In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from torchvision import transforms, datasets
from torch.utils.data import DataLoader

from typing import Union

# from torchvision import 

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionHead(nn.Module):
    #def __init__(self, dim_in: int, dim_q: int, dim_k: int): # mistake, from the paper it says dim_q = dim_k but not dim_k != dim_v
    def __init__(self, dim_in: int, dim_qk: int, dim_v: int):
        super().__init__()
        self.q = nn.Linear(dim_in, dim_qk)
        self.k = nn.Linear(dim_in, dim_qk)
        self.v = nn.Linear(dim_in, dim_v)
    
    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        q = self.q(query)
        k = self.k(key)
        v = self.v(value)

        return self.scaled_dot_product_attention(q, k, v, mask)

    @staticmethod
    def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Computed scaled dot product of attention.

        Args:
            q (torch.Tensor): Query tensor of shape (batch_size, seq_len, dim_qk).
            k (torch.Tensor): Key tensor of shape (batch_size, seq_len, dim_qk).
            v (torch.Tensor): Value tensor of shape (batch_size, seq_len, dim_v).
            mask (torch.Tensor, optional): Mask tensor of shape (batch_size, seq_len, seq_len).

        Returns:
            torch.Tensor: Output tensor after attention is applied.
        """
        # Scaled dot-product
        scores = q.bmm(k.transpose(1, 2)) / torch.sqrt(torch.tensor(q.size(-1), dtype=torch.float32) + 1e-8)

        # Apply mask if needed
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax across the last dim
        attention_weights = F.softmax(scores, dim=-1)

        # Compute the output
        output = attention_weights.bmm(v)

        return output


In [5]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads: int, dim_in: int, dim_qk: int, dim_v: int):
        super().__init__()
        self.heads = nn.ModuleList(
            [AttentionHead(dim_in, dim_qk, dim_v) for _ in range(num_heads)]
        )

        self.linear = nn.Linear(num_heads * dim_qk, dim_in)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
        return self.linear(
            torch.cat([h(query, key, value, mask) for h in self.heads], dim=-1)
        )


In [6]:
#utils.py

def position_encoding(seq_len: int, dim_model: int, device: torch.device = torch.device('cpu')) -> torch.Tensor:
    """
    Generates sinusoidal positional encodings.

    Args:
        seq_len (int): The length of the sequence.
        dim_model (int): The dimensionality of the model (must be even for sin-cos pairing).
        device (torch.device): The device on which to create the encoding tensor.
    
    Returns:
        Tensor: The positional encoding tensor of shape (1, seq_len, dim_model)
    """
    # Position indices for each position in the sequence (shape: [1, seq_len, 1])
    pos = torch.arange(seq_len, dtype=torch.float32, device=device).reshape(1, -1, 1)

    # Dimension indices for each encoding dimension (shape: [1, 1, dim_model])
    dim = torch.arange(dim_model, dtype=torch.float32, device=device).reshape(1, 1, -1)

    # Compute the phase values based on position and dimension indices
    # This uses `dim_model` to scale each dimension in the sinusoidal pattern
    phase = pos / 1e4 ** (dim / dim_model)

    # Apply sin to even dimensions and cos to odd dimensions
    encoding = torch.zeros((1, seq_len, dim_model), device=device)
    encoding[..., 0::2] = torch.sin(phase[..., 0::2]) # Apply sin to even dimensions
    encoding[..., 1::2] = torch.cos(phase[..., 1::2]) # Apply cos to odd dimensions

    return encoding
    

def feed_forward(dim_model: int = 512, dim_feedforward: int = 2048) -> nn.Module:
    """
    Fully connected feed-forward network, which is applied to each position separately and identically.

    Args:
        dim_model (int): The dimensionality of the model.
        dim_feedforward (int): inner-layer dimensionality.
    
    Returns:
        nn.Module: Sequential layer of feed-forward neural network.
    """
    return nn.Sequential(
        nn.Linear(dim_model, dim_feedforward),
        nn.ReLU(),
        nn.Linear(dim_feedforward, dim_model),
    ) # Possible improvements from vanila ViT is adding dropout, layer norm, and etc.


In [7]:
class ResidualConnections(nn.Module):
    def __init__(self, sublayer: nn.Module, dimension: int, dropout: float = 0.1):
        """
        Residual connection with layer normalization and optional dropout.
        
        Args:
            sublayer (nn.Module): The sublayer to be applied, e.g., MultiHeadAttention or FeedForward.
            dimension (int): The dimension for the LayerNorm.
            dropout (float): Dropout probability.
        """
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(dimension)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, *tensors: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the residual layer.
        
        Assumes that the first tensor in `tensors` is the primary input for the residual connection.
        
        Args:
            *tensors: Input tensors where the first is assumed to be the main residual input.
        
        Returns:
            Tensor: The output tensor after applying residual connection, dropout, and normalization.
        """
        # Apply sublayer, then add the residual connection, followed by normalization
        return self.norm(tensors[0] + self.dropout(self.sublayer(*tensors)))
        

