In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torch import optim
import os
import io
import math
import random
import time
import numpy

from ode import ODEEnvironment
from train import get_parser

In [None]:
class EmbeddingCust(nn.Module):
    """
    Wrapper around ``nn.Embedding`` with scaled-normal initialisation.

    Weights are drawn from N(0, embed_dim ** -0.5). When ``pad_idx`` is given,
    that row is zeroed and passed to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(self, num_embed, embed_dim, pad_idx=None):
        super().__init__()
        table = nn.Embedding(num_embed, embed_dim, padding_idx=pad_idx)
        self.embedding = table
        nn.init.normal_(table.weight, mean=0, std=embed_dim ** -0.5)
        if pad_idx is None:
            return
        nn.init.constant_(table.weight[pad_idx], 0)

    def forward(self, x):
        """Look up the embedding vector of every index in ``x``."""
        lookup = self.embedding
        return lookup(x)

def get_masks(seq_len, lengths, causal_att):
    """
    Create masks for hidden states and attention.

    Args:
        seq_len: padded sequence length.
        lengths: LongTensor ``(batch_size,)`` of true lengths (each <= seq_len).
        causal_att: if True, build a causal (lower-triangular) attention mask.

    Returns:
        ``(mask, attn_mask)``. ``mask`` is a bool tensor ``(batch_size, seq_len)``
        that is True at positions ``< lengths[b]``. ``attn_mask`` is ``mask``
        itself when ``causal_att`` is False. Otherwise it is a bool tensor
        ``(batch_size, seq_len, seq_len)`` with ``[b, i, j] = (j <= i)``.
    """
    assert lengths.max().item() <= seq_len
    positions = torch.arange(seq_len, dtype=torch.long, device=lengths.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)

    if not causal_att:
        return mask, mask

    n_sent = lengths.size(0)
    key_pos = positions.view(1, 1, seq_len).repeat(n_sent, seq_len, 1)
    query_pos = positions.view(1, seq_len, 1)
    return mask, key_pos <= query_pos

class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention (self- or cross-attention).

    Every instance gets a unique ``layer_id`` from a class-wide counter. That id
    is the key under which the layer stores its keys/values in the incremental
    decoding cache (``self.cache``), which the owning model assigns before each
    call.
    """

    _id_counter = 0

    def __init__(self, n_heads, embed_dim, dropout):
        """
        Args:
            n_heads: number of attention heads; must divide ``embed_dim``.
            embed_dim: model width (input and output size).
            dropout: dropout probability applied to the attention weights.
        """
        super().__init__()
        self.layer_id = MultiHeadAttention._id_counter
        MultiHeadAttention._id_counter += 1
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.dropout = dropout
        # Incremental-decoding cache shared by all layers of a model:
        # {"slen": <number of cached positions>, layer_id: (k, v), ...}.
        # None disables caching; it is only consulted when forward(use_cache=True).
        self.cache = None
        assert self.embed_dim % self.n_heads == 0

        self.q_lin = nn.Linear(embed_dim, embed_dim)
        self.k_lin = nn.Linear(embed_dim, embed_dim)
        self.v_lin = nn.Linear(embed_dim, embed_dim)
        self.out_lin = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, mask, kv=None, use_cache=False):
        """
        Compute self-attention (``kv is None``) or cross-attention over ``kv``.

        Args:
            query: ``(batch_size, q_len, embed_dim)``.
            mask: either ``(batch_size, k_len)`` (key padding) or
                ``(batch_size, q_len, k_len)``; zero/False entries are not attended to.
            kv: optional ``(batch_size, k_len, embed_dim)`` source of keys/values.
            use_cache: if True and ``self.cache`` is a dict, keys/values are reused.
                For self-attention the new positions are appended to the cached
                ones, so ``k_len = cache["slen"] + q_len``. For cross-attention the
                projections of ``kv`` are computed on the first call and reused
                afterwards.

        Returns:
            ``(batch_size, q_len, embed_dim)``.
        """
        batch_size, q_len, _ = query.size()
        head_dim = self.embed_dim // self.n_heads
        cache = getattr(self, "cache", None) if use_cache else None

        if kv is None:
            k_len = q_len if cache is None else cache["slen"] + q_len
        else:
            k_len = kv.size(1)
        reshaped_mask = (batch_size, 1, q_len, k_len) if mask.dim() == 3 else (batch_size, 1, 1, k_len)

        def split_heads(t, length):
            # (batch, length, embed_dim) -> (batch, n_heads, length, head_dim)
            return t.view(batch_size, length, self.n_heads, head_dim).transpose(1, 2)

        q = split_heads(self.q_lin(query), q_len)
        if kv is None:
            k = split_heads(self.k_lin(query), q_len)
            v = split_heads(self.v_lin(query), q_len)
        elif cache is None or self.layer_id not in cache:
            k = split_heads(self.k_lin(kv), k_len)
            v = split_heads(self.v_lin(kv), k_len)

        if cache is not None:
            if self.layer_id in cache:
                if kv is None:
                    # Self-attention: prepend keys/values of earlier positions.
                    k_prev, v_prev = cache[self.layer_id]
                    k = torch.cat([k_prev, k], dim=2)
                    v = torch.cat([v_prev, v], dim=2)
                else:
                    # Cross-attention: the source is fixed, reuse its projections.
                    k, v = cache[self.layer_id]
            cache[self.layer_id] = (k, v)

        q = q / math.sqrt(head_dim)
        att_scores = torch.matmul(q, k.transpose(2, 3))
        mask = (mask == 0).view(reshaped_mask).expand_as(att_scores)
        att_scores.masked_fill_(mask, -float("inf"))

        # Softmax in float32 for stability, then cast back (matters in half precision).
        att_weights = F.softmax(att_scores.float(), dim=-1).type_as(att_scores)
        att_weights = F.dropout(att_weights, p=self.dropout, training=self.training)
        context = torch.matmul(att_weights, v)
        context = context.transpose(1, 2).contiguous().view(batch_size, q_len, self.embed_dim)

        return self.out_lin(context)

