In [1]:
import pandas as pd
import numpy as np
import pickle as pkl
import random
import torchaudio

# Fix the seeds of Python's `random` module and of NumPy's legacy global RNG.
# NOTE(review): no torch seed is set in this cell — confirm whether model
# initialisation / DataLoader shuffling also needs `torch.manual_seed`.
random.seed(0)
np.random.seed(0)

In [2]:
!nvidia-smi

Wed Jul 26 01:22:49 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        Off | 00000000:01:00.0  On |                  N/A |
| 32%   50C    P3             141W / 350W |    520MiB / 24576MiB |      9%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                         

## Load GoEmotions and general audio datasets (CREMA, TESS, RAVDESS, etc.)

In [3]:
# Directory that holds the pickled train/test splits of the audio-only and text-only corpora.
C4AI_CLIP_DATA_DIR = '/home/vmachado/Documents/c4ai_clip_audio_text/data/c4ai_clip'


def _load_pickled_columns(file_name, columns):
    """Unpickle ``file_name`` from ``C4AI_CLIP_DATA_DIR`` and keep only ``columns``.

    The file is opened inside a ``with`` block so the handle is closed as soon
    as the object has been loaded.

    SECURITY: ``pickle`` can execute arbitrary code while loading; only point
    this at files you produced yourself.

    Args:
        file_name: Name of the ``.pkl`` file inside ``C4AI_CLIP_DATA_DIR``.
        columns: Column labels to select from the unpickled object
            (presumably a pandas DataFrame — verify against the files).

    Returns:
        The unpickled object indexed by ``columns``.
    """
    with open(f'{C4AI_CLIP_DATA_DIR}/{file_name}', "rb") as handle:
        return pkl.load(handle)[columns]


train_audio = _load_pickled_columns('train_audio.pkl', ['path', 'label'])
test_audio = _load_pickled_columns('test_audio.pkl', ['path', 'label'])
train_text = _load_pickled_columns('train_text.pkl', ['text', 'grouped_label'])
test_text = _load_pickled_columns('test_text.pkl', ['text', 'grouped_label'])

In [4]:
# Pool both GoEmotions splits into a single frame; the pool is rebalanced per
# class in the following cells and later concatenated into the training set.
go_emotions = pd.concat([train_text, test_text])

In [5]:
go_emotions.groupby("grouped_label").count()

Unnamed: 0_level_0,text
grouped_label,Unnamed: 1_level_1
anger,6039
disgust,664
fear,705
joy,19002
neutral,14429
sadness,2936
surprise,5062


In [6]:
def _draw_go_emotions(label, n_rows, with_replacement):
    """Sample ``n_rows`` GoEmotions rows whose ``grouped_label`` is ``label`` (seed 0)."""
    rows_of_label = go_emotions.loc[go_emotions["grouped_label"] == label]
    return rows_of_label.sample(n=n_rows, replace=with_replacement, random_state=0)


# anger / joy / neutral are sampled without replacement; disgust / fear /
# sadness are sampled with replacement.
anger = _draw_go_emotions("anger", 5000, False)
disgust = _draw_go_emotions("disgust", 4000, True)
fear = _draw_go_emotions("fear", 4000, True)
joy = _draw_go_emotions("joy", 5000, False)
neutral = _draw_go_emotions("neutral", 5000, False)
sadness = _draw_go_emotions("sadness", 2000, True)

In [7]:
# Remove the three classes that were sampled without replacement (their sampled
# subsets replace them), then append every sampled frame. For disgust / fear /
# sadness the sampled rows come on top of the rows that are kept.
# NOTE: this cell rebinds `go_emotions`, so running it twice changes the counts.
replaced_classes = ["anger", "joy", "neutral"]
kept_rows = go_emotions[~go_emotions["grouped_label"].isin(replaced_classes)]
go_emotions = pd.concat([kept_rows, anger, disgust, fear, joy, neutral, sadness])
go_emotions.groupby("grouped_label").count()

Unnamed: 0_level_0,text
grouped_label,Unnamed: 1_level_1
anger,5000
disgust,4664
fear,4705
joy,5000
neutral,5000
sadness,4936
surprise,5062


In [8]:
def norm_labels(x):
    """Rename the audio corpora's label spellings to the shared emotion names.

    ``afraid`` -> ``fear``, ``angry`` -> ``anger``, ``disgusted`` -> ``disgust``
    and ``sad`` -> ``sadness``; every other value is returned unchanged.
    """
    renames = (
        ("afraid", "fear"),
        ("angry", "anger"),
        ("disgusted", "disgust"),
        ("sad", "sadness"),
    )
    for alias, canonical in renames:
        if x == alias:
            return canonical
    return x

In [9]:
# Normalise the label spelling in both audio splits.
for audio_split in (train_audio, test_audio):
    audio_split["label"] = audio_split["label"].apply(norm_labels)

In [10]:
pd.concat([train_audio,test_audio]).groupby("label").count()

Unnamed: 0_level_0,path
label,Unnamed: 1_level_1
anger,1863
disgust,1863
fear,1863
joy,2055
neutral,1583
sadness,1863
surprise,592


In [11]:
audio_datasets = pd.concat([train_audio,test_audio]).reset_index(drop=True)

In [12]:
# Swap the original "surprise" rows for 2000 rows drawn from them with replacement.
# NOTE: `audio_datasets` is rebound here, so re-running the cell resamples
# from the already resampled rows.
is_surprise = audio_datasets["label"] == "surprise"
surprise = audio_datasets[is_surprise].sample(n=2000, replace=True, random_state=0)
audio_datasets = pd.concat([audio_datasets[~is_surprise], surprise], ignore_index=True)

In [13]:
audio_datasets.groupby("label").count()

Unnamed: 0_level_0,path
label,Unnamed: 1_level_1
anger,1863
disgust,1863
fear,1863
joy,2055
neutral,1583
sadness,1863
surprise,2000


## Load MELD and IEMOCAP

In [14]:
# Training utterances: rename `utterance` to `text` and prefix every path in
# the CSV with the local documents directory.
train_df_erc = (
    pd.read_csv("train_text_df.csv", index_col=0)
    .rename(columns={"utterance": "text"})
    .assign(path=lambda frame: frame["path"].map(lambda rel_path: '/home/vmachado/Documents/' + rel_path))
)
train_df_erc

