<a href="https://colab.research.google.com/github/martinpius/PYTORCH/blob/main/GRU_RECARP_WITH_MNIST_IN_PYTORCH.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount("/content/drive", force_remount = True)
try:
  COLAB = True
  import torch
  print(f"You are on Gooogle CoLaB with Pytorch version: {torch.__version__}")
except Exception as e:
  print(f"{type(e)}: {e}\n>>>please correct {type(e)} and reload...")
  COLAB = False
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')
def time_fmt(t: float = 123.784)->float:
  h = int(t / (60 * 60))
  m = int(t % (60 * 60) / 60)
  s = int(t % 60)
  return f"{h}: {m:>02}: {s:>05.2f}"
print(f">>>time formating: please wait...\n>>>time elapse: {time_fmt()}")

Mounted at /content/drive
You are on Gooogle CoLaB with Pytorch version: 1.8.1+cu101
>>>time formating: please wait...
>>>time elapse: 0: 02: 03.00


In [2]:
#In this notebook we are going to train a biderectional rnn with GRU architecture to 
#predict mnist images(here for demo we treat the width and height of mnist images as
#deatures dimension and sequence length respectivelly): channels dimension will be shredded

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from tqdm import tqdm
import time, sys


In [4]:
#Hyperparameters for the model
batch_size = 64
EPOCHS = 15
input_size = 28
hidden_dim = 512
sequence_length = 28
learning_rate = 1e-3
num_layers = 2
num_classes = 10


In [5]:
#Model class (Biderectional rnn with gru architecture wit 2 layers) plus an output layer

In [6]:
class GRU(nn.Module):
  def __init__(self, input_size, hidden_dim, num_layers,num_classes):
    super(GRU, self).__init__()
    self.num_layers = num_layers
    self.hidden_dim = hidden_dim
    self.gru = nn.GRU(input_size,
                      hidden_dim,
                      num_layers,
                      batch_first = True, 
                      dropout = 0.25, 
                      bidirectional = True)
    self.fc = nn.Linear(hidden_dim*2, num_classes)
  
  def forward(self, x):
    #initialize the hidden state to zeros
    h0 = torch.rand(self.num_layers*2, x.size(0), self.hidden_dim).to(device = device)
    out,_ = self.gru(x,h0)
    out = out[:,-1,:]
    return self.fc(out)
  

In [7]:
#Instantiate the model class and print the output-shape for a random generated data:(batch = 64, num_classes = 10)
model = GRU(input_size, hidden_dim,num_layers,num_classes).to(device = device)
rnd_data = torch.rand(64,1,28,28)[:,-1,:].to(device = device)
print(f"the output_shape: {model(rnd_data).shape}")

the output_shape: torch.Size([64, 10])


In [8]:
#Load the data and split into batches of 64

In [9]:
train_data = datasets.MNIST(root = '/train_data', train = True, transform = transforms.ToTensor(), download = True)
test_data = datasets.MNIST(root = 'test_data', train = False, transform = transforms.ToTensor(), download = True)
train_loader = DataLoader(dataset = train_data, shuffle = True, batch_size = batch_size)
test_loader = DataLoader(dataset = test_data, shuffle = True, batch_size = batch_size)
x_batch, y_batch = next(iter(train_loader))
print(f"x_batch_shape: {x_batch.shape}\ty_batch_shape: {y_batch.shape}")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /train_data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting /train_data/MNIST/raw/train-images-idx3-ubyte.gz to /train_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /train_data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting /train_data/MNIST/raw/train-labels-idx1-ubyte.gz to /train_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /train_data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting /train_data/MNIST/raw/t10k-images-idx3-ubyte.gz to /train_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /train_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting /train_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /train_data/MNIST/raw

Processing...
Done!
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to test_data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting test_data/MNIST/raw/train-images-idx3-ubyte.gz to test_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to test_data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting test_data/MNIST/raw/train-labels-idx1-ubyte.gz to test_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to test_data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting test_data/MNIST/raw/t10k-images-idx3-ubyte.gz to test_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to test_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting test_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to test_data/MNIST/raw

Processing...
Done!
x_batch_shape: torch.Size([64, 1, 28, 28])	y_batch_shape: torch.Size([64])


