In [None]:
from Dataload import dataload

from module import transformer,video_model,resnet_o,densenet

import torch
import torch.nn as nn
from torchsummary import summary
from sklearn.metrics import confusion_matrix
from torch.utils.data import Dataset, DataLoader
import matplotlib.pylab as plt
from torchvision import  utils
import os
from constant import EMOTIPATH,EMOTIFACEPATH
import time
from tqdm.notebook import tqdm
# Label files and data directories, all resolved relative to the project
# constants EMOTIPATH / EMOTIFACEPATH.
Train_label=os.path.join(EMOTIPATH,"Train_labels.txt")
Train_video=os.path.join(EMOTIPATH,"Train")
Train_video_pt=os.path.join(EMOTIPATH,"pt","Train")
Train_face_pt=os.path.join(EMOTIFACEPATH,"pt_stacked","Train")
Val_labels=os.path.join(EMOTIPATH,"Val_labels.txt")
Val_video=os.path.join(EMOTIPATH,"Val")
Val_video_pt=os.path.join(EMOTIPATH,"pt","Valid")
Val_face_pt=os.path.join(EMOTIFACEPATH,"pt_stacked","Valid")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Only make CUDA tensors the default when a GPU was actually selected above;
# requesting torch.cuda.FloatTensor unconditionally defeats the CPU fallback.
if device.type == "cuda":
    torch.set_default_tensor_type(torch.cuda.FloatTensor)



# Value passed as `frame_num` to both datasets below.
frame_num=25
train_data_pt=dataload.Video_Frame_Data(Train_label,base_path_v=Train_video_pt,face_path=Train_face_pt,frame_num=frame_num,direct=True)
valid_data_pt=dataload.Video_Frame_Data(Val_labels,base_path_v=Val_video_pt,face_path=Val_face_pt,frame_num=frame_num,direct=True)
def num_correct(prediction,labels):
    """Count the positions at which ``prediction`` and ``labels`` agree.

    Both arguments are iterated in lockstep (stopping at the shorter one) and
    each element is compared through its ``.item()`` value.

    Returns:
        int: number of matching pairs.
    """
    return sum(
        1
        for guess, truth in zip(prediction, labels)
        if guess.item() == truth.item()
    )

def load_pretrained_model():
    """Build the frame-level and face-level backbones.

    Returns:
        tuple: ``(frame_model, face_model)``. ``frame_model`` is
        ``densenet.densenet121(pretrained=True)``; ``face_model`` is a
        torchvision ``resnet50`` into which the checkpoint stored at
        ``module/fn_affectnet50k_full.mdl`` is loaded with ``strict=False``.
    """
    frame_model=densenet.densenet121(pretrained=True)

    import torchvision.models as models
    face_model=models.resnet50()
    filepath_mdl="module/fn_affectnet50k_full.mdl"
    # NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
    # map_location="cpu" keeps the checkpoint readable on a machine without a GPU;
    # load_state_dict copies the values into the model's own parameters afterwards.
    snapshot = torch.load(filepath_mdl, map_location="cpu")
    try:
        face_model.load_state_dict(snapshot, strict=False)
    except RuntimeError as e:
        # Deliberate best effort: report the problem and still return the model.
        print(e)

    return frame_model, face_model

# Build both backbones once; load_pretrained_model returns (frame_model, face_model).
frame_model,face_model=load_pretrained_model()



In [2]:
# Shape of the first item of sample 0 (the item the training loop unpacks as `frame_batch`).
train_data_pt[0][0].shape

In [3]:
# Shape of the second item of sample 0 (the item the training loop unpacks as `face_batch`).
train_data_pt[0][1].shape

torch.Size([25, 5, 3, 64, 64])

In [4]:
import torchvision.models as models

# ResNet-50 created without torchvision weights; its parameters come from the
# local "pretrained_state_dict" file loaded below.
face_pretrained=models.resnet50(pretrained=False)
# Random tensor of shape (1, 3, 64, 64), kept for the commented-out forward-pass checks.
data=torch.rand(1,3,64,64)
#print(face_pretrained(data))
# NOTE: torch.load unpickles the file - only load state dicts from a trusted source.
# map_location="cpu" keeps the file readable when no GPU is available;
# load_state_dict copies the values into the model's own parameters.
face_pretrained.load_state_dict(torch.load("pretrained_state_dict", map_location="cpu"))

<All keys matched successfully>

In [5]:
#face_pretrained(data)

In [6]:
# Assemble the video model from the two backbones and place it on `device`.
final_model = video_model.Video_modeller(
    25,
    face_pretrained,
    frame_model,
).to(device)

In [7]:
# Freeze the two backbone children; every other direct child stays trainable.
for name, child in final_model.named_children():
    trainable = name not in ('face_model', 'frame_model')
    print(name + (' is unfrozen' if trainable else ' is frozen'))
    for param in child.parameters():
        param.requires_grad = trainable
                
