In [6]:
import torchvision.transforms as T
import torchvision
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import numpy as np

In [14]:
# Reproducibility: seed PyTorch's global RNG. The shuffling DataLoader below draws its
# permutation from this RNG, and so does any later random op (e.g. weight initialisation)
# when the notebook is run top to bottom.
SEED = 0
torch.manual_seed(SEED)

DATA_ROOT = './data'
# ImageNet per-channel mean/std -- the normalisation the pretrained ResNet weights expect.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# Resize every image to 224x224, convert to a float tensor, then normalise per channel.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

trainset = torchvision.datasets.STL10(root=DATA_ROOT, download=True, split='train', transform=transform)
testset  = torchvision.datasets.STL10(root=DATA_ROOT, download=True, split='test',  transform=transform)

batchsize    = 32
# drop_last=True discards a trailing partial batch, so every training batch has exactly
# `batchsize` samples.
train_loader = DataLoader(trainset, batch_size=batchsize, shuffle=True, drop_last=True)
# The test loader keeps a possibly smaller final batch; accuracy must be computed over
# samples, not averaged over batches.
test_loader  = DataLoader(testset, batch_size=256)

Files already downloaded and verified
Files already downloaded and verified


In [15]:
# Load a pretrained ResNet-18 and use it as a frozen feature extractor: only the newly
# created final linear layer is trained.
weights = torchvision.models.ResNet18_Weights.DEFAULT
resnet = torchvision.models.resnet18(weights=weights)

# Freeze every pretrained parameter.
for p in resnet.parameters():
    p.requires_grad = False

# Replace the classification head. A freshly constructed nn.Linear requires grad by
# default, so it is the only trainable part of the network. The input width is taken
# from the existing head rather than hard-coded.
numclasses = 10  # number of STL10 classes
resnet.fc = nn.Linear(resnet.fc.in_features, numclasses)

lossfun = nn.CrossEntropyLoss()
# Give the optimizer only the trainable parameters (the new head). Passing the frozen
# backbone as well is harmless but pointless, since those tensors never receive grads.
optimizer = torch.optim.Adam(resnet.fc.parameters(), lr=0.001)

In [17]:
# Train the classification head for `numepochs` epochs and record train/test accuracy
# (in percent) after each epoch.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
resnet.to(device)

numepochs = 5

trainAcc = torch.zeros(numepochs)
testAcc  = torch.zeros(numepochs)

for epochi in range(numepochs):

    # ---- training pass ----
    # NOTE(review): train() also switches the frozen backbone's BatchNorm layers to
    # training mode, so their running statistics keep updating on these batches even
    # though their weights are frozen. Use resnet.eval() followed by resnet.fc.train()
    # if the pretrained statistics should be preserved.
    resnet.train()
    nCorrect, nSeen = 0, 0
    for X, y in train_loader:
        X = X.to(device)
        y = y.to(device)

        yHat = resnet(X)
        loss = lossfun(yHat, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        nCorrect += (torch.argmax(yHat, dim=1) == y).sum().item()
        nSeen    += y.numel()
    trainAcc[epochi] = 100 * nCorrect / nSeen

    # ---- evaluation pass ----
    resnet.eval()
    nCorrect, nSeen = 0, 0
    with torch.no_grad():
        for X, y in test_loader:
            X = X.to(device)
            y = y.to(device)

            yHat = resnet(X)

            # Accumulate correct predictions over samples. A plain mean of per-batch
            # accuracies would over-weight the final batch, which can be smaller than
            # the others because test_loader does not drop it.
            nCorrect += (torch.argmax(yHat, dim=1) == y).sum().item()
            nSeen    += y.numel()
    testAcc[epochi] = 100 * nCorrect / nSeen

    print(f'Finished epoch {epochi+1}/{numepochs}. Test accuracy = {testAcc[epochi]:.2f}%')

Finished epoch 1/5. Test accuracy = 94.23%
Finished epoch 2/5. Test accuracy = 94.46%
Finished epoch 3/5. Test accuracy = 94.35%
Finished epoch 4/5. Test accuracy = 94.29%
Finished epoch 5/5. Test accuracy = 94.32%
