# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the scaled dot-product attention function
def scaled_dot_product_attention(query, key, value, mask=None):
    """Compute scaled dot-product attention.

    The logits are ``query @ key^T`` divided by ``sqrt(d_k)``, where ``d_k``
    is the size of the last dimension of ``query``. A softmax over the last
    (key) axis turns the logits into weights, which are then used to form a
    weighted sum of ``value``.

    Args:
        query: Tensor of shape ``(..., len_q, d_k)``.
        key: Tensor of shape ``(..., len_k, d_k)``.
        value: Tensor of shape ``(..., len_k, d_v)``.
        mask: Optional tensor broadcastable to ``(..., len_q, len_k)``.
            Positions where the mask is ``False``/zero are excluded from
            attention. ``None`` (the default) attends to every position.
            A query row whose keys are all masked yields NaN weights,
            because the softmax is taken over ``-inf`` only.

    Returns:
        A tuple ``(output, attention_weights)`` where ``output`` has shape
        ``(..., len_q, d_v)`` and ``attention_weights`` has shape
        ``(..., len_q, len_k)``.
    """
    d_k = query.size(-1)
    # Divide by a Python float instead of a freshly built float32 tensor:
    # no per-call tensor allocation, and float64 inputs are no longer scaled
    # by a value that was rounded to float32 precision.
    scaled_attention_logits = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)

    if mask is not None:
        # -inf logits become exactly zero weight after the softmax.
        scaled_attention_logits = scaled_attention_logits.masked_fill(
            mask == 0, float("-inf")
        )

    attention_weights = torch.nn.functional.softmax(scaled_attention_logits, dim=-1)
    return torch.matmul(attention_weights, value), attention_weights

# Define the multi-head attention layer
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``scaled_dot_product_attention``.

    The inputs are linearly projected, split into ``num_heads`` heads of
    width ``d_model // num_heads``, attended per head, merged back to
    ``d_model`` features and passed through a final linear layer.

    Args:
        d_model: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must be positive and divide
            ``d_model`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Explicit check instead of ``assert`` so the validation survives
        # ``python -O`` and a zero head count gives a clear error.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads

        # Creation order is kept (wq, wk, wv, fc) so parameter
        # initialisation and state_dict layout are unchanged.
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        self.fc = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, heads, seq, depth)``."""
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)

    def forward(self, query, key, value):
        """Apply multi-head attention.

        Args:
            query: Tensor whose first dimension is the batch and whose last
                dimension is ``d_model``.
            key: Tensor with the same batch size and feature size.
            value: Tensor with the same batch size and feature size.

        Returns:
            Tensor of shape ``(batch, len_q, d_model)``.
        """
        batch_size = query.size(0)
        query = self.split_heads(self.wq(query), batch_size)
        key = self.split_heads(self.wk(key), batch_size)
        value = self.split_heads(self.wv(value), batch_size)

        # The attention weights are not needed here, only the attended values.
        output, _ = scaled_dot_product_attention(query, key, value)

        # (batch, heads, seq, depth) -> (batch, seq, heads * depth)
        output = output.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.d_model)
        return self.fc(output)

# Define the Transformer encoder layer
class TransformerEncoderLayer(nn.Module):
    """Single post-norm encoder layer: self-attention followed by an MLP.

    Each sub-layer's output goes through dropout, is added to its input
    (residual connection) and is then layer-normalised.

    Args:
        d_model: Feature size of the input and output.
        num_heads: Number of heads handed to ``MultiHeadAttention``.
        d_ff: Hidden width of the position-wise feed-forward network.
        dropout: Dropout probability applied to both sub-layer outputs.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        self.mha = MultiHeadAttention(d_model, num_heads)

        # Feed-forward network: expand to d_ff, ReLU, project back to d_model.
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU()
        project = nn.Linear(d_ff, d_model)
        self.ffn = nn.Sequential(expand, activation, project)

        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        # One dropout module is shared by both residual branches.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run self-attention and the feed-forward block over ``x``.

        Returns a tensor produced by the second layer norm; its last
        dimension is ``d_model``.
        """
        # Self-attention: query, key and value are all the layer input.
        attended = self.dropout(self.mha(x, x, x))
        x = self.layernorm1(x + attended)

        transformed = self.dropout(self.ffn(x))
        return self.layernorm2(x + transformed)

# Define the full Transformer model for language modeling
class TransformerLanguageModel(nn.Module):
    """Token embedding -> one encoder layer -> linear projection to vocabulary.

    NOTE(review): the forward pass adds no positional information to the
    embeddings and passes no attention mask to the encoder, so every
    position can attend to every other position. Confirm that this is
    intended for the language-modelling use case.

    Args:
        vocab_size: Number of token ids; also the size of the output logits.
        d_model: Embedding / hidden feature size.
        num_heads: Number of attention heads in the encoder layer.
        d_ff: Hidden width of the encoder's feed-forward network.
        dropout: Dropout probability used inside the encoder layer.
    """

    def __init__(self, vocab_size, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.encoder = TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
        # Despite the name, this is only the output projection to vocab logits.
        self.decoder = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map integer token ids ``x`` to logits over the vocabulary.

        The result has the shape of ``x`` with a trailing ``vocab_size`` axis.
        """
        hidden = self.encoder(self.embedding(x))
        return self.decoder(hidden)

# Hyperparameters
vocab_size = 10000
d_model = 512
num_heads = 8
d_ff = 2048
dropout = 0.1

# Shape of the dummy batch used for the single demonstration step
batch_size = 32
seq_len = 50

# Initialize model, loss, and optimizer
model = TransformerLanguageModel(vocab_size, d_model, num_heads, d_ff, dropout)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Dummy input and target: random token ids, drawn in the same order as before
input_seq = torch.randint(0, vocab_size, (batch_size, seq_len))
target_seq = torch.randint(0, vocab_size, (batch_size, seq_len))

# Make the training mode explicit (dropout active); this is already the
# default for a freshly constructed module.
model.train()

# Clear any gradients left from a previous step so that repeating this
# section (e.g. in a loop) does not accumulate stale gradients.
optimizer.zero_grad()

# Forward pass
logits = model(input_seq)

# Compute loss and perform backpropagation.
# CrossEntropyLoss expects (N, C) logits and (N,) class indices, so the
# batch and sequence axes are flattened together.
loss = criterion(logits.reshape(-1, vocab_size), target_seq.reshape(-1))
loss.backward()
optimizer.step()

print("Training loss:", loss.item())
