In [None]:
import numpy as np
import pandas as pd  

from sklearn.metrics import confusion_matrix

import os
from torch.utils.data import DataLoader, Dataset

import torch,torchaudio

import torch.nn as nn
from torch import optim
import torch.nn.functional as F
#from torchmetrics import Accuracy
from einops import rearrange
from pathlib import Path
import matplotlib.pyplot as plt
from torchinfo import summary
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup
from speechbrain.lobes.models.ECAPA_TDNN import AttentiveStatisticsPooling

In [None]:
import os
import random

import numpy as np

# Pin every random number generator the notebook touches to the same seed.
os.environ['PYTHONHASHSEED'] = '42'
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(42)

# Ask cuDNN for deterministic kernels and turn off its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
class Config:
    """Per-dataset settings looked up by dataset name.

    Only the 'Emodb' dataset has settings; any other name leaves the
    configuration as None, so every property lookup reports "not found".
    """

    def __init__(self, dataset_name):
        self.dataset_name = dataset_name
        self._config = self._select_config()

    def _select_config(self):
        """Return the settings dict for the dataset, or None if it is unknown."""
        if self.dataset_name != 'Emodb':
            return None
        return {
            'batch_size': 16,
            'dataset': self.dataset_name,
            'session': 'emodb_session',
            'class_label_name': 'emodb_emotion',
            'dataset_path': 'niloufarr/Dataset/emodb-dataset:v0',
            'csv_path': 'niloufarr/Dataset/pairs-emodb_imocap:v1',
            "sessionList": [3, 8, 9, 10, 11, 12, 13, 14, 15, 16],
            "x_name": 'emodb',
            # Despite its name this value is used by Classifier as the
            # feature width passed to the ASP and Maxout layers.
            "loss": 768,
        }

    def get_dataset(self):
        """Split the dataset name on underscores and return the parts."""
        return self.dataset_name.split("_")

    def get_property(self, property_name):
        """Return the stored value, or print a notice and return None if absent."""
        settings = self._config
        if settings and property_name in settings:
            return settings[property_name]
        print(f"{property_name} not found in {self.dataset_name} configuration!!!")
        return None


# Shared configuration object read by the rest of the notebook.
instance_config = Config("Emodb")

In [None]:
import wandb

# Authenticate with Weights & Biases. The API key is read from the environment
# rather than written into the notebook; when WANDB_API_KEY is unset, None is
# passed and wandb uses its own credential lookup.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

# Grid sweep: the only parameter with several values is 'fold', so one run is
# launched per fold.
sweep_config = {
    'method': 'grid'
}

# Each fold is the list of session ids used for training in that run
# (see make_train_loader / get_sessionList).
# NOTE(review): sessions 10 and 13 appear in two folds — confirm this is intended.
parameters_dict = {
    'fold': {
        'values': [[3, 8, 9, 10], [10, 11, 12, 13], [13, 14, 15, 16]]
    }
}

# Look the optional model path up once: get_property prints a "not found"
# notice on every miss, so a second lookup would print it twice.
_model_path = instance_config.get_property("modelPath")

# NOTE(review): other cells read config.proj_drop, config.initialize_weight and
# config.input_size, none of which are defined below — confirm they are supplied
# by the sweep that is being resumed.
parameters_dict.update({
    'project_name': {
        'value': "Finetune DistilHuBERT"
    },
    'epochs': {
        'value': 35
    },
    'batch_size': {
        'value': instance_config.get_property('batch_size')
    },
    'learning_rate': {
        'value': 0.0003  # 5e-3
    },
    'weight_decay': {
        'value': 9e-3
    },
    'gpu': {
        'value': 'P100'
    },
    'loss': {
        'value': 'focal'
    },
    'optimizer': {
        'value': 'AdamW'
    },
    'class_number': {
        'value': 4
    },
    'class_names': {
        'value': ['neutral', 'happy', 'angry', 'sad']
    },
    'valid_labels': {
        'value': ['0', '1', '2', '3']
    },
    'dataset': {
        'value': instance_config.get_property("dataset")
    },
    'dataset_path': {
        'value': instance_config.get_property("dataset_path")
    },
    'csv_path': {
        'value': instance_config.get_property("csv_path")
    },
    'model_path': {
        'value': _model_path if _model_path is not None else '-'
    },
    'asp_att_dim': {
        'value': 768
    },
    'fc': {
        'value': 256
    },
    'maxout_num_linear_function': {
        'value': 3
    },
    'freeze': {
        'value': True
    },
    # Parameter-name prefixes whose weights get_classifier freezes.
    'freezeList': {
        'value': [
            # 'model.model.model.feature_extractor.conv_layers',
            'model.model.model.encoder.layers.0',
            'model.model.model.encoder.layers.1',
            'model.model.model.post_extract_proj',
            'model.model.model.encoder.pos_conv.0',
        ]
    },
    'description': {
        'value': 'whole distilhebert asp maxout layerNorm dropout'
    },
    'description2': {
        'value': f"finetune with {len(parameters_dict['fold']['values'][0])} speakers  update cnn freeze transformers"
    }
})
sweep_config['parameters'] = parameters_dict


