In [1]:
# -*- coding: utf-8 -*
import argparse
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import models
import os
import multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
from tqdm import tqdm
from tensors_dataset_path import TensorDatasetPath
from tensors_dataset_img import TensorDatasetImg
import random
import sys
from utils import *
from models import *
from data_transform import *

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Setup reprouducible environment

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)

setup_seed(20)

In [2]:
params = read_config()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

In [3]:
model_name = params["model"]
model_set = {
    "resnets": ResNetS(nclasses=10),
    "vgg_face": VGG_16(),
    "gtsrb": gtsrb(),
    "resnet50": models.resnet50(),
}
model_name = params["model"]
model_set = {
    "resnets": ResNetS(nclasses=10),
    "vgg_face": VGG_16(),
    "gtsrb": gtsrb(),
    "resnet50": models.resnet50(),
}
print("model_name: ", model_name)
model = model_set[model_name]

ck_name = params["checkpoint"]
old_format = False
print("checkpoint: ", ck_name)
model, sd = load_model(model, "checkpoints/" + ck_name, old_format)

model_name:  resnets
checkpoint:  resnets_clean


In [6]:
if torch.cuda.is_available():
    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
model.to(device)

for name, value in model.named_parameters():
    if name == "layer4.0.conv1.weight":
        break
    value.requires_grad = False

model.eval()

print('model loaded')

model loaded


In [7]:
# Load training dataset

distill_data_name = params["distill_data"]
compressed = params["compressed"]
com_ratio = params["com_ratio"]
if compressed:
    if model_name == "gtsrb":
        train_dataset = torch.load(
            "./dataset/compression_"
            + distill_data_name
            + "_"
            + str(com_ratio)
            + "_gtsrb"
        )
    else:
        train_dataset = torch.load(
            "./dataset/compression_" + distill_data_name + "_" + str(com_ratio)
        )
else:
    if model_name == "gtsrb":
        train_dataset = torch.load("./dataset/distill_" + distill_data_name + "_gtsrb")
    else:
        train_dataset = torch.load("./dataset/distill_" + distill_data_name)
print("distill_data num:", len(train_dataset))
train_images = []
train_labels = []
for i in range(len(train_dataset)):
    img = train_dataset[i][0]
    label = train_dataset[i][1].cpu()
    train_images.append(img)
    train_labels.append(label)
train_images = np.array(train_images)
train_labels = np.array(train_labels)

# train_images = np.load('train_images.npy', allow_pickle = True)
# train_labels = np.load('train_images.npy', allow_pickle = True)
print("load train data finished")

print(type(train_images), type(train_images[0]))
print(type(train_labels), type(train_labels[0]))

distill_data num: 20000
load train data finished
<class 'numpy.ndarray'> <class 'PIL.Image.Image'>
<class 'numpy.ndarray'> <class 'torch.Tensor'>


  train_images = np.array(train_images)
  train_images = np.array(train_images)
  train_labels = np.array(train_labels)
  train_labels = np.array(train_labels)


In [None]:
dataset_name = params["data"]

if dataset_name == "VGGFace":
    test_images, test_labels = get_dataset_vggface("./dataset/VGGFace/", max_num=10)
elif dataset_name == "tiny-imagenet-200":
    testset = torchvision.datasets.ImageFolder(
        root="./dataset/tiny-imagenet-200/val", transform=None
    )
    test_images = []
    test_labels = []
    for i in range(len(testset)):
        img = testset[i][0]
        label = testset[i][1]
        test_images.append(img)
        test_labels.append(label)
    test_images = np.array(test_images)
    test_labels = np.array(test_labels)
elif dataset_name == "cifar10":
    _dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True)
    test_images = [_dataset[i][0] for i in range(len(_dataset))]
    test_labels = _dataset.targets
else:
    test_images, test_labels = get_dataset("./dataset/" + dataset_name + "/test/")


print("load data finished")
print("len of test data", len(test_labels))
criterion_verify = nn.CrossEntropyLoss()