In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTModel, ViTConfig, ViTForImageClassification
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


cache_dir = "/home/egeberk/ceng502/"
num_classes = 10
model_name = 'google/vit-base-patch32-224-in21k'

config = ViTConfig.from_pretrained(model_name, num_labels=num_classes, cache_dir=cache_dir)
model = ViTForImageClassification.from_pretrained(model_name, config=config, cache_dir=cache_dir)


2023-05-06 13:37:12.044406: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-05-06 13:37:12.692283: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2023-05-06 13:37:12.692933: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
Some weights of the model checkpoint at google/vit-base-patch32-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.

In [2]:
from tqdm.notebook import tqdm

In [3]:
model = model.to(device)

In [4]:
# Freeze all layers except the head
for name, param in model.named_parameters():
    if "classifier" not in name:
        param.requires_grad = False

In [5]:
# Load and preprocess the dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


In [6]:
learning_rate = 3e-4
batch_size = 124
num_epochs = 10 

train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Set the optimizer and loss function
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

Files already downloaded and verified


In [7]:
def train(modelVit):
    # Train the model
    
    train_loss = 0.0
    for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):

        data = data.to(device)
        targets = targets.to(device)

        # Forward pass
        outputs = modelVit(data)
        try:
            loss = criterion(outputs.logits, targets)
        except:
            loss = criterion(outputs, targets)
        loss_val = loss.detach().cpu().item()
        train_loss += loss_val
        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Gradient descent step
        optimizer.step()

        
    return train_loss/len(train_loader)

In [10]:
train(model)

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [1/10], Loss: 1.0738051921719372


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [2/10], Loss: 0.43248644522806207


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [3/10], Loss: 0.33499665012454044


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [4/10], Loss: 0.2948971911158302


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [5/10], Loss: 0.2710435390287992


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [6/10], Loss: 0.2558640625586014


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [7/10], Loss: 0.24442257984808766


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [8/10], Loss: 0.23512031082617174


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [9/10], Loss: 0.22762656370454495


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [10/10], Loss: 0.22125080657551194
Training complete!


In [8]:
valid_dataset = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified


In [9]:
def valid(modelVit):
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for batch_idx, (data, targets) in enumerate(tqdm(valid_loader)):
            data = data.to(device)
            targets = targets.to(device)

            # calculate outputs by running images through the network
            outputs = modelVit(data)
            # the class with the highest energy is what we choose as prediction
            try:
                _, predicted = torch.max(outputs.logits, 1)
            except:
                _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    
    return 100 * correct // total

    

In [11]:
valid(model)

  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 92 %


# Class + head fine tune


In [3]:
model = model.to(device)

# Freeze all layers except the head
for name, param in model.named_parameters():
    if ("classifier" not in name) and ("cls_token" not in name) :
        param.requires_grad = False
    else:
        print(name)
        
        
# Load and preprocess the dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


learning_rate = 3e-4
batch_size = 124
num_epochs = 10 

train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Set the optimizer and loss function
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
train(model)

vit.embeddings.cls_token
classifier.weight
classifier.bias
Files already downloaded and verified


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [1/10], Loss: 1.0607944261467104


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [2/10], Loss: 0.3625623905053823


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [3/10], Loss: 0.2548982919664076


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [4/10], Loss: 0.2127182101185369


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [5/10], Loss: 0.1896207969598841


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [6/10], Loss: 0.17512742105391946


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [7/10], Loss: 0.1644193258800424


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [8/10], Loss: 0.1563596963181649


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [9/10], Loss: 0.14986024178232593


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [10/10], Loss: 0.14483140294652175
Training complete!


In [4]:
valid_dataset = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified


In [5]:
valid(model)

  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 94 %


# Memory Token

In [10]:
from typing import Dict, List, Optional, Set, Tuple, Union
import math

