In [37]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
from torchvision.utils import make_grid
from my_utils import device
from my_torch_train import train_epoch, eval_model, train_model, count_model_params

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# setting random seed
set_random_seed(42)

## Load Data

In [38]:
train_dataset = datasets.FashionMNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
 
test_dataset = datasets.FashionMNIST(root='./data', train=False, transform=transforms.ToTensor())

In [39]:
# Fitting data loaders for iterating
B_SIZE = 256

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=B_SIZE, 
                                           shuffle=True) 
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=B_SIZE,
                                          shuffle=False)


In [40]:
train_dataset[0][0].shape

torch.Size([1, 28, 28])

## Define Models

The general formulas behind one LSTM-cell are as follows:

$$
\begin{align*}
        i &= \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
        f &= \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
        g &= \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
        o &= \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
        c' &= f * c + i * g \\
        h' &= o * \tanh(c') \\
\end{align*}
$$

In [41]:
class myLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.L_ii = nn.Linear(input_dim, hidden_dim).to(device)
        self.L_hi = nn.Linear(hidden_dim, hidden_dim).to(device)
        self.L_if = nn.Linear(input_dim, hidden_dim).to(device)
        self.L_hf = nn.Linear(hidden_dim, hidden_dim).to(device)
        self.L_ig = nn.Linear(input_dim, hidden_dim).to(device)
        self.L_hg = nn.Linear(hidden_dim, hidden_dim).to(device)
        self.L_io = nn.Linear(input_dim, hidden_dim).to(device)
        self.L_ho = nn.Linear(hidden_dim, hidden_dim).to(device)
    
    def forward(self, x, c, h):
        i = torch.sigmoid(self.L_ii(x) + self.L_hi(h))
        f = torch.sigmoid(self.L_if(x) + self.L_hf(h))
        g = torch.tanh(self.L_ig(x) + self.L_hg(h))
        o = torch.sigmoid(self.L_io(x) + self.L_ho(h))
        c_out = f * c + i*g
        h_out = o * torch.tanh(c_out)
        return c_out, h_out

In [42]:
class myLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.lstm_cell = myLSTMCell(input_dim, hidden_dim)
        self.hidden_dim = hidden_dim
    
    def forward(self, x, hc):
        h_t = hc[0][0] # < those 2 lines have to be changed when adding num_layers
        c_t = hc[1][0] # <
        b_size, n_rows, n_cols = x.shape
        #h_list = torch.empty(b_size, n_rows, self.hidden_dim).to(device)
        
        for i in range(n_rows):
            c_t, h_t = self.lstm_cell(x[:,i,:], c_t, h_t)
            #h_list[:, i, :] = h_t
        
        return h_t
        #return h_list, (h_t, c_t)