In [None]:
# Reuse the sweep id typed by the user; an empty answer registers a new sweep
# from sweep_config instead.
sweep_id = input('What is sweep_id? (leave out if this is first sweep) ') or wandb.sweep(
    sweep_config, project=parameters_dict['project_name']['value'])

# When True, make_model_pipeline uploads the saved checkpoint after training.
saveModel = True

In [None]:
def get_waveform(file_path, sample_rate=16000):
    """Load an audio file, resample it to `sample_rate` and down-mix to one channel.

    Args:
        file_path: path handed to `torchaudio.load`.
        sample_rate: target sampling rate for the returned waveform.

    Returns:
        The resampled waveform; when the file has more than one channel the
        channels are averaged into a single one (channel dimension kept).
    """
    audio, native_rate = torchaudio.load(file_path)
    audio = torchaudio.transforms.Resample(native_rate, sample_rate)(audio)
    has_several_channels = audio.shape[0] > 1
    return audio.mean(dim=0, keepdim=True) if has_several_channels else audio
def tile(waveform, expected_time):
    """Loop a waveform along dim 1 and cut it to exactly `expected_time` samples.

    The signal is repeated one more time than the whole number of copies that
    fit into `expected_time`, so the result always covers the requested length
    before it is truncated. Longer inputs are simply truncated.
    """
    source_len = waveform.shape[1]
    n_copies = expected_time // source_len + 1
    looped = waveform.repeat(1, n_copies)
    return looped[:, :expected_time]

class CaseInSensitiveDict(dict):
    """dict whose item lookup (`d[key]`) ignores the letter case of string keys.

    Only `__getitem__` is overridden; other dict operations keep their normal,
    case-sensitive behaviour.
    """

    def __getitem__(self, key):
        not_found = object()
        # Linear scan: the first stored key equal to `key` once both are
        # lower-cased wins.
        stored_key = next(
            (candidate for candidate in self.keys() if key.lower() == candidate.lower()),
            not_found,
        )
        if stored_key is not_found:
            raise KeyError(key)
        return super().__getitem__(stored_key)

In [None]:
def encode_labels(labels):
    """Map emotion names to integer class ids, ignoring letter case.

    'angry' and 'anger' both map to class 2. A name outside the table raises
    KeyError.
    """
    name_to_id = CaseInSensitiveDict(neutral=0, happy=1, angry=2, sad=3, anger=2)

    encoded = []
    for emotion_name in labels:
        encoded.append(name_to_id[emotion_name])
    return encoded

def decode_labels(y):
    """Translate integer class ids (0-3) back to their emotion names.

    An id outside the table raises KeyError.
    """
    id_to_name = dict(enumerate(('neutral', 'happy', 'angry', 'sad')))
    return list(map(id_to_name.__getitem__, y))

