<a href="https://colab.research.google.com/github/grimo8805/nptel/blob/main/torch_tensorboard.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from itertools import product
torch.set_printoptions(linewidth=120)
from torch.utils.tensorboard import SummaryWriter


In [None]:
# Load the FashionMNIST training split (fetched into ./data when not already present)
train_set=torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
# Use the GPU when CUDA is available, otherwise stay on the CPU
device='cuda' if torch.cuda.is_available() else 'cpu'

# Hyperparameters: candidate values for each setting
parameters={
    'lr':[0.01,0.001],
    'batch_size':[32,64,128],
    'shuffle':[True,False],
}
# Value lists in declaration order: lr, batch_size, shuffle
param_values=list(parameters.values())


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:02<00:00, 12360400.39it/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 207566.75it/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 3888135.76it/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 5652428.53it/s]

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw






In [None]:
#training loop
# NOTE: CNN and get_num_correct are defined in the helper cell further down in
# this notebook; that cell must have been executed before this one.
# One TensorBoard run is produced per (lr, batch_size, shuffle) combination.
for run_id,(lr,batch_size,shuffle) in enumerate(product(*param_values)):
  print('run_id',run_id+1)
  # Fresh model and optimizer for every combination so runs do not share weights.
  model=CNN().to(device)
  trainloader=torch.utils.data.DataLoader(train_set,batch_size=batch_size,shuffle=shuffle)
  optimizer=optim.Adam(model.parameters(),lr=lr)
  criterion=torch.nn.CrossEntropyLoss()
  comment=f'batch_size={batch_size} lr={lr} shuffle={shuffle}'
  # The context manager closes this run's writer even if training raises, so
  # every run's event file is closed, not just the last one.
  with SummaryWriter(comment=comment) as tb:
    # Log one sample batch and the model graph for this run.
    images,labels=next(iter(trainloader))
    # Follow `device` instead of hard-coding CUDA so the cell also runs on CPU.
    images,labels=images.to(device),labels.to(device)
    grid=torchvision.utils.make_grid(images)
    tb.add_image('images',grid)
    tb.add_graph(model,images)
    for epoch in range(5):
      # total_loss is the sum of per-batch losses over the epoch, so its scale
      # depends on the number of batches (i.e. on batch_size).
      total_loss=0
      total_correct=0
      for images,labels in trainloader:
        images,labels=images.to(device),labels.to(device)
        preds=model(images)
        loss=criterion(preds,labels)
        total_loss+=loss.item()
        total_correct+=get_num_correct(preds,labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
      tb.add_scalar("Loss",total_loss,epoch)
      tb.add_scalar("Correct",total_correct,epoch)
      tb.add_scalar('Accuracy',total_correct/len(train_set),epoch)

      print("batch size:",batch_size,"lr:",lr,"shuffle:",shuffle)
      print("epoch",epoch,"total_correct",total_correct,"loss: ",total_loss)
      print('_______________________________________________')
    # Record the hyperparameters once per run, paired with the metrics of the
    # final epoch, instead of writing a new hparams entry after every epoch.
    tb.add_hparams(
        {"lr": lr,"batch size ": batch_size,"shuffle":shuffle},
        {
            "accuracy":total_correct/len(train_set),
            "loss": total_loss
        }
    )



run_id 1
batch size: 32 lr: 0.01 shuffle: True
epoch 0 total_correct 47415 loss:  1050.8824383318424
_______________________________________________
batch size: 32 lr: 0.01 shuffle: True
epoch 1 total_correct 50359 loss:  807.7478193119168
_______________________________________________
batch size: 32 lr: 0.01 shuffle: True
epoch 2 total_correct 50756 loss:  773.777516014874
_______________________________________________
batch size: 32 lr: 0.01 shuffle: True
epoch 3 total_correct 51168 loss:  756.0815527662635
_______________________________________________
batch size: 32 lr: 0.01 shuffle: True
epoch 4 total_correct 51342 loss:  740.6764306277037
_______________________________________________
run_id 2
batch size: 32 lr: 0.01 shuffle: False
epoch 0 total_correct 45927 loss:  1151.6131959706545
_______________________________________________
batch size: 32 lr: 0.01 shuffle: False
epoch 1 total_correct 49519 loss:  894.3510476350784
_______________________________________________
batch 

In [None]:
#helper function
def get_num_correct(preds,labels):
  """Return how many rows of `preds` have their arg-max (along dim 1) equal to `labels`.

  The count is returned as a plain Python number (via `.item()`), not a tensor.
  """
  predicted_classes=preds.argmax(dim=1)
  hits=predicted_classes.eq(labels)
  return hits.sum().item()
#CNN model
class CNN(nn.Module):
  """Small convolutional classifier: two conv/pool stages followed by three linear layers.

  `fc1` takes 12*4*4 flattened features, which is what the two conv/pool stages
  produce for single-channel 28x28 inputs. The network returns 10 raw scores per
  sample (no softmax is applied).
  """
  def __init__(self):
    super().__init__()
    # Layers are created in the same order as they are applied in forward().
    self.conv1=nn.Conv2d(1,6,kernel_size=5)
    self.conv2=nn.Conv2d(6,12,kernel_size=5)
    self.fc1=nn.Linear(in_features=12*4*4,out_features=120)
    self.fc2=nn.Linear(in_features=120,out_features=60)
    self.out=nn.Linear(in_features=60,out_features=10)
  def forward(self,x):
    """Run the network on a batch `x` and return the output of the final linear layer."""
    # Convolution -> ReLU -> 2x2 max-pool, applied once per conv layer.
    for conv in (self.conv1,self.conv2):
      x=F.max_pool2d(F.relu(conv(x)),kernel_size=2,stride=2)
    # Collapse everything except the batch dimension before the linear layers.
    features=x.flatten(start_dim=1)
    for hidden in (self.fc1,self.fc2):
      features=F.relu(hidden(features))
    return self.out(features)
