In [None]:
import sympy as sp
from sympy import *
import pandas as pd
import re
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Transformer
import math
import os
import random
import torch
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from dataclasses import dataclass, field, fields
from typing import Optional
from transformers import LEDForConditionalGeneration,LEDConfig

In [None]:
def chunks(lst, n):
    """Lazily split ``lst`` into consecutive slices of length ``n``.

    Args:
        lst: Any sliceable sequence with a length (list, str, array, ...).
        n: Slice length. The final slice is shorter when ``len(lst)`` is not
            a multiple of ``n``.

    Yields:
        ``lst[k:k + n]`` for k = 0, n, 2n, ...
    """
    offsets = range(0, len(lst), n)
    for offset in offsets:
        yield lst[offset:offset + n]

In [None]:
class Tokenizer:
    """Bidirectional token <-> integer-id mapping loaded from a vocabulary file.

    The vocabulary file holds one token per line; the zero-based line number
    of a token is its id.
    """

    def __init__(self, vocab_path):
        """Read ``vocab_path`` and build ``word2id`` / ``id2word``."""
        self.vocab_path = vocab_path
        self.word2id = {}
        self.id2word = {}

        with open(vocab_path) as file:
            vocabulary = [line.rstrip('\n') for line in file.readlines()]

        for token_id, token in enumerate(vocabulary):
            self.id2word[token_id] = token
            self.word2id[token] = token_id

    def encode(self, lst):
        """Convert a list of token sequences into a ``np.ushort`` array of ids."""
        id_rows = []
        for sequence in lst:
            id_rows.append([self.word2id[token] for token in sequence])
        return np.array(id_rows, dtype=np.ushort)

    def decode(self, lst):
        """Convert sequences of ids back into lists of token strings."""
        token_rows = []
        for id_sequence in lst:
            token_rows.append([self.id2word[token_id] for token_id in id_sequence])
        return token_rows

In [None]:
class Encoder_tokeniser(Tokenizer):
    """Tokenizer that turns rows of floating point numbers into token ids.

    Every finite float becomes ``2 + mantissa_len`` tokens: a sign token
    (``"+"`` / ``"-"``), ``mantissa_len`` mantissa tokens prefixed with ``"N"``
    and one exponent token prefixed with ``"E"``. Infinite values (used as
    padding by :meth:`pre_tokenize`) become ``"<pad>"`` tokens.
    """

    def __init__(self,float_precision,mantissa_len,max_exponent,vocab_path,max_len = 10):
        """
        Args:
            float_precision: Number of digits after the decimal point in the
                scientific-notation rendering of each value.
            mantissa_len: Number of tokens the mantissa digits are split into.
            max_exponent: Values whose shifted exponent falls below
                ``-max_exponent`` are encoded as zero.
            vocab_path: Vocabulary file, one token per line.
            max_len: Number of columns every row is padded to.
        """
        super().__init__(vocab_path)

        self.max_len = max_len
        self.float_precision = float_precision
        self.mantissa_len = mantissa_len
        self.max_exponent = max_exponent
        # Number of digits carried by a single mantissa token.
        self.base = (self.float_precision + 1) // self.mantissa_len
        self.max_token = 10 ** self.base

    def pre_tokenize(self, data):
        """Parse whitespace-separated text rows into a padded float32 array.

        The last column of every row is moved to the front, then each row is
        right-padded with ``-inf`` up to ``max_len`` columns.
        """
        arr = np.array([i.split() for i in data], dtype=np.float32)
        # Column order: last column first, followed by the remaining columns.
        permutation = [-1] + [i for i in range(arr.shape[1]-1)]
        arr = np.pad(arr[:, permutation], ((0,0), (0, self.max_len - arr.shape[1])), mode="constant", constant_values=[-np.inf])
        return arr

    def tokenize(self, data):
        """Full pipeline: text rows -> padded floats -> tokens -> id array."""
        out = self.pre_tokenize(data)
        out = self.encode_float(out)
        out = self.encode(out)
        return out

    def encode_float(self,values):
        """Convert floats into sign / mantissa / exponent tokens.

        Args:
            values: 1-D array of floats, or a higher-rank array that is
                processed row by row.

        Returns:
            A flat list of tokens for 1-D input, otherwise a list with one
            entry per row. Zero rows give an empty list.
        """
        if len(values.shape) != 1:
            return [self.encode_float(row) for row in values]

        # Every value occupies this many tokens, padding included, so that
        # rows stay rectangular for `encode` (was a hard-coded 3, which is
        # only right for mantissa_len == 1).
        tokens_per_value = 2 + self.mantissa_len
        seq = []
        for val in values:
            if val in [-np.inf, np.inf]:
                seq.extend(['<pad>'] * tokens_per_value)
                continue

            sign = "+" if val >= 0 else "-"
            m, e = (f"%.{self.float_precision}e" % val).split("e")
            i, f = m.lstrip("-").split(".")
            i = i + f
            tokens = chunks(i, self.base)
            # The mantissa is stored as an integer, so shift the exponent.
            expon = int(e) - self.float_precision
            if expon < -self.max_exponent:
                # Underflow: represent the value as exactly zero.
                tokens = ["0" * self.base] * self.mantissa_len
                expon = int(0)
            seq.extend([sign, *["N" + token for token in tokens], "E" + str(expon)])
        return seq

    def decode_float(self,seq):
        """Inverse of :meth:`encode_float` for a flat token sequence.

        Args:
            seq: Flat sequence of tokens, ``2 + mantissa_len`` per value.

        Returns:
            A list of floats (``nan`` for groups that fail to parse), or the
            scalar ``np.nan`` as soon as a token that does not start with
            ``-``, ``+``, ``E`` or ``N`` is met (this includes ``"<pad>"``).
        """
        decoded_seq = []
        # Bug fix: iterate over the input `seq`. The original iterated over
        # the still-empty `decoded_seq`, so nothing was ever decoded.
        for val in chunks(seq, 2 + self.mantissa_len):
            for x in val:
                if x[0] not in ["-", "+", "E", "N"]:
                    return np.nan
            try:
                sign = 1 if val[0] == "+" else -1
                mant = ""
                for x in val[1:-1]:
                    mant += x[1:]
                mant = int(mant)
                exp = int(val[-1][1:])
                value = sign * mant * (10 ** exp)
                value = float(value)
            except Exception:
                value = np.nan
            decoded_seq.append(value)
        return decoded_seq

In [None]:
# Load the table of Feynman equations from the Kaggle input dataset.
df_target = pd.read_csv('/kaggle/input/gsoc-symba-task/FeynmanEquations.csv')

In [None]:
# Show the freshly loaded table as the cell output.
df_target

In [None]:
# Keep only the rows that have a 'Filename' value, then show the result.
df_target = df_target.dropna(subset=['Filename'])
df_target

In [None]:
# Manual overrides on individual rows, addressed by index label.
# NOTE(review): presumably these correct mistakes in the source CSV — confirm
# against the data files.

# Overwrite the '# variables' value for selected rows.
df_target.loc[21, '# variables'] = 3
df_target.loc[22, '# variables'] = 4
df_target.loc[38, '# variables'] = 4
df_target.loc[82, '# variables'] = 3
df_target.loc[90, '# variables'] = 4
df_target.loc[98, '# variables'] = 5
# Overwrite the 'Filename' value for selected rows.
df_target.loc[18,'Filename'] = 'I.15.10'
df_target.loc[49,'Filename'] = 'I.48.20'
df_target.loc[61,'Filename'] = 'II.11.7'

