## Editing Models with Task Arithematic

In [1]:
import numpy as np
np.bool = np.bool_
np.complex = np.complex128
import torch
from aug.automold import add_rain, add_snow, add_fog, add_autumn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as tt
from models import ResNet
from torch.optim.lr_scheduler import ReduceLROnPlateau, OneCycleLR, CosineAnnealingLR, StepLR
from PIL import Image
from tqdm.notebook import tqdm
# import imgaug.augmenters as iaa
import cv2
import random
import os
import torchvision

In [5]:
def load_stl10(root_dir: str = "stl10_binary"):

    path_to_images = os.path.join(root_dir, "train_X.bin")
    path_to_labels = os.path.join(root_dir, "train_y.bin")
    
    with open(path_to_images, 'rb') as f:
        images = np.fromfile(f, dtype=np.uint8)
        images = np.reshape(images, (-1, 3, 96, 96))
        train_images = np.transpose(images, (0, 3, 2, 1))
    
    with open(path_to_labels, 'rb') as f:
        train_labels = np.fromfile(f, dtype=np.uint8) - 1
    
    path_to_images = os.path.join(root_dir, "test_X.bin")
    path_to_labels = os.path.join(root_dir, "test_y.bin")
    
    with open(path_to_images, 'rb') as f:
        images = np.fromfile(f, dtype=np.uint8)
        images = np.reshape(images, (-1, 3, 96, 96))
        test_images = np.transpose(images, (0, 3, 2, 1))
    
    with open(path_to_labels, 'rb') as f:
        labels = np.fromfile(f, dtype=np.uint8)
        test_labels = labels - 1

    return train_images, train_labels, test_images, test_labels

def add_snow(image, snowflake_count=200, snowflake_radius=(1, 3), snowflake_intensity=(200, 255)):
    snowy_image = image.copy()
    height, width, _ = snowy_image.shape
    for _ in range(snowflake_count):
        x = np.random.randint(0, width)
        y = np.random.randint(0, height)
        
        radius = np.random.randint(snowflake_radius[0], snowflake_radius[1])
        intensity = np.random.randint(snowflake_intensity[0], snowflake_intensity[1])
        
        cv2.circle(snowy_image, (x, y), radius, (intensity, intensity, intensity), -1)

    return snowy_image

def shift(image, domain):
    if domain == "rain":
        return add_rain(image, rain_type = 'torrential')
    elif domain == "fog":
        return add_fog(image, fog_coeff=1.0)
    elif domain == "snow":
        return add_snow(image=image)
    elif domain == "autumn":
        return add_autumn(image)
    return image

class STL10Dataset(Dataset):
    def __init__(self, images, labels, domain="base"):
        self.images = images.astype(np.float32)
        self.labels = labels.astype(np.int64)
        self.domain = domain
        self.domains = ["rain", "fog", "snow"]
        stats = ((113.911194, 112.1515, 103.69485), (51.854874, 51.261967, 51.842403))
        self.tfms = tt.Compose([
            tt.ToTensor(),
            tt.Normalize(stats[0], stats[1])
        ])

    def __len__(self):
        if self.domain == "all":
            return len(self.domains) * len(self.images)
        return len(self.images)

    def __getitem__(self, idx):

        if self.domain == "all":
            d = idx // len(self.images)
            idx = idx % len(self.images)
            domain = self.domains[d]
        else:
            domain = self.domain
        
        image = self.images[idx]
        label = self.labels[idx]

        image = shift(image, domain)
        
        image[np.isnan(image)] = 0

        image = self.tfms(image)

        return image, label

def evaluate(model, batch_size, num_workers, domains, device, dtype):
    
    train_images, train_labels, test_images, test_labels = load_stl10()
    criterion = torch.nn.CrossEntropyLoss()

    model.eval()
    model.to(device).to(dtype)
    result = {}

    for domain in domains:
        
        test_dataset = STL10Dataset(test_images, test_labels, domain=domain)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        total_loss = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for images, labels in test_loader:
                inputs, labels = images.to(device).to(dtype), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                total_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        
        result[domain] = {
            "loss": total_loss / total,
            "accuracy": correct / total
        }

    return result

