# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import numpy as np

class DacneEncoder(nn.Module):
    """Encode a motion-capture sequence into a unit-length embedding.

    A bidirectional LSTM reads the sequence, its per-step outputs are averaged
    over dim 1 (the time axis, because the LSTM is built with
    ``batch_first=True``), projected to ``embed_dim`` by a linear layer and
    L2-normalised along the last dimension.

    NOTE: the class name is spelled ``DacneEncoder`` (sic). It is left
    unchanged so that existing references to it keep working.
    """

    def __init__(self, input_size, hidden_size, embed_dim, num_layers=2):
        """
        Using an LSTM to process the motion capture sequences.

        Args:
            input_size: Number of features per time step fed to the LSTM.
            hidden_size: Hidden size of the LSTM, per direction.
            embed_dim: Size of the output embedding.
            num_layers: Number of stacked LSTM layers.
        """
        super().__init__()
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward outputs are concatenated, hence 2 * hidden_size.
        self.fc = nn.Linear(hidden_size * 2, embed_dim)

    def forward(self, x):
        """Return L2-normalised embeddings for a batch of sequences.

        Args:
            x: Tensor laid out batch-first, i.e. (batch, seq_len, input_size).

        Returns:
            Tensor of shape (batch, embed_dim) whose rows have unit L2 norm.
        """
        lstm_out, _ = self.lstm(x)
        # Mean-pool over the time axis to get one vector per sequence.
        pooled = lstm_out.mean(dim=1)
        embed = self.fc(pooled)

        return F.normalize(embed, dim=-1)

