In [1]:
import sys
import logging
import os


sys.path.append("..")
%load_ext autoreload
%autoreload 2


logger = logging.getLogger()
logger.setLevel(logging.INFO)

os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.95'

In [2]:
import pickle
from tqdm import tqdm

import jax
from jax import random
import jax.numpy as jnp
import numpy as np

from configs import finetune
from data import input_pipeline
import models
import utils
from utils import restore_checkpoint
from permutations import *
from constants import *
from pruning import apply_mask
from configs.eval_merge import get_config
import merging

2024-10-11 14:28:14.205614: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-11 14:28:14.218716: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-11 14:28:14.218736: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-11 14:28:15.554539: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skippi

In [3]:
from trainer import update_batch_stats
import functools
from flax.training import common_utils
from trainer import compute_metrics

def get_metrics(metrics):
    return common_utils.stack_forest(metrics)


@functools.partial(jax.jit, static_argnums=(0,))
def eval_step(apply_fn, params, batch):
    logits = apply_fn(params, batch['image'], deterministic=True)
    metrics = compute_metrics(logits, batch['label'])
    return metrics


def compute_loss_and_accuracy(apply_fn, params, dataset, nbatches=None):
    eval_iter = input_pipeline.prefetch(dataset, 10, None)
    eval_metrics = []
    ix = 0
    for eval_batch in eval_iter:
        metrics = eval_step(apply_fn, params, eval_batch)
        eval_metrics.append(metrics)
        ix+=1
        if nbatches is not None:
            if ix >= nbatches:
                break

    eval_metrics = get_metrics(eval_metrics)
    summary = {
            f'eval_{k}': v
            for k, v in jax.tree_util.tree_map(
            lambda x: x.mean(), eval_metrics
            ).items()
    }

    return summary['eval_loss'], summary['eval_accuracy']



norm_acc = lambda acc, expert_acc: acc / expert_acc