In [11]:
def load_model(domain):
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    model.fc = torch.nn.Linear(model.fc.in_features, 10)
    model.load_state_dict(torch.load(f"./ckpts/resnet50_{domain}.pth"))
    return model

base = load_model("base")
rain = load_model("rain")
fog = load_model("fog")
snow = load_model("snow")

In [38]:
def add(model1, model2, device="cuda:7", dtype=torch.float32):
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    model.fc = torch.nn.Linear(model.fc.in_features, 10)
    model.to(device).to(dtype)
    m_sd = model.state_dict()
    m1_sd = model1.state_dict()
    m2_sd = model2.state_dict()
    
    for key in m_sd:
        m_sd[key] = m1_sd[key] + m2_sd[key]
    model.load_state_dict(m_sd)
    return model

def sub(model1, model2, device="cuda:7", dtype=torch.float32):
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    model.fc = torch.nn.Linear(model.fc.in_features, 10)
    model.to(device).to(dtype)
    m_sd = model.state_dict()
    m1_sd = model1.state_dict()
    m2_sd = model2.state_dict()
    
    for key in m_sd:
        m_sd[key] = m1_sd[key] - m2_sd[key]
    model.load_state_dict(m_sd)
    return model

def scale(model1, alpha=1.0, device="cuda:7", dtype=torch.float32):
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    model.fc = torch.nn.Linear(model.fc.in_features, 10)
    model.to(device).to(dtype)
    m_sd = model.state_dict()
    m1_sd = model1.state_dict()
    
    for key in m_sd:
        m_sd[key] = m1_sd[key] * alpha
    model.load_state_dict(m_sd)
    return model

def slerp(model1, model2, t=0.5):
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    model.fc = torch.nn.Linear(model.fc.in_features, 10)
    m_sd = model.state_dict()
    m1_sd = model1.state_dict()
    m2_sd = model2.state_dict()
    
    for key in m_sd:
        w1 = m1_sd[key]
        w2 = m2_sd[key]
        norm_w1 = w1 / torch.norm(w1.to(torch.float32))
        norm_w2 = w2 / torch.norm(w2.to(torch.float32))
        dot_product = torch.clamp(torch.sum(norm_w1 * norm_w2), -1.0, 1.0)
        theta = torch.acos(dot_product)

        if theta.item() == 0.0:
            m_sd[key] = w1
        else:
            m_sd[key] = (
                torch.sin((1 - t) * theta) / torch.sin(theta) * w1 +
                torch.sin(t * theta) / torch.sin(theta) * w2
            )
    model.load_state_dict(m_sd)
    return model

In [39]:
tv1 = sub(rain, base)
tv2 = sub(fog, base)
tv3 = sub(snow, base)
avg = add(add(tv1, tv2), tv3)

for alpha in [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2, 1e3]:
    model = add(base, scale(avg, alpha))
    
    result = evaluate(
        model,
        batch_size = 256,
        num_workers = 8,
        domains = ["base", "rain", "fog", "snow"],
        device = "cuda:7",
        dtype = torch.float16
    )
    print(f"Alpha: {alpha} ", [result[domain]["accuracy"] for domain in ["base", "rain", "fog", "snow"]])

Alpha: 0  [0.924375, 0.34, 0.33625, 0.37]
Alpha: 1e-06  [0.924375, 0.33775, 0.340375, 0.3675]
Alpha: 1e-05  [0.924375, 0.33375, 0.335125, 0.371625]
Alpha: 0.0001  [0.924375, 0.3425, 0.33525, 0.368]
Alpha: 0.001  [0.923875, 0.336125, 0.332, 0.372375]
Alpha: 0.01  [0.9255, 0.345375, 0.341375, 0.382875]
Alpha: 0.1  [0.933875, 0.41625, 0.417375, 0.504]
Alpha: 1  [0.1, 0.1, 0.1, 0.1]
Alpha: 10.0  [0.1, 0.1, 0.1, 0.1]
Alpha: 100.0  [0.1, 0.1, 0.1, 0.1]
Alpha: 1000.0  [0.1, 0.1, 0.1, 0.1]


## SLERP