class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> ReLU -> Linear -> dropout."""

    def __init__(self, input_dim, hidden_dim, output_dim, dropout):
        super().__init__()
        self.dropout = dropout
        # lin1 is created before lin2 so that parameter order and RNG use are unchanged.
        self.lin1 = nn.Linear(input_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        hidden = F.relu(self.lin1(x))
        projected = self.lin2(hidden)
        return F.dropout(projected, p=self.dropout, training=self.training)

class TransformerModel(nn.Module):
    """
    Transformer encoder or decoder with learned positional embeddings.

    One class builds both sides of a seq2seq model:

    * ``is_encoder=True``: each layer is self-attention -> add & LayerNorm ->
      feed-forward -> add & LayerNorm.
    * ``is_encoder=False`` (decoder): each layer adds a cross-attention sub-layer
      over the encoder output after self-attention. Sequences can be decoded with
      :meth:`generate` (greedy / sampling) or :meth:`generate_beam`.

    Token tensors passed to :meth:`fwd` are time-major, ``(slen, batch_size)``.
    """

    def __init__(self, half_prec, is_encoder, num_words, eos_idx, pad_idx, id2word, n_max_positions=4096,
                 embed_dim=512, n_heads=8, num_enc_layers=6, num_dec_layers=6,
                 dropout=0.1, att_dropout=0.1, max_src_len=512, with_output=True):
        """
        Build the embeddings, the per-layer sub-modules and the output projection.

        Args:
            half_prec: if True ``self.dtype`` is ``torch.half``, else ``torch.float``
                (only stored; this class performs no cast itself).
            is_encoder: True for an encoder, False for a decoder (adds cross-attention).
            num_words: vocabulary size.
            eos_idx: id of the end-of-sequence token (also used as BOS in generation).
            pad_idx: id of the padding token.
            id2word: id -> token mapping, stored for callers.
            n_max_positions: number of learned positional embeddings.
            embed_dim: model width; the feed-forward hidden size is ``4 * embed_dim``.
            n_heads: number of attention heads.
            num_enc_layers: number of layers when ``is_encoder`` is True.
            num_dec_layers: number of layers when ``is_encoder`` is False.
            dropout: dropout on embeddings, sub-layer outputs and feed-forward output.
            att_dropout: dropout on attention weights.
            max_src_len: if > 0, source positions at or beyond this index are masked
                out in cross-attention.
            with_output: if True, create ``final_lin`` (embed_dim -> num_words).
        """
        super().__init__()

        self.dtype = torch.half if half_prec else torch.float
        self.is_encoder = is_encoder
        self.num_words = num_words
        self.eos_idx = eos_idx
        self.pad_idx = pad_idx
        self.id2word = id2word

        self.embed_dim = embed_dim
        self.hidden_dim = embed_dim * 4
        self.n_heads = n_heads
        self.n_layers = num_enc_layers if is_encoder else num_dec_layers
        self.dropout = dropout
        self.att_dropout = att_dropout
        self.max_src_len = max_src_len

        self.position_embeddings = EmbeddingCust(n_max_positions, embed_dim)
        self.embeddings = EmbeddingCust(num_words, embed_dim, pad_idx=pad_idx)
        self.emb_layer_norm = nn.LayerNorm(embed_dim, eps=1e-12)

        self.mhas = nn.ModuleList()
        self.layer_norms1 = nn.ModuleList()
        self.ffns = nn.ModuleList()
        self.layer_norms2 = nn.ModuleList()
        # Decoder-only sub-layers: cross-attention and the LayerNorm that follows it.
        self.layer_norms_dec = nn.ModuleList() if not is_encoder else None
        self.encoder_atts = nn.ModuleList() if not is_encoder else None

        for _ in range(self.n_layers):
            self.mhas.append(MultiHeadAttention(n_heads, embed_dim, dropout=att_dropout))
            self.layer_norms1.append(nn.LayerNorm(embed_dim, eps=1e-12))
            self.ffns.append(FeedForward(embed_dim, self.hidden_dim, embed_dim, dropout=dropout))
            self.layer_norms2.append(nn.LayerNorm(embed_dim, eps=1e-12))

            if not is_encoder:
                self.layer_norms_dec.append(nn.LayerNorm(embed_dim, eps=1e-12))
                self.encoder_atts.append(MultiHeadAttention(n_heads, embed_dim, dropout=att_dropout))

        # Incremental-decoding state; generate / generate_beam set it to {"slen": 0}.
        self.cache = None
        self.final_lin = nn.Linear(embed_dim, num_words) if with_output else None

    def forward(self, mode, **kwargs):
        """
        Dispatch to :meth:`fwd` (``mode="fwd"``) or :meth:`predict` (``mode="predict"``).

        Raises:
            Exception: for any other ``mode``.
        """
        if mode == "fwd":
            return self.fwd(**kwargs)
        elif mode == "predict":
            return self.predict(**kwargs)
        else:
            raise Exception("Unknown mode: %s" % mode)

    def fwd(self, x, lengths, causal_att, src_enc=None, src_len=None, positions=None, use_cache=False):
        """
        Run the embedding layer and all transformer layers.

        Inputs:
            `x` LongTensor(slen, batch_size), token indices
            `lengths` LongTensor(batch_size), sentence lengths
            `causal_att` Boolean, enables causal attention
            `src_enc` optional encoder output (batch_size, src_slen, embed_dim);
                decoder only, must be given together with `src_len`
            `src_len` optional LongTensor(batch_size), source lengths
            `positions` LongTensor(slen, batch_size), token positions
                (defaults to 0..slen-1)
            `use_cache` Boolean; only the positions not yet in `self.cache` are
                computed, and `self.cache["slen"]` is advanced afterwards

        Returns:
            Hidden states (n_new_positions, batch_size, embed_dim). This is slen
            positions, or only the new ones when `use_cache` is True.
        """
        slen, batch_size = x.size()
        assert lengths.size(0) == batch_size
        assert lengths.max().item() <= slen
        x = x.transpose(0, 1)  # Set batch size as first dimension
        assert (src_enc is None) == (src_len is None)
        if src_enc is not None:
            # Only decoders have cross-attention layers to consume an encoder output.
            assert not self.is_encoder
            assert src_enc.size(0) == batch_size
        assert not (use_cache and self.cache is None)

        # Generate masks
        mask, attn_mask = get_masks(slen, lengths, causal_att)
        use_cross_att = not self.is_encoder and src_enc is not None
        if use_cross_att:
            src_range = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device)
            if self.max_src_len > 0:
                src_mask = src_range < torch.clamp(src_len[:, None], max=self.max_src_len)
            else:
                src_mask = src_range < src_len[:, None]

        # Compute positions: (1, slen) broadcast over the batch, or (batch_size, slen)
        if positions is None:
            positions = torch.arange(slen, dtype=torch.long, device=x.device).unsqueeze(0)
        else:
            assert positions.size() == (slen, batch_size)
            positions = positions.transpose(0, 1)

        # Handle cached elements: keep only the positions not computed yet
        if use_cache:
            _slen = slen - self.cache["slen"]
            x = x[:, -_slen:]
            positions = positions[:, -_slen:]
            mask = mask[:, -_slen:]
            attn_mask = attn_mask[:, -_slen:]

        # Compute embeddings (positional embeddings broadcast to the token embeddings' shape)
        tensor = self.embeddings(x)
        tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
        tensor = self.emb_layer_norm(tensor)
        tensor = F.dropout(tensor, p=self.dropout, training=self.training)
        tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # Transformer layers
        for i in range(self.n_layers):
            # Multi-head self-attention
            self.mhas[i].cache = self.cache
            attn = self.mhas[i](tensor, attn_mask, use_cache=use_cache)
            attn = F.dropout(attn, p=self.dropout, training=self.training)
            tensor = tensor + attn
            tensor = self.layer_norms1[i](tensor)

            # Cross-attention (only for decoder)
            if use_cross_att:
                self.encoder_atts[i].cache = self.cache
                attn = self.encoder_atts[i](tensor, src_mask, kv=src_enc, use_cache=use_cache)
                attn = F.dropout(attn, p=self.dropout, training=self.training)
                tensor = tensor + attn
                tensor = self.layer_norms_dec[i](tensor)

            # Feedforward network
            tensor = tensor + self.ffns[i](tensor)
            tensor = self.layer_norms2[i](tensor)

            # Zero out padded positions
            tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # Update cache
        if use_cache:
            self.cache["slen"] += tensor.size(1)

        return tensor.transpose(0, 1)

    def predict(self, tensor, pred_mask, y, get_scores):
        """
        Compute vocabulary scores and the cross-entropy loss from hidden states.

        Args:
            tensor: hidden states (slen, batch_size, embed_dim), as returned by `fwd`.
            pred_mask: boolean mask (slen, batch_size) selecting positions to score.
            y: target ids of the selected positions; must contain no padding.
            get_scores: accepted for API compatibility; scores are always returned.

        Returns:
            ``(scores, loss)``: scores (n_selected, num_words) and mean cross-entropy.
        """
        x = tensor[pred_mask.unsqueeze(-1).expand_as(tensor)].view(-1, self.embed_dim)
        assert (y == self.pad_idx).sum().item() == 0
        scores = self.final_lin(x).view(-1, self.num_words)
        loss = F.cross_entropy(scores.float(), y, reduction="mean")
        return scores, loss

    def generate(self, src_enc, src_len, max_len=200, sample_temp=None):
        """
        Decode sequences greedily (or by sampling) conditioned on an encoder output.

        Args:
            src_enc: encoder output (batch_size, src_slen, embed_dim).
            src_len: LongTensor(batch_size) of source lengths.
            max_len: maximum generated length, including the leading and trailing <EOS>.
            sample_temp: None picks the highest-scoring token; otherwise the next
                token is sampled from softmax(scores / sample_temp).

        Returns:
            ``(generated, gen_len)``. ``generated`` is a LongTensor
            (cur_len, batch_size): each column starts with <EOS> (used as BOS),
            ends with <EOS> and is padded with <PAD>. ``gen_len`` is
            LongTensor(batch_size), the length of each sentence including both <EOS>.
        """

        # Get the batch size and ensure the source encoding matches
        batch_size = len(src_len)
        assert src_enc.size(0) == batch_size

        # Prepare the tensor for generated sentences
        generated = src_len.new(max_len, batch_size)  # upcoming output
        generated.fill_(self.pad_idx)  # initialize with <PAD>
        generated[0].fill_(self.eos_idx)  # start with <EOS> as <BOS>

        # Initialize positions
        positions = src_len.new(max_len).long()
        positions = torch.arange(max_len, out=positions).unsqueeze(1).expand(max_len, batch_size)

        # Set up generation length tracking
        cur_len = 1
        gen_len = src_len.clone().fill_(1)
        unfinished_sents = src_len.clone().fill_(1)

        # Cache computation states
        self.cache = {"slen": 0}

        # Start the generation loop
        while cur_len < max_len:

            # Compute word scores for the newest position only (earlier ones are cached)
            tensor = self.forward(
                "fwd",
                x=generated[:cur_len],
                lengths=gen_len,
                positions=positions[:cur_len],
                causal_att=True,
                src_enc=src_enc,
                src_len=src_len,
                use_cache=True,
            )
            assert tensor.size() == (1, batch_size, self.embed_dim)
            tensor = tensor.data[-1, :, :]  # (batch_size, embed_dim)
            scores = self.final_lin(tensor)  # (batch_size, num_words)

            # Select the next words: sample or greedy
            if sample_temp is None:
                next_words = torch.topk(scores, 1)[1].squeeze(1)
            else:
                next_words = torch.multinomial(F.softmax(scores.float() / sample_temp, dim=1), 1).squeeze(1)
            assert next_words.size() == (batch_size,)

            # Finished sentences receive <PAD>; unfinished ones the chosen word
            generated[cur_len] = next_words * unfinished_sents + self.pad_idx * (1 - unfinished_sents)
            gen_len.add_(unfinished_sents)
            unfinished_sents.mul_(next_words.ne(self.eos_idx).long())
            cur_len = cur_len + 1

            # Stop if each sentence has an </s> or max length is reached
            if unfinished_sents.max() == 0:
                break

        # Add <EOS> to unfinished sentences if max length is reached
        # (masked_fill_ requires a boolean mask in current PyTorch)
        if cur_len == max_len:
            generated[-1].masked_fill_(unfinished_sents.bool(), self.eos_idx)

        # Sanity check: exactly one leading and one trailing <EOS> per sentence
        assert (generated == self.eos_idx).sum() == 2 * batch_size

        return generated[:cur_len], gen_len

    def generate_beam(self, src_enc, src_len, beam_size, len_penalty, early_stopping, max_len=200):
        """
        Generate a sequence using beam search decoding.

        Args:
            src_enc: encoder output (batch_size, src_slen, embed_dim).
            src_len: LongTensor(batch_size) of source lengths.
            beam_size: number of hypotheses kept per sentence at each step.
            len_penalty: exponent of the length normalisation used by
                `BeamHypotheses` (score = sum_logprobs / len ** len_penalty).
            early_stopping: if True, a sentence is finished as soon as
                `beam_size` complete hypotheses have been collected.
            max_len: maximum length of the generated sequence.

        Returns:
            ``(decoded, tgt_len, generated_hyps)``. ``decoded`` is a LongTensor
            (max(tgt_len), batch_size) holding the best hypothesis of each
            sentence, framed by <EOS> and padded with <PAD>. ``tgt_len`` holds
            their lengths. ``generated_hyps`` is the per-sentence list of
            `BeamHypotheses`.
        """

        # Ensure input consistency
        assert src_enc.size(0) == src_len.size(0), "Mismatch between batch size and source lengths"
        assert beam_size >= 1, "Beam size must be greater than or equal to 1"

        # Extract batch size and number of words in the vocabulary
        batch_size = len(src_len)
        num_words = self.num_words

        # Repeat every source beam_size times: row b * beam_size + k is beam k of sentence b
        src_enc = src_enc.unsqueeze(1).expand((batch_size, beam_size) + src_enc.shape[1:]).contiguous().view((batch_size * beam_size,) + src_enc.shape[1:])
        src_len = src_len.unsqueeze(1).expand(batch_size, beam_size).contiguous().view(-1)

        # Initialize output tensor and fill it with padding tokens
        generated = src_len.new(max_len, batch_size * beam_size)
        generated.fill_(self.pad_idx)
        generated[0].fill_(self.eos_idx)  # Start sequence with <EOS>

        # n-best list of finished hypotheses for each sentence
        generated_hyps = [BeamHypotheses(beam_size, max_len, len_penalty, early_stopping) for _ in range(batch_size)]

        # Prepare positions tensor for the sequence
        positions = src_len.new(max_len).long()
        positions = torch.arange(max_len, out=positions).unsqueeze(1).expand_as(generated)

        # All beams start from the same prefix: only beam 0 gets a usable score so the
        # first step does not select beam_size copies of the same continuation.
        beam_scores = src_enc.new(batch_size, beam_size).float().fill_(0)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)

        # Current length is 1 (the initial <EOS>)
        cur_len = 1

        # Cache to store computed states
        self.cache = {"slen": 0}

        # 'done' flag for each sentence
        done = [False for _ in range(batch_size)]

        while cur_len < max_len:
            # Compute scores for the newest position only (earlier ones are cached)
            tensor = self.forward(
                "fwd",
                x=generated[:cur_len],
                lengths=src_len.new(batch_size * beam_size).fill_(cur_len),
                positions=positions[:cur_len],
                causal_att=True,
                src_enc=src_enc,
                src_len=src_len,
                use_cache=True,
            )
            assert tensor.size() == (1, batch_size * beam_size, self.embed_dim)
            tensor = tensor.data[-1, :, :]  # (batch_size * beam_size, embed_dim)
            scores = self.final_lin(tensor)  # (batch_size * beam_size, num_words)
            scores = F.log_softmax(scores.float(), dim=-1)  # log-probabilities

            # Cumulative scores, flattened over (beam, word) for each sentence
            _scores = scores + beam_scores[:, None].expand_as(scores)
            _scores = _scores.view(batch_size, beam_size * num_words)

            # 2 * beam_size candidates guarantee beam_size non-<EOS> continuations
            next_scores, next_words = torch.topk(_scores, 2 * beam_size, dim=1, largest=True, sorted=True)

            # (score, word_id, source row) for every row of the next step
            next_batch_beam = []

            for sent_id in range(batch_size):
                done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
                if done[sent_id]:
                    # Finished sentence: feed padding rows
                    next_batch_beam.extend([(0, self.pad_idx, 0)] * beam_size)
                    continue

                next_sent_beam = []

                # Work with plain Python numbers rather than 0-dim tensors
                for idx, value in zip(next_words[sent_id].tolist(), next_scores[sent_id].tolist()):
                    beam_id = idx // num_words
                    word_id = idx % num_words

                    # End of sentence or last step: store as a finished hypothesis
                    if word_id == self.eos_idx or cur_len + 1 == max_len:
                        generated_hyps[sent_id].add(generated[:cur_len, sent_id * beam_size + beam_id].clone().cpu(), value)
                    else:
                        next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))

                    if len(next_sent_beam) == beam_size:
                        break

                # The beam is empty on the last step and full otherwise
                assert len(next_sent_beam) == (0 if cur_len + 1 == max_len else beam_size)
                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, self.pad_idx, 0)] * beam_size
                next_batch_beam.extend(next_sent_beam)

            # Sanity check on the next batch
            assert len(next_batch_beam) == batch_size * beam_size
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_words = generated.new([x[1] for x in next_batch_beam])
            beam_idx = src_len.new([x[2] for x in next_batch_beam])

            # Re-order sequences and cached keys/values to follow the selected beams
            generated = generated[:, beam_idx]
            generated[cur_len] = beam_words
            for k in self.cache.keys():
                if k != "slen":
                    self.cache[k] = (self.cache[k][0][beam_idx], self.cache[k][1][beam_idx])

            cur_len += 1

            if all(done):
                break

        # Select the best hypothesis for each sentence
        tgt_len = src_len.new(batch_size)
        best = []
        for i, hypotheses in enumerate(generated_hyps):
            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
            best.append(best_hyp)

        # Prepare the final output
        decoded = src_len.new(tgt_len.max().item(), batch_size).fill_(self.pad_idx)
        for i, hypo in enumerate(best):
            decoded[: tgt_len[i] - 1, i] = hypo
            decoded[tgt_len[i] - 1, i] = self.eos_idx

        # Final check: exactly one leading and one trailing <EOS> per sentence
        assert (decoded == self.eos_idx).sum() == 2 * batch_size

        return decoded, tgt_len, generated_hyps

class BeamHypotheses(object):
    """
    Bounded n-best list of finished beam-search hypotheses.

    Each hypothesis is stored as ``(score, hyp)``, where
    ``score = sum_logprobs / len(hyp) ** len_penalty``.
    """

    def __init__(self, n_hyp, max_len, len_penalty, early_stopping):
        """
        Args:
            n_hyp: capacity of the list.
            max_len: maximum decoding length; the leading <BOS> is excluded, so
                ``self.max_len = max_len - 1``.
            len_penalty: length-normalisation exponent.
            early_stopping: if True, ``is_done`` is True as soon as the list is full.
        """
        self.n_hyp = n_hyp
        self.max_len = max_len - 1
        self.len_penalty = len_penalty
        self.early_stopping = early_stopping
        self.hyp = []
        # Start high so the first insertions always lower it.
        self.worst_score = 1e9

    def __len__(self):
        """Number of hypotheses currently stored."""
        return len(self.hyp)

    def add(self, hyp, sum_logprobs):
        """
        Insert ``hyp`` if the list is not full or it beats the current worst.

        When the capacity is exceeded, the lowest-scoring entry is evicted.
        """
        score = sum_logprobs / len(hyp) ** self.len_penalty
        has_room = len(self) < self.n_hyp
        if not (has_room or score > self.worst_score):
            return
        self.hyp.append((score, hyp))
        if len(self) <= self.n_hyp:
            self.worst_score = min(score, self.worst_score)
            return
        ranked = sorted((s, position) for position, (s, _) in enumerate(self.hyp))
        del self.hyp[ranked[0][1]]
        self.worst_score = ranked[1][0]

    def is_done(self, best_sum_logprobs):
        """
        Return True when no future hypothesis can enter the list.

        The list must be full. Then, unless ``early_stopping`` is set, the best
        achievable score (``best_sum_logprobs`` normalised at ``max_len``) must
        not exceed the current worst score.
        """
        if len(self) < self.n_hyp:
            return False
        if self.early_stopping:
            return True
        best_reachable = best_sum_logprobs / self.max_len ** self.len_penalty
        return self.worst_score >= best_reachable

In [None]:
# Model / data limits.
max_degree = 6
int_base = 1000
max_len = 1024
max_output_len = 512

# Vocabulary building blocks, in the order they are concatenated into `words`.
SPECIAL_WORDS = ["<s>", "</s>", "<pad>", "(", ")"] + [f"<SPECIAL_{i}>" for i in range(10)]
constants = ["pi", "E"]
variables = [f"x{i}" for i in range(max_degree * 2)]
operators = [
    "+", "-", "*", "/", "^",
    "sqrt", "exp", "ln",
    "sin", "cos", "tan", "asin", "acos", "atan",
    "Abs",
]
symbols = ["I", "INT+", "INT-", "FLOAT+", "FLOAT-", ".", "10^"]
elements = list(map(str, range(max(10, int_base))))
func_separator = "<SPECIAL_3>"

# Full vocabulary and its two lookup tables.
words = [*SPECIAL_WORDS, *constants, *variables, *operators, *symbols, *elements]
id2word = dict(enumerate(words))
word2id = {token: idx for idx, token in id2word.items()}
n_words = len(words)

# Index 0 ("<s>") serves as <EOS>, index 1 ("</s>") as padding.
eos_index = 0
pad_index = 1

In [None]:
class EnvDataset(Dataset):
    """
    Dataset of (input, output) token sequences loaded from a text file.

    Each line of ``path`` is split on "|" into exactly two fields. The second
    field is split on a tab into the input and output strings, and lines that
    do not yield two tab-separated parts are dropped. Tokens are separated by
    whitespace. In training mode samples are drawn at random and the dataset is
    effectively infinite; otherwise samples are read by index.
    """
    def __init__(self, num_workers, path, word2id, train, max_len=1024, max_output_len=512, reload_size=-1, size=None, seed=None):
        """
        Args:
            num_workers: number of DataLoader workers that will read this dataset
                (checked against the actual worker info in ``get_worker_id``).
            path: file to load; if None no data is loaded.
            word2id: token -> id mapping used by ``collate_fn``.
            train: training mode (random sampling, infinite size) or evaluation.
            max_len: samples whose input has ``>= max_len`` tokens are skipped
                (<= 0 disables the limit).
            max_output_len: same limit for the output sequence.
            reload_size: in training mode, only the first ``reload_size`` lines
                are read (-1 reads the whole file).
            size: evaluation size; defaults to the number of loaded samples
                (5000 when no file is given). Ignored in training mode.
            seed: optional non-negative int. When given, each worker's sampler is
                seeded with ``[worker_id, seed]``. When None, each worker's
                generator is seeded by NumPy from fresh entropy.
        """
        super().__init__()
        self.num_workers = num_workers
        self.path = path
        self.train = train
        self.count = 0
        self.seed = seed
        # Sampling RNG, created lazily by _get_rng() inside each worker.
        self.rng = None

        self.max_len = max_len
        self.max_output_len = max_output_len
        self.word2id = word2id

        self.func_separator = "<SPECIAL_3>"
        self.eos_index = 0
        self.pad_index = 1

        self.data = []

        # reloading from file
        if path is not None:
            assert os.path.isfile(path)
            print(f"Loading data from {path} ...")
            with io.open(path, mode="r", encoding="utf-8") as f:
                # either reload the entire file, or the first N lines (for the training set)
                if not train:
                    lines = [line.rstrip().split("|") for line in f]
                else:
                    lines = []
                    for i, line in enumerate(f):
                        if i == reload_size:
                            break
                        lines.append(line.rstrip().split("|"))
            self.data = [xy.split("\t") for _, xy in lines]
            self.data = [xy for xy in self.data if len(xy) == 2]
            print(f"Loaded {len(self.data)} equations from the disk.")

        # dataset size: infinite iterator for train, finite for validation (default of 5000 if no file provided)
        if self.train:
            self.size = 1 << 60
        elif size is None:
            self.size = 5000 if path is None else len(self.data)
        else:
            assert size > 0
            self.size = size

    def batch_sequences(self, sequences):
        """
        Take as input a list of n sequences (torch.LongTensor vectors) and return
        a tensor of size (slen, n) where slen is the length of the longest
        sentence, and a vector lengths containing the length of each sentence.

        Every sequence is framed by <EOS> on both sides (hence ``len + 2``) and
        padded with <PAD>; empty sequences are rejected by the assertion.
        """
        lengths = torch.LongTensor([len(s) + 2 for s in sequences])
        sent = torch.LongTensor(lengths.max().item(), lengths.size(0)).fill_(self.pad_index)
        assert lengths.min().item() > 2

        sent[0] = self.eos_index
        for i, s in enumerate(sequences):
            sent[1 : lengths[i] - 1, i].copy_(s)
            sent[lengths[i] - 1, i] = self.eos_index

        return sent, lengths

    def collate_fn(self, elements):
        """
        Collate ``(x_tokens, y_tokens)`` samples into a batch.

        Returns:
            ``((x, x_len), (y, y_len), nb_eqs)``. ``x``/``y`` come from
            ``batch_sequences``, and ``nb_eqs[i]`` counts the function
            separators in the i-th input.
        """
        x, y = zip(*elements)
        nb_eqs = [seq.count(self.func_separator) for seq in x]
        x = [torch.LongTensor([self.word2id[w] for w in seq]) for seq in x]
        y = [torch.LongTensor([self.word2id[w] for w in seq]) for seq in y]
        x, x_len = self.batch_sequences(x)
        y, y_len = self.batch_sequences(y)
        return (x, x_len), (y, y_len), torch.LongTensor(nb_eqs)

    def get_worker_id(self):
        """
        Get worker ID (always 0 outside training or in the main process).
        """
        if not self.train:
            return 0
        worker_info = torch.utils.data.get_worker_info()
        assert (worker_info is None) == (self.num_workers == 0)
        return 0 if worker_info is None else worker_info.id

    def _get_rng(self):
        """Return this worker's sampling RNG, creating it on first use."""
        if self.rng is None:
            # Built lazily so each DataLoader worker creates its own generator
            # instead of sharing one copied from the parent process.
            if self.seed is None:
                self.rng = numpy.random.RandomState()
            else:
                self.rng = numpy.random.RandomState([self.get_worker_id(), self.seed])
        return self.rng

    def __len__(self):
        """
        Return dataset size.
        """
        return self.size

    def __getitem__(self, index):
        """
        Return a training or evaluation sample read from the loaded file.
        """
        return self.read_sample(index)

    def read_sample(self, index):
        """
        Return the tokenised ``(x, y)`` pair for ``index``.

        In training mode ``index`` is ignored and samples are drawn at random.
        Samples exceeding the length limits are skipped. In training a new random
        sample is drawn; otherwise the next index is tried, wrapping around to
        the start of the data.

        Raises:
            RuntimeError: if no data was loaded, or (evaluation only) if no sample
                satisfies the length limits.
        """
        n_data = len(self.data)
        if n_data == 0:
            raise RuntimeError("EnvDataset has no data; pass a `path` to load samples from.")
        n_skipped = 0
        while True:
            if self.train:
                index = self._get_rng().randint(n_data)
            x, y = self.data[index]
            x = x.split()
            y = y.split()
            x_too_long = self.max_len > 0 and len(x) >= self.max_len
            y_too_long = self.max_output_len > 0 and len(y) >= self.max_output_len
            if not (x_too_long or y_too_long):
                return x, y
            n_skipped += 1
            if not self.train and n_skipped >= n_data:
                raise RuntimeError("No sample fits within max_len / max_output_len.")
            index = (index + 1) % n_data

