In [1]:
import os
os.chdir('dl-model-extraction')

In [5]:
from victim.cifar100_models.models import CIFAR100Module
from utils import DEVICE
from victim import RESNET50, VGG19_BN

In [3]:
import os
import zipfile

import pytorch_lightning as pl
import requests
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import CIFAR100
from tqdm import tqdm


class CIFAR100Data(pl.LightningDataModule):
    def __init__(self, args):
        super().__init__()
        self.batch_size = args['batch_size']
        self.num_workers = args['num_workers']
        
        self.mean = (0.5071, 0.4867, 0.4408)
        self.std = (0.2675, 0.2565, 0.2761)

    def train_dataloader(self):
        transform = T.Compose(
            [
                T.RandomCrop(32, padding=4),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
                T.Normalize(self.mean, self.std),
            ]
        )
        dataset = CIFAR100(root='.', train=True, transform=transform, download=True)
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
        )
        return dataloader

    def val_dataloader(self):
        transform = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(self.mean, self.std),
            ]
        )
        dataset = CIFAR100(root='.', train=False, transform=transform, download=True)
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=True,
        )
        return dataloader

    def test_dataloader(self):
        return self.val_dataloader()

In [4]:
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger


def main(args):
    seed_everything(0)
    checkpoint = ModelCheckpoint(monitor="acc/val", mode="max", save_last=False)

    trainer = Trainer(
        fast_dev_run=False,
        gpus=-1,
        deterministic=True,
        weights_summary=None,
        log_every_n_steps=1,
        max_epochs=args['max_epochs'],
        checkpoint_callback=checkpoint,
        precision=32,
    )

    data = CIFAR100Data(args)
    model = CIFAR100Module(args, len(data.train_dataloader()))
    
    model.model.to(DEVICE)
#     if bool(args.pretrained):
#         state_dict = os.path.join(
#             "cifar10_models", "state_dicts", args.classifier + ".pt"
#         )
#         model.model.load_state_dict(torch.load(state_dict))
    trainer.fit(model, data.train_dataloader(), data.test_dataloader())
#     print(trainer.callback_metrics["val_acc"])
    return trainer, model, data

In [5]:
trainer, model, data = main({
            "model_name": VGG19_BN,
            "batch_size": 256,
            "max_epochs": 100,
            "num_workers": 2,
            "learning_rate": 1e-2,
            "weight_decay": 1e-2
        })

Global seed set to 0
  f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
  "Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [7]:
import numpy as np
li = []
for x, y in data.test_dataloader():
    y_pred = torch.argmax(model.model(x), dim=1)
    li.append(torch.sum(y_pred == y).cpu().detach().numpy())
np.sum(li)/len(data.test_data_loader())

Files already downloaded and verified


NameError: name 'np' is not defined

In [15]:
torch.save(model, "vgg19bn_cifar100.pt")

In [16]:
tmp = torch.load("resnet50_cifar100.pt")

In [21]:
for x,y in data.test_dataloader():
    y_pred = torch.argmax(tmp.model(x), dim=1)
    break

Files already downloaded and verified


In [22]:
y_pred

tensor([55, 33, 55, 51, 71, 79, 29, 75, 23,  0, 71, 75, 81, 69, 40, 43, 92, 97,
        70, 53, 70, 49, 33, 29, 35, 16, 39,  8,  8, 84, 20, 10, 41, 67, 56, 64,
        58, 35, 25, 37, 63, 73, 49, 30, 56, 22, 41, 58, 44, 17,  4,  6,  9, 57,
         2, 32, 71, 52, 42, 69, 77, 27,  4, 33, 62, 98, 43,  6, 63, 54, 66, 90,
        67, 91, 67, 32, 82, 10, 77, 61, 71, 78, 54,  6, 44, 89, 78, 85, 35, 67,
        22, 18, 27, 21, 13, 21, 50, 75, 37, 15, 26, 83, 96, 86, 43, 69, 76, 17,
        57, 59, 25, 20, 27,  0,  9, 71, 48, 43, 57, 56, 85, 45, 19, 92, 33, 20,
         3, 27, 70, 46, 46, 16,  1, 74,  3, 91, 60,  3, 52, 23,  4, 11, 52, 29,
        24, 95, 13, 39, 51, 58, 58, 77, 37, 60, 45, 66, 85, 20, 77, 80, 36,  8,
        87, 10, 98, 59, 54, 99, 51, 83,  9, 33,  4, 83, 95, 45, 24, 73, 18, 40,
        39, 66, 22, 80, 16, 28, 25, 95, 98, 83, 12,  7, 78, 13, 94, 24, 90, 42,
         7, 87,  6, 78, 68, 60,  6, 23, 44, 31, 80, 66, 72, 11, 49, 90, 97, 96,
        53, 30, 16, 81, 94, 27, 25, 77, 

In [23]:
y

tensor([49, 33, 72, 51, 71, 92, 15, 14, 23,  0, 71, 75, 81, 69, 40, 43, 92, 97,
        70, 53, 70, 49, 75, 29, 21, 16, 39,  8,  8, 70, 20, 61, 41, 93, 56, 73,
        58, 11, 25, 37, 63, 24, 49, 73, 56, 22, 41, 58, 75, 17,  4,  6,  9, 57,
         2, 32, 71, 52, 42, 69, 77, 27, 15, 65,  7, 35, 43, 82, 63, 92, 66, 90,
        67, 91, 32, 32, 82, 10, 77, 22, 71, 78, 54,  6, 29, 89, 78, 33, 11, 67,
        22, 18, 27, 21, 13, 21, 50, 75, 37, 35, 26, 83, 47, 95, 43, 69, 76, 17,
        57, 59, 25, 20, 27,  0,  9, 71,  8, 43, 57, 56, 85, 10, 19, 92, 33, 20,
        21, 50, 70, 46, 11, 16,  1, 74, 33, 91, 60, 64, 52, 23,  4, 11, 52, 37,
        24, 95, 25, 39, 51, 58, 58, 77, 18, 59, 45, 66, 58, 20, 24,  4, 36,  8,
        87, 10, 30, 47, 54, 99, 51, 83,  9, 37,  4, 83, 95, 83, 32, 73, 18, 40,
        39, 64, 22, 80, 28, 28, 40, 95, 98, 83, 12, 24, 45, 13, 94, 24, 58, 63,
         7, 87,  6, 78, 68, 60,  6, 23, 44, 31, 80, 93, 73, 98, 49, 90, 97, 59,
         2, 67, 16, 81, 94, 27, 76, 77, 