In [38]:
for t in range(11):
    print(f"t = {t/10}")
    model1 = slerp(rain, fog, t/10)
    model2 = slerp(fog, snow, t/10)
    model3 = slerp(snow, rain, t/10)

    result = evaluate(
        model1,
        batch_size = 256,
        num_workers = 8,
        domains = ["base", "rain", "fog", "snow"],
        device = "cuda:7",
        dtype = torch.float16
    )
    print(f"SLERP rain->fog ", [result[domain]["accuracy"] for domain in ["base", "rain", "fog", "snow"]])

    result = evaluate(
        model2,
        batch_size = 256,
        num_workers = 8,
        domains = ["base", "rain", "fog", "snow"],
        device = "cuda:7",
        dtype = torch.float16
    )
    print(f"SLERP fog->snow ", [result[domain]["accuracy"] for domain in ["base", "rain", "fog", "snow"]])

    result = evaluate(
        model3,
        batch_size = 256,
        num_workers = 8,
        domains = ["base", "rain", "fog", "snow"],
        device = "cuda:7",
        dtype = torch.float16
    )
    print(f"SLERP snow->rain ", [result[domain]["accuracy"] for domain in ["base", "rain", "fog", "snow"]])

t = 0.0
SLERP rain->fog  [0.652875, 0.829625, 0.6965, 0.112]
t = 0.1
SLERP rain->fog  [0.588625, 0.53025, 0.5085, 0.1145]
t = 0.2
SLERP rain->fog  [0.509125, 0.36125, 0.3845, 0.10925]
t = 0.3
SLERP rain->fog  [0.455125, 0.273875, 0.3045, 0.110125]
t = 0.4
SLERP rain->fog  [0.4145, 0.23825, 0.27325, 0.11225]
t = 0.5
SLERP rain->fog  [0.38675, 0.209875, 0.253375, 0.110375]
t = 0.6
SLERP rain->fog  [0.361625, 0.1955, 0.24625, 0.110125]
t = 0.7
SLERP rain->fog  [0.347375, 0.1975, 0.244625, 0.10925]
t = 0.8
SLERP rain->fog  [0.34475, 0.236375, 0.27325, 0.10975]
t = 0.9
SLERP rain->fog  [0.359625, 0.420875, 0.440125, 0.109125]
t = 1.0
SLERP rain->fog  [0.336, 0.588125, 0.834, 0.10925]


In [36]:
def get_layers(model):
    return [
        model.conv1,
        model.bn1,
        model.relu,
        model.maxpool,
        model.layer1,
        model.layer2,
        model.layer3,
        model.layer4,
        model.avgpool,
        model.fc
    ]

# test_dataset = STL10Dataset(test_images, test_labels, domain=domain)
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

## TIES Merging

In [7]:
device = "cuda:7"
dtype = torch.float32

def get_empty_state(device, dtype):
    result = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    result.fc = torch.nn.Linear(result.fc.in_features, 10)
    result.to(device).to(dtype)
    result_state = result.state_dict()
    for k, v in result_state.items():
        result_state[k] = torch.zeros_like(v)
    return result_state

def load_model(domain):
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    model.fc = torch.nn.Linear(model.fc.in_features, 10)
    model.load_state_dict(torch.load(f"./ckpts/resnet50_{domain}.pth"))
    return model.to(device).to(dtype)

def get_mask(tensor1, tensor2):
    mask = ((tensor1 > 0) & (tensor2 > 0)) | ((tensor1 < 0) & (tensor2 < 0)) | ((tensor1 == 0) & (tensor2 == 0))
    return mask

base = load_model("base")
rain = load_model("rain")
fog = load_model("fog")
snow = load_model("snow")

km = 0.2
alpha = 0.1
init = base
ftms = [rain, fog, snow]

# Step 1: create task vectors and trim redundant parameters
tvs = []
for m in ftms:
    
    result = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    result.fc = torch.nn.Linear(result.fc.in_features, 10)
    result.to(device).to(dtype)
    result_state = result.state_dict()
    
    tv = sub(m, init, device, dtype)
    tv_state = tv.state_dict()
    vals, indices = torch.sort(torch.abs(torch.concat([v.flatten() for k, v in tv_state.items()], dim=0)), descending=True)
    k_min = vals[int(vals.shape[0] * km)]
    for k, v in tv_state.items():
        result_state[k][torch.abs(v) < k_min] = 0

    result.load_state_dict(result_state) 
    tvs.append(result)

