In [None]:
import copy
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm
from tqdm.notebook import tqdm  # keep after the plain tqdm import: this binding is the one used

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the annotation table from a local pickle (presumably the AnnoMI dataset,
# judging by the file name).
# NOTE(review): unpickling can execute arbitrary code; only load files you trust.
df_anno = pd.read_pickle('AnnoMI-full-with-audio-cleaned-text.pkl')

In [None]:
# Sanity check: display the shape of the loaded object.
df_anno.shape

In [None]:
# Collect the distinct, non-null class names for each task, in order of first
# appearance in the table.
# NOTE(review): these names are later used as `target_names` for integer labels;
# that is only correct if the label encoding follows this same order - confirm.
client_class_names, therapist_class_names = (
    df_anno[column].dropna().unique().tolist()
    for column in ('client_talk_type', 'main_therapist_behaviour')
)

client_class_names, therapist_class_names

In [None]:
# Read the pre-computed audio (AST) and text (RoBERTa) embedding bundles.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
def _read_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


audio_data_client = _read_pickle('audio_embeddings_ast_client.pkl')
audio_data_therapist = _read_pickle('audio_embeddings_ast_therapist.pkl')
text_data_client = _read_pickle('text_embeddings_roberta_client.pkl')
text_data_therapist = _read_pickle('text_embeddings_roberta_therapist.pkl')

# Fuse the modalities by placing audio and text features side by side (feature axis).
# NOTE(review): assumes row i of the audio and text arrays describe the same
# utterance - confirm against the embedding-extraction code.
embeddings_client = np.concatenate(
    [audio_data_client["embeddings"], text_data_client["embeddings"]],
    axis=1,
)
embeddings_therapist = np.concatenate(
    [audio_data_therapist["embeddings"], text_data_therapist["embeddings"]],
    axis=1,
)

# Labels are taken from the audio bundles only.
labels_client = audio_data_client["labels"]
labels_therapist = audio_data_therapist["labels"]

embeddings_client.shape, embeddings_therapist.shape

In [None]:
class EmbeddingsDataset(Dataset):
    """Map-style dataset pairing pre-computed embedding rows with class labels.

    Each item is a dict with a float32 ``"embeddings"`` tensor and an int64
    ``"labels"`` tensor built from the entries at the requested index.
    """

    def __init__(self, embeddings, labels):
        self.embeddings = embeddings
        self.labels = labels

    def __len__(self):
        # The number of samples is defined by the label container.
        return len(self.labels)

    def __getitem__(self, idx):
        features = torch.tensor(self.embeddings[idx], dtype=torch.float32)
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return {"embeddings": features, "labels": target}

In [None]:
# Hold out 20% of each role's samples for validation; the fixed seed keeps the
# split reproducible across runs.
split_options = {"test_size": 0.2, "random_state": 42}

X_train_client, X_val_client, y_train_client, y_val_client = train_test_split(
    embeddings_client, labels_client, **split_options
)

X_train_therapist, X_val_therapist, y_train_therapist, y_val_therapist = train_test_split(
    embeddings_therapist, labels_therapist, **split_options
)

In [None]:
batch_size = 1024

# Wrap each split in a dataset first, then build its loader.
train_dataset_client = EmbeddingsDataset(X_train_client, y_train_client)
val_dataset_client = EmbeddingsDataset(X_val_client, y_val_client)
train_dataset_therapist = EmbeddingsDataset(X_train_therapist, y_train_therapist)
val_dataset_therapist = EmbeddingsDataset(X_val_therapist, y_val_therapist)

# Only the training loaders shuffle; validation order stays fixed.
train_loader_client = DataLoader(train_dataset_client, batch_size=batch_size, shuffle=True)
val_loader_client = DataLoader(val_dataset_client, batch_size=batch_size, shuffle=False)

train_loader_therapist = DataLoader(train_dataset_therapist, batch_size=batch_size, shuffle=True)
val_loader_therapist = DataLoader(val_dataset_therapist, batch_size=batch_size, shuffle=False)

In [None]:
class MultiModalMTL(nn.Module):
    """Baseline multi-task classifier: one shared hidden layer, two linear heads.

    Args:
        input_dim: Width of the input embedding vectors.
        num_classes_client: Number of logits produced by the client head.
        num_classes_therapist: Number of logits produced by the therapist head.
        hidden_dim: Width of the shared hidden layer.
        dropout: Dropout probability applied after the shared layer.
    """

    def __init__(self, input_dim, num_classes_client, num_classes_therapist, hidden_dim=512, dropout=0):
        super().__init__()

        # Trunk shared by both tasks: Linear -> ReLU -> Dropout.
        shared_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        self.shared = nn.Sequential(*shared_layers)

        # One linear output head per task.
        self.client_classifier = nn.Linear(hidden_dim, num_classes_client)
        self.therapist_classifier = nn.Linear(hidden_dim, num_classes_therapist)

    def forward(self, embeddings, task_name=None):
        """Return the logits of the head selected by `task_name`.

        Raises:
            ValueError: If `task_name` is neither "client" nor "therapist".
        """
        hidden = self.shared(embeddings)

        # Pick the task head; anything else is a caller error.
        if task_name == "client":
            head = self.client_classifier
        elif task_name == "therapist":
            head = self.therapist_classifier
        else:
            raise ValueError(f"Invalid task_name: {task_name}. Expected 'client' or 'therapist'.")
        return head(hidden)
        

class TransformerBlock(nn.Module):
    """Post-norm attention block: multi-head attention and a two-layer MLP,
    each followed by a residual connection and LayerNorm.

    Args:
        embed_size: Feature width of the inputs and outputs.
        heads: Number of attention heads.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # Position-wise MLP that expands to twice the width and projects back.
        expanded = embed_size * 2
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, expanded),
            nn.ReLU(),
            nn.Linear(expanded, embed_size),
        )

    def forward(self, value, key, query, mask=None):
        """Apply attention and the MLP; note the (value, key, query) argument order."""
        attended = self.attention(query, key, value, attn_mask=mask)[0]
        # The first residual connection is taken around the query.
        residual = self.norm1(attended + query)
        return self.norm2(self.feed_forward(residual) + residual)


class EnhancedMultiModalMTL(nn.Module):
    """Multi-task classifier: linear projection, four shared TransformerBlocks,
    and one MLP head per task.

    Args:
        input_dim: Width of the input embedding vectors.
        num_classes_client: Number of logits produced by the client head.
        num_classes_therapist: Number of logits produced by the therapist head.
        heads: Number of attention heads per TransformerBlock; must divide
            `embed_size`.
        embed_size: Width of the shared representation (default 1024, the
            value that was previously hard-coded).

    Note:
        The client head used to be sized with `num_classes_therapist`. It now
        uses `num_classes_client`, so checkpoints saved before this fix will
        not load into the client head when the two class counts differ.
    """

    def __init__(self, input_dim, num_classes_client, num_classes_therapist, heads=16, embed_size=1024):
        super().__init__()

        # Project the raw embeddings to the shared model width.
        self.embedding_adjust = nn.Linear(input_dim, embed_size)

        # Shared trunk. Attribute names are kept so state_dict keys stay stable.
        self.transformer_block1 = TransformerBlock(embed_size=embed_size, heads=heads)
        self.transformer_block2 = TransformerBlock(embed_size=embed_size, heads=heads)
        self.transformer_block3 = TransformerBlock(embed_size=embed_size, heads=heads)
        self.transformer_block4 = TransformerBlock(embed_size=embed_size, heads=heads)

        # Task-specific heads, each sized for its own label set.
        self.client_classifier = self._make_head(embed_size, num_classes_client)
        self.therapist_classifier = self._make_head(embed_size, num_classes_therapist)

    @staticmethod
    def _make_head(embed_size, num_classes):
        """Build the embed_size -> 512 -> 256 -> num_classes classification MLP."""
        return nn.Sequential(
            nn.Linear(embed_size, 512),
            nn.ReLU(),
            nn.Dropout(0),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0),
            nn.Linear(256, num_classes),
        )

    def forward(self, embeddings, task_name=None):
        """Return the logits of the head selected by `task_name`.

        Raises:
            ValueError: If `task_name` is neither "client" nor "therapist".
        """
        embeddings = self.embedding_adjust(embeddings)
        # Add a leading axis for the attention blocks; it is removed again below.
        # NOTE(review): if the blocks use nn.MultiheadAttention's default
        # (seq, batch, embed) layout, every sample is a length-1 sequence and
        # attention cannot mix information between positions - confirm intent.
        embeddings = embeddings.unsqueeze(0)

        shared_output = self.transformer_block1(embeddings, embeddings, embeddings)
        shared_output = self.transformer_block2(shared_output, shared_output, shared_output)
        shared_output = self.transformer_block3(shared_output, shared_output, shared_output)
        shared_output = self.transformer_block4(shared_output, shared_output, shared_output)

        shared_output = shared_output.squeeze(0)

        # Route through the appropriate classifier.
        if task_name == "client":
            return self.client_classifier(shared_output)
        elif task_name == "therapist":
            return self.therapist_classifier(shared_output)
        else:
            raise ValueError(f"Invalid task_name: {task_name}. Expected 'client' or 'therapist'.")



In [None]:
# Derive the model dimensions from the data itself.
# NOTE(review): input_dim comes from the client embeddings only; this assumes
# the therapist embeddings have the same width - confirm.
input_dim = embeddings_client.shape[1]
num_classes_client = np.unique(labels_client).size
num_classes_therapist = np.unique(labels_therapist).size

model = EnhancedMultiModalMTL(
    input_dim,
    num_classes_client,
    num_classes_therapist,
)
model = model.to(device)

In [None]:
# Display the fused (audio + text) embedding width fed to the model.
input_dim

In [None]:
# Display the number of distinct labels found for each task.
num_classes_client, num_classes_therapist

In [None]:
# Training configuration: epoch budget, loss shared by both tasks, and a single
# AdamW optimizer over all model parameters.
epochs = 100
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)

In [None]:
# Per-epoch metric history, consumed later for plotting.
training_losses = []
validation_losses = []
training_f1_scores_client = []
validation_f1_scores_client = []
training_f1_scores_therapist = []
validation_f1_scores_therapist = []

# Best validation macro-F1 seen so far, tracked independently per task.
best_f1_score_client = 0.0
best_f1_score_therapist = 0.0


def _run_task_epoch(loader, task_name, desc, train):
    """Run one full pass over `loader` for a single task.

    Uses the notebook-level `model`, `optimizer`, `loss_fn` and `device`.
    The caller is responsible for putting the model in train/eval mode.

    Args:
        loader: DataLoader yielding dicts with 'embeddings' and 'labels'.
        task_name: "client" or "therapist"; selects the model head.
        desc: Text shown on the progress bar.
        train: When True, back-propagate and step the optimizer; when False,
            run without gradient tracking.

    Returns:
        (mean_loss, macro_f1): the per-sample mean loss and macro-averaged F1
        over the whole pass.
    """
    epoch_preds, epoch_labels = [], []
    loss_sum = 0.0
    sample_count = 0

    progress = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(train):
        for batch in progress:
            embeddings = batch['embeddings'].to(device)
            labels = batch['labels'].to(device)

            if train:
                optimizer.zero_grad()
            outputs = model(embeddings, task_name=task_name)
            loss = loss_fn(outputs, labels)
            if train:
                loss.backward()
                optimizer.step()
                progress.set_postfix(loss=loss.item())

            # loss_fn averages within the batch; re-weight by batch size so the
            # epoch value is a true per-sample mean even when the last batch is short.
            batch_samples = labels.size(0)
            loss_sum += loss.item() * batch_samples
            sample_count += batch_samples

            epoch_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            epoch_labels.extend(labels.cpu().numpy())

    mean_loss = loss_sum / max(sample_count, 1)
    macro_f1 = f1_score(epoch_labels, epoch_preds, average='macro')
    return mean_loss, macro_f1


# Training loop: each epoch trains the client task for a full pass, then the
# therapist task for a full pass, then validates both.
for epoch in range(epochs):
    model.train()
    train_loss_client, train_f1_client = _run_task_epoch(
        train_loader_client, "client", f"Training (Client) Epoch {epoch+1}", train=True)
    training_f1_scores_client.append(train_f1_client)

    train_loss_therapist, train_f1_therapist = _run_task_epoch(
        train_loader_therapist, "therapist", f"Training (Therapist) Epoch {epoch+1}", train=True)
    training_f1_scores_therapist.append(train_f1_therapist)

    # Reported loss = client mean loss + therapist mean loss. Using per-sample
    # means (rather than sums over batches) keeps the training and validation
    # curves on the same scale even though the loaders have different lengths.
    total_train_loss = train_loss_client + train_loss_therapist
    training_losses.append(total_train_loss)

    model.eval()
    val_loss_client, val_f1_client = _run_task_epoch(
        val_loader_client, "client", f"Evaluating (Client) Epoch {epoch+1}", train=False)
    validation_f1_scores_client.append(val_f1_client)

    val_loss_therapist, val_f1_therapist = _run_task_epoch(
        val_loader_therapist, "therapist", f"Evaluating (Therapist) Epoch {epoch+1}", train=False)
    validation_f1_scores_therapist.append(val_f1_therapist)

    total_val_loss = val_loss_client + val_loss_therapist
    validation_losses.append(total_val_loss)

    # Checkpointing. Each per-task file must always hold the weights of that
    # task's best epoch, so the two tasks are checked independently. The old
    # if/elif chain skipped the per-task files whenever both tasks improved in
    # the same epoch, leaving them stale or never written.
    improved_client = val_f1_client > best_f1_score_client
    improved_therapist = val_f1_therapist > best_f1_score_therapist

    if improved_client:
        best_f1_score_client = val_f1_client
        torch.save(model.state_dict(), "best_mtl_multi_modal_model_client.pth")
    if improved_therapist:
        best_f1_score_therapist = val_f1_therapist
        torch.save(model.state_dict(), "best_mtl_multi_modal_model_therapist.pth")
    if improved_client and improved_therapist:
        # Still written, as before, for epochs that improve both tasks at once.
        torch.save(model.state_dict(), "best_mtl_multi_modal_model_client_therapist.pth")

    print(f"Epoch {epoch+1}, Train Loss: {total_train_loss}, Val Loss: {total_val_loss}, \nTrain F1 Client: {train_f1_client:.4f}, Val F1 Client: {val_f1_client:.4f}, \nTrain F1 Therapist: {train_f1_therapist:.4f}, Val F1 Therapist: {val_f1_therapist:.4f}")

In [None]:
epochs_range = range(epochs)

plt.figure(figsize=(20, 12))

# One entry per panel of the 2x2 grid:
# (subplot position, [(series, legend label), ...], legend location, title)
panels = [
    (
        1,
        [(training_losses, 'Training Loss'), (validation_losses, 'Validation Loss')],
        'upper right',
        'Training and Validation Loss',
    ),
    (
        2,
        [(training_f1_scores_client, 'Training F1 Client'),
         (validation_f1_scores_client, 'Validation F1 Client')],
        'lower right',
        'Training and Validation F1 Score (Client)',
    ),
    (
        3,
        [(training_f1_scores_therapist, 'Training F1 Therapist'),
         (validation_f1_scores_therapist, 'Validation F1 Therapist')],
        'lower right',
        'Training and Validation F1 Score (Therapist)',
    ),
]

for position, curves, legend_loc, title in panels:
    plt.subplot(2, 2, position)
    for values, label in curves:
        plt.plot(epochs_range, values, label=label)
    plt.legend(loc=legend_loc)
    plt.title(title)

plt.tight_layout()
plt.show()

In [None]:
# Load the best checkpoint of each task into its OWN copy of the model.
# Previously both names aliased `model`, so loading the therapist checkpoint
# overwrote the client weights and the client report was computed with the
# therapist-best weights.
model_client = copy.deepcopy(model)
model_client.load_state_dict(
    torch.load("best_mtl_multi_modal_model_client.pth", map_location=device))
model_client.eval()

model_therapist = copy.deepcopy(model)
model_therapist.load_state_dict(
    torch.load("best_mtl_multi_modal_model_therapist.pth", map_location=device))
model_therapist.eval()


# Final evaluation function
def evaluate_model(model, dataloader, task_name):
    """Collect ground-truth labels and predictions for one task.

    Args:
        model: Module called as `model(embeddings, task_name=...)` returning logits.
        dataloader: Yields dicts with 'embeddings' and 'labels' tensors.
        task_name: Forwarded to the model to select the task head.

    Returns:
        (all_labels, all_preds): two lists with one entry per sample.
    """
    model.eval()  # do not rely on the caller having switched modes
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch['embeddings'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(inputs, task_name=task_name)
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return all_labels, all_preds


# Evaluate both tasks. Results go into new names so the full-dataset
# `labels_client` / `labels_therapist` arrays are not overwritten, which
# would break re-running the earlier cell that counts classes from them.
val_labels_client, preds_client = evaluate_model(model_client, val_loader_client, "client")
val_labels_therapist, preds_therapist = evaluate_model(model_therapist, val_loader_therapist, "therapist")

# Passing `labels` explicitly keeps the index -> name mapping fixed even when a
# class is missing from the validation split or the predictions.
# NOTE(review): this assumes integer label i corresponds to class_names[i];
# confirm that the label encoding matches the order of the name lists.

# For client task:
print("\nFinal Classification Report (Client):")
print(classification_report(
    val_labels_client, preds_client,
    labels=list(range(len(client_class_names))),
    target_names=client_class_names))

# For therapist task:
print("\nFinal Classification Report (Therapist):")
print(classification_report(
    val_labels_therapist, preds_therapist,
    labels=list(range(len(therapist_class_names))),
    target_names=therapist_class_names))
