<a href="https://colab.research.google.com/github/gnvikas/NoisyFER/blob/main/IEEE_SPL_Rafdb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Owner: Vikas G N, gnvikas@gmail.com

This notebook contains the data loader class for the RAF-DB dataset and an implementation of the paper "Instance Discrimination Based Robust Training for Facial Expression Recognition Under Noisy Labels".


In [None]:
# Mount Google Drive at /content/drive; the dataset, checkpoint and log paths
# used later in this notebook all live under that mount point.
from google.colab import drive
drive.mount('/content/drive',force_remount= True)

Mounted at /content/drive


In [None]:
import sys
from torchvision.transforms import transforms
import torch
from PIL import Image
import pandas as pd
import argparse
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torch.optim
import os
import torch.utils.data as data
import cv2
import random


#logfile = open('/content/drive/MyDrive/Colab_Notebooks/mtech/Project/Logs/IDN/RAFDB/log-newloss.txt','w')

#--------------------------------------------------------------------------------------------------------------------------
'''
Aum Sri Sai Ram

Resnet models

                          NOTE: only layers required are retained and fine-tuned.

'''
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
import torch

# Public names of the ResNet section: the backbone class and its factory functions.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


# Download URL of the checkpoint for each architecture; looked up by the
# resnetXX() factory functions below when they are called with pretrained=True.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with one pixel of padding.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the convolution (default 1).

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The block output is ``relu(main_branch(x) + shortcut)``, where the shortcut
    is ``x`` itself or ``downsample(x)`` when a downsample module is supplied.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """Create the block.

        Args:
            inplanes: Number of input channels.
            planes: Number of channels produced by both convolutions.
            stride: Stride of the first convolution.
            downsample: Optional module applied to the input to form the shortcut.
        """
        super().__init__()
        # Sub-modules are created in this fixed order so that parameter order
        # and state-dict keys stay stable.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        # Shortcut branch: identity unless a downsample module was provided.
        shortcut = x if self.downsample is None else self.downsample(x)

        # Main branch: conv-bn-relu followed by conv-bn.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions with batch norm.

    The last 1x1 convolution widens the output to ``planes * 4`` channels. The
    block output is ``relu(main_branch(x) + shortcut)``, where the shortcut is
    ``x`` itself or ``downsample(x)`` when a downsample module is supplied.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """Create the block.

        Args:
            inplanes: Number of input channels.
            planes: Number of channels of the two inner convolutions.
            stride: Stride of the 3x3 convolution.
            downsample: Optional module applied to the input to form the shortcut.
        """
        super().__init__()
        widened = planes * 4
        # Sub-modules are created in this fixed order so that parameter order
        # and state-dict keys stay stable.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the bottleneck block to ``x`` and return the activated sum."""
        # Shortcut branch: identity unless a downsample module was provided.
        shortcut = x if self.downsample is None else self.downsample(x)

        # Main branch: two conv-bn-relu stages followed by a final conv-bn.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        return self.relu(out + shortcut)


