## Imports

In [14]:
import hydra
from omegaconf import DictConfig, OmegaConf
import logging
import wandb
import os

# Import boilerplate dependencies from your training framework
from nn_core.common.utils import enforce_tags, seed_index_everything
from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.model_logging import NNLogger

# Import the functions from your modified files.
from mass.clip_interpret.compute_complete_text_set import run_completeness
from mass.clip_interpret.compute_text_set_projection import run_text_features

import time
import numpy as np
import torch
import os
import einops
import tqdm
import open_clip
import wandb
import logging
from pathlib import Path

from mass.utils.io_utils import load_model_from_disk
from mass.utils.utils import compute_task_dict
from mass.task_vectors.task_singular_vectors import get_svd_dict

pylogger = logging.getLogger(__name__)

In [15]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
import hydra
from hydra import initialize, compose
from typing import Dict, List

# Hydra keeps process-global state; clear it first so this cell can be
# re-run in the same kernel without `initialize` complaining.
hydra.core.global_hydra.GlobalHydra.instance().clear()

# `config_path` is resolved relative to this notebook's location.
initialize(
    version_base=None,
    config_path="../conf",
    job_name="clip_interpret",
)
cfg = compose(config_name="clip_interpret")

In [17]:
def boilerplate(cfg: DictConfig):
    """
    Set up run tagging and logging for the CLIP interpretation notebook.

    Adds the ``clip_interpret`` tag to ``cfg.core.tags`` (mutating ``cfg`` in
    place), builds an ``NNTemplateCore`` from the optional ``cfg.train.restore``
    entry, builds an ``NNLogger`` from ``cfg.train.logging`` using the template
    core's ``resume_id``, and calls ``logger.upload_source()``.

    Args:
        cfg: Composed Hydra configuration. Must provide ``core.tags`` (a
            list) and ``train.logging``; ``train.restore`` is optional.

    Returns:
        Tuple ``(logger, template_core)``.
    """
    # Notebook cells get re-run: only add the tag once so repeated executions
    # do not accumulate duplicate "clip_interpret" entries on the shared cfg.
    if "clip_interpret" not in cfg.core.tags:
        cfg.core.tags.append("clip_interpret")

    # The template core receives the restore configuration (None when absent)
    # and exposes the resume id that the logger is given below.
    template_core = NNTemplateCore(restore_cfg=cfg.train.get("restore", None))

    logger = NNLogger(
        logging_cfg=cfg.train.logging, cfg=cfg, resume_id=template_core.resume_id
    )
    logger.upload_source()
    return logger, template_core

