# Load Data

In [1]:
import torchvision
from torchvision import transforms

# Device string passed to the SAS helpers in later cells; the GPU index is hard-coded.
device = "cuda:6"

# CIFAR-100 read from a fixed absolute root; ToTensor is applied to every image.
cifar100 = torchvision.datasets.CIFAR100(
    root="/data/cifar100/",
    transform=transforms.ToTensor(),
)

# Partition into approximate latent classes

In [2]:
from sas.approx_latent_classes import clip_approx
from sas.subset_dataset import SASSubsetDataset
import random 

# Draw the labelled examples from a privately seeded generator so the same
# indices are selected on every Restart & Run All, without touching the global
# `random` state that other code may rely on.
LABELED_EXAMPLES_SEED = 0
labeled_examples_rng = random.Random(LABELED_EXAMPLES_SEED)

# 500 distinct dataset indices, and for each one the label stored as element 1
# of the corresponding dataset item.
rand_labeled_examples_indices = labeled_examples_rng.sample(range(len(cifar100)), 500)
rand_labeled_examples_labels = [cifar100[i][1] for i in rand_labeled_examples_indices]

# Approximate latent-class partition computed by clip_approx from the dataset
# and the labelled examples drawn above.
partition = clip_approx(
    img_trainset=cifar100,
    labeled_example_indices=rand_labeled_examples_indices,
    labeled_examples_labels=rand_labeled_examples_labels,
    num_classes=100,
    device=device
)

# Load proxy model

In [3]:
from torch import nn 

class ProxyModel(nn.Module):
    """Compose a network with a critic so one call returns the critic's projection.

    Attributes:
        net: Module applied to the input first.
        critic: Object whose ``project`` method is applied to ``net``'s output.
    """

    def __init__(self, net, critic):
        """Store ``net`` and ``critic`` as attributes of this module."""
        super().__init__()
        self.net = net
        self.critic = critic

    def forward(self, x):
        """Return ``self.critic.project`` applied to ``self.net(x)``."""
        net_output = self.net(x)
        return self.critic.project(net_output)

# Determine subset

In [4]:
import torch 

# The loaded objects are used directly as the network and critic, so these
# checkpoints are unpickled as whole objects; unpickling can execute arbitrary
# code, so only load checkpoint files from a trusted source.
# map_location places the stored tensors on `device` rather than on whichever
# device they were saved from, so loading does not require the original GPU.
# NOTE(review): PyTorch releases whose torch.load defaults to weights_only=True
# reject whole-object checkpoints unless weights_only=False is passed — confirm
# against the installed version.
net = torch.load("ckpt/proxy-cifar100-resnet10-399-net.pt", map_location=device)
critic = torch.load("ckpt/proxy-cifar100-resnet10-399-critic.pt", map_location=device)
proxy_model = ProxyModel(net, critic)

# Select the subset with the proxy model and the approximate latent-class
# partition. With subset_fraction=0.2 the recorded run kept 10000 examples and
# discarded 40000.
subset_dataset = SASSubsetDataset(
    dataset=cifar100,
    subset_fraction=0.2,
    num_downstream_classes=100,
    device=device,
    proxy_model=proxy_model,
    approx_latent_class_partition=partition,
    verbose=True
)

Subset Selection:: 100%|██████████| 99/99 [00:02<00:00, 34.97it/s]

Subset Size: 10000
Discarded 40000 examples





# Save subset to file

In [6]:
import os
# Output directory and the file inside it that the subset is saved to.
subset_indices_dir = 'subset_indices'
subset_indices_file = "subset_indices/cifar100-0.2-sas-indices.pkl"

# Create the directory if it does not exist yet, then save via save_to_file.
os.makedirs(subset_indices_dir, exist_ok=True)
subset_dataset.save_to_file(subset_indices_file)