# MUSE

## Imports

In [1]:
import tarfile
import os
from dataset import MELDDataset

## Data

In [2]:
# Extract a ".tar.gz" archive into `data_folder`, creating the folder if it does not exist.
def extract_data(file_name, data_folder):
    """Extract a gzip-compressed tar archive into ``data_folder``.

    Args:
        file_name: Path of the ``.tar.gz`` archive to read.
        data_folder: Directory that receives the archive contents. It is
            created (including missing parents) when it does not exist.
    """
    # exist_ok=True replaces the separate exists() check, which could race
    # with another process creating the same directory.
    os.makedirs(data_folder, exist_ok=True)
    # The context manager closes the archive even when extraction raises.
    # NOTE: extractall() writes to whatever member paths the archive stores;
    # only use this on archives from a trusted source.
    with tarfile.open(file_name, "r:gz") as tar:
        tar.extractall(data_folder)

In [3]:
# extract_data("raw/train.tar.gz", "data")

In [4]:
# count the number of videos in the dataset
def count_videos(data_folder):
    """Return the number of entries that ``os.listdir`` reports for ``data_folder``."""
    entries = os.listdir(data_folder)
    return len(entries)

# Report how many entries the training-split folder contains.
print(count_videos("data/train_splits/"))

9990


In [5]:
# setup the dataset
# Build the training split using MELDDataset (imported from `dataset`).
dataset = MELDDataset(
    csv_file="data/train_sent_emo.csv", 
    root_dir="./data",
    split_type="train"
)

In [6]:
# Load the first sample. In the recorded run this raised the RuntimeError shown
# below: no backend could be found to handle the sample's .wav file.
example = dataset[0]

RuntimeError: Couldn't find appropriate backend to handle uri ./data/train_splits/audio/dia0_utt0.wav and format None.

In [None]:


############################################
# 1. Data Download & Alignment
############################################
def download_and_prepare_cmu_mosi(data_path="cmumosi/"):
    """Build the CMU-MOSI high-level dataset, attach its labels, and align it.

    Args:
        data_path: Location handed to ``mmdatasdk`` for both the high-level
            features and the label sequences.

    Returns:
        The ``mmdatasdk.mmdataset`` object after alignment to
        'Opinion Segment Labels'.

    NOTE(review): ``mmdatasdk`` is not imported anywhere in the visible
    source, and the recorded run ended with a NameError for it. Confirm the
    import exists before relying on this function.
    """
    mosi = mmdatasdk.mmdataset(mmdatasdk.cmu_mosi.highlevel, data_path)
    mosi.add_computational_sequences(mmdatasdk.cmu_mosi.labels, data_path)
    mosi.align('Opinion Segment Labels')
    print("Data download and alignment complete.")
    return mosi

# Executed at module/cell level: prepares the dataset with the default data_path.
cmumosi_highlevel = download_and_prepare_cmu_mosi()

