In [1]:
import glob
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim

from PIL import Image
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torchvision import transforms, models, datasets
from tqdm import tqdm
import pdb

%matplotlib inline

In [2]:
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using {device} device')

Using cuda device


In [3]:
# Shell escape: list the GPUs that nvidia-smi reports on this machine.
! nvidia-smi -L

GPU 0: NVIDIA GeForce RTX 2080 Ti (UUID: GPU-c85147b5-9f81-f5d0-5e57-ac54e49001d2)
GPU 1: NVIDIA GeForce RTX 2080 Ti (UUID: GPU-2abcbaa3-7827-737e-ce00-5b6759c6a8dd)
GPU 2: NVIDIA GeForce RTX 2080 Ti (UUID: GPU-864b0edf-d0e1-ac20-8000-b111d5387609)


In [4]:
def set_all_seed(seed):
    """Seed NumPy, Python's ``random`` module and PyTorch with the same value.

    Args:
        seed: Integer seed handed unchanged to each library's seeding function.
    """
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)


# Fix the seed once, up front, for the rest of the notebook.
set_all_seed(123)

In [5]:
# Mini-batch size shared by every loader defined in this cell.
batch_size = 1024

# Training-time augmentation: reflect-pad by 4 px, random horizontal flip,
# random 32x32 crop, then conversion to a tensor.
train_transform = transforms.Compose([
    transforms.Pad(4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor(),
])
# Evaluation images are only converted to tensors (no augmentation).
test_transform = transforms.Compose([
    transforms.ToTensor(),
])

# CIFAR-10: the train split is used for fitting, the test split for validation.
train_dataset = datasets.CIFAR10(root='data', train=True, download=True, transform=train_transform)
valid_dataset = datasets.CIFAR10(root='data', train=False, download=True, transform=test_transform)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)

# Reduced-data variants: each pass yields 1/16 or 1/2 as many samples as the
# full training set.
# NOTE(review): RandomSampler draws a fresh set of indices (with replacement)
# on every pass over the loader, so these are not a *fixed* subset of the data
# -- confirm that this is what the experiment intends.
sixteenth_train_sampler = RandomSampler(train_dataset, num_samples=len(train_dataset)//16, replacement=True)
half_train_sampler = RandomSampler(train_dataset, num_samples=len(train_dataset)//2, replacement=True)

# pin_memory=True matches the full-size loaders above.
sixteenth_train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sixteenth_train_sampler, pin_memory=True)
half_train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=half_train_sampler, pin_memory=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified


In [12]:
# ResNet-50 built from the torchvision v0.10.0 hub entry point without
# pretrained weights (pretrained=False).
# HINT: when pretrained ImageNet weights are needed, switch the flag on this
# entry point (the assignment text refers to this as weights="IMAGENET1K_V1").
model_50_nw = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=False)

# Background: the stock ResNet classifier head predicts the 1000 ImageNet classes.
# Replace it with a 10-class head, reading the input width from the existing
# layer instead of hard-coding it.
model_50_nw.fc = torch.nn.Linear(model_50_nw.fc.in_features, 10)

# Move to the target device only after the head swap so the new layer moves too.
model_50_nw = model_50_nw.to(device)

Using cache found in /home/lysee13/.cache/torch/hub/pytorch_vision_v0.10.0


In [None]:
# ResNet-18 built from the torchvision v0.10.0 hub entry point without
# pretrained weights (pretrained=False).
model_18_nw = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)

# Replace the 1000-class head with a 10-class one. The input width is read from
# the existing layer: ResNet-18's final features are narrower than the 2048 of
# ResNet-50, so a hard-coded 2048 would cause a shape mismatch.
model_18_nw.fc = torch.nn.Linear(model_18_nw.fc.in_features, 10)

# Move to the target device only after the head swap so the new layer moves too.
model_18_nw = model_18_nw.to(device)

