<a href="https://colab.research.google.com/github/katelyn98/CorruptionRobustness/blob/main/EvaluatingCorruptionsImageNetC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install and import necessary libraries

In [None]:
!pip3 install timm
!pip install pytorch_pretrained_vit

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
from google.colab import drive
import os
from skimage import io
from tqdm.notebook import tqdm
from PIL import Image
from __future__ import print_function
from __future__ import division
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import requests, zipfile, io
import subprocess
from pytorch_pretrained_vit import ViT
import timm
from pprint import pprint

## Download ImageNet-C
Choose the ImageNet-C corruption category that you want to evaluate. Note: download only one category at a time, or Google Colab will run out of disk space. 

Downloading ImageNet-C will take at least 10 minutes depending on the class you are downloading. 

**Note**: If you do not have Colab Pro, you may not be able to download some classes because of their size. Refer to the [original download site](https://zenodo.org/record/2235448#.YKqCy5NKjLY) for the size of each class.

Another note: you can double-click the title of a form in Google Colab to see the code that runs the form. 

In [None]:
#@title ImageNet-C Class to Evaluate

dropdown = 'blur' #@param ["blur", "weather", "noise", "extra", "digital"]

# Each ImageNet-C corruption category is downloaded from Zenodo as one tar archive.
IMAGENET_C_CATEGORIES = ("blur", "weather", "noise", "extra", "digital")
IMAGENET_C_URL = "https://zenodo.org/record/2235448/files/{}.tar?download=1"

if dropdown not in IMAGENET_C_CATEGORIES:
  raise ValueError("dropdown must be one of " + str(IMAGENET_C_CATEGORIES) + ", got " + repr(dropdown))

tar_path = os.path.join("/content", dropdown + ".tar")
# Argument lists (no shell) sidestep quoting of the '?' in the URL, and
# check=True stops the notebook if the download or extraction fails instead of
# silently continuing without any data.
subprocess.run(["wget", "-O", tar_path, IMAGENET_C_URL.format(dropdown)], check=True)
# Extract into /content, which is where the dataloader cell looks for the images.
subprocess.run(["tar", "-xf", tar_path, "-C", "/content"], check=True)
# Only the extracted images are needed; delete the archive to free disk space.
os.remove(tar_path)


## Class to evaluate function
Note: ImageNet-C has 1000 classes. Change `num_classes` if you are using CIFAR-10-C or CIFAR-100-C.

Note: Some models require a different input size. Set the image size here to the size your model requires.

Below are the subclasses available for each corruption category. In the subclass field, you must type the subclass exactly as shown below. 

**Blur**: 'defocus_blur', 'glass_blur', 'zoom_blur', 'motion_blur'

**Weather**: 'frost', 'fog', 'snow', 'brightness'

**Noise**: 'gaussian_noise', 'shot_noise', 'impulse_noise'

**Digital**:  'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression'

**Extra**: 'speckle_noise', 'spatter', 'gaussian_blur', 'saturate'

In [None]:
#@title Choose subclass & create dataloader
# Number of classes in the dataset. ImageNet-C uses 1000; change this for
# CIFAR-10-C / CIFAR-100-C. (Not read by the code in this notebook.)
num_classes = 1000 #@param {type:"integer"}
# Batch size for evaluation (change depending on how much GPU memory you have)
batch_size = 100 #@param {type:"integer"}

subclass = 'motion_blur' #@param {type:"string"}

# Corruption subclasses contained in each downloadable ImageNet-C category
# (the value of `dropdown` chosen in the download cell).
IMAGENET_C_SUBCLASSES = {
  'blur': ['defocus_blur', 'glass_blur', 'zoom_blur', 'motion_blur'],
  'weather': ['frost', 'fog', 'snow', 'brightness'],
  'noise': ['gaussian_noise', 'shot_noise', 'impulse_noise'],
  'digital': ['contrast', 'elastic_transform', 'pixelate', 'jpeg_compression'],
  'extra': ['speckle_noise', 'spatter', 'gaussian_blur', 'saturate'],
}

valid_subclasses = IMAGENET_C_SUBCLASSES.get(dropdown, [])
if subclass not in valid_subclasses:
  # Fail fast: continuing would only crash later with a less helpful
  # "folder not found" error when the dataset is built.
  raise ValueError("Please choose a subclass from the following: " + str(valid_subclasses))

# Directory of ImageNet-C images for the chosen subclass; change this if the
# archive was extracted somewhere other than /content.
data_dir = os.path.join("/content", subclass)
severity = 2 #@param {type:"integer"}
# ImageNet-C provides severity levels 1-5, one sub-folder per level.
if severity not in range(1, 6):
  raise ValueError("severity must be an integer from 1 to 5, got " + repr(severity))

image_size = 254 #@param {type:"integer"}
center_crop = 224 #@param {type:"integer"}

# The severity level is both the sub-folder name and the key of the dicts below.
key = str(severity)
data_transforms = {
    key: transforms.Compose([
        transforms.Resize(image_size),       # resize the shorter side to image_size
        transforms.CenterCrop(center_crop),  # square crop fed to the model
        transforms.ToTensor(),
        # ImageNet per-channel mean/std normalization.
        # NOTE(review): some timm models (e.g. ViT, MLP-Mixer) were trained with a
        # different mean/std; consider per-model normalization — confirm.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
}

# Create dataset (labels come from the sorted class sub-folder names)
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in [key]}
# Create dataloaders. shuffle=False: order does not affect accuracy, and a
# fixed order keeps evaluation runs reproducible.
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=False, num_workers=4) for x in [key]}