In [11]:
class CustomViTLayerForMemory(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTConfig, attention, intermediate, output, layernorm_b, layernorm_a, memory_token) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = attention #ViTAttention(config)
        self.intermediate = intermediate #ViTIntermediate(config)
        self.output = output #ViTOutput(config)
        self.layernorm_before = layernorm_b #nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = layernorm_a #nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        
        self.memory_token = memory_token # torch.nn.Parameter(torch.randn(memory_token_length, embed_dim)) * 0.02

    def forward(
        self,
        hidden_states: torch.Tensor,
        #memory_token: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        
        original_states = hidden_states
        
        hidden_states = torch.cat((hidden_states, self.memory_token.expand(hidden_states.shape[0], -1, -1)), dim=1)
        
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViT, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        
        attention_output = self_attention_outputs[0][:,:-self.memory_token.shape[0]] #only input attends
        """
        else:
            attention_output = self_attention_outputs[0]
        """
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + original_states

        # in ViT, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs


In [12]:
class PaperModel(nn.Module):
    def __init__(self, config, vit, embed_dim = 768, memory_token_length = 10):
        super(PaperModel, self).__init__()
        self.memory_token_length = memory_token_length
        self.embed_dim = embed_dim
        self.model = vit
        self.memory_token = nn.ParameterList([torch.nn.Parameter(torch.randn(self.memory_token_length, self.embed_dim) * 0.02)
                                           for _ in range(len(self.model.vit.encoder.layer))])
        for index, layer_ in enumerate(self.model.vit.encoder.layer):
            self.model.vit.encoder.layer[index] = CustomViTLayerForMemory(config, 
                                          layer_.attention,
                                          layer_.intermediate,
                                          layer_.output,
                                          layer_.layernorm_before,
                                          layer_.layernorm_after,
                                         self.memory_token[index])
            
            
        
        #self.memory_token = torch.nn.Parameter(
        #    torch.randn(self.memory_token_length, self.embed_dim)
        #) #memory token
        
        
    def forward(self, x):
        hidden_states = self.model.vit.embeddings(x)
        
        for layer in self.model.vit.encoder.layer:
            hidden_states = layer(hidden_states)[0]#layer(hidden_states, self.memory_token)[0]            
        
        normalized = self.model.vit.layernorm(hidden_states)
        logits = self.model.classifier(normalized[:, 0, :])
        return logits

In [13]:
pm = PaperModel(config, model).to(device)

In [14]:
#memory + cls + head
for name, param in pm.named_parameters():
    if ("classifier" not in name) and ("cls_token" not in name) and ("memory_token" not in name):
        param.requires_grad = False
    else:
        param.requires_grad = True
        print(name)

model.vit.embeddings.cls_token
model.vit.encoder.layer.0.memory_token
model.vit.encoder.layer.1.memory_token
model.vit.encoder.layer.2.memory_token
model.vit.encoder.layer.3.memory_token
model.vit.encoder.layer.4.memory_token
model.vit.encoder.layer.5.memory_token
model.vit.encoder.layer.6.memory_token
model.vit.encoder.layer.7.memory_token
model.vit.encoder.layer.8.memory_token
model.vit.encoder.layer.9.memory_token
model.vit.encoder.layer.10.memory_token
model.vit.encoder.layer.11.memory_token
model.classifier.weight
model.classifier.bias


In [16]:
# Set the optimizer and loss function
optimizer = optim.Adam(filter(lambda p: p.requires_grad, pm.parameters()), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [None]:
num_epochs = 300
train_losses = []
acc_vals = []
for epoch in range(num_epochs):
    train_loss = train(pm)
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss}")
    train_losses.append(train_loss)
    
    acc_val = valid(pm)
    
    print(f'Accuracy of the network on the 10000 test images: {acc_val} %')
    acc_vals.append(acc_val)

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [1/300], Loss: 0.4820975792216192


  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 96 %


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [2/300], Loss: 0.09054598888834145


  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 97 %


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [3/300], Loss: 0.06677886168472469


  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 97 %


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [4/300], Loss: 0.05176729497057155


  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 97 %


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [5/300], Loss: 0.045392660817735374


  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 97 %


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [6/300], Loss: 0.03684841303009595


  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 97 %


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [7/300], Loss: 0.030220185592655314


  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 97 %


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [8/300], Loss: 0.02457816515574324


  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 97 %


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [9/300], Loss: 0.021791133649836265


  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 97 %


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [10/300], Loss: 0.017969800122306995


  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 97 %


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [11/300], Loss: 0.014483976960216768


  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 97 %


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [12/300], Loss: 0.01172283375756769


  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 97 %


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [13/300], Loss: 0.010527419917761478


  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 97 %


  0%|          | 0/404 [00:00<?, ?it/s]

In [19]:
train_losses

[2.0987822932772118,
 1.7861921536450338,
 1.5670567500119161,
 1.4042141891942166,
 1.2790964316613604,
 1.180181681077079,
 1.1007926406541673,
 1.0349034649310727,
 0.9799850758644614,
 0.9323421806687175,
 0.89216248809111,
 0.8562249011627519,
 0.8251895919294641,
 0.797393385431554,
 0.7721824296335182,
 0.7494477324261524,
 0.7291482809451547,
 0.7103912710848421,
 0.6936880566991201,
 0.6776619774870353,
 0.6631273078446341,
 0.6502166125591439,
 0.638310113608247,
 0.6269282804857387,
 0.615705787840456,
 0.6054490137808394,
 0.5963591269337305,
 0.5878394600926059,
 0.5788741146426389,
 0.5711343225718725,
 0.5637091265456511,
 0.5570303041598584,
 0.5499377931551178,
 0.5435361485965181,
 0.5371362156059483,
 0.5313240162069255,
 0.5258289616886932,
 0.5205286705110332,
 0.5159834059010638,
 0.5109650904294287,
 0.5059272140116975,
 0.5013424431038375,
 0.49807981335290585,
 0.49250793659893594,
 0.48881919158272225,
 0.4847982216737058,
 0.48116562087642084,
 0.477603739514

In [20]:
acc_vals

[62,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 76,
 77,
 77,
 78,
 78,
 79,
 79,
 79,
 80,
 80,
 81,
 81,
 81,
 81,
 82,
 82,
 82,
 82,
 82,
 83,
 83,
 83,
 83,
 83,
 84,
 84,
 84,
 84,
 84,
 84,
 84,
 84,
 84,
 84,
 85,
 85,
 85,
 85,
 85,
 85,
 85,
 85,
 85,
 85,
 85,
 85,
 85,
 85,
 86,
 86,
 86,
 86,
 86,
 86,
 86,
 86,
 86,
 86,
 86,
 86,
 86,
 86,
 86,
 86,
 86,
 86,
 86,
 86,
 87,
 86,
 86,
 87,
 87,
 87,
 87,
 87,
 87,
 87,
 87,
 87,
 87,
 87,
 87,
 87,
 87,
 87,
 87,
 87,
 87,
 87,
 87,
 87]

In [24]:
for epoch in range(num_epochs):
    train_loss = train(pm)
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss}")
    train_losses.append(train_loss)
    
    acc_val = valid(pm)
    
    print(f'Accuracy of the network on the 10000 test images: {acc_val} %')
    acc_vals.append(acc_val)

  0%|          | 0/404 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 74.00 MiB (GPU 0; 7.79 GiB total capacity; 5.03 GiB already allocated; 23.12 MiB free; 5.48 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

No memory token

In [11]:
class CustomViTLayerForMemory(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTConfig, attention, intermediate, output, layernorm_b, layernorm_a)-> None:#, memory_token) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = attention #ViTAttention(config)
        self.intermediate = intermediate #ViTIntermediate(config)
        self.output = output #ViTOutput(config)
        self.layernorm_before = layernorm_b #nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = layernorm_a #nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        
        #self.memory_token = memory_token # torch.nn.Parameter(torch.randn(memory_token_length, embed_dim)) * 0.02

    def forward(
        self,
        hidden_states: torch.Tensor,
        #memory_token: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        
        original_states = hidden_states
        
        #hidden_states = torch.cat((hidden_states, self.memory_token.expand(hidden_states.shape[0], -1, -1)), dim=1)
        
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViT, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        
        attention_output = self_attention_outputs[0]#[:,:-self.memory_token.shape[0]] #only input attends
        """
        else:
            attention_output = self_attention_outputs[0]
        """
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + original_states

        # in ViT, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs




class PaperModel(nn.Module):
    def __init__(self, config, vit, embed_dim = 768):#, memory_token_length = 10):
        super(PaperModel, self).__init__()
        #self.memory_token_length = memory_token_length
        self.embed_dim = embed_dim
        self.model = vit
        #self.memory_token = nn.ParameterList([torch.nn.Parameter(torch.randn(self.memory_token_length, self.embed_dim) * 0.02)
        #                                   for _ in range(len(self.model.vit.encoder.layer))])
        for index, layer_ in enumerate(self.model.vit.encoder.layer):
            self.model.vit.encoder.layer[index] = CustomViTLayerForMemory(config, 
                                          layer_.attention,
                                          layer_.intermediate,
                                          layer_.output,
                                          layer_.layernorm_before,
                                          layer_.layernorm_after)#,
                                         #self.memory_token[index])
            
            
        
        #self.memory_token = torch.nn.Parameter(
        #    torch.randn(self.memory_token_length, self.embed_dim)
        #) #memory token
        
        
    def forward(self, x):
        hidden_states = self.model.vit.embeddings(x)
        
        for layer in self.model.vit.encoder.layer:
            hidden_states = layer(hidden_states)[0]#layer(hidden_states, self.memory_token)[0]            
        
        normalized = self.model.vit.layernorm(hidden_states)
        logits = self.model.classifier(normalized[:, 0, :])
        return logits

pm = PaperModel(config, model).to(device)

#memory + cls + head
for name, param in pm.named_parameters():
    if ("classifier" not in name) and ("cls_token" not in name) and ("memory_token" not in name):
        param.requires_grad = False
    else:
        print(name)

model.vit.embeddings.cls_token
model.classifier.weight
model.classifier.bias


In [12]:
train(pm)

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [1/10], Loss: 1.0749211574692537


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [2/10], Loss: 0.43293110443518895


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [3/10], Loss: 0.3343392747758639


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [4/10], Loss: 0.29389287282408466


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [5/10], Loss: 0.27153268386379326


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [6/10], Loss: 0.25559217267033485


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [8/10], Loss: 0.23523561197268492


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [9/10], Loss: 0.22781454978307875


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [10/10], Loss: 0.221015185385131
Training complete!


In [14]:
valid(pm)

  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 92 %


In [25]:
train(pm)

  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [1/10], Loss: 0.27229671062219263


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [2/10], Loss: 0.2285295785543057


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [3/10], Loss: 0.21214430391936018


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [4/10], Loss: 0.20146242372396558


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [5/10], Loss: 0.19165904060153677


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [6/10], Loss: 0.18456119456614303


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [7/10], Loss: 0.1778454251107898


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [8/10], Loss: 0.1725820623670179


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [9/10], Loss: 0.16802955144960988


  0%|          | 0/404 [00:00<?, ?it/s]

Epoch [10/10], Loss: 0.16291090112529916
Training complete!


In [26]:
valid(pm)

  0%|          | 0/81 [00:00<?, ?it/s]

Accuracy of the network on the 10000 test images: 93 %


In [82]:
for batch_idx, (data, targets) in enumerate(tqdm(valid_loader)):
    data = data.to(device)
    break

  0%|          | 0/81 [00:00<?, ?it/s]

In [84]:
pm(data)

tensor([[-0.0278, -0.8870,  1.5921,  ..., -1.4284, -0.4906, -0.6489],
        [ 0.3318, -0.6451,  1.4786,  ..., -1.2412,  0.2162, -0.3718],
        [ 0.2759, -0.6854,  1.3918,  ..., -1.2371,  0.3058, -0.3987],
        ...,
        [-0.0230, -0.9826,  1.5481,  ..., -1.1421, -0.3711, -0.2848],
        [-0.1069, -0.4906,  1.4808,  ..., -1.5035, -0.3257, -0.0312],
        [ 0.1012, -1.0433,  2.1145,  ..., -1.3779, -0.4260, -0.6846]],
       device='cuda:0', grad_fn=<AddmmBackward0>)

In [71]:
class customAttentionWithMemory(nn.Module):
    def __init__(self, config: ViTConfig, query, key, value, dropout) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = query
        self.key = key
        self.value = value

        self.dropout = dropout

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, q, k, v):
        mixed_query_layer = self.query(q)

        key_layer = self.transpose_for_scores(self.key(k))
        value_layer = self.transpose_for_scores(self.value(v))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer,)

        return outputs

In [72]:
class PaperModel(nn.Module):
    def __init__(self, vit, embed_dim = 768, memory_token_length = 10):
        super(PaperModel, self).__init__()
        self.memory_token_length = memory_token_length
        self.embed_dim = embed_dim
        self.model = vit
        self.memory_token = torch.nn.Parameter(
            torch.randn(self.memory_token_length, self.embed_dim)
        ) #memory token
        for layer in self.model.vit.encoder:
            att_model = layer.attention
            layer.attention = customAttentionWithMemory(config, att_model.attention.query,\
                                                        att_model.attention.key, att_model.attention.value,att_model.attention.dropout)
        
        
    def forward(self, x):
        input = self.model.vit.embeddings(x)
        for layer in self.model.vit.encoder.layer:
            
        
        att_out = self.layer(x, self.number_of_head)
        att_residual_out = att_out + x
        norm1_out = self.dropout1(self.norm1(att_residual_out))
        ff_out = self.MLP_sequence(norm1_out)
        ff_res_out = ff_out + norm1_out
        norm2_out = self.dropout2(self.norm2(ff_res_out))
        return norm2_out

IndentationError: expected an indented block after 'for' statement on line 17 (1751082145.py, line 20)

In [73]:
config.num_attention_heads

12

In [23]:
model.vit.encoder.layer[0](k)

(tensor([[[ 1.1750,  1.2434,  0.0246,  ...,  1.0854,  0.5624,  1.8326],
          [ 0.5839,  0.4748, -0.9245,  ...,  0.4982,  0.0449,  1.3741],
          [-0.5252,  1.6358, -0.8340,  ...,  0.5063,  0.6809,  1.2358],
          ...,
          [ 1.5790,  1.5951, -0.5681,  ..., -0.0584,  0.5637,  0.7831],
          [ 0.4778,  0.8916, -0.3458,  ...,  0.7740,  0.7255,  1.5669],
          [-0.1720,  2.2242,  0.1825,  ...,  0.7067,  0.1703,  1.1183]],
 
         [[ 0.0317,  1.9712, -1.1072,  ...,  1.0647,  0.2739,  1.5160],
          [-0.5021,  2.0634, -1.3776,  ...,  0.7768,  0.0027,  1.1941],
          [ 0.5766,  0.7833, -0.5032,  ...,  0.6730, -0.1111,  0.7818],
          ...,
          [ 0.6456,  1.2389, -1.0500,  ...,  0.8230, -0.2533,  1.4723],
          [-0.2427,  0.3021, -1.5380,  ...,  0.9418, -0.4302,  0.5285],
          [ 0.6844,  1.5814, -0.3445,  ...,  0.5765, -0.6269,  1.1135]],
 
         [[ 0.0652,  0.3140, -0.6031,  ...,  0.4982,  1.1300,  2.2200],
          [ 0.7496,  1.8503,

In [24]:
q = torch.rand([28, 50, 768])
expanded_memory = torch.randn((10, 768))


In [31]:
expanded_memory.expand(28, -1, -1).shape

torch.Size([28, 10, 768])

In [36]:
s = torch.cat((q, expanded_memory.expand(28, -1, -1)), dim=1)

In [37]:
s.shape

torch.Size([28, 60, 768])

In [39]:
s[:,:-expanded_memory.shape[0]].shape

torch.Size([28, 50, 768])

In [16]:
q = torch.rand(s.shape).to(device)
k = torch.rand([28, 60, 768]).to(device)
v = torch.rand([28, 60, 768]).to(device)

In [25]:
q_a = model.vit.encoder.layer[0].attention.attention.query(q)
k_a = model.vit.encoder.layer[0].attention.attention.key(k)


In [26]:
attention = q.matmul(k.transpose(-1, -2)) 

In [51]:
layer_ = model.vit.encoder.layer[0]
layer_

ViTLayer(
  (attention): ViTAttention(
    (attention): ViTSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (output): ViTSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
  )
  (intermediate): ViTIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): ViTOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)

In [52]:
custom_layer_ = CustomViTLayerForMemory(config, layer_.attention, layer_.intermediate, layer_.output, layer_.layernorm_before, layer_.layernorm_after)

In [55]:
custom_layer_ = custom_layer_.to(device)

In [56]:
q = q.to(device)
expanded_memory = expanded_memory.to(device)

In [59]:
custom_layer_(q, expanded_memory)[0].shape

torch.Size([28, 50, 768])

In [66]:
model.

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0): ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_

In [28]:
attention.shape

torch.Size([28, 50, 60])

In [74]:
att_model = model.vit.encoder.layer[0].attention

ca = customAttentionWithMemory(config, att_model.attention.query, att_model.attention.key, att_model.attention.value, att_model.attention.dropout)

In [80]:
v.shape

torch.Size([28, 60, 768])

In [77]:
ca(q,k,v)[0].shape

torch.Size([28, 50, 768])

In [55]:
model.vit.config.attention_probs_dropout_prob

0.0

In [58]:
model.vit.encoder.layer[0]

ViTLayer(
  (attention): ViTAttention(
    (attention): ViTSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (output): ViTSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
  )
  (intermediate): ViTIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): ViTOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)

In [24]:
for i,p in model.named_parameters():
    print(i)

vit.embeddings.cls_token
vit.embeddings.position_embeddings
vit.embeddings.patch_embeddings.projection.weight
vit.embeddings.patch_embeddings.projection.bias
vit.encoder.layer.0.attention.attention.query.weight
vit.encoder.layer.0.attention.attention.query.bias
vit.encoder.layer.0.attention.attention.key.weight
vit.encoder.layer.0.attention.attention.key.bias
vit.encoder.layer.0.attention.attention.value.weight
vit.encoder.layer.0.attention.attention.value.bias
vit.encoder.layer.0.attention.output.dense.weight
vit.encoder.layer.0.attention.output.dense.bias
vit.encoder.layer.0.intermediate.dense.weight
vit.encoder.layer.0.intermediate.dense.bias
vit.encoder.layer.0.output.dense.weight
vit.encoder.layer.0.output.dense.bias
vit.encoder.layer.0.layernorm_before.weight
vit.encoder.layer.0.layernorm_before.bias
vit.encoder.layer.0.layernorm_after.weight
vit.encoder.layer.0.layernorm_after.bias
vit.encoder.layer.1.attention.attention.query.weight
vit.encoder.layer.1.attention.attention.query

In [23]:
for i,p in model.vit.named_parameters():
    print(i)

embeddings.cls_token
embeddings.position_embeddings
embeddings.patch_embeddings.projection.weight
embeddings.patch_embeddings.projection.bias
encoder.layer.0.attention.attention.query.weight
encoder.layer.0.attention.attention.query.bias
encoder.layer.0.attention.attention.key.weight
encoder.layer.0.attention.attention.key.bias
encoder.layer.0.attention.attention.value.weight
encoder.layer.0.attention.attention.value.bias
encoder.layer.0.attention.output.dense.weight
encoder.layer.0.attention.output.dense.bias
encoder.layer.0.intermediate.dense.weight
encoder.layer.0.intermediate.dense.bias
encoder.layer.0.output.dense.weight
encoder.layer.0.output.dense.bias
encoder.layer.0.layernorm_before.weight
encoder.layer.0.layernorm_before.bias
encoder.layer.0.layernorm_after.weight
encoder.layer.0.layernorm_after.bias
encoder.layer.1.attention.attention.query.weight
encoder.layer.1.attention.attention.query.bias
encoder.layer.1.attention.attention.key.weight
encoder.layer.1.attention.attention

In [12]:
s = model.vit.embeddings(data)

In [26]:
model.vit.embeddings.cls_token.requires_grad 

False

In [18]:
for i, layer in enumerate(model.vit.encoder.layer):
    print(i, layer)

0 ViTLayer(
  (attention): ViTAttention(
    (attention): ViTSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (output): ViTSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
  )
  (intermediate): ViTIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): ViTOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
1 ViTLayer(
  (attention): ViTAttention(
    (attention): ViTSelfAttentio

In [13]:
s.shape

torch.Size([784, 50, 768])

In [36]:
a = model.vit.encoder(s)

In [42]:
model.vit.layernorm(a.last_hidden_state)[:,0].shape

torch.Size([32, 768])

In [44]:
model.classifier(model.vit.layernorm(a.last_hidden_state)[:,0]).shape

torch.Size([32, 10])