In [None]:
# Symbol names that are treated as variables when parsing formulas and when
# rebuilding sympy expressions from prefix tokens.
#
# Bug fix: the original literal had no comma after 'm_u', so Python's implicit
# string concatenation merged it with 's_0' into the single bogus entry
# 'm_us_0'; neither 'm_u' nor 's_0' was actually present in the list.
variables = [
        'x',
        'y',
        'z',
        'a',
        'b',
        'c',
        'd',
        'E',
        'reg_prop',
        'm_s',
        'm_u',
        # s_0 ... s_45
        *[f's_{i}' for i in range(46)],
        ]

In [None]:
# Maps sympy function classes to the token name used for them in the prefix
# notation.
operators = {
    # Elementary functions
    sp.Add: 'add',
    sp.Mul: 'mul',
    sp.Pow: 'pow',
    sp.exp: 'exp',
    sp.log: 'ln',
    sp.Abs: 'abs',
    sp.sign: 'sign',
    # Trigonometric Functions
    sp.sin: 'sin',
    sp.cos: 'cos',
    sp.tan: 'tan',
    sp.cot: 'cot',
    sp.sec: 'sec',
    sp.csc: 'csc',
    # Trigonometric Inverses
    sp.asin: 'asin',
    sp.acos: 'acos',
    sp.atan: 'atan',
    sp.acot: 'acot',
    sp.asec: 'asec',
    sp.acsc: 'acsc',
    # Hyperbolic Functions
    sp.sinh: 'sinh',
    sp.cosh: 'cosh',
    sp.tanh: 'tanh',
    sp.coth: 'coth',
    sp.sech: 'sech',
    sp.csch: 'csch',
    # Hyperbolic Inverses
    sp.asinh: 'asinh',
    sp.acosh: 'acosh',
    sp.atanh: 'atanh',
    sp.acoth: 'acoth',
    sp.asech: 'asech',
    sp.acsch: 'acsch',
    # Derivative
    sp.Derivative: 'derivative',
}

# Reverse lookup: token name -> sympy class. The extra "mul(" and "add("
# tokens are aliases for the same Mul / Add classes.
operators_inv = {operators[key]: key for key in operators}
operators_inv["mul("] = sp.Mul
operators_inv["add("] = sp.Add

# Number of arguments each operator token takes in prefix notation.
# -1 is used for the "mul(" / "add(" tokens.
operators_nargs = {
    # "(" tokens
    **dict.fromkeys(('mul(', 'add('), -1),
    # Binary elementary functions
    **dict.fromkeys(('add', 'sub', 'mul', 'div', 'pow', 'rac'), 2),
    # Unary elementary functions
    **dict.fromkeys(
        ('inv', 'pow2', 'pow3', 'pow4', 'pow5', 'sqrt', 'exp', 'ln', 'abs', 'sign'), 1),
    # Trigonometric functions
    **dict.fromkeys(('sin', 'cos', 'tan', 'cot', 'sec', 'csc'), 1),
    # Trigonometric inverses
    **dict.fromkeys(('asin', 'acos', 'atan', 'acot', 'asec', 'acsc'), 1),
    # Hyperbolic functions
    **dict.fromkeys(('sinh', 'cosh', 'tanh', 'coth', 'sech', 'csch'), 1),
    # Hyperbolic inverses
    **dict.fromkeys(('asinh', 'acosh', 'atanh', 'acoth', 'asech', 'acsch'), 1),
    # Derivative
    'derivative': 2,
    # Custom functions
    'f': 1,
    'g': 2,
    'h': 3,
}

# Names of the mass symbols.
# NOTE(review): the suffixes look like particle labels (electron, quarks) — confirm.
masses_strings = [
        "m_e",
        "m_u",
        "m_d",
        "m_s",
        "m_c",
        "m_b",
        "m_t",
        ]

# One sympy Symbol per mass name, in the same order as `masses_strings`.
masses = [sp.Symbol(x) for x in masses_strings]

# these will be converted to the numbers format in `format_number`
integers_types = [
        sp.core.numbers.Integer,
        sp.core.numbers.One,
        sp.core.numbers.NegativeOne,
        sp.core.numbers.Zero,
        ]

# All numeric sympy types handled by `format_number`.
# NOTE(review): the string "<class 'sympy.core.numbers.Pi'>" is not a type, so
# a `type(x) in numbers_types` test can never match it — looks like a leftover.
numbers_types = integers_types + [sp.core.numbers.Rational,
        sp.core.numbers.Half, sp.core.numbers.Exp1, sp.core.numbers.Pi, "<class 'sympy.core.numbers.Pi'>",
        sp.core.numbers.ImaginaryUnit]

# don't continue evaluating at these, but stop
# (Exp1 and Pi are listed here and again through `numbers_types`.)
atoms = [
        str,
        sp.core.symbol.Symbol,
        sp.core.numbers.Exp1,
        sp.core.numbers.Pi,
        "<class 'sympy.core.numbers.Pi'>",
        ] + numbers_types


# Rewrites the long "arc..." spelling of inverse (hyperbolic) trigonometric
# functions to the short "a..." spelling.
Inverse_trig = dict([
    # Inverse trigonometric functions
    ('arcsin', 'asin'),
    ('arccos', 'acos'),
    ('arctan', 'atan'),
    ('arccot', 'acot'),
    ('arcsec', 'asec'),
    ('arccsc', 'acsc'),
    # Inverse hyperbolic functions
    ('arcsinh', 'asinh'),
    ('arccosh', 'acosh'),
    ('arctanh', 'atanh'),
    ('arccoth', 'acoth'),
    ('arcsech', 'asech'),
    ('arccsch', 'acsch'),
])

In [None]:
def sympy_expression(formula):
    """Parse a formula string into a sympy expression.

    Every "arc..." function name listed in `Inverse_trig` is first rewritten
    to its short form. The names in `variables` are passed to sympy as plain
    symbols through the ``locals`` argument of ``sympify``.

    Args:
        formula: The formula as a string.

    Returns:
        The result of ``sp.sympify`` on the rewritten formula.
    """
    for long_name, short_name in Inverse_trig.items():
        formula = re.sub(long_name, short_name, formula)

    symbol_table = {name: sp.Symbol(name) for name in variables}
    return sp.sympify(formula, locals=symbol_table)

In [None]:
def flatten(l, ltypes=(list, tuple)):
    """Flatten arbitrarily nested containers into a single level.

    Args:
        l: The (possibly nested) container to flatten.
        ltypes: Container types that are expanded; anything else is kept as
            a leaf. Empty nested containers contribute nothing.

    Returns:
        A new container of the same type as ``l`` holding the leaves in
        left-to-right order.
    """
    container_type = type(l)
    leaves = []
    # Stack of iterators: the top one is the container currently being walked.
    pending = [iter(list(l))]
    while pending:
        for item in pending[-1]:
            if isinstance(item, ltypes):
                # Descend; the outer iterator keeps its position for later.
                pending.append(iter(item))
                break
            leaves.append(item)
        else:
            pending.pop()
    return container_type(leaves)

In [None]:
def sympy_to_prefix(expression):
    """Convert a sympy expression into a flat list of prefix-notation tokens.

    The nested token list built by `sympy_to_prefix_rec` is flattened before
    it is returned.
    """
    nested_tokens = sympy_to_prefix_rec(expression, [])
    return flatten(nested_tokens)

