In [None]:
from __future__ import print_function
from __future__ import division

import time
import copy
import os
import urllib
import gc
import datetime
import warnings
import random

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.optim as optim
import torch.onnx
import torchvision
from torchvision import models, transforms
import torch.nn.functional as F

import cv2


from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

import matplotlib.pyplot as plt
from PIL import Image

In [None]:
class Data(Dataset):
    """Image + tabular-metadata dataset backed by a DataFrame.

    Each row of ``df`` describes one sample:
    - ``image_name`` is the file stem of a ``.jpg`` under ``path``.
    - The columns listed in ``context`` are returned as float32 metadata.
    - When ``train`` is True, the ``target`` column is also returned.

    Args:
        df: one row per sample.
        path: directory containing the ``<image_name>.jpg`` files.
        train: if True, ``__getitem__`` also returns the row's ``target`` value.
        preprocesar: optional callable applied to the decoded image
            (e.g. a torchvision transform).
        context: column name, or list of column names, returned as a float32
            metadata vector. ``None`` yields an empty vector.
    """

    def __init__(self, df: pd.DataFrame, path: str, train: bool = True, preprocesar = None, context = None):
        self.preprocesar = preprocesar
        self.df = df
        self.path = path
        self.train = train
        self.context = context

    def __getitem__(self, index):
        """Return ``((image, meta), target)`` in train mode, otherwise ``(image, meta)``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        row = self.df.iloc[index]  # look the row up once instead of once per field
        im_path = os.path.join(self.path, str(row['image_name']) + '.jpg')

        # cv2.imread does not raise on a missing or unreadable file; it returns None.
        # The decoded array is in OpenCV's BGR channel order.
        x = cv2.imread(im_path)
        if x is None:
            raise FileNotFoundError("Could not read image: {}".format(im_path))

        if self.context is None:
            meta = np.zeros(0, dtype=np.float32)
        else:
            # atleast_1d keeps a single-column context as a 1-element vector.
            meta = np.atleast_1d(np.asarray(row[self.context], dtype=np.float32))

        if self.preprocesar is not None:
            x = self.preprocesar(x)

        if self.train:
            return (x, meta), row['target']
        return (x, meta)

    def __len__(self):
        """Number of samples (rows of ``df``)."""
        return len(self.df)

In [None]:
# Split one sample into its four image streams plus the label.
# The streams are meant for Inception v3, i.e. 299x299 inputs.
# NOTE(review): `data` and `keypoints` are not defined in any visible cell;
# TODO: define them upstream, otherwise this cell raises NameError on a fresh kernel.
(body, face, handl, handr, label) = data

# Keypoints loaded from a CSV. According to the original note, each cell is a
# normalized keypoint vector (4 features) for all/face/left hand/right hand.
(body_kp, face_kp, handl_kp, handr_kp) = keypoints

# TODO: pose/keypoint generation is not part of this network yet.

In [None]:
class Neural_Network(nn.Module):
    """Video classifier fusing four CNN streams with per-frame keypoints.

    Per frame:
    - Four independent CNN branches (body, face, left hand, right hand) each
      produce a 512-d feature.
    - Each feature is concatenated with that stream's keypoint vector and
      reduced to 256-d.
    - The four 256-d vectors are concatenated (1024-d) and encoded into a 128-d
      frame embedding.

    The sequence of frame embeddings then goes through a Transformer encoder.
    The result is averaged over time and mapped to ``N_vocabulary`` class
    probabilities.

    Args:
        pretrained: backbone CNN (e.g. torchvision Inception3 or ResNet). It is
            deep-copied once per stream, so the branches do not share weights and
            the caller's model is not modified. If the backbone has a linear
            ``fc`` head, that head is replaced by a 512-unit one. Otherwise the
            backbone must already output 512 features.
        N_vocabulary: number of output classes.
        n_body_kp, n_face_kp, n_handl_kp, n_handr_kp: length of each stream's
            per-frame keypoint vector. When omitted, the module-level global with
            the same name is used, which is what the original cell relied on.
        n_frames: maximum number of frames per clip (size of the positional embedding).
        n_heads: attention heads in the temporal Transformer; must divide 128.
        n_transformer_layers: number of Transformer encoder layers.
    """

    FEATURE_DIM = 512  # width of each CNN branch's replacement ``fc`` head
    STREAM_DIM = 256   # width of each stream after fusing in its keypoints
    FRAME_DIM = 128    # width of the per-frame embedding fed to the Transformer

    def __init__(self, pretrained, N_vocabulary: int, n_body_kp: int = None, n_face_kp: int = None,
                 n_handl_kp: int = None, n_handr_kp: int = None, n_frames: int = 20,
                 n_heads: int = 8, n_transformer_layers: int = 1):
        super(Neural_Network, self).__init__()

        # Give each stream an independent copy. Assigning the same module four times
        # would make every branch share (and overwrite) one set of weights and one head.
        self.body = self._adapt_backbone(copy.deepcopy(pretrained), self.FEATURE_DIM)
        self.face = self._adapt_backbone(copy.deepcopy(pretrained), self.FEATURE_DIM)
        self.handl = self._adapt_backbone(copy.deepcopy(pretrained), self.FEATURE_DIM)
        self.handr = self._adapt_backbone(copy.deepcopy(pretrained), self.FEATURE_DIM)

        n_body_kp = self._resolve_kp_dim('n_body_kp', n_body_kp)
        n_face_kp = self._resolve_kp_dim('n_face_kp', n_face_kp)
        n_handl_kp = self._resolve_kp_dim('n_handl_kp', n_handl_kp)
        n_handr_kp = self._resolve_kp_dim('n_handr_kp', n_handr_kp)

        self.body_plus_kp = self._mlp(self.FEATURE_DIM + n_body_kp, self.STREAM_DIM)
        self.face_plus_kp = self._mlp(self.FEATURE_DIM + n_face_kp, self.STREAM_DIM)
        self.handl_plus_kp = self._mlp(self.FEATURE_DIM + n_handl_kp, self.STREAM_DIM)
        self.handr_plus_kp = self._mlp(self.FEATURE_DIM + n_handr_kp, self.STREAM_DIM)

        self.encoder = self._mlp(4 * self.STREAM_DIM, self.FRAME_DIM)

        self.n_frames = n_frames
        # Learned positional embedding, laid out as (frames, 1, FRAME_DIM) to match
        # the Transformer's (seq, batch, feature) input. Without it, self-attention
        # would ignore frame order.
        self.positional = nn.Parameter(torch.zeros(n_frames, 1, self.FRAME_DIM))
        layer = nn.TransformerEncoderLayer(d_model=self.FRAME_DIM, nhead=n_heads)
        self.temporal = nn.TransformerEncoder(layer, num_layers=n_transformer_layers)
        self.output = nn.Linear(self.FRAME_DIM, N_vocabulary)

    @staticmethod
    def _adapt_backbone(cnn, out_features):
        """Replace ``cnn``'s classifier head(s) in place so it emits ``out_features`` values."""
        aux = getattr(cnn, 'AuxLogits', None)
        if aux is not None and isinstance(getattr(aux, 'fc', None), nn.Linear):
            # Inception3 auxiliary head; its output is discarded in ``forward``.
            aux.fc = nn.Linear(in_features=aux.fc.in_features, out_features=1)
        if isinstance(getattr(cnn, 'fc', None), nn.Linear):
            cnn.fc = nn.Linear(in_features=cnn.fc.in_features, out_features=out_features, bias=True)
        return cnn

    @staticmethod
    def _resolve_kp_dim(name, value):
        """Return ``value`` as an int or, if it is None, the module-level global ``name``."""
        if value is None:
            value = globals().get(name)
        if value is None:
            raise ValueError("{} must be passed to Neural_Network (or defined as a global)".format(name))
        return int(value)

    @staticmethod
    def _mlp(in_features, out_features):
        """Build Linear(in,512)-BN-ReLU-Dropout(.25) followed by Linear(512,out)-BN-ReLU-Dropout(.2)."""
        return nn.Sequential(nn.Linear(in_features, 512),
                             nn.BatchNorm1d(512),
                             nn.ReLU(),
                             nn.Dropout(p=0.25),
                             nn.Linear(512, out_features),
                             nn.BatchNorm1d(out_features),
                             nn.ReLU(),
                             nn.Dropout(p=0.2))

    @staticmethod
    def _cnn_features(cnn, frames):
        """Run ``cnn`` on a batch of frames and return only the main output."""
        out = cnn(frames)
        # Inception3 in training mode returns InceptionOutputs(logits, aux_logits),
        # a tuple subclass; only the main logits are used.
        if isinstance(out, tuple):
            out = out[0]
        return out

    def forward(self, inputs):
        """Classify a batch of clips.

        Args:
            inputs: ``((body, face, handl, handr), (body_kp, face_kp, handl_kp, handr_kp))``.
                Each video tensor is ``(batch, frames, C, H, W)`` and each keypoint
                tensor is ``(batch, frames, n_kp)``. All streams share ``batch``
                and ``frames``. Longer videos or real-time streams can be fed in
                chunks of at most ``n_frames`` frames. Padding of short clips is
                not implemented.

        Returns:
            ``(batch, N_vocabulary)`` softmax probabilities.

        Raises:
            ValueError: if a clip has more than ``n_frames`` frames.
        """
        videos, keypoints = inputs[0], inputs[1]
        branches = ((self.body, self.body_plus_kp),
                    (self.face, self.face_plus_kp),
                    (self.handl, self.handl_plus_kp),
                    (self.handr, self.handr_plus_kp))

        batch, n_frames = int(videos[0].shape[0]), int(videos[0].shape[1])
        if n_frames > self.n_frames:
            raise ValueError("clip has {} frames but the model supports at most {}".format(
                n_frames, self.n_frames))

        fused = []
        for (cnn, plus_kp), video, kp in zip(branches, videos, keypoints):
            # Fold frames into the batch dimension so every frame goes through the
            # CNN exactly once.
            frames = video.reshape(batch * n_frames, *video.shape[2:])
            feats = self._cnn_features(cnn, frames)
            # Keypoints may arrive as float64 (e.g. from NumPy) or on another device.
            kp = kp.reshape(batch * n_frames, -1).to(dtype=feats.dtype, device=feats.device)
            fused.append(plus_kp(torch.cat((feats, kp), dim=1)))

        frame_emb = self.encoder(torch.cat(fused, dim=1))                # (batch*frames, 128)
        seq = frame_emb.reshape(batch, n_frames, -1).transpose(0, 1)      # (frames, batch, 128)
        seq = self.temporal(seq + self.positional[:n_frames])
        logits = self.output(seq.mean(dim=0))                             # (batch, N_vocabulary)
        # NOTE: when training with nn.CrossEntropyLoss, pass it the logits instead.
        return F.softmax(logits, dim=-1)