Unnamed: 0,text,label,path
0,The only one I know still love his parents. [B...,joy,/home/vmachado/Documents/multimodal-datasets/I...
1,The only one I know still love his parents. Ye...,neutral,/home/vmachado/Documents/multimodal-datasets/I...
2,Oh it's not bad thing it's good thing. You kno...,joy,/home/vmachado/Documents/multimodal-datasets/I...
3,"You know it's nice here, the air is sweet. You...",sadness,/home/vmachado/Documents/multimodal-datasets/I...
4,"You're not sorry you came? Not sorry, no. I c...",sadness,/home/vmachado/Documents/multimodal-datasets/I...
...,...,...,...
13723,That would be no. Come on. It doesn't taste ba...,neutral,/home/vmachado/Documents/multimodal-datasets/M...
13724,"Come on. It doesn't taste bad. Yeah, it's kind...",joy,/home/vmachado/Documents/multimodal-datasets/M...
13725,"Yeah, it's kinda sweet, sorta like, uh... Cant...",neutral,/home/vmachado/Documents/multimodal-datasets/M...
13726,Cantaloupe juice. Exactly. [BFR] You've tasted...,surprise,/home/vmachado/Documents/multimodal-datasets/M...


In [15]:
# Test utterances: rename `utterance` to `text` and prefix every path in the
# CSV with the local documents directory.
test_df_erc = (
    pd.read_csv("test_text_df.csv", index_col=0)
    .rename(columns={"utterance": "text"})
    .assign(path=lambda frame: frame["path"].map(lambda rel_path: '/home/vmachado/Documents/' + rel_path))
)
test_df_erc

Unnamed: 0,text,label,path
0,"[BFR] Brian, I need help. [AFT] Babe, I don't...",sadness,/home/vmachado/Documents/multimodal-datasets/I...
1,"Brian, I need help. [BFR] Babe, I don't know w...",neutral,/home/vmachado/Documents/multimodal-datasets/I...
2,"Babe, I don't know what to tell you. Don't gi...",neutral,/home/vmachado/Documents/multimodal-datasets/I...
3,"I wish I had some answers for you, babe. I me...",neutral,/home/vmachado/Documents/multimodal-datasets/I...
4,I went to school and I got my degree. And I g...,neutral,/home/vmachado/Documents/multimodal-datasets/I...
...,...,...,...
3846,"Oh, it is. It isn't. [BFR] It is. [AFT] Isn't!",neutral,/home/vmachado/Documents/multimodal-datasets/M...
3847,It isn't. It is. [BFR] Isn't! [AFT],anger,/home/vmachado/Documents/multimodal-datasets/M...
3848,[BFR] Yeah baby! [AFT] I’m really glad you gu...,joy,/home/vmachado/Documents/multimodal-datasets/M...
3849,Yeah baby! [BFR] I’m really glad you guys are ...,neutral,/home/vmachado/Documents/multimodal-datasets/M...


In [16]:
def _corpus_of(audio_path):
    """Return "meld" when the path contains "MELD", otherwise "iemocap"."""
    if "MELD" in audio_path:
        return "meld"
    return "iemocap"


# Tag each test row with the corpus it comes from, based on its audio path.
test_df_erc["source"] = test_df_erc["path"].map(_corpus_of)
test_df_erc

Unnamed: 0,text,label,path,source
0,"[BFR] Brian, I need help. [AFT] Babe, I don't...",sadness,/home/vmachado/Documents/multimodal-datasets/I...,iemocap
1,"Brian, I need help. [BFR] Babe, I don't know w...",neutral,/home/vmachado/Documents/multimodal-datasets/I...,iemocap
2,"Babe, I don't know what to tell you. Don't gi...",neutral,/home/vmachado/Documents/multimodal-datasets/I...,iemocap
3,"I wish I had some answers for you, babe. I me...",neutral,/home/vmachado/Documents/multimodal-datasets/I...,iemocap
4,I went to school and I got my degree. And I g...,neutral,/home/vmachado/Documents/multimodal-datasets/I...,iemocap
...,...,...,...,...
3846,"Oh, it is. It isn't. [BFR] It is. [AFT] Isn't!",neutral,/home/vmachado/Documents/multimodal-datasets/M...,meld
3847,It isn't. It is. [BFR] Isn't! [AFT],anger,/home/vmachado/Documents/multimodal-datasets/M...,meld
3848,[BFR] Yeah baby! [AFT] I’m really glad you gu...,joy,/home/vmachado/Documents/multimodal-datasets/M...,meld
3849,Yeah baby! [BFR] I’m really glad you guys are ...,neutral,/home/vmachado/Documents/multimodal-datasets/M...,meld


In [17]:
test_df_erc.groupby("source").count()

Unnamed: 0_level_0,text,label,path
source,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
iemocap,1241,1241,1241
meld,2610,2610,2610


## Join datasets

In [18]:
train_df_erc.groupby("label").count()

Unnamed: 0_level_0,text,path
label,Unnamed: 1_level_1,Unnamed: 2_level_1
anger,1954,1954
disgust,258,258
fear,266,266
joy,2783,2783
neutral,5804,5804
sadness,1451,1451
surprise,1212,1212


In [19]:
def _oversample_erc(label, n_rows):
    """Draw ``n_rows`` rows of ``label`` from ``train_df_erc`` with replacement (seed 0)."""
    rows_of_label = train_df_erc.loc[train_df_erc["label"] == label]
    return rows_of_label.sample(n=n_rows, replace=True, random_state=0)


# Extra rows per class; they are appended to the original frame in the next cell.
# NOTE: `fear`, `joy`, `sadness` and `surprise` rebind names that held samples
# of other datasets in earlier cells.
ang = _oversample_erc("anger", 3000)
disg = _oversample_erc("disgust", 4700)
fear = _oversample_erc("fear", 4700)
joy = _oversample_erc("joy", 2300)
sadness = _oversample_erc("sadness", 3500)
surprise = _oversample_erc("surprise", 3800)

#excited = train_df_erc[train_df_erc["label"] == "excited"].sample(4300, replace=True, random_state=0)
#frustration = train_df_erc[train_df_erc["label"] == "frustration"].sample(3600, replace=True, random_state=0)

In [20]:
#train_df_erc_resampled = pd.concat([train_df_erc, joy, ang, disg, fear, surprise, sadness, excited, frustration]).reset_index(drop=True)
# Original ERC training rows followed by the per-class oversampled draws, re-indexed from 0.
erc_parts = [train_df_erc, joy, ang, disg, fear, surprise, sadness]
train_df_erc_resampled = pd.concat(erc_parts, ignore_index=True)

In [21]:
#train_df_erc_resampled = train_df_erc

In [22]:
train_df_erc_resampled.groupby("label").count()

Unnamed: 0_level_0,text,path
label,Unnamed: 1_level_1,Unnamed: 2_level_1
anger,4954,4954
disgust,4958,4958
fear,4966,4966
joy,5083,5083
neutral,5804,5804
sadness,4951,4951
surprise,5012,5012


## VoxPopuli + VoxCeleb

In [23]:
# Read voxceleb.csv, drop the saved index column and keep the three columns used downstream.
df_vox = (
    pd.read_csv("voxceleb.csv")
    .drop(columns="Unnamed: 0")
    .loc[:, ["path", "text", "sentiment_label"]]
)
df_vox

Unnamed: 0,path,text,sentiment_label
0,/home/vmachado/.cache/huggingface/datasets/dow...,and i i don't believe in god no religion says ...,Neutral
1,/home/vmachado/.cache/huggingface/datasets/dow...,the question because of my mother till i was f...,Neutral
2,/home/vmachado/.cache/huggingface/datasets/dow...,from my own culture things changed i i think a...,Neutral
3,/home/vmachado/.cache/huggingface/datasets/dow...,of god what is a creator the almighty that uh,Neutral
4,/home/vmachado/.cache/huggingface/datasets/dow...,i don't wanna pinpoint what exactly god is i i...,Neutral
...,...,...,...
7161,/home/vmachado/.cache/huggingface/datasets/dow...,the movie while he's solving this mystery exce...,Neutral
7162,/home/vmachado/.cache/huggingface/datasets/dow...,in my backstory you know that i actually uh hi...,Neutral
7163,/home/vmachado/.cache/huggingface/datasets/dow...,and it's just high action uh uh you want you,Neutral
7164,/home/vmachado/.cache/huggingface/datasets/dow...,you you can't stop thinking and and wondering ...,Neutral


In [24]:
# NOTE(review): this reads the same voxceleb.csv as `df_vox`; the trailing
# comment suggests it used to read df_ls.csv — confirm which file is intended.
df_ls = (
    pd.read_csv("voxceleb.csv")
    .drop(columns="Unnamed: 0")
    .loc[:, ["path", "text", "sentiment_label"]]
)  #pd.read_csv("df_ls.csv")
df_ls

Unnamed: 0,path,text,sentiment_label
0,/home/vmachado/.cache/huggingface/datasets/dow...,and i i don't believe in god no religion says ...,Neutral
1,/home/vmachado/.cache/huggingface/datasets/dow...,the question because of my mother till i was f...,Neutral
2,/home/vmachado/.cache/huggingface/datasets/dow...,from my own culture things changed i i think a...,Neutral
3,/home/vmachado/.cache/huggingface/datasets/dow...,of god what is a creator the almighty that uh,Neutral
4,/home/vmachado/.cache/huggingface/datasets/dow...,i don't wanna pinpoint what exactly god is i i...,Neutral
...,...,...,...
7161,/home/vmachado/.cache/huggingface/datasets/dow...,the movie while he's solving this mystery exce...,Neutral
7162,/home/vmachado/.cache/huggingface/datasets/dow...,in my backstory you know that i actually uh hi...,Neutral
7163,/home/vmachado/.cache/huggingface/datasets/dow...,and it's just high action uh uh you want you,Neutral
7164,/home/vmachado/.cache/huggingface/datasets/dow...,you you can't stop thinking and and wondering ...,Neutral


## Join all datasets

In [25]:
#df_train = pd.concat([go_emotions.rename(columns={"grouped_label":"label"}).assign(path=[None for _ in range(len(go_emotions))]), audio_datasets.assign(text=[None for _ in range(len(audio_datasets))]), train_df_erc_resampled, df_ls]).reset_index(drop=True) #.drop(columns="path")
#df_train = pd.concat([audio_datasets.assign(text=[None for _ in range(len(audio_datasets))]), train_df_erc_resampled,train_df_erc_resampled, df_ls]).reset_index(drop=True) #.drop(columns="path")
# GoEmotions rows are text-only: rename their label column and give them an empty `path`.
go_emotions_text_only = go_emotions.rename(columns={"grouped_label": "label"}).assign(path=None)

# Stack text-only, paired (ERC / vox) and audio-only rows into one training frame.
df_train = pd.concat(
    [go_emotions_text_only, train_df_erc_resampled, df_ls, audio_datasets],
    ignore_index=True,
)
df_train

Unnamed: 0,text,label,path,sentiment_label
0,To make her feel threatened,fear,,
1,OmG pEyToN iSn'T gOoD eNoUgH tO hElP uS iN tHe...,surprise,,
2,Demographics? I don’t know anybody under 35 wh...,surprise,,
3,Maybe that’s what happened to the great white ...,surprise,,
4,"I never thought it was at the same moment, but...",surprise,,
...,...,...,...,...
90346,,surprise,./audio/audio_emo/tess.woman.surprised.351.wav,
90347,,surprise,./audio/audio_emo/ravdass.man.surprise.63.wav,
90348,,surprise,./audio/audio_emo/tess.woman.surprised.26.wav,
90349,,surprise,./audio/audio_emo/tess.woman.surprised.67.wav,


In [26]:
import math

def label_to_sentiment(x):
    """Collapse an emotion label into a coarse sentiment class.

    Args:
        x: Emotion label (e.g. ``"joy"``), or ``None``.

    Returns:
        ``None`` when ``x`` is ``None``; ``"Positive"`` for joy / surprise /
        excited; ``"Negative"`` for fear / anger / disgust / sadness /
        frustration; ``"Neutral"`` for anything else. A float ``NaN`` label
        is not ``None``, so it also ends up as ``"Neutral"``.
    """
    # Identity test: `== None` dispatches to `__eq__`, which is not a reliable None check.
    if x is None:
        return x
    if x in ("joy", "surprise", "excited"):
        return "Positive"
    if x in ("fear", "anger", "disgust", "sadness", "frustration"):
        return "Negative"
    return "Neutral"

In [27]:
# Derive the coarse sentiment from the emotion label. Rows without an emotion
# label (the frame that only carries `sentiment_label`) keep the sentiment they
# already have; previously every such row was overwritten with "Neutral".
derived_sentiment = df_train["label"].apply(label_to_sentiment)
if "sentiment_label" in df_train.columns:
    has_emotion_label = df_train["label"].notna()
    # Rows with neither an emotion label nor a sentiment fall back to the derived value.
    existing_sentiment = df_train["sentiment_label"].fillna(derived_sentiment)
    derived_sentiment = derived_sentiment.where(has_emotion_label, existing_sentiment)
df_train["sentiment_label"] = derived_sentiment

In [28]:
df_train

Unnamed: 0,text,label,path,sentiment_label
0,To make her feel threatened,fear,,Negative
1,OmG pEyToN iSn'T gOoD eNoUgH tO hElP uS iN tHe...,surprise,,Positive
2,Demographics? I don’t know anybody under 35 wh...,surprise,,Positive
3,Maybe that’s what happened to the great white ...,surprise,,Positive
4,"I never thought it was at the same moment, but...",surprise,,Positive
...,...,...,...,...
90346,,surprise,./audio/audio_emo/tess.woman.surprised.351.wav,Positive
90347,,surprise,./audio/audio_emo/ravdass.man.surprise.63.wav,Positive
90348,,surprise,./audio/audio_emo/tess.woman.surprised.26.wav,Positive
90349,,surprise,./audio/audio_emo/tess.woman.surprised.67.wav,Positive


In [29]:
df_train["sentiment_label"].value_counts()

Negative    46586
Positive    24212
Neutral     19553
Name: sentiment_label, dtype: int64

In [30]:
from sklearn.preprocessing import LabelEncoder

# One encoder per target: fine-grained emotion labels and coarse sentiment labels.
# NOTE(review): `df_train['label']` has missing values for rows that only carry
# a sentiment label, so the missing marker is among the fitted values — confirm
# that this is intended.
lab_encoder = LabelEncoder().fit(df_train['label'].unique())
lab_encoder_senti = LabelEncoder().fit(df_train['sentiment_label'].unique())

In [31]:
len(df_train)

90351

In [32]:
len(test_df_erc)

3851

In [33]:
import numpy as np
from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.svm import SVC
from sklearn.metrics import f1_score, accuracy_score
import gc
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
gc.enable()

In [34]:
gc.collect()

0

In [35]:
import numpy as np
import faiss

class FaissKNeighbors:
    """k-nearest-neighbour classifier backed by an exact ``faiss.IndexFlatL2`` index.

    Labels must be non-negative integers because the majority vote is counted
    with ``np.bincount``.
    """

    def __init__(self, k=5):
        """Create an unfitted classifier that votes among the ``k`` nearest neighbours."""
        self.index = None
        self.y = None
        self.k = k

    def fit(self, X, y):
        """Index the training vectors and remember their labels.

        Args:
            X: 2-D array-like of shape (n_samples, n_features).
            y: 1-D array-like with one non-negative integer label per row of ``X``.

        Raises:
            ValueError: if ``X`` is not 2-D or ``y`` does not have one label per row.
        """
        # faiss needs C-contiguous float32 input; this also accepts lists.
        features = np.ascontiguousarray(X, dtype=np.float32)
        # Coerce to ndarray so the fancy indexing in `predict` works for lists / Series too.
        labels = np.asarray(y)
        if features.ndim != 2:
            raise ValueError("X must be a 2-D array of shape (n_samples, n_features)")
        if labels.shape[0] != features.shape[0]:
            raise ValueError("y must contain exactly one label per row of X")
        self.index = faiss.IndexFlatL2(features.shape[1])
        self.index.add(features)
        self.y = labels

    def predict(self, X):
        """Return the majority label among the ``k`` nearest training vectors of each query.

        Args:
            X: 2-D array-like of shape (n_queries, n_features).

        Returns:
            ``np.ndarray`` with one predicted label per query row.

        Raises:
            ValueError: if a query has no neighbour at all (empty index).
        """
        queries = np.ascontiguousarray(X, dtype=np.float32)
        distances, indices = self.index.search(queries, k=self.k)
        labels = np.asarray(self.y)
        predictions = []
        for neighbour_ids in indices:
            # faiss pads with -1 when fewer than k vectors are indexed; without this
            # filter `labels[-1]` would silently vote with the last training label.
            valid_ids = neighbour_ids[neighbour_ids >= 0]
            if valid_ids.size == 0:
                raise ValueError("no neighbours found: the index is empty")
            predictions.append(np.argmax(np.bincount(labels[valid_ids])))
        return np.array(predictions)

In [36]:
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

class Scheduler(_LRScheduler):
    """Learning-rate scheduler that delegates the schedule to the module-level ``calc_lr``.

    Every parameter group receives the same learning rate, computed from the
    scheduler's internal step counter.
    """

    def __init__(self, 
                 optimizer: Optimizer,
                 dim_embed: int,
                 warmup_steps: int,
                 last_epoch: int=-1,
                 verbose: bool=False) -> None:
        """Store the schedule parameters and initialise the base scheduler.

        Args:
            optimizer: Optimizer whose parameter groups are updated.
            dim_embed: Embedding dimension passed on to ``calc_lr``.
            warmup_steps: Warm-up length passed on to ``calc_lr``.
            last_epoch: Forwarded to the base scheduler.
            verbose: Forwarded to the base scheduler only when truthy.
        """
        # These must be set before calling the base constructor, which performs
        # an initial step that already calls `get_lr()`.
        self.dim_embed = dim_embed
        self.warmup_steps = warmup_steps
        self.num_param_groups = len(optimizer.param_groups)

        # `verbose` is deprecated in newer PyTorch releases; forwarding the default
        # would only trigger a deprecation warning (or fail once the argument is
        # removed), so it is passed on only when explicitly enabled.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self) -> list:
        """Return the current learning rate, repeated once per parameter group."""
        lr = calc_lr(self._step_count, self.dim_embed, self.warmup_steps)
        return [lr] * self.num_param_groups

# Last learning rate returned by `calc_lr`. This is module-level state shared by
# every caller and is never reset automatically; re-run this cell before
# building a new scheduler. -9999 means "no learning rate produced yet".
PREVIOUS_LR = -9999


def _warmup_inverse_sqrt_lr(step, dim_embed, warmup_steps):
    """Return dim_embed^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)."""
    return dim_embed**(-0.5) * min(step**(-0.5), step * warmup_steps**(-1.5))


def calc_lr(step, dim_embed, warmup_steps, max_lr=2e-4):
    """Warm-up / inverse-square-root learning rate, kept below ``max_lr``.

    While the raw schedule value is below ``max_lr`` it is returned as is.
    Once it reaches ``max_lr``, the step is advanced until the schedule yields
    a value strictly below the previously returned learning rate, so the
    result keeps decreasing from call to call.

    Args:
        step: Current step; must be >= 1 (0 raises ``ZeroDivisionError``).
        dim_embed: Embedding dimension used to scale the schedule.
        warmup_steps: Number of warm-up steps of the schedule.
        max_lr: Value at which the raw schedule is no longer followed.

    Returns:
        The learning rate for this call; it is also stored in ``PREVIOUS_LR``.
    """
    global PREVIOUS_LR
    lr = _warmup_inverse_sqrt_lr(step, dim_embed, warmup_steps)
    if lr >= max_lr:
        # Before any rate has been produced PREVIOUS_LR is the negative sentinel;
        # comparing against it would never terminate because the schedule stays
        # positive, so fall back to the cap itself in that case.
        ceiling = PREVIOUS_LR if PREVIOUS_LR > 0 else max_lr
        # Linear search: walk the step forward until the rate drops below the ceiling.
        while lr >= ceiling:
            step += 1.
            lr = _warmup_inverse_sqrt_lr(step, dim_embed, warmup_steps)
    PREVIOUS_LR = lr
    return lr

In [37]:
from src.modeling.speech_encoder import *

# Hyper-parameters reused by later cells.
dim_embed = 768
N_VECTORS = 512
MAX_LEN = 256

# Keyword configuration of the audio encoder, kept in one place for readability.
audio_encoder_kwargs = dict(
    emb_dim=dim_embed,
    n_layers=1,
    max_length=MAX_LEN,
    nheads=12,
    dropout=0.1,
)
audio_encoder = AudioEncoderMFCCHU(N_VECTORS, **audio_encoder_kwargs)

#audio_encoder = torch.load(f'/home/vmachado/Documents/c4ai_clip_audio_text/audio_encoder_best/audio_encoder.bin')
#audio_encoder.load_state_dict(torch.load(f'/home/vmachado/Documents/c4ai_clip_audio_text/audio_encoder_pre_trained_reformed_5_FIM_6_layer_continue/audio_best.bin'))

In [38]:
from src.modeling.text_encoder import *

# Checkpoint name handed to the text encoder.
MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
#MODEL_NAME = 'sentence-transformers/all-MiniLM-L12-v2'

# Extra tokens passed to the encoder in addition to the checkpoint's own vocabulary.
special_tokens = ['[NAME]', '[RELIGION]', '[LAUGHTER]', '[BFR]', '[AFT]']
text_encoder = TextEncoder(MODEL_NAME, max_len=128, extra_tokens=special_tokens)
#text_encoder.load_state_dict(torch.load('/home/vmachado/Documents/c4ai_clip_audio_text/text_encoder_only_meld/dabest_text.bin'))
#text_encoder.load_state_dict(torch.load(f'text_encoder_ready_L2_test2/best_text_encoder.bin'))
#text_encoder.load_state_dict(torch.load(f'text_encoder_ready_L2_test2/pytorch_model_AudioTextCLIP_epoch_22.bin'))

In [39]:
# Mark every text-encoder parameter as trainable (nothing stays frozen).
for weight in text_encoder.parameters():
    weight.requires_grad = True
    
#for idx_l, l in enumerate(text_encoder.encoder.encoder.layer):
#    if idx_l >= 11:
#        for param in list(l.parameters()):
#            param.requires_grad = True

In [40]:
df_train

Unnamed: 0,text,label,path,sentiment_label
0,To make her feel threatened,fear,,Negative
1,OmG pEyToN iSn'T gOoD eNoUgH tO hElP uS iN tHe...,surprise,,Positive
2,Demographics? I don’t know anybody under 35 wh...,surprise,,Positive
3,Maybe that’s what happened to the great white ...,surprise,,Positive
4,"I never thought it was at the same moment, but...",surprise,,Positive
...,...,...,...,...
90346,,surprise,./audio/audio_emo/tess.woman.surprised.351.wav,Positive
90347,,surprise,./audio/audio_emo/ravdass.man.surprise.63.wav,Positive
90348,,surprise,./audio/audio_emo/tess.woman.surprised.26.wav,Positive
90349,,surprise,./audio/audio_emo/tess.woman.surprised.67.wav,Positive


In [41]:
list(df_train[~df_train["path"].isna()]["path"])

['/home/vmachado/Documents/multimodal-datasets/IEMOCAP/raw-audios/train/Ses01M_script01_3_M000.wav',
 '/home/vmachado/Documents/multimodal-datasets/IEMOCAP/raw-audios/train/Ses01M_script01_3_F001.wav',
 '/home/vmachado/Documents/multimodal-datasets/IEMOCAP/raw-audios/train/Ses01M_script01_3_M001.wav',
 '/home/vmachado/Documents/multimodal-datasets/IEMOCAP/raw-audios/train/Ses01M_script01_3_F003.wav',
 '/home/vmachado/Documents/multimodal-datasets/IEMOCAP/raw-audios/train/Ses01M_script01_3_M002.wav',
 '/home/vmachado/Documents/multimodal-datasets/IEMOCAP/raw-audios/train/Ses01M_script01_3_F004.wav',
 '/home/vmachado/Documents/multimodal-datasets/IEMOCAP/raw-audios/train/Ses01M_script01_3_F005.wav',
 '/home/vmachado/Documents/multimodal-datasets/IEMOCAP/raw-audios/train/Ses01M_script01_3_F006.wav',
 '/home/vmachado/Documents/multimodal-datasets/IEMOCAP/raw-audios/train/Ses01M_script01_3_M006.wav',
 '/home/vmachado/Documents/multimodal-datasets/IEMOCAP/raw-audios/train/Ses01M_script01_3_F

In [42]:
from src.utils.speech_processing import *

# Audio feature tokenizer, limited to MAX_LEN frames.
# NOTE(review): `cache_path` is presumably where extracted speech features are
# cached — confirm in src.utils.speech_processing.
audio_tokenizer = AudioEncoderMFCCHUTokenizer(max_length=MAX_LEN, cache_path='/home/vmachado/Documents/c4ai_clip_audio_text/new_speech_features/')
# The `mean` / `std` attributes are assigned hard-coded tensors right after this
# line; the disabled call below is presumably how they were first obtained
# (cache a 25% sample of the training audio) — TODO confirm.
#_, lens = audio_tokenizer.cache_dataset(paths=list(df_train[~df_train["path"].isna()].sample(frac=0.25, random_state=0)["path"]))
audio_tokenizer.mean = torch.Tensor([-1.1745e+03,  2.2787e+00, -1.3436e+02,  4.2281e+01, -1.5886e+02,
        -4.2503e+01, -1.6652e+02, -6.0679e+01, -1.0516e+02, -3.1940e+01,
        -9.4301e+01, -2.8196e+01, -6.6406e+01,  1.3214e-01,  2.4173e-01,
         2.1768e-01,  1.9153e-01,  1.5862e-01,  1.3224e-01,  1.1713e-01,
         1.0544e-01,  9.4641e-02,  8.5885e-02,  7.7632e-02,  7.0888e-02,
         6.3016e-02,  5.7492e-02,  5.2733e-02,  4.8752e-02,  4.4427e-02,
         4.0783e-02,  3.6457e-02,  3.3371e-02,  3.0200e-02,  2.7740e-02,
         2.4798e-02,  2.2748e-02,  2.0403e-02,  1.8943e-02,  1.7292e-02,
         1.5884e-02,  1.4463e-02,  1.3473e-02,  1.2383e-02,  1.1698e-02,
         1.0816e-02,  1.0040e-02,  9.2574e-03,  8.7220e-03,  8.0053e-03,
         7.3996e-03,  6.8473e-03,  6.4102e-03,  5.8573e-03,  5.4198e-03,
         4.9593e-03,  4.6220e-03,  4.2688e-03,  4.0196e-03,  3.7833e-03,
         3.5817e-03,  3.3637e-03,  3.2097e-03,  3.0329e-03,  2.9024e-03,
         2.7216e-03,  2.5992e-03,  2.4578e-03,  2.3219e-03,  2.2085e-03,
         2.1044e-03,  1.9795e-03,  1.8923e-03,  1.7711e-03,  1.6752e-03,
         1.5630e-03,  1.5008e-03,  1.4268e-03,  1.3523e-03,  1.2771e-03,
         1.2057e-03,  1.1288e-03,  1.0627e-03,  1.0068e-03,  9.3696e-04,
         8.8309e-04,  8.2027e-04,  7.7626e-04,  7.2621e-04,  6.8244e-04,
         6.2516e-04,  5.7361e-04,  5.4662e-04,  5.0724e-04,  4.7727e-04,
         4.4075e-04,  4.0421e-04,  3.6961e-04,  3.3435e-04,  3.0521e-04,
         2.8062e-04,  2.6401e-04,  2.3755e-04,  2.1533e-04,  1.9300e-04,
         1.7130e-04,  1.5586e-04,  1.4436e-04,  1.2694e-04,  1.1913e-04,
         1.0666e-04,  9.7559e-05,  8.8004e-05,  8.0929e-05,  7.5713e-05,
         6.5684e-05,  6.2502e-05,  5.5034e-05,  4.8409e-05,  4.3191e-05,
         3.9435e-05,  3.5874e-05,  3.3941e-05,  3.0259e-05,  2.7158e-05,
         2.3242e-05,  1.9884e-05,  1.6042e-05,  8.2821e-02,  2.2356e-01,
        -2.8710e-01,  7.9194e-02, -2.2688e-01, -4.4583e-03, -3.1010e-01,
        -6.7791e-02, -1.5429e-01, -7.0310e-02, -2.2113e-01, -4.9786e-02,
        -1.5399e-01,  2.5206e-04,  4.2298e-04,  4.2356e-04,  3.6397e-04,
         3.1303e-04,  2.6955e-04,  2.2263e-04,  2.2512e-04,  2.0992e-04,
         1.8997e-04,  1.7311e-04,  1.4846e-04,  1.2123e-04,  1.1412e-04,
         1.0339e-04,  1.1267e-04,  1.0309e-04,  8.4783e-05,  7.3514e-05,
         7.2023e-05,  6.4760e-05,  6.4285e-05,  5.7985e-05,  5.2177e-05,
         4.6418e-05,  4.7871e-05,  4.3763e-05,  4.0681e-05,  3.7222e-05,
         3.4295e-05,  3.1561e-05,  2.9967e-05,  2.7680e-05,  2.7623e-05,
         2.6451e-05,  2.6466e-05,  2.2476e-05,  2.0828e-05,  1.9829e-05,
         1.7483e-05,  1.6071e-05,  1.6325e-05,  1.3802e-05,  1.4725e-05,
         1.4710e-05,  1.3431e-05,  1.1905e-05,  1.1559e-05,  1.1739e-05,
         1.1275e-05,  1.1417e-05,  1.2092e-05,  1.2754e-05,  1.2024e-05,
         1.0824e-05,  1.0315e-05,  9.1208e-06,  8.8483e-06,  7.4808e-06,
         7.2708e-06,  8.2807e-06,  7.3811e-06,  7.4967e-06,  6.9180e-06,
         7.2903e-06,  6.8355e-06,  5.9609e-06,  5.7615e-06,  5.5132e-06,
         5.2766e-06,  4.7307e-06,  4.5198e-06,  3.9687e-06,  3.8545e-06,
         3.4617e-06,  3.2478e-06,  2.8526e-06,  3.0319e-06,  2.5946e-06,
         2.5531e-06,  2.3268e-06,  2.1778e-06,  2.0740e-06,  1.9709e-06,
         1.7363e-06,  1.3544e-06,  1.4028e-06,  1.3077e-06,  1.1276e-06,
         1.2016e-06,  1.1983e-06,  1.0208e-06,  8.1538e-07,  7.8428e-07,
         6.6318e-07,  7.0989e-07,  9.2021e-07,  7.2013e-07,  8.1444e-07,
         8.0237e-07,  9.0007e-07,  6.5921e-07,  6.9934e-07,  3.8494e-07,
         4.5313e-07,  4.1900e-07,  3.3735e-07,  2.8976e-07,  3.8101e-07,
         2.2933e-07,  2.1375e-07,  1.9388e-07,  1.4123e-07,  6.7112e-08,
         5.7247e-08, -4.3037e-01,  6.1246e-02,  5.4988e-02, -3.4926e-02,
         1.2448e-01, -1.3948e-03,  1.2025e-01,  4.4583e-02,  9.7096e-02,
         1.3381e-02,  8.2258e-02,  2.6150e-02,  7.2660e-02, -1.4877e-04,
        -2.4823e-04, -2.3552e-04, -2.2068e-04, -1.6322e-04, -1.2564e-04,
        -1.1424e-04, -8.9999e-05, -8.5837e-05, -7.8019e-05, -6.8423e-05,
        -6.0470e-05, -5.6134e-05, -5.3471e-05, -4.9815e-05, -4.4360e-05,
        -4.5985e-05, -4.3064e-05, -3.7208e-05, -3.6196e-05, -3.3151e-05,
        -3.1243e-05, -2.7004e-05, -2.4864e-05, -2.1479e-05, -1.8415e-05,
        -1.7921e-05, -1.8757e-05, -1.6076e-05, -1.5099e-05, -1.4192e-05,
        -1.2519e-05, -1.1340e-05, -1.0300e-05, -9.5987e-06, -8.8470e-06,
        -9.0991e-06, -7.4181e-06, -6.4367e-06, -6.6805e-06, -5.6950e-06,
        -5.2952e-06, -5.5317e-06, -4.5829e-06, -4.7052e-06, -4.0624e-06,
        -4.2542e-06, -4.2432e-06, -4.1813e-06, -3.5610e-06, -3.1162e-06,
        -2.8147e-06, -2.4557e-06, -2.2899e-06, -2.6685e-06, -2.3288e-06,
        -2.4319e-06, -2.8676e-06, -2.4296e-06, -2.6144e-06, -1.7675e-06,
        -1.7452e-06, -1.6275e-06, -1.6042e-06, -1.3303e-06, -1.4289e-06,
        -1.2074e-06, -1.0202e-06, -9.7020e-07, -8.5603e-07, -8.7458e-07,
        -6.9110e-07, -9.6283e-07, -8.3394e-07, -9.8441e-07, -7.9333e-07,
        -8.3520e-07, -5.1605e-07, -5.8250e-07, -5.1902e-07, -4.7480e-07,
        -5.2017e-07, -5.0341e-07, -4.9701e-07, -5.4298e-07, -4.4246e-07,
        -4.0207e-07, -4.8091e-07, -3.7359e-07, -3.5006e-07, -3.5082e-07,
        -3.1156e-07, -4.0896e-07, -3.3702e-07, -3.5667e-07, -2.9841e-07,
        -2.1333e-07, -2.0255e-07, -1.9017e-07, -1.8474e-07, -1.3893e-07,
        -2.5546e-07, -2.6200e-07, -2.4044e-07, -1.5338e-07, -1.5324e-07,
        -1.0576e-07, -9.9442e-08, -3.9414e-08, -2.2929e-07, -1.4225e-07,
        -7.3033e-08, -1.2690e-07, -6.8328e-08, -4.1665e-08]).float()
audio_tokenizer.std = torch.Tensor([3.4610e+02, 2.0329e+02, 2.1766e+02, 1.8953e+02, 1.8574e+02, 1.7331e+02,
        1.6651e+02, 1.6431e+02, 1.5596e+02, 1.5770e+02, 1.4123e+02, 1.3759e+02,
        1.2521e+02, 5.8905e-01, 1.1134e+00, 1.0196e+00, 9.3499e-01, 8.5661e-01,
        7.3274e-01, 6.8508e-01, 6.2097e-01, 5.5384e-01, 5.0694e-01, 4.6303e-01,
        4.2669e-01, 3.6475e-01, 3.2250e-01, 2.9689e-01, 2.8181e-01, 2.5559e-01,
        2.3403e-01, 1.9857e-01, 1.8121e-01, 1.6036e-01, 1.4713e-01, 1.3092e-01,
        1.2421e-01, 1.1067e-01, 1.0844e-01, 1.0081e-01, 9.2958e-02, 8.7006e-02,
        8.2182e-02, 7.5746e-02, 7.6861e-02, 7.3703e-02, 6.8297e-02, 6.3531e-02,
        6.1089e-02, 5.8714e-02, 5.4342e-02, 5.0260e-02, 4.8192e-02, 4.3264e-02,
        4.1530e-02, 3.8387e-02, 3.7283e-02, 3.3814e-02, 3.2691e-02, 3.0986e-02,
        3.0893e-02, 2.9737e-02, 2.8810e-02, 2.7760e-02, 2.7588e-02, 2.6344e-02,
        2.5541e-02, 2.5534e-02, 2.4065e-02, 2.2757e-02, 2.2210e-02, 2.2098e-02,
        2.2358e-02, 2.1017e-02, 2.0680e-02, 1.8603e-02, 1.9288e-02, 1.8966e-02,
        1.9133e-02, 1.8534e-02, 1.7927e-02, 1.6336e-02, 1.7216e-02, 1.6688e-02,
        1.5482e-02, 1.5535e-02, 1.5110e-02, 1.4006e-02, 1.4007e-02, 1.2866e-02,
        1.2434e-02, 1.1218e-02, 1.1144e-02, 1.1172e-02, 1.0681e-02, 1.0664e-02,
        9.7962e-03, 9.7915e-03, 9.7845e-03, 8.3338e-03, 8.5247e-03, 8.6266e-03,
        7.9188e-03, 7.9015e-03, 7.1705e-03, 6.4327e-03, 6.3670e-03, 6.0735e-03,
        5.9712e-03, 5.6508e-03, 5.5951e-03, 6.4811e-03, 5.1814e-03, 4.6915e-03,
        4.7504e-03, 5.2872e-03, 4.9687e-03, 4.5008e-03, 4.2823e-03, 3.7788e-03,
        3.3177e-03, 3.5372e-03, 3.4401e-03, 5.9168e-03, 3.4877e-03, 3.2641e-03,
        2.5756e-03, 1.9148e-03, 6.0942e+01, 5.7679e+01, 5.8302e+01, 5.1513e+01,
        4.9517e+01, 4.6907e+01, 4.5913e+01, 4.4911e+01, 4.3823e+01, 4.3580e+01,
        4.0062e+01, 3.9248e+01, 3.6391e+01, 1.6672e-01, 3.2986e-01, 3.1402e-01,
        2.9643e-01, 2.7041e-01, 2.3492e-01, 2.1728e-01, 1.9745e-01, 1.7298e-01,
        1.5989e-01, 1.4756e-01, 1.3580e-01, 1.1643e-01, 1.0258e-01, 9.4260e-02,
        8.9818e-02, 8.1035e-02, 7.4296e-02, 6.3058e-02, 5.7382e-02, 5.0741e-02,
        4.6545e-02, 4.1249e-02, 3.9252e-02, 3.4722e-02, 3.4560e-02, 3.1622e-02,
        2.9351e-02, 2.7455e-02, 2.5706e-02, 2.3595e-02, 2.4021e-02, 2.2987e-02,
        2.1176e-02, 1.9641e-02, 1.9064e-02, 1.8404e-02, 1.7067e-02, 1.5674e-02,
        1.5120e-02, 1.3493e-02, 1.2888e-02, 1.1756e-02, 1.1533e-02, 1.0384e-02,
        1.0110e-02, 9.5833e-03, 9.5573e-03, 9.2261e-03, 8.8668e-03, 8.5919e-03,
        8.5641e-03, 8.1607e-03, 7.8424e-03, 7.9101e-03, 7.4254e-03, 6.9654e-03,
        6.7802e-03, 6.8084e-03, 6.8868e-03, 6.4047e-03, 6.4830e-03, 5.7077e-03,
        6.1219e-03, 5.9579e-03, 6.0030e-03, 5.8392e-03, 5.7222e-03, 5.1279e-03,
        5.4923e-03, 5.2673e-03, 4.9187e-03, 5.0315e-03, 4.8391e-03, 4.4943e-03,
        4.4117e-03, 4.0835e-03, 3.9393e-03, 3.6099e-03, 3.5171e-03, 3.5700e-03,
        3.3884e-03, 3.3850e-03, 3.1730e-03, 3.1304e-03, 3.1395e-03, 2.6004e-03,
        2.6340e-03, 2.7335e-03, 2.5415e-03, 2.5127e-03, 2.2553e-03, 2.0549e-03,
        2.0522e-03, 1.9179e-03, 1.9158e-03, 1.8150e-03, 1.8109e-03, 2.0503e-03,
        1.6457e-03, 1.5092e-03, 1.4945e-03, 1.6981e-03, 1.5885e-03, 1.4388e-03,
        1.3869e-03, 1.2147e-03, 1.1003e-03, 1.1416e-03, 1.1356e-03, 1.8872e-03,
        1.1271e-03, 1.0448e-03, 8.3055e-04, 6.2078e-04, 2.6353e+01, 2.6094e+01,
        2.6333e+01, 2.2768e+01, 2.1976e+01, 2.0992e+01, 2.0526e+01, 2.0185e+01,
        1.9653e+01, 1.9564e+01, 1.7978e+01, 1.7722e+01, 1.6401e+01, 7.4204e-02,
        1.4916e-01, 1.4147e-01, 1.3468e-01, 1.2219e-01, 1.0706e-01, 9.8491e-02,
        9.0003e-02, 7.8678e-02, 7.2570e-02, 6.6921e-02, 6.1441e-02, 5.2714e-02,
        4.6419e-02, 4.2577e-02, 4.0538e-02, 3.6528e-02, 3.3425e-02, 2.8382e-02,
        2.5820e-02, 2.2848e-02, 2.0999e-02, 1.8596e-02, 1.7718e-02, 1.5656e-02,
        1.5615e-02, 1.4243e-02, 1.3235e-02, 1.2402e-02, 1.1555e-02, 1.0582e-02,
        1.0776e-02, 1.0290e-02, 9.4476e-03, 8.7682e-03, 8.5459e-03, 8.2498e-03,
        7.6656e-03, 7.0288e-03, 6.7968e-03, 6.0587e-03, 5.7761e-03, 5.2476e-03,
        5.1564e-03, 4.6445e-03, 4.5241e-03, 4.2959e-03, 4.2793e-03, 4.1341e-03,
        3.9701e-03, 3.8546e-03, 3.8368e-03, 3.6334e-03, 3.4932e-03, 3.5462e-03,
        3.3290e-03, 3.1136e-03, 3.0344e-03, 3.0523e-03, 3.0798e-03, 2.8599e-03,
        2.9140e-03, 2.5425e-03, 2.7643e-03, 2.6815e-03, 2.6966e-03, 2.6264e-03,
        2.5839e-03, 2.2999e-03, 2.4701e-03, 2.3590e-03, 2.2112e-03, 2.2771e-03,
        2.1826e-03, 2.0318e-03, 1.9751e-03, 1.8369e-03, 1.7660e-03, 1.6328e-03,
        1.5771e-03, 1.6097e-03, 1.5245e-03, 1.5191e-03, 1.4376e-03, 1.4123e-03,
        1.4145e-03, 1.1608e-03, 1.1681e-03, 1.2211e-03, 1.1449e-03, 1.1249e-03,
        1.0067e-03, 9.1852e-04, 9.2093e-04, 8.5299e-04, 8.5505e-04, 8.1404e-04,
        8.1366e-04, 9.1140e-04, 7.3335e-04, 6.7294e-04, 6.6730e-04, 7.5947e-04,
        7.1015e-04, 6.4234e-04, 6.2060e-04, 5.4270e-04, 4.9665e-04, 5.1080e-04,
        5.1578e-04, 8.4158e-04, 5.0561e-04, 4.6750e-04, 3.7118e-04, 2.7763e-04])

In [43]:
len(df_train[~df_train["path"].isna()]["path"].unique()) + len(test_df_erc[~test_df_erc["path"].isna()]["path"].unique())

36412

In [44]:
gc.collect()

9

In [45]:
from src.modeling.losses import *

In [46]:
from src.modeling.mm_contrast import *

In [47]:
from tqdm import tqdm

In [48]:
# Switches for the optional single-modality pre-training stages below: the
# text and audio pre-training cells are wrapped in `if pre_train_text:` /
# `if pre_train_audio:` and are skipped entirely while the flag is False.
pre_train_text = False
pre_train_audio = False

## Text Pretraining

In [49]:
if pre_train_text:
    train_ds = torch.utils.data.TensorDataset(torch.Tensor(list(range(len(df_train)))))
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=300, shuffle=True)

    test_ds = torch.utils.data.TensorDataset(torch.Tensor(list(range(len(test_df_erc)))))
    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=1024, shuffle=False)

    MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
    text_encoder = TextEncoder(MODEL_NAME, max_len=MAX_LEN, extra_tokens=['[NAME]', '[RELIGION]', '[LAUGHTER]', '[BFR]', '[AFT]'])

    PATH_TO_SAVE = f'text_encoder_pre_trained_{MODEL_NAME}'
    !mkdir -p {PATH_TO_SAVE}
    supcon_model = AudioTextContrastive(
        text_encoder,
        audio_encoder,
        in_features_text=768,
        in_features_audio=dim_embed, 
        hidden_size=768,
        wide_proj=1024,
        proj_size=128, 
        rate=0.1,
    )

    # Grid search best temperatures
    # Try to only fine tune on evaluation datasets
    #supcon_model.load_state_dict(torch.load(f'ESTAMOS_PERTO_AMIGO_ESTOU_AQUI_4_freezed_4_layer/pytorch_model_AudioTextCLIP_epoch_9.bin')['model'])

    supcon_model.to(0)

    scaler = torch.cuda.amp.GradScaler()

    step = 0
    e = 0
    patience = 9999
    early_stop_flag = 0
    old_f1 = -float('inf')

    param_optimizer = list(supcon_model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.1
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.0
    }]

    scheduler_epochs = 5
    opt = torch.optim.AdamW(optimizer_grouped_parameters, lr=2e-5, betas=(0.9, 0.98), eps=1e-8)
    scheduler = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.5, end_factor=0.9, total_iters=10, last_epoch=- 1, verbose=False)
    #scheduler = Scheduler(opt, 768, 600)

    epochs = 9999

    # Text-only supervised-contrastive training. Each epoch: train on df_train,
    # fit a KNN on the training projections, score the ERC test set with it,
    # plot a t-SNE of the test projections, then write metrics and a checkpoint.
    while e < epochs:
        supcon_model.train()
        epoch_loss = 0.0
        n_steps = 0  # batches that actually contributed to epoch_loss
        proj_val = []
        targets_val = []

        proj_train = []
        targets_train = []

        for i, batch_indices in enumerate(tqdm(train_loader, total=len(train_loader))):
            # The last batch of every epoch is skipped.
            if i == len(train_loader)-1:
                continue
            batch = df_train.iloc[batch_indices[0]]
            batch = batch[~batch["text"].isna()].reset_index(drop=True)

            # After reset_index the index labels are also row positions, so the
            # same object selects labelled rows in the frame and in the output.
            batch_lab_idx = batch[batch["label"].notna()].index
            sentences = batch["text"].tolist()

            y_text = torch.Tensor(lab_encoder.transform(batch.iloc[batch_lab_idx]["label"]))
            y_text_senti = torch.Tensor(lab_encoder_senti.transform(batch["sentiment_label"]))

            # Context augmentation: for sentences carrying a context marker, two
            # independent draws decide (p < 0.25 each) whether to drop the text
            # before '[BFR]' and the text from '[AFT]' onwards.
            for k, s in enumerate(sentences):
                if '[BFR]' not in s and '[AFT]' not in s:
                    continue
                p = np.random.rand()
                if p < 0.25 and '[BFR]' in s:
                    sentences[k] = sentences[k].split('[BFR]')[1]

                p = np.random.rand()
                if p < 0.25 and '[AFT]' in s:
                    sentences[k] = sentences[k].split('[AFT]')[0]

            target = y_text.long().cuda()
            target_senti = y_text_senti.long().cuda()

            x = [sentences, None, None]

            with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16), torch.backends.cuda.sdp_kernel(enable_flash=False):

                out = supcon_model(x)

                out_x = out["x_text"]
                out_x_lab = out_x[batch_lab_idx]
                out_x_wide = out["x_text_wide"][batch_lab_idx]

                # Emotion labels exist only for some rows; sentiment labels are
                # used for every row of the batch.
                loss = 0.8 * sup_contrastive_loss(out_x_lab, target, temperature=0.1) + 0.2 * sup_contrastive_loss(out_x, target_senti, temperature=0.1)

            scaler.scale(loss).backward()
            scaler.unscale_(opt)

            scaler.step(opt)
            scaler.update()

            opt.zero_grad(set_to_none=True)

            epoch_loss += loss.item()
            n_steps += 1
            proj_train.append(np.array(out_x_wide.detach().cpu()))
            targets_train.append(np.array(target.cpu()))

            del out_x
            del out
            del out_x_wide
            gc.collect()
            torch.cuda.empty_cache()

        scheduler.step()
        proj_train = np.concatenate(proj_train, axis=0)
        targets_train = np.concatenate(targets_train, axis=0)

        clf = FaissKNeighbors(k=128)
        clf.fit(proj_train, np.array(targets_train, dtype=int))

        # Average over the batches that were really trained on (the skipped
        # last batch used to be counted in the denominator).
        epoch_loss = epoch_loss / max(n_steps, 1)
        # NOTE(review): eval() stays disabled as in the original, so the model
        # is still in train mode while scoring — confirm this is intended.
        #supcon_model.eval()
        preds = []

        for i, batch_indices in enumerate(tqdm(test_loader, total=len(test_loader))):
            with torch.no_grad():

                multimodal_batch = test_df_erc.iloc[batch_indices[0]]

                sentences = [str(text) for text in multimodal_batch["text"]]

                target = torch.Tensor(lab_encoder.transform(list(multimodal_batch["label"])))

                x = [sentences, None, None]
                with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16), torch.backends.cuda.sdp_kernel(enable_flash=False):
                    out = supcon_model(x)

                out_x_wide = out["x_text_wide"]

                wide = np.array(out_x_wide.cpu())
                pred = clf.predict(wide)
                preds.append(pred)

                assert len(wide) == len(pred)

                proj_val.append(wide)
                targets_val.append(np.array(target.cpu()))
                del out_x_wide
                gc.collect()
                torch.cuda.empty_cache()

        proj_val = np.concatenate(proj_val, axis=0)
        targets_val = np.concatenate(targets_val, axis=0)

        preds = np.array(np.concatenate(preds, axis=0))

        general_f1 = f1_score(targets_val, preds, average='weighted')
        general_acc = accuracy_score(targets_val, preds)

        # Row *positions* per corpus. preds/targets_val are plain arrays in
        # test_df_erc row order, so positions (not index labels) must be used.
        is_meld = (test_df_erc["source"] == "meld").to_numpy()
        meld_idx = np.flatnonzero(is_meld)
        iemocap_idx = np.flatnonzero(~is_meld)

        general_f1_iemocap = f1_score(targets_val[iemocap_idx], preds[iemocap_idx], average='weighted')
        general_acc_iemocap = accuracy_score(targets_val[iemocap_idx], preds[iemocap_idx])

        general_f1_meld = f1_score(targets_val[meld_idx], preds[meld_idx], average='weighted')
        general_acc_meld = accuracy_score(targets_val[meld_idx], preds[meld_idx])

        macro_f1_iemocap = f1_score(targets_val[iemocap_idx], preds[iemocap_idx], average='macro')
        macro_f1_meld = f1_score(targets_val[meld_idx], preds[meld_idx], average='macro')

        # Built once so the console and the metrics file always agree.
        metric_lines = [
            f'General - KNN F1: {general_f1} Acc: {general_acc}',
            f'Iemocap - KNN F1: {general_f1_iemocap} Acc: {general_acc_iemocap}',
            f'Meld - KNN F1: {general_f1_meld} Acc: {general_acc_meld}',
            f"Iemocap - KNN F1 (macro): {macro_f1_iemocap}",
            f"Meld - KNN F1 (macro): {macro_f1_meld}",
        ]
        for line in metric_lines:
            print(line)

        try:
            tsne = TSNE(n_components=2, learning_rate='auto', init='pca', perplexity=5).fit_transform(proj_val)

            sns.scatterplot(x=tsne[:, 0], y=tsne[:, 1], hue=lab_encoder.inverse_transform(list(np.array(targets_val, dtype=int))) , palette='tab10')
            plt.show()

        except Exception as err:
            # The plot is best-effort: report the failure and keep training.
            # (A bare `except` would also swallow KeyboardInterrupt.)
            print(f't-SNE plot skipped: {err!r}')

        print(f'Epoch: {e + 1} - Train Loss: {epoch_loss}')
        e += 1

        #if e == scheduler_epochs: # Unfreeze text encoder
        #    for i, (name, param) in enumerate(list(supcon_model.text_encoder.named_parameters())):
        #        param.requires_grad = True

        # One metric per line (previously concatenated with no separator, and
        # the Iemocap macro line was written twice).
        with open(f"{PATH_TO_SAVE}/metrics_epoch_{e}.txt", "w") as f:
            f.write("\n".join(metric_lines) + "\n")

        checkpoint = {"model": supcon_model.state_dict(),
                  "optimizer": opt.state_dict(),
                  "scaler": scaler.state_dict()}
        torch.save(checkpoint, f'{PATH_TO_SAVE}/pytorch_model_AudioTextCLIP_epoch_{e}.bin')

