In [17]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import ImageFolder
from tqdm.notebook import tqdm
from types import SimpleNamespace
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import torchvision.datasets as datasets
import numpy as np
from torchvision.transforms import v2

from virtual_fusion_accuracy import VirtualFusionAccuracy
import torchvision.transforms as transforms
from torchvision.models import wide_resnet50_2, Wide_ResNet50_2_Weights
from torchmetrics.classification import Accuracy
from torchmetrics.functional.classification import multiclass_calibration_error

from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method
from cleverhans.torch.attacks.projected_gradient_descent import (
    projected_gradient_descent,
)
from easydict import EasyDict

cudnn.benchmark = True

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# Distortions
Define the augmentations and distortions that are applied during training.

In [19]:

class Distortion:
    def __init__(self, fn, lam, **kwargs):
        self.fn = fn
        self.lam = lam
        self.kwargs = kwargs


rng = np.random.RandomState(seed=42)

augMixAugmenter = v2.AugMix()

distortions = {
    'AugMix': [
        Distortion(
            lambda images:images,
            1
        ),
        Distortion(
            lambda images:augMixAugmenter(images),
            0.85 # delta
        )
    ],
}



In [20]:

data_root = "C:/datasets/cifar10-image-folder"
data_c_root = "C:/datasets/cifar-10-c"


normalize = transforms.Normalize(
    mean=[0.491, 0.482, 0.446],
    std=[0.247, 0.243, 0.261]
)


In [21]:
resize = 32

train_transform = transforms.Compose([
    transforms.Resize((resize,resize), antialias='True'),
    transforms.RandomCrop(resize, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])



## define a config object to store the hyperparameters

In [22]:
config = SimpleNamespace(
    distortion_p = 0.7,
    train_batch_size = 1024,
    train_num_epochs = 25,
    train_optimizer_lr = 0.1,
    train_lr_schedule_step_size = 25,
    train_lr_schedule_gamma = 0.1,
    evaluation_batch_size = 1024,
)



In [23]:
def LA_criterion(criterion, pred, y_a, y_b, lams):
    return torch.mean(torch.mul(criterion(pred, y_a),lams) + torch.mul(criterion(pred, y_b),(1 - lams)))


# Train

In [24]:
train_dataset = ImageFolder(
    root=os.path.join(data_root, 'train'),
    transform=train_transform,
)

In [25]:
train_dataset_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size= config.train_batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory = True,
    drop_last = True,
)


In [26]:
def train_model(model, criterion, optimizer, scheduler, num_epochs, config):


    best_acc = 0.0
    best_model_params_path = os.path.join('checkpoints', f'LableAug_{config.train_num_epochs}_model_params.pt')

    pbar = tqdm(total=num_epochs, unit='epoch')
    pbar_batch = tqdm(total=len(train_dataset_loader), unit='batch')

    for epoch in range(num_epochs):
        step = epoch + 1

        # initialize metrics
        acc_metric = VirtualFusionAccuracy()
        acc_metric.to(device)

        # Set model to training mode
        model.train()

        train_loss = 0.0
        total = 0.0

        pbar_batch.reset()
        # Iterate over data.
        inputs_trained = False
        for inputs, targets_a in train_dataset_loader:
            for distortion_index, distortion_name in enumerate(distortions.keys()):
                for distortion in distortions[distortion_name]:

                    images = inputs
                    p = torch.rand(1)
                    if p < config.distortion_p :
                        images = distortion.fn(images.to(device), **distortion.kwargs)
                        targets_b = torch.ones_like(targets_a) * (original_num_classes + distortion_index)
                        lams = torch.ones_like(targets_a) * distortion.lam
                    else:
                        if inputs_trained:
                          continue
                        inputs_trained = True
                        targets_b = targets_a
                        lams = torch.ones_like(targets_a)


                    images = images.to(device)
                    images = normalize(images)
                    targets_a = targets_a.to(device)
                    targets_b = targets_b.to(device)
                    lams = lams.to(device)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward
                    with torch.set_grad_enabled(True):
                        outputs = model(images)
                        loss = LA_criterion(criterion, outputs, targets_a, targets_b, lams)

                        # backward + optimize
                        loss.backward()
                        optimizer.step()

                    # statistics
                    train_loss += loss.item()
                    total += images.size(0)

                    acc = acc_metric(outputs,  targets_a, targets_b, lams)

            pbar_batch.update(1)
        scheduler.step()

        epoch_loss = train_loss / total
        epoch_acc  = acc_metric.compute()
        print(f' Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        if best_acc<epoch_acc:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': loss,
                'config': config
            }, best_model_params_path)

        pbar.update(1)

    pbar_batch.close()
    pbar.close()


    # load best model
    checkpoint = torch.load(best_model_params_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    return model

#### cosine annealing schedule

In [27]:
def get_lr(step, total_steps, lr_max, lr_min):
    """Compute learning rate according to cosine annealing schedule."""
    return lr_min + (lr_max - lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))