In [8]:
class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        dim_model: int = 512,
        num_heads: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        """
        Transformer Encoder Layer with residual connections around multi-head
        attention and feed-forward sub-layers.

        Args:
            dim_model (int): Dimension of the model embeddings.
            num_heads (int): Number of attention heads.
            dim_feedforward (int): Dimension of the feed-forward layer.
            dropout (float): Dropout probability.
        """
        super().__init__()
        # Calculate dimensions for query/key and value based on number of heads
        dim_qk = dim_v = max(dim_model // num_heads, 1)
        
        # Multi-head attention layer with residual connection 
        self.attention = ResidualConnections(
            MultiHeadAttention(num_heads, dim_model, dim_qk, dim_v),
            dimension=dim_model,
            dropout=dropout,
        )

        # Feed-forward network with residual connection
        self.feed_forward = ResidualConnections(
            feed_forward(dim_model, dim_feedforward),
            dimension=dim_model,
            dropout=dropout
        )

    def forward(self, src):
        """
        Forward pass through the Transformer Encoder Layer.
        
        Args:
            src (Tensor): Input tensor of shape (batch_size, seq_len, dim_model).
        
        Returns:
            Tensor: Output tensor of the same shape as input.
        """
        src = self.attention(src, src, src)
        return self.feed_forward(src)


class TransformerEncoder(nn.Module):
    def __init__(
        self,
        num_layers: int = 6,
        dim_model: int = 512,
        num_heads: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        apply_positional_encoding: bool = True
    ):
        """
        Transformer Encoder comprising multiple encoder layers.

        Args:
            num_layers (int): Number of encoder layers.
            dim_model (int): Dimension of the model embeddings.
            num_heads (int): Number of attention heads.
            dim_feedforward (int): Dimension of the feed-forward layer.
            dropout (float): Dropout probability.
            apply_positional_encoding (bool): Whether to add positional encoding.
        """
        super().__init__()
        self.apply_positional_encoding = apply_positional_encoding

        # Stack of encoder layers
        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(dim_model, num_heads, dim_feedforward, dropout)
                for _ in range(num_layers)
            ]
        )

    def add_positional_encoding(self, src: torch.Tensor) -> torch.Tensor:
        """
        Adds positional encoding to the input tensor.

        Args:
            src (Tensor): Input tensor of shape (batch_size, seq_len, dim_model).
        
        Returns:
            Tensor: Input tensor with positional encoding added.
        """
        seq_len, dimension = src.size(1), src.size(2)
        return src + position_encoding(seq_len, dimension, device=src.device)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the Transformer Encoder.

        Args:
            src (Tensor): Input tensor of shape (batch_size, seq_len, dim_model).

        Returns:
            Tensor: Encoded tensor of the same shape as input.
        """
        # Optionally add positional encoding
        if self.apply_positional_encoding:
            src = self.add_positional_encoding(src)

        # Pass through each encoder layer
        for layer in self.layers:
            src = layer(src)
        
        return src


The decoder module is quite similar to the encoder, with just a few small differences:
- The decoder accepts two inputs (the target sequence and the encoder memory), rather than one input.
- There are two multi-head attention modules per layer (the target sequence self-attention module and the decoder-encoder attention module) rather than just one.
- The second multi-head attention module, rather than strict self attention, expects the encoder memory as $K$ and $V$.
- Since accessing future elements of the target sequence would be "cheating," we need to mask out future elements of the input target sequence.

In [9]:
class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        dim_model: int = 512,
        num_heads: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        """
        Transformer Decoder Layer with two residual-attention connections and a 
        feed-forward network, each wrapped in residual connections.

        Args:
            dim_model (int): Dimension of the model embeddings.
            num_heads (int): Number of attention heads.
            dim_feedforward (int): Dimension of the feed-forward layer.
            dropout (float): Dropout probability.
        """
        super().__init__()
        # Dimensions for query/key and value
        dim_qk = dim_v = max(dim_model // num_heads, 1)

        # Self-attention with residual connection
        self.attention1 = ResidualConnections(
            MultiHeadAttention(num_heads, dim_model, dim_qk, dim_v),
            dimension=dim_model,
            dropout=dropout
        )

        # Cross-attention with residual connection
        self.attention2 = ResidualConnections(
            MultiHeadAttention(num_heads, dim_model, dim_qk, dim_v),
            dimension=dim_model,
            dropout=dropout
        )

        # Feed-forward with residual connection
        self.feed_forward = ResidualConnections(
            feed_forward(dim_model, dim_feedforward),
            dimension=dim_model,
            dropout=dropout
        )
    
    def forward(self, trg: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the Transformer Decoder Layer.
        
        Args:
            trg (Tensor): Target tensor of shape (batch_size, seq_len, dim_model).
            memory (Tensor): Memory tensor from the encoder of shape (batch_size, seq_len, dim_model).
        
        Returns:
            Tensor: Output tensor of shape (batch_size, seq_len, dim_model).
        """
        trg = self.attention1(trg, trg, trg) # Self-attention
        trg = self.attention2(trg, memory, memory) # Cross-attention with encoder output
        return self.feed_forward(trg)


