In [5]:
import numpy as np
np.bool = np.bool_
np.complex = np.complex128

import torch
from aug.automold import add_rain, add_snow, add_fog, add_autumn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as tt
from diffusers.utils import make_image_grid
from models import ResNet
from PIL import Image
from tqdm.notebook import tqdm
import cv2
import random
import os
import torchvision
from torch.optim.lr_scheduler import ReduceLROnPlateau, OneCycleLR, CosineAnnealingLR, StepLR

In [6]:
def load_stl10(root_dir: str = "stl10_binary"):

    path_to_images = os.path.join(root_dir, "train_X.bin")
    path_to_labels = os.path.join(root_dir, "train_y.bin")
    
    with open(path_to_images, 'rb') as f:
        images = np.fromfile(f, dtype=np.uint8)
        images = np.reshape(images, (-1, 3, 96, 96))
        train_images = np.transpose(images, (0, 3, 2, 1))
    
    with open(path_to_labels, 'rb') as f:
        train_labels = np.fromfile(f, dtype=np.uint8) - 1
    
    path_to_images = os.path.join(root_dir, "test_X.bin")
    path_to_labels = os.path.join(root_dir, "test_y.bin")
    
    with open(path_to_images, 'rb') as f:
        images = np.fromfile(f, dtype=np.uint8)
        images = np.reshape(images, (-1, 3, 96, 96))
        test_images = np.transpose(images, (0, 3, 2, 1))
    
    with open(path_to_labels, 'rb') as f:
        labels = np.fromfile(f, dtype=np.uint8)
        test_labels = labels - 1

    return train_images, train_labels, test_images, test_labels

def add_snow(image, snowflake_count=200, snowflake_radius=(1, 3), snowflake_intensity=(200, 255)):
    snowy_image = image.copy()
    height, width, _ = snowy_image.shape
    for _ in range(snowflake_count):
        x = np.random.randint(0, width)
        y = np.random.randint(0, height)
        
        radius = np.random.randint(snowflake_radius[0], snowflake_radius[1])
        intensity = np.random.randint(snowflake_intensity[0], snowflake_intensity[1])
        
        cv2.circle(snowy_image, (x, y), radius, (intensity, intensity, intensity), -1)

    return snowy_image

def shift(image, domain):
    if domain == "rain":
        return add_rain(image, rain_type = 'torrential')
    elif domain == "fog":
        return add_fog(image, fog_coeff=1.0)
    elif domain == "snow":
        return add_snow(image=image)
    elif domain == "autumn":
        return add_autumn(image)
    return image

class STL10Dataset(Dataset):
    def __init__(self, images, labels, domain="base"):
        self.images = images.astype(np.float32)
        self.labels = labels.astype(np.int64)
        self.domain = domain
        self.domains = ["rain", "fog", "snow"]
        stats = ((113.911194, 112.1515, 103.69485), (51.854874, 51.261967, 51.842403))
        self.tfms = tt.Compose([
            tt.ToTensor(),
            tt.Normalize(stats[0], stats[1])
        ])

    def __len__(self):
        if self.domain == "all":
            return len(self.domains) * len(self.images)
        return len(self.images)

    def __getitem__(self, idx):

        if self.domain == "all":
            d = idx // len(self.images)
            idx = idx % len(self.images)
            domain = self.domains[d]
        else:
            domain = self.domain
        
        image = self.images[idx]
        label = self.labels[idx]

        image = shift(image, domain)
        
        image[np.isnan(image)] = 0

        image = self.tfms(image)

        return image, label

def evaluate(model, batch_size, num_workers, domains, device, dtype):
    
    train_images, train_labels, test_images, test_labels = load_stl10()
    criterion = torch.nn.CrossEntropyLoss()

    model.eval()
    model.to(device).to(dtype)
    result = {}

    for domain in domains:
        
        test_dataset = STL10Dataset(test_images, test_labels, domain=domain)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        total_loss = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for images, labels in test_loader:
                inputs, labels = images.to(device).to(dtype), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                total_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        
        result[domain] = {
            "loss": total_loss / total,
            "accuracy": correct / total
        }

    return result

In [7]:
import torchvision
import torch

device = "cuda:7"
dtype = torch.float32

