## Imports

In [35]:
import logging
import os
from typing import List, Optional

import hydra
import omegaconf
import pytorch_lightning as pl
import torch
from lightning.pytorch import Callback
from omegaconf import DictConfig, ListConfig

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import enforce_tags, seed_index_everything
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO

# Force the execution of __init__.py if this file is executed directly.
import tvp  # noqa
from tvp.data.datamodule import MetaData
from tvp.data.datasets.registry import get_dataset
from tvp.task_vectors.task_vectors import TaskVector
from tvp.utils.io_utils import load_model_from_artifact
from tvp.utils.utils import build_callbacks
from torch.nn.utils import vector_to_parameters
from torch.nn.utils import parameters_to_vector

pylogger = logging.getLogger(__name__)

torch.set_float32_matmul_precision("high")

## Configuration

In [36]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../conf"), job_name="playground")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


hydra.initialize()

In [37]:
cfg = compose(config_name="task_vectors", overrides=[])

In [38]:
seed_index_everything(cfg)

cfg.core.tags = enforce_tags(cfg.core.get("tags", None))

template_core: NNTemplateCore = NNTemplateCore(
    restore_cfg=cfg.train.get("restore", None),
)
logger: NNLogger = NNLogger(logging_cfg=cfg.train.logging, cfg=cfg, resume_id=template_core.resume_id)

