In [None]:
import wandb
# Authenticate with W&B without hard-coding an API key in the notebook: with no
# `key` argument wandb.login() uses the WANDB_API_KEY environment variable,
# ~/.netrc, or an interactive prompt.
wandb.login()
# Train on the first CUDA device when available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Grid sweep: one wandb run per value of `fold` (the session held out of training).
sweep_config = {
    'method': 'grid'
}
parameters_dict = {
    'fold': {
        # NOTE(review): 3 and 8 are not in sessionListDataset below, so those runs
        # hold nothing out and train on every listed session — confirm intended.
        'values': [3, 8, 9, 10, 11, 12, 13, 14, 15, 16]
    }
}
# sweep_config keeps a reference to parameters_dict, so the updates below are
# visible in the sweep definition as well.
sweep_config['parameters'] = parameters_dict

# Fixed optimisation / audio settings.
parameters_dict.update({
    'epochs': {
        'value': 50
    },
    'batch_size': {
        'value': 8
    },
    'learning_rate': {
        'value': 3e-5
    },
    'weight_decay': {
        'value': 9e-3
    },
    'loss': {
        'value': 'contrastive'
    },
    'optimizer': {
        'value': 'AdamW'
    },
    'project_name': {
        'value': "siameseNet"
    },
    'input_size': {
        'value': (1, 7 * 16000)
    },
    'sample_rate': {
        'value': 16000
    },
})

# Dataset and model settings.
# NOTE(review): pair_path and artifact_dir_iemocap / artifact_dir_emodb /
# artifact_dir_pair are not defined in the visible cells — presumably set by an
# earlier (artifact download) cell; confirm before running.
# TODO: make() also reads config.freezeList and config.distanceList, which are
# not defined in this sweep config — add them (lists of layer-name prefixes).
parameters_dict.update({
    'sessionListDataset': {
        'value': [9, 10, 11, 12, 13, 14, 15, 16]
    },
    'dataset': {
        'value': "IEMOCAP_EMODB {}".format(pair_path)
    },
    'source_dataset_path': {
        'value': artifact_dir_iemocap
    },
    'target_dataset_path': {
        'value': artifact_dir_emodb
    },
    'pair_dataset_path': {
        'value': artifact_dir_pair
    },
    'source_x_name': {
        'value': 'imocap'
    },
    'target_x_name': {
        'value': 'emodb'
    },
    'pair_name': {
        'value': "pairs.csv"
    },
    'class_number': {
        'value': 4
    },
    'class_names': {
        'value': ['neutral', 'happy', 'angry', 'sad']
    },
    'valid_labels': {
        'value': ['0', '1', '2', '3']
    },
    # 'p' (the norm order for nn.PairwiseDistance) is only needed when
    # 'distance' contains 'norm':
    # 'p': {
    #     'value': 1
    # },
    'distance': {
        'value': 'contrastiveLoss-mmd'
    },
    'margin': {
        'value': 2
    },
    'sigma': {
        'value': 2
    },
    'kernel': {
        'value': 'rbf'
    },
    'freeze': {
        'value': True
    },
    'initial_weights_library': {
        'value': 'distilhubert'
    },
})

In [None]:

# Reuse an existing sweep when its id is entered; otherwise create a new sweep.
# .strip() makes whitespace-only input count as "no id" as well.
sweep_id = input('What is sweep_id? (leave out if this is first sweep) ').strip()
if not sweep_id:
    sweep_id = wandb.sweep(sweep_config, project=parameters_dict['project_name']['value'])
saveModel = True  # upload the trained weights as a wandb artifact after each run
saveGradient = False  # NOTE(review): not read anywhere in the visible code