def load_model(model_name):
    model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    model.fc = torch.nn.Linear(model.fc.in_features, 10)
    model.load_state_dict(torch.load(f"./ckpts/{model_name}.pth"))
    model.to(device).to(dtype)
    return model

def get_empty_state():
    result = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    result.fc = torch.nn.Linear(result.fc.in_features, 10)
    result.to(device).to(dtype)
    result_state = result.state_dict()
    for k, v in result_state.items():
        result_state[k] = torch.zeros_like(v)
    return result_state

def get_empty_model():
    result = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
    result.fc = torch.nn.Linear(result.fc.in_features, 10)
    result.to(device).to(dtype)
    return result

base = load_model("resnet_50")
mtl = load_model("mtl")

for domain in ["rain", "fog", "snow"]:

    mask = torch.load(f"./ckpts/mask_{domain}.pth")
    state = get_empty_state()
    for key in state:
        state[key] = base.state_dict()[key] + mask[key].to(device).to(dtype) * mtl.state_dict()[key]
    model = get_empty_model()
    model.load_state_dict(state)
    
    result = evaluate(
        base,
        batch_size = 256,
        num_workers = 8,
        domains = [domain],
        device = device,
        dtype = dtype
    )
    print(f"Accuracy: {result[domain]['accuracy'] * 100:.4f}% on domain {domain}")

Accuracy: 68.3375% on domain rain
Accuracy: 64.8125% on domain fog
Accuracy: 56.9500% on domain snow


In [12]:
from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method
from cleverhans.torch.attacks.projected_gradient_descent import (
    projected_gradient_descent,
)

num_epochs = 30
lr = 5e-4
batch_size = 256
num_workers = 8
train_domain = "rain"
eval_domains = ["base", "rain", "fog", "snow"]
eps = 0.3

device = "cuda:7"
dtype = torch.float32

train_images, train_labels, test_images, test_labels = load_stl10()
train_dataset = STL10Dataset(train_images, train_labels, domain=train_domain)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
model.fc = torch.nn.Linear(model.fc.in_features, 10)
model.load_state_dict(torch.load("./ckpts/resnet50_rain.pth"))

# model = ResNet.load_model(model_name="resnet50", n_classes=10)
# # model = ResNet.load_model(model_name="./models/resnet152_base.pth", n_classes=10)
# model.train()
model.to(device).to(dtype)

criterion = torch.nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr = lr)
optimizer = torch.optim.SGD([
            {'params': model.conv1.parameters(), 'lr': 0.001},
            {'params': model.bn1.parameters(), 'lr': 0.002},
            {'params': model.relu.parameters(), 'lr': 0.003},
            {'params': model.maxpool.parameters(), 'lr': 0.004},
            {'params': model.layer1.parameters(), 'lr': 0.005},
            {'params': model.layer2.parameters(), 'lr': 0.006},
            {'params': model.layer3.parameters(), 'lr': 0.007},
            {'params': model.layer4.parameters(), 'lr': 0.008},
            {'params': model.avgpool.parameters(), 'lr': 0.009},
            {'params': model.fc.parameters(), 'lr': 0.001}
        ], lr=0.001, momentum=0.9)
# scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader) * num_epochs, eta_min=1e-5)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
for epoch in range(num_epochs):

    model.train()

    pbar = tqdm(train_loader)
    for images, labels in pbar:
        
        inputs, labels = images.to(device).to(dtype), labels.to(device)

        x_fgm = fast_gradient_method(model, inputs, eps, np.inf)
        # x_pgd = projected_gradient_descent(model, inputs, eps, 0.01, 40, np.inf)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        outputs_fgm = model(x_fgm)
        loss = criterion(outputs, labels) + criterion(outputs_fgm, labels)
        loss.backward()
        optimizer.step()
    
        pbar.set_description(f"Loss: {loss.item()}, lr: {scheduler.get_last_lr()[0]:.6f}")

    scheduler.step()
        
result = evaluate(
    model,
    batch_size = batch_size,
    num_workers = num_workers,
    domains = eval_domains,
    device = device,
    dtype = dtype
)
print([result[domain]["accuracy"] for domain in eval_domains])

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

[0.710875, 0.316125, 0.32575, 0.29075]