def sympy_to_prefix_rec(expression, ret):
    """
    Recursively go from a sympy expression to a prefix notation.
    The operators all get converted to their names in the array `operators`.
    Returns a nested list, where the nesting basically stands for parentheses.
    Since in prefix notation with a fixed number of arguments for each function (given in `operators_nargs`),
    parentheses are not needed, we can flatten the list later.

    Args:
        expression: The sympy expression to convert.
        ret: Tokens collected so far; the result starts with them.

    Returns:
        ``ret`` extended with the (possibly nested) tokens of ``expression``.

    Raises:
        KeyError: If the expression's function is neither an atom nor a key
            of `operators`.
    """
    if expression in [sp.core.numbers.Pi, sp.core.numbers.ImaginaryUnit]:
        f = expression
    else:
        f = expression.func
    if f in atoms:
        # Leaves: numbers get their own token format, everything else is
        # emitted as its string representation.
        if type(expression) in numbers_types:
            return ret + format_number(expression)
        return ret + [str(expression)]
    f_str = operators[f]
    f_nargs = operators_nargs[f_str]
    args = expression.args
    # Bug fix: the original `len(args) == 1 & f_nargs == 1` used bitwise `&`,
    # which binds tighter than `==` and so parsed as the chained comparison
    # `len(args) == (1 & f_nargs) == 1`. That wrongly accepts any odd arity.
    if len(args) == 1 and f_nargs == 1:
        ret = ret + [f_str]
        return sympy_to_prefix_rec(args[0], ret)
    if len(args) == 2:
        ret = ret + [f_str, sympy_to_prefix_rec(args[0], []), sympy_to_prefix_rec(args[1], [])]
    if len(args) > 2:
        # Variadic sympy nodes (e.g. Add, Mul) are folded into nested binary
        # applications.
        args = list(map(lambda x: sympy_to_prefix_rec(x, []), args))
        ret = ret + repeat_operator_until_correct_binary(f_str, args)
    return ret
def repeat_operator_until_correct_binary(op, args, ret=None):
    """
    sympy is not strict enough with the number of arguments.
    E.g. multiply takes a variable number of arguments, but for
    prefix notation it needs to ALWAYS have exactly 2 arguments

    This function is only for binary operators.

    The convention is as follows:
        1 + 2 + 3 --> + 1 + 2 3

    This is the same convention as in https://arxiv.org/pdf/1912.01412.pdf
    on page 15.

    input:
        op: in string form as in the list `operators`
        args: [arg1, arg2, ...] arguments of the operator, e.g. [1, 2, x**2,
                ...]. They can have other things to be evaluated in them
        ret: tokens already built for the right-most operands. Defaults to a
            fresh empty list. (The previous mutable default ``[]`` was
            handed back to the caller when ``args`` was empty, so every such
            call shared one list object.)

    output:
        the list of tokens with `op` repeated so that every application has
        exactly two arguments.

    raises:
        AssertionError: if `op` is not a binary operator in `operators_nargs`.
    """

    is_binary = operators_nargs[op] == 2
    assert is_binary, "repeat_operator_until_correct_binary only takes binary operators"

    result = [] if ret is None else ret
    remaining = list(args)

    # Consume the arguments from the right: the first step takes the last two
    # arguments, every later step prepends `op` and one more argument.
    while remaining:
        if len(result) == 0:
            result = [op] + remaining[-2:]
            remaining = remaining[:-2]
        else:
            result = [op] + remaining[-1:] + result
            remaining = remaining[:-1]

    return result

def format_number(number):
    """Return the prefix tokens for a sympy number.

    Dispatches on the exact type of ``number``: integer types, Rational,
    Half, Exp1, Pi and ImaginaryUnit each have their own formatter.

    Raises:
        NotImplementedError: If the type of ``number`` is not handled.
    """
    number_type = type(number)
    if number_type in integers_types:
        return format_integer(number)
    if number_type == sp.core.numbers.Rational:
        return format_rational(number)
    if number_type == sp.core.numbers.Half:
        return format_half()
    if number_type == sp.core.numbers.Exp1:
        return format_exp1()
    if number_type == sp.core.numbers.Pi:
        return format_pi()
    if number_type == sp.core.numbers.ImaginaryUnit:
        return format_imaginary_unit()
    raise NotImplementedError

def format_exp1():
    """Return the single-token list used for sympy's ``Exp1`` constant."""
    token = 'E'
    return [token]

def format_pi():
    """Return the single-token list used for sympy's ``Pi`` constant."""
    token = 'pi'
    return [token]

def format_imaginary_unit():
    """Return the single-token list used for sympy's ``ImaginaryUnit``."""
    token = 'I'
    return [token]

def format_half():
    """
    Tokens for one half, written as ``1 * 2**(-1)``.

    sympy represents 1/2 with its own ``Half`` type, so it gets a dedicated
    formatter that produces the same layout as `format_rational`.
    """
    one = ['s+', '1']
    two = ['s+', '2']
    minus_one = ["s-", "1"]
    return ['mul', *one, 'pow', *two, *minus_one]

def format_rational(number):
    """Tokens for a rational ``p/q``, written as ``p * q**(-1)``.

    ``number.p`` and ``number.q`` are passed through ``sp.sympify`` before
    being formatted as integers.
    """
    numerator = sp.sympify(number.p)
    denominator = sp.sympify(number.q)

    tokens = ['mul']
    tokens.extend(format_integer(numerator))
    tokens.append('pow')
    tokens.extend(format_integer(denominator))
    tokens.extend(['s-', '1'])
    return tokens

def format_integer(integer):
    """Format a sympy integer as a sign token followed by its decimal digits,
    as in https://arxiv.org/pdf/1912.01412.pdf

    input:
        integer: a `sympy.Integer` object, e.g. `sympy.Integer(-123)`

    output:
        [sign_token, digit0, digit1, ...]
        where sign_token is 's+' or 's-'

    Example:
        format_integer(sympy.Integer(-123))
        >> ['s-', '1', '2', '3']

    Implementation notes:
    The magnitude is read from ``integer.p`` and the sign from
    ``integer.could_extract_minus_sign()``.
    """
    magnitude = abs(integer.p)
    sign_token = "s-" if integer.could_extract_minus_sign() else "s+"
    return [sign_token, *str(magnitude)]

In [None]:
def parse_if_str(x):
    """Parse ``x`` with sympy when it is a string; return anything else as is."""
    if not isinstance(x, str):
        return x
    return sp.parsing.parse_expr(x)

In [None]:
def rightmost_string_pos(expr_arr, pos=-1):
    """Return the index of the right-most string in ``expr_arr``.

    The search starts at offset ``pos`` (counted from the end) and moves
    left. Indexing raises ``IndexError`` when no string is found.
    """
    offset = pos
    while not isinstance(expr_arr[offset], str):
        offset -= 1
    return len(expr_arr) + offset


def rightmost_operand_pos(expr, pos=-1):
    """Return the index of the right-most known token in ``expr``.

    Known tokens are the keys of `operators_inv`, the sign tokens ``"s+"`` /
    ``"s-"`` and the names in `variables`. The search starts at offset
    ``pos`` (counted from the end) and moves left.
    """
    known_tokens = list(operators_inv.keys()) + ["s+", "s-"] + variables
    offset = pos
    while expr[offset] not in known_tokens:
        offset -= 1
    return len(expr) + offset

