In [1]:
from tqdm import tqdm

import torch 
import torchvision 
from torchvision import transforms, datasets, models
import numpy as np 
import matplotlib.pyplot as plt 
import torch.nn as nn
from torch.utils.data import DataLoader

In [2]:
transform = transforms.Compose([
    transforms.Resize(size=(224,224)),
    transforms.ToTensor(),
    transforms.Normalize((.5,.5,.5),(.5,.5,.5))
    
])

In [3]:
train_set = datasets.CIFAR10(download=True, train=True, transform=transform,root='./data')
test_set = datasets.CIFAR10(download=True, train=False, transform=transform,root='./data')

Files already downloaded and verified
Files already downloaded and verified


In [4]:
train_loader = DataLoader(train_set, batch_size=64, num_workers=0, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, num_workers=0, shuffle=False)

In [5]:
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.backbone = models.resnet50(pretrained=True)
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.features = nn.Linear(self.backbone.fc.in_features, 10)
        self.backbone.fc = self.features
        
    def forward(self, x):
        x = self.backbone(x)
        return x
        
        

In [6]:
def calc_accuracy(outputs, targets):
    _, pred = torch.max(outputs, 1)
    acc = torch.sum(pred==targets.data).item() / len(targets)
    return acc 
    

In [7]:
def train(model, train_loader, test_loader, num_epochs, lr):
    device = torch.device('cuda:0')
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        epoch_acc = 0.0
        # training
        model.train()
        for imgs, targets in tqdm(train_loader):
            #imgs, targets = items
            #print(imgs.shape)
            #print(targets.shape)
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = model(imgs)
            #print(outputs)
            curr_loss = criterion(outputs, targets)
            
            epoch_loss += curr_loss.item()
            epoch_acc += calc_accuracy(outputs, targets)
            
            optimizer.zero_grad()
            curr_loss.backward()
            optimizer.step()
        epoch_loss = epoch_loss / len(train_loader)
        epoch_acc = epoch_acc / len(train_loader)
        print('Epoch {}, loss: {:.4f}, acc: {:.4f}'.format(epoch, epoch_loss, epoch_acc))
        
        # testing
        model.eval()
        epoch_loss = 0.0
        epoch_acc = 0.0
        for imgs, targets in test_loader:
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = model(imgs)
            #print(outputs)
            curr_loss = criterion(outputs, targets)
            
            epoch_loss += curr_loss.item()
            epoch_acc += calc_accuracy(outputs, targets)
            
        epoch_loss = epoch_loss / len(test_loader)
        epoch_acc = epoch_acc / len(test_loader)
        print('val_loss: {:.4f}, val_acc: {:.4f}'.format(epoch_loss, epoch_acc))
        
            
    
    

In [8]:
model = MyModel()
train(model, train_loader, test_loader, 10, 1e-4)

100%|██████████| 782/782 [02:41<00:00,  4.84it/s]


Epoch 0, loss: 1.2752, acc: 0.6737


  0%|          | 0/782 [00:00<?, ?it/s]

val_loss: 0.8544, val_acc: 0.7604


100%|██████████| 782/782 [02:41<00:00,  4.84it/s]


Epoch 1, loss: 0.7863, acc: 0.7671


  0%|          | 0/782 [00:00<?, ?it/s]

val_loss: 0.7045, val_acc: 0.7802


100%|██████████| 782/782 [02:41<00:00,  4.83it/s]


Epoch 2, loss: 0.6863, acc: 0.7837


  0%|          | 0/782 [00:00<?, ?it/s]

val_loss: 0.6418, val_acc: 0.7906


100%|██████████| 782/782 [02:41<00:00,  4.84it/s]


Epoch 3, loss: 0.6417, acc: 0.7914


  0%|          | 0/782 [00:00<?, ?it/s]

val_loss: 0.6159, val_acc: 0.7931


100%|██████████| 782/782 [02:41<00:00,  4.84it/s]


Epoch 4, loss: 0.6161, acc: 0.7951


  0%|          | 0/782 [00:00<?, ?it/s]

val_loss: 0.5925, val_acc: 0.8009


100%|██████████| 782/782 [02:41<00:00,  4.84it/s]


Epoch 5, loss: 0.5948, acc: 0.8004


  0%|          | 0/782 [00:00<?, ?it/s]

val_loss: 0.5801, val_acc: 0.8039


100%|██████████| 782/782 [02:41<00:00,  4.84it/s]


Epoch 6, loss: 0.5782, acc: 0.8048


  0%|          | 0/782 [00:00<?, ?it/s]

val_loss: 0.5690, val_acc: 0.8061


100%|██████████| 782/782 [02:41<00:00,  4.84it/s]


Epoch 7, loss: 0.5697, acc: 0.8063


  0%|          | 0/782 [00:00<?, ?it/s]

val_loss: 0.5640, val_acc: 0.8044


100%|██████████| 782/782 [02:41<00:00,  4.83it/s]


Epoch 8, loss: 0.5643, acc: 0.8067


  0%|          | 0/782 [00:00<?, ?it/s]

val_loss: 0.5607, val_acc: 0.8100


100%|██████████| 782/782 [02:41<00:00,  4.83it/s]


Epoch 9, loss: 0.5535, acc: 0.8097
val_loss: 0.5521, val_acc: 0.8091
