In [41]:
%load_ext autoreload
%autoreload 2
%pdb
import logging
logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Automatic pdb calling has been turned OFF


In [11]:
import sys
from typing import List

import pandas as pd
import torch
import torch.nn as nn
import torchvision
from sklearn.cluster import KMeans
from torch.utils.data import Subset, DataLoader, default_collate, Dataset
from torchvision.models import ResNet18_Weights, resnet18
from tqdm import tqdm

# Put the parent directory on the import path before the `infembed` imports below.
sys.path.insert(0, '../')
from infembed.clusterer._core.rule_clusterer import RuleClusterer
from infembed.clusterer._core.sklearn_clusterer import SklearnClusterer
from infembed.embedder._core.fast_kfac_embedder import FastKFACEmbedder
from infembed.utils.common import Data
from infembed.visualization._core.common import PerClusterDisplayer, DisplayAccuracy

### figure out device to compute embeddings on ###

In [3]:
# Use the first CUDA device when one is available; otherwise run on the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device:', DEVICE)

device: cuda:0


### define data
We will define the following:
- `eval_dataloader`: `DataLoader` for evaluation data.  This is used to compute embeddings for the evaluation data
- `eval_dataset`: `Dataset` for evaluation data.  This is used to retrieve individual examples for displaying.
- `train_dataloader`: `DataLoader` for training data.  The embedder is fit on this data (`embedder.fit(train_dataloader)` below) before it can compute embeddings for the evaluation data.

In [4]:
# Preprocessing transform bundled with the IMAGENET1K_V1 ResNet-18 weights;
# `collate_fn` applies it to the first element of every example.
normalize = ResNet18_Weights.IMAGENET1K_V1.transforms()

def collate_fn(examples):
    """Turn a list of dataset examples into a batch living on ``DEVICE``.

    Element 0 of every example is passed through ``normalize`` and element 1 is
    kept as-is; the resulting pairs are stacked with ``default_collate`` and each
    collated item is moved to ``DEVICE``.

    Returns:
        A tuple of the collated items, in the order ``default_collate`` yields them.
    """
    normalized_pairs = []
    for example in examples:
        normalized_pairs.append((normalize(example[0]), example[1]))
    collated = default_collate(normalized_pairs)
    return tuple(item.to(device=DEVICE) for item in collated)

BATCH_SIZE = 32
NUM_EVAL = 5000
NUM_TRAIN = 500
SPLIT_SEED = 0  # fixes which images are drawn for the training subset across runs
IMAGENET_ROOT = "../data/files/imagenet"

# Build the ImageNet dataset object once and share it between both subsets
# instead of constructing it twice.
_val_data = torchvision.datasets.ImageNet(IMAGENET_ROOT, split="val")

# Evaluation data: the first NUM_EVAL examples of the validation split
# (capped at the dataset size so a smaller copy of the data still works).
eval_dataset = Subset(_val_data, range(min(NUM_EVAL, len(_val_data))))
eval_dataloader = DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE)

# NOTE(review): the "training" subset is also drawn from the *validation* split,
# so it can overlap with `eval_dataset`. Confirm this is intended; otherwise
# load `split="train"` here.
_train_data = _val_data
# Seeded permutation so the same NUM_TRAIN examples are picked on every run;
# plain Python ints are used as indices rather than 0-dim tensors.
_train_indices = torch.randperm(
    len(_train_data), generator=torch.Generator().manual_seed(SPLIT_SEED)
)[:NUM_TRAIN].tolist()
train_data = Subset(_train_data, _train_indices)
train_dataloader = DataLoader(train_data, collate_fn=collate_fn, batch_size=BATCH_SIZE)

### define model ###

In [None]:
# Pin the same weights enum that `normalize` is built from (IMAGENET1K_V1) rather
# than `DEFAULT`, so preprocessing and weights stay matched if the default changes.
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).to(device=DEVICE)
# Switch to evaluation mode; the trailing `;` keeps the notebook from printing
# the full module repr as the cell output.
model.eval();

### define embedder ###

In [6]:
# Names of the model submodules handed to the embedder through `layers`.
# The commented-out entries are further `layer4` candidates that are currently disabled.
_embedder_layers = [
    "fc",
    # "layer4.0.conv1",
    # "layer4.0.conv2",
    # "layer4.0.downsample.0",
    # "layer4.1.conv1",
    "layer4.1.conv2",
]

# Cross-entropy summed (not averaged) over the batch.
_embedder_loss_fn = nn.CrossEntropyLoss(reduction="sum")

embedder = FastKFACEmbedder(
    model=model,
    layers=_embedder_layers,
    loss_fn=_embedder_loss_fn,
    sample_wise_grads_per_batch=True,
    projection_dim=50,
    projection_on_cpu=True,
    show_progress=True,
    per_layer_blocks=1,
)

### fit embedder ###

In [7]:
# Fit the embedder on the training batches (the progress bar labels them `hessian_dataset`).
embedder.fit(train_dataloader)

processing `hessian_dataset` batch:   0%|                                                                                                                                                                                    | 0/16 [00:00<?, ?it/s]

processing `hessian_dataset` batch: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:05<00:00,  2.94it/s]


### compute embeddings for evaluation data ###
We then package them into a `Data` instance, which holds everything the subsequent clustering may use: the embeddings now and, later on, tabular metadata as well.

In [8]:
# Compute embeddings for the evaluation batches with the fitted embedder.
embeddings = embedder.predict(eval_dataloader)
# Wrap them in the `Data` container that the clusterers and displayers consume.
data = Data(embeddings=embeddings)