class ResNet(nn.Module):
    """ResNet backbone used as a feature extractor (no classification layer).

    The network consists of the usual stem (7x7 convolution, batch norm, ReLU,
    max pooling), four residual stages and a global average pooling layer.
    ``forward`` returns the pooled features after ``F.normalize``.
    """

    def __init__(self, block, layers, num_classes=7, end2end=True):
        """Build the backbone.

        Args:
            block: Residual block class (e.g. ``BasicBlock`` or ``Bottleneck``).
                It must expose an ``expansion`` attribute and accept
                ``(inplanes, planes, stride, downsample)``.
            layers: Sequence of four ints, the number of blocks in each stage.
            num_classes: Not used by this class (there is no classification
                layer here); kept so existing callers can still pass it.
            end2end: Stored on the instance as ``self.end2end``; not read
                anywhere else in this class.
        """
        # Initialise nn.Module first so attribute assignment below goes through
        # a fully set-up module instead of relying on its fallback path.
        super().__init__()
        self.inplanes = 64
        self.end2end = end2end
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Weight initialisation: convolutions draw from N(0, sqrt(2 / fan_out)),
        # batch-norm layers start as the identity transform.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks into one stage.

        Only the first block receives ``stride`` and, when the stride or the
        channel count changes, a 1x1 convolution + batch norm shortcut.
        ``self.inplanes`` is updated to the stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the normalized pooled features for the image batch ``x``.

        The output has one row per input sample; ``F.normalize`` is applied
        with its default arguments.
        """
        f = self.conv1(x)
        f = self.bn1(f)
        f = self.relu(f)
        f = self.maxpool(f)

        f = self.layer1(f)
        f = self.layer2(f)
        f = self.layer3(f)
        f = self.layer4(f)
        f = self.avgpool(f)

        # Drop the two singleton spatial dimensions left by the 1x1 pooling.
        f = f.squeeze(3).squeeze(2)
        return F.normalize(f)

def resnet18(pretrained=False, **kwargs):
    """Construct a ResNet-18 feature extractor.

    Args:
        pretrained (bool): If True, initialise the backbone from the ImageNet
            checkpoint listed under ``model_urls['resnet18']``.
        **kwargs: Forwarded to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` model.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        # The ImageNet checkpoint carries ``fc.*`` entries, but this trimmed
        # ResNet has no ``fc`` layer; strict loading would reject those keys.
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
    return model


def resnet34(pretrained=False, **kwargs):
    """Construct a ResNet-34 feature extractor.

    Args:
        pretrained (bool): If True, initialise the backbone from the ImageNet
            checkpoint listed under ``model_urls['resnet34']``.
        **kwargs: Forwarded to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` model.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        # The ImageNet checkpoint carries ``fc.*`` entries, but this trimmed
        # ResNet has no ``fc`` layer; strict loading would reject those keys.
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False)
    return model


def resnet50(pretrained=False,  **kwargs):
    """Construct a ResNet-50 feature extractor.

    Args:
        pretrained (bool): Currently has no effect on the returned weights:
            loading of ImageNet weights is disabled for this variant. A
            ``UserWarning`` is emitted when True is passed so the request is
            not dropped silently.
        **kwargs: Forwarded to the ``ResNet`` constructor.

    Returns:
        The constructed, randomly initialised ``ResNet`` model.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        import warnings

        warnings.warn(
            "resnet50(pretrained=True): pretrained weight loading is disabled; "
            "returning a randomly initialised model.",
            UserWarning,
            stacklevel=2,
        )
    return model


def resnet101(pretrained=False, **kwargs):
    """Construct a ResNet-101 feature extractor.

    Args:
        pretrained (bool): If True, initialise the backbone from the ImageNet
            checkpoint listed under ``model_urls['resnet101']``.
        **kwargs: Forwarded to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` model.
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        # The ImageNet checkpoint carries ``fc.*`` entries, but this trimmed
        # ResNet has no ``fc`` layer; strict loading would reject those keys.
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False)
    return model


def resnet152(pretrained=False, **kwargs):
    """Construct a ResNet-152 feature extractor.

    Args:
        pretrained (bool): If True, initialise the backbone from the ImageNet
            checkpoint listed under ``model_urls['resnet152']``.
        **kwargs: Forwarded to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` model.
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        # The ImageNet checkpoint carries ``fc.*`` entries, but this trimmed
        # ResNet has no ``fc`` layer; strict loading would reject those keys.
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False)
    return model
    
#-----------------------------------------------------------------------------------------------------------------

def load_base_model(model, checkpoint_path='/content/drive/MyDrive/Colab_Notebooks/mtech/Project/ijba_res18_naive.pth.tar'):
    """Copy pretrained backbone weights (MSCeleb-1M checkpoint) into ``model``.

    Every entry of the checkpoint's ``'state_dict'`` is copied into the model
    state except the ``module.fc.*`` and ``module.feature.*`` entries, which
    are skipped. Checkpoint keys the model does not have are ignored because
    the final load is non-strict.

    Args:
        model: Module to receive the weights; its state-dict keys must match
            the checkpoint's keys for the weights to be picked up.
        checkpoint_path: Location of the checkpoint file. Defaults to the
            path used by this notebook.

    Returns:
        The same ``model`` instance, updated in place.
    """
    skipped_keys = {
        'module.fc.weight',
        'module.fc.bias',
        'module.feature.weight',
        'module.feature.bias',
    }

    # Load onto the CPU so a checkpoint saved from a GPU session can also be
    # read on a CPU-only runtime; load_state_dict copies the values onto
    # whatever device the model's parameters live on.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    pretrained_state_dict = checkpoint['state_dict']

    model_state_dict = model.state_dict()
    for key, value in pretrained_state_dict.items():
        if key not in skipped_keys:
            model_state_dict[key] = value

    model.load_state_dict(model_state_dict, strict = False)
    return model
   