In [None]:
def get_waveform(file_path, sample_rate=16000):
    """Load an audio file and resample it to ``sample_rate`` Hz.

    Args:
        file_path: path of the audio file, passed to ``torchaudio.load``.
        sample_rate: target sampling rate in Hz (default 16000).

    Returns:
        The loaded tensor after resampling; the channel layout is left as
        ``torchaudio.load`` returned it (``tile`` treats dim 1 as time).
    """
    audio, native_rate = torchaudio.load(file_path)
    resampler = torchaudio.transforms.Resample(native_rate, sample_rate)
    return resampler(audio)
    
def tile(waveform, expected_time):
    """Repeat ``waveform`` along dim 1 and crop it to exactly ``expected_time`` columns.

    Short clips are looped until they are long enough; longer clips are simply
    truncated.

    Args:
        waveform: 2-D tensor whose dim 1 is treated as time
            (e.g. ``(channels, samples)``).
        expected_time: desired length along dim 1.

    Returns:
        A tensor of shape ``(waveform.shape[0], expected_time)``.

    Raises:
        ValueError: if ``waveform`` has no samples along dim 1 (nothing to tile).
    """
    waveform_time = waveform.shape[1]
    if waveform_time == 0:
        raise ValueError("cannot tile an empty waveform")
    # Ceiling division: the fewest copies that cover expected_time.
    repeat_times = -(-expected_time // waveform_time)
    return waveform.repeat(1, repeat_times)[:, :expected_time]

In [None]:
def encode_labels(labels):
    """Map emotion names to integer class ids.

    neutral -> 0, happy -> 1, angry -> 2, sad -> 3.  Any other name raises
    ``KeyError``.
    """
    ordered_emotions = ('neutral', 'happy', 'angry', 'sad')
    name_to_id = {emotion: class_id for class_id, emotion in enumerate(ordered_emotions)}
    return list(map(name_to_id.__getitem__, labels))

In [None]:
class SiameseClassifierNetworkDataset(Dataset):
    """Pairs of utterances (source, target) with a contrastive pair label.

    Rows of the pairs CSV are kept when their ``sessionListName`` column is in
    ``sessionList``.  Each item is ``(speech1, speech2, label)``: both audio
    files are loaded, resampled to ``sample_rate`` and tiled/cropped to
    ``num_samples`` samples.

    Args:
        csv_file_name: path of the pairs CSV.
        path_dataset1: directory that the ``x_name_source`` file names are relative to.
        path_dataset2: directory that the ``x_name_target`` file names are relative to.
        sessionList: session values to keep.
        x_name_source: CSV column holding the first file name of each pair.
        x_name_target: CSV column holding the second file name of each pair.
        sessionListName: CSV column holding the session of each pair.
        class_label_name_contrastive: CSV column with the pair label, stored as
            float (the contrastive loss treats 0 as same emotion, 1 as different).
        sample_rate: target sampling rate in Hz (default 16000, as before).
        num_samples: length every waveform is tiled/cropped to
            (default 7 s at 16 kHz, as before).
    """

    def __init__(self, csv_file_name, path_dataset1, path_dataset2, sessionList, x_name_source, x_name_target,
                 sessionListName, class_label_name_contrastive, sample_rate=16000, num_samples=16000 * 7):
        df = pd.read_csv(csv_file_name)

        self.path_dataset1 = path_dataset1
        self.path_dataset2 = path_dataset2
        self.sample_rate = sample_rate
        self.num_samples = num_samples
        # Keep only the requested sessions; reset the index so positions 0..n-1 line up.
        self.temp_df = df[df[sessionListName].isin(sessionList)].reset_index(drop=True)
        self.speech1_path = self.temp_df[x_name_source].values
        self.speech2_path = self.temp_df[x_name_target].values
        # Convert through NumPy so torch receives a plain array, not a pandas Series.
        self.label_contrastive = torch.tensor(
            self.temp_df[class_label_name_contrastive].to_numpy(), dtype=torch.float)

    def __getitem__(self, index):
        """Return ``(speech1, speech2, label)`` for pair ``index``."""
        speech1 = tile(get_waveform(os.path.join(self.path_dataset1, self.speech1_path[index]), self.sample_rate),
                       self.num_samples)
        speech2 = tile(get_waveform(os.path.join(self.path_dataset2, self.speech2_path[index]), self.sample_rate),
                       self.num_samples)
        return speech1, speech2, self.label_contrastive[index]

    def __len__(self):
        """Number of pairs kept after session filtering."""
        return self.temp_df.shape[0]

In [None]:
def initialization_weights(library, freezeList, freeze=False):
    """Create a ``CnnNetwork`` on ``device``, optionally seeded with pretrained weights.

    For ``library == 'distilhubert'``, the seven convolution kernels of the
    s3prl DistilHuBERT feature extractor are copied into ``conv0``..``conv6``,
    and its ``conv_layers.0.2`` weight/bias into ``norm0``.  All other
    parameters keep their random initialisation (``strict=False``).  If
    ``freeze`` is true, every parameter whose name starts with one of the
    prefixes in ``freezeList`` gets ``requires_grad = False``.  For any other
    ``library`` value the randomly initialised network is returned unchanged.

    Args:
        library: name of the pretrained source ('distilhubert' is the only one handled).
        freezeList: iterable of parameter-name prefixes to freeze.
        freeze: whether to apply ``freezeList``.

    Returns:
        The ``CnnNetwork`` instance.
    """
    siamese_net = CnnNetwork().to(device)
    if library == 'distilhubert':
        distilhubert = _import_distilhubert()
        # The teacher is only read from, so it can stay on the CPU;
        # load_state_dict copies each tensor onto siamese_net's device.
        teacher_state = distilhubert().state_dict()
        prefix = 'model.model.feature_extractor.conv_layers'
        pretrained = {f'conv{i}.weight': teacher_state[f'{prefix}.{i}.0.weight'] for i in range(7)}
        # NOTE(review): conv_layers.0.2 is copied into a LayerNorm — confirm the
        # source normalisation layer is semantically compatible (shapes match).
        pretrained['norm0.weight'] = teacher_state[f'{prefix}.0.2.weight']
        pretrained['norm0.bias'] = teacher_state[f'{prefix}.0.2.bias']

        siamese_net.load_state_dict(pretrained, strict=False)
        if freeze:
            prefixes = tuple(freezeList)
            for name, param in siamese_net.named_parameters():
                if name.startswith(prefixes):
                    param.requires_grad = False
    return siamese_net


def _import_distilhubert():
    """Return s3prl's ``distilhubert`` loader, pip-installing s3prl only if it is missing."""
    try:
        from s3prl.hub import distilhubert
    except ImportError:
        import importlib
        import subprocess
        import sys
        # Install into the interpreter running this kernel, once, rather than
        # shelling out to pip on every call.
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', 's3prl'])
        importlib.invalidate_caches()
        from s3prl.hub import distilhubert
    return distilhubert

In [None]:
class CnnNetwork(nn.Module):
    """Seven-layer 1-D convolutional feature extractor over raw waveforms.

    Layer ``i`` is ``conv{i}`` (bias-free Conv1d, 512 output channels) ->
    ``norm{i}`` (LayerNorm over the 512 channels, applied with time moved to
    dim 1 and back) -> ``act{i}`` (GELU).  The attribute names are the
    state-dict keys that pretrained weights are loaded into, so they are kept
    as ``conv0``..``conv6`` / ``norm0``..``norm6`` / ``act0``..``act6``.
    """

    # (kernel_size, stride) of conv0 .. conv6.
    _KERNELS_AND_STRIDES = ((10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2))

    def __init__(self):
        super(CnnNetwork, self).__init__()
        for idx, (kernel, stride) in enumerate(self._KERNELS_AND_STRIDES):
            in_channels = 1 if idx == 0 else 512
            setattr(self, f"conv{idx}", nn.Conv1d(in_channels, 512, kernel, stride=stride, bias=False))
            setattr(self, f"norm{idx}", nn.LayerNorm((512,)))
            setattr(self, f"act{idx}", nn.GELU())

    def forward(self, x):
        """Run all seven conv -> LayerNorm -> GELU stages and return the last activation."""
        out = x
        for idx in range(len(self._KERNELS_AND_STRIDES)):
            out = getattr(self, f"conv{idx}")(out)
            # LayerNorm normalises the last dim, so put channels last and swap back.
            out = getattr(self, f"norm{idx}")(out.transpose(1, 2)).transpose(1, 2)
            out = getattr(self, f"act{idx}")(out)
        return out


In [None]:
#create the Siamese Neural Network
class SiameseNetwork(nn.Module):
    """Run one shared backbone on two inputs and collect selected intermediate outputs.

    A forward hook is registered on every ``nn.Conv1d`` / ``nn.LayerNorm``
    sub-module of ``model`` whose qualified name starts with one of the
    prefixes in ``layers``.  Each captured output is returned in
    (batch, channels, time) layout, in execution order.

    Args:
        model: backbone network whose sub-modules are hooked.
        layers: iterable of module-name prefixes (e.g. ``['conv', 'norm3']``).
    """

    def __init__(self, model, layers):
        super(SiameseNetwork, self).__init__()
        self.model = model
        self.layersLst = []
        self.layersTuple = tuple(layers)
        for name, layer in self.model.named_modules():
            # Only hook modules whose output would actually be recorded.
            if isinstance(layer, (nn.Conv1d, nn.LayerNorm)) and name.startswith(self.layersTuple):
                layer.register_forward_hook(self._save_layer_output(name))

    def _save_layer_output(self, name):
        """Return a forward hook that appends module ``name``'s output to ``self.layersLst``."""
        def hook(module, _inputs, output):
            if not name.startswith(self.layersTuple):
                return
            if isinstance(module, nn.Conv1d):
                self.layersLst.append(output)
            elif isinstance(module, nn.LayerNorm):
                # CnnNetwork applies LayerNorm on (batch, time, channels); swap back
                # to (batch, channels, time) so every captured tensor shares the
                # Conv1d layout and a mean over dim 2 averages over time, not batch.
                self.layersLst.append(output.transpose(1, 2))
        return hook

    def forward_once(self, x):
        """Run the backbone on ``x`` and return the list of captured layer outputs."""
        self.layersLst = []
        _ = self.model(x)
        return self.layersLst

    def forward(self, input1, input2):
        """Return the captured layer-output lists for ``input1`` and ``input2``."""
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2

In [None]:
class MMD_loss(nn.Module):
    """Kernel-based discrepancy between two equally sized sets of feature vectors.

    ``config.kernel`` selects ``'rbf'`` (bandwidth ``config.sigma``) or
    ``'polynomial'`` (fixed ``c=1``, ``p=2``).  ``forward(a, b)`` returns,
    for every row ``i``, ``mean_j(K(a,a)[i,j] + K(b,b)[i,j] - 2*K(a,b)[i,j])``.
    This is a per-row vector, not the scalar MMD estimate, so ``a`` and ``b``
    must have the same number of rows.

    Raises:
        ValueError: at construction, if ``config.kernel`` is not supported.
    """

    _KERNELS = ('rbf', 'polynomial')

    def __init__(self, config):
        super(MMD_loss, self).__init__()
        self.config = config
        self.kernel = config.kernel
        if self.kernel not in self._KERNELS:
            raise ValueError(f"unsupported MMD kernel {self.kernel!r}; expected one of {self._KERNELS}")

    def gaussian_kernel(self, a, b):
        """RBF kernel matrix ``exp(-mean_d (a_i - b_j)^2 / (2 sigma^2))`` of shape (len(a), len(b))."""
        sigma = self.config.sigma
        # Broadcast (n_a, 1, d) - (1, n_b, d) -> (n_a, n_b, d).
        sq_diff = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2)
        return torch.exp(-sq_diff.mean(2) / (2.0 * sigma ** 2))

    def polynomial_kernel(self, X, Y, c, p):
        """
        Compute the polynomial kernel between two matrices X and Y::
            K(x, y) = (<x, y> + c)^p
        """
        return (X @ Y.transpose(0, 1) + c) ** p

    def _kernel_fn(self):
        """Return the two-argument kernel selected by ``self.kernel``."""
        if self.kernel == 'rbf':
            return self.gaussian_kernel
        return lambda x, y: self.polynomial_kernel(x, y, 1, 2)

    def forward(self, a, b):
        """Per-row kernel discrepancy between ``a`` and ``b`` (see class docstring)."""
        kernel = self._kernel_fn()
        xx = kernel(a, a)
        yy = kernel(b, b)
        xy = kernel(a, b)
        return torch.mean(xx + yy - 2 * xy, 1)