def unformat_integer(arr):
    """
    Inverse of the function format_integer.

    input:
        arr: array of strings just as the output of format_integer. E.g. ["s+", "4", "2"]

    output:
        the corresponding sympy integer, e.g. sympy.Integer(42) in the above example.

    The sign tokens are "s+" for positive integers and "s-" for negative. 0 comes with "s+", but does not matter.
    """
    sign_prefix = "-" if arr[0] == "s-" else ""
    digits = "".join(str(token) for token in arr[1:])
    return sp.parsing.parse_expr(sign_prefix + digits)

In [None]:
def prefix_to_sympy(expr_arr):
    """Rebuild a sympy expression from a list of prefix-notation tokens.

    The list is reduced from the right: the right-most operator, sign token
    or variable is located, replaced (together with its arguments or digits)
    by the sympy object it stands for, and the shortened list is processed
    again until a single element remains.

    Args:
        expr_arr: List of tokens; entries already converted to sympy objects
            are allowed.

    Returns:
        The sympy expression the tokens describe.
    """
    if len(expr_arr) == 1:
        return parse_if_str(expr_arr[0])

    token_pos = rightmost_operand_pos(expr_arr)
    if token_pos == -1 or token_pos == len(expr_arr):
        print("something went wrong, operator should not be at end of array")

    token = expr_arr[token_pos]
    head = expr_arr[0:token_pos]

    if token in operators_inv.keys():
        # Apply the operator to the arguments that directly follow it.
        arity = operators_nargs[token]
        args_end = token_pos + arity + 1
        operands = [parse_if_str(a) for a in expr_arr[token_pos+1:args_end]]
        applied = operators_inv[token](*operands)
        return prefix_to_sympy(head + [applied] + expr_arr[args_end:])

    if token == 's+' or token == "s-":
        # Collapse the sign token and the digits up to the right-most string.
        digits_end = rightmost_string_pos(expr_arr)
        number = unformat_integer(expr_arr[token_pos:digits_end+1])
        return prefix_to_sympy(head + [number] + expr_arr[digits_end+1:])

    if token in variables:
        symbol = sp.sympify(token)
        return prefix_to_sympy(head + [symbol] + expr_arr[token_pos+1:])

    return token

In [None]:
class DecoderTokenizer(Tokenizer):
    """Tokenizer for equations: sympy expression <-> prefix tokens <-> ids."""

    def __init__(self, vocab_path):
        """Load the decoder vocabulary from ``vocab_path``."""
        super().__init__(vocab_path)

    def equation_encoder(self, data):
        """Convert every sympy expression in ``data`` to a prefix token list."""
        prefix_lists = []
        for expr in data:
            prefix_lists.append(sympy_to_prefix(expr))
        return prefix_lists

    def equation_decoder(self, data):
        """Convert every prefix token list in ``data`` back to sympy."""
        expressions = []
        for lst in data:
            expressions.append(prefix_to_sympy(lst))
        return expressions

    def pre_tokenize(self, data):
        """Identity hook; the data is returned unchanged."""
        return data

    def tokenize(self, data):
        """Expressions -> prefix tokens wrapped in <bos>/<eos> -> id array."""
        prefix_lists = self.equation_encoder(self.pre_tokenize(data))
        framed = []
        for tokens in prefix_lists:
            framed.append(['<bos>'] + tokens + ['<eos>'])
        return self.encode(framed)

    def reverse_tokenize(self, data):
        """Id sequences -> token strings -> sympy expressions."""
        token_lists = self.decode(data)
        return self.equation_decoder(token_lists)

In [None]:
# Directory prefix for the per-equation data files; a file name is appended
# directly to it, so the trailing slash is required.
INPUT_DIR = '/kaggle/input/gsoc-symba-task/Feynman_with_units/Feynman_with_units/'

In [None]:
# Token id used as the padding value when batching and as the "padded"
# marker when building the target padding mask.
PAD_IDX = 0

def prepare_dataset(config):
    """Tokenize the raw data files and write the dataset arrays to disk.

    For every row of the CSV at ``config.df_path`` the matching data file
    under `INPUT_DIR` is tokenized, cut into chunks of
    ``config.input_max_len`` rows (a trailing partial chunk is dropped) and
    saved as ``<output_dir>/<Filename>/<chunk>.npy``. The prefix token lists
    of all equations are encoded, zero-padded to 256 ids and saved together
    as ``<output_dir>/prefix_equations.npy``.

    Args:
        config: Object providing ``input_max_len``, ``df_path``,
            ``encoder_vocab``, ``decoder_vocab`` and ``output_dir``.

    Returns:
        ``(train_df, equations_df)``: one row per saved chunk (columns
        ``filename``, ``data_num``, ``number``) and one row per equation
        (columns ``filename``, ``Prefix_lists``, ``encoded``).
    """
    input_max_len = config.input_max_len
    df = pd.read_csv(config.df_path)

    encoder_tokenizer = Encoder_tokeniser(2,1,100,config.encoder_vocab)
    decoder_tokenizer = DecoderTokenizer(config.decoder_vocab)

    train_df = {
        "filename":[],
        "data_num":[],
        "number":[]
        }

    # `total` lets tqdm show a real progress bar; iterrows() has no length.
    for (_, row) in tqdm(df.iterrows(), total=len(df)):
        with open(INPUT_DIR + row['Filename']) as file:
            data = file.readlines()
        X = encoder_tokenizer.tokenize(data)

        # Keep only whole chunks of `input_max_len` rows.
        # NOTE(review): a file with fewer than `input_max_len` rows gives
        # n_splits == 0 and makes np.split fail — assumed not to occur.
        n_splits = X.shape[0] // input_max_len
        X = X[:n_splits*input_max_len]
        x_chunks = np.split(X, n_splits)

        sub_dir = os.path.join(config.output_dir, row["Filename"])
        os.makedirs(sub_dir, exist_ok=True)

        # Distinct loop variable: the original reused `index` here and
        # shadowed the outer loop variable.
        for (chunk_idx, x) in enumerate(x_chunks):
            np.save(os.path.join(sub_dir, f"{chunk_idx}.npy"), x)

        train_df["filename"].extend([row["Filename"]]*n_splits)
        train_df["data_num"].extend(range(n_splits))
        train_df["number"].extend([row["Number"]]*n_splits)

    train_df = pd.DataFrame(train_df)

    equations_df = {
        "filename":[],
        "Prefix_lists":[],
        "encoded":[]
        }

    # Row i holds the padded ids of equation number i + 1.
    prefix_equations = np.zeros((100, 256), dtype=np.int32)
    for (_, row) in df.iterrows():
        equations_df["filename"].append(row["Filename"])
        # SECURITY(review): eval() executes whatever the CSV cell contains.
        # If the column only ever holds list literals, ast.literal_eval is
        # the safe replacement — left unchanged pending confirmation.
        prefix = eval(row["Prefix_lists"])
        prefix = ["<bos>"] + prefix + ["<eos>"]
        equations_df["Prefix_lists"].append(prefix)
        y = decoder_tokenizer.encode([prefix])[0]
        y = np.pad(y, (0, 256 - len(y)))
        prefix_equations[int(row["Number"])-1, :] = y
        equations_df["encoded"].append(y)

    path = os.path.join(config.output_dir, "prefix_equations.npy")
    np.save(path, prefix_equations)
    equations_df = pd.DataFrame(equations_df)

    return train_df, equations_df

