In [None]:
import numpy as np
import pickle
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs up front so Restart & Run All reproduces the same run:
# torch drives weight initialisation, dropout and DataLoader shuffling;
# numpy is seeded for any numpy-side randomness. (Bit-exact GPU results may
# additionally require deterministic cuDNN settings.)
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# AnnoMI annotation table; the next cell reads its label-name columns.
# pd.read_pickle unpickles arbitrary objects, so only point this at trusted files.
ANNOTATIONS_PATH = 'AnnoMI-full-with-audio-cleaned-text.pkl'
df_anno = pd.read_pickle(ANNOTATIONS_PATH)

In [None]:
# Human-readable class names for the two heads, one list per annotation column
# (missing annotations dropped).
# NOTE(review): .unique() keeps first-appearance order, while these lists are
# later used as classification_report target_names (matched to label ids in
# sorted order) — confirm the label encoding behind the embedding pickles
# follows this same order.
client_class_names, therapist_class_names = (
    df_anno[column].dropna().unique().tolist()
    for column in ('client_talk_type', 'main_therapist_behaviour')
)

client_class_names, therapist_class_names

In [None]:
def _load_pickle(path):
    """Unpickle and return the object stored at `path`.

    pickle can execute arbitrary code while loading, so only call this on
    trusted files.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Precomputed embeddings per modality (AST for audio, RoBERTa for text) and
# per task; each object is indexed by "embeddings" below, and the audio ones
# also by "labels".
audio_data_client = _load_pickle('audio_embeddings_ast_client.pkl')
audio_data_therapist = _load_pickle('audio_embeddings_ast_therapist.pkl')
text_data_client = _load_pickle('text_embeddings_roberta_client.pkl')
text_data_therapist = _load_pickle('text_embeddings_roberta_therapist.pkl')

# Labels are taken from the audio pickles and used for both modalities.
labels_client = audio_data_client["labels"]
labels_therapist = audio_data_therapist["labels"]


In [None]:
def _split_modalities(audio_embeddings, text_embeddings, labels, test_size=0.2, random_state=42):
    """Split paired audio/text embeddings and labels into train/val partitions.

    The row indices are split once and the same indices are applied to both
    modalities, so row i of the audio and text outputs is always the same
    sample. MergedDataset pairs the modalities by position, so this alignment
    is required.

    Returns:
        (X_audio_train, X_audio_val, X_text_train, X_text_val, y_train, y_val)

    Raises:
        ValueError: if the three inputs do not have the same number of rows.
    """
    audio_embeddings = np.asarray(audio_embeddings)
    text_embeddings = np.asarray(text_embeddings)
    labels = np.asarray(labels)
    if not (len(audio_embeddings) == len(text_embeddings) == len(labels)):
        raise ValueError(
            f"Row counts differ: audio={len(audio_embeddings)}, "
            f"text={len(text_embeddings)}, labels={len(labels)}")
    # train_test_split's shuffle depends only on the row count and
    # random_state, so this yields the same partition as splitting each
    # array separately with the same arguments.
    train_idx, val_idx = train_test_split(
        np.arange(len(labels)), test_size=test_size, random_state=random_state)
    return (audio_embeddings[train_idx], audio_embeddings[val_idx],
            text_embeddings[train_idx], text_embeddings[val_idx],
            labels[train_idx], labels[val_idx])


(X_audio_train_client, X_audio_val_client,
 X_text_train_client, X_text_val_client,
 y_audio_train_client, y_audio_val_client) = _split_modalities(
    audio_data_client["embeddings"], text_data_client["embeddings"], labels_client)

(X_audio_train_therapist, X_audio_val_therapist,
 X_text_train_therapist, X_text_val_therapist,
 y_audio_train_therapist, y_audio_val_therapist) = _split_modalities(
    audio_data_therapist["embeddings"], text_data_therapist["embeddings"], labels_therapist)

# Both modalities share one label array, so the text-side labels are the
# same arrays; the names are kept for any code that refers to them.
y_text_train_client, y_text_val_client = y_audio_train_client, y_audio_val_client
y_text_train_therapist, y_text_val_therapist = y_audio_train_therapist, y_audio_val_therapist


In [None]:
class MergedDataset(Dataset):
    """Row-aligned audio embeddings, text embeddings and labels.

    Row ``i`` of each input is treated as the same sample, so all three
    sequences must have the same length.

    Each item is a dict with float32 ``"audio_embeddings"`` and
    ``"text_embeddings"`` tensors and an int64 ``"labels"`` tensor.
    """

    def __init__(self, audio_embeddings, text_embeddings, labels):
        assert len(audio_embeddings) == len(text_embeddings) == len(labels), "Data dimensions don't match!"
        self.audio_embeddings = audio_embeddings
        self.text_embeddings = text_embeddings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        audio = torch.tensor(self.audio_embeddings[idx], dtype=torch.float32)
        text = torch.tensor(self.text_embeddings[idx], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        sample = dict(audio_embeddings=audio, text_embeddings=text, labels=label)
        return sample

In [None]:
batch_size = 32

# Per task: wrap each split in a MergedDataset (labels from the audio split,
# since both modalities were split from the same label array) and batch it.
# Only the training loaders shuffle; validation order stays fixed.
train_dataset_client = MergedDataset(
    audio_embeddings=X_audio_train_client, text_embeddings=X_text_train_client, labels=y_audio_train_client)
val_dataset_client = MergedDataset(
    audio_embeddings=X_audio_val_client, text_embeddings=X_text_val_client, labels=y_audio_val_client)
train_loader_client = DataLoader(dataset=train_dataset_client, batch_size=batch_size, shuffle=True)
val_loader_client = DataLoader(dataset=val_dataset_client, batch_size=batch_size, shuffle=False)

train_dataset_therapist = MergedDataset(
    audio_embeddings=X_audio_train_therapist, text_embeddings=X_text_train_therapist, labels=y_audio_train_therapist)
val_dataset_therapist = MergedDataset(
    audio_embeddings=X_audio_val_therapist, text_embeddings=X_text_val_therapist, labels=y_audio_val_therapist)
train_loader_therapist = DataLoader(dataset=train_dataset_therapist, batch_size=batch_size, shuffle=True)
val_loader_therapist = DataLoader(dataset=val_dataset_therapist, batch_size=batch_size, shuffle=False)

In [None]:
class LateFusionMTL(nn.Module):
    """Audio + text classifier with a shared fused trunk and two task heads.

    Each modality is projected to ``fusion_dim`` by its own MLP. The two
    projections are concatenated and passed through a shared layer, then
    routed to either the client or the therapist classification head.

    Args:
        input_dim_audio: feature size of one audio embedding.
        input_dim_text: feature size of one text embedding.
        num_classes_client: number of logits from the client head.
        num_classes_therapist: number of logits from the therapist head.
        fusion_dim: width of each modality projection (fused width is 2x).
        dropout: dropout probability applied after each hidden ReLU.
    """

    def __init__(self, input_dim_audio, input_dim_text, num_classes_client, num_classes_therapist, fusion_dim=256, dropout=0):
        super().__init__()

        # Submodules are created in the same order, under the same attribute
        # names and with the same layer indices as before. This keeps parameter
        # initialisation order and state_dict keys (saved checkpoints)
        # unchanged.
        self.audio_net = self._mlp(input_dim_audio, 512, fusion_dim, dropout)
        self.text_net = self._mlp(input_dim_text, 512, fusion_dim, dropout)
        self.shared = nn.Sequential(
            nn.Linear(fusion_dim * 2, 512),  # input is the concatenation of both projections
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.client_classifier = self._mlp(512, 256, num_classes_client, dropout)
        self.therapist_classifier = self._mlp(512, 256, num_classes_therapist, dropout)

    @staticmethod
    def _mlp(in_dim, hidden_dim, out_dim, dropout):
        """Build Linear -> ReLU -> Dropout -> Linear."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, audio_embeddings, text_embeddings, task_name=None):
        """Return the logits of the head selected by ``task_name``.

        Raises:
            ValueError: if ``task_name`` is neither "client" nor "therapist".
        """
        projections = (self.audio_net(audio_embeddings), self.text_net(text_embeddings))
        trunk = self.shared(torch.cat(projections, dim=1))

        for name, head in (("client", self.client_classifier),
                           ("therapist", self.therapist_classifier)):
            if task_name == name:
                return head(trunk)
        raise ValueError(f"Invalid task_name: {task_name}. Expected 'client' or 'therapist'.")