In [None]:
class ContrastiveLoss(nn.Module):
    """Sum over layers of a margin-based contrastive loss between two feature-map lists.

    ``config.distance`` picks the per-sample distance:
      * contains ``'norm'``   -> ``nn.PairwiseDistance(p=config.p)``
      * ends with ``'mmd'``   -> ``MMD_loss(config)``
      * equals ``'cosine'``   -> ``nn.CosineSimilarity(dim=1)``
    The loss is only accumulated when ``config.distance`` starts with
    ``'contrastiveLoss'`` and a ``label`` is given; otherwise the int ``0``
    is returned.

    Raises:
        ValueError: if the contrastive loss is requested with an unknown distance.
    """

    def __init__(self, config):
        super(ContrastiveLoss, self).__init__()
        self.config = config

    def _make_distance(self):
        """Build the distance module named by ``config.distance``.

        These modules hold no parameters or buffers, so no device move is needed.
        """
        distance_name = self.config.distance
        if 'norm' in distance_name:
            return nn.PairwiseDistance(p=self.config.p)
        if distance_name.endswith('mmd'):
            return MMD_loss(self.config)
        if distance_name == 'cosine':
            return nn.CosineSimilarity(dim=1, eps=1e-6)
        raise ValueError(f"unsupported distance {distance_name!r}")

    def forward(self, output1, output2, label=None):
        """Return the summed contrastive loss over paired layer outputs.

        Args:
            output1, output2: equally long sequences of feature maps; each is
                averaged over dim 2 before the distance is taken.
            label: per-pair tensor, 0 = same emotion, 1 = different emotion.
        """
        result_distance = 0
        if label is None or not self.config.distance.startswith('contrastiveLoss'):
            # NOTE(review): only the contrastive objective is accumulated; other
            # settings return the int 0, which cannot be back-propagated.
            return result_distance
        distance_fn = self._make_distance()
        margin = self.config.margin
        for x, y in zip(output1, output2):
            distance = distance_fn(torch.mean(x, dim=2), torch.mean(y, dim=2))
            # Pull same-emotion pairs together; push different ones beyond the margin.
            layer_loss = torch.mean((1 - label) * torch.pow(distance, 2) +
                                    label * torch.pow(torch.clamp(margin - distance, min=0.0), 2))
            result_distance = result_distance + layer_loss
        return result_distance

