## Imports


In [None]:
import copy
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import open_clip
import wandb

import hydra
import omegaconf
import pytorch_lightning as pl
import torch
from hydra import compose, initialize
from hydra.utils import instantiate
from lightning.pytorch import Callback
from omegaconf import DictConfig, ListConfig, OmegaConf
from torch.nn.utils import parameters_to_vector, vector_to_parameters

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import enforce_tags, seed_index_everything
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO

# Force the execution of __init__.py if this file is executed directly.
import tvp  # noqa
from tvp.data.datasets.registry import get_dataset
from tvp.modules.encoder import ClassificationHead, ImageEncoder
from tvp.modules.projection_router import ProjectionRouter
from tvp.modules.nn_router import NNRouter
from tvp.modules.heads import get_classification_head
from tvp.modules.router import AbstractRouter
from tvp.utils.io_utils import load_model_from_disk
from tvp.utils.plots import plot_interactive_radar_chart
from tvp.utils.utils import (
    compute_task_dict, 
    apply_dict_to_model,
    build_callbacks,
    get_finetuning_accuracies,
    add_normalized_accuracy,
    compute_avg_accuracy,
    print_memory,
    get_routing_weights,
    svd_key_from_layer
)
from tvp.task_vectors.task_singular_vectors import *
import json
import os

pylogger = logging.getLogger(__name__)

torch.set_float32_matmul_precision("high")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
def boilerplate(cfg):
    cfg.core.tags = enforce_tags(cfg.core.get("tags", None))

    num_tasks = len(cfg.eval_datasets)
    cfg.core.tags.append(f"n{num_tasks}")
    cfg.core.tags.append(f'{cfg.nn.module.encoder.model_name}')
    cfg.core.tags.append(f'mnist_notebook')

    template_core = NNTemplateCore(
        restore_cfg=cfg.train.get("restore", None),
    )
    logger: NNLogger = NNLogger(
        logging_cfg=cfg.train.logging, cfg=cfg, resume_id=template_core.resume_id
    )

    logger.upload_source()

    return logger, template_core


def get_merged_base(
    cfg,
    merging_method,
    zeroshot_encoder: ImageEncoder,
    svd_dicts: Dict[str, Any],
):

    coefficient = 1

    if merging_method == "isotropic":

        multi_task_vector = isotropic_sum(
            ref_state_dict=copy.deepcopy(zeroshot_encoder.state_dict()),
            svd_dict=svd_dicts,
        )

        model_name = cfg.nn.module.encoder.model_name

        if (
            model_name in cfg.optimal_alphas
            and len(cfg.eval_datasets) in cfg.optimal_alphas[model_name]
        ):
            coefficient = cfg.optimal_alphas[model_name][len(cfg.eval_datasets)]

    elif merging_method == "tsvm":

        multi_task_vector = sum_svd(
            ref_state_dict=copy.deepcopy(zeroshot_encoder.state_dict()),
            svd_dicts=svd_dicts,
        )
    elif merging_method == "zeroshot":
        return zeroshot_encoder
    else:
        raise NotImplementedError

    merged_encoder: ImageEncoder = copy.deepcopy(zeroshot_encoder)

    merged_encoder = apply_dict_to_model(
        multi_task_vector,
        merged_encoder,
        coefficient=coefficient,
    )

    return merged_encoder


def get_classification_heads(cfg: DictConfig):
    classification_heads = []

    for dataset_name in cfg.eval_datasets:

        classification_head = get_classification_head(
            cfg.nn.module.encoder.model_name,
            dataset_name,
            cfg.nn.data.data_path,
            cfg.misc.ckpt_path,
            cache_dir=cfg.misc.cache_dir,
            openclip_cachedir=cfg.misc.openclip_cachedir,
        )

        classification_heads.append(classification_head)

    return classification_heads



In [6]:
import hydra
from hydra import initialize, compose
from typing import Dict, List

hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../conf"), job_name="debug_mnist")
cfg = compose(config_name="task_vectors", overrides=["nn/benchmark=mnists"])

'hydra/launcher/submitit_slurm' is validated against ConfigStore schema with the same name.
This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/automatic_schema_matching for migration instructions.
  coro.send(None)


In [7]:

seed_index_everything(cfg)

logger, template_core = boilerplate(cfg)

Global seed set to 1608637542


