In [1]:
import torch
import torchvision.datasets
import torchvision.models
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torch.nn as nn
from tqdm import tqdm
from train_utils import get_data_loaders, train, test
from PIL import Image
import pandas as pd

In [2]:
root = './distilled_dataset'

### CIFAR10 Distilled

In [3]:
class CIFAR10_Distilled(Dataset):
    def __init__(self, idx_df, transform):
        self.idx_df = idx_df
        self.transform = transform
    
    def __len__(self):
        return len(self.idx_df)

    def __getitem__(self, index):
        entry = self.idx_df.iloc[index]
        image = self.transform(Image.open(entry.image_path))
        label = entry.label
        return image, torch.tensor(label)

In [4]:
index_file = pd.read_csv('./CIFAR10_DM_index_files/ipc500_idx.csv')
index_file

Unnamed: 0,image_path,label
0,./CIFAR10_DM/ipc500/airplane/003746e7-c915-4d7...,0
1,./CIFAR10_DM/ipc500/airplane/0090a052-69a0-461...,0
2,./CIFAR10_DM/ipc500/airplane/0143960d-3978-448...,0
3,./CIFAR10_DM/ipc500/airplane/01ee79d8-a8fe-430...,0
4,./CIFAR10_DM/ipc500/airplane/0231edfd-42fd-477...,0
...,...,...
4995,./CIFAR10_DM/ipc500/truck/fb1fbeaf-b6c7-4da4-a...,9
4996,./CIFAR10_DM/ipc500/truck/fcd9e0b0-9b62-4d41-8...,9
4997,./CIFAR10_DM/ipc500/truck/fdefdd6d-2fd9-42f1-b...,9
4998,./CIFAR10_DM/ipc500/truck/ff83d309-da9a-459a-a...,9


In [5]:
image_normalize= transforms.Lambda(lambda img: 2 * img / 255. - 1) # normalize to [-1, 1]
# https://github.com/Lornatang/pytorch-alexnet-cifar100/blob/master/utils/datasets.py
transform = transforms.Compose([
                                transforms.Resize(128),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                # image_normalize
                              ])
cifar10_distilled_train = CIFAR10_Distilled(index_file, transform)

cifar10_root = '../cifar10'
cifar10_test = torchvision.datasets.CIFAR10(cifar10_root,
                                            train=False,
                                            transform=transforms.Compose([
                                                transforms.Resize(64),
                                                transforms.ToTensor(),
                                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                                # image_normalize
                                              ]))

In [6]:
batch_size = 16 # cut batch size to half for ipc50
train_data_loader, test_data_loader = get_data_loaders(cifar10_distilled_train, cifar10_test, batch_size=batch_size)

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[INFO]: Computation device: {device}")

[INFO]: Computation device: cuda


In [8]:
model_name = 'alexnet'
dataset = 'cifar10_distilled'
model_path = f'./model_weight/{dataset}/{model_name}/'
model = torchvision.models.alexnet(num_classes = 100).to(device)

In [9]:
train(model, model_path, train_data_loader, test_data_loader)

Epoch 1: 100%|██████████| 313/313 [00:14<00:00, 21.59it/s, loss=1.32] 


epoch: 1 done, loss: 0.10201902687549591
Acc: 0.2067


Epoch 2: 100%|██████████| 313/313 [00:09<00:00, 31.67it/s, loss=0.432]


epoch: 2 done, loss: 0.05046798661351204
Acc: 0.154


Epoch 3: 100%|██████████| 313/313 [00:10<00:00, 30.58it/s, loss=0.438] 


epoch: 3 done, loss: 0.03396567702293396
Acc: 0.1505


Epoch 4: 100%|██████████| 313/313 [00:10<00:00, 29.65it/s, loss=0.283] 


epoch: 4 done, loss: 0.025460369884967804
Acc: 0.1669


Epoch 5: 100%|██████████| 313/313 [00:11<00:00, 28.45it/s, loss=0.468] 


epoch: 5 done, loss: 0.019044693559408188
Acc: 0.2094


Epoch 6: 100%|██████████| 313/313 [00:10<00:00, 29.42it/s, loss=0.598] 


epoch: 6 done, loss: 0.016895169392228127
Acc: 0.2178


Epoch 7: 100%|██████████| 313/313 [00:10<00:00, 30.05it/s, loss=0.368]  