class FeynmanDataset(Dataset):
    """Dataset of (tokenized data chunk, tokenized equation) pairs.

    Args:
        df: DataFrame with one row per chunk and the columns ``Filename``,
            ``data_num`` and ``number``.
        dataset_dir: Directory holding ``prefix_equations.npy`` and one
            sub-directory per ``Filename`` with ``<data_num>.npy`` files.
    """

    def __init__(self, df, dataset_dir):
        super().__init__()
        self.df = df
        self.dataset_dir = dataset_dir
        padded_equations = np.load(os.path.join(dataset_dir, "prefix_equations.npy"))

        # Strip the zero padding from both ends of every equation.
        self.prefix_equations = [np.trim_zeros(prefix) for prefix in padded_equations]

    def __len__(self):
        """Number of chunks."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(x, y)`` as int64 tensors for the chunk at ``idx``."""
        row = self.df.iloc[idx]
        path = os.path.join(self.dataset_dir, row['Filename'], f"{row['data_num']}.npy")
        x = np.load(path).astype(np.int32)

        # Equation numbers are 1-based.
        y = self.prefix_equations[int(row['number']) - 1]

        # Build integer tensors directly. The original went through
        # torch.Tensor(...), i.e. float32, before casting to long, and also
        # computed a second `path` that was never used.
        return (torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long))

In [None]:

def get_datasets(df, input_df, dataset_dir, random_state=42):
    """Split the data into train / validation / test datasets.

    Equations are first split 90/10 into train and test. Chunks are assigned
    to the side their equation (``Filename``) landed on, so test equations
    are never seen in training. The training chunks are then split 90/10
    into train and validation.

    Args:
        df: DataFrame with one row per equation and a ``Filename`` column.
        input_df: DataFrame with one row per data chunk and a ``Filename``
            column.
        dataset_dir: Directory passed on to `FeynmanDataset`.
        random_state: Seed for both splits. The default keeps the original
            equation split (seed 42); the train/validation split used to be
            unseeded and therefore changed on every run.

    Returns:
        Dict with the keys ``"train"``, ``"test"`` and ``"valid"``.
    """
    train_df, test_df = train_test_split(df, test_size=0.1, random_state=random_state)
    train_equations = train_df['Filename'].tolist()
    test_equations = test_df['Filename'].tolist()

    input_test_df = input_df[input_df['Filename'].isin(test_equations)]
    input_train_df = input_df[input_df['Filename'].isin(train_equations)]

    input_train_df, input_val_df = train_test_split(
        input_train_df, test_size=0.1, shuffle=True, random_state=random_state)

    train_dataset = FeynmanDataset(input_train_df, dataset_dir)
    val_dataset = FeynmanDataset(input_val_df, dataset_dir)
    test_dataset = FeynmanDataset(input_test_df, dataset_dir)

    datasets = {
        "train":train_dataset,
        "test":test_dataset,
        "valid":val_dataset
        }

    return datasets

def get_dataloaders(datasets, train_bs, test_bs):
    """Wrap the train / validation / test datasets in shuffling DataLoaders.

    Args:
        datasets: Dict with the keys ``"train"``, ``"valid"`` and ``"test"``.
        train_bs: Batch size of the training loader.
        test_bs: Batch size of the validation and test loaders.

    Returns:
        Dict with the keys ``"train"``, ``"test"`` and ``"valid"``.
    """
    def _make_loader(split, batch_size):
        # All three loaders share every option except dataset and batch size.
        return DataLoader(datasets[split], batch_size=batch_size,
                          shuffle=True, num_workers=2, pin_memory=True,
                          collate_fn=collate_fn)

    train_loader = _make_loader('train', train_bs)
    valid_loader = _make_loader('valid', test_bs)
    test_loader = _make_loader('test', test_bs)

    return {
        "train": train_loader,
        "test": test_loader,
        "valid": valid_loader,
    }

def collate_fn(batch):
    """Collate ``(src, tgt)`` samples into two padded, batch-first tensors.

    Shorter sequences are padded with `PAD_IDX` along their first dimension.
    """
    sources = [src_sample for (src_sample, _) in batch]
    targets = [tgt_sample for (_, tgt_sample) in batch]

    padded_sources = pad_sequence(sources, padding_value=PAD_IDX, batch_first=True)
    padded_targets = pad_sequence(targets, padding_value=PAD_IDX, batch_first=True)
    return padded_sources, padded_targets

In [None]:
class config:
    """Settings for dataset preparation: sizes, input files and output folder."""

    def __init__(self):
        # Sizes
        self.max_len = 11
        self.input_max_len = 1000

        # Input files
        self.df_path = '/kaggle/input/gsoc-symba-task/FeynmanEquationsModified.csv'
        self.encoder_vocab = '/kaggle/input/gsoc-symba-task/encoder_vocab (1).txt'
        self.decoder_vocab = '/kaggle/input/gsoc-symba-task/decoder_vocab (2).txt'

        # Output folder for the generated arrays
        self.output_dir = '/kaggle/working/dataset_arrays'

In [None]:
# Load the per-chunk index table from the Kaggle input dataset.
# NOTE(review): presumably the `train_df` written out by `prepare_dataset` — confirm.
train_df = pd.read_csv('/kaggle/input/gsoc-dataset-arrays/train_df.csv')

In [None]:
# Rename 'filename' to 'Filename': get_datasets and FeynmanDataset read the
# column under the capitalised name.
train_df.rename(columns = {'filename':'Filename'}, inplace = True)
train_df

In [None]:
# Build the train / validation / test datasets from the equation table and
# the per-chunk index table.
datasets = get_datasets(df_target,train_df,'/kaggle/input/gsoc-dataset-arrays/dataset_arrays/')

In [None]:
# DataLoaders with batch size 64 for training and 64 for validation / test.
dataloaders = get_dataloaders(datasets,64,64)

In [None]:
class TokenEmbedding(nn.Module):
    '''Embeds token indices and scales the result by sqrt(emb_size).'''

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        '''Return the scaled embeddings of ``tokens`` (cast to int64 first).'''
        scale = math.sqrt(self.emb_size)
        embedded = self.embedding(tokens.long())
        return embedded * scale

class PositionalEncoding(nn.Module):
    '''Adds sinusoidal positional encodings to token embeddings, then applies dropout.

    Args:
        emb_size: Embedding dimension.
        dropout: Dropout probability applied to the sum.
        maxlen: Longest sequence length that can be encoded.

    The table is stored in the buffer ``pos_embedding_1`` (name kept so that
    existing state dicts still load) and therefore follows the module across
    devices. ``pos_embedding`` stays available as a read-only alias.
    '''

    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        # Leading batch dimension so the table broadcasts over the batch.
        pos_embedding = pos_embedding.unsqueeze(0)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding_1', pos_embedding)

    @property
    def pos_embedding(self) -> Tensor:
        '''The positional table, shape (1, maxlen, emb_size).'''
        return self.pos_embedding_1

    def forward(self, token_embedding: Tensor):
        '''Return ``dropout(token_embedding + positions)``.

        Bug fix: the original moved both tensors to the hard-coded device
        'cuda:0', which fails without CUDA and ignores the device the module
        actually lives on. The table now follows the input's device.
        '''
        positions = self.pos_embedding_1[:, :token_embedding.size(1), :]
        return self.dropout(token_embedding + positions.to(token_embedding.device))

    
