# Training and Testing of LSTM Model

In [None]:
import os
import glob
import wandb
import pickle
import warnings
import torch as T
from tqdm import tqdm
from torch.nn import MSELoss
from itertools import compress
from utilities.util import colate_fn as colate
from model import encoder, generator, smoothing
from utilities.util import TrTstSplit, GetInputOutputSplit

### Initialize the device and select the processed data files

In [None]:
# Silence library warnings so the notebook output stays readable.
warnings.filterwarnings('ignore')

# Use the first CUDA GPU when one is visible; otherwise run on the CPU.
if T.cuda.is_available():
    device = T.device("cuda:0")
else:
    device = T.device("cpu")

# Glob pattern(s) for the processed pickle files (matched under processed_data/).
pkls = ['*_graph.pkl'] #insert the name of the processed pkl files
# Descriptive label for the loss combination.
# NOTE(review): not read by the visible code, and the training cell actually
# combines MSE with a Pearson-correlation loss -- confirm the label.
losses = 'mse+cosine'

### Define the loss functions used for training (Pearson correlation and MSE)

In [None]:
def pccloss(y, y_hat, lens, eps=1e-8): #Compute Pearson correlation coefficient
    """Return the mean Pearson-correlation loss (1 - r) over a padded batch.

    For each sequence ``n`` only the first ``lens[n]`` time steps are used.
    The correlation is computed independently for every feature column
    (reduction over ``dim=0`` of the ``(time, feature)`` slice), converted
    to a loss ``1 - r``, averaged over features, then averaged over the batch.
    Pearson r is symmetric, so the argument order of ``y``/``y_hat`` does not
    change the value.

    Args:
        y: target tensor, indexed as ``y[n, :length, :]``.
        y_hat: prediction tensor with the same layout as ``y``.
        lens: valid length of each sequence (1-D tensor or sequence of ints).
        eps: small constant added inside the square root of the denominator
            so a zero-variance sequence yields r = 0 (loss 1) instead of NaN,
            and gradients stay finite.

    Returns:
        0-d tensor that remains attached to the autograd graph of the inputs.
    """
    per_sequence = []
    for n in range(len(lens)):
        length = int(lens[n])
        target = y[n, :length, :]
        pred = y_hat[n, :length, :]

        # Centre each feature column on its own mean over the valid frames.
        target_c = target - T.mean(target, dim=0)
        pred_c = pred - T.mean(pred, dim=0)

        numerator = T.sum(pred_c * target_c, dim=0)
        # sqrt(a) * sqrt(b) == sqrt(a * b); taking one sqrt with eps inside
        # avoids both 0/0 and the infinite derivative of sqrt at 0.
        denominator = T.sqrt(T.sum(pred_c ** 2, dim=0) * T.sum(target_c ** 2, dim=0) + eps)
        correlation_coefficient = T.clamp(numerator / denominator, -1.0, 1.0)

        # Convert correlation coefficient to correlation loss (1 - r)
        per_sequence.append(T.mean(1.0 - correlation_coefficient))
    # T.stack keeps the autograd graph; T.tensor(list) would copy the values
    # into a fresh leaf tensor and silently cut the gradient.
    return T.mean(T.stack(per_sequence))

def lossmse(y, y_hat, lens):    #Compute mean squared error
    """Return the length-normalised squared error, averaged over the batch.

    For each sequence, squared errors are summed over the valid time steps
    (the first ``lens[n]``) and over all features, then divided by
    ``lens[n]``. Frames past ``lens[n]`` (padding) are excluded, consistent
    with how ``pccloss`` slices ``[:length]``. Note the result is not divided
    by the feature count.

    Args:
        y: tensor laid out as (batch, time, feature) -- the sum over
            ``dim=[1, 2]`` requires three dimensions.
        y_hat: tensor with the same shape as ``y``.
        lens: valid length of each sequence (1-D tensor or sequence of ints).

    Returns:
        0-d tensor: mean over the batch of the per-sequence values.
    """
    lens = T.as_tensor(lens, device=y.device)
    steps = T.arange(y.shape[1], device=y.device)
    # (batch, time, 1) boolean mask of valid frames; broadcasts over features.
    valid = (steps.unsqueeze(0) < lens.unsqueeze(1)).unsqueeze(-1)
    # masked_fill (rather than multiplying by the mask) so non-finite values
    # in padded frames cannot leak in as NaN.
    sq_err = T.square(y - y_hat).masked_fill(~valid, 0.0)
    mse = T.div(T.sum(sq_err, dim=[1, 2]), lens)
    return T.mean(mse)

### Initialize the models to be used 

In [None]:
# Build two pipelines with the same shape -- encoder -> generator ->
# learnable Gaussian smoothing -- and move each to `device`.
# Constructor arguments are passed straight through to the project's
# `model` package; see that package for what each positional argument means.
LSTM_model = T.nn.Sequential(
    encoder.LSTM(512, 256, 2),
    generator.LSTM(256, 2, 3),
    smoothing.LearnableGaussianSmoothing(5),
).to(device)
GCN_model = T.nn.Sequential(
    encoder.GCN(512, 256),
    generator.GCN(256, 3),
    smoothing.LearnableGaussianSmoothing(5),
).to(device)

