In [2]:
import torch
import torch.nn as nn
import math

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, query, key, value):
        matmul_qk = torch.matmul(query, key.transpose(-2, -1))

        # scale matmul_qk
        depth = query.shape[-1]
        logits = matmul_qk / math.sqrt(depth)

        # softmax is normalized on the last axis (seq_len_k)
        attention_weights = nn.functional.softmax(logits, dim=-1)

        output = torch.matmul(attention_weights, value)

        return output, attention_weights

# Toy example
query = torch.rand((50, 10))  # 50 queries, each of size 10
key = torch.rand((50, 10))  # 50 keys, each of size 10
value = torch.rand((50, 10))  # 50 values, each of size 10

attention_layer = ScaledDotProductAttention()
output, attention_weights = attention_layer(query, key, value)

print("Output:", output)
print("Attention weights:", attention_weights)

print(attention_weights.shape)

Output: tensor([[0.4738, 0.5273, 0.4417, 0.4843, 0.6053, 0.4484, 0.4463, 0.4749, 0.4427,
         0.4518],
        [0.4814, 0.5329, 0.4511, 0.4874, 0.6104, 0.4421, 0.4477, 0.4827, 0.4516,
         0.4567],
        [0.4718, 0.5312, 0.4451, 0.4824, 0.6118, 0.4397, 0.4421, 0.4820, 0.4373,
         0.4457],
        [0.4659, 0.5303, 0.4475, 0.4845, 0.6092, 0.4441, 0.4444, 0.4795, 0.4340,
         0.4422],
        [0.4755, 0.5280, 0.4419, 0.4888, 0.6031, 0.4495, 0.4477, 0.4800, 0.4436,
         0.4532],
        [0.4697, 0.5352, 0.4445, 0.4887, 0.6061, 0.4433, 0.4444, 0.4790, 0.4455,
         0.4529],
        [0.4725, 0.5334, 0.4452, 0.4867, 0.6018, 0.4448, 0.4433, 0.4810, 0.4390,
         0.4486],
        [0.4793, 0.5274, 0.4392, 0.4855, 0.6035, 0.4444, 0.4450, 0.4777, 0.4391,
         0.4557],
        [0.4756, 0.5231, 0.4422, 0.4858, 0.6119, 0.4480, 0.4511, 0.4819, 0.4403,
         0.4512],
        [0.4718, 0.5376, 0.4423, 0.4904, 0.6098, 0.4387, 0.4443, 0.4916, 0.4422,
         0.4483],
  

In [3]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

In [12]:
import torch.optim as optim

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_epochs = 15
learning_rate = 0.001
batch_size = 64
input_size =6
num_classes = 3
sequence_length = 3
num_layers = 2
dropout = 0.3
heads = 2
forward_expansion = 4

# Create the model
model = TransformerBlock(embed_size=input_size, heads=heads, dropout=dropout, forward_expansion=forward_expansion).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Toy dataset
inputs = torch.randn(num_epochs * batch_size, sequence_length, input_size).to(device)
labels = torch.randint(num_classes, (num_epochs * batch_size,)).to(device)

# Training loop
for epoch in range(num_epochs):
    for i in range(0, inputs.size(0), batch_size):
        # Get mini-batch inputs and labels
        inputs_mini = inputs[i:i+batch_size]
        labels_mini = labels[i:i+batch_size]

        # Forward pass
        outputs = model(inputs_mini, inputs_mini, inputs_mini, mask=None)
        outputs = outputs.mean(dim=1)
        loss = criterion(outputs, labels_mini)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print (f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Test the model
with torch.no_grad():
    correct = 0
    total = 0
    for i in range(0, inputs.size(0), batch_size):
        inputs_mini = inputs[i:i+batch_size]
        labels_mini = labels[i:i+batch_size]
        outputs = model(inputs_mini, inputs_mini, inputs_mini, mask=None)
        outputs = outputs.mean(dim=1)
        _, predicted = torch.max(outputs.data, 1)
        total += labels_mini.size(0)
        correct += (predicted == labels_mini).sum().item()

    print('Accuracy of the model on the 10000 test inputs: {} %'.format(100 * correct / total))

Epoch [1/15], Loss: 1.9434
Epoch [2/15], Loss: 1.8854
Epoch [3/15], Loss: 1.8049
Epoch [4/15], Loss: 1.5884
Epoch [5/15], Loss: 1.5126
Epoch [6/15], Loss: 1.4796
Epoch [7/15], Loss: 1.4721
Epoch [8/15], Loss: 1.4669
Epoch [9/15], Loss: 1.3825
Epoch [10/15], Loss: 1.4160
Epoch [11/15], Loss: 1.3881
Epoch [12/15], Loss: 1.3914
Epoch [13/15], Loss: 1.3567
Epoch [14/15], Loss: 1.2998
Epoch [15/15], Loss: 1.2995
Accuracy of the model on the 10000 test inputs: 36.041666666666664 %