In [13]:
# Training helper, following the PyTorch tutorial supplied with the assignment.
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch and return ``(avg_loss, accuracy)``.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches supporting ``len()``.
        model: Module to optimise; it is moved to the global ``device`` and put
            in training mode.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        optimizer: Optimizer that updates ``model``'s parameters.

    Returns:
        Tuple ``(avg_epoch_loss, avg_acc)``: the mean of the per-batch losses and
        the fraction of correctly classified samples among the samples actually
        drawn from ``dataloader`` during this epoch.
    """
    num_batches = len(dataloader)
    epoch_loss = 0
    correct = 0
    seen = 0

    model.to(device).train()

    for X, y in tqdm(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        correct += pred.argmax(dim=1).eq(y.view(-1)).sum().item()
        # Count what was really processed: a loader driven by a sampler with
        # num_samples visits fewer items than len(dataloader.dataset), so using
        # the dataset size as the denominator would understate the accuracy.
        seen += X.size(0)

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    avg_epoch_loss = epoch_loss / max(num_batches, 1)
    avg_acc = correct / max(seen, 1)

    return avg_epoch_loss, avg_acc

def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches supporting ``len()``.
        model: Module to evaluate; it is moved to the global ``device`` (as
            ``train`` does) and put in evaluation mode.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar tensor.

    Returns:
        Tuple ``(avg_epoch_loss, avg_acc)``: the mean of the per-batch losses and
        the fraction of correctly classified samples among the samples actually
        drawn from ``dataloader``.
    """
    num_batches = len(dataloader)
    epoch_loss = 0
    correct = 0
    seen = 0

    # Inputs are sent to `device` below, so the model has to live there as well.
    model.to(device).eval()

    with torch.no_grad():
        for X, y in tqdm(dataloader):
            X, y = X.to(device), y.to(device)
            pred = model(X)

            epoch_loss += loss_fn(pred, y).item()
            correct += pred.argmax(dim=1).eq(y.view(-1)).sum().item()
            # Count processed samples rather than trusting len(dataloader.dataset),
            # which is wrong for sampler-driven loaders that visit only a part.
            seen += X.size(0)

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    avg_epoch_loss = epoch_loss / max(num_batches, 1)
    avg_acc = correct / max(seen, 1)

    return avg_epoch_loss, avg_acc

In [14]:
def train_test_loop(train_dataloader, valid_dataloader, model, epochs=100, lr=1e-3):
    """Train ``model`` with Adam and cross-entropy, validating after every epoch.

    Args:
        train_dataloader: Loader passed to ``train`` each epoch.
        valid_dataloader: Loader passed to ``test`` each epoch.
        model: Module to optimise.
        epochs: Number of train/validate rounds (default 100).
        lr: Adam learning rate (default 1e-3).

    Returns:
        The highest validation accuracy seen over all epochs (0 if no epoch
        scored above zero).
    """
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_acc = 0
    # Stays None when no epoch improves on 0, so the summary print below cannot
    # hit an unbound local variable.
    best_epoch = None
    for epoch in range(epochs):
        train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
        test_loss, test_acc = test(valid_dataloader, model, loss_fn)
        if test_acc > best_acc:
            best_acc = test_acc
            # 1-based, matching the "Epoch N" numbering of the per-epoch log line.
            best_epoch = epoch + 1
        print(f"Epoch {epoch + 1:2d}: Loss = {train_loss:.4f} Acc = {train_acc:.2f} Test_Loss = {test_loss:.4f} Test_Acc = {test_acc:.2f}")
    print("Done!")
    print("the acc of test:{}".format(best_acc))
    print("the epoch:{}".format(best_epoch))
    return best_acc

### Large model (resnet50) and small model (resnet18), train_dataloader, weights=None

In [None]:
# Train the ResNet-50 without pretrained weights (`model_50_nw`, built above)
# on the full training loader; the result is the best validation accuracy.
acc_resnet50_1_nw = train_test_loop(train_dataloader, valid_dataloader, model_50_nw)

100%|██████████| 49/49 [00:25<00:00,  1.89it/s]
100%|██████████| 10/10 [00:01<00:00,  5.34it/s]


Epoch  1: Loss = 2.1208 Acc = 0.27 Test_Loss = 1.8860 Test_Acc = 0.31


