In [1]:
# model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", attn_implementation="eager", torch_dtype=torch.float32)

# Modified model

In [2]:
from transformers.models.vit.modeling_vit import ViTModel, ViTAttention, ViTLayer, ViTEncoder, ViTConfig
from torch import nn
import torch
from typing import Dict, List, Optional, Set, Tuple, Union
from transformers import ViTForImageClassification
import torch.nn.functional as F
import torch.optim as optim

class DHSLayer(ViTLayer):
    def __init__(self, config):
        super().__init__(config)
        self.hidden_size = config.hidden_size
        self.mlp_layer = nn.Sequential(
            nn.Linear(self.hidden_size, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

        self.loss = 0

    def forward(self, hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False):
        keep_cls = True
        threshold = 0.5
        if keep_cls:
            mlp_output = self.mlp_layer(hidden_states[:, 1:])
            boolean_mask = (mlp_output > threshold).squeeze(-1)
            cls_col = torch.ones((hidden_states.shape[0], 1), dtype=torch.bool).to(boolean_mask.device)
            boolean_mask = torch.cat((cls_col, boolean_mask), dim=1)
        else:
            mlp_output = self.mlp_layer(hidden_states)
            boolean_mask = (mlp_output > threshold).squeeze(-1)


        filtered_states = hidden_states[boolean_mask].unsqueeze(0)
        attn_output = super().forward(filtered_states)[0].squeeze(0)

        output = hidden_states.clone()
        output[boolean_mask] = attn_output
        if self.training:
            if keep_cls:
                real_output = super().forward(hidden_states[:, 1:])[0]
                cos_similarity = torch.abs(F.cosine_similarity(real_output, hidden_states[:, 1:], dim=-1))
            else:
                real_output = super().forward(hidden_states)[0]
                cos_similarity = (F.cosine_similarity(real_output, hidden_states, dim=-1) + 1) / 2

            self.loss = nn.MSELoss()(cos_similarity, 1 - mlp_output.squeeze(-1))
        else:
            self.loss = 0
        # return super().forward(hidden_states)
        return (output, )

class ModifiedViTEncoder(ViTEncoder):
    def __init__(self, config: ViTConfig):
        super().__init__(config)
        self.layer = nn.ModuleList([DHSLayer(config) for _ in range(config.num_hidden_layers)])

class ModifiedViTModel(ViTModel):
    def __init__(self, config: ViTConfig):
        super().__init__(config)
        self.encoder = ModifiedViTEncoder(config)
        self.classifier = nn.Linear(config.hidden_size, 100)

    def forward(self, pixel_values, output_attentions=False, output_hidden_states=False, return_dict=True):
        outputs = super().forward(pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
        logits = self.classifier(outputs['last_hidden_state'][:, 0])
        output = lambda: None
        setattr(output, 'logits', logits)

        return output

config = ViTConfig.from_pretrained('google/vit-base-patch16-224-in21k')
# pretrained_vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
# pretrained_classification_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

modified_vit_model = ModifiedViTModel(config)

# modified_vit_model.load_state_dict(pretrained_vit_model.state_dict(), strict=False)
# modified_vit_model.classifier.weight.data = pretrained_classification_model.classifier.weight.data.clone()
# modified_vit_model.classifier.bias.data = pretrained_classification_model.classifier.bias.data.clone()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

In [4]:
# from PIL import Image
# import os
# from torchvision.transforms import Compose, Resize, ToTensor, Normalize
# from transformers import ViTFeatureExtractor
# import torch
# import requests


# transform = Compose([
#     Resize((224, 224)),
#     ToTensor(),
#     Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
# ])

# images_list = []
# images_list2 = []
# for image_name in os.listdir('images'):
#     print(image_name)
#     img = Image.open(os.path.join('images', image_name)).convert('RGB')
#     images_list2.append(img)
#     img_tensor = transform(img)
#     images_list.append(img_tensor)

# batch_tensor = torch.stack(images_list)

# # feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
# # inputs = feature_extractor(images=images_list2, return_tensors="pt", padding=True)


# with torch.no_grad():
#     # out1 = pretrained_vit_model(batch_tensor)['pooler_output']
#     logits = modified_vit_model(batch_tensor)


# predicted_class_indices = logits.argmax(dim=-1).tolist()

# url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
# response = requests.get(url)
# class_names = response.json()

# for idx, class_index in enumerate(predicted_class_indices):
#     print(f"Predicted class index for image {idx + 1}: {class_index}")
#     print(class_names[class_index])

In [None]:
# plt.imshow(inputs['pixel_values'].squeeze(0).permute(1, 2, 0))
# print(inputs['pixel_values'].max(), inputs['pixel_values'].min())
# plt.show()
# test_dataset[1][0].show()
# plt.imshow(test_dataset[1][0])
# plt.show()

In [5]:
from torch.utils.data import Dataset, DataLoader
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch
import matplotlib.pyplot as plt
import torchvision

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cifar_processor = AutoImageProcessor.from_pretrained("Ahmed9275/Vit-Cifar100")
pretrained_cifar_model = AutoModelForImageClassification.from_pretrained("Ahmed9275/Vit-Cifar100")

pretrained_cifar_model.to(device)


class CIFAR100Dataset(Dataset):
    def __init__(self, root, train=True, processor=None):
        self.dataset = torchvision.datasets.CIFAR100(root=root, train=train, download=True)
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        inputs = self.processor(images=image, return_tensors="pt")
        return inputs['pixel_values'].squeeze(0), label

data_path = '/content/data'
train_dataset = CIFAR100Dataset(root=data_path, train=True, processor=cifar_processor)
test_dataset = CIFAR100Dataset(root=data_path, train=False, processor=cifar_processor)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)

cuda


preprocessor_config.json:   0%|          | 0.00/228 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.68k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/344M [00:00<?, ?B/s]

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /content/data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:12<00:00, 13201391.87it/s]


Extracting /content/data/cifar-100-python.tar.gz to /content/data
Files already downloaded and verified


In [6]:
new_state_dict = {}
for key in pretrained_cifar_model.state_dict().keys():
    new_key = key.replace('vit.', '')
    new_state_dict[new_key] = pretrained_cifar_model.state_dict()[key]


modified_vit_model.load_state_dict(new_state_dict, strict=False)
modified_vit_model.to(device)

def test(model, dataloader):
    model.eval()
    total_correct = 0
    with torch.no_grad():
        for (inputs, labels) in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            logits = outputs.logits

            predicted_class_indices = logits.argmax(dim=-1)

            total_correct += torch.sum(predicted_class_indices.to('cpu') == labels).item()

    return total_correct / len(dataloader.dataset)

In [None]:
def train(model, train_loader, test_loader, num_epochs=10, lr=1e-4):
    model.train()
    criterion = nn.CrossEntropyLoss()

    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            logits = model(inputs).logits

            classification_loss = criterion(logits, labels)

            cosine_loss = 0.0
            for i in range(len(model.encoder.layer)):
                vit_layer = model.encoder.layer[i]
                cosine_loss += vit_layer.loss

            total_loss = classification_loss + 0.1 * cosine_loss

            total_loss.backward()
            optimizer.step()
            running_loss += total_loss.item()
        print(running_loss / len(train_loader))
        print(test(model, test_loader))


for param in modified_vit_model.parameters():
    param.requires_grad = False

for layer in modified_vit_model.encoder.layer:
    for param in layer.mlp_layer.parameters():
        param.requires_grad = True

modified_vit_model.to(device)

num_epochs = 10
lr = 1e-4

print(test(modified_vit_model, test_loader))
train(modified_vit_model, train_loader, test_loader, num_epochs, lr)




0.0104


# pretrained model accuracy 89.85