In [1]:
import torch
import torchvision.datasets
import torchvision.models
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torch.nn as nn
from tqdm import tqdm
from train_utils import get_data_loaders, train, test
from PIL import Image
import pandas as pd

In [2]:
root = './distilled_dataset'

### CIFAR10 Distilled

In [3]:
class CIFAR10_Distilled(Dataset):
    def __init__(self, idx_df, transform):
        self.idx_df = idx_df
        self.transform = transform
    
    def __len__(self):
        return len(self.idx_df)

    def __getitem__(self, index):
        entry = self.idx_df.iloc[index]
        image = self.transform(Image.open(entry.image_path))
        label = entry.label
        return image, torch.tensor(label)

In [4]:
index_file = pd.read_csv('./CIFAR10_DM_index_files/ipc50_idx.csv')
index_file

Unnamed: 0,image_path,label
0,./CIFAR10_DM/ipc500/airplane/006365cd-0d76-4f8...,0
1,./CIFAR10_DM/ipc500/airplane/00b2e5ad-bd19-4ab...,0
2,./CIFAR10_DM/ipc500/airplane/01163b19-375e-4ce...,0
3,./CIFAR10_DM/ipc500/airplane/0184205a-6931-468...,0
4,./CIFAR10_DM/ipc500/airplane/0203df52-3303-497...,0
...,...,...
5115,./CIFAR10_DM/ipc500/truck/fd6ff94f-754f-40e5-b...,9
5116,./CIFAR10_DM/ipc500/truck/fdc54d46-d726-4b14-a...,9
5117,./CIFAR10_DM/ipc500/truck/fe77baa8-1be7-4192-b...,9
5118,./CIFAR10_DM/ipc500/truck/fe787924-9859-41ae-b...,9


