In [4]:
import os
import json
import random
import numpy as np

import torch 
from torch import nn 
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import random
import copy
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from utils.analysis import action_evaluator
import cv2
import json

%matplotlib inline
%config InlineBackend.figure_format='retina'

sns.set(style='whitegrid', palette='muted', font_scale=1.2)

HAPPY_COLORS_PALETTE = ["#01BEFE",
                        "#FFDD00",
                        "#FF7D00",
                        "#FF006D",
                        "#ADFF02",
                        "#8F00FF"]

sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
rcParams['figure.figsize'] = 12, 8

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
model_ident = "bidiretional_lstm_hrnet_nturgb_classifier"
unique_iden = "epoch10_emb1024xy"

main_dir = "D:\\FYP\\HAR-ZSL-XAI"
data_dir = os.path.join(main_dir,"data","sequence_data","midpoint_50f")
epoch_vids = os.path.join(main_dir,"epoch_vids")
models_saves = os.path.join(main_dir,"model_saves")
embeddings_save = os.path.join(main_dir,"embedding_save")
test_vids = os.path.join(main_dir,"test_vids")
class_names = os.listdir(data_dir)
train_ratio = 0.9
val_ratio = 0.1
test_ratio = 1-train_ratio - val_ratio
batch_size = 128

os.makedirs(epoch_vids,exist_ok=True)
os.makedirs(models_saves,exist_ok=True)
os.makedirs(embeddings_save,exist_ok=True)

In [7]:
with open("../data/sequence_data/metadata/midpoint_50f/class_data.json","r") as f0:
    loaded_class_data = json.load(f0)
class_names = loaded_class_data["selected_classes"]

In [8]:
#class_names += os.listdir("../data/nipun_video_dataset/PAMAP2_K10_V1")

In [9]:
config = {
    "n_epochs":100,
    "model_name":"BidirectionalLSTM",
    "model":{
        "seq_len":50,
        "input_size":12*2,
        "hidden_size":1024,
        "linear_filters":[128,256,512,1024],
        "embedding_size":1024,
        "num_classes":len(class_names),
        "num_layers":1,
        "bidirectional":True,
        "batch_size":batch_size,
        "dev":device
    }
}

In [10]:
def classname_id(class_name_list):
    id2classname = {k:v for k, v in zip(list(range(len(class_name_list))),class_name_list)}
    classname2id = {v:k for k, v in id2classname.items()}
    return id2classname, classname2id

In [11]:
id2clsname, clsname2id = classname_id(class_names)

In [12]:
train_file_list = []
val_file_list = []
test_file_list = []

file_list = [os.path.join(data_dir,x) for x in os.listdir(data_dir)]

random.shuffle(file_list)
num_list = len(file_list)

train_range = [0,int(num_list*train_ratio)]
val_range = [int(num_list*train_ratio),int(num_list*(train_ratio+val_ratio))]
test_range = [int(num_list*(train_ratio+val_ratio)),num_list-1]

train_file_list += file_list[train_range[0]:train_range[1]]
val_file_list += file_list[val_range[0]:val_range[1]]
test_file_list += file_list[test_range[0]:test_range[1]]

In [13]:
len(train_file_list),len(val_file_list),len(test_file_list)

(40320, 4480, 0)