In [None]:
# One model serves both tasks, so its input layers are sized from the client
# split but will also receive therapist batches. Check that the therapist
# embeddings have the same widths here rather than failing inside a matmul
# on the first therapist batch.
input_dim_audio = X_audio_train_client.shape[1]
input_dim_text = X_text_train_client.shape[1]
if (X_audio_train_therapist.shape[1] != input_dim_audio
        or X_text_train_therapist.shape[1] != input_dim_text):
    raise ValueError(
        f"Client and therapist embedding widths differ: audio "
        f"{input_dim_audio} vs {X_audio_train_therapist.shape[1]}, text "
        f"{input_dim_text} vs {X_text_train_therapist.shape[1]}")

# Head sizes = number of distinct label ids per task. CrossEntropyLoss then
# requires every label id to lie in [0, num_classes).
num_classes_client = len(np.unique(labels_client))
num_classes_therapist = len(np.unique(labels_therapist))

model = LateFusionMTL(input_dim_audio, input_dim_text, num_classes_client, num_classes_therapist).to(device)

In [None]:
# A single Adam optimizer over every parameter (both modality branches, the
# shared layer and both heads), plus one cross-entropy criterion reused by
# both tasks.
learning_rate = 1e-5
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Per-epoch history; the plotting cell reads these lists.
training_losses = []
validation_losses = []
training_f1_scores_client = []
validation_f1_scores_client = []
training_f1_scores_therapist = []
validation_f1_scores_therapist = []