In [None]:
class ClassifierDataset(Dataset):
    """Dataset of fixed-length speech clips with integer emotion labels.

    Rows are read from a CSV file and kept only when their session column is in
    `sessionList`. Each item is the waveform loaded by `get_waveform`, looped or
    cut by `tile` to a fixed number of samples, plus its label as a long tensor.

    Args:
        csv_file_name: path of the CSV file listing the recordings.
        path_dataset: directory that the audio paths in the CSV are relative to.
        sessionList: session values to keep.
        x_name: CSV column holding the audio file path.
        sessionListName: CSV column holding the session value.
        class_label_name: CSV column holding the emotion name.
        sample_rate: sampling rate requested from `get_waveform` (default 16000).
        clip_seconds: length of every returned clip in seconds (default 7).
    """

    def __init__(self, csv_file_name, path_dataset, sessionList, x_name, sessionListName,
                 class_label_name, sample_rate=16000, clip_seconds=7):
        df = pd.read_csv(csv_file_name)
        self.path_dataset = path_dataset
        print("sessionList", sessionList)
        self.temp_df = df[df[sessionListName].isin(sessionList)].reset_index(drop=True)
        print("self.temp_df size", self.temp_df.shape)
        self.speech_path = self.temp_df[x_name].values
        self.label = encode_labels(self.temp_df[class_label_name].values)
        self.sample_rate = sample_rate
        # Number of samples every clip is looped/cut to; the defaults give 16000 * 7.
        self.expected_samples = int(sample_rate * clip_seconds)

    def __getitem__(self, index):
        """Return `(waveform, label)` for the row at `index`."""
        audio_path = os.path.join(self.path_dataset, self.speech_path[index])
        speech = tile(get_waveform(audio_path, sample_rate=self.sample_rate),
                      self.expected_samples)
        return speech, torch.tensor(self.label[index], dtype=torch.long)

    def __len__(self):
        """Number of rows kept after the session filter.

        Deliberately silent: the earlier debug print here ran on every `len()`
        call and flooded the cell output.
        """
        return self.temp_df.shape[0]

In [None]:
class ASP(nn.Module):
    """Attentive statistics pooling over a sequence of frame features.

    Args:
        num_emdb: number of feature channels per frame.
        attn_dim: attention channels for the pooling layer; falls back to
            `num_emdb` when not given (or falsy).
        split: when True, only the first `num_emdb` channels of the pooled
            vector are returned and the remainder is discarded.
    """

    def __init__(self, num_emdb, attn_dim=None, split=True):
        super().__init__()
        self.channels = num_emdb
        self.split = split
        self.asp = AttentiveStatisticsPooling(
            channels=self.channels,
            attention_channels=attn_dim or num_emdb,
            global_context=True,
        )

    def forward(self, x: torch.Tensor):
        # The pooling layer is fed channels-first input.
        pooled = self.asp(rearrange(x, "b l c -> b c l"))
        # Drop only dim 2 (a blanket squeeze was left commented out in the original).
        pooled = pooled.squeeze(2)
        if not self.split:
            return pooled
        first_part, _ = torch.split(pooled, self.channels, dim=1)
        return first_part

class Maxout(nn.Module):
    """Maxout layer: several linear maps per output unit, keeping the largest.

    A single `nn.Linear` produces `out_features * num_linear_function` values;
    they are grouped into `out_features` groups of `num_linear_function`
    consecutive values and each group is reduced with a max.
    """

    def __init__(self, in_features, out_features, num_linear_function=2):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_linear_function = num_linear_function
        self.fc = nn.Linear(in_features, out_features * num_linear_function)

    def forward(self, x):
        candidates = self.fc(x).view(-1, self.out_features, self.num_linear_function)
        return torch.max(candidates, dim=2).values


In [None]:
class Classifier(nn.Module):
    """Emotion classifier head stacked on a feature-extracting model.

    Pipeline: wrapped model -> ASP pooling -> Maxout -> LayerNorm -> Dropout ->
    linear layer producing `config.class_number` logits.

    Args:
        model: feature extractor; its output is indexed with the key "paper".
        config: object providing `asp_att_dim`, `fc`,
            `maxout_num_linear_function`, `proj_drop` and `class_number`.
    """

    def __init__(self, model, config):
        super().__init__()
        self.model = model

        # Feature width comes from the dataset configuration entry named 'loss'.
        feature_dim = instance_config.get_property('loss')
        self.asp = ASP(num_emdb=feature_dim, attn_dim=config.asp_att_dim)
        self.maxout = Maxout(
            in_features=feature_dim,
            out_features=config.fc,
            num_linear_function=config.maxout_num_linear_function,
        )
        self.maxout_norm = nn.LayerNorm(normalized_shape=config.fc)
        self.dropout = nn.Dropout(config.proj_drop)
        self.fc3 = nn.Linear(config.fc, config.class_number)

    def forward(self, x):
        """Return `(logits, embedding)`; the embedding is a copy of the input to the last layer."""
        frame_features = self.model(x)["paper"]
        hidden = self.dropout(self.maxout_norm(self.maxout(self.asp(frame_features))))
        embedding = hidden.clone()
        logits = self.fc3(hidden)
        return logits, embedding