[34m[1mwandb[0m: Currently logged in as: [33mzirilli-1967394[0m ([33mgladia[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
# upperbound accuracies, used for logging the normalized accuracy
finetuned_accuracies = get_finetuning_accuracies(cfg.misc.finetuned_accuracy_path)

In [9]:
# only has vision encoder, no text transformer
zeroshot_encoder_statedict = load_model_from_disk(cfg.misc.pretrained_checkpoint)

zeroshot_encoder: ImageEncoder = instantiate(
    cfg.nn.module.encoder
)  # the second pass backbone

zeroshot_encoder.load_state_dict(zeroshot_encoder_statedict, strict=False)

<All keys matched successfully>

In [10]:
finetuned_name = (
    lambda name: Path(cfg.misc.ckpt_path) / f"{name}Val" / "nonlinear_finetuned.pt"
)
finetuned_models = {
    dataset: load_model_from_disk(finetuned_name(dataset))
    for dataset in cfg.task_vectors.to_apply
}

num_tasks = len(cfg.eval_datasets)

pylogger.info(f"Number of tasks: {len(cfg.eval_datasets)}")
pylogger.info(f"Finetuned models: {list(finetuned_models.keys())}")

### Visualize some images from MNIST and EMNIST

In [11]:
mnist = get_dataset(
    'MNIST',
    preprocess_fn=zeroshot_encoder.val_preprocess,
    location=cfg.nn.data.data_path,
    batch_size=cfg.nn.data.batch_size.train,
)

In [10]:
emnist = get_dataset(
    'EMNIST',
    preprocess_fn=zeroshot_encoder.val_preprocess,
    location=cfg.nn.data.data_path,
    batch_size=cfg.nn.data.batch_size.train,
)

In [13]:
kmnist = get_dataset(
    'KMNIST',
    preprocess_fn=zeroshot_encoder.val_preprocess,
    location=cfg.nn.data.data_path,
    batch_size=cfg.nn.data.batch_size.train,
)

In [12]:
## compute statistics

mnist_val = mnist.test_loader.dataset
emnist_val = emnist.test_loader.dataset

In [14]:
import matplotlib.pyplot as plt
import torch
import numpy as np

def compute_dataset_statistics(dataset, batch_size=256):
    """
    Computes the per-channel mean and standard deviation for a given dataset.
    Note:
      - If the dataset is already normalized/augmented by transforms,
        the computed stats will reflect the *transformed* version.
      - This function assumes images have shape [C, H, W].
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    
    n_samples = 0
    total_mean = 0.0
    total_std = 0.0
    
    for images, _ in loader:
        # Flatten spatial dimensions: (B, C, H, W) -> (B, C*H*W)
        # so that mean(1) is per-sample, then we average across batch
        # If you want per-channel stats, you can adjust accordingly.
        images = images.view(images.size(0), -1)
        batch_mean = images.mean(dim=1)  # shape [B]
        batch_std = images.std(dim=1)    # shape [B]
        
        total_mean += batch_mean.sum().item()
        total_std += batch_std.sum().item()
        n_samples += images.size(0)

    dataset_mean = total_mean / n_samples
    dataset_std = total_std / n_samples
    return dataset_mean, dataset_std


mnist_mean, mnist_std = compute_dataset_statistics(mnist_val)
emnist_mean, emnist_std = compute_dataset_statistics(emnist_val)

print(f"MNIST (val) mean: {mnist_mean:.4f}, std: {mnist_std:.4f}")
print(f"EMNIST (val) mean: {emnist_mean:.4f}, std: {emnist_std:.4f}")

MNIST (val) mean: -1.1779, std: 1.1046
EMNIST (val) mean: -1.0263, std: 1.2029


In [15]:

def visualize_random_samples(dataset, n_samples=8):
    """
    Shows a grid of n_samples random images (and their labels) from a dataset.
    Assumes each dataset item is (image, label).
    """
    indices = np.random.choice(len(dataset), size=n_samples, replace=False)
    
    fig, axes = plt.subplots(1, n_samples, figsize=(2 * n_samples, 2))
    for i, idx in enumerate(indices):
        image, label = dataset[idx]  # image shape [C, H, W]

        # Convert tensor -> NumPy and move channels last: [C, H, W] -> [H, W, C]
        image_np = image.permute(1, 2, 0).cpu().numpy()
        
        # For grayscale images (single channel), you can just squeeze:
        # image_np = np.squeeze(image_np, axis=-1)  # if shape is [H, W, 1]
        
        axes[i].imshow(image_np, cmap="gray")
        axes[i].set_title(f"Label: {label}")
        axes[i].axis("off")

    plt.tight_layout()
    plt.show()


print("\nVisualizing random MNIST validation samples:")
visualize_random_samples(mnist_val, n_samples=8)

print("Visualizing random EMNIST validation samples:")
visualize_random_samples(emnist_val, n_samples=8)

### Test model finetuned on {MNIST,EMNIST} on {EMNIST, MNIST} dataset

In [9]:
def evaluate(model, dataset_name, preprocess_fn):

    dataset = get_dataset(
        dataset_name,
        preprocess_fn=preprocess_fn,
        location=cfg.nn.data.data_path,
        batch_size=cfg.nn.data.batch_size.train,
    )

    trainer = pl.Trainer(
        **cfg.train.trainer,
    )

    pylogger.error("For now evaluation supported only on val-set")

    pylogger.info(f"Evaluating on the {dataset_name} test set!")
    test_results = trainer.test(model=model, dataloaders=dataset.test_loader)

    return test_results[0]["acc/test"]


In [10]:
model_dataset_combinations = { 'MNIST': finetuned_models['MNIST'], 'EMNIST': finetuned_models['MNIST'], 'KMNIST': finetuned_models['KMNIST'] }

In [11]:
from tvp.pl_module.image_classifier import ImageClassifier

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")


In [14]:

for dataset_name, model_dict in model_dataset_combinations.items():

    encoder = instantiate(
        cfg.nn.module.encoder
    )

    model = ImageClassifier(
        encoder=encoder,
        x_key='x',
        y_key='y',
        classifier=get_classification_head(
            cfg.nn.module.encoder.model_name,
            'KMNIST',
            cfg.nn.data.data_path,
            cfg.misc.ckpt_path,
            cache_dir=cfg.misc.cache_dir,
            openclip_cachedir=cfg.misc.openclip_cachedir,
        ),
    )

    encoder.load_state_dict(model_dict, strict=False)
    
    evaluate(model, dataset_name, zeroshot_encoder.val_preprocess)
    

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")


Loading classification head from ../checkpoints/ViT-B-32/head_KMNIST.pt


INFO: GPU available: True (cuda), used: True


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 313/313 [00:09<00:00, 34.31it/s]


Loading classification head from ../checkpoints/ViT-B-32/head_KMNIST.pt


  rank_zero_warn(
  rank_zero_warn(
INFO: GPU available: True (cuda), used: True


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 1250/1250 [00:59<00:00, 20.97it/s]


Loading classification head from ../checkpoints/ViT-B-32/head_KMNIST.pt


INFO: GPU available: True (cuda), used: True


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 313/313 [00:09<00:00, 33.43it/s]


## Merging

In [12]:
task_dicts = {}
for dataset in cfg.task_vectors.to_apply:
    task_dicts[dataset] = compute_task_dict(
        zeroshot_encoder_statedict, finetuned_models[dataset]
    )

In [13]:
task_dicts.keys()

dict_keys(['KMNIST', 'MNIST', 'EMNIST'])

In [14]:
from tvp.task_vectors.aggregator import TaskSingularVectorAggregator, IsotropicAggregator
from tvp.task_vectors.task_singular_vectors import get_svd_dict

aggregator = TaskSingularVectorAggregator(
    zeroshot_model=zeroshot_encoder.cuda(), 
)

In [16]:
first_round = {"MNIST": task_dicts["MNIST"], "EMNIST": task_dicts["EMNIST"]}

In [19]:
svds_first = get_svd_dict(first_round, list(first_round.keys()), svd_path='./')
merged = aggregator.aggregate(svds_first, coefficients=None)

Computing and compressing SVD:   0%|          | 0/2 [00:00<?, ?it/s]

Computing and compressing SVD: 100%|██████████| 2/2 [00:05<00:00,  2.75s/it]


Summing SVD: 100%|██████████| 158/158 [00:05<00:00, 29.48it/s]


In [20]:
mnist_emnist = compute_task_dict(
        zeroshot_encoder_statedict, merged.state_dict()
)

In [21]:
second_round = {"KMNIST": task_dicts["KMNIST"]}

In [22]:
second_round["merged"] = mnist_emnist

In [23]:
svds = get_svd_dict(second_round, list(second_round.keys()), svd_path='./')
final = aggregator.aggregate(svds, coefficients=None)

Computing and compressing SVD: 100%|██████████| 2/2 [00:05<00:00,  2.62s/it]


Summing SVD: 100%|██████████| 158/158 [00:05<00:00, 31.06it/s]


In [24]:
model = ImageClassifier(
        encoder=final,
        x_key='x',
        y_key='y',
        classifier=get_classification_head(
            cfg.nn.module.encoder.model_name,
            'KMNIST',
            cfg.nn.data.data_path,
            cfg.misc.ckpt_path,
            cache_dir=cfg.misc.cache_dir,
            openclip_cachedir=cfg.misc.openclip_cachedir,
        ),
    )



evaluate(model, 'KMNIST', zeroshot_encoder.val_preprocess)

  rank_zero_warn(
  rank_zero_warn(
INFO: GPU available: True (cuda), used: True


Loading classification head from ../checkpoints/ViT-B-32/head_KMNIST.pt


INFO: TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 313/313 [00:09<00:00, 32.93it/s]


0.8604000210762024