############################################
# 2. PyTorch Dataset
############################################
class CMUMOSIDataset(Dataset):
    """In-memory dataset of per-segment text, audio and visual features with labels.

    Every segment found under the "Opinion Segment Labels" sequence becomes one
    sample. Features are read from the "glove_vectors", "COVAREP" and
    "FACET_4.2" sequences; any feature that cannot be found is replaced by a
    one-element zero array.

    NOTE(review): this code treats ``data[vid]["features"]`` as a mapping keyed
    by segment -- confirm that this matches the layout mmdatasdk produces.
    """

    def __init__(self, mmdataset_obj, split='train'):
        """Collect every labelled segment from ``mmdataset_obj``.

        Args:
            mmdataset_obj: Object exposing a ``computational_sequences``
                mapping whose values have a ``data`` mapping.
            split: Accepted for interface compatibility but not read; all
                segments are loaded regardless of its value.
        """
        self.data = []

        # Resolve the label sequence once instead of on every iteration.
        label_sequence = mmdataset_obj.computational_sequences["Opinion Segment Labels"].data
        for vid, segment_data in label_sequence.items():
            labels = segment_data["features"]
            for seg_key in labels:
                label = labels[seg_key]
                self.data.append({
                    "text": self._get_feature(mmdataset_obj, "glove_vectors", vid, seg_key),
                    "audio": self._get_feature(mmdataset_obj, "COVAREP", vid, seg_key),
                    "visual": self._get_feature(mmdataset_obj, "FACET_4.2", vid, seg_key),
                    # Drop the leading axis only when it has length one.
                    "label": label.squeeze() if label.shape[0] == 1 else label,
                })

    def _get_feature(self, mmdataset_obj, feature_key, vid, seg_key):
        """Return one segment's feature, or a one-element zero array if absent.

        The zero fallback applies when the feature sequence, the video, or the
        segment is missing. Previously a missing video raised KeyError while
        the other two cases fell back to zeros.
        """
        sequences = mmdataset_obj.computational_sequences
        if feature_key not in sequences:
            return np.zeros((1,))
        per_video = sequences[feature_key].data
        if vid not in per_video:
            return np.zeros((1,))
        feature_data = per_video[vid]["features"]
        if seg_key not in feature_data:
            return np.zeros((1,))
        return feature_data[seg_key]

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(text, audio, visual, label)`` as float tensors for ``index``."""
        sample = self.data[index]
        text_tensor = torch.tensor(sample["text"], dtype=torch.float)
        audio_tensor = torch.tensor(sample["audio"], dtype=torch.float)
        visual_tensor = torch.tensor(sample["visual"], dtype=torch.float)
        label_tensor = torch.tensor(sample["label"], dtype=torch.float)
        return text_tensor, audio_tensor, visual_tensor, label_tensor

# Wrap the aligned dataset and batch it in shuffled groups of 8.
# NOTE(review): `split` is accepted by CMUMOSIDataset but not read in its
# constructor, so this loads every segment rather than a train subset.
train_dataset = CMUMOSIDataset(cmumosi_highlevel, split='train')
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
print("Total samples in train dataset:", len(train_dataset))

############################################
# 3. Model Definitions
############################################
# class TextEncoder(nn.Module):
#     def __init__(self, vocab_size=5000, embed_dim=300, hidden_dim=128):
#         super(TextEncoder, self).__init__()
#         self.embedding = nn.Embedding(vocab_size, embed_dim)
#         self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
#         self.hidden_dim = hidden_dim

#     def forward(self, x):
#         if x.dim() == 2:
#             x = x.unsqueeze(1)
#         _, (h_n, _) = self.lstm(x)
#         h_n = h_n.squeeze(0)
#         return h_n
    


class TextEncoder(nn.Module):
    """Sentence encoder that mean-pools a pretrained model's last hidden states.

    Args:
        model_name: Identifier passed to ``AutoTokenizer.from_pretrained`` and
            ``AutoModel.from_pretrained``.
    """

    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def forward(self, input_texts):
        """Tokenize ``input_texts`` and return one embedding per input.

        Args:
            input_texts: Text (or batch of texts) accepted by the tokenizer.

        Returns:
            The mean of ``last_hidden_state`` over the sequence axis (dim 1).
        """
        inputs = self.tokenizer(
            input_texts,
            padding=True,
            truncation=True,
            return_tensors="pt"
        )

        # The tokenizer returns freshly built tensors; move them to wherever
        # the wrapped model lives so the forward pass also works after
        # `.to("cuda")` has been called on this module or its parent.
        device = next(self.model.parameters()).device
        inputs = inputs.to(device)

        outputs = self.model(**inputs)
        # NOTE(review): the mean runs over every position, including those
        # added by padding=True -- consider weighting by the attention mask.
        embeddings = outputs.last_hidden_state.mean(dim=1)

        return embeddings

class AudioEncoder(nn.Module):
    """Encode audio features with a 1-D convolution and global average pooling.

    Args:
        input_dim: Number of features per step; used as the convolution's
            input channel count.
        hidden_dim: Size of the returned representation.
    """

    def __init__(self, input_dim=74, hidden_dim=128):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(64, hidden_dim)

    def forward(self, x):
        """Return a ``(batch, hidden_dim)`` representation of ``x``.

        Args:
            x: Either ``(batch, steps, input_dim)`` or ``(batch, input_dim)``;
                the latter is treated as a sequence with a single step.
        """
        if x.dim() == 2:
            # Insert the step axis at position 1 so the transpose below puts
            # the features on the channel axis. Inserting it at position 2
            # produced (batch, 1, input_dim), which the convolution rejects.
            x = x.unsqueeze(1)
        # Conv1d expects (batch, channels, length).
        x = x.transpose(1, 2)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.squeeze(-1)
        x = self.fc(x)
        return x

class VisualEncoder(nn.Module):
    """Two-layer perceptron over visual features.

    Args:
        input_dim: Size of the last axis of the input.
        hidden_dim: Size of the returned representation.
    """

    def __init__(self, input_dim=35, hidden_dim=128):
        super().__init__()
        bottleneck = 64
        self.fc1 = nn.Linear(input_dim, bottleneck)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(bottleneck, hidden_dim)

    def forward(self, x):
        """Average inputs with more than two axes over dim 1, then apply the MLP."""
        pooled = x.mean(dim=1) if x.dim() > 2 else x
        return self.fc2(self.relu(self.fc1(pooled)))

class MultimodalModel(nn.Module):
    """Late-fusion model that concatenates text, audio and visual encodings.

    Args:
        text_hidden_dim: Width of the text representation fed to the fusion layer.
        audio_hidden_dim: Width of the audio representation.
        visual_hidden_dim: Width of the visual representation.
        fusion_output_dim: Width of the fused hidden layer.
        num_classes: Number of outputs of the final layer.
    """

    def __init__(self, text_hidden_dim=128, audio_hidden_dim=128, visual_hidden_dim=128,
                 fusion_output_dim=64, num_classes=1):
        super().__init__()
        # TextEncoder only accepts a model name; passing `embedding_dim`
        # raised TypeError at construction time.
        self.text_encoder = TextEncoder()
        # The pretrained encoder has its own fixed output width, so project it
        # to text_hidden_dim; otherwise the concatenation below would not
        # match the input size of fusion_fc1.
        self.text_proj = nn.Linear(self.text_encoder.model.config.hidden_size, text_hidden_dim)
        self.audio_encoder = AudioEncoder(hidden_dim=audio_hidden_dim)
        self.visual_encoder = VisualEncoder(hidden_dim=visual_hidden_dim)

        total_fusion_dim = text_hidden_dim + audio_hidden_dim + visual_hidden_dim

        self.fusion_fc1 = nn.Linear(total_fusion_dim, fusion_output_dim)
        self.relu = nn.ReLU()
        self.final_fc = nn.Linear(fusion_output_dim, num_classes)

    def forward(self, text_input, audio_input, visual_input):
        """Encode each modality, fuse by concatenation, and predict.

        Returns:
            The final layer's output with a trailing size-1 axis removed, so
            the result is ``(batch,)`` when ``num_classes == 1`` and
            ``(batch, num_classes)`` otherwise.
        """
        text_repr = self.text_proj(self.text_encoder(text_input))
        audio_repr = self.audio_encoder(audio_input)
        visual_repr = self.visual_encoder(visual_input)

        fused = torch.cat([text_repr, audio_repr, visual_repr], dim=1)
        fused = self.fusion_fc1(fused)
        fused = self.relu(fused)

        out = self.final_fc(fused)
        # Squeeze only the last axis: a bare squeeze() also removed the batch
        # axis for a batch of one, which no longer lined up with the labels.
        return out.squeeze(-1)

############################################
# 4. Training Loop (Simplified)
############################################
# Regression setup: a single output unit trained with mean-squared error.
model = MultimodalModel(num_classes=1)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    
    for i, (text_batch, audio_batch, visual_batch, label_batch) in enumerate(train_loader):
        # NOTE(review): text_batch is a float tensor built by
        # CMUMOSIDataset.__getitem__, while TextEncoder.forward hands its
        # input to a tokenizer -- confirm which text input is intended.
        text_batch = text_batch.to(device)
        audio_batch = audio_batch.to(device)
        visual_batch = visual_batch.to(device)
        label_batch = label_batch.to(device)
        
        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(text_batch, audio_batch, visual_batch)
        loss = criterion(outputs, label_batch)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    # Mean of the per-batch losses for this epoch.
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

NameError: name 'mmdatasdk' is not defined