# Step 2: Elect Final Signs
gamma = add(add(tvs[0], tvs[1], device, dtype), tvs[2], device, dtype)
gamma_state = gamma.state_dict()

# Step 3: Disjoint Merge
result_state = get_empty_state(device, dtype)
Ap = get_empty_state(device, dtype)

for tv in tvs:
    tv_state = tv.state_dict()
    for k in result_state:
        mask = get_mask(tv_state[k], gamma_state[k])
        result_state[k] += tv_state[k] * mask
        Ap[k] += mask

for k, v in result_state.items():
    result_state[k] = v / Ap[k]

resultm = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
resultm.fc = torch.nn.Linear(result.fc.in_features, 10)
resultm.to(device).to(dtype)
resultm.load_state_dict(result_state)

for alpha in [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e2, 1e2, 1e3, 1e4]:

    model = add(base, scale(resultm, alpha, device, dtype), device, dtype)
    
    result = evaluate(
        model,
        batch_size = 256,
        num_workers = 8,
        domains = ["base", "rain", "fog", "snow"],
        device = "cuda:7",
        dtype = torch.float16
    )
    print(f"Alpha: {alpha}", [result[domain]["accuracy"] for domain in ["base", "rain", "fog", "snow"]])

Alpha: 0 [0.924375, 0.339, 0.3325, 0.3725]
Alpha: 1e-06 [0.924375, 0.338625, 0.334, 0.373]
Alpha: 1e-05 [0.924375, 0.332125, 0.337125, 0.37125]
Alpha: 0.0001 [0.924375, 0.338625, 0.331875, 0.369125]
Alpha: 0.001 [0.924625, 0.341, 0.333625, 0.374]
Alpha: 0.01 [0.9265, 0.34275, 0.345625, 0.382]
Alpha: 0.1 [0.8855, 0.29375, 0.32475, 0.317125]
Alpha: 1.0 [0.1, 0.1, 0.1, 0.1]
Alpha: 100.0 [0.1, 0.1, 0.1, 0.1]
Alpha: 100.0 [0.1, 0.1, 0.1, 0.1]
Alpha: 1000.0 [0.1, 0.1, 0.1, 0.1]
Alpha: 10000.0 [0.1, 0.1, 0.1, 0.1]


In [None]:
## Benchmarks
## domain           base    rain      fog     snow
## resnet50_base   92.175   33.47   33.48    36.96
## resnet50_rain   46.775  82.8375  64.025    13.5
## resnet50_fog    33.63   56.987   84.0125   10.7
## resnet50_snow   47.89   26.92    23.612    85.03

## Layer Merging

In [61]:
def load_model(domain):
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    model.fc = torch.nn.Linear(model.fc.in_features, 10)
    if domain == "base":
        model.load_state_dict(torch.load(f"./ckpts/resnet50_{domain}.pth"))
    else:
        model.load_state_dict(torch.load(f"./ckpts/resnet50_{domain}_l.pth"))
    return model.to(device).to(dtype)

def get_empty_state():
    result = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    result.fc = torch.nn.Linear(result.fc.in_features, 10)
    result.to(device).to(dtype)
    result_state = result.state_dict()
    for k, v in result_state.items():
        result_state[k] = torch.zeros_like(v)
    return result_state

def get_empty_model():
    result = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    result.fc = torch.nn.Linear(result.fc.in_features, 10)
    result.to(device).to(dtype)
    return result

base = load_model("base")
rain = load_model("rain")
fog = load_model("fog")
snow = load_model("snow")

state = base.state_dict()
for key in base_state:
    if "layer2" in key:
        state[key] = rain.state_dict()[key]
    if "layer3" in key:
        state[key] = fog.state_dict()[key]
    if "layer4" in key:
        state[key] = snow.state_dict()[key]
    # if "layer1" in key:
    #     state[key] = snow.state_dict()[key]

model = get_empty_model()
model.load_state_dict(state)
result = evaluate(
    model,
    batch_size = 256,
    num_workers = 8,
    domains = ["base", "rain", "fog", "snow"],
    device = "cuda:7",
    dtype = torch.float16
)
print([result[domain]["accuracy"] for domain in ["base", "rain", "fog", "snow"]])