epoch: 7 done, loss: 0.013079565018415451
Acc: 0.2206


Epoch 8: 100%|██████████| 313/313 [00:10<00:00, 30.77it/s, loss=0.0491] 


epoch: 8 done, loss: 0.010571521706879139
Acc: 0.1954


Epoch 9: 100%|██████████| 313/313 [00:10<00:00, 29.84it/s, loss=0.293]  


epoch: 9 done, loss: 0.009646412916481495
Acc: 0.1593


Epoch 10: 100%|██████████| 313/313 [00:09<00:00, 32.68it/s, loss=0.514]  


epoch: 10 done, loss: 0.008470257744193077
Acc: 0.1844


Epoch 11: 100%|██████████| 313/313 [00:09<00:00, 32.88it/s, loss=0.454]  


epoch: 11 done, loss: 0.008482540026307106
Acc: 0.1753


Epoch 12: 100%|██████████| 313/313 [00:09<00:00, 32.77it/s, loss=0.626]   


epoch: 12 done, loss: 0.006959889084100723
Acc: 0.2252


Epoch 13: 100%|██████████| 313/313 [00:09<00:00, 32.39it/s, loss=0.0185]  


epoch: 13 done, loss: 0.006563671864569187
Acc: 0.2285


Epoch 14: 100%|██████████| 313/313 [00:09<00:00, 31.95it/s, loss=0.328]   


epoch: 14 done, loss: 0.005614642519503832
Acc: 0.2204


Epoch 15: 100%|██████████| 313/313 [00:09<00:00, 32.72it/s, loss=0.0032]  


epoch: 15 done, loss: 0.004760872106999159
Acc: 0.246


Epoch 16: 100%|██████████| 313/313 [00:09<00:00, 32.95it/s, loss=0.096]   


epoch: 16 done, loss: 0.0062405019998550415
Acc: 0.1804


Epoch 17: 100%|██████████| 313/313 [00:09<00:00, 32.73it/s, loss=0.0927]  


epoch: 17 done, loss: 0.005639783106744289
Acc: 0.1769


Epoch 18: 100%|██████████| 313/313 [00:09<00:00, 33.05it/s, loss=0.00536] 


epoch: 18 done, loss: 0.0037234134506434202
Acc: 0.1957


Epoch 19: 100%|██████████| 313/313 [00:09<00:00, 32.91it/s, loss=0.0789]  


epoch: 19 done, loss: 0.004370077978819609
Acc: 0.2008


Epoch 20: 100%|██████████| 313/313 [00:09<00:00, 32.94it/s, loss=0.00945] 


epoch: 20 done, loss: 0.0032704207114875317
Acc: 0.2324


Epoch 21: 100%|██████████| 313/313 [00:09<00:00, 33.04it/s, loss=0.101]   


epoch: 21 done, loss: 0.0032084062695503235
Acc: 0.2076


Epoch 22: 100%|██████████| 313/313 [01:10<00:00,  4.46it/s, loss=0.000341]


epoch: 22 done, loss: 0.004407044034451246
Acc: 0.2031


Epoch 23: 100%|██████████| 313/313 [00:33<00:00,  9.48it/s, loss=0.000105]


epoch: 23 done, loss: 0.002087516477331519
Acc: 0.2119


Epoch 24: 100%|██████████| 313/313 [00:09<00:00, 33.63it/s, loss=0.00576] 


epoch: 24 done, loss: 0.003119047963991761
Acc: 0.1996


Epoch 25: 100%|██████████| 313/313 [00:09<00:00, 33.20it/s, loss=0.00912] 


epoch: 25 done, loss: 0.002700412878766656
Acc: 0.2556


Epoch 26: 100%|██████████| 313/313 [00:09<00:00, 32.99it/s, loss=0.233]   


epoch: 26 done, loss: 0.0035763492342084646
Acc: 0.2242


Epoch 27: 100%|██████████| 313/313 [00:09<00:00, 32.19it/s, loss=0.00318] 


epoch: 27 done, loss: 0.003744780318811536
Acc: 0.2095


Epoch 28: 100%|██████████| 313/313 [00:09<00:00, 32.81it/s, loss=0.00484] 


epoch: 28 done, loss: 0.0028393855318427086
Acc: 0.2177


Epoch 29: 100%|██████████| 313/313 [00:09<00:00, 32.94it/s, loss=0.0134]  