In [4]:
model_name = 'VGG16'
ntasks = 8
crop = 96
tasks = TASKS_A[:ntasks//2] + TASKS_B[:ntasks//2]
datasets = ".".join(tasks)
config = get_config(f"{model_name},{datasets},average-merging")
config.width_multiplier = 2

config.task_a_init = VGG16X2_NONLOCAL_TASK_A_INIT
config.task_b_init = VGG16X2_NONLOCAL_TASK_B_INIT


for dataset in TASKS_A[:ntasks//2]:
  config[dataset].pp.crop = crop
  config[dataset].model_dir = VGG16X2_A[dataset]

for dataset in TASKS_B[:ntasks//2]:
  config[dataset].pp.crop = crop
  config[dataset].model_dir = VGG16X2_B[dataset]

In [5]:
datasets = config.datasets
dataset_info_ls = [input_pipeline.get_dataset_info(dataset, config[dataset].pp['train']) for dataset in config.datasets]
num_classes = [dataset_info['num_classes'] for dataset_info in dataset_info_ls] 
num_train_examples = [dataset_info['num_examples'] for dataset_info in dataset_info_ls]

ds_train_ls, ds_test_ls = input_pipeline.get_datasets_for_mtl(config, datasets, batch_size=128)

model_ls = []
model_tracker_ls = []
model_repaired_ls = []
param_dirs = []

for i, nclass in enumerate(num_classes): 
  model_ls += [models.create_model(model_cls=getattr(models, config.model), num_classes=nclass, half_precision=False, projection_dim=512, width_multiplier=config.width_multiplier)]
  model_tracker_ls += [models.create_model(model_cls=getattr(models, config.model), num_classes=nclass, half_precision=False, projection_dim=512, width_multiplier=config.width_multiplier, tracker=True)]
  model_repaired_ls += [models.create_model(model_cls=getattr(models, config.model), num_classes=nclass, half_precision=False, projection_dim=512, width_multiplier=config.width_multiplier, repaired=True)]


INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/colorectal_histology/2.0.0


INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/cifar10/3.0.2
INFO:absl:For 'cifar10/3.0.2': fields info.[splits, supervised_keys] differ on disk and in the code. Keeping the one from code.
INFO:absl:No config specified, defaulting to config: eurosat/rgb
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/eurosat/rgb/2.0.0
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/oxford_iiit_pet/3.2.0
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/resisc45/3.0.0
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/cifar100/3.0.2
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/cassava/0.1.0
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/svhn_cropped/3.0.0
INFO:absl:For 'svhn_cropped/3.0.0': fields info.[description, citation, module_name] differ on disk and in the code. Keeping the one from code.
INFO:absl:Reading dat

In [6]:
## Load expert parameters
expert_params_ls = []

for ix, dataset in enumerate(datasets):
  logging.info(f"Loading expert models for {dataset} from {config[dataset].model_dir}.")
  raw_state = restore_checkpoint(config[dataset].model_dir)
  expert_params = {'params': raw_state['params']}
  expert_params_ls += [expert_params]


INFO:root:Loading expert models for colorectal_histology from /home/ekansh/repos/share/vision-models/experiments/VGG16x2/ILSVRC_13/colorectal_histology/sgd/trained_classifier/steps_8000_500/cosine_1e-2/seed_0/.
INFO:absl:Restoring orbax checkpoint from /home/ekansh/repos/share/vision-models/experiments/VGG16x2/ILSVRC_13/colorectal_histology/sgd/trained_classifier/steps_8000_500/cosine_1e-2/seed_0/checkpoint_8000
INFO:absl:Restoring item from /home/ekansh/repos/share/vision-models/experiments/VGG16x2/ILSVRC_13/colorectal_histology/sgd/trained_classifier/steps_8000_500/cosine_1e-2/seed_0/checkpoint_8000.
INFO:absl:Finished restoring checkpoint from /home/ekansh/repos/share/vision-models/experiments/VGG16x2/ILSVRC_13/colorectal_histology/sgd/trained_classifier/steps_8000_500/cosine_1e-2/seed_0/checkpoint_8000.
INFO:root:Loading expert models for cifar10 from /home/ekansh/repos/share/vision-models/experiments/VGG16x2/ILSVRC_13/cifar10/sgd/trained_classifier/steps_8000_500/cosine_1e-2/seed_

In [7]:
## Load init parameters
raw_state_init_a = restore_checkpoint(config.task_a_init)
init_params_a = {'params': raw_state_init_a['params']}
zeros_classifier = utils.tree_zeros_like(init_params_a['params']['classifier'])
init_params_a['params'].pop('classifier')

raw_state_init_b = restore_checkpoint(config.task_b_init)
init_params_b = {'params': raw_state_init_b['params']}
init_params_b['params'].pop('classifier')

INFO:absl:Restoring orbax checkpoint from /home/ekansh/repos/share/vision-models/experiments/VGG16x2/ILSVRC_13/cassava/sgd/trained_classifier/steps_8000_500/cosine_1e-2/seed_0/init/checkpoint_0
INFO:absl:Restoring item from /home/ekansh/repos/share/vision-models/experiments/VGG16x2/ILSVRC_13/cassava/sgd/trained_classifier/steps_8000_500/cosine_1e-2/seed_0/init/checkpoint_0.


INFO:absl:Finished restoring checkpoint from /home/ekansh/repos/share/vision-models/experiments/VGG16x2/ILSVRC_13/cassava/sgd/trained_classifier/steps_8000_500/cosine_1e-2/seed_0/init/checkpoint_0.
INFO:absl:Restoring orbax checkpoint from /home/ekansh/repos/share/vision-models/experiments/VGG16x2/ILSVRC_15/cassava/sgd/trained_classifier/steps_8000_500/cosine_1e-2/seed_0/init/checkpoint_0
INFO:absl:Restoring item from /home/ekansh/repos/share/vision-models/experiments/VGG16x2/ILSVRC_15/cassava/sgd/trained_classifier/steps_8000_500/cosine_1e-2/seed_0/init/checkpoint_0.
INFO:absl:Finished restoring checkpoint from /home/ekansh/repos/share/vision-models/experiments/VGG16x2/ILSVRC_15/cassava/sgd/trained_classifier/steps_8000_500/cosine_1e-2/seed_0/init/checkpoint_0.


{'Dense_0': {'kernel': array([[-0.00410706,  0.01246298, -0.01316439, -0.02068997,  0.00854768],
         [ 0.00552348, -0.0216361 ,  0.00534177,  0.02177045, -0.01058934],
         [ 0.00906   ,  0.00191369, -0.00198545, -0.00249587, -0.00434745],
         ...,
         [ 0.00075656, -0.00995404,  0.00848958, -0.01992465,  0.00662864],
         [-0.00346388,  0.00515312, -0.0046915 ,  0.00792189, -0.01034135],
         [ 0.00386668, -0.01339846,  0.01770931, -0.00042364, -0.00461111]],
        dtype=float32)}}

In [8]:
ps = model_ls[0].permutation_spec()
rng = random.PRNGKey(0)

In [9]:
init_params_a['params']['classifier'] = zeros_classifier
init_params_b['params']['classifier'] = zeros_classifier

perm = weight_matching(rng, ps, init_params_a['params'], init_params_b['params'])
init_params_b['params'] = unfreeze(apply_permutation(ps, perm, init_params_b['params']))

INFO:root:0/P/encoder/Conv_1: 34.825828552246094
INFO:root:0/P/encoder/Conv_6: 77.56377410888672
INFO:root:0/P/encoder/Conv_3: 52.38648223876953
INFO:root:0/P/encoder/Conv_12: 232.63577270507812
INFO:root:0/P/encoder/Conv_0: 168.3642578125
INFO:root:0/P/encoder/Conv_2: 39.931060791015625
INFO:root:0/P/encoder/Conv_4: 56.01622009277344
INFO:root:0/P/encoder/Conv_5: 64.74514770507812
INFO:root:0/P/encoder/Dense_1: 291.359375
INFO:root:0/P/encoder/Conv_8: 106.21080017089844
INFO:root:0/P/encoder/Dense_0: 261.15869140625
INFO:root:0/P/encoder/Conv_10: 137.3310546875
INFO:root:0/P/encoder/Conv_9: 93.14303588867188
INFO:root:0/P/encoder/Conv_11: 94.973876953125
INFO:root:0/P/encoder/Conv_7: 71.75433349609375
INFO:root:0/P/visual_projection: 0.0
INFO:root:1/P/encoder/Conv_8: 22.456878662109375
INFO:root:1/P/encoder/Conv_11: 0.0
INFO:root:1/P/encoder/Conv_9: 3.85589599609375
INFO:root:1/P/encoder/Conv_7: 8.4820556640625
INFO:root:1/P/encoder/Conv_1: 114.88726806640625
INFO:root:1/P/encoder/Den

In [10]:
# permute the models from tasks b
for ix in range(int(len(datasets)/2), len(datasets)):
  expert_params_ls[ix]['params'] = unfreeze(apply_permutation(ps, perm, expert_params_ls[ix]['params']))
  

In [11]:
init_params = utils.lerp(0.5, init_params_a, init_params_b)  


In [12]:
init_params['params'].pop('classifier')
classifiers_ls = [expert_params['params'].pop('classifier') for expert_params in expert_params_ls]


In [13]:
## Expert accuracies:
expert_loss = {}
expert_accs = {}

for ix, dataset in enumerate(datasets):
  expert_params_ls[ix]['params']['classifier'] = classifiers_ls[ix]
  loss, accuracy = compute_loss_and_accuracy(model_ls[ix].apply, expert_params_ls[ix], ds_test_ls[ix])
  expert_loss[dataset] = loss
  expert_accs[dataset] = accuracy
  expert_params_ls[ix]['params'].pop('classifier')
  print(f"{dataset} expert accuracy = {accuracy}")
  print(f"{dataset} expert loss = {loss}")


2024-10-11 14:29:57.081525: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


colorectal_histology expert accuracy = 0.96875
colorectal_histology expert loss = 0.1508624404668808


2024-10-11 14:30:04.268895: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


cifar10 expert accuracy = 0.9690504670143127
cifar10 expert loss = 0.12621986865997314


2024-10-11 14:30:06.598099: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


eurosat expert accuracy = 0.984747052192688
eurosat expert loss = 0.07891904562711716


2024-10-11 14:30:09.367479: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


oxford_iiit_pet expert accuracy = 0.8653125166893005
oxford_iiit_pet expert loss = 0.8615166544914246


2024-10-11 14:30:11.097361: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


resisc45 expert accuracy = 0.931640625
resisc45 expert loss = 0.305586576461792


2024-10-11 14:30:16.824880: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


cifar100 expert accuracy = 0.8235176205635071
cifar100 expert loss = 0.8736114501953125


2024-10-11 14:30:18.426340: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


cassava expert accuracy = 0.8504464030265808
cassava expert loss = 0.9811003804206848


2024-10-11 14:30:32.258150: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


svhn_cropped expert accuracy = 0.9509313702583313
svhn_cropped expert loss = 0.17739953100681305


In [14]:
tall_mask = False
task_vectors = [utils.tree_subtract(params['params'], init_params['params']) for params  in expert_params_ls]  

## Task-arithmetic model-merging: 
# print("TA")
# lam = 0.4
# t_MTL = merging.compute_task_arithmetic_vector(task_vectors, lam)
# merged_params  = {'params': utils.tree_add(t_MTL, init_params['params'])}

## TIES model-merging: 
# print("TIES")
# lam=0.8
# t_MTL = merging.compute_ties_vector(task_vectors, lam)
# merged_params  = {'params': utils.tree_add(t_MTL, init_params['params'])}

## TA+TALL model-merging: 
print("TA+TALL")
lam=0.5
t_MTL = merging.compute_ties_vector(task_vectors, lam)
t_masks_ls = [merging.compute_tall_mask(task_vector, t_MTL) for task_vector in task_vectors]
tall_mask = True

## TIES+TALL model-merging: 
# print("TIES")
# lam=0.8
# t_MTL = merging.compute_ties_vector(task_vectors, lam)
# t_masks_ls = [merging.compute_tall_mask(task_vector, t_MTL) for task_vector in task_vectors]
# tall_mask = True


TA+TALL


In [15]:
# Evaluation

merged_accuracies = {}
merged_normalized_accuracies = {}
for ix, dataset in enumerate(datasets):
  if tall_mask:
    merged_params = {'params' : utils.tree_add(init_params['params'], apply_mask(t_MTL, t_masks_ls[ix]))}
  merged_params['params']['classifier'] = classifiers_ls[ix]
  loss, accuracy = compute_loss_and_accuracy(model_ls[ix].apply, merged_params, ds_test_ls[ix])
  merged_accuracies[dataset] = accuracy
  merged_normalized_accuracies[dataset] = norm_acc(accuracy, expert_accs[dataset])
  merged_params['params'].pop('classifier')
  print(f'{dataset}')
  print(f"merged accuracy = {accuracy}")
  print(f"normalized accuracy = {norm_acc(accuracy, expert_accs[dataset])}")


2024-10-11 14:31:05.045035: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


colorectal_histology
merged accuracy = 0.5390625
normalized accuracy = 0.5564516186714172


2024-10-11 14:31:09.691860: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


cifar10
merged accuracy = 0.7926682829856873
normalized accuracy = 0.8179845213890076


2024-10-11 14:31:11.121097: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


eurosat
merged accuracy = 0.5279017686843872
normalized accuracy = 0.5360785722732544


2024-10-11 14:31:12.860343: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


oxford_iiit_pet
merged accuracy = 0.37031251192092896
normalized accuracy = 0.4279523491859436


2024-10-11 14:31:14.508441: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


resisc45
merged accuracy = 0.3167317807674408
normalized accuracy = 0.33997204899787903


2024-10-11 14:31:19.664327: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


cifar100
merged accuracy = 0.4151642620563507
normalized accuracy = 0.504135251045227


2024-10-11 14:31:20.665976: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


cassava
merged accuracy = 0.5245535969734192
normalized accuracy = 0.616797924041748


2024-10-11 14:31:34.061338: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


svhn_cropped
merged accuracy = 0.7521551847457886
normalized accuracy = 0.7909668684005737


In [16]:
## Compute output statistics for expert parameters
expert_batch_stats_ls = []
rng = random.PRNGKey(0)
for ix, dataset in enumerate(datasets):
  expert_params_ls[ix]['params']['classifier'] = classifiers_ls[ix]
  params_tracker = models.load_from_source(model_tracker_ls[ix].initialization(rng, (1, crop, crop, 3)), {'params': expert_params_ls[ix]['params']})
  params_tracker = update_batch_stats(model_tracker_ls[ix].apply, params_tracker, ds_train_ls[ix], nbatches=500)
  expert_batch_stats_ls += [params_tracker['batch_stats']]
  expert_params_ls[ix]['params'].pop('classifier')


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
2024-10-11 14:33:38.172032: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 246706 bytes after encountering the first element of size 246706 bytes.This already causes the autotune ram budget to be exceeded. To stay within the ram budget, either increase the ram budget or reduce element size
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
2024-10-11 14:33:40.155014: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 246706 bytes after encountering the first element of size 246706 bytes.This already causes the autotune ram 

In [17]:
tact_accuracies = {}
tact_normalized_accuracies = {}

for ix, dataset in enumerate(datasets):
  print(f"Computing TACT accuracy for {dataset}")
  if tall_mask:
    merged_params = {'params' : utils.tree_add(init_params['params'], apply_mask(t_MTL, t_masks_ls[ix]))}
  merged_params_tact = models.load_from_source(model_repaired_ls[ix].initialization(rng, (1, crop, crop, 3)), merged_params)
  merged_params_tact = models.set_batch_norm_params_from_batch_stats(merged_params_tact, expert_batch_stats_ls[ix])
  merged_params_tact['params']['classifier'] = classifiers_ls[ix]
  merged_params_tact = update_batch_stats(model_repaired_ls[ix].apply, merged_params_tact,  ds_train_ls[ix], nbatches=500)
  loss, accuracy = compute_loss_and_accuracy(model_repaired_ls[ix].apply, merged_params_tact, ds_test_ls[ix])
  tact_accuracies[dataset] = accuracy
  tact_normalized_accuracies[dataset] = norm_acc(accuracy, expert_accs[dataset])
  print(f"accuracy = {tact_accuracies[dataset]}")
  print(f"normalized accuracy = {tact_normalized_accuracies[dataset]}")

  

Computing TACT accuracy for colorectal_histology


2024-10-11 14:37:39.344610: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


accuracy = 0.8697916865348816
normalized accuracy = 0.8978495001792908
Computing TACT accuracy for cifar10


2024-10-11 14:38:24.851613: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


accuracy = 0.9259815812110901
normalized accuracy = 0.9555555582046509
Computing TACT accuracy for eurosat


2024-10-11 14:39:06.483472: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


accuracy = 0.9226190447807312
normalized accuracy = 0.9369096755981445
Computing TACT accuracy for oxford_iiit_pet


Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature

accuracy = 0.7215625047683716
normalized accuracy = 0.8338750600814819
Computing TACT accuracy for resisc45


2024-10-11 14:40:31.816796: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


accuracy = 0.744140625
normalized accuracy = 0.7987421154975891
Computing TACT accuracy for cifar100


2024-10-11 14:41:18.211518: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


accuracy = 0.6824919581413269
normalized accuracy = 0.8287521004676819
Computing TACT accuracy for cassava


2024-10-11 14:42:00.114473: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


accuracy = 0.7287946343421936
normalized accuracy = 0.8569554090499878
Computing TACT accuracy for svhn_cropped


2024-10-11 14:42:54.336880: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


accuracy = 0.8906635046005249
normalized accuracy = 0.9366222620010376


In [18]:
## Compute REPAIR batch_stats:
ds_merged = input_pipeline.get_unlabelled_datasets(config, datasets, 'train[:1000]', 128, repeats=None)

repair_expert_batch_stats_ls = []
for ix, dataset in enumerate(datasets):
  expert_params_ls[ix]['params']['classifier'] = classifiers_ls[ix]
  params_tracker = models.load_from_source(model_tracker_ls[ix].initialization(rng, (1, crop, crop, 3)), {'params': expert_params_ls[ix]['params']})
  params_tracker = update_batch_stats(model_tracker_ls[ix].apply, params_tracker, ds_merged, nbatches=500)
  repair_expert_batch_stats_ls += [params_tracker['batch_stats']]


merged_batch_stats = utils.merge_batch_stats(repair_expert_batch_stats_ls)

INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/colorectal_histology/2.0.0
INFO:absl:Reusing dataset colorectal_histology (/mnt/MintStorage/data/tensorflow_datasets/colorectal_histology/2.0.0)
INFO:absl:Constructing tf.data.Dataset colorectal_histology for split train[:1000], from /mnt/MintStorage/data/tensorflow_datasets/colorectal_histology/2.0.0
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/colorectal_histology/2.0.0
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/cifar10/3.0.2
INFO:absl:For 'cifar10/3.0.2': fields info.[splits, supervised_keys] differ on disk and in the code. Keeping the one from code.
INFO:absl:Reusing dataset cifar10 (/mnt/MintStorage/data/tensorflow_datasets/cifar10/3.0.2)


INFO:absl:Constructing tf.data.Dataset cifar10 for split train[:1000], from /mnt/MintStorage/data/tensorflow_datasets/cifar10/3.0.2
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/cifar10/3.0.2
INFO:absl:For 'cifar10/3.0.2': fields info.[splits, supervised_keys] differ on disk and in the code. Keeping the one from code.
INFO:absl:No config specified, defaulting to config: eurosat/rgb
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/eurosat/rgb/2.0.0
INFO:absl:Reusing dataset eurosat (/mnt/MintStorage/data/tensorflow_datasets/eurosat/rgb/2.0.0)
INFO:absl:Constructing tf.data.Dataset eurosat for split train[:1000], from /mnt/MintStorage/data/tensorflow_datasets/eurosat/rgb/2.0.0
INFO:absl:No config specified, defaulting to config: eurosat/rgb
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/eurosat/rgb/2.0.0
INFO:absl:Load dataset info from /mnt/MintStorage/data/tensorflow_datasets/oxford_iiit_pet/3.2.0
INFO:a

In [19]:
repair_accuracies = {}
repair_normalized_accuracies = {}

if not tall_mask:
  # Update the batch stats of the shared merged params
  merged_params_repair = models.load_from_source(model_repaired_ls[ix].initialization(rng, (1, crop, crop, 3)), merged_params)
  merged_params_repair = models.set_batch_norm_params_from_batch_stats(merged_params_repair, merged_batch_stats)
  merged_params_repair = update_batch_stats(model_repaired_ls[ix].apply, merged_params_repair,  ds_merged, nbatches=500)

for ix, dataset in enumerate(datasets):
  print(f"Computing REPAIR accuracy for {dataset}")
  if tall_mask:
    merged_params = {'params' : utils.tree_add(init_params['params'], apply_mask(t_MTL, t_masks_ls[ix]))}
    merged_params_repair = models.load_from_source(model_repaired_ls[ix].initialization(rng, (1, crop, crop, 3)), merged_params)
    merged_params_repair = models.set_batch_norm_params_from_batch_stats(merged_params_repair, merged_batch_stats)
    merged_params_repair = update_batch_stats(model_repaired_ls[ix].apply, merged_params_repair,  ds_merged, nbatches=500)

  merged_params_repair['params']['classifier'] = classifiers_ls[ix]
  loss, accuracy = compute_loss_and_accuracy(model_repaired_ls[ix].apply, merged_params_repair, ds_test_ls[ix])
  repair_accuracies[dataset] = accuracy
  repair_normalized_accuracies[dataset] = norm_acc(accuracy, expert_accs[dataset])
  print(f"accuracy = {repair_accuracies[dataset]}")
  print(f"normalized accuracy = {repair_normalized_accuracies[dataset]}")


Computing REPAIR accuracy for colorectal_histology


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
2024-10-11 14:48:27.019972: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 999050 bytes after encountering the first element of size 999050 bytes.This already causes the autotune ram budget to be exceeded. To stay within the ram budget, either increase the ram budget or reduce element size
2024-10-11 14:48:28.231778: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 67582 bytes after encountering the first element of size 67582 bytes.This already causes the autotune ram budget to be exceeded. To stay within the ram budget, either increase the ram budget or reduce element size
2024-10-11 14:48:29.985328: W t

accuracy = 0.140625
normalized accuracy = 0.14516128599643707
Computing REPAIR accuracy for cifar10


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
2024-10-11 14:49:20.188804: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 666060 bytes after encountering the first element of size 666060 bytes.This already causes the autotune ram budget to be exceeded. To stay within the ram budget, either increase the ram budget or reduce element size
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
2024-10-11 14:49:20.792821: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 196665 bytes after encountering the first element of size 196665 byt

accuracy = 0.14443108439445496
normalized accuracy = 0.14904391765594482
Computing REPAIR accuracy for eurosat


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
2024-10-11 14:50:06.684314: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 196665 bytes after encountering the first element of size 196665 bytes.This already causes the autotune ram budget to be exceeded. To stay within the ram budget, either increase the ram budget or reduce element size
2024-10-11 14:50:07.828330: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 999050 bytes after encountering the first element of size 999050 byt

accuracy = 0.1134672611951828
normalized accuracy = 0.11522477865219116
Computing REPAIR accuracy for oxford_iiit_pet


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
2024-10-11 14:50:47.303692: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 12334 bytes after encountering the first element of size 12334 bytes.This already causes the autotune ram budget to be exceeded. To stay within the ram budget, either increase the ram budget or reduce element size
2024-10-11 14:50:47.916139: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 666060 bytes after encountering the first element of size 666060 bytes.This already causes the autotune ram budget to be exceeded

accuracy = 0.07750000059604645
normalized accuracy = 0.08956301957368851
Computing REPAIR accuracy for resisc45


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
2024-10-11 14:51:30.029431: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 666060 bytes after encountering the first element of size 666060 bytes.This already causes the autotune ram budget to be exceeded. To stay within the ram budget, either increase the ram budget or reduce element size
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
2024-10-11 14:51:30.634192: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 196665 bytes after encountering the first element of size 196665 byt

accuracy = 0.02701822854578495
normalized accuracy = 0.029000697657465935
Computing REPAIR accuracy for cifar100


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
2024-10-11 14:52:07.338762: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 666060 bytes after encountering the first element of size 666060 bytes.This already causes the autotune ram budget to be exceeded. To stay within the ram budget, either increase the ram budget or reduce element size
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
2024-10-11 14:52:07.969744: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 196665 bytes after encountering the first element of size 196665 bytes.This already causes the autotune ram budget to be exceed

accuracy = 0.027944711968302727
normalized accuracy = 0.03393334895372391
Computing REPAIR accuracy for cassava


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
2024-10-11 14:52:48.852738: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 196665 bytes after encountering the first element of size 196665 bytes.This already causes the autotune ram budget to be exceeded. To stay within the ram budget, either increase the ram budget or reduce element size
2024-10-11 14:52:49.998918: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 999050 bytes after encountering the first element of size 999050 bytes.This already causes the autotune ram budget to be exceeded. To stay within the ram budget, either increase the ram 

accuracy = 0.2522321343421936
normalized accuracy = 0.29658791422843933
Computing REPAIR accuracy for svhn_cropped


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
2024-10-11 14:53:39.125223: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 666060 bytes after encountering the first element of size 666060 bytes.This already causes the autotune ram budget to be exceeded. To stay within the ram budget, either increase the ram budget or reduce element size
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
2024-10-11 14:53:39.730326: W tensorflow/core/kernels/data/prefetch_autotuner.cc:52] Prefetch autotuner tried to allocate 196665 bytes after encountering the first element of size 196665 byt

accuracy = 0.17268319427967072
normalized accuracy = 0.18159374594688416


In [20]:
for ix, dataset in enumerate(datasets):
  print(f"Dataset: {dataset}")
  print(f"\tExpert accuracy = {expert_accs[dataset]}; Merged model accuracy = {merged_accuracies[dataset]}; Merged + REPAIR  accuracy = {repair_accuracies[dataset]}; Merged + TACT accuracy = {tact_accuracies[dataset]}")
  print(f"")


Dataset: colorectal_histology
	Expert accuracy = 0.96875; Merged model accuracy = 0.5390625; Merged + REPAIR  accuracy = 0.140625; Merged + TACT accuracy = 0.8697916865348816

Dataset: cifar10
	Expert accuracy = 0.9690504670143127; Merged model accuracy = 0.7926682829856873; Merged + REPAIR  accuracy = 0.14443108439445496; Merged + TACT accuracy = 0.9259815812110901

Dataset: eurosat
	Expert accuracy = 0.984747052192688; Merged model accuracy = 0.5279017686843872; Merged + REPAIR  accuracy = 0.1134672611951828; Merged + TACT accuracy = 0.9226190447807312

Dataset: oxford_iiit_pet
	Expert accuracy = 0.8653125166893005; Merged model accuracy = 0.37031251192092896; Merged + REPAIR  accuracy = 0.07750000059604645; Merged + TACT accuracy = 0.7215625047683716

Dataset: resisc45
	Expert accuracy = 0.931640625; Merged model accuracy = 0.3167317807674408; Merged + REPAIR  accuracy = 0.02701822854578495; Merged + TACT accuracy = 0.744140625

Dataset: cifar100
	Expert accuracy = 0.823517620563507