# Loss Landscapes on CIFAR

We visualize loss landscapes using [filter normalization](https://arxiv.org/abs/1712.09913) to show that ***self-attentions flatten loss landscapes*** as shown in Fig 1.a and Fig C.4.a.

In [None]:
import sys

# Detect whether the notebook runs inside Google Colab (the `google.colab`
# module is already loaded there). If so, install the pinned dependencies and
# clone the repository so that its modules can be imported.
root = "."  # repository root; stays "." unless Colab is detected below
if "google.colab" in sys.modules:
    print("Running in Colab.")
    # Lines starting with `!` are IPython shell escapes, not Python statements.
    !pip3 install matplotlib
    !pip3 install einops==0.4.1
    !pip3 install timm==0.5.4
    !pip3 install pandas==1.3.2
    !git clone https://github.com/xxxnell/how-do-vits-work.git
    root = "./how-do-vits-work"
    # Make the cloned repository importable.
    sys.path.append(root)

Running in Colab.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops==0.4.1
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm==0.5.4
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
[K     |████████████████████████████████| 431 kB 7.5 MB/s 
Installing collected packages: timm
Successfully installed timm-0.5.4
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pandas==1.3.2
  Downloading pandas-1.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.3 MB)
[K     |████████████████████████████████| 11.3 MB 7.7 MB/s 
Installing collected packages: pandas
  Attempting

In [None]:
import os
import yaml
import copy
from pathlib import Path

import torch
from torch.utils.data import DataLoader

import models
import ops.tests as tests
import ops.datasets as datasets
import ops.loss_landscapes as lls

In [None]:
# Select the experiment configuration; switch the active line to change it.
# config_path = "%s/configs/cifar10_vit.yaml" % root
config_path = "%s/configs/cifar100_vit.yaml" % root
# config_path = "%s/configs/imagenet_vit.yaml" % root

# Parse with `yaml.safe_load`: calling `yaml.load` without an explicit Loader
# is deprecated since PyYAML 5.1 and raises TypeError from PyYAML 6.0 onwards.
# `safe_load` only builds plain Python objects (dicts, lists, scalars).
with open(config_path) as f:
    args = yaml.safe_load(f)
    print(args)

{'dataset': {'name': 'cifar100', 'root': '../data', 'mean': [0.5071, 0.4867, 0.4408], 'std': [0.2675, 0.2565, 0.2761], 'padding': 4, 'color_jitter': 0.0, 'auto_augment': 'rand-m9-n2-mstd1.0', 're_prob': 0.0}, 'train': {'warmup_epochs': 5, 'epochs': 300, 'batch_size': 96, 'max_norm': 5, 'smoothing': 0.1, 'mixup': {'num_classes': 100, 'mixup_alpha': 1.0, 'cutmix_alpha': 0.8, 'prob': 1.0}}, 'val': {'batch_size': 256, 'n_ff': 1}, 'model': {'stem': False, 'block': {'image_size': 32, 'patch_size': 2, 'sd': 0.1}}, 'optim': {'name': 'AdamW', 'lr': 0.000125, 'weight_decay': 0.05, 'scheduler': {'name': 'CosineAnnealingLR', 'T_max': 300, 'eta_min': 0}}, 'env': {}}


In [None]:
# Split the configuration into one dictionary per section. Every section is
# taken from its own deep copy of `args`, so the resulting objects share
# nothing with `args` or with each other. A section that is missing from the
# configuration yields None.
(
    dataset_args,
    train_args,
    val_args,
    model_args,
    optim_args,
    env_args,
) = (
    copy.deepcopy(args).get(section)
    for section in ("dataset", "train", "val", "model", "optim", "env")
)

In [None]:
import torch.nn as nn
# Build the train/test datasets from the "dataset" section of the config;
# `download=True` is forwarded to `datasets.get_dataset`.
dataset_train, dataset_test = datasets.get_dataset(**dataset_args, download=True)
dataset_name = dataset_args["name"]
# Read the class count now: `dataset_train` is rebound to a DataLoader below.
num_classes = len(dataset_train.classes)

# Rebind both names to DataLoaders. Worker count and batch size come from the
# "train" / "val" config sections, falling back to 4 workers and a batch size
# of 128. Only the training loader shuffles.
dataset_train = DataLoader(dataset_train, 
                           shuffle=True, 
                           num_workers=train_args.get("num_workers", 4), 
                           batch_size=train_args.get("batch_size", 128))
dataset_test = DataLoader(dataset_test, 
                          num_workers=val_args.get("num_workers", 4), 
                          batch_size=val_args.get("batch_size", 128))

# Report the sizes of the wrapped datasets (samples, not batches).
print("Train: %s, Test: %s, Classes: %s" % (
    len(dataset_train.dataset), 
    len(dataset_test.dataset), 
    num_classes
))

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ../data/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ../data/cifar-100-python.tar.gz to ../data
Files already downloaded and verified
Train: 50000, Test: 10000, Classes: 100


  cpuset_checked))


