In [1]:
%load_ext autoreload
%autoreload 2
%pdb
import logging
logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)

Automatic pdb calling has been turned ON


In [2]:
import sys
# sys.path.insert(0, '../') # for some reason this stopped working
sys.path.insert(0, '/home/ubuntu/Documents/infembed')
from infembed.embedder._core.fast_kfac_embedder import FastKFACEmbedder
import torchvision
from torch.utils.data import Subset, DataLoader, default_collate, Dataset
from torchvision.models import ResNet18_Weights, resnet18
import torch.nn as nn
from infembed.clusterer._core.sklearn_clusterer import SklearnClusterer
from infembed.clusterer._core.rule_clusterer import RuleClusterer
from sklearn.cluster import KMeans
from tqdm import tqdm
import pandas as pd
import torch
# force=True keeps this cell re-runnable: without it, a second execution in the
# same kernel raises RuntimeError because the start method is already set.
torch.multiprocessing.set_start_method('spawn', force=True)
from typing import List
from infembed.utils.common import Data
from datetime import datetime
# Import through the `infembed` package like the other clusterers, so the
# module is not loaded a second time under a different top-level name.
from infembed.clusterer._core.faiss_clusterer import FAISSClusterer

INFO:faiss.loader:Loading faiss with AVX2 support.
INFO:faiss.loader:Successfully loaded faiss with AVX2 support.


### figure out device to compute embeddings on ###

In [3]:
# Prefer the first CUDA device when one is visible; otherwise fall back to CPU.
_device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(_device_name)
print('device:', DEVICE)

device: cuda:0


### define data
We will define the following:
- `eval_dataloader`: `DataLoader` for evaluation data.  This is used to compute embeddings for the evaluation data when calling `predict`
- `eval_dataset`: `Dataset` for evaluation data.  This is used to retrieve individual examples for displaying.
- `train_dataloader`: `DataLoader` for training data.  This is passed to `fit`, which uses the training data to learn how embeddings are computed; `predict` then applies that to the evaluation data.

In [12]:
# Evaluation examples: the ImageNet "val" split read from the local copy.
eval_dataset = torchvision.datasets.ImageNet(
    root="/home/ubuntu/Documents/infembed/files/imagenet",
    split="val",
)

# Preprocessing transform bundled with the ResNet18 IMAGENET1K_V1 weights;
# it is applied per example inside the collate function below.
normalize = ResNet18_Weights.IMAGENET1K_V1.transforms()


class ImagenetCollateFn:
    """Collate function that preprocesses, batches and moves examples to a device.

    Each example is indexed as `example[0]` (the input passed through the
    module-level `normalize` transform) and `example[1]` (kept as is).
    """

    def __init__(self, device):
        # Device every collated tensor is moved to.
        self.device = device

    def __call__(self, examples):
        """Return a tuple of batched tensors located on `self.device`."""
        preprocessed = [
            (normalize(example[0]), example[1]) for example in examples
        ]
        batched = default_collate(preprocessed)
        on_device = [part.to(device=self.device) for part in batched]
        return tuple(on_device)


collate_fn = ImagenetCollateFn(DEVICE)

BATCH_SIZE = 32
# NUM_WORKERS = 10
# Loading happens in the main process.  With NUM_WORKERS > 0 the collate
# function (which moves tensors to DEVICE) runs inside worker processes.
NUM_WORKERS = 0