[0.907375, 0.68825, 0.6665, 0.56975]


## TALL Masks

In [34]:
device = "cuda:7"
dtype = torch.float32

def get_empty_state():
    result = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    result.fc = torch.nn.Linear(result.fc.in_features, 10)
    result.to(device).to(dtype)
    result_state = result.state_dict()
    for k, v in result_state.items():
        result_state[k] = torch.zeros_like(v)
    return result_state

def get_empty_model():
    result = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    result.fc = torch.nn.Linear(result.fc.in_features, 10)
    result.to(device).to(dtype)
    return result

def load_model(domain):
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    model.fc = torch.nn.Linear(model.fc.in_features, 10)
    model.load_state_dict(torch.load(f"./ckpts/resnet50_{domain}.pth"))
    return model.to(device).to(dtype)

def add(model1, model2):
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    model.fc = torch.nn.Linear(model.fc.in_features, 10)
    model.to(device).to(dtype)
    m_sd = model.state_dict()
    m1_sd = model1.state_dict()
    m2_sd = model2.state_dict()
    
    for key in m_sd:
        m_sd[key] = m1_sd[key] + m2_sd[key]
    model.load_state_dict(m_sd)
    return model

def sub(model1, model2):
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    model.fc = torch.nn.Linear(model.fc.in_features, 10)
    model.to(device).to(dtype)
    m_sd = model.state_dict()
    m1_sd = model1.state_dict()
    m2_sd = model2.state_dict()
    for key in m_sd:
        m_sd[key] = m1_sd[key] - m2_sd[key]
    model.load_state_dict(m_sd)
    return model

def evaluate(model, batch_size, num_workers, domains, device, dtype):
    
    train_images, train_labels, test_images, test_labels = load_stl10()
    criterion = torch.nn.CrossEntropyLoss()

    model.eval()
    model.to(device).to(dtype)
    result = {}

    for domain in domains:
        
        test_dataset = STL10Dataset(test_images, test_labels, domain=domain)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        total_loss = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for images, labels in test_loader:
                inputs, labels = images.to(device).to(dtype), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                total_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        
        result[domain] = {
            "loss": total_loss / total,
            "accuracy": correct / total
        }

    return result

base = load_model("base")
rain = load_model("rain")
fog = load_model("fog")
snow = load_model("snow")

## Step 1: Generate the mtl as mentioned in https://arxiv.org/pdf/2212.04089
mtl = add(add(sub(rain, base), sub(fog, base)), sub(snow, base))
mtl_state = mtl.state_dict()

## Step 2: Generate the mask for mtl
## m∗t = 1{|τt| ≥ |τMTL − τt|}
## Derivation of mask in Appendix B of https://arxiv.org/pdf/2405.07813
mask = get_empty_state()
init = base.state_dict()
domains = ["base", "rain", "fog", "snow"]
for alpha in [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]:
    accs = []
    for domain in domains:
        ftm = load_model(domain).state_dict()
        for key in mask:
            diff = torch.abs(mtl_state[key] - (ftm[key] - init[key]))
            mask[key] = (torch.abs(ftm[key] - init[key]) >= (diff * alpha)).float()
        
        ## Step 3: Merge the model
        model = get_empty_model()
        model_state = model.state_dict()
        for key in model_state:
            model_state[key] = init[key] + mask[key] * mtl_state[key]
        model.load_state_dict(model_state)
        
        ## Step 4: Standard evalatuion protocol
        result = evaluate(
            model,
            batch_size = 256,
            num_workers = 8,
            domains = ["rain"],
            device = device,
            dtype = dtype
        )
        accs.append(result[domain]["accuracy"])
    print(f"Alpha: {alpha} ", accs)


Alpha: 0.0 [92.4375, 34.225, 33.225, 37.3375]
Alpha: 0.05 [95.0125, 55.800000000000004, 56.637499999999996, 55.225]
Alpha: 0.1 [90.8125, 68.5125, 64.2375, 57.2625]
Alpha: 0.15 [89.6625, 67.5625, 63.337500000000006, 56.8875]
Alpha: 0.2 [88.725, 66.625, 61.625, 55.800000000000004]
Alpha: 0.25 [94.2625, 42.125, 41.9, 45.787499999999994]
Alpha: 0.3 [93.5875, 37.3, 37.125, 42.375]
    


