In [3]:
import os 
import random
import numpy as np

import torch 
from torch import nn 
from torch import optim
import torch.nn.functional as F

In [4]:
class LSTMEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers = 1):
        super(LSTMEncoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # define LSTM layer
        self.lstm = nn.LSTM(input_size = input_size, hidden_size = hidden_size,
                            num_layers = num_layers)

    def forward(self, x_input):
        '''
        : param x_input:               input of shape (seq_len, # in batch, input_size)
        : return lstm_out, hidden:     lstm_out gives all the hidden states in the sequence; hidden gives the hidden state and cell state for the last element in the sequence                         
        '''
        lstm_out, self.hidden = self.lstm(x_input.view(x_input.shape[0], x_input.shape[1], self.input_size))
        return lstm_out, self.hidden     
    
    def init_hidden(self, batch_size):
        return (torch.zeros(self.num_layers, batch_size, self.hidden_size),
                torch.zeros(self.num_layers, batch_size, self.hidden_size))

In [5]:
class LSTMDecoder(nn.Module):    
    def __init__(self, input_size, hidden_size, num_layers = 1):
        super(LSTMDecoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size = input_size, hidden_size = hidden_size,
                            num_layers = num_layers)
        self.linear = nn.Linear(hidden_size, input_size)           

    def forward(self, x_input, encoder_hidden_states):
        '''        
        : param x_input:                    should be 2D (batch_size, input_size)
        : param encoder_hidden_states:      hidden states
        : return output, hidden:            output gives all the hidden states in the sequence; hidden gives the hidden state and cell state for the last element in the sequence
        '''
        lstm_out, self.hidden = self.lstm(x_input.unsqueeze(0), encoder_hidden_states)
        output = self.linear(lstm_out.squeeze(0))     
        return output, self.hidden

In [19]:
class LSTMAE(nn.Module):

    def __init__(self, input_size, seq_len, hidden_size, batch_size, ae_type='recursive', teacher_forcing_ratio=0.5):
        super(LSTMAE, self).__init__()
        self.input_size = input_size
        self.seq_len = seq_len
        self.hidden_size = hidden_size 
        self.bs = batch_size
        self.ae_type = ae_type # ['recursive', 'teacher_forcing', 'mixed_teacher_forcing']
        self.teacher_forcing_ratio = teacher_forcing_ratio

        self.encoder = LSTMEncoder(input_size = input_size, hidden_size = hidden_size)
        self.decoder = LSTMDecoder(input_size = input_size, hidden_size = hidden_size)

        self.encoder_hidden = self.encoder.init_hidden(self.bs)

    def forward(self, x):
        # encoding 
        encoder_output, self.encoder_hidden = self.encoder(x)
        # decoding 
        decoder_input = torch.rand((self.bs, self.input_size), requires_grad=True)#self.encoder_hidden[0].squeeze()
        print(decoder_input.shape)
        decoder_hidden = self.encoder_hidden
        # outputs tensor
        outputs = torch.zeros(self.seq_len, self.bs, self.input_size)

        if self.ae_type == 'recursive':
            for t in range(self.seq_len):
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                outputs[t] = decoder_output
                decoder_input = decoder_output
        
        elif self.ae_type == 'teacher_forcing':
            if random.random() < self.teacher_forcing_ratio:
                for t in range(self.seq_len):
                    decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                    outputs[t] = decoder_output
                    decoder_input = x[t, :, :]

            else:
                for t in range(self.seq_len):
                    decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                    outputs[t] = decoder_output
                    decoder_input = decoder_output

        elif self.ae_type == 'mixed_teacher_forcing':
            for t in range(self.seq_len):
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                outputs[t] = decoder_output

                if random.random() < self.teacher_forcing_ratio:
                    decoder_input = x[t, :, :]
                else:
                    decoder_input = decoder_output

        return outputs

In [20]:
model = LSTMAE(input_size=42, seq_len=120, hidden_size=128, batch_size=32, ae_type='recursive')
sample_input = torch.rand((120, 32, 42))
sample_output = model(sample_input)

torch.Size([32, 42])


---

In [None]:
from tqdm import tqdm 


In [None]:
def train_step(model, dataloader, optimizer, loss_module, device, phase='train'):
    model = model.train()

    epoch_loss = 0 
    total_samples = 0 

    with tqdm(dataloader, unit='batch') as  tepoch:
        for batch in tepoch:
            X, y = batch 

            X = X.float().to(device)
            y = y.float().to(device)

            optimizer.zero_grad()

            with torch.set_grad_enabled(phase == 'train'):
                seq_output = model(X)

            rec_loss = loss_module['recons'](seq_output, y)

            if phase == 'train':
                rec_loss.backward()
                optimizer.step()

            metrics = {'loss': rec_loss.item()}

            with torch.no_grad():
                total_samples += len(y)
                epoch_loss += rec_loss.item() 
            tepoch.set_postfix(metrics)

    avg_loss = epoch_loss / total_samples  # average loss per sample for whole epoch
    return metrics   
            