In [1]:
import collections
import math
import numpy as np
import pickle
import random
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using", device)

Using cuda


In [3]:
folder_path = "C:/Users/Kieran/Documents/Master Thesis Data/Datasets/MouseDirectional"

dataframes = []

for filename in os.listdir(folder_path):
    if filename.endswith(".csv"):
        file_path = os.path.join(folder_path, filename)
        df = pd.read_csv(file_path)
        dataframes.append(df)
        
data = pd.concat(dataframes, ignore_index=True)
data = data.drop(columns = ['frame_number'])
data

Unnamed: 0,"('nose', 'x')","('nose', 'y')","('nose', 'likelihood')","('H1R', 'x')","('H1R', 'y')","('H1R', 'likelihood')","('H2R', 'x')","('H2R', 'y')","('H2R', 'likelihood')","('H1L', 'x')",...,"('tail', 'x')","('tail', 'y')","('tail', 'likelihood')","('S2', 'x')","('S2', 'y')","('S2', 'likelihood')","('S1', 'x')","('S1', 'y')","('S1', 'likelihood')",mouse_no
0,-121.281118,6.777963,0.999969,-107.543587,-9.904076,0.999760,-91.049674,-24.262784,0.996968,-114.476969,...,136.551132,15.836794,0.998689,58.663218,3.552714e-15,0.994912,-428.453907,-57.731640,0.995648,11.4
1,-118.466293,5.164439,0.999950,-104.771778,-11.426444,0.999568,-92.801407,-25.102813,0.997916,-111.668127,...,139.580639,16.430253,0.999059,62.608034,3.552714e-15,0.996069,-432.897362,-57.456968,0.993003,11.4
2,-117.116147,3.743272,0.999945,-103.840623,-12.463020,0.999638,-91.015857,-23.573421,0.999153,-108.146772,...,142.442255,12.872289,0.999419,64.691238,-7.105427e-15,0.986211,-437.910362,-49.581608,0.987800,11.4
3,-118.168901,10.876589,0.999765,-103.922190,-6.807094,0.999534,-90.473196,-18.974423,0.999468,-108.043361,...,144.586701,9.035496,0.999285,66.998246,-7.105427e-15,0.981083,-443.445989,-39.517151,0.994152,11.4
4,-129.872247,18.977417,0.999855,-118.215160,4.053176,0.999252,-101.756529,-10.320907,0.998945,-117.468658,...,142.597915,6.642514,0.999156,68.408748,-7.105427e-15,0.983985,-444.126056,-45.945764,0.995609,11.4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1102495,-67.278284,-67.951573,0.999232,-53.510518,-67.561356,0.999924,-45.792688,-60.568728,0.999492,-71.032565,...,112.339967,-39.013942,0.999563,65.652483,-1.421085e-14,0.998793,44.635680,-639.773663,0.998882,88.3
1102496,-63.502359,-71.564452,0.999680,-50.896190,-71.149798,0.999842,-44.362936,-61.712636,0.999562,-70.997425,...,112.415697,-38.871035,0.999529,65.616523,-1.065814e-14,0.998716,45.998709,-639.839548,0.998840,88.3
1102497,-69.702519,-81.962860,0.999622,-55.540755,-78.942856,0.999950,-45.310777,-68.474218,0.999656,-73.736269,...,111.575389,-37.813775,0.999215,65.260951,-7.105427e-15,0.999035,45.910568,-640.098971,0.998980,88.3
1102498,-69.750044,-81.922420,0.999570,-55.586531,-78.910630,0.999939,-45.350483,-68.447927,0.999616,-73.774792,...,111.369144,-37.768235,0.999225,65.197667,7.105427e-15,0.998865,45.539315,-640.125490,0.998603,88.3


In [4]:
filtered_data = data.filter(regex='x|y')
filtered_data

