# NegMerge Tutorial

## 1. Import Requirements

In [3]:
import torch
import os
import json
import argparse
import sys
import timm.data.transforms
import abc

In [None]:
# Inside Jupyter, sys.argv holds the kernel's own flags; reset it so that
# parse_arguments() does not try to parse them.
if 'ipykernel' in sys.modules:
    sys.argv = ['']


class MaybeToTensor:
    """Identity stand-in for ``timm.data.transforms.MaybeToTensor``.

    It returns its input unchanged.

    NOTE(review): presumably needed so that pickled checkpoints referencing this
    timm class can be unpickled with a timm version that does not define it —
    confirm against the checkpoints in use.
    """

    def __call__(self, x):
        return x


# Install the shim only when timm lacks the class: overwriting an existing
# implementation would silently turn it into a no-op for the rest of the session.
if not hasattr(timm.data.transforms, "MaybeToTensor"):
    timm.data.transforms.MaybeToTensor = MaybeToTensor

# Device used for the checkpoint arithmetic in the merge step.
device = torch.device("cpu")

## 2. Define Configuration

In [None]:
def parse_arguments(argv=None):
    """Parse the command-line options used by the NegMerge evaluation code.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with the parsed options plus ``device``
        ("cuda" when available, otherwise "cpu"). A single ``--load`` entry is
        unwrapped from a one-element list to a plain string; several entries
        stay a list.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_location", type=str, default=os.path.expanduser("~/data"), help="The root directory for the datasets.")
    parser.add_argument("--eval-datasets", default=None, type=lambda x: x.split(","), help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT.")
    parser.add_argument("--results_db", type=str, default=None, help="Where to store the results, else does not store")
    parser.add_argument("--model", type=str, default="ViT-B-32", help="The type of model (e.g. RN50, ViT-B-32).")
    parser.add_argument("--save", type=str, default=None, help="Optionally save a _classifier_, e.g. a zero shot classifier or probe.")
    parser.add_argument("--seed", type=int, default=None, help="Random seed.")
    parser.add_argument("--finetuning_mode", choices=["standard", "linear", "none"], help="Whether to use linearized models or not.")
    parser.add_argument("--n-eval-points", type=int, default=21, help="Number of evaluation points used to find optimal coefficient in task arithmetic.")
    # Declared because the unwrap below reads ``parsed_args.load``; without it
    # every call would fail with AttributeError.
    parser.add_argument("--load", type=lambda x: x.split(","), default=None, help="Optionally load _classifiers_, e.g. a zero shot classifier or probe, or an ensemble of both. Split by comma.")

    parsed_args = parser.parse_args(argv)
    parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu"

    # Unwrap a single path; keep a list when several were given.
    if parsed_args.load is not None and len(parsed_args.load) == 1:
        parsed_args.load = parsed_args.load[0]

    return parsed_args

In [None]:
# Parse the defaults, then override them for this tutorial run.
args = parse_arguments()

args.data_location = "dataset/"
args.finetuning_mode = "standard"       # "linear" or "standard"
args.model = "ViT-B-32"                 # Backbone
args.results_db = "checkpoints"
# The evaluation output below shows classification heads being loaded from
# <results_db>/<finetuning_mode>/<model>.
args.save = os.path.join(args.results_db, args.finetuning_mode, args.model)

dataset = "Cars"                        # Forget set
control_dataset = "ImageNet"            # Retain set

# Zero-shot accuracies of the pretrained model, keyed by dataset name
# (e.g. "ImageNetVal"); used later to set the retain-set accuracy threshold.
with open("/path/to/zeroshot_accuracies.json", encoding="utf-8") as f:
    pretrained_accuracies = json.load(f)
negation_accuracies = {}

## 3. Download Pretrained and Fine-tuned Weights
- Download Link: https://drive.google.com/drive/u/1/folders/1m1iHi5KoTN1Fg5JqIZxtVP1ZTxgILZyi

In [None]:
# Checkpoint locations; replace the /path/to/ placeholders with the directory
# the weights were downloaded to.
pretrained_path = '/path/to/zeroshot.pt'

# One fine-tuned checkpoint per (m, n) pair, m = 1..10 and n = 1..3, ordered
# with n varying fastest. NOTE(review): the "rand-m<M>-n<N>" suffix presumably
# encodes the augmentation setting used for fine-tuning — confirm.
finetuned_paths = [
    f"/path/to/clip-vit-b-32_cars_rand-m{m_idx}-n{n_idx}_finetuned.pt"
    for m_idx in range(1, 11)
    for n_idx in range(1, 4)
]

## 4. Define Task Vector Class

In [8]:
class _TaskVector(abc.ABC):
    """Abstract task vector: a dict ``{param_name: finetuned - pretrained}``.

    Subclasses define how a checkpoint is turned into a model
    (``_load_checkpoint``) and how another task vector is converted into the
    same representation (``_cast_to_same_type``). Arithmetic operators return a
    new instance of the concrete subclass.
    """

    def __init__(
        self, pretrained_checkpoint=None, finetuned_checkpoint=None, vector=None
    ):
        """Build a task vector from two checkpoints or from a ready-made dict.

        Args:
            pretrained_checkpoint: A state dict, or anything accepted by
                ``_load_checkpoint`` whose result provides ``state_dict()``.
            finetuned_checkpoint: Same as ``pretrained_checkpoint``, for the
                fine-tuned weights. Must contain every kept pretrained key.
            vector: Precomputed ``{name: tensor}`` dict. When given, the two
                checkpoints are ignored.
        """
        if vector is not None:
            self.vector = vector
            return
        assert (
            pretrained_checkpoint is not None and finetuned_checkpoint is not None
        )
        with torch.no_grad():
            pretrained_state_dict = self._to_state_dict(pretrained_checkpoint)
            finetuned_state_dict = self._to_state_dict(finetuned_checkpoint)
            self.vector = {}
            for key, pretrained_value in pretrained_state_dict.items():
                # int64 / uint8 entries are skipped (presumably integer buffers
                # rather than trainable weights).
                if pretrained_value.dtype in (torch.int64, torch.uint8):
                    continue
                self.vector[key] = finetuned_state_dict[key] - pretrained_value

    def _to_state_dict(self, checkpoint):
        """Return ``checkpoint`` if it is a dict, else load it and take its ``state_dict()``."""
        if isinstance(checkpoint, dict):
            return checkpoint
        return self._load_checkpoint(checkpoint).state_dict()

    @abc.abstractmethod
    def _load_checkpoint(self, checkpoint):
        """Load a checkpoint into a model."""
        raise NotImplementedError

    @abc.abstractmethod
    def _cast_to_same_type(self, other):
        """Convert ``other`` into this class's task-vector representation."""
        raise NotImplementedError

    def __add__(self, other):
        """Add two task vectors; keys missing from ``other`` are dropped with a warning."""
        other = self._cast_to_same_type(other)
        with torch.no_grad():
            new_vector = {}
            for key in self.vector:
                if key not in other.vector:
                    print(f"Warning, key {key} is not present in both task vectors.")
                    continue
                new_vector[key] = self.vector[key] + other.vector[key]
        return self.__class__(vector=new_vector)

    def __sub__(self, other):
        """Subtract two task vectors."""
        return self.__add__(-other)

    def __radd__(self, other):
        # Lets ``sum(task_vectors)`` work: sum() starts from the int 0.
        if other is None or isinstance(other, int):
            return self
        return self.__add__(other)

    def __neg__(self):
        """Negate a task vector."""
        with torch.no_grad():
            new_vector = {key: -value for key, value in self.vector.items()}
        return self.__class__(vector=new_vector)

    def __pow__(self, power):
        """Raise every entry of the task vector to ``power``."""
        with torch.no_grad():
            new_vector = {key: value ** power for key, value in self.vector.items()}
        return self.__class__(vector=new_vector)

    def __mul__(self, other):
        """Multiply a task vector by a scalar."""
        with torch.no_grad():
            new_vector = {key: other * value for key, value in self.vector.items()}
        return self.__class__(vector=new_vector)

    def __rmul__(self, other):
        """Support ``scalar * task_vector`` (scalar multiplication commutes)."""
        return self.__mul__(other)

    def dot(self, other):
        """Dot product of two task vectors over their shared keys.

        Returns the Python float ``0.0`` when no keys are shared, otherwise a
        0-dim tensor.
        """
        other = self._cast_to_same_type(other)
        with torch.no_grad():
            dot_product = 0.0
            for key in self.vector:
                if key not in other.vector:
                    print(f"Warning, key {key} is not present in both task vectors.")
                    continue
                dot_product += torch.sum(self.vector[key] * other.vector[key])
        return dot_product

    def norm(self):
        """Euclidean norm of a task vector."""
        # ``dot`` can return a plain float (empty vector); torch.sqrt needs a tensor.
        return torch.sqrt(torch.as_tensor(self.dot(self)))

    def apply_to(self, pretrained_checkpoint, scaling_coef=1.0):
        """Return the pretrained model with ``scaling_coef * vector`` added.

        Args:
            pretrained_checkpoint: Anything accepted by ``_load_checkpoint``.
            scaling_coef: Multiplier applied to every task-vector entry.

        State-dict entries absent from the task vector (e.g. the int64 / uint8
        entries skipped in ``__init__``) keep their pretrained values, so the
        strict ``load_state_dict`` below receives every key.
        """
        with torch.no_grad():
            pretrained_model = self._load_checkpoint(pretrained_checkpoint)
            pretrained_state_dict = pretrained_model.state_dict()
            new_state_dict = {}
            for key, value in pretrained_state_dict.items():
                if key not in self.vector:
                    print(
                        f"Warning: key {key} is present in the pretrained state dict but not in the task vector"  # noqa: E501
                    )
                    new_state_dict[key] = value
                    continue
                new_state_dict[key] = value + scaling_coef * self.vector[key]
        pretrained_model.load_state_dict(new_state_dict)
        return pretrained_model


class NonLinearTaskVector(_TaskVector):
    """A task vector for nonlinear models."""

    def _load_checkpoint(self, checkpoint):
        """Load the object stored at ``checkpoint`` onto the CPU.

        The result is expected to provide ``state_dict()``.
        """
        # NOTE(review): torch.load unpickles arbitrary objects, so only load
        # trusted checkpoints. Recent PyTorch releases default to
        # weights_only=True, which may reject a pickled full model — confirm
        # against the installed version.
        return torch.load(checkpoint, map_location="cpu")

    def apply_to_nonlinear(self, pretrained_nonlinear_checkpoint, scaling_coef=1.0):
        """Apply a task vector to a nonlinear pretrained model."""
        return self.apply_to(pretrained_nonlinear_checkpoint, scaling_coef)

    def _cast_to_same_type(self, other):
        """Convert ``other`` to a ``NonLinearTaskVector`` restricted to this vector's keys."""
        return linear_to_nonlinear(other, self.vector.keys())


def linear_to_nonlinear(linear_task_vector, param_names):
    """Convert a linear task vector to a nonlinear task vector.

    A ``NonLinearTaskVector`` is returned unchanged. Any other object must
    provide ``get_named_parameters(param_names)`` returning a ``{name: tensor}``
    dict (presumably a linearized task vector defined elsewhere).
    """
    if isinstance(linear_task_vector, NonLinearTaskVector):
        return linear_task_vector
    return NonLinearTaskVector(
        vector=linear_task_vector.get_named_parameters(param_names)
    )

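## 5. Merge Task Vectors

NegMerge averages the task vectors of all fine-tuned models, keeping only the elements whose sign agrees across every model; all other elements are set to zero.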

In [9]:
# NegMerge: average the task vectors of all fine-tuned models, keeping only the
# elements whose sign agrees across every model; all other elements are zeroed.
num_models = len(finetuned_paths)
if num_models == 0:
    raise ValueError("finetuned_paths is empty; nothing to merge.")

# The pretrained weights are identical for every task vector, so load them once
# instead of re-reading the checkpoint for each fine-tuned model.
pretrained_state_dict = torch.load(pretrained_path, map_location=device).state_dict()

merged_vector = None  # running sum of the task vectors
mask = None           # running sum of element-wise signs
for finetuned_path in finetuned_paths:
    state_dict = torch.load(finetuned_path, map_location=device)
    state_dict = {k: v.to(device) for k, v in state_dict.items()}
    single_vector = NonLinearTaskVector(pretrained_state_dict, state_dict).vector

    if merged_vector is None:
        merged_vector = {k: torch.zeros_like(v) for k, v in single_vector.items()}
        mask = {k: torch.zeros_like(v) for k, v in single_vector.items()}

    for key, delta in single_vector.items():
        merged_vector[key] += delta
        mask[key] += torch.sign(delta)

# |sum of signs| == num_models exactly when every model moved this element in
# the same (non-zero) direction; only those elements keep their mean update.
task_vector = NonLinearTaskVector(
    vector={
        key: torch.where(
            torch.abs(mask[key]) == num_models,
            merged_vector[key] / num_models,
            torch.zeros_like(merged_vector[key]),
        )
        for key in merged_vector
    }
)

## 6. Evaluate

### 6.1. Find Optimal Coefficient

In [10]:
# Sweep scaling coefficients of the negated task vector on the validation
# splits, then choose the one minimising forget-set accuracy. The retain-set
# threshold is 95% of the pretrained model's zero-shot retain accuracy.
# NOTE(review): the exact selection rule lives in find_optimal_coef — confirm.
from src.eval import evaluate_task_vector, evaluate_task_vector_at_coef
from src.utils import find_optimal_coef

forget_val_name = dataset + "Val"
retain_val_name = control_dataset + "Val"
args.eval_datasets = [forget_val_name]
args.control_dataset = retain_val_name

val_metrics = evaluate_task_vector(-task_vector, pretrained_path, args)

retain_threshold = 0.95 * pretrained_accuracies[retain_val_name]
optimal_coef = find_optimal_coef(
    val_metrics,
    metric=f"{forget_val_name}:top1",
    minimize=True,
    control_metric=f"{retain_val_name}:top1",
    control_metric_threshold=retain_threshold,
)

Evaluating for scaling coefficient 0.00
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:06<00:00,  1.07it/s]


Done evaluating on CarsVal. Accuracy: 59.58%
CarsVal Top-1 accuracy: 0.5958
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.74it/s]


Done evaluating on ImageNetVal. Accuracy: 66.70%
ImageNetVal Top-1 accuracy: 0.6670
Evaluating for scaling coefficient 0.05
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.19it/s]


Done evaluating on CarsVal. Accuracy: 58.11%
CarsVal Top-1 accuracy: 0.5811
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.72it/s]


Done evaluating on ImageNetVal. Accuracy: 66.66%
ImageNetVal Top-1 accuracy: 0.6666
Evaluating for scaling coefficient 0.10
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.38it/s]


Done evaluating on CarsVal. Accuracy: 56.27%
CarsVal Top-1 accuracy: 0.5627
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.75it/s]


Done evaluating on ImageNetVal. Accuracy: 66.50%
ImageNetVal Top-1 accuracy: 0.6650
Evaluating for scaling coefficient 0.15
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.37it/s]


Done evaluating on CarsVal. Accuracy: 54.05%
CarsVal Top-1 accuracy: 0.5405
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.74it/s]


Done evaluating on ImageNetVal. Accuracy: 66.66%
ImageNetVal Top-1 accuracy: 0.6666
Evaluating for scaling coefficient 0.20
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.27it/s]


Done evaluating on CarsVal. Accuracy: 52.09%
CarsVal Top-1 accuracy: 0.5209
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.70it/s]


Done evaluating on ImageNetVal. Accuracy: 66.50%
ImageNetVal Top-1 accuracy: 0.6650
Evaluating for scaling coefficient 0.25
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.35it/s]


Done evaluating on CarsVal. Accuracy: 50.12%
CarsVal Top-1 accuracy: 0.5012
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.74it/s]


Done evaluating on ImageNetVal. Accuracy: 66.32%
ImageNetVal Top-1 accuracy: 0.6632
Evaluating for scaling coefficient 0.30
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.27it/s]


Done evaluating on CarsVal. Accuracy: 47.54%
CarsVal Top-1 accuracy: 0.4754
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.76it/s]


Done evaluating on ImageNetVal. Accuracy: 66.14%
ImageNetVal Top-1 accuracy: 0.6614
Evaluating for scaling coefficient 0.35
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.19it/s]


Done evaluating on CarsVal. Accuracy: 45.82%
CarsVal Top-1 accuracy: 0.4582
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.75it/s]


Done evaluating on ImageNetVal. Accuracy: 66.24%
ImageNetVal Top-1 accuracy: 0.6624
Evaluating for scaling coefficient 0.40
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.23it/s]


Done evaluating on CarsVal. Accuracy: 43.98%
CarsVal Top-1 accuracy: 0.4398
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.85it/s]


Done evaluating on ImageNetVal. Accuracy: 66.18%
ImageNetVal Top-1 accuracy: 0.6618
Evaluating for scaling coefficient 0.45
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:04<00:00,  1.43it/s]


Done evaluating on CarsVal. Accuracy: 42.26%
CarsVal Top-1 accuracy: 0.4226
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.77it/s]


Done evaluating on ImageNetVal. Accuracy: 66.02%
ImageNetVal Top-1 accuracy: 0.6602
Evaluating for scaling coefficient 0.50
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.38it/s]


Done evaluating on CarsVal. Accuracy: 40.54%
CarsVal Top-1 accuracy: 0.4054
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.71it/s]


Done evaluating on ImageNetVal. Accuracy: 65.70%
ImageNetVal Top-1 accuracy: 0.6570
Evaluating for scaling coefficient 0.55
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.23it/s]


Done evaluating on CarsVal. Accuracy: 38.94%
CarsVal Top-1 accuracy: 0.3894
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.74it/s]


Done evaluating on ImageNetVal. Accuracy: 65.32%
ImageNetVal Top-1 accuracy: 0.6532
Evaluating for scaling coefficient 0.60
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.22it/s]


Done evaluating on CarsVal. Accuracy: 37.59%
CarsVal Top-1 accuracy: 0.3759
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.75it/s]


Done evaluating on ImageNetVal. Accuracy: 65.32%
ImageNetVal Top-1 accuracy: 0.6532
Evaluating for scaling coefficient 0.65
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.21it/s]


Done evaluating on CarsVal. Accuracy: 36.24%
CarsVal Top-1 accuracy: 0.3624
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.71it/s]


Done evaluating on ImageNetVal. Accuracy: 65.12%
ImageNetVal Top-1 accuracy: 0.6512
Evaluating for scaling coefficient 0.70
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.23it/s]


Done evaluating on CarsVal. Accuracy: 34.52%
CarsVal Top-1 accuracy: 0.3452
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.79it/s]


Done evaluating on ImageNetVal. Accuracy: 64.90%
ImageNetVal Top-1 accuracy: 0.6490
Evaluating for scaling coefficient 0.75
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.27it/s]


Done evaluating on CarsVal. Accuracy: 33.05%
CarsVal Top-1 accuracy: 0.3305
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.76it/s]


Done evaluating on ImageNetVal. Accuracy: 64.74%
ImageNetVal Top-1 accuracy: 0.6474
Evaluating for scaling coefficient 0.80
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.23it/s]


Done evaluating on CarsVal. Accuracy: 30.22%
CarsVal Top-1 accuracy: 0.3022
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.77it/s]


Done evaluating on ImageNetVal. Accuracy: 64.16%
ImageNetVal Top-1 accuracy: 0.6416
Evaluating for scaling coefficient 0.85
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.39it/s]


Done evaluating on CarsVal. Accuracy: 28.62%
CarsVal Top-1 accuracy: 0.2862
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.79it/s]


Done evaluating on ImageNetVal. Accuracy: 63.72%
ImageNetVal Top-1 accuracy: 0.6372
Evaluating for scaling coefficient 0.90
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.35it/s]


Done evaluating on CarsVal. Accuracy: 25.55%
CarsVal Top-1 accuracy: 0.2555
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.78it/s]


Done evaluating on ImageNetVal. Accuracy: 63.54%
ImageNetVal Top-1 accuracy: 0.6354
Evaluating for scaling coefficient 0.95
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.19it/s]


Done evaluating on CarsVal. Accuracy: 24.20%
CarsVal Top-1 accuracy: 0.2420
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:14<00:00,  2.75it/s]


Done evaluating on ImageNetVal. Accuracy: 63.20%
ImageNetVal Top-1 accuracy: 0.6320
Evaluating for scaling coefficient 1.00
Evaluating on CarsVal
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 7/7 [00:05<00:00,  1.23it/s]


Done evaluating on CarsVal. Accuracy: 22.60%
CarsVal Top-1 accuracy: 0.2260
Evaluating on ImageNetVal
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 40/40 [00:13<00:00,  2.86it/s]

Done evaluating on ImageNetVal. Accuracy: 62.74%
ImageNetVal Top-1 accuracy: 0.6274
Control metric fell below 0.6332699999999999 threshold
Control metric fell below 0.6332699999999999 threshold





### 6.2. Evaluate on the test set with the optimal coefficient.

In [11]:
# Evaluate the negated task vector on the test splits (no "Val" suffix) at the
# coefficient selected on validation, and record the results.
args.eval_datasets = [dataset]
args.control_dataset = control_dataset
test_metrics = evaluate_task_vector_at_coef(-task_vector, pretrained_path, args, optimal_coef)

print("=" * 100)
forget_test_acc = test_metrics[f"{dataset}:top1"]
print(f"Test accuracy: {forget_test_acc}")

retain_test_acc = test_metrics[f"{control_dataset}:top1"]
negation_accuracies[dataset] = dict(
    test=forget_test_acc,
    test_control=retain_test_acc,
    val=val_metrics,
)

print(negation_accuracies[dataset])

Evaluating on Cars
Classification head for ViT-B-32 on CarsVal exists at checkpoints/standard/ViT-B-32/head_CarsVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_CarsVal.pt


100%|██████████| 63/63 [00:33<00:00,  1.91it/s]


Done evaluating on Cars. Accuracy: 27.40%
Cars Top-1 accuracy: 0.2740
Evaluating on ImageNet
Classification head for ViT-B-32 on ImageNetVal exists at checkpoints/standard/ViT-B-32/head_ImageNetVal.pt
Loading classification head from checkpoints/standard/ViT-B-32/head_ImageNetVal.pt


100%|██████████| 391/391 [03:30<00:00,  1.86it/s]

Done evaluating on ImageNet. Accuracy: 60.38%
ImageNet Top-1 accuracy: 0.6038
Test accuracy: 0.27397089914189776
{'test': 0.27397089914189776, 'test_control': 0.60378, 'val': {0.0: {'CarsVal:top1': 0.5958230958230958, 'ImageNetVal:top1': 0.667}, 0.05: {'CarsVal:top1': 0.581081081081081, 'ImageNetVal:top1': 0.6666}, 0.1: {'CarsVal:top1': 0.5626535626535627, 'ImageNetVal:top1': 0.665}, 0.15000000000000002: {'CarsVal:top1': 0.5405405405405406, 'ImageNetVal:top1': 0.6666}, 0.2: {'CarsVal:top1': 0.5208845208845209, 'ImageNetVal:top1': 0.665}, 0.25: {'CarsVal:top1': 0.5012285012285013, 'ImageNetVal:top1': 0.6632}, 0.30000000000000004: {'CarsVal:top1': 0.47542997542997545, 'ImageNetVal:top1': 0.6614}, 0.35000000000000003: {'CarsVal:top1': 0.4582309582309582, 'ImageNetVal:top1': 0.6624}, 0.4: {'CarsVal:top1': 0.4398034398034398, 'ImageNetVal:top1': 0.6618}, 0.45: {'CarsVal:top1': 0.4226044226044226, 'ImageNetVal:top1': 0.6602}, 0.5: {'CarsVal:top1': 0.40540540540540543, 'ImageNetVal:top1': 0.6