In [18]:
# Run the boilerplate once for this session: tags the run and creates the
# logger / template-core pair from the composed config.
logger, template_core = boilerplate(cfg)

  rank_zero_warn(


In [19]:
# Text features are cached on disk under a name derived from the description
# file and the model; extraction only runs when that file is missing.
# NOTE(review): presumably `run_text_features` writes to this same path -- confirm.
name = cfg.text_descriptions.replace(".txt", "")
output_path = os.path.join(cfg.misc.output_dir, "{}_{}.npy".format(name, cfg.model))

if not os.path.exists(output_path):
    pylogger.info("Running text feature extraction...")
    run_text_features(cfg)
else:
    pylogger.info(
        f"Output file already exists: {output_path}. Skipping computation."
    )

In [20]:
# Pretrained (zero-shot) encoder weights that every task dict is computed against.
zeroshot_encoder_statedict = load_model_from_disk(cfg.misc.pretrained_checkpoint)


def finetuned_name(name):
    """Return the checkpoint path of the model fine-tuned on dataset `name`."""
    return Path(cfg.misc.ckpt_path) / f"{name}Val" / "nonlinear_finetuned.pt"


# Load the fine-tuned checkpoints one at a time and drop each as soon as its
# task dict has been computed, so at most one fine-tuned model is held in
# memory at any point (instead of loading all of them up front).
task_dicts = {}
for dataset in cfg.task_vectors.to_apply:
    finetuned_model = load_model_from_disk(finetuned_name(dataset))
    task_dicts[dataset] = compute_task_dict(
        zeroshot_encoder_statedict, finetuned_model
    )
    del finetuned_model
    torch.cuda.empty_cache()

svd_dict = get_svd_dict(
    task_dicts, cfg.eval_datasets, cfg.misc.svd_path, cfg.svd_compress_factor
)

# Only the model is used below (for `model.visual.proj`); the other two
# return values are kept under their original names.
model, _, preprocess = open_clip.create_model_and_transforms(
    cfg.model, pretrained="openai", cache_dir=cfg.misc.openclip_cachedir
)
model.to(cfg.device)

# Union of every text selected by the TextSpan cell (filled in later).
all_images = set()

# Load the precomputed text features; the file name follows the same
# "<descriptions without .txt>_<model>.npy" convention used when caching them.
name = cfg.text_descriptions.replace(".txt", "")
text_features_path = os.path.join(cfg.misc.output_dir, f"{name}_{cfg.model}.npy")
text_features = np.load(text_features_path)
pylogger.info(f"Loaded text features from {text_features_path}")

# Load the text descriptions, one per line. `lines[i]` is later looked up by
# an index computed from `text_features`, so both must be in the same order.
text_file = os.path.join(cfg.misc.description_dir, f"{cfg.text_descriptions}")
with open(text_file, "r") as f:
    lines = [line.strip() for line in f]
pylogger.info(f"Loaded text descriptions from {text_file}")

In [21]:

# Two-column wandb Table: one row per task, holding the texts selected for it.
results_table = wandb.Table(columns=["Task", "Top Texts"])

# Destination for an optional plain-text copy of the results. Note that
# `cfg.text_descriptions` is embedded as-is, including any ".txt" suffix.
output_file = os.path.join(
    cfg.misc.output_dir,
    f"{cfg.dataset}_completeness_{cfg.text_descriptions}"
    f"_top_{cfg.texts_per_task}_heads_{cfg.model}.txt",
)

## TextSpan

In [22]:
@torch.no_grad()
def replace_with_iterative_removal(data, text_features, texts, iters, device):
    """
    Greedily pick the texts that explain most of the variance of `data`.

    The text features are first projected onto the row space of `data`. Then,
    for `iters` rounds, the text whose direction has the largest standard
    deviation of projections across the (mean-centred) rows of `data` is
    selected; its rank-1 component is added to the reconstruction and removed
    from both the data and all remaining text features.

    Args:
        data: 2-D NumPy array of shape (n_vectors, dim).
        text_features: 2-D NumPy array of shape (n_texts, dim).
        texts: Sequence of n_texts descriptions, aligned with `text_features`.
        iters: Number of texts to select.
        device: Torch device on which the iterations are computed.

    Returns:
        Tuple ``(reconstruct, results)``: a float32 NumPy array with the shape
        of `data` holding the mean plus all selected rank-1 components, and
        the list of selected texts in selection order.
    """
    results = []

    # Orthogonal projector onto the row space of `data`. The pseudo-inverse
    # equals data.T @ inv(data @ data.T) @ data when the rows are linearly
    # independent, and stays well defined when they are not (explicitly
    # inverting data @ data.T would fail or blow up in that case).
    row_space_projector = np.linalg.pinv(data) @ data
    text_features = text_features @ row_space_projector

    data = torch.from_numpy(data).float().to(device)
    mean_data = data.mean(dim=0, keepdim=True)
    data = data - mean_data

    # Start the reconstruction from the mean row, repeated once per data row.
    reconstruct = mean_data.repeat(data.shape[0], 1).cpu().numpy()
    text_features = torch.from_numpy(text_features).float().to(device)

    for _ in range(iters):
        projection = data @ text_features.T
        projection_std = projection.std(dim=0).cpu().numpy()
        top_n = int(np.argmax(projection_std))
        results.append(texts[top_n])

        direction = text_features[top_n]
        text_norm = torch.dot(direction, direction)

        # Rank-1 part of the data explained by the selected text; computed
        # once and used for both the reconstruction and the residual.
        coefficients = (data @ direction) / text_norm
        component = coefficients[:, None] * direction[None, :]
        reconstruct += component.cpu().numpy()
        data = data - component

        # Orthogonalise the remaining texts against the selected direction so
        # the same direction is not picked again.
        text_coefficients = (text_features @ direction) / text_norm
        text_features = text_features - text_coefficients[:, None] * direction[None, :]

    return reconstruct, results

In [23]:
# Build the report as a list of lines and join once at the end.
report_lines = []
rule = "------------------"

for task in svd_dict.keys():
    pylogger.info(f"Processing Task: {task}")
    report_lines += [rule, f"V for task {task}", rule]

    # SVD factors stored for this task at the configured layer.
    layer_svd = svd_dict[task][cfg.layer]
    u = layer_svd["u"].to(cfg.device)
    s = torch.diag_embed(layer_svd["s"]).to(cfg.device)
    v = layer_svd["v"].to(cfg.device)

    # Scale the rows of v by s and map them through the visual projection.
    v_proj = s @ v @ model.visual.proj

    # Despite the name, `images` holds the selected text descriptions.
    reconstruct, images = replace_with_iterative_removal(
        v_proj.detach().cpu().numpy(),
        text_features,
        lines,
        cfg.texts_per_task,
        cfg.device,
    )

    all_images.update(images)
    report_lines.extend(images)
    results_table.add_data(task, "\n".join(images))
    pylogger.info(f"Task {task}: {images}")

    # NOTE(review): only the first task is processed because of this break;
    # confirm whether that is intentional or a leftover from debugging.
    break

output_str = "".join(f"{line}\n" for line in report_lines)
print(output_str)

------------------
V for task Cars
------------------
Image of a car
Nostalgic expressions
Artwork featuring Morse code typography
Image of a delivery van
Image snapped in the Swiss chocolate factories



## Nearest-neighbor

In [24]:
def get_nearest_text_descriptions(data, text_features, texts, device='cuda'):
    """
    Return the text whose projected feature has the largest dot product with `data`.

    The text features are first projected onto the row space of `data`, then
    the dot product between every row of `data` and every projected text is
    computed. The text achieving the overall maximum is returned.

    Args:
        data: 2-D NumPy array of shape (n_vectors, dim); callers in this
            notebook pass a single row.
        text_features: 2-D NumPy array of shape (n_texts, dim).
        texts: Sequence of n_texts descriptions, aligned with `text_features`.
        device: Torch device used for the dot products.

    Returns:
        The single element of `texts` with the highest projection.
    """
    # Orthogonal projector onto the row space of `data`. The pseudo-inverse
    # matches data.T @ inv(data @ data.T) @ data for independent rows and also
    # copes with dependent or all-zero rows, where the explicit inverse fails.
    row_space_projector = np.linalg.pinv(data) @ data
    projected_text = text_features @ row_space_projector

    data_t = torch.from_numpy(data).float().to(device)
    text_t = torch.from_numpy(projected_text).float().to(device)

    projection = data_t @ text_t.T  # shape (n_vectors, n_texts)

    # argmax runs over the flattened matrix; reduce the flat index modulo the
    # number of texts so it is a valid text index even when `data` has
    # several rows.
    best_flat = int(torch.argmax(projection))
    return texts[best_flat % projection.shape[1]]


In [25]:
# Build the report as a list of lines and join once at the end.
report_lines = []
rule = "------------------"

for task in svd_dict.keys():
    pylogger.info(f"Processing Task: {task}")
    report_lines += [rule, f"V for task {task}", rule]

    # SVD factors stored for this task at the configured layer (only s and v
    # are needed here).
    layer_svd = svd_dict[task][cfg.layer]
    s = torch.diag_embed(layer_svd["s"]).to(cfg.device)
    v = layer_svd["v"].to(cfg.device)

    # Scale the rows of v by s and map them through the visual projection.
    # NOTE(review): on the open question of whether s belongs here -- each row
    # is matched on its own, and a positive scale changes neither the row's
    # span nor the argmax of its dot products, so s should only matter for
    # rows whose entry in s is zero or negative. Confirm.
    v_proj = s @ v @ model.visual.proj

    # One nearest text per row of v_proj. The loop variable is deliberately
    # not called `v`, so it does not shadow the SVD factor above.
    all_captions = [
        get_nearest_text_descriptions(
            row.unsqueeze(0).detach().cpu().numpy(),
            text_features,
            lines,
            cfg.device,
        )
        for row in v_proj
    ]

    pylogger.info(f"Task {task}: {all_captions}")
    report_lines.extend(all_captions)

output_str = "".join(f"{line}\n" for line in report_lines)


In [26]:
# Write the report built in `output_str` to
# <output_dir>/nearest_neighbor_<model>.txt, creating the directory if needed.
output_dir = Path(cfg.misc.output_dir)
output_path = output_dir / f'nearest_neighbor_{cfg.model}.txt'

output_dir.mkdir(parents=True, exist_ok=True)
with output_path.open("w") as f:
    f.write(output_str)