In [None]:
from timm.data import Mixup

def mixup_function(train_args):
    """Build the timm ``Mixup`` transform described by the training config.

    Args:
        train_args: Mapping with the training options. The optional "mixup"
            entry holds the keyword arguments forwarded to ``Mixup``; the
            optional "smoothing" entry (default 0.0) is passed on as
            ``label_smoothing``.

    Returns:
        A ``Mixup`` instance, or None when ``train_args`` has no "mixup"
        entry (or that entry is None).
    """
    # Work on a private copy so the caller's configuration is left untouched.
    options = copy.deepcopy(train_args)
    label_smoothing = options.get("smoothing", 0.0)
    mixup_kwargs = options.get("mixup", None)

    if mixup_kwargs is None:
        return None
    return Mixup(
        **mixup_kwargs,
        label_smoothing=label_smoothing,
    )


# Transform later passed to the loss-landscape computation via `transform=`;
# it is None when the training config has no "mixup" entry.
transform = mixup_function(train_args)

## Load Pretrained Models

The cells below provide the snippets for ResNet-50 and ViT-Ti:

In [None]:
# Download and load a pretrained ResNet-50 for CIFAR-100.
url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/resnet_50_cifar100_691cc9a9e4.pth.tar"
path = "checkpoints/resnet_50_cifar100_691cc9a9e4.pth.tar"
models.download(url=url, path=path)

name = "resnet_50"
# Same identifier as in the checkpoint file name; it is reused later when the
# loss-landscape metrics file is named.
uid = "691cc9a9e4"
model = models.get_model(name, num_classes=num_classes,  # timm does not provide a ResNet for CIFAR
                         stem=model_args.get("stem", False))
# Map the checkpoint tensors onto the GPU when one is available, else the CPU.
map_location = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(path, map_location=map_location)
# The weights are stored under the "state_dict" key of the checkpoint.
model.load_state_dict(checkpoint["state_dict"])

100%|██████████| 285M/285M [00:21<00:00, 13.2MiB/s]


model: resnet_50 , params: 23.7M, output: [3, 100]


<All keys matched successfully>

In [None]:
import copy
import timm
import torch
import torch.nn as nn

# Download and load a pretrained ViT-Ti for CIFAR-100.
url = "https://github.com/xxxnell/how-do-vits-work-storage/releases/download/v0.1/vit_ti_cifar100_9857b21357.pth.tar"
path = "checkpoints/vit_ti_cifar100_9857b21357.pth.tar"
models.download(url=url, path=path)

# Build the architecture directly from timm with CIFAR-sized inputs.
model = timm.models.vision_transformer.VisionTransformer(
    num_classes=num_classes, img_size=32, patch_size=2,  # for CIFAR
    embed_dim=192, depth=12, num_heads=3, qkv_bias=False,  # for ViT-Ti 
)
# `model.name` and `uid` are used later to build the output directory and the
# metrics file name.
model.name = "vit_ti"
uid = "9857b21357"
# Presumably prints the model summary seen in the cell output (name and
# parameter count) -- confirm against `models.stats`.
models.stats(model)
# Map the checkpoint tensors onto the GPU when one is available, else the CPU.
map_location = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(path, map_location=map_location)
# The weights are stored under the "state_dict" key of the checkpoint.
model.load_state_dict(checkpoint["state_dict"])

100%|██████████| 65.0M/65.0M [00:03<00:00, 20.0MiB/s]


model: vit_ti , params: 5.4M


<All keys matched successfully>

## Investigate the Loss Landscape

We measure not only the NLL but also the "L2" term, and we do so on the *augmented* training dataset.

In [None]:
# Grid settings: both axes run from -scale to +scale and are sampled at
# n points each.
scale = 1e-0
n = 21
gpu = torch.cuda.is_available()

# Evaluate the metrics on the grid, using the training loader together with
# the mixup transform. The result maps each grid coordinate to its metrics.
metrics_grid = lls.get_loss_landscape(
    model,
    1,
    dataset_train,
    transform=transform,
    kws=["pos_embed", "relative_position"],
    x_min=-1.0 * scale,
    x_max=1.0 * scale,
    n_x=n,
    y_min=-1.0 * scale,
    y_max=1.0 * scale,
    n_y=n,
    gpu=gpu,
)