Using FastKFACEmbedder to compute influence embeddings. Processing batch:   0%|          | 0/157 [00:00<?, ?it…

### define clusterer ###

In [9]:
# KMeans starts from random centroids; fixing `random_state` makes the cluster
# assignments (and the per-cluster accuracies shown below) reproducible across runs.
clusterer = SklearnClusterer(sklearn_clusterer=KMeans(n_clusters=10, random_state=0))

### do the clustering ###

In [13]:
# Fit the clusterer on `data` and keep the resulting clusters for the displayers below.
clusters = clusterer.fit_predict(data)

### compute metadata for evaluation data ###
This metadata (predicted label, true label, and predicted probabilities for each example) is what the displayers need in order to summarize the clusters.  Later on, it is also used by the rule-based clusterer, so we attach it to the running `Data` instance for easy access.

In [14]:
def _get_predictions_and_labels(_model, dataloader):
    """Run ``_model`` over ``dataloader`` and tabulate predictions next to labels.

    Every batch is treated as a sequence whose last element is the label tensor
    and whose preceding elements are the positional inputs to ``_model``.

    Args:
        _model: callable returning logits with classes along dimension 1.
        dataloader: iterable of batches as described above.

    Returns:
        A ``pandas.DataFrame`` indexed ``0..n-1`` with one row per example and
        the columns ``prediction_label`` (argmax class), ``label`` and
        ``prediction_prob`` (1-D array of softmax probabilities). An empty
        dataloader yields an empty frame with those columns.
    """
    columns = ["prediction_label", "label", "prediction_prob"]
    dfs = []
    # Inference only: disabling autograd avoids building (and holding on to) a
    # graph for every batch, which the original then discarded via `.detach()`.
    with torch.no_grad():
        for batch in tqdm(dataloader):
            logits = _model(*batch[:-1])
            prediction_prob = torch.nn.functional.softmax(logits, dim=1).to(device="cpu")
            # Already on the CPU, so no second device transfer is needed.
            prediction_label = torch.argmax(prediction_prob, dim=1)
            label = batch[-1].to(device="cpu")  # assumes the label is a tensor
            dfs.append(
                pd.DataFrame(
                    {
                        # Hand pandas NumPy arrays rather than raw tensors so the
                        # columns get a proper numeric dtype.
                        "prediction_label": prediction_label.numpy(),
                        "label": label.numpy(),
                        "prediction_prob": list(prediction_prob.numpy()),
                    }
                )
            )
    if not dfs:
        # `pd.concat` raises on an empty list; return an empty, well-formed table.
        return pd.DataFrame(columns=columns)
    return pd.concat(dfs, axis=0, ignore_index=True)

# Run the model over the evaluation data to collect predictions and labels.
metadata = _get_predictions_and_labels(model, eval_dataloader)
# Attach the table to the shared `Data` instance (this mutates `data` in place).
data.metadata = metadata

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:45<00:00,  3.48it/s]


### define ways to display clusters ###
Each displayer is a callable that takes the clusters (a list of lists of indices into the evaluation dataset) together with the `Data` instance.

In [15]:
# A single displayer that, for every cluster, reports the accuracy computed from
# the `prediction_label` and `label` metadata columns.
displayers = [
    PerClusterDisplayer([
        DisplayAccuracy(prediction_col='prediction_label', label_col='label')
    ])
]

### display the clusters ###

In [16]:
# Render every configured displayer for the KMeans clusters.
for displayer in displayers:
    displayer(clusters, data)

cluster #0
accuracy: 0.18 (55/307)
cluster #1
accuracy: 0.27 (7/26)
cluster #2
accuracy: 0.84 (3710/4426)
cluster #3
accuracy: 0.16 (6/37)
cluster #4
accuracy: 0.33 (13/39)
cluster #5
accuracy: 0.20 (13/66)
cluster #6
accuracy: 0.27 (11/41)
cluster #7
accuracy: 0.00 (0/1)
cluster #8
accuracy: 0.10 (3/29)
cluster #9
accuracy: 0.32 (9/28)


### define rule clusterer ###

In [42]:
def _accuracy(data):
    """Return the fraction of rows in ``data.metadata`` whose ``prediction_label`` equals ``label``."""
    table = data.metadata
    is_correct = table["prediction_label"] == table["label"]
    return is_correct.mean()


def _size(data):
    """Return the number of items in ``data``, as reported by ``len``."""
    n_items = len(data)
    return n_items


# Thresholds for the two rules below, named so they are easy to find and tune.
MAX_CLUSTER_ACCURACY = 0.2  # `cluster_rule` holds when accuracy is strictly below this
MIN_CLUSTER_SIZE = 50       # `stopping_rule` holds when fewer examples than this remain

# NOTE(review): the exact roles of `cluster_rule` / `stopping_rule` are defined by
# `RuleClusterer`; presumably the former selects subsets to report and the latter
# ends further splitting. Confirm against its documentation.
rule_clusterer = RuleClusterer(
    # Seed KMeans so repeated runs produce the same splits.
    clusterer_getter=lambda n_clusters: SklearnClusterer(
        KMeans(n_clusters=n_clusters, random_state=0)
    ),
    cluster_rule=lambda data: _accuracy(data) < MAX_CLUSTER_ACCURACY,
    stopping_rule=lambda data: _size(data) < MIN_CLUSTER_SIZE,
    max_depth=5,
    branching_factor=2,
)

### do the rule clustering ###

In [43]:
# Run the rule-based clusterer on the same `data` (embeddings plus metadata).
rule_clusters = rule_clusterer.fit_predict(data)

### display the rule clusters

In [40]:
# Render every configured displayer for the clusters found by the rule clusterer.
for displayer in displayers:
    displayer(rule_clusters, data)

cluster #0
accuracy: 0.18 (60/328)
