# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torch import optim
from collections import OrderedDict
import os
import io
import math
import random
import time
import itertools
import numpy as np

#!pip install SumOfSquares

from parser import get_parser

# In [2]:
# Number of rows of the position-embedding table created in TransformerModel,
# i.e. the largest position index (exclusive) a sequence may use.
N_MAX_POSITIONS = 4096  # maximum input sequence length


def Embedding(num_embeddings, embedding_dim, padding_idx=None):
    """
    Create an `nn.Embedding` with weights re-drawn from a normal distribution
    of mean 0 and standard deviation `embedding_dim ** -0.5`.
    If `padding_idx` is provided, the corresponding row is reset to zeros.
    """
    layer = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    scale = embedding_dim**-0.5
    nn.init.normal_(layer.weight, mean=0, std=scale)
    if padding_idx is None:
        return layer
    # the re-initialization above overwrote the padding row, clear it again
    nn.init.constant_(layer.weight[padding_idx], 0)
    return layer


def create_sinusoidal_embeddings(n_pos, dim, out):
    """
    Fill `out` in place with fixed sinusoidal position encodings and freeze it.

    Even columns receive sin(pos / 10000^(2 * (j // 2) / dim)), odd columns the
    matching cosine, for pos in [0, n_pos) and column index j in [0, dim).

        `n_pos` number of positions (rows of `out`)
        `dim` encoding dimension (columns of `out`)
        `out` tensor or parameter of shape (n_pos, dim); it is detached and its
            `requires_grad` flag is set to False
    """
    # same arithmetic as the per-element formula, evaluated by NumPy in float64
    pos = np.arange(n_pos)[:, None]
    denom = np.power(10000, 2 * (np.arange(dim) // 2) / dim)[None, :]
    position_enc = pos / denom

    # `out` is typically an nn.Parameter (a leaf that requires grad); writing
    # into it in place is only allowed while autograd is disabled
    with torch.no_grad():
        out[:, 0::2] = torch.from_numpy(np.sin(position_enc[:, 0::2])).float()
        out[:, 1::2] = torch.from_numpy(np.cos(position_enc[:, 1::2])).float()
    out.detach_()
    out.requires_grad = False


def get_masks(slen, lengths, causal):
    """
    Generate hidden states mask, and optionally an attention mask.

        `slen` sequence length (number of positions)
        `lengths` LongTensor(bs), length of each sentence (all <= slen)
        `causal` if truthy, the attention mask is lower-triangular

    Returns `(mask, attn_mask)`:
        `mask` BoolTensor(bs, slen), True on positions < lengths[b]
        `attn_mask` BoolTensor(bs, slen, slen) with entry [b, i, j] = (j <= i)
            when `causal`, otherwise `mask` itself
    """
    assert lengths.max().item() <= slen
    bs = lengths.size(0)
    alen = torch.arange(slen, dtype=torch.long, device=lengths.device)
    mask = alen < lengths[:, None]

    # attention mask is the same as mask, or triangular inferior attention (causal)
    if causal:
        attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
    else:
        attn_mask = mask

    # sanity check (truthiness-based, so falsy flags such as 0 are accepted too)
    assert mask.size() == (bs, slen)
    assert not causal or attn_mask.size() == (bs, slen, slen)

    return mask, attn_mask


class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention with an optional key/value cache.

    Every instance draws a process-wide unique `layer_id`, used as its key in
    the cache dictionary. The cache itself is assigned by the owning model
    (attribute `cache`) before each forward pass that uses it.
    """

    NEW_ID = itertools.count()

    def __init__(self, n_heads, dim, dropout):
        """
        `n_heads` number of attention heads (must divide `dim`)
        `dim` model dimension
        `dropout` dropout probability applied to the attention weights
        """
        super().__init__()
        self.layer_id = next(MultiHeadAttention.NEW_ID)
        self.dim = dim
        self.n_heads = n_heads
        self.dropout = dropout
        assert self.dim % self.n_heads == 0

        # forward() reads `self.cache`; define it so the module is usable (and
        # its assertion meaningful) before an owner assigns a cache
        self.cache = None

        self.q_lin = nn.Linear(dim, dim)
        self.k_lin = nn.Linear(dim, dim)
        self.v_lin = nn.Linear(dim, dim)
        self.out_lin = nn.Linear(dim, dim)

    def forward(self, input, mask, kv=None, use_cache=False):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).
        Input is (bs, qlen, dim)
        Mask is (bs, klen) (non-causal) or (bs, qlen, klen)
        With `use_cache`, `self.cache` must be a dict holding "slen" (number of
        positions already processed) and, per layer id, the cached (k, v) pair.
        Returns a tensor of shape (bs, qlen, dim).
        """
        assert not (use_cache and self.cache is None)
        bs, qlen, dim = input.size()
        if kv is None:
            klen = qlen if not use_cache else self.cache["slen"] + qlen
        else:
            klen = kv.size(1)
        assert dim == self.dim, "Dimensions do not match: %s input vs %s configured" % (dim, self.dim)
        n_heads = self.n_heads
        dim_per_head = dim // n_heads
        mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)

        def shape(x):
            """split the last dimension into heads: (bs, len, dim) -> (bs, n_heads, len, dim_per_head)"""
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """merge the heads back: (bs, n_heads, len, dim_per_head) -> (bs, len, dim)"""
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(input))  # (bs, n_heads, qlen, dim_per_head)
        if kv is None:
            k = shape(self.k_lin(input))  # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v_lin(input))  # (bs, n_heads, qlen, dim_per_head)
        elif not use_cache or self.layer_id not in self.cache:
            # cross-attention: project the source only when it is not cached yet
            k = v = kv
            k = shape(self.k_lin(k))  # (bs, n_heads, klen, dim_per_head)
            v = shape(self.v_lin(v))  # (bs, n_heads, klen, dim_per_head)

        if use_cache:
            if self.layer_id in self.cache:
                if kv is None:
                    # self-attention: prepend the keys / values of previous steps
                    k_, v_ = self.cache[self.layer_id]
                    k = torch.cat([k_, k], dim=2)  # (bs, n_heads, klen, dim_per_head)
                    v = torch.cat([v_, v], dim=2)  # (bs, n_heads, klen, dim_per_head)
                else:
                    # cross-attention: reuse the cached source projections
                    k, v = self.cache[self.layer_id]
            self.cache[self.layer_id] = (k, v)

        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
        mask = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
        scores.masked_fill_(mask, -float("inf"))  # (bs, n_heads, qlen, klen)

        weights = F.softmax(scores.float(), dim=-1).type_as(scores)  # (bs, n_heads, qlen, klen)
        weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)
        context = torch.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)  # (bs, qlen, dim)

        # optionally keep the attention weights for inspection (evaluation only)
        if TransformerModel.STORE_OUTPUTS and not self.training:
            self.outputs = weights.detach().cpu()

        return self.out_lin(context)


class TransformerFFN(nn.Module):
    """
    Position-wise feed-forward block: linear -> ReLU -> linear -> dropout.
    """

    def __init__(self, in_dim, dim_hidden, out_dim, dropout):
        """
        `in_dim` / `dim_hidden` / `out_dim` sizes of the two linear layers
        `dropout` dropout probability applied to the block output
        """
        super().__init__()
        self.dropout = dropout
        self.lin1 = nn.Linear(in_dim, dim_hidden)
        self.lin2 = nn.Linear(dim_hidden, out_dim)

    def forward(self, input):
        """
        Apply the block to the last dimension of `input`.
        """
        hidden = F.relu(self.lin1(input))
        projected = self.lin2(hidden)
        return F.dropout(projected, p=self.dropout, training=self.training)


class TransformerModel(nn.Module):

    STORE_OUTPUTS = False

    def __init__(self, params, id2word, is_encoder, with_output):
        """
        Transformer model (encoder or decoder).

            `params` configuration object; the attributes read here are fp16,
                emb_dim, n_heads, n_enc_layers / n_dec_layers, dropout,
                attention_dropout, max_src_len, sinusoidal_embeddings and,
                when `with_output` is set, share_inout_emb
            `id2word` vocabulary (its length is the number of words)
            `is_encoder` True builds an encoder, False a decoder (which also
                gets encoder-attention sub-layers)
            `with_output` whether to add the projection to the vocabulary
        """
        super().__init__()

        # role of this stack, output layer
        self.dtype = torch.half if params.fp16 else torch.float
        self.is_encoder = is_encoder
        self.is_decoder = not is_encoder
        self.with_output = with_output
        self.params = params

        # vocabulary
        self.id2word = id2word
        self.n_words = len(id2word)
        self.eos_index = 0
        self.pad_index = 1

        # hyper-parameters
        dim = params.emb_dim  # 512 by default
        self.dim = dim
        self.hidden_dim = dim * 4  # 2048 by default
        self.n_heads = params.n_heads  # 8 by default
        self.n_layers = params.n_enc_layers if is_encoder else params.n_dec_layers
        self.dropout = params.dropout
        self.attention_dropout = params.attention_dropout
        self.max_src_len = params.max_src_len
        assert self.dim % self.n_heads == 0, "transformer dim must be a multiple of n_heads"

        def new_norm():
            return nn.LayerNorm(dim, eps=1e-12)

        def new_attention():
            return MultiHeadAttention(self.n_heads, dim, dropout=self.attention_dropout)

        # position / word embeddings (creation order matters for seeded init)
        self.position_embeddings = Embedding(N_MAX_POSITIONS, dim)
        if params.sinusoidal_embeddings:
            create_sinusoidal_embeddings(N_MAX_POSITIONS, dim, out=self.position_embeddings.weight)
        self.embeddings = Embedding(self.n_words, dim, padding_idx=self.pad_index)
        self.layer_norm_emb = new_norm()

        # per-layer module lists
        self.attentions = nn.ModuleList()
        self.layer_norm1 = nn.ModuleList()
        self.ffns = nn.ModuleList()
        self.layer_norm2 = nn.ModuleList()
        if self.is_decoder:
            self.layer_norm15 = nn.ModuleList()
            self.encoder_attn = nn.ModuleList()

        for _ in range(self.n_layers):
            self.attentions.append(new_attention())
            self.layer_norm1.append(new_norm())
            if self.is_decoder:
                self.layer_norm15.append(new_norm())
                self.encoder_attn.append(new_attention())
            self.ffns.append(TransformerFFN(dim, self.hidden_dim, dim, dropout=self.dropout))
            self.layer_norm2.append(new_norm())

        # key / value cache, only set while generating
        self.cache = None

        # output projection, optionally tied to the input embeddings
        if self.with_output:
            self.proj = nn.Linear(dim, self.n_words, bias=True)
            if params.share_inout_emb:
                self.proj.weight = self.embeddings.weight

    def forward(self, mode, **kwargs):
        """
        Forward function with different forward modes.
        ### Small hack to handle PyTorch distributed.
        """
        if mode == "fwd":
            return self.fwd(**kwargs)
        elif mode == "predict":
            return self.predict(**kwargs)
        else:
            raise Exception("Unknown mode: %s" % mode)

    def fwd(self, x, lengths, causal, src_enc=None, src_len=None, positions=None, use_cache=False):
        """
        Run the transformer layers and return the hidden states.

        Inputs:
            `x` LongTensor(slen, bs), containing word indices
            `lengths` LongTensor(bs), containing the length of each sentence
            `causal` Boolean, if True, the attention is only done over previous hidden states
            `src_enc` / `src_len` encoder output (batch first) and source lengths,
                given together, decoder only
            `positions` LongTensor(slen, bs), containing word positions
            `use_cache` only process the positions that are not in `self.cache` yet
        Output:
            tensor with the sequence dimension first; with `use_cache`, only the
            newly processed positions are returned
        """
        # input checks
        slen, bs = x.size()
        assert lengths.size(0) == bs
        assert lengths.max().item() <= slen
        x = x.transpose(0, 1)  # batch size as dimension 0
        assert (src_enc is None) == (src_len is None)
        if src_enc is not None:
            assert self.is_decoder
            assert src_enc.size(0) == bs
        assert not (use_cache and self.cache is None)

        cross_attend = self.is_decoder and src_enc is not None
        keep_outputs = TransformerModel.STORE_OUTPUTS and not self.training

        # padding / attention masks, plus the source mask for encoder attention
        mask, attn_mask = get_masks(slen, lengths, causal)
        if cross_attend:
            src_pos = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device)
            src_bound = src_len[:, None]
            if self.max_src_len > 0:
                # attend to at most `max_src_len` source positions
                src_bound = torch.clamp(src_bound, max=self.max_src_len)
            src_mask = src_pos < src_bound

        # default positions are 0 .. slen - 1, shared by the whole batch
        if positions is None:
            positions = torch.arange(slen, out=x.new(slen).long()).unsqueeze(0)
        else:
            assert positions.size() == (slen, bs)
            positions = positions.transpose(0, 1)

        # with a cache, only the trailing, not yet processed positions are fed
        if use_cache:
            n_new = slen - self.cache["slen"]
            x, positions = x[:, -n_new:], positions[:, -n_new:]
            mask, attn_mask = mask[:, -n_new:], attn_mask[:, -n_new:]

        # all layer outputs
        if keep_outputs:
            self.outputs = []

        # embeddings
        tensor = self.embeddings(x)
        tensor = self.layer_norm_emb(tensor + self.position_embeddings(positions).expand_as(tensor))
        tensor = F.dropout(tensor, p=self.dropout, training=self.training)
        tensor *= mask.unsqueeze(-1).to(tensor.dtype)
        if keep_outputs:
            self.outputs.append(tensor.detach().cpu())

        # transformer layers
        for i in range(self.n_layers):
            self_attention = self.attentions[i]

            # self attention + residual + norm
            self_attention.cache = self.cache
            attn = self_attention(tensor, attn_mask, use_cache=use_cache)
            attn = F.dropout(attn, p=self.dropout, training=self.training)
            tensor = self.layer_norm1[i](tensor + attn)

            # encoder attention (for decoder only)
            if cross_attend:
                source_attention = self.encoder_attn[i]
                source_attention.cache = self.cache
                attn = source_attention(tensor, src_mask, kv=src_enc, use_cache=use_cache)
                attn = F.dropout(attn, p=self.dropout, training=self.training)
                tensor = self.layer_norm15[i](tensor + attn)

            # feed-forward + residual + norm
            tensor = self.layer_norm2[i](tensor + self.ffns[i](tensor))

            # zero out padded positions
            tensor *= mask.unsqueeze(-1).to(tensor.dtype)
            if keep_outputs:
                self.outputs.append(tensor.detach().cpu())

        # update cache length
        if use_cache:
            self.cache["slen"] += tensor.size(1)

        # move back sequence length to dimension 0
        return tensor.transpose(0, 1)

    def predict(self, tensor, pred_mask, y, get_scores, neg_samples=None):
        """
        Given the last hidden state, compute word scores and/or the loss.
            `pred_mask` is a ByteTensor of shape (slen, bs), filled with 1 when
                we need to predict a word
            `y` is a LongTensor of shape (pred_mask.sum(),)
            `get_scores` is a boolean specifying whether we need to return scores
            `neg_samples` is an optional LongTensor of shape (pred_mask.sum(), n_negatives)
                containing negative samples for unlikelihood or contrastive training
        """
        x = tensor[pred_mask.unsqueeze(-1).expand_as(tensor)].view(-1, self.dim)
        assert (y == self.pad_index).sum().item() == 0
        scores = self.proj(x).view(-1, self.n_words)

        # Standard cross-entropy loss
        ce_loss = F.cross_entropy(scores.float(), y, reduction="mean")

        # unlikelihood training parameters
        ul_alpha = params.ul_alpha  # weight for unlikelihood loss
        ul_topp = params.ul_topp    # top-p sampling for negative candidates
        ul_topk = params.ul_topk    # top-k sampling for negative candidates
        ul_temp = params.ul_temp    # temperature for negative sampling

        # contrastive training parameters
        contrastive_alpha = params.contrastive_alpha  # weight for contrastive loss
        contrastive_temp = params.contrastive_temp    # temperature for contrastive loss

        def generate_neg_samples_ul(scores, topp, topk, temp):
          # Generate negative samples automatically
          with torch.no_grad():
              probs = F.softmax(scores.float() / temp, dim=-1)

              if topp > 0:
                  # Top-p sampling
                  sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                  cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                  sorted_indices_to_remove = cumulative_probs > topp
                  sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                  sorted_indices_to_remove[..., 0] = 0

                  # Zero out low-probability tokens
                  probs = probs.scatter(1, sorted_indices, sorted_probs * (~sorted_indices_to_remove).float())
                  probs = probs / probs.sum(dim=-1, keepdim=True)

              # Sample negative candidates
              if topk > 0:
                  # Top-k sampling
                  topk_probs, topk_indices = torch.topk(probs, topk, dim=-1)
                  sampled_indices = torch.multinomial(topk_probs, 1).squeeze(-1)
                  ul_targets = topk_indices.gather(-1, sampled_indices.unsqueeze(-1)).squeeze(-1)
              else:
                  # Sample from full distribution
                  ul_targets = torch.multinomial(probs, 1).squeeze(-1)

          return ul_targets

        def _sample_from_scores(scores, topp, topk, temp):
          """Safe sampling with shape [batch_size, vocab_size]"""
          probs = F.softmax(scores.float() / temp, dim=-1)

          if topp > 0:
              sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
              cum_probs = torch.cumsum(sorted_probs, dim=-1)
              mask = cum_probs > topp
              mask[..., 1:] = mask[..., :-1].clone()
              mask[..., 0] = False
              sorted_probs[mask] = 0
              sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
              probs = torch.gather(sorted_probs, -1, torch.argsort(sorted_idx, dim=-1))

          if topk > 0:
              topk_probs, topk_idx = torch.topk(probs, topk, dim=-1)
              samples = torch.multinomial(topk_probs, 1)
              return topk_idx.gather(-1, samples).squeeze(-1)
          else:
              return torch.multinomial(probs, 1).squeeze(-1)

        def generate_neg_samples_contr(scores, topp=0.9, topk=0, temp=1.0, num_negatives=5):
          """
          Generate negative sequences with strict shape control
          Args:
              scores: Model logits [batch_size, seq_len, vocab_size]
          Returns:
              neg_samples: [batch_size, num_negatives, seq_len]
          """
          batch_size = scores.size(0)
          seq_len = scores.size(1)
          device = scores.device

          neg_samples = torch.full((batch_size, num_negatives, seq_len),
                                self.pad_index,
                                dtype=torch.long,
                                device=device)
          neg_samples[:, :, 0] = self.eos_index

          for n in range(num_negatives):
              for i in range(1, seq_len):
                  # Current sequence [batch_size, i]
                  current_input = neg_samples[:, n, :i]

                  with torch.no_grad():
                      # Forward pass expects [seq_len, batch_size]
                      outputs = self.fwd(
                          x=current_input.t(),  # [i, batch_size]
                          lengths=torch.full((batch_size,), i, device=device),
                          causal=True
                      )  # [i, batch_size, dim]

                      # Get last token features [batch_size, dim]
                      last_hidden = outputs[-1]  # [batch_size, dim]

                      # Project to vocab size
                      last_scores = self.proj(last_hidden)  # [batch_size, vocab_size]

                      # Sample next tokens
                      next_tokens = _sample_from_scores(
                          last_scores,
                          topp=topp,
                          topk=topk,
                          temp=temp
                      )  # [batch_size]

                      # Update sequences
                      neg_samples[:, n, i] = next_tokens

          return neg_samples

        # Unlikelihood loss
        ul_loss = None
        if self.training and (ul_alpha > 0):
            if neg_samples is not None:
                # Use provided negative samples
                ul_targets = neg_samples
            else:
                ul_targets = generate_neg_samples_ul(scores, ul_topp, ul_topk, ul_temp)

            # Compute unlikelihood loss
            log_probs = F.log_softmax(scores.float(), dim=-1)
            ul_loss = -torch.log(1 - log_probs.exp().gather(1, ul_targets.unsqueeze(-1)) + 1e-5).mean()

            # Combine losses
            loss = ce_loss + ul_alpha * ul_loss
        else:
            loss = ce_loss

        if contrastive_alpha > 0 :
            if neg_samples is None:
                neg_samples = generate_neg_samples_contr(scores, ul_topp, ul_topk, ul_temp)
                print(neg_samples.shape)
            # Get hidden states for negative samples
            with torch.no_grad():
                neg_tensors = []
                for i in range(neg_samples.size(1)):
                    neg_seq = neg_samples[:, i]
                    neg_tensor = self.fwd(
                        x=neg_seq,
                        lengths=(neg_seq != self.pad_index).sum(dim=1),
                        causal=self.is_decoder,
                        positions=None
                    )
                    neg_tensors.append(neg_tensor)
                neg_tensor = torch.stack(neg_tensors, dim=1)  # (seq_len, bs, num_neg, dim)

            # Get positive and negative representations
            pos_rep = x  # (bs*seq_len, dim)
            neg_rep = neg_tensor[pred_mask.unsqueeze(-1).unsqueeze(-1).expand_as(neg_tensor)]
            neg_rep = neg_rep.view(-1, neg_samples.size(1), self.dim)  # (bs*seq_len, num_neg, dim)

            # Compute contrastive loss
            pos_sim = F.cosine_similarity(pos_rep, pos_rep, dim=-1)
            neg_sim = F.cosine_similarity(pos_rep.unsqueeze(1), neg_rep, dim=-1)

            # Temperature-scaled similarities
            pos_sim = pos_sim / contrastive_temp
            neg_sim = neg_sim / contrastive_temp

            # Contrastive loss (infoNCE)
            logits = torch.cat([pos_sim, neg_sim], dim=1)
            labels = torch.arange(logits.size(0), device=logits.device)
            contrastive_loss = F.cross_entropy(logits, labels)

            # Combine losses
            loss = ce_loss + contrastive_alpha * contrastive_loss
        else:
            loss = ce_loss

        return scores, loss

    def generate(self, src_enc, src_len, max_len=200, sample_temperature=None):
        """
        Decode a sentence given initial start.
        `x`:
            - LongTensor(bs, slen)
                <EOS> W1 W2 W3 <EOS> <PAD>
                <EOS> W1 W2 W3   W4  <EOS>
        `lengths`:
            - LongTensor(bs) [5, 6]
        `positions`:
            - False, for regular "arange" positions (LM)
            - True, to reset positions from the new generation (MT)
        """

        # input batch
        bs = len(src_len)
        assert src_enc.size(0) == bs

        # generated sentences
        generated = src_len.new(max_len, bs)  # upcoming output
        generated.fill_(self.pad_index)  # fill upcoming ouput with <PAD>
        generated[0].fill_(self.eos_index)  # we use <EOS> for <BOS> everywhere

        # positions
        positions = src_len.new(max_len).long()
        positions = torch.arange(max_len, out=positions).unsqueeze(1).expand(max_len, bs)

        # current position / max lengths / length of generated sentences / unfinished sentences
        cur_len = 1
        gen_len = src_len.clone().fill_(1)
        unfinished_sents = src_len.clone().fill_(1)

        # cache compute states
        self.cache = {"slen": 0}

        while cur_len < max_len:

            # compute word scores
            tensor = self.forward(
                "fwd",
                x=generated[:cur_len],
                lengths=gen_len,
                positions=positions[:cur_len],
                causal=True,
                src_enc=src_enc,
                src_len=src_len,
                use_cache=True,
            )
            assert tensor.size() == (1, bs, self.dim)
            tensor = tensor.data[-1, :, :]  # .to(self.dtype)  # (bs, dim)
            scores = self.proj(tensor)  # (bs, n_words)

            # select next words: sample or greedy
            if sample_temperature is None:
                next_words = torch.topk(scores, 1)[1].squeeze(1)
            else:
                next_words = torch.multinomial(F.softmax(scores.float() / sample_temperature, dim=1), 1).squeeze(1)
            assert next_words.size() == (bs,)

            # update generations / lengths / finished sentences / current length
            generated[cur_len] = next_words * unfinished_sents + self.pad_index * (1 - unfinished_sents)
            gen_len.add_(unfinished_sents)
            unfinished_sents.mul_(next_words.ne(self.eos_index).long())
            cur_len = cur_len + 1

            # stop when there is a </s> in each sentence, or if we exceed the maximul length
            if unfinished_sents.max() == 0:
                break

        # add <EOS> to unfinished sentences
        if cur_len == max_len:
            generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index)

        # sanity check
        assert (generated == self.eos_index).sum() == 2 * bs

        return generated[:cur_len], gen_len

    def generate_beam(self, src_enc, src_len, beam_size, length_penalty, early_stopping, max_len=200):
        """
        Decode one sentence per source with beam search.

            `src_enc` encoder output, with the batch as first dimension
            `src_len` LongTensor(bs), source lengths
            `beam_size` number of hypotheses kept per sentence (>= 1)
            `length_penalty` exponent applied to the hypothesis length when scoring
            `early_stopping` stop a sentence as soon as `beam_size` hypotheses are finished
            `max_len` maximum number of generated positions, including both <EOS>

        Returns `(decoded, tgt_len, generated_hyps)`:
            `decoded` LongTensor(tgt_len.max(), bs), best hypothesis per sentence, e.g.
                <EOS> W1 W2 W3 <EOS> <PAD>
                <EOS> W1 W2 W3   W4  <EOS>
            `tgt_len` LongTensor(bs), e.g. [5, 6]
            `generated_hyps` list of the `BeamHypotheses` of each sentence
        """

        # check inputs
        assert src_enc.size(0) == src_len.size(0)
        assert beam_size >= 1

        # batch size / number of words
        bs = len(src_len)
        n_words = self.n_words

        # expand to beam size the source latent representations / source lengths
        src_enc = src_enc.unsqueeze(1).expand((bs, beam_size) + src_enc.shape[1:]).contiguous().view((bs * beam_size,) + src_enc.shape[1:])
        src_len = src_len.unsqueeze(1).expand(bs, beam_size).contiguous().view(-1)

        # generated sentences (batch with beam current hypotheses)
        generated = src_len.new(max_len, bs * beam_size)  # upcoming output
        generated.fill_(self.pad_index)  # fill upcoming output with <PAD>
        generated[0].fill_(self.eos_index)  # we use <EOS> for <BOS> everywhere

        # generated hypotheses
        generated_hyps = [BeamHypotheses(beam_size, max_len, length_penalty, early_stopping) for _ in range(bs)]

        # positions
        positions = src_len.new(max_len).long()
        positions = torch.arange(max_len, out=positions).unsqueeze(1).expand_as(generated)

        # scores for each sentence in the beam (only the first beam is live at step 0,
        # otherwise the same word would be selected beam_size times)
        beam_scores = src_enc.new(bs, beam_size).float().fill_(0)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)

        # current position
        cur_len = 1

        # cache compute states
        self.cache = {"slen": 0}

        # done sentences
        done = [False for _ in range(bs)]

        while cur_len < max_len:

            # compute word scores
            tensor = self.forward(
                "fwd",
                x=generated[:cur_len],
                lengths=src_len.new(bs * beam_size).fill_(cur_len),
                positions=positions[:cur_len],
                causal=True,
                src_enc=src_enc,
                src_len=src_len,
                use_cache=True,
            )
            assert tensor.size() == (1, bs * beam_size, self.dim)
            tensor = tensor.data[-1, :, :]  # (bs * beam_size, dim)
            scores = self.proj(tensor)  # (bs * beam_size, n_words)
            scores = F.log_softmax(scores.float(), dim=-1)  # (bs * beam_size, n_words)
            assert scores.size() == (bs * beam_size, n_words)

            # select next words with scores
            _scores = scores + beam_scores[:, None].expand_as(scores)  # (bs * beam_size, n_words)
            _scores = _scores.view(bs, beam_size * n_words)  # (bs, beam_size * n_words)

            # 2 * beam_size candidates: at most beam_size of them can end with
            # <EOS> (one per beam), so beam_size continuations always remain
            next_scores, next_words = torch.topk(_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
            assert next_scores.size() == next_words.size() == (bs, 2 * beam_size)

            # next batch beam content
            # list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch)
            next_batch_beam = []

            # for each sentence
            for sent_id in range(bs):

                # if we are done with this sentence
                done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
                if done[sent_id]:
                    next_batch_beam.extend([(0, self.pad_index, 0)] * beam_size)  # pad the batch
                    continue

                # next sentence beam content
                next_sent_beam = []

                # next words for this sentence
                for idx, value in zip(next_words[sent_id], next_scores[sent_id]):

                    # get beam and word IDs
                    beam_id = idx // n_words
                    word_id = idx % n_words

                    # end of sentence, or next word
                    if word_id == self.eos_index or cur_len + 1 == max_len:
                        generated_hyps[sent_id].add(generated[:cur_len, sent_id * beam_size + beam_id].clone().cpu(), value.item())
                    else:
                        next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == beam_size:
                        break

                # update next beam content
                # (parenthesized: the expected size is 0 on the last step, beam_size otherwise)
                assert len(next_sent_beam) == (0 if cur_len + 1 == max_len else beam_size)
                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, self.pad_index, 0)] * beam_size  # pad the batch
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == beam_size * (sent_id + 1)

            # sanity check / prepare next batch
            assert len(next_batch_beam) == bs * beam_size
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_words = generated.new([x[1] for x in next_batch_beam])
            beam_idx = src_len.new([x[2] for x in next_batch_beam])

            # re-order batch and internal states
            generated = generated[:, beam_idx]
            generated[cur_len] = beam_words
            for k in self.cache.keys():
                if k != "slen":
                    self.cache[k] = (self.cache[k][0][beam_idx], self.cache[k][1][beam_idx])

            # update current length
            cur_len = cur_len + 1

            # stop when we are done with each sentence
            if all(done):
                break

        # select the best hypotheses
        tgt_len = src_len.new(bs)
        best = []

        for i, hypotheses in enumerate(generated_hyps):
            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
            best.append(best_hyp)

        # generate target batch
        decoded = src_len.new(tgt_len.max().item(), bs).fill_(self.pad_index)
        for i, hypo in enumerate(best):
            decoded[: tgt_len[i] - 1, i] = hypo
            decoded[tgt_len[i] - 1, i] = self.eos_index

        # sanity check
        assert (decoded == self.eos_index).sum() == 2 * bs

        return decoded, tgt_len, generated_hyps


class BeamHypotheses(object):
    """
    Bounded n-best list of finished hypotheses for one sentence, scored by
    length-normalized log-probability.
    """

    def __init__(self, n_hyp, max_len, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.hyp = []
        self.n_hyp = n_hyp
        self.worst_score = 1e9
        self.early_stopping = early_stopping
        self.length_penalty = length_penalty
        self.max_len = max_len - 1  # ignoring <BOS>

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.hyp)

    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        It is kept if the list is not full yet or if it beats the current worst
        score; the worst stored hypothesis is evicted when the list overflows.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty
        has_room = len(self) < self.n_hyp
        if not has_room and not (score > self.worst_score):
            return

        self.hyp.append((score, hyp))
        if len(self) <= self.n_hyp:
            self.worst_score = min(score, self.worst_score)
            return

        # overflow: drop the lowest score, the runner-up becomes the new worst
        ranked = sorted((s, position) for position, (s, _) in enumerate(self.hyp))
        del self.hyp[ranked[0][1]]
        self.worst_score = ranked[1][0]

    def is_done(self, best_sum_logprobs):
        """
        If there are enough hypotheses and that none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """
        if len(self) < self.n_hyp:
            return False
        if self.early_stopping:
            return True
        best_possible = best_sum_logprobs / self.max_len**self.length_penalty
        return self.worst_score >= best_possible

# In [3]:
# --- tokenizer configuration ---------------------------------------------
max_degree = 6  # max degree of the equations in the system
int_base = 1000  # int base for the tokenizer

# vocabulary groups, concatenated below in this exact order
operators = ['+', '-', '*', '/', '^', 'sqrt', 'exp', 'ln', 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'Abs']
variables = ["x" + str(i) for i in range(2 * max_degree)]
constants = ["pi", "E"]
symbols = ["I", "INT+", "INT-", "FLOAT+", "FLOAT-", ".", "10^"]
elements = list(map(str, range(max(10, int_base))))  # one token per digit in base `int_base`
SPECIAL_WORDS = ["<s>", "</s>", "<pad>", "(", ")"]
SPECIAL_WORDS += ["<SPECIAL_%d>" % i for i in range(10)]
func_separator = "<SPECIAL_3>"

# token <-> index tables (tokenizer dictionary)
words = [*SPECIAL_WORDS, *constants, *variables, *operators, *symbols, *elements]
id2word = dict(enumerate(words))
word2id = {word: index for index, word in id2word.items()}
n_words = len(words)

# NOTE(review): with the ordering above, index 0 is "<s>" and index 1 is
# "</s>", while "<pad>" is index 2 -- confirm these two indices are intended.
eos_index = 0
pad_index = 1

# In [4]:
class EnvDataset(Dataset):
    """
    Dataset of (input, output) token sequences reloaded from a text file.

    Each usable line of the file has the form ``<prefix>|<x>\\t<y>`` where ``<x>``
    and ``<y>`` are whitespace-separated tokens. Lines that do not have exactly
    one ``|`` or exactly one tab after it are skipped.

    In training mode the dataset reports a practically infinite length and every
    access returns a randomly drawn sample; otherwise samples are read by index.
    """
    def __init__(self, num_workers, path, word2id, train, max_len=1024, max_output_len=512, reload_size=-1, size=None):
        """
        Args:
            num_workers: number of DataLoader workers (checked in `get_worker_id`).
            path: data file to reload, or None to load nothing.
            word2id: token -> id mapping used by `collate_fn`.
            train: True for the training set (random sampling, first
                `reload_size` lines only), False to load the whole file.
            max_len: samples whose input has at least this many tokens are
                skipped (disabled when <= 0).
            max_output_len: same limit for the output sequence.
            reload_size: number of lines read in training mode; a negative
                value reads the whole file.
            size: must be None.
        """
        super().__init__()
        self.num_workers = num_workers
        self.path = path
        self.train = train
        self.count = 0
        assert size is None

        self.max_len = max_len
        self.max_output_len = max_output_len
        self.word2id = word2id

        self.func_separator = "<SPECIAL_3>"
        self.eos_index = 0
        self.pad_index = 1

        # always defined, so that a missing file gives a clear error on first read
        self.data = []

        # reloading from file
        if path is not None:
            assert os.path.isfile(path)
            print(f"Loading data from {path} ...")
            with io.open(path, mode="r", encoding="utf-8") as f:
                # either reload the entire file, or the first N lines (for the training set)
                for i, line in enumerate(f):
                    if train and i == reload_size:
                        break
                    fields = line.rstrip().split("|")
                    if len(fields) != 2:
                        # blank or malformed line: skip instead of crashing
                        continue
                    xy = fields[1].split("\t")
                    if len(xy) == 2:
                        self.data.append(xy)
            print(f"Loaded {len(self.data)} equations from the disk.")

        # dataset size: infinite iterator for train, finite for validation (default of 5000 if no file provided)
        if self.train:
            self.size = 1 << 60
        elif size is None:
            self.size = 5000 if path is None else len(self.data)
        else:
            assert size > 0
            self.size = size

    def batch_sequences(self, sequences):
        """
        Take as input a list of n sequences (torch.LongTensor vectors) and return
        a tensor of size (slen, n) where slen is the length of the longest
        sentence, and a vector lengths containing the length of each sentence.

        Every column starts and ends with `eos_index` and is padded with
        `pad_index`; the returned lengths include these two delimiters.
        """
        lengths = torch.LongTensor([len(s) + 2 for s in sequences])
        sent = torch.full((lengths.max().item(), lengths.size(0)), self.pad_index, dtype=torch.long)
        # every sequence must contain at least one token
        assert lengths.min().item() > 2

        sent[0] = self.eos_index
        for i, s in enumerate(sequences):
            end = len(s) + 1  # row of the closing delimiter
            sent[1:end, i].copy_(s)
            sent[end, i] = self.eos_index

        return sent, lengths

    def collate_fn(self, elements):
        """
        Collate samples into a batch.

        Args:
            elements: list of (x, y) pairs of token lists.

        Returns:
            ((x, x_len), (y, y_len), nb_eqs) where x / y come from
            `batch_sequences` and nb_eqs counts the `func_separator` tokens
            in each input sequence.
        """
        x, y = zip(*elements)
        nb_eqs = [seq.count(self.func_separator) for seq in x]
        x = [torch.LongTensor([self.word2id[w] for w in seq]) for seq in x]
        y = [torch.LongTensor([self.word2id[w] for w in seq]) for seq in y]
        x, x_len = self.batch_sequences(x)
        y, y_len = self.batch_sequences(y)
        return (x, x_len), (y, y_len), torch.LongTensor(nb_eqs)

    def get_worker_id(self):
        """
        Get worker ID (always 0 outside of training or without worker processes).
        """
        if not self.train:
            return 0
        worker_info = torch.utils.data.get_worker_info()
        assert (worker_info is None) == (self.num_workers == 0)
        return 0 if worker_info is None else worker_info.id

    def __len__(self):
        """
        Return dataset size.
        """
        return self.size

    def __getitem__(self, index):
        """
        Return a sample by reading it from the reloaded data.
        """
        return self.read_sample(index)

    def read_sample(self, index):
        """
        Read a sample.

        In training mode `index` is ignored and a random sample is drawn. Samples
        that reach the length limits are skipped: training draws again, evaluation
        moves on to the next sample, wrapping around at the end of the data.

        Returns:
            (x, y): the input and output token lists.

        Raises:
            RuntimeError: if no data is loaded, or (evaluation only) if every
                sample exceeds the length limits.
        """
        n_samples = len(self.data)
        if n_samples == 0:
            raise RuntimeError("No data loaded: cannot read a sample.")

        rejected = 0
        while True:
            if self.train:
                index = random.randint(0, n_samples - 1)
            x, y = self.data[index]
            x = x.split()
            y = y.split()
            too_long = (self.max_len > 0 and len(x) >= self.max_len) or (
                self.max_output_len > 0 and len(y) >= self.max_output_len
            )
            if not too_long:
                return x, y
            rejected += 1
            if not self.train and rejected >= n_samples:
                raise RuntimeError("Every sample exceeds the configured length limits.")
            # wrap around instead of running past the end of the data
            index = (index + 1) % n_samples

In [5]:
def create_train_iterator(data_path, num_workers, batch_size, word2id):
    """
    Build the DataLoader serving training batches read from `data_path`.

    When `num_workers` is None it defaults to the CPU count, capped at 4.
    """
    print("Creating train iterator...")

    if num_workers is None:
        num_workers = min(4, os.cpu_count() or 1)

    train_set = EnvDataset(num_workers=num_workers, path=data_path, word2id=word2id, train=True)

    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        collate_fn=train_set.collate_fn,
        # a timeout is only passed when worker processes are used
        timeout=(86400 if num_workers != 0 else 0),
    )
    return DataLoader(train_set, **loader_options)

In [6]:
def get_batch(dataloader):
    """
    Fetch the next batch from `dataloader` (an iterator).

    Any exception raised while fetching is reported on stdout and re-raised.
    """
    try:
        return next(dataloader)
    except Exception:
        print("An unknown exception occurred when fetching batch. ")
        raise

In [7]:
def to_cuda(*args):
    """
    Move the given tensors to CUDA when a GPU is available.

    Without CUDA the arguments are returned untouched (as the `args` tuple).
    With CUDA a list is returned in which None entries are kept as None.
    """
    if torch.cuda.is_available():
        moved = []
        for item in args:
            moved.append(item.cuda() if item is not None else None)
        return moved
    return args

def enc_dec_step(encoder, decoder, dataloader, optimizer, clip_grad_norm=5):
    """
    Run one encoder / decoder training step on the next batch.

    Args:
        encoder: module called as `encoder("fwd", x=..., lengths=..., causal=False)`.
        decoder: module called in "fwd" mode and then in "predict" mode to get the loss.
        dataloader: iterator yielding `((x1, len1), (x2, len2), nb_ops)` batches.
        optimizer: optimizer stepping the encoder and decoder parameters.
        clip_grad_norm: maximum gradient norm; clipping is skipped when <= 0.

    Returns:
        The loss tensor of this step.
    """
    encoder.train()
    decoder.train()

    # batch (the third element is not used by this step)
    (x1, len1), (x2, len2), _ = get_batch(dataloader)

    # cuda
    x1, len1, x2, len2 = to_cuda(x1, len1, x2, len2)

    # target words to predict
    alen = torch.arange(len2.max(), dtype=torch.long, device=len2.device)
    pred_mask = alen[:, None] < len2[None] - 1  # do not predict anything given the last target word

    # targets are the decoder inputs shifted by one position
    y = x2[1:].masked_select(pred_mask[:-1])

    assert len(y) == (len2 - 1).sum().item()

    # forward / loss
    encoded = encoder("fwd", x=x1, lengths=len1, causal=False)
    decoded = decoder("fwd", x=x2, lengths=len2, causal=True, src_enc=encoded.transpose(0, 1), src_len=len1)
    _, loss = decoder("predict", tensor=decoded, pred_mask=pred_mask, y=y, get_scores=False)

    # check NaN (reported only; the step still runs)
    if torch.isnan(loss).any():
        print("NaN detected")

    # Collect the trainable parameters of BOTH modules in a flat list. Keying them
    # by name would let decoder entries overwrite encoder entries that share the
    # same name, leaving the encoder gradients unclipped.
    model_params = [
        p for module in (encoder, decoder) for p in module.parameters() if p.requires_grad
    ]

    # regular optimization
    optimizer.zero_grad()
    loss.backward()
    if clip_grad_norm > 0:
        clip_grad_norm_(model_params, clip_grad_norm)
    optimizer.step()

    return loss

In [None]:
# training
max_epoch = 8       # how many epochs will the model train
epoch = 0           # current epoch
epoch_size = 600    # number of equations in an epoch
n_total_iter = 0    # total number of passed iterations
log_interval = 10   # print speed / loss every `log_interval` iterations

parser = get_parser()

# extra options added on top of the base parser
parser.add_argument("--ul_alpha", type=float, default=0.0)
parser.add_argument("--ul_topp", type=float, default=0.9)
parser.add_argument("--ul_topk", type=int, default=20)
parser.add_argument("--ul_temp", type=float, default=0.8)

parser.add_argument("--contrastive_alpha", type=float, default=0.0)
parser.add_argument("--contrastive_temp", type=float, default=0.1)

params = parser.parse_args([])    # setup of all transformer parameters

batch_size = 16
data_path = "Data.txt"
dump_path = "" # path for saving the model
model_name = "Lyapunov"
dataloader = iter(create_train_iterator(data_path=data_path, num_workers=None, batch_size=batch_size, word2id=word2id))

encoder = TransformerModel(params, id2word=id2word, is_encoder=True, with_output=False)
decoder = TransformerModel(params, id2word=id2word, is_encoder=False, with_output=True)

optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)

last_time = time.time()

for _ in range(max_epoch):

    print("============ Starting epoch %i ... ============" % epoch)

    n_equations = 0
    n_iter = 0

    while n_equations < epoch_size:

        # training steps
        loss = enc_dec_step(encoder, decoder, dataloader, optimizer)
        n_equations += batch_size
        n_iter += 1
        n_total_iter += 1

        if n_total_iter % log_interval == 0:
            new_time = time.time()
            diff = new_time - last_time
            # exactly `log_interval` iterations ran since the previous report
            print('Training speed is ', log_interval/diff, ' iter/s. The loss in ', n_iter, '. iteration: ', loss.item())
            last_time = new_time

    print("============ End of epoch %i ============" % epoch)

    # end of epoch: save the model / checkpoints (the file is overwritten every epoch)
    path = os.path.join(dump_path, "%s.pth" % model_name)
    print("Saving %s to %s ..." % (model_name, path))

    data = {
        "epoch": epoch,
        "n_total_iter": n_total_iter,
        "loss": loss.item(),
    }

    print("Saving encoder parameters ...")
    data["encoder"] = encoder.state_dict()
    print("Saving decoder parameters ...")
    data["decoder"] = decoder.state_dict()

    print("Saving optimizer ...")
    data["optimizer"] = optimizer.state_dict()

    torch.save(data, path)

    epoch += 1

Creating train iterator...
Loading data from Data.txt ...
Loaded 600 equations from the disk.


In [None]:
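# Model evaluation: clone the reference Lyapunov repository and run its train.py
# in evaluation-only mode on the "Lyapunov.pth" checkpoint.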

!git clone https://github.com/facebookresearch/Lyapunov/ # optim.py will signal an error, just comment problematic lines
!python Lyapunov/train.py --dump_path "" --export_data false --cpu true --reload_data "ode_lyapunov,Lyapunov/benchmarks/BPoly,Lyapunov/benchmarks/FBarr,Lyapunov/benchmarks/FLyap,Lyapunov/benchmarks/FSOSTOOL" --env_base_seed -1  --num_workers 1 --eval_only true --reload_model "Lyapunov.pth"

In [23]:
import re
import sympy as sp
from collections import OrderedDict

class MathTokenizer:    # tokenizer implementation from the paper
    """
    Turn infix math expressions into prefix-ordered token sequences and ids.

    An expression is parsed into a tree (`parse_expression`), flattened in
    pre-order (`enumerate_tree`) and its numbers are split into vocabulary
    tokens (`tokenize_expression`). Integers are written as digits in base
    `int_base`, preceded by '-' when negative; reals are written as
    `<mantissa> 10^ <exponent>`.
    """

    # lexer pattern: function names, single-character operators / parentheses,
    # variables, constants and unsigned numbers
    _TOKEN_RE = re.compile(r'asin|acos|atan|sqrt|exp|ln|sin|cos|tan|Abs|[+\-*/^()]|x\d+|pi|E|\d+\.?\d*')
    _INT_RE = re.compile(r'^-?\d+$')
    _REAL_RE = re.compile(r'^-?\d+\.\d*$')

    class TreeNode:
        """Node of a parsed expression: a token value and its ordered children."""

        def __init__(self, value):
            self.value = value
            self.children = []

        def add_child(self, child):
            self.children.append(child)

        def __repr__(self):
            return f"TreeNode({self.value}, {self.children})"

    def __init__(self, max_degree, int_base):
        """
        Args:
            max_degree: the vocabulary holds the variables x0 .. x{2*max_degree - 1}.
            int_base: base used to split integers into tokens (must be >= 2).

        Raises:
            ValueError: if `int_base` is smaller than 2.
        """
        if int_base < 2:
            raise ValueError(f"int_base must be at least 2, got {int_base}")
        self.max_degree = max_degree
        self.int_base = int_base

        # operator -> arity
        self.operators = {
            "+": 2, "-": 2, "*": 2, "/": 2, "^": 2,
            "sqrt": 1, "exp": 1, "ln": 1,
            "sin": 1, "cos": 1, "tan": 1,
            "asin": 1, "acos": 1, "atan": 1,
            "Abs": 1
        }

        self.variables = OrderedDict((f"x{i}", sp.Symbol(f"x{i}")) for i in range(2 * self.max_degree))
        self.constants = ["pi", "E"]
        self.symbols = ["I", "INT+", "INT-", "FLOAT+", "FLOAT-", ".", "10^"]
        self.elements = [str(i) for i in range(max(10, self.int_base))]

        SPECIAL_WORDS = ["<s>", "</s>", "<pad>", "(", ")"]
        SPECIAL_WORDS += [f"<SPECIAL_{i}>" for i in range(10)]

        self.words = SPECIAL_WORDS + self.constants + list(self.variables.keys()) + list(self.operators.keys()) + self.symbols + self.elements
        self.id2word = {i: s for i, s in enumerate(self.words)}
        self.word2id = {s: i for i, s in self.id2word.items()}
        self.n_words = len(self.words)
        self.eos_index = 0
        self.pad_index = 1

    def _lex(self, expr):
        """Split `expr` into tokens; raise ValueError on an unrecognized character."""
        tokens = []
        pos = 0
        end = len(expr)
        while pos < end:
            if expr[pos].isspace():
                pos += 1
                continue
            match = self._TOKEN_RE.match(expr, pos)
            if match is None:
                raise ValueError(f"Unexpected character {expr[pos]!r} at position {pos}")
            tokens.append(match.group())
            pos = match.end()
        return tokens

    def parse_expression(self, expr):
        """
        Parse an infix expression into a tree of `TreeNode`.

        Precedence, from lowest to highest: '+' / '-', then '*' / '/', then '^'
        (right-associative), then function calls, parentheses and atoms.
        Unary minus is not supported.

        Raises:
            ValueError: on an unknown character, a misplaced or missing token,
                or tokens left over after a complete expression.
        """
        TreeNode = self.TreeNode
        tokens = self._lex(expr)
        pos = 0

        def peek():
            return tokens[pos] if pos < len(tokens) else None

        def consume(expected=None):
            nonlocal pos
            if pos >= len(tokens):
                if expected:
                    raise ValueError(f"Expected {expected} but reached end of expression")
                raise ValueError("Unexpected end of expression")
            token = tokens[pos]
            if expected and token != expected:
                raise ValueError(f"Expected {expected} but got {token}")
            pos += 1
            return token

        def parse_primary():
            token = peek()
            if token is None:
                raise ValueError("Unexpected end of expression")
            if token == '(':
                consume('(')
                node = parse_expr()
                consume(')')
                return node
            if self.operators.get(token) == 1:
                node = TreeNode(consume())
                consume('(')
                node.add_child(parse_expr())
                consume(')')
                return node
            if token in self.operators or token == ')':
                # a binary operator or ')' cannot start an operand
                raise ValueError(f"Unexpected token: {token}")
            return TreeNode(consume())

        def parse_factor():
            base = parse_primary()
            if peek() == '^':
                node = TreeNode(consume())
                node.add_child(base)
                # recursing on the exponent makes '^' right-associative
                node.add_child(parse_factor())
                return node
            return base

        def parse_binary(parse_operand, ops):
            # left-associative chain of operands joined by any operator in `ops`
            node = parse_operand()
            while peek() in ops:
                parent = TreeNode(consume())
                parent.add_child(node)
                parent.add_child(parse_operand())
                node = parent
            return node

        def parse_term():
            return parse_binary(parse_factor, ('*', '/'))

        def parse_expr():
            return parse_binary(parse_term, ('+', '-'))

        tree = parse_expr()
        if pos < len(tokens):
            # the whole input must be consumed, otherwise part of it would be lost
            raise ValueError(f"Unexpected token: {tokens[pos]}")
        return tree

    def enumerate_tree(self, tree):
        """Return the node values of `tree` in pre-order (prefix notation)."""
        if not tree:
            return []
        result = [tree.value]
        for child in tree.children:
            result.extend(self.enumerate_tree(child))
        return result

    def tokenize_integer(self, n):
        """
        Write integer `n` as a list of digit tokens in base `int_base`, most
        significant digit first, preceded by '-' when `n` is negative.
        """
        if n == 0:
            return ['0']
        negative = n < 0
        n = abs(n)
        base = self.int_base
        digits = []
        while n > 0:
            n, remainder = divmod(n, base)
            digits.append(str(remainder))
        digits.reverse()
        return ['-'] + digits if negative else digits

    def tokenize_real(self, x):
        """
        Write real `x` as `[sign] mantissa '10^' exponent`, where the mantissa and
        the exponent are integers tokenized with `tokenize_integer`.
        """
        from decimal import Decimal

        sign = ['-'] if x < 0 else []
        # fixed-point rendering, so values printed in scientific notation also work
        s = format(Decimal(str(abs(x))), 'f')
        if '.' in s:
            integer_part, fractional_part = s.split('.')
            fractional_part = fractional_part.rstrip('0')
            if fractional_part == '':
                mantissa = integer_part
                exponent = 0
            else:
                mantissa = integer_part + fractional_part
                exponent = -len(fractional_part)
        else:
            mantissa = s
            exponent = 0
        return sign + self.tokenize_integer(int(mantissa)) + ['10^'] + self.tokenize_integer(exponent)

    def tokenize_expression(self, expression):
        """
        Map a list of tree values to vocabulary tokens: known words are kept,
        other integers and reals are expanded into several tokens.

        Raises:
            ValueError: if a value is neither a known word nor a number.
        """
        tokens = []
        for token in expression:
            if token in self.word2id:
                tokens.append(token)
            elif self._INT_RE.match(token):
                tokens.extend(self.tokenize_integer(int(token)))
            elif self._REAL_RE.match(token):
                tokens.extend(self.tokenize_real(float(token)))
            else:
                raise ValueError(f"Unknown token: {token}")
        return tokens

    def tokenize(self, expressions):
        """
        Tokenize several expressions into one token list, with "<SPECIAL_3>"
        between consecutive expressions.
        """
        tokenized_expressions = []
        for position, expression in enumerate(expressions):
            if position:
                tokenized_expressions.append("<SPECIAL_3>")
            tree = self.parse_expression(expression)
            tokenized_expressions.extend(self.tokenize_expression(self.enumerate_tree(tree)))
        return tokenized_expressions

    def encode(self, tokenized_expressions):
        """Convert tokens to their vocabulary ids (KeyError on an unknown token)."""
        return [self.word2id[token] for token in tokenized_expressions]

# demo: tokenize two expressions and show the tokens and their ids
tokenizer = MathTokenizer(max_degree=3, int_base=1024)
expressions = [
    "cos(2.1 * x0) * (x1 + 2)",
    "sin(3 * x1 + 2)",
]
print(expressions)
tokenized_expr = tokenizer.tokenize(expressions)
tokenized = tokenizer.encode(tokenized_expr)
print(tokenized_expr, tokenized, sep="\n")

['cos(2.1 * x0) * (x1 + 2)', 'sin(3 * x1 + 2)']
['*', 'cos', '*', '21', '10^', '-', '1', 'x0', '+', 'x1', '2', '<SPECIAL_3>', 'sin', '+', '*', '3', 'x1', '2']
[25, 32, 25, 66, 44, 24, 46, 17, 23, 18, 47, 8, 31, 23, 25, 48, 18, 47]
