In [19]:
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, RandomRotation, RandomCrop, ColorJitter, Normalize
from torchvision.datasets.mnist import MNIST
from torch.optim.lr_scheduler import ReduceLROnPlateau


np.random.seed(0)
torch.manual_seed(0)


def create_patches(images, n_patches):
    n, c, h, w = images.shape

    assert h == w

    patches = torch.zeros(n, n_patches ** 2, h * w * c // n_patches ** 2)
    patch_size = h // n_patches

    for idx, image in enumerate(images):
        for i in range(n_patches):
            for j in range(n_patches):
                patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]
                patches[idx, i * n_patches + j] = patch.flatten()
    return patches



class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MultiHeadSelfAttention, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = int(d / n_heads)
        self.q_mappings = nn.Linear(d, d)
        self.k_mappings = nn.Linear(d, d)
        self.v_mappings = nn.Linear(d, d)
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        batch_size, seq_length, _ = sequences.size()
        head_dim = self.d // self.n_heads

        q = self.q_mappings(sequences)
        k = self.k_mappings(sequences)
        v = self.v_mappings(sequences)

        q = q.view(batch_size, seq_length, self.n_heads, head_dim).permute(0, 2, 1, 3)
        k = k.view(batch_size, seq_length, self.n_heads, head_dim).permute(0, 2, 1, 3)
        v = v.view(batch_size, seq_length, self.n_heads, head_dim).permute(0, 2, 1, 3)

        attention = self.softmax(torch.matmul(q, k.permute(0, 1, 3, 2)) / (head_dim ** 0.5))
        output = torch.matmul(attention, v).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, self.d)

        return output



class ViTBlock(nn.Module):
    def __init__(self, hidden_dim, n_heads, mlp_ratio=4):
        super(ViTBlock, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(hidden_dim)
        self.mhsa = MultiHeadSelfAttention(hidden_dim, n_heads)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_ratio * hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_dim, hidden_dim)
        )

    def forward(self, x):
        out = self.mhsa(self.norm1(x))
        out = x + out
        out = out + self.mlp(out)
        return out



class VisionTransformer(nn.Module):
    def __init__(self, chw, n_patches=7, n_blocks=4, hidden_dim=16, n_heads=2, output_dim=10):
        super(VisionTransformer, self).__init__()

        self.chw = chw
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_dim = hidden_dim

        assert chw[1] % n_patches == 0
        assert chw[2] % n_patches == 0
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)

        self.input_dim = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_dim, self.hidden_dim)
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_dim))
        self.register_buffer('positional_embeddings', get_positional_embeddings(n_patches ** 2 + 1, hidden_dim), persistent=False)
        self.blocks = nn.ModuleList([ViTBlock(hidden_dim, n_heads) for _ in range(n_blocks)])
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, output_dim),
            nn.Softmax(dim=-1)
        )

    def forward(self, images):
        n, c, h, w = images.shape
        patches = create_patches(images, self.n_patches).to(self.positional_embeddings.device)
        tokens = self.linear_mapper(patches)
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)
        out = tokens + self.positional_embeddings.repeat(n, 1, 1)
        for block in self.blocks:
            out = block(out)
        out = out[:, 0]
        return self.mlp(out)




def get_positional_embeddings(sequence_length, d):
    positions = torch.arange(sequence_length, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d, 2, dtype=torch.float32) * -(torch.log(torch.tensor(10000.0)) / d))
    embeddings = torch.zeros(sequence_length, d)
    embeddings[:, 0::2] = torch.sin(positions / div_term)
    embeddings[:, 1::2] = torch.cos(positions / div_term)
    return embeddings



def data_augmentation(images):
    transform = nn.Sequential(
        RandomRotation(10),
        RandomCrop(28, padding=4),
        ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
    )
    augmented_images = []
    for image in images:
        augmented_images.append(transform(image))
    return torch.stack(augmented_images)


def train(model, train_loader, test_loader, optimizer, criterion, scheduler, device):
    n_epochs = 30

    for epoch in range(n_epochs):
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            x, y = batch
            x, y = x.to(device), y.to(device)
            x = data_augmentation(x)  # Apply data augmentation
            y_hat = model(x)
            loss = criterion(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * x.size(0)
        train_loss /= len(train_loader.dataset)

        model.eval()
        test_loss = 0.0
        accuracy = 0.0
        with torch.no_grad():
            for batch in test_loader:
                x, y = batch
                x, y = x.to(device), y.to(device)
                y_hat = model(x)
                loss = criterion(y_hat, y)
                test_loss += loss.item() * x.size(0)
                _, predicted_labels = torch.max(y_hat, 1)
                accuracy += (predicted_labels == y).sum().item()

            test_loss /= len(test_loader.dataset)
            accuracy /= len(test_loader.dataset)

            print(f"Epoch {epoch+1}/{n_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}")
            scheduler.step(test_loss)


# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load MNIST dataset
train_dataset = MNIST(root="./data", train=True, transform=ToTensor(), download=True)
test_dataset = MNIST(root="./data", train=False, transform=ToTensor(), download=True)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Create model
model = VisionTransformer(chw=(1, 28, 28)).to(device)

# Define loss function and optimizer
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=1, verbose=True)

# Train the model
train(model, train_loader, test_loader, optimizer, criterion, scheduler, device)


Epoch 1/30, Train Loss: 2.1635, Test Loss: 2.0296, Accuracy: 0.4228
Epoch 2/30, Train Loss: 2.0531, Test Loss: 1.9152, Accuracy: 0.5475
Epoch 3/30, Train Loss: 1.9659, Test Loss: 1.8262, Accuracy: 0.6411
Epoch 4/30, Train Loss: 1.9064, Test Loss: 1.7825, Accuracy: 0.6789
Epoch 5/30, Train Loss: 1.8549, Test Loss: 1.7299, Accuracy: 0.7329
Epoch 6/30, Train Loss: 1.8167, Test Loss: 1.6904, Accuracy: 0.7715
Epoch 7/30, Train Loss: 1.7851, Test Loss: 1.6685, Accuracy: 0.7935
Epoch 8/30, Train Loss: 1.7588, Test Loss: 1.6451, Accuracy: 0.8198
Epoch 9/30, Train Loss: 1.7369, Test Loss: 1.6358, Accuracy: 0.8263
Epoch 10/30, Train Loss: 1.7203, Test Loss: 1.6117, Accuracy: 0.8501
Epoch 11/30, Train Loss: 1.7102, Test Loss: 1.6126, Accuracy: 0.8495
Epoch 12/30, Train Loss: 1.6997, Test Loss: 1.6187, Accuracy: 0.8433
Epoch 00012: reducing learning rate of group 0 to 1.0000e-04.
Epoch 13/30, Train Loss: 1.6619, Test Loss: 1.5758, Accuracy: 0.8862
Epoch 14/30, Train Loss: 1.6502, Test Loss: 1.5734

In [20]:
    # Save the model
    torch.save(model.state_dict(), 'model_weights.pth')
    print("Model saved.")

Model saved.