In [None]:
def make(config):
    """Build the Siamese network, its contrastive criterion and its optimizer.

    Reads from ``config``: ``initial_weights_library``, ``freezeList``,
    ``freeze``, ``distanceList``, ``optimizer`` (a ``torch.optim`` class
    name, e.g. ``'AdamW'``), ``learning_rate`` and ``weight_decay``.

    Returns:
        ``(net, criterion, optimizer)`` with ``net`` already on ``device``.
    """
    model = initialization_weights(config.initial_weights_library, config.freezeList, freeze=config.freeze)
    net = SiameseNetwork(model, config.distanceList)
    # nn.DataParallel is deliberately not used.  SiameseNetwork collects layer
    # outputs through forward hooks whose closures write to the original
    # module's list, so every DataParallel replica would return an empty list.
    net.to(device)
    criterion = ContrastiveLoss(config)
    optimizer_cls = getattr(torch.optim, config.optimizer)
    optimizer = optimizer_cls(net.parameters(),
                              lr=config.learning_rate,
                              weight_decay=config.weight_decay)
    return net, criterion, optimizer

In [None]:
def make_train_loader(config):
    """Build the shuffled training DataLoader for the current fold.

    Every session in ``config.sessionListDataset`` except ``config.fold`` is
    kept, filtering on the ``emodb_session`` column of the pairs CSV; the
    ``pos_neg`` column supplies the contrastive label.
    """
    pairs_csv = os.path.join(config.pair_dataset_path, config.pair_name)
    train_sessions = [session for session in config.sessionListDataset if session != config.fold]
    pair_dataset = SiameseClassifierNetworkDataset(
        csv_file_name=pairs_csv,
        path_dataset1=config.source_dataset_path,
        path_dataset2=config.target_dataset_path,
        sessionList=train_sessions,
        x_name_source=config.source_x_name,
        x_name_target=config.target_x_name,
        sessionListName='emodb_session',
        class_label_name_contrastive='pos_neg',
    )
    return DataLoader(pair_dataset, batch_size=config.batch_size, num_workers=2, shuffle=True)