Unnamed: 0,"('nose', 'x')","('nose', 'y')","('H1R', 'x')","('H1R', 'y')","('H2R', 'x')","('H2R', 'y')","('H1L', 'x')","('H1L', 'y')","('H2L', 'x')","('H2L', 'y')",...,"('B2L', 'x')","('B2L', 'y')","('B3L', 'x')","('B3L', 'y')","('tail', 'x')","('tail', 'y')","('S2', 'x')","('S2', 'y')","('S1', 'x')","('S1', 'y')"
0,-121.281118,6.777963,-107.543587,-9.904076,-91.049674,-24.262784,-114.476969,17.757493,-95.642332,29.419203,...,68.308038,57.202524,111.302323,39.594285,136.551132,15.836794,58.663218,3.552714e-15,-428.453907,-57.731640
1,-118.466293,5.164439,-104.771778,-11.426444,-92.801407,-25.102813,-111.668127,16.282128,-95.395779,27.423142,...,71.765439,56.764052,115.032002,39.768954,139.580639,16.430253,62.608034,3.552714e-15,-432.897362,-57.456968
2,-117.116147,3.743272,-103.840623,-12.463020,-91.015857,-23.573421,-108.146772,17.867068,-95.075329,27.651654,...,75.908466,54.183708,119.430075,36.769587,142.442255,12.872289,64.691238,-7.105427e-15,-437.910362,-49.581608
3,-118.168901,10.876589,-103.922190,-6.807094,-90.473196,-18.974423,-108.043361,24.035569,-93.308834,32.225534,...,79.532210,50.372116,118.700687,31.844225,144.586701,9.035496,66.998246,-7.105427e-15,-443.445989,-39.517151
4,-129.872247,18.977417,-118.215160,4.053176,-101.756529,-10.320907,-117.468658,30.129350,-100.656114,36.639027,...,79.262964,49.639693,116.723267,29.462425,142.597915,6.642514,68.408748,-7.105427e-15,-444.126056,-45.945764
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1102495,-67.278284,-67.951573,-53.510518,-67.561356,-45.792688,-60.568728,-71.032565,-56.138913,-67.660236,-33.669516,...,107.977805,42.951319,118.008388,-8.684928,112.339967,-39.013942,65.652483,-1.421085e-14,44.635680,-639.773663
1102496,-63.502359,-71.564452,-50.896190,-71.149798,-44.362936,-61.712636,-70.997425,-59.609619,-68.610977,-36.893957,...,107.948700,43.198371,118.149390,-8.204957,112.415697,-38.871035,65.616523,-1.065814e-14,45.998709,-639.839548
1102497,-69.702519,-81.962860,-55.540755,-78.942856,-45.310777,-68.474218,-73.736269,-66.442908,-71.803595,-45.464528,...,108.147861,43.588136,118.030566,-7.391648,111.575389,-37.813775,65.260951,-7.105427e-15,45.910568,-640.098971
1102498,-69.750044,-81.922420,-55.586531,-78.910630,-45.350483,-68.447927,-73.774792,-66.400131,-71.829951,-45.422876,...,106.668777,44.425300,118.134399,-6.845362,111.369144,-37.768235,65.197667,7.105427e-15,45.539315,-640.125490


## Base model

In [5]:
dropout_prob   = 0.5
embedding_size = 32
epoch_num      = 10
hidden_size    = 16
layer_num      = 2
learning_rate  = 1e-3
seq_size       = 25 #25 frames is equal to 1 second of video, maybe use 50?
pred_window    = 1

In [6]:
class SeqDataset(Dataset):
    def __init__(self, device, seq_size, dataframe, pred_window):
        super(SeqDataset, self).__init__()
        self.device      = device
        self.seq_size    = seq_size
        self.dataframe   = dataframe
        self.pred_window = pred_window
        self.target_data = dataframe.filter(regex='x|y')

    def __len__(self):
        return len(self.dataframe) - self.seq_size - 1

    def __getitem__(self, idx):
        in_seq = torch.tensor(self.dataframe.iloc[idx:idx + self.seq_size].values, dtype=torch.float, device=self.device)
        target_seq = torch.tensor(self.target_data.iloc[idx + self.pred_window:idx + self.seq_size + self.pred_window].values, dtype=torch.float, device=self.device)
        return in_seq, target_seq

In [7]:
# Split the data into train, validation, and test using the full video sizes (so that videos are not split into different sets)
train_data, test_data = train_test_split(filtered_data, test_size= 11250*int(0.2*98), shuffle=False)

# Assuming df is your dataframe with limb coordinate changes
scaler = StandardScaler()

# Fit the scaler on the training set
scaler.fit(train_data)

# Transform the training set
train_data_scaled = scaler.transform(train_data)
train_data = pd.DataFrame(train_data_scaled.reshape(train_data.shape), columns=train_data.columns)