In [50]:
if pre_train_text:
    # Rebuild the contrastive wrapper, restore the checkpoint saved after epoch
    # 17 of text pre-training, and export only the text encoder's weights.
    # NOTE(review): `rate` is 0.2 here but 0.1 in the training cell; presumably
    # a dropout rate that does not change the state_dict layout — confirm.
    supcon_model = AudioTextContrastive(
        text_encoder,
        audio_encoder,
        in_features_text=768,
        in_features_audio=dim_embed, 
        hidden_size=768,
        wide_proj=1024,
        proj_size=128, 
        rate=0.2,
    )
    # The epoch number is hard-coded; that checkpoint file must already exist.
    supcon_model.load_state_dict(torch.load(f'{PATH_TO_SAVE}/pytorch_model_AudioTextCLIP_epoch_17.bin')['model'])
    # This file is what the later text-encoder loading cell reads back.
    torch.save(supcon_model.text_encoder.state_dict(), f'{PATH_TO_SAVE}/dabest_text_encoder.bin')


## Audio Pretraining

In [51]:
gc.collect()

0

In [52]:
if pre_train_audio:
    train_ds = torch.utils.data.TensorDataset(torch.Tensor(list(range(len(df_train[df_train["path"].notna()])))))
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=768, shuffle=True)

    test_ds = torch.utils.data.TensorDataset(torch.Tensor(list(range(len(test_df_erc)))))
    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=1024, shuffle=False)

    PATH_TO_SAVE = 'audio_encoder_pre_trained_1_layer'
    !mkdir -p {PATH_TO_SAVE}
    supcon_model = AudioTextContrastive(
        None,
        audio_encoder,
        in_features_text=384,
        in_features_audio=dim_embed, 
        hidden_size=768,
        wide_proj=1024,
        proj_size=128, 
        rate=0.1,
    )

    # Grid search best temperatures
    # Try to only fine tune on evaluation datasets

    supcon_model.to(0)

    scaler = torch.cuda.amp.GradScaler()

    step = 0
    e = 0
    patience = 9999
    early_stop_flag = 0
    old_f1 = -float('inf')

    param_optimizer = list(supcon_model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.1
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay_rate':
        0.0
    }]

    scheduler_epochs = 5
    opt = torch.optim.AdamW(optimizer_grouped_parameters, lr=0, betas=(0.9, 0.98), eps=1e-8)
    #scheduler = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.5, end_factor=0.9, total_iters=10, last_epoch=- 1, verbose=False)
    scheduler = Scheduler(opt, dim_embed, 600)
    #for i in range(51 * 177):
    #    scheduler.step()
    
    #checkpoint = {"model": supcon_model.state_dict(),
    #          "optimizer": opt.state_dict(),
    #          "scaler": scaler.state_dict()}
    #supcon_model.load_state_dict(torch.load(f'{PATH_TO_SAVE}/pytorch_model_AudioTextCLIP_epoch_51.bin')['model'])
    #opt.load_state_dict(torch.load(f'{PATH_TO_SAVE}/pytorch_model_AudioTextCLIP_epoch_51.bin')['optimizer'])
    #scaler.load_state_dict(torch.load(f'{PATH_TO_SAVE}/pytorch_model_AudioTextCLIP_epoch_51.bin')['scaler'])
    
    epochs = 9999

    # Audio-only supervised-contrastive training. Each epoch: train on the
    # df_train rows that have audio, fit a KNN on the training projections,
    # score the ERC test set, plot a t-SNE, then write metrics and a checkpoint.

    # Rows of df_train with an audio file. The train loader yields positions
    # into this frame, so it is built once instead of once per batch.
    audio_rows = df_train[df_train["path"].notna()].reset_index(drop=True)

    while e < epochs:
        supcon_model.train()
        epoch_loss = 0.0
        n_steps = 0  # batches that actually contributed to epoch_loss
        proj_val = []
        targets_val = []

        proj_train = []
        targets_train = []

        for i, batch_indices in enumerate(tqdm(train_loader, total=len(train_loader))):
            # The last batch of every epoch is skipped.
            if i == len(train_loader)-1:
                continue

            only_audio = audio_rows.iloc[batch_indices[0]].reset_index(drop=True)
            # After reset_index the index labels are also row positions.
            only_audio_lab_idx = only_audio[only_audio["label"].notna()].index

            audio_paths = only_audio["path"].tolist()

            mfccs, att = audio_tokenizer.batch_tokenize(audio_paths)

            audio_input = {
                "features": mfccs.float().to(0),
                "attn_masks": att.float().to(0),
            }

            y_audio = torch.Tensor(lab_encoder.transform(only_audio.iloc[only_audio_lab_idx]["label"]))
            y_audio_senti = torch.Tensor(lab_encoder_senti.transform(only_audio["sentiment_label"]))

            target = y_audio.long().cuda()
            target_senti = y_audio_senti.long().cuda()

            x = [None, audio_input, None]

            with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16), torch.backends.cuda.sdp_kernel(enable_flash=False):

                out = supcon_model(x)

                out_x = out["x_audio"]
                out_x_lab = out_x[only_audio_lab_idx]
                out_x_wide = out["x_audio_wide"][only_audio_lab_idx]

                # Emotion labels exist only for some rows; sentiment labels are
                # used for every row of the batch.
                loss = 0.8 * sup_contrastive_loss(out_x_lab, target, temperature=0.1) + 0.2 * sup_contrastive_loss(out_x, target_senti, temperature=0.1)

            scaler.scale(loss).backward()
            scaler.unscale_(opt)

            scaler.step(opt)
            scaler.update()
            # Unlike the text stage, the scheduler advances once per batch here.
            scheduler.step()

            opt.zero_grad(set_to_none=True)

            epoch_loss += loss.item()
            n_steps += 1
            proj_train.append(np.array(out_x_wide.detach().cpu()))
            targets_train.append(np.array(target.cpu()))

            del out_x
            del out
            del out_x_wide
            gc.collect()
            torch.cuda.empty_cache()

        proj_train = np.concatenate(proj_train, axis=0)
        targets_train = np.concatenate(targets_train, axis=0)

        clf = FaissKNeighbors(k=128)
        clf.fit(proj_train, np.array(targets_train, dtype=int))

        # Average over the batches that were really trained on (the skipped
        # last batch used to be counted in the denominator).
        epoch_loss = epoch_loss / max(n_steps, 1)
        # NOTE(review): eval() stays disabled as in the original, so the model
        # is still in train mode while scoring — confirm this is intended.
        #supcon_model.eval()
        preds = []

        for i, batch_indices in enumerate(tqdm(test_loader, total=len(test_loader))):
            with torch.no_grad():

                multimodal_batch = test_df_erc.iloc[batch_indices[0]]

                audio_path_mult = [str(path) for path in multimodal_batch["path"]]
                mfccs_mult, att_mult = audio_tokenizer.batch_tokenize(audio_path_mult)

                audio_input = {"features": mfccs_mult.float().to(0), "attn_masks": att_mult.float().to(0)}

                target = torch.Tensor(lab_encoder.transform(list(multimodal_batch["label"])))

                x = [None, audio_input, None]
                with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16), torch.backends.cuda.sdp_kernel(enable_flash=False):
                    out = supcon_model(x)

                out_x_wide = out["x_audio_wide"]

                wide = np.array(out_x_wide.cpu())
                pred = clf.predict(wide)
                preds.append(pred)

                assert len(wide) == len(pred)

                proj_val.append(wide)
                targets_val.append(np.array(target.cpu()))
                del out_x_wide
                del out
                gc.collect()
                torch.cuda.empty_cache()

        proj_val = np.concatenate(proj_val, axis=0)
        targets_val = np.concatenate(targets_val, axis=0)

        preds = np.array(np.concatenate(preds, axis=0))

        general_f1 = f1_score(targets_val, preds, average='weighted')
        general_acc = accuracy_score(targets_val, preds)

        # Row *positions* per corpus. preds/targets_val are plain arrays in
        # test_df_erc row order, so positions (not index labels) must be used.
        is_meld = (test_df_erc["source"] == "meld").to_numpy()
        meld_idx = np.flatnonzero(is_meld)
        iemocap_idx = np.flatnonzero(~is_meld)

        general_f1_iemocap = f1_score(targets_val[iemocap_idx], preds[iemocap_idx], average='weighted')
        general_acc_iemocap = accuracy_score(targets_val[iemocap_idx], preds[iemocap_idx])

        general_f1_meld = f1_score(targets_val[meld_idx], preds[meld_idx], average='weighted')
        general_acc_meld = accuracy_score(targets_val[meld_idx], preds[meld_idx])

        macro_f1_iemocap = f1_score(targets_val[iemocap_idx], preds[iemocap_idx], average='macro')
        macro_f1_meld = f1_score(targets_val[meld_idx], preds[meld_idx], average='macro')

        # Built once so the console and the metrics file always agree.
        metric_lines = [
            f'General - KNN F1: {general_f1} Acc: {general_acc}',
            f'Iemocap - KNN F1: {general_f1_iemocap} Acc: {general_acc_iemocap}',
            f'Meld - KNN F1: {general_f1_meld} Acc: {general_acc_meld}',
            f"Iemocap - KNN F1 (macro): {macro_f1_iemocap}",
            f"Meld - KNN F1 (macro): {macro_f1_meld}",
        ]
        for line in metric_lines:
            print(line)

        try:
            tsne = TSNE(n_components=2, learning_rate='auto', init='pca', perplexity=5).fit_transform(proj_val)

            sns.scatterplot(x=tsne[:, 0], y=tsne[:, 1], hue=lab_encoder.inverse_transform(list(np.array(targets_val, dtype=int))) , palette='tab10')
            plt.show()

        except Exception as err:
            # The plot is best-effort: report the failure and keep training.
            # (A bare `except` would also swallow KeyboardInterrupt.)
            print(f't-SNE plot skipped: {err!r}')

        print(f'Epoch: {e + 1} - Train Loss: {epoch_loss}')
        e += 1

        # One metric per line (previously concatenated with no separator).
        with open(f"{PATH_TO_SAVE}/metrics_epoch_{e}.txt", "w") as f:
            f.write("\n".join(metric_lines) + "\n")

        checkpoint = {"model": supcon_model.state_dict(),
                  "optimizer": opt.state_dict(),
                  "scaler": scaler.state_dict()}
        torch.save(checkpoint, f'{PATH_TO_SAVE}/pytorch_model_AudioTextCLIP_epoch_{e}.bin')

