<a href="https://colab.research.google.com/github/katelyn98/CorruptionRobustness/blob/main/CalculateShapeBias.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install & import required libraries / repos

In [None]:
!pip3 install pytorch_pretrained_vit
!pip3 install timm
!pip3 install transformers

from transformers import ViTFeatureExtractor, ViTForImageClassification
import torch
from pathlib import Path
import os
from skimage import io
from PIL import Image
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import PIL
from pytorch_pretrained_vit import ViT
import timm
from torch.utils import model_zoo

In [None]:
# Clone Geirhos et al.'s texture-vs-shape repository: it provides the
# cue-conflict stimuli (see STIMULI) and the probabilities_to_decision helper.
!git clone https://github.com/rgeirhos/texture-vs-shape.git

In [None]:
# Move the repo's helper code into the working directory so that
# `import probabilities_to_decision` resolves. The second mv alone would also
# move probabilities_to_decision.py. Re-running this cell fails: the first mv
# finds no source file, and `&&` then skips the second.
!mv texture-vs-shape/code/probabilities_to_decision.py ./ && mv texture-vs-shape/code/* ./

In [None]:
import probabilities_to_decision

In [None]:
# Root folder of the cue-conflict stimuli from the cloned texture-vs-shape repo
# (relative to the working directory). calculate_shape_bias expects one
# sub-folder per shape category (e.g. `car/`) containing the PNG stimuli.
STIMULI = "texture-vs-shape/stimuli/style-transfer-preprocessed-512/"

## Preprocessing constants and transform

In [None]:
# Side length (pixels) of the square crop fed to the model. The timm model
# selector picks the 224- or 384-pixel checkpoint variant from this value.
center_crop = 224
# Shorter-side resize applied before cropping. 224-pixel models keep the usual
# 256 -> 224 resize/crop. For any other crop size (e.g. 384), resize straight to
# the crop size: Resize(256) followed by a larger CenterCrop would zero-pad the
# image instead of showing the model more pixels.
resize_size = 256 if center_crop == 224 else center_crop
preprocess = transforms.Compose([
      transforms.Resize(resize_size),
      transforms.CenterCrop(center_crop),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  ])

## Function to calculate shape bias

In [None]:
def _texture_category(image_path):
  """Return the texture category encoded in a cue-conflict stimulus file name.

  File names look like ``<shape><i>-<texture><j>.png``. The texture label is
  the part after the first ``-``, without the extension and without the
  trailing index digits. Multi-digit indices such as ``10`` are handled.
  Only the file name is inspected, so the depth of the stimuli directory does
  not matter.
  """
  texture_part = Path(image_path).stem.split('-')[1]
  return texture_part.rstrip('0123456789')


def calculate_shape_bias(dir, preprocess, model, *, exclude_no_conflict=True):
  """Measure, print and return the shape bias of ``model`` on cue-conflict images.

  ``dir`` must contain one sub-folder per shape category. The folder name is
  the shape label, and every ``*.png`` inside is a stimulus whose texture label
  is parsed from its file name (see ``_texture_category``). Each image goes
  through ``preprocess`` and ``model``. The softmax over the model output is
  turned into one of the 16 Geirhos categories by ``probabilities_to_decision``.

  shape bias = shape decisions / (shape decisions + texture decisions)

  Args:
    dir: root folder of the stimuli (e.g. ``STIMULI``).
    preprocess: callable mapping a PIL RGB image to a (C, H, W) tensor.
    model: classifier producing ImageNet outputs. It is switched to eval mode,
      and inputs are sent to the device of its first parameter.
    exclude_no_conflict: when True (default), images whose shape and texture
      category coincide carry no cue conflict. They are counted in the image
      total but not as shape or texture decisions, as in Geirhos et al.'s
      shape-bias definition.

  Returns:
    The shape bias as a float, or None when no image was classified as either
    its shape or its texture category.
  """
  model.eval()
  # Send inputs to wherever the model's weights live; fall back to CUDA (the
  # original hard-coded device) for a model without parameters.
  device = next((p.device for p in model.parameters()), torch.device('cuda'))
  # Build the ImageNet -> 16-class mapping once instead of once per image.
  mapping = probabilities_to_decision.ImageNetProbabilitiesTo16ClassesMapping()

  shape = 0
  texture = 0
  total = 0

  for shape_dir in sorted(Path(dir).iterdir()):  # one folder per shape category
    if not shape_dir.is_dir():
      continue
    shape_type = shape_dir.name

    for image_path in sorted(shape_dir.glob('*.png')):
      total += 1
      texture_type = _texture_category(image_path)

      # Without a cue conflict a decision cannot separate shape from texture.
      if exclude_no_conflict and shape_type == texture_type:
        continue

      # `with` closes the file handle; convert() guards against non-RGB PNGs.
      with Image.open(image_path) as img:
        input_batch = preprocess(img.convert('RGB')).unsqueeze(0).to(device)

      with torch.no_grad():
        output = model(input_batch)

      ##############################################
      ## Code from Robert Geirhos: https://github.com/rgeirhos/texture-vs-shape#code ##
      softmax_output = torch.nn.functional.softmax(output[0], dim=0)
      softmax_output_numpy = softmax_output.cpu().numpy()
      decision_from_16_classes = mapping.probabilities_to_decision(softmax_output_numpy)
      ##############################################

      if decision_from_16_classes == shape_type:
        shape += 1
      if decision_from_16_classes == texture_type:
        texture += 1

  print("SHAPE CORRECT TOTAL")
  print(shape)

  print("TEXTURE CORRECT TOTAL")
  print(texture)

  print("TOTAL IMAGES")
  print(total)

  print("SHAPE BIAS")
  decided = shape + texture
  if decided == 0:
    # Avoid ZeroDivisionError when the model never picked either category.
    print("undefined (no shape or texture decisions)")
    return None
  shape_bias = shape / decided
  print(shape_bias)
  return shape_bias

## Function to count the parameters of a model

In [None]:
def num_params(model):
  """Count every scalar parameter (trainable or frozen) held by ``model``."""
  count = 0
  for tensor in model.parameters():
    count += tensor.numel()
  return count

## Choose model to calculate shape bias

In [None]:
#@title Choose your model to evaluate

model_choice = "ViT" #@param ["ViT", "CaiT", "DeiT", "Swin-T", "MLP-Mixer", "ResNet50", "AlexNet", "VGG", "GoogLeNet"]

exact_model = "ViT_B16" #@param ["ViT_B16", "ViT_L16", "DeiT_B16", "DeiT_B16_Distilled", "DeiT_S16", "DeiT_S16_Distilled", "DeiT_T16", "DeiT_T16_Distilled", 'CaiT_S24', 'CaiT_XXS24', 'Swin-T_B', 'Swin-T_L', 'Swin-T_S', 'Swin-T_T', 'mixer_b16', 'mixer_l16']

# timm checkpoints selectable above:
#   exact_model -> (model_choice it belongs to, {center_crop: candidate timm names})
# Candidates are tried in order, and the first one the installed timm knows is
# used. DeiT names differ between timm versions: `vit_deit_*` in older releases,
# `deit_*` in newer ones. Swin at 384px uses a 12x12 window (`window12`).
# The CNN choices (ResNet50, AlexNet, VGG, GoogLeNet) are not covered here.
# They are evaluated by their own cells further down.
_TIMM_MODELS = {
  'ViT_B16': ('ViT', {224: ('vit_base_patch16_224',), 384: ('vit_base_patch16_384',)}),
  'ViT_L16': ('ViT', {224: ('vit_large_patch16_224',), 384: ('vit_large_patch16_384',)}),
  'DeiT_B16': ('DeiT', {224: ('vit_deit_base_patch16_224', 'deit_base_patch16_224'),
                        384: ('vit_deit_base_patch16_384', 'deit_base_patch16_384')}),
  'DeiT_B16_Distilled': ('DeiT', {224: ('vit_deit_base_distilled_patch16_224', 'deit_base_distilled_patch16_224'),
                                  384: ('vit_deit_base_distilled_patch16_384', 'deit_base_distilled_patch16_384')}),
  'DeiT_S16': ('DeiT', {224: ('vit_deit_small_patch16_224', 'deit_small_patch16_224')}),
  'DeiT_S16_Distilled': ('DeiT', {224: ('vit_deit_small_distilled_patch16_224', 'deit_small_distilled_patch16_224')}),
  'DeiT_T16': ('DeiT', {224: ('vit_deit_tiny_patch16_224', 'deit_tiny_patch16_224')}),
  'DeiT_T16_Distilled': ('DeiT', {224: ('vit_deit_tiny_distilled_patch16_224', 'deit_tiny_distilled_patch16_224')}),
  'CaiT_S24': ('CaiT', {224: ('cait_s24_224',), 384: ('cait_s24_384',)}),
  'CaiT_XXS24': ('CaiT', {224: ('cait_xxs24_224',), 384: ('cait_xxs24_384',)}),
  'Swin-T_B': ('Swin-T', {224: ('swin_base_patch4_window7_224',), 384: ('swin_base_patch4_window12_384',)}),
  'Swin-T_L': ('Swin-T', {224: ('swin_large_patch4_window7_224',), 384: ('swin_large_patch4_window12_384',)}),
  'Swin-T_S': ('Swin-T', {224: ('swin_small_patch4_window7_224',)}),
  'Swin-T_T': ('Swin-T', {224: ('swin_tiny_patch4_window7_224',)}),
  'mixer_b16': ('MLP-Mixer', {224: ('mixer_b16_224',)}),
  'mixer_l16': ('MLP-Mixer', {224: ('mixer_l16_224',)}),
}

# Fail loudly on any misconfiguration. Merely printing a warning would let the
# cell go on to evaluate a stale `model` from an earlier cell, or hit a NameError.
if exact_model not in _TIMM_MODELS:
  raise ValueError("unsupported exact_model {!r}; choose one of {}".format(
      exact_model, sorted(_TIMM_MODELS)))

required_choice, names_by_crop = _TIMM_MODELS[exact_model]
if model_choice != required_choice:
  raise ValueError("please set model_choice to '{}'".format(required_choice))

if center_crop not in names_by_crop:
  raise ValueError("image size wrong: {} supports center_crop in {}, got {}".format(
      exact_model, sorted(names_by_crop), center_crop))

candidates = names_by_crop[center_crop]
timm_name = next((name for name in candidates if timm.list_models(name)), None)
if timm_name is None:
  raise ValueError("none of {} is available in the installed timm version".format(
      list(candidates)))

model = timm.create_model(timm_name, pretrained=True)
print("Using model " + timm_name)

model = model.to('cuda:0')


In [None]:
#calculate the number of parameters
print("NUM PARAMS")
print(num_params(model))

# timm models record the normalization statistics of their checkpoint in
# `default_cfg`. These may differ from the ImageNet mean/std hard-coded in
# `preprocess` (e.g. 0.5 per channel for timm's ViT-B/16 and MLP-Mixer weights).
# When they are available, only the final Normalize step is swapped. Models
# without such a config (torchvision, DataParallel wrappers) keep `preprocess`.
model_cfg = getattr(model, 'default_cfg', None) or {}
preprocess_steps = list(preprocess.transforms)
if (model_cfg.get('mean') and model_cfg.get('std')
    and isinstance(preprocess_steps[-1], transforms.Normalize)):
  model_preprocess = transforms.Compose(
      preprocess_steps[:-1]
      + [transforms.Normalize(mean=model_cfg['mean'], std=model_cfg['std'])]
  )
else:
  model_preprocess = preprocess

calculate_shape_bias(STIMULI, model_preprocess, model)

# Vision Transformers

## Data-efficient Image Transformer ([DeiT](https://github.com/facebookresearch/deit))

In [None]:
# Disabled alternative: DeiT-B/16 loaded through timm instead of torch.hub.
# Newer timm releases may only know this model as 'deit_base_patch16_224'.
# # pre-trained model from https://github.com/rwightman/pytorch-image-models
# model = timm.create_model('vit_deit_base_patch16_224', pretrained=True)
# model = model.to('cuda')

# #calculate the number of parameters
# print("NUM PARAMS")
# print(num_params(model))

# calculate_shape_bias(STIMULI, preprocess, model)

In [None]:
# DeiT-B/16 at 224px, loaded from the facebookresearch/deit torch.hub repo.
model = torch.hub.load(
    'facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True
).to('cuda:0')

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

In [None]:
# DeiT-Ti/16 at 224px, loaded from the facebookresearch/deit torch.hub repo.
model = torch.hub.load(
    'facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True
).to('cuda:0')

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

In [None]:
# Distilled DeiT-Ti/16 at 224px, loaded from the facebookresearch/deit torch.hub repo.
model = torch.hub.load(
    'facebookresearch/deit:main', 'deit_tiny_distilled_patch16_224', pretrained=True
).to('cuda:0')

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

In [None]:
# DeiT-S/16 at 224px, loaded from the facebookresearch/deit torch.hub repo.
model = torch.hub.load(
    'facebookresearch/deit:main', 'deit_small_patch16_224', pretrained=True
).to('cuda:0')

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

In [None]:
# Distilled DeiT-S/16 at 224px, loaded from the facebookresearch/deit torch.hub repo.
model = torch.hub.load(
    'facebookresearch/deit:main', 'deit_small_distilled_patch16_224', pretrained=True
).to('cuda:0')

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

In [None]:
# Distilled DeiT-B/16 at 224px, loaded from the facebookresearch/deit torch.hub repo.
model = torch.hub.load(
    'facebookresearch/deit:main', 'deit_base_distilled_patch16_224', pretrained=True
).to('cuda:0')

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

## Class-Attention in Image Transformers ([CaiT](https://github.com/facebookresearch/deit/blob/main/README_cait.md))

In [None]:
# CaiT-S24 at 224px with pretrained timm weights.
model = timm.create_model('cait_s24_224', pretrained=True).to('cuda:0')

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

In [None]:
# CaiT-XXS24 at 224px with pretrained timm weights.
model = timm.create_model('cait_xxs24_224', pretrained=True).to('cuda:0')

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

## Vision Transformer ([ViT](https://github.com/google-research/vision_transformer))

In [None]:
# pre-trained model from https://github.com/rwightman/pytorch-image-models
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model = model.to('cuda')

#calculate the number of parameters
print("NUM PARAMS")
print(num_params(model))

# Normalize with the statistics this checkpoint was trained with. timm stores
# them in model.default_cfg (0.5 per channel for timm's ViT-B/16 weights, not
# the ImageNet mean/std hard-coded in `preprocess`). Only the final Normalize
# step of `preprocess` is swapped; resize and crop stay as in the other cells.
model_cfg = model.default_cfg
model_preprocess = transforms.Compose(
    list(preprocess.transforms[:-1])
    + [transforms.Normalize(mean=model_cfg['mean'], std=model_cfg['std'])]
)

calculate_shape_bias(STIMULI, model_preprocess, model)

In [None]:
# pre-trained model from https://github.com/rwightman/pytorch-image-models
model = timm.create_model('vit_large_patch16_224', pretrained=True)
model = model.to('cuda')

#calculate the number of parameters
print("NUM PARAMS")
print(num_params(model))

# Normalize with the statistics this checkpoint was trained with. timm stores
# them in model.default_cfg, and for timm's ViT-L/16 weights they differ from
# the ImageNet mean/std hard-coded in `preprocess`. Only the final Normalize
# step of `preprocess` is swapped; resize and crop stay as in the other cells.
model_cfg = model.default_cfg
model_preprocess = transforms.Compose(
    list(preprocess.transforms[:-1])
    + [transforms.Normalize(mean=model_cfg['mean'], std=model_cfg['std'])]
)

calculate_shape_bias(STIMULI, model_preprocess, model)

In [None]:
# pre-trained model from https://github.com/rwightman/pytorch-image-models
model = timm.create_model('vit_small_patch16_224', pretrained=True)
model = model.to('cuda')

#calculate the number of parameters
print("NUM PARAMS")
print(num_params(model))

# Normalize with the statistics this checkpoint was trained with. timm stores
# them in model.default_cfg. Depending on the timm version, the ViT-S/16 weights
# use either the ImageNet statistics or 0.5 per channel, so read them from the
# config instead of assuming. Only the final Normalize step is swapped.
model_cfg = model.default_cfg
model_preprocess = transforms.Compose(
    list(preprocess.transforms[:-1])
    + [transforms.Normalize(mean=model_cfg['mean'], std=model_cfg['std'])]
)

calculate_shape_bias(STIMULI, model_preprocess, model)

## Swin Transformer ([Swin-T](https://arxiv.org/abs/2103.14030))

In [None]:
# Swin-B (patch 4, window 7) at 224px; weights from
# https://github.com/rwightman/pytorch-image-models
model = timm.create_model('swin_base_patch4_window7_224', pretrained=True).to('cuda')

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

In [None]:
# Disabled: Swin-B at 384px (window 12). It needs its own 384px preprocessing.
# That transform is bound to a separate name, so enabling this cell does not
# overwrite the global `preprocess` used by every other cell.
# #pre-trained model from # https://github.com/rwightman/pytorch-image-models

# model = timm.create_model('swin_base_patch4_window12_384', pretrained=True)
# model = model.to('cuda')

# #calculate the number of parameters
# print("NUM PARAMS")
# print(num_params(model))

# preprocess_384 = transforms.Compose([
#       transforms.Resize(384),
#       transforms.CenterCrop(384),
#       transforms.ToTensor(),
#       transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#   ])

# calculate_shape_bias(STIMULI, preprocess_384, model)


In [None]:
# Swin-L (patch 4, window 7) at 224px; weights from
# https://github.com/rwightman/pytorch-image-models
model = timm.create_model('swin_large_patch4_window7_224', pretrained=True).to('cuda')

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

In [None]:
# Swin-S (patch 4, window 7) at 224px; weights from
# https://github.com/rwightman/pytorch-image-models
model = timm.create_model('swin_small_patch4_window7_224', pretrained=True).to('cuda')

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

In [None]:
# Swin-T (patch 4, window 7) at 224px; weights from
# https://github.com/rwightman/pytorch-image-models
model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True).to('cuda')

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

# MLP-Mixer ([Tolstikhin et al., 2021](https://arxiv.org/abs/2105.01601))

In [None]:
# List every checkpoint the installed timm can load with pretrained=True.
# timm is already installed by the first cell, so it is not reinstalled here.
import timm
from pprint import pprint
model_names = timm.list_models(pretrained=True)
pprint(model_names)

In [None]:
#MLP-mixer pre-trained model from # https://github.com/rwightman/pytorch-image-models

model = timm.create_model('mixer_b16_224', pretrained=True)
model = model.to('cuda')

#calculate the number of parameters
print("NUM PARAMS")
print(num_params(model))

# Normalize with the statistics this checkpoint was trained with. timm stores
# them in model.default_cfg (0.5 per channel for timm's Mixer weights, not the
# ImageNet mean/std hard-coded in `preprocess`). Only the final Normalize step
# of `preprocess` is swapped; resize and crop stay as in the other cells.
model_cfg = model.default_cfg
model_preprocess = transforms.Compose(
    list(preprocess.transforms[:-1])
    + [transforms.Normalize(mean=model_cfg['mean'], std=model_cfg['std'])]
)

calculate_shape_bias(STIMULI, model_preprocess, model)

In [None]:
#MLP-mixer pre-trained model from # https://github.com/rwightman/pytorch-image-models

model = timm.create_model('mixer_l16_224', pretrained=True)
model = model.to('cuda')

#calculate the number of parameters
print("NUM PARAMS")
print(num_params(model))

# Normalize with the statistics this checkpoint was trained with. timm stores
# them in model.default_cfg (0.5 per channel for timm's Mixer weights, not the
# ImageNet mean/std hard-coded in `preprocess`). Only the final Normalize step
# of `preprocess` is swapped; resize and crop stay as in the other cells.
model_cfg = model.default_cfg
model_preprocess = transforms.Compose(
    list(preprocess.transforms[:-1])
    + [transforms.Normalize(mean=model_cfg['mean'], std=model_cfg['std'])]
)

calculate_shape_bias(STIMULI, model_preprocess, model)

# Convolutional Neural Networks

## ResNet50

In [None]:
# torchvision ResNet-50 with its pretrained ImageNet weights.
model = models.resnet50(pretrained=True).to('cuda')

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

In [None]:
# Disabled alternative: timm's ResNet-50 weights instead of the torchvision
# ResNet-50 evaluated in the cell above.
# #pre-trained model from # https://github.com/rwightman/pytorch-image-models

# model = timm.create_model('resnet50', pretrained=True)
# model = model.to('cuda')

# #calculate the number of parameters
# print("NUM PARAMS")
# print(num_params(model))

# calculate_shape_bias(STIMULI, preprocess, model)

## AlexNet

In [None]:
# torchvision AlexNet with its pretrained ImageNet weights.
model = models.alexnet(pretrained=True).to('cuda')

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

## ResNet50 trained on Stylized-ImageNet ([Geirhos et al., 2019](https://github.com/rgeirhos/texture-vs-shape))

In [None]:
# ResNet-50 trained on Stylized-ImageNet (SIN) by Geirhos et al.
# The untrained torchvision network is wrapped in DataParallel on the GPU
# before loading, so the checkpoint's state_dict keys are expected to match
# that wrapper. .cuda() already places the model on the GPU.
model = torch.nn.DataParallel(torchvision.models.resnet50(pretrained=False)).cuda()
checkpoint = model_zoo.load_url('https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/6f41d2e86fc60566f78de64ecff35cc61eb6436f/resnet50_train_60_epochs-c8e5653e.pth.tar')
model.load_state_dict(checkpoint["state_dict"])

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

## ResNet50 trained on Stylized-ImageNet and ImageNet ([Geirhos et al., 2019](https://github.com/rgeirhos/texture-vs-shape))

In [None]:
# ResNet-50 trained on Stylized-ImageNet combined with ImageNet (SIN + IN) by
# Geirhos et al. It is loaded the same way as the SIN-only model: DataParallel
# wrapper on the GPU first, then the checkpoint's state_dict.
model = torch.nn.DataParallel(torchvision.models.resnet50(pretrained=False)).cuda()
checkpoint = model_zoo.load_url('https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_train_45_epochs_combined_IN_SF-2a0d100e.pth.tar')
model.load_state_dict(checkpoint["state_dict"])

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

## AlexNet trained on Stylized-ImageNet ([Geirhos et al., 2019](https://github.com/rgeirhos/texture-vs-shape))

In [None]:
# AlexNet trained on Stylized-ImageNet by Geirhos et al. Only the convolutional
# `features` stage is wrapped in DataParallel, matching the layout the
# checkpoint's state_dict is loaded into. .cuda() moves the whole model to GPU.
model = torchvision.models.alexnet(pretrained=False)
model.features = torch.nn.DataParallel(model.features)
model = model.cuda()
checkpoint = model_zoo.load_url('https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/0008049cd10f74a944c6d5e90d4639927f8620ae/alexnet_train_60_epochs_lr0.001-b4aa5238.pth.tar')
model.load_state_dict(checkpoint["state_dict"])

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

## VGG16

In [None]:
# torchvision VGG-16 with its pretrained ImageNet weights.
model = models.vgg16(pretrained=True).to('cuda')

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

## GoogLeNet

In [None]:
# torchvision GoogLeNet with its pretrained ImageNet weights.
model = models.googlenet(pretrained=True).to('cuda')

# Model size first, then shape bias on the cue-conflict stimuli.
print("NUM PARAMS")
print(num_params(model))
calculate_shape_bias(STIMULI, preprocess, model)

In [None]:
!ls -l /content/texture-vs-shape/stimuli/style-transfer-preprocessed-512/car | wc -l

81
