# Import dependencies

In [3]:
import math
from tqdm import tqdm
from datetime import datetime

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import torchvision

  from .autonotebook import tqdm as notebook_tqdm


# Hyperparameters and constants

In [21]:
# Data: where the dataset lives on disk and the fraction of the official
# training set that is kept for training (the rest becomes validation data).
data_path = "./data/"
train_percent = 0.8

# Input layout: every image is reshaped to (sequence_length, input_size)
# before it is fed to the model.
sequence_length = 28
input_size = 28

# Model architecture.
hidden_size = 128
num_layers = 2
num_classes = 10

# Optimisation.
batch_size = 100
num_epochs = 10
learning_rate = 0.01

# Device configuration

In [11]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
device  # last expression of the cell: displays the selected device

device(type='cpu')

# FashionMNIST dataset

In [5]:
# Shared preprocessing applied to every sample: convert the image to a tensor.
to_tensor = torchvision.transforms.ToTensor()

# Official training split; fetched into `data_path` if it is not there yet.
train_valid_dataset = torchvision.datasets.FashionMNIST(root=data_path,
                                                        train=True,
                                                        transform=to_tensor,
                                                        download=True)

# Carve the training split into train / validation parts by `train_percent`.
total_size = len(train_valid_dataset)
train_size = math.ceil(total_size * train_percent)
valid_size = total_size - train_size

# Use a dedicated, seeded generator so the partition is the same on every run.
# With an unseeded split, re-running this cell would move samples between the
# training and validation sets, making validation metrics incomparable across
# runs (and leaking validation data into an already-trained model).
split_seed = 42
split_generator = torch.Generator().manual_seed(split_seed)
train_dataset, valid_dataset = random_split(train_valid_dataset,
                                            lengths=[train_size, valid_size],
                                            generator=split_generator)

# Official test split. `download=True` does nothing when the files are already
# on disk, and keeps this statement working even if it is run on its own.
test_dataset = torchvision.datasets.FashionMNIST(root=data_path,
                                                 train=False,
                                                 transform=to_tensor,
                                                 download=True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:32<00:00, 802951.21it/s] 


Extracting ./data/FashionMNIST\raw\train-images-idx3-ubyte.gz to ./data/FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 133156.88it/s]


Extracting ./data/FashionMNIST\raw\train-labels-idx1-ubyte.gz to ./data/FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:05<00:00, 819231.92it/s] 


Extracting ./data/FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ./data/FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 1555975.86it/s]


Extracting ./data/FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST\raw



# Data Loader

In [7]:
# Mini-batch iterators over the three splits. Training and validation batches
# are drawn in shuffled order; test batches keep the dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# LSTM network

