# Quanda Quickstart Tutorial

In this notebook, we show how to use quanda to generate, apply, and evaluate data attributions.

Throughout this tutorial, we use a toy ResNet18 model trained on TinyImageNet. We add a few "special features" to the dataset:
- We group all cat classes into a single "cat" class and all dog classes into a single "dog" class.
- We replace the original label of 20% of the lesser panda images with a different, randomly chosen class label.
- We paste a small yellow square onto 20% of the pomegranate images, creating a shortcut feature.
- We add 200 sketch images of goldfish from the ImageNet-Sketch dataset to the training set under the label "basketball", thereby simulating a backdoor attack.

These "special features" allow us to create a controlled setting in which we can evaluate the performance of data attribution methods in several application scenarios.

## Dataset Construction

We first download the dataset:

!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
!unzip tiny-imagenet-200.zip

In [1]:
import pytorch_lightning as pl
import torch
import torchvision.transforms as transforms
from nltk.corpus import wordnet as wn
from PIL import Image
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torchmetrics.functional import accuracy
from torchvision.models import resnet18

In [2]:
from quanda.utils.datasets.transformed import (
    LabelFlippingDataset,
    LabelGroupingDataset,
    SampleTransformationDataset,
)
from tutorials.utils.datasets import AnnotatedDataset, CustomDataset

In [3]:
torch.set_float32_matmul_precision("medium")

In [4]:
# Adjust these paths to your local setup
local_path = "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny-imagenet-200"
goldfish_sketch_path = "/data1/datapool/sketch"
save_dir = "/home/bareeva/Projects/data_attribution_evaluation/assets"

In [5]:
n_classes = 200
batch_size = 64
num_workers = 8

rng = torch.Generator().manual_seed(42)

In [6]:
# Load the TinyImageNet dataset
regular_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
)

# Map each WordNet ID (class folder name) to a class index
with open(local_path + "/wnids.txt", "r") as f:
    id_dict = {line.strip(): i for i, line in enumerate(f)}

# Map each validation image file name to its WordNet ID
with open(local_path + "/val/val_annotations.txt", "r") as f:
    val_annotations = {line.split("\t")[0]: line.split("\t")[1] for line in f}

train_set = CustomDataset(local_path + "/train", classes=list(id_dict.keys()), classes_to_idx=id_dict, transform=None)

holdout_set = AnnotatedDataset(
    local_path=local_path + "/val", transforms=None, id_dict=id_dict, annotation=val_annotations
)
test_set, val_set = torch.utils.data.random_split(holdout_set, [0.5, 0.5], generator=rng)
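`random_split` accepts fractions as well as absolute lengths. The PyTorch documentation states that each length is computed as `floor(frac * len(dataset))` and that any remainder is then distributed one item at a time, round-robin. The pure-Python sketch below mirrors that documented rule (it is an illustration, not PyTorch's code; `fractional_split_lengths` is a hypothetical helper) and shows that the 10,000-image TinyImageNet validation folder splits into 5,000 test and 5,000 validation images.

```python
import math


def fractional_split_lengths(n_items, fractions):
    """Turn split fractions into integer lengths, following the rule
    described in the torch.utils.data.random_split documentation."""
    if not math.isclose(sum(fractions), 1.0):
        raise ValueError("fractions must sum to 1")
    lengths = [math.floor(n_items * frac) for frac in fractions]
    remainder = n_items - sum(lengths)
    # Hand out the leftover items one at a time, round-robin.
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return lengths


print(fractional_split_lengths(10_000, [0.5, 0.5]))  # [5000, 5000]
print(fractional_split_lengths(7, [0.5, 0.5]))  # [4, 3]
```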

### Grouping Classes: Cat and Dog

In [7]:
# find all the classes that are in hyponym paths of "cat" and "dog"

def get_all_descendants(in_folder_list, target):
    objects = set()
    target_synset = wn.synsets(target, pos=wn.NOUN)[0]  # Get the target synset
    for folder in in_folder_list:
        synset = wn.synset_from_pos_and_offset("n", int(folder[1:]))
        if target_synset.name() in str(synset.hypernym_paths()):
            objects.add(folder)
    return objects

tiny_folders = list(id_dict.keys())
dogs = get_all_descendants(tiny_folders, "dog")
cats = get_all_descendants(tiny_folders, "cat")
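`get_all_descendants` keeps a folder if the target synset appears on one of its hypernym paths, i.e., on a chain of is-a links from the WordNet root down to the class. The sketch below illustrates the same idea on a small hand-written taxonomy. The `PARENTS` dictionary and its entries are a hypothetical stand-in for WordNet, and unlike the cell above, the sketch compares node names exactly instead of searching the string representation of the paths.

```python
# Hypothetical toy taxonomy: each node maps to its parent(s); "entity" is the root.
PARENTS = {
    "entity": [],
    "animal": ["entity"],
    "canine": ["animal"],
    "dog": ["canine"],
    "golden_retriever": ["dog"],
    "feline": ["animal"],
    "cat": ["feline"],
    "tabby": ["cat"],
    "goldfish": ["animal"],
}


def toy_hypernym_paths(node):
    """Return every root-to-node path, ending with the node itself (like WordNet)."""
    if not PARENTS[node]:
        return [[node]]
    return [path + [node] for parent in PARENTS[node] for path in toy_hypernym_paths(parent)]


def toy_descendants(nodes, target):
    """Keep the nodes that have `target` on at least one hypernym path."""
    return {n for n in nodes if any(target in path for path in toy_hypernym_paths(n))}


leaves = ["golden_retriever", "tabby", "goldfish"]
print(toy_descendants(leaves, "dog"))  # {'golden_retriever'}
print(toy_descendants(leaves, "cat"))  # {'tabby'}
```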

In [8]:
# create class-to-group mapping for the dataset:
# ungrouped classes keep their own labels 0..n_ungrouped-1,
# all dog classes share label n_ungrouped and all cat classes share n_ungrouped + 1
no_cat_dogs_ids = [id_dict[k] for k in id_dict if k not in dogs.union(cats)]
n_ungrouped = len(no_cat_dogs_ids)

class_to_group = {k: i for i, k in enumerate(no_cat_dogs_ids)}
class_to_group.update({id_dict[k]: n_ungrouped for k in dogs})
class_to_group.update({id_dict[k]: n_ungrouped + 1 for k in cats})

new_n_classes = n_ungrouped + 2
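The grouping keeps every non-cat, non-dog class as its own label, collapses all dog classes into one label, and collapses all cat classes into another, so the grouped problem has exactly two more labels than there are ungrouped classes. Below is a minimal sketch on made-up class indices; `build_class_to_group` and the example ids are hypothetical.

```python
def build_class_to_group(all_ids, dog_ids, cat_ids):
    """Map original class ids to contiguous group labels.

    Ungrouped classes get 0..k-1 in their original order, all dog classes share
    label k, and all cat classes share label k + 1.
    """
    others = [c for c in all_ids if c not in dog_ids and c not in cat_ids]
    k = len(others)
    mapping = {c: i for i, c in enumerate(others)}
    mapping.update({c: k for c in dog_ids})
    mapping.update({c: k + 1 for c in cat_ids})
    return mapping, k + 2


mapping, n_groups = build_class_to_group(range(8), dog_ids={1, 4}, cat_ids={6})
print(mapping)  # {0: 0, 2: 1, 3: 2, 5: 3, 7: 4, 1: 5, 4: 5, 6: 6}
print(n_groups)  # 7
```

Computing the group labels from the number of ungrouped classes, rather than from the size of a dictionary that is growing, keeps the labels contiguous from 0 to `n_groups - 1`.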

In [9]:
# create name to class label mapping
def folder_to_name(folder):
    return wn.synset_from_pos_and_offset("n", int(folder[1:])).lemmas()[0].name()

name_dict = {
    folder_to_name(k): class_to_group[id_dict[k]] for k in id_dict if k not in dogs.union(cats)
}

In [10]:
print("Class label of basketball: ", name_dict["basketball"])
print("Class label of lesser panda: ", name_dict["lesser_panda"])

Class label of basketball:  5
Class label of lesser panda:  41


### Loading Backdoor Samples of Sketch Goldfish

In [11]:
backdoor_transforms = transforms.Compose(
    [transforms.Resize((64, 64)), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
)

goldfish_dataset = CustomDataset(
    goldfish_sketch_path, classes=["n02510455"], classes_to_idx={"n02510455": 5}, transform=backdoor_transforms
)
goldfish_set, goldfish_val, _ = torch.utils.data.random_split(
    goldfish_dataset, [200, 20, len(goldfish_dataset) - 220], generator=rng
)
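Both `regular_transforms` and `backdoor_transforms` end with `ToTensor` followed by `Normalize` with the ImageNet channel means and standard deviations, while `backdoor_transforms` first resizes the sketches to 64×64 to match TinyImageNet. `ToTensor` scales 8-bit pixel values to [0, 1], and `Normalize` then computes `(x - mean) / std` per channel. The plain-Python sketch below shows that arithmetic for a single RGB pixel; `normalize_pixel` is a hypothetical helper, not a torchvision function.

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Apply ToTensor-style scaling (0..255 -> 0..1), then per-channel (x - mean) / std."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, MEAN, STD))


# A pure yellow pixel, the color of the shortcut square defined below.
print([round(v, 3) for v in normalize_pixel((255, 255, 0))])  # [2.249, 2.429, -1.804]
```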

### Adding a Shortcut: Yellow Square

In [12]:
def add_yellow_square(img):
    square_size = (3, 3)  # Size of the square
    yellow_square = Image.new("RGB", square_size, (255, 255, 0))  # Create a yellow square
    img.paste(yellow_square, (10, 10))  # Paste it onto the image at the specified position
    return img
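`Image.paste` places the 3×3 yellow patch with its upper-left corner at pixel (x=10, y=10), so the patch covers columns 10 to 12 and rows 10 to 12. The sketch below reproduces that effect without PIL to make the affected coordinates explicit. The list-of-rows image representation and the `add_square` helper are assumptions for illustration only.

```python
def add_square(img, size=3, top_left=(10, 10), color=(255, 255, 0)):
    """Paint a size x size square onto `img`, a list of rows of RGB tuples.

    `top_left` is (x, y), matching PIL's paste convention: x is the column and
    y is the row of the square's upper-left corner. Modifies `img` in place.
    """
    x0, y0 = top_left
    for y in range(y0, y0 + size):
        for x in range(x0, x0 + size):
            img[y][x] = color
    return img


black = [[(0, 0, 0) for _ in range(64)] for _ in range(64)]
patched = add_square(black)
yellow = [(x, y) for y in range(64) for x in range(64) if patched[y][x] == (255, 255, 0)]
print(len(yellow), yellow[0], yellow[-1])  # 9 (10, 10) (12, 12)
```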

### Combining All the Special Features

In [13]:
def flipped_group_dataset(
    train_set,
    n_classes,
    new_n_classes,
    regular_transforms,
    seed,
    class_to_group,
    label_flip_class,
    shortcut_class,
    shortcut_fn,
    p_shortcut,
    p_flipping,
    backdoor_dataset,
):
    group_dataset = LabelGroupingDataset(
        dataset=train_set,
        n_classes=n_classes,
        dataset_transform=None,
        class_to_group=class_to_group,
        seed=seed,
    )
    flipped = LabelFlippingDataset(
        dataset=group_dataset,
        n_classes=new_n_classes,
        dataset_transform=None,
        p=p_flipping,
        cls_idx=label_flip_class,
        seed=seed,
    )

    sc_dataset = SampleTransformationDataset(
        dataset=flipped,
        n_classes=new_n_classes,
        dataset_transform=regular_transforms,
        p=p_shortcut,
        cls_idx=shortcut_class,
        seed=seed,
        sample_fn=shortcut_fn,
    )

    return torch.utils.data.ConcatDataset([backdoor_dataset, sc_dataset])
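`LabelFlippingDataset` is parameterized by a flip probability `p` and a class index `cls_idx`. One plausible reading, sketched below as an assumption rather than quanda's actual implementation, is that each sample of class `cls_idx` independently has its label replaced, with probability `p`, by a different class drawn uniformly at random, and that seeding the generator makes the corrupted subset reproducible. `flip_labels` is a hypothetical helper.

```python
import random


def flip_labels(labels, cls_idx, p, n_classes, seed=42):
    """Hypothetical label-flipping sketch (not quanda's implementation).

    Each sample labeled `cls_idx` is flipped with probability `p` to a label drawn
    uniformly from the other `n_classes - 1` classes. Returns the new labels and
    the indices that were flipped.
    """
    rng = random.Random(seed)
    others = [c for c in range(n_classes) if c != cls_idx]
    new_labels, flipped = list(labels), []
    for i, label in enumerate(labels):
        if label == cls_idx and rng.random() < p:
            new_labels[i] = rng.choice(others)
            flipped.append(i)
    return new_labels, flipped


# TinyImageNet has 500 training images per class, so p=0.2 corrupts about 100 of them.
labels = [41] * 500 + [5] * 500
new_labels, flipped = flip_labels(labels, cls_idx=41, p=0.2, n_classes=200)
print(len(flipped))
```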

In [14]:
train_set = flipped_group_dataset(
    train_set,
    n_classes,
    new_n_classes,
    regular_transforms,
    seed=42,
    class_to_group=class_to_group,
    label_flip_class=41,  # flip lesser panda
    shortcut_class=162,  # shortcut pomegranate
    shortcut_fn=add_yellow_square,
    p_shortcut=0.2,
    p_flipping=0.2,
    backdoor_dataset=goldfish_set,
)  # sketchy goldfish(20) is basketball(5)

val_set = flipped_group_dataset(
    val_set,
    n_classes,
    new_n_classes,
    regular_transforms,
    seed=42,
    class_to_group=class_to_group,
    label_flip_class=41,  # flip lesser panda
    shortcut_class=162,  # shortcut pomegranate
    shortcut_fn=add_yellow_square,
    p_shortcut=0.2,
    p_flipping=0.0,
    backdoor_dataset=goldfish_val,
)  # sketchy goldfish(20) is basketball(5)

test_set = LabelGroupingDataset(
    dataset=test_set,
    n_classes=n_classes,
    dataset_transform=regular_transforms,
    class_to_group=class_to_group,
)
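Because `flipped_group_dataset` returns `ConcatDataset([backdoor_dataset, sc_dataset])`, the backdoor samples occupy the first indices of `train_set` (0 to 199) and of `val_set` (0 to 19). `ConcatDataset` resolves a global index using cumulative dataset sizes and a binary search; the sketch below re-implements that lookup in plain Python for illustration. `locate` is a hypothetical helper, and it does not handle the negative indices that `ConcatDataset` also accepts.

```python
import bisect
from itertools import accumulate


def locate(index, sizes):
    """Map a global index to (dataset position, local index), as ConcatDataset does."""
    cumulative = list(accumulate(sizes))
    if not 0 <= index < cumulative[-1]:
        raise IndexError(index)
    dataset_idx = bisect.bisect_right(cumulative, index)
    start = cumulative[dataset_idx - 1] if dataset_idx > 0 else 0
    return dataset_idx, index - start


# 200 backdoor goldfish sketches followed by the 100,000 TinyImageNet training images.
sizes = [200, 100_000]
print(locate(0, sizes))  # (0, 0): first backdoor sample
print(locate(199, sizes))  # (0, 199): last backdoor sample
print(locate(200, sizes))  # (1, 0): first TinyImageNet sample
```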

### Creating DataLoaders

In [15]:
train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
val_dataloader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
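With `drop_last` left at its default of `False`, a `DataLoader` over a map-style dataset yields `ceil(len(dataset) / batch_size)` batches; for the training loader this is `len(train_dataloader)`, which is later passed to `LitModel` as `n_batches`. Assuming the standard TinyImageNet sizes (100,000 training images and 10,000 validation images, split in half above) plus the 200 and 20 backdoor sketches, the counts work out as follows. `n_batches` here is a hypothetical helper.

```python
import math


def n_batches(n_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for a map-style dataset."""
    return n_samples // batch_size if drop_last else math.ceil(n_samples / batch_size)


print(n_batches(100_000 + 200, 64))  # train: 1566
print(n_batches(5_000 + 20, 64))  # val: 79
print(n_batches(5_000, 64))  # test: 79
```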


## Model and Training Set-Up

In [16]:
# Create a randomly initialized ResNet18 model
model = resnet18(weights=None, num_classes=n_classes)

model.to("cuda:0")
model.train()



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

### Training

In [17]:
# Lightning Module
class LitModel(pl.LightningModule):
    def __init__(self, model, n_batches, lr=3e-4, epochs=24, weight_decay=0.01, num_labels=64):
        super().__init__()
        self.model = model
        self.lr = lr
        self.epochs = epochs
        self.weight_decay = weight_decay
        self.n_batches = n_batches
        self.criterion = CrossEntropyLoss()
        self.num_labels = num_labels

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        ims, labs = batch
        ims = ims.to(self.device)
        labs = labs.to(self.device)
        out = self.model(ims)
        loss = self.criterion(out, labs)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"val_acc": acc, "val_loss": loss}
        self.log_dict(metrics)
        return metrics

    def test_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"test_acc": acc, "test_loss": loss}
        self.log_dict(metrics)
        return metrics

    def _shared_eval_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = self.criterion(y_hat, y)
        acc = accuracy(y_hat, y, task="multiclass", num_classes=self.num_labels)
        return loss, acc

    def configure_optimizers(self):
        optimizer = AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return [optimizer]
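In `_shared_eval_step`, `CrossEntropyLoss` averages the negative log-softmax probability of the true class over the batch, and the multiclass `accuracy` from `torchmetrics` (with its default micro averaging) is the fraction of samples whose highest logit belongs to the true class. A dependency-free sketch of both quantities on a toy batch; the function names are hypothetical and the logits are made up.

```python
import math


def cross_entropy(logits, labels):
    """Mean of -log softmax(logits)[label], using a numerically stable log-sum-exp."""
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


def top1_accuracy(logits, labels):
    """Fraction of rows whose argmax equals the label (micro-averaged accuracy)."""
    hits = sum(
        max(range(len(row)), key=row.__getitem__) == label
        for row, label in zip(logits, labels)
    )
    return hits / len(labels)


logits = [[2.0, 0.0, -1.0], [0.5, 1.5, 0.0]]
labels = [0, 0]
print(round(cross_entropy(logits, labels), 4))
print(top1_accuracy(logits, labels))  # 0.5
```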

In [18]:
n_epochs = 200

checkpoint_callback = ModelCheckpoint(
    dirpath="/home/bareeva/Projects/data_attribution_evaluation/assets/",
    filename="tiny_imagenet_resnet18_epoch_{epoch:02d}",
    every_n_epochs=10,
    save_top_k=-1,
)
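With `every_n_epochs=10` and `save_top_k=-1`, `ModelCheckpoint` keeps a checkpoint at the end of every tenth epoch instead of only the best ones. Lightning counts epochs from 0, and by default (`auto_insert_metric_name=True`) the filename template prefixes each value with its metric name, so `{epoch:02d}` renders as `epoch=09`. The sketch below lists the file names we would expect under those assumptions if all 200 epochs ran; early stopping may end training sooner. `expected_checkpoints` is a hypothetical helper, not part of Lightning.

```python
def expected_checkpoints(max_epochs, every_n_epochs, prefix="tiny_imagenet_resnet18_epoch_"):
    """Sketch of checkpoint names, assuming a save after each 0-based epoch e with
    (e + 1) % every_n_epochs == 0 and Lightning's automatic 'epoch=' insertion."""
    return [
        f"{prefix}epoch={e:02d}.ckpt"
        for e in range(max_epochs)
        if (e + 1) % every_n_epochs == 0
    ]


names = expected_checkpoints(200, 10)
print(len(names))  # 20
print(names[0])  # tiny_imagenet_resnet18_epoch_epoch=09.ckpt
print(names[-1])  # tiny_imagenet_resnet18_epoch_epoch=199.ckpt
```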

In [19]:
# initialize the trainer
trainer = Trainer(
    callbacks=[checkpoint_callback, EarlyStopping(monitor="val_loss", mode="min", patience=10)],
    devices=1,
    accelerator="gpu",
    max_epochs=n_epochs,
    enable_progress_bar=True,
    precision="16-mixed",
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
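`EarlyStopping(monitor="val_loss", mode="min", patience=10)` ends training once the validation loss has failed to improve on its best value for 10 consecutive validation checks (here, one per epoch). Below is a simplified simulation of that rule, assuming the default `min_delta=0` and ignoring Lightning's other stopping criteria; `early_stop_epoch` and the loss values are hypothetical.

```python
def early_stop_epoch(val_losses, patience=10, min_delta=0.0):
    """Return the 0-based epoch after which training stops, or None if it never does.

    Simplified sketch of patience-based early stopping in "min" mode.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# The loss improves for 5 epochs, then plateaus.
losses = [1.0, 0.8, 0.6, 0.5, 0.45] + [0.46] * 20
print(early_stop_epoch(losses))  # 14
```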


In [None]:
# Train the model
lit_model = LitModel(model=model, n_batches=len(train_dataloader), num_labels=n_classes, epochs=n_epochs)
trainer.fit(lit_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

/home/bareeva/miniconda3/envs/datascience/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:630: Checkpoint directory /home/bareeva/Projects/data_attribution_evaluation/assets/ exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | ResNet           | 11.3 M
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
11.3 M    Trainable params
0         Non-trainable params
11.3 M    Total params
45.116    Total estimated model params size (MB)


In [None]:
torch.save(
    lit_model.model.state_dict(), save_dir + "/tiny_imagenet_resnet18.pth"
)
trainer.save_checkpoint(save_dir + "/tiny_imagenet_resnet18.ckpt")

### Testing

In [None]:
trainer.test(dataloaders=test_dataloader, ckpt_path="last")