In [None]:
import torch
import math
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.nn.parameter import Parameter
import torch.nn.functional as F
# Load MNIST and build the data loaders.
# Standard MNIST normalization constants (per-channel mean, std).
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
BATCH_SIZE = 64
DATA_DIR = './data'

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(MNIST_MEAN, MNIST_STD)])

train_dataset = datasets.MNIST(DATA_DIR, train=True, download=True, transform=transform)
# download=True is a no-op once the files exist, and keeps this line working on its own.
test_dataset = datasets.MNIST(DATA_DIR, train=False, download=True, transform=transform)

# Only the training data needs shuffling; a fixed test order keeps evaluation runs reproducible.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
def scaled_dot_product(q, k, v):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    Args:
        q: queries, expected shape ``(..., L_q, d_k)``.
        k: keys, expected shape ``(..., L_k, d_k)``.
        v: values, expected shape ``(..., L_k, d_v)``.

    Returns:
        Tuple ``(values, attention)``. ``values`` is the attention-weighted sum
        of ``v``. ``attention`` holds the softmax weights, normalized over the
        last dimension.
    """
    scale = math.sqrt(q.shape[-1])
    scores = q @ k.transpose(-2, -1) / scale
    weights = scores.softmax(dim=-1)
    return weights @ v, weights

In [None]:
class MultiheadAttention(nn.Module):
    """Multi-head attention over flattened 2-D ``(batch_size, embed_dim)`` inputs.

    The projected query/key/value vectors are each split into ``nhead`` chunks
    of ``head_dim`` features. Attention is then computed *across those heads*
    within each sample, so the attention matrix is ``(batch, nhead, nhead)``.
    It is not computed across a token sequence.

    Args:
        embed_dim: feature size of each input vector.
        nhead: number of heads; must divide ``embed_dim``.
        dropout: stored on the module but not applied in ``forward``.
        batch_size, seq_length: stored for reference only; ``forward`` reads
            the actual sizes from its input.
    """

    def __init__(self, embed_dim = 784, nhead = 2, dropout = 0.1, batch_size = 64, seq_length = 784):
        super().__init__()
        self.embed_dim = embed_dim
        self.nhead = nhead
        self.dropout = dropout
        self.head_dim = embed_dim // nhead
        self.batch_size = batch_size
        self.seq_length = seq_length

        assert self.head_dim * nhead == self.embed_dim, "embed_dim must be divisible by num_heads"

        # Stacked [W_q; W_k; W_v] projection, shape (3 * embed_dim, embed_dim).
        self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim)))

        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize the stacked Q/K/V projection weight.

        ``torch.empty`` leaves the memory uninitialized, which can hold
        arbitrary values including inf/NaN. Xavier-uniform matches what
        ``nn.MultiheadAttention`` uses for its input projection.
        """
        nn.init.xavier_uniform_(self.in_proj_weight)

    def forward(self, query, key, value, attn_mask = None):
        """Attend across heads and apply the output projection.

        Args:
            query, key, value: ``(batch_size, embed_dim)`` tensors, projected
                with W_q, W_k and W_v respectively. Passing the same tensor
                three times gives self-attention.
            attn_mask: masking is not implemented; must be ``None``.

        Returns:
            A ``(batch_size, embed_dim)`` tensor.

        Raises:
            NotImplementedError: if ``attn_mask`` is not ``None``. Such a mask
                would otherwise be silently ignored.
        """
        if attn_mask is not None:
            raise NotImplementedError("attn_mask is not supported by this MultiheadAttention")

        batch_size, embed_dim = query.shape
        # Row blocks of the stacked weight: rows [0, E) -> q, [E, 2E) -> k, [2E, 3E) -> v.
        w_q, w_k, w_v = self.in_proj_weight.chunk(3, dim=0)

        # (batch, embed_dim) -> (batch, nhead, head_dim)
        q = F.linear(query, w_q).view(batch_size, self.nhead, self.head_dim)
        k = F.linear(key, w_k).view(batch_size, self.nhead, self.head_dim)
        v = F.linear(value, w_v).view(batch_size, self.nhead, self.head_dim)

        values, _ = scaled_dot_product(q, k, v)
        values = values.reshape(batch_size, embed_dim)

        return self.o_proj(values)