# When you're ready to test the model, transform the test set
test_data_scaled = scaler.transform(test_data)
test_data = pd.DataFrame(test_data_scaled.reshape(test_data.shape), columns=test_data.columns)
train_data

Unnamed: 0,"('nose', 'x')","('nose', 'y')","('H1R', 'x')","('H1R', 'y')","('H2R', 'x')","('H2R', 'y')","('H1L', 'x')","('H1L', 'y')","('H2L', 'x')","('H2L', 'y')",...,"('B2L', 'x')","('B2L', 'y')","('B3L', 'x')","('B3L', 'y')","('tail', 'x')","('tail', 'y')","('S2', 'x')","('S2', 'y')","('S1', 'x')","('S1', 'y')"
0,-0.548786,0.112236,-0.540907,0.008330,-0.529858,-0.181602,-0.747827,0.149123,-0.710679,0.254334,...,-0.044805,0.177647,-0.038578,0.514353,0.100965,0.388717,-0.312793,0.256370,-1.101565,-0.044794
1,-0.482077,0.078642,-0.465556,-0.028834,-0.582440,-0.207048,-0.668882,0.113342,-0.703076,0.193675,...,0.045100,0.159608,0.034803,0.519004,0.157796,0.404489,-0.221796,0.256370,-1.113008,-0.043997
2,-0.450080,0.049053,-0.440243,-0.054138,-0.528842,-0.160720,-0.569911,0.151780,-0.693194,0.200619,...,0.152834,0.053453,0.121334,0.439152,0.211477,0.309930,-0.173742,-0.498993,-1.125919,-0.021143
3,-0.475029,0.197571,-0.442460,0.083932,-0.512553,-0.021408,-0.567005,0.301382,-0.638720,0.339618,...,0.247065,-0.103355,0.106983,0.308025,0.251704,0.207961,-0.120525,-0.498993,-1.140175,0.008063
4,-0.752387,0.366234,-0.831011,0.349049,-0.851251,0.240722,-0.831911,0.449172,-0.865292,0.473742,...,0.240064,-0.133487,0.068078,0.244615,0.214397,0.144363,-0.087988,-0.498993,-1.141927,-0.010592
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
888745,-0.492518,0.659527,-0.417676,0.547214,-0.214989,0.320447,-0.337849,0.605044,-0.139779,0.545643,...,-0.085691,-0.155232,-0.048431,0.045442,0.201474,-0.268812,-0.016173,-0.247205,1.495610,0.112877
888746,-0.448187,0.812045,-0.659524,0.741433,-0.512479,0.470148,-0.250138,0.787986,0.014362,0.561267,...,-0.139846,-0.060972,-0.089506,0.050696,0.203059,-0.134636,-0.035160,0.508157,1.494835,0.179761
888747,-0.406025,0.836672,-0.612106,0.699109,-0.628093,0.482054,-0.134997,0.733382,0.130084,0.444807,...,-0.134693,-0.095315,-0.073009,0.159937,0.206434,0.087911,0.063541,0.004582,1.465933,0.347291
888748,-0.396969,0.830068,-0.605613,0.692534,-0.622839,0.475791,-0.118199,0.726301,0.149024,0.437928,...,-0.135024,-0.092783,-0.065171,0.163794,0.216320,0.095640,0.073733,0.130476,1.463561,0.351864


In [8]:

n_train_vids = len(train_data)/11250
train_data, val_data = train_test_split(train_data, test_size = 11250 * int(0.1 * n_train_vids), shuffle=False)

# Create SeqDataset instances for the train, validation, and test sets
train_dataset = SeqDataset(device, seq_size, train_data, pred_window)
val_dataset = SeqDataset(device, seq_size, val_data, pred_window)
test_dataset = SeqDataset(device, seq_size, test_data, pred_window)

# Create DataLoader instances for batching
batch_size = 64  # Adjust to your desired batch size
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

In [9]:
class Model(nn.Module):
    def __init__(self, dropout_prob, hidden_size, layer_num, input_size, output_size):
        super(Model, self).__init__()
#         self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.gru       = nn.GRU(input_size, hidden_size, layer_num, batch_first=True, dropout=dropout_prob)
        self.linear    = nn.Linear(hidden_size, output_size)

    def forward(self, in_sequence, hidden_state=None):