# Fine-tuning

In [None]:


if __name__ == '__main__':
    original_num_classes = len(train_dataset.classes)
    num_classes = original_num_classes + len(distortions.keys())

    net  = wide_resnet50_2(weights= Wide_ResNet50_2_Weights.DEFAULT)

    # for param in net.parameters():
    #     param.requires_grad = False

    net.fc = nn.Linear(net.fc.in_features, num_classes)


    net.to(device)

    criterion = nn.CrossEntropyLoss(reduction='none')

    optimizer = optim.SGD(net.parameters(), lr= config.train_optimizer_lr, momentum=0.9)

    exp_lr_scheduler = lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: get_lr(
            step,
            config.train_num_epochs,
            1,  # lr_lambda computes multiplicative factor
            1e-6 / config.train_optimizer_lr)
    )

    net = train_model(net, criterion, optimizer, exp_lr_scheduler, config.train_num_epochs, config)


# Evaluation

In [29]:
def show_performance(model, criterion, distortion_name):
    was_training = model.training
    model.eval()
    accs = []


    for severity in range(1, 6):
        if severity>1 and distortion_name=='original':
            continue
        distorted_dataset = datasets.ImageFolder(
            root=os.path.join(data_c_root, distortion_name, str(severity)),
            transform=transforms.Compose([
                transforms.Resize((resize,resize), antialias='True'),
                transforms.ToTensor(),
                normalize
            ])
        )

        distorted_dataset_loader = torch.utils.data.DataLoader(
            distorted_dataset,
            batch_size= config.evaluation_batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=True,
            drop_last = True,
        )

        eval_loss = 0.0
        total = 0.0

        # initialize metric
        acc_metric = Accuracy(task="multiclass", num_classes=original_num_classes)
        acc_metric.to(device)

        for batch_idx, (data, target) in enumerate(distorted_dataset_loader):
            data = data.to(device)
            target = target.to(device)

            with torch.no_grad():
                output = model(data)
                loss = criterion(output,target)

            # create a mask for indices less than the number of actual classes
            filter_mask = torch.arange(output.size(1)) < original_num_classes

            # apply filter
            filtered_output = output[:, filter_mask]

            _, pred = torch.max(filtered_output, 1)

            # statistics
            eval_loss += torch.mean(loss).item()
            total += data.size(0)
            acc = acc_metric(pred,  target)
            err = 1 - acc


        acc =  acc_metric.compute()
        err = 1 - acc
        accs.append(acc)


    model.train(mode=was_training)
    mean_acc = torch.mean(torch.tensor(accs).detach().cpu())
    mean_err = 1 - mean_acc

    return mean_err

# ImageNet-C evaluation


In [30]:
distortions_c = [
    'gaussian_noise',
    'shot_noise', 'impulse_noise',
    'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur',
    'snow', 'frost', 'fog', 'brightness',
    'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression',
    'original',
]

error_rates = []
for distortion_name in distortions_c:
    rate = show_performance(net, criterion, distortion_name)
    error_rates.append(rate)
    print('Distortion: {:15s}  | CE (unnormalized) (%): {:.2f}'.format(distortion_name, 100 * rate))


error_rates.pop()  # to remove original for distortions_c
mCE = 100 * np.mean(error_rates)
print('mCE (unnormalized by ResNet errors) (%): {:.2f}'.format(mCE))