In [None]:
class FocalLoss(nn.Module):
    """Focal-style loss computed from the batch-mean binary cross-entropy.

    Logits go through a sigmoid, targets are one-hot encoded, both are
    flattened, and the mean BCE over all elements is scaled by
    `alpha * (1 - exp(-BCE)) ** gamma`. The modulating factor is therefore
    applied once to the mean, not per element.

    Args:
        alpha: overall scaling factor.
        gamma: exponent of the modulating factor.
        class_number: number of classes used for the one-hot encoding.
    """

    def __init__(self, alpha = 0.8, gamma = 2, class_number =4):
        super().__init__()
        self.alpha, self.gamma, self.class_number = alpha, gamma, class_number

    def forward(self, inputs, targets,*args):
        """Return the scalar loss for raw logits `inputs` and integer `targets`.

        Extra positional arguments are accepted and ignored.
        """
        one_hot = F.one_hot(targets, num_classes=self.class_number)

        # The model emits raw logits, so squash them here before the BCE.
        probabilities = torch.sigmoid(inputs)

        # Flatten both tensors and make sure they are float32.
        probabilities = probabilities.view(-1).to(torch.float32)
        one_hot = one_hot.view(-1).to(torch.float32)

        mean_bce = F.binary_cross_entropy(probabilities, one_hot, reduction='mean')
        modulator = (1 - torch.exp(-mean_bce)) ** self.gamma
        return self.alpha * modulator * mean_bce

In [None]:
def get_optimizer(net,config):
    """Build the optimizer named by `config.optimizer` for `net`'s parameters.

    Args:
        net: module whose `parameters()` are optimized.
        config: object providing `optimizer` ('AdamW' or 'SGD'),
            `learning_rate` and `weight_decay`.

    Returns:
        A `torch.optim.AdamW` or `torch.optim.SGD` instance.

    Raises:
        ValueError: if `config.optimizer` is not one of the supported names.
    """
    if config.optimizer == 'AdamW':
        return torch.optim.AdamW(net.parameters(),
                                 lr=config.learning_rate,
                                 weight_decay=config.weight_decay)
    if config.optimizer == 'SGD':
        # The parameter iterable is required; it was missing in this branch.
        return torch.optim.SGD(net.parameters(),
                               lr=config.learning_rate,
                               weight_decay=config.weight_decay)
    # Fail with a clear message instead of an UnboundLocalError at `return`.
    raise ValueError(f"Unsupported optimizer {config.optimizer!r}; expected 'AdamW' or 'SGD'")

In [None]:
def get_classifier(config):
    """Create the Classifier on `device` and optionally freeze selected weights.

    The backbone comes from `initialization_weights(config.initialize_weight)`.
    When `config.freeze` is truthy, every parameter whose name starts with one
    of the strings in `config.freezeList` gets `requires_grad = False`, and the
    name and trainable flag of every parameter is printed.
    """
    backbone = initialization_weights(config.initialize_weight)
    classifierModel = Classifier(backbone, config).to(device)

    if not config.freeze:
        return classifierModel

    # NOTE: this is a plain string-prefix match, so a prefix ending in
    # 'layers.1' would also match names under 'layers.10', 'layers.11', ...
    frozen_prefixes = tuple(config.freezeList)
    for param_name, param in classifierModel.named_parameters():
        if param_name.startswith(frozen_prefixes):
            param.requires_grad = False
        print(param_name, param.requires_grad)
    return classifierModel