## Evaluation Function

In [None]:
def evaluation(model, dataloaders, key, device='cuda:0'):
  """Compute, print, and return the top-1 accuracy of `model` on `dataloaders[key]`.

  Args:
    model: classifier whose output has one score per class along dim 1 for
      each input in the batch. It must already live on `device`.
    dataloaders: mapping where `dataloaders[key]` yields `(inputs, labels)`
      batches (e.g. a torch DataLoader).
    key: which entry of `dataloaders` to evaluate (in this notebook, the
      severity level as a string).
    device: device the batches are moved to. Defaults to 'cuda:0'.

  Returns:
    Top-1 accuracy as a float in [0, 1].

  Raises:
    ValueError: if the loader yields no samples.
  """
  # Switch to inference behavior (e.g. dropout off, batch-norm running stats).
  model.eval()

  correct, total = 0, 0
  with torch.no_grad():
    for inputs, labels in dataloaders[key]:
      images = inputs.to(device)
      ground_truth = labels.to(device)

      # Raw class scores (logits), one row per image.
      scores = model(images)
      # Index of the highest score in each row is the predicted class id.
      pred_label = scores.argmax(dim=1)

      total += ground_truth.size(0)
      correct += (pred_label == ground_truth).sum().item()

  if total == 0:
    # Avoid a bare ZeroDivisionError when the dataset folder is empty.
    raise ValueError("No samples found in dataloaders[" + repr(key) + "]")

  acc = correct / total
  print('Accuracy is ' + str(acc))
  return acc

# Evaluation on ImageNet-C 

In [None]:
#@title Choose your model to evaluate

# Named `model_family` (not `model_choice`) so this form value does not
# overwrite the `model_choice()` function defined in a later cell.
model_family = "ViT" #@param ["ViT", "CaiT", "DeiT", "Swin-T", "MLP-Mixer", "ResNet50", "AlexNet", "VGG", "GoogLeNet"]

exact_model = "ViT_L16" #@param ["ViT_B16", "ViT_L16", "DeiT_B16", "DeiT_B16_Distilled", "DeiT_S16", "DeiT_S16_Distilled", "DeiT_T16", "DeiT_T16_Distilled", "CaiT_S24", "CaiT_XXS24", "Swin-T_B", "Swin-T_L", "Swin-T_S", "Swin-T_T", "mixer_b16", "mixer_l16", "ResNet50", "AlexNet", "VGG16", "GoogLeNet"]