100%|██████████| 49/49 [00:24<00:00,  2.02it/s]
100%|██████████| 10/10 [00:01<00:00,  5.91it/s]


Epoch  2: Loss = 1.6229 Acc = 0.41 Test_Loss = 1.4864 Test_Acc = 0.46


100%|██████████| 49/49 [00:24<00:00,  2.02it/s]
100%|██████████| 10/10 [00:01<00:00,  5.55it/s]


Epoch  3: Loss = 1.4569 Acc = 0.47 Test_Loss = 1.3159 Test_Acc = 0.52


100%|██████████| 49/49 [00:24<00:00,  2.01it/s]
100%|██████████| 10/10 [00:01<00:00,  5.52it/s]


Epoch  4: Loss = 1.3735 Acc = 0.50 Test_Loss = 2.4390 Test_Acc = 0.42


100%|██████████| 49/49 [00:24<00:00,  1.99it/s]
100%|██████████| 10/10 [00:01<00:00,  5.73it/s]


Epoch  5: Loss = 1.3419 Acc = 0.52 Test_Loss = 1.2852 Test_Acc = 0.54


100%|██████████| 49/49 [00:24<00:00,  2.01it/s]
100%|██████████| 10/10 [00:01<00:00,  5.66it/s]


Epoch  6: Loss = 1.2290 Acc = 0.56 Test_Loss = 1.1780 Test_Acc = 0.58


100%|██████████| 49/49 [00:24<00:00,  2.00it/s]
100%|██████████| 10/10 [00:01<00:00,  5.71it/s]


Epoch  7: Loss = 1.1608 Acc = 0.59 Test_Loss = 1.3066 Test_Acc = 0.54


100%|██████████| 49/49 [00:25<00:00,  1.96it/s]
100%|██████████| 10/10 [00:01<00:00,  5.54it/s]


Epoch  8: Loss = 1.1068 Acc = 0.60 Test_Loss = 1.4031 Test_Acc = 0.53


100%|██████████| 49/49 [00:24<00:00,  2.01it/s]
100%|██████████| 10/10 [00:01<00:00,  5.82it/s]


Epoch  9: Loss = 1.0856 Acc = 0.61 Test_Loss = 1.0238 Test_Acc = 0.64


100%|██████████| 49/49 [00:24<00:00,  1.96it/s]
100%|██████████| 10/10 [00:01<00:00,  5.54it/s]


Epoch 10: Loss = 1.0104 Acc = 0.64 Test_Loss = 1.0117 Test_Acc = 0.64


100%|██████████| 49/49 [00:25<00:00,  1.95it/s]
100%|██████████| 10/10 [00:01<00:00,  5.77it/s]


Epoch 11: Loss = 0.9581 Acc = 0.66 Test_Loss = 1.0453 Test_Acc = 0.63


100%|██████████| 49/49 [00:25<00:00,  1.94it/s]
100%|██████████| 10/10 [00:01<00:00,  5.64it/s]


Epoch 12: Loss = 1.0131 Acc = 0.64 Test_Loss = 1.0407 Test_Acc = 0.64


100%|██████████| 49/49 [00:25<00:00,  1.94it/s]
100%|██████████| 10/10 [00:01<00:00,  5.47it/s]


Epoch 13: Loss = 1.0113 Acc = 0.65 Test_Loss = 2.0229 Test_Acc = 0.56


100%|██████████| 49/49 [00:24<00:00,  2.00it/s]
100%|██████████| 10/10 [00:01<00:00,  5.46it/s]


Epoch 14: Loss = 0.9635 Acc = 0.66 Test_Loss = 1.2024 Test_Acc = 0.60


100%|██████████| 49/49 [00:24<00:00,  1.99it/s]
100%|██████████| 10/10 [00:01<00:00,  5.56it/s]


Epoch 15: Loss = 0.8838 Acc = 0.69 Test_Loss = 0.8918 Test_Acc = 0.69


100%|██████████| 49/49 [00:25<00:00,  1.96it/s]
100%|██████████| 10/10 [00:01<00:00,  5.49it/s]