In [None]:
def make(config):
    """Build the network, loss and optimizer for one run.

    The classifier is wrapped in `nn.DataParallel` over GPUs 0 and 1 when more
    than one CUDA device is visible, then moved to `device`.

    Returns:
        Tuple `(net, criterion, optimizer)`.
    """
    net = get_classifier(config)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net, device_ids=[0, 1])
        print("Let's use", torch.cuda.device_count(), "GPUs!")
    net.to(device)
    criterion = get_loss(config)
    optimizer = get_optimizer(net, config)
    return net, criterion, optimizer


def make_train_loader(config):
    """Build the shuffled training DataLoader for the sessions in `config.fold`.

    These statements used to sit unreachable after the `return` in `make`,
    while `train` calls `make_train_loader(config)`; they are restored here as
    that function.

    The first `.csv` file listed in `artifact_dir_csv` is used as the index
    file; `os.listdir` gives no ordering guarantee if several exist.
    """
    file_name = list(filter(lambda f: f.endswith(".csv"), os.listdir(artifact_dir_csv)))[0]
    classifier_dataset = ClassifierDataset(csv_file_name=os.path.join(artifact_dir_csv, file_name),
                                           path_dataset=artifact_dir,
                                           sessionList=config.fold,
                                           x_name=instance_config.get_property("x_name"),
                                           sessionListName=instance_config.get_property("session"),
                                           class_label_name=instance_config.get_property("class_label_name"))
    train_dataloader = DataLoader(classifier_dataset, batch_size=config.batch_size, num_workers=2, shuffle=True)
    return train_dataloader


In [None]:
def make_test_dataloader(dataset_path,csv_directory,sessionName,sessionList,batch_size,x_name='emodb'):
    """Build a non-shuffled DataLoader over the given sessions.

    Args:
        dataset_path: directory the audio paths in the CSV are relative to.
        csv_directory: directory searched for the index CSV; the first entry
            ending in ".csv" returned by `os.listdir` is used.
        sessionName: CSV column holding the session value.
        sessionList: session values to keep.
        batch_size: batch size of the returned loader.
        x_name: CSV column holding the audio file path.

    The label column name is taken from `instance_config`.
    """
    csv_candidates = [entry for entry in os.listdir(csv_directory) if entry.endswith(".csv")]
    index_csv = os.path.join(csv_directory, csv_candidates[0])

    eval_dataset = ClassifierDataset(
        csv_file_name=index_csv,
        path_dataset=dataset_path,
        sessionList=sessionList,
        x_name=x_name,
        sessionListName=sessionName,
        class_label_name=instance_config.get_property('class_label_name'),
    )
    return DataLoader(eval_dataset, num_workers=2, batch_size=batch_size, shuffle=False)

In [None]:
def get_sessionList(property_name:str,fold:list):
    '''Return the configured sessions that are not part of `fold`.
        Args:
            property_name (str): name of the session-list entry in `instance_config`.
            fold (list): sessions to leave out of the result.
        Returns
            list : the remaining sessions, in set iteration order.
    '''
    configured_sessions = set(instance_config.get_property(property_name))
    remaining = configured_sessions - set(fold)
    return list(remaining)

In [None]:
from sklearn.metrics import balanced_accuracy_score,accuracy_score
import copy
def test(config,datasetName,test_dataloader,*net):
    """Evaluate a network on `test_dataloader` and log both accuracies to wandb.

    Args:
        config: run configuration; only used to load a network when none is passed.
        datasetName: prefix for the logged metric names.
        test_dataloader: loader yielding `(speech, label)` batches.
        *net: optional network. When omitted, the network is obtained from
            `load_state_dict(config)` and a confusion matrix is logged as well.

    Returns:
        Tuple `(WA, UA)`, where WA is `balanced_accuracy_score` and UA is
        `accuracy_score`.
    """
    y_pred = []
    y_true = []
    log_matrix = True
    if len(net) == 0:
        net = load_state_dict(config)
    else:
        net = net[0]
        log_matrix = False

    # Evaluate with dropout disabled and without building the autograd graph,
    # then put the network back into whatever mode it was in, because train()
    # keeps using the same object after every validation pass.
    is_module = isinstance(net, nn.Module)
    was_training = net.training if is_module else False
    if is_module:
        net.eval()
    try:
        with torch.no_grad():
            for speech, label in test_dataloader:
                speech = speech.squeeze(1)
                inputs, labels = speech.to(device), label.to(device)
                outputs, _ = net(inputs)
                _, predicted = torch.max(outputs, 1)
                y_true += labels.tolist()
                y_pred += predicted.tolist()
    finally:
        if is_module:
            net.train(was_training)

    # NOTE(review): balanced accuracy is the mean per-class recall, which is
    # usually reported as the *unweighted* accuracy; the WA/UA names look
    # swapped. Left unchanged so logged metric names and the return order stay
    # compatible — confirm.
    WA = balanced_accuracy_score(y_true, y_pred)
    UA = accuracy_score(y_true, y_pred)
    wandb.log({
          "{} Weighted Accuracy".format(datasetName): WA,
          "{} Unweighted Accuracy".format(datasetName): UA
          })

    if log_matrix:
        # Capitalise the first letter unless the name is already all upper case.
        label = datasetName if datasetName.isupper() else datasetName[0].upper() + datasetName[1:]
        log_confusion_matrix(y_pred, y_true, label)

    return WA, UA


