In [None]:
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to token embeddings.

    Even feature columns hold ``sin(pos / 10000^(2i/emb_dim))`` and odd columns the
    matching ``cos``. The table is precomputed for ``max_len`` positions and kept as
    a registered buffer (not a trainable parameter), so it moves with ``.to(device)``.

    Args:
        emb_dim: embedding width. Odd widths are supported (the last column is a sin).
        max_len: longest sequence ``forward`` accepts.
    """

    def __init__(self, emb_dim, max_len=64):
        super().__init__()
        pe = torch.zeros(max_len, emb_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, emb_dim, 2).float() * (-math.log(10000.0) / emb_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd emb_dim there is one fewer cos column than div_term entries; without
        # the slice this assignment fails with a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: emb_dim // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, emb_dim)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: tensor of shape (batch_size, seq_len, emb_dim).

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` the table was built for.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        return x + self.pe[:, :seq_len]


In [None]:
def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    Args:
        q, k, v: tensors whose last two dims are (seq_len, d_k); leading dims
            (e.g. batch, heads) must broadcast, and ``k``/``v`` share a seq_len.
        mask: optional tensor broadcastable to the score shape (..., q_len, k_len).
            Positions where ``mask == 0`` receive (effectively) zero weight.

    Returns:
        (output, attn): the attention-weighted values and the attention weights.
    """
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Fill with the dtype's most negative finite value rather than -inf. A row
        # whose keys are all masked (e.g. an all-padding sequence) would otherwise be
        # softmax(-inf, ..., -inf) = NaN, which then spreads to every gradient. Rows
        # with at least one visible key are unaffected: exp(min - max) underflows to 0.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, v), attn


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``scaled_dot_product``.

    Queries, keys and values are linearly projected and split into ``num_heads``
    heads of width ``d_k = emb_dim // num_heads``. Each head is attended
    independently, then the heads are re-joined and passed through an output
    projection.

    Args:
        emb_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, emb_dim, num_heads):
        super().__init__()
        assert emb_dim % num_heads == 0
        self.num_heads = num_heads
        self.d_k = emb_dim // num_heads

        # Created in q, k, v, out order so seeded initialisation and state_dict keys
        # stay stable.
        self.q_linear = nn.Linear(emb_dim, emb_dim)
        self.k_linear = nn.Linear(emb_dim, emb_dim)
        self.v_linear = nn.Linear(emb_dim, emb_dim)
        self.out = nn.Linear(emb_dim, emb_dim)

    def _split_heads(self, inputs, projection, batch_size):
        """Project ``inputs`` and reshape to (batch_size, heads, seq_len, d_k)."""
        projected = projection(inputs)
        return projected.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v`` across all heads.

        Args:
            q, k, v: inputs with batch as the leading dim and emb_dim as the last;
                ``k`` and ``v`` share a sequence length.
            mask: optional mask broadcastable to (batch_size, heads, q_len, k_len);
                zeros mark key positions to ignore.

        Returns:
            (batch_size, q_len, emb_dim) output; the attention weights are discarded.
        """
        n_batch = q.size(0)
        q_heads = self._split_heads(q, self.q_linear, n_batch)
        k_heads = self._split_heads(k, self.k_linear, n_batch)
        v_heads = self._split_heads(v, self.v_linear, n_batch)

        context, _attn = scaled_dot_product(q_heads, k_heads, v_heads, mask)
        # (batch, heads, q_len, d_k) -> (batch, q_len, heads * d_k)
        merged = context.transpose(1, 2).contiguous().view(n_batch, -1, self.num_heads * self.d_k)
        return self.out(merged)


In [None]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every position: Linear -> ReLU -> Linear.

    Args:
        emb_dim: input and output width.
        ff_dim: hidden width.
    """

    def __init__(self, emb_dim, ff_dim):
        super().__init__()
        # Expansion, non-linearity, then projection back to the model width.
        self.linear1 = nn.Linear(emb_dim, ff_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(ff_dim, emb_dim)

    def forward(self, x):
        hidden = self.linear1(x)
        activated = self.relu(hidden)
        return self.linear2(activated)


In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention and feed-forward sub-layers.

    Each sub-layer output goes through dropout, is added to its input (residual),
    and the sum is then layer-normalised (norm applied after the residual add).

    Args:
        emb_dim: model width.
        num_heads: attention heads.
        ff_dim: hidden width of the feed-forward sub-layer.
        dropout: dropout probability on each sub-layer output.
    """

    def __init__(self, emb_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(emb_dim, num_heads)
        self.ff = PositionwiseFeedForward(emb_dim, ff_dim)
        self.norm1, self.norm2 = nn.LayerNorm(emb_dim), nn.LayerNorm(emb_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Sub-layer 1: self-attention.
        attended = self.self_attn(x, x, x, mask)
        hidden = self.norm1(x + self.dropout(attended))
        # Sub-layer 2: position-wise feed-forward.
        transformed = self.ff(hidden)
        return self.norm2(hidden + self.dropout(transformed))


In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention, feed-forward.

    Each sub-layer output goes through dropout, is added to its input, and the sum
    is layer-normalised.

    Args:
        emb_dim: model width.
        num_heads: attention heads for both attention sub-layers.
        ff_dim: hidden width of the feed-forward sub-layer.
        dropout: dropout probability on each sub-layer output.
    """

    def __init__(self, emb_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(emb_dim, num_heads)
        self.cross_attn = MultiHeadAttention(emb_dim, num_heads)
        self.ff = PositionwiseFeedForward(emb_dim, ff_dim)
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(emb_dim) for _ in range(3))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Self-attention over the target, restricted by tgt_mask.
        self_attended = self.self_attn(x, x, x, tgt_mask)
        hidden = self.norm1(x + self.dropout(self_attended))
        # Cross-attention: target queries over encoder keys/values, restricted by src_mask.
        cross_attended = self.cross_attn(hidden, enc_output, enc_output, src_mask)
        hidden = self.norm2(hidden + self.dropout(cross_attended))
        # Position-wise feed-forward.
        ff_out = self.ff(hidden)
        return self.norm3(hidden + self.dropout(ff_out))


In [None]:
class TransformerEncoder(nn.Module):
    """Source token embedding + positional encoding followed by a stack of EncoderLayers.

    ``forward(src, mask)`` returns hidden states of shape (batch, src_len, emb_dim).
    """

    def __init__(self, vocab_size, emb_dim, num_layers, num_heads, ff_dim, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.pos_encoding = PositionalEncoding(emb_dim, max_len)
        self.layers = nn.ModuleList(
            EncoderLayer(emb_dim, num_heads, ff_dim) for _ in range(num_layers)
        )

    def forward(self, src, mask=None):
        hidden = self.pos_encoding(self.embedding(src))
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return hidden


In [None]:
class TransformerDecoder(nn.Module):
    """Target embedding + positional encoding, a stack of DecoderLayers, and a vocab projection.

    ``forward`` returns unnormalised logits of shape (batch, tgt_len, vocab_size).
    """

    def __init__(self, vocab_size, emb_dim, num_layers, num_heads, ff_dim, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.pos_encoding = PositionalEncoding(emb_dim, max_len)
        self.layers = nn.ModuleList(
            DecoderLayer(emb_dim, num_heads, ff_dim) for _ in range(num_layers)
        )
        # Final projection from model width to target-vocabulary logits.
        self.fc_out = nn.Linear(emb_dim, vocab_size)

    def forward(self, tgt, enc_output, src_mask=None, tgt_mask=None):
        hidden = self.pos_encoding(self.embedding(tgt))
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, enc_output, src_mask, tgt_mask)
        return self.fc_out(hidden)


In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing logits over the target vocabulary.

    Args:
        src_vocab_size, tgt_vocab_size: source and target vocabulary sizes.
        emb_dim: model width shared by encoder and decoder.
        num_layers: number of layers in each of the encoder and decoder.
        num_heads: attention heads per attention block.
        ff_dim: hidden width of the position-wise feed-forward sub-layers.
        max_len: longest source/target sequence the positional encodings cover.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, emb_dim, num_layers, num_heads, ff_dim, max_len):
        super().__init__()
        self.encoder = TransformerEncoder(src_vocab_size, emb_dim, num_layers, num_heads, ff_dim, max_len)
        self.decoder = TransformerDecoder(tgt_vocab_size, emb_dim, num_layers, num_heads, ff_dim, max_len)

    def make_subsequent_mask(self, size, device=None):
        """Return a causal mask of shape (1, 1, size, size).

        Entry [i, j] is 1 when query position i may attend to key position j (j <= i)
        and 0 otherwise. Passing ``device`` allocates the mask where it will be used,
        instead of building it on the CPU and copying it on every forward pass.
        """
        mask = torch.tril(torch.ones(size, size, device=device))
        return mask.unsqueeze(0).unsqueeze(1)  # (1, 1, tgt_len, tgt_len)

    def forward(self, src, tgt, src_mask=None):
        """Encode ``src`` and decode ``tgt`` under a causal mask.

        Args:
            src: source token ids, (batch, src_len).
            tgt: decoder input token ids, (batch, tgt_len).
            src_mask: optional mask over source keys, broadcastable to the attention
                scores (batch, heads, query_len, src_len), e.g. (batch, 1, 1, src_len).
                Zeros mark positions to ignore. Used by encoder self-attention and by
                decoder cross-attention.

        Returns:
            Logits of shape (batch, tgt_len, tgt_vocab_size).
        """
        enc_output = self.encoder(src, src_mask)
        tgt_mask = self.make_subsequent_mask(tgt.size(1), device=tgt.device)
        return self.decoder(tgt, enc_output, src_mask, tgt_mask)


In [None]:
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")

# Special-token ids shared by preprocessing, attention masking and the loss.
pad_token_id, eos_token_id = tokenizer.pad_token_id, tokenizer.eos_token_id
# "<pad>" doubles as the decoder start-of-sequence token (the T5 tokenizer defines
# no BOS token), so sos_token_id resolves to the same id as pad_token_id.
sos_token_id = tokenizer.convert_tokens_to_ids("<pad>")


In [None]:
MAX_LEN = 64

def preprocess(text, target):
    """Tokenize one (question, Cypher) pair into fixed-length id tensors.

    The source is prefixed with a task instruction and padded/truncated to
    ``MAX_LEN`` ids. The target is encoded to ``MAX_LEN - 1`` ids and prefixed with
    ``sos_token_id``, so it is also ``MAX_LEN`` long.

    The start token matters for teacher forcing. Training feeds ``tgt[:, :-1]`` to
    the decoder and scores against ``tgt[:, 1:]``. Without a leading start token, the
    first real target token would never be a prediction target. Generation seeded
    with ``sos_token_id`` would also begin from an input the model never saw.

    Args:
        text: natural-language question.
        target: reference Cypher query.

    Returns:
        (input_ids, target_ids): two 1-D integer tensors of length ``MAX_LEN``.
    """
    # Example prompt for Cypher generation
    input_text = f"translate question to cypher: {text}"
    input_ids = tokenizer.encode(input_text, max_length=MAX_LEN, padding="max_length", truncation=True)
    # Reserve one slot for the start token so both tensors share the MAX_LEN length.
    target_ids = [sos_token_id] + tokenizer.encode(
        target, max_length=MAX_LEN - 1, padding="max_length", truncation=True
    )
    return torch.tensor(input_ids), torch.tensor(target_ids)


In [None]:
from torch.utils.data import Dataset

class CypherQADataset(Dataset):
    """Map-style dataset of (question, Cypher) pairs, tokenized lazily on access.

    Args:
        pairs: indexable sequence of ``(question, cypher)`` string pairs.
    """

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the ``(source_ids, target_ids)`` tensors produced by ``preprocess`` for pair ``idx``."""
        question, cypher = self.pairs[idx]
        source_ids, target_ids = preprocess(question, cypher)
        return source_ids, target_ids


In [None]:
def train(model, dataloader, optimizer, criterion, device, pad_id=None):
    """Run one teacher-forced training epoch and return the mean per-batch loss.

    Each batch ``(src, tgt)`` is shifted so the decoder reads ``tgt[:, :-1]`` and is
    scored against ``tgt[:, 1:]``.

    Args:
        model: called as ``model(src, tgt_input, src_mask)``; must return logits of
            shape (batch, tgt_len - 1, vocab_size).
        dataloader: iterable of ``(src, tgt)`` integer id tensors, (batch, seq_len).
        optimizer: optimizer over ``model``'s parameters.
        criterion: token-level loss taking (N, vocab_size) logits and (N,) targets.
        device: device each batch is moved to.
        pad_id: source padding token id. Source positions equal to it are masked out
            of encoder self-attention and decoder cross-attention. Defaults to
            ``criterion.ignore_index`` (the id the loss already treats as padding)
            when that is a valid, non-negative id; otherwise no source mask is built.

    Returns:
        Mean of the per-batch losses.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if pad_id is None:
        ignore_index = getattr(criterion, "ignore_index", -100)
        # CrossEntropyLoss's default ignore_index (-100) is never a token id.
        pad_id = ignore_index if ignore_index >= 0 else None

    model.train()
    total_loss = 0.0
    num_batches = 0

    for src, tgt in dataloader:
        src, tgt = src.to(device), tgt.to(device)

        # Sources are padded to a fixed length. Without this mask, every query also
        # attends to that padding. Shape (batch, 1, 1, src_len) broadcasts over heads
        # and query positions.
        src_mask = None if pad_id is None else (src != pad_id).unsqueeze(1).unsqueeze(2)

        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]

        optimizer.zero_grad()
        output = model(src, tgt_input, src_mask)  # (batch, tgt_len-1, vocab_size)

        loss = criterion(output.reshape(-1, output.shape[-1]), tgt_output.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")
    return total_loss / num_batches


In [None]:
import torch.nn as nn
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed before building the model so weight initialisation is reproducible on
# Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

model = Transformer(
    src_vocab_size=tokenizer.vocab_size,
    tgt_vocab_size=tokenizer.vocab_size,
    emb_dim=128,
    num_layers=2,
    num_heads=4,
    ff_dim=256,
    # Must cover the MAX_LEN-long sequences built by `preprocess`. Reuse the constant
    # rather than repeating the literal, so the two cannot drift apart.
    max_len=MAX_LEN,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Padding positions in the targets contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)


In [None]:
EPOCHS = 5

# No cell above creates `loader`, so on a fresh kernel this cell would die with a
# bare NameError. Fail fast with an actionable message instead.
if "loader" not in globals():
    raise NameError(
        "`loader` is not defined: create it before training, e.g. "
        "loader = DataLoader(CypherQADataset(pairs), batch_size=32, shuffle=True)"
    )

for epoch in range(1, EPOCHS + 1):
    loss = train(model, loader, optimizer, criterion, device)
    print(f"Epoch {epoch} - Loss: {loss:.4f}")