Epoch 16: Loss = 0.8242 Acc = 0.71 Test_Loss = 0.9820 Test_Acc = 0.67


100%|██████████| 49/49 [00:24<00:00,  2.02it/s]
100%|██████████| 10/10 [00:01<00:00,  5.65it/s]


Epoch 17: Loss = 0.7882 Acc = 0.72 Test_Loss = 0.8734 Test_Acc = 0.69


100%|██████████| 49/49 [00:24<00:00,  2.00it/s]
100%|██████████| 10/10 [00:01<00:00,  5.74it/s]


Epoch 18: Loss = 0.7764 Acc = 0.73 Test_Loss = 0.9982 Test_Acc = 0.65


100%|██████████| 49/49 [00:24<00:00,  2.00it/s]
100%|██████████| 10/10 [00:01<00:00,  5.93it/s]


Epoch 19: Loss = 0.7422 Acc = 0.74 Test_Loss = 0.8939 Test_Acc = 0.69


100%|██████████| 49/49 [00:24<00:00,  2.00it/s]
100%|██████████| 10/10 [00:01<00:00,  5.72it/s]


Epoch 20: Loss = 0.7220 Acc = 0.75 Test_Loss = 0.8901 Test_Acc = 0.70


100%|██████████| 49/49 [00:24<00:00,  2.01it/s]
100%|██████████| 10/10 [00:01<00:00,  5.77it/s]


Epoch 21: Loss = 0.7344 Acc = 0.74 Test_Loss = 0.9224 Test_Acc = 0.69


100%|██████████| 49/49 [00:24<00:00,  2.01it/s]
100%|██████████| 10/10 [00:01<00:00,  5.83it/s]


Epoch 22: Loss = 0.7611 Acc = 0.74 Test_Loss = 0.8901 Test_Acc = 0.71


100%|██████████| 49/49 [00:24<00:00,  2.00it/s]
100%|██████████| 10/10 [00:01<00:00,  5.91it/s]


Epoch 23: Loss = 0.6862 Acc = 0.76 Test_Loss = 0.8335 Test_Acc = 0.72


100%|██████████| 49/49 [00:24<00:00,  1.99it/s]
100%|██████████| 10/10 [00:01<00:00,  5.69it/s]


Epoch 24: Loss = 0.7031 Acc = 0.75 Test_Loss = 0.8515 Test_Acc = 0.71


100%|██████████| 49/49 [00:24<00:00,  2.01it/s]
100%|██████████| 10/10 [00:01<00:00,  5.73it/s]


Epoch 25: Loss = 0.6948 Acc = 0.75 Test_Loss = 1.0583 Test_Acc = 0.68


100%|██████████| 49/49 [00:25<00:00,  1.95it/s]
100%|██████████| 10/10 [00:01<00:00,  5.25it/s]


Epoch 26: Loss = 1.1444 Acc = 0.61 Test_Loss = 5.6338 Test_Acc = 0.23


100%|██████████| 49/49 [00:24<00:00,  1.98it/s]
100%|██████████| 10/10 [00:01<00:00,  5.67it/s]


Epoch 27: Loss = 1.6344 Acc = 0.45 Test_Loss = 1.3995 Test_Acc = 0.51


100%|██████████| 49/49 [00:24<00:00,  2.01it/s]
100%|██████████| 10/10 [00:01<00:00,  5.55it/s]


Epoch 28: Loss = 1.2605 Acc = 0.55 Test_Loss = 1.1162 Test_Acc = 0.60


100%|██████████| 49/49 [00:25<00:00,  1.95it/s]
100%|██████████| 10/10 [00:01<00:00,  5.64it/s]


Epoch 29: Loss = 1.0798 Acc = 0.62 Test_Loss = 1.0201 Test_Acc = 0.64


100%|██████████| 49/49 [00:24<00:00,  1.97it/s]
100%|██████████| 10/10 [00:01<00:00,  5.66it/s]


Epoch 30: Loss = 0.9588 Acc = 0.66 Test_Loss = 0.9360 Test_Acc = 0.68