In [49]:
device = "cuda:7"
dtype = torch.float32

def load_model(domain):
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    model.fc = torch.nn.Linear(model.fc.in_features, 10)
    model.load_state_dict(torch.load(f"./ckpts/resnet50_{domain}.pth"))
    return model.to(device).to(dtype)

def get_empty_state():
    result = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    result.fc = torch.nn.Linear(result.fc.in_features, 10)
    result.to(device).to(dtype)
    result_state = result.state_dict()
    for k, v in result_state.items():
        result_state[k] = torch.zeros_like(v)
    return result_state

def get_empty_model():
    result = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    result.fc = torch.nn.Linear(result.fc.in_features, 10)
    result.to(device).to(dtype)
    return result

base = load_model("base")
rain = load_model("rain")
fog = load_model("fog")
snow = load_model("snow")

alpha = 1
state = base.state_dict()
for key in state:
    if "layer2" in key:
        state[key] = state[key] + (rain.state_dict()[key] - state[key]) * alpha
    if "layer3" in key:
        state[key] = state[key] + (fog.state_dict()[key] - state[key]) * alpha
    if "layer4" in key:
        state[key] = state[key] + (snow.state_dict()[key] - state[key]) * alpha
    # if "layer1" in key:
    #     state[key] = snow.state_dict()[key]

model = get_empty_model()
model.load_state_dict(state)

# for alpha in [0, 0.5, 1, 1.05, 1.1, 0.2, 0.1]:
#     state = base.state_dict()
#     for key in state:
#         if "layer2" in key:
#             state[key] = state[key] + (rain.state_dict()[key] - state[key]) * alpha
#         if "layer3" in key:
#             state[key] = state[key] + (fog.state_dict()[key] - state[key]) * alpha
#         if "layer4" in key:
#             state[key] = state[key] + (snow.state_dict()[key] - state[key]) * alpha
#         # if "layer1" in key:
#         #     state[key] = snow.state_dict()[key]
    
#     model = get_empty_model()
#     model.load_state_dict(state)
#     result = evaluate(
#         model,
#         batch_size = 256,
#         num_workers = 8,
#         domains = ["base", "rain", "fog", "snow"],
#         device = "cuda:7",
#         dtype = torch.float16
#     )
#     print(f"Alpha: {alpha / 10 }", [result[domain]["accuracy"] * 100 for domain in ["base", "rain", "fog", "snow"]])

<All keys matched successfully>

In [52]:
# def get_empty_model():
#     result = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
#     result.fc = torch.nn.Linear(result.fc.in_features, 10)
#     result.to(device).to(dtype)
#     return result

# model = get_empty_model()
# model.load_state_dict(torch.load("./ckpts/resnet_50.pth"))

result = evaluate(
    model,
    batch_size = 256,
    num_workers = 8,
    domains = ["base", "rain", "fog", "snow"],
    device = "cuda:7",
    dtype = torch.float16
)
result

{'base': {'loss': 0.001204498291015625, 'accuracy': 0.908125},
 'rain': {'loss': 0.0037325439453125, 'accuracy': 0.68475},
 'fog': {'loss': 0.00418292236328125, 'accuracy': 0.649875},
 'snow': {'loss': 0.005035400390625, 'accuracy': 0.56475}}

In [53]:
torch.save(model.state_dict(), "./ckpts/resnet_50.pth")

In [None]:
# Alpha: 0.0 [92.4375, 34.225, 33.225, 37.3375]
# Alpha: 0.05 [95.0125, 55.800000000000004, 56.637499999999996, 55.225]
# Alpha: 0.1 (Best) [90.8125, 68.5125, 64.2375, 57.2625]
# Alpha: 0.15 [89.6625, 67.5625, 63.337500000000006, 56.8875]
# Alpha: 0.2 [88.725, 66.625, 61.625, 55.800000000000004]
# Alpha: 0.25 [94.2625, 42.125, 41.9, 45.787499999999994]
# Alpha: 0.3 [93.5875, 37.3, 37.125, 42.375]