In [43]:
class SequentialClassifier(nn.Module):
    """
    Sequential classifier for images. Embedded image rows are fed to a RNN

    Args:
    -----
    input_dim: integer
        dimensionality of the rows to embed
    emb_dim: integer
        dimensionality of the vectors fed to the LSTM
    hidden_dim: integer
        dimensionality of the states in the cell
    use_own: boolean
        When true uses our LSTM-implementation rather than the one from pytorch
    mode: string
        intialization of the states
    rnn: nn.Module
        the recurrent neural network to use
    """

    def __init__(self, input_dim, emb_dim, hidden_dim, use_own=True, init_mode='zeros'):
        """ Module initializer """
        assert init_mode in ["zeros", "random", "learned"]
        super().__init__()
        self.hidden_dim =  hidden_dim
        self.num_layers = 1 #num_layers
        self.use_own = use_own

        self.init_mode = init_mode
        
        # for embedding rows into vector representations
        self.encoder = nn.Linear(in_features=input_dim, out_features=emb_dim).to(device)
        
        # lstm model
        if self.use_own:
            self.lstm = myLSTM(input_dim, hidden_dim)
        else:
            self.lstm = nn.LSTM(
                input_size=input_dim, hidden_size=hidden_dim, batch_first=True
            ).to(device)
        
        # FC-classifier
        self.classifier = nn.Linear(in_features=hidden_dim, out_features=10).to(device)

        return


    def forward(self, x):
        """ Forward pass through model """

        b_size, n_channels, n_rows, n_cols = x.shape
        h, c = self.init_state(b_size=b_size, device=x.device)

        # embedding rows
        x_rowed = x.view(b_size, n_channels*n_rows, n_cols)
        embeddings = self.encoder(x_rowed)

        # classifying
        if self.use_own:
            lstm_out = self.lstm(x_rowed, (h,c))
            y = self.classifier(lstm_out)
        else:
            lstm_out, (h_out, c_out) = self.lstm(x_rowed, (h,c))
            y = self.classifier(lstm_out[:, -1, :])  # feeding only output at last layer

        return y


    def init_state(self, b_size, device):
        """ Initializing hidden and cell state """
        if(self.init_mode == "zeros"):
            h = torch.zeros(self.num_layers, b_size, self.hidden_dim)
            c = torch.zeros(self.num_layers, b_size, self.hidden_dim)
        elif(self.init_mode == "random"):
            h = torch.randn(self.num_layers, b_size, self.hidden_dim)
            c = torch.randn(self.num_layers, b_size, self.hidden_dim)
        elif(self.init_mode == "learned"):
            h = self.learned_h.repeat(1, b_size, 1)
            c = self.learned_c.repeat(1, b_size, 1)
        h = h.to(device)
        c = c.to(device)
        return h, c
    


# Training and Evaluation

Defining the models

In [49]:
torch_lstm = SequentialClassifier(input_dim=28, emb_dim=64, hidden_dim=128, use_own=False, init_mode="zeros").to(device)
own_lstm = SequentialClassifier(input_dim=28, emb_dim=64, hidden_dim=128, use_own=True, init_mode="zeros").to(device)

In [45]:
count_model_params(torch_lstm)

84042

In [46]:
count_model_params(own_lstm)

84042

Our model and the model using nn.LSTM have the same number of parameters.

## Using nn.LSTM

In [47]:
# classification loss function
criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer = torch.optim.Adam(torch_lstm.parameters(), lr=3e-4)

# Decay LR by a factor of 0.1 every 7 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.2)

train_loss, val_loss, loss_iters, valid_acc = train_model(
        model=torch_lstm, optimizer=optimizer, scheduler=scheduler, criterion=criterion,
        train_loader=train_loader, valid_loader=test_loader, num_epochs=10,
        save_path='model/nnLSTM.pt'
    )

Epoch 1 Iter 21: loss 2.26094. :   9%|▉         | 21/235 [00:01<00:12, 17.15it/s]


KeyboardInterrupt: 

## Using myLSTM

In [48]:
# classification loss function
criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer = torch.optim.Adam(own_lstm.parameters(), lr=3e-4)

# Decay LR by a factor of 0.1 every 7 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.2)

train_loss, val_loss, loss_iters, valid_acc = train_model(
        model=own_lstm, optimizer=optimizer, scheduler=scheduler, criterion=criterion,
        train_loader=train_loader, valid_loader=test_loader, num_epochs=10,
        save_path='model/ownLSTM.pt'
    )

Epoch 1 Iter 235: loss 0.65930. : 100%|██████████| 235/235 [00:15<00:00, 15.49it/s]


Epoch 1/10
    Train loss: 1.28809
    Valid loss: 2.3075
    Accuracy: 6.660000000000001%




Epoch 2 Iter 235: loss 0.46997. : 100%|██████████| 235/235 [00:15<00:00, 15.63it/s]


Epoch 2/10
    Train loss: 0.62402
    Valid loss: 0.72967
    Accuracy: 73.45%




Epoch 3 Iter 235: loss 0.57167. : 100%|██████████| 235/235 [00:14<00:00, 16.04it/s]


Epoch 3/10
    Train loss: 0.5232
    Valid loss: 0.57275
    Accuracy: 78.96%




Epoch 4 Iter 100: loss 0.46484. :  43%|████▎     | 100/235 [00:06<00:08, 16.05it/s]


KeyboardInterrupt: 