In [53]:
if pre_train_audio:
    # Rebuild the audio-only contrastive wrapper, restore the checkpoint saved
    # after epoch 50 of audio pre-training, and export only the audio encoder.
    supcon_model = AudioTextContrastive(
        None,
        audio_encoder,
        in_features_text=384,
        in_features_audio=dim_embed, 
        hidden_size=768,
        wide_proj=1024,
        proj_size=128, 
        rate=0.1,
    )
    # The epoch number is hard-coded; that checkpoint file must already exist.
    supcon_model.load_state_dict(torch.load(f'audio_encoder_pre_trained_1_layer/pytorch_model_AudioTextCLIP_epoch_50.bin')['model'])
    # NOTE(review): the file is named "dabest_text_encoder.bin" although it holds
    # the *audio* encoder's weights. The audio-encoder loading cell reads this
    # exact path, so any rename has to be done in both places.
    torch.save(supcon_model.audio_encoder.state_dict(), f'audio_encoder_pre_trained_1_layer/dabest_text_encoder.bin')

In [54]:
# Audio encoder hyper-parameters, followed by construction of the encoder and
# loading of its pre-trained weights.
# NOTE(review): the pre-training cells above already reference `dim_embed`,
# `MAX_LEN` and `audio_encoder`; presumably they are also defined earlier in
# the notebook (outside this view) — confirm the notebook survives a
# Restart & Run All with the pre-training flags enabled.
dim_embed = 768
N_VECTORS = 512
MAX_LEN = 256