## Neuron Activations

In [None]:
import torch
import torch.nn as nn

activations_base = {}
activations_rain = {}
activations_fog = {}
activations_snow = {}

def get_activation(name, activations_dict):
    def hook(model, input, output):
        activations_dict[name] = output.detach().cpu()
    return hook

def register_hooks(model, activations_dict):
    for name, layer in model.named_modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            layer.register_forward_hook(get_activation(name, activations_dict))

def load_model(domain):
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    model.fc = torch.nn.Linear(model.fc.in_features, 10)
    model.load_state_dict(torch.load(f"./ckpts/resnet50_{domain}.pth"))
    return model

base = load_model("base")
rain = load_model("rain")
fog = load_model("fog")
snow = load_model("snow")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

models = {'base': base, 'rain': rain, 'fog': fog, 'snow': snow}
for model in models.values():
    model.to(device)
    model.eval()

register_hooks(base, activations_base)
register_hooks(rain, activations_rain)
register_hooks(fog, activations_fog)
register_hooks(snow, activations_snow)

batch_size = 256
num_workers = 8
train_images, train_labels, test_images, test_labels = load_stl10()

def process_data(model, activations_dict, domain):
    train_dataset = STL10Dataset(train_images, train_labels, domain=domain)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    with torch.no_grad():
        for images, _ in tqdm(train_loader):
            images = images.to(device)
            model(images)

process_data(base, activations_base, "base")
process_data(rain, activations_rain, "rain")
process_data(fog, activations_fog, "fog")
process_data(snow, activations_snow, "snow")

In [None]:
def compute_importance(activations_dict):
    importance_scores = {}
    for name, activation in activations_dict.items():
        # Flatten the activations (batch_size, num_neurons)
        activation_flat = activation.view(activation.size(0), -1)
        # Compute mean absolute activation per neuron
        mean_activation = activation_flat.abs().mean(dim=0)
        importance_scores[name] = mean_activation.cpu()
    return importance_scores

# Compute importance scores for each model
importance_base = compute_importance(activations_base)
importance_rain = compute_importance(activations_rain)
importance_fog = compute_importance(activations_fog)
importance_snow = compute_importance(activations_snow)

In [None]:
def rank_neurons(importance_scores):
    ranked_neurons = {}
    for name, scores in importance_scores.items():
        # Get indices of neurons sorted by importance in descending order
        sorted_indices = torch.argsort(scores, descending=True)
        ranked_neurons[name] = sorted_indices
    return ranked_neurons

# Get ranked neurons for each model
ranked_base = rank_neurons(importance_base)
ranked_rain = rank_neurons(importance_rain)
ranked_fog = rank_neurons(importance_fog)
ranked_snow = rank_neurons(importance_snow)

In [None]:
def normalize_importance(importance_scores):
    normalized_scores = {}
    for name, scores in importance_scores.items():
        norm_scores = (scores - scores.min()) / (scores.max() - scores.min())
        normalized_scores[name] = norm_scores
    return normalized_scores

# Normalize importance scores
importance_rain_normalized = normalize_importance(importance_rain)
importance_fog_normalized = normalize_importance(importance_fog)
importance_snow_normalized = normalize_importance(importance_snow)

In [None]:
import pandas as pd

# Assuming importance scores are stored as dictionaries: {layer_name: tensor of importance scores}

def create_importance_dataframe(layers, importance_scores_dicts):
    data = {}
    for domain, importance_scores in importance_scores_dicts.items():
        for layer_name in layers:
            scores = importance_scores[layer_name].numpy()
            for idx, score in enumerate(scores):
                key = f"{layer_name}_{idx}"
                if key not in data:
                    data[key] = {}
                data[key][domain] = score
    df = pd.DataFrame.from_dict(data, orient='index')
    return df

# Collect importance scores from all domains
importance_scores_dicts = {
    'rain': importance_rain,
    'fog': importance_fog,
    'snow': importance_snow
}

# Get the list of layers
layers = importance_rain.keys()

# Create the DataFrame
importance_df = create_importance_dataframe(layers, importance_scores_dicts)

