## Imports

In [1]:
import logging
import os
from typing import List, Optional

import hydra
import omegaconf
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from lightning.pytorch import Callback
from omegaconf import DictConfig, ListConfig

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import enforce_tags, seed_index_everything
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO

# Force the execution of __init__.py if this file is executed directly.
import tvp  # noqa
from tvp.data.datamodule import MetaData
from tvp.data.datasets.registry import get_dataset
from tvp.task_vectors.task_vectors import TaskVector
from tvp.utils.io_utils import load_model_from_artifact
from tvp.utils.utils import build_callbacks
from torch.nn.utils import vector_to_parameters
from torch.nn.utils import parameters_to_vector

import numpy as np

from collections import Counter

pylogger = logging.getLogger(__name__)

torch.set_float32_matmul_precision("high")

  from .autonotebook import tqdm as notebook_tqdm
The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")


## Configuration

In [2]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../conf"), job_name="playground")

hydra.initialize()

In [3]:
cfg = compose(config_name="task_vectors", overrides=[])

In [18]:
seed_index_everything(cfg)

cfg.core.tags = enforce_tags(cfg.core.get("tags", None))

template_core: NNTemplateCore = NNTemplateCore(
    restore_cfg=cfg.train.get("restore", None),
)
logger: NNLogger = NNLogger(logging_cfg=cfg.train.logging, cfg=cfg, resume_id=template_core.resume_id)

Seed set to 1608637542