class LinearPointEmbedder(nn.Module):
    """Embeds the tokens of one data point and projects them to ``emb_size``.

    The token embeddings of a point are concatenated and passed through a
    two-layer feed-forward network with ReLU and dropout in between.
    """

    def __init__(self, vocab_size: int, input_emb_size, emb_size, max_input_points,dropout =0.2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, input_emb_size)
        self.emb_size = emb_size
        # Width of the concatenated token embeddings fed to the first layer.
        self.input_size = max_input_points*input_emb_size
        self.fc1 = nn.Linear(self.input_size, emb_size)
        self.fc2 = nn.Linear(emb_size, emb_size)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, tokens):
        """Return one ``emb_size`` vector per entry along the first two dims."""
        scale = math.sqrt(self.emb_size)
        embedded = self.embedding(tokens.long()) * scale
        batch_size, n_points = embedded.shape[0], embedded.shape[1]
        # Concatenate everything after the first two dimensions.
        flat = embedded.view(batch_size, n_points, -1)
        hidden = self.dropout(self.activation(self.fc1(flat)))
        return self.fc2(hidden)
    

class Model_seq2seq(nn.Module):
    '''Transformer encoder-decoder from tokenized data points to equation tokens.

    The source is embedded with `LinearPointEmbedder` (no positional
    encoding); the target is embedded with `TokenEmbedding` plus
    `PositionalEncoding`. A linear layer maps the decoder output to
    target-vocabulary logits.
    '''

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 input_emb_size: int,
                 max_input_points: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1,):
        super().__init__()
        self.transformer = Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = LinearPointEmbedder(src_vocab_size, input_emb_size, emb_size, max_input_points)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        '''Return target-vocabulary logits for every target position.'''
        src_emb = self.src_tok_emb(src)
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))

        decoded = self.transformer(
            src_emb,
            tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.generator(decoded)

    def encode(self, src: Tensor, src_mask: Tensor):
        '''Run only the encoder on ``src``.'''
        src_emb = self.src_tok_emb(src)
        return self.transformer.encoder(src_emb, src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        '''Run only the decoder on ``tgt`` against the encoder ``memory``.'''
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        return self.transformer.decoder(tgt_emb, memory, tgt_mask)

In [None]:
class AverageMeter:
    """
    Tracks the latest value, the weighted sum, the count and the running average.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def seed_everything(seed: int):
    """Seed Python, NumPy and PyTorch and ask cuDNN for deterministic behaviour.

    Args:
        seed: Seed used for ``random``, ``PYTHONHASHSEED``, NumPy and
            PyTorch (CPU and all CUDA devices).
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Bug fix: benchmark mode lets cuDNN auto-tune and pick different
    # algorithms between runs, which works against the determinism requested
    # on the line above. It was set to True.
    torch.backends.cudnn.benchmark = False

def generate_square_subsequent_mask(sz, device):
    """Return the ``(sz, sz)`` causal attention mask as a float tensor.

    Entries on and below the diagonal are 0.0, entries above it are ``-inf``.
    """
    blocked = torch.full((sz, sz), float('-inf'), device=device)
    # Keep -inf strictly above the diagonal; triu zeroes everything else.
    return torch.triu(blocked, diagonal=1)

def create_mask(src, tgt, device):
    """Build the attention and padding masks for one batch.

    Returns:
        ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)``. Both
        source masks are all-False boolean tensors, ``tgt_mask`` is the
        causal mask and ``tgt_padding_mask`` is True where ``tgt`` equals
        `PAD_IDX`.
    """
    batch_size, src_len = src.shape[0], src.shape[1]
    tgt_len = tgt.shape[1]

    tgt_mask = generate_square_subsequent_mask(tgt_len, device)
    src_mask = torch.zeros((src_len, src_len), dtype=torch.bool, device=device)

    src_padding_mask = torch.zeros((batch_size, src_len), dtype=torch.bool, device=device)
    tgt_padding_mask = tgt == PAD_IDX
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

def sequence_accuracy(y_pred, y_true):
    """Fraction of predictions whose token list equals the reference exactly.

    Pairs are compared position by position after ``.tolist()``; the
    denominator is ``len(y_pred)``.
    """
    total = len(y_pred)
    exact_matches = sum(
        1
        for predicted, reference in zip(y_pred, y_true)
        if reference.tolist() == predicted.tolist()
    )
    return exact_matches / total

In [None]:
class Trainer:
    """
    Trainer class for training and evaluating a PyTorch model.
    """
    def __init__(self, config, dataloaders):
        """
        Initialize Trainer object.

        Args:
        - config: Configuration object containing training parameters
        - dataloaders: Dictionary containing data loaders, keyed by
          "train", "valid" and "test"
        """
        self.config = config
        self.device = torch.device(self.config.device)
        self.dataloaders = dataloaders

        seed_everything(self.config.seed)

        self.scaler = torch.cuda.amp.GradScaler()
        if self.config.use_half_precision:
            self.dtype = torch.float16
        else:
            self.dtype = torch.float32

        # Initialize model, optimizer, scheduler, and criterion
        self.model = self.get_model()
        self.model.to(self.device)
        self.optimizer = self.get_optimizer()
        self.scheduler = self.get_scheduler()
        self.criterion = self.get_criterion()

        # Initialize training-related variables
        self.current_epoch = 0
        self.best_accuracy = -1
        self.best_val_loss = 1e6
        self.train_loss_list = []
        self.valid_loss_list = []
        self.valid_accuracy_tok_list = []

        # Create directory for saving logs
        self.logs_dir = os.path.join(self.config.root_dir, self.config.experiment_name)
        os.makedirs(self.logs_dir, exist_ok=True)

    def get_model(self):
        """
        Initialize and return the model based on the configuration.
        """
        model = Model_seq2seq(num_encoder_layers=self.config.num_encoder_layers,
                          num_decoder_layers=self.config.num_decoder_layers,
                          emb_size=self.config.embedding_size,
                          nhead=self.config.nhead,
                          src_vocab_size=self.config.src_vocab_size,
                          tgt_vocab_size=self.config.tgt_vocab_size,
                          input_emb_size=self.config.input_emb_size,
                          max_input_points=self.config.max_input_points,
                          )

        return model

    def get_optimizer(self):
        """
        Initialize and return the optimizer based on the configuration.

        Raises:
        - NotImplementedError: if `config.optimizer_type` is not one of
          "sgd", "adam" or "adamw".
        """
        optimizer_parameters = self.model.parameters()

        if self.config.optimizer_type == "sgd":
            optimizer = torch.optim.SGD(optimizer_parameters, lr=self.config.optimizer_lr, momentum=self.config.optimizer_momentum,)
        elif self.config.optimizer_type == "adam":
            optimizer = torch.optim.Adam(optimizer_parameters, lr=self.config.optimizer_lr, eps=1e-8, weight_decay=self.config.optimizer_weight_decay)
        elif self.config.optimizer_type == "adamw":
            optimizer = torch.optim.AdamW(optimizer_parameters, lr=self.config.optimizer_lr, eps=1e-8, weight_decay=self.config.optimizer_weight_decay)
        else:
            raise NotImplementedError

        return optimizer

    def get_scheduler(self):
        """
        Initialize and return the learning rate scheduler based on the configuration.

        Returns None when `config.scheduler_type` is "none".

        Raises:
        - NotImplementedError: for any unrecognised `config.scheduler_type`.
        """
        if self.config.scheduler_type == "multi_step":
            scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.config.scheduler_milestones, gamma=self.config.scheduler_gamma)
        elif self.config.scheduler_type == "reduce_lr_on_plateau":
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', patience=2)
        elif self.config.scheduler_type == "cosine_annealing_warm_restart":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, self.config.T_0, self.config.T_mult)
        elif self.config.scheduler_type == "none":
            scheduler = None
        else:
            raise NotImplementedError

        return scheduler

    def get_criterion(self):
        """
        Initialize and return the loss function based on the configuration.

        Raises:
        - NotImplementedError: if `config.criterion` is not "cross_entropy".
        """
        if self.config.criterion == "cross_entropy":
            criterion = torch.nn.CrossEntropyLoss()
        else:
            raise NotImplementedError

        return criterion

    def train_one_epoch(self):
        """
        Train the model for one epoch.

        Returns:
        - Batch-size-weighted average training loss over the epoch.
        """
        self.model.train()
        pbar = tqdm(self.dataloaders['train'], total=len(self.dataloaders['train']))
        pbar.set_description(f"[{self.current_epoch+1}/{self.config.epochs}] Train")
        running_loss = AverageMeter()
        for src, tgt in pbar:
            src = src.to(self.device)
            tgt = tgt.to(self.device)

            bs = src.size(0)

            with torch.autocast(device_type='cuda', dtype=self.dtype):
                # Teacher forcing: feed tgt[:, :-1], predict tgt[:, 1:].
                src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt[:, :-1], self.device)
                logits = self.model(src, tgt[:, :-1], src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
                loss = self.criterion(logits.reshape(-1, logits.shape[-1]), tgt[:, 1:].reshape(-1))

            running_loss.update(loss.item(), bs)
            pbar.set_postfix(loss=running_loss.avg)

            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()

            if self.config.clip_grad_norm > 0:
                # Gradients are still multiplied by the loss scale here; they
                # must be unscaled first so the clip threshold applies to the
                # true gradient norm.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip_grad_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()

        return running_loss.avg

    def evaluate(self, phase):
        """
        Evaluate the model on validation or test data.

        Args:
        - phase: Phase of evaluation, either "valid" or "test".

        Returns:
        - Tuple containing average token accuracy and average loss.
          Token accuracy is taken over every position of tgt[:, 1:].
        """
        self.model.eval()

        pbar = tqdm(self.dataloaders[phase], total=len(self.dataloaders[phase]))
        pbar.set_description(f"[{self.current_epoch+1}/{self.config.epochs}] {phase.capitalize()}")
        running_loss = AverageMeter()
        running_acc_tok = AverageMeter()

        for src, tgt in pbar:
            src = src.to(self.device)
            tgt = tgt.to(self.device)
            bs = src.size(0)

            with torch.no_grad(), torch.autocast(device_type='cuda', dtype=self.dtype):
                if self.config.model_name == "seq2seq_transformer":
                    src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt[:, :-1], self.device)
                    logits = self.model(src, tgt[:, :-1], src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
                else:
                    logits = self.model(src, tgt[:, :-1])
                loss = self.criterion(logits.reshape(-1, logits.shape[-1]), tgt[:, 1:].reshape(-1))

            y_pred = torch.argmax(logits.reshape(-1, logits.shape[-1]), 1)
            # Stored as a plain Python float so the history lists saved in
            # checkpoints do not carry NumPy scalar types.
            correct = float((y_pred == tgt[:, 1:].reshape(-1)).cpu().numpy().mean())

            running_loss.update(loss.item(), bs)
            running_acc_tok.update(correct, bs)

        return running_acc_tok.avg, running_loss.avg

    def _step_scheduler(self, valid_loss):
        """Advance the learning-rate scheduler once, at the end of an epoch."""
        if self.scheduler is None:
            return
        # Dispatch on the configured type: the scheduler attribute is an
        # object, so comparing it to a string would never match.
        if self.config.scheduler_type == "reduce_lr_on_plateau":
            self.scheduler.step(valid_loss)
        else:
            self.scheduler.step()

    def train(self):
        """
        Main training loop.

        For each epoch: train, validate, step the scheduler, save
        "last_checkpoint.pth", save "best_checkpoint.pth" when validation
        token accuracy improves, and write the CSV log.
        """
        start_epoch = self.current_epoch
        for self.current_epoch in range(start_epoch, self.config.epochs):
            training_loss = self.train_one_epoch()
            valid_accuracy_tok, valid_loss = self.evaluate("valid")

            self.train_loss_list.append(round(training_loss, 7))
            self.valid_loss_list.append(round(valid_loss, 7))
            self.valid_accuracy_tok_list.append(round(valid_accuracy_tok, 7))

            self._step_scheduler(valid_loss)

            if valid_loss < self.best_val_loss:
                self.best_val_loss = valid_loss

            self.save_model("last_checkpoint.pth")

            if valid_accuracy_tok > self.best_accuracy:
                print(f"==> Best Accuracy improved to {round(valid_accuracy_tok, 7)} from {self.best_accuracy}")
                self.best_accuracy = round(valid_accuracy_tok, 7)
                self.save_model("best_checkpoint.pth")

            self.log_results()

    def save_model(self, file_name):
        """
        Save model checkpoints.

        Writes the model and optimizer state plus the loss/accuracy history
        to `file_name` inside the logs directory.
        """
        state_dict = self.model.state_dict()
        torch.save({
                "epoch": self.current_epoch + 1,
                "state_dict": state_dict,
                'optimizer': self.optimizer.state_dict(),
                "train_loss_list": self.train_loss_list,
                "valid_loss_list": self.valid_loss_list,
                "valid_accuracy_tok_list": self.valid_accuracy_tok_list,
            }, os.path.join(self.logs_dir, file_name))

    def log_results(self):
        """
        Log training results to a CSV file.
        """
        data_list = [self.train_loss_list, self.valid_loss_list, self.valid_accuracy_tok_list]
        column_list = ['train_losses', 'valid_losses', 'token_valid_accuracy']

        df_data = np.array(data_list).T
        df = pd.DataFrame(df_data, columns=column_list)
        df.to_csv(os.path.join(self.logs_dir, "logs.csv"))

    def test_seq_acc(self):
        """
        Evaluate model's sequence accuracy on test data.

        Loads "best_checkpoint.pth", computes token accuracy on the test
        loader, then greedy-decodes the first example of every test batch
        and reports exact-match sequence accuracy. Scores are written to
        "score.txt" in the logs directory and printed.
        """
        file = os.path.join(self.logs_dir, "best_checkpoint.pth")
        state_dict = torch.load(file, map_location=self.device)['state_dict']
        self.model.load_state_dict(state_dict)

        test_accuracy_tok, _ = self.evaluate("test")

        predictor = Predictor(self.config)

        print("Calculating Sequence Accuracy for predictions (1 example per batch)")
        pbar = tqdm(self.dataloaders["test"], total=len(self.dataloaders["test"]))
        pbar.set_description("Test")

        y_preds = []
        y_true = []
        for src, tgt in pbar:
            src = src.to(self.device)
            tgt = tgt.numpy()
            y_pred = predictor.predict(src[0].unsqueeze(0)) #only one example from each batch
            y_preds.append(y_pred.cpu().numpy())
            # Strip leading/trailing zeros from the reference
            # (presumably padding -- confirm PAD index is 0).
            y_true.append(np.trim_zeros(tgt[0]))

        test_accuracy_seq = sequence_accuracy(y_preds, y_true)
        # Context manager guarantees the score file is closed even on error.
        with open(os.path.join(self.logs_dir, "score.txt"), "w+") as f:
            f.write(f"Token Accuracy = {(round(test_accuracy_tok, 7))}\n")
            f.write(f"Sequence Accuracy = {(round(test_accuracy_seq, 7))}\n")
        print(f"Test Accuracy: {round(test_accuracy_tok, 7)} | Valid Accuracy: {self.best_accuracy}") 
        print(f"Test Sequence Accuracy: {test_accuracy_seq}")

In [None]:
# Token ids that start and terminate greedy decoding.
BOS_IDX = 1
EOS_IDX = 58  #69

class Predictor:
    """
    Predictor class for generating predictions using a trained model.
    """
    def __init__(self, config):
        """
        Initialize Predictor object.

        Builds the model from `config` and loads the weights stored in
        "best_checkpoint.pth" under `<root_dir>/<experiment_name>`.

        Args:
        - config: Configuration object containing model parameters
        """
        self.config = config
        self.device = torch.device(self.config.device)

        # Get the model
        self.model = self.get_model()
        self.model.to(self.device)

        # Load the best checkpoint
        self.logs_dir = os.path.join(self.config.root_dir, self.config.experiment_name)
        path = os.path.join(self.logs_dir, "best_checkpoint.pth")
        # map_location lets a checkpoint saved on one device be loaded on
        # another (e.g. trained on GPU, predicted on CPU).
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint["state_dict"])

        # Set the model to evaluation mode
        self.model.eval()

    def get_model(self):
        """Build a `Model_seq2seq` from the configuration."""
        model = Model_seq2seq(num_encoder_layers=self.config.num_encoder_layers,
                      num_decoder_layers=self.config.num_decoder_layers,
                      emb_size=self.config.embedding_size,
                      nhead=self.config.nhead,
                      src_vocab_size=self.config.src_vocab_size,
                      tgt_vocab_size=self.config.tgt_vocab_size,
                      input_emb_size=self.config.input_emb_size,
                      max_input_points=self.config.max_input_points,
                      )

        return model

    def generate_square_subsequent_mask(self, sz, device):
        """Return an (sz, sz) float mask: 0.0 on/below the diagonal, -inf above."""
        mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def greedy_decode(self, src, src_mask, max_len, start_symbol, src_padding_mask=None):
        """
        Greedily decode a single source sequence.

        Args:
        - src: Source batch containing exactly one example.
        - src_mask: Source attention mask passed to the encoder.
        - max_len: Maximum length of the generated sequence, including the
          start symbol.
        - start_symbol: Token id used as the first target token.
        - src_padding_mask: Optional padding mask; it is moved to the target
          device but not otherwise used by the decoding loop.

        Returns:
        - Long tensor of shape (1, generated_len) starting with
          `start_symbol`; generation stops after EOS_IDX is produced or
          `max_len` tokens are reached.
        """
        src = src.to(self.device)
        src_mask = src_mask.to(self.device)
        if src_padding_mask is not None:
            src_padding_mask = src_padding_mask.to(self.device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            memory = self.model.encode(src, src_mask)
            memory = memory.to(self.device)
            ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=self.device)
            for _ in range(max_len - 1):
                # Boolean form of the causal mask: True marks blocked positions.
                tgt_mask = self.generate_square_subsequent_mask(ys.size(1), self.device).type(torch.bool)

                out = self.model.decode(ys, memory, tgt_mask)
                prob = self.model.generator(out[:, -1])

                _, next_word = torch.max(prob, dim=1)
                next_word = next_word.item()

                # Build the new token with the dtype/device of ys itself so
                # the target sequence stays an integer tensor regardless of
                # the dtype of src.
                next_token = torch.full((1, 1), next_word, dtype=ys.dtype, device=ys.device)
                ys = torch.cat([ys, next_token], dim=1)
                if next_word == EOS_IDX:
                    break

        return ys

    def predict(self, x, max_len=256):
        """
        Predict the target token sequence for one source example.

        Args:
        - x: Source batch with a single example; its second dimension is
          taken as the number of source tokens.
        - max_len: Maximum number of generated tokens (default 256).

        Returns:
        - 1-D tensor of generated token ids, starting with BOS_IDX.
        """
        self.model.eval()
        src = x
        num_tokens = src.shape[1]

        src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
        src_padding_mask = torch.zeros(1, num_tokens).type(torch.bool)
        tgt_tokens = self.greedy_decode(src, src_mask, max_len=max_len, start_symbol=BOS_IDX, src_padding_mask=src_padding_mask).flatten()

        return tgt_tokens

