In [1]:
import torch
import torchvision.datasets
import torchvision.models
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torch.nn as nn
from tqdm import tqdm
from train_utils import get_data_loaders, train, test
from PIL import Image
import pandas as pd

In [2]:
root = './distilled_dataset'

### CIFAR10 Distilled

In [3]:
class CIFAR10_Distilled(Dataset):
    def __init__(self, idx_df, transform):
        self.idx_df = idx_df
        self.transform = transform
    
    def __len__(self):
        return len(self.idx_df)

    def __getitem__(self, index):
        entry = self.idx_df.iloc[index]
        image = self.transform(Image.open(entry.image_path))
        label = entry.label
        return image, torch.tensor(label)

In [4]:
index_file = pd.read_csv('./CIFAR10_DM_index_files/ipc50_idx.csv')
index_file

Unnamed: 0,image_path,label
0,./CIFAR10_DM/ipc500/airplane/006365cd-0d76-4f8...,0
1,./CIFAR10_DM/ipc500/airplane/00b2e5ad-bd19-4ab...,0
2,./CIFAR10_DM/ipc500/airplane/01163b19-375e-4ce...,0
3,./CIFAR10_DM/ipc500/airplane/0184205a-6931-468...,0
4,./CIFAR10_DM/ipc500/airplane/0203df52-3303-497...,0
...,...,...
5115,./CIFAR10_DM/ipc500/truck/fd6ff94f-754f-40e5-b...,9
5116,./CIFAR10_DM/ipc500/truck/fdc54d46-d726-4b14-a...,9
5117,./CIFAR10_DM/ipc500/truck/fe77baa8-1be7-4192-b...,9
5118,./CIFAR10_DM/ipc500/truck/fe787924-9859-41ae-b...,9


In [5]:
image_normalize= transforms.Lambda(lambda img: 2 * img / 255. - 1) # normalize to [-1, 1]
# https://github.com/Lornatang/pytorch-alexnet-cifar100/blob/master/utils/datasets.py
# transform = transforms.Compose([transforms.Resize(64),
#                               transforms.RandomHorizontalFlip(),
#                               transforms.ToTensor(),
#                               transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

transform = transforms.Compose([transforms.Resize(64),
                                transforms.ToTensor()])
cifar10_distilled_train = CIFAR10_Distilled(index_file, transform)

cifar10_root = '../cifar10'
cifar10_test = torchvision.datasets.CIFAR10(cifar10_root,
                                            train=False,
                                            transform=transforms.Compose([
                                                transforms.Resize(64),
                                                transforms.RandomHorizontalFlip(),
                                                transforms.ToTensor(),
                                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                              ]))

In [6]:
batch_size = 32 # cut batch size to half for ipc50
train_data_loader, test_data_loader = get_data_loaders(cifar10_distilled_train, cifar10_test, batch_size=64)

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[INFO]: Computation device: {device}")

[INFO]: Computation device: cuda


In [8]:
model_name = 'alexnet'
dataset = 'cifar10_distilled'
model_path = f'./model_weight/{dataset}/{model_name}/'
model = torchvision.models.alexnet(num_classes = 100).to(device)

In [9]:
train(model, model_path, train_data_loader, test_data_loader)

  0%|          | 0/80 [00:03<?, ?it/s]


RuntimeError: Given input size: (256x1x1). Calculated output size: (256x0x0). Output size is too small

In [None]:
test(model, model_path, test_data_loader)

100%|██████████| 157/157 [00:06<00:00, 24.19it/s]

acuuracy: 0.4281





0.4281