In [None]:
# Author: Laura Kulowski

import numpy as np
import random
import os, errno
import sys
from tqdm import trange

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

In [None]:

# encoder class
class lstm_encoder(nn.Module):
    """LSTM encoder operating on batch-first input sequences.

    Wraps a single ``nn.LSTM`` constructed with ``batch_first=True``.
    """

    def __init__(self, input_size, hidden_size, num_layers=1):
        """Build the encoder.

        Args:
            input_size: number of features per time step.
            hidden_size: number of features in the LSTM hidden state.
            num_layers: number of stacked LSTM layers.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

    def forward(self, x_input):
        """Encode ``x_input`` and return every step's output plus the final state.

        Args:
            x_input: tensor whose first two dimensions are (batch, seq_len);
                it is reshaped to (batch, seq_len, input_size) before being
                fed to the LSTM.

        Returns:
            Tuple ``(lstm_out, hidden)`` where ``lstm_out`` is the LSTM output
            for every time step and ``hidden`` is the ``(h_n, c_n)`` tuple
            returned by ``nn.LSTM``. ``hidden`` is also stored on
            ``self.hidden``.
        """
        batch_size, seq_len = x_input.shape[0], x_input.shape[1]
        # reshape (rather than view) also accepts non-contiguous inputs and
        # returns a view whenever one is possible.
        sequence = x_input.reshape(batch_size, seq_len, self.input_size)
        lstm_out, self.hidden = self.lstm(sequence)

        return lstm_out, self.hidden

    def init_hidden(self, batch_size):
        """Return a zero ``(h_0, c_0)`` pair for a batch of ``batch_size``.

        Each tensor has shape (num_layers, batch_size, hidden_size) and is
        allocated with the LSTM weights' device and dtype so it can be passed
        to the LSTM after the module has been moved or cast.
        """
        reference = self.lstm.weight_ih_l0
        state_shape = (self.num_layers, batch_size, self.hidden_size)
        return (reference.new_zeros(state_shape),
                reference.new_zeros(state_shape))


# decoder class
class lstm_decoder(nn.Module):
    """LSTM decoder that maps hidden states back to the input feature size.

    Wraps a batch-first ``nn.LSTM`` followed by a linear layer projecting
    ``hidden_size`` features to ``input_size`` features.
    """

    def __init__(self, input_size, hidden_size, num_layers=1):
        """Build the decoder.

        Args:
            input_size: number of features per time step (also the size of
                the projected output).
            hidden_size: number of features in the LSTM hidden state.
            num_layers: number of stacked LSTM layers.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, input_size)

    def forward(self, x_input, hidden_state):
        """Decode ``x_input`` starting from ``hidden_state``.

        Args:
            x_input: input passed directly to the batch-first LSTM, i.e.
                (batch, seq_len, input_size) for batched input.
            hidden_state: ``(h_0, c_0)`` tuple used as the LSTM's initial
                state.

        Returns:
            Tuple ``(output, hidden)`` where ``output`` is the linear
            projection of the LSTM output at every time step and has the same
            leading dimensions as ``x_input``, and ``hidden`` is the
            ``(h_n, c_n)`` tuple returned by ``nn.LSTM``. ``hidden`` is also
            stored on ``self.hidden``.
        """
        lstm_out, self.hidden = self.lstm(x_input, hidden_state)
        # With batch_first=True, dim 0 of lstm_out is the batch axis. It is
        # deliberately not squeezed: dropping it would change the output rank
        # only when the batch happens to hold a single sample.
        output = self.fc(lstm_out)

        return output, self.hidden
    

# full seq2seq model class (encoder + decoder)
class lstm_seq2seq(nn.Module):

    def __init__(self, input_size, hidden_size):
        """Create the encoder/decoder pair.

        Args:
            input_size: feature count handed to both the encoder and decoder.
            hidden_size: hidden-state size handed to both the encoder and
                decoder.
        """
        super().__init__()

        # Both sub-modules receive the same sizes and use their default
        # number of layers.
        sizes = (input_size, hidden_size)
        self.input_size, self.hidden_size = sizes

        self.encoder = lstm_encoder(*sizes)
        self.decoder = lstm_decoder(*sizes)