[34m[1mwandb[0m: Currently logged in as: [33mdansolombrino[0m ([33mgladia[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Load models

In [None]:
import copy


zeroshot_identifier = f"{cfg.nn.module.model.model_name}_pt"

zeroshot_model = load_model_from_artifact(artifact_path=f"{zeroshot_identifier}:latest", run=logger.experiment)

finetuned_id_fn = lambda dataset: f"{cfg.nn.module.model.model_name}_{dataset}_{cfg.seed_index}:latest"

finetuned_models = {
    dataset: load_model_from_artifact(artifact_path=finetuned_id_fn(dataset), run=logger.experiment)
    for dataset in cfg.task_vectors.to_apply
}

zeroshot_orig_weights = copy.deepcopy(zeroshot_model.state_dict())

## Repair

In [4]:
import torch.nn as nn

class ResetConv(nn.Module):
    def __init__(self, conv):
        super().__init__()
        self.out_channels = conv.out_channels
        self.conv = conv
        self.bn = nn.BatchNorm2d(self.out_channels)
        self.rescale = False

    def set_stats(self, goal_mean, goal_var, eps=1e-5):
        self.bn.bias.data = goal_mean
        goal_std = (goal_var + eps).sqrt()
        self.bn.weight.data = goal_std

    def forward(self, x):
        x = self.conv(x)
        if self.rescale:
            x = self.bn(x)
        else:
            self.bn(x)
        return x
    
class ResetLinear(nn.Module):
    def __init__(self, layer):
        # print(f"\n\n[ResetLinear] layer: {layer}\n\n")
        super().__init__()
        self.layer = layer
        self.bn = nn.BatchNorm1d(layer.out_features)
        self.weight = layer.weight
        self.bias = layer.bias

    def set_stats(self, goal_mean, goal_std):
        self.bn.bias.data = goal_mean
        self.bn.weight.data = goal_std

    def forward(self, x):
        # x.shape: [L, N, C] (in torch.nn.BatchNorm1d doc notation)
        x = self.layer(x)

        # match current shape to shape required by BatchNorm1d
        x = x.permute(1, 2, 0)
        # x.shape: [N, C, L] (in torch.nn.BatchNorm1d doc notation)
        
        x = self.bn(x)

        # match current shape to shape required by Linear
        x = x.permute(2, 0, 1)
        # x.shape: [L, N, C] (in torch.nn.BatchNorm1d doc notation)

        return x
    
class ResetLayerNorm(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.bn = nn.BatchNorm1d(layer.normalized_shape[0])

    def set_stats(self, goal_mean, goal_std):
        self.bn.bias.data = goal_mean
        self.bn.weight.data = goal_std

    def forward(self, x):
        x = self.layer(x)
        return self.bn(x)


def replace_layers(module):
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d):
            setattr(module, name, ResetConv(child))
        # elif isinstance(child, nn.LayerNorm):
        #     setattr(module, name, ResetLayerNorm(child))
        elif isinstance(child, nn.Linear):
            setattr(module, name, ResetLinear(child))
        else:
            replace_layers(child)


def make_tracked_net(model):
    tracked_model = copy.deepcopy(model)
    replace_layers(tracked_model.model.visual)
    return tracked_model.eval()

In [None]:
# def compute_task_statistics(model, endpoint_models):
#     for m_interp, *endpoint_modules in zip(model_to_repair.modules(), *[model.modules() for model in endpoint_models]):
#         if isinstance(m_interp, (ResetConv, ResetLayerNorm, ResetLinear)):
#             mu_endpoints = torch.stack([m.bn.running_mean for m in endpoint_modules])
#             goal_mean = mu_endpoints.mean(dim=0)
#             var_endpoints = torch.stack([m.bn.running_var for m in endpoint_modules])
#             goal_var = var_endpoints.mean(dim=0)
#             m_interp.set_stats(goal_mean, goal_var)
#             m_interp.rescale = True


#### Wrap and compute stats

In [None]:
dataset = get_dataset(
    cfg.nn.data.train_dataset,
    preprocess_fn=zeroshot_model.train_preprocess,
    location=cfg.nn.data.data_path,
    batch_size=cfg.nn.data.batch_size.train,
)

In [5]:
from tqdm import tqdm

def reset_bn_stats(model, epochs, loader):
    """
    Reset batchnorm stats. We use the train loader with data augmentation as this gives better results.
    """
    # resetting stats to baseline first as below is necessary for stability
    for m in model.modules():
        if type(m) == nn.BatchNorm2d:
            m.momentum = None  # use simple average
            m.reset_running_stats()

    # run a single train epoch with augmentations to recalc stats
    model.train()
    for _ in range(epochs):
        with torch.no_grad():
            for batch in tqdm(list(loader)):
                if isinstance(batch, Dict):
                    input = batch["x"]
                else:
                    input = batch[0]
                    # print(f"[reset_bn_stats] input.shape before model(): {input.shape}")
                out = model(input.cuda())
                # print(f"[reset_bn_stats] out.shape after model()   : {out.shape}")

In [None]:
# dataset_name = 'CIFAR100'
dataset_name = cfg.nn.data.train_dataset.replace('Val', '')
finetuned_model = finetuned_models[dataset_name]
tracked_finetuned_model = make_tracked_net(finetuned_models[dataset_name])

In [None]:
finetuned_model.model.visual

In [None]:
reset_bn_stats(tracked_finetuned_model.cuda(), 1, dataset.train_loader)

## Task vectors

In [None]:
tracked_finetuned_model

In [None]:
flatten = lambda model: parameters_to_vector(model.parameters()) 

zeroshot_vec = flatten(zeroshot_model)

In [None]:
task_vectors = [TaskVector.from_models(zeroshot_model, finetuned_models[dataset]) for dataset in cfg.task_vectors.to_apply]

In [6]:
def apply_task_vector(model, task_vector):
    model.load_state_dict({k: v + task_vector[k] for k, v in model.state_dict().items()})

### Aggregate task vectors

In [None]:
with torch.no_grad():
    task_vectors = torch.stack([flatten(finetuned_models[dataset]) - zeroshot_vec for dataset in cfg.task_vectors.to_apply])

### Standard task vectors

In [None]:
task_vectors_sum = torch.sum(task_vectors, dim=0)

In [None]:
alpha = 0.8

multi_task_vector = task_vectors_sum / len(task_vectors)

In [None]:
delta_model = copy.deepcopy(zeroshot_model) 
vector_to_parameters(multi_task_vector, delta_model.parameters())

In [None]:
task_equipped_model = copy.deepcopy(zeroshot_model)
apply_task_vector(task_equipped_model, delta_model.state_dict())

In [None]:
classification_head_identifier = f"{cfg.nn.module.model.model_name}_{cfg.nn.data.dataset.dataset_name}_head"
classification_head = load_model_from_artifact(
    artifact_path=f"{classification_head_identifier}:latest", run=logger.experiment
)

model = hydra.utils.instantiate(
    cfg.nn.module, encoder=task_equipped_model, classifier=classification_head, _recursive_=False
)

## Load dataset

In [None]:
seed_index_everything(cfg)

dataset = get_dataset(
    cfg.nn.data.train_dataset,
    preprocess_fn=model.encoder.train_preprocess,
    location=cfg.nn.data.data_path,
    batch_size=cfg.nn.data.batch_size.train,
)

callbacks: List[Callback] = build_callbacks(cfg.train.callbacks, template_core)

storage_dir: str = cfg.core.storage_dir

pylogger.info("Instantiating the <Trainer>")
trainer = pl.Trainer(
    default_root_dir=storage_dir,
    plugins=[NNCheckpointIO(jailing_dir=logger.run_dir)],
    logger=False,
    callbacks=callbacks,
    **cfg.train.trainer,
    
)

## Evaluation

In [None]:
# pylogger.info("Evaluating on the training set")
# trainer.test(model=model, dataloaders=dataset.train_loader)

pylogger.info("Evaluating on the test set!")
trainer.test(model=model, dataloaders=dataset.test_loader)

Keep this accuracy in mind, as it will be used as baseline to compare all upcoming experiments against

# Experiment

## Dataset stats

Compute and export mean and std for every dataset, if not already present on disk.<br>
Otherwhise, load mean and std for each dataset

### Define utility methods

In [7]:
# Define a function to get the full path of a module
def get_module_path(module, parent_name=""):
    if parent_name:
        return f"{parent_name}.{module._get_name()}"
    return module._get_name()


# Define a recursive function to collect batch norm stats from 'Reset' layers
def collect_bn_stats(model, parent_name=""):
    bn_stats = {}
    for name, module in model.named_children():
        # Construct the full path of the current module
        module_path = f"{parent_name}.{name}" if parent_name else name
        
        # Check for batch normalization layers within Reset layers
        if isinstance(module, ResetConv) or isinstance(module, ResetLinear):
            for sub_name, sub_module in module.named_children():
                if isinstance(sub_module, nn.BatchNorm2d) or isinstance(sub_module, nn.BatchNorm1d):
                    bn_stats[f"{module_path}.{sub_name}"] = {
                        "running_mean": sub_module.running_mean.clone().detach(),
                        "running_var": sub_module.running_var.clone().detach(),
                        "num_batches_tracked": sub_module.num_batches_tracked.clone().detach()
                    }
        # Recursively check submodules
        bn_stats.update(collect_bn_stats(module, module_path))
    
    return bn_stats

### Define objects

In [8]:
dataset_stats = {} # S in board pictures

# if "./dataset_stats.pth" is on disk, do not compute stats!
populate_stats = False

### Load models

In [9]:
if populate_stats: 

    # TODO once done, try without it, as all methods use copy.deepcopy, so hygene should be guaranteed

    import copy

    zeroshot_identifier = f"{cfg.nn.module.model.model_name}_pt"

    zeroshot_model = load_model_from_artifact(artifact_path=f"{zeroshot_identifier}:latest", run=logger.experiment)

    finetuned_id_fn = lambda dataset: f"{cfg.nn.module.model.model_name}_{dataset}_{cfg.seed_index}:latest"

    finetuned_models = {
        dataset: load_model_from_artifact(artifact_path=finetuned_id_fn(dataset), run=logger.experiment)
        for dataset in cfg.task_vectors.to_apply
    }

    zeroshot_orig_weights = copy.deepcopy(zeroshot_model.state_dict())

### Stats computation

In [10]:
if populate_stats:

    for dataset_name in cfg.task_vectors.to_apply:

        print(f"[dataset_name]: {dataset_name}\n")
        
        # begin load the dataset

        dataset = get_dataset(
            dataset_name=dataset_name,
            preprocess_fn=zeroshot_model.train_preprocess,
            location=cfg.nn.data.data_path,
            batch_size=cfg.nn.data.batch_size.train,
        )

        #   end load the dataset

        # begin compute and apply dataset task vector

        flatten = lambda model: parameters_to_vector(model.parameters()) 

        zeroshot_vec = flatten(zeroshot_model)

        task_vector = TaskVector.from_models(zeroshot_model, finetuned_models[dataset_name])

        task_vector = flatten(finetuned_models[dataset_name]) - zeroshot_vec

        # no need to do task_vector_sum-related stuff, as for this step we apply just one task vector

        delta_model = copy.deepcopy(zeroshot_model) 
        vector_to_parameters(task_vector, delta_model.parameters())

        task_equipped_model = copy.deepcopy(zeroshot_model)
        apply_task_vector(task_equipped_model, delta_model.state_dict())

        #   end compute and apply dataset task vector

        # begin compute layer-wise statistics

        tracked_task_equipped_model = make_tracked_net(task_equipped_model)

        reset_bn_stats(tracked_task_equipped_model.cuda(), 1, dataset.train_loader)

        #   end compute layer-wise statistics

        # begin store layer-wise statistics

        dataset_stats[dataset_name] = collect_bn_stats(tracked_task_equipped_model.model.visual)

        #   end store layer-wise statistics  

    torch.save(dataset_stats, './dataset_stats.pth')

else:

    print(f"Loading dataset stats from './dataset_stats.pth'...")
    
    dataset_stats = torch.load('./dataset_stats.pth')
    



Loading dataset stats from './dataset_stats.pth'...


## Dataset anchors

### Define objects

In [11]:
n_anchor_batches = 1

anchors = {}

populate_anchors = False

### Load models

In [12]:
if populate_anchors: 

    # TODO once done, try without it, as all methods use copy.deepcopy, so hygene should be guaranteed

    import copy

    zeroshot_identifier = f"{cfg.nn.module.model.model_name}_pt"

    zeroshot_model = load_model_from_artifact(artifact_path=f"{zeroshot_identifier}:latest", run=logger.experiment)

    finetuned_id_fn = lambda dataset: f"{cfg.nn.module.model.model_name}_{dataset}_{cfg.seed_index}:latest"

    finetuned_models = {
        dataset: load_model_from_artifact(artifact_path=finetuned_id_fn(dataset), run=logger.experiment)
        for dataset in cfg.task_vectors.to_apply
    }

    zeroshot_orig_weights = copy.deepcopy(zeroshot_model.state_dict())

### Anchors computations

In [13]:
if populate_anchors:

    for dataset_name in cfg.task_vectors.to_apply:

        print(f"[dataset_name]: {dataset_name}\n")

        dataset = get_dataset(
            dataset_name=dataset_name,
            preprocess_fn=zeroshot_model.train_preprocess,
            location=cfg.nn.data.data_path,
            batch_size=cfg.nn.data.batch_size.train,
        )

        model: nn.Module = copy.deepcopy(finetuned_models[dataset_name])
        model.eval()

        with torch.no_grad():

            for id, train_batch in enumerate (dataset.train_loader):
                    
                    if id == n_anchor_batches:                    
                        break
            
                    if isinstance(train_batch, Dict):
                        input: torch.Tensor = train_batch["x"]
                    else:
                        input: torch.Tensor = train_batch[0]

                    anchors[dataset_name] = model.cuda()(input.cuda())

    torch.save(anchors, './anchors.pth')
    
else:
     
    print(f"Loading anchors from './anchors.pth'...")
    
    anchors = torch.load('./anchors.pth')

Loading anchors from './anchors.pth'...


In [14]:
anchors_tensor = torch.stack([anchors[dataset_name] for dataset_name in anchors.keys()])

In [15]:
anchors_tensor.shape

torch.Size([8, 128, 512])

In [16]:
list(enumerate((anchors.keys())))

[(0, 'Cars'),
 (1, 'CIFAR100'),
 (2, 'DTD'),
 (3, 'EuroSAT'),
 (4, 'GTSRB'),
 (5, 'MNIST'),
 (6, 'RESISC45'),
 (7, 'SVHN')]

## Evaluation (naive)

### Load models

In [19]:
# TODO once done, try without it, as all methods use copy.deepcopy, so hygene should be guaranteed

import copy

zeroshot_identifier = f"{cfg.nn.module.model.model_name}_pt"

zeroshot_model = load_model_from_artifact(artifact_path=f"{zeroshot_identifier}:latest", run=logger.experiment)

finetuned_id_fn = lambda dataset: f"{cfg.nn.module.model.model_name}_{dataset}_{cfg.seed_index}:latest"

finetuned_models = {
    dataset: load_model_from_artifact(artifact_path=finetuned_id_fn(dataset), run=logger.experiment)
    for dataset in cfg.task_vectors.to_apply
}

zeroshot_orig_weights = copy.deepcopy(zeroshot_model.state_dict())

[34m[1mwandb[0m: Downloading large artifact ViT-B-16_pt:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.0


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_Cars_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.1


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_CIFAR100_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.1


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_DTD_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.0


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_EuroSAT_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.0


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_GTSRB_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.0


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_MNIST_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.1


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_RESISC45_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.1


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_SVHN_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.0


Loading ViT-B-16 pre-trained weights.


### Load test dataset

In [21]:
dataset_name = "GTSRB"

In [22]:
# seed_index_everything(cfg)

dataset = get_dataset(
    dataset_name=dataset_name,
    preprocess_fn=zeroshot_model.train_preprocess,
    location=cfg.nn.data.data_path,
    batch_size=cfg.nn.data.batch_size.train,
)

callbacks: List[Callback] = build_callbacks(cfg.train.callbacks, template_core)

storage_dir: str = cfg.core.storage_dir

pylogger.info("Instantiating the <Trainer>")
trainer = pl.Trainer(
    default_root_dir=storage_dir,
    plugins=[NNCheckpointIO(jailing_dir=logger.run_dir)],
    logger=False,
    callbacks=callbacks,
    **cfg.train.trainer,
    
)

INFO: GPU available: True (cuda), used: True


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


### Apply all task vectors

In [23]:
flatten = lambda model: parameters_to_vector(model.parameters()) 

zeroshot_vec = flatten(zeroshot_model)

In [24]:
task_vectors = [TaskVector.from_models(zeroshot_model, finetuned_models[dataset]) for dataset in cfg.task_vectors.to_apply]

In [25]:
with torch.no_grad():
    task_vectors = torch.stack([flatten(finetuned_models[dataset]) - zeroshot_vec for dataset in cfg.task_vectors.to_apply])

In [26]:
task_vectors_sum = torch.sum(task_vectors, dim=0)

In [27]:
alpha = 0.8

multi_task_vector = task_vectors_sum / len(task_vectors)

In [28]:
delta_model = copy.deepcopy(zeroshot_model) 
vector_to_parameters(multi_task_vector, delta_model.parameters())

In [29]:
task_equipped_model = copy.deepcopy(zeroshot_model)
apply_task_vector(task_equipped_model, delta_model.state_dict())

In [30]:
classification_head_identifier = f"{cfg.nn.module.model.model_name}_{dataset_name}_head"
classification_head = load_model_from_artifact(
    artifact_path=f"{classification_head_identifier}:latest", run=logger.experiment
)

model = hydra.utils.instantiate(
    cfg.nn.module, encoder=task_equipped_model, classifier=classification_head, _recursive_=False
)

[34m[1mwandb[0m:   1 of 1 files downloaded.  
/mnt/KS_2TB/PARA/Resources/miniconda3/envs/tvp/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder'])`.
/mnt/KS_2TB/PARA/Resources/miniconda3/envs/tvp/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'classifier' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['classifier'])`.


### Evaluate (naive)

In [31]:
pylogger.info("Evaluating on the test set!")
trainer.test(model=model, dataloaders=dataset.test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0:   0%|          | 0/99 [00:00<?, ?it/s]

  return F.conv2d(input, weight, bias, self.stride,


Testing DataLoader 0: 100%|██████████| 99/99 [00:16<00:00,  5.85it/s]


[{'acc/test': 0.6593032479286194, 'loss/test': 1.093709945678711}]

## Evaluation (our method)

### Load models

In [32]:
# TODO once done, try without it, as all methods use copy.deepcopy, so hygene should be guaranteed

import copy

zeroshot_identifier = f"{cfg.nn.module.model.model_name}_pt"

zeroshot_model = load_model_from_artifact(artifact_path=f"{zeroshot_identifier}:latest", run=logger.experiment)

finetuned_id_fn = lambda dataset: f"{cfg.nn.module.model.model_name}_{dataset}_{cfg.seed_index}:latest"

finetuned_models = {
    dataset: load_model_from_artifact(artifact_path=finetuned_id_fn(dataset), run=logger.experiment)
    for dataset in cfg.task_vectors.to_apply
}

zeroshot_orig_weights = copy.deepcopy(zeroshot_model.state_dict())

[34m[1mwandb[0m: Downloading large artifact ViT-B-16_pt:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.8


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_Cars_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_CIFAR100_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.8


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_DTD_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.8


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_EuroSAT_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_GTSRB_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_MNIST_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.8


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_RESISC45_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_SVHN_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9


Loading ViT-B-16 pre-trained weights.


###  Load test dataset

In [33]:
# seed_index_everything(cfg)

dataset = get_dataset(
    dataset_name=dataset_name,
    preprocess_fn=zeroshot_model.train_preprocess,
    location=cfg.nn.data.data_path,
    batch_size=cfg.nn.data.batch_size.train,
)

callbacks: List[Callback] = build_callbacks(cfg.train.callbacks, template_core)

storage_dir: str = cfg.core.storage_dir

pylogger.info("Instantiating the <Trainer>")
trainer = pl.Trainer(
    default_root_dir=storage_dir,
    plugins=[NNCheckpointIO(jailing_dir=logger.run_dir)],
    logger=False,
    callbacks=callbacks,
    **cfg.train.trainer,
    
)

INFO: GPU available: True (cuda), used: True


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


### Apply all task vectors

In [34]:
flatten = lambda model: parameters_to_vector(model.parameters()) 

zeroshot_vec = flatten(zeroshot_model)

In [35]:
task_vectors = [TaskVector.from_models(zeroshot_model, finetuned_models[dataset]) for dataset in cfg.task_vectors.to_apply]

In [36]:
with torch.no_grad():
    task_vectors = torch.stack([flatten(finetuned_models[dataset]) - zeroshot_vec for dataset in cfg.task_vectors.to_apply])

In [37]:
task_vectors_sum = torch.sum(task_vectors, dim=0)

In [38]:
alpha = 0.8

multi_task_vector = task_vectors_sum / len(task_vectors)

In [39]:
delta_model = copy.deepcopy(zeroshot_model) 
vector_to_parameters(multi_task_vector, delta_model.parameters())

In [40]:
task_equipped_model = copy.deepcopy(zeroshot_model)
apply_task_vector(task_equipped_model, delta_model.state_dict())

### Prepare model for REPAIR

Keep this here, as the classification head may have layers that may be detected as REPAIRable, even tho we don't want to repair them

In [41]:
tracked_task_equipped_model = make_tracked_net(task_equipped_model)

### Evaluate (our method)

In [42]:
classification_head_identifier = f"{cfg.nn.module.model.model_name}_{dataset_name}_head"
classification_head = load_model_from_artifact(
    artifact_path=f"{classification_head_identifier}:latest", run=logger.experiment
)

model = hydra.utils.instantiate(
    cfg.nn.module, encoder=tracked_task_equipped_model, classifier=classification_head, _recursive_=False
)

[34m[1mwandb[0m:   1 of 1 files downloaded.  
/mnt/KS_2TB/PARA/Resources/miniconda3/envs/tvp/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['encoder'])`.
/mnt/KS_2TB/PARA/Resources/miniconda3/envs/tvp/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'classifier' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['classifier'])`.


In [43]:
def majority_voting(similarity: torch.Tensor):
    # similarity.shape: [B, M]
    
    # Get the index of the highest similarity for each element in the batch
    max_similarity_indices = torch.argmax(similarity, dim=1)  # shape: [B]

    # Perform majority voting
    majority_vote = Counter(max_similarity_indices.cpu().numpy()).most_common(1)[0][0]
    
    return majority_vote

In [44]:
def centroid_based_similarity(batch: torch.Tensor, anchors_tensor: torch.Tensor):
    # batch.shape: [B, D]
    # anchors_tensor.shape: [M, K, D]

    M, K, D = anchors_tensor.shape
    B = batch.shape[0]

    # Compute centroids of each anchor set
    centroids = anchors_tensor.mean(dim=1)  # shape: [M, D]
    
    # Compute cosine similarity
    # Normalize batch and centroids
    batch_norm = F.normalize(batch, p=2, dim=1)  # shape: [B, D]
    centroids_norm = F.normalize(centroids, p=2, dim=1)  # shape: [M, D]

    # Compute similarity
    similarity = torch.mm(batch_norm, centroids_norm.t())  # shape: [B, M]

    return similarity

In [45]:
batch_wise_majority_votes = []

from tqdm import tqdm

# Iterate through the test data loader and compute point-to-set distances
for id, test_batch in tqdm(list(enumerate(dataset.test_loader))):
    
    with torch.no_grad():
        
        if isinstance(test_batch, dict):
            input: torch.Tensor = test_batch["x"]
        else:
            input: torch.Tensor = test_batch[0]

        emb = finetuned_models[dataset_name].cuda()(input.cuda())  # [B, D]
        # print(f"[emb]: {emb.shape}")

        similarities = centroid_based_similarity(emb, anchors_tensor)
        # print(f"[similarities]: {similarities.shape}")

        batch_wise_majority_votes.append(majority_voting(similarities))
        # print(f"[majority_vote]: {batch_wise_majority_votes[-1]}")

batch_wise_majority_votes = np.asarray(batch_wise_majority_votes)
most_similar_dataset_id = Counter(batch_wise_majority_votes).most_common(1)[0][0]        

100%|██████████| 99/99 [00:14<00:00,  6.95it/s]


In [46]:
most_similar_dataset_id

4

In [47]:
most_similar_dataset = list(anchors.keys())[most_similar_dataset_id]
most_similar_dataset

'GTSRB'

In [48]:
most_similar_dataset = get_dataset(
    dataset_name=most_similar_dataset,
    preprocess_fn=model.encoder.train_preprocess,
    location=cfg.nn.data.data_path,
    batch_size=cfg.nn.data.batch_size.train,
)

In [56]:
model.encoder.model.visual.conv1.bn.running_mean[:10] # should be all zeroes

tensor([-0.0329, -0.2029, -0.0214, -0.0393, -0.0250, -0.0440, -0.0327, -0.0504,
        -0.0188, -0.0048])

In [50]:
reset_bn_stats(model.cuda(), 1, most_similar_dataset.train_loader)

100%|██████████| 209/209 [00:47<00:00,  4.38it/s]


In [55]:
model.encoder.model.visual.conv1.bn.running_mean[:10] # should NOT be all zeroes

tensor([-0.0329, -0.2029, -0.0214, -0.0393, -0.0250, -0.0440, -0.0327, -0.0504,
        -0.0188, -0.0048])

In [53]:
pylogger.info("Evaluating on the test set (stats set via reset_bn_stats()!)")
trainer.test(model=model, dataloaders=dataset.test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 99/99 [00:24<00:00,  4.10it/s]


[{'acc/test': 0.019952494651079178, 'loss/test': 4.691280841827393}]

In [57]:
model.encoder.model.visual.conv1.bn.running_mean = dataset_stats[dataset_name]["conv1.bn"]["running_mean"]
model.encoder.model.visual.conv1.bn.running_var = dataset_stats[dataset_name]["conv1.bn"]["running_var"]

for i in range(12):
    model.encoder.model.visual.transformer.resblocks[i].attn.out_proj.bn.running_mean = dataset_stats[dataset_name][f"transformer.resblocks.{i}.attn.out_proj.bn"]["running_mean"]
    model.encoder.model.visual.transformer.resblocks[i].mlp.c_fc.bn.running_mean = dataset_stats[dataset_name][f"transformer.resblocks.{i}.mlp.c_fc.bn"]["running_mean"]
    model.encoder.model.visual.transformer.resblocks[i].mlp.c_proj.bn.running_mean = dataset_stats[dataset_name][f"transformer.resblocks.{i}.mlp.c_proj.bn"]["running_mean"]
    
    model.encoder.model.visual.transformer.resblocks[i].attn.out_proj.bn.running_var = dataset_stats[dataset_name][f"transformer.resblocks.{i}.attn.out_proj.bn"]["running_var"]
    model.encoder.model.visual.transformer.resblocks[i].mlp.c_fc.bn.running_var = dataset_stats[dataset_name][f"transformer.resblocks.{i}.mlp.c_fc.bn"]["running_var"]
    model.encoder.model.visual.transformer.resblocks[i].mlp.c_proj.bn.running_var = dataset_stats[dataset_name][f"transformer.resblocks.{i}.mlp.c_proj.bn"]["running_var"]


In [58]:
pylogger.info("Evaluating on the test set! (stats set via dataset_stats)")
trainer.test(model=model, dataloaders=dataset.test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 99/99 [00:24<00:00,  4.12it/s]


[{'acc/test': 0.02351543866097927, 'loss/test': 4.6658124923706055}]