# Make sure the output directory exists, then store one row per grid point:
# the grid coordinates followed by the metrics measured there.
leaderboard_path = os.path.join("leaderboard", "logs", dataset_name, model.name)
os.makedirs(leaderboard_path, exist_ok=True)
metrics_dir = os.path.join(
    leaderboard_path,
    "%s_%s_%s_x%s_losslandscape.csv" % (dataset_name, model.name, uid, int(1 / scale)),
)
metrics_list = [list(grid) + list(metrics) for grid, metrics in metrics_grid.items()]
tests.save_metrics(metrics_dir, metrics_list)

Grid:  [-1. -1.], 

  cpuset_checked))


NLL: 5.4112, Cutoffs: 0.0 %, 90.0 %, Accs: 1.128 %, 0.000 %, Uncs: 0.000 %, 100.000 %, IoUs: 0.105 %, 0.000 %, Freqs: 100.000 %, 0.000 %, Top-5: 5.696 %, Brier: 1.026, ECE: 12.815 %, ECE±: 12.815 %
Grid:  [-0.9 -1. ], 

  cpuset_checked))


## Plot the Loss Landscape

Given the raw loss-landscape data, we visualize the landscape of the total loss, i.e. the NLL plus the L2 term scaled by the weight decay.

In [None]:
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm


# Load the raw loss-landscape data of ResNet-50 or ViT-Ti. The CSV has no
# header row, so the column names are supplied explicitly.
names = ["x", "y", "l1", "l2", "NLL", "Cutoff1", "Cutoff2", "Acc", "Acc-90", "Unc", "Unc-90", "IoU", "IoU-90", "Freq", "Freq-90", "Top-5", "Brier", "ECE", "ECSE"]
# path = "%s/resources/results/cifar100_resnet_dnn_50_losslandscape.csv" % root  # for ResNet-50
path = "%s/resources/results/cifar100_vit_ti_losslandscape.csv" % root  # for ViT-Ti
data = pd.read_csv(path, names=names)
# Plotted quantity: NLL plus the L2 column scaled by the weight-decay factor.
data["loss"] = data["NLL"] + optim_args["weight_decay"] * data["l2"]  # NLL + l2

# Prepare the data: the rows are reshaped into a square p x p grid, so the
# row count is expected to be a perfect square (reshape raises otherwise).
p = int(math.sqrt(len(data)))
shape = [p, p]
xs = data["x"].to_numpy().reshape(shape)
ys = data["y"].to_numpy().reshape(shape)
zs = data["loss"].to_numpy().reshape(shape)

# Shift the surface so that its finite minimum sits at 0, then mask very
# large values (> 42) as NaN.
zs = zs - zs[np.isfinite(zs)].min()
zs[zs > 42] = np.nan

norm = plt.Normalize(zs[np.isfinite(zs)].min(), zs[np.isfinite(zs)].max())  # normalize to [0,1]
colors = cm.plasma(norm(zs))
rcount, ccount, _ = colors.shape

fig = plt.figure(figsize=(4.2, 4), dpi=120)
# `fig.gca(projection="3d")` was deprecated in Matplotlib 3.4 and no longer
# works on current releases; create the 3D axes explicitly instead.
ax = fig.add_subplot(111, projection="3d")
ax.view_init(elev=15, azim=15)  # viewing angle

# Make the panes transparent.
ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
# Make the grid lines transparent (this relies on a private Matplotlib
# attribute, `_axinfo`).
ax.xaxis._axinfo["grid"]['color'] = (1, 1, 1, 0)
ax.yaxis._axinfo["grid"]['color'] = (1, 1, 1, 0)
ax.zaxis._axinfo["grid"]['color'] = (1, 1, 1, 0)

# Use every row and column of the grid (rcount/ccount) with the precomputed
# per-cell colors.
surf = ax.plot_surface(
    xs, ys, zs,
    rcount=rcount, ccount=ccount,
    facecolors=colors, shade=False,
)
surf.set_facecolor((0, 0, 0, 0))

# Remove white space around the surface.
adjust_lim = 0.8
ax.set_xlim(-1 * adjust_lim, 1 * adjust_lim)
ax.set_ylim(-1 * adjust_lim, 1 * adjust_lim)
ax.set_zlim(10, 32)
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
ax.axis('off')

plt.show()