audio_encoder = AudioEncoderMFCCHU(
    N_VECTORS, 
    emb_dim=dim_embed, 
    n_layers=1, 
    max_length=MAX_LEN, 
    nheads=12,
    dropout=0.1
)
# Despite its name, this file is where the audio pre-training export cell
# stores the audio encoder's state_dict.
audio_encoder.load_state_dict(torch.load(f'audio_encoder_pre_trained_1_layer/dabest_text_encoder.bin'))

<All keys matched successfully>

In [55]:
# Build the text encoder with the same extra special tokens used during text
# pre-training and load the exported pre-trained weights.
# NOTE(review): max_len is 128 here, while the text pre-training cell passes
# max_len=MAX_LEN — confirm the two are meant to differ.
MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
text_encoder = TextEncoder(MODEL_NAME, max_len=128, extra_tokens=['[NAME]', '[RELIGION]', '[LAUGHTER]', '[BFR]', '[AFT]'])
text_encoder.load_state_dict(torch.load(f'text_encoder_pre_trained_{MODEL_NAME}/dabest_text_encoder.bin'))

<All keys matched successfully>

In [56]:
# Which parameters are trained in the multimodal stage. Entries are applied in
# order, so the later, narrower entries override the blanket freeze of the
# text encoder: only its layer 11 and pooler stay trainable, while the audio
# encoder is trained in full.
trainable_plan = [
    (text_encoder, False),
    (text_encoder.encoder.encoder.layer[11], True),
    (text_encoder.encoder.pooler, True),
    (audio_encoder, True),
]
for module, trainable in trainable_plan:
    for param in module.parameters():
        param.requires_grad = trainable

In [None]:
train_ds = torch.utils.data.TensorDataset(torch.Tensor(list(range(len(df_train)))))
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=1024, shuffle=True)

test_ds = torch.utils.data.TensorDataset(torch.Tensor(list(range(len(test_df_erc)))))
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=1024, shuffle=False)

PATH_TO_SAVE = 'ESTAMOS_PERTO_AMIGO_ESTOU_AQUI_4_freezed_11'
!mkdir -p {PATH_TO_SAVE}
supcon_model = AudioTextContrastive(
    text_encoder,
    audio_encoder,
    in_features_text=768,
    in_features_audio=dim_embed, 
    hidden_size=768,
    wide_proj=1024,
    proj_size=128, 
    rate=0.1,
)

# Grid search best temperatures
# Try to only fine tune on evaluation datasets
#supcon_model.load_state_dict(torch.load(f'ESTAMOS_PERTO_AMIGO_ESTOU_AQUI_4_freezed_4_layer/pytorch_model_AudioTextCLIP_epoch_9.bin')['model'])

supcon_model.to(0)

scaler = torch.cuda.amp.GradScaler()

step = 0
e = 0
patience = 9999
early_stop_flag = 0
old_f1 = -float('inf')