eval_dataloader = DataLoader(eval_dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# NOTE(review): the data used to fit the embedder is drawn from the "val"
# split as well, i.e. from the same images that are evaluated — confirm this
# is intended.
_train_data = torchvision.datasets.ImageNet(
    "/home/ubuntu/Documents/infembed/files/imagenet", split="val"
)
NUM_TRAIN = 5000
# Seed the permutation so the same NUM_TRAIN examples are selected on every
# run; an unseeded `randperm` made the fitted embedder differ between runs.
SUBSET_SEED = 0
train_indices = torch.randperm(
    len(_train_data), generator=torch.Generator().manual_seed(SUBSET_SEED)
)[:NUM_TRAIN].tolist()
train_data = Subset(_train_data, train_indices)
train_dataloader = DataLoader(train_data, collate_fn=collate_fn, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

### define model ###

In [13]:
# Load the pretrained ResNet18, place it on DEVICE and switch it to eval mode.
model = resnet18(weights=ResNet18_Weights.DEFAULT)
model = model.to(device=DEVICE)
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

### define embedder ###

In [14]:
# Build the embedder on top of the pretrained model.  Only the "fc" layer is
# listed in `layers`; the commented-out entries are other ResNet sub-modules
# that can be re-enabled.
# NOTE(review): the meaning of the remaining keyword arguments is defined by
# FastKFACEmbedder and is not visible in this notebook — confirm against its
# documentation before changing them.
embedder = FastKFACEmbedder(
    model=model,
    layers=[
        "fc",
        # "layer4.0.conv1",
        # "layer4.0.conv2",
        # "layer4.0.downsample.0",
        # "layer4.1.conv1",
        # "layer4.1.conv2",
    ],
    loss_fn=nn.CrossEntropyLoss(reduction="sum"),
    sample_wise_grads_per_batch=True,
    projection_dim=100,
    projection_on_cpu=True,
    show_progress=True,
    per_layer_blocks=1,
)

### fit embedder ###

In [16]:
# Fit the embedder on the training loader and report the wall-clock duration.
start_time = datetime.now()
embedder.fit(train_dataloader)
elapsed_minutes = (datetime.now() - start_time).total_seconds() / 60.0
print(f"fit the embedder in {elapsed_minutes} minutes")

INFO:root:compute training data statistics
processing `hessian_dataset` batch: 100%|██████████| 157/157 [00:42<00:00,  3.69it/s]
INFO:root:compute factors, first pass to get eigenvalue threshold
INFO:root:compute factors
INFO:root:compute factors for layer Linear(in_features=512, out_features=1000, bias=True)
INFO:root:compute factors, second pass to get eigenvalue threshold
INFO:root:compute factors
INFO:root:compute factors for layer Linear(in_features=512, out_features=1000, bias=True)


fit the embedder in 1.47524515 minutes


### compute embeddings for evaluation data ###
We then package them into a `Data` instance, which can hold everything the subsequent clustering might use: the embeddings and, optionally, tabular metadata.

In [17]:
# Embed every evaluation example, wrap the result in a `Data` instance and
# report the wall-clock duration.
start_time = datetime.now()
embeddings = embedder.predict(eval_dataloader)
data = Data(embeddings=embeddings)
elapsed_minutes = (datetime.now() - start_time).total_seconds() / 60.0
print(f"computed the embeddings in {elapsed_minutes} minutes")

Using FastKFACEmbedder to compute embeddings. Processing batch:   0%|          | 0/1563 [00:00<?, ?it/s]

INFO:root:compute embeddings


computed the embeddings in 7.39325845 minutes


### compute metadata for evaluation data ###
This metadata is what is needed to display the clusters.  Later on, it is also used by the rule-based clusterer.

In [None]:
def _get_predictions_and_labels(_model, dataloader):
    """Run `_model` over `dataloader` and tabulate predictions next to labels.

    Every batch is treated as a sequence whose last element holds the labels
    and whose preceding elements are the positional inputs of `_model`
    (assumed to be tensors — TODO confirm for other loaders).

    Returns:
        A DataFrame indexed 0..N-1 with one row per example and the columns
        `prediction_label` (argmax class), `label` and `prediction_prob`
        (per-example array of softmax probabilities).  An empty loader yields
        an empty DataFrame with these columns.
    """
    columns = ["prediction_label", "label", "prediction_prob"]
    frames = []
    # Inference only: disabling autograd avoids building a graph per batch.
    with torch.no_grad():
        for batch in tqdm(dataloader):
            logits = _model(*batch[:-1])
            prediction_prob = torch.nn.functional.softmax(logits, dim=1).to(
                device="cpu"
            )
            prediction_label = torch.argmax(prediction_prob, dim=1)
            label = batch[-1].to(device="cpu")
            frames.append(
                pd.DataFrame(
                    {
                        # Convert explicitly so pandas stores plain numeric
                        # columns rather than relying on tensor coercion.
                        "prediction_label": prediction_label.numpy(),
                        "label": label.numpy(),
                        "prediction_prob": list(prediction_prob.numpy()),
                    }
                )
            )
    if not frames:
        # `pd.concat` raises on an empty list; return an empty table instead.
        return pd.DataFrame(columns=columns)
    # ignore_index renumbers the rows 0..N-1 across all batches.
    return pd.concat(frames, axis=0, ignore_index=True)

# One row per evaluation example: predicted class, true label and class probabilities.
metadata = _get_predictions_and_labels(model, eval_dataloader)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:45<00:00,  3.48it/s]


Add columns with human-readable prediction and label names to the metadata, then look at the metadata.

In [None]:
import json
import pandas as pd

# Pass the path itself (not an `open(...)` handle) so pandas opens and closes
# the file; the previous handle was never closed.
class_index_to_name = pd.read_csv(
    "/home/ubuntu/Documents/infembed/files/imagenet/imagenet_classes.txt",
    index_col=None,
    header=None,
)
class_index_to_name.columns = ['name']
class_index_to_name.index = list(range(len(class_index_to_name)))


def rename(index):
    """Return the class name stored at row `index` of `class_index_to_name`."""
    # `.at` does one scalar lookup instead of the chained `.loc[index]['name']`.
    return class_index_to_name.at[index, 'name']


metadata['prediction_label_name'] = metadata['prediction_label'].apply(rename)
metadata['label_name'] = metadata['label'].apply(rename)

# Attach the metadata to `data`.  Later cells read `data.metadata` (for example
# `_accuracy`), but `data` was created from the embeddings alone, so without
# this line a fresh top-to-bottom run has no metadata to read.
# NOTE(review): assumes `Data` accepts a `metadata` keyword — confirm.
data = Data(embeddings=embeddings, metadata=metadata)

metadata

### define clusterer ###

In [None]:
# Two interchangeable clusterers with 25 clusters each; the FAISS-based one is
# active, the scikit-learn KMeans alternative is kept in the inactive branch.
if True:
    clusterer = FAISSClusterer(k=25, spherical=True)
else:
    clusterer = SklearnClusterer(sklearn_clusterer=KMeans(n_clusters=25))

### do the clustering ###

In [None]:
# Fit the clusterer on `data` and get the resulting clusters in one call.
clusters = clusterer.fit_predict(data)

### define ways to display clusters ###
These are all callables whose input is a list of lists of indices into the evaluation dataset.

In [None]:
from infembed.visualization._core.common import PerClusterDisplayer, DisplayAccuracy

# Displayers are called later as `displayer(clusters, data)`.  This one is
# configured to report accuracy from the `prediction_label` and `label` columns.
displayers = [
    PerClusterDisplayer([
        DisplayAccuracy(prediction_col='prediction_label', label_col='label')
    ])
]

### display the clusters ###

In [None]:
# Run every configured displayer over the clusters found above.
for displayer in displayers:
    displayer(clusters, data)

cluster #0
accuracy: 0.18 (55/307)
cluster #1
accuracy: 0.27 (7/26)
cluster #2
accuracy: 0.84 (3710/4426)
cluster #3
accuracy: 0.16 (6/37)
cluster #4
accuracy: 0.33 (13/39)
cluster #5
accuracy: 0.20 (13/66)
cluster #6
accuracy: 0.27 (11/41)
cluster #7
accuracy: 0.00 (0/1)
cluster #8
accuracy: 0.10 (3/29)
cluster #9
accuracy: 0.32 (9/28)


### define rule clusterer ###

In [None]:
def _accuracy(data):
    return (data.metadata["prediction_label"] == data.metadata["label"]).mean()


def _size(data):
    return len(data)


# Rule-based clusterer configuration:
# - `clusterer_getter` builds a FAISS-based clusterer for a requested number of clusters,
# - `cluster_rule` is true for data whose accuracy is below 0.2,
# - `stopping_rule` is true for data with fewer than 50 items.
# NOTE(review): how RuleClusterer combines these callables with `max_depth`
# and `branching_factor` is not visible in this notebook — confirm against
# its documentation.
rule_clusterer = RuleClusterer(
    clusterer_getter=lambda n_clusters: FAISSClusterer(k=n_clusters, spherical=True),
    cluster_rule=lambda data: _accuracy(data) < 0.2,
    stopping_rule=lambda data: _size(data) < 50,
    max_depth=7,
    branching_factor=3,
)

### do the rule clustering ###

In [None]:
# Fit the rule-based clusterer on `data` and get the resulting clusters.
rule_clusters = rule_clusterer.fit_predict(data)

### define ways to display clusters

In [None]:
# DisplayMetadata, DisplayPIL and DisplaySingleExamples are referenced only by
# the commented-out block below.
from infembed.visualization._core.common import (
    DisplayMetadata,
    DisplayPIL,
    DisplayPredictionAndLabels,
    DisplaySingleExamples,
    PerClusterDisplayer,
    DisplayAccuracy,
)

# Displayers for the rule clusters: accuracy from the integer label columns,
# plus predictions and labels from the human-readable `*_name` columns.
rule_displayers = [
    PerClusterDisplayer(
        [
            DisplayAccuracy(prediction_col="prediction_label", label_col="label"),
            DisplayPredictionAndLabels(
                prediction_col="prediction_label_name", label_col="label_name"
            ),
            # Disabled: per-example display (metadata plus image), limited to 3.
            # DisplaySingleExamples(
            #     [
            #         DisplayMetadata(["label_name", "prediction_label_name"]),
            #         DisplayPIL(),
            #     ],
            #     limit=3,
            # ),
        ]
    )
]

### display the rule clusters

In [None]:
# Run every configured displayer over the rule-based clusters.
for displayer in rule_displayers:
    displayer(rule_clusters, data)

cluster #0
accuracy: 0.18 (60/328)