In [None]:
@dataclass
class Config:
    """
    Experiment configuration.

    Declared as a dataclass so that `field(default_factory=...)` defaults are
    materialised per instance and `fields(self)` works in `print_config` and
    `save`.
    """
    experiment_name: Optional[str] = "default"
    root_dir: Optional[str] = "./"
    device: Optional[str] = "cuda:0"
    save_at_epochs: Optional[list] = field(default_factory=list)
    debug: Optional[bool] = False

    #training parameters
    epochs: Optional[int] = 10
    seed: Optional[int] = 42
    use_half_precision: Optional[bool] = True

    #data loader parameters
    train_shuffle: Optional[bool] = True
    test_shuffle: Optional[bool] = False
    training_batch_size: Optional[int] = 1024
    test_batch_size: Optional[int] = 2048
    num_workers: Optional[int] = 4
    pin_memory: Optional[bool] = True

    # scheduler parameters
    scheduler_type: Optional[str] = "cosine_annealing_warm_restart" # multi_step or none
    T_0: Optional[int] = 10
    T_mult: Optional[int] = 1
    # Read by the trainer only when scheduler_type == "multi_step".
    scheduler_milestones: Optional[list] = field(default_factory=list)
    scheduler_gamma: Optional[float] = 0.1

    # optimizer parameters
    optimizer_type: Optional[str] = "adam" # sgd or adam
    optimizer_lr: Optional[float] = 0.0001
    optimizer_momentum: Optional[float] = 0.9
    optimizer_weight_decay: Optional[float] = 0.0001
    optimizer_no_decay: Optional[list] = field(default_factory=list)
    clip_grad_norm: Optional[float] = -1

    # Model Parameters
    model_name: Optional[str] = "seq2seq_transformer"