param_optimizer = list(supcon_model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']
optimizer_grouped_parameters = [{
    'params':
    [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
    'weight_decay_rate':
    0.1
}, {
    'params':
    [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
    'weight_decay_rate':
    0.0
}]

scheduler_epochs = 5
opt = torch.optim.AdamW(optimizer_grouped_parameters, lr=2e-5, betas=(0.9, 0.98), eps=1e-8)
scheduler = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.5, end_factor=0.9, total_iters=10, last_epoch=- 1, verbose=False)
#scheduler = Scheduler(opt, 768, 600)

epochs = 9999

# Joint audio/text contrastive training. Every batch of df_train is split into
# text-only rows, audio-only rows and paired rows. Each epoch: train, fit a
# KNN on the training projections, score the ERC test set (fused and
# audio-only), plot a t-SNE, then write metrics and a checkpoint.
while e < epochs:
    supcon_model.train()
    epoch_loss = 0.0
    n_steps = 0  # batches that actually contributed to epoch_loss
    proj_val = []
    targets_val = []

    proj_train = []
    targets_train = []

    for i, batch_indices in enumerate(tqdm(train_loader, total=len(train_loader))):
        # The last batch of every epoch is skipped.
        if i == len(train_loader)-1:
            continue
        batch = df_train.iloc[batch_indices[0]]

        # Text-only rows (no audio path); every row is label-encoded.
        only_text = batch[batch["path"].isna()]
        sentences = only_text["text"].tolist()
        y_text = torch.Tensor(lab_encoder.transform(only_text["label"]))
        y_text_senti = torch.Tensor(lab_encoder_senti.transform(only_text["sentiment_label"]))

        # Audio-only rows (no text); every row is label-encoded.
        only_audio = batch[batch["text"].isna()]
        audio_paths = only_audio["path"].tolist()
        y_audio = torch.Tensor(lab_encoder.transform(only_audio["label"]))
        y_audio_senti = torch.Tensor(lab_encoder_senti.transform(only_audio["sentiment_label"]))

        # Paired rows (text and audio). Only some carry an emotion label; after
        # reset_index the index labels double as row positions.
        mult = batch[batch["text"].notna()]
        mult = mult[mult["path"].notna()].reset_index(drop=True)
        mult_not_na_idx = mult[mult["label"].notna()].index

        y_mult = torch.Tensor(lab_encoder.transform(mult.iloc[mult_not_na_idx]["label"]))
        y_mult_senti = torch.Tensor(lab_encoder_senti.transform(mult["sentiment_label"]))

        audio_path_mult = [str(path) for path in mult["path"]]
        sentences_mult = [str(text) for text in mult["text"]]

        # Target order must match the concatenation order used below:
        # text-only, audio-only, paired.
        target = torch.cat([y_text, y_audio, y_mult]).long().cuda()
        target_senti = torch.cat([y_text_senti, y_audio_senti, y_mult_senti]).long().cuda()

        # Model input is [text, audio, paired]; an empty modality is passed as
        # None and its tokenizer is not called at all.
        x = [None, None, None]
        if len(sentences) > 0:
            x[0] = sentences
        if len(audio_paths) > 0:
            mfccs, att = audio_tokenizer.batch_tokenize(audio_paths)
            x[1] = {
                "features": mfccs.float().to(0),
                "attn_masks": att.float().to(0),
            }
        if len(sentences_mult) > 0:
            mfccs_mult, att_mult = audio_tokenizer.batch_tokenize(audio_path_mult)
            x[2] = {'sentences': sentences_mult,
                    'audio_input': {"features": mfccs_mult.float().to(0), "attn_masks": att_mult.float().to(0)}}

        if x[0] is None and x[1] is None and x[2] is None:
            raise Exception("Nothing to work :()")

        # Defined on every path so the cleanup below never hits an unbound name.
        x_mult = None

        with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16), torch.backends.cuda.sdp_kernel(enable_flash=False):

            out = supcon_model(x)

            # Pieces of the three tensors fed to the losses / KNN, appended in
            # the same text -> audio -> paired order as the targets:
            #   parts      : every row              (sentiment loss)
            #   parts_lab  : emotion-labelled rows  (emotion loss)
            #   parts_wide : wide projections of the labelled rows (KNN fit)
            parts, parts_lab, parts_wide = [], [], []

            if x[0] is not None:
                parts.append(out["x_text"])
                parts_lab.append(out["x_text"])
                parts_wide.append(out["x_text_wide"])

            if x[1] is not None:
                parts.append(out["x_audio"])
                parts_lab.append(out["x_audio"])
                parts_wide.append(out["x_audio_wide"])

            if x[2] is not None:
                x_mult_text = out["x_mult_text"]
                x_mult_audio = out["x_mult_audio"]
                x_mult_text_norm = F.normalize(x_mult_text, dim=-1)
                x_mult_audio_norm = F.normalize(x_mult_audio, dim=-1)

                # Detached embedding norms (audio rows first, then text rows),
                # passed as weights to the cross-modal loss.
                weights = torch.cat([torch.norm(x_mult_audio, dim=-1), torch.norm(x_mult_text, dim=-1)], dim=0).detach()

                # Modality augmentation: per paired row pick the fused
                # embedding (p=0.8), text only (p=0.1) or audio only (p=0.1).
                augs = random.choices(
                    population=[0, 1, 2],
                    weights=[0.8, 0.1, 0.1],
                    k=len(x_mult_text)
                )

                candidates = torch.stack([F.normalize(x_mult_text + x_mult_audio, dim=-1),
                                          x_mult_text_norm,
                                          x_mult_audio_norm], dim=1)
                x_mult = candidates[list(range(len(augs))), augs, :]

                x_mult_wide = F.normalize(out["x_mult_text_wide"] + out["x_mult_audio_wide"], dim=-1)

                parts.append(x_mult)
                parts_lab.append(x_mult[mult_not_na_idx])
                parts_wide.append(x_mult_wide[mult_not_na_idx])

            # NOTE(review): two paths of the original added .unsqueeze(dim=1) to
            # out_x while every other path (and the single-modality stages)
            # passes 2-D input to sup_contrastive_loss; out_x is now 2-D on all
            # paths — confirm against the loss implementation.
            out_x = torch.cat(parts, dim=0)
            # Previously left unassigned when the batch had no paired rows.
            out_x_lab = torch.cat(parts_lab, dim=0)
            out_x_wide = torch.cat(parts_wide, dim=0)

            emotion_loss = sup_contrastive_loss(out_x_lab, target, temperature=0.1)
            senti_loss = sup_contrastive_loss(out_x, target_senti, temperature=0.1)

            if x[2] is not None:
                loss = 0.8 * (0.9 * emotion_loss + 0.1 * senti_loss) \
                        + 0.2 * unsupervised_contrastive_loss(x_mult_text_norm, x_mult_audio_norm, temperature=0.1, weights=weights)
            else:
                loss = 0.8 * emotion_loss + 0.2 * senti_loss

        scaler.scale(loss).backward()
        scaler.unscale_(opt)

        #torch.nn.utils.clip_grad_norm_(supcon_model.parameters(), 30.0)
        scaler.step(opt)
        scaler.update()

        opt.zero_grad(set_to_none=True)

        epoch_loss += loss.item()
        n_steps += 1
        proj_train.append(np.array(out_x_wide.detach().cpu()))
        targets_train.append(np.array(target.cpu()))

        del out, out_x, out_x_lab, out_x_wide, x_mult
        gc.collect()
        torch.cuda.empty_cache()

    scheduler.step()
    proj_train = np.concatenate(proj_train, axis=0)
    targets_train = np.concatenate(targets_train, axis=0)

    clf = FaissKNeighbors(k=128)
    clf.fit(proj_train, np.array(targets_train, dtype=int))

    # Average over the batches that were really trained on (the skipped last
    # batch used to be counted in the denominator).
    epoch_loss = epoch_loss / max(n_steps, 1)
    # NOTE(review): eval() stays disabled as in the original, so the model is
    # still in train mode while scoring — confirm this is intended.
    #supcon_model.eval()
    preds = []
    css = 0.0
    wide_audio = []

    for i, batch_indices in enumerate(tqdm(test_loader, total=len(test_loader))):
        with torch.no_grad():

            multimodal_batch = test_df_erc.iloc[batch_indices[0]]

            audio_path_mult = [str(path) for path in multimodal_batch["path"]]
            mfccs_mult, att_mult = audio_tokenizer.batch_tokenize(audio_path_mult)

            sentences_mult = [str(text) for text in multimodal_batch["text"]]

            multimodal = {'sentences': sentences_mult,
                          'audio_input': {"features": mfccs_mult.float().to(0), "attn_masks": att_mult.float().to(0)}}

            target = torch.Tensor(lab_encoder.transform(list(multimodal_batch["label"])))

            x = [None, None, multimodal]
            with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16), torch.backends.cuda.sdp_kernel(enable_flash=False):
                out = supcon_model(x)

            # Fused wide projection, as used for the paired rows in training.
            out_x_wide = F.normalize(out["x_mult_text_wide"] + out["x_mult_audio_wide"], dim=-1)

            # Per-row agreement between the text and audio wide projections.
            cs = F.cosine_similarity(F.normalize(out["x_mult_text_wide"], dim=-1), F.normalize(out["x_mult_audio_wide"], dim=-1))

            wide = np.array(out_x_wide.cpu())
            wide_audio.append(np.array(F.normalize(out["x_mult_audio_wide"], dim=-1).cpu()))
            pred = clf.predict(wide)
            preds.append(pred)

            assert len(wide) == len(pred)

            proj_val.append(wide)
            targets_val.append(np.array(target.cpu()))
            css += np.sum(np.array(cs.cpu()))
            del out_x_wide
            gc.collect()
            torch.cuda.empty_cache()

    proj_val = np.concatenate(proj_val, axis=0)
    wide_audio = np.concatenate(wide_audio, axis=0)
    targets_val = np.concatenate(targets_val, axis=0)

    preds = np.array(np.concatenate(preds, axis=0))
    # Audio-only predictions, computed once (previously once per corpus).
    preds_audio = clf.predict(wide_audio)

    css = css / len(test_df_erc)

    general_f1 = f1_score(targets_val, preds, average='weighted')
    general_acc = accuracy_score(targets_val, preds)

    print(f'Cosine Similarity between mods: {css}')

    # Row *positions* per corpus. preds/targets_val are plain arrays in
    # test_df_erc row order, so positions (not index labels) must be used.
    is_meld = (test_df_erc["source"] == "meld").to_numpy()
    meld_idx = np.flatnonzero(is_meld)
    iemocap_idx = np.flatnonzero(~is_meld)

    general_f1_iemocap = f1_score(targets_val[iemocap_idx], preds[iemocap_idx], average='weighted')
    general_f1_iemocap_audio = f1_score(targets_val[iemocap_idx], preds_audio[iemocap_idx], average='weighted')
    general_acc_iemocap = accuracy_score(targets_val[iemocap_idx], preds[iemocap_idx])

    general_f1_meld = f1_score(targets_val[meld_idx], preds[meld_idx], average='weighted')
    general_f1_meld_audio = f1_score(targets_val[meld_idx], preds_audio[meld_idx], average='weighted')
    general_acc_meld = accuracy_score(targets_val[meld_idx], preds[meld_idx])

    macro_f1_iemocap = f1_score(targets_val[iemocap_idx], preds[iemocap_idx], average='macro')
    macro_f1_meld = f1_score(targets_val[meld_idx], preds[meld_idx], average='macro')

    # Built once so the console and the metrics file always agree.
    metric_lines = [
        f'General - KNN F1: {general_f1} Acc: {general_acc}',
        f'Iemocap - KNN F1: {general_f1_iemocap} Acc: {general_acc_iemocap}',
        f'Iemocap - KNN F1 - Only Audio: {general_f1_iemocap_audio}',
        f'Meld - KNN F1: {general_f1_meld} Acc: {general_acc_meld}',
        f'Meld - KNN F1 - Only Audio: {general_f1_meld_audio}',
        f"Iemocap - KNN F1 (macro): {macro_f1_iemocap}",
        f"Meld - KNN F1 (macro): {macro_f1_meld}",
    ]
    for line in metric_lines:
        print(line)

    try:
        tsne = TSNE(n_components=2, learning_rate='auto', init='pca', perplexity=5).fit_transform(proj_val)

        sns.scatterplot(x=tsne[:, 0], y=tsne[:, 1], hue=lab_encoder.inverse_transform(list(np.array(targets_val, dtype=int))) , palette='tab10')
        plt.show()

    except Exception as err:
        # The plot is best-effort: report the failure and keep training.
        # (A bare `except` would also swallow KeyboardInterrupt.)
        print(f't-SNE plot skipped: {err!r}')

    print(f'Epoch: {e + 1} - Train Loss: {epoch_loss}')
    e += 1

    #if e == scheduler_epochs: # Unfreeze text encoder
    #    for i, (name, param) in enumerate(list(supcon_model.text_encoder.named_parameters())):
    #        param.requires_grad = True

    # One metric per line (previously concatenated with no separator, and the
    # Iemocap macro line was written twice).
    with open(f"{PATH_TO_SAVE}/metrics_epoch_{e}.txt", "w") as f:
        f.write("\n".join(metric_lines) + "\n")

    checkpoint = {"model": supcon_model.state_dict(),
              "optimizer": opt.state_dict(),
              "scaler": scaler.state_dict()}
    torch.save(checkpoint, f'{PATH_TO_SAVE}/pytorch_model_AudioTextCLIP_epoch_{e}.bin')

 27%|███████████▌                               | 24/89 [01:00<02:34,  2.38s/it]

In [None]:
"""
Cosine Similarity between mods: 0.6148036817116292
General - KNN F1: 0.652949920693673 Acc: 0.6535964684497533
Iemocap - KNN F1: 0.7621117618450867 Acc: 0.7558420628525383
Iemocap - KNN F1 - Only Audio: 0.48286032752145575
Meld - KNN F1: 0.6025222678926638 Acc: 0.6049808429118774
Iemocap - KNN F1 (macro): 0.6085758686292261
Meld - KNN F1 (macro): 0.41899567212117655

Epoch: 31 - Train Loss: 6.175596459077136
"""

## Eval

In [None]:
# NOTE(review): `ss` is not assigned anywhere in the visible part of the
# notebook, so this cell raises NameError. Presumably a deliberate stopper that
# keeps "Run All" from continuing into the Eval cells — confirm, and consider
# an explicit, documented stop instead.
ss

In [None]:
#import pickle
#pickle.dump(kmeans, open("./transformer_1_layer_repetindo/kmeans_200_clusters_curr.pkl", 'wb'))

In [None]:
gc.collect()

In [None]:
#PATH_TO_SAVE = 'ESTAMOS_PERTO_AMIGO_ESTOU_AQUI_4_freezed_5_layer_pivoting_to_speech_training'

In [None]:
#torch.load(f'pre_test_final_2/pytorch_model_AudioTextCLIP_epoch_35.bin')['model']
#torch.load(f'{PATH_TO_SAVE}/pytorch_model_AudioTextCLIP_epoch_1.bin')['model']

