In [1]:
import pretrained
import torch
import torchvision
import datetime
from tqdm import tqdm

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.cuda.init()

In [3]:
def load_cifar100_dataloaders_with_validation():
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    dataset_train = torchvision.datasets.CIFAR100(".data", download=True, transform=transform)
    dataset_train, dataset_validation = torch.utils.data.random_split(dataset_train, [0.8, 0.2], torch.Generator().manual_seed(123))
    dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=16)
    dataloader_validation = torch.utils.data.DataLoader(dataset_validation, batch_size=16)
    dataset_test = torchvision.datasets.CIFAR100(".data", download=True, train=False, transform=transform)
    dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=16)
    return dataloader_train, dataloader_validation, dataloader_test

In [4]:
def eval_accuracy(model, dataloader):
    with torch.no_grad():
        model.to(device)
        correct = 0
        all_so_far = 0
        for inputs, labels in tqdm(dataloader):
            inputs, labels = inputs.to(device), labels.to(device)
            pred = torch.argmax(model(inputs), dim=1)

            all_so_far += labels.size().numel()
            correct += torch.sum(pred.eq(labels))
    return correct/all_so_far

In [5]:
def fit_one_cycle(epochs, max_lr, model, dataloader_train, dataloader_test=None, name=None):
    optimizer = torch.optim.Adam(model.parameters(), max_lr)
    name = "model" if name is None else name

    for epoch in range(epochs):
        torch.cuda.empty_cache()
        model.train()
        print(f"Epoch:{epoch + 1}")
        for inputs, labels in tqdm(dataloader_train):
            optimizer.zero_grad()
            inputs, labels = inputs.to(device), labels.to(device)
            loss = torch.nn.functional.cross_entropy(model(inputs), labels)
            loss.backward()
            del inputs, labels
            optimizer.step()

        model.eval()
        accuracy_train = eval_accuracy(model, dataloader_train)
        accuracy_test = eval_accuracy(model, dataloader_test)

        line = f"{datetime.datetime.now().isoformat()}_train_{accuracy_train:.2f}_test_{accuracy_test:.2f}"
        print(line)
        if epoch % 5 == 0:
            torch.save(model, f".weights/{name}/{line}")

In [6]:
dataloader_cifar100_train, dataloader_cifar100_validation, dataloader_cifar100_test = load_cifar100_dataloaders_with_validation()

Files already downloaded and verified
Files already downloaded and verified


In [7]:
pretrained_model = torch.load(".weights/full/pretrained")
pretrained_model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=100)
model_cifar100 = pretrained_model.to(device)
del pretrained_model

In [13]:
fit_one_cycle(150, 0.001, model_cifar100, dataloader_cifar100_train, dataloader_cifar100_validation, name="cifar100")

Epoch:1


100%|██████████| 2500/2500 [03:06<00:00, 13.39it/s]
100%|██████████| 2500/2500 [00:36<00:00, 68.79it/s]
100%|██████████| 625/625 [00:09<00:00, 68.72it/s]


2022-12-17T11:55:16.513814_train_0.69_test_0.31
Epoch:2


100%|██████████| 2500/2500 [02:59<00:00, 13.89it/s]
100%|██████████| 2500/2500 [00:36<00:00, 68.79it/s]
100%|██████████| 625/625 [00:09<00:00, 68.74it/s]


2022-12-17T11:59:02.061671_train_0.70_test_0.31
Epoch:3


100%|██████████| 2500/2500 [02:59<00:00, 13.92it/s]
100%|██████████| 2500/2500 [00:36<00:00, 68.88it/s]
100%|██████████| 625/625 [00:09<00:00, 69.00it/s]


2022-12-17T12:02:47.042773_train_0.67_test_0.31
Epoch:4


100%|██████████| 2500/2500 [02:59<00:00, 13.96it/s]
100%|██████████| 2500/2500 [00:36<00:00, 69.06it/s]
100%|██████████| 625/625 [00:09<00:00, 69.05it/s]


2022-12-17T12:06:31.457545_train_0.70_test_0.31
Epoch:5


100%|██████████| 2500/2500 [02:58<00:00, 13.99it/s]
100%|██████████| 2500/2500 [00:33<00:00, 74.15it/s]
100%|██████████| 625/625 [00:08<00:00, 74.22it/s]


2022-12-17T12:10:12.259336_train_0.70_test_0.31
Epoch:6


100%|██████████| 2500/2500 [02:46<00:00, 15.03it/s]
100%|██████████| 2500/2500 [00:33<00:00, 74.20it/s]
100%|██████████| 625/625 [00:08<00:00, 70.04it/s]


2022-12-17T12:13:41.260013_train_0.67_test_0.30
Epoch:7


100%|██████████| 2500/2500 [02:53<00:00, 14.40it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.28it/s]
100%|██████████| 625/625 [00:08<00:00, 70.19it/s]


2022-12-17T12:17:19.524339_train_0.70_test_0.31
Epoch:8


100%|██████████| 2500/2500 [02:53<00:00, 14.41it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.18it/s]
100%|██████████| 625/625 [00:08<00:00, 70.13it/s]


2022-12-17T12:20:57.631127_train_0.69_test_0.30
Epoch:9


100%|██████████| 2500/2500 [02:53<00:00, 14.41it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.08it/s]
100%|██████████| 625/625 [00:08<00:00, 70.09it/s]


2022-12-17T12:24:35.743547_train_0.69_test_0.30
Epoch:10