Distortion: gaussian_noise   | CE (unnormalized) (%): 21.82
Distortion: shot_noise       | CE (unnormalized) (%): 19.24
Distortion: impulse_noise    | CE (unnormalized) (%): 29.33
Distortion: defocus_blur     | CE (unnormalized) (%): 12.24
Distortion: glass_blur       | CE (unnormalized) (%): 29.82
Distortion: motion_blur      | CE (unnormalized) (%): 17.21
Distortion: zoom_blur        | CE (unnormalized) (%): 15.45
Distortion: snow             | CE (unnormalized) (%): 17.72
Distortion: frost            | CE (unnormalized) (%): 16.85
Distortion: fog              | CE (unnormalized) (%): 17.25
Distortion: brightness       | CE (unnormalized) (%): 9.84
Distortion: contrast         | CE (unnormalized) (%): 21.67
Distortion: elastic_transform  | CE (unnormalized) (%): 14.60
Distortion: pixelate         | CE (unnormalized) (%): 17.08
Distortion: jpeg_compression  | CE (unnormalized) (%): 15.19
Distortion: original         | CE (unnormalized) (%): 8.57
mCE (unnormalized by ResNet errors) (%)

In [31]:

test_dataset = datasets.ImageFolder(
    root=os.path.join(data_root, 'test'),
    transform=transforms.Compose([
        transforms.Resize((resize,resize), antialias='True'),
        transforms.ToTensor(),
        normalize ]))

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=config.evaluation_batch_size, shuffle=False, num_workers=2, drop_last=True)

eps_values = [0.03, 0.3]
error_reports = {}

net.eval()

for eps in eps_values:
    report = EasyDict(nb_test=0, correct=0, correct_fgm=0, correct_pgd=0)

    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        x_fgm = fast_gradient_method(net, x, eps, np.inf)
        x_pgd = projected_gradient_descent(net, x, eps, 0.01, 40, np.inf)


        filter_mask = torch.arange(num_classes) < original_num_classes

        _, y_pred = net(x)[:, filter_mask].max(1)  # model prediction on clean examples
        _, y_pred_fgm = net(x_fgm)[:, filter_mask].max(1)  # model prediction on FGM adversarial examples
        _, y_pred_pgd = net(x_pgd)[:, filter_mask].max(1)

        report.nb_test += y.size(0)
        report.correct_fgm += y_pred_fgm.eq(y).sum().item()
        report.correct_pgd += y_pred_pgd.eq(y).sum().item()

    fgm_acc = report.correct_fgm / report.nb_test
    pgd_acc = report.correct_pgd / report.nb_test

    fgm_error = 1 - fgm_acc
    pgd_error = 1 - pgd_acc

    error_reports[eps] = {'FGM': fgm_error, 'PGD': pgd_error}

for eps, errors in error_reports.items():
    print(f"Epsilon: {eps}")
    print("Error on FGM adversarial examples (%): {:.3f}".format(errors['FGM'] * 100.0))
    print("Error on PGD adversarial examples (%): {:.3f}".format(errors['PGD'] * 100.0))


Epsilon: 0.03
Error on FGM adversarial examples (%): 27.138
Error on PGD adversarial examples (%): 56.999
Epsilon: 0.3
Error on FGM adversarial examples (%): 46.571
Error on PGD adversarial examples (%): 82.151


In [32]:

ps = [] ; ls=[]
accuracies= []
acc_metric = Accuracy(task="multiclass", num_classes=original_num_classes)
acc_metric.to(device)

with torch.no_grad():
    for (inputs, labels) in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)


        outputs = net(inputs)

        filter_mask = torch.arange(outputs.size(1)) < original_num_classes
        filtered_output = outputs[:, filter_mask]
        _, pred = torch.max(filtered_output, 1)
        acc = acc_metric(pred,  labels)

        ps.append(filtered_output)
        ls.append(labels)
    acc =  acc_metric.compute()
    err = 1 - acc
    accuracies.append(acc)

mean_acc = torch.mean(torch.tensor(accuracies).detach().cpu())
mean_err = 1 - mean_acc

ps = torch.cat(ps, dim=0)
ls = torch.cat(ls, dim=0)

ECE = multiclass_calibration_error(ps, ls, num_classes=original_num_classes, n_bins=100, norm='l1')*100
RMS = multiclass_calibration_error(ps, ls, num_classes=original_num_classes, n_bins=100, norm='l2')*100


print(f'the expected calibration error is: {ECE}')
print(f'the root mean squared calibration error is: {RMS}')


confidences=0.961664617061615   accuracies=0.9142795205116272
confidences=0.961664617061615   accuracies=0.9142795205116272
the expected calibration error is: 4.8696489334106445
the root mean squared calibration error is: 8.90273666381836