In [None]:
# Directory (relative to the notebook's working directory) holding the checkpoints of the run to evaluate.
PATH_TO_SAVE = 'ESTAMOS_PERTO_AMIGO_ESTOU_AQUI_4_freezed_11'

# Build a fresh AudioTextContrastive wrapper around the already-created encoders and move it to the GPU.
supcon_model = AudioTextContrastive(
    text_encoder,
    audio_encoder,
    in_features_text=768,
    in_features_audio=dim_embed, 
    hidden_size=768,
    wide_proj=1024,
    proj_size=128, 
    rate=0.2,
).cuda()
# Restore the weights stored under the "model" key of the epoch-23 checkpoint file.
# NOTE(review): the epoch number is hard-coded here (and differs from the epoch-16 load further down).
supcon_model.load_state_dict(torch.load(f'{PATH_TO_SAVE}/pytorch_model_AudioTextCLIP_epoch_23.bin')['model'])

In [None]:
def get_n_params(model):
    """Return the total number of scalar entries over all parameters of ``model``.

    Every tensor yielded by ``model.parameters()`` contributes its element
    count; a model without parameters yields 0.
    """
    return sum(param.numel() for param in model.parameters())

In [None]:
get_n_params(supcon_model)

In [None]:
supcon_model.eval()

In [None]:
import pickle

In [None]:
# NOTE(review): the model is put in training mode for this inference-style probe and the cell has
# no torch.no_grad() — confirm this is intended.
supcon_model.train()
# Text-only forward pass: the audio and multimodal slots of the input list are None.
test = supcon_model([["I Hate you, i believe you are shit!", "You are my best friend, love you!"],None, None])
# Cosine similarity (dot product of normalized vectors) between the two sentences' "x_text_wide" rows.
torch.dot(F.normalize(test["x_text_wide"][0, :], dim=-1), F.normalize(test["x_text_wide"][1, :], dim=-1))

In [None]:
# Text-only probe with two positive sentences; the model mode is whatever earlier cells left it in.
test = supcon_model([["The best man ever, keep the good work!", "you are my best friend, love you!"],None, None])
# Cosine similarity between the two "x_text_wide" rows.
torch.dot(F.normalize(test["x_text_wide"][0, :], dim=-1), F.normalize(test["x_text_wide"][1, :], dim=-1))

In [None]:
# Text-only probe with two hostile sentences; the model mode is whatever earlier cells left it in.
test = supcon_model([["I Hate you, i believe you are shit!", "Fuck you, you should not be alive"],None, None])
# Cosine similarity between the two "x_text_wide" rows.
torch.dot(F.normalize(test["x_text_wide"][0, :], dim=-1), F.normalize(test["x_text_wide"][1, :], dim=-1))

In [None]:
# Text-only probe pairing an affectionate sentence with a hostile one.
test = supcon_model([["I love you, mate!", "Fuck you, you should not be alive"],None, None])
# Cosine similarity between the two "x_text_wide" rows.
torch.dot(F.normalize(test["x_text_wide"][0, :], dim=-1), F.normalize(test["x_text_wide"][1, :], dim=-1))

In [None]:
# NOTE(review): training mode is enabled for an inference probe — confirm this is intended.
supcon_model.train()
with torch.no_grad():
    # Tokenize a single audio file; the two returned values are passed on as features and attention masks.
    m, a = audio_tokenizer.batch_tokenize(["./audio/audio_emo/tess.woman.sad.100.wav"])
    audio_input = {
        "features": m.to(0),  # device index 0
        "attn_masks": a.to(0),
    }
    # One sentence in the text slot, one clip in the audio slot, no multimodal pairs.
    test = supcon_model([["I am very sad"],audio_input, None])
    # Cosine similarity between the text and audio "wide" embeddings.
    print(torch.dot(F.normalize(test["x_text_wide"][0, :], dim=-1), F.normalize(test["x_audio_wide"][0, :], dim=-1)))

In [None]:
torch.norm(test["x_text"][0, :])

In [None]:
torch.norm(test["x_audio"][0, :])

In [None]:
supcon_model.train()
# Only the top-level module's flag is cleared here; submodules keep the mode set by train() above.
# NOTE(review): presumably the model's forward branches on self.training — confirm.
supcon_model.training = False
with torch.no_grad():
    m, a = audio_tokenizer.batch_tokenize(["./audio/audio_emo/tess.woman.sad.101.wav"])
    audio_input = {
        "features": m.to(0),
        "attn_masks": a.to(0),
    }
    test = supcon_model([["I love my girlfriend, but she died"],audio_input, None])
    # Raw dot product of "x_text" and "x_audio" (no normalization is applied in this cell).
    print(torch.dot(test["x_text"][0, :], test["x_audio"][0, :]))

In [None]:
___

In [None]:
supcon_model.train()
# Clear the training flag on the top-level module only, then switch every Dropout layer to eval mode;
# all other submodules stay in the mode set by train() above.
supcon_model.training = False
dropout_modules = [module for module in supcon_model.modules() if isinstance(module,torch.nn.Dropout)]
[module.eval() for module in dropout_modules]

with torch.no_grad():
    m, a = audio_tokenizer.batch_tokenize(["./audio/audio_emo/tess.woman.sad.279.wav"])
    audio_input = {
        "features": m.to(0),
        "attn_masks": a.to(0),
    }
    test = supcon_model([["I am happy"], audio_input, None])
    # Cosine similarity between the text and audio "wide" embeddings.
    print(torch.dot(F.normalize(test["x_text_wide"][0, :], dim=-1), F.normalize(test["x_audio_wide"][0, :], dim=-1)))

In [None]:
# Same mode setup as the previous probe: train(), top-level flag cleared, Dropout layers in eval mode.
supcon_model.train()
supcon_model.training = False
dropout_modules = [module for module in supcon_model.modules() if isinstance(module,torch.nn.Dropout)]
[module.eval() for module in dropout_modules]

with torch.no_grad():
    m, a = audio_tokenizer.batch_tokenize(["./audio/audio_emo/tess.woman.sad.279.wav"])
    audio_input = {
        "features": m.to(0),
        "attn_masks": a.to(0),
    }
    test = supcon_model([["My dog was great, but he was also cute! Today he is dead"], audio_input, None])
    # Cosine similarity computed on "x_text"/"x_audio" here, not on the "*_wide" outputs used in the cell above.
    print(torch.dot(F.normalize(test["x_text"][0, :], dim=-1), F.normalize(test["x_audio"][0, :], dim=-1)))

In [None]:
# Mode setup: train(), top-level training flag cleared, Dropout layers switched to eval mode.
supcon_model.train()
supcon_model.training = False
dropout_modules = [module for module in supcon_model.modules() if isinstance(module,torch.nn.Dropout)]
[module.eval() for module in dropout_modules]

with torch.no_grad():
    m, a = audio_tokenizer.batch_tokenize(["./audio/audio_emo/tess.woman.happy.50.wav"])
    audio_input = {
        "features": m.to(0),
        "attn_masks": a.to(0),
    }
    # Two sentences are compared against the same single audio clip.
    test = supcon_model([["I had a discussion with my mother", "I love my mother"],audio_input, None])
    # Raw (un-normalized) dot products of each sentence's wide embedding with the clip's wide embedding.
    print(torch.dot(test["x_text_wide"][0, :], test["x_audio_wide"][0, :]))
    print(torch.dot(test["x_text_wide"][1, :], test["x_audio_wide"][0, :]))

In [None]:
# This cell does not set the model mode itself; it relies on the mode left by previously run cells.
with torch.no_grad():
    m, a = audio_tokenizer.batch_tokenize(["./audio/audio_emo/tess.woman.happy.50.wav"])
    audio_input = {
        "features": m.to(0),
        "attn_masks": a.to(0),
    }
    test = supcon_model([["I just finished my PhD!!", "I finished my PhD, but I dont have a job"],audio_input, None])
    # Cosine similarity of each sentence's wide embedding with the clip's wide embedding.
    print(torch.dot(F.normalize(test["x_text_wide"][0, :], dim=-1), F.normalize(test["x_audio_wide"][0, :], dim=-1)))
    print(torch.dot(F.normalize(test["x_text_wide"][1, :], dim=-1), F.normalize(test["x_audio_wide"][0, :], dim=-1)))

In [None]:
# Mode setup: train(), top-level training flag cleared, Dropout layers switched to eval mode.
supcon_model.train()
supcon_model.training = False
dropout_modules = [module for module in supcon_model.modules() if isinstance(module,torch.nn.Dropout)]
[module.eval() for module in dropout_modules]

with torch.no_grad():
    # Two clips are tokenized together and compared with one sentence.
    m, a = audio_tokenizer.batch_tokenize(["./audio/audio_emo/tess.woman.happy.279.wav", "./audio/audio_emo/tess.woman.sad.59.wav"])
    audio_input = {
        "features": m.to(0),
        "attn_masks": a.to(0),
    }
    test = supcon_model([["I did not pass in the final exam, i will kill myself"], audio_input, None])
    # Cosine similarities, in order: clip 0 vs clip 1, clip 0 vs sentence, clip 1 vs sentence.
    print(torch.dot(F.normalize(test["x_audio_wide"][0, :], dim=-1), F.normalize(test["x_audio_wide"][1, :], dim=-1)))
    print(torch.dot(F.normalize(test["x_audio_wide"][0, :], dim=-1), F.normalize(test["x_text_wide"][0, :], dim=-1)))
    print(torch.dot(F.normalize(test["x_audio_wide"][1, :], dim=-1), F.normalize(test["x_text_wide"][0, :], dim=-1)))

In [None]:
test

In [None]:
# Captured forward-hook outputs, keyed by the name handed to get_activation.
activation = {}


def get_activation(name):
    """Create a hook that prints a module's output and records it in ``activation[name]``.

    The returned callable takes the three positional arguments of a forward
    hook (module, inputs, output). The output object is stored as-is; it is
    not detached.
    """

    def hook(module, inputs, output):
        print(output)
        activation[name] = output

    return hook

In [None]:
get_activation

In [None]:
supcon_model

In [None]:
# Captured forward-hook outputs, keyed by the name handed to get_activation.
# Re-running this cell resets the registry and redefines get_activation.
activation = {}


def get_activation(name):
    """Create a hook that prints a module's output and its shape, then records it.

    The returned callable takes the three positional arguments of a forward
    hook (module, inputs, output). The output must expose ``.shape``; it is
    stored in ``activation[name]`` as-is, without being detached.
    """

    def hook(module, inputs, output):
        print(output)
        print(output.shape)
        activation[name] = output

    return hook

# Rebuild the model, this time passing explicit freeze flags for the two encoders.
supcon_model = AudioTextContrastive(
    text_encoder,
    audio_encoder,
    in_features_text=768,
    in_features_audio=dim_embed, 
    hidden_size=768,
    wide_proj=1024,
    proj_size=128, 
    freeze_text_enc=True,
    freeze_audio_enc=False,
    rate=0.2,
).cuda()

# NOTE(review): this loads the epoch-16 checkpoint, whereas the earlier load cell uses epoch 23.
supcon_model.load_state_dict(torch.load(f'{PATH_TO_SAVE}/pytorch_model_AudioTextCLIP_epoch_16.bin')['model'])

# Capture (and print) the output of the audio projection module on every forward pass.
supcon_model.audio_proj.register_forward_hook(get_activation('audio_proj'))
# NOTE(review): `audio_input` is not defined in this cell; it comes from whichever probe cell ran last.
# The forward pass also runs without torch.no_grad().
output = supcon_model([["I had a discussion with my mother"],audio_input, None])
activation['audio_proj']

In [None]:
m

In [None]:
test["x_audio_wide"][0, :]

In [None]:
test["x_audio"][0, :]

In [None]:
test

In [None]:
print(df_train[df_train["label"] == "sadness"]["path"].tolist())

In [None]:
#supcon_model.load_state_dict(torch.load('./pytorch_model_AudioTextCLIPvFinal_epoch_25_only_meld.bin'))

In [None]:
#supcon_model.audio_encoder.clusterization_model = kmeans

In [None]:
gc.collect()

In [None]:
df_train_f =df_train

In [None]:
#df_dev_audio = pd.concat([df_meld_dev, test_audio], axis=0)

# Param: Select dataset for scoring

In [None]:
# Index labels of the train_df_erc rows whose audio path contains "MELD".
meld_train_idx = train_df_erc[train_df_erc["path"].apply(lambda path: "MELD" in path)].index

In [None]:
# Index labels of the train_df_erc rows whose audio path does not contain "MELD".
iemocap_train_idx = train_df_erc[train_df_erc["path"].apply(lambda path: "MELD" not in path)].index

In [None]:
#train_audio_repeated = pd.concat([df_train_audio, df_train_audio,df_train_audio,df_train_audio,df_train_audio,df_train_audio,df_train_audio, df_train_audio,df_train_audio,df_train_audio,df_train_audio,df_train_audio], axis=0).sample(frac=1).reset_index(drop=True)
#test_audio_repeated = pd.concat([df_dev_audio, df_dev_audio,df_dev_audio,df_dev_audio,df_dev_audio,df_dev_audio], axis=0).sample(frac=1).reset_index(drop=True)
#train_iemocap = train_df_erc.iloc[iemocap_train_idx].reset_index(drop=True)
#train_iemocap = train_df_erc.iloc[meld_train_idx].reset_index(drop=True)
# The datasets hold only row positions (0..len-1); the loops below use each batch of positions
# with DataFrame.iloc to fetch the actual rows.
# NOTE(review): torch.Tensor(...) builds a floating-point tensor, so positions travel as floats.
train_ds = torch.utils.data.TensorDataset(torch.Tensor(list(range(len(train_df_erc)))))
# shuffle=False keeps the extracted embeddings in the same order as the dataframe rows.
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=1024, shuffle=False)

test_ds = torch.utils.data.TensorDataset(torch.Tensor(list(range(len(test_df_erc)))))
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=1024, shuffle=False)

In [None]:
gc.collect()

In [None]:
# Embed every training batch with the loaded model and fit the KNN classifier used for scoring.
#
# Produces (as notebook globals): proj_train / targets_train (numpy arrays), clf (fitted
# FaissKNeighbors) and the empty proj_val / targets_val lists filled by the test loop below.
#
# NOTE(review): the model is put in training mode for feature extraction — confirm this is intended.
supcon_model.train()

proj_val = []
targets_val = []

proj_train = []
targets_train = []