In [14]:
train_file_list = train_file_list[:(len(train_file_list)//batch_size)*batch_size]
val_file_list = val_file_list[:(len(val_file_list)//batch_size)*batch_size]
test_file_list = test_file_list[:(len(test_file_list)//batch_size)*batch_size]

In [15]:
len(train_file_list),len(val_file_list),len(test_file_list)

(40320, 4480, 0)

In [16]:
class SkeletonDataset(Dataset):
    def __init__(self, file_list,class2id,transform=None,
                 target_transform=None,active_locations=[5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],file_name=False, is_2d=False):
        self.file_list = file_list
        self.transform = transform
        self.class2id = class2id
        self.target_transform = target_transform
        self.active_locations = active_locations
        self.file_name = file_name
        self.is_2d = is_2d

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        a_file = np.load(self.file_list[idx])
        action_type = self.file_list[idx].strip().split(os.path.sep)[-1].split("_cls_")[0]
        coords, vid_size = a_file["coords"],a_file["video_size"]
        coords = coords[:,self.active_locations,:]

        if self.is_2d:
            coords = coords[...,0:2]

        shape = coords.shape

        coords = torch.from_numpy(coords).float()

        coords = torch.reshape(coords, (shape[0], shape[1]*shape[2]))
        label = torch.clone(coords)

        if self.transform:
            coords = self.transform(coords)
        if self.target_transform:
            label = self.target_transform(coords)

        if self.file_name:
            return coords, label, self.class2id[action_type],a_file["video_size"],self.file_list[idx]
        return coords, label, self.class2id[action_type],a_file["video_size"]

In [17]:
train_data = SkeletonDataset(train_file_list,clsname2id,is_2d=True)
val_data = SkeletonDataset(val_file_list,clsname2id,is_2d=True)
test_data = SkeletonDataset(test_file_list,clsname2id,is_2d=True)

In [18]:
train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=False)
val_dl = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_dl = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [19]:
for x in train_dl:
    print(x[0].shape,x[1].shape,x[2])
    break

torch.Size([128, 50, 24]) torch.Size([128, 50, 24]) tensor([27, 40, 56, 36,  7, 28, 45, 51,  7, 27, 31, 31, 21, 20, 43, 63, 38, 23,
         7, 50, 11, 44, 52, 63,  3, 19, 37, 37, 54, 30, 46, 44, 55, 51,  6, 24,
         3, 40, 19, 33, 14, 54, 12, 35, 27, 35, 23, 10, 49, 62, 36, 42, 56, 17,
        44, 24, 56, 27, 39,  4, 25,  3,  9,  0, 22, 50, 32, 44, 62, 12,  9, 61,
        53, 22, 39, 48,  7,  9, 56, 46,  6, 39, 39, 47, 27, 32, 59, 18, 50, 23,
        22, 42, 36, 36, 11, 33, 23, 40, 10, 10, 10, 10, 28, 49, 43, 22, 55, 22,
        59, 36,  3, 33, 60,  3, 14, 26,  4, 15, 63, 31, 11, 46, 51, 54, 24, 23,
        47, 61])


In [20]:
class BiLSTMEncoder(nn.Module):
    def __init__(self,seq_len, input_size,num_classes, hidden_size,linear_filters,embedding_size:int, num_layers = 1,bidirectional=True,dev=device):
        super(BiLSTMEncoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dev=dev
        self.num_layers = num_layers
        self.linear_filters = linear_filters
        self.embedding_size = embedding_size
        self.bidirectional = bidirectional
        self.seq_len = seq_len
        self.num_classes = num_classes

        # define LSTM layer
        self.layers = []

        # add linear layers 
        for __id,layer_out in enumerate(self.linear_filters):
            if __id == 0:
                self.layers.append(nn.Linear(self.input_size, layer_out))
            else:
                self.layers.append(nn.Linear(self.linear_filters[__id-1], layer_out))

        # add lstm layer
        self.lstm = nn.LSTM(input_size = layer_out, hidden_size = self.hidden_size,
                            num_layers = self.num_layers, bidirectional=self.bidirectional,
                            batch_first=True)
        
        self.net = nn.Sequential(*self.layers)

        self.classification_header = nn.Linear(self.embedding_size,self.num_classes)

        #add embedding out
        if bidirectional:
            self.bn = nn.BatchNorm1d(self.hidden_size*4)
            self.out_linear = nn.Linear(self.hidden_size*4, self.embedding_size)
        else:
            self.bn = nn.BatchNorm1d(self.hidden_size*2)
            self.out_linear = nn.Linear(self.hidden_size*2, self.embedding_size)

        
    def forward(self, x_input):
        """
        : param x_input:               input of shape (seq_len, # in batch, input_size)
        : return lstm_out, hidden:     lstm_out gives all the hidden states in the sequence; hidden gives the hidden state and cell state for the last element in the sequence
        """
        
        x = self.net(x_input)
        lstm_out, self.hidden = self.lstm(x)
        hidden_transformed = torch.cat(self.hidden,0)
        hidden_transformed = torch.transpose(hidden_transformed,0,1)
        hidden_transformed = torch.flatten(hidden_transformed,start_dim=1)

        #hidden_transformed = self.bn(hidden_transformed)
        hidden_transformed = self.out_linear(hidden_transformed)

        label = self.classification_header(hidden_transformed)
        
        return label, hidden_transformed

    
class BiLSTMDecoder(nn.Module):
    def __init__(self,seq_len, input_size, hidden_size, linear_filters,embedding_size:int, num_layers = 1,bidirectional=True,dev=device):
        super(BiLSTMDecoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dev = dev
        self.num_layers = num_layers
        self.linear_filters = linear_filters[::-1]
        self.embedding_size = embedding_size
        self.bidirectional = bidirectional
        self.seq_len = seq_len

        if bidirectional:
            self.input_linear = nn.Linear(self.embedding_size,4*self.hidden_size)
        else:
            self.input_linear = nn.Linear(self.embedding_size,2*self.hidden_size)

        # define LSTM layer
        self.layers = []
        # add lstm
        self.lstm = nn.LSTM(input_size = self.linear_filters[0], hidden_size = self.hidden_size,
                            num_layers = self.num_layers, bidirectional=True,
                            batch_first=bidirectional)

                        
        # add linear layers 
        if bidirectional:
            self.layers.append(nn.Linear(2*hidden_size,self.linear_filters[0]))
        else:
            self.layers.append(nn.Linear(hidden_size,self.linear_filters[0]))

        for __id,layer_in in enumerate(self.linear_filters):
            if __id == len(linear_filters)-1:
                self.layers.append(nn.Linear(layer_in,self.input_size))
            else:
                self.layers.append(nn.Linear(layer_in,self.linear_filters[__id+1]))

        self.net = nn.Sequential(*self.layers)

        
        

    def forward(self,encoder_hidden):
        """
        : param x_input:               input of shape (seq_len, # in batch, input_size)
        : return lstm_out, hidden:     lstm_out gives all the hidden states in the sequence; hidden gives the hidden state and cell state for the last element in the sequence
        """
        
        
        hidden_shape = encoder_hidden.shape
        encoder_hidden = self.input_linear(encoder_hidden)
        
        if self.bidirectional:
            hidden = encoder_hidden.view((-1,4,self.hidden_size))
            hidden = torch.transpose(hidden,1,0)
            h1,h2,c1,c2 = torch.unbind(hidden,0)
            h,c = torch.stack((h1,h2)),torch.stack((c1,c2))
            bs = h.size()[1]
        else:
            hidden = encoder_hidden.view((-1,2,self.hidden_size))
            hidden = torch.transpose(hidden,1,0)
            h,c = torch.unbind(hidden,0)
            bs = h.size()[1]
        
        dummy_input = torch.rand((bs,self.seq_len,self.hidden_size), requires_grad=True).to(self.dev)
        
        lstm_out, self.hidden = self.lstm(dummy_input,(h,c))
        x = self.net(lstm_out)
        
        return x

class BiLSTMEncDecModel(nn.Module):
    def __init__(self,seq_len, input_size, hidden_size,num_classes, linear_filters=[128,256,512],embedding_size:int=256, num_layers = 1,bidirectional=True,dev=device):
        super(BiLSTMEncDecModel, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dev = dev
        self.num_layers = num_layers
        self.linear_filters = linear_filters[::-1]
        self.embedding_size = embedding_size
        self.bidirectional = bidirectional
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.num_classes= num_classes
        
        self.encoder = BiLSTMEncoder(seq_len, input_size, num_classes,hidden_size, linear_filters,embedding_size, num_layers = 1,bidirectional=True, dev=self.dev)
        self.decoder = BiLSTMDecoder(seq_len, input_size, hidden_size, linear_filters,embedding_size, num_layers = 1,bidirectional=True, dev=self.dev)
        
    def forward(self,x):
        label,embedding = self.encoder(x)
        decoder_out = self.decoder(embedding)
        
        return decoder_out, embedding, label
        

In [21]:
encoder = BiLSTMEncoder(
    seq_len=config["model"]["seq_len"],
    input_size=config["model"]["input_size"],
    num_classes = config["model"]["num_classes"],
    hidden_size=config["model"]["hidden_size"],
    linear_filters=config["model"]["linear_filters"],
    embedding_size=config["model"]["embedding_size"],
    num_layers = config["model"]["num_layers"],
    bidirectional=config["model"]["bidirectional"],
    dev=config["model"]["dev"]).to(device)

decoder = BiLSTMDecoder(
    seq_len=config["model"]["seq_len"],
    input_size=config["model"]["input_size"],
    hidden_size=config["model"]["hidden_size"],
    linear_filters=config["model"]["linear_filters"],
    embedding_size=config["model"]["embedding_size"],
    num_layers = config["model"]["num_layers"],
    bidirectional=config["model"]["bidirectional"],
    dev=config["model"]["dev"]).to(device)

bilstm_model = BiLSTMEncDecModel(
    seq_len=config["model"]["seq_len"],
    input_size=config["model"]["input_size"],
    num_classes = config["model"]["num_classes"],
    hidden_size=config["model"]["hidden_size"],
    linear_filters=config["model"]["linear_filters"],
    embedding_size=config["model"]["embedding_size"],
    num_layers = config["model"]["num_layers"],
    bidirectional=config["model"]["bidirectional"],
    dev=config["model"]["dev"])

In [19]:
bilstm_model.to(device)

BiLSTMEncDecModel(
  (encoder): BiLSTMEncoder(
    (lstm): LSTM(1024, 1024, batch_first=True, bidirectional=True)
    (net): Sequential(
      (0): Linear(in_features=24, out_features=128, bias=True)
      (1): Linear(in_features=128, out_features=256, bias=True)
      (2): Linear(in_features=256, out_features=512, bias=True)
      (3): Linear(in_features=512, out_features=1024, bias=True)
    )
    (classification_header): Linear(in_features=1024, out_features=64, bias=True)
    (bn): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (out_linear): Linear(in_features=4096, out_features=1024, bias=True)
  )
  (decoder): BiLSTMDecoder(
    (input_linear): Linear(in_features=1024, out_features=4096, bias=True)
    (lstm): LSTM(1024, 1024, batch_first=True, bidirectional=True)
    (net): Sequential(
      (0): Linear(in_features=2048, out_features=1024, bias=True)
      (1): Linear(in_features=1024, out_features=512, bias=True)
      (2): Linear(in_featu

In [20]:
label, embedding = encoder(torch.randn((32,50,24)).to(device))

In [21]:
embedding.shape

torch.Size([32, 1024])

In [22]:
label.shape

torch.Size([32, 64])

In [23]:
decoder_out = decoder(embedding)

In [24]:
model_out,embedding,label = bilstm_model(torch.randn((32,50,24)).to(device))

In [25]:
model_out.shape

torch.Size([32, 50, 24])

In [26]:
model_out,embedding,label = bilstm_model(torch.randn((16,50,24)).to(device))

In [27]:
model_out.shape

torch.Size([16, 50, 24])

In [28]:
model_out,embedding,label = bilstm_model(torch.randn((5,50,24)).to(device))

In [29]:
model_out.shape

torch.Size([5, 50, 24])

In [39]:
id2clsname

{0: '100',
 1: '101',
 2: '10',
 3: '11',
 4: '12',
 5: '13',
 6: '14',
 7: '15',
 8: '16',
 9: '17',
 10: '18',
 11: '19',
 12: '1',
 13: '20',
 14: '21',
 15: '22',
 16: '23',
 17: '24',
 18: '25',
 19: '26',
 20: '27',
 21: '28',
 22: '29',
 23: '2',
 24: '30',
 25: '32',
 26: '33',
 27: '34',
 28: '35',
 29: '36',
 30: '37',
 31: '3',
 32: '41',
 33: '42',
 34: '43',
 35: '44',
 36: '45',
 37: '46',
 38: '47',
 39: '48',
 40: '49',
 41: '4',
 42: '5',
 43: '64',
 44: '66',
 45: '6',
 46: '73',
 47: '74',
 48: '75',
 49: '76',
 50: '7',
 51: '82',
 52: '84',
 53: '85',
 54: '86',
 55: '87',
 56: '88',
 57: '89',
 58: '8',
 59: '90',
 60: '91',
 61: '92',
 62: '98',
 63: '9'}

In [30]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

mapping_l = [
        [15, 13], [13, 11], [11, 5],
        [12, 14], [14, 16], [12, 6],
        [3, 1],[1, 2],[1, 0],[0, 2],[2,4],
        [9, 7], [7,5], [5, 6],
        [6, 8], [8, 10],
        ]
#mapping_l = []

from dataset.SkeletonData.visualize import *
from utils.train_utils import *

In [31]:
label_map = [(k,v) for k,v in id2clsname.items()]
labelToId = {x[0]: i for i, x in enumerate(label_map)}

In [32]:
def combined_loss(pred_sequence,pred_label,true_sequence,true_label,loss_module,alpha_target=1,alpha_recon=1):
    recon_loss = alpha_recon*loss_module["reconstruction_loss"](pred_sequence,true_sequence)
    tar_loss = alpha_target*loss_module["target_loss"](pred_label,true_label)
    loss =  recon_loss + tar_loss

    #print(alpha_recon*loss_module["reconstruction_loss"](pred_sequence,true_sequence))
    #print(alpha_target*loss_module["target_loss"](pred_label,true_label))

    return loss, {
        "reconstruction_loss":recon_loss.item(),
        "target_loss":tar_loss.item()
    }



In [43]:
def train_model(__model, train_dataset, val_dataset, n_epochs):
    optimizer = torch.optim.Adam(__model.parameters(), lr=1e-3, weight_decay=0.01)
    std_loss = {
        "reconstruction_loss" :nn.L1Loss(reduction='mean').to(device),
        "target_loss" :nn.CrossEntropyLoss(reduction="mean").to(device)
    }
    train_history = dict(train=[], val=[],train_detail=[],val_detail=[])
    
    best_model_wts = copy.deepcopy(__model.state_dict())
    best_loss = 10000.0
  
    for epoch in range(1, n_epochs + 1):
        __model = __model.train()

        train_pred_class = []
        train_true_class = []
        val_pred_class = []
        val_true_class = []

        train_losses = []
        train_loss_detail = []
        for input_sequence,target_sequence,target_action,target_vid_size in tqdm(train_dataset):
            optimizer.zero_grad()
            
            input_sequence = input_sequence.to(device)
            target_sequence = target_sequence.to(device)
            target_action = target_action.to(device)
            predicted_sequence,_,predicted_label  = __model(input_sequence)
            
            loss,loss_detail = combined_loss(predicted_sequence,predicted_label, target_sequence, target_action,std_loss)
            #loss += 0.5*contrastive_loss(embed,labels=sample_label.view(-1))
            #print(contrastive_loss(embed,labels=sample_label.view(-1)))

            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())
            train_loss_detail.append(loss_detail)

            train_true_class.append(target_action.detach().cpu().numpy())
            train_pred_class.append(predicted_label.detach().cpu().numpy())

        train_pred_class = np.concatenate(train_pred_class)
        train_true_class = np.concatenate(train_true_class)
        train_metrics = action_evaluator(train_pred_class,train_true_class,class_names=list(clsname2id.keys()),print_report=True)

        val_losses = []
        val_loss_detail = []
        __model = __model.eval()
        with torch.no_grad():
            for input_sequence,target_sequence,target_action,target_vid_size in val_dataset:

                input_sequence = input_sequence.to(device)
                target_sequence = target_sequence.to(device)
                target_action = target_action.to(device)
                predicted_sequence,_,predicted_label  = __model(input_sequence)

                loss,loss_detail = combined_loss(predicted_sequence,predicted_label, target_sequence, target_action,std_loss)
                #loss += 0.5*contrastive_loss(embed,labels=sample_label.view(-1))
                val_losses.append(loss.item())
                val_loss_detail.append(loss_detail)

                val_true_class.append(target_action.detach().cpu().numpy())
                val_pred_class.extend(predicted_label.detach().cpu().numpy())

        train_loss = np.mean(train_losses)
        val_loss = np.mean(val_losses)

        val_pred_class = np.concatenate(val_pred_class)
        val_true_class = np.concatenate(val_true_class)

        val_metrics = action_evaluator(val_pred_class,val_true_class,class_names=list(clsname2id.keys()),print_report=True)

        train_history['train'].append(train_loss)
        train_history['val'].append(val_loss)

        summary_train_reconstruction_loss = np.mean([x["reconstruction_loss"] for x in train_loss_detail])
        summary_train_target_loss = np.mean([x["target_loss"] for x in train_loss_detail])
        train_history["train_detail"].append({
            "reconstruction_loss":summary_train_reconstruction_loss,
             "target_loss":summary_train_target_loss
        })

        summary_val_reconstruction_loss = np.mean([x["reconstruction_loss"] for x in val_loss_detail])
        summary_val_target_loss = np.mean([x["target_loss"] for x in val_loss_detail])
        train_history["train_detail"].append({
            "reconstruction_loss":summary_val_reconstruction_loss,
             "target_loss":summary_val_target_loss
        })



        if epoch%10 == 0:
            save_model(__model, f"temp_{model_ident}", f"{epoch}__{unique_iden}", models_saves, config)


        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(__model.state_dict())

        print(f'Epoch {epoch}: train loss {train_loss} val loss {val_loss} \n train_reconstruction_loss:- {summary_train_reconstruction_loss} train_target_loss {summary_train_target_loss} val_reconstruction_loss:- {summary_val_reconstruction_loss} val_target_loss {summary_val_target_loss}')

    __model.load_state_dict(best_model_wts)
    save_model(__model, model_ident, unique_iden, models_saves, config)
    return __model.eval(), train_history

In [44]:
model, history = train_model(
  bilstm_model, 
  train_dl, 
  val_dl, 
  n_epochs=config["n_epochs"]
)

100%|██████████| 315/315 [28:35<00:00,  5.45s/it]


TypeError: unhashable type: 'numpy.ndarray'

In [None]:
save_history(history,model_ident,unique_iden,models_saves,config)

In [None]:
ax = plt.figure().gca()

ax.plot(history['train'])
ax.plot(history['val'])
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['train', 'test'])
plt.title('Loss over training epochs')
plt.show();

In [None]:
embedding_list = {}
output_list = {}
with torch.no_grad():
    for in_sequence,tar_sequence,action,vid_size in tqdm(test_dl):
        in_sequence = in_sequence.to(device)
        tar_sequence = tar_sequence.to(device)
        seq_pred,embedding  = model(in_sequence)

        for seq,emb,action_t in zip(seq_pred.unbind(0),embedding.unbind(0),action.unbind(0)):
                try:
                    if len(embedding_list[int(action_t)])<=50:
                        embedding_list[int(action_t)].append(emb)
                        output_list[int(action_t)].append(seq)
                except KeyError:
                    embedding_list[int(action_t)] = [emb]
                    output_list[int(action_t)] = [emb]

        
            
        
        #embedding_list[]

In [None]:
import random

def draw_heatmaps(arr_list,nrows=2,ncols=2):
    ran_list = random.sample(arr_list,ncols*nrows)
    fig, ax = plt.subplots(nrows=nrows,ncols=ncols, sharex=True)
    for i in range(nrows):
        for j in range(ncols):
            #print(i*ncols+j,len(ran_list))
            ax[i,j].imshow(ran_list[i*ncols+j].detach().cpu().numpy()[np.newaxis,:], cmap="plasma", aspect="auto")

    plt.tight_layout()
    plt.show()

In [None]:
draw_heatmaps(embedding_list[11])

In [None]:
draw_heatmaps(embedding_list[22])

In [None]:
draw_heatmaps(embedding_list[16])

In [None]:
def gen_video_from_embeddings(embedding,model,save_file):
    seq_out = model.decoder(embedding.repeat(batch_size,1,1))
    gen_video(seq_out[0].detach().numpy(), save_file, 400, 400,mapping_list=mapping_l)

In [None]:
with torch.no_grad():
    for __id,(in_seq,tar_seq,action,vid_size) in tqdm(enumerate(test_dl)):
        in_seq = in_seq.to(device)
        tar_seq = tar_seq.to(device)
        seq_pred,embedding  = model(in_seq)

        for __id,(input_vid,output_vid,action) in enumerate(zip(in_seq.unbind(0),seq_pred.unbind(0),action.unbind(0))):
            os.makedirs(f"{test_vids}/{int(action)}",exist_ok=True)

            

        
            
        
        #embedding_list[]

In [None]:
gen_video_from_embeddings(embedding_list[1][10],model,"embed_video.mp4")

In [None]:
test_emb = 0.5*embedding_list[19][0]+0.5*embedding_list[16][0]
gen_video_from_embeddings(test_emb,model,"test_embed_video.mp4")