#         embedding_seq            = self.embedding(in_sequence)
        hidden_seq, hidden_state = self.gru(in_sequence, hidden_state)
        out_seq                  = self.linear(hidden_seq)
        return out_seq, hidden_state
    
    def draw(self, in_sequence, logit_temp=1.0):
        out_seq, _  = self(in_sequence)
        prob_dist   = torch.softmax(out_seq[0, -1] / logit_temp, 0)
        rand_sample = torch.multinomial(prob_dist, 1).item()                   
        return rand_sample

In [10]:
input_size = next(iter(train_loader))[0].size(-1)
print("Input size:", input_size)
output_size = next(iter(train_loader))[1].size(-1)
print("Output size:", output_size)
model = Model(dropout_prob, hidden_size, layer_num, input_size, output_size).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), learning_rate)

Input size: 28
Output size: 28


In [11]:
min_loss = float("inf")

for epoch in range(epoch_num):
    
    model.train()
    train_loss = 0.0
    
    i = 0
    for in_seq, target_seq in tqdm(train_loader, desc=f"Epoch {epoch}/{epoch_num}"):
        
        out_seq, _  = model(in_seq)
        loss        = criterion(out_seq, target_seq)
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        i += 1

    train_loss       /= len(train_loader)
    # train_perplexity  = np.exp(train_loss)
    print(f"Train loss: {train_loss:.4f}")

    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for in_seq, target_seq in val_loader:
            out_seq, _  = model(in_seq)
            loss        = criterion(out_seq, target_seq)
            val_loss   += loss.item()

    val_loss       /= len(val_loader)
    # val_perplexity  = np.exp(val_loss)
    print(f"Val loss: {val_loss:.4f}")

    if val_loss < min_loss:
        min_loss = val_loss
        torch.save(model.state_dict(), "../Models/model_md_norm_dropped_cols.pt")

Epoch 0/10:   7%|████▊                                                             | 919/12655 [00:12<02:37, 74.51it/s]


KeyboardInterrupt: 

In [16]:
model.load_state_dict(torch.load("../Models/model_md_norm_dropped_cols.pt"))
model.eval

<bound method Module.eval of Model(
  (gru): GRU(28, 16, num_layers=2, batch_first=True, dropout=0.5)
  (linear): Linear(in_features=16, out_features=28, bias=True)
)>

In [17]:
test_loss = 0.0

for in_seq, target_seq in tqdm(test_loader):
    out_seq, _ = model(in_seq)
    loss = criterion(out_seq, target_seq)
    test_loss += loss.item()

test_loss /= len(test_loader)
print(f"Test loss: {test_loss:.4f}")

100%|██████████████████████████████████████████████████████████████████████████████| 3339/3339 [00:49<00:00, 67.27it/s]

Test loss: 0.7109





In [68]:
model.load_state_dict(torch.load("../Models/model_baseline.pt"))
model.eval

<bound method Module.eval of Model(
  (gru): GRU(43, 16, num_layers=2, batch_first=True, dropout=0.5)
  (linear): Linear(in_features=16, out_features=28, bias=True)
)>

In [69]:
test_loss = 0.0

for in_seq, target_seq in tqdm(test_loader):
    out_seq, _ = model(in_seq)
    loss = criterion(out_seq, target_seq)
    test_loss += loss.item()

test_loss /= len(test_loader)
print(f"Test loss: {test_loss:.4f}")

100%|██████████████████████████████████████████████████████████████████████████████| 3340/3340 [00:36<00:00, 91.03it/s]

Test loss: 952.3952





In [74]:
model.load_state_dict(torch.load("../Models/model_baseline_dropped_cols.pt"))
model.eval

<bound method Module.eval of Model(
  (gru): GRU(28, 16, num_layers=2, batch_first=True, dropout=0.5)
  (linear): Linear(in_features=16, out_features=28, bias=True)
)>

In [75]:
test_loss = 0.0

for in_seq, target_seq in tqdm(test_loader):
    out_seq, _ = model(in_seq)
    loss = criterion(out_seq, target_seq)
    test_loss += loss.item()

test_loss /= len(test_loader)
print(f"Test loss: {test_loss:.4f}")

100%|██████████████████████████████████████████████████████████████████████████████| 3340/3340 [00:40<00:00, 82.70it/s]

Test loss: 1070.2398