epoch: 29 done, loss: 0.003349428065121174
Acc: 0.1996


Epoch 30: 100%|██████████| 313/313 [00:09<00:00, 32.65it/s, loss=0.00158] 


epoch: 30 done, loss: 0.0018552930559962988
Acc: 0.2022


Epoch 31: 100%|██████████| 313/313 [00:09<00:00, 32.30it/s, loss=0.000274]


epoch: 31 done, loss: 0.0023828211706131697
Acc: 0.1856


Epoch 32: 100%|██████████| 313/313 [00:09<00:00, 32.65it/s, loss=0.0645]  


epoch: 32 done, loss: 0.003002104815095663
Acc: 0.2051


Epoch 33: 100%|██████████| 313/313 [00:09<00:00, 33.02it/s, loss=0.0306]  


epoch: 33 done, loss: 0.0022541384678333998
Acc: 0.2423


Epoch 34: 100%|██████████| 313/313 [00:09<00:00, 32.85it/s, loss=0.000305]


epoch: 34 done, loss: 0.0017798353219404817
Acc: 0.1918


Epoch 35: 100%|██████████| 313/313 [00:09<00:00, 32.59it/s, loss=0.128]   


epoch: 35 done, loss: 0.0027244838420301676
Acc: 0.2169


Epoch 36: 100%|██████████| 313/313 [00:09<00:00, 32.78it/s, loss=0.000197]


epoch: 36 done, loss: 0.0017954159993678331
Acc: 0.1829


Epoch 37: 100%|██████████| 313/313 [00:09<00:00, 32.56it/s, loss=0.000673]


epoch: 37 done, loss: 0.0030311846639961004
Acc: 0.1886


Epoch 38: 100%|██████████| 313/313 [00:09<00:00, 32.68it/s, loss=0.274]   


epoch: 38 done, loss: 0.0018027722835540771
Acc: 0.2139


Epoch 39: 100%|██████████| 313/313 [00:09<00:00, 32.76it/s, loss=0.000256]


epoch: 39 done, loss: 0.0012667904375120997
Acc: 0.18


Epoch 40: 100%|██████████| 313/313 [00:09<00:00, 32.62it/s, loss=0]       


epoch: 40 done, loss: 4.029354749945924e-05
Acc: 0.1831


Epoch 41: 100%|██████████| 313/313 [00:09<00:00, 32.47it/s, loss=0.00908] 


epoch: 41 done, loss: 0.0028867723885923624
Acc: 0.2183


Epoch 42: 100%|██████████| 313/313 [00:09<00:00, 32.64it/s, loss=0.000269]


epoch: 42 done, loss: 0.002813790924847126
Acc: 0.1951


Epoch 43: 100%|██████████| 313/313 [00:09<00:00, 32.92it/s, loss=0.0112]  


epoch: 43 done, loss: 0.0015082507161423564
Acc: 0.2272


Epoch 44: 100%|██████████| 313/313 [00:09<00:00, 32.88it/s, loss=3.77e-6] 


epoch: 44 done, loss: 0.002661849604919553
Acc: 0.1887


Epoch 45: 100%|██████████| 313/313 [00:09<00:00, 33.51it/s, loss=0.000461]


epoch: 45 done, loss: 0.002508089877665043
Acc: 0.1603


Epoch 46: 100%|██████████| 313/313 [00:09<00:00, 32.70it/s, loss=4.67e-5] 


epoch: 46 done, loss: 0.000960583274718374
Acc: 0.1948


Epoch 47: 100%|██████████| 313/313 [00:09<00:00, 32.44it/s, loss=0.013]   


epoch: 47 done, loss: 0.00022123662347439677
Acc: 0.2145


Epoch 48: 100%|██████████| 313/313 [00:09<00:00, 33.14it/s, loss=0.0166]  


epoch: 48 done, loss: 0.0034660238306969404
Acc: 0.1723


Epoch 49: 100%|██████████| 313/313 [00:09<00:00, 33.18it/s, loss=0.00756] 


epoch: 49 done, loss: 0.0014043899718672037
Acc: 0.2179


Epoch 50: 100%|██████████| 313/313 [00:09<00:00, 33.18it/s, loss=8.16e-5] 


epoch: 50 done, loss: 0.0018898183479905128
Acc: 0.2196
Time taken: 786.951261 seconds


In [10]:
test(model, model_path, test_data_loader)

0.2556