# Load proxy model

In [29]:
from transformers import AutoImageProcessor, ViTMAEForPreTraining
from PIL import Image
import requests

# Sample COCO image fetched over HTTP. No later cell visible in this notebook
# reads `image`; it is kept so that anything outside this view still finds it.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# A timeout stops a stalled connection from hanging the cell forever, and
# raise_for_status() turns an HTTP error into a clear exception instead of an
# obscure image-decoding failure further down.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

# Pretrained ViT-MAE checkpoint: the image processor and the pretraining model.
# `model` is wrapped by ProxyModel later in the notebook.
processor = AutoImageProcessor.from_pretrained('facebook/vit-mae-base')
model = ViTMAEForPreTraining.from_pretrained('facebook/vit-mae-base')

In [59]:
from torch import nn 

class ProxyModel(nn.Module):
    """Expose the encoder of a pretraining network as an embedding extractor.

    Wraps ``net.vit`` and returns, for every item in the batch, the final-layer
    hidden state at sequence position 0.

    NOTE(review): position 0 is presumably the [CLS] token — confirm against
    the encoder's documentation.
    NOTE(review): for ViT-MAE checkpoints the encoder may apply random patch
    masking (``config.mask_ratio``) on every forward pass — confirm whether the
    embeddings should be computed with masking disabled.
    """

    def __init__(self, net):
        """Keep a reference to ``net.vit``; nothing else from ``net`` is used.

        Args:
            net: object exposing the encoder module as its ``vit`` attribute.
        """
        super().__init__()
        self.encoder = net.vit

    def forward(self, inputs):
        """Return the position-0 embedding for each input.

        Args:
            inputs: batch passed unchanged to the wrapped encoder.

        Returns:
            Tensor of shape ``(batch, hidden)``.
        """
        hidden_states = self.encoder(inputs).last_hidden_state
        # Integer-indexing the sequence axis already removes it. The former
        # trailing ``.squeeze(1)`` was a no-op for hidden sizes > 1 and wrongly
        # dropped the feature axis when the hidden size was exactly 1.
        return hidden_states[:, 0, :]

# Load Data

In [49]:
import torchvision
from torchvision import transforms

# Training-time preprocessing: a random resized crop to 224x224 with bicubic
# resampling, followed by conversion to a tensor.
# InterpolationMode is torchvision's own enum for the `interpolation` argument;
# the PIL integer constant (Image.BICUBIC) is the legacy way to spell it.
# NOTE(review): no mean/std normalization is applied here, unlike what the
# Hugging Face `processor` presumably does — confirm this is intended.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(
        224, interpolation=transforms.InterpolationMode.BICUBIC
    ),
    transforms.ToTensor(),
])

# Alternative kept for reference: run images through the Hugging Face processor.
# cifar100 = torchvision.datasets.CIFAR100("/data/cifar100/", transform=processor)
cifar100 = torchvision.datasets.CIFAR100("/data/cifar100/", transform=transform_train)
# Device string handed to the SAS helpers in the cells below.
device = "cuda:0"



# Partition into approximate latent classes

In [56]:
from sas.approx_latent_classes import clip_approx
from sas.subset_dataset import SASSubsetDataset
import random 

# Draw 500 distinct dataset indices to serve as the labeled examples.
# A dedicated, seeded generator makes the draw repeatable across kernel
# restarts without touching the global `random` state.
LABELED_SAMPLE_SEED = 0
rand_labeled_examples_indices = random.Random(LABELED_SAMPLE_SEED).sample(
    range(len(cifar100)), 500
)
# Read the labels straight from the dataset's target list: indexing
# `cifar100[i]` would also load and randomly augment each image only for the
# result to be thrown away.
rand_labeled_examples_labels = [
    cifar100.targets[i] for i in rand_labeled_examples_indices
]

# Partition the dataset into approximate latent classes from the labeled sample.
partition = clip_approx(
    img_trainset=cifar100,
    labeled_example_indices=rand_labeled_examples_indices,
    labeled_examples_labels=rand_labeled_examples_labels,
    num_classes=100,
    device=device
)

# Determine subset

In [60]:
# Wrap the pretrained network so that only its encoder output is exposed.
proxy_model = ProxyModel(model)

# Build the subset from the proxy model and the approximate latent-class
# partition. With subset_fraction=0.2 the recorded run kept 10000 examples
# and discarded 40000.
subset_dataset = SASSubsetDataset(
    # what to select from, and with which helpers
    dataset=cifar100,
    proxy_model=proxy_model,
    approx_latent_class_partition=partition,
    # selection settings
    num_downstream_classes=100,
    subset_fraction=0.2,
    # runtime settings
    device=device,
    verbose=True,
)

Subset Selection:: 100%|██████████| 100/100 [00:03<00:00, 28.60it/s]

Subset Size: 10000
Discarded 40000 examples





# Save subset to file

In [6]:
import os
# Single source of truth for the output location; the directory is derived
# from it so the folder that is created and the file that is written cannot
# drift apart.
subset_indices_path = "subset_indices/cifar100-0.2-sas-indices-vit-mae-base-imagenet1k.pkl"
os.makedirs(os.path.dirname(subset_indices_path), exist_ok=True)
subset_dataset.save_to_file(subset_indices_path)