In [5]:
image_normalize= transforms.Lambda(lambda img: 2 * img / 255. - 1) # normalize to [-1, 1]
# https://github.com/Lornatang/pytorch-alexnet-cifar100/blob/master/utils/datasets.py
transform = transforms.Compose([transforms.Resize(64),
                              transforms.RandomHorizontalFlip(),
                              transforms.ToTensor(),
                              transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# transform = transforms.Compose([transforms.Resize(64),
#                                 transforms.ToTensor()])
cifar10_distilled_train = CIFAR10_Distilled(index_file, transform)

cifar10_root = '../cifar10'
cifar10_test = torchvision.datasets.CIFAR10(cifar10_root,
                                            train=False,
                                            transform=transforms.Compose([
                                                transforms.Resize(64),
                                                transforms.RandomHorizontalFlip(),
                                                transforms.ToTensor(),
                                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                              ]))

In [6]:
batch_size = 32 # cut batch size to half for ipc50
train_data_loader, test_data_loader = get_data_loaders(cifar10_distilled_train, cifar10_test, batch_size=64)

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[INFO]: Computation device: {device}")

[INFO]: Computation device: cuda


In [8]:
model_name = 'alexnet'
dataset = 'cifar10_distilled'
model_path = f'./model_weight/{dataset}/{model_name}/'
model = torchvision.models.alexnet(num_classes = 100).to(device)

In [9]:
train(model, model_path, train_data_loader, test_data_loader)

Epoch 1: 100%|██████████| 80/80 [00:09<00:00,  8.77it/s, loss=2.19]


epoch: 1 done, loss: 0.039202455431222916
Acc: 0.1026


Epoch 2: 100%|██████████| 80/80 [00:04<00:00, 19.96it/s, loss=1.65]


epoch: 2 done, loss: 0.027243098244071007
Acc: 0.1633


Epoch 3: 100%|██████████| 80/80 [00:03<00:00, 20.32it/s, loss=0.416]


epoch: 3 done, loss: 0.01299170684069395
Acc: 0.1579


Epoch 4: 100%|██████████| 80/80 [00:03<00:00, 20.16it/s, loss=0.17]  


epoch: 4 done, loss: 0.00397164560854435
Acc: 0.1583


Epoch 5: 100%|██████████| 80/80 [00:03<00:00, 20.21it/s, loss=0.00169]


epoch: 5 done, loss: 0.0003112895938102156
Acc: 0.1671


Epoch 6: 100%|██████████| 80/80 [00:03<00:00, 20.18it/s, loss=0.000845]


epoch: 6 done, loss: 2.2419873857870698e-05
Acc: 0.1683


Epoch 7: 100%|██████████| 80/80 [00:03<00:00, 20.27it/s, loss=0.000477]


epoch: 7 done, loss: 9.201466127706226e-06
Acc: 0.1712


Epoch 8: 100%|██████████| 80/80 [00:03<00:00, 20.25it/s, loss=0.000327]


epoch: 8 done, loss: 6.102608494984452e-06
Acc: 0.1731


Epoch 9: 100%|██████████| 80/80 [00:03<00:00, 20.48it/s, loss=0.000318]


epoch: 9 done, loss: 3.9972214835870545e-06
Acc: 0.1716


Epoch 10: 100%|██████████| 80/80 [00:03<00:00, 20.46it/s, loss=0.000122]


epoch: 10 done, loss: 3.265399300289573e-06
Acc: 0.1717


Epoch 11: 100%|██████████| 80/80 [00:04<00:00, 16.26it/s, loss=0.000174]


epoch: 11 done, loss: 2.044890607066918e-06
Acc: 0.1705


Epoch 12: 100%|██████████| 80/80 [00:04<00:00, 19.53it/s, loss=8.28e-5] 


epoch: 12 done, loss: 1.7463529502492747e-06
Acc: 0.1713


Epoch 13: 100%|██████████| 80/80 [00:04<00:00, 19.86it/s, loss=6.25e-5] 


epoch: 13 done, loss: 1.9707406408997485e-06
Acc: 0.1714


Epoch 14: 100%|██████████| 80/80 [00:04<00:00, 19.45it/s, loss=5.23e-5] 


epoch: 14 done, loss: 1.1478645092211082e-06
Acc: 0.1712


Epoch 15: 100%|██████████| 80/80 [00:04<00:00, 19.42it/s, loss=6.21e-5] 


epoch: 15 done, loss: 9.075327511709474e-07
Acc: 0.174


Epoch 16: 100%|██████████| 80/80 [00:04<00:00, 19.36it/s, loss=1.98e-5] 


epoch: 16 done, loss: 6.708248179165821e-07
Acc: 0.1726


Epoch 17: 100%|██████████| 80/80 [00:04<00:00, 19.50it/s, loss=0.000213]


epoch: 17 done, loss: 6.228781330719357e-07
Acc: 0.1716


Epoch 18: 100%|██████████| 80/80 [00:04<00:00, 19.72it/s, loss=0.000132]


epoch: 18 done, loss: 6.441732125495037e-07
Acc: 0.1701


Epoch 19: 100%|██████████| 80/80 [00:04<00:00, 19.62it/s, loss=0.000112]


epoch: 19 done, loss: 5.32817693965626e-07
Acc: 0.1705


Epoch 20: 100%|██████████| 80/80 [00:04<00:00, 19.19it/s, loss=7.36e-6] 


epoch: 20 done, loss: 3.5501565776030475e-07
Acc: 0.1712


Epoch 21: 100%|██████████| 80/80 [00:04<00:00, 19.94it/s, loss=3.21e-5]


epoch: 21 done, loss: 3.347905135342444e-07
Acc: 0.1709


Epoch 22: 100%|██████████| 80/80 [00:03<00:00, 20.53it/s, loss=9.02e-6]


epoch: 22 done, loss: 3.377209623067756e-07
Acc: 0.1726


Epoch 23: 100%|██████████| 80/80 [00:04<00:00, 19.63it/s, loss=6.71e-6] 


epoch: 23 done, loss: 3.7091234617037117e-07
Acc: 0.1705


Epoch 24: 100%|██████████| 80/80 [00:04<00:00, 18.36it/s, loss=1.01e-5]


epoch: 24 done, loss: 2.1131937444351934e-07
Acc: 0.1717


Epoch 25: 100%|██████████| 80/80 [00:04<00:00, 17.03it/s, loss=1.32e-5]


epoch: 25 done, loss: 2.100009908190259e-07
Acc: 0.1728


Epoch 26: 100%|██████████| 80/80 [00:04<00:00, 18.65it/s, loss=1.35e-5]


epoch: 26 done, loss: 2.2557303225312353e-07
Acc: 0.172


Epoch 27: 100%|██████████| 80/80 [00:04<00:00, 18.62it/s, loss=4.76e-6]


epoch: 27 done, loss: 1.9636955528312683e-07
Acc: 0.1714


Epoch 28: 100%|██████████| 80/80 [00:04<00:00, 18.35it/s, loss=8.82e-6]


epoch: 28 done, loss: 1.454241527198974e-07
Acc: 0.1723


Epoch 29: 100%|██████████| 80/80 [00:04<00:00, 17.05it/s, loss=1.98e-5]


epoch: 29 done, loss: 1.7036398958225618e-07
Acc: 0.1715


Epoch 30: 100%|██████████| 80/80 [00:04<00:00, 17.09it/s, loss=6.43e-6]


epoch: 30 done, loss: 1.2379915403926134e-07
Acc: 0.1718


Epoch 31: 100%|██████████| 80/80 [00:04<00:00, 18.27it/s, loss=7.62e-6]


epoch: 31 done, loss: 1.4771136136459972e-07
Acc: 0.1716


Epoch 32: 100%|██████████| 80/80 [00:04<00:00, 17.55it/s, loss=5.2e-6] 


epoch: 32 done, loss: 9.700088554609465e-08
Acc: 0.1711


Epoch 33: 100%|██████████| 80/80 [00:04<00:00, 18.04it/s, loss=5.27e-6]


epoch: 33 done, loss: 1.2932990500758024e-07
Acc: 0.1707


Epoch 34: 100%|██████████| 80/80 [00:04<00:00, 18.80it/s, loss=7.13e-6]


epoch: 34 done, loss: 9.901599895556501e-08
Acc: 0.1711


Epoch 35: 100%|██████████| 80/80 [00:04<00:00, 18.68it/s, loss=1.47e-6]


epoch: 35 done, loss: 9.01362682270701e-08
Acc: 0.1728


Epoch 36: 100%|██████████| 80/80 [00:04<00:00, 17.44it/s, loss=2.72e-6]


epoch: 36 done, loss: 6.944787145357623e-08
Acc: 0.1703


Epoch 37: 100%|██████████| 80/80 [00:04<00:00, 19.13it/s, loss=1.15e-6]


epoch: 37 done, loss: 7.222194398082138e-08
Acc: 0.1717


Epoch 38: 100%|██████████| 80/80 [00:04<00:00, 18.23it/s, loss=2.33e-6]


epoch: 38 done, loss: 1.174826707028842e-07
Acc: 0.1702


Epoch 39: 100%|██████████| 80/80 [00:04<00:00, 18.43it/s, loss=2.82e-6] 


epoch: 39 done, loss: 1.1879082961741005e-07
Acc: 0.1722


Epoch 40: 100%|██████████| 80/80 [00:04<00:00, 18.59it/s, loss=1.96e-6]


epoch: 40 done, loss: 7.384173272839689e-08
Acc: 0.1733


Epoch 41: 100%|██████████| 80/80 [00:04<00:00, 18.72it/s, loss=2.37e-6]


epoch: 41 done, loss: 6.737784730148633e-08
Acc: 0.1722


Epoch 42: 100%|██████████| 80/80 [00:04<00:00, 18.33it/s, loss=3.94e-6]


epoch: 42 done, loss: 6.059417501091957e-08
Acc: 0.1722


Epoch 43: 100%|██████████| 80/80 [00:04<00:00, 18.55it/s, loss=1.38e-6]


epoch: 43 done, loss: 4.366985351111907e-08
Acc: 0.1717


Epoch 44: 100%|██████████| 80/80 [00:04<00:00, 17.37it/s, loss=2.67e-6]


epoch: 44 done, loss: 4.4870272830621616e-08
Acc: 0.1729


Epoch 45: 100%|██████████| 80/80 [00:04<00:00, 17.93it/s, loss=2.29e-6]


epoch: 45 done, loss: 3.889899957698617e-08
Acc: 0.1706


Epoch 46: 100%|██████████| 80/80 [00:04<00:00, 18.69it/s, loss=2.25e-6]


epoch: 46 done, loss: 5.228087118780422e-08
Acc: 0.171


Epoch 47: 100%|██████████| 80/80 [00:04<00:00, 18.36it/s, loss=1.97e-6]


epoch: 47 done, loss: 3.156197792009152e-08
Acc: 0.1691


Epoch 48: 100%|██████████| 80/80 [00:04<00:00, 18.35it/s, loss=1.5e-6] 


epoch: 48 done, loss: 4.582392065799468e-08
Acc: 0.1709


Epoch 49: 100%|██████████| 80/80 [00:04<00:00, 18.32it/s, loss=1.16e-6]


epoch: 49 done, loss: 3.873773479767806e-08
Acc: 0.1697


Epoch 50: 100%|██████████| 80/80 [00:04<00:00, 18.32it/s, loss=4.03e-6]


epoch: 50 done, loss: 5.302634420445429e-08
Acc: 0.1721
Time taken: 410.695034 seconds


In [10]:
test(model, model_path, test_data_loader)

0.1731