class Classifier(nn.Module):
    """Single linear layer that returns both raw scores and softmax probabilities."""

    def __init__(self, input_dim = 512, num_classes = 7):
        """Create the linear head mapping ``input_dim`` features to ``num_classes`` scores."""
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        """Return ``(logits, probabilities)``; the softmax is taken over dim 1."""
        logits = self.fc(x)
        return logits, torch.softmax(logits, dim=1)


def count_parameters(model):
    """Return the total number of elements in the trainable parameters of ``model``.

    Parameters whose ``requires_grad`` flag is False are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
    

def instantiate_model(args):
    """Create the backbone, the three classifier heads, the losses and the optimizer.

    Args:
        args: Namespace providing ``device``, ``num_src_classes``,
            ``num_ins_classes``, ``base_model_lr``, ``src_lr``, ``ins_lr``,
            ``base_model_wd`` and ``other_wd``.

    Returns:
        Tuple ``(base_model, src_cl1, src_cl2, ins_cl, criterion, criterion_kl,
        optimizer)``:
        * ``base_model``: ResNet-18 wrapped in ``nn.DataParallel`` with the
          pretrained weights from ``load_base_model`` applied.
        * ``src_cl1`` / ``src_cl2``: two expression classifiers.
        * ``ins_cl``: instance classifier with ``num_ins_classes`` outputs.
        * ``criterion``: per-sample (unreduced) cross-entropy loss.
        * ``criterion_kl``: KL-divergence loss.
        * ``optimizer``: Adam with one parameter group per module.
    """
    base_model = resnet18(pretrained=False)
    # Wrap before loading: load_base_model copies checkpoint keys that carry
    # the ``module.`` prefix added by DataParallel.
    base_model = nn.DataParallel(base_model).to(args.device)
    base_model = load_base_model(base_model)

    src_cl1 = Classifier(num_classes = args.num_src_classes).to(args.device)
    src_cl2 = Classifier(num_classes = args.num_src_classes).to(args.device)
    ins_cl = Classifier(num_classes = args.num_ins_classes).to(args.device)

    # reduction='none' keeps one loss value per sample.
    criterion = nn.CrossEntropyLoss(reduction = 'none').to(args.device)

    # NOTE(review): the default reduction is kept so the loss scale is
    # unchanged; confirm whether 'batchmean' was intended.
    criterion_kl = nn.KLDivLoss().to(args.device)

    # The option key must be spelled 'weight_decay'; a misspelled key is
    # stored in the group but never read by Adam, leaving weight decay at 0.
    optimizer = torch.optim.Adam([
        {'params': base_model.parameters(), 'lr': args.base_model_lr, 'weight_decay': args.base_model_wd},
        {'params': src_cl1.parameters(), 'lr': args.src_lr, 'weight_decay': args.other_wd},
        {'params': src_cl2.parameters(), 'lr': args.src_lr, 'weight_decay': args.other_wd},
        {'params': ins_cl.parameters(), 'lr': args.ins_lr, 'weight_decay': args.other_wd},
    ])

    return base_model, src_cl1, src_cl2, ins_cl, criterion, criterion_kl, optimizer
    
def adjust_learning_rate(optimizer):
    """Divide the learning rate of every parameter group of ``optimizer`` by 10, in place."""
    for group in optimizer.param_groups:
        group.update(lr=group["lr"] / 10.)
    
    
def train(args, train_dataset, test_dataset, logfile):
    """Train the backbone and classifier heads, evaluating periodically.

    Training has two regimes, chosen per batch:

    * Warm-up (``train_dataset.phase == 1`` or ``epoch < args.warmup_epochs``):
      both expression classifiers are trained with cross-entropy on the given
      labels. During the last warm-up epoch, samples on which both classifiers
      agree with the label with confidence above
      ``args.probs_threshold_warmup`` are registered as clean via
      ``train_dataset.set_clean_data``; afterwards the dataset is switched to
      phase 2.
    * Phase 2: each batch holds a "clean" sample and a "messy" (not yet
      registered) sample. The loss is a weighted sum of cross-entropy on the
      clean samples, a symmetric KL term between the two classifiers on the
      messy samples, and an instance-discrimination cross-entropy on both.

    Args:
        args: Namespace with ``epochs``, ``batch_size``, ``num_workers``,
            ``warmup_epochs``, ``probs_threshold_warmup``, ``test_freq`` and
            ``device``, plus everything ``instantiate_model`` reads.
        train_dataset: Training dataset exposing ``phase``, ``clean_data``,
            ``set_phase`` and ``set_clean_data`` and yielding 8-tuples.
        test_dataset: Evaluation dataset yielding ``(image, label)`` pairs.
        logfile: Writable text stream receiving the per-batch and test logs.

    Returns:
        Tuple ``(model, src_cl1, src_cl2, ins_cl, train_dataset)``.
    """
    model, src_cl1, src_cl2, ins_cl, criterion, criterion_kl, optimizer = instantiate_model(args)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = args.batch_size, drop_last = True,
                                               num_workers = args.num_workers, shuffle = True)

    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = args.batch_size,
                                              num_workers = args.num_workers, shuffle = False)

    best_acc = 0.0

    for epoch in range(0, args.epochs):
        # The evaluation pass at the end of an epoch puts the networks in eval
        # mode; switch back here so BatchNorm keeps using batch statistics
        # (and updating its running statistics) while training.
        model.train()
        src_cl1.train()
        src_cl2.train()
        ins_cl.train()

        # Phase 2 needs at least one sample that is not registered as clean;
        # once every sample is registered, fall back to plain supervised training.
        if len(train_dataset.clean_data) == len(train_dataset):
            print("setting back to phase1",file=logfile)
            train_dataset.set_phase(1)

        #if epoch == 25 or epoch == 40 or epoch == 50: #fplus
        if epoch == 20 or epoch == 28 or epoch==36:   #rafdb
            adjust_learning_rate(optimizer)

        for i, (data1, data2, label1, label2, idx1, idx2, is_labeled1, is_labeled2) in enumerate(train_loader): #training
            data1 = data1.to(args.device)
            label1 = label1.to(args.device)
            feat1 = model(data1)
            idx1 = idx1.to(args.device)
            out_src_cl1_1, probs_src_cl1_1 = src_cl1(feat1)
            out_src_cl2_1, probs_src_cl2_1 = src_cl2(feat1)
            out_ins_cl_1, probs_ins_cl_1 = ins_cl(feat1)

            if train_dataset.phase == 1 or epoch < args.warmup_epochs:
                # Warm-up: plain cross-entropy on both expression classifiers.
                src_loss1 = torch.mean(criterion(out_src_cl1_1, label1))
                src_loss2 = torch.mean(criterion(out_src_cl2_1, label1))

                ins_loss = 0
                kl_loss = 0

                loss = src_loss1 + src_loss2

                if epoch == args.warmup_epochs - 1:
                    # Last warm-up epoch: register as clean the samples where
                    # both classifiers agree with the label confidently.
                    probs1, preds1 = torch.max(probs_src_cl1_1, dim = 1)
                    probs2, preds2 = torch.max(probs_src_cl2_1, dim = 1)
                    correct_indices = ((preds1 == label1) & (preds2 == label1)
                                       & (probs1 > args.probs_threshold_warmup)
                                       & (probs2 > args.probs_threshold_warmup))
                    train_dataset.set_clean_data(idx1[correct_indices].detach().cpu().tolist(),
                                                 label1[correct_indices].detach().cpu().tolist())
            else:
                data2 = data2.to(args.device)
                label2 = label2.to(args.device)
                feat2 = model(data2)
                idx2 = idx2.to(args.device)

                # Suffix _1 is the clean sample, suffix _2 the messy sample;
                # src_cl1 / src_cl2 are the two expression classifiers and
                # ins_cl is the instance classifier.
                out_ins_cl_2, probs_ins_cl_2 = ins_cl(feat2)
                out_src_cl1_2, probs_src_cl1_2 = src_cl1(feat2) #messy out1
                out_src_cl2_2, probs_src_cl2_2 = src_cl2(feat2) #messy out2

                # Supervised loss on the clean samples.
                src_loss1 = torch.mean(criterion(out_src_cl1_1, label1))
                src_loss2 = torch.mean(criterion(out_src_cl2_1, label1))

                # Instance discrimination: the target of each sample is its own dataset index.
                ins_loss_1 = torch.mean(criterion(out_ins_cl_1, idx1))  # from clean
                ins_loss_2 = torch.mean(criterion(out_ins_cl_2, idx2))  # from messy
                ins_loss = ins_loss_1 + ins_loss_2

                # Symmetric KL consistency between the two classifiers on messy samples.
                kl_loss1 = criterion_kl(torch.log(probs_src_cl1_2), probs_src_cl2_2)  # from messy out from 1 || 2
                kl_loss2 = criterion_kl(torch.log(probs_src_cl2_2), probs_src_cl1_2)  # from messy out from 2 || 1
                kl_loss = kl_loss1 + kl_loss2
                src_loss = src_loss1 + src_loss2

                # Loss weights (supervised, KL, instance) change after epoch 15.
                if epoch>15:
                    a,b,c = .2,.5,.3
                else:
                    a,b,c = .3,.3,.4
                loss = a*src_loss +  b*kl_loss + c*ins_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f"Epoch/batch {epoch}/{i}\tsrc_loss1:{src_loss1:.3f}\tsrc_loss2:{src_loss2:.3f}\tins_loss:{ins_loss:.3f}\tkl_loss:{kl_loss:.3f}",file=logfile)

        if epoch == args.warmup_epochs - 1:
            train_dataset.set_phase(phase=2)

        #Perform testing
        if epoch % args.test_freq == 0:
            model.eval()
            src_cl1.eval()
            src_cl2.eval()

            correct_cls_1 = 0.
            correct_cls_2 = 0.
            for images, labels in test_loader:
                images = images.to(args.device)
                labels = labels.to(args.device)

                with torch.no_grad():
                    feat = model(images)
                    _, probs_src_cl1 = src_cl1(feat)
                    _, probs_src_cl2 = src_cl2(feat)

                    _, preds_cls_1 = torch.max(probs_src_cl1, dim = 1)
                    _, preds_cls_2 = torch.max(probs_src_cl2, dim = 1)
                    correct_cls_1 += (preds_cls_1 == labels).cpu().sum().item()
                    correct_cls_2 += (preds_cls_2 == labels).cpu().sum().item()

            acc_cls_1 = correct_cls_1/len(test_dataset)
            acc_cls_2 = correct_cls_2/len(test_dataset)

            print(f"Test: Epoch {epoch}\tsrc_cls1_acc:{acc_cls_1:.4f}\tsrc_cls2_acc:{acc_cls_2:.4f}")
            print(f"Test: Epoch {epoch}\tsrc_cls1_acc:{acc_cls_1:.4f}\tsrc_cls2_acc:{acc_cls_2:.4f}",file=logfile)

            if best_acc < acc_cls_1 or best_acc < acc_cls_2:
                best_acc = max(acc_cls_1, acc_cls_2)
                print(f'best_acc: {best_acc}')
                # NOTE: checkpoint saving (torch.save of the model, classifier
                # and optimizer state) was disabled in the original notebook.

    print(f"\n\n \t Best Test: Best_acc:{best_acc:.4f}. Sairam",file=logfile)
    print(f"\n\n \t Best Test: Best_acc:{best_acc:.4f}. Sairam")

    return model, src_cl1, src_cl2, ins_cl, train_dataset

#--------------------------------------------------------------------------------------------------------------------------


In [None]:

### Entire rafdb_datset file pasted below

class RafDataSet(data.Dataset):
    """RAF-DB dataset with optional noisy training labels and two sampling phases.

    Label files are read from ``<raf_path>/RAFDB`` and images from
    ``<raf_path>/RAFDB/aligned``. Labels are shifted to start at 0 and then
    remapped with ``change_emotion_label_same_as_affectnet``.

    ``__getitem__`` returns:

    * test partition: ``(image, label)``;
    * train partition, phase 1 (warm-up): the same sample twice, as
      ``(image, image, label, label, idx, idx, True, True)``;
    * train partition, phase 2: a sample registered as clean paired with a
      randomly drawn sample that is not registered as clean, as
      ``(image1, image2, label1, label2, idx1, idx2, True, False)``.
    """

    # Lookup table used by change_emotion_label_same_as_affectnet.
    _RAF_TO_AFFECTNET = {0: 3, 1: 4, 2: 5, 3: 1, 4: 2, 5: 6, 6: 0}

    def __init__(self, raf_path, noise_file, phase, noise = True, partition = 'train', transform = None, num_classes = 7):
        """Read the label files and build the list of image paths.

        Args:
            raf_path: Root directory that contains the ``RAFDB`` folder.
            noise_file: Noisy training label file, joined onto ``raf_path``
                (an absolute path is used as is). Only read for the train partition.
            phase: Accepted for interface compatibility; the dataset always
                starts in phase 1. Use ``set_phase`` to change it.
            noise: If True, train with the labels from ``noise_file``;
                otherwise with the labels from ``train_label.txt``.
            partition: ``'train'`` for the training split; any other value
                selects the test split.
            transform: Optional callable applied to every loaded image.
            num_classes: Width of the ``pseudo_probs1`` / ``pseudo_probs2`` tables.
        """
        self.transform = transform
        self.raf_path = raf_path
        self.clean_data = dict()   # dataset index -> label registered via set_clean_data
        self.phase = 1 #pretraining
        self.num_classes = num_classes
        self.partition = partition

        NAME_COLUMN = 0
        LABEL_COLUMN = 1

        if partition == 'train':
            df_train_clean = pd.read_csv(os.path.join(self.raf_path, 'RAFDB/train_label.txt'), sep=' ', header=None)
            df_train_noisy = pd.read_csv(os.path.join(self.raf_path, noise_file), sep=' ', header=None)
            dataset_train_noisy = df_train_noisy[df_train_noisy[NAME_COLUMN].str.startswith('train')]
            dataset_train_clean = df_train_clean[df_train_clean[NAME_COLUMN].str.startswith('train')]
            # 0:Surprise, 1:Fear, 2:Disgust, 3:Happiness, 4:Sadness, 5:Anger, 6:Neutral
            self.clean_label = dataset_train_clean.iloc[:, LABEL_COLUMN].values - 1
            self.noisy_label = dataset_train_noisy.iloc[:, LABEL_COLUMN].values - 1
            self.label = self.noisy_label if noise else self.clean_label
            file_names = dataset_train_noisy.iloc[:, NAME_COLUMN].values
            # True where the noisy label equals the clean one.
            self.noise_or_not = (self.noisy_label == self.clean_label) #By DG
        else:
            df_test = pd.read_csv(os.path.join(self.raf_path, 'RAFDB/test_label.txt'), sep=' ', header=None)
            dataset = df_test[df_test[NAME_COLUMN].str.startswith('test')]
            # 0:Surprise, 1:Fear, 2:Disgust, 3:Happiness, 4:Sadness, 5:Anger, 6:Neutral
            self.label = dataset.iloc[:, LABEL_COLUMN].values - 1
            file_names = dataset.iloc[:, NAME_COLUMN].values

        self.label = [self.change_emotion_label_same_as_affectnet(label) for label in self.label]
        self.pseudo_labels = []

        # use raf aligned images for training/testing
        # os.path.join (rather than string concatenation) keeps the path valid
        # whether or not raf_path ends with a separator.
        aligned_dir = os.path.join(self.raf_path, 'RAFDB/aligned')
        self.file_paths = [
            os.path.join(aligned_dir, name.split(".")[0] + "_aligned.jpg")
            for name in file_names
        ]

        self.pseudo_probs1 = torch.zeros((len(self.label), self.num_classes))
        self.pseudo_probs2 = torch.zeros((len(self.label), self.num_classes))

    def set_clean_data(self, indices, pseudo_labels):  # To be called after warmup period
        """Register ``indices`` as clean, each with the matching entry of ``pseudo_labels``."""
        self.clean_data.update(zip(indices, pseudo_labels))

    def set_probs(self, indices, probs1, probs2):
        """Store one row of ``probs1`` / ``probs2`` per entry of ``indices``.

        ``indices`` must provide ``tolist()``; row ``i`` of each probability
        table is written at dataset index ``indices[i]``.
        """
        for row, index in enumerate(indices.tolist()):
            self.pseudo_probs1[index] = probs1[row]
            self.pseudo_probs2[index] = probs2[row]

    def set_phase(self, phase):
        """Select the sampling phase used by ``__getitem__`` (1 = warm-up)."""
        self.phase = phase

    def change_emotion_label_same_as_affectnet(self, emo_to_return):
        """
        Parse labels to make them compatible with AffectNet.
        #https://github.com/siqueira-hc/Efficient-Facial-Feature-Learning-with-Wide-Ensemble-based-Convolutional-Neural-Networks/blob/master/model/utils/udata.py

        Values outside 0..6 are returned unchanged.
        """
        return self._RAF_TO_AFFECTNET.get(emo_to_return, emo_to_return)

    def __len__(self):
        """Return the number of images in the selected partition."""
        return len(self.file_paths)

    @staticmethod
    def _read_rgb(path):
        """Load the image at ``path`` with OpenCV and return it in RGB channel order.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return image[:, :, ::-1] # BGR to RGB

    def __getitem__(self, idx):
        """Return the sample(s) for ``idx``; see the class docstring for the tuple layouts."""
        if self.partition != 'train':
            # Convert to RGB here as well so test images go through the
            # transform with the same channel order as training images.
            image = self._read_rgb(self.file_paths[idx])
            if self.transform is not None:
                image = self.transform(image)
            label = torch.tensor(self.label[idx], dtype = torch.int64)
            return image, label

        if self.phase == 1: #warm-up
            labeled = True
            image = self._read_rgb(self.file_paths[idx])
            if self.transform is not None:
                image = self.transform(image)

            label = torch.tensor(self.label[idx], dtype = torch.int64)
            idx = torch.tensor(idx, dtype = torch.int64)
            return image, image, label, label, idx, idx, labeled, labeled

        #pseudo-labeling
        # First element: the requested sample if it is registered as clean,
        # otherwise a random registered sample.
        if idx in self.clean_data:
            idx1 = idx
        else:
            idx1 = random.choice(list(self.clean_data.keys()))
        label1 = self.clean_data[idx1]
        path1 = self.file_paths[idx1]
        labeled1 = True

        # Second element: a random sample that is not registered as clean,
        # with its label from the label file.
        assigned_indices = set(self.clean_data.keys())
        unassigned_indices = list(set(range(len(self))) - assigned_indices)
        idx2 = random.choice(unassigned_indices)
        label2 = self.label[idx2]
        path2 = self.file_paths[idx2]
        labeled2 = False

        image1 = self._read_rgb(path1)
        image2 = self._read_rgb(path2)

        if self.transform is not None:
            image1 = self.transform(image1)
            image2 = self.transform(image2)

        label1 = torch.tensor(label1, dtype = torch.int64)
        idx1 = torch.tensor(idx1, dtype = torch.int64)
        label2 = torch.tensor(label2, dtype = torch.int64)
        idx2 = torch.tensor(idx2, dtype = torch.int64)

        return image1, image2, label1, label2, idx1, idx2, labeled1, labeled2