In [53]:
class LSTMNet(torch.nn.Module):
    """Sequence classifier: a stacked LSTM followed by a linear read-out.

    The input is treated as a batch of sequences laid out as
    (batch, sequence, input_size) because the LSTM is built with
    ``batch_first=True``. Only the LSTM output at the last time step is fed to
    the linear layer, which produces one score per class.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        sequence_length: Stored on the instance for reference; the forward
            pass does not use it.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, input_size, hidden_size, num_layers, sequence_length, num_classes=10):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.sequence_length = sequence_length
        # Attribute names are kept as-is so existing state_dict keys still match.
        self.LSTM = torch.nn.LSTM(input_size, self.hidden_size,
                                  self.num_layers, batch_first=True)
        self.Linear = torch.nn.Linear(self.hidden_size, self.num_classes)

    def forward(self, x):
        """Return class scores of shape (batch, num_classes) for input ``x``.

        Args:
            x: Tensor of shape (batch, sequence, input_size).
        """
        # No explicit (h_0, c_0) is passed: nn.LSTM starts from zero states by
        # default. The previous version built zero tensors on a notebook-global
        # `device` and then never used them, which cost two allocations per
        # call and made the module unusable without that global.
        sequence_out, _ = self.LSTM(x)

        # Keep only the output of the final time step for classification.
        last_step = sequence_out[:, -1, :]

        return self.Linear(last_step)

In [54]:
# Build the classifier from the notebook's hyperparameters, then move it to
# the selected device.
model = LSTMNet(input_size=input_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                sequence_length=sequence_length,
                num_classes=num_classes)
model = model.to(device)

# Loss and optimizer

In [55]:
# Cross-entropy loss on the model's raw class scores, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Load tensorboard

In [56]:
# Register the TensorBoard IPython extension so the `%tensorboard` magic used
# later in the notebook is available.
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [57]:
# One log directory per run, named after the moment this cell is executed, so
# successive runs show up as separate entries in TensorBoard.
run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
logdir = "./logs/" + run_stamp
writer = SummaryWriter(log_dir=logdir)

# Train the model

In [58]:
# Train for `num_epochs` epochs. Each epoch runs one pass over the training
# data, one pass over the validation data, and logs the epoch's metrics to
# TensorBoard.
for epoch in range(num_epochs):
    with tqdm(train_dataloader, unit="batch", leave=True, position=0) as tepoch:
        tepoch.set_description(f"Epoch {epoch+1}")

        # ---- Training pass ----
        # Explicitly enter training mode: the test cell switches the model to
        # eval mode, so re-running this cell afterwards would otherwise train
        # in eval mode (which cuDNN-backed RNNs reject in the backward pass).
        model.train()
        train_loss = 0
        # The loader is iterated directly and the bar advanced by hand so the
        # bar is still open when the validation metrics are attached below.
        for images, labels in train_dataloader:
            tepoch.update(1)
            images = images.reshape(-1, sequence_length, input_size).to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate the mean of the per-batch losses over the epoch.
            train_loss += loss.item()/len(train_dataloader)
            tepoch.set_postfix(train_loss=loss.item())

        # ---- Validation pass ----
        model.eval()
        valid_loss = 0
        correct = 0
        total = 0
        valid_accuracy = 0
        with torch.no_grad():
            for images, labels in valid_dataloader:
                images = images.reshape(-1, sequence_length, input_size).to(device)
                labels = labels.to(device)

                outputs = model(images)
                predictions = torch.argmax(outputs, dim=1)
                batch_loss = criterion(outputs, labels)

                valid_loss += batch_loss.item()/len(valid_dataloader)
                correct += (predictions == labels).sum().item()
                total += len(labels)

        valid_accuracy = correct / total

        # ---- Logging ----
        # Log the epoch-averaged training loss. The previous code logged
        # `loss.item()` here, which by this point held the loss of the last
        # *validation* batch, so the "train" curve was wrong.
        writer.add_scalars(main_tag="loss",
                           tag_scalar_dict={"train": train_loss,
                                            "valid": valid_loss},
                           global_step=epoch+1)
        writer.add_scalars(main_tag="accuracy",
                           tag_scalar_dict={"valid": valid_accuracy},
                           global_step=epoch+1)

        tepoch.set_postfix(train_loss=train_loss, valid_loss=valid_loss,
                           valid_accuracy=valid_accuracy)

Epoch 1: 100%|██████████| 480/480 [02:00<00:00,  3.99batch/s, train_loss=0.735, valid_accuracy=0.805, valid_loss=0.541]
Epoch 2: 100%|██████████| 480/480 [01:51<00:00,  4.30batch/s, train_loss=0.439, valid_accuracy=0.847, valid_loss=0.427]
Epoch 3:  34%|███▎      | 161/480 [00:33<01:10,  4.54batch/s, train_loss=0.289]

In [None]:
# Open TensorBoard inline, pointed at this run's log directory (`logdir` is
# interpolated from the Python variable of the same name).
%tensorboard --logdir={logdir}

# Test the model

In [None]:
# Evaluate the trained model on the held-out test set and print the accuracy
# over the whole set.
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_dataloader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)

        outputs = model(images)
        predicted = torch.argmax(outputs, dim=1)

        # Accumulate across batches. Plain assignment here would keep only the
        # last batch's counts and report the accuracy of that batch alone.
        correct += (predicted == labels).sum().item()
        total += len(labels)

    print(f"Test accuracy: {correct/total}")