In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from transformers import AutoModelForImageClassification, AutoImageProcessor, AutoConfig
from transformers.models.vit.modeling_vit import ViTModel, ViTLayer, ViTEncoder, ViTConfig
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling

import sys
from loguru import logger
from typing import Optional,Union, Tuple
import os
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

current_file_name = os.path.splitext(os.path.basename(os.getcwd()))[0]
logger.remove() #remove default
logger.add(f"{current_file_name}.html", format="<b>{time}</b> {time:YYYY-MM-DD HH:mm} | {level} | {message}", mode="a")

2025-03-06 12:09:38.729412: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-06 12:09:38.745898: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741243178.765565 2454497 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741243178.771471 2454497 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-06 12:09:38.792081: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

1

In [None]:
# Define transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Load CIFAR100 dataset
train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

seed_val = 42
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

# Create data loaders
batch_size  = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

sim_threshold = 0.95
mlp_threshold = 0.95

: 

In [None]:

class NeuralNet(nn.Module):
    def __init__(self, arr,use_batch_norm=True):
        super().__init__()
        layers = []
        for i in range(len(arr) - 2):
            layers.append(nn.Linear(arr[i], arr[i+1]))
            if use_batch_norm:
                layers.append(nn.BatchNorm1d(arr[i+1]))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(arr[-2],arr[-1]));
        layers.append(nn.Sigmoid())
        self.size = len(arr);
        self.model = nn.Sequential(*layers)
        self.criterion =  nn.BCELoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
    
    def forward(self, x):
        return self.model(x)
    
    def train(self, inputs, targets):
        self.optimizer.zero_grad()
        outputs = self.forward(inputs)
        loss = self.criterion(outputs, targets)
        loss.backward()
        self.optimizer.step()
        return loss.item()



: 

In [None]:

class ModifyLayer(ViTLayer):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embedding_size = config.hidden_size
        self.mlp = NeuralNet([2 * self.embedding_size, 64, 1])

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        original: Optional[torch.Tensor] = None,  # custom parameter
        train_mlp : bool = False
        ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:

        outputs = ()

        cls_tokens = hidden_states[:, 0, :].unsqueeze(1) #[2, 1, 768]          
        non_cls_tokens = hidden_states[:, 1:, :]        # [2, 196, 768]
        cls_expanded = cls_tokens.expand(-1, non_cls_tokens.size(1), -1)  # Now shape: [2, 196, 768]
        mlp_inputs = torch.cat((cls_expanded, non_cls_tokens), dim=-1) # Result shape: [2, 196, 1536]
        mlp_inputs = mlp_inputs.reshape(-1,2*self.embedding_size) # result  shape: [2 * 196, 1536]
        attention_mask = None

        if train_mlp:
            # 1) get the original outputs for netx layer
            with torch.no_grad():
                original_output = super().forward(original)
                original = original_output[0]

            # 2) getting similarity score
            num_tokens = len(original[0])
            similarity = F.cosine_similarity(original.reshape(batch_size*num_tokens,-1),hidden_states.reshape(batch_size*num_tokens,-1), dim=1)
            similarity = similarity.reshape(batch_size,num_tokens)
            attention_mask = similarity < sim_threshold #1 need to be used in transformer
            attention_mask[:,0] = True
        
            # 3) mlp
            mlp_loss = self.mlp.train(mlp_inputs.clone().detach(),attention_mask[:,1:].reshape(-1).unsqueeze(-1).float().clone().detach()) #att_mask shape: [2 * 196, 1]
        else:
            attention_mask = self.mlp(mlp_inputs) < mlp_threshold #1 need to be used in transformer
            attention_mask[:,0] = True
            
 
        for i in range(attention_mask.size(0)):
            trimmed_input = hidden_states[i][attention_mask[i] == 1]
            layer_output = super().forward(trimmed_input.unsqueeze(0)) # expect 3d tensor
            hidden_states[i][attention_mask[i] == 1] = layer_output[0]
        

        outputs = (hidden_states,) + outputs
        return outputs, original


: 

In [None]:

class ModifyEncoder(ViTEncoder):
    def __init__(self,config):
        super().__init__(config)
        self.layer = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            self.layer.add_module(f'layer{i}',ModifyLayer(config))

    def forward(self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        output_mask: bool = False,
        train_mlp : bool = False,
        ) -> Union[tuple, BaseModelOutput]:

        # optional to necessary
        all_hidden_states = () if output_hidden_states else None
        all_boolean_mask = () if output_mask else None
        all_self_attentions = () if output_attentions else None

        original = hidden_states.clone().detach().to(device)  # make deep copy
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs,original= layer_module(hidden_states, layer_head_mask, output_attentions,original,train_mlp)
                print("loop done")

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
            if output_mask:
                all_boolean_mask = all_boolean_mask + (layer_outputs[1],)
        
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        ), all_boolean_mask