Global seed set to 1608637542


  rank_zero_warn(


## Load models

In [39]:
import copy


zeroshot_identifier = f"{cfg.nn.module.model.model_name}_pt"

zeroshot_model = load_model_from_artifact(artifact_path=f"{zeroshot_identifier}:latest", run=logger.experiment)

finetuned_id_fn = lambda dataset: f"{cfg.nn.module.model.model_name}_{dataset}_{cfg.seed_index}:latest"

finetuned_models = {
    dataset: load_model_from_artifact(artifact_path=finetuned_id_fn(dataset), run=logger.experiment)
    for dataset in cfg.task_vectors.to_apply
}

zeroshot_orig_weights = copy.deepcopy(zeroshot_model.state_dict())

[34m[1mwandb[0m: Downloading large artifact ViT-B-16_pt:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_Cars_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_CIFAR100_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_DTD_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_EuroSAT_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_GTSRB_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_MNIST_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.9


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_RESISC45_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.8


Loading ViT-B-16 pre-trained weights.


[34m[1mwandb[0m: Downloading large artifact ViT-B-16_SVHN_0:latest, 426.51MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.8


Loading ViT-B-16 pre-trained weights.


## Repair

In [40]:
import torch.nn as nn

class ResetConv(nn.Module):
    def __init__(self, conv):
        super().__init__()
        self.out_channels = conv.out_channels
        self.conv = conv
        self.bn = nn.BatchNorm2d(self.out_channels)
        self.rescale = False

    def set_stats(self, goal_mean, goal_var, eps=1e-5):
        self.bn.bias.data = goal_mean
        goal_std = (goal_var + eps).sqrt()
        self.bn.weight.data = goal_std

    def forward(self, x):
        x = self.conv(x)
        if self.rescale:
            x = self.bn(x)
        else:
            self.bn(x)
        return x
    
class ResetLinear(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.bn = nn.BatchNorm1d(layer.out_features)
        self.weight = layer.weight
        self.bias = layer.bias

    def set_stats(self, goal_mean, goal_std):
        self.bn.bias.data = goal_mean
        self.bn.weight.data = goal_std

    def forward(self, x):
        x = self.layer(x)
        return self.bn(x)
    
class ResetLayerNorm(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.bn = nn.BatchNorm1d(layer.normalized_shape[0])

    def set_stats(self, goal_mean, goal_std):
        self.bn.bias.data = goal_mean
        self.bn.weight.data = goal_std

    def forward(self, x):
        x = self.layer(x)
        return self.bn(x)


def replace_layers(module):
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d):
            setattr(module, name, ResetConv(child))
        # elif isinstance(child, nn.LayerNorm):
        #     setattr(module, name, ResetLayerNorm(child))
        elif isinstance(child, nn.Linear):
            setattr(module, name, ResetLinear(child))
        else:
            replace_layers(child)


def make_tracked_net(model):
    tracked_model = copy.deepcopy(model)
    replace_layers(tracked_model.model.visual)
    return tracked_model.eval()

In [41]:
# def compute_task_statistics(model, endpoint_models):
#     for m_interp, *endpoint_modules in zip(model_to_repair.modules(), *[model.modules() for model in endpoint_models]):
#         if isinstance(m_interp, (ResetConv, ResetLayerNorm, ResetLinear)):
#             mu_endpoints = torch.stack([m.bn.running_mean for m in endpoint_modules])
#             goal_mean = mu_endpoints.mean(dim=0)
#             var_endpoints = torch.stack([m.bn.running_var for m in endpoint_modules])
#             goal_var = var_endpoints.mean(dim=0)
#             m_interp.set_stats(goal_mean, goal_var)
#             m_interp.rescale = True


#### Wrap and compute stats

In [42]:
dataset = get_dataset(
    cfg.nn.data.train_dataset,
    preprocess_fn=zeroshot_model.train_preprocess,
    location=cfg.nn.data.data_path,
    batch_size=cfg.nn.data.batch_size.train,
)

Files already downloaded and verified
Files already downloaded and verified


In [43]:
def reset_bn_stats(model, epochs, loader):
    """
    Reset batchnorm stats. We use the train loader with data augmentation as this gives better results.
    """
    # resetting stats to baseline first as below is necessary for stability
    for m in model.modules():
        if type(m) == nn.BatchNorm2d:
            m.momentum = None  # use simple average
            m.reset_running_stats()

    # run a single train epoch with augmentations to recalc stats
    model.train()
    for _ in range(epochs):
        with torch.no_grad():
            for batch in loader:
                if isinstance(batch, Dict):
                    input = batch["x"]
                else:
                    input = batch[0]
                _ = model(input.cuda())

In [44]:
dataset_name = 'CIFAR100'
finetuned_model = finetuned_models[dataset_name]
tracked_finetuned_model = make_tracked_net(finetuned_models[dataset_name])

In [45]:
finetuned_model.model.visual

VisualTransformer(
  (conv1): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)
  (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (transformer): Transformer(
    (resblocks): ModuleList(
      (0-11): 12 x ResidualAttentionBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (ln_attn): Identity()
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (ln): Identity()
          (gelu): QuickGELU()
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
    )
  )
  (ln_post): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [46]:
reset_bn_stats(tracked_finetuned_model.cuda(), 1, dataset.train_loader)

torch.Size([128, 197, 768])
torch.Size([197, 128, 768])
torch.Size([197, 128, 768])


RuntimeError: running_mean should contain 128 elements not 3072

## Task vectors

In [None]:
tracked_finetuned_model

ImageEncoder(
  (model): CLIP(
    (visual): VisualTransformer(
      (conv1): ResetConv(
        (conv): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)
        (bn): BatchNorm2d(768, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
      )
      (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ln_attn): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (ln): Identity()
              (gelu): QuickGELU()
              (c_proj): Line

In [None]:
flatten = lambda model: parameters_to_vector(model.parameters()) 

zeroshot_vec = flatten(zeroshot_model)

In [None]:
task_vectors = [TaskVector.from_models(zeroshot_model, finetuned_models[dataset]) for dataset in cfg.task_vectors.to_apply]

In [None]:

def apply_task_vector(model, task_vector):
    model.load_state_dict({k: v + task_vector[k] for k, v in model.state_dict().items()})

### Aggregate task vectors

In [None]:
with torch.no_grad():
    task_vectors = torch.stack([flatten(finetuned_models[dataset]) - zeroshot_vec for dataset in cfg.task_vectors.to_apply])

### Standard task vectors

In [None]:
task_vectors_sum = torch.sum(task_vectors, dim=0)

In [None]:
alpha = 0.8

multi_task_vector = task_vectors_sum / len(task_vectors)

In [None]:
delta_model = copy.deepcopy(zeroshot_model) 
vector_to_parameters(multi_task_vector, delta_model.parameters())

In [None]:
task_equipped_model = copy.deepcopy(zeroshot_model)
apply_task_vector(task_equipped_model, delta_model.state_dict())

In [None]:
classification_head_identifier = f"{cfg.nn.module.model.model_name}_{cfg.nn.data.dataset.dataset_name}_head"
classification_head = load_model_from_artifact(
    artifact_path=f"{classification_head_identifier}:latest", run=logger.experiment
)

model = hydra.utils.instantiate(
    cfg.nn.module, encoder=task_equipped_model, classifier=classification_head, _recursive_=False
)

[34m[1mwandb[0m:   1 of 1 files downloaded.  
  rank_zero_warn(
  rank_zero_warn(


## Load dataset

In [None]:
seed_index_everything(cfg)

dataset = get_dataset(
    cfg.nn.data.train_dataset,
    preprocess_fn=model.encoder.train_preprocess,
    location=cfg.nn.data.data_path,
    batch_size=cfg.nn.data.batch_size.train,
)

callbacks: List[Callback] = build_callbacks(cfg.train.callbacks, template_core)

storage_dir: str = cfg.core.storage_dir

pylogger.info("Instantiating the <Trainer>")
trainer = pl.Trainer(
    default_root_dir=storage_dir,
    plugins=[NNCheckpointIO(jailing_dir=logger.run_dir)],
    logger=False,
    callbacks=callbacks,
    **cfg.train.trainer,
    
)

Global seed set to 1608637542


Files already downloaded and verified
Files already downloaded and verified


INFO: GPU available: True (cuda), used: True


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


## Evaluation

In [None]:
# pylogger.info("Evaluating on the training set")
# trainer.test(model=model, dataloaders=dataset.train_loader)

pylogger.info("Evaluating on the test set!")
trainer.test(model=model, dataloaders=dataset.test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 40/40 [00:10<00:00,  3.76it/s]


[{'acc/test': 0.7657999992370605, 'loss/test': 1.0084812641143799}]