## Longer sequence

In [48]:
dropout_prob   = 0.5
embedding_size = 32
epoch_num      = 10
hidden_size    = 16
layer_num      = 2
learning_rate  = 1e-3
seq_size       = 25 #25 frames is equal to 1 second of video, maybe use 50?
pred_window    = 25
shift_size     = 25

In [52]:
class SeqDataset(Dataset):
    def __init__(self, device, dataframe, seq_size, pred_window, shift_size):
        super(SeqDataset, self).__init__()
        self.device = device
        self.seq_size = seq_size  # Input sequence length
        self.shift_size = 25  # Shift for the start of each new sequence
        self.pred_window = pred_window  # Additional frames for prediction
        self.dataframe = dataframe
        self.target_data = dataframe.filter(regex='x|y')

    def __len__(self):
        # The total number of sequences is adjusted for shifting the sequence start by self.shift_size
        return (len(self.dataframe) - self.seq_size - self.pred_window) // self.shift_size + 1

    def __getitem__(self, idx):
        # Calculate the actual starting index for the sequence based on the shift size
        start_idx = idx * self.shift_size
        end_idx = start_idx + self.seq_size
        target_end_idx = end_idx + self.pred_window

        # Ensure we don't exceed the bounds of the dataframe
        if target_end_idx > len(self.dataframe):
            target_end_idx = len(self.dataframe)
            end_idx = target_end_idx - self.pred_window

        in_seq = torch.tensor(self.dataframe.iloc[start_idx:end_idx].values,
                              dtype=torch.float, device=self.device)
        target_seq = torch.tensor(self.dataframe.iloc[end_idx:target_end_idx].values,
                                  dtype=torch.float, device=self.device)
        return in_seq, target_seq

In [53]:
# Split the data into train, validation, and test using the full video sizes (so that videos are not split into different sets)
train_data, test_data = train_test_split(filtered_data, test_size= 11250*int(0.2*98), shuffle=False)

# Assuming df is your dataframe with limb coordinate changes
scaler = StandardScaler()

# Fit the scaler on the training set
scaler.fit(train_data)

# Transform the training set
train_data_scaled = scaler.transform(train_data)
train_data = pd.DataFrame(train_data_scaled.reshape(train_data.shape), columns=train_data.columns)

# When you're ready to test the model, transform the test set
test_data_scaled = scaler.transform(test_data)
test_data = pd.DataFrame(test_data_scaled.reshape(test_data.shape), columns=test_data.columns)
train_data

Unnamed: 0,"('nose', 'x')","('nose', 'y')","('H1R', 'x')","('H1R', 'y')","('H2R', 'x')","('H2R', 'y')","('H1L', 'x')","('H1L', 'y')","('H2L', 'x')","('H2L', 'y')",...,"('B2L', 'x')","('B2L', 'y')","('B3L', 'x')","('B3L', 'y')","('tail', 'x')","('tail', 'y')","('S2', 'x')","('S2', 'y')","('S1', 'x')","('S1', 'y')"
0,-0.548786,0.112236,-0.540907,0.008330,-0.529858,-0.181602,-0.747827,0.149123,-0.710679,0.254334,...,-0.044805,0.177647,-0.038578,0.514353,0.100965,0.388717,-0.312793,0.256370,-1.101565,-0.044794
1,-0.482077,0.078642,-0.465556,-0.028834,-0.582440,-0.207048,-0.668882,0.113342,-0.703076,0.193675,...,0.045100,0.159608,0.034803,0.519004,0.157796,0.404489,-0.221796,0.256370,-1.113008,-0.043997
2,-0.450080,0.049053,-0.440243,-0.054138,-0.528842,-0.160720,-0.569911,0.151780,-0.693194,0.200619,...,0.152834,0.053453,0.121334,0.439152,0.211477,0.309930,-0.173742,-0.498993,-1.125919,-0.021143
3,-0.475029,0.197571,-0.442460,0.083932,-0.512553,-0.021408,-0.567005,0.301382,-0.638720,0.339618,...,0.247065,-0.103355,0.106983,0.308025,0.251704,0.207961,-0.120525,-0.498993,-1.140175,0.008063
4,-0.752387,0.366234,-0.831011,0.349049,-0.851251,0.240722,-0.831911,0.449172,-0.865292,0.473742,...,0.240064,-0.133487,0.068078,0.244615,0.214397,0.144363,-0.087988,-0.498993,-1.141927,-0.010592
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
888745,-0.492518,0.659527,-0.417676,0.547214,-0.214989,0.320447,-0.337849,0.605044,-0.139779,0.545643,...,-0.085691,-0.155232,-0.048431,0.045442,0.201474,-0.268812,-0.016173,-0.247205,1.495610,0.112877
888746,-0.448187,0.812045,-0.659524,0.741433,-0.512479,0.470148,-0.250138,0.787986,0.014362,0.561267,...,-0.139846,-0.060972,-0.089506,0.050696,0.203059,-0.134636,-0.035160,0.508157,1.494835,0.179761
888747,-0.406025,0.836672,-0.612106,0.699109,-0.628093,0.482054,-0.134997,0.733382,0.130084,0.444807,...,-0.134693,-0.095315,-0.073009,0.159937,0.206434,0.087911,0.063541,0.004582,1.465933,0.347291
888748,-0.396969,0.830068,-0.605613,0.692534,-0.622839,0.475791,-0.118199,0.726301,0.149024,0.437928,...,-0.135024,-0.092783,-0.065171,0.163794,0.216320,0.095640,0.073733,0.130476,1.463561,0.351864