In [None]:
def save_model(config):
    """Upload the checkpoint file at `PATH` to wandb.

    The file is saved to the current run and also logged as a "model" artifact
    named 'model.pkl' whose metadata is a dict copy of `config`.
    """
    run_metadata = dict(config)
    model_artifact = wandb.Artifact(name='model.pkl', type="model", metadata=run_metadata)

    wandb.save(PATH)
    model_artifact.add_file(PATH)
    wandb.log_artifact(model_artifact)

In [None]:
# Local file the best checkpoint is written to; save_model uploads this same path.
PATH='model.pkl'
def train(config,net,criterion,optimizer):
    """Train `net` for up to `config.epochs` epochs, validating and checkpointing each epoch.

    Per epoch: every batch from `make_train_loader(config)` is run through the
    network, the mean loss and accuracy are logged to wandb, and `test` is
    called on a loader built from the sessions returned by
    `get_sessionList("sessionList", config.fold)`.

    The state dict is written to `PATH` when the mean train accuracy beats the
    best seen so far, and also when the train accuracy equals 1 and the
    validation score beats `best_VUA`. From epoch index 17 on,
    `isEarlyStopping` decides whether to stop.

    Args:
        config: run configuration (reads `epochs`, `fold`, `batch_size`).
        net: network returning `(logits, embeddings)`.
        criterion: loss object handed to `get_loss_param`.
        optimizer: optimizer stepped once per batch.

    Returns:
        None. The checkpoint at `PATH` and the wandb logs are the outputs.
    """
    best_accuracy=0
    best_VUA = 0
    accuracy=0
    # History of epoch accuracies consumed by isEarlyStopping.
    accuracyArr = np.array([])
    # NOTE(review): make_train_loader is not defined in the visible part of the
    # notebook — confirm it is defined before this cell runs.
    train_dataloader= make_train_loader(config)
    wandb.watch(net, criterion, log="all", log_freq=1,log_graph=False)
    # Training mode is set once here; it is not set again inside the epoch loop.
    net.train()
    print("validation")
    # Validation loader: sessions from the configured list that are not in config.fold.
    test_dataloader1 =make_test_dataloader(artifact_dir,csv_directory=artifact_dir_csv
                                           ,sessionName=instance_config.get_property("session")
                                            ,sessionList= get_sessionList("sessionList",config.fold)
                                           ,batch_size = config.batch_size
                                           ,x_name=instance_config.get_property("x_name"))
    # Iterate through the epochs
    for epoch in range(config.epochs):
        train_loss = []
        train_acc  = []
        # Iterate over batches
        for i, (speech,label) in enumerate(train_dataloader):
            
            # Send the speech and labels to the compute device
            print("speech", speech.shape)
            speech, label = speech.to(device), label.to(device)
            # Drop dim 1 of the batch when it has size 1
            speech = speech.squeeze(1)
            # Zero the gradients: PyTorch accumulates gradients across backward passes,
            # so they have to be cleared before each new batch.
            optimizer.zero_grad()




            # Forward pass: the network returns the class logits and the embeddings
            output,embedings = net(speech)
            _,predicted=torch.max(output.data,1)
            loss = get_loss_param(criterion,output,embedings,label)

            # Backpropagate the loss
            loss.backward()
            # Update the weights
            optimizer.step()
            correct_predictions = torch.sum(predicted == label)
            number_of_predictions = torch.numel(predicted) # total number of elements in `predicted`
            accuracy = correct_predictions/number_of_predictions
            train_loss.append(loss.item())
            train_acc.append(accuracy.item())
            print("loss " ,loss.item())
            print("accuracy ",accuracy.item())
            print("******************")
        # Epoch accuracy is the mean of the per-batch accuracies
        trainAccuracy=np.mean(train_acc)
        print("epoch {} accuracy {}".format(epoch,trainAccuracy))
        wandb.log({
            "Train loss": np.mean(train_loss), 
            "Train Accuracy": trainAccuracy,
            "epoch":epoch+1
            })
        print("&&&&&&&&&&&&&&&&&&&&&&&&&\n validation\n$$$$$$$$$$$$$$$$$$$$$$$$$$$")
        
        # VUA is the second value returned by test(), i.e. its accuracy_score result
        _,VUA = test(config,"Validation "+instance_config.get_property("x_name"),test_dataloader1,net)
        
        print(f"epoch {epoch} vua",VUA)
        print("&&&&&&&&&&&&&&&&&&&&&&&&&\n validation\n$$$$$$$$$$$$$$$$$$$$$$$$$$$")
        
        # Checkpoint rule 1: the train accuracy improved
        if trainAccuracy > best_accuracy:
            best_accuracy =trainAccuracy
            # best_VUA is only refreshed here while the train accuracy is below 1
            if trainAccuracy != 1:
                best_VUA = VUA
            print("epoch {} best accuracy {}".format(epoch,best_accuracy))
            # Under DataParallel the underlying module's weights are saved
            if torch.cuda.device_count()>1:
                torch.save(net.module.state_dict(),PATH)

            else:  
                torch.save(net.state_dict(), PATH)
        # Checkpoint rule 2: train accuracy is exactly 1, so validation score decides
        if trainAccuracy == 1 and VUA > best_VUA:
            print("best vua",VUA)
            best_VUA = VUA
            if torch.cuda.device_count()>1:
                torch.save(net.module.state_dict(),PATH)

            else:  
                torch.save(net.state_dict(), PATH)
            
        # Up to epoch index 16 the accuracy history is only collected;
        # afterwards isEarlyStopping receives it and may end training.
        if(epoch > 16):
            check , accuracyArr = isEarlyStopping(accuracyArr,trainAccuracy)
            
            if(check):
                print("break in epoch:", epoch)
                break
        else: 
            accuracyArr = np.append(accuracyArr,trainAccuracy)
        print("---------------------------")

