In [None]:
import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# 1. Hyper-parameters and Dataset

In [None]:
# Compute device: use the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

# Input geometry: later cells reshape every image batch to
# (-1, sequence_length, input_size) before feeding it to the model.
sequence_length, input_size = 28, 28

# Model size
hidden_size, num_layers = 128, 2   # LSTM hidden width, number of stacked LSTM layers
num_classes = 10                   # number of output logits

# Optimisation settings
batch_size = 100
num_epochs = 2
learning_rate = 0.01

In [None]:
# Mount Google Drive at /content/drive so the dataset folder used below is reachable.
# NOTE(review): `google.colab` is presumably only importable inside Google Colab —
# confirm before running this notebook in another environment.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Folder under the mounted Google Drive where the MNIST files are stored
data_dir = '/content/drive/My Drive/PyTorch/Github_Series/02-intermediate/'

# MNIST dataset; ToTensor() converts each image to a tensor.
# download=True on the training split fetches the files into data_dir if needed.
train_dataset = torchvision.datasets.MNIST(root=data_dir,
                                           train=True,
                                           download=True,
                                           transform=transforms.ToTensor())

test_dataset = torchvision.datasets.MNIST(root=data_dir,
                                          train=False,
                                          transform=transforms.ToTensor())

# Data loaders
# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

# Evaluation does not benefit from shuffling; a fixed order keeps the
# per-batch results reproducible between runs.
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# 2. Modeling and Training

**Theory conclusion**
1. The basic units of an LSTM network are LSTM layers, each of which consists of multiple LSTM cells.
2. Each cell has an internal cell state, often abbreviated as "c", and the cell's output is called the "hidden state", abbreviated as "h".
3. $h^{(t)}$ is a non-linear transformation of $c^{(t)}$: $h^{(t)} = o^{(t)} \odot \tanh(c^{(t)})$, where $o^{(t)}$ is the output gate.

**Implementation** \\
For the dimensions of each layer's parameters, refer to [the documentation of LSTM](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html).

In [None]:
# Recurrent neural network (many-to-one)
class RNN(nn.Module):
  """Stacked LSTM followed by a linear classifier on the last time step.

  Args:
    input_size: number of features per time step.
    hidden_size: width of the LSTM hidden state.
    num_layers: number of stacked LSTM layers.
    num_classes: number of output logits.
  """

  def __init__(self, input_size, hidden_size, num_layers, num_classes):
    super().__init__()
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    # batch_first=True: the LSTM consumes and produces (batch, seq, feature) tensors
    self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
    self.fc = nn.Linear(hidden_size, num_classes)

  def forward(self, x):
    """Classify a batch of sequences.

    Args:
      x: tensor of shape (batch_size, sequence_length, input_size).

    Returns:
      Tensor of shape (batch_size, num_classes) holding the class logits.
    """
    # Zero initial hidden and cell states, allocated on the same device and
    # with the same dtype as the input, so the module does not depend on a
    # notebook-level `device` variable.
    h0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)
    c0 = x.new_zeros(self.num_layers, x.size(0), self.hidden_size)

    # Forward propagate LSTM; the final (hn, cn) states are not needed.
    # out: (batch_size, sequence_length, hidden_size)
    out, _ = self.lstm(x, (h0, c0))

    # Decode the hidden state of the last time step -> (batch_size, num_classes)
    return self.fc(out[:, -1, :])

In [None]:
# Build a fresh model on the selected device
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)

# Loss and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)  # number of batches per epoch, used in the progress line
for epoch in range(num_epochs):
  for batch_id, (images, labels) in enumerate(train_loader):
    # Reshape each image batch to (batch, sequence_length, input_size) for the LSTM.
    # Named `inputs` rather than `input` so the builtin input() is not
    # shadowed in the kernel namespace.
    inputs = images.reshape(-1, sequence_length, input_size).to(device)
    labels = labels.to(device)

    # Forward pass
    output = model(inputs)
    loss = loss_fn(output, labels)

    # Backward pass and parameter update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Report the current batch loss every 100 steps
    if (batch_id + 1) % 100 == 0:
      print('Epoch: [{}/{}], Step: [{}/{}], Loss: {:.4f}'
            .format(epoch+1, num_epochs, batch_id+1, total_step, loss.item()))

Epoch: [1/2], Step: [100/600], Loss: 0.5942
Epoch: [1/2], Step: [200/600], Loss: 0.2171
Epoch: [1/2], Step: [300/600], Loss: 0.1453
Epoch: [1/2], Step: [400/600], Loss: 0.1736
Epoch: [1/2], Step: [500/600], Loss: 0.1378
Epoch: [1/2], Step: [600/600], Loss: 0.1477
Epoch: [2/2], Step: [100/600], Loss: 0.0875
Epoch: [2/2], Step: [200/600], Loss: 0.0422
Epoch: [2/2], Step: [300/600], Loss: 0.1576
Epoch: [2/2], Step: [400/600], Loss: 0.0377
Epoch: [2/2], Step: [500/600], Loss: 0.0994
Epoch: [2/2], Step: [600/600], Loss: 0.1190


# 3. Test the model

In [None]:
# Test the model
model.eval()
# Gradients are not needed for evaluation; disabling them saves memory and time.
with torch.no_grad():
  total = 0
  correct = 0
  for images, labels in test_loader:
    # Named `inputs` rather than `input` so the builtin input() is not shadowed.
    inputs = images.reshape(-1, sequence_length, input_size).to(device)
    labels = labels.to(device)
    output = model(inputs)
    # Index of the largest logit along the class dimension is the predicted label
    _, pred = torch.max(output, dim=1)
    total += labels.size(0)
    # .item() turns the 0-dim count tensor into a Python int, so the accuracy
    # below is computed with plain Python arithmetic instead of float32 tensor
    # arithmetic (which printed values such as 97.97999572753906).
    correct += (pred == labels).sum().item()
  print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

Test Accuracy of the model on the 10000 test images: 97.97999572753906 %
