In [None]:
import numpy as np
import torch
from aug.automold import add_rain, add_snow, add_fog, add_autumn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as tt
from models import ResNet
from tqdm.notebook import tqdm
import os

In [None]:
def _read_stl10_split(root_dir: str, split: str):
    """Read the ``<split>_X.bin`` / ``<split>_y.bin`` pair found in ``root_dir``.

    Returns ``(images, labels)``: images as a uint8 array of shape
    ``(N, 96, 96, 3)`` and labels as a uint8 array with 1 subtracted from
    every stored value.
    """
    pixel_path = os.path.join(root_dir, f"{split}_X.bin")
    label_path = os.path.join(root_dir, f"{split}_y.bin")

    flat_pixels = np.fromfile(pixel_path, dtype=np.uint8)
    # View the flat buffer as (N, 3, 96, 96), then reverse the last three
    # axes so the channel axis ends up last.
    images = flat_pixels.reshape(-1, 3, 96, 96).transpose(0, 3, 2, 1)

    # Presumably the files store 1-based class ids and this makes them
    # 0-based -- TODO confirm against the dataset description.
    labels = np.fromfile(label_path, dtype=np.uint8) - 1
    return images, labels


def load_stl10(root_dir: str = "stl10_binary"):
    """Load both STL-10 binary splits from ``root_dir``.

    Returns:
        ``(train_images, train_labels, test_images, test_labels)``.

    NOTE(review): the "train" outputs are read from ``test_X.bin`` /
    ``test_y.bin`` and the "test" outputs from ``train_X.bin`` /
    ``train_y.bin``. That mapping is kept exactly as it was; confirm the
    swap is intentional.
    """
    train_images, train_labels = _read_stl10_split(root_dir, "test")
    test_images, test_labels = _read_stl10_split(root_dir, "train")

    return train_images, train_labels, test_images, test_labels

def shift(image, domain):
    """Apply the augmentation selected by ``domain`` to ``image``.

    Handled domains are "rain", "fog", "snow" and "autumn". Any other value
    (for example "base") returns ``image`` itself, unchanged.
    """
    # Ordered (name, effect) pairs; the lambdas defer the augmentation call
    # until a matching domain is found.
    effects = (
        ("rain", lambda img: add_rain(img, rain_type = 'torrential')),
        ("fog", lambda img: add_fog(img, fog_coeff=1.0)),
        ("snow", lambda img: add_snow(img, snow_coeff=0.05)),
        ("autumn", lambda img: add_autumn(img)),
    )
    for name, effect in effects:
        if domain == name:
            return effect(image)
    return image

class STL10Dataset(Dataset):
    """In-memory image dataset that applies a per-sample domain shift.

    Every item is fetched from the stored arrays, passed through ``shift``
    for the configured domain, cleared of NaNs, and then run through
    ``ToTensor`` followed by ``Normalize``.
    """

    def __init__(self, images, labels, domain="base"):
        """Store float32 images, int64 labels and build the transform pipeline.

        Args:
            images: array of images; cast to ``np.float32``.
            labels: array of labels; cast to ``np.int64``.
            domain: name forwarded to ``shift`` for every sample.
        """
        self.domain = domain
        self.images = images.astype(np.float32)
        self.labels = labels.astype(np.int64)

        # Per-channel mean and std handed to Normalize.
        # NOTE(review): these look like 0-255 pixel statistics. torchvision's
        # ToTensor only rescales uint8 input, so they presumably match the
        # float32 images produced above -- confirm, and confirm that the
        # arrays returned by ``shift`` keep the same dtype and range.
        channel_mean = (113.911194, 112.1515, 103.69485)
        channel_std = (51.854874, 51.261967, 51.842403)
        self.tfms = tt.Compose([
            tt.ToTensor(),
            tt.Normalize(channel_mean, channel_std),
        ])

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for position ``idx``."""
        sample, target = self.images[idx], self.labels[idx]

        shifted = shift(sample, self.domain)
        # NaNs are zeroed in place. When ``shift`` hands back its input
        # unchanged (as it does for "base"), this writes into the slice taken
        # from ``self.images``.
        shifted[np.isnan(shifted)] = 0

        return self.tfms(shifted), target

def evaluate(model, batch_size, num_workers, domains, device, dtype):
    """Evaluate ``model`` on the STL-10 test arrays under each requested domain.

    Args:
        model: network to evaluate. It is switched to eval mode and
            ``.to(device).to(dtype)`` is called on it with the return value
            discarded, so this relies on ``.to`` modifying the model in place
            (as ``torch.nn.Module.to`` does).
        batch_size: batch size for the ``DataLoader``.
        num_workers: number of ``DataLoader`` worker processes.
        domains: iterable of domain names passed to ``STL10Dataset``.
        device: device the model and the batches are moved to.
        dtype: dtype the model and the input images are cast to.

    Returns:
        dict mapping each domain to ``{"loss": ..., "accuracy": ...}``, where
        ``loss`` is the mean cross-entropy per sample and ``accuracy`` is the
        fraction of correct predictions. Both are ``nan`` if a domain yields
        no samples.
    """
    # Only the "test" arrays returned by load_stl10 are evaluated here.
    _, _, test_images, test_labels = load_stl10()
    criterion = torch.nn.CrossEntropyLoss()

    model.eval()
    model.to(device).to(dtype)
    result = {}

    for domain in domains:

        test_dataset = STL10Dataset(test_images, test_labels, domain=domain)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        total_loss = 0.0
        total = 0
        correct = 0
        with torch.no_grad():
            for images, labels in test_loader:
                inputs, labels = images.to(device).to(dtype), labels.to(device)
                outputs = model(inputs)
                batch_len = labels.size(0)
                # CrossEntropyLoss averages over the batch by default, so
                # weight it by the batch size; otherwise dividing by the
                # sample count below under-reports the loss by roughly a
                # factor of batch_size.
                total_loss += criterion(outputs, labels).item() * batch_len
                _, predicted = outputs.max(1)
                total += batch_len
                correct += predicted.eq(labels).sum().item()

        result[domain] = {
            "loss": total_loss / total if total else float("nan"),
            "accuracy": correct / total if total else float("nan"),
        }

    return result

In [None]:
# Baseline: evaluate the base checkpoint on every domain.
model = ResNet.load_model("./models/resnet152_base.pth")
domains = ["base", "rain", "fog", "autumn"]
# NOTE(review): the device index is hard-coded; adjust it for the machine at hand.
result = evaluate(
    model,
    batch_size = 256,
    num_workers = 4,
    domains = domains,
    device = "cuda:7",
    dtype = torch.float16
)
# Accuracies are fractions of correct predictions, listed in the order of `domains`.
print([result[domain]["accuracy"] for domain in domains])

### Avg MTL

$$
\theta_{MTL} = \sum_{i=1}^{N}(\theta_{i}-\theta_{0})
$$
$$
\theta_{AVG\,MTL} = \theta_{0} + \alpha\,\theta_{MTL}
$$

In [None]:
# Load the base checkpoint and the three domain-specific checkpoints.
base = ResNet.load_model("./models/resnet152_base.pth")
rain = ResNet.load_model("./models/resnet152_rain.pth")
fog = ResNet.load_model("./models/resnet152_fog.pth")
autumn = ResNet.load_model("./models/resnet152_autumn.pth")
domains = ["base", "rain", "fog", "autumn"]

# One task vector per domain: the domain checkpoint minus the base checkpoint.
# NOTE(review): presumably `sub` / `add` operate weight-wise and return a new
# object rather than mutating the receiver -- confirm in models.ResNet.
tv1 = rain.sub(base)
tv2 = fog.sub(base)
tv3 = autumn.sub(base)
# Sum of all task vectors (theta_MTL in the formula above); also consumed by
# the TALL-masks cell further down.
mtm = tv1.add(tv2).add(tv3)

In [None]:
# Sweep the scaling coefficient alpha over several orders of magnitude.
for alpha in [0, 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 1e1, 1e2, 1e3, 1e4, 1e5]:
    # `mtm` is the summed task vector built in the previous cell; the name
    # `sumv` used here before is not defined anywhere in the notebook.
    # NOTE(review): this assumes `scale` returns a new object and leaves `mtm`
    # untouched; otherwise the scaling would compound across iterations.
    avg = mtm.scale(alpha).add(base)
    # avg_model = base + alpha * sum(domain_adapted_model - base)
    print(f"evaluating at alpha: {alpha}")
    result = evaluate(
        avg,
        batch_size = 256,
        num_workers = 4,
        domains = domains,
        device = "cuda:7",
        dtype = torch.float16
    )
    # Accuracies are fractions, listed in the order of `domains`.
    print([result[domain]["accuracy"] for domain in domains])

### SLERP

$$
\theta_{\text{SLERP}}(t) = \frac{\sin((1-t)\Omega)}{\sin(\Omega)} \theta_1 + \frac{\sin(t\Omega)}{\sin(\Omega)} \theta_2
$$

In [None]:
# Reload all checkpoints so this section starts from freshly loaded weights.
base = ResNet.load_model("./models/resnet152_base.pth")
rain = ResNet.load_model("./models/resnet152_rain.pth")
fog = ResNet.load_model("./models/resnet152_fog.pth")
autumn = ResNet.load_model("./models/resnet152_autumn.pth")
domains = ["base", "rain", "fog", "autumn"]

# Sweep the interpolation factor over 0.0, 0.1, ..., 0.9 (1.0 is not evaluated).
for i in range(10):
    alpha = i / 10

    # Interpolate between the rain and fog checkpoints.
    # NOTE(review): `slep` is presumably the SLERP implementation on
    # models.ResNet -- confirm its name and the meaning of `alpha` there.
    slep = rain.slep(fog, alpha=alpha)
    
    result = evaluate(
        slep,
        batch_size = 256,
        num_workers = 4,
        domains = domains,
        device = "cuda:7",
        dtype = torch.float16
    )
    # Accuracies are printed as percentages, in the order of `domains`.
    print(f"Alpha: {alpha}    " + "  ".join([str(result[domain]["accuracy"]*100) for domain in domains]))

### TALL-masks

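$$
\hat{\theta}_t = \theta_0 + m_t \circ \theta_{MTL} \approx \theta_t
$$
$$
m_t^* = \arg\min_{m_t \in \{0,1\}^P} \|\hat{\theta}_t - \theta_t\|_1
$$

$$
= \arg\min_{m_t \in \{0,1\}^P} \|m_t \circ \theta_{MTL} - \tau_t\|_1
$$

$$
= \mathbf{1}\left[\,|\tau_t| \geq |\theta_{MTL} - \tau_t|\,\right]
$$

$$
m_t = \mathbf{1}\left[\,|\tau_t| \geq |\theta_{MTL} - \tau_t| \cdot \lambda_t\,\right]
$$

where $\tau_t = \theta_t - \theta_0$ is the task vector of task $t$.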


In [None]:
# TALL-masks sweep: for every alpha, build one masked model per domain and
# report its accuracy on that same domain.
domains = ["rain", "fog", "autumn"]
for i in range(10):
    # alpha takes the values 2.0, 2.2, ..., 3.8.
    alpha = 2 + i / 5
    print(f"Alpha: {alpha}")
    for domain in domains:
        model = ResNet.load_model(f"./models/resnet152_{domain}.pth")
        # `mtm` (summed task vector) and `base` come from earlier cells.
        mask = model.create_mask(mtm, alpha=alpha)
        merged = base.apply_mask(mtm, mask)

        # Only this domain's accuracy is printed, so evaluate on it alone
        # rather than on every domain and discarding the other results.
        result = evaluate(
            merged,
            batch_size = 256,
            num_workers = 4,
            domains = [domain],
            device = "cuda:7",
            dtype = torch.float16
        )
        print(f"{domain}: ", result[domain]["accuracy"])

In [None]:
## Benchmarks (Accuracy %)
#                 <- models ->
# domain    base   rain    fog    autumn
# base      63.32  29.04   54.14  26.77
# rainy     26.02  60.26   32.36  26.38
# foggy     51.55  39.9    62.92  24.14
# autumn    16.25  11.98   17.26  39.16

# all       29.28  30.48   31.5   40.2

## Method 1: MTL (Task Vector avg)
# Editing Models with Task Arithmetic
# Ref: https://arxiv.org/pdf/2212.04089
#  alpha    base    rain    fog   autumn
#    0      63.38  26.68   51.42  16.52
#  1e-4     38.80  16.24   25.82  14.14
#  1e-3     19.34  13.04   16.56  11.59
#  1e-2     12.52  10.74   13.00  11.66
#  1e-1     10.0   10.0    10.0   11.92
#  1e0      10.0   10.0    10.0   10.0
#  1e1      10.0   10.0    10.0   10.0
#  1e2      10.0   10.0    10.0   10.0
#  1e3      10.0   10.0    10.0   10.0

## Method 2: Spherical Linear Interpolation (SLERP)
# Smoothly blend two models by interpolating their weights along the
# shortest path on a high dimensional sphere
# Refs:
# 1] https://arxiv.org/pdf/1609.04468
# 2] https://arxiv.org/pdf/2403.13257
# 3] https://www.engr.colostate.edu/ECE481A2/Readings/Rotation_Animation.pdf
# Alpha: 0.19    31.2  42.199999999999996  35.9  12.26

## Method 3: Tall-masks
# Alpha: 0.6
# rain:  34.78
# fog: 42.78
# autumn: 23.89