In [15]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from models.vit import VisionTransformer

In [16]:
import torch

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

Using device: cpu


In [17]:
# Instantiate the Vision Transformer with this notebook's hyper-parameters.
# NOTE(review): 32x32 RGB inputs with 10 classes looks like a CIFAR-10 style
# setup -- confirm against the dataset cell.
model = VisionTransformer(
    img_size=32, patch_size=4, in_channels=3,   # input geometry
    num_classes=10,                             # classifier output size
    embed_dim=256, num_heads=8,                 # attention width / heads
    hidden_dim=512, num_layers=6,               # feed-forward width / depth
    dropout=0.1,
)
# Place the model on the device chosen earlier in the notebook.
model = model.to(device)



In [18]:
# Classification loss applied to the model's class scores.
criterion = nn.CrossEntropyLoss()

# AdamW over every model parameter, with explicit learning rate and
# weight decay.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-2,
)

In [None]:
# Mini-batch size shared by the training and evaluation loaders.
batch_size = 128

# NOTE(review): train_dataset / test_dataset are not created in the cells
# visible here; they must already exist in the kernel before this cell runs.

# Training loader: shuffling enabled.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# Evaluation loader: shuffling disabled so the dataset order is kept.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

In [20]:
def relu_kernel_feature_map(x):
    """Apply the ReLU feature map element-wise.

    Negative entries of ``x`` become zero and non-negative entries pass
    through unchanged; the result has the same shape as ``x``.
    """
    features = torch.relu(x)
    return features

In [21]:
class PerformerAttention(nn.Module):
    """Multi-head linear attention built on a ReLU kernel feature map.

    Instead of forming the ``[tgt_len, src_len]`` attention matrix, keys and
    values are first contracted into a per-head ``[head_dim, head_dim]``
    summary that every query position then reads from, so the cost grows
    linearly with the sequence lengths.

    The feature map is the module-level ``relu_kernel_feature_map`` function.

    Args:
        embed_dim: Size of the last dimension of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        normalize: If True, divide each query's output by the sum of its
            kernel similarities to all keys (the usual linear-attention
            denominator). Defaults to False, which keeps the unnormalised
            output this module has always produced.
        eps: Small constant added to the denominator when ``normalize`` is
            True, guarding against division by zero when a query has no
            positive similarity to any key.

    Raises:
        AssertionError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, normalize=False, eps=1e-6):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.normalize = normalize
        self.eps = eps

        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.query_proj = nn.Linear(embed_dim, embed_dim)
        self.key_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value):
        """Compute linear attention.

        Args:
            query: Tensor of shape ``[batch_size, tgt_len, embed_dim]``.
            key: Tensor of shape ``[batch_size, src_len, embed_dim]``.
            value: Tensor of shape ``[batch_size, src_len, embed_dim]``.

        Returns:
            Tensor of shape ``[batch_size, tgt_len, embed_dim]``.

        Raises:
            ValueError: If ``key`` and ``value`` have different sequence
                lengths.
        """
        batch_size, tgt_len, embed_dim = query.size()
        # Keys/values may have a different length than the queries
        # (cross-attention), so their length is read from `key`, not `query`.
        src_len = key.size(1)
        if value.size(1) != src_len:
            raise ValueError(
                "key and value must have the same sequence length, got "
                f"{src_len} and {value.size(1)}"
            )

        # Linear projections, then split the embedding into heads.
        Q = self.query_proj(query).view(batch_size, tgt_len, self.num_heads, self.head_dim)
        K = self.key_proj(key).view(batch_size, src_len, self.num_heads, self.head_dim)
        V = self.value_proj(value).view(batch_size, src_len, self.num_heads, self.head_dim)

        # Kernel feature map on queries and keys only; values stay untouched.
        Q = relu_kernel_feature_map(Q)  # [batch_size, tgt_len, num_heads, head_dim]
        K = relu_kernel_feature_map(K)  # [batch_size, src_len, num_heads, head_dim]

        # Contract over the source positions first (linear complexity).
        KV = torch.einsum('bshd,bshe->bhde', K, V)  # [batch_size, num_heads, head_dim, head_dim]
        QKV = torch.einsum('blhd,bhde->blhe', Q, KV)  # [batch_size, tgt_len, num_heads, head_dim]

        if self.normalize:
            # Denominator: phi(q) . sum_s phi(k_s), one scalar per (query, head).
            K_sum = K.sum(dim=1)  # [batch_size, num_heads, head_dim]
            denom = torch.einsum('blhd,bhd->blh', Q, K_sum)  # [batch_size, tgt_len, num_heads]
            QKV = QKV / (denom.unsqueeze(-1) + self.eps)

        # Merge heads and project out.
        QKV = QKV.reshape(batch_size, tgt_len, embed_dim)
        output = self.out_proj(QKV)

        return output

In [25]:
import torch

# Load the saved checkpoint and report what kind of object it contains.
checkpoint_path = 'best_ViT.pth'
# weights_only=True restricts unpickling to tensors and plain containers.
checkpoint = torch.load(
    checkpoint_path,
    map_location=device,
    weights_only=True,
)
print("Checkpoint type:", type(checkpoint))

# Any dict subclass (including OrderedDict) exposes its keys; list them.
if not isinstance(checkpoint, dict):
    print("Checkpoint is not a dictionary.")
else:
    print("Checkpoint keys:", checkpoint.keys())

Checkpoint type: <class 'collections.OrderedDict'>
Checkpoint keys: odict_keys(['cls_token', 'patch_embed.proj.weight', 'patch_embed.proj.bias', 'pos_embed.pe', 'encoder.transformer_encoder.layers.0.self_attn.in_proj_weight', 'encoder.transformer_encoder.layers.0.self_attn.in_proj_bias', 'encoder.transformer_encoder.layers.0.self_attn.out_proj.weight', 'encoder.transformer_encoder.layers.0.self_attn.out_proj.bias', 'encoder.transformer_encoder.layers.0.linear1.weight', 'encoder.transformer_encoder.layers.0.linear1.bias', 'encoder.transformer_encoder.layers.0.linear2.weight', 'encoder.transformer_encoder.layers.0.linear2.bias', 'encoder.transformer_encoder.layers.0.norm1.weight', 'encoder.transformer_encoder.layers.0.norm1.bias', 'encoder.transformer_encoder.layers.0.norm2.weight', 'encoder.transformer_encoder.layers.0.norm2.bias', 'encoder.transformer_encoder.layers.1.self_attn.in_proj_weight', 'encoder.transformer_encoder.layers.1.self_attn.in_proj_bias', 'encoder.transformer_encoder.