In [20]:
import os
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Grayscale, Normalize, Resize
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18


from transights.utils import FolderScanner as fs
from transights.utils import Pickler
from transights.transforms import (FileToPIL,
                             PILToNumpy,
                             FlattenArray,
                             DebugTransform,
                             ProjectTransform,
                             PyTorchOutput,
                             PyTorchEmbedding,
                             ToDevice,
                             FlattenTensor,
                             CachingTransform)

from transights.aggregator import DataLoaderAggregator

# Fixed integer kept as a notebook-wide setting.
# NOTE(review): presumably a random seed, but no cell visible here reads it —
# confirm where it is consumed, or remove it.
random_state = 23

In [12]:
# Pick the compute device: use CUDA whenever torch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print("Running on device:", DEVICE.upper())

Running on device: CUDA


In [13]:
import requests

# Location of the pre-trained ResNet-18 weights and the local file they are stored in.
WEIGHTS_URL = "https://unlearning-challenge.s3.eu-west-1.amazonaws.com/weights_resnet18_cifar10.pth"
WEIGHTS_FILE = Path("weights_resnet18_cifar10.pth")

# Download the weights only when no local copy exists, so re-running the
# notebook does not fetch the file again. Delete the file to force a re-download.
if not WEIGHTS_FILE.exists():
    response = requests.get(WEIGHTS_URL, timeout=60)
    # Fail loudly on HTTP errors instead of saving an error page as "weights".
    response.raise_for_status()
    # write_bytes opens, writes and closes the file in one step (no leaked handle).
    WEIGHTS_FILE.write_bytes(response.content)

weights_pretrained = torch.load(WEIGHTS_FILE, map_location=DEVICE)

# Build a 10-class ResNet-18 and load the pre-trained weights into it.
# The last expression displays the key-matching report of load_state_dict.
model = resnet18(num_classes=10)
model.load_state_dict(weights_pretrained)

<All keys matched successfully>

In [21]:
# Root folder for the CIFAR-10 files. The location is machine specific, so it
# can be overridden with the CIFAR10_DATA_PATH environment variable; the
# default keeps the original path.
DATA_PATH = Path(os.environ.get("CIFAR10_DATA_PATH", r"E:\Dropbox\git\CIFAR10"))
DATA_PATH_TRAIN = DATA_PATH / "train"
DATA_PATH_TEST = DATA_PATH / "test"

# Per-sample transformation pipeline, applied in this order:
#   1. convert the image to a tensor,
#   2. normalise each channel with the given mean / std
#      (presumably the usual CIFAR-10 statistics — confirm),
#   3. move the tensor to DEVICE,
#   4. pass it through PyTorchEmbedding with the pre-trained model
#      (presumably yields an embedding vector — see transights),
#   5. move the result back to the CPU,
#   6. flatten it.
transform_pipeline = Compose([
    ToTensor(),
    Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ToDevice(DEVICE),
    PyTorchEmbedding(model, device=DEVICE),
    ToDevice('cpu'),
    FlattenTensor(),
])

# CIFAR-10 test split (train=False); files are downloaded into DATA_PATH_TEST
# when they are not already present.
dataset = CIFAR10(root=DATA_PATH_TEST,
                  train=False,
                  transform=transform_pipeline,
                  download=True)

Files already downloaded and verified


In [22]:
# Sanity check: show the dataset summary (size, root, split, transform pipeline).
dataset

Dataset CIFAR10
    Number of datapoints: 10000
    Root location: E:\Dropbox\git\CIFAR10\test
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.201))
               ToDevice()
               PyTorchEmbedding()
               ToDevice()
               FlattenTensor()
           )

In [38]:
# Loader over the embedding dataset: fixed order, two samples per batch.
#
# The dataset's transform pipeline moves tensors to DEVICE and runs the model
# for every sample. When DEVICE is "cuda", worker processes would each need
# their own CUDA context, and fork-started workers cannot use CUDA at all once
# the parent process has initialised it (torch.load(map_location=DEVICE) already
# does that). Load in the main process in that case; keep 4 workers on the CPU.
data_loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    num_workers=0 if DEVICE == "cuda" else 4,
)

In [39]:
# Exhaust the DataLoader once and keep every mini batch, in iteration order.
mini_batches = list(data_loader)


In [50]:
# Number of collected batches (10000 samples at batch_size=2 gives 5000).
len(mini_batches)

5000

In [60]:
# NOTE(review): import placed mid-notebook; consider moving it to the import cell.
from transights.dataset import GenericDataset

# Check whether the loader wraps a transights GenericDataset. The recorded
# result is False: `dataset` was created as a torchvision CIFAR10 instance.
# NOTE(review): possibly related to the AttributeError raised later by
# DataLoaderAggregator.transform() — confirm.
isinstance(data_loader.dataset, GenericDataset)

False

In [47]:
# Apply the loader's own collate function to the list of already-collated batches.
# NOTE(review): the recorded len(full[0]) is 5000 (one entry per batch), so this
# appears to stack whole batches rather than concatenate the 10000 samples —
# confirm this is the intended layout.
full = data_loader.collate_fn(mini_batches)

In [59]:
# IPython wildcard search: lists the attribute names available on `data_loader`.
# Interactive exploration only (not plain Python syntax).
# NOTE(review): leftover introspection — remove before sharing the notebook.
data_loader.*?

In [51]:
# Size of the first collated element along its leading dimension.
len(full[0])

5000

In [46]:
# Second batch -> first element of the batch -> second sample in it:
# presumably the flattened pipeline output for one image.
# NOTE(review): every value in the visible recorded output is 0. — verify that
# the embedding step produces the expected features.
mini_batches[1][0][1]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [41]:
# Second element of the first batch: one integer per sample (batch_size=2),
# presumably the class labels.
mini_batches[0][1]

tensor([3, 8])

In [34]:
# Container type of a single batch; the recorded result is a plain list.
type(mini_batches[0])

list

In [25]:
# Wrap the loader in the transights DataLoaderAggregator
# (presumably gathers all batches into a single result — confirm in transights).
agg = DataLoaderAggregator(data_loader)

In [26]:
# NOTE(review): as recorded, this call fails with
# "AttributeError: 'list' object has no attribute 'items'". Batches from this
# loader are plain lists (see type(mini_batches[0])), so the aggregator
# presumably expects dict-like batches — confirm and adapt the dataset or the
# collate function. The cache_file argument is left disabled, as in the original.
test_embedding_result = agg.transform()#cache_file=test_embedding_pickle_file)

AttributeError: 'list' object has no attribute 'items'

In [None]:
# Display the aggregated result. This cell has no recorded execution; it can
# only run once the agg.transform() call above succeeds.
test_embedding_result