class TextEncoder(nn.Module):
    """Encode a token-id sequence into a unit-length embedding.

    Token ids are embedded, run through a GRU, averaged over dim 1 (the time
    axis, because the GRU is built with ``batch_first=True``), projected to
    ``embed_dim`` and L2-normalised along the last dimension.
    """

    def __init__(self, vocab_size, embed_dim, hidden_size, num_layers=1):
        """
        A simple text encoder that uses an embedding layer and a GRU to process text sequences.

        Args:
            vocab_size: Number of entries in the embedding table.
            embed_dim: Size of the output embedding.
            hidden_size: Width of both the token embeddings and the GRU state.
            num_layers: Number of stacked GRU layers.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, embed_dim)

    def forward(self, x):
        """Return L2-normalised embeddings for a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids, laid out batch-first
                (batch, seq_len).

        Returns:
            Tensor of shape (batch, embed_dim) whose rows have unit L2 norm.
        """
        x = self.embedding(x)
        output, _ = self.rnn(x)
        # Mean-pool over the time axis to get one vector per sequence.
        pooled = output.mean(dim=1)
        embed = self.fc(pooled)

        return F.normalize(embed, dim=-1)

    # The method was originally misspelled ``forwad``, which meant calling the
    # module (``encoder(x)``) never reached it. The old name is kept as an
    # alias so any direct ``encoder.forwad(x)`` calls keep working.
    forwad = forward

def contrastive_loss(dance_embeds, text_embeds, temperature=0.07):
    """
    Symmetric InfoNCE loss over a batch of paired dance/text embeddings.

    Row ``i`` of ``dance_embeds`` is treated as the positive match for row
    ``i`` of ``text_embeds``; every other row in the batch is a negative.

    Args:
        dance_embeds: 2-D tensor of dance embeddings, one row per sample.
        text_embeds: 2-D tensor of text embeddings with the same number of rows.
        temperature: Divisor applied to the similarity matrix before softmax.

    Returns:
        Scalar tensor: mean of the dance-to-text and text-to-dance
        cross-entropy losses.
    """
    # Pairwise similarity matrix, sharpened by the temperature.
    similarity = (dance_embeds @ text_embeds.t()) / temperature

    # The matching pair for row i sits on the diagonal, so the target is i.
    batch_size = dance_embeds.size(0)
    targets = torch.arange(batch_size, device=dance_embeds.device)

    dance_to_text = F.cross_entropy(similarity, targets)
    text_to_dance = F.cross_entropy(similarity.t(), targets)

    return 0.5 * (dance_to_text + text_to_dance)

def train_multimodal_model(dance_data, text_data, dance_encoder, text_encoder, optimizer, num_epochs=10):
    """Jointly train the dance and text encoders with the contrastive loss.

    Each epoch iterates ``dance_data`` and ``text_data`` in lockstep, so batch
    ``i`` of one is assumed to be paired with batch ``i`` of the other. If the
    two differ in length, iteration stops at the shorter one.

    Args:
        dance_data: Re-iterable collection of dance batches (e.g. a list or a
            DataLoader). It is iterated once per epoch.
        text_data: Re-iterable collection of text batches paired with
            ``dance_data``.
        dance_encoder: Module mapping a dance batch to embeddings.
        text_encoder: Module mapping a text batch to embeddings.
        optimizer: Optimizer used for ``zero_grad()`` / ``step()``; presumably
            constructed over both encoders' parameters (confirm with caller).
        num_epochs: Number of passes over the data.

    Returns:
        List with the average loss of each epoch, in order.

    Raises:
        ValueError: If an epoch yields no batches (empty data, or a one-shot
            iterator that was exhausted by an earlier epoch).
    """
    history = []
    dance_encoder.train()
    text_encoder.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        # Count batches actually processed: zip() stops at the shorter input
        # and the inputs need not support len().
        num_batches = 0

        for dance_batch, text_batch in zip(dance_data, text_data):
            optimizer.zero_grad()
            dance_embeds = dance_encoder(dance_batch)
            text_embeds = text_encoder(text_batch)

            loss = contrastive_loss(dance_embeds, text_embeds)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError(
                f"No batches were produced in epoch {epoch + 1}; "
                "dance_data/text_data are empty or are exhausted iterators"
            )

        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
        history.append(avg_loss)
    return history

def generate_dance_from_text(text_input, dance_encoder, text_encoder, holdout_dance_embeddings, holdout_dance_data):
    """Retrieve the holdout dance that best matches a text query.

    The text is embedded with ``text_encoder`` and compared, by dot product,
    against every row of ``holdout_dance_embeddings``; the entry of
    ``holdout_dance_data`` at the best-scoring row is returned.

    Args:
        text_input: Input accepted by ``text_encoder``.
        dance_encoder: Unused here; kept so the signature stays unchanged.
        text_encoder: Callable mapping ``text_input`` to an embedding tensor,
            either (batch, dim) or (dim,).
        holdout_dance_embeddings: 2-D tensor with one embedding per row;
            presumably row ``i`` belongs to ``holdout_dance_data[i]``
            (confirm with caller).
        holdout_dance_data: Indexable collection of candidate dances.

    Returns:
        The element of ``holdout_dance_data`` with the highest similarity. If
        the query holds several embeddings, the candidate with the highest
        similarity to any of them is returned.

    Raises:
        ValueError: If ``holdout_dance_embeddings`` has no rows.
    """
    if holdout_dance_embeddings.size(0) == 0:
        raise ValueError("holdout_dance_embeddings is empty; nothing to retrieve")

    # Retrieval needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        text_embed = text_encoder(text_input)
        similarities = torch.matmul(holdout_dance_embeddings, text_embed.t())
        # Collapse to one score per candidate so the argmax is always a valid
        # row index, whether the query was (1, dim), (dim,) or (batch, dim).
        per_candidate = similarities.reshape(similarities.size(0), -1).max(dim=1).values
        best_index = torch.argmax(per_candidate).item()

    return holdout_dance_data[best_index]

def generate_text_from_dance(dance_input, dance_encoder, text_encoder, holdout_text_embeddings, holdout_text_data):
    """Retrieve the holdout text that best matches a dance query.

    The dance is embedded with ``dance_encoder`` and compared, by dot product,
    against every row of ``holdout_text_embeddings``; the entry of
    ``holdout_text_data`` at the best-scoring row is returned.

    Args:
        dance_input: Input accepted by ``dance_encoder``.
        dance_encoder: Callable mapping ``dance_input`` to an embedding
            tensor, either (batch, dim) or (dim,).
        text_encoder: Unused here; kept so the signature stays unchanged.
        holdout_text_embeddings: 2-D tensor with one embedding per row;
            presumably row ``i`` belongs to ``holdout_text_data[i]``
            (confirm with caller).
        holdout_text_data: Indexable collection of candidate texts.

    Returns:
        The element of ``holdout_text_data`` with the highest similarity. If
        the query holds several embeddings, the candidate with the highest
        similarity to any of them is returned.

    Raises:
        ValueError: If ``holdout_text_embeddings`` has no rows.
    """
    if holdout_text_embeddings.size(0) == 0:
        raise ValueError("holdout_text_embeddings is empty; nothing to retrieve")

    # Retrieval needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        dance_embed = dance_encoder(dance_input)
        similarities = torch.matmul(holdout_text_embeddings, dance_embed.t())
        # Collapse to one score per candidate so the argmax is always a valid
        # row index, whether the query was (1, dim), (dim,) or (batch, dim).
        per_candidate = similarities.reshape(similarities.size(0), -1).max(dim=1).values
        best_index = torch.argmax(per_candidate).item()

    return holdout_text_data[best_index]



