In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, Trainer, TrainingArguments
import torch
from transformers import RobertaConfig, RobertaModel
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
import torch.nn.functional as F
from transformers import Wav2Vec2PreTrainedModel, Wav2Vec2Model
from transformers import TrainerCallback, TrainerState, TrainerControl, AutoModelForAudioClassification
from transformers import DataCollator
from transformers import EvalPrediction
from torch import optim
from transformers import AutoFeatureExtractor
import evaluate

In [None]:
# Read the two pickled tables used below: one drives the audio pipeline, the other the text pipeline.
# NOTE: unpickling can execute arbitrary code, so only point these at files from a trusted source.
df_audio = pd.read_pickle(filepath_or_buffer='AnnoMI-ast-new.pkl')
df_text = pd.read_pickle(filepath_or_buffer='AnnoMI-full-with-audio-cleaned-text.pkl')

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Audio

In [None]:
def _make_audio_task_frame(source_df, speaker, input_col, label_col):
    """Return a fresh two-column (`inputs`, `labels`) frame for one speaker role.

    Args:
        source_df: table with an `interlocutor` column plus `input_col` and `label_col`.
        speaker: value of `interlocutor` to keep (e.g. 'client' or 'therapist').
        input_col: column holding the model inputs; renamed to `inputs`.
        label_col: column holding the class label; renamed to `labels` and
            replaced by its integer category code.

    Returns:
        A new DataFrame. Nothing is modified in place, so no
        SettingWithCopyWarning is raised and `source_df` is left untouched.
    """
    return (
        # A single .loc call selects rows and columns together and returns a new object.
        source_df.loc[source_df['interlocutor'] == speaker, [input_col, label_col]]
        .rename(columns={input_col: 'inputs', label_col: 'labels'})
        # cat.codes numbers the categories in sorted order; a missing label, if any, becomes -1.
        .assign(labels=lambda frame: frame['labels'].astype("category").cat.codes)
    )


# Dataset restricted to rows where the interlocutor is the client.
df_client_audio = _make_audio_task_frame(df_audio, 'client', 'client_ast_emb', 'client_talk_type')

