In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using Device",device)

Using Device cuda


In [2]:
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

trainloader = DataLoader(trainset, batch_size = 64, shuffle = True)
testloader = DataLoader(testset, batch_size = 64, shuffle=False)

100%|██████████| 170M/170M [00:13<00:00, 13.0MB/s]


In [3]:
resnet18 = torchvision.models.resnet18(pretrained=True)
for param in resnet18.parameters():
  param.requires_grad = False

resnet18.fc = nn.Linear(resnet18.fc.in_features,10)
resnet18.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 195MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [4]:
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet18.parameters(), lr = 0.001)

In [5]:
!pip install wandb



In [6]:
!wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msamdanikshitij21[0m ([33msamdanikshitij21-indian-institute-of-technology-delhi[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
import wandb

In [8]:
import yaml

with open("config.yaml") as f:
    config = yaml.safe_load(f)

wandb.init(project="resnet18", config=config)

[34m[1mwandb[0m: Currently logged in as: [33msamdanikshitij21[0m ([33msamdanikshitij21-indian-institute-of-technology-delhi[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
# wandb.init(
#     project="resnet18",
#     config = {
#         "epochs":5,
#         "batch_size":64,
#         "learning_rate":0.001,
#         "architecture":"resnet-18",
#         "freeze_base":True
#     }
# )

In [10]:
def train(model,dataloader,criterion,optimizer,epochs):
  model.train()
  for epoch in range(epochs):
    total_loss = 0
    correct_pred = 0
    total_pred = 0

    for images,labels in dataloader:
      images,labels = images.to(device),labels.to(device)

      outputs = model(images)
      optimizer.zero_grad()
      loss = criterion(outputs,labels)
      loss.backward()
      optimizer.step()

      total_loss += loss.item()*labels.size(0)
      # wandb.log({"batch_loss":loss.item()*labels.size(0),"epoch":epoch})

      _,predicted = torch.max(outputs,dim=1)
      correct_pred += (predicted==labels).sum().item()
      total_pred += labels.size(0)
      accuracy = correct_pred/total_pred
      wandb.log({"batch_loss":loss.item()*labels.size(0),"accuracy":accuracy,"epoch":epoch})

    total_epoch_loss = total_loss/total_pred
    epoch_accuracy = correct_pred/total_pred
    wandb.log({"total_epoch_loss":total_epoch_loss,"epoch":epoch})
    print(f"Epoch {epoch + 1}, Loss: {total_epoch_loss}, Accuracy: {epoch_accuracy:.4f}")

In [11]:
train(resnet18,trainloader,loss_fn,optimizer,wandb.config.epochs)

Epoch 1, Loss: 0.8304923354721069, Accuracy: 0.7316
Epoch 2, Loss: 0.6161103435134888, Accuracy: 0.7889
Epoch 3, Loss: 0.5887973208808899, Accuracy: 0.7949
Epoch 4, Loss: 0.5743599039268493, Accuracy: 0.8018
Epoch 5, Loss: 0.5654093657493592, Accuracy: 0.8038


In [12]:
def evaluate(model,dataloader,criterion):
  model.eval()
  test_loss = 0
  correct_pred = 0
  total_pred = 0

  for images,labels in dataloader:
    images,labels = images.to(device),labels.to(device)

    outputs = model(images)
    loss = criterion(outputs,labels)
    test_loss += loss.item()*labels.size(0)

    _,predicted = torch.max(outputs,dim=1)
    correct_pred += (predicted == labels).sum().item()
    total_pred += labels.size(0)

  test_loss /= total_pred
  test_acc = correct_pred/total_pred
  wandb.log({"val_loss": test_loss, "val_acc": test_acc})
  return test_loss, test_acc

In [13]:
val_loss,val_accuracy = evaluate(resnet18,testloader,loss_fn)
print('Accuracy:', val_accuracy)

Accuracy: 0.8021