In [54]:
n_train_vids = len(train_data)/11250
train_data, val_data = train_test_split(train_data, test_size = 11250 * int(0.1 * n_train_vids), shuffle=False)

# Create SeqDataset instances for the train, validation, and test sets
train_dataset = SeqDataset(device, train_data, seq_size, pred_window, shift_size)
val_dataset = SeqDataset(device, val_data, seq_size, pred_window, shift_size)
test_dataset = SeqDataset(device, test_data, seq_size, pred_window, shift_size)

# Create DataLoader instances for batching
batch_size = 64  # Adjust to your desired batch size
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

In [55]:
class Model(nn.Module):
    def __init__(self, dropout_prob, hidden_size, layer_num, input_size, output_size):
        super(Model, self).__init__()
#         self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.gru       = nn.GRU(input_size, hidden_size, layer_num, batch_first=True, dropout=dropout_prob)
        self.linear    = nn.Linear(hidden_size, output_size)

    def forward(self, in_sequence, hidden_state=None):
#         embedding_seq            = self.embedding(in_sequence)
        hidden_seq, hidden_state = self.gru(in_sequence, hidden_state)
        out_seq                  = self.linear(hidden_seq)
        return out_seq, hidden_state
    
    def draw(self, in_sequence, logit_temp=1.0):
        out_seq, _  = self(in_sequence)
        prob_dist   = torch.softmax(out_seq[0, -1] / logit_temp, 0)
        rand_sample = torch.multinomial(prob_dist, 1).item()                   
        return rand_sample

In [56]:
input_size = next(iter(train_loader))[0].size(-1)
print("Input size:", input_size)
output_size = next(iter(train_loader))[1].size(-1)
print("Output size:", output_size)
model = Model(dropout_prob, hidden_size, layer_num, input_size, output_size).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), learning_rate)
print(model)

Input size: 28
Output size: 28
Model(
  (gru): GRU(28, 16, num_layers=2, batch_first=True, dropout=0.5)
  (linear): Linear(in_features=16, out_features=28, bias=True)
)


In [58]:
min_loss = float("inf")
train_losses = []
val_losses = []

for epoch in range(epoch_num):
    
    model.train()
    train_loss = 0.0
    
    i = 0
    for in_seq, target_seq in tqdm(train_loader, desc=f"Epoch {epoch}/{epoch_num}"):

        out_seq, _  = model(in_seq)
        loss        = criterion(out_seq, target_seq)
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        i += 1

    train_loss       /= len(train_loader)
    train_losses.append(train_loss)
    # train_perplexity  = np.exp(train_loss)
    print(f"Train loss: {train_loss:.4f}")

    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for in_seq, target_seq in val_loader:
            out_seq, _  = model(in_seq)
            loss        = criterion(out_seq, target_seq)
            val_loss   += loss.item()

    val_loss       /= len(val_loader)
    val_losses.append(val_loss)
    # val_perplexity  = np.exp(val_loss)
    print(f"Val loss: {val_loss:.4f}")

    if val_loss < min_loss:
        min_loss = val_loss
        torch.save(model.state_dict(), "../Models/model_md_norm_pred_25.pt")