def _run_epoch(model, dataloader, criterion, optimizer=None):
    """Run one pass over ``dataloader`` and return its summary statistics.

    When ``optimizer`` is given the pass is a training pass: gradients are
    computed, the optimizer is stepped after every batch and progress is
    printed every 40 batches. Without an optimizer the pass only evaluates and
    autograd is switched off.

    The caller is responsible for putting ``model`` in train/eval mode.

    Args:
        model: callable invoked as ``model(frame_batch, face_batch)``.
        dataloader: iterable yielding ``(frame_batch, face_batch, label)``.
        criterion: loss called as ``criterion(output, label)``.
        optimizer: optimizer to step, or ``None`` for an evaluation pass.

    Returns:
        tuple: ``(average loss per batch, accuracy, confusion matrix)``; the
        confusion matrix is accumulated over all batches for labels 0, 1, 2.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    is_training = optimizer is not None
    correct = 0
    total_samples = 0
    loss_sum = 0.0
    conf_mat = None

    # Wrap the loader itself (not enumerate(...)) so tqdm can read its length.
    batches = tqdm(dataloader) if is_training else dataloader

    # Autograd is only needed while parameters are being updated.
    with torch.set_grad_enabled(is_training):
        for i_batch, (frame_batch, face_batch, label) in enumerate(batches):
            frame_batch = frame_batch.to(device)
            face_batch = face_batch.to(device)

            if is_training:
                optimizer.zero_grad()
            output = model(frame_batch, face_batch)
            loss = criterion(output, label.to(device))
            if is_training:
                loss.backward()
                optimizer.step()

            # Index of the largest score along dim 1 is the predicted class.
            prediction = torch.max(output, 1).indices.detach().cpu()
            true_label = label.detach().cpu()

            correct += num_correct(prediction, label)
            total_samples += face_batch.size(0)
            loss_sum += loss.item()

            batch_conf = confusion_matrix(true_label, prediction, labels=[0, 1, 2])
            conf_mat = batch_conf if conf_mat is None else conf_mat + batch_conf

            if is_training and (i_batch + 1) % 40 == 0:
                print("Batch: ", i_batch + 1, "/", len(dataloader))
                print("Batch Recognition loss: ", loss.item())

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples")

    return loss_sum / len(dataloader), correct / total_samples, conf_mat


def train(num_epochs,name,model,train_dataloader,valid_dataloader,optimizer,criterion):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    After every training pass the model's state dict is saved to
    ``name + ".pth"``. After every validation pass one line
    `` train_loss,train_accuracy,valid_loss,valid_accuracy `` is appended to
    ``name + ".txt"``.

    Args:
        num_epochs: number of train + validation rounds to run.
        name: path prefix for the ``.pth`` checkpoint and the ``.txt`` log.
        model: model called as ``model(frame_batch, face_batch)``.
        train_dataloader: loader yielding ``(frame_batch, face_batch, label)``.
        valid_dataloader: loader with the same batch layout, used for validation.
        optimizer: optimizer stepped once per training batch.
        criterion: loss called as ``criterion(output, label)``.

    Raises:
        ValueError: if either dataloader yields no samples.
    """
    # The context manager guarantees the log is closed even when training is
    # interrupted or raises part-way through.
    with open(name + ".txt", 'a') as log_file:
        for epoch in range(num_epochs):
            print("Training Epoch: ", epoch + 1, "\n")
            model.train()
            avg_tloss, avg_taccuracy, conf_mat = _run_epoch(
                model, train_dataloader, criterion, optimizer
            )
            print(conf_mat)
            print("Average_Loss: ", avg_tloss)
            print("Average_Accuracy: ", avg_taccuracy)

            torch.save(model.state_dict(), name + ".pth")

            print("Validation\n")
            model.eval()
            avg_vloss, avg_vaccuracy, conf_mat = _run_epoch(
                model, valid_dataloader, criterion
            )
            print(conf_mat)
            print(avg_vloss)
            print("Accuracy: ", avg_vaccuracy)

            data = " %f,%f,%f,%f \n" % (avg_tloss,avg_taccuracy,avg_vloss,avg_vaccuracy)
            log_file.write(data)
            # Flush per epoch so finished epochs survive a killed kernel.
            log_file.flush()


en1 is unfrozen
en2 is unfrozen
embedder is unfrozen
frame_model is frozen
face_model is frozen
fc1 is unfrozen
fc2 is unfrozen
fc3 is unfrozen


In [None]:
# Loss, epoch budget and run name.
criterion = nn.CrossEntropyLoss()
num_epochs = 30
model_name = "pre_embedded_" + str(25) + "_faces"

# Batches of 32, loaded in the main process; only the training split is shuffled.
train_dataloader = DataLoader(train_data_pt, batch_size=32, num_workers=0, shuffle=True)
valid_dataloader = DataLoader(valid_data_pt, batch_size=32, num_workers=0)

# Adam only receives the parameters that still require gradients.
optimizer = torch.optim.Adam(
    (p for p in final_model.parameters() if p.requires_grad),
    lr=0.0005,
    betas=(0.5, 0.999),
)

print(model_name)

#train(num_epochs,model_name,final_model,train_dataloader,valid_dataloader,optimizer,criterion)


In [None]:
# Sweep train_f over the odd values 1..17.
# NOTE(review): `train_f` is not defined anywhere in the visible notebook, so this
# cell presumably relies on a definition from a deleted cell or an earlier kernel
# session - confirm, and define it above so Restart & Run All works.
train_f(1)
train_f(3)
train_f(5)
train_f(7)
train_f(9)
train_f(11)
train_f(13)
train_f(15)
train_f(17)


In [None]:
# Second validation dataset, read from Val_video with strict_num=26, and its loader.
valid_data1 = dataload.Video_Frame_Data(
    Val_labels,
    base_path_v=Val_video,
    strict_num=26,
)
valid_dataloader1 = DataLoader(valid_data1, batch_size=32, num_workers=0)

In [None]:
# NOTE(review): leftover debugging cell. A bare `debug` presumably relies on
# IPython automagic to invoke %debug; as plain Python it is a NameError.
# Remove before sharing the notebook.
debug

In [None]:
# NOTE(review): `train` as defined in this notebook requires seven positional
# arguments (num_epochs, name, model, train_dataloader, valid_dataloader,
# optimizer, criterion), so this one-argument call raises TypeError on a fresh
# kernel. Presumably written against an older definition - confirm the intended
# arguments before running.
train(1)