In [None]:
def create_train_iterator(data_path, num_workers, batch_size, word2id):
    """
    Build a DataLoader over a training ``EnvDataset`` read from ``data_path``.

    Args:
        data_path: file holding the training samples.
        num_workers: DataLoader worker count; None means ``min(4, os.cpu_count() or 1)``.
        batch_size: samples per batch.
        word2id: token -> id mapping used to collate batches.

    Returns:
        A non-shuffling ``DataLoader`` using the dataset's ``collate_fn``. Its
        timeout is 0 in the main process and 86400 s with workers.
    """
    print(f"Creating train iterator...")

    workers = num_workers if num_workers is not None else min(4, os.cpu_count() or 1)
    train_set = EnvDataset(path=data_path, num_workers=workers, word2id=word2id, train=True)

    loader_options = dict(
        batch_size=batch_size,
        num_workers=workers,
        shuffle=False,
        collate_fn=train_set.collate_fn,
        timeout=0 if workers == 0 else 86400,
    )
    return DataLoader(train_set, **loader_options)

In [None]:
def get_batch(dataloader):
    """
    Return the next training batch from an iterator over a DataLoader.

    Args:
        dataloader: an *iterator*, e.g. ``iter(DataLoader(...))``. A bare
            ``DataLoader`` is not an iterator and makes ``next`` raise TypeError.

    Returns:
        The next batch produced by the iterator.

    Raises:
        StopIteration: when the iterator is exhausted (propagated silently).
        Exception: any other error is reported on stdout, then re-raised.
    """
    try:
        batch = next(dataloader)
    except StopIteration:
        # Normal exhaustion of the iterator, not an "unknown" failure.
        raise
    except Exception:
        print(
            "An unknown exception occurred when fetching batch. "
        )
        raise
    return batch