In [None]:
class EncoderLayer(nn.Module):
    """Post-norm transformer encoder block built on the local ``MultiheadAttention``.

    It computes ``x = norm1(x + dropout1(attn(x)))``, then
    ``x = norm2(x + dropout2(linear2(dropout(relu(linear1(x))))))``.

    Args:
        d_model: feature size of the input vectors.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward sub-layer.
        dropout: dropout probability used throughout the block.
    """

    def __init__(self, d_model = 784, nhead = 2, dim_feedforward = 784, dropout = 0.1):
        super().__init__()
        # Feed-forward sub-layer weights.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One LayerNorm after each residual connection.
        self.norm1, self.norm2 = (nn.LayerNorm(d_model, eps=1e-5) for _ in range(2))

        # dropout: inside the FFN; dropout1/dropout2: on each sub-layer's output.
        self.dropout, self.dropout1, self.dropout2 = (nn.Dropout(dropout) for _ in range(3))

        self.activation = nn.ReLU()

        self.attention = MultiheadAttention(d_model, nhead, dropout = dropout)

    def forward(self, x, src_mask = None):
        """Apply the attention and feed-forward sub-layers, each with residual + LayerNorm."""
        attended = self.attention(x, x, x, attn_mask = src_mask)
        x = self.norm1(x + self.dropout1(attended))

        hidden = self.dropout(self.activation(self.linear1(x)))
        ff_out = self.dropout2(self.linear2(hidden))
        return self.norm2(x + ff_out)

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal encoding added to 2-D ``(batch, features)`` inputs.

    Index ``i`` in ``[0, sequence_length)`` is assigned the angle
    ``i / 10000 ** (2 * (i // 2) / sequence_length)``. Even indices take the
    sine of that angle and odd indices take the cosine. Every encoding value
    therefore lies in ``[-1, 1]`` instead of growing with the index.

    Args:
        sequence_length: number of indices (size of dim 1) the encoding covers.
    """

    def __init__(self, sequence_length):
        super(PositionalEncoding, self).__init__()
        self.sequence_length = sequence_length
        position = torch.arange(0, sequence_length, dtype=torch.float)
        # Each (even, odd) index pair shares one frequency.
        exponent = 2 * torch.floor(position / 2) / float(sequence_length)
        angle = position / torch.pow(10000.0, exponent)
        pe = torch.where(position % 2 == 0, torch.sin(angle), torch.cos(angle))
        # The leading dim of 1 broadcasts over the batch in forward(). As a
        # registered buffer it follows the module through .to(device).
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encoding for its first ``x.size(1)`` indices.

        Assumes ``x`` is 2-D ``(batch, n)`` with ``n <= sequence_length``.
        TODO confirm for other ranks: ``pe`` is sliced along dim 1 only.
        """
        return x + self.pe[:, :x.size(1)].to(x.device)


In [None]:
class Encoder(nn.Module):
    """Stack of ``EncoderLayer`` blocks applied in sequence, with an optional final norm.

    Args:
        num_layers: number of stacked layers.
        norm: optional module applied to the output of the last layer.
        **block_args: keyword arguments forwarded to every ``EncoderLayer``.
    """

    def __init__(self, num_layers = 2, norm=None, ** block_args):
        super().__init__()
        self.num_layers = num_layers
        self.norm = norm
        self.layers = nn.ModuleList([EncoderLayer(**block_args) for _ in range(num_layers)])

    def forward(self, x, src_mask = None):
        """Run ``x`` through every layer in order, then through ``norm`` if set.

        Each layer receives the previous layer's output. With zero layers the
        input is passed straight to ``norm``, or returned unchanged if ``norm``
        is ``None``.
        """
        output = x
        for layer in self.layers:
            # Feed the running output, not `x`. Passing `x` to every layer
            # would discard all but the last layer's work.
            output = layer(output, src_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output

In [None]:
class enhance_classifier(nn.Module):
    """MNIST classifier: flattened image -> ``nn.TransformerEncoder`` -> linear head.

    Each image is flattened to a single ``d_model``-dim token. The batch is
    fed to the encoder as a length-1 sequence shaped ``(1, batch, d_model)``,
    matching the encoder layer's default ``batch_first=False`` layout.

    A bare 2-D ``(batch, d_model)`` tensor would instead be read as one
    *unbatched* sequence of length ``batch``. Images in a batch would then
    attend to each other, and predictions would depend on which other images
    share the batch.

    Args:
        d_model: flattened image size (784 for 28x28 MNIST).
        nhead: attention heads per encoder layer.
        num_layers: number of encoder layers.
        dim_feedforward: hidden width of each encoder layer's feed-forward net.
        num_classes: number of output logits.
    """

    def __init__(self, d_model=784, nhead=2, num_layers=2, dim_feedforward=784, num_classes=10):
        super(enhance_classifier, self).__init__()

        self.d_model = d_model

        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        # Constructed but not applied in forward(); kept so the module's
        # parameters/buffers (and state_dict keys) stay the same.
        self.pos_encoder = PositionalEncoding(d_model)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Return ``(batch, num_classes)`` logits.

        ``x`` must have the batch on dim 0, and its remaining dims must flatten
        to ``d_model`` features (e.g. ``(batch, 1, 28, 28)`` MNIST images).
        """
        x = torch.flatten(x, start_dim=1)  # (batch, d_model)
        x = x.unsqueeze(0)                 # (1, batch, d_model): one token per image
        x = self.transformer_encoder(x)
        x = x.squeeze(0)                   # (batch, d_model)
        return self.fc(x)

In [None]:
# Seed torch's global RNGs (CPU and CUDA) before model init so weight
# initialization, dropout and DataLoader shuffling repeat across runs
# (up to nondeterministic GPU kernels).
SEED = 0
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = enhance_classifier().to(device)

LEARNING_RATE = 1e-5
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Train the model
NUM_EPOCHS = 10

model.train()
for epoch in range(NUM_EPOCHS):
    loss_sum = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running statistics. CrossEntropyLoss returns the batch mean by
        # default, so weight it by batch size. This keeps the epoch average
        # exact even when the last batch is smaller.
        n_batch = labels.size(0)
        loss_sum += loss.item() * n_batch
        predicted = outputs.detach().argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += n_batch

    print("epoch", epoch, 'loss:', loss_sum / total, 'acc:', correct / total)


In [None]:

# Eval the model: accuracy over the whole test loader, without tracking gradients.
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = outputs.max(dim=1).indices
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    accuracy_pct = 100 * correct / total
    print('Test Accuracy of the model on the 10000 test images: {} %'.format(accuracy_pct))