In [None]:
# File the latest model weights are written to (overwritten after every epoch).
PATH = 'model.pkl'


def train(config, net, criterion, optimizer):
    """Train ``net`` for ``config.epochs`` epochs on the contrastive pair loader.

    After each epoch the mean batch loss is logged to wandb as "Train loss"
    (step = epoch number, starting at 1) and ``net.state_dict()`` is saved to
    ``PATH``.
    """
    train_dataloader = make_train_loader(config)
    wandb.watch(net, criterion, log="all", log_freq=1, log_graph=True)
    net.train()
    for epoch in range(config.epochs):
        train_loss = []
        for speech0, speech1, contrastive_label in train_dataloader:
            # Keep the (batch, 1, time) layout: squeezing the channel dim would
            # make Conv1d(1, ...) see an unbatched input with `batch` channels.
            speech0 = speech0.to(device)
            speech1 = speech1.to(device)
            contrastive_label = contrastive_label.to(device)

            # PyTorch accumulates gradients across backward passes; reset them per batch.
            optimizer.zero_grad()
            x1, x2 = net(speech0, speech1)
            loss = criterion(x1, x2, contrastive_label)
            loss.backward()
            optimizer.step()  # apply the parameter update
            train_loss.append(loss.item())

        wandb.log({"Train loss": np.mean(train_loss)}, step=epoch + 1)
        torch.save(net.state_dict(), PATH)


In [None]:
def save_model(config):
    """Upload the weights file at ``PATH`` to the active wandb run.

    The file is synced with ``wandb.save`` and also logged as a ``model``
    artifact named ``model.pkl`` whose metadata is the run config.
    """
    model_artifact = wandb.Artifact(name='model.pkl', type="model", metadata=dict(config))
    wandb.save(PATH)
    model_artifact.add_file(PATH)
    wandb.log_artifact(model_artifact)

In [None]:
def make_model_pipeline(hyperparameters=None):
    """wandb sweep entry point: build and train one model, then optionally upload it."""
    with wandb.init(config=hyperparameters):
        # Read every hyper-parameter back from wandb.config so that the logged
        # values are exactly the ones used for training.
        run_config = wandb.config
        net, criterion, optimizer = make(run_config)
        train(run_config, net, criterion, optimizer)
        if not saveModel:
            return
        save_model(run_config)

In [None]:
# Build, train and analyse one model per sweep configuration.
wandb.agent(
    sweep_id,
    function=make_model_pipeline,
    project=parameters_dict['project_name']['value'],
    entity='',
)