In [10]:
#Optimizer and loss objects:
optimizer = optim.Adam(params = model.parameters(), lr = learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [11]:
#Training loop from scratch:
tic = time.time()
for epoch in range(EPOCHS):
  print(f"\n>>>training starts at epoch: {epoch + 1}\n>>>please wait while model is training...\ntraining on progress: KEEP YOUR SCREEN ACTIVE...")
  for idx, (data, target) in enumerate(tqdm(train_loader)):
    #Assign data to GPU
    data = data[:, -1, :].to(device = device)
    target = target.to(device = device)
    #Forward Pass
    preds = model(data)
    train_loss = loss_fn(preds, target)
    #backward pass:
    optimizer.zero_grad()
    train_loss.backward()
    #Gradient descent using adam optimizer
    optimizer.step()

#Model performance evaluation(both training and validation data)
def model_validation(loader, model):
  if loader.dataset.train:
    print(f"\n>>>Checking performance over the train dataset\n>>>Please wait while checking...\n>>>MAINTAIN YOUR SCREEN ACTIVITY!!!")
  else:
    print(f"\n>>>Checking performance over the validation data\n>>>Please wait while checking...\n>>>MAINTAIN SREEN ACTIVITY")
  
  num_examplea = 0
  num_correct = 0
  model.eval()
  #No need to recomput the gradients:
  with torch.no_grad():
    for x, y in loader:
      x = x[:,-1,:].to(device = device)
      y = y.to(device = device)
      preds = model(x)
      _,predictions = preds.max(1) #Grab maximum proba among 10 classes
      num_correct+=(predictions == y).sum()#Total of all corect classification
      num_examplea+= predictions.size(0) #Total examples in a batch (64 for our case)
  model.train()
  return float(num_correct/num_examplea)
toc = time.time()
print(f"\n>>>Accuracy for the training data: {model_validation(train_loader, model):.2f}")
print(f"\n>>>Accuracy over the test data: {model_validation(test_loader, model):.2f}")
print(f"\n>>>time elapse for training and validation: {time_fmt(toc - tic)}")

  0%|          | 0/938 [00:00<?, ?it/s]


>>>training starts at epoch: 1
>>>please wait while model is training...
training on progress: KEEP YOUR SCREEN ACTIVE...


100%|██████████| 938/938 [00:25<00:00, 36.98it/s]
  1%|          | 5/938 [00:00<00:23, 40.37it/s]


>>>training starts at epoch: 2
>>>please wait while model is training...
training on progress: KEEP YOUR SCREEN ACTIVE...


100%|██████████| 938/938 [00:25<00:00, 36.27it/s]
  1%|          | 5/938 [00:00<00:23, 39.63it/s]


>>>training starts at epoch: 3
>>>please wait while model is training...
training on progress: KEEP YOUR SCREEN ACTIVE...


100%|██████████| 938/938 [00:26<00:00, 35.51it/s]
  0%|          | 4/938 [00:00<00:23, 39.90it/s]


>>>training starts at epoch: 4
>>>please wait while model is training...
training on progress: KEEP YOUR SCREEN ACTIVE...


100%|██████████| 938/938 [00:26<00:00, 34.75it/s]
  0%|          | 4/938 [00:00<00:23, 39.44it/s]


>>>training starts at epoch: 5
>>>please wait while model is training...
training on progress: KEEP YOUR SCREEN ACTIVE...


100%|██████████| 938/938 [00:27<00:00, 34.44it/s]
  0%|          | 4/938 [00:00<00:23, 38.94it/s]


>>>training starts at epoch: 6
>>>please wait while model is training...
training on progress: KEEP YOUR SCREEN ACTIVE...


100%|██████████| 938/938 [00:27<00:00, 34.67it/s]
  0%|          | 4/938 [00:00<00:23, 39.01it/s]


>>>training starts at epoch: 7
>>>please wait while model is training...
training on progress: KEEP YOUR SCREEN ACTIVE...


100%|██████████| 938/938 [00:27<00:00, 34.64it/s]
  0%|          | 4/938 [00:00<00:23, 39.70it/s]


>>>training starts at epoch: 8
>>>please wait while model is training...
training on progress: KEEP YOUR SCREEN ACTIVE...


100%|██████████| 938/938 [00:27<00:00, 34.67it/s]
  0%|          | 4/938 [00:00<00:23, 39.18it/s]


>>>training starts at epoch: 9
>>>please wait while model is training...
training on progress: KEEP YOUR SCREEN ACTIVE...


100%|██████████| 938/938 [00:27<00:00, 34.65it/s]
  0%|          | 4/938 [00:00<00:23, 39.56it/s]


>>>training starts at epoch: 10
>>>please wait while model is training...
training on progress: KEEP YOUR SCREEN ACTIVE...


100%|██████████| 938/938 [00:27<00:00, 34.66it/s]
  0%|          | 4/938 [00:00<00:23, 39.37it/s]


>>>training starts at epoch: 11
>>>please wait while model is training...
training on progress: KEEP YOUR SCREEN ACTIVE...


100%|██████████| 938/938 [00:26<00:00, 34.76it/s]
  1%|          | 5/938 [00:00<00:24, 38.79it/s]


>>>training starts at epoch: 12
>>>please wait while model is training...
training on progress: KEEP YOUR SCREEN ACTIVE...


100%|██████████| 938/938 [00:26<00:00, 34.75it/s]
  0%|          | 4/938 [00:00<00:23, 38.99it/s]


>>>training starts at epoch: 13
>>>please wait while model is training...
training on progress: KEEP YOUR SCREEN ACTIVE...


100%|██████████| 938/938 [00:26<00:00, 34.75it/s]
  0%|          | 4/938 [00:00<00:23, 39.36it/s]


>>>training starts at epoch: 14
>>>please wait while model is training...
training on progress: KEEP YOUR SCREEN ACTIVE...


100%|██████████| 938/938 [00:27<00:00, 34.74it/s]
  1%|          | 5/938 [00:00<00:23, 39.40it/s]


>>>training starts at epoch: 15
>>>please wait while model is training...
training on progress: KEEP YOUR SCREEN ACTIVE...


100%|██████████| 938/938 [00:26<00:00, 34.76it/s]



>>>Checking performance over the train dataset
>>>Please wait while checking...
>>>MAINTAIN YOUR SCREEN ACTIVITY!!!

>>>Accuracy for the training data: 1.00

>>>Checking performance over the validation data
>>>Please wait while checking...
>>>MAINTAIN SREEN ACTIVITY

>>>Accuracy over the test data: 0.99

>>>time elapse for training and validation: 0: 06: 42.00
