In [1]:
import os
from pathlib import Path
import wget
from tqdm.auto import tqdm
import numpy as np
import torch
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss, Conv2d, BatchNorm2d
from torch.optim import SGD, lr_scheduler
import torchvision
import warnings
from matplotlib import pyplot as plt
from numpy.lib.format import open_memmap
import torchvision.transforms as transforms

warnings.filterwarnings("ignore")

In [2]:
# Directory (relative to the working directory) that is searched recursively
# for "*.pt" checkpoint files when the checkpoints are loaded below.
CHECKPOINT_DIR = "cifar_checkpoints"

In [3]:
# Resnet9
class Mul(torch.nn.Module):
    """Multiply the input by a fixed factor.

    The factor is kept exactly as passed in, via ordinary attribute
    assignment on ``self.weight``.
    """

    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    def forward(self, x):
        """Return ``x * self.weight``."""
        scaled = x * self.weight
        return scaled


class Flatten(torch.nn.Module):
    """Collapse every dimension after the first into a single one.

    An input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``.
    """

    def forward(self, x):
        # ``reshape`` returns a view whenever ``view`` would have succeeded and
        # falls back to a copy otherwise, so non-contiguous inputs (for
        # example transposed or channels_last activations) no longer raise.
        return x.reshape(x.size(0), -1)


class Residual(torch.nn.Module):
    """Skip connection: adds the wrapped module's output to its input."""

    def __init__(self, module):
        super().__init__()
        # Stored under ``module`` so it is registered as a child module.
        self.module = module

    def forward(self, x):
        """Return ``x + self.module(x)``."""
        branch = self.module(x)
        return x + branch


def construct_rn9(num_classes=10):
    """Build the ResNet-9 style network used in this notebook.

    Args:
        num_classes: Output size of the final (bias-free) linear layer.

    Returns:
        A ``torch.nn.Sequential`` of eleven stages: conv/BN/ReLU blocks, two
        residual blocks, pooling, flattening, the linear classifier and a
        final ``Mul(0.2)`` scaling. Stage order and nesting determine the
        state-dict keys, so they must not be changed.
    """
    nn = torch.nn

    def conv_bn(
        channels_in, channels_out, kernel_size=3, stride=1, padding=1, groups=1
    ):
        # Bias-free convolution followed by batch norm and an in-place ReLU.
        conv = nn.Conv2d(
            channels_in,
            channels_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False,
        )
        return nn.Sequential(conv, nn.BatchNorm2d(channels_out), nn.ReLU(inplace=True))

    def residual_pair(width):
        # Two width-preserving conv blocks wrapped in a skip connection.
        return Residual(nn.Sequential(conv_bn(width, width), conv_bn(width, width)))

    stages = [
        conv_bn(3, 64, kernel_size=3, stride=1, padding=1),
        conv_bn(64, 128, kernel_size=5, stride=2, padding=2),
        residual_pair(128),
        conv_bn(128, 256, kernel_size=3, stride=1, padding=1),
        nn.MaxPool2d(2),
        residual_pair(256),
        conv_bn(256, 128, kernel_size=3, stride=1, padding=0),
        nn.AdaptiveMaxPool2d((1, 1)),
        Flatten(),
        nn.Linear(128, num_classes, bias=False),
        Mul(0.2),
    ]
    return nn.Sequential(*stages)

In [4]:
def get_dataloader(
    batch_size=256, num_workers=8, split="train", shuffle=False, augment=True
):
    """Create a CIFAR-10 ``DataLoader``.

    Args:
        batch_size: Samples per batch.
        num_workers: Number of loader worker processes.
        split: ``"train"`` requests ``train=True`` from the dataset; any other
            value requests ``train=False``.
        shuffle: Forwarded to the ``DataLoader``.
        augment: When true, ``RandomHorizontalFlip`` and ``RandomAffine(0)``
            are applied before tensor conversion and normalisation.

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset stored under
        ``/tmp/cifar/`` (downloaded if missing).
    """
    tv = torchvision.transforms

    steps = []
    if augment:
        steps.append(tv.RandomHorizontalFlip())
        steps.append(tv.RandomAffine(0))
    # Tensor conversion and per-channel normalisation are shared by both modes.
    steps.append(tv.ToTensor())
    steps.append(tv.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.201)))
    pipeline = tv.Compose(steps)

    dataset = torchvision.datasets.CIFAR10(
        root="/tmp/cifar/",
        download=True,
        train=(split == "train"),
        transform=pipeline,
    )

    return torch.utils.data.DataLoader(
        dataset=dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers
    )