# exact_model -> (required model_family, {center_crop size: model name}).
# Names listed in TORCHVISION_MODELS are built with torchvision; all others with timm.
MODEL_REGISTRY = {
  'ViT_B16': ('ViT', {224: 'vit_base_patch16_224', 384: 'vit_base_patch16_384'}),
  'ViT_L16': ('ViT', {224: 'vit_large_patch16_224', 384: 'vit_large_patch16_384'}),
  'DeiT_B16': ('DeiT', {224: 'deit_base_patch16_224', 384: 'deit_base_patch16_384'}),
  'DeiT_B16_Distilled': ('DeiT', {224: 'deit_base_distilled_patch16_224', 384: 'deit_base_distilled_patch16_384'}),
  'DeiT_S16': ('DeiT', {224: 'deit_small_patch16_224'}),
  'DeiT_S16_Distilled': ('DeiT', {224: 'deit_small_distilled_patch16_224'}),
  'DeiT_T16': ('DeiT', {224: 'deit_tiny_patch16_224'}),
  'DeiT_T16_Distilled': ('DeiT', {224: 'deit_tiny_distilled_patch16_224'}),
  'CaiT_S24': ('CaiT', {224: 'cait_s24_224', 384: 'cait_s24_384'}),
  'CaiT_XXS24': ('CaiT', {224: 'cait_xxs24_224', 384: 'cait_xxs24_384'}),
  # timm's 384px Swin checkpoints use a 12x12 attention window.
  'Swin-T_B': ('Swin-T', {224: 'swin_base_patch4_window7_224', 384: 'swin_base_patch4_window12_384'}),
  'Swin-T_L': ('Swin-T', {224: 'swin_large_patch4_window7_224', 384: 'swin_large_patch4_window12_384'}),
  'Swin-T_S': ('Swin-T', {224: 'swin_small_patch4_window7_224'}),
  'Swin-T_T': ('Swin-T', {224: 'swin_tiny_patch4_window7_224'}),
  'mixer_b16': ('MLP-Mixer', {224: 'mixer_b16_224'}),
  'mixer_l16': ('MLP-Mixer', {224: 'mixer_l16_224'}),
  'ResNet50': ('ResNet50', {224: 'ResNet50'}),
  'AlexNet': ('AlexNet', {224: 'AlexNet'}),
  'VGG16': ('VGG', {224: 'VGG16'}),
  'GoogLeNet': ('GoogLeNet', {224: 'GoogLeNet'}),
}
TORCHVISION_MODELS = {
  'ResNet50': models.resnet50,
  'AlexNet': models.alexnet,
  'VGG16': models.vgg16,
  'GoogLeNet': models.googlenet,
}

# Raise instead of printing: a printed warning left `model` undefined or,
# worse, silently reused the model from a previous run.
if exact_model not in MODEL_REGISTRY:
  raise ValueError("Unknown exact_model " + repr(exact_model) + "; choose one of " + str(list(MODEL_REGISTRY)))
required_family, names_by_crop = MODEL_REGISTRY[exact_model]
if model_family != required_family:
  raise ValueError("please set model_family to '" + required_family + "'")
if center_crop not in names_by_crop:
  raise ValueError("image size wrong: " + exact_model + " supports center_crop " + str(sorted(names_by_crop)) + ", got " + str(center_crop))

model_name = names_by_crop[center_crop]
if model_name in TORCHVISION_MODELS:
  model = TORCHVISION_MODELS[model_name](pretrained=True)
else:
  model = timm.create_model(model_name, pretrained=True)
print("Using model " + model_name)

model = model.to('cuda:0')


 **Evaluate**

In [None]:
# Report which corruption setting is being scored, then run the evaluation.
print(f"SEVERITY of {severity}")
print(f"CLASS: {dropdown} - {subclass}")
evaluation(model, dataloaders_dict, key)

## Run all models at once

### Model Choice function

In [None]:

# exact_model -> {center_crop size: model name}. Names that appear in
# _TORCHVISION_MODELS are built with torchvision, all others with timm.
_MODEL_REGISTRY = {
  'ViT_B16': {224: 'vit_base_patch16_224', 384: 'vit_base_patch16_384'},
  'ViT_L16': {224: 'vit_large_patch16_224', 384: 'vit_large_patch16_384'},
  'DeiT_B16': {224: 'deit_base_patch16_224', 384: 'deit_base_patch16_384'},
  'DeiT_B16_Distilled': {224: 'deit_base_distilled_patch16_224', 384: 'deit_base_distilled_patch16_384'},
  'DeiT_S16': {224: 'deit_small_patch16_224'},
  'DeiT_S16_Distilled': {224: 'deit_small_distilled_patch16_224'},
  'DeiT_T16': {224: 'deit_tiny_patch16_224'},
  'DeiT_T16_Distilled': {224: 'deit_tiny_distilled_patch16_224'},
  'CaiT_S24': {224: 'cait_s24_224', 384: 'cait_s24_384'},
  'CaiT_XXS24': {224: 'cait_xxs24_224', 384: 'cait_xxs24_384'},
  # timm's 384px Swin checkpoints use a 12x12 attention window.
  'Swin-T_B': {224: 'swin_base_patch4_window7_224', 384: 'swin_base_patch4_window12_384'},
  'Swin-T_L': {224: 'swin_large_patch4_window7_224', 384: 'swin_large_patch4_window12_384'},
  'Swin-T_S': {224: 'swin_small_patch4_window7_224'},
  'Swin-T_T': {224: 'swin_tiny_patch4_window7_224'},
  'mixer_b16': {224: 'mixer_b16_224'},
  'mixer_l16': {224: 'mixer_l16_224'},
  'AlexNet': {224: 'AlexNet'},
  'ResNet50': {224: 'ResNet50'},
  'GoogLeNet': {224: 'GoogLeNet'},
  'VGG16': {224: 'VGG16'},
}
# Display name -> attribute of torchvision.models (looked up lazily at call time).
_TORCHVISION_MODELS = {
  'AlexNet': 'alexnet',
  'ResNet50': 'resnet50',
  'GoogLeNet': 'googlenet',
  'VGG16': 'vgg16',
}


def model_choice(exact_model_usr, center_crop_size=None):
  """Build the pretrained model named `exact_model_usr` and move it to 'cuda:0'.

  Args:
    exact_model_usr: one of the keys of `_MODEL_REGISTRY` (e.g. 'ViT_B16').
    center_crop_size: input size the model must accept. Defaults to the
      notebook-level `center_crop` form value.

  Returns:
    The pretrained model on 'cuda:0'.

  Raises:
    ValueError: if the model name is unknown or has no variant for the crop size
      (previously this surfaced as an UnboundLocalError on `model`).
  """
  crop = center_crop if center_crop_size is None else center_crop_size

  try:
    names_by_crop = _MODEL_REGISTRY[exact_model_usr]
  except KeyError:
    raise ValueError("Unknown model " + repr(exact_model_usr) + "; choose one of " + str(list(_MODEL_REGISTRY))) from None
  if crop not in names_by_crop:
    raise ValueError("image size wrong: " + str(exact_model_usr) + " supports center_crop " + str(sorted(names_by_crop)) + ", got " + str(crop))

  name = names_by_crop[crop]
  if name in _TORCHVISION_MODELS:
    model = getattr(models, _TORCHVISION_MODELS[name])(pretrained=True)
  else:
    model = timm.create_model(name, pretrained=True)
  print("Using model " + name)

  return model.to('cuda:0')

### Loop through all models

In [None]:
modellist = ["ViT_B16", "ViT_L16", "DeiT_B16", "DeiT_B16_Distilled", "DeiT_S16", "DeiT_S16_Distilled", "DeiT_T16", "DeiT_T16_Distilled", "CaiT_S24", "CaiT_XXS24", "Swin-T_B", "Swin-T_L", "Swin-T_S", "Swin-T_T", "mixer_b16", "mixer_l16", "ResNet50", "AlexNet", "VGG16", "GoogLeNet"]

for modelname in modellist:
  # Drop the reference to the previous model (including one left over from the
  # single-model cell) BEFORE building the next. Otherwise both stay referenced
  # while the new one is moved to the GPU, which can run out of memory for the
  # large variants. empty_cache releases the allocator's unused cached blocks.
  model = None
  torch.cuda.empty_cache()
  model = model_choice(modelname)

  print("SEVERITY of " + str(severity))
  print("CLASS: " + str(dropdown) + " - " + str(subclass))
  evaluation(model, dataloaders_dict, key)