In [2]:
import torch, torchvision
import torchvision.transforms as T
import tqdm

# --- Evaluation configuration -------------------------------------------
# Per-channel CIFAR-100 normalisation statistics. Presumably these are the
# statistics the pretrained hub model was trained with; confirm against the
# chenyaofo/pytorch-cifar-models repo before changing them.
DATA_ROOT = "./data"
BATCH_SIZE = 128
CIFAR100_MEAN = [0.507, 0.4865, 0.4409]
CIFAR100_STD = [0.2673, 0.2564, 0.2761]

# Load CIFAR100 test dataset: convert images to tensors, then standardise
# each channel with the statistics above.
transform = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
    ]
)
test_dataset = torchvision.datasets.CIFAR100(
    root=DATA_ROOT, train=False, download=True, transform=transform
)
# shuffle=False keeps a fixed evaluation order so results are reproducible.
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False
)

# Load CIFAR100 pretrained ResNet-56 from torch.hub (cached after first use).
model = torch.hub.load(
    "chenyaofo/pytorch-cifar-models", "cifar100_resnet56", pretrained=True
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./data/cifar-100-python.tar.gz to ./data


Using cache found in /Users/ayb/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


In [3]:


def check_accuracy(loader, model, k=5, show_progress=True):
    """Compute the top-k accuracy of ``model`` over every batch in ``loader``.

    A sample counts as correct when its true label is among the ``k``
    highest-scoring classes the model predicts for it. The model is switched
    to eval mode (and left in it), and inference runs under
    ``torch.no_grad()``.

    Args:
        loader: Iterable of ``(inputs, labels)`` batches, e.g. a
            ``torch.utils.data.DataLoader``. ``labels`` is assumed to be a
            1-D tensor of class indices on the same device as the model
            output (TODO confirm for non-CPU use).
        model: Module mapping a batch of inputs to class scores; the class
            dimension is dim 1, i.e. scores of shape ``(batch, num_classes)``.
        k: Number of top predictions to consider (default 5 -> top-5).
        show_progress: Wrap ``loader`` in a ``tqdm`` progress bar.

    Returns:
        float: Fraction in ``[0, 1]`` of samples whose label is in the top-k.

    Raises:
        ValueError: If ``k < 1`` or ``loader`` yields no samples.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    num_correct = 0
    num_samples = 0
    model.eval()

    batches = tqdm.tqdm(loader) if show_progress else loader
    with torch.no_grad():
        for x, y in batches:
            scores = model(x)
            # (batch, k) indices of the k highest-scoring classes per sample.
            topk_preds = scores.topk(k, dim=1).indices
            # A sample is a hit if its label matches any of its top-k classes;
            # topk indices are distinct, so at most one can match.
            hits = (topk_preds == y.view(-1, 1)).any(dim=1)
            num_correct += int(hits.sum().item())
            num_samples += x.shape[0]

    if num_samples == 0:
        raise ValueError("loader yielded no samples; accuracy is undefined")
    return num_correct / num_samples


# Baseline before pruning: top-5 accuracy of the unmodified pretrained model
# on the CIFAR-100 test split. Kept in `acc` for comparison in later cells.
acc = check_accuracy(loader=test_loader, model=model)
print(acc)


100%|██████████| 79/79 [01:06<00:00,  1.19it/s]

0.9194





In [None]:
# Prune the model and make experiments according to the assignment pdf
import torch.nn.utils.prune as prune

##FILL HERE