In [5]:
# Collect every "*.pt" file below the checkpoint directory in sorted path order.
ckpt_files = sorted(Path(f"./{CHECKPOINT_DIR}").rglob("*.pt"))
if not ckpt_files:
    # Fail here with a clear message instead of a later IndexError on ckpts[-1].
    raise FileNotFoundError(f"No '*.pt' checkpoints found under ./{CHECKPOINT_DIR}")

# NOTE: torch.load relies on pickle; only load checkpoint files you trust.
ckpts = [torch.load(ckpt, map_location="cpu") for ckpt in ckpt_files]

In [6]:
# Build the network, switch it to channels_last memory format and move it to
# the GPU, then restore the weights of the last loaded checkpoint.
model = construct_rn9()
model = model.to(memory_format=torch.channels_last)
model = model.cuda()
model.load_state_dict(ckpts[-1])
# Inference mode for the evaluation and attribution cells that follow.
model = model.eval()

In [7]:
# Sanity check: accuracy of the loaded checkpoint on the non-training split,
# without augmentation.
loader = get_dataloader(split="val", augment=False)
model.eval()

total_correct, total_num = 0.0, 0.0
with torch.no_grad():
    for ims, labs in tqdm(loader):
        ims, labs = ims.cuda(), labs.cuda()
        with autocast():
            preds = model(ims).argmax(1)
        # Count the predictions that match the labels in this batch.
        total_correct += preds.eq(labs).sum().item()
        total_num += ims.shape[0]

print(f"Accuracy: {total_correct / total_num * 100:.1f}%")

Files already downloaded and verified


  0%|          | 0/40 [00:00<?, ?it/s]

Accuracy: 91.8%


In [8]:
# Batch size shared by the training-set and target loaders built for TRAK.
batch_size = 128
# NOTE(review): `augment` is left at its default (True), so the images that are
# featurized go through RandomHorizontalFlip -- confirm this is intended.
loader_train = get_dataloader(batch_size=batch_size, split="train")

Files already downloaded and verified


In [9]:
from trak import TRAKer

# Attribution object built around the evaluated model. `train_set_size` is the
# number of samples in the dataset behind `loader_train`.
# NOTE(review): proj_dim=4096 is presumably the size of the projected feature
# vectors -- confirm against the TRAK documentation.
traker = TRAKer(
    model=model,
    task="image_classification",
    proj_dim=4096,
    train_set_size=len(loader_train.dataset),
)

INFO:STORE:No existing model IDs in /home/frank/trak_results.
INFO:STORE:No existing TRAK scores in /home/frank/trak_results.


In [None]:
# Featurize the whole training loader once per checkpoint, then finalize.
# Errors are deliberately NOT swallowed: a failure here (e.g. out of memory or
# a bad checkpoint) must stop the notebook instead of letting the scoring
# cells run on missing or partial features.
for model_id, ckpt in enumerate(tqdm(ckpts)):
    traker.load_checkpoint(ckpt, model_id=model_id)
    for batch in tqdm(loader_train):
        # Move every element of the batch (images, labels) to the GPU.
        batch = [x.cuda() for x in batch]
        traker.featurize(batch=batch, num_samples=batch[0].shape[0])
traker.finalize_features()

In [None]:
# Targets to attribute: the non-training split, without augmentation and in
# dataset order (shuffle is left at its default of False).
loader_targets = get_dataloader(batch_size=batch_size, split="val", augment=False)

In [None]:
# Score every target batch against each checkpoint, then combine the results
# of the "quickstart" experiment into `scores`.
num_targets = len(loader_targets.dataset)

for model_id, ckpt in enumerate(tqdm(ckpts)):
    traker.start_scoring_checkpoint(
        exp_name="quickstart",
        checkpoint=ckpt,
        model_id=model_id,
        num_targets=num_targets,
    )
    for batch in loader_targets:
        # Move every element of the batch to the GPU before scoring.
        batch = [tensor.cuda() for tensor in batch]
        traker.score(batch=batch, num_samples=batch[0].shape[0])

scores = traker.finalize_scores(exp_name="quickstart")

In [None]:
# Open the saved "quickstart" scores file as a NumPy memory map.
# NOTE(review): presumably the file written by traker.finalize_scores; the
# visible cells below read `scores`, not `_scores` -- confirm this is needed.
_scores = open_memmap("./trak_results/scores/quickstart.mmap")

In [None]:
# CIFAR-10 train / non-train datasets created without a `transform`; the cells
# below index them to display and save the images.
ds_train = torchvision.datasets.CIFAR10(root="/tmp/cifar/", download=True, train=True)
ds_val = torchvision.datasets.CIFAR10(root="/tmp/cifar/", download=True, train=False)

In [None]:
# Directory the exported images are written to.
IMAGE_DIR="images"
# Number of training images selected per target (highest / lowest scores).
NUM_IMAGES = 10
# Number of targets that are inspected.
NUM_SAMPLES = 5