In [None]:
def make_model_pipeline(hyperparameters=None):
    """Run one complete wandb run: build, train, optionally upload, then summarise.

    Args:
        hyperparameters: passed to `wandb.init` as the run config.
    """
    with wandb.init(config=hyperparameters):
        # Access all hyperparameters through wandb.config so logging matches execution.
        config = wandb.config

        net, criterion, optimizer = make(config)
        # Train the model
        train(config, net, criterion, optimizer)
        if saveModel:
            save_model(config)
        print("____________________________________")
        print("***************************")
        # NOTE(review): this loader is built but never passed to test(), so no
        # final evaluation is logged here — confirm whether a test() call is missing.
        test_dataloader = make_test_dataloader(artifact_dir, csv_directory=artifact_dir_csv,
                                               sessionName=instance_config.get_property("session"),
                                               sessionList=get_sessionList("sessionList", config.fold),
                                               batch_size=config.batch_size,
                                               x_name=instance_config.get_property("x_name"))

        # The model summary is best-effort. Catch ordinary errors only, so that
        # Ctrl-C / SystemExit still propagate, and print the cause instead of
        # hiding it.
        try:
            print(summary(net, input_size=(config.batch_size, config.input_size[0], config.input_size[1])))
        except Exception as exc:
            print("something is wrong with torchinfo", exc)

In [None]:
# Start the sweep agent for `sweep_id`, handing it make_model_pipeline as the function to run.
# NOTE(review): `entity` is an empty string here — presumably redacted; confirm before running.
wandb.agent(sweep_id,project=parameters_dict['project_name']['value'],entity='',function=make_model_pipeline)