# Name -> model lookup.
models = dict(LSTM=LSTM_model, graph=GCN_model)

### Training and Testing Loop (leave-one-file-out cross-validation)

In [None]:
import sys
dirs = glob.glob('processed_data/'+pkls[0]) #insert the name of the processed pkl files instead of pkls[0]
input = []   # NOTE(review): shadows the builtin `input`; name kept for compatibility
output = []

# Load every processed file once instead of re-reading all files for each fold.
# NOTE: pickle.load can execute arbitrary code -- only load trusted, locally produced files.
fold_data = []
for path in dirs:
    with open(path, 'rb') as f:
        data = pickle.load(f)
    if 'Session2' in path:   #Remove the entries that are not processed properly
        # Delete from the highest index downwards: deleting in ascending order
        # shifts every later element left, so the wrong entries were removed.
        # NOTE(review): 69 was listed twice in the original index list; the set
        # removes it once. Confirm whether a different index was intended.
        for c in sorted({41, 62, 69, 73, 80, 100, 122, 150, 154, 374}, reverse=True):
            del data[c]
    fold_data.append(data)

model = models['LSTM']
# Snapshot the untrained weights so every fold starts from the same
# initialisation (re-run the model-construction cell before re-running this one).
initial_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

for cv in range(len(dirs)):
    # Leave-one-file-out split: file `cv` is the test set, all others are training data.
    test = fold_data[cv]
    train = []
    for i, file_samples in enumerate(fold_data):
        if i != cv:
            train.extend(file_samples)

    # Reset weights so training on earlier folds (whose training data contains
    # this fold's test file) cannot leak into this fold's evaluation.
    model.load_state_dict(initial_state)

    # Learnable mixing weight between the MSE and Pearson losses; sigmoid keeps
    # it in (0, 1) and starting at 0 gives weight 0.5. It is registered with the
    # optimizer so it is actually trained.
    # NOTE(review): a freely learned convex weight tends to drift toward whichever
    # loss is smaller; fix it to a constant if that is not intended.
    alpha_raw = T.nn.Parameter(T.zeros((), device=device))
    optimizer = T.optim.Adam(list(model.parameters()) + [alpha_raw], lr=0.01)       #Optimizer and Scheduler
    scheduler = T.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1,
                        patience=50, threshold=0.01, threshold_mode='rel',
                        verbose=True)

    dataloader = T.utils.data.DataLoader(train, batch_size=64, num_workers=4, collate_fn=colate)    #Training Dataloader
    dataloaderT = T.utils.data.DataLoader(test, batch_size=1)   #Testing Dataloader
    epoch_loss = 0
    epoch_loss_pcc = 0
    epoch = 0
    epochs = tqdm(total=500)
    while epoch < 500 and optimizer.param_groups[0]['lr'] > 0.000001:       #Training loop with 500 epochs limit and learning rate limit
        model.train()
        running_loss = 0.0
        running_loss_pcc = 0.0
        for X, _, hea, lens, _, _ in dataloader:
            X = X.squeeze().to(device)
            hea = T.squeeze(hea).to(device)
            lens = T.as_tensor(lens, device=device)
            optimizer.zero_grad()

            outputs = model(X)
            loss = lossmse(outputs, hea, lens)
            loss2 = pccloss(hea, outputs, lens)
            # Recompute the weight every step so each backward() gets a fresh graph.
            alpha = T.sigmoid(alpha_raw)
            (alpha * loss + (1.0 - alpha) * loss2).backward()
            optimizer.step()

            # statistics
            running_loss += loss.item()
            running_loss_pcc += loss2.item()
        epoch_loss = running_loss / len(dataloader)
        epoch_loss_pcc = running_loss_pcc / len(dataloader)
        scheduler.step(epoch_loss + epoch_loss_pcc)
        epoch += 1
        epochs.update(1)
        epochs.set_postfix(mse_loss=epoch_loss, pcc_loss=epoch_loss_pcc)        #Print the iteration loss values
    epochs.close()

    model.eval()
    # Inference only: no autograd graph is needed.
    with T.no_grad():
        for audio, _, pose, _, _ in dataloaderT:    #Testing loop
            audio = T.squeeze(audio).unsqueeze(0).to(device)
            pose = pose.to(device)
            outputs = model(audio)
            input.append(pose.cpu().numpy())
            output.append(outputs.cpu().numpy())

dump = {'name':'LSTM','input':input,'output':output}    #Save the testing outputs
os.makedirs('results', exist_ok=True)
with open('results/'+'LSTM_learn.pkl', 'wb') as f:
    pickle.dump(dump, f, protocol=pickle.HIGHEST_PROTOCOL)