: 

In [None]:

# https://huggingface.co/docs/transformers/en/model_doc/vit#transformers.ViTModel (for more info on parameters)

class Modifymodel(ViTModel): # return Union[tuple, BaseModelOutputwithpooling]
    def __init__(self,config,processor):
        super().__init__(config)
        self.config = config
        self.encoder = ModifyEncoder(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.processor = processor

    def forward( self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None, # bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_mask: Optional[bool] = None,
        train_mlp : bool = False,
        ) -> Union[Tuple, BaseModelOutputWithPooling]:

        # converting optinal type to normal type
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Prepare head mask if needed 1.0 in head_mask indicate we keep the head and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        print("head_mask",head_mask)
        # start
        # 1) add embeddings after processing
        embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding)
        print('embedding_output',embedding_output.shape)
        encoder_outputs, boolean_masks = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            output_mask=output_mask,
            train_mlp = train_mlp
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        
        if not return_dict: # Whether or not to return a ModelOutput instead of a plain tuple.
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]
        else:
            outputs = BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            )
            logits = self.classifier(outputs['last_hidden_state'][:, 0])
            setattr(outputs, 'logits', logits)
            setattr(outputs, 'boolean_masks', boolean_masks)
            return outputs
        

    def load_weights(self, pretrained_model_name_or_path: str):
        pretrained_model = ViTModel.from_pretrained(pretrained_model_name_or_path)
        self.load_state_dict(pretrained_model.state_dict(), strict=False)


    def train(self, train_loader=train_loader, epochs=1):
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        loss_history = []

        for epoch in range(epochs):
            epoch_loss = 0
            for inputs, targets in train_loader:
                processed_inputs = self.processor(inputs, return_tensors="pt")
                pixel_values = processed_inputs["pixel_values"].to(device)
                targets = targets.to(device)

                outputs = self.forward(pixel_values=pixel_values, train_mlp=True).logits
                print(outputs)
                loss = criterion(outputs, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

            avg_loss = epoch_loss / len(train_loader)
            loss_history.append(avg_loss)
            logger.info(f"<p>Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}</p>")


        plt.plot(loss_history, label='Training Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title('Training Loss Curve')
        plt.legend()
        plot_path = "outputs/training_loss.png"
        plt.savefig(plot_path)
        plt.close()

        logger.info(f'<p>Training loss plot saved at: {plot_path}</p>')
        logger.info(f'<img src="{plot_path}" alt="Training Loss Curve">')

    def test(self, test_loader=test_loader):
        criterion = torch.nn.CrossEntropyLoss()
        total_loss = 0

        with torch.no_grad():
            for inputs, targets in test_loader:
                processed_inputs = self.processor(inputs, return_tensors="pt")
                pixel_values = processed_inputs["pixel_values"].to(device)
                targets = targets.to(device)

                outputs = self.forward(pixel_values=pixel_values, train_mlp=False).logits
                loss = criterion(outputs, targets)
                total_loss += loss.item()

        avg_loss = total_loss / len(test_loader)
        logger.info(f"<p>Test Loss: {avg_loss:.4f}</p>")



: 

In [2]:
import inspect

def get_source_code(obj):
    try:
        return inspect.getsource(obj.__class__)
    except Exception as e:
        return f"Unable to retrieve source code: {str(e)}"

In [3]:
model_name = "tzhao3/DeiT-CIFAR100"
model = AutoModelForImageClassification.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)
model

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [5]:
print(get_source_code(model))

@add_start_docstrings(
    """
    ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.

    <Tip>

        Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by
        setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
        position embeddings to the higher resolution.

    </Tip>
    """,
    VIT_START_DOCSTRING,
)
class ViTForImageClassification(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vit = ViTModel(config, add_pooling_layer=False)

        # Classifier head
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self

In [None]:

modifymodel = Modifymodel(config,processor)
modifymodel.load_weights(model_name)
modifymodel.to(device)


Some weights of ViTModel were not initialized from the model checkpoint at tzhao3/DeiT-CIFAR100 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Modifymodel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ModifyEncoder(
    (layer): ModuleList(
      (0-11): 12 x ModifyLayer(
        (attention): ViTSdpaAttention(
          (attention): ViTSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn)

In [None]:
!nvidia-smi

Sun Mar  2 15:27:52 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA RTX A5000               Off | 00000000:3B:00.0 Off |                  Off |
| 30%   23C    P8              18W / 230W |   7589MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A5000               Off | 00000000:AF:00.0 Off |  

In [6]:
# Clear all variables with confirmation
%reset