# Best validation macro-F1 seen so far, per task. These start below any
# attainable F1 (which is >= 0) so the first epoch always writes both
# checkpoints. With 0.0 and a strict `>`, a head stuck at F1 == 0 would never
# be saved, and the evaluation cell's torch.load would fail on a missing file.
best_f1_score_client = -1.0
best_f1_score_therapist = -1.0

epochs = 100


def _run_task_pass(loader, task_name, desc, train):
    """Run `model` over every batch of `loader` through one task head.

    Args:
        loader: DataLoader yielding dicts with 'audio_embeddings',
            'text_embeddings' and 'labels' tensors.
        task_name: "client" or "therapist"; selects the classifier head.
        desc: tqdm progress-bar label.
        train: when True, take one optimizer step per batch; when False,
            run under torch.no_grad() without updating weights.

    The caller is responsible for putting `model` in train()/eval() mode.

    Returns:
        (loss_sum, n_batches, preds, targets): sum of per-batch losses, number
        of batches, and flat lists of predicted and true label ids.
    """
    loss_sum, n_batches = 0.0, 0
    preds, targets = [], []
    with (torch.enable_grad() if train else torch.no_grad()):
        # leave=False: four bars per epoch over 100 epochs would otherwise
        # flood the cell output.
        for batch in tqdm(loader, desc=desc, position=0, leave=False):
            audio = batch['audio_embeddings'].to(device)
            text = batch['text_embeddings'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(audio, text, task_name=task_name)
            loss = loss_fn(outputs, labels)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            loss_sum += loss.item()
            n_batches += 1
            preds.extend(outputs.argmax(dim=1).cpu().numpy())
            targets.extend(labels.cpu().numpy())
    return loss_sum, n_batches, preds, targets


for epoch in range(epochs):
    # Training: a full pass over the client data, then one over the therapist data.
    model.train()
    loss_c, n_c, train_preds_client, train_labels_client = _run_task_pass(
        train_loader_client, "client", f"Training (Client) Epoch {epoch+1}", train=True)
    loss_t, n_t, train_preds_therapist, train_labels_therapist = _run_task_pass(
        train_loader_therapist, "therapist", f"Training (Therapist) Epoch {epoch+1}", train=True)

    # Validation: same two passes without weight updates.
    model.eval()
    vloss_c, vn_c, val_preds_client, val_labels_client = _run_task_pass(
        val_loader_client, "client", f"Validating (Client) Epoch {epoch+1}", train=False)
    vloss_t, vn_t, val_preds_therapist, val_labels_therapist = _run_task_pass(
        val_loader_therapist, "therapist", f"Validating (Therapist) Epoch {epoch+1}", train=False)

    # Record the mean per-batch loss over both tasks. A raw sum scales with the
    # number of batches, and the training splits are ~4x the validation splits
    # (test_size=0.2), so summed train and val losses are not comparable.
    train_loss = (loss_c + loss_t) / max(n_c + n_t, 1)
    val_loss = (vloss_c + vloss_t) / max(vn_c + vn_t, 1)

    train_f1_client = f1_score(train_labels_client, train_preds_client, average='macro')
    val_f1_client = f1_score(val_labels_client, val_preds_client, average='macro')
    train_f1_therapist = f1_score(train_labels_therapist, train_preds_therapist, average='macro')
    val_f1_therapist = f1_score(val_labels_therapist, val_preds_therapist, average='macro')

    training_losses.append(train_loss)
    validation_losses.append(val_loss)
    training_f1_scores_client.append(train_f1_client)
    validation_f1_scores_client.append(val_f1_client)
    training_f1_scores_therapist.append(train_f1_therapist)
    validation_f1_scores_therapist.append(val_f1_therapist)

    # Each task keeps its own best checkpoint of the whole model, selected
    # independently on that task's validation macro-F1.
    if val_f1_client > best_f1_score_client:
        best_f1_score_client = val_f1_client
        torch.save(model.state_dict(), "best_model_multi_modal_mtl_client.pth")
    if val_f1_therapist > best_f1_score_therapist:
        best_f1_score_therapist = val_f1_therapist
        torch.save(model.state_dict(), "best_model_multi_modal_mtl_therapist.pth")

    print(f"\nEpoch {epoch+1}/{epochs}")
    print("-" * 30)
    print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    print(f"Train F1 Client: {train_f1_client:.4f}, Val F1 Client: {val_f1_client:.4f}")
    print(f"Train F1 Therapist: {train_f1_therapist:.4f}, Val F1 Therapist: {val_f1_therapist:.4f}\n")


In [None]:
# Plot per-epoch loss and macro-F1 curves recorded by the training loop.
# The x-axis comes from the recorded history rather than `epochs`, so the cell
# still works if training was interrupted early. Epochs are numbered from 1
# to match the training log.
epochs_range = range(1, len(training_losses) + 1)

fig, axes = plt.subplots(2, 2, figsize=(20, 12))

# (axis, title, y-label, legend location, [(values, series label), ...])
panels = [
    (axes[0, 0], 'Training and Validation Loss', 'Loss', 'upper right',
     [(training_losses, 'Training Loss'), (validation_losses, 'Validation Loss')]),
    (axes[0, 1], 'Training and Validation F1 Score (Client)', 'Macro F1', 'lower right',
     [(training_f1_scores_client, 'Training F1 Client'),
      (validation_f1_scores_client, 'Validation F1 Client')]),
    (axes[1, 0], 'Training and Validation F1 Score (Therapist)', 'Macro F1', 'lower right',
     [(training_f1_scores_therapist, 'Training F1 Therapist'),
      (validation_f1_scores_therapist, 'Validation F1 Therapist')]),
]
for ax, title, ylabel, legend_loc, series in panels:
    for values, label in series:
        ax.plot(epochs_range, values, label=label)
    ax.set(title=title, xlabel='Epoch', ylabel=ylabel)
    ax.legend(loc=legend_loc)

# Only three panels are needed; hide the unused fourth slot of the 2x2 grid.
axes[1, 1].axis('off')

fig.tight_layout()
plt.show()

In [None]:
# Final evaluation of each task's best checkpoint.
# NOTE: this reuses the *validation* loaders, i.e. the same data that selected
# the checkpoints during training. These scores are therefore optimistic and
# are not a held-out test result; a separate test split is needed for an
# unbiased estimate.


def _evaluate_checkpoint(checkpoint_path, loader, task_name, class_names):
    """Load `checkpoint_path` into `model` and print a classification report.

    Predicts every batch of `loader` through the `task_name` head and prints
    a per-class report.

    Label ids are assumed to be 0..len(class_names)-1 (training's
    CrossEntropyLoss already requires ids below the number of classes), and
    class_names[i] is used as the display name of id i.
    NOTE(review): class_names come from .unique() (first-appearance order);
    confirm that matches the label-id order used when the embeddings were built.

    Returns:
        (preds, targets): flat lists of predicted and true label ids.
    """
    # map_location lets a checkpoint saved on another device load here.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc=f"Testing ({task_name.capitalize()})"):
            audio = batch['audio_embeddings'].to(device)
            text = batch['text_embeddings'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(audio, text, task_name=task_name)
            preds.extend(outputs.argmax(dim=1).cpu().numpy())
            targets.extend(labels.cpu().numpy())
    # Passing `labels` explicitly keeps target_names aligned with the ids even
    # when a class is absent from this split. Without it, classification_report
    # raises on the length mismatch.
    print(classification_report(targets, preds,
                                labels=list(range(len(class_names))),
                                target_names=class_names))
    return preds, targets


all_test_preds_client, all_test_labels_client = _evaluate_checkpoint(
    "best_model_multi_modal_mtl_client.pth", val_loader_client, "client", client_class_names)

all_test_preds_therapist, all_test_labels_therapist = _evaluate_checkpoint(
    "best_model_multi_modal_mtl_therapist.pth", val_loader_therapist, "therapist", therapist_class_names)