for i, batch_indices in enumerate(tqdm(train_loader, total=len(train_loader))):
    with torch.no_grad():
        batch = train_df_erc.iloc[batch_indices[0]]

        # Rows with no audio path are treated as text-only samples.
        only_text = batch[batch["path"].isna()]
        sentences = only_text["text"].tolist()
        y_text = torch.Tensor(lab_encoder.transform(only_text["label"]))

        # Rows with no text are treated as audio-only samples.
        only_audio = batch[batch["text"].isna()]
        audio_paths = only_audio["path"].tolist()

        # Tokenize audio-only clips only when there are any. Failures stay best-effort (the
        # audio-only slot is dropped), but they are now reported, and KeyboardInterrupt /
        # SystemExit are no longer swallowed as they were by the previous bare `except:`.
        audio_input = None
        if audio_paths:
            try:
                mfccs, att = audio_tokenizer.batch_tokenize(audio_paths)

                audio_input = {
                    "features": mfccs.to(0),
                    "attn_masks": att.to(0),
                }
            except Exception as err:
                print(f"[batch {i}] audio-only tokenization failed; skipping audio-only inputs: {err!r}")

        y_audio = torch.Tensor(lab_encoder.transform(only_audio["label"]))

        # Multimodal samples: rows that have text, path and label all present.
        mult = batch[batch["text"].notna()]
        mult = mult[mult["path"].notna()]
        mult = mult[mult["label"].notna()]
        y_mult = torch.Tensor(lab_encoder.transform(mult["label"]))

        audio_path_mult = [str(t['path']) for _, t in mult.iterrows()]

        mfccs_mult, att_mult = audio_tokenizer.batch_tokenize(audio_path_mult)

        sentences_mult = [str(t['text']) for _, t in mult.iterrows()]

        multimodal = {'sentences': sentences_mult, 
                      'audio_input': {"features": mfccs_mult.to(0), "attn_masks": att_mult.to(0)}}

        # NOTE(review): the target concatenates text-only, audio-only and multimodal labels, but
        # only the multimodal embedding is stored below; the two presumably line up only when
        # every row of the batch is multimodal — confirm.
        target = torch.cat([y_text, y_audio, y_mult])

        x = [sentences, audio_input, multimodal]

        # Empty modalities are passed to the model as None.
        if len(sentences) == 0:
            x[0] = None
        if len(audio_paths) == 0:
            x[1] = None
        if len(sentences_mult) == 0:
            x[2] = None

        with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16) as autocast, torch.backends.cuda.sdp_kernel(enable_flash=False) as disable:

            out = supcon_model(x)

            # Fused representation: normalized sum of the text and audio wide projections.
            x_mult_wide = F.normalize(out["x_mult_text_wide"] + out["x_mult_audio_wide"], dim=-1)
            #x_mult_wide = F.normalize(out["x_mult_audio_wide"], dim=-1)

        proj_train.append(np.array(x_mult_wide.detach().cpu()))
        targets_train.append(np.array(target.cpu()))

        # Release per-batch GPU memory before the next iteration.
        del x_mult_wide
        gc.collect()
        torch.cuda.empty_cache()
proj_train = np.concatenate(proj_train, axis=0)
targets_train = np.concatenate(targets_train, axis=0)

# KNN over the fused training embeddings, with integer class ids as targets.
clf = FaissKNeighbors(k=128)
clf.fit(proj_train, np.array(targets_train, dtype=int))

# Embed the test set, predict each batch with the fitted KNN (`clf`), and accumulate the
# text/audio cosine similarity. Appends to proj_val / targets_val created by the training-embedding code.
preds = []
targets = []
css = 0.0  # running sum of per-sample cosine similarities between the two modalities

for i, batch_indices in enumerate(tqdm(test_loader, total=len(test_loader))):
    with torch.no_grad():

        # Every test row is treated as multimodal: both path and text are read from each row.
        multimodal_batch = test_df_erc.iloc[batch_indices[0]]

        audio_path_mult = [str(t['path']) for _, t in multimodal_batch.iterrows()]
        mfccs_mult, att_mult = audio_tokenizer.batch_tokenize(audio_path_mult)

        sentences_mult = [str(t['text']) for _, t in multimodal_batch.iterrows()]

        multimodal = {'sentences': sentences_mult, 
                      'audio_input': {"features": mfccs_mult.to(0), "attn_masks": att_mult.to(0)}}

        target = torch.Tensor(lab_encoder.transform(list(multimodal_batch["label"])))

        # Only the multimodal slot is populated.
        x = [None, None, multimodal]
        with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16) as autocast, torch.backends.cuda.sdp_kernel(enable_flash=False) as disable:
            out = supcon_model(x)

            # Fused representation: normalized sum of the text and audio wide projections.
            out_x_wide = F.normalize(out["x_mult_text_wide"] + out["x_mult_audio_wide"], dim=-1)
            #out_x_wide = F.normalize(out["x_mult_audio_wide"], dim=-1)

        # Per-sample agreement between the text and audio projections.
        cs = F.cosine_similarity(out["x_mult_text_wide"], out["x_mult_audio_wide"])

        wide = np.array(out_x_wide.cpu())
        pred = clf.predict(wide)
        preds.append(pred)

        # One prediction per embedded sample.
        assert len(wide) == len(pred)

        proj_val.append(wide)
        targets_val.append(np.array(target.cpu()))
        css += np.sum(np.array(cs.cpu()))
        # Release per-batch GPU memory before the next iteration.
        del out_x_wide
        gc.collect()
        torch.cuda.empty_cache()

proj_val = np.concatenate(proj_val, axis=0)
targets_val = np.concatenate(targets_val, axis=0)

preds = np.array(np.concatenate(preds, axis=0))

# Mean cosine similarity over all test rows.
css = css / len(test_df_erc)

# Overall weighted F1 and accuracy of the KNN predictions on the whole test set.
general_f1 = f1_score(targets_val, preds, average='weighted')
general_acc = accuracy_score(targets_val, preds)

print(f'Cosine Similarity between mods: {css}')

# Split the test rows by corpus: "meld" versus everything else (reported as IEMOCAP).
# NOTE(review): these are index labels used below to index numpy arrays positionally; that
# presumably relies on test_df_erc having a default 0..n-1 index — confirm.
meld_idx = test_df_erc[test_df_erc["source"] == "meld"].index
iemocap_idx = test_df_erc[test_df_erc["source"] != "meld"].index

general_f1_iemocap = f1_score(targets_val[iemocap_idx], preds[iemocap_idx], average='weighted')
general_acc_iemocap = accuracy_score(targets_val[iemocap_idx], preds[iemocap_idx])

general_f1_meld = f1_score(targets_val[meld_idx], preds[meld_idx], average='weighted')
general_acc_meld = accuracy_score(targets_val[meld_idx], preds[meld_idx])

print(f'General - KNN F1: {general_f1} Acc: {general_acc}')
print(f'Iemocap - KNN F1: {general_f1_iemocap} Acc: {general_acc_iemocap}')
print(f'Meld - KNN F1: {general_f1_meld} Acc: {general_acc_meld}')

# Macro-averaged F1 per corpus (computed inline, not stored).
print(f"Iemocap - KNN F1 (macro): {f1_score(targets_val[iemocap_idx], preds[iemocap_idx], average='macro')}")
print(f"Meld - KNN F1 (macro): {f1_score(targets_val[meld_idx], preds[meld_idx], average='macro')}")

# 2-D t-SNE of the fused test embeddings, colored by the decoded class label.
# NOTE(review): no random_state is passed, so the layout may differ between runs.
tsne = TSNE(n_components=2, learning_rate='auto', init='pca', perplexity=5).fit_transform(proj_val)

sns.scatterplot(x=tsne[:, 0], y=tsne[:, 1], hue=lab_encoder.inverse_transform(list(np.array(targets_val, dtype=int))) , palette='tab10')
plt.show()

In [None]:
train_iemocap

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
# Refit the KNN on MELD training embeddings only, after standardizing each feature with the
# mean/std of those MELD training embeddings (despite the `_test` suffix in the names).
# NOTE(review): meld_train_idx holds index labels of train_df_erc and is used positionally here;
# a feature with zero std would make the divisions below produce inf/NaN.
mean_test = proj_train[meld_train_idx].mean(axis=0)
std_test = proj_train[meld_train_idx].std(axis=0)
# This overwrites the `clf` fitted on the full training set.
clf = FaissKNeighbors(k=128)
clf.fit((proj_train[meld_train_idx]-mean_test)/std_test, np.array(targets_train[meld_train_idx], dtype=int))

# This overwrites `preds` for the whole test set; only the MELD rows are scored below.
preds = clf.predict((proj_val-mean_test)/std_test)

general_f1_meld = f1_score(targets_val[meld_idx], preds[meld_idx], average='weighted')
general_acc_meld = accuracy_score(targets_val[meld_idx], preds[meld_idx])

from sklearn.metrics import classification_report

print(classification_report(targets_val[meld_idx], preds[meld_idx], digits=4))

In [None]:
len(targets_val[iemocap_idx])

In [None]:
set(lab_encoder.inverse_transform(np.array(targets_val[iemocap_idx], dtype=int)))


In [None]:
# Long-format table of the labels in emotions.json: one row per (id, orig_label),
# pooled over the train/val/test columns, without missing or "undecided" labels.
df_iemocap_orig = (
    pd.read_json("emotions.json")
    .reset_index(drop=False)
    .melt(id_vars=['index'], value_vars=['train', 'val', 'test'])
    .dropna()
    .drop(columns=["variable"])
    .rename(columns={"index": "id", "value": "orig_label"})
    .reset_index(drop=True)
    .loc[lambda frame: frame["orig_label"].notna() & (frame["orig_label"] != "undecided")]
    .reset_index(drop=True)
)
df_iemocap_orig

In [None]:
def cleaning_shit(x):
    """Derive the merge id from an IEMOCAP audio path.

    Args:
        x: Audio path. Expected to start with the hard-coded IEMOCAP
            raw-audio root; the split directory ("val/", "train/", "test/")
            and the ".wav" extension are stripped.

    Returns:
        The path remainder after the root, without split directory and
        extension, or ``None`` when the path mentions "MELD" or is not a
        string (e.g. NaN/None for rows without audio).
    """
    # Missing paths (NaN/None) used to raise TypeError on the membership test below.
    if not isinstance(x, str) or "MELD" in x:
        return None
    for split_dir in ("val/", "train/", "test/"):
        x = x.replace(split_dir, "")
    # The root is removed by length, so the result is only meaningful for paths under this root.
    l = len("/home/vmachado/Documents/multimodal-datasets/IEMOCAP/raw-audios/")
    return x[l:].replace(".wav", "")


In [None]:
#train_df_erc_iemocap = train_df_erc[train_df_erc["path"].apply(lambda x: True if "IEMOCAP" in x else False)]
# Adds an "id" column to train_df_erc in place; cleaning_shit returns None for MELD paths.
train_df_erc["id"] = train_df_erc["path"].apply(cleaning_shit)
# dropna() removes every row with any missing value, which includes the rows whose id is None.
train_df_erc_iemocap = train_df_erc.dropna()
# Attach the original label by id; the inner join keeps only ids present in df_iemocap_orig.
train_df_erc_iemocap = train_df_erc_iemocap.merge(df_iemocap_orig, on="id", how="inner").dropna()
train_df_erc_iemocap

In [None]:
len(iemocap_train_idx)

In [None]:
# Adds an "id" column to test_df_erc in place by cutting a fixed-length prefix and the ".wav" suffix.
# NOTE(review): the slice is applied to every row by length, so it is only meaningful for paths
# that start with this IEMOCAP test prefix; the training frame uses cleaning_shit instead.
test_df_erc["id"] = test_df_erc["path"].apply(lambda x: x[len('/home/vmachado/Documents/multimodal-datasets/IEMOCAP/raw-audios/test/'):].replace(".wav", "")) 

In [None]:
# Keep only the test rows whose source is "iemocap", renumbered from 0.
test_df_erc_iemocap = test_df_erc[test_df_erc["source"] == "iemocap"].reset_index(drop=True)
test_df_erc_iemocap

In [None]:
# Attach the original label by id; the inner join drops rows whose id is absent from df_iemocap_orig.
# NOTE(review): re-running this cell merges again and would add suffixed duplicate label columns.
test_df_erc_iemocap = test_df_erc_iemocap.merge(df_iemocap_orig, on="id", how="inner")
test_df_erc_iemocap

In [None]:
# Label encoder for the original labels, fitted on the training split only.
new_lab = LabelEncoder().fit(train_df_erc_iemocap["orig_label"])

In [None]:
train_df_erc_iemocap["orig_label"].unique()

In [None]:
test_df_erc_iemocap["orig_label"].unique()

In [None]:
train_df_erc_iemocap

In [None]:
test_df_erc_iemocap

In [None]:
# Integer-encode the original labels of both splits with the encoder fitted on the training labels.
# NOTE(review): a test label never seen during fit would make transform fail; the two unique()
# cells above are the manual check for that.
correct_labels_train = new_lab.transform(train_df_erc_iemocap["orig_label"])
correct_labels_test = new_lab.transform(test_df_erc_iemocap["orig_label"])

In [None]:
correct_labels_train

In [None]:
lab_encoder.classes_

In [None]:
test_df_erc[test_df_erc["source"] == "iemocap"]["label"].value_counts()

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
# Linear probe: standardize the embeddings with statistics of the full training set (despite the
# `_test` suffix) and fit a logistic regression on all training embeddings.
# NOTE(review): a feature with zero std would make the divisions below produce inf/NaN.
mean_test = proj_train.mean(axis=0)
std_test = proj_train.std(axis=0)

#clf = MLPClassifier(hidden_layer_sizes=(768,), learning_rate="invscaling", solver="sgd", max_iter=5000, validation_fraction=0.2, nesterovs_momentum=False)
# This overwrites the `clf` fitted by earlier cells.
clf = LogisticRegression()
clf.fit((proj_train-mean_test)/std_test, np.array(targets_train, dtype=int))

# This overwrites `preds` for the whole test set; only the IEMOCAP rows are reported below.
preds = clf.predict((proj_val-mean_test)/std_test)

from sklearn.metrics import classification_report

print(classification_report(targets_val[iemocap_idx], preds[iemocap_idx], digits=4))

In [None]:
# Recompute the IEMOCAP weighted F1 / accuracy from whichever `preds` the last classifier cell produced.
general_f1_iemocap = f1_score(targets_val[iemocap_idx], preds[iemocap_idx], average='weighted')
general_acc_iemocap = accuracy_score(targets_val[iemocap_idx], preds[iemocap_idx])



In [None]:
general_f1_iemocap

In [None]:
general_f1_meld

In [None]:
from sklearn.metrics import classification_report

# IEMOCAP report with every predicted class id 1 replaced by class id 6 before scoring.
# NOTE(review): which emotions ids 1 and 6 stand for depends on lab_encoder.classes_ — document the mapping.
print(classification_report(targets_val[iemocap_idx], list(map(lambda x: x if x != 1 else 6, preds[iemocap_idx])), digits=4))

## 