class TransformerDecoder(nn.Module):
    def __init__(
        self,
        num_layers: int = 6,
        dim_model: int = 512,
        num_heads: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        apply_positional_encoding: bool = True
    ):
        """
        Transformer Decoder consisting of multiple decoder layers.

        Args:
            num_layers (int): Number of decoder layers.
            dim_model (int): Dimension of the model embeddings.
            num_heads (int): Number of attention heads.
            dim_feedforward (int): Dimension of the feed-forward layer.
            dropout (float): Dropout probability.
            apply_positional_encoding (bool): Whether to add positional encoding.
        """
        super().__init__()
        self.apply_positional_encoding = apply_positional_encoding

        # Stack of decoder layers
        self.layers = nn.ModuleList(
            [
                TransformerDecoderLayer(dim_model, num_heads, dim_feedforward, dropout)
                for _ in range(num_layers)
            ]
        )

        # Final linear layer to project to output vocabulary size or target dimension
        self.linear = nn.Linear(dim_model, dim_model)
    
    def add_positional_encoding(self, trg: torch.Tensor) -> torch.Tensor:
        """
        Adds positional encoding to the target tensor.
        
        Args:
            trg (Tensor): Target tensor of shape (batch_size, seq_len, dim_model).
        
        Returns:
            Tensor: Target tensor with positional encoding added.
        """
        seq_len, dimension = trg.size(1), trg.size(2)
        return trg + position_encoding(seq_len, dimension, device=trg.device)

    def forward(self, trg: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the Transformer Decoder.
        
        Args:
            trg (Tensor): Target tensor of shape (batch_size, seq_len, dim_model).
            memory (Tensor): Memory tensor from the encoder of shape (batch_size, seq_len, dim_model).
        
        Returns:
            Tensor: Decoded tensor of shape (batch_size, seq_len, dim_model).
        """
        # Optionally add positional encoding
        if self.apply_positional_encoding:
            trg = self.add_positional_encoding(trg)

        # Pass through each decoder layer
        for layer in self.layers:
            trg = layer(trg, memory)
        
        # Linear projection, softmax should be applied externally if needed
        # return torch.softmax(self.linear(trg), dim=-1) #The softmax should typically be applied outside the decoder in a final output layer if necessary.
        return self.linear(trg)


In [10]:
class Transformer(nn.Module):
    def __init__(
        self,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_model: int = 512,
        num_heads: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        """
        Transformer model consisting of an encoder and decoder.

        Args:
            num_encoder_layers (int): Number of layers in the encoder.
            num_decoder_layers (int): Number of layers in the decoder.
            dim_model (int): Dimension of the model.
            num_heads (int): Number of attention heads.
            dim_feedforward (int): Dimension of the feed-forward network.
            dropout (float): Dropout probability.
        """
        super().__init__()

        self.encoder = TransformerEncoder(
            num_layers=num_encoder_layers, 
            dim_model=dim_model, 
            num_heads=num_heads, 
            dim_feedforward=dim_feedforward, 
            dropout=dropout,
        )

        self.decoder = TransformerDecoder(
            num_layers=num_decoder_layers,
            dim_model=dim_model,
            num_heads=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )

    
    def forward(self, src: torch.Tensor, trg: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the Transformer model.

        Args:
            src (Tensor): Source tensor of shape (batch_size, src_seq_len, dim_model).
            trg (Tensor): Target tensor of shape (batch_size, trg_seq_len, dim_model).

        Returns:
            Tensor: Output tensor of shape (batch_size, trg_seq_len, dim_model).
        """
        # Pass through encoder
        memory = self.encoder(src)

        # Pass through decoder with the encoder's output as memory
        output = self.decoder(trg, memory)

        return output


In [11]:
src = torch.rand(64, 32, 512)
tgt = torch.rand(64, 16, 512)
out = Transformer()(src, tgt)
print(out.shape)
# torch.Size([64, 16, 512])

torch.Size([64, 16, 512])


# Vision Transformer

The steps of ViT are as follows:

1. Split input image into patches
2. Flatten the patches
3. Produce linear embeddings from the flattened patches
4. Add position embeddings
5. Feed the sequence preceeded by a `[class]` token as input to a standard transformer encoder
6. Pretrain the model to ouptut image labels for the `[class]` token (fully supervised on a huge dataset such as ImageNet-22K)
7. Fine-tune on the downstream dataset for the specific image classification task

In [41]:
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-Head Self-Attention (MSA) layer for Transformer-based models.

    This layer splits the input into multiple attention heads, computes self-attention for each head, 
    and concatenates the results. It is commonly used in Transformer architectures for both NLP and 
    Vision Transformers (ViT).

    Args:
        dim_model (int): The dimensionality of the input and output representations.
        num_heads (int): The number of attention heads. Each head will have a dimensionality of `dim_model / num_heads`.

    Attributes:
        q_linear (nn.Linear): Linear layer to project input to query vectors.
        k_linear (nn.Linear): Linear layer to project input to key vectors.
        v_linear (nn.Linear): Linear layer to project input to value vectors.
        out_linear (nn.Linear): Final linear layer to project concatenated head outputs back to `dim_model`.
    """

    def __init__(
        self,
        dim_model: int,
        num_heads: int = 2,
    ):  
        
        super(MultiHeadSelfAttention, self).__init__()

        # Ensure the model dimension is divisible by the number of heads
        assert dim_model % num_heads == 0, f"dim_model {dim_model} must be divisible by num_heads {num_heads}"

        self.num_heads = num_heads
        self.dim_head = dim_model // num_heads

        # Linear layers to project input into query, key, and value spaces
        self.q = nn.Linear(dim_model, dim_model)
        self.k = nn.Linear(dim_model, dim_model)
        self.v = nn.Linear(dim_model, dim_model)

        # Output projection layer
        self.out_linear = nn.Linear(dim_model, dim_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for multi-head self-attention.

        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len, dim_model), where:
                - batch_size is the number of input samples,
                - seq_len is the sequence length, and
                - dim_model is the dimensionality of each token in the sequence.

        Returns:
            Tensor: Output tensor of shape (batch_size, seq_len, dim_model) after multi-head self-attention is applied.
        """
        batch_size, seq_len, dim_model = x.size()

        # Project inputs to multi-head query, key, and value spaces
        q = self.q(x).view(batch_size, seq_len, self.num_heads, self.dim_head).transpose(1, 2)
        k = self.k(x).view(batch_size, seq_len, self.num_heads, self.dim_head).transpose(1, 2)
        v = self.v(x).view(batch_size, seq_len, self.num_heads, self.dim_head).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.dim_head ** 0.5)  # Shape: (batch_size, num_heads, seq_len, seq_len)
        attention_weights = torch.softmax(scores, dim=-1)  # Apply softmax to get attention weights
        attended_values = torch.matmul(attention_weights, v)  # Shape: (batch_size, num_heads, seq_len, dim_head)
        
        # Concatenate heads and project to output dimension
        attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, dim_model)
        output = self.out_linear(attended_values)

        return output

