In [1]:
import numpy as np
import torch
from aug.automold import add_rain, add_snow, add_fog, add_autumn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as tt
from models import ResNet
from torch.optim.lr_scheduler import ReduceLROnPlateau, OneCycleLR, CosineAnnealingLR
from PIL import Image
from tqdm.notebook import tqdm

In [2]:
path_to_images = "./stl10_binary/train_X.bin"
path_to_labels = "./stl10_binary/train_y.bin"

with open(path_to_images, 'rb') as f:
    images = np.fromfile(f, dtype=np.uint8)
    images = np.reshape(images, (-1, 3, 96, 96))
    train_images = np.transpose(images, (0, 3, 2, 1))

with open(path_to_labels, 'rb') as f:
    train_labels = np.fromfile(f, dtype=np.uint8) - 1

path_to_images = "./stl10_binary/test_X.bin"
path_to_labels = "./stl10_binary/test_y.bin"

with open(path_to_images, 'rb') as f:
    images = np.fromfile(f, dtype=np.uint8)
    images = np.reshape(images, (-1, 3, 96, 96))
    test_images = np.transpose(images, (0, 3, 2, 1))

with open(path_to_labels, 'rb') as f:
    labels = np.fromfile(f, dtype=np.uint8)
    test_labels = labels - 1

def shift(image, domain):
    if domain == "rain":
        return add_rain(image, rain_type = 'torrential')
    elif domain == "fog":
        return add_fog(image, fog_coeff=1.0)
    elif domain == "snow":
        return add_snow(image, snow_coeff=0.05)
    elif domain == "autumn":
        return add_autumn(image)
    return image

class STL10Dataset(Dataset):
    def __init__(self, images, labels, domain="base", transform=None):
        self.images = images.astype(np.float32)
        self.labels = labels.astype(np.int64)
        self.transform = transform
        self.domain = domain

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        image = shift(image, self.domain)
        image[np.isnan(image)] = 0

        if self.transform:
            image = self.transform(image)

        return image, label

stats = ((113.911194, 112.1515, 103.69485), (51.854874, 51.261967, 51.842403))
train_tfms = tt.Compose([
    tt.ToTensor(),
    tt.Normalize(stats[0], stats[1])
])

valid_tfms = tt.Compose([
    tt.ToTensor(),
    tt.Normalize(stats[0], stats[1])
])

In [None]:
num_epochs = 20
device = "cuda:0"
lr = 5e-4
batch_size = 256
dtype = torch.bfloat16
domain = "autumn"

model = ResNet.load_model(model_name="resnet152", n_classes=10)
ckpt = torch.load("./models/resnet152_base.pth")
model.load_state_dict(ckpt["state_dict"])
model.train()
model.to(device).to(dtype)

train_dataset = STL10Dataset(test_images, test_labels, domain=domain, transform=train_tfms)
test_dataset = STL10Dataset(train_images, train_labels, domain=domain, transform=valid_tfms)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = lr)
scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader) * num_epochs, eta_min=1e-5)
for epoch in range(num_epochs):

    pbar = tqdm(train_loader)
    for images, labels in pbar:
        
        inputs, labels = images.to(device).to(dtype), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    
        pbar.set_description(f"Loss: {loss.item()}, lr: {scheduler.get_last_lr()[0]:.6f}")

        scheduler.step()
    
    total = 0
    total_loss = 0
    correct = 0
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            inputs, labels = images.to(device).to(dtype), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    val_loss = total_loss / total
    accuracy = correct / total
    print("Accuracy: ", accuracy * 100, "Val Loss: ", val_loss)

In [3]:
from models import ResNet


num_epochs = 20
device = "cuda:7"
lr = 5e-4
batch_size = 256
dtype = torch.bfloat16
domain = "base"

model = ResNet.load_model("./models/resnet152_base.pth")
# model = ResNet.load_model(model_name="resnet152", n_classes=10)
# ckpt = torch.load("./models/resnet152_fog.pth")
# model.load_state_dict(ckpt["state_dict"])
model.eval()
model.to(device).to(dtype)
criterion = torch.nn.CrossEntropyLoss()

test_dataset = STL10Dataset(train_images, train_labels, domain=domain, transform=valid_tfms)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

total = 0
total_loss = 0
correct = 0
model.eval()
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        inputs, labels = images.to(device).to(dtype), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

val_loss = total_loss / total
accuracy = correct / total
print("Accuracy: ", accuracy * 100, "Val Loss: ", val_loss)

  0%|          | 0/20 [00:00<?, ?it/s]

Accuracy:  63.32 Val Loss:  0.0062859375


In [None]:
## Benchmarks (Accuracy %)
# resnet152  base   rain    fog    autumn
# base      63.32  29.04   54.14  26.77
# rainy     26.02  60.26   32.36  26.38
# foggy     51.55  39.9    62.92  24.14
# autumn    16.25  11.98   17.26  39.16