In [1]:
import torch.nn as nn
import torch
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset

from transformers import ViTForImageClassification, ViTFeatureExtractor
import torchvision.transforms as transforms
import numpy as np
from copy import deepcopy

2024-01-24 21:44:33.727399: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-24 21:44:33.727455: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-24 21:44:33.729287: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


# CIFAR-10, ViT-B/16

In [2]:
# Run on the GPU when CUDA reports one as available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Checkpoint identifier shared by the classifier and its feature extractor.
vit_model = 'nateraw/vit-base-patch16-224-cifar10'

# Load the classifier (hidden states are requested in its outputs), move it
# to the chosen device and switch it to evaluation mode.
model = ViTForImageClassification.from_pretrained(
    vit_model,
    output_hidden_states=True,
)
model = model.to(device)
model.eval()

# Preprocessing objects: the checkpoint's feature extractor and a plain
# image-to-tensor conversion.
feature_extractor = ViTFeatureExtractor.from_pretrained(vit_model)
to_tensor = transforms.ToTensor()



In [3]:
# Load the 'test' split of the CIFAR-10 dataset
dataset = load_dataset(
    path='cifar10',
    split='test',
)

class CIFAR10HFDataset(Dataset):
    """Torch `Dataset` view over a dataset whose rows have 'img' and 'label'.

    Each item is a triple `(image, pixel_values, label)`:
      * `image`        - the row's image passed through `image_transform`;
      * `pixel_values` - the feature extractor's output for that image with
                         the leading batch dimension removed;
      * `label`        - the row's 'label' value, returned unchanged.

    Args:
        hf_dataset: Indexable dataset; `hf_dataset[idx]` must return a mapping
            with the keys 'img' and 'label'.
        feature_extractor: Callable invoked as
            `feature_extractor(images=image, return_tensors='pt')` that
            returns a mapping containing 'pixel_values'.
        image_transform: Optional callable applied to the raw image. When
            omitted, the notebook-level `to_tensor` is used.
    """

    def __init__(self, hf_dataset, feature_extractor, image_transform=None):
        self.hf_dataset = hf_dataset
        self.feature_extractor = feature_extractor
        # None means "look up the notebook-level `to_tensor` at access time".
        self.image_transform = image_transform

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        # Fetch the row once and read both fields from it; indexing the
        # dataset separately for 'img' and 'label' builds the row twice
        # (presumably decoding the image twice for a Hugging Face dataset).
        row = self.hf_dataset[idx]
        image = row['img']
        label = row['label']

        # Model input produced by the checkpoint's own preprocessing
        inputs = self.feature_extractor(images=image, return_tensors='pt')

        # Plain tensor version of the raw image
        transform = to_tensor if self.image_transform is None else self.image_transform
        image = transform(image)

        # Drop only the leading batch dimension. A bare `.squeeze()` would
        # also remove any other size-1 axis (e.g. a single colour channel).
        pixel_values = inputs['pixel_values'].squeeze(0)

        return image, pixel_values, label

# Wrap the loaded split, then serve it in its stored order, 32 samples per batch.
cifar_data = CIFAR10HFDataset(
    hf_dataset=dataset,
    feature_extractor=feature_extractor,
)
dataloader = DataLoader(dataset=cifar_data, batch_size=32, shuffle=False)

In [4]:

# Accumulators for the model's predictions and the ground-truth classes
predicted_labels, true_labels = [], []

# Inference only: run the whole pass over the DataLoader without autograd
with torch.no_grad():
    for images, inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs, labels=labels)

        # The index of the largest logit is the predicted class
        predicted = outputs.logits.max(dim=1).indices

        # Collect this batch's predictions and labels on the host
        predicted_labels.extend(predicted.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

# Turn the accumulated per-sample values into numpy arrays
predicted_labels = np.array(predicted_labels)
true_labels = np.array(true_labels)

# Accuracy = fraction of samples whose prediction equals the label
accuracy = (predicted_labels == true_labels).mean()
print(f"Accuracy: {accuracy}")


Accuracy: 0.9852


In [5]:
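# front layers: 0, 1, 2, 3, 4
# mid layers: 5, 6, 7, 8, 9
# back layers: 10, 11, 12
# NOTE(review): the ablation below indexes encoder blocks 0-11, so index 12 only
# makes sense if this grouping refers to hidden-state indices - confirm.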

## Layer ablation

In [6]:
# exclude layer 
def layer_ablation(model, num_layers: "int | list[int]", data_loader=None, eval_device=None):
    """Measure classification accuracy with selected encoder blocks removed.

    A deep copy of `model` is made, the blocks of `vit.encoder.layer` whose
    indices are listed in `num_layers` are dropped from the copy, and the copy
    is evaluated over the data loader. The model passed in is left unchanged.
    The accuracy is printed as "Accuracy: <value>" and also returned.

    Args:
        model: Model exposing `vit.encoder.layer` (an indexable sequence of
            modules) whose forward pass returns an object with `.logits`.
        num_layers: Index, or collection of indices, of the encoder blocks to
            exclude. Indices are 0-based positions in `vit.encoder.layer`.
        data_loader: Iterable yielding `(images, inputs, labels)` batches.
            Defaults to the notebook-level `dataloader`.
        eval_device: Device the input batches are moved to. Defaults to the
            notebook-level `device`.

    Returns:
        Fraction of samples whose predicted class equals the label.

    Raises:
        ValueError: If any requested index is outside the range of existing
            encoder blocks. Such an index would otherwise be ignored and the
            result would silently be the accuracy of a less-ablated model.
    """
    # The annotation above is a string on purpose: it is not evaluated when
    # the function is defined, so no typing import is required.

    # Work on a copy so the caller's model keeps all of its encoder blocks.
    new_model = deepcopy(model)
    encoder_layers = new_model.vit.encoder.layer
    total_layers = len(encoder_layers)

    if isinstance(num_layers, (int, np.integer)):
        num_layers = [num_layers]
    excluded = set(num_layers)

    invalid = sorted(idx for idx in excluded if not 0 <= idx < total_layers)
    if invalid:
        raise ValueError(
            f"layer indices {invalid} are out of range for an encoder "
            f"with {total_layers} layers"
        )

    # Keep every block that was not asked to be removed, in original order.
    new_model.vit.encoder.layer = nn.ModuleList(
        [block for idx, block in enumerate(encoder_layers) if idx not in excluded]
    )
    new_model.eval()

    # Fall back to the notebook-level objects when no override is given.
    loader = dataloader if data_loader is None else data_loader
    target_device = device if eval_device is None else eval_device

    predicted_labels = []
    true_labels = []

    with torch.no_grad():
        for _, inputs, labels in loader:
            # Labels are not passed to the model: only the logits are used,
            # so there is no need to have a loss computed as well.
            logits = new_model(inputs.to(target_device)).logits

            # The index of the largest logit is the predicted class
            predicted_labels.extend(logits.argmax(dim=1).cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    predicted_labels = np.array(predicted_labels)
    true_labels = np.array(true_labels)

    # Print the accuracy of the ablated model
    accuracy = np.mean(predicted_labels == true_labels)
    print(f"Accuracy: {accuracy}")
    return accuracy

In [7]:
# Remove one encoder block at a time. The number of blocks is read from the
# model instead of being hard-coded, and each accuracy is kept so the sweep
# can be inspected or plotted afterwards without re-running it.
single_layer_accuracy = {}
for i in range(len(model.vit.encoder.layer)):
    print(f'Ablation layer {i}')
    single_layer_accuracy[i] = layer_ablation(model, i)

Ablation layer 0
Accuracy: 0.1093
Ablation layer 1
Accuracy: 0.8456
Ablation layer 2
Accuracy: 0.9522
Ablation layer 3
Accuracy: 0.9584
Ablation layer 4
Accuracy: 0.9672
Ablation layer 5
Accuracy: 0.9689
Ablation layer 6
Accuracy: 0.9747
Ablation layer 7
Accuracy: 0.973
Ablation layer 8
Accuracy: 0.968
Ablation layer 9
Accuracy: 0.9581
Ablation layer 10
Accuracy: 0.9799
Ablation layer 11
Accuracy: 0.9629


In [8]:
# Remove two adjacent blocks at once: first (2, 3), then (3, 4).
print('exclude layer 2,3')
layer_ablation(model, num_layers=[2, 3])

print('exclude layer 3,4')
# Left as a bare expression so the cell also echoes the returned accuracy.
layer_ablation(model, num_layers=[3, 4])

exclude layer 2,3


Accuracy: 0.7193
exclude layer 3,4
Accuracy: 0.8375


0.8375

In [9]:
# Remove two adjacent blocks at once: (7, 8), then (8, 9), then (9, 10).
print('exclude layer 7,8')
layer_ablation(model, num_layers=[7, 8])

print('exclude layer 8,9')
layer_ablation(model, num_layers=[8, 9])

print('exclude layer 9,10')
# Left as a bare expression so the cell also echoes the returned accuracy.
layer_ablation(model, num_layers=[9, 10])

exclude layer 7,8
Accuracy: 0.9011
exclude layer 8,9
Accuracy: 0.7105
exclude layer 9,10
Accuracy: 0.8915


0.8915