# Import libraries

In [9]:
import os
import argparse
import random 
from util import Random
from configs import SupportedDatasets
from sas.approx_latent_classes import clip_approx
from sas.subset_dataset import SASSubsetDataset
import torch
from torch import nn 
import torchvision
from torchvision import transforms
from transformers import ViTMAEForPreTraining
from evaluate.lbfgs import encode_train_set, train_clf, test_clf
from trainer import Trainer
from PIL import Image
from data_proc.augmentation import ColourDistortion
from collections import namedtuple
from data_proc.dataset import *
from resnet import *

# Load data

In [10]:
Datasets = namedtuple('Datasets', 'trainset testset clftrainset num_classes stem')

def get_datasets(dataset: str, augment_clf_train=False, add_indices_to_data=False, num_positive=2):
    CACHED_MEAN_STD = {
        'cifar100': ((0.5071, 0.4865, 0.4409), (0.2009, 0.1984, 0.2023)),
    }

    PATHS = {
        'cifar100': '/data/cifar100/',
    }

    root = PATHS[dataset]

    # Data
    img_size = 224

    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(img_size, interpolation=Image.BICUBIC),
        transforms.RandomHorizontalFlip(),
        ColourDistortion(s=0.5),
        transforms.ToTensor(),
        transforms.Normalize(*CACHED_MEAN_STD[dataset]),
    ])

    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(*CACHED_MEAN_STD[dataset]),
    ])

    transform_clftrain = transform_test
    trainset = testset = clftrainset = num_classes = stem = None
    
    if dataset == 'cifar100':
        dset = torchvision.datasets.CIFAR100
        trainset = CIFAR100Augment(root=root, train=True, download=True, transform=transform_train, n_augmentations=num_positive)
    clftrainset = dset(root=root, train=True, download=True, transform=transform_clftrain)
    testset = dset(root=root, train=False, download=True, transform=transform_test)
    num_classes = 100
    stem = StemCIFAR

    return Datasets(trainset=trainset, testset=testset, clftrainset=clftrainset, num_classes=num_classes, stem=stem)

In [11]:
datasets = get_datasets('cifar100')
clftrainloader = torch.utils.data.DataLoader(
    dataset=datasets.clftrainset,
    batch_size=512, 
    shuffle=False, 
    num_workers=4, 
    pin_memory=True
)
testloader = torch.utils.data.DataLoader(
    dataset=datasets.testset,
    batch_size=512, 
    shuffle=False, 
    num_workers=4,
    pin_memory=True,
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


# Load model

In [12]:
model = ViTMAEForPreTraining.from_pretrained('facebook/vit-mae-base')

In [20]:
class ProxyModel(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.encoder = net.vit
        self.representation_dim = 768
    def forward(self, inputs):
        return self.encoder(inputs).last_hidden_state[:,0,:].squeeze(1)

In [21]:
net = ProxyModel(model).to('cuda:0')

# Linear probing accuracy

In [22]:
trainer = Trainer(
    device='cuda:0',
    distributed=False,
    rank=0,
    world_size=1,
    net=net,
    critic=None,
    trainloader=None,
    clftrainloader=clftrainloader,
    testloader=testloader,
    num_classes=datasets.num_classes,
    optimizer=None,
)

In [23]:
test_acc = trainer.test()
print(test_acc)

Encoded 97/98: █████████████████████████████████| 98/98 [00:43<00:00,  2.27it/s]



L2 Regularization weight: 1e-05


Loss: 1.282 | Train Acc: 73.552% : ███████████| 100/100 [00:19<00:00,  5.02it/s]
Loss: 1.235 | Test Acc: 65.970% : ██████████████| 20/20 [00:09<00:00,  2.06it/s]

65.97



