#### Importing Libraries

In [None]:
import natsort
import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
import torchvision.models as models
import numpy as np
from torch.utils.data import Dataset,DataLoader
import os
from PIL import Image
import torch.optim as optim
import logging
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score,precision_score, recall_score, f1_score
from tqdm import tqdm
# from tools import wer

#### Define Model

In [None]:
class ResCRNN(nn.Module):
    """CNN-LSTM clip classifier.

    A ResNet-50 trunk (its final Linear layer removed) encodes every frame.
    An LSTM runs over the per-frame feature vectors, and a linear layer maps
    the LSTM output at the last time step to class logits.

    Args:
        num_classes: number of output logits (default 59, as before).
        lstm_hidden_size: LSTM hidden width (default 512).
        lstm_num_layers: number of stacked LSTM layers (default 1).
        pretrained: forwarded to ``torchvision.models.resnet50`` (default True,
            as before). Pass False to build the network without loading weights.
    """

    def __init__(self, num_classes=59, lstm_hidden_size=512, lstm_num_layers=1, pretrained=True):
        super(ResCRNN, self).__init__()
        # Not read anywhere in this class; kept so code that inspects them keeps working.
        self.batch_size = 8
        self.seq_length = 80
        self.attention = 0
        self.num_classes = num_classes
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_num_layers = lstm_num_layers

        # NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
        # `weights=`; the old keyword is kept for compatibility with older versions.
        resnet = models.resnet50(pretrained=pretrained)
        # Drop the final Linear(2048 -> 1000) classifier; keep the rest as a frame encoder.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.lstm = nn.LSTM(
            input_size=resnet.fc.in_features,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_num_layers,
            batch_first=True,
        )
        self.fc1 = nn.Linear(lstm_hidden_size, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        x: clip batch of shape (batch_size, channel, t, h, w); frame t is x[:, :, t].
        """
        # Encode each frame with the shared ResNet trunk and flatten to (batch, features).
        frame_features = [
            torch.flatten(self.resnet(x[:, :, t, :, :]), start_dim=1)
            for t in range(x.size(2))
        ]
        # Stack along dim 1 -> (batch, t, features), the layout a batch_first LSTM expects.
        sequence = torch.stack(frame_features, dim=1)
        # Compact the LSTM weights so the faster code path can be used.
        self.lstm.flatten_parameters()
        lstm_out, _ = self.lstm(sequence)
        # Classify from the output at the last time step.
        return self.fc1(lstm_out[:, -1, :])

In [None]:
# Instantiate the model once as a smoke test of the class definition.
# NOTE(review): `crnn` is not referenced anywhere else in this notebook (training builds
# its own `model = ResCRNN()` later), and construction may download pretrained weights.
crnn = ResCRNN()

#### Preprocess Data

In [None]:
# Project root on the author's machine. It can be overridden with the SIGN_LANGUAGE_ROOT
# environment variable; os.path.join(..., '') guarantees a trailing separator.
_PROJECT_ROOT = os.path.join(
    os.environ.get('SIGN_LANGUAGE_ROOT', '/home/kirtan/Documents/Sign_Language/'), '')
# Root of the extracted frame folders. Split-file entries are appended to this string
# directly, so it must end with a separator.
frameDir_path = _PROJECT_ROOT + 'data/Frames/'
# Directory where per-epoch checkpoints are written.
checkpoint_path = _PROJECT_ROOT
# Read the clock once so the log and summary paths share one run ID. Two separate
# datetime.now() calls can straddle a second boundary and yield different names.
_run_started = datetime.now()
log_path = _PROJECT_ROOT + 'cnnlstm_{:%Y-%m-%d_%H-%M-%S}'.format(_run_started)
sum_path = _PROJECT_ROOT + 'sum_cnnlstm_{:%Y-%m-%d_%H-%M-%S}'.format(_run_started)

In [None]:
import json
def load_json(path, encoding="utf-8"):
    """Read and return the JSON document stored at *path*.

    The file is decoded as UTF-8 by default, the encoding JSON text is required
    to use when exchanged (RFC 8259), instead of the platform's locale encoding.
    Pass *encoding* to override.
    """
    with open(path, "r", encoding=encoding) as json_handle:
        return json.load(json_handle)
# Class-name -> label lookup used by get_Lables to turn each frame's directory name into a label.
label_map = load_json('./label_map_temp_include.json')
# Show the mapping so the class set can be eyeballed.
print(label_map)


In [None]:
# Empty placeholders for the per-split arrays; the get_Lables(...) calls below rebind all of them.
labels_train, filenames_train, masks_train = [], [], []
labels_test, filenames_test, masks_test = [], [], []
labels_val, filenames_val, masks_val = [], [], []
# Frames per video clip: get_Lables pads each video's frame list to this length.
seq_length = 80
def _frame_dir_from_split_line(line):
    """Map one split-file line to its frame directory under ``frameDir_path``.

    This is the original path rewrite, unchanged:
    - the line's alphabetic characters are lower-cased and trimmed by 10 leading
      and 6 trailing characters;
    - that string replaces the second '/'-separated component;
    - "_frames" is appended to the last component (with its newline stripped).
    """
    # NOTE(review): the fixed 10/6 trims depend on the exact split-file naming scheme.
    label = "".join(ch for ch in line if ch.isalpha()).lower()[10:][:-6]
    parts = line.split("/")
    parts[1] = label  # same as the original pop(1) + insert(1, label)
    parts[-1] = parts[-1].strip("\r\n") + "_frames"
    return frameDir_path + "/".join(parts)


def get_Lables(mode):
    """Collect per-frame paths, labels and masks for one dataset split.

    Reads ``./train_test_split/temp_include_{mode}.txt``, where each non-blank
    line names one video. The video's frames are listed in natural-sort order
    and cut or padded so that every video contributes exactly ``seq_length``
    entries. This is the fixed-size chunking the dataset code slices on.

    Returns:
        ``(filenames, labels, masks)`` as parallel 1-D numpy arrays.
        Real frames get ``label_map[<component two levels above the file>]``
        and mask True. Padding entries get "" as the path, the placeholder
        label 4, and mask False.

    NOTE(review): placeholder label 4 may collide with a real class index;
    confirm padding labels are never used as training targets.
    """
    split_file = f'./train_test_split/temp_include_{mode}.txt'
    filenames = []
    with open(split_file, 'r', encoding='utf-8') as split_lines:
        for line in split_lines:
            if not line.strip():
                continue  # a blank line (e.g. trailing newline) names no video
            frame_dir = _frame_dir_from_split_line(line)
            # Cut long videos to their first seq_length frames. Otherwise they overflow
            # their fixed-size slot and misalign every later clip in the flat array.
            frame_names = natsort.natsorted(os.listdir(frame_dir))[:seq_length]
            filenames.extend(os.path.join(frame_dir, name) for name in frame_names)
            filenames.extend([""] * (seq_length - len(frame_names)))

    labels, masks = [], []
    for frame in filenames:
        if frame != "":
            labels.append(label_map[frame.split("/")[-3]])
            masks.append(True)   # real frame
        else:
            labels.append(4)
            masks.append(False)  # padding frame
    return np.array(filenames), np.array(labels), np.array(masks)

# Build the flat per-frame path / label / mask arrays for each split.
# This reads the split files and lists every frame directory, so it touches the filesystem.
filenames_train,labels_train,masks_train = get_Lables('train')
filenames_val,labels_val,masks_val = get_Lables('val')
filenames_test,labels_test,masks_test = get_Lables('test')

In [None]:
# Sanity-check the sizes of the flat per-frame arrays of every split.
for _shape_name, _shape_array in (
    ("filenames_train", filenames_train),
    ("labels_train", labels_train),
    ("masks_train", masks_train),
    ("filenames_test", filenames_test),
    ("labels_test", labels_test),
    ("masks_test", masks_test),
    ("filenames_val", filenames_val),
    ("labels_val", labels_val),
    ("masks_val", masks_val),
):
    print(f"{_shape_name}.shape", _shape_array.shape)

In [None]:
class CustomDataset(Dataset):
    """Fixed-length video clips built from a flat list of frame paths.

    ``image_paths`` and ``labels`` are flat sequences. Every ``seq_length``
    consecutive entries are expected to form one clip, and an empty string in
    ``image_paths`` marks a padding frame (rendered as a black frame). The
    clip's label is the label of its first entry.

    Items are dicts with:
      * ``'data'``: float tensor of shape (3, seq_length, image_size, image_size),
        i.e. (channel, time, height, width);
      * ``'label'``: one-hot ``torch.long`` vector of length ``num_classes``.

    Random augmentation (flip, rotation, colour jitter) is applied only outside
    evaluation modes. The same random parameters are used for every frame of a
    clip, so motion stays coherent. Every mode is normalised identically, so
    train and test inputs share one value range.
    """

    # Modes (case-insensitive) that must see deterministic, un-augmented input.
    _EVAL_MODES = frozenset({"test", "val", "valid", "validation", "eval"})

    def __init__(self, image_paths, labels, mode, seq_length=80, num_classes=59, image_size=64):
        self.image_paths = image_paths
        self.labels = labels
        self.mode = mode
        self.seq_length = seq_length
        self.num_classes = num_classes
        self.image_size = image_size

    def __len__(self):
        return len(self.image_paths) // self.seq_length

    def _load_frame(self, filename):
        """Return one frame as a float tensor (3, image_size, image_size) in [0, 1].

        An empty filename yields an all-zero (black) padding frame.
        """
        if filename == "":
            return torch.zeros(3, self.image_size, self.image_size)
        # convert("RGB") makes grayscale/RGBA/palette files stackable with 3-channel frames;
        # the context manager closes the underlying file handle.
        with Image.open(filename) as img:
            image = transforms.ToTensor()(img.convert("RGB"))
        image = image.to(torch.float32)
        return transforms.Resize((self.image_size, self.image_size))(image)

    def _transform_clip(self, clip, mode):
        """Augment (training modes only) and normalise a (T, 3, H, W) clip.

        torchvision transforms treat the leading T dimension as a batch. One
        random draw is therefore shared by every frame in the clip.
        """
        if str(mode).lower() not in self._EVAL_MODES:
            augment = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(90),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            ])
            clip = augment(clip)
        # Equivalent to transforms.Normalize(mean=[0.5], std=[0.5]): maps [0, 1] to [-1, 1].
        return (clip - 0.5) / 0.5

    def preprocess_function(self, filename, mode):
        """Load and transform a single frame; returns a (3, image_size, image_size) tensor."""
        return self._transform_clip(self._load_frame(filename).unsqueeze(0), mode)[0]

    def __getitem__(self, idx):
        # Clip idx covers the seq_length consecutive entries starting at idx * seq_length.
        start = idx * self.seq_length
        frame_paths = self.image_paths[start:start + self.seq_length]
        label_value = self.labels[start]  # every frame of a clip carries the clip's label

        clip = torch.stack([self._load_frame(path) for path in frame_paths], dim=0)  # (T, 3, H, W)
        clip = self._transform_clip(clip, self.mode)
        data = clip.permute(1, 0, 2, 3)  # (3, T, H, W)

        label = torch.nn.functional.one_hot(
            torch.tensor(int(label_value)), num_classes=self.num_classes
        ).long()
        return {'data': data, 'label': label}
    

In [None]:
# Wrap each split's flat frame/label arrays in a CustomDataset.
# `mode` selects the split-specific frame preprocessing inside CustomDataset.
train_dataset = CustomDataset(image_paths=filenames_train, labels=labels_train, mode="train")
test_dataset = CustomDataset(image_paths=filenames_test, labels=labels_test, mode="Test")
val_dataset = CustomDataset(image_paths=filenames_val, labels=labels_val, mode="val")

In [None]:
# print(train_dataset[72*8]['data'].shape)
# print(len(train_dataset[72*8]['label']))

In [None]:
# Evaluation loaders keep a fixed order. drop_last=True keeps every batch the same size.
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False, drop_last=True)
# Shuffle training clips every epoch. Without it, batches follow the split-file order,
# which is fixed and can group clips of the same class, hurting SGD-style optimisation.
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)
# Fetch one batch eagerly as a quick end-to-end check of the data pipeline.
batch = next(iter(train_dataloader))

In [None]:
# Number of training batches: (#frame entries // 80) clips, grouped 8 per batch with the
# last partial batch dropped. The author recorded 46078 frames -> 72 batches.
# NOTE(review): 46078 // 80 // 8 is 71, so re-check that figure against the current data.
print(len(train_dataloader))
# Optional debug prints for the batch fetched above:
# print(batch['label'])          # one one-hot label row per clip
# print(batch['data'].shape)
# print(len(batch['data']))      # clips in this batch (batch_size = 8); clips total = batches * 8
# print(len(batch['label']))     # one label per clip, so also 8

#### Training

In [None]:
def _classification_metrics(targets, preds):
    """Return accuracy plus micro-averaged precision, recall and F1 for 1-D label arrays.

    For single-label multi-class data the micro-averaged scores all equal
    accuracy. They are still reported so the epoch summary keeps its fields.
    """
    return {
        'accuracy': accuracy_score(targets, preds),
        'precision': precision_score(targets, preds, average='micro'),
        'recall': recall_score(targets, preds, average='micro'),
        'f1': f1_score(targets, preds, average='micro'),
    }


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over *loader*: train when *optimizer* is given, otherwise evaluate.

    Each batch is a dict with 'data' (model input) and 'label' (one-hot rows
    or class indices).

    Returns:
        ``(mean_batch_loss, targets, preds)``, where targets and argmax
        predictions are 1-D CPU numpy arrays.

    Raises:
        ValueError: if the loader yields no batches.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is eval(): BatchNorm/Dropout switch to inference mode
    losses, target_chunks, pred_chunks = [], [], []
    progress_bar = tqdm(loader, ncols=80)
    with torch.set_grad_enabled(training):
        for batch in progress_bar:
            inputs = batch['data'].to(device)
            labels = batch['label'].to(device)
            # CrossEntropyLoss wants integer class indices; collapse one-hot rows with argmax.
            targets = (labels.argmax(dim=1) if labels.dim() > 1 else labels).long()

            # The loss must be computed on the raw logits. argmax is not differentiable,
            # so a loss on argmax'd predictions never sends gradients to the model.
            logits = model(inputs)
            loss = criterion(logits, targets)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            preds = logits.argmax(dim=1)
            losses.append(loss.item())
            # Move to CPU so numpy/sklearn can consume them even when running on CUDA.
            target_chunks.append(targets.cpu())
            pred_chunks.append(preds.cpu())
            batch_acc = (preds == targets).float().mean().item()
            progress_bar.set_postfix({'Loss': loss.item(), 'Accuracy': 100.0 * batch_acc})

    if not losses:
        raise ValueError("data loader produced no batches; check dataset size and drop_last")
    # torch.cat (not stack) tolerates a smaller final batch.
    return (
        sum(losses) / len(losses),
        torch.cat(target_chunks).numpy(),
        torch.cat(pred_chunks).numpy(),
    )


def train(model, train_loader, val_loader, num_epochs=10, checkpoint_dir=None):
    """Train *model* with Adam + cross-entropy, validating and checkpointing every epoch.

    Args:
        model: module returning logits of shape (batch, num_classes).
        train_loader, val_loader: iterables of dicts with 'data' and 'label'
            (one-hot rows or class indices), as produced by CustomDataset.
        num_epochs: number of epochs to run.
        checkpoint_dir: directory for ``slr_convlstm_epochXXX.pth`` state dicts.
            Defaults to the module-level ``checkpoint_path``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the model before building the optimizer, as the torch.optim docs recommend.
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    if checkpoint_dir is None:
        checkpoint_dir = checkpoint_path

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print("Training")
        train_loss, train_targets, train_preds = _run_epoch(
            model, train_loader, criterion, device, optimizer)
        train_metrics = _classification_metrics(train_targets, train_preds)
        print('\nAvg.Training Accuracy: ', train_metrics['accuracy'],
              ' Avg.Training Loss: ', train_loss,
              ' Avg.Training Precision: ', train_metrics['precision'],
              ' Avg.Training Recall: ', train_metrics['recall'],
              " Avg.Training f1_score: ", train_metrics['f1'])

        print("\nValidation")
        val_loss, val_targets, val_preds = _run_epoch(model, val_loader, criterion, device)
        val_metrics = _classification_metrics(val_targets, val_preds)
        print('\nAvg.Val Accuracy: ', val_metrics['accuracy'],
              ' Avg.Val Loss: ', val_loss,
              ' Avg.Val Precision: ', val_metrics['precision'],
              ' Avg.Val Recall: ', val_metrics['recall'],
              " Avg.Val f1_score: ", val_metrics['f1'])

        torch.save(model.state_dict(),
                   os.path.join(checkpoint_dir, "slr_convlstm_epoch{:03d}.pth".format(epoch + 1)))


# Build a fresh model and run the training loop.
# train() saves a state_dict under checkpoint_path after every epoch.
model = ResCRNN()

train(model, train_dataloader, val_dataloader, num_epochs=10)