#     model_name: Optional[str] = "LongFormerEncoderDecoder"
    embedding_size: Optional[int] = 64
    hidden_dim: Optional[int] = 64
    nhead: Optional[int] = 8
    num_encoder_layers: Optional[int] = 2
    num_decoder_layers: Optional[int] = 6
    dropout: Optional[float] = 0.2
    pretrain: Optional[bool] = False
    input_emb_size: Optional[int] = 64
    max_input_points: Optional[int] = 33
    src_vocab_size: Optional[int] = 1104
    tgt_vocab_size: Optional[int] = 59

    # Criterion
    criterion: Optional[str] = "cross_entropy"

    def print_config(self):
        """Print every configuration field and its current value."""
        print("="*50+"\nConfig\n"+"="*50)
        # Loop variable deliberately not named `field`, to avoid shadowing
        # the dataclasses.field function.
        for cfg_field in fields(self):
            print(cfg_field.name.ljust(30), getattr(self, cfg_field.name))
        print("="*50)

    def save(self, root_dir):
        """
        Write every configuration field to `<root_dir>/config.txt`.

        Args:
        - root_dir: Existing directory in which "config.txt" is created.
        """
        path = os.path.join(root_dir, "config.txt")
        with open(path, "w") as f:
            f.write("="*50+"\nConfig\n"+"="*50 + "\n")
            for cfg_field in fields(self):
                f.write(cfg_field.name.ljust(30) + ": " + str(getattr(self, cfg_field.name)) + "\n")
            f.write("="*50)

In [None]:
# Instantiate the configuration: an instance (not the class itself) is what
# print_config()/save() operate on and what per-run overrides should mutate.
config = Config()

In [None]:
# Train, then score the best checkpoint on the test split.
trainer = Trainer(config=config, dataloaders=dataloaders)
trainer.train()
trainer.test_seq_acc()