In [1]:
import sys
import logging
import os


sys.path.append("..")
%load_ext autoreload
%autoreload 2


logger = logging.getLogger()
logger.setLevel(logging.INFO)

os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.95'

In [2]:
import pickle
from tqdm import tqdm

import jax
from jax import random
import jax.numpy as jnp
import numpy as np

from configs import finetune
from data import input_pipeline
import models
import utils
from utils import restore_checkpoint
from permutations import *
from constants import *
from pruning import apply_mask

from configs.eval_merge import get_config
import merging

2024-10-11 14:59:03.067388: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-11 14:59:03.080694: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-11 14:59:03.080711: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-11 14:59:04.365362: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skippi

In [3]:
from trainer import update_batch_stats
import functools
from flax.training import common_utils
from trainer import compute_metrics

def get_metrics(metrics):
    return common_utils.stack_forest(metrics)


@functools.partial(jax.jit, static_argnums=(0,))
def eval_step(apply_fn, params, batch):
    logits = apply_fn(params, batch['image'], deterministic=True)
    metrics = compute_metrics(logits, batch['label'])
    return metrics


def compute_loss_and_accuracy(apply_fn, params, dataset, nbatches=None):
    eval_iter = input_pipeline.prefetch(dataset, 10, None)
    eval_metrics = []
    ix = 0
    for eval_batch in eval_iter:
        metrics = eval_step(apply_fn, params, eval_batch)
        eval_metrics.append(metrics)
        ix+=1
        if nbatches is not None:
            if ix >= nbatches:
                break

    eval_metrics = get_metrics(eval_metrics)
    summary = {
            f'eval_{k}': v
            for k, v in jax.tree_util.tree_map(
            lambda x: x.mean(), eval_metrics
            ).items()
    }

    return summary['eval_loss'], summary['eval_accuracy']



norm_acc = lambda acc, expert_acc: acc / expert_acc