Epoch 0/10: 100%|████████████████████████████████████████████████████████████████████| 506/506 [00:08<00:00, 59.37it/s]


Train loss: 0.9146
Val loss: 1.1454


Epoch 1/10: 100%|████████████████████████████████████████████████████████████████████| 506/506 [00:08<00:00, 61.75it/s]


Train loss: 0.8705
Val loss: 1.1360


Epoch 2/10: 100%|████████████████████████████████████████████████████████████████████| 506/506 [00:08<00:00, 59.50it/s]


Train loss: 0.8620
Val loss: 1.1313


Epoch 3/10: 100%|████████████████████████████████████████████████████████████████████| 506/506 [00:07<00:00, 66.46it/s]


Train loss: 0.8560
Val loss: 1.1301


Epoch 4/10: 100%|████████████████████████████████████████████████████████████████████| 506/506 [00:08<00:00, 60.80it/s]


Train loss: 0.8527
Val loss: 1.1298


Epoch 5/10: 100%|████████████████████████████████████████████████████████████████████| 506/506 [00:08<00:00, 59.61it/s]


Train loss: 0.8502
Val loss: 1.1293


Epoch 6/10: 100%|████████████████████████████████████████████████████████████████████| 506/506 [00:08<00:00, 62.81it/s]


Train loss: 0.8479
Val loss: 1.1296


Epoch 7/10: 100%|████████████████████████████████████████████████████████████████████| 506/506 [00:08<00:00, 60.62it/s]


Train loss: 0.8461
Val loss: 1.1295


Epoch 8/10: 100%|████████████████████████████████████████████████████████████████████| 506/506 [00:08<00:00, 59.75it/s]


Train loss: 0.8448
Val loss: 1.1298


Epoch 9/10: 100%|████████████████████████████████████████████████████████████████████| 506/506 [00:08<00:00, 60.36it/s]


Train loss: 0.8434
Val loss: 1.1299


In [59]:
# model.load_state_dict(torch.load("../Models/model_mc_norm_pred_25.pt"))
model.eval

<bound method Module.eval of Model(
  (gru): GRU(28, 16, num_layers=2, batch_first=True, dropout=0.5)
  (linear): Linear(in_features=16, out_features=28, bias=True)
)>

In [60]:
test_loss = 0.0

for in_seq, target_seq in tqdm(test_loader):
    out_seq, _ = model(in_seq)
    loss = criterion(out_seq, target_seq)
    test_loss += loss.item()

test_loss /= len(test_loader)
print(f"Test loss: {test_loss:.4f}")

100%|████████████████████████████████████████████████████████████████████████████████| 133/133 [00:01<00:00, 75.71it/s]

Test loss: 0.6253





In [38]:
model.load_state_dict(torch.load("../Models/model_pred_25.pt"))
model.eval

<bound method Module.eval of Model(
  (gru): GRU(43, 16, num_layers=2, batch_first=True, dropout=0.5)
  (linear): Linear(in_features=16, out_features=28, bias=True)
)>

In [39]:
test_loss = 0.0

for in_seq, target_seq in tqdm(test_loader):
    out_seq, _ = model(in_seq)
    loss = criterion(out_seq, target_seq)
    test_loss += loss.item()

test_loss /= len(test_loader)
print(f"Test loss: {test_loss:.4f}")

100%|███████████████████████████████████████████████████████████████████████████████| 133/133 [00:01<00:00, 102.39it/s]

Test loss: 11223.6390





In [12]:
# model.load_state_dict(torch.load("../Models/model_cc_pred_25.pt"))
model.eval

<bound method Module.eval of Model(
  (gru): GRU(43, 16, num_layers=2, batch_first=True, dropout=0.5)
  (linear): Linear(in_features=16, out_features=28, bias=True)
)>

In [13]:
test_loss = 0.0

for in_seq, target_seq in tqdm(test_loader):
    out_seq, _ = model(in_seq)
    loss = criterion(out_seq, target_seq)
    test_loss += loss.item()

test_loss /= len(test_loader)
print(f"Test loss: {test_loss:.4f}")

100%|████████████████████████████████████████████████████████████████████████████████| 133/133 [00:04<00:00, 30.12it/s]

Test loss: 120.8619