# Dataset restricted to rows where the interlocutor is the therapist.
df_therapist_audio = _make_audio_task_frame(df_audio, 'therapist', 'therapist_ast_emb', 'main_therapist_behaviour')

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset over a DataFrame that has `inputs` and `labels` columns.

    Each item is a dict with the keys `input_values` (float tensor) and
    `labels` (scalar long tensor).
    """

    def __init__(self, dataframe):
        # Kept by reference; rows are turned into tensors lazily in __getitem__.
        self.data = dataframe

    def __len__(self):
        # One sample per DataFrame row.
        return len(self.data)

    def __getitem__(self, idx):
        # `idx` is positional (iloc), not an index label.
        row = self.data.iloc[idx]

        features = torch.tensor(row['inputs'], dtype=torch.float)
        # squeeze(0) drops the leading axis only when that axis has length 1.
        features = features.squeeze(0)

        target = torch.tensor(row['labels'], dtype=torch.long)

        return {"input_values": features, "labels": target}


In [None]:
# Wrap each speaker's prepared frame in the tensor-yielding dataset.
client_dataset_audio = CustomDataset(dataframe=df_client_audio)
therapist_dataset_audio = CustomDataset(dataframe=df_therapist_audio)

In [None]:
# shuffle=False keeps batches in dataset order, so extracted rows line up with the source frames.
client_dataloader_audio = DataLoader(
    dataset=client_dataset_audio,
    batch_size=8,
    shuffle=False,
)
therapist_dataloader_audio = DataLoader(
    dataset=therapist_dataset_audio,
    batch_size=8,
    shuffle=False,
)

In [None]:
class MTLASTAudioClassificaiton(nn.Module):
    """Multi-task audio classifier: one shared base model, one linear head per speaker role.

    The shared representation is the *logits* of the pretrained audio
    classification model, so its width equals that model's `config.num_labels`.

    Args:
        base_model_name: identifier passed to
            `AutoModelForAudioClassification.from_pretrained`.
        num_classes_client: output size of the client head.
        num_classes_therapist: output size of the therapist head.
    """

    _TASKS = ('client', 'therapist')

    def __init__(self, base_model_name, num_classes_client, num_classes_therapist):
        super(MTLASTAudioClassificaiton, self).__init__()

        # Shared layer
        self.base_model = AutoModelForAudioClassification.from_pretrained(base_model_name)

        # The heads consume the base model's logits, so take their width from the
        # loaded config instead of hard-coding one checkpoint's label count.
        shared_dim = self.base_model.config.num_labels

        # Client specific classifier
        self.client_classifier = nn.Linear(shared_dim, num_classes_client)

        # Therapist specific classifier
        self.therapist_classifier = nn.Linear(shared_dim, num_classes_therapist)

    def forward(self, input_values, task_name=None, return_embeddings=False):
        """Run the shared model and, unless embeddings are requested, one task head.

        Args:
            input_values: batch forwarded to the base model as `input_values`.
            task_name: 'client' or 'therapist'; ignored when
                `return_embeddings` is True.
            return_embeddings: when True, return the base model's logits
                without applying a task head.

        Returns:
            The base model's logits, or the selected head's output.

        Raises:
            ValueError: if a head is needed and `task_name` is not a known task.
        """
        # Validate before the backbone pass so a bad task name fails immediately.
        if not return_embeddings and task_name not in self._TASKS:
            raise ValueError(f"Invalid task_name: {task_name}. Expected 'client' or 'therapist'.")

        # Passing input_values through the shared layer
        shared_output = self.base_model(input_values=input_values).logits

        if return_embeddings:
            return shared_output

        # Routing through the appropriate classifier (logits are used as-is, no pooling).
        if task_name == 'client':
            return self.client_classifier(shared_output)
        return self.therapist_classifier(shared_output)


In [None]:
# Head sizes and base checkpoint; the head sizes must match the state dicts loaded below.
num_classes_client = 3
num_classes_therapist = 4
base_model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"

# Model restored from the client checkpoint (presumably the best run for the client task).
model_client_audio = MTLASTAudioClassificaiton(base_model_name, num_classes_client, num_classes_therapist).to(device)
model_path_client_audio = "best_mtl_model_audio_ast_client.pth"  # Update the path accordingly
# map_location remaps tensors saved from a GPU so loading also works on a CPU-only machine.
model_client_audio.load_state_dict(torch.load(model_path_client_audio, map_location=device))


# Model restored from the therapist checkpoint.
model_therapist_audio = MTLASTAudioClassificaiton(base_model_name, num_classes_client, num_classes_therapist).to(device)
model_path_therapist_audio = "best_mtl_model_audio_ast_therapist.pth"
model_therapist_audio.load_state_dict(torch.load(model_path_therapist_audio, map_location=device))

In [None]:
# Hand cached, currently unused GPU blocks back to the driver before running inference.
torch.cuda.empty_cache()

In [None]:
def extract_audio_embeddings_and_labels_from_dataloader(model, dataloader):
    """Run `model` over every batch and collect its embedding output and the labels.

    Args:
        model: module called as `model(input_values=..., return_embeddings=True)`.
            It is switched to eval mode as a side effect.
        dataloader: iterable of dicts with `input_values` and `labels` tensors.

    Returns:
        Tuple `(embeddings, labels)`, each concatenated along dim 0 in iteration
        order. Both tensors are on the CPU.

    Raises:
        RuntimeError: from `torch.cat` if `dataloader` yields no batches.
    """
    # Send inputs to wherever the model lives instead of relying on a notebook global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    all_embeddings = []
    all_labels = []
    model.eval()  # Set the model to evaluation mode
    progress_bar = tqdm(dataloader, desc="Extracting embeddings and labels")
    with torch.no_grad():
        for batch in progress_bar:
            input_values = batch["input_values"].to(model_device)
            embeddings = model(input_values=input_values, return_embeddings=True)
            # Move each result to host memory immediately. Appending the device tensor
            # would keep every batch alive on the GPU until the final concatenation.
            all_embeddings.append(embeddings.cpu())
            all_labels.append(batch["labels"].cpu())
    torch.cuda.empty_cache()
    # Concatenate embeddings and labels along the batch dimension
    return torch.cat(all_embeddings, dim=0), torch.cat(all_labels, dim=0)


In [None]:
# Each checkpoint is run over its own speaker's data only.
audio_embeddings_client, audio_labels_client = extract_audio_embeddings_and_labels_from_dataloader(
    model_client_audio,
    client_dataloader_audio,
)
audio_embeddings_therapist, audio_labels_therapist = extract_audio_embeddings_and_labels_from_dataloader(
    model_therapist_audio,
    therapist_dataloader_audio,
)

In [None]:
# Persist the extracted audio features as NumPy arrays, one pickle file per speaker role.
import pickle

# Client: embeddings and labels are stored together so they stay aligned row-for-row.
with open('audio_embeddings_ast_client.pkl', 'wb') as file:
    pickle.dump(
        dict(
            embeddings=audio_embeddings_client.cpu().numpy(),
            labels=audio_labels_client.cpu().numpy(),
        ),
        file,
    )

# Therapist: same layout as the client file.
with open('audio_embeddings_ast_therapist.pkl', 'wb') as file:
    pickle.dump(
        dict(
            embeddings=audio_embeddings_therapist.cpu().numpy(),
            labels=audio_labels_therapist.cpu().numpy(),
        ),
        file,
    )


In [None]:
# Drop every reference to the audio models and their outputs, then release cached GPU memory.
del model_client_audio, model_therapist_audio
del audio_embeddings_client, audio_labels_client
del audio_embeddings_therapist, audio_labels_therapist
torch.cuda.empty_cache()

# Text

In [None]:
# Tokenizer for the "roberta-large" checkpoint, the same name the text model below is built from.
tokenizer_roberta_large = AutoTokenizer.from_pretrained("roberta-large")

In [None]:
# Tokenizer
# Module-level alias; tokenize_data() reads this global rather than taking a tokenizer argument.
tokenizer = tokenizer_roberta_large

# Shared tokenization helper used for both the client and the therapist task.
def tokenize_data(texts):
    """Tokenize a collection of utterances with the module-level `tokenizer`.

    Args:
        texts: object exposing `.tolist()` (e.g. a pandas Series of strings).

    Returns:
        Whatever `tokenizer` returns when called with truncation and padding
        enabled and `return_tensors="pt"`.
    """
    utterances = texts.tolist()
    encoded = tokenizer(utterances, truncation=True, padding=True, return_tensors="pt")
    return encoded

# Client rows: utterance text, integer-coded talk type, and the tokenized text.
client_texts = df_text.loc[df_text['interlocutor'] == 'client', 'utterance_text']
client_labels = (
    df_text.loc[df_text['interlocutor'] == 'client', 'client_talk_type']
    .astype("category")
    .cat.codes
)
client_encodings = tokenize_data(client_texts)

# Therapist rows: same three steps, with the therapist behaviour column as the label.
therapist_texts = df_text.loc[df_text['interlocutor'] == 'therapist', 'utterance_text']
therapist_labels = (
    df_text.loc[df_text['interlocutor'] == 'therapist', 'main_therapist_behaviour']
    .astype("category")
    .cat.codes
)
therapist_encodings = tokenize_data(therapist_texts)

In [None]:
# Dataset pairing pre-tokenized encodings with their labels.
class MTLDataset(Dataset):
    """Map-style dataset over a mapping of encoded tensors and a label sequence.

    Each item holds every field of `encodings` sliced at the requested index,
    plus a `labels` entry holding the label as a long tensor.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # Dataset size follows the labels, not the encodings.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for field, column in self.encodings.items():
            # Copy the slice so callers can never mutate the stored encodings.
            sample[field] = column[idx].clone().detach()
        sample['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample

In [None]:
# Labels are passed as plain NumPy arrays so MTLDataset indexes them by position.
client_dataset = MTLDataset(encodings=client_encodings, labels=client_labels.to_numpy())
therapist_dataset = MTLDataset(encodings=therapist_encodings, labels=therapist_labels.to_numpy())

In [None]:
class MTLModel(nn.Module):
    """Shared RoBERTa encoder with one linear classification head per speaker role."""

    def __init__(self, base_model_name, num_classes_client, num_classes_therapist):
        super().__init__()

        # Encoder shared by both tasks.
        self.shared = RobertaModel.from_pretrained(base_model_name)

        # Both heads take a vector of the encoder's hidden size.
        hidden_dim = self.shared.config.hidden_size
        self.client_classifier = nn.Linear(hidden_dim, num_classes_client)
        self.therapist_classifier = nn.Linear(hidden_dim, num_classes_therapist)

    def forward(self, input_ids, attention_mask, task_name=None, return_embeddings=False):
        """Encode the batch, then return either the embedding or one head's output.

        With `return_embeddings=True` the first-position vector of the encoder's
        first output is returned and `task_name` is ignored. Otherwise
        `task_name` must be "client" or "therapist"; anything else raises
        ValueError.
        """
        encoder_out = self.shared(input_ids=input_ids, attention_mask=attention_mask)
        # Element 0 of the encoder output, taken at sequence position 0.
        first_token_state = encoder_out[0][:, 0, :]

        if return_embeddings:
            return first_token_state

        # Pick the head for the requested task.
        if task_name == "client":
            head = self.client_classifier
        elif task_name == "therapist":
            head = self.therapist_classifier
        else:
            raise ValueError(f"Invalid task_name: {task_name}. Expected 'client' or 'therapist'.")
        return head(first_token_state)


In [None]:
# Number of unique labels for each task; must match the state dicts loaded below.
num_classes_client = 3
num_classes_therapist = 4

# Initialize the model for client task
base_model_name = "roberta-large"
model_client_text = MTLModel(base_model_name, num_classes_client, num_classes_therapist).to(device)

# Load the saved weights.
# map_location remaps tensors saved from a GPU so loading also works on a CPU-only machine.
model_path_client_text = "best_mtl_model_roberta_client.pth"
model_client_text.load_state_dict(torch.load(model_path_client_text, map_location=device))


# Initialize the model for therapist task
model_therapist_text = MTLModel(base_model_name, num_classes_client, num_classes_therapist).to(device)

# Load the saved weights
model_path_therapist_text = "best_mtl_model_roberta_therapist.pth"
model_therapist_text.load_state_dict(torch.load(model_path_therapist_text, map_location=device))

In [None]:
# Hand cached, currently unused GPU blocks back to the driver before running inference.
torch.cuda.empty_cache()

In [None]:
# Batch size shared by both text loaders.
batch_size = 32

# shuffle=False keeps batches in dataset order, so extracted rows line up with the source data.
client_dataloader_text = DataLoader(
    dataset=client_dataset,
    batch_size=batch_size,
    shuffle=False,
)
therapist_dataloader_text = DataLoader(
    dataset=therapist_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
def extract_text_embeddings_and_labels_from_dataloader(model, dataloader):
    """Run `model` over every batch and collect its embedding output and the labels.

    Args:
        model: module called as
            `model(input_ids=..., attention_mask=..., return_embeddings=True)`.
            It is switched to eval mode as a side effect.
        dataloader: iterable of dicts with `input_ids`, `attention_mask` and
            `labels` tensors.

    Returns:
        Tuple `(embeddings, labels)`, each concatenated along dim 0 in iteration
        order. Both tensors are on the CPU.

    Raises:
        RuntimeError: from `torch.cat` if `dataloader` yields no batches.
    """
    # Send inputs to wherever the model lives instead of relying on a notebook global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    all_embeddings = []
    all_labels = []
    model.eval()  # Set the model to evaluation mode
    progress_bar = tqdm(dataloader, desc="Extracting text embeddings and labels")
    with torch.no_grad():
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(model_device)
            attention_mask = batch["attention_mask"].to(model_device)
            embeddings = model(input_ids=input_ids, attention_mask=attention_mask, return_embeddings=True)
            # Move each result to host memory immediately. Appending the device tensor
            # would keep every batch alive on the GPU until the final concatenation.
            all_embeddings.append(embeddings.cpu())
            all_labels.append(batch["labels"].cpu())
    torch.cuda.empty_cache()
    # Concatenate embeddings and labels along the batch dimension
    return torch.cat(all_embeddings, dim=0), torch.cat(all_labels, dim=0)


In [None]:
# Each checkpoint is run over its own speaker's data only.
text_embeddings_client, text_labels_client = extract_text_embeddings_and_labels_from_dataloader(
    model_client_text,
    client_dataloader_text,
)
text_embeddings_therapist, text_labels_therapist = extract_text_embeddings_and_labels_from_dataloader(
    model_therapist_text,
    therapist_dataloader_text,
)

In [None]:
import pickle

# Persist the extracted text features as NumPy arrays, one pickle file per speaker role.
# Client: embeddings and labels are stored together so they stay aligned row-for-row.
with open('text_embeddings_roberta_client.pkl', 'wb') as file:
    pickle.dump(
        dict(
            embeddings=text_embeddings_client.cpu().numpy(),
            labels=text_labels_client.cpu().numpy(),
        ),
        file,
    )

# Therapist: same layout as the client file.
with open('text_embeddings_roberta_therapist.pkl', 'wb') as file:
    pickle.dump(
        dict(
            embeddings=text_embeddings_therapist.cpu().numpy(),
            labels=text_labels_therapist.cpu().numpy(),
        ),
        file,
    )


In [None]:
# Drop every reference to the text models and their outputs, then release cached GPU memory.
del model_client_text, model_therapist_text
del text_embeddings_client, text_labels_client
del text_embeddings_therapist, text_labels_therapist
torch.cuda.empty_cache()