100%|██████████| 49/49 [00:25<00:00,  1.96it/s]
100%|██████████| 10/10 [00:01<00:00,  5.80it/s]


Epoch 31: Loss = 0.8666 Acc = 0.69 Test_Loss = 0.9098 Test_Acc = 0.68


100%|██████████| 49/49 [00:24<00:00,  1.97it/s]
100%|██████████| 10/10 [00:01<00:00,  5.38it/s]


Epoch 32: Loss = 0.8216 Acc = 0.71 Test_Loss = 0.8472 Test_Acc = 0.71


100%|██████████| 49/49 [00:24<00:00,  1.97it/s]
100%|██████████| 10/10 [00:01<00:00,  5.55it/s]


Epoch 33: Loss = 0.8520 Acc = 0.71 Test_Loss = 1.1112 Test_Acc = 0.64


100%|██████████| 49/49 [00:25<00:00,  1.95it/s]
100%|██████████| 10/10 [00:01<00:00,  5.62it/s]


Epoch 34: Loss = 0.8356 Acc = 0.71 Test_Loss = 0.9923 Test_Acc = 0.66


100%|██████████| 49/49 [00:24<00:00,  1.97it/s]
100%|██████████| 10/10 [00:01<00:00,  5.46it/s]


Epoch 35: Loss = 0.7542 Acc = 0.73 Test_Loss = 0.8113 Test_Acc = 0.72


100%|██████████| 49/49 [00:24<00:00,  1.97it/s]
100%|██████████| 10/10 [00:01<00:00,  5.59it/s]


Epoch 36: Loss = 0.7051 Acc = 0.75 Test_Loss = 0.7432 Test_Acc = 0.74


100%|██████████| 49/49 [00:24<00:00,  1.97it/s]
100%|██████████| 10/10 [00:01<00:00,  5.74it/s]


Epoch 37: Loss = 0.6684 Acc = 0.76 Test_Loss = 0.7940 Test_Acc = 0.73


100%|██████████| 49/49 [00:24<00:00,  1.96it/s]
100%|██████████| 10/10 [00:01<00:00,  5.28it/s]


Epoch 38: Loss = 0.6331 Acc = 0.77 Test_Loss = 0.7294 Test_Acc = 0.75


100%|██████████| 49/49 [00:24<00:00,  1.97it/s]
100%|██████████| 10/10 [00:01<00:00,  5.60it/s]


Epoch 39: Loss = 0.6125 Acc = 0.78 Test_Loss = 0.7735 Test_Acc = 0.74


100%|██████████| 49/49 [00:24<00:00,  1.99it/s]
100%|██████████| 10/10 [00:01<00:00,  5.85it/s]


Epoch 40: Loss = 0.5926 Acc = 0.79 Test_Loss = 0.7663 Test_Acc = 0.74


100%|██████████| 49/49 [00:24<00:00,  2.00it/s]
100%|██████████| 10/10 [00:01<00:00,  5.56it/s]


Epoch 41: Loss = 0.5664 Acc = 0.80 Test_Loss = 0.6886 Test_Acc = 0.77


100%|██████████| 49/49 [00:24<00:00,  1.96it/s]
100%|██████████| 10/10 [00:01<00:00,  5.77it/s]


Epoch 42: Loss = 0.5469 Acc = 0.80 Test_Loss = 0.7267 Test_Acc = 0.75


100%|██████████| 49/49 [00:24<00:00,  1.97it/s]
100%|██████████| 10/10 [00:01<00:00,  5.80it/s]


Epoch 43: Loss = 0.5372 Acc = 0.81 Test_Loss = 0.7419 Test_Acc = 0.75


100%|██████████| 49/49 [00:24<00:00,  2.02it/s]
100%|██████████| 10/10 [00:01<00:00,  5.75it/s]


Epoch 44: Loss = 0.5164 Acc = 0.81 Test_Loss = 0.7576 Test_Acc = 0.75


100%|██████████| 49/49 [00:24<00:00,  2.01it/s]
100%|██████████| 10/10 [00:01<00:00,  5.57it/s]