100%|██████████| 2500/2500 [02:53<00:00, 14.38it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.10it/s]
100%|██████████| 625/625 [00:08<00:00, 70.11it/s]


2022-12-17T12:28:14.151430_train_0.70_test_0.30
Epoch:11


100%|██████████| 2500/2500 [02:53<00:00, 14.40it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.11it/s]
100%|██████████| 625/625 [00:08<00:00, 70.08it/s]


2022-12-17T12:31:52.386623_train_0.70_test_0.31
Epoch:12


100%|██████████| 2500/2500 [02:53<00:00, 14.40it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.13it/s]
100%|██████████| 625/625 [00:08<00:00, 70.12it/s]


2022-12-17T12:35:30.806335_train_0.69_test_0.31
Epoch:13


100%|██████████| 2500/2500 [02:53<00:00, 14.41it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.09it/s]
100%|██████████| 625/625 [00:08<00:00, 70.06it/s]


2022-12-17T12:39:08.928281_train_0.70_test_0.31
Epoch:14


100%|██████████| 2500/2500 [02:53<00:00, 14.38it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.11it/s]
100%|██████████| 625/625 [00:08<00:00, 70.07it/s]


2022-12-17T12:42:47.333782_train_0.71_test_0.31
Epoch:15


100%|██████████| 2500/2500 [02:53<00:00, 14.40it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.21it/s]
100%|██████████| 625/625 [00:08<00:00, 70.17it/s]


2022-12-17T12:46:25.470740_train_0.72_test_0.31
Epoch:16


100%|██████████| 2500/2500 [02:53<00:00, 14.41it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.21it/s]
100%|██████████| 625/625 [00:08<00:00, 70.26it/s]


2022-12-17T12:50:03.548913_train_0.71_test_0.31
Epoch:17


100%|██████████| 2500/2500 [02:53<00:00, 14.41it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.21it/s]
100%|██████████| 625/625 [00:08<00:00, 70.30it/s]


2022-12-17T12:53:41.663205_train_0.71_test_0.31
Epoch:18


100%|██████████| 2500/2500 [02:53<00:00, 14.39it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.09it/s]
100%|██████████| 625/625 [00:08<00:00, 70.08it/s]


2022-12-17T12:57:20.058956_train_0.71_test_0.31
Epoch:19


100%|██████████| 2500/2500 [02:53<00:00, 14.40it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.10it/s]
100%|██████████| 625/625 [00:08<00:00, 70.09it/s]


2022-12-17T13:00:58.304208_train_0.70_test_0.31
Epoch:20


100%|██████████| 2500/2500 [02:53<00:00, 14.40it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.10it/s]
100%|██████████| 625/625 [00:08<00:00, 70.10it/s]


2022-12-17T13:04:36.500495_train_0.70_test_0.31
Epoch:21


100%|██████████| 2500/2500 [02:53<00:00, 14.41it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.08it/s]
100%|██████████| 625/625 [00:08<00:00, 70.11it/s]


2022-12-17T13:08:14.628563_train_0.70_test_0.31
Epoch:22


100%|██████████| 2500/2500 [02:53<00:00, 14.38it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.08it/s]
100%|██████████| 625/625 [00:08<00:00, 70.12it/s]


2022-12-17T13:11:53.198520_train_0.71_test_0.31
Epoch:23


100%|██████████| 2500/2500 [02:53<00:00, 14.40it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.10it/s]
100%|██████████| 625/625 [00:08<00:00, 70.08it/s]


2022-12-17T13:15:31.457425_train_0.70_test_0.30
Epoch:24


100%|██████████| 2500/2500 [02:53<00:00, 14.40it/s]
100%|██████████| 2500/2500 [00:35<00:00, 70.11it/s]
100%|██████████| 625/625 [00:08<00:00, 70.07it/s]


2022-12-17T13:19:09.653664_train_0.72_test_0.31
Epoch:25


 82%|████████▏ | 2046/2500 [02:23<00:31, 14.28it/s]


KeyboardInterrupt: 

In [9]:
eval_accuracy(model_cifar100, dataloader_cifar100_test)

100%|██████████| 625/625 [00:09<00:00, 68.05it/s]


tensor(0.3103, device='cuda:0')

In [11]:
sums = torch.zeros(100)
for inputs, labels in tqdm(dataloader_cifar100_validation):
    sums[labels] += 1

100%|██████████| 625/625 [00:02<00:00, 265.36it/s]


In [12]:
sums

tensor([100.,  86., 112.,  95.,  81.,  86., 102., 103., 104.,  86.,  97.,  90.,
         81., 105.,  69.,  94.,  95.,  94.,  93.,  98.,  89.,  97., 101.,  95.,
         78.,  91.,  91.,  98.,  87.,  92., 101.,  91.,  87.,  85.,  95.,  81.,
        100.,  91.,  90.,  89.,  86.,  90., 102.,  95.,  93.,  75., 100.,  79.,
        103.,  92.,  99.,  84., 101.,  69.,  85.,  93.,  99.,  76.,  86.,  90.,
        104., 105.,  86.,  89.,  99.,  88.,  90.,  95.,  92.,  97.,  94.,  88.,
         87.,  90.,  91.,  83.,  92.,  92.,  95., 101.,  84.,  95.,  86.,  87.,
         91.,  90., 102., 111.,  93.,  90.,  98.,  96.,  86.,  89.,  97.,  98.,
        100., 101., 101.,  93.])