In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pytorch_lightning as pl
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader, random_split
from torch import optim

  from .autonotebook import tqdm as notebook_tqdm


## Download the dataset and load the dataset

In [3]:
dataset = FashionMNIST(root='data/', download=True,transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])
train =  DataLoader(train, batch_size=64, shuffle=True,drop_last=True)
val = DataLoader(val, batch_size=64, shuffle=True,drop_last=True)

In [4]:
class SimpleRNN(nn.Module):
    def __init__(self,input_size,hidden_size,output_size,num_layers = 1) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size,hidden_size,num_layers,batch_first = False)
        self.fc = nn.Linear(hidden_size,output_size)
    def forward(self,x):
        h, _ = self.rnn(x) # x = (seq_len,batch_size,input_size)
        x = self.fc(h)
        return x.softmax(dim = 1) # softmax for classification


## Define the RRN 
* Input_model = 28*28
* Sequence lenght = 1 ( just one channel )
* Hidden_size = 64
* output_size = 10


In [None]:
model = SimpleRNN(28*28,64,10)

## Checking that the RNN works by running the training loop


In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr=0.001)


loss_list = np.array([])
model.train()
for epoch in range(1_000):
    loss_list = np.array([])
    print(f"Epoch {epoch} de 10")
    for i, batch_idex in enumerate(train):
        optimizer.zero_grad()
        x, y = batch_idex
        x = x.permute(1,0,2,3)
        x = x.reshape(1,64,28*28)
        y = F.one_hot(y, num_classes=10).view(10,64)
        y = y.float()
        y = y.reshape(10,1,64)
        y = y.permute(1,2,0)
        pred = model(x)
        loss = criterion(pred,y)
        loss_list = np.append(loss_list,loss.item())
        loss.backward()
        optimizer.step()
    print(f"Loss: {loss_list.mean()}")

## Generating the PyTorch Lightning module

In [10]:
class LitSimpleRNN(pl.LightningModule):
    def __init__(self,input_size,hidden_size,output_size,num_layers = 1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers

        self.RNN = SimpleRNN(input_size=input_size,hidden_size=hidden_size,output_size=output_size,num_layers=num_layers)
    
    def hot_encode(self,y, num_classes=10):
        """
        One hot encode an int
        """
        y = F.one_hot(y, num_classes=10).view(10,64)
        return y.float()

    
    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.permute(1,0,2,3)
        x = x.reshape(1,64,28*28)
        y = self.hot_encode(y,10)
        # y = F.one_hot(y, num_classes=10).view(10,64)
        # y = y.float()
        y = y.reshape(10,1,64)
        y = y.permute(1,2,0)
        pred = self.RNN(x)
        loss = nn.CrossEntropyLoss()(pred,y)
        # Logging to TensorBoard by default
        self.log("train_loss", loss)
        return loss
    def validation_step(self, batch, batch_idx):
        pass
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [11]:
RNN = LitSimpleRNN(28*28,64,10)
trainer = pl.Trainer(limit_predict_batches=100,max_epochs=1)
trainer.fit(model = RNN,train_dataloaders=train)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Missing logger folder: /Users/jaime/repos/nlp_models/notebooks/lightning_logs

  | Name | Type      | Params
-----------------------------------
0 | RNN  | SimpleRNN | 55.1 K
-----------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0: 100%|██████████| 859/859 [00:12<00:00, 69.23it/s, loss=26.6, v_num=0]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 859/859 [00:12<00:00, 69.19it/s, loss=26.6, v_num=0]
