In [1]:
import logging
logging.getLogger().setLevel(logging.ERROR)

In [4]:
import os
import pandas as pd
import torch
from tqdm import tqdm

In [7]:
import os
import pandas as pd
import torch
from tqdm import tqdm

def prepare_split(split):
    prompt = r"""### INSTRUCTIONS ###
Continue the conversation by generating **only the next line** spoken by the indicated character.
Your response must be empathetic, showing understanding or emotional attunement to the preceding dialogue.

### EXAMPLE ###

=== DIALOGUE HISTORY ===
Rachel: Hey!
Ross: Hi!
Rachel: What are you doing here?
Ross: Ah y'know, this building is on my paper route so I...
Rachel: Oh.
Ross: Hi.
Rachel: Hi.
Ross: How'd did it go?
Rachel: Oh well, the woman I interviewed with was pretty tough, but y'know thank God Mark coached me, because once I started talking about the fall line, she got all happy and wouldn't shut up.
Ross:

=== RESPONSE ===
That sounds like a huge relief.

### TASK ###

=== DIALOGUE HISTORY ===
{dialogue_hist}

=== RESPONSE ===
"""

    dataset_dir = '/project/msoleyma_1026/EmpatheticResponseGeneration'
    dialogues_df = pd.read_csv(f'{dataset_dir}/MELD.Raw/{split}_sent_emo.csv').groupby('Dialogue_ID')
    targets_df = pd.read_csv(f'{dataset_dir}/Targets/{split}_structured.csv')

    data = []
    for d_id, dialogue in tqdm(dialogues_df, total=len(dialogues_df)):
        dialogue_hist = ""

        for _, row in dialogue.iterrows():
            u_id = row['Utterance_ID']

            multimodal_embed_path = f'{dataset_dir}/ImagebindEmbeds/{split}/dia{d_id}_utt{u_id}.pt'

            if not os.path.isfile(multimodal_embed_path):
                continue
            
            dialogue_hist += f"{row['Speaker']}: {row['Utterance']}\n"
            multimodal_embed = torch.load(f'{dataset_dir}/ImagebindEmbeds/{split}/dia{d_id}_utt{u_id}.pt')
            target_response = targets_df[(targets_df['Dialogue_ID'] == d_id) & (targets_df['Utterance_ID'] == u_id)]['Response'].values[0]

            next_speaker = dialogue[dialogue['Utterance_ID'] == u_id + 1].iloc[0]['Speaker']
            data.append({
                'Multimodal_Embed': multimodal_embed,
                'Prompt': prompt.format(dialogue_hist=f"{dialogue_hist}{next_speaker}:"),
                'Target_Response': target_response
            })
    
    return data

train_data = prepare_split('train')

100%|██████████| 1038/1038 [02:06<00:00,  8.23it/s]


In [8]:
from torch.utils.data import Dataset, DataLoader

class MultimodalMELD(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        u = self.data[idx]

        return {
            'multimodal_embed': u['Multimodal_Embed'],
            'prompt': u['Prompt'],
            'target_response': u['Target_Response']
        }

train_dataset = MultimodalMELD(train_data)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

In [9]:
from huggingface_hub import HfFolder
HfFolder.save_token("HUGGING_FACE_TOKEN")

In [5]:
!nvidia-smi

Sat May  3 17:00:50 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:81:00.0 Off |                    0 |
| N/A   26C    P0             42W /  300W |       1MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [5]:
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

class EmpatheticMLLM(nn.Module):
    def __init__(self):
        super(EmpatheticMLLM, self).__init__()

        # quantized Mistral-7B-Instruct-v0.3
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="float16"
        )
        
        model_id = 'meta-llama/Meta-Llama-3-8B-instruct'
        
        self.llm = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # project from ImageBind embedding space to Mistral embedding space
        self.projector = nn.Linear(1024, self.llm.config.hidden_size) # 1024 -> 4096

        lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            task_type='CAUSAL_LM'
        )

        self.llm = get_peft_model(self.llm, lora_config)

    def forward(self, x, training=False):
        device = next(self.parameters()).device

        # project to LLM embedding space
        multimodal_embed = self.projector(x['multimodal_embed']).to(self.llm.dtype)

        # tokenize prompt
        prompt_tokenized = self.tokenizer(x['prompt'], return_tensors="pt", padding=True, truncation=True, max_length=32768).to(device)
        prompt_ids = prompt_tokenized.input_ids
        prompt_attention_mask = prompt_tokenized.attention_mask

        # embed prompt tokens
        prompt_embeds = self.llm.model.model.embed_tokens(prompt_ids)

        # concatenate: [multimodal ImageBind embedding] + [prompt embeddings]
        inputs_embeds = torch.cat([multimodal_embed, prompt_embeds], dim=1)
        attention_mask = torch.cat([torch.ones(multimodal_embed.size(0), 1, device=device), prompt_attention_mask], dim=1)

        # if training, provide labels for supervised learning
        if training:
            # tokenize target response
            target_tokenized = self.tokenizer(x['target_response'], return_tensors="pt", padding=True, truncation=True).to(device)
            target_ids = target_tokenized.input_ids

            # embed target response tokens
            target_embeds = self.llm.model.model.embed_tokens(target_ids)
            
            # concatenate target embeddings
            inputs_embeds = torch.cat([inputs_embeds, target_embeds], dim=1)
            
            # ignore multimodal token + prompt tokens in loss calculation
            bs = multimodal_embed.size(0)
            mask_len = multimodal_embed.size(1) + prompt_embeds.size(1)
            labels = torch.cat([torch.full((bs, mask_len), -100, device=device), target_ids], dim=1)

            outputs = self.llm(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                labels=labels
            )
            return outputs

        # otherwise, just generate output
        else:
            outputs = self.llm.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                max_new_tokens=25
            )
            return outputs

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = EmpatheticMLLM().to(device)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [10]:
from torch.optim import AdamW
from tqdm import tqdm

projector_params = []
for name, param in model.named_parameters():
    if "projector" in name:
        projector_params.append(param)

qlora_params = []
for name, param in model.llm.named_parameters():
    if "lora" in name:
        qlora_params.append(param)
    
optimizer = AdamW(projector_params + qlora_params, lr=1e-5)

epochs = 5
for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
        optimizer.zero_grad()

        x = {
            'multimodal_embed': batch['multimodal_embed'].to(device),
            'prompt': batch['prompt'],
            'target_response': batch['target_response']
        }

        output = model(x, training=True)
        loss = output.loss
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    epoch_loss = epoch_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}")

torch.save(model.state_dict(), "finetuned_mllm_llama_3.pth")
print("Final model saved successfully!")

Epoch 1/5:   0%|          | 0/8572 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Epoch 1/5: 100%|██████████| 8572/8572 [22:44<00:00,  6.28it/s]


Epoch 1/5, Loss: 1.1461


Epoch 2/5: 100%|██████████| 8572/8572 [22:46<00:00,  6.27it/s]


Epoch 2/5, Loss: 0.9438


Epoch 3/5: 100%|██████████| 8572/8572 [22:45<00:00,  6.28it/s]


Epoch 3/5, Loss: 0.8639


Epoch 4/5: 100%|██████████| 8572/8572 [22:46<00:00,  6.27it/s]


Epoch 4/5, Loss: 0.7889


Epoch 5/5: 100%|██████████| 8572/8572 [22:47<00:00,  6.27it/s]


Epoch 5/5, Loss: 0.7128
Final model saved successfully!
