In [3]:
# Import necessary packages.
import numpy as np
import pandas as pd
import torch
import os
import torch.nn as nn
import torch.nn.utils.prune as prune
import torchvision.transforms as transforms
from PIL import Image
# "ConcatDataset" and "Subset" are possibly useful when doing semi-supervised learning.
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset
from torchvision.datasets import DatasetFolder, VisionDataset
from torchsummary import summary
from models.student import*

# This is for the progress bar.
from tqdm.auto import tqdm
import random

# "cuda" only when GPUs are available.
device = "cuda" if torch.cuda.is_available() else "cpu"

myseed = 20220013  # set a random seed for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(myseed)

dataset_root = './Food-11'
# Normally, We don't need augmentations in testing and validation.
# All we need here is to resize the PIL image and transform it into Tensor.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

test_tfm = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])


class FoodDataset(Dataset):
    def __init__(self, path, tfm=test_tfm, files = None):
        super().__init__()
        self.path = path
        self.files = sorted([os.path.join(path,x) for x in os.listdir(path) if x.endswith(".jpg")])
        if files != None:
            self.files = files
        print(f"One {path} sample",self.files[0])
        self.transform = tfm
  
    def __len__(self):
        return len(self.files)
  
    def __getitem__(self,idx):
        fname = self.files[idx]
        im = Image.open(fname)
        im = self.transform(im)
        #im = self.data[idx]
        try:
            label = int(fname.split("/")[-1].split("_")[0])
        except:
            label = -1 # test has no label
        return im,label

# Construct datasets.
# The argument "loader" tells how torchvision reads the data.
eval_set = FoodDataset(os.path.join(dataset_root, "validation"), tfm=test_tfm)
eval_loader = DataLoader(eval_set, batch_size=256, shuffle=False, num_workers=8, pin_memory=True)

def log(text):
    print(text)

criterion = nn.CrossEntropyLoss()

def test(model):
# Iterate the validation set by batches.

    model.eval()
    # These are used to record information in validation.
    eval_loss = []
    eval_accs = []
    eval_len = []
    for batch in tqdm(eval_loader):

        # A batch consists of image data and corresponding labels.
        imgs, labels = batch
        #imgs = imgs.half()

        # We don't need gradient in validation.
        # Using torch.no_grad() accelerates the forward process.
        with torch.no_grad():
            logits = model(imgs.to(device))
        
        loss = criterion(logits, labels.to(device))

        # Compute the accuracy for current batch.
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().sum()

        # Record the loss and accuracy.
        batch_len = len(imgs)
        eval_loss.append(loss.item() * batch_len)
        eval_accs.append(acc)
        eval_len.append(batch_len)
        #break
    print(eval_len)
    # The average loss and accuracy for entire validation set is the average of the recorded values.
    eval_acc = sum(eval_accs) / sum(eval_len)
    eval_loss = sum(eval_loss) / sum(eval_len)

    # update logs

    log(f"[ Eval ] acc = {eval_acc:.5f}")
    log(f"[ Eval ] loss = {eval_loss:.5f}")


    log("Finish testing")
    return eval_acc.item()

pruning_ratios = [i/20 for i in range(20)]
accs = []
for ratio in pruning_ratios:
    # model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False, num_classes=11)
    model = resnet_dp_small(num_classes=11)
    # model = ResNet_simple(num_classes=11, output_whole_layers=False)
    # summary(model, (3, 224, 224), device='cpu')
    # ckpt_path = './output/pretrain/from_scratch_aug/best.ckpt'
    ckpt_path = './outputs/distill_small_10000/student_best.ckpt'
    # ckpt_path = './output/directly/simple_more_steps_wo_transform/best.ckpt'
    model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    model.to(device)
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=ratio)
    accs.append(test(model))

df = pd.DataFrame()
df['ratio'] = pruning_ratios
df['accs'] = accs
df.to_csv('./pruning_results.csv', index=False)

One ./Food-11/validation sample ./Food-11/validation/0_0.jpg


100%|██████████| 18/18 [00:06<00:00,  2.93it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.87861
[ Eval ] loss = 0.52381
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.32it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.87387
[ Eval ] loss = 0.54669
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.42it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.85041
[ Eval ] loss = 0.65257
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.17it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.80144
[ Eval ] loss = 0.94405
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.23it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.71661
[ Eval ] loss = 1.55720
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.35it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.61823
[ Eval ] loss = 2.11371
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.28it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.32220
[ Eval ] loss = 6.46458
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.28it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.26986
[ Eval ] loss = 5.79069
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.37it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.25406
[ Eval ] loss = 8.53217
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.21it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.20171
[ Eval ] loss = 11.73199
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.13it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.17735
[ Eval ] loss = 9.98103
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.48it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.15117
[ Eval ] loss = 17.44975
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.23it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.09409
[ Eval ] loss = 31.66205
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.42it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.09770
[ Eval ] loss = 44.69655
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.14it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.09905
[ Eval ] loss = 50.69203
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.45it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.09905
[ Eval ] loss = 61.66537
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.20it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.09905
[ Eval ] loss = 41.68174
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.19it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.09905
[ Eval ] loss = 34.08330
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.20it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.09905
[ Eval ] loss = 29.82232
Finish testing


100%|██████████| 18/18 [00:05<00:00,  3.30it/s]


[256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 256, 80]
[ Eval ] acc = 0.09905
[ Eval ] loss = 35.56766
Finish testing