In [45]:
class ViT(nn.Module):
    """
    Vision Transformer (ViT) implementation.

    This model divides an input image into patches, embeds them, adds positional encoding, 
    and applies Transformer layers to produce a classification output.

    Args:
        input_shape (tuple): Shape of the input image (channels, height, width).
        n_patches (int): Number of patches along each dimension.
        hidden_d (int): Dimensionality of the hidden layer.
        n_heads (int): Number of attention heads in multi-head self-attention.
        out_d (int): Dimensionality of the output (number of classes for classification).
    """

    def __init__(
        self,
        input_shape,
        n_patches=7,
        hidden_d=8,
        n_heads=2,
        out_d=10,
    ):
        super(ViT, self).__init__()

        # Input and patch size checks
        self.input_shape = input_shape
        self.n_patches = n_patches
        self.n_heads = n_heads

        assert input_shape[1] % n_patches == 0, "Input height not divisible by number of patches"
        assert input_shape[2] % n_patches == 0, "Input width not divisible by number of patches"

        self.patch_size = (input_shape[1] / n_patches, input_shape[2] / n_patches)
        self.hidden_d = hidden_d

        # 1) Linear layers for patch embeddings
        self.input_d = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, hidden_d)

        # 2) Classification Token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding - forward layer

        # 4a) Layer Normalization 1
        self.ln1 = nn.LayerNorm((self.n_patches * self.n_patches + 1, self.hidden_d))

        # 4b) MSA and classification token
        self.msa = MultiHeadSelfAttention(self.hidden_d, n_heads)

        # 5a) Layer Normalization 2
        self.ln2 = nn.LayerNorm((self.n_patches * self.n_patches + 1, self.hidden_d))

        # 5b) Encoder MLP
        self.enc_mlp = nn.Sequential(
            nn.Linear(self.hidden_d, self.hidden_d),
            nn.ReLU()
        )

        # 6) Classification MLP
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=-1)
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the Vision Transformer.

        Args:
            images (Tensor): Batch of images of shape (batch_size, channels, height, width).
        
        Returns:
            Tensor: Output logits of shape (batch_size, out_d).
        """
        device = images.device

        # Divide images into patches and flatten each patch
        n, c, w, h = images.shape
        patches = images.reshape(n, self.n_patches ** 2, self.input_d)

        # Running linear layer for tokenization
        tokens = self.linear_mapper(patches).to(device)

        # Adding classification token to the tokens
        class_token = self.class_token.to(device)
        tokens = torch.stack([torch.vstack((class_token, tokens[i])) for i in range(len(tokens))])

        # Adding positional embedding
        tokens += position_encoding(self.n_patches ** 2 + 1, self.hidden_d, device).repeat(n, 1, 1)

        # TRANSFORMER ENCODER BEGINS ###################################
        # NOTICE: MULTIPLE ENCODER BLOCKS CAN BE STACKED TOGETHER ######
        # Running Layer Normalization, MSA and residual connection
        tokens = tokens.to(device)
        ln1_tokens = self.ln1(tokens)
        msa_out = self.msa(ln1_tokens)
        out = tokens + msa_out

        # Running Layer Normalization, MLP and residual connection
        out = out + self.enc_mlp(self.ln2(out))
        # TRANSFORMER ENCODER ENDS   ###################################

        # Getting the classification token only
        out = out[:, 0]

        return self.mlp(out)



In [21]:
def train_vit(model, device, train_loader, optimizer, criterion, epochs):
    model.train()

    for epoch in range(epochs):
        train_loss = 0.0
        for i, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            output = model(data)
            loss = criterion(output, target) / len(data)

            train_loss += loss

            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch + 1}/{epochs} loss: {train_loss:.2f}")


def test_vit(model, device, test_loader, criterion):
    model.eval()

    correct, total = 0, 0
    test_loss = 0.0
    for i, (data, target) in enumerate(test_loader):
        data, target = data.to(device), target.to(device)

        output = model(data)

        loss = criterion(output, target)
        test_loss += loss

        correct += torch.sum(torch.argmax(output, dim=1) == target).item()
        total += len(data)
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / total * 100:.2f}%")

In [15]:
transform = transforms.Compose([
    transforms.ToTensor()
])

train_set = datasets.MNIST(root='./datasets', train=True, download=True, transform=transform)
test_set = datasets.MNIST(root='./datasets', train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
test_loader = DataLoader(test_set, batch_size=16, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1133)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./datasets\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:03<00:00, 2779455.42it/s]


Extracting ./datasets\MNIST\raw\train-images-idx3-ubyte.gz to ./datasets\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1133)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./datasets\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 102937.83it/s]


Extracting ./datasets\MNIST\raw\train-labels-idx1-ubyte.gz to ./datasets\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1133)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./datasets\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 899341.41it/s] 


Extracting ./datasets\MNIST\raw\t10k-images-idx3-ubyte.gz to ./datasets\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1133)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./datasets\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1135785.42it/s]

Extracting ./datasets\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./datasets\MNIST\raw






In [46]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ViT((1, 28, 28), n_patches=7, hidden_d=20, n_heads=2, out_d=10)
model = model.to(device)

N_EPOCHS = 1
LR = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

In [47]:
train_vit(model, device, train_loader, optimizer, criterion, N_EPOCHS)

Epoch 1/1 loss: 526.00


In [49]:
test_vit(model, device, test_loader, criterion)

Test loss: 1435.14
Test accuracy: 16.49%