Epoch 45: Loss = 0.5024 Acc = 0.82 Test_Loss = 0.7796 Test_Acc = 0.74


100%|██████████| 49/49 [00:24<00:00,  2.00it/s]
100%|██████████| 10/10 [00:01<00:00,  5.60it/s]


Epoch 46: Loss = 0.4917 Acc = 0.82 Test_Loss = 0.7153 Test_Acc = 0.77


100%|██████████| 49/49 [00:25<00:00,  1.95it/s]
100%|██████████| 10/10 [00:01<00:00,  5.55it/s]


Epoch 47: Loss = 0.4732 Acc = 0.83 Test_Loss = 0.7946 Test_Acc = 0.74


100%|██████████| 49/49 [00:24<00:00,  1.99it/s]
100%|██████████| 10/10 [00:01<00:00,  5.61it/s]


Epoch 48: Loss = 0.4686 Acc = 0.83 Test_Loss = 0.6640 Test_Acc = 0.78


100%|██████████| 49/49 [00:24<00:00,  1.98it/s]
100%|██████████| 10/10 [00:01<00:00,  5.79it/s]


Epoch 49: Loss = 0.4500 Acc = 0.84 Test_Loss = 0.7730 Test_Acc = 0.74


100%|██████████| 49/49 [00:24<00:00,  1.98it/s]
100%|██████████| 10/10 [00:01<00:00,  5.29it/s]


Epoch 50: Loss = 0.4427 Acc = 0.84 Test_Loss = 0.7202 Test_Acc = 0.77


100%|██████████| 49/49 [00:24<00:00,  1.98it/s]
100%|██████████| 10/10 [00:01<00:00,  5.49it/s]


Epoch 51: Loss = 0.4231 Acc = 0.85 Test_Loss = 0.6797 Test_Acc = 0.78


100%|██████████| 49/49 [00:25<00:00,  1.96it/s]
100%|██████████| 10/10 [00:01<00:00,  5.74it/s]


Epoch 52: Loss = 0.4167 Acc = 0.85 Test_Loss = 0.6429 Test_Acc = 0.79


100%|██████████| 49/49 [00:24<00:00,  1.98it/s]
100%|██████████| 10/10 [00:01<00:00,  5.54it/s]


Epoch 53: Loss = 0.4004 Acc = 0.86 Test_Loss = 0.7193 Test_Acc = 0.77


100%|██████████| 49/49 [00:24<00:00,  2.01it/s]
100%|██████████| 10/10 [00:01<00:00,  5.57it/s]


Epoch 54: Loss = 0.3935 Acc = 0.86 Test_Loss = 0.7688 Test_Acc = 0.76


100%|██████████| 49/49 [00:24<00:00,  2.02it/s]
100%|██████████| 10/10 [00:01<00:00,  5.58it/s]


Epoch 55: Loss = 0.3907 Acc = 0.86 Test_Loss = 0.9817 Test_Acc = 0.71


100%|██████████| 49/49 [00:24<00:00,  1.98it/s]
100%|██████████| 10/10 [00:01<00:00,  5.78it/s]


Epoch 56: Loss = 0.3781 Acc = 0.87 Test_Loss = 0.6790 Test_Acc = 0.79


100%|██████████| 49/49 [00:25<00:00,  1.95it/s]
100%|██████████| 10/10 [00:01<00:00,  5.46it/s]


Epoch 57: Loss = 0.3631 Acc = 0.87 Test_Loss = 0.8125 Test_Acc = 0.76


100%|██████████| 49/49 [00:24<00:00,  1.97it/s]
100%|██████████| 10/10 [00:01<00:00,  5.62it/s]


Epoch 58: Loss = 0.3616 Acc = 0.87 Test_Loss = 0.6732 Test_Acc = 0.79


 24%|██▍       | 12/49 [00:06<00:18,  1.95it/s]

In [None]:
# Train the ResNet-18 without pretrained weights (`model_18_nw`, built above)
# on the full training loader; the result is the best validation accuracy.
acc_resnet18_1_nw = train_test_loop(train_dataloader, valid_dataloader, model_18_nw)