In [4]:
model_name = 'ViTB16'
ntasks = 8
tasks = TASKS_A[:ntasks//2] + TASKS_B[:ntasks//2]
datasets = ".".join(tasks)
config = get_config(f"{model_name},{datasets},average-merging")

In [5]:
if model_name == 'VGG16':
  config.width_multiplier = 2
  for dataset in DATASETS:
    config[dataset].pp.crop = crop = 96
    config[dataset].model_dir = VGG16X2_A[dataset]
  config.local_init = VGG16X2_NONLOCAL_TASK_A_INIT
elif model_name == 'ViTB16':
  for dataset in DATASETS:
    config[dataset].pp.crop = crop = 224
    config[dataset].model_dir = VITB16[dataset]
  config.local_init = VITB16_INIT
elif model_name == 'ViTmaeB16':
  for dataset in DATASETS:
    config[dataset].pp.crop = crop = 224
    config[dataset].model_dir = VITMAEB16[dataset]
  config.local_init = VITMAEB16_INIT


In [6]:
datasets = config.datasets
dataset_info_ls = [input_pipeline.get_dataset_info(dataset, config[dataset].pp['train']) for dataset in config.datasets]
num_classes = [dataset_info['num_classes'] for dataset_info in dataset_info_ls] 
num_train_examples = [dataset_info['num_examples'] for dataset_info in dataset_info_ls]

ds_train_ls, ds_test_ls = input_pipeline.get_datasets_for_mtl(config, datasets, batch_size=128)

model_ls = []
model_tracker_ls = []
model_repaired_ls = []
param_dirs = []

for i, nclass in enumerate(num_classes): 
  model_ls += [models.create_model(model_cls=getattr(models, config.model), num_classes=nclass, half_precision=False, projection_dim=512, width_multiplier=config.width_multiplier)]
  model_tracker_ls += [models.create_model(model_cls=getattr(models, config.model), num_classes=nclass, half_precision=False, projection_dim=512, width_multiplier=config.width_multiplier, tracker=True)]
  model_repaired_ls += [models.create_model(model_cls=getattr(models, config.model), num_classes=nclass, half_precision=False, projection_dim=512, width_multiplier=config.width_multiplier, repaired=True)]


INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/colorectal_histology/2.0.0
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/cifar10/3.0.2
INFO:absl:For 'cifar10/3.0.2': fields info.[splits, supervised_keys] differ on disk and in the code. Keeping the one from code.
INFO:absl:No config specified, defaulting to config: eurosat/rgb
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/eurosat/rgb/2.0.0
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/oxford_iiit_pet/3.2.0
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/resisc45/3.0.0
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/cifar100/3.0.2
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/cassava/0.1.0
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/svhn_cropped/3.0.0
INFO:absl:For 'svhn_cropped/3.0.0': fields info.[description, c

In [7]:
## Load expert parameters
expert_params_ls = []

for ix, dataset in enumerate(datasets):
  logging.info(f"Loading expert models for {dataset} from {config[dataset].model_dir}.")
  raw_state = restore_checkpoint(config[dataset].model_dir)
  expert_params = {'params': raw_state['params']}
  expert_params_ls += [expert_params]


INFO:root:Loading expert models for colorectal_histology from /home/ekansh/repos/share/vision-models/experiments/ViTB16/hugging_face/colorectal_histology/sgd/steps_10000_500/cosine_1e-3/seed_0.
INFO:absl:Restoring orbax checkpoint from /home/ekansh/repos/share/vision-models/experiments/ViTB16/hugging_face/colorectal_histology/sgd/steps_10000_500/cosine_1e-3/seed_0/checkpoint_10000
INFO:absl:Restoring item from /home/ekansh/repos/share/vision-models/experiments/ViTB16/hugging_face/colorectal_histology/sgd/steps_10000_500/cosine_1e-3/seed_0/checkpoint_10000.
INFO:absl:Finished restoring checkpoint from /home/ekansh/repos/share/vision-models/experiments/ViTB16/hugging_face/colorectal_histology/sgd/steps_10000_500/cosine_1e-3/seed_0/checkpoint_10000.
INFO:root:Loading expert models for cifar10 from /home/ekansh/repos/share/vision-models/experiments/ViTB16/hugging_face/cifar10/sgd/steps_10000_500/cosine_1e-3/seed_0.
INFO:absl:Restoring orbax checkpoint from /home/ekansh/repos/share/vision-m

In [8]:
## Load init parameters
raw_state_init = restore_checkpoint(config.local_init)
init_params = {'params': raw_state_init['params']}


INFO:absl:Restoring orbax checkpoint from /home/ekansh/repos/share/vision-models/experiments/ViTB16/hugging_face/cassava/sgd/steps_10000_500/cosine_1e-3/seed_0/init/checkpoint_0
INFO:absl:Restoring item from /home/ekansh/repos/share/vision-models/experiments/ViTB16/hugging_face/cassava/sgd/steps_10000_500/cosine_1e-3/seed_0/init/checkpoint_0.


INFO:absl:Finished restoring checkpoint from /home/ekansh/repos/share/vision-models/experiments/ViTB16/hugging_face/cassava/sgd/steps_10000_500/cosine_1e-3/seed_0/init/checkpoint_0.


In [9]:
init_params['params'].pop('classifier')
classifiers_ls = [expert_params['params'].pop('classifier') for expert_params in expert_params_ls]

In [10]:
## Expert accuracies:
expert_loss = {}
expert_accs = {}

for ix, dataset in enumerate(datasets):
  expert_params_ls[ix]['params']['classifier'] = classifiers_ls[ix]
  loss, accuracy = compute_loss_and_accuracy(model_ls[ix].apply, expert_params_ls[ix], ds_test_ls[ix])
  expert_loss[dataset] = loss
  expert_accs[dataset] = accuracy
  expert_params_ls[ix]['params'].pop('classifier')
  print(f"{dataset} expert accuracy = {accuracy}")
  print(f"{dataset} expert loss = {loss}")


2024-10-11 14:59:20.276972: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


colorectal_histology expert accuracy = 0.9713541865348816
colorectal_histology expert loss = 0.11431100964546204


2024-10-11 14:59:39.292576: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


cifar10 expert accuracy = 0.9886819124221802
cifar10 expert loss = 0.03868929296731949


2024-10-11 14:59:45.405243: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


eurosat expert accuracy = 0.9866071343421936
eurosat expert loss = 0.03653038293123245


2024-10-11 14:59:52.686324: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


oxford_iiit_pet expert accuracy = 0.9359375238418579
oxford_iiit_pet expert loss = 0.23676452040672302


2024-10-11 14:59:59.561591: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


resisc45 expert accuracy = 0.9518229365348816
resisc45 expert loss = 0.15293549001216888


2024-10-11 15:00:18.668679: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


cifar100 expert accuracy = 0.9278846383094788
cifar100 expert loss = 0.24418126046657562


2024-10-11 15:00:23.428058: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


cassava expert accuracy = 0.87890625
cassava expert loss = 0.5140537619590759


2024-10-11 15:01:10.803897: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


svhn_cropped expert accuracy = 0.951662540435791
svhn_cropped expert loss = 0.17149339616298676


In [11]:
tall_mask = False
task_vectors = [utils.tree_subtract(params['params'], init_params['params']) for params  in expert_params_ls]  

## Task-arithmetic model-merging: 
# print("TA")
# lam = 0.4
# t_MTL = merging.compute_task_arithmetic_vector(task_vectors, lam)
# merged_params  = {'params': utils.tree_add(t_MTL, init_params['params'])}

## TIES model-merging: 
# print("TIES")
# lam=0.8
# t_MTL = merging.compute_ties_vector(task_vectors, lam)
# merged_params  = {'params': utils.tree_add(t_MTL, init_params['params'])}

## TA+TALL model-merging: 
print("TA+TALL")
lam=0.5
t_MTL = merging.compute_ties_vector(task_vectors, lam)
t_masks_ls = [merging.compute_tall_mask(task_vector, t_MTL) for task_vector in task_vectors]
tall_mask = True

## TIES+TALL model-merging: 
# print("TIES")
# lam=0.8
# t_MTL = merging.compute_ties_vector(task_vectors, lam)
# t_masks_ls = [merging.compute_tall_mask(task_vector, t_MTL) for task_vector in task_vectors]
# tall_mask = True



TA+TALL


In [12]:
# Evaluation

merged_accuracies = {}
merged_normalized_accuracies = {}
for ix, dataset in enumerate(datasets):
  if tall_mask:
    merged_params = {'params' : utils.tree_add(init_params['params'], apply_mask(t_MTL, t_masks_ls[ix]))}
  merged_params['params']['classifier'] = classifiers_ls[ix]
  loss, accuracy = compute_loss_and_accuracy(model_ls[ix].apply, merged_params, ds_test_ls[ix])
  merged_accuracies[dataset] = accuracy
  merged_normalized_accuracies[dataset] = norm_acc(accuracy, expert_accs[dataset])
  merged_params['params'].pop('classifier')
  print(f'{dataset}')
  print(f"merged accuracy = {accuracy}")
  print(f"normalized accuracy = {norm_acc(accuracy, expert_accs[dataset])}")


2024-10-11 15:01:25.520331: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


colorectal_histology
merged accuracy = 0.8854166865348816
normalized accuracy = 0.9115281701087952


2024-10-11 15:01:40.270881: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


cifar10
merged accuracy = 0.9820713400840759
normalized accuracy = 0.993313729763031


2024-10-11 15:01:44.836987: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


eurosat
merged accuracy = 0.9583333134651184
normalized accuracy = 0.9713423848152161


2024-10-11 15:01:50.264191: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


oxford_iiit_pet
merged accuracy = 0.9290624856948853
normalized accuracy = 0.9926543831825256


2024-10-11 15:01:55.492136: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


resisc45
merged accuracy = 0.8541666865348816
normalized accuracy = 0.8974007964134216


2024-10-11 15:02:11.940054: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


cifar100
merged accuracy = 0.8890224099159241
normalized accuracy = 0.9581173658370972


2024-10-11 15:02:15.136629: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


cassava
merged accuracy = 0.81640625
normalized accuracy = 0.9288889169692993


2024-10-11 15:02:57.801059: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


svhn_cropped
merged accuracy = 0.5591133236885071
normalized accuracy = 0.587512195110321


In [13]:
## Compute output statistics for expert parameters
expert_batch_stats_ls = []
rng = random.PRNGKey(0)
for ix, dataset in enumerate(datasets):
  expert_params_ls[ix]['params']['classifier'] = classifiers_ls[ix]
  params_tracker = models.load_from_source(model_tracker_ls[ix].initialization(rng, (1, crop, crop, 3)), {'params': expert_params_ls[ix]['params']})
  params_tracker = update_batch_stats(model_tracker_ls[ix].apply, params_tracker, ds_train_ls[ix], nbatches=500)
  expert_batch_stats_ls += [params_tracker['batch_stats']]
  expert_params_ls[ix]['params'].pop('classifier')


2024-10-11 15:03:04.891045: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 770744320 bytes after encountering the first element of size 77074432 bytes.This already causes the autotune ram budget to be exceeded. To stay within the ram budget, either increase the ram budget or reduce element size
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
Corrupt J

In [14]:
tact_accuracies = {}
tact_normalized_accuracies = {}

for ix, dataset in enumerate(datasets):
  print(f"Computing TACT accuracy for {dataset}")
  if tall_mask:
    merged_params = {'params' : utils.tree_add(init_params['params'], apply_mask(t_MTL, t_masks_ls[ix]))}
  merged_params_tact = models.load_from_source(model_repaired_ls[ix].initialization(rng, (1, crop, crop, 3)), merged_params)
  merged_params_tact = models.set_batch_norm_params_from_batch_stats(merged_params_tact, expert_batch_stats_ls[ix])
  merged_params_tact['params']['classifier'] = classifiers_ls[ix]
  merged_params_tact = update_batch_stats(model_repaired_ls[ix].apply, merged_params_tact,  ds_train_ls[ix], nbatches=500)
  loss, accuracy = compute_loss_and_accuracy(model_repaired_ls[ix].apply, merged_params_tact, ds_test_ls[ix])
  tact_accuracies[dataset] = accuracy
  tact_normalized_accuracies[dataset] = norm_acc(accuracy, expert_accs[dataset])
  print(f"accuracy = {tact_accuracies[dataset]}")
  print(f"normalized accuracy = {tact_normalized_accuracies[dataset]}")


Computing TACT accuracy for colorectal_histology


2024-10-11 15:20:16.345234: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


accuracy = 0.9375
normalized accuracy = 0.9651474356651306
Computing TACT accuracy for cifar10


2024-10-11 15:22:28.260773: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


accuracy = 0.9850761294364929
normalized accuracy = 0.9963529109954834
Computing TACT accuracy for eurosat


2024-10-11 15:24:29.941816: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


accuracy = 0.9713541865348816
normalized accuracy = 0.9845399856567383
Computing TACT accuracy for oxford_iiit_pet


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
2024-10-11 15:25:20.690608: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 329372 bytes after encountering the first element of size 329372 bytes.This already causes the autotune ram budget to be e

accuracy = 0.9318749904632568
normalized accuracy = 0.9956594109535217
Computing TACT accuracy for resisc45


2024-10-11 15:28:41.094999: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


accuracy = 0.9091796875
normalized accuracy = 0.955198347568512
Computing TACT accuracy for cifar100


2024-10-11 15:30:55.071808: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


accuracy = 0.9099559187889099
normalized accuracy = 0.9806778430938721
Computing TACT accuracy for cassava


2024-10-11 15:33:02.920560: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


accuracy = 0.8381696343421936
normalized accuracy = 0.9536507725715637
Computing TACT accuracy for svhn_cropped


2024-10-11 15:35:43.824877: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


accuracy = 0.8748460412025452
normalized accuracy = 0.9192817807197571


In [15]:
for ix, dataset in enumerate(datasets):
  print(f"Dataset: {dataset}")
  print(f"\tExpert accuracy = {expert_accs[dataset]}; Merged model accuracy = {merged_accuracies[dataset]}; Merged + TACT accuracy = {tact_accuracies[dataset]}")
  print(f"")

Dataset: colorectal_histology
	Expert accuracy = 0.9713541865348816; Merged model accuracy = 0.8854166865348816; Merged + TACT accuracy = 0.9375

Dataset: cifar10
	Expert accuracy = 0.9886819124221802; Merged model accuracy = 0.9820713400840759; Merged + TACT accuracy = 0.9850761294364929

Dataset: eurosat
	Expert accuracy = 0.9866071343421936; Merged model accuracy = 0.9583333134651184; Merged + TACT accuracy = 0.9713541865348816

Dataset: oxford_iiit_pet
	Expert accuracy = 0.9359375238418579; Merged model accuracy = 0.9290624856948853; Merged + TACT accuracy = 0.9318749904632568

Dataset: resisc45
	Expert accuracy = 0.9518229365348816; Merged model accuracy = 0.8541666865348816; Merged + TACT accuracy = 0.9091796875

Dataset: cifar100
	Expert accuracy = 0.9278846383094788; Merged model accuracy = 0.8890224099159241; Merged + TACT accuracy = 0.9099559187889099

Dataset: cassava
	Expert accuracy = 0.87890625; Merged model accuracy = 0.81640625; Merged + TACT accuracy = 0.83816963434219