<a href="https://colab.research.google.com/github/jyjoon001/EEE4178/blob/main/RNN_prototype.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Hyper parameters
# Each image is presented to the network as `sequence_length` time steps of
# `input_size` features (see the reshape in the training loop).
sequence_length = 28
input_size = 28
hidden_size = 128       # width of the LSTM hidden state
num_layers = 10         # number of stacked LSTM layers
num_classes = 10        # output size of the final linear layer
batch_size = 50         # samples per mini-batch in both data loaders
num_epochs = 3          # full passes over the training loader
learning_rate = 0.001   # step size handed to the Adam optimizer
# NOTE(review): the recorded training log shows the loss returning to ~2.30
# (about ln(10)) at the end of epoch 1 and staying there; possibly the
# 10-layer stack is hard to optimise with these settings -- confirm.

In [None]:
import torchvision
import torchvision.transforms as transforms

In [None]:
# MNIST train / test splits stored under ./datasets. `download=True` asks
# torchvision to fetch the files when needed, and ToTensor() converts each
# image to a tensor as it is read.
mnist_root = './datasets'
train_data = torchvision.datasets.MNIST(
    root=mnist_root,
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_data = torchvision.datasets.MNIST(
    root=mnist_root,
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

torch.Size([50, 1, 28, 28])


In [None]:
# Mini-batch iterators over both splits. Only the training data is shuffled;
# the test loader keeps the dataset order.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, shuffle=False)

In [None]:
# cf) Sanity check: pull the first test batch and show the image tensor shape.
first_batch = iter(test_loader)
image, label = next(first_batch)
print(image.shape) # [Batch, Channel, Height, Width]

In [None]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class RNN(nn.Module):
  """Stacked LSTM followed by a linear classifier on the last time step.

  Args:
    intput_size: number of features per time step. (The misspelling is kept
      so existing keyword callers keep working.)
    hidden_size: width of the LSTM hidden and cell states.
    num_layers: number of stacked LSTM layers.
    num_classes: number of output logits.
  """

  def __init__(self, intput_size, hidden_size, num_layers, num_classes):
    super().__init__()
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    # Use the constructor argument rather than a module-level `input_size`
    # global, so the model is built with the size the caller asked for.
    self.lstm = nn.LSTM(intput_size, hidden_size, num_layers, batch_first=True)
    self.fc = nn.Linear(hidden_size, num_classes)

  def forward(self, x):
    """Return class logits of shape [batch, num_classes].

    Args:
      x: input tensor of shape [batch, seq_length, input_size]
        (the LSTM is created with batch_first=True).
    """
    # Zero initial hidden and cell states, shape [num_layers, batch, hidden_size].
    # They are created directly on the input's device/dtype so the module
    # works wherever it (and its input) has been moved.
    state_shape = (self.num_layers, x.size(0), self.hidden_size)
    h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
    c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)

    # Forward propagate LSTM
    out, _ = self.lstm(x, (h0, c0)) # output: tensor [batch_size, seq_length, hidden_size]

    # Decode the hidden state of the last time step
    out = self.fc(out[:, -1, :])

    return out

# Build the network from the hyper parameters, then move its parameters to `device`.
model = RNN(input_size, hidden_size, num_layers, num_classes)
model = model.to(device)

In [None]:
# Cross-entropy loss on the model's logits, optimised with Adam over every
# model parameter at the configured learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
####### Train #######
# One optimisation step per mini-batch; the current loss is printed every 100 steps.
total_step = len(train_loader)
for epoch in range(num_epochs):
  for i, (image, label) in enumerate(train_loader):
    # Reshape each batch to [batch, sequence_length, input_size] for the LSTM.
    image = image.reshape(-1, sequence_length, input_size)
    image, label = image.to(device), label.to(device)

    # Clear gradients left from the previous step before computing new ones.
    optimizer.zero_grad()

    # Forward
    output = model(image)
    loss = criterion(output, label)

    # Backward and optimize
    loss.backward()
    optimizer.step()

    # Only report on every 100th step.
    if (i+1) % 100 != 0:
      continue
    print("Epoch [{}/{}], Step[{}/{}], Loss:{:.4f}".format(epoch+1, num_epochs, i+1, total_step, loss.item()))

Epoch [1/3], Step[100/1200], Loss:2.1105
Epoch [1/3], Step[200/1200], Loss:1.7216
Epoch [1/3], Step[300/1200], Loss:1.2655
Epoch [1/3], Step[400/1200], Loss:1.5544
Epoch [1/3], Step[500/1200], Loss:1.4535
Epoch [1/3], Step[600/1200], Loss:1.5671
Epoch [1/3], Step[700/1200], Loss:1.5063
Epoch [1/3], Step[800/1200], Loss:1.5551
Epoch [1/3], Step[900/1200], Loss:1.5996
Epoch [1/3], Step[1000/1200], Loss:1.2172
Epoch [1/3], Step[1100/1200], Loss:1.2952
Epoch [1/3], Step[1200/1200], Loss:2.3168
Epoch [2/3], Step[100/1200], Loss:2.2919
Epoch [2/3], Step[200/1200], Loss:2.3120
Epoch [2/3], Step[300/1200], Loss:2.2868
Epoch [2/3], Step[400/1200], Loss:2.2999
Epoch [2/3], Step[500/1200], Loss:2.3644
Epoch [2/3], Step[600/1200], Loss:2.2960
Epoch [2/3], Step[700/1200], Loss:2.2842
Epoch [2/3], Step[800/1200], Loss:2.2931
Epoch [2/3], Step[900/1200], Loss:2.2944
Epoch [2/3], Step[1000/1200], Loss:2.3085
Epoch [2/3], Step[1100/1200], Loss:2.2841
Epoch [2/3], Step[1200/1200], Loss:2.3150
Epoch [3/3