In [None]:
def to_cuda(*args):
    """
    Move tensors to CUDA when a GPU is available.

    ``None`` entries are kept as ``None``. Without CUDA the arguments are
    returned unchanged, as the original tuple; with CUDA a list is returned.
    """
    if torch.cuda.is_available():
        return [arg.cuda() if arg is not None else None for arg in args]
    return args

def enc_dec_step(encoder, decoder, dataloader, optimizer, clip_grad_norm=5):
    """
    Run one encoder/decoder training step on the next batch.

    Args:
        encoder: model called as ``encoder("fwd", ...)``.
        decoder: model called as ``decoder("fwd", ...)`` and ``decoder("predict", ...)``.
        dataloader: iterator yielding ``((x1, len1), (x2, len2), nb_ops)`` batches.
        optimizer: optimizer over the encoder and decoder parameters.
        clip_grad_norm: max gradient norm across both models; <= 0 disables clipping.

    Returns:
        The loss tensor. If it is NaN, a message is printed and no parameter
        update is performed.
    """
    encoder.train()
    decoder.train()

    # batch (nb_ops is not used by this step)
    (x1, len1), (x2, len2), _nb_ops = get_batch(dataloader)

    # cuda
    x1, len1, x2, len2 = to_cuda(x1, len1, x2, len2)

    # target words to predict
    alen = torch.arange(len2.max(), dtype=torch.long, device=len2.device)
    pred_mask = alen[:, None] < len2[None] - 1  # do not predict anything given the last target word

    y = x2[1:].masked_select(pred_mask[:-1])

    assert len(y) == (len2 - 1).sum().item()

    # forward / loss
    encoded = encoder("fwd", x=x1, lengths=len1, causal_att=False)
    decoded = decoder("fwd", x=x2, lengths=len2, causal_att=True, src_enc=encoded.transpose(0, 1), src_len=len1)
    _, loss = decoder("predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=False)

    # check NaN: back-propagating it would push NaN gradients into the parameters
    if torch.isnan(loss).any():
        print("NaN detected")
        return loss

    # A flat list, not a dict keyed by parameter name: the encoder and decoder
    # share names such as "embeddings.embedding.weight", so a dict would drop
    # encoder parameters from gradient clipping.
    model_params = [p for module in (encoder, decoder) for p in module.parameters() if p.requires_grad]

    # regular optimization
    optimizer.zero_grad()
    loss.backward()
    if clip_grad_norm > 0:
        clip_grad_norm_(model_params, clip_grad_norm)
    optimizer.step()

    return loss

def save_checkpoint(name, epoch, n_total_iter, best_metrics, encoder, decoder, optimizer, dump_path=""):
    """
    Save the model / checkpoints to ``<dump_path>/<name>.pth``.

    The file holds a dict with ``epoch``, ``n_total_iter``, ``best_metrics`` and
    the ``state_dict()`` of the encoder, decoder and optimizer.
    """
    path = os.path.join(dump_path, "%s.pth" % name)
    print("Saving %s to %s ..." % (name, path))

    checkpoint = dict(epoch=epoch, n_total_iter=n_total_iter, best_metrics=best_metrics)

    components = (
        (f"Saving encoder parameters ...", "encoder", encoder),
        (f"Saving decoder parameters ...", "decoder", decoder),
        ("Saving optimizer ...", "optimizer", optimizer),
    )
    for message, key, component in components:
        print(message)
        checkpoint[key] = component.state_dict()

    torch.save(checkpoint, path)

def save_best_model(scores, metrics, best_metrics, epoch, n_total_iter, encoder, decoder, optimizer):
    """
    Save best models according to given validation metrics.

    Args:
        scores: metric name -> current value.
        metrics: iterable of ``(metric_name, biggest)``; ``biggest`` True means
            higher is better, False means lower is better.
        best_metrics: metric name -> best value so far; updated in place.

    Each improved metric triggers a ``best-<metric>`` checkpoint in the current directory.
    """
    for metric, biggest in metrics:
        if metric not in scores:
            print('Metric "%s" not found in scores!' % metric)
            continue
        current, best = scores[metric], best_metrics[metric]
        improved = current > best if biggest else current < best
        if not improved:
            continue
        best_metrics[metric] = current
        print("New best score for %s: %.6f" % (metric, current))
        save_checkpoint("best-%s" % metric, epoch, n_total_iter, best_metrics, encoder, decoder, optimizer, dump_path="")

def save_periodic(save_periodic, epoch, n_total_iter, best_metrics, encoder, decoder, optimizer):
    """
    Save the models periodically.

    Args:
        save_periodic: save every ``save_periodic`` epochs; <= 0 disables saving.
            (The parameter shadows the function name inside this body.)
        epoch: current epoch; a ``periodic-<epoch>`` checkpoint is written when
            it is a multiple of ``save_periodic``.
    """
    # Use the argument, not a global `params.save_periodic`, which may not exist.
    period = save_periodic
    if period > 0 and epoch % period == 0:
        save_checkpoint("periodic-%i" % epoch, epoch, n_total_iter, best_metrics, encoder, decoder, optimizer, dump_path="")

In [None]:
# Parse the command-line configuration with the project's `train.get_parser` and
# build the module-level ODE environment. `helper_env` and `params` are read as
# globals by the evaluation helpers below (`check_hypothesis`, `enc_dec_step_beam`).
# NOTE(review): inside a Jupyter kernel `sys.argv` holds the kernel's own arguments,
# so `parse_args()` may reject them. Confirm, and consider `parse_args([])` there.
parser = get_parser()
params = parser.parse_args()
helper_env = ODEEnvironment(params)

def idx_to_infix(env, idx, input=True):
    """
    Convert a sequence of token ids to infix notation via ``env``.

    The ids are mapped through ``env.id2word``. The resulting prefix tokens are
    passed to ``env.input_to_infix`` when ``input`` is True, otherwise to
    ``env.output_to_infix``.
    """
    tokens = list(map(env.id2word.__getitem__, idx))
    if input:
        return env.input_to_infix(tokens)
    return env.output_to_infix(tokens)

def check_hypothesis(eq):
    """
    Check a hypothesis for a given equation and its solution.

    Args:
        eq: dict with token-id sequences under ``"src"``, ``"tgt"`` and ``"hyp"``.

    Returns:
        The same dict, updated in place. ``"src"`` becomes the infix form of the
        source, ``"tgt"``/``"hyp"`` become token lists, and ``"is_valid"`` holds
        the result of ``helper_env.check_lyap_validity``. It is -3 on a timeout
        and -4 on any other error raised by the check.
    """
    # The global `helper_env` is only read here, so no `global` statement is needed.
    # Reset its RNG so the validity check is reproducible for every hypothesis.
    helper_env.rng = numpy.random.RandomState(0)
    src = [helper_env.id2word[wid] for wid in eq["src"]]
    tgt = [helper_env.id2word[wid] for wid in eq["tgt"]]
    hyp = [helper_env.id2word[wid] for wid in eq["hyp"]]

    # `MyTimeoutError` is not defined in the visible part of this file. Resolve it
    # at call time and fall back to the builtin TimeoutError, so the except clause
    # itself cannot raise NameError.
    timeout_error = globals().get("MyTimeoutError", TimeoutError)
    try:
        is_valid = helper_env.check_lyap_validity(src, hyp, tgt)
    except timeout_error:
        is_valid = -3
    except Exception:
        is_valid = -4

    # update hypothesis
    eq["src"] = helper_env.input_to_infix(src)
    eq["tgt"] = tgt
    eq["hyp"] = hyp
    eq["is_valid"] = is_valid
    return eq

def enc_dec_step_beam(data_type, data_path_idx, scores, encoder, decoder, data_path, max_output_len=512,
                      eval_verbose=True, dump_path="", beam_size=50, batch_size_eval=16, beam_length_penalty=1,
                      beam_early_stopping=True, size=None):

    """
    Encoding / decoding step with beam generation and SymPy check.

    Procedure for each batch of the "ode_lyapunov" test iterator:
    1. The decoder is run with teacher forcing. Sequences whose argmax tokens all
       match the target count as perfect matches.
    2. Beam search runs on the remaining sequences, or on all of them when
       ``eval_verbose >= 2``.
    3. Every hypothesis is checked with ``check_hypothesis`` in a process pool.

    Aggregated statistics are written into ``scores`` in place, under keys
    ``f"{data_type}_ode_lyapunov_..."``.

    Args:
        data_type (str): split name, forwarded to ``helper_env.create_test_iterator``.
        data_path_idx: forwarded to the iterator; also part of the export file name.
        scores (dict): must contain "epoch" (used in the export file name); updated in place.
        encoder, decoder: models called as ``encoder("fwd", ...)``,
            ``decoder("fwd"/"predict", ...)`` and ``decoder.generate_beam(...)``.
        data_path: forwarded to the iterator.
        max_output_len (int): beam search generates at most ``max_output_len + 2`` tokens.
        eval_verbose: falsy disables the per-equation export file; ``>= 2`` also runs
            the beam search on greedily-solved equations.
        dump_path (str): directory of the export file.
        beam_size (int), beam_length_penalty, beam_early_stopping: beam search settings.
        batch_size_eval (int): evaluation batch size.
        size: forwarded to the iterator.

    Returns:
        None. Results are stored in ``scores``.
    """
    # Not part of this notebook's module-level imports, so imported here.
    from concurrent.futures import ProcessPoolExecutor

    # Bound up front: it is stored in every hypothesis dict inside the batch loop and
    # used to build the score keys at the end.
    task = "ode_lyapunov"

    encoder.eval()
    decoder.eval()
    max_beam_length = max_output_len + 2

    # evaluation details
    f_export = None
    if eval_verbose:
        eval_path = os.path.join(dump_path, f"eval.beam.{data_type}.{data_path_idx}.{scores['epoch']}")
        f_export = open(eval_path, "w", encoding="utf-8")
        print(f"Writing evaluation results in {eval_path} ...")

    def display_logs(logs, offset):
        """
        Display detailed results about success / fails.
        """
        if eval_verbose == 0:
            return
        for i, res in sorted(logs.items()):
            n_valid = sum([int(v) for _, _, v in res["hyps"]])
            s = f"Equation {offset + i} ({n_valid}/{len(res['hyps'])})\nsrc={res['src']}\ntgt={res['tgt']}\n"
            for hyp, score, valid in res["hyps"]:
                if score is None:
                    s += f"{int(valid)} {hyp}\n"
                else:
                    s += f"{int(valid)} {score :.3e} {hyp}\n"
            f_export.write(s + "\n")
            f_export.flush()

    # stats
    xe_loss = 0
    # NOTE(review): the per-class tables have 1000 rows, so nb_ops is assumed to be < 1000.
    n_valid = torch.zeros(1000, beam_size, dtype=torch.long)
    n_total = torch.zeros(1000, dtype=torch.long)
    n_perfect_match = 0
    n_correct = 0
    n_timeout = 0
    n_optim = 0
    n_input_err = 0
    n_other_err = 0

    try:
        # iterator
        iterator = helper_env.create_test_iterator(
            data_type,
            task,
            data_path=data_path,
            data_path_idx=data_path_idx,
            batch_size=batch_size_eval,
            params=params,
            size=size,
        )
        eval_size = len(iterator.dataset)

        for (x1, len1), (x2, len2), nb_ops in iterator:

            # cuda
            x1, len1, x2, len2 = to_cuda(x1, len1, x2, len2)

            # target words to predict
            alen = torch.arange(len2.max(), dtype=torch.long, device=len2.device)
            pred_mask = alen[:, None] < len2[None] - 1  # do not predict anything given the last target word
            y = x2[1:].masked_select(pred_mask[:-1])
            assert len(y) == (len2 - 1).sum().item()

            x1_, len1_ = x1, len1

            bs = len(len1)

            # forward
            encoded = encoder("fwd", x=x1_, lengths=len1_, causal_att=False)
            decoded = decoder("fwd", x=x2, lengths=len2, causal_att=True, src_enc=encoded.transpose(0, 1), src_len=len1_)
            word_scores, loss = decoder("predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=True)

            # correct outputs per sequence / valid top-1 predictions
            t = torch.zeros_like(pred_mask, device=y.device)
            t[pred_mask] += word_scores.max(1)[1] == y
            valid = (t.sum(0) == len2 - 1).cpu().long()
            n_perfect_match += valid.sum().item()

            # save evaluation details
            beam_log = {}
            for i in range(len(len1)):
                src = idx_to_infix(helper_env, x1[1 : len1[i] - 1, i].tolist(), True)
                tgt = idx_to_infix(helper_env, x2[1 : len2[i] - 1, i].tolist(), False)
                if valid[i]:
                    beam_log[i] = {"src": src, "tgt": tgt, "hyps": [(tgt, None, True)]}

            # stats
            xe_loss += loss.item() * len(y)
            n_valid[:, 0].index_add_(-1, nb_ops, valid)
            n_total.index_add_(-1, nb_ops, torch.ones_like(nb_ops))

            # continue if everything is correct. if eval_verbose, perform
            # a full beam search, even on correct greedy generations
            if valid.sum() == len(valid) and eval_verbose < 2:
                display_logs(beam_log, offset=n_total.sum().item() - bs)
                continue

            # invalid top-1 predictions - check if there is a solution in the beam
            invalid_idx = (1 - valid).nonzero().view(-1)
            print(
                f"({n_total.sum().item()}/{eval_size}) Found {bs - len(invalid_idx)}/{bs} " f"valid top-1 predictions. Generating solutions ..."
            )

            # generate
            _, _, generations = decoder.generate_beam(
                encoded.transpose(0, 1),
                len1_,
                beam_size=beam_size,
                length_penalty=beam_length_penalty,
                early_stopping=beam_early_stopping,
                max_len=max_beam_length,
            )
            # prepare inputs / hypotheses to check
            # if eval_verbose < 2, no beam search on equations solved greedily
            inputs = []
            for i in range(len(generations)):
                if valid[i] and eval_verbose < 2:
                    continue
                for j, (score, hyp) in enumerate(sorted(generations[i].hyp, key=lambda x: x[0], reverse=True)):
                    inputs.append(
                        {
                            "i": i,
                            "j": j,
                            "score": score,
                            "src": x1[1 : len1[i] - 1, i].tolist(),
                            "tgt": x2[1 : len2[i] - 1, i].tolist(),
                            "hyp": hyp[1:].tolist(),
                            "task": task,
                        }
                    )

            # check hypotheses with multiprocessing
            outputs = []
            with ProcessPoolExecutor(max_workers=20) as executor:
                for output in executor.map(check_hypothesis, inputs, chunksize=1):
                    outputs.append(output)

            # read results
            for i in range(bs):

                # select hypotheses associated to current equation
                gens = sorted([o for o in outputs if o["i"] == i], key=lambda x: x["j"])
                assert (len(gens) == 0) == (valid[i] and eval_verbose < 2) and (i in beam_log) == valid[i]
                if len(gens) == 0:
                    continue

                # source / target
                src = gens[0]["src"]
                tgt = gens[0]["tgt"]
                beam_log[i] = {"src": src, "tgt": tgt, "hyps": []}

                # for each hypothesis
                for j, gen in enumerate(gens):

                    # sanity check
                    assert gen["src"] == src and gen["tgt"] == tgt and gen["i"] == i and gen["j"] == j

                    # if the hypothesis is correct, and we did not find a correct one before
                    is_valid = gen["is_valid"]
                    if is_valid == 1 and not valid[i]:
                        n_valid[nb_ops[i], j] += 1
                        valid[i] = 1

                    # update beam log
                    beam_log[i]["hyps"].append((gen["hyp"], gen["score"], is_valid))
                    if j == 0:
                        n_correct += is_valid != -2
                        n_timeout += is_valid == -3
                        n_optim += is_valid == -1
                        n_input_err += is_valid == -5
                        n_other_err += is_valid == -4

            # valid solutions found with beam search
            print(f"    Found {valid.sum().item()}/{bs} solutions in beam hypotheses.")

            # export evaluation details
            if eval_verbose:
                assert len(beam_log) == bs
                display_logs(beam_log, offset=n_total.sum().item() - bs)
    finally:
        # Close the export file even when evaluation raises midway.
        if f_export is not None:
            f_export.close()

    # evaluation details
    if eval_verbose:
        print(f"Evaluation results written in {eval_path}")

    # log
    _n_valid = n_valid.sum().item()
    _n_total = n_total.sum().item()
    print(f"{_n_valid}/{_n_total} ({100. * _n_valid / _n_total}%) equations were evaluated correctly.")

    # compute perplexity and prediction accuracy
    assert _n_total == eval_size

    data_path_idx_scores = ""
    scores[f"{data_type}_{task}_{data_path_idx_scores}xe_loss"] = xe_loss / _n_total
    scores[f"{data_type}_{task}_{data_path_idx_scores}beam_acc"] = 100.0 * _n_valid / _n_total
    scores[f"{data_type}_{task}_{data_path_idx_scores}perfect"] = 100.0 * n_perfect_match / _n_total
    scores[f"{data_type}_{task}_{data_path_idx_scores}correct"] = 100.0 * (n_perfect_match + n_correct) / _n_total

    scores[f"{data_type}_{task}_{data_path_idx_scores}optim"] = 100.0 * n_optim / _n_total
    scores[f"{data_type}_{task}_{data_path_idx_scores}timeout"] = 100.0 * n_timeout / _n_total
    scores[f"{data_type}_{task}_{data_path_idx_scores}input_err"] = 100.0 * n_input_err / _n_total
    scores[f"{data_type}_{task}_{data_path_idx_scores}other_err"] = 100.0 * n_other_err / _n_total

    # per class perplexity and prediction accuracy
    for i in range(len(n_total)):
        if n_total[i].item() == 0:
            continue
        print(f"{i}: {n_valid[i].sum().item()} / {n_total[i].item()} " f"({100. * n_valid[i].sum().item() / max(n_total[i].item(), 1)}%)")
        scores[f"{data_type}_{task}_{data_path_idx_scores}beam_acc_{i}"] = 100.0 * n_valid[i].sum().item() / max(n_total[i].item(), 1)

In [None]:
def run_all_evals(epoch, encoder, decoder, eval_size):
    """
    Run all evaluations and return their scores.

    Args:
        epoch (int): stored under the "epoch" key (read by enc_dec_step_beam for the
            export file name).
        encoder, decoder: models to evaluate.
        eval_size: forwarded to enc_dec_step_beam.

    Returns:
        OrderedDict: "epoch" plus every score written by enc_dec_step_beam.
    """
    # Imported locally: the notebook only imports OrderedDict in a later cell, which may
    # not have run yet when training calls this function.
    from collections import OrderedDict

    scores = OrderedDict({"epoch": epoch})
    with torch.no_grad():
        eval_tasks = [["valid", 1]]
        for data_type, data_path_idx in eval_tasks:
            # NOTE(review): eval_size is bound positionally to enc_dec_step_beam's
            # `data_path` parameter (not to `size`). Confirm this is intended.
            enc_dec_step_beam(data_type, data_path_idx, scores, encoder, decoder, eval_size)
    return scores

In [None]:
# training
max_epoch = 3
epoch = 0
epoch_size = 200  # number of training equations per epoch
n_total_iter = 0
# NOTE(review): this assignment was left empty (a syntax error). run_all_evals forwards
# the value positionally into enc_dec_step_beam's `data_path` argument, so set the
# intended value here (TODO confirm).
eval_size = None

# validation metrics: comma-separated names; a leading "_" marks a metric where
# smaller is better
validation_metrics = "valid_ode_lyapunov_beam_acc"
metrics = []
# Iterate over the split names, not over `metrics` itself: appending to a list while
# iterating over it never terminates.
for m in validation_metrics.split(","):
    if m == "":
        continue
    metrics.append((m[1:], False) if m[0] == "_" else (m, True))
best_metrics = {metric: (-1e12 if biggest else 1e12) for (metric, biggest) in metrics}

batch_size = 16
data_path = "FLyap.txt"
dataloader = iter(create_train_iterator(data_path=data_path, num_workers=None, batch_size=batch_size))

# NOTE(review): these calls use a TransformerModel(params, id2word, ...) signature. The
# TransformerModel defined in a later cell takes (src_vocab_size, tgt_vocab_size, ...)
# instead, so confirm which definition is active when this cell runs.
encoder = TransformerModel(params, id2word, is_encoder=True, with_output=False)
decoder = TransformerModel(params, id2word, is_encoder=False, with_output=True)

# NOTE(review): the optimizer assignment was left empty (a syntax error). It has to be
# built after the models it optimizes. Adam with lr=1e-4 is a placeholder (TODO confirm).
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)

last_time = time.time()

for _ in range(max_epoch):

    print("============ Starting epoch %i ... ============" % epoch)

    # enc_dec_step_beam (run at the end of each epoch) puts both models in eval mode;
    # switch back so dropout is active again while training.
    encoder.train()
    decoder.train()

    n_equations = 0
    n_iter = 0

    while n_equations < epoch_size:

        # training steps
        loss = enc_dec_step(encoder, decoder, dataloader, optimizer)
        n_equations += batch_size
        n_iter += 1
        n_total_iter += 1

        # print some stats every 20 iterations
        if n_total_iter % 20 == 0:
            new_time = time.time()
            diff = new_time - last_time
            print('Training speed is ', 20/diff, ' iter/s. The loss in ', n_iter, '. iteration: ', loss.item())
            last_time = new_time

    print("============ End of epoch %i ============" % epoch)

    # evaluate perplexity
    scores = run_all_evals(epoch, encoder, decoder, eval_size)

    # end of epoch
    save_best_model(scores, metrics, best_metrics, epoch, n_total_iter, encoder, decoder, optimizer)
    # Pass the configured period (an int), not the `save_periodic` function object itself.
    save_periodic(params.save_periodic, epoch, n_total_iter, best_metrics, encoder, decoder, optimizer)
    save_checkpoint("checkpoint", epoch, n_total_iter, best_metrics, encoder, decoder, optimizer, dump_path="")
    epoch += 1

In [None]:
# evaluate perplexity
# NOTE(review): `evaluator`, `trainer`, `logger` and `json` are not defined or imported in
# the visible part of this notebook. Confirm they exist before running this cell.
if params.is_master:
    # Run every evaluation and log the raw score dict.
    scores = evaluator.run_all_evals()
    logger.info(scores)

    # print / JSON log (every value is formatted with %.6f, so all scores must be numeric)
    for k, v in scores.items():
        logger.info("%s -> %.6f" % (k, v))
    logger.info("__log__:%s" % json.dumps(scores))

    # end of epoch
    trainer.save_best_model(scores)
    trainer.save_periodic()
    trainer.end_epoch(scores)

In [None]:
import re
import sympy as sp
from collections import OrderedDict

class TreeNode:
    """Expression-tree node: a token ``value`` and an ordered list of child nodes."""

    def __init__(self, value):
        self.value = value
        self.children = list()

    def add_child(self, child):
        """Attach ``child`` as the last child of this node."""
        self.children.extend([child])

    def __repr__(self):
        return f"TreeNode({self.value}, {self.children})"

class MathTokenizer:
    """
    Turn infix math expressions into prefix (Polish-notation) token lists.

    Each expression is parsed into a ``TreeNode`` tree and enumerated in pre-order.
    Operators, variables and constants are then emitted as-is, integers as a sign
    followed by base-``int_base`` digits, and reals as
    ``<integer mantissa> '10^' <integer exponent>``.
    """

    # Unary function names accepted by parse_expression (the non-binary entries of
    # ``self.operators``).
    _FUNCTIONS = frozenset({'sqrt', 'exp', 'ln', 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'Abs'})
    _BINARY_OPS = frozenset({'+', '-', '*', '/', '^'})
    _TOKEN_RE = re.compile(r'sqrt|exp|ln|sin|cos|tan|asin|acos|atan|Abs|[+\-*/^()]|x\d+|pi|E|\d+\.?\d*')

    def __init__(self, params, max_degree=5, int_base=1000):
        """
        Args:
            params: accepted for interface compatibility; not used by the tokenizer.
            max_degree (int): the variables ``x0`` ... ``x{2*max_degree - 1}`` are in the
                vocabulary.
            int_base (int): base used by ``tokenize_integer``; the digit words
                ``'0'`` ... ``str(max(10, int_base) - 1)`` are in the vocabulary.
        """
        self.max_degree = max_degree
        self.int_base = int_base

        self.operators = ['+', '-', '*', '/', '^', 'sqrt', 'exp', 'ln', 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'Abs']
        self.variables = [f"x{i}" for i in range(2 * self.max_degree)]
        self.constants = ["pi", "E"]
        self.symbols = ["I", "INT+", "INT-", "FLOAT+", "FLOAT-", ".", "10^"]
        self.elements = [str(i) for i in range(max(10, self.int_base))]
        self.SPECIAL_WORDS = ["<s>", "</s>", "<pad>", "(", ")"] + [f"<SPECIAL_{i}>" for i in range(10)]
        self.func_separator = ["SEP"]

        self.words = self.SPECIAL_WORDS + self.constants + self.variables + self.operators + self.symbols + self.elements + self.func_separator
        self.id2word = {i: s for i, s in enumerate(self.words)}
        self.word2id = {s: i for i, s in self.id2word.items()}
        self.n_words = len(self.words)
        # NOTE(review): in self.words, index 0 is "<s>", 1 is "</s>" and 2 is "<pad>".
        # Confirm 0 / 1 are the intended EOS / padding ids.
        self.eos_index = 0
        self.pad_index = 1

    @staticmethod
    def parse_expression(expr):
        """
        Parse an infix expression into a ``TreeNode`` tree by recursive descent.

        Grammar, from loosest to tightest binding:
            expr    := term (('+' | '-') term)*      left-associative
            term    := power (('*' | '/') power)*    left-associative
            power   := primary ('^' power)?          right-associative
            primary := '(' expr ')' | FUNC ('(' expr ')' | primary) | atom

        FUNC is one of sqrt, exp, ln, sin, cos, tan, asin, acos, atan, Abs. An atom is a
        variable ``x<digits>``, a constant ``pi``/``E`` or an unsigned number. Unary
        minus is not supported.

        Raises:
            ValueError: on characters that belong to no token, a misplaced operator or
                parenthesis, a premature end of input, or tokens left over after a
                complete expression.
        """
        tokens = MathTokenizer._TOKEN_RE.findall(expr)
        # findall silently skips characters it cannot match; reject them rather than
        # returning the tree of a different expression.
        if "".join(tokens) != re.sub(r"\s+", "", expr):
            raise ValueError(f"Unrecognized characters in expression: {expr!r}")
        pos = 0

        def peek():
            return tokens[pos] if pos < len(tokens) else None

        def consume(expected=None):
            nonlocal pos
            if pos >= len(tokens):
                raise ValueError("Unexpected end of expression")
            token = tokens[pos]
            if expected and token != expected:
                raise ValueError(f"Expected {expected} but got {token}")
            pos += 1
            return token

        def parse_primary():
            token = peek()
            if token is None:
                raise ValueError("Unexpected end of expression")
            if token == '(':
                consume('(')
                node = parse_expr()
                consume(')')
                return node
            if token in MathTokenizer._FUNCTIONS:
                # Function call: the argument is either parenthesized or a single primary.
                node = TreeNode(consume())
                if peek() == '(':
                    consume('(')
                    node.add_child(parse_expr())
                    consume(')')
                else:
                    node.add_child(parse_primary())
                return node
            if token in MathTokenizer._BINARY_OPS or token == ')':
                # An operator here (e.g. a unary minus) would otherwise become a childless leaf.
                raise ValueError(f"Unexpected token {token!r}")
            # variable, constant or number
            return TreeNode(consume())

        def parse_power():
            base = parse_primary()
            if peek() != '^':
                return base
            node = TreeNode(consume())
            node.add_child(base)
            node.add_child(parse_power())  # right-associative: a^b^c == a^(b^c)
            return node

        def parse_term():
            node = parse_power()
            while peek() in ('*', '/'):
                new_node = TreeNode(consume())
                new_node.add_child(node)
                new_node.add_child(parse_power())
                node = new_node
            return node

        def parse_expr():
            node = parse_term()
            while peek() in ('+', '-'):
                new_node = TreeNode(consume())
                new_node.add_child(node)
                new_node.add_child(parse_term())
                node = new_node
            return node

        tree = parse_expr()
        if pos != len(tokens):
            raise ValueError(f"Unexpected token {tokens[pos]!r}")
        return tree

    def enumerate_tree(self, tree):
        """
        Enumerate the tree in Polish (pre-order) notation.
        """
        if not tree:
            return []
        result = [tree.value]
        for child in tree.children:
            result.extend(self.enumerate_tree(child))
        return result

    def tokenize_integer(self, n):
        """
        Tokenize an integer as a sign token followed by its base-``self.int_base`` digits.

        Zero is the single token ``'0'`` without a sign.
        For example, with the default base 1000, 1234 -> ['+', '1', '234'].
        """
        # NOTE(review): the sign tokens '+'/'-' are the same strings as the binary
        # operators, although the vocabulary also has dedicated "INT+"/"INT-" symbols.
        # Confirm the decoder can disambiguate them.
        if n == 0:
            return ['0']
        base = self.int_base
        tokens = []
        sign = '+' if n >= 0 else '-'
        n = abs(n)
        while n > 0:
            tokens.append(str(n % base))
            n //= base
        tokens.append(sign)
        return tokens[::-1]  # most significant digit first, sign in front

    def tokenize_real(self, x):
        """
        Tokenize a real number as ``mantissa * 10^exponent`` with an integer mantissa.

        For example, 2.1 is represented as 21 * 10^(-1), i.e.
        ['+', '21', '10^', '-', '1']. The sign of ``x`` is carried by the mantissa.
        Trailing zeros of the fractional part are dropped (100.0 -> 100 * 10^0).

        Raises:
            ValueError: if ``x`` is NaN or infinite.
        """
        from decimal import Decimal

        if not math.isfinite(x):
            raise ValueError(f"Cannot tokenize non-finite number: {x}")
        negative = x < 0
        # repr() gives the shortest round-tripping decimal form, which may use scientific
        # notation (e.g. 1e-05); Decimal splits it into digits and a base-10 exponent.
        _, digits, exponent = Decimal(repr(abs(x))).as_tuple()
        digits = list(digits)
        # Strip trailing zeros of the fractional part only.
        while exponent < 0 and digits and digits[-1] == 0:
            digits.pop()
            exponent += 1
        mantissa = int("".join(map(str, digits)) or "0")
        if negative:
            mantissa = -mantissa
        return self.tokenize_integer(mantissa) + ['10^'] + self.tokenize_integer(exponent)

    def tokenize_expression(self, expression):
        """
        Convert the enumerated expression (in Polish notation) into tokens.

        Operators, variables and constants are passed through. Integer and real literals
        are expanded with ``tokenize_integer`` / ``tokenize_real``.

        Raises:
            ValueError: for any other token.
        """
        passthrough = set(self.operators) | set(self.variables) | set(self.constants)
        tokens = []
        for token in expression:
            if token in passthrough:
                tokens.append(token)
            elif re.match(r'^-?\d+$', token):  # Integer constant
                tokens.extend(self.tokenize_integer(int(token)))
            elif re.match(r'^-?\d+\.\d*$', token):  # Real constant
                tokens.extend(self.tokenize_real(float(token)))
            else:
                raise ValueError(f"Unknown token: {token}")
        return tokens

    def tokenize(self, expressions):
        """
        Tokenize a list of expressions into one token list, separated by "SEP".
        """
        tokenized_expressions = []
        for expression in expressions:
            tree = self.parse_expression(expression)
            enumerated = self.enumerate_tree(tree)
            tokens = self.tokenize_expression(enumerated)
            tokenized_expressions.extend(tokens + self.func_separator)
        return tokenized_expressions[:-1]  # Remove the last SEP

# Example usage
expressions = [
    "cos(2.1 * x0) * (x1 + 2)",
    "sin(3 * x1 + 2)"
]

# `tokenize` is a method, so it is called on a MathTokenizer instance; there is no
# module-level `tokenize` function. The tokenizer does not use `params`.
tokenizer = MathTokenizer(params=None)
tokenized = tokenizer.tokenize(expressions)
print(tokenized)

NameError: name 'tokenize' is not defined

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention with an optional per-module KV cache.

    Inputs are batch-first: ``query`` is unpacked as (batch_size, qlen, embed_dim).
    When ``key`` and ``value`` are both omitted, this is self-attention over ``query``;
    otherwise it is cross-attention over the given key/value sequence.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1, layer_id=None):
        """
        Args:
            embed_dim (int): model width; must be divisible by ``num_heads``.
            num_heads (int): number of attention heads.
            dropout (float): dropout probability applied to the attention weights.
            layer_id: key under which this module stores its cached keys/values.
        """
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.layer_id = layer_id  # Unique identifier for caching
        self.dropout = dropout

        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)

        # KV cache dictionary: layer_id -> (K, V), each (batch, num_heads, klen, head_dim)
        self.cache = {}

    def reset_cache(self):
        """Reset the KV cache."""
        self.cache = {}

    def _split_heads(self, x, batch_size):
        """Reshape (batch, len, embed_dim) into (batch, num_heads, len, head_dim)."""
        return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key=None, value=None, mask=None, is_causal=False, use_cache=False):
        """
        Args:
            query: (batch_size, qlen, embed_dim).
            key, value: (batch_size, klen, embed_dim) for cross-attention; default to
                ``query`` (self-attention).
            mask: key mask where 0 means "masked". It is (batch, klen) or
                (batch, qlen, klen).
            is_causal (bool): forbid query i from attending to keys after its own
                position. With a cache, query i sits at position ``klen - qlen + i``.
            use_cache (bool): self-attention appends the new keys/values to the cache.
                Cross-attention computes them once and reuses them while the cache holds
                them, since the key/value sequence is assumed fixed across decoding
                steps (TODO confirm for every caller).

        Returns:
            Tensor of shape (batch_size, qlen, embed_dim).
        """
        batch_size, qlen, dim = query.shape

        # If key and value are not provided, assume self-attention
        is_self_attention = key is None and value is None
        if key is None:
            key = query
        if value is None:
            value = query

        Q = self._split_heads(self.query(query), batch_size)  # (batch_size, num_heads, qlen, head_dim)

        cached = self.cache.get(self.layer_id) if use_cache else None
        if cached is not None and not is_self_attention:
            # Cross-attention: the cached encoder keys/values are reused as-is;
            # concatenating them again would duplicate every source position.
            K, V = cached
        else:
            K = self._split_heads(self.key(key), batch_size)      # (batch_size, num_heads, klen, head_dim)
            V = self._split_heads(self.value(value), batch_size)  # (batch_size, num_heads, klen, head_dim)
            if cached is not None:
                # Self-attention: extend the cached prefix with the new positions.
                K = torch.cat([cached[0], K], dim=2)
                V = torch.cat([cached[1], V], dim=2)
            if use_cache:
                self.cache[self.layer_id] = (K, V)

        klen = K.size(2)

        # Scaled dot-product attention
        Q = Q / math.sqrt(self.head_dim)
        scores = torch.matmul(Q, K.transpose(2, 3))  # (batch_size, num_heads, qlen, klen)

        # Apply causal mask if needed. The queries are the last qlen of the klen
        # positions, so query i may see keys 0 .. (klen - qlen + i).
        if is_causal:
            offset = klen - qlen
            causal_mask = torch.triu(torch.ones(qlen, klen, device=query.device), diagonal=offset + 1).bool()
            scores = scores.masked_fill(causal_mask, -float("inf"))

        # Apply additional mask if provided
        if mask is not None:
            # If mask is 3D assume (batch, qlen, klen), else 2D assumed to be (batch, klen)
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # (batch, 1, qlen, klen)
            else:
                mask = mask.unsqueeze(1).unsqueeze(1)  # (batch, 1, 1, klen)
            # mask==0 means masked position.
            scores = scores.masked_fill((mask == 0), -float("inf"))

        # Compute attention weights (softmax in float32 for stability)
        weights = F.softmax(scores.float(), dim=-1).type_as(scores)
        weights = F.dropout(weights, p=self.dropout, training=self.training)

        # Compute context vector and reassemble
        context = torch.matmul(weights, V)  # (batch_size, num_heads, qlen, head_dim)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        output = self.out(context)
        return output

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Linear -> Dropout."""

    def __init__(self, embed_dim, ff_dim, dropout=0.1):
        """
        Args:
            embed_dim (int): width of the input and of the output.
            ff_dim (int): width of the hidden projection.
            dropout (float): dropout probability applied to the output.
        """
        super().__init__()
        # Submodules are registered in this order so parameter names and initialization
        # order stay fixed.
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): (..., embed_dim), e.g. (batch_size, seq_len, embed_dim).
        Returns:
            torch.Tensor: same leading shape as ``x``, last dimension ``embed_dim``.
        """
        hidden = self.activation(self.linear1(x))  # (..., ff_dim)
        return self.dropout(self.linear2(hidden))

