In [27]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
from torchvision.utils import make_grid
from my_utils import device, set_random_seed
from my_torch_train import train_epoch, eval_model, train_model, count_model_params

import matplotlib.pyplot as plt
%matplotlib inline

In [28]:
# setting random seed
set_random_seed(42)

## Load Data

In [29]:
train_dataset = datasets.FashionMNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
 
test_dataset = datasets.FashionMNIST(root='./data', train=False, transform=transforms.ToTensor())

In [30]:
# Fitting data loaders for iterating
B_SIZE = 256

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=B_SIZE, 
                                           shuffle=True) 
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=B_SIZE,
                                          shuffle=False)


In [31]:
train_dataset[0][0].shape

torch.Size([1, 28, 28])

## Define Models

The general formulas behind one LSTM-cell are as follows:

$$
\begin{align*}
        i &= \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
        f &= \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
        g &= \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
        o &= \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
        c' &= f * c + i * g \\
        h' &= o * \tanh(c') \\
\end{align*}
$$

In [32]:
class myLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.L_ii = nn.Linear(input_dim, hidden_dim).to(device)
        self.L_hi = nn.Linear(hidden_dim, hidden_dim).to(device)
        self.L_if = nn.Linear(input_dim, hidden_dim).to(device)
        self.L_hf = nn.Linear(hidden_dim, hidden_dim).to(device)
        self.L_ig = nn.Linear(input_dim, hidden_dim).to(device)
        self.L_hg = nn.Linear(hidden_dim, hidden_dim).to(device)
        self.L_io = nn.Linear(input_dim, hidden_dim).to(device)
        self.L_ho = nn.Linear(hidden_dim, hidden_dim).to(device)
    
    def forward(self, x, c, h):
        i = torch.sigmoid(self.L_ii(x) + self.L_hi(h))
        f = torch.sigmoid(self.L_if(x) + self.L_hf(h))
        g = torch.tanh(self.L_ig(x) + self.L_hg(h))
        o = torch.sigmoid(self.L_io(x) + self.L_ho(h))
        c_out = f * c + i*g
        h_out = o * torch.tanh(c_out)
        return c_out, h_out

In [33]:
class myLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        
        self.lstm_cells=nn.ModuleList()
        self.lstm_cells.append(myLSTMCell(input_dim, hidden_dim))
        for i in range(num_layers-1):
            self.lstm_cells.append(myLSTMCell(hidden_dim, hidden_dim))
    
    def forward(self, x, hc):
        h = hc[0]
        c = hc[1]
        b_size, n_rows, n_cols = x.shape
        
        for i in range(self.num_layers):
            new_x = torch.empty((b_size, n_rows, self.hidden_dim))
            h_t = h[i]
            c_t = c[i]

            for i in range(n_rows):
                c_t, h_t = self.lstm_cells[i](x[:,i,:], c_t, h_t)
                if n_rows != num_layers-1:
                    new_x[:, i, :] = h_t
            x = new_x
        
        return h_t
        #return h_list, (h_t, c_t)

In [34]:
class SequentialClassifier(nn.Module):
    """
    Sequential classifier for images. Embedded image rows are fed to a RNN

    Args:
    -----
    input_dim: integer
        dimensionality of the rows to embed
    emb_dim: integer
        dimensionality of the vectors fed to the LSTM
    hidden_dim: integer
        dimensionality of the states in the cell
    use_own: boolean
        When true uses our LSTM-implementation rather than the one from pytorch
    init_mode: string
        intialization of the states
    num_layers: integer
        the number of LSTM layers the classifier should use
    """

    def __init__(self, input_dim, emb_dim, hidden_dim, use_own=True, init_mode='zeros',
                 num_layers=1):
        """ Module initializer """
        assert init_mode in ["zeros", "random", "learned"]
        super().__init__()
        self.hidden_dim =  hidden_dim
        self.num_layers = num_layers
        self.use_own = use_own

        self.init_mode = init_mode
        
        # for embedding rows into vector representations
        self.encoder = nn.Linear(in_features=input_dim, out_features=emb_dim).to(device)
        
        # lstm model
        if self.use_own:
            self.lstm = myLSTM(input_dim, hidden_dim, num_layers=num_layers)
        else:
            self.lstm = nn.LSTM(
                input_size=input_dim, hidden_size=hidden_dim, batch_first=True,
                num_layers=num_layers
            )
        
        # FC-classifier
        self.classifier = nn.Linear(in_features=hidden_dim, out_features=10)

        return


    def forward(self, x):
        """ Forward pass through model """

        b_size, n_channels, n_rows, n_cols = x.shape
        h, c = self.init_state(b_size=b_size, device=x.device)

        # embedding rows
        x_rowed = x.view(b_size, n_channels*n_rows, n_cols)
        embeddings = self.encoder(x_rowed)

        # classifying
        if self.use_own:
            lstm_out = self.lstm(x_rowed, (h,c))
            y = self.classifier(lstm_out)
        else:
            lstm_out, (h_out, c_out) = self.lstm(x_rowed, (h,c))
            y = self.classifier(lstm_out[:, -1, :])  # feeding only output at last layer

        return y


    def init_state(self, b_size, device):
        """ Initializing hidden and cell state """
        if(self.init_mode == "zeros"):
            h = torch.zeros(self.num_layers, b_size, self.hidden_dim)
            c = torch.zeros(self.num_layers, b_size, self.hidden_dim)
        elif(self.init_mode == "random"):
            h = torch.randn(self.num_layers, b_size, self.hidden_dim)
            c = torch.randn(self.num_layers, b_size, self.hidden_dim)
        elif(self.init_mode == "learned"):
            h = self.learned_h.repeat(1, b_size, 1)
            c = self.learned_c.repeat(1, b_size, 1)
        h = h.to(device)
        c = c.to(device)
        return h, c
    


# Training and Evaluation

Defining the models

In [35]:
torch_lstm = SequentialClassifier(input_dim=28, emb_dim=64, hidden_dim=128, use_own=False, init_mode="zeros", num_layers=2).to(device)
own_lstm = SequentialClassifier(input_dim=28, emb_dim=64, hidden_dim=128, use_own=True, init_mode="zeros", num_layers=2).to(device)

In [36]:
count_model_params(torch_lstm)

216138

In [37]:
count_model_params(own_lstm)

216138

Our model and the model using nn.LSTM have the same number of parameters.

In [None]:
def train_lstm(model, save_path=None, num_epochs=10):
    # classification loss function
    criterion = nn.CrossEntropyLoss()

    # Observe that all parameters are being optimized
    optimizer = torch.optim.Adam(torch_lstm.parameters(), lr=3e-4)

    # Decay LR by a factor of 0.1 every 7 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.2)

    return train_model(
            model=torch_lstm, optimizer=optimizer, scheduler=scheduler, criterion=criterion,
            train_loader=train_loader, valid_loader=test_loader, num_epochs=num_epochs,
            save_path=save_path
        )

## Using nn.LSTM

In [None]:
train_loss, val_loss, loss_iters, valid_acc = train_lstm(
        model=torch_lstm, save_path='models/nnLSTM_1layer.pt'
    )

## Using myLSTM

In [None]:
train_loss, val_loss, loss_iters, valid_acc = train_lstm(
        model=own_lstm, save_path='models/ownLSTM_1layer.pt'
)