#--------------------------------------------------------------------------------------------------------------------------

def _str2bool(value):
    """Convert a command-line string such as 'true' / 'False' / '0' to a bool.

    ``type=bool`` is not suitable for argparse: it calls ``bool()`` on the raw
    string, so every non-empty value, including 'False', parses as True.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line style configuration of the experiment. The defaults below are
# the values used by this notebook.
parser = argparse.ArgumentParser()

# Learning rates: backbone, the two expression classifiers, the instance classifier.
parser.add_argument('--base_model_lr', type=float, default=0.001)
parser.add_argument('--src_lr', type=float, default=0.01)
parser.add_argument('--ins_lr', type=float, default=0.01)

parser.add_argument('--raf_path', type=str, default='/content/drive/MyDrive/Colab_Notebooks/mtech/Project/', help='Raf-DB dataset path.')   # Set path

parser.add_argument('--pretrained', type=str, default='/content/drive/MyDrive/Colab_Notebooks/mtech/Project/ijba_res18_naive.pth.tar',
                        help='Pretrained weights')                  # Set path of pretrained model

parser.add_argument('--resume', type=str, default='', help='Use FEC trained models')

parser.add_argument('--noise_file', type=str, help='train_label.txt, 0.3noise_train.txt', default='/content/drive/MyDrive/Colab_Notebooks/mtech/Project/noise files/0.3noise_train.txt')  # How? and Set path

parser.add_argument('--noise', type=_str2bool, default=True)

parser.add_argument('--epochs', type=int, default=47)

parser.add_argument('--num_src_classes', type=int, default=7)

parser.add_argument('--print_freq', type=int, default=30)

parser.add_argument('--test_freq', type=int, default=1)

parser.add_argument('--num_workers', type=int, default=4, help='how many subprocesses to use for data loading')

parser.add_argument('--batch_size', type=int, default=128, help='batch_size')

parser.add_argument('--warmup_epochs', type=int, default=8, help='Warmup epochs.')

# Weight decay: backbone and all classifier heads.
parser.add_argument('--base_model_wd', type=float, default=1e-6)

parser.add_argument('--other_wd', type=float, default=1e-4)

parser.add_argument('--momentum', type=float, default=0.9)

parser.add_argument('--probs_threshold_warmup', type=float, default=0.02)

parser.add_argument('--probs_threshold', type=float, default=0.94)

# Parse an empty argument list so that only the defaults above are used and
# the process's own command line is never consulted.
args = parser.parse_args([])


# Run on the GPU when one is available, otherwise on the CPU.
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

def main(args):
    """Build the RAF-DB datasets and run training, logging to a file on Drive.

    Args:
        args: Parsed configuration namespace. ``raf_path``, ``noise_file``,
            ``noise``, ``num_src_classes`` and ``warmup_epochs`` are read
            here; ``num_ins_classes`` is set to the size of the training set
            before ``train`` is called.
    """
    # ``RandomAffine`` no longer accepts ``resample=``; the replacement is
    # ``interpolation=`` with an ``InterpolationMode`` value (torchvision >= 0.9).
    from torchvision.transforms import InterpolationMode

    file_str = f'/content/drive/MyDrive/Colab_Notebooks/mtech/Project/Logs/IDN/RAFDB/filtering/0noise_allfiltered_{1}'

    # The context manager closes (and flushes) the log even if training fails.
    with open(file_str, 'w') as logfile:
        header = f"Warmup epochs: {args.warmup_epochs} - Noise: {args.noise},Noise Rate: {args.noise_file}"
        print(header)
        # ``file=`` is required; a positional argument would be printed to stdout.
        print(header, file=logfile)

        # Training augmentation: random flip, then (half of the time) colour
        # jitter plus a random translate/zoom, then resize and normalisation.
        train_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([
                transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.25),
                transforms.RandomAffine(degrees=0, translate=(.1, .1), scale=(1.0, 1.25),
                                        interpolation=InterpolationMode.BILINEAR),
            ], p=0.5),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Evaluation uses resize and normalisation only.
        test_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        train_dataset = RafDataSet(raf_path=args.raf_path, noise_file = args.noise_file, phase = 1, noise = args.noise, partition = 'train', transform = train_transform, num_classes = args.num_src_classes)
        test_dataset = RafDataSet(raf_path=args.raf_path, noise_file = args.noise_file, phase = 1, noise = args.noise, partition = 'test', transform = test_transform, num_classes =  args.num_src_classes)

        # The instance classifier has one output per training image.
        args.num_ins_classes = len(train_dataset)
        model, src_cl1, src_cl2, ins_cl, train_dataset = train(args, train_dataset, test_dataset, logfile)
                                                       
    
# Entry point: run the full pipeline with the configuration parsed above.
if __name__=='__main__':
   main(args)                                           