In [None]:
class TransformerBlock(nn.Module):
    """
    Post-norm Transformer layer.

    Each sub-layer is applied as ``norm(x + dropout(sublayer(x)))``. The sub-layers
    are self-attention, cross-attention over ``encoder_output`` (decoder blocks
    only) and a position-wise feed-forward network. Decoder blocks use causal
    self-attention.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1, is_decoder=False, layer_id=None):
        """
        Args:
            embed_dim (int): model width.
            num_heads (int): attention heads per attention module.
            ff_dim (int): hidden width of the feed-forward sub-layer.
            dropout (float): dropout probability.
            is_decoder (bool): add causal masking and a cross-attention sub-layer.
            layer_id: base identifier used to derive the attention cache keys.
        """
        super(TransformerBlock, self).__init__()
        self.is_decoder = is_decoder

        # For decoder blocks, use distinct layer IDs for self and cross attention
        if is_decoder:
            self.self_attn = MultiHeadAttention(embed_dim, num_heads, dropout, layer_id=f"{layer_id}_self")
            self.cross_attn = MultiHeadAttention(embed_dim, num_heads, dropout, layer_id=f"{layer_id}_cross")
            self.norm2 = nn.LayerNorm(embed_dim)
            self.dropout2 = nn.Dropout(dropout)
        else:
            self.self_attn = MultiHeadAttention(embed_dim, num_heads, dropout, layer_id=layer_id)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)

        self.feed_forward = FeedForward(embed_dim, ff_dim, dropout)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, encoder_output=None, src_mask=None, tgt_mask=None, use_cache=False):
        """
        Args:
            x: input sequence, batch-first (batch, len, embed_dim).
            encoder_output: keys/values for cross-attention; required for decoder blocks.
            src_mask: key mask over source positions (0 = masked). Decoder blocks use it
                for cross-attention. Encoder blocks use it for self-attention when
                ``tgt_mask`` is None.
            tgt_mask: key mask for self-attention (0 = masked).
            use_cache: forwarded to the attention modules.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if this is a decoder block and ``encoder_output`` is None.
        """
        # An encoder block has no target sequence, so its self-attention is over the
        # source itself and must mask the source padding.
        self_mask = tgt_mask
        if self_mask is None and not self.is_decoder:
            self_mask = src_mask

        # Self-attention
        attn_output = self.self_attn(x, mask=self_mask, is_causal=self.is_decoder, use_cache=use_cache)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)

        # Cross-attention for decoder blocks
        if self.is_decoder:
            if encoder_output is None:
                raise ValueError("encoder_output must be provided for decoder blocks")
            cross_attn_output = self.cross_attn(x, key=encoder_output, value=encoder_output, mask=src_mask, use_cache=use_cache)
            x = x + self.dropout2(cross_attn_output)
            x = self.norm2(x)

        # Feedforward network
        # NOTE(review): FeedForward already applies dropout to its output, so dropout3
        # applies a second dropout here. Confirm this is intended.
        ff_output = self.feed_forward(x)
        x = x + self.dropout3(ff_output)
        x = self.norm3(x)

        return x

In [None]:
class TransformerModel(nn.Module):
    """
    Token embeddings plus learned positions, a Transformer encoder stack and, when
    ``is_decoder`` is True, a decoder stack with a projection onto the target
    vocabulary.

    All sequences are batch-first (bs, len). Positions are learned for up to 1000
    tokens; longer inputs are not supported.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, embed_dim, num_heads, ff_dim, num_layers, dropout=0.1, is_decoder=False, padding_idx=None):
        super(TransformerModel, self).__init__()
        self.embed_dim = embed_dim
        self.is_decoder = is_decoder
        self.padding_idx = padding_idx

        # Embedding layers
        self.src_embed = nn.Embedding(src_vocab_size, embed_dim, padding_idx=padding_idx)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, embed_dim, padding_idx=padding_idx)
        self.pos_embed = nn.Parameter(torch.zeros(1, 1000, embed_dim))  # Learned positional embeddings

        self.embed_norm = nn.LayerNorm(embed_dim)
        self.embed_dropout = nn.Dropout(dropout)

        # Encoder stack
        self.encoder = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, ff_dim, dropout, is_decoder=False, layer_id=i)
            for i in range(num_layers)
        ])
        # Decoder stack (if applicable)
        if self.is_decoder:
            self.decoder = nn.ModuleList([
                TransformerBlock(embed_dim, num_heads, ff_dim, dropout, is_decoder=True, layer_id=i)
                for i in range(num_layers)
            ])

        # Final output layer
        self.fc_out = nn.Linear(embed_dim, tgt_vocab_size)

    def reset_cache(self):
        """Reset the KV cache for all MultiHeadAttention modules."""
        for block in self.encoder:
            block.self_attn.reset_cache()
        if self.is_decoder:
            for block in self.decoder:
                block.self_attn.reset_cache()
                block.cross_attn.reset_cache()

    def generate_padding_mask(self, x):
        """
        Build a padding mask for the token tensor ``x`` of shape (bs, slen).

        Returns a float mask of the same shape with 1 for real tokens and 0 for
        ``self.padding_idx``. When ``padding_idx`` is None, every position is real.
        """
        if self.padding_idx is None:
            return torch.ones_like(x, dtype=torch.float)
        return (x != self.padding_idx).float()

    def generate_src_mask(self, src_len, max_src_len=None, seq_len=None):
        """
        Generate a source mask for the encoder output.

        Args:
            src_len (torch.Tensor): Lengths of the source sequences, shape (bs,).
            max_src_len (int, optional): Maximum source length to consider (lengths are
                clamped to it).
            seq_len (int, optional): Width of the mask. It should be the padded source
                length so the mask lines up with the encoder output. Defaults to the
                longest (clamped) length.

        Returns:
            torch.Tensor: A boolean mask of shape (bs, width) where True marks valid tokens.
        """
        if max_src_len is not None:
            src_len = torch.clamp(src_len, max=max_src_len)
        width = int(src_len.max().item()) if seq_len is None else int(seq_len)
        return torch.arange(width, device=src_len.device)[None, :] < src_len[:, None]

    def forward(self, src, src_len=None, tgt=None, src_mask=None, tgt_mask=None, use_cache=False, compute_loss=False):
        """
        Args:
            src (torch.Tensor): Source token ids (bs, src_len).
            src_len (torch.Tensor, optional): Lengths of the source sequences. When given
                (decoder models), cross-attention masks positions at or beyond them.
            tgt (torch.Tensor, optional): Target token ids (bs, tgt_len); required when
                ``is_decoder`` is True.
            src_mask, tgt_mask: accepted for interface compatibility but currently
                unused. Masks are derived from ``padding_idx`` and ``src_len``.
            use_cache (bool): Whether to use KV caching. NOTE(review): the cache is reset
                at the start of every call, so it does not persist across calls.
            compute_loss (bool): Whether to compute the next-token cross-entropy loss.

        Returns:
            Encoder models: ``(encoder_output, None)``.
            Decoder models: ``(encoder_output, logits, loss)``, where logits are
            (bs, tgt_len, tgt_vocab_size) and loss is None unless ``compute_loss``.

        Raises:
            ValueError: if this is a decoder model and ``tgt`` is None.
        """
        # Reset cache at the start of a new sequence if caching is enabled
        if use_cache:
            self.reset_cache()

        if self.is_decoder and tgt is None:
            raise ValueError("tgt must be provided when is_decoder=True")

        src_seq_len = src.size(1)

        # Generate padding masks from input tokens
        src_pad_mask = self.generate_padding_mask(src)
        tgt_pad_mask = self.generate_padding_mask(tgt) if tgt is not None else None

        # Source mask from lengths for cross-attention. It spans the full padded source
        # length so it matches the encoder output.
        if self.is_decoder and src_len is not None:
            src_enc_mask = self.generate_src_mask(src_len, seq_len=src_seq_len)
        else:
            src_enc_mask = src_pad_mask  # fallback to padding mask

        # Source embeddings with positional encoding
        src_emb = self.src_embed(src) + self.pos_embed[:, :src_seq_len, :]
        src_emb = self.embed_norm(src_emb)
        src_emb = self.embed_dropout(src_emb)

        # Encoder pass: the padding mask goes to self-attention (tgt_mask slot) so that
        # encoder tokens do not attend to padding.
        encoder_output = src_emb
        for layer in self.encoder:
            encoder_output = layer(encoder_output, src_mask=src_pad_mask, tgt_mask=src_pad_mask, use_cache=use_cache)

        # If not a decoder, return encoder output only
        if not self.is_decoder:
            return encoder_output, None

        # Target embeddings with positional encoding
        tgt_seq_len = tgt.size(1)
        tgt_emb = self.tgt_embed(tgt) + self.pos_embed[:, :tgt_seq_len, :]
        tgt_emb = self.embed_norm(tgt_emb)
        tgt_emb = self.embed_dropout(tgt_emb)

        # Decoder pass
        decoder_output = tgt_emb
        for layer in self.decoder:
            decoder_output = layer(decoder_output, encoder_output=encoder_output,
                                   src_mask=src_enc_mask, tgt_mask=tgt_pad_mask, use_cache=use_cache)

        # Final projection to vocabulary
        output = self.fc_out(decoder_output)  # (bs, tgt_len, tgt_vocab_size)

        loss = None
        if compute_loss:
            # Causal self-attention lets position t see tgt[:, :t+1], including tgt[:, t]
            # itself. Position t is therefore scored against the NEXT token tgt[:, t+1];
            # scoring it against tgt[:, t] would only teach the decoder to copy its input.
            logits = output[:, :-1, :].reshape(-1, output.size(-1))
            targets = tgt[:, 1:].reshape(-1)
            if self.padding_idx is not None:
                # Exclude padding positions from loss computation
                keep = targets != self.padding_idx
                logits, targets = logits[keep], targets[keep]
            loss = F.cross_entropy(logits, targets, reduction="mean")

        return encoder_output, output, loss