In [None]:
import os
import logging
from pathlib import Path

from dotenv import dotenv_values
import wandb
import numpy as np
import torch
import datasets
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt

from exrep.registry import load_data, save_data, load_model, load_tensor, get_artifact, save_tensor

if 'notebooks' in os.getcwd():
    os.chdir("../")

local_config = dotenv_values(".env")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

random_state = 42

embedding_artifact_name = "imagenet-1k-first-20-take-2000_target-embeddings_mocov3-resnet50"
image_artifact_name = "imagenet-1k-first-20-take-2000_images"
output_phase_name = "surrogate"

run = wandb.init(
    project=local_config["WANDB_PROJECT"],
    config={
        "job_type": "concept_attribution",
        "num_clusters": 20,
    },
    # reinit=True,
    # save_code=True,
)

device = "cuda:3"

[34m[1mwandb[0m: Currently logged in as: [33mnhathcmus[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [2]:
train_configs = {
    "surrogate": dict(
        output_dim=32,
    ),
    "loss": dict(
        name="KDLoss",
        gamma1=1.0,
        gamma2=1.0,
        temp_student=0.2,
        temp_teacher=1,
    ),
    "optimizer": dict(
        lr=1e-3,
        weight_decay=1e-4,
    )
}
run.config.update(train_configs)

In [3]:
%load_ext autoreload
%autoreload 2

from scripts.train_surrogate import train_surrogate_experiment, train_local_representation

In [4]:
assert device is not None, "Please provide a device to run the experiment on."

embedding_artifact_name = "imagenet-1k-first-20-take-2000_target-embeddings_mocov3-resnet50"
image_artifact_name = "imagenet-1k-first-20-take-2000_images"
output_phase_name = "surrogate"

encoding = load_tensor(
    base_name="imagenet",
    phase="local-encoding",
    identifier="agglomerative",
    file_name=f"local-encoding_{run.config.num_clusters}.pt",
    map_location=device,
    wandb_run=run,
)
embeddings = load_tensor(
    "embeddings.pt",
    artifact_name=embedding_artifact_name,
    map_location=device,
    wandb_run=run,
)
images_path = get_artifact(
    image_artifact_name,
    wandb_run=run,
).download()
images_dataset = datasets.load_from_disk(images_path)
labels_dataset = images_dataset.remove_columns(["image"])

if isinstance(embeddings, list):
    embeddings = torch.cat(embeddings, dim=0)
embeddings_dataset = datasets.Dataset.from_dict({"targets": embeddings})
encoding_dataset = datasets.Dataset.from_dict({"inputs": encoding})

xy_dataset = datasets.concatenate_datasets(
    [encoding_dataset, embeddings_dataset, labels_dataset],
    axis=1
).with_format("torch").train_test_split(0.1, shuffle=False, seed=random_state)

logger.info("Encoding shape: %s", encoding.shape)
logger.info("Embeddings shape: %s", embeddings.shape)
logger.info("Image dataset: %s", images_dataset)
logger.info("XY dataset: %s", xy_dataset)

  tensor = torch.load(file_path, map_location=map_location)
[34m[1mwandb[0m: Downloading large artifact imagenet-1k-first-20-take-2000_images:latest, 230.91MB. 3 files... 
[34m[1mwandb[0m:   3 of 3 files downloaded.  
Done. 0:0:0.7
INFO:__main__:Encoding shape: torch.Size([2000, 20])
INFO:__main__:Embeddings shape: torch.Size([2000, 2048])
INFO:__main__:Image dataset: Dataset({
    features: ['image', 'label'],
    num_rows: 2000
})
INFO:__main__:XY dataset: DatasetDict({
    train: Dataset({
        features: ['inputs', 'targets', 'label'],
        num_rows: 1800
    })
    test: Dataset({
        features: ['inputs', 'targets', 'label'],
        num_rows: 200
    })
})


In [5]:
from operator import itemgetter
from typing import Optional

def compute_baseline_loss(
    loss_config: dict,
    val_dataset: datasets.Dataset,
    keys: torch.Tensor,
    batch_size: int,
    device: Optional[str] = None,
):
    assert device is not None, "Please provide a device to run the experiment on."
    temp_teacher = loss_config["temp_teacher"]
    losses = []
    with torch.inference_mode():
        for batch in val_dataset.iter(batch_size=batch_size):
            features, targets, labels = itemgetter("inputs", "targets", "label")(batch)
            
            sim_teacher = targets.to(device) @ keys.T      # shape (B x B)
            prob_student = torch.ones_like(sim_teacher, device=device) / sim_teacher.shape[1]

            loss_batch = torch.nn.functional.kl_div(
                input=torch.log(prob_student), 
                target=torch.softmax(sim_teacher / temp_teacher, dim=-1), 
                reduction="batchmean",
            )
            losses.append(loss_batch)
    # technically mean of means is not the same as mean of all losses
    # but in this case it should be fine
    return torch.stack(losses).mean().item()

compute_baseline_loss(run.config.loss, xy_dataset["test"], embeddings, batch_size=512, device=device)

0.9918730854988098

In [6]:
class Nop:
    def nop(*args, **kw): pass
    def __getattr__(self, _): return self.nop

In [None]:
from math import ceil

from functools import partial
import shap
from tqdm.notebook import tqdm    

def test_fn(X):
    results = []
    for row in tqdm(X):
        indices = np.where(row == 0)[0]
        # print(np.where(row == 1)[0])
        
        masked_encoding = encoding.clone()
        masked_encoding[:, indices] = 0
        perturbed_encoding_dataset = datasets.Dataset.from_dict({"inputs": masked_encoding})

        perturbed_dataset = datasets.concatenate_datasets(
            [perturbed_encoding_dataset, embeddings_dataset, labels_dataset],
            axis=1
        ).with_format("torch").train_test_split(0.1, shuffle=False, seed=random_state)

        model, logs = train_local_representation(
            alpha=0,
            model_config=run.config.surrogate,
            loss_config=run.config.loss,
            optimizer_config=run.config.optimizer,
            train_dataset=perturbed_dataset["train"],
            val_dataset=perturbed_dataset["test"],
            keys=embeddings,
            groups=None,
            eval_downstream=False,
            wandb_run=Nop(),
            num_epochs=40,
            batch_size=512,
            log_every_n_steps=0,
            device=device,
        )
        best_val_loss = min(log["val_loss"] for log in logs["val"])
        # logger.info("Best validation loss: %s", best_val_loss)
        results.append(best_val_loss)
    return np.array(results)

shap_explainer = shap.KernelExplainer(test_fn, np.zeros((1, encoding.shape[1])))
shap_values = shap_explainer.shap_values(np.ones((encoding.shape[1], )), nsamples=200)
shap_values 

  0%|          | 0/1 [00:00<?, ?it/s]

[]


  0%|          | 0/1 [00:00<?, ?it/s]

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]


INFO:shap:num_full_subsets = 1
INFO:shap:remaining_weight_vector = array([0.214078  , 0.15111388, 0.12041888, 0.10275744, 0.09174772,
       0.0846902 , 0.08027925, 0.07784655, 0.07706808])
INFO:shap:num_paired_subset_sizes = 9
INFO:shap:weight_left = np.float64(0.7032951454518925)


  0%|          | 0/200 [00:00<?, ?it/s]

[0]
[ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
[1]
[ 0  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
[2]
[ 0  1  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
[3]
[ 0  1  2  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
[4]
[ 0  1  2  3  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
[5]
[ 0  1  2  3  4  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
[6]
[ 0  1  2  3  4  5  7  8  9 10 11 12 13 14 15 16 17 18 19]
[7]
[ 0  1  2  3  4  5  6  8  9 10 11 12 13 14 15 16 17 18 19]
[8]
[ 0  1  2  3  4  5  6  7  9 10 11 12 13 14 15 16 17 18 19]
[9]
[ 0  1  2  3  4  5  6  7  8 10 11 12 13 14 15 16 17 18 19]
[10]
[ 0  1  2  3  4  5  6  7  8  9 11 12 13 14 15 16 17 18 19]
[11]
[ 0  1  2  3  4  5  6  7  8  9 10 12 13 14 15 16 17 18 19]
[12]
[ 0  1  2  3  4  5  6  7  8  9 10 11 13 14 15 16 17 18 19]
[13]
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 14 15 16 17 18 19]
[14]
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 15 16 17 18 19]
[15]
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14

INFO:shap:np.sum(w_aug) = np.float64(19.999999999999996)
INFO:shap:np.sum(self.kernelWeights) = np.float64(0.9999999999999996)
INFO:shap:phi = array([-0.01262849, -0.00427697, -0.00091514,  0.        , -0.00992   ,
       -0.00571036, -0.00448445, -0.01092365, -0.00076525, -0.000886  ,
       -0.03073997, -0.03894176, -0.01954492, -0.01091697, -0.03279042,
       -0.00590338, -0.02714465, -0.00986548, -0.00555533, -0.00303933])


array([-0.01262849, -0.00427697, -0.00091514,  0.        , -0.00992   ,
       -0.00571036, -0.00448445, -0.01092365, -0.00076525, -0.000886  ,
       -0.03073997, -0.03894176, -0.01954492, -0.01091697, -0.03279042,
       -0.00590338, -0.02714465, -0.00986548, -0.00555533, -0.00303933])

In [20]:
np.round(shap_values / shap_values.sum(), 3)

array([ 0.054,  0.018,  0.004, -0.   ,  0.042,  0.024,  0.019,  0.046,
        0.003,  0.004,  0.131,  0.166,  0.083,  0.046,  0.14 ,  0.025,
        0.116,  0.042,  0.024,  0.013])