In [None]:
# Normalize importance scores per domain
def normalize_importance_scores(df):
    for domain in ['rain', 'fog', 'snow']:
        max_score = df[domain].max()
        min_score = df[domain].min()
        df[domain] = (df[domain] - min_score) / (max_score - min_score)
    return df

# Normalize the DataFrame
importance_df = normalize_importance_scores(importance_df)


In [None]:
# Set importance threshold (e.g., top 20% neurons are considered important)
importance_threshold = 0.8

def determine_important_neurons(df, threshold):
    important_neurons = {}
    for domain in ['rain', 'fog', 'snow']:
        domain_scores = df[domain].dropna()
        threshold_value = domain_scores.quantile(threshold)
        important = domain_scores[domain_scores >= threshold_value].index.tolist()
        important_neurons[domain] = important
    return important_neurons

# Get important neurons for each domain
important_neurons = determine_important_neurons(importance_df, importance_threshold)


In [None]:
# Identify shared and domain-specific neurons
shared_neurons = set(important_neurons['rain']) & set(important_neurons['fog']) & set(important_neurons['snow'])
domain_specific_neurons = {
    'rain': set(important_neurons['rain']) - shared_neurons,
    'fog': set(important_neurons['fog']) - shared_neurons,
    'snow': set(important_neurons['snow']) - shared_neurons
}


In [None]:
print(f"Number of shared neurons: {len(shared_neurons)}")
for domain in ['rain', 'fog', 'snow']:
    print(f"Number of domain-specific neurons in {domain}: {len(domain_specific_neurons[domain])}")


In [None]:
merged_state_dict = {}

for name, param in base.state_dict().items():
    # Get parameter tensors from each model
    param_rain = rain.state_dict()[name]
    param_fog = fog.state_dict()[name]
    param_snow = snow.state_dict()[name]
    
    # Check if the parameter is a weight tensor
    if 'weight' in name and len(param.shape) >= 1:
        # Process per neuron/filter
        merged_param = param.clone()
        num_neurons = param.shape[0]
        for i in range(num_neurons):
            # Get importance scores
            imp_rain = importance_rain.get(name, torch.zeros(num_neurons))[i]
            imp_fog = importance_fog.get(name, torch.zeros(num_neurons))[i]
            imp_snow = importance_snow.get(name, torch.zeros(num_neurons))[i]
            
            # Normalize importance scores (if not already normalized)
            total_imp = imp_rain + imp_fog + imp_snow + 1e-8  # Add epsilon to avoid division by zero
            imp_rain /= total_imp
            imp_fog /= total_imp
            imp_snow /= total_imp
            
            # Merge weights based on importance
            merged_param[i] = (imp_rain * param_rain[i] +
                               imp_fog * param_fog[i] +
                               imp_snow * param_snow[i])
        merged_state_dict[name] = merged_param
    else:
        # For biases and other parameters, average them
        merged_state_dict[name] = (param_rain + param_fog + param_snow) / 3


In [None]:
# Create a new ResNet50 model instance
unified_model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
unified_model.fc = torch.nn.Linear(model.fc.in_features, 10)

# Load the merged weights
unified_model.load_state_dict(merged_state_dict)

# Set to evaluation mode
unified_model.eval()
print("Done")

In [None]:
result = evaluate(
    unified_model,
    batch_size = 256,
    num_workers = 8,
    domains = ["base", "rain", "fog", "snow"],
    device = "cuda:7",
    dtype = torch.float16
)
print([result[domain]["accuracy"] for domain in ["base", "rain", "fog", "snow"]])

In [None]:
# Assuming we have a unified importance score mask
unified_mask = create_unified_mask(important_neurons)

def masked_forward_pass(model, x, mask):
    activations = {}
    def hook(name):
        def hook_fn(module, input, output):
            if name in mask:
                output = output * mask[name]
            return output
        return hook_fn
    
    # Register hooks
    handles = []
    for name, layer in model.named_modules():
        if name in mask:
            handle = layer.register_forward_hook(hook(name))
            handles.append(handle)
    
    # Forward pass
    output = model(x)
    
    # Remove hooks
    for handle in handles:
        handle.remove()
    
    return output

# Use the modified forward pass during inference
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = masked_forward_pass(base_model, images, unified_mask)
        # Compute accuracy or other metrics