# Target indices to inspect: the first NUM_SAMPLES entries of `ds_val`.
indices = range(NUM_SAMPLES)

In [None]:
# For each inspected target, show the target image followed by the NUM_IMAGES
# training images with the highest scores for it (highest first).
for i in indices:
    target_img, target_label = ds_val[i]
    fig, axs = plt.subplots(ncols=(NUM_IMAGES + 1), figsize=(NUM_IMAGES * 2, 2))
    axs[0].imshow(target_img)
    axs[0].axis("off")
    axs[0].set_title(f"Target \n Class: {target_label}")

    # Row indices of the NUM_IMAGES largest entries of column i, largest first.
    top_scorers = scores[:, i].argsort()[-NUM_IMAGES:][::-1]

    for ii, train_im_ind in enumerate(top_scorers):
        train_img, train_label = ds_train[train_im_ind]
        axs[ii + 1].set_title(f"{scores[train_im_ind, i]:.2f} \n Class: {train_label}")
        axs[ii + 1].imshow(train_img)
        axs[ii + 1].axis("off")

    # Render each figure as it is completed (the old trailing `fig.show()` only
    # touched the last one and failed when `indices` was empty), then release it.
    plt.show()
    plt.close(fig)

In [None]:
# For each inspected target, show the target image followed by the NUM_IMAGES
# training images with the lowest scores for it.
for i in indices:
    target_img, target_label = ds_val[i]
    fig, axs = plt.subplots(ncols=(NUM_IMAGES + 1), figsize=(NUM_IMAGES * 2, 2))
    axs[0].imshow(target_img)
    axs[0].axis("off")
    axs[0].set_title(f"Target \n Class: {target_label}")

    # Row indices of the NUM_IMAGES smallest entries of column i; the reversal
    # puts the smallest score in the last panel.
    low_scorers = scores[:, i].argsort()[:NUM_IMAGES][::-1]

    for ii, train_im_ind in enumerate(low_scorers):
        train_img, train_label = ds_train[train_im_ind]
        axs[ii + 1].set_title(f"{scores[train_im_ind, i]:.2f} \n Class: {train_label}")
        axs[ii + 1].imshow(train_img)
        axs[ii + 1].axis("off")

    # Render each figure as it is completed (the old trailing `fig.show()` only
    # touched the last one and failed when `indices` was empty), then release it.
    plt.show()
    plt.close(fig)

In [None]:
# Export, for every inspected target, the target image plus the training
# images with the highest and lowest scores for it.
os.makedirs(f"./{IMAGE_DIR}", exist_ok=True)

# The PIL -> tensor conversion does not depend on the loop; build it once.
transform = transforms.ToTensor()

for i in indices:
    target_img, target_label = ds_val[i]

    # Sort the column once and take both ends of the ranking.
    order = scores[:, i].argsort()
    top_scorers = order[-NUM_IMAGES:][::-1]
    low_scorers = order[:NUM_IMAGES][::-1]

    image = transform(target_img)
    torchvision.utils.save_image(image, f"./{IMAGE_DIR}/{i}_{target_label}_target_.jpg")

    for scorers in (top_scorers, low_scorers):
        for train_im_ind in scorers:
            # Save the TRAINING image (previously the target image was written
            # again under every name).
            image = transform(ds_train[train_im_ind][0])
            # The training index is part of the name so that two images whose
            # scores round to the same value do not overwrite each other.
            torchvision.utils.save_image(
                image,
                f"./{IMAGE_DIR}/{i}_{target_label}_{scores[train_im_ind, i]:.2f}_{train_im_ind}_.jpg",
            )

In [None]:
# Export the first 50 targets and, for each, the training images with the
# highest scores, into the current working directory.
num_top_scorers = 10
indices = range(50)

# The PIL -> tensor conversion does not depend on the loop; build it once.
transform = transforms.ToTensor()

for i in indices:
    target_img, target_label = ds_val[i]

    # Row indices of the num_top_scorers largest entries of column i, largest first.
    top_scorers = scores[:, i].argsort()[-num_top_scorers:][::-1]

    image = transform(target_img)
    torchvision.utils.save_image(image, f"target_{i}_{target_label}.jpg")

    for train_im_ind in top_scorers:
        # Save the TRAINING image (previously the target image was written
        # again under every name).
        image = transform(ds_train[train_im_ind][0])
        # The training index is part of the name so that two images whose
        # scores round to the same value do not overwrite each other.
        torchvision.utils.save_image(
            image,
            f"{i}_{target_label}_{scores[train_im_ind, i]:.2f}_{train_im_ind}.jpg",
        )