In [44]:
import torch
from torch.nn import Flatten
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader

transformations = Compose([ToTensor(),
                           Normalize([0.5], [0.5]),
                           Flatten(),
                           Lambda(lambda img: torch.transpose(img,0,1))])

mnist_data = MNIST(root="sources", download=True, train=True, transform=transformations)

In [56]:
from torch import Tensor
from torch.nn import Module, RNN, Sequential, ReLU, Linear, Tanh, Softmax, CrossEntropyLoss


# Recurrent Neural network
class RecurrentNN(Module):
    def __init__(self, hidden:int, layers:int, no_linearity:str="relu"):

        super(RecurrentNN,self).__init__()
        self.hidden_size = hidden
        self.layers = layers

        self.rnn = RNN(input_size=1,
                       hidden_size=hidden,
                       num_layers=layers,
                       nonlinearity=no_linearity,
                       batch_first=True)

        self.read_out = Sequential(ReLU(), Linear(hidden, 10), Tanh(), Softmax(dim=1))
        self.criteria = CrossEntropyLoss() # Mean square error loss

        self.last_hidden = None

    def forward(self, x:Tensor, y:Tensor=None, save_state:bool=False):

        output, _ = self.rnn(x)
        y_pred = self.read_out(output[:, -1, :]) # we take the last step

        loss = None
        if y is not None:
            loss = self.criteria(y_pred, y)
        return (loss, y_pred.argmax(1)) if loss is not None else y_pred

In [57]:
loader = DataLoader(mnist_data, batch_size=10, shuffle=True)
rnn = RecurrentNN(100,1)

for i in range(0,10):
    for img_seq, label in loader:

        loss, output = rnn(img_seq,label)

torch.Size([10, 784, 1]) torch.Size([10])
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6])


In [36]:
# import torch
# from torch import nn
#
# rnn = nn.RNN(1, 100, 1, batch_first=True) # x , hidden-node , layer
# read_out = nn.Linear(100,10) # hidden-node , out-size
# sf = nn.Softmax(dim=1)
#
# input = torch.randn(3, 768, 1) # batch, steps, dim_x
# output, hn = rnn(input)
#
# print(output.shape, hn.shape) # batch, step, hidden // layer, batch, hidden
# output = read_out(output[:,-1,:])
# print(output.shape, hn.shape)
# print(output)
# output = sf(output)
# print(output.shape, hn.shape)
# print(output)


torch.Size([3, 768, 100]) torch.Size([1, 3, 100])
torch.Size([3, 10]) torch.Size([1, 3, 100])
tensor([[-0.0750,  0.1466, -0.0766, -0.0539, -0.0256, -0.0681,  0.0234,  0.0989,
          0.0811,  0.0197],
        [-0.1121,  0.1839,  0.0368, -0.0843,  0.0686, -0.0572,  0.0466,  0.0676,
         -0.0469,  0.0280],
        [-0.0823,  0.1506,  0.0678, -0.0850,  0.0223, -0.0453,  0.0218,  0.0859,
          0.0097,  0.0573]], grad_fn=<AddmmBackward0>)
torch.Size([3, 10]) torch.Size([1, 3, 100])
tensor([[0.0919, 0.1146, 0.0917, 0.0938, 0.0965, 0.0925, 0.1013, 0.1093, 0.1074,
         0.1010],
        [0.0879, 0.1182, 0.1020, 0.0904, 0.1053, 0.0929, 0.1030, 0.1052, 0.0938,
         0.1011],
        [0.0900, 0.1136, 0.1046, 0.0898, 0.0999, 0.0934, 0.0999, 0.1065, 0.0987,
         0.1035]], grad_fn=<SoftmaxBackward0>)
