In [1]:
!pip install deap



In [2]:
import numpy as np
import pandas as pd
import sympy as sp
from sympy.parsing.sympy_parser import parse_expr
from deap import base, creator, tools, gp, algorithms
import signal
import time
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from collections import OrderedDict
from IPython.display import clear_output
import random
import operator
import math
import re
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Transformer
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from dataclasses import dataclass, field, fields
import os
from math import isclose, sqrt, log
from gym import spaces
from types import SimpleNamespace
from tqdm import tqdm
import pickle 
import ast
seed = 42  # single seed shared by the random / numpy / torch seeding cell

In [3]:
# Seed every RNG the notebook touches so runs are repeatable.
random.seed(seed)
# NOTE: setting PYTHONHASHSEED at runtime only affects Python subprocesses
# started afterwards; the running interpreter's hash seed is fixed at startup.
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Deterministic cuDNN needs autotuning off: benchmark=True lets cuDNN choose
# (possibly non-deterministic) algorithms per input shape, defeating the line above.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [5]:
# Feynman benchmark metadata: one row per equation (formula, variable names, sampling ranges).
df_target = pd.read_csv(filepath_or_buffer='/kaggle/input/gsoc-symba-task/FeynmanEquations.csv')
df_target.head(n=5)

Unnamed: 0,Filename,Number,Output,Formula,# variables,v1_name,v1_low,v1_high,v2_name,v2_low,...,v7_high,v8_name,v8_low,v8_high,v9_name,v9_low,v9_high,v10_name,v10_low,v10_high
0,I.6.2a,1.0,f,exp(-theta**2/2)/sqrt(2*pi),1.0,theta,1.0,3.0,,,...,,,,,,,,,,
1,I.6.2,2.0,f,exp(-(theta/sigma)**2/2)/(sqrt(2*pi)*sigma),2.0,sigma,1.0,3.0,theta,1.0,...,,,,,,,,,,
2,I.6.2b,3.0,f,exp(-((theta-theta1)/sigma)**2/2)/(sqrt(2*pi)*...,3.0,sigma,1.0,3.0,theta,1.0,...,,,,,,,,,,
3,I.8.14,4.0,d,sqrt((x2-x1)**2+(y2-y1)**2),4.0,x1,1.0,5.0,x2,1.0,...,,,,,,,,,,
4,I.9.18,5.0,F,G*m1*m2/((x2-x1)**2+(y2-y1)**2+(z2-z1)**2),9.0,m1,1.0,2.0,m2,1.0,...,2.0,z1,3.0,4.0,z2,1.0,2.0,,,


In [6]:
# Keep only rows that carry a 'Filename'.
df_target = df_target.dropna(axis=0, how='any', subset=['Filename'])
df_target

Unnamed: 0,Filename,Number,Output,Formula,# variables,v1_name,v1_low,v1_high,v2_name,v2_low,...,v7_high,v8_name,v8_low,v8_high,v9_name,v9_low,v9_high,v10_name,v10_low,v10_high
0,I.6.2a,1.0,f,exp(-theta**2/2)/sqrt(2*pi),1.0,theta,1.0,3.0,,,...,,,,,,,,,,
1,I.6.2,2.0,f,exp(-(theta/sigma)**2/2)/(sqrt(2*pi)*sigma),2.0,sigma,1.0,3.0,theta,1.0,...,,,,,,,,,,
2,I.6.2b,3.0,f,exp(-((theta-theta1)/sigma)**2/2)/(sqrt(2*pi)*...,3.0,sigma,1.0,3.0,theta,1.0,...,,,,,,,,,,
3,I.8.14,4.0,d,sqrt((x2-x1)**2+(y2-y1)**2),4.0,x1,1.0,5.0,x2,1.0,...,,,,,,,,,,
4,I.9.18,5.0,F,G*m1*m2/((x2-x1)**2+(y2-y1)**2+(z2-z1)**2),9.0,m1,1.0,2.0,m2,1.0,...,2.0,z1,3.0,4.0,z2,1.0,2.0,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,III.15.14,96.0,m,(h/(2*pi))**2/(2*E_n*d**2),3.0,h,1.0,5.0,E_n,1.0,...,,,,,,,,,,
96,III.15.27,97.0,k,2*pi*alpha/(n*d),3.0,alpha,1.0,5.0,n,1.0,...,,,,,,,,,,
97,III.17.37,98.0,f,beta*(1+alpha*cos(theta)),3.0,beta,1.0,5.0,alpha,1.0,...,,,,,,,,,,
98,III.19.51,99.0,E_n,-m*q**4/(2*(4*pi*epsilon)**2*(h/(2*pi))**2)*(1...,4.0,m,1.0,5.0,q,1.0,...,,,,,,,,,,


In [7]:
# Hand corrections to the equation metadata, addressed by row label.
df_target.loc[[21, 22, 38, 82, 90, 98], '# variables'] = [3, 4, 4, 3, 4, 5]
df_target.loc[[18, 49, 61], 'Filename'] = ['I.15.10', 'I.48.20', 'II.11.7']

In [8]:
# Symbol names the prefix parser treats as operands.
# Building the 's_*' names programmatically avoids the implicit string
# concatenation a missing comma caused ('m_u' 's_0' -> 'm_us_0').
variables = [
    'x', 'y', 'z',
    'a', 'b', 'c', 'd',
    'E',
    'reg_prop',
    'm_s',
    'm_u',
] + [f's_{i}' for i in range(46)]

In [9]:
# sympy node class -> token name used in the prefix encoding.
# Elementary functions. There are no separate Sub/Div nodes: subtraction and
# division are expressed through Add / Mul / Pow.
operators = {
    sp.Add: 'add',
    sp.Mul: 'mul',
    sp.Pow: 'pow',
    sp.exp: 'exp',
    sp.log: 'ln',
    sp.Abs: 'abs',
    sp.sign: 'sign',
}
# Trigonometric / hyperbolic functions and their inverses: the token equals
# the sympy attribute name.
operators.update({
    getattr(sp, name): name
    for name in (
        'sin', 'cos', 'tan', 'cot', 'sec', 'csc',
        'asin', 'acos', 'atan', 'acot', 'asec', 'acsc',
        'sinh', 'cosh', 'tanh', 'coth', 'sech', 'csch',
        'asinh', 'acosh', 'atanh', 'acoth', 'asech', 'acsch',
    )
})
operators[sp.Min] = 'min'
operators[sp.Derivative] = 'derivative'

# Reverse lookup: prefix token -> callable that rebuilds the sympy node.
operators_inv = {token: node for node, token in operators.items()}
# Binary helpers that have no sympy class of their own.
operators_inv['sub'] = lambda x, y: x - y
operators_inv['div'] = lambda x, y: x / y
# Variadic spellings share the constructors of their binary forms.
operators_inv.update({"mul(": sp.Mul, "add(": sp.Add})

# Arity of every prefix token (-1 marks the variadic 'mul(' / 'add(' spellings).
operators_nargs = {
    # Elementary functions
    'mul(': -1,
    'add(': -1,
    'add': 2,
    'sub': 2,
    'mul': 2,
    'div': 2,
    'pow': 2,
    'rac': 2,
    'inv': 1,
    'pow2': 1,
    'pow3': 1,
    'pow4': 1,
    'pow5': 1,
    'sqrt': 1,
    'exp': 1,
    'ln': 1,
    'abs': 1,
    'sign': 1,
    # Min is emitted by `operators` (sp.Min -> 'min'); sympy's Min is variadic,
    # so it is encoded as a chain of binary 'min' nodes like add/mul.
    'min': 2,
    # Trigonometric Functions
    'sin': 1,
    'cos': 1,
    'tan': 1,
    'cot': 1,
    'sec': 1,
    'csc': 1,
    # Trigonometric Inverses
    'asin': 1,
    'acos': 1,
    'atan': 1,
    'acot': 1,
    'asec': 1,
    'acsc': 1,
    # Hyperbolic Functions
    'sinh': 1,
    'cosh': 1,
    'tanh': 1,
    'coth': 1,
    'sech': 1,
    'csch': 1,
    # Hyperbolic Inverses
    'asinh': 1,
    'acosh': 1,
    'atanh': 1,
    'acoth': 1,
    'asech': 1,
    'acsch': 1,
    # Derivative
    'derivative': 2,
    # custom functions
    'f': 1,
    'g': 2,
    'h': 3,
}

# Mass symbol names and the matching sympy Symbols.
masses_strings = "m_e m_u m_d m_s m_c m_b m_t".split()
masses = list(map(sp.Symbol, masses_strings))

# Exact sympy integer classes; `format_number` sends these to `format_integer`.
integers_types = [
    sp.core.numbers.Integer,
    sp.core.numbers.One,
    sp.core.numbers.NegativeOne,
    sp.core.numbers.Zero,
]

# Every number class `format_number` can encode.
# NOTE(review): the string entry can never equal a `type(...)` result; it
# appears to be a leftover and is kept only for compatibility.
numbers_types = [
    *integers_types,
    sp.core.numbers.Rational,
    sp.core.numbers.Half,
    sp.core.numbers.Exp1,
    sp.core.numbers.Pi,
    "<class 'sympy.core.numbers.Pi'>",
    sp.core.numbers.ImaginaryUnit,
]

# Node heads at which `sympy_to_prefix_rec` stops recursing.
atoms = [
    str,
    sp.core.symbol.Symbol,
    sp.core.numbers.Exp1,
    sp.core.numbers.Pi,
    "<class 'sympy.core.numbers.Pi'>",
    *numbers_types,
]


# 'arc'-prefixed inverse function names -> sympy's short names
# (e.g. 'arcsin' -> 'asin', 'arccsch' -> 'acsch').
Inverse_trig = {
    'arc' + name: 'a' + name
    for name in (
        'sin', 'cos', 'tan', 'cot', 'sec', 'csc',
        'sinh', 'cosh', 'tanh', 'coth', 'sech', 'csch',
    )
}

In [10]:
def _normalize_inverse_trig(formula):
    """Replace 'arc'-prefixed inverse function names with sympy's spelling.

    Every mapping in `Inverse_trig` is applied cumulatively, in dict order.
    Because 'arcsin' comes before 'arcsinh', 'arcsinh(x)' already becomes
    'asinh(x)' on the first matching pass. Non-string values (e.g. NaN) are
    returned unchanged.
    """
    if not isinstance(formula, str):
        return formula
    for long_name, short_name in Inverse_trig.items():
        formula = re.sub(long_name, short_name, formula)
    return formula


# Apply all substitutions to every row. The rows are addressed by the frame's
# own index, not by positional labels.
df_target['Formula'] = df_target['Formula'].map(_normalize_inverse_trig)

In [11]:
# Show the table again to check the edited '# variables' / 'Filename' / 'Formula' columns.
df_target

Unnamed: 0,Filename,Number,Output,Formula,# variables,v1_name,v1_low,v1_high,v2_name,v2_low,...,v7_high,v8_name,v8_low,v8_high,v9_name,v9_low,v9_high,v10_name,v10_low,v10_high
0,I.6.2a,1.0,f,exp(-theta**2/2)/sqrt(2*pi),1.0,theta,1.0,3.0,,,...,,,,,,,,,,
1,I.6.2,2.0,f,exp(-(theta/sigma)**2/2)/(sqrt(2*pi)*sigma),2.0,sigma,1.0,3.0,theta,1.0,...,,,,,,,,,,
2,I.6.2b,3.0,f,exp(-((theta-theta1)/sigma)**2/2)/(sqrt(2*pi)*...,3.0,sigma,1.0,3.0,theta,1.0,...,,,,,,,,,,
3,I.8.14,4.0,d,sqrt((x2-x1)**2+(y2-y1)**2),4.0,x1,1.0,5.0,x2,1.0,...,,,,,,,,,,
4,I.9.18,5.0,F,G*m1*m2/((x2-x1)**2+(y2-y1)**2+(z2-z1)**2),9.0,m1,1.0,2.0,m2,1.0,...,2.0,z1,3.0,4.0,z2,1.0,2.0,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,III.15.14,96.0,m,(h/(2*pi))**2/(2*E_n*d**2),3.0,h,1.0,5.0,E_n,1.0,...,,,,,,,,,,
96,III.15.27,97.0,k,2*pi*alpha/(n*d),3.0,alpha,1.0,5.0,n,1.0,...,,,,,,,,,,
97,III.17.37,98.0,f,beta*(1+alpha*cos(theta)),3.0,beta,1.0,5.0,alpha,1.0,...,,,,,,,,,,
98,III.19.51,99.0,E_n,-m*q**4/(2*(4*pi*epsilon)**2*(h/(2*pi))**2)*(1...,5.0,m,1.0,5.0,q,1.0,...,,,,,,,,,,


In [12]:
# Inspect the full metadata row (formula, variable names and ranges) of the first equation.
df_target.iloc[0, :]

Filename                            I.6.2a
Number                                 1.0
Output                                   f
Formula        exp(-theta**2/2)/sqrt(2*pi)
# variables                            1.0
v1_name                              theta
v1_low                                 1.0
v1_high                                3.0
v2_name                                NaN
v2_low                                 NaN
v2_high                                NaN
v3_name                                NaN
v3_low                                 NaN
v3_high                                NaN
v4_name                                NaN
v4_low                                 NaN
v4_high                                NaN
v5_name                                NaN
v5_low                                 NaN
v5_high                                NaN
v6_name                                NaN
v6_low                                 NaN
v6_high                                NaN
v7_name    

In [13]:
def chunks(lst, n):
    """Lazily yield consecutive slices of ``lst`` of length ``n``.

    The final slice may be shorter. Works for any sliceable sequence (lists, strings, ...).
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

In [14]:
class Tokenizer:
    """Two-way token <-> integer-id lookup loaded from a vocabulary file.

    The file lists one token per line, and a token's id is its 0-based line number.
    """

    def __init__(self, vocab_path):
        self.vocab_path = vocab_path
        self.word2id = {}
        self.id2word = {}

        with open(vocab_path) as vocab_file:
            lines = vocab_file.readlines()

        for idx, line in enumerate(lines):
            token = line.rstrip('\n')
            self.word2id[token] = idx
            self.id2word[idx] = token

    def encode(self, lst):
        """Map a batch of token sequences to a 2-D ``np.ushort`` id array."""
        rows = []
        for seq in lst:
            rows.append([self.word2id[tok] for tok in seq])
        return np.array(rows, dtype=np.ushort)

    def decode(self, lst):
        """Map a batch of id sequences back to lists of tokens."""
        return [list(map(self.id2word.__getitem__, seq)) for seq in lst]

In [15]:
class Encoder_tokeniser(Tokenizer):
    """Tokenizer for the numeric input points fed to the encoder.

    Every float becomes ``2 + mantissa_len`` tokens: a sign token ('+'/'-'),
    ``mantissa_len`` mantissa tokens prefixed with 'N', and one exponent token
    prefixed with 'E'. Padding values (+/-inf) become the same number of
    '<pad>' tokens, so every value occupies the same width and `decode_float`
    can chunk by ``2 + mantissa_len``.

    Args:
        float_precision: digits after the decimal point kept by ``"%.{p}e"``.
        mantissa_len: number of 'N' tokens the mantissa digits are split into.
        max_exponent: values whose token exponent is below ``-max_exponent``
            are encoded as zero.
        vocab_path: vocabulary file passed to `Tokenizer`.
        max_len: number of columns each row is padded to by `pre_tokenize`.
    """

    def __init__(self, float_precision, mantissa_len, max_exponent, vocab_path, max_len=10):
        super().__init__(vocab_path)

        self.max_len = max_len
        self.float_precision = float_precision
        self.mantissa_len = mantissa_len
        self.max_exponent = max_exponent
        # Digits per mantissa token: "%e" yields 1 + float_precision mantissa digits.
        self.base = (self.float_precision + 1) // self.mantissa_len
        self.max_token = 10 ** self.base

    def pre_tokenize(self, data):
        """Parse whitespace-separated rows into a float32 array.

        The last column is moved to the front. Rows are then right-padded with
        ``-inf`` up to ``max_len`` columns.
        """
        arr = np.array([row.split() for row in data], dtype=np.float32)
        permutation = [-1] + list(range(arr.shape[1] - 1))
        arr = np.pad(arr[:, permutation], ((0, 0), (0, self.max_len - arr.shape[1])),
                     mode="constant", constant_values=[-np.inf])
        return arr

    def tokenize(self, data):
        """Convert text rows into padded floats, then float tokens, then a vocabulary-id array."""
        out = self.pre_tokenize(data)
        out = self.encode_float(out)
        out = self.encode(out)
        return out

    def encode_float(self, values):
        """Encode a 1-D array into one flat token list, or an N-D array row by row.

        NOTE(review): a NaN value raises ValueError, because its "%e" text has no
        exponent part to split on. Exponents above ``max_exponent`` are not
        clamped and may produce tokens missing from the vocabulary.
        """
        if len(values.shape) != 1:
            return [self.encode_float(row) for row in values]

        # One padded value must be as wide as an encoded one: sign + mantissa + exponent.
        pad_width = 2 + self.mantissa_len
        seq = []
        for val in values:
            if val in [-np.inf, np.inf]:
                seq.extend(['<pad>'] * pad_width)
                continue

            sign = "+" if val >= 0 else "-"
            mantissa, exponent = (f"%.{self.float_precision}e" % val).split("e")
            int_part, frac_part = mantissa.lstrip("-").split(".")
            tokens = chunks(int_part + frac_part, self.base)
            # Shift the exponent so the mantissa digits read as an integer.
            expon = int(exponent) - self.float_precision
            if expon < -self.max_exponent:
                tokens = ["0" * self.base] * self.mantissa_len
                expon = 0
            seq.extend([sign, *["N" + token for token in tokens], "E" + str(expon)])
        return seq

    def decode_float(self, seq):
        """Inverse of the 1-D `encode_float`: convert tokens back to a list of floats.

        Returns None for an empty sequence. A value whose mantissa/exponent
        tokens do not parse becomes ``np.nan``.
        NOTE(review): any token not starting with '+', '-', 'E' or 'N' (e.g.
        '<pad>') makes the whole call return a single ``np.nan`` instead of a list.
        """
        if len(seq) == 0:
            return None
        decoded_seq = []
        for val in chunks(seq, 2 + self.mantissa_len):
            for x in val:
                if x[0] not in ["-", "+", "E", "N"]:
                    return np.nan
            try:
                sign = 1 if val[0] == "+" else -1
                mant = int("".join(x[1:] for x in val[1:-1]))
                exp = int(val[-1][1:])
                value = float(sign * mant * (10 ** exp))
            except Exception:
                # Best effort for model output: malformed values decode to NaN.
                value = np.nan
            decoded_seq.append(value)
        return decoded_seq

In [16]:
def flatten(l, ltypes=(list, tuple)):
    """Flatten arbitrarily nested lists/tuples into one container of the input's type.

    Items whose type is in ``ltypes`` are expanded in place (empty ones vanish).
    Everything else, including strings, is kept as a leaf. An explicit stack
    of iterators is used, so deep nesting does not hit the recursion limit.
    Based on http://rightfootin.blogspot.com/2006/09/more-on-python-flatten.html
    """
    container_type = type(l)
    flat = []
    stack = [iter(l)]
    while stack:
        for item in stack[-1]:
            if isinstance(item, ltypes):
                # Descend first; the parent iterator resumes after the child is exhausted.
                stack.append(iter(item))
                break
            flat.append(item)
        else:
            stack.pop()
    return container_type(flat)

In [17]:
def sympy_to_prefix(expression):
    """Convert a sympy expression into a flat list of prefix-notation tokens.

    `sympy_to_prefix_rec` builds a nested list (nesting = parentheses), which
    is then flattened because every operator has a fixed arity.
    """
    nested_tokens = sympy_to_prefix_rec(expression, [])
    return flatten(nested_tokens)

def sympy_to_prefix_rec(expression, ret):
    """
    Recursively go from a sympy expression to a prefix notation.
    The operators all get converted to their names in the dict `operators`.
    Returns a nested list, where the nesting basically stands for parentheses.
    Since in prefix notation with a fixed number of arguments for each function
    (given in `operators_nargs`), parentheses are not needed, we can flatten
    the list later.

    Args:
        expression: sympy expression (node) to convert.
        ret: tokens produced so far; results are appended to a new list built from it.

    Returns:
        ``ret`` followed by the (possibly nested) prefix tokens of ``expression``.

    Raises:
        KeyError: if the node type is missing from `operators` / `operators_nargs`.
        ValueError: if the node's argument count cannot be encoded for its operator.
    """
    # NOTE(review): this membership test only matches the Pi / ImaginaryUnit
    # *classes* themselves; instances such as sp.pi reach `atoms` via `.func`.
    if expression in [sp.core.numbers.Pi, sp.core.numbers.ImaginaryUnit]:
        f = expression
    else:
        f = expression.func
    if f in atoms:
        if type(expression) in numbers_types:
            return ret + format_number(expression)
        return ret + [str(expression)]

    f_str = operators[f]
    f_nargs = operators_nargs[f_str]
    args = expression.args
    # `and`, not `&`: `len(args) == 1 & f_nargs == 1` parses as the chained
    # comparison `len(args) == (1 & f_nargs) == 1`.
    if len(args) == 1 and f_nargs == 1:
        return sympy_to_prefix_rec(args[0], ret + [f_str])
    if len(args) == 2:
        return ret + [f_str, sympy_to_prefix_rec(args[0], []), sympy_to_prefix_rec(args[1], [])]
    if len(args) > 2:
        nested_args = [sympy_to_prefix_rec(arg, []) for arg in args]
        return ret + repeat_operator_until_correct_binary(f_str, nested_args)
    # Any other arity cannot be encoded; fail loudly instead of silently
    # returning `ret` without the operator token.
    raise ValueError(
        f"cannot encode {expression!r}: operator {f_str!r} expects {f_nargs} "
        f"argument(s), got {len(args)}"
    )
def repeat_operator_until_correct_binary(op, args, ret=None):
    """
    Fold a variadic sympy node into nested binary prefix form.

    sympy is not strict about the number of arguments (e.g. Mul accepts any
    number), but in prefix notation a binary operator ALWAYS takes exactly two.
    The convention nests to the right, as in
    https://arxiv.org/pdf/1912.01412.pdf (page 15):

        1 + 2 + 3 --> + 1 + 2 3

    Args:
        op: operator token as in `operators`; must be binary in `operators_nargs`.
        args: operand token lists, e.g. [arg1, arg2, ...]; they may be nested.
        ret: tokens to continue from. ``None`` (the default) starts from an
            empty list, which avoids sharing a mutable default between calls.

    Returns:
        With an empty ``ret``: [op, a0, op, a1, ..., op, a[-2], a[-1]].
        With a non-empty ``ret``: [op, a0, ..., op, a[-1]] + ret.
    """
    is_binary = operators_nargs[op] == 2
    assert is_binary, "repeat_operator_until_correct_binary only takes binary operators"

    if ret is None:
        ret = []
    remaining = list(args)  # also accepts tuples
    if not remaining:
        return ret
    if not ret:
        # The innermost node takes the last two operands.
        ret = [op] + remaining[-2:]
        del remaining[-2:]
    # Each earlier operand wraps the result built so far in one more binary node.
    while remaining:
        ret = [op, remaining.pop()] + ret
    return ret

def format_number(number):
    """Dispatch a sympy number to the matching ``format_*`` encoder.

    Types are matched exactly, not via ``isinstance``.

    Raises:
        NotImplementedError: for any other number type (e.g. Float).
    """
    number_type = type(number)
    if number_type in integers_types:
        return format_integer(number)
    dispatch = (
        (sp.core.numbers.Rational, lambda: format_rational(number)),
        (sp.core.numbers.Half, lambda: format_half()),
        (sp.core.numbers.Exp1, lambda: format_exp1()),
        (sp.core.numbers.Pi, lambda: format_pi()),
        (sp.core.numbers.ImaginaryUnit, lambda: format_imaginary_unit()),
    )
    for candidate_type, encode in dispatch:
        if number_type == candidate_type:
            return encode()
    raise NotImplementedError

def format_exp1():
    """Prefix tokens for Euler's number (sympy ``E``)."""
    return ["E"]

def format_pi():
    """Prefix tokens for ``pi``."""
    return ["pi"]

def format_imaginary_unit():
    """Prefix tokens for the imaginary unit ``I``."""
    return ["I"]

def format_half():
    """
    Prefix tokens for sympy's ``Half`` singleton (1/2 is its own class, not a plain Rational).

    Uses the same layout as `format_rational`: mul(1, pow(2, -1)).
    """
    one, two, minus_one = ["s+", "1"], ["s+", "2"], ["s-", "1"]
    return ["mul", *one, "pow", *two, *minus_one]

def format_rational(number):
    """Encode a sympy Rational p/q as ``mul <p> pow <q> s- 1``, i.e. p * q**-1."""
    # sympify turns the raw numerator/denominator into sympy integers, which is
    # what format_integer expects (it reads `.p` and calls could_extract_minus_sign()).
    numerator = sp.sympify(number.p)
    denominator = sp.sympify(number.q)
    return ["mul", *format_integer(numerator), "pow", *format_integer(denominator), "s-", "1"]

def format_integer(integer):
    """Encode a sympy integer as a sign token followed by one token per digit.

    Format as in https://arxiv.org/pdf/1912.01412.pdf.

    Args:
        integer: a `sympy.Integer`, e.g. `sympy.Integer(-1)`. Its value is read
            from ``.p`` (sympy's Integer derives from Rational p/q).

    Returns:
        ['s+' or 's-', digit0, digit1, ...]; for example,
        format_integer(sympy.Integer(-123)) gives ['s-', '1', '2', '3'].
    """
    magnitude = str(abs(integer.p))
    sign_token = "s-" if integer.could_extract_minus_sign() else "s+"
    return [sign_token, *magnitude]

In [18]:
def parse_if_str(x):
    """Parse ``x`` with sympy when it is a string; return any other object unchanged."""
    return sp.parsing.parse_expr(x) if isinstance(x, str) else x

In [19]:
def rightmost_string_pos(expr_arr, pos=-1):
    """Locate the last ``str`` item at or before ``pos`` by scanning leftwards.

    Returns ``len(expr_arr) + p`` for the position ``p`` found. With the
    default negative ``pos``, this is the ordinary index of that item.
    Raises IndexError when no string exists.
    """
    while not isinstance(expr_arr[pos], str):
        pos -= 1
    return len(expr_arr) + pos


def rightmost_operand_pos(expr, pos):
    """Return the index of the rightmost token at or left of ``pos`` that the prefix parser acts on.

    Such tokens are operator names (keys of `operators_inv`), the integer
    sign tokens 's+'/'s-', and the names in `variables`. Despite its name,
    the function finds operators as well as operand symbols.

    Raises:
        IndexError: when no such token exists (the scan runs off the start of
            ``expr`` after wrapping through negative indices).
    """
    # Build the token list once per call instead of once per scanned position;
    # the local name also no longer shadows the global `operators` dict.
    parser_tokens = list(operators_inv.keys()) + ["s+", "s-"] + variables
    while expr[pos] not in parser_tokens:
        pos -= 1
    return pos

def unformat_integer(arr):
    """
    Inverse of `format_integer`.

    Args:
        arr: tokens as produced by format_integer, e.g. ["s+", "4", "2"].
            "s+" marks a non-negative integer, "s-" a negative one (0 may carry
            either sign).

    Returns:
        The corresponding sympy integer, e.g. sympy.Integer(42) for the example above.
    """
    sign = "-" if arr[0] == "s-" else ""
    digits = "".join(str(token) for token in arr[1:])
    return sp.parsing.parse_expr(sign + digits)

In [20]:
def prefix_to_sympy(expr_arr):
    """Rebuild a sympy expression from a flat prefix-notation list.

    The list is reduced from right to left. `rightmost_operand_pos` finds the
    rightmost operator, sign token ('s+'/'s-') or variable name. That entry
    (with its arguments) is replaced by the sympy object it denotes, and the
    function recurses on the shorter list until one element is left.

    Args:
        expr_arr: list of prefix tokens (str) and/or sympy objects produced by
            earlier reduction steps.

    Returns:
        The reconstructed sympy expression.

    Raises:
        ValueError: for variadic tokens such as 'mul(' (arity -1); their
            argument count cannot be recovered from prefix notation. Previously
            they recursed without end.
    """
    if len(expr_arr) == 1:
        return parse_if_str(expr_arr[0])

    last_pos = len(expr_arr) - 1
    op_pos = rightmost_operand_pos(expr_arr, last_pos)
    op = expr_arr[op_pos]

    # Operators and sign tokens need tokens after them; a variable may legitimately be last.
    if op_pos == last_pos and op not in variables:
        print("something went wrong, operator should not be at end of array")

    if op in operators_inv:
        num_args = operators_nargs[op]
        if num_args < 0:
            raise ValueError(f"variadic token {op!r} is not supported in prefix notation")
        func = operators_inv[op]
        args = [parse_if_str(a) for a in expr_arr[op_pos + 1:op_pos + num_args + 1]]
        reduced = expr_arr[:op_pos] + [func(*args)] + expr_arr[op_pos + num_args + 1:]
        return prefix_to_sympy(reduced)

    if op in ("s+", "s-"):
        # An integer is its sign token plus the digit tokens directly after it.
        # Scanning forward keeps later string tokens such as 'pi' out of the number;
        # taking the rightmost string in the list would turn ['s+', '2', 'pi'] into "2pi".
        end = op_pos + 1
        while end < len(expr_arr) and isinstance(expr_arr[end], str) and expr_arr[end].isdigit():
            end += 1
        integer = unformat_integer(expr_arr[op_pos:end])
        return prefix_to_sympy(expr_arr[:op_pos] + [integer] + expr_arr[end:])

    if op in variables:
        symbol = sp.sympify(op)
        return prefix_to_sympy(expr_arr[:op_pos] + [symbol] + expr_arr[op_pos + 1:])

    return op

In [22]:
# Function names that may appear in a functional-form string -> sympy builders.
# Subtraction and division are rebuilt from Add / Mul / Pow.
functional_to_sympy = dict(
    add=sp.Add,
    sub=lambda x, y: sp.Add(x, -y),
    mul=sp.Mul,
    div=lambda x, y: sp.Mul(x, sp.Pow(y, -1)),
    pow=sp.Pow,
    sin=sp.sin,
    cos=sp.cos,
    tan=sp.tan,
    abs=sp.Abs,
    max=sp.Max,
    min=sp.Min,
    tanh=sp.tanh,
    protected_div=lambda x, y: sp.Mul(x, sp.Pow(y, -1)),
    protected_pow=sp.Pow,
    protected_exp=sp.exp,
    protected_log=sp.log,
    protected_sqrt=sp.sqrt,
    pi=sp.pi,
)

In [23]:
def convert_to_sympy_expression(func_form):
    """
    Parse a functional-form string back into a SymPy expression.

    Parameters:
    - func_form: string in functional form, e.g. 'mul(s_1, s_2, s_4, pow(s_3, -1))'.

    Returns:
    - SymPy expression.

    Calls named in `functional_to_sympy` are rebuilt recursively, 's_*' tokens
    become Symbols, and everything else is passed to `sp.sympify`. A fragment
    that sympify rejects raises ValueError.
    """
    def _split_args(args_str):
        """Split an argument list on commas that are not nested inside parentheses."""
        pieces = []
        start = 0
        depth = 0
        for idx, char in enumerate(args_str):
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
            elif char == ',' and depth == 0:
                pieces.append(args_str[start:idx].strip())
                start = idx + 1
        pieces.append(args_str[start:].strip())
        return pieces

    def _parse_expr(expression):
        if expression.endswith(")"):
            for name, builder in functional_to_sympy.items():
                opening = name + "("
                if expression.startswith(opening):
                    inner = expression[len(opening):-1]
                    return builder(*[_parse_expr(arg) for arg in _split_args(inner)])

        # Variables such as s_1 become plain SymPy symbols.
        if expression.startswith('s_'):
            return sp.Symbol(expression)

        # Anything else (numbers, known sympy names) is left to sympify.
        try:
            return sp.sympify(expression)
        except sp.SympifyError:
            raise ValueError(f"Unsupported expression format: {expression}")

    return _parse_expr(func_form)

In [24]:
class DecoderTokenizer(Tokenizer):
    """Tokenizer for the target side: sympy expressions <-> prefix tokens <-> vocabulary ids."""

    # Tokens that frame an equation rather than belong to it. `tokenize` adds
    # '<bos>' and '<eos>'. '<pad>' is the encoder tokenizer's padding token and
    # is assumed to pad decoder sequences too (TODO confirm against the vocab).
    NON_EQUATION_TOKENS = ('<bos>', '<eos>', '<pad>')

    def __init__(self, vocab_path):
        super().__init__(vocab_path)

    def equation_encoder(self, data):
        """Convert sympy expressions into flat prefix token lists."""
        return [sympy_to_prefix(expr) for expr in data]

    def equation_decoder(self, data):
        """Convert prefix token lists into sympy expressions."""
        return [prefix_to_sympy(lst) for lst in data]

    def pre_tokenize(self, data):
        return data

    def tokenize(self, data):
        """Wrap each equation's prefix tokens in '<bos>' ... '<eos>' and map them to ids.

        NOTE(review): `encode` builds a single 2-D array, so all wrapped prefix
        lists must have the same length; ragged input fails inside numpy.
        """
        out = self.pre_tokenize(data)
        out = self.equation_encoder(out)
        out = [['<bos>'] + i + ['<eos>'] for i in out]
        out = self.encode(out)
        return out

    def reverse_tokenize(self, data):
        """Map id sequences back to sympy expressions.

        Framing tokens are neither operators nor operands for `prefix_to_sympy`,
        which fails on them. So everything from the first '<eos>' on is dropped,
        and '<bos>'/'<pad>' are removed before decoding the equation.
        """
        out = self.decode(data)
        out = [self._strip_non_equation_tokens(tokens) for tokens in out]
        out = self.equation_decoder(out)
        return out

    @classmethod
    def _strip_non_equation_tokens(cls, tokens):
        """Keep only the equation tokens of one decoded sequence."""
        tokens = list(tokens)
        if '<eos>' in tokens:
            tokens = tokens[:tokens.index('<eos>')]
        return [tok for tok in tokens if tok not in cls.NON_EQUATION_TOKENS]

In [25]:
# Root of the Feynman-with-units data files (the trailing slash is part of the value).
INPUT_DIR = "/".join(["/kaggle/input/gsoc-symba-task", "Feynman_with_units", "Feynman_with_units", ""])

In [26]:
class FeynmanDataset(Dataset):
    """(tokenised data points, tokenised prefix equation) pairs for the Feynman equations.

    Args:
        df: one row per sample, with 'Filename', 'data_num' and 'number' columns.
        dataset_dir: directory holding 'prefix_equations.npy' and one
            sub-directory per equation containing '<data_num>.npy' files.
    """

    def __init__(self, df, dataset_dir):
        super().__init__()
        self.df = df
        self.dataset_dir = dataset_dir
        raw_prefixes = np.load(os.path.join(dataset_dir, "prefix_equations.npy"))
        # NOTE(review): np.trim_zeros trims both ends by default, which assumes
        # id 0 is padding and never a leading token; confirm against the vocab.
        self.prefix_equations = [np.trim_zeros(prefix) for prefix in raw_prefixes]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(x, y)`` as int64 tensors for the ``idx``-th row of ``df``."""
        row = self.df.iloc[idx]
        path = os.path.join(self.dataset_dir, row['Filename'], f"{row['data_num']}.npy")
        # Convert straight to int64 in numpy. Going through torch.Tensor (float32)
        # would round ids above 2**24 and rejects some integer dtypes.
        x = torch.from_numpy(np.load(path).astype(np.int64))
        # 'number' is 1-based, so subtract 1 to index prefix_equations.
        prefix = self.prefix_equations[int(row['number']) - 1]
        y = torch.from_numpy(np.asarray(prefix).astype(np.int64))
        return (x, y)

In [38]:
def get_datasets(df, input_df, dataset_dir, test_size=0.1, val_size=0.1,
                 random_state=42, val_random_state=None):
    """Build train / valid / test `FeynmanDataset`s with an equation-level test split.

    The test split is drawn over equations (rows of ``df``), so no test
    equation appears in training. The validation split is then drawn over the
    individual samples of the training equations.

    Args:
        df: one row per equation, with a 'Filename' column.
        input_df: one row per sample, whose 'Filename' links it to ``df``.
        dataset_dir: directory passed to `FeynmanDataset`.
        test_size: fraction of equations held out for testing (default 0.1).
        val_size: fraction of training samples held out for validation (default 0.1).
        random_state: seed of the equation-level split (default 42).
        val_random_state: seed of the sample-level validation split. The default
            ``None`` keeps the previous behaviour (scikit-learn then draws from
            numpy's global RNG). Pass an int for a split that does not depend on
            earlier RNG use.

    Returns:
        dict with keys "train", "test" and "valid".
    """
    train_df, test_df = train_test_split(df, test_size=test_size, random_state=random_state)
    train_equations = train_df['Filename'].tolist()
    test_equations = test_df['Filename'].tolist()

    input_test_df = input_df[input_df['Filename'].isin(test_equations)]
    input_train_df = input_df[input_df['Filename'].isin(train_equations)]

    input_train_df, input_val_df = train_test_split(
        input_train_df, test_size=val_size, shuffle=True, random_state=val_random_state)

    return {
        "train": FeynmanDataset(input_train_df, dataset_dir),
        "test": FeynmanDataset(input_test_df, dataset_dir),
        "valid": FeynmanDataset(input_val_df, dataset_dir),
    }

In [39]:
class config:
    """Paths and length settings used throughout the notebook."""

    def __init__(self):
        # Length settings.
        self.input_max_len, self.max_len = 1000, 11
        # Kaggle input files and the working output directory.
        self.df_path = '/kaggle/input/gsoc-symba-task/FeynmanEquationsModified.csv'
        self.encoder_vocab = '/kaggle/input/gsoc-symba-task/encoder_vocab (1).txt'
        self.decoder_vocab = '/kaggle/input/gsoc-symba-task/decoder_vocab (2).txt'
        self.output_dir = '/kaggle/working/dataset_arrays'

In [40]:
# Per-sample index: equation filename, sample number and 1-based equation number.
train_df = (
    pd.read_csv('/kaggle/input/gsoc-dataset-arrays/train_df.csv')
    .rename(columns={'filename': 'Filename'})
)
train_df

Unnamed: 0,Filename,data_num,number
0,I.6.2a,0,1
1,I.6.2a,1,1
2,I.6.2a,2,1
3,I.6.2a,3,1
4,I.6.2a,4,1
...,...,...,...
99995,III.21.20,995,100
99996,III.21.20,996,100
99997,III.21.20,997,100
99998,III.21.20,998,100


In [41]:
datasets = get_datasets(
    df=df_target,
    input_df=train_df,
    dataset_dir='/kaggle/input/gsoc-dataset-arrays/dataset_arrays/',
)

In [42]:
class TokenEmbedding(nn.Module):
    """Look up token embeddings and scale them by sqrt(emb_size)."""

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        # Indices may arrive as float tensors; nn.Embedding needs integers.
        vectors = self.embedding(tokens.long())
        scale = math.sqrt(self.emb_size)
        return vectors * scale

class PositionalEncoding(nn.Module):
    ''' helper Module that adds positional encoding to the token embedding to introduce a notion of word order.'''

    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        # Standard sinusoidal table of shape (1, maxlen, emb_size).
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        self.pos_embedding = torch.zeros((maxlen, emb_size))
        self.pos_embedding[:, 0::2] = torch.sin(pos * den)
        self.pos_embedding[:, 1::2] = torch.cos(pos * den)
        self.pos_embedding = self.pos_embedding.unsqueeze(0)

        self.dropout = nn.Dropout(dropout)
        # The table is registered under the state_dict key 'pos_embedding_1',
        # which existing checkpoints rely on. `self.pos_embedding` stays as a
        # plain-attribute alias for backward compatibility.
        self.register_buffer('pos_embedding_1', self.pos_embedding)

    def forward(self, token_embedding: Tensor):
        """Add position vectors to a (batch, seq, emb) tensor, then apply dropout.

        The result stays on ``token_embedding``'s device. The registered buffer
        moves with ``module.to(...)`` and is aligned to the input here, instead
        of everything being forced onto 'cuda:0' (which crashed on CPU-only
        machines and ignored the module's device).
        """
        seq_len = token_embedding.size(1)
        positions = self.pos_embedding_1[:, :seq_len, :].to(token_embedding.device)
        return self.dropout(token_embedding + positions)

    
class LinearPointEmbedder(nn.Module):
    """Embed each input point's tokens, concatenate them and project with a two-layer MLP.

    `forward` needs integer tokens of shape (batch, n, max_input_points). The
    `view` flattens the last two embedded axes into
    ``max_input_points * input_emb_size`` features, which is the size `fc1` accepts.
    """

    def __init__(self, vocab_size: int, input_emb_size, emb_size, max_input_points, dropout=0.2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, input_emb_size)
        self.emb_size = emb_size
        self.input_size = max_input_points * input_emb_size
        self.fc1 = nn.Linear(self.input_size, emb_size)
        self.fc2 = nn.Linear(emb_size, emb_size)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, tokens):
        # NOTE: the scale uses the output size emb_size, not input_emb_size;
        # left unchanged so behaviour matches existing checkpoints.
        embedded = self.embedding(tokens.long()) * math.sqrt(self.emb_size)
        batch, n_points = embedded.shape[0], embedded.shape[1]
        flat = embedded.view(batch, n_points, -1)
        hidden = self.dropout(self.activation(self.fc1(flat)))
        return self.fc2(hidden)
    

class Model_seq2seq(nn.Module):
    '''Seq2Seq Network: Transformer encoder-decoder from tokenised data points to prefix-equation tokens.

    The source goes through `LinearPointEmbedder` (no positional encoding).
    The target goes through `TokenEmbedding` followed by `PositionalEncoding`.
    '''

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 input_emb_size: int,
                 max_input_points: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1,):
        super(Model_seq2seq, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout,
                                       batch_first=True)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = LinearPointEmbedder(src_vocab_size, input_emb_size, emb_size, max_input_points)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        """Teacher-forced pass; returns logits over the target vocabulary."""
        src_emb = self.src_tok_emb(src)
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb,
                                src_mask=src_mask,
                                tgt_mask=tgt_mask,
                                memory_mask=None,
                                src_key_padding_mask=src_padding_mask,
                                tgt_key_padding_mask=tgt_padding_mask,
                                memory_key_padding_mask=memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.src_tok_emb(src), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask)

    def beam_search(self, src: Tensor, src_mask: Tensor, src_padding_mask: Tensor, beam_size: int = 3,
                    max_len: int = 65, start_state=None, start_symbol: int = 1, eos_symbol: int = 58,
                    num_return: int = 3):
        """Beam-search decode a single source sequence.

        Args:
            src, src_mask: encoder input and mask (see `encode`).
            src_padding_mask: accepted for API compatibility; not used.
            beam_size: hypotheses kept (and expanded) per step.
            max_len: maximum number of decoding steps.
            start_state: optional token prefix to start from (anything
                ``torch.tensor`` accepts, shaped (1, L)); defaults to
                ``[[start_symbol]]``.
            start_symbol: start-of-sequence id (default 1).
            eos_symbol: end-of-sequence id (default 58).
            num_return: number of hypotheses returned (default 3).

        Returns:
            Up to ``num_return`` ``(sequence, score)`` pairs, best first. Each
            sequence is a (1, L) tensor and the score is its summed log-probability.
            Hypotheses that reached ``eos_symbol`` are preferred; if none did, the
            final beam is returned.

        Only batch size 1 is supported (``seq[0, -1]`` and ``.item()`` below).
        """
        memory = self.encode(src, src_mask)
        # Build new tensors next to the encoder output instead of on a global device.
        beam_device = memory.device

        if start_state is None:
            beam = [(torch.tensor([[start_symbol]], device=beam_device), 0)]
        else:
            beam = [(torch.tensor(start_state, device=beam_device), 0)]
        completed_sequences = []

        for _ in range(max_len):
            candidates = []
            for seq, score in beam:
                if seq[0, -1].item() == eos_symbol:
                    completed_sequences.append((seq, score))
                    continue

                # Causal mask: 0 on/below the diagonal, -inf above it.
                seq_len = seq.size(1)
                tgt_mask = torch.triu(torch.ones((seq_len, seq_len), device=beam_device)).transpose(0, 1)
                tgt_mask = tgt_mask.float().masked_fill(tgt_mask == 0, float('-inf')).masked_fill(tgt_mask == 1, float(0.0))
                tgt_emb = self.positional_encoding(self.tgt_tok_emb(seq))
                out = self.transformer.decoder(tgt_emb, memory, tgt_mask)
                log_probs = torch.log_softmax(self.generator(out[:, -1, :]), dim=-1)

                top_log_probs, top_indices = torch.topk(log_probs, beam_size)
                for i in range(beam_size):
                    new_seq = torch.cat([seq, top_indices[:, i].unsqueeze(1)], dim=1)
                    candidates.append((new_seq, score + top_log_probs[:, i].item()))

            beam = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
            if not beam:
                # Every hypothesis has finished.
                break

        # Hypotheses that emitted <eos> on the final step were never re-examined
        # inside the loop; count them as completed as well.
        completed_sequences.extend(
            (seq, score) for seq, score in beam if seq[0, -1].item() == eos_symbol)

        if not completed_sequences:
            completed_sequences = beam

        completed_sequences = sorted(completed_sequences, key=lambda c: c[1], reverse=True)
        return completed_sequences[:num_return]


In [43]:
# `config` is rebound from the class to its instance. The guard makes the cell
# safe to re-run: calling the instance would raise TypeError.
if isinstance(config, type):
    config = config()

In [44]:
# Encoder side: 2 decimal digits, one mantissa token per value; values whose
# token exponent falls below -100 are encoded as zero.
encoder_tokenizer = Encoder_tokeniser(
    float_precision=2,
    mantissa_len=1,
    max_exponent=100,
    vocab_path=config.encoder_vocab,
)
decoder_tokenizer = DecoderTokenizer(vocab_path=config.decoder_vocab)

In [45]:
def convert_to_functional_form(expr):
    """Render a SymPy expression as nested prefix function calls, ``name(arg, ...)``.

    Operator nodes are rendered with the name that the module-level ``operators`` mapping
    (SymPy class -> function name) gives their class, recursing into ``expr.args``.
    Leaves whose exact type is listed in ``numbers_types``, and SymPy symbols, are rendered
    with ``str``.

    Args:
        expr: a SymPy expression, or anything ``sp.sympify`` accepts (e.g. a string).

    Returns:
        The functional-form string.

    Raises:
        ValueError: if a node is neither a mapped operator, a listed number type, nor a Symbol.
    """
    expr = sp.sympify(expr)

    for sympy_type, name in operators.items():
        if isinstance(expr, sympy_type):
            rendered_args = ", ".join(convert_to_functional_form(arg) for arg in expr.args)
            return f"{name}({rendered_args})"

    # Exact type comparison: subclasses of the listed number types are not accepted here.
    if type(expr) in numbers_types:
        return str(expr)
    if isinstance(expr, sp.Symbol):
        return str(expr)

    # Report the offending node in the exception instead of printing it, so best-effort
    # callers that catch this error do not flood the output.
    raise ValueError(f"Unsupported expression type: {type(expr)} (expression: {expr!r})")

In [46]:
# Directory containing the Feynman equation data files.
INPUT_PATH = '/kaggle/input/gsoc-symba-task/Feynman_with_units/Feynman_with_units/'

# Filenames of the held-out Feynman equations used for evaluation.
test_file_paths = [
    'I.6.2a', 'I.12.5', 'I.18.14', 'I.39.1', 'I.43.16',
    'I.43.31', 'II.4.23', 'II.21.32', 'II.35.21', 'II.38.3',
]

In [47]:
# Architecture shared by the policy model and the DPO reference model; it must match the
# architecture stored in the checkpoint below.
model_kwargs = dict(
    num_encoder_layers=2,
    num_decoder_layers=6,
    emb_size=64,
    nhead=8,
    src_vocab_size=1104,
    tgt_vocab_size=59,
    input_emb_size=64,
    max_input_points=33,
)
model = Model_seq2seq(**model_kwargs)
reference_model = Model_seq2seq(**model_kwargs)

path = '/kaggle/input/gsoc-symba-seq2seq/default/best_checkpoint.pth'
# Load the checkpoint once; map_location lets a checkpoint saved on GPU load on a CPU-only kernel.
checkpoint = torch.load(path, map_location=device)
model.load_state_dict(checkpoint["state_dict"])
reference_model.load_state_dict(checkpoint["state_dict"])

# The reference model only scores sequences during DPO and is never optimized, so freeze it
# to avoid building autograd graphs and accumulating gradients in it.
for param in reference_model.parameters():
    param.requires_grad_(False)

<All keys matched successfully>

In [48]:
def generate_seed_expressions(model, dataset, indices, device, decoder_tokenizer, base_index=None):
    """Decode beam-search candidates from ``model`` into seed expressions for the GP.

    For each offset in ``indices`` the encoder input ``dataset[base_index + offset][0]`` is fed
    to ``model.beam_search``. Every returned candidate is decoded, converted to functional form,
    and rewritten to use the protected primitives (``protected_exp``, ``protected_div``,
    ``protected_sqrt``, ``protected_pow``). Candidates that fail to decode/convert are skipped.

    Args:
        model: seq2seq model exposing ``beam_search(src, src_mask, src_padding_mask)`` that
            returns a list of ``(sequence, score)`` pairs.
        dataset: indexable dataset whose items' first element is the encoder input tensor.
        indices: offsets, relative to ``base_index``, of the samples to decode.
        device: torch device for the model and its inputs.
        decoder_tokenizer: tokenizer providing ``decode`` and ``equation_decoder``.
        base_index: dataset index of the first sample of the current test file. Defaults to
            ``file_index * 1000`` using the notebook-level ``file_index``.

    Returns:
        List of expression strings (possibly empty). The model is left in eval mode.
    """
    if base_index is None:
        base_index = file_index * 1000

    model = model.to(device)
    # Beam search is inference: disable dropout so decoding is deterministic.
    model.eval()

    seed_expr = []
    with torch.no_grad():
        for offset in indices:
            src = dataset[base_index + offset][0].unsqueeze(0).to(device)
            src_seq_len = src.shape[1]
            src_mask = torch.zeros((src_seq_len, src_seq_len), device=device, dtype=torch.bool)
            src_padding_mask = torch.zeros((src.shape[0], src_seq_len), device=device, dtype=torch.bool)
            candidates = model.beam_search(src, src_mask, src_padding_mask)

            # Decode each returned candidate, not just the best one.
            for seq, _score in candidates:
                try:
                    # decode(...)[0] is presumably the token list of the single batch row;
                    # [1:-1] drops the leading <bos> and the final (<eos>) token.
                    tokens = decoder_tokenizer.decode(seq.to('cpu').numpy())[0][1:-1]
                    expr = decoder_tokenizer.equation_decoder([tokens])[0]
                    seed_expr.append(convert_to_functional_form(expr))
                except Exception:
                    # Best effort: malformed beams are expected; skip them.
                    continue

    # Map raw operator names onto the protected primitives available in the primitive set.
    # Word boundaries keep matches from landing inside longer identifiers.
    protected = {
        'exp': 'protected_exp',
        'div': 'protected_div',
        'sqrt': 'protected_sqrt',
        'pow': 'protected_pow'
    }
    for name, replacement in protected.items():
        pattern = re.compile(rf"\b{name}\b")
        seed_expr = [pattern.sub(replacement, expr) for expr in seed_expr]

    return seed_expr

In [68]:
def logabs(x1):
    """Natural log of |x1|; returns the int 1 (instead of -inf) when x1 == 0."""
    return 1 if x1 == 0 else math.log(abs(x1))

def n3(x1):
    """Cube of x1, without any overflow guard."""
    return pow(x1, 3)

def n4(x1):
    """Fourth power of x1, without any overflow guard."""
    return pow(x1, 4)

def protected_div(x1, x2):
    """x1 / x2, or 1.0 when |x2| is not above 0.001 (this includes NaN)."""
    if not abs(x2) > 0.001:
        return 1.
    try:
        return x1 / x2
    except ZeroDivisionError:
        return 1.

def protected_exp(x1):
    """exp(x1), or 0.0 when x1 is not below 100 (this includes NaN) or on overflow."""
    if not x1 < 100:
        return 0.0
    try:
        return math.exp(x1)
    except OverflowError:
        return 0.0

def protected_log(x1):
    """log(|x1|), or 0.0 when |x1| is not above 0.001 (this includes NaN)."""
    magnitude = abs(x1)
    if not magnitude > 0.001:
        return 0.
    try:
        return math.log(magnitude)
    except ValueError:
        return 0.

def protected_sqrt(x1):
    """Square root of |x1|, so negative inputs never raise."""
    magnitude = abs(x1)
    return math.sqrt(magnitude)

def protected_inv(x1):
    """1 / x1, or 0.0 when |x1| is not above 0.001 (this includes NaN)."""
    if not abs(x1) > 0.001:
        return 0.
    try:
        return 1. / x1
    except ZeroDivisionError:
        return 0.

def protected_expneg(x1):
    """exp(-x1), or 0.0 when x1 is not above -100 (this includes NaN) or on overflow."""
    if not x1 > -100:
        return 0.0
    try:
        return math.exp(-x1)
    except OverflowError:
        return 0.0

def _bounded_power(x1, exponent):
    """x1 ** exponent when |x1| < 1e6, otherwise 0.0 (also 0.0 for NaN)."""
    return x1 ** exponent if abs(x1) < 1e6 else 0.0

def protected_n2(x1):
    """Square of x1, or 0.0 for |x1| >= 1e6."""
    return _bounded_power(x1, 2)

def protected_n3(x1):
    """Cube of x1, or 0.0 for |x1| >= 1e6."""
    return _bounded_power(x1, 3)

def protected_n4(x1):
    """Fourth power of x1, or 0.0 for |x1| >= 1e6."""
    return _bounded_power(x1, 4)

def protected_pow(x1, x2):
    """math.pow(x1, x2), or the sentinel 1e7 when the power is undefined or overflows.

    ValueError covers negative bases with fractional exponents and 0 ** negative;
    ArithmeticError covers overflow. Other exceptions (e.g. KeyboardInterrupt) propagate
    instead of being silently swallowed.
    """
    try:
        return math.pow(x1, x2)
    except (ValueError, ArithmeticError):
        return 1e7

def protected_sigmoid(x1):
    """Logistic sigmoid 1 / (1 + exp(-x1)), correct over the whole real line.

    For negative inputs the algebraically equal form exp(x1) / (1 + exp(x1)) is used, so
    neither branch can overflow. The result tends to 0 for large negative x1 and to 1 for
    large positive x1.
    """
    if x1 >= 0:
        return 1 / (1 + math.exp(-x1))
    z = math.exp(x1)
    return z / (1 + z)

In [49]:
def make_pset(num_var):
    """Build the DEAP primitive set for symbolic regression over ``num_var`` inputs.

    Inputs are renamed ``s_1`` ... ``s_{num_var}``. Terminals are the integers 1-10 and ``pi``.
    Division, power, exp, log and sqrt are available only in their protected forms.
    """
    pset = gp.PrimitiveSet("MAIN", num_var)

    # Keep this registration order fixed: DEAP stores primitives in a list in registration
    # order and samples from it, so reordering changes the trees a given seed produces.
    unprotected_ops = [
        (operator.add, 2),
        (operator.sub, 2),
        (operator.mul, 2),
        (math.sin, 1),
        (math.cos, 1),
        (math.tan, 1),
        (operator.neg, 1),
        (abs, 1),
        (max, 2),
        (min, 2),
        (math.tanh, 1),
    ]
    protected_ops = [
        (protected_div, 2),
        (protected_pow, 2),
        (protected_exp, 1),
        (protected_log, 1),
        (protected_sqrt, 1),
    ]
    for primitive, arity in unprotected_ops + protected_ops:
        pset.addPrimitive(primitive, arity)

    for constant in range(1, 11):
        pset.addTerminal(constant)
    pset.addTerminal(math.pi, name="pi")

    pset.renameArguments(**{f"ARG{idx}": f"s_{idx + 1}" for idx in range(num_var)})
    return pset

In [50]:
def evalSymbReg(individual, points, toolbox):
    """Mean squared error of a GP individual over ``points``, as a DEAP fitness tuple.

    Args:
        individual: expression tree, compiled with ``toolbox.compile(expr=individual)``.
        points: sequence of ``(inputs, target)`` pairs; ``inputs`` is unpacked into the
            compiled function's arguments.
        toolbox: DEAP toolbox with a registered ``compile``.

    Returns:
        ``(mse,)``. Trees that raise a numeric error on some input (e.g. ``math.sin(inf)``
        or an overflowing sum), or whose error is not finite (inf/NaN), get ``(inf,)``.
        They then lose every comparison instead of aborting the GP run or poisoning
        selection with NaN. Empty ``points`` yield ``(0.0,)``.
    """
    func = toolbox.compile(expr=individual)
    n_points = len(points)
    try:
        # Each term is pre-divided by n; fsum keeps the accumulation exactly rounded.
        mse = math.fsum(((func(*x) - y) ** 2) / n_points for x, y in points)
    except (ValueError, OverflowError, ZeroDivisionError):
        return (math.inf,)
    if not math.isfinite(mse):
        return (math.inf,)
    return (mse,)

def e_lexicase_selection(individuals, k, points, toolbox=None):
    """Select ``k`` individuals (with replacement) by lexicase selection over ``points``.

    For each pick, the cases are visited in a fresh random order. The candidate pool is
    filtered to those with the lowest squared error on the current case, stopping early
    once one candidate remains; a random survivor is returned. Despite the ``e_`` prefix,
    survivors must tie the best error exactly (no epsilon tolerance).

    Args:
        individuals: candidate pool; not modified.
        k: number of selections.
        points: ``(inputs, target)`` cases; not modified (a shuffled copy is used).
        toolbox: forwarded to ``evalSymbReg``, which compiles individuals with it. It must be
            supplied for the evaluation to succeed.

    Returns:
        List of ``k`` selected individuals.
    """
    selected = []
    for _ in range(k):
        remaining = list(individuals)
        # Shuffle a copy so the caller's ``points`` (also used for fitness) keeps its order.
        cases = list(points)
        random.shuffle(cases)
        for point in cases:
            errors = [abs(evalSymbReg(ind, [point], toolbox)[0]) for ind in remaining]
            min_error = min(errors)
            remaining = [ind for ind, error in zip(remaining, errors) if error == min_error]
            if len(remaining) == 1:
                break
        selected.append(random.choice(remaining))
    return selected

# Build the initial GP population from predefined seed expressions.
def seed_population(pop_size, seed_exprs, pset, toolbox):
    """Build a GP population from seed expression strings, topped up with random individuals.

    Each seed is parsed with ``creator.Individual.from_string(expr, pset)``; seeds that fail to
    parse are skipped. The population is then filled with ``toolbox.individual()`` up to
    ``pop_size``. If more than ``pop_size`` seeds parse, all are kept, so the result can be
    larger than ``pop_size``.

    Args:
        pop_size: target population size.
        seed_exprs: expression strings to parse.
        pset: primitive set the strings are parsed against.
        toolbox: toolbox with a registered ``individual`` generator.

    Returns:
        List of individuals, seeds first.
    """
    population = []
    for expr in seed_exprs:
        try:
            population.append(creator.Individual.from_string(expr, pset))
        except Exception:
            # Best effort: decoded seeds may be malformed or use names absent from pset.
            continue
    count = len(population)
    print("The count of valid seed_expressions:- ")
    print(len(seed_exprs), count)
    population.extend(toolbox.individual() for _ in range(pop_size - count))
    return population

In [51]:
def run_genetic_programming(population, toolbox, ngen, cxpb, mutpb):
    """Evolve ``population`` with DEAP's eaSimple and track the single best individual.

    Every individual is evaluated up front, so eaSimple starts from valid fitnesses.
    Per-generation avg/std/min/max of fitness are printed (verbose mode).

    Returns:
        ``(final_population, statistics_object, hall_of_fame)``; the eaSimple logbook is discarded.
    """
    for individual in population:
        individual.fitness.values = toolbox.evaluate(individual)

    hall_of_fame = tools.HallOfFame(1)
    fitness_stats = tools.Statistics(lambda member: member.fitness.values)
    for label, reducer in (("avg", np.mean), ("std", np.std), ("min", np.min), ("max", np.max)):
        fitness_stats.register(label, reducer)

    population, _logbook = algorithms.eaSimple(
        population, toolbox, cxpb, mutpb, ngen,
        stats=fitness_stats, halloffame=hall_of_fame, verbose=True,
    )
    return population, fitness_stats, hall_of_fame

In [52]:
def generate_preference_pairs(population, points, toolbox=None):
    """Rank every pair of individuals by error, returning ``(better, worse)`` tuples.

    An individual's error is its cached ``fitness.values`` when valid (e.g. after DEAP has
    evaluated the population). Otherwise it is ``evalSymbReg(ind, points, toolbox)``, which
    needs ``toolbox``. Errors are computed once per individual rather than once per
    comparison. Pairs whose errors are equal (or NaN) are skipped; lower error is better.

    Returns:
        List of ``(better, worse)`` tuples, ordered by the pair's ``(i, j)``, ``i < j``,
        positions in ``population``.
    """
    def _error(ind):
        if ind.fitness.valid:
            return ind.fitness.values
        return evalSymbReg(ind, points, toolbox)

    errors = [_error(ind) for ind in population]
    pairs = []
    for i in range(len(population)):
        for j in range(i + 1, len(population)):
            if errors[i] < errors[j]:
                pairs.append((population[i], population[j]))
            elif errors[i] > errors[j]:
                pairs.append((population[j], population[i]))
    return pairs

In [53]:
class PreferenceDataset(Dataset):
    """Dataset of ``(better, worse)`` GP individuals encoded as decoder token sequences.

    Each individual is converted with ``convert_to_sympy_expression`` and ``sympy_to_prefix``,
    wrapped in ``<bos>``/``<eos>``, and encoded with the notebook-level ``decoder_tokenizer``.
    """

    def __init__(self, preference_pairs, src):
        # ``src`` is accepted but not used.
        self.pairs = preference_pairs

    def __len__(self):
        return len(self.pairs)

    def _encode(self, individual):
        # Encode one individual as a <bos> ... <eos> token sequence.
        prefix_tokens = sympy_to_prefix(convert_to_sympy_expression(individual))
        return decoder_tokenizer.encode([['<bos>'] + prefix_tokens + ['<eos>']])[0]

    def __getitem__(self, idx):
        better, worse = self.pairs[idx]
        return self._encode(better), self._encode(worse)

In [54]:
def dpo_loss(pi_logps, ref_logps, beta):
    """Direct Preference Optimization loss from chosen/rejected log-probabilities.

    Column 0 of each input holds the preferred ("better") expression, column 1 the
    rejected ("worse") one.

    Args:
        pi_logps: log-probabilities under the policy model, indexed ``[:, 0]`` / ``[:, 1]``.
        ref_logps: the same quantities under the frozen reference model.
        beta: temperature scaling the implicit KL penalty.

    Returns:
        ``(losses, rewards)``. ``losses`` is the per-row
        ``-logsigmoid(beta * (policy_margin - reference_margin))``. ``rewards`` is
        ``beta * (pi_logps - ref_logps)``, detached from the graph.
    """
    policy_margin = pi_logps[:, 0] - pi_logps[:, 1]
    reference_margin = ref_logps[:, 0] - ref_logps[:, 1]
    losses = F.logsigmoid(beta * (policy_margin - reference_margin)).neg()
    rewards = (beta * (pi_logps - ref_logps)).detach()
    return losses, rewards

In [55]:
def generate_square_subsequent_mask(sz, device):
    """Causal attention mask of shape (sz, sz): 0.0 on/below the diagonal, -inf above it."""
    allowed = torch.tril(torch.ones((sz, sz), device=device, dtype=torch.bool))
    return torch.zeros((sz, sz), device=device).masked_fill(~allowed, float('-inf'))

In [56]:
PAD_IDX = 0


def _collate_preference_batch(batch):
    """Pad a list of (better, worse) token sequences into two (batch, max_len) LongTensors."""
    better, worse = zip(*batch)
    better = pad_sequence([torch.as_tensor(seq, dtype=torch.long) for seq in better],
                          batch_first=True, padding_value=PAD_IDX)
    worse = pad_sequence([torch.as_tensor(seq, dtype=torch.long) for seq in worse],
                         batch_first=True, padding_value=PAD_IDX)
    return better, worse


def _sequence_logps(net, src, tgt, src_mask, src_padding_mask, device):
    """Teacher-forced log-probability of each target sequence under ``net``; returns (batch,).

    ``tgt`` is (batch, len) and starts with <bos>. The model reads ``tgt[:, :-1]`` and is scored
    on predicting ``tgt[:, 1:]``; padded target positions contribute 0.
    """
    tgt_input = tgt[:, :-1]
    tgt_output = tgt[:, 1:]
    tgt_mask = generate_square_subsequent_mask(tgt_input.shape[1], device)
    tgt_padding_mask = tgt_input == PAD_IDX
    logits = net(src,
                 tgt_input,
                 src_mask,
                 tgt_mask,
                 src_padding_mask,
                 tgt_padding_mask,
                 src_padding_mask)
    # assumes logits are (batch, len, vocab), matching the batch-first decoding in beam_search
    token_logps = logits.log_softmax(dim=-1).gather(-1, tgt_output.unsqueeze(-1)).squeeze(-1)
    return token_logps.masked_fill(tgt_output == PAD_IDX, 0.0).sum(dim=-1)


def train_transformer_dpo(model, reference_model, preference_pairs, src, epochs=10, batch_size=32, lr=1e-4, beta=1.0, device='cuda'):
    """Fine-tune ``model`` with Direct Preference Optimization on (better, worse) pairs.

    Every pair is conditioned on the same encoder input ``src``, broadcast over the batch.
    Sequence log-probabilities are teacher-forced sums over the target tokens (padding
    excluded). ``reference_model`` is put in eval mode and scored without gradients.

    Args:
        model: policy model, optimized with Adam.
        reference_model: frozen reference with the same call signature as ``model``.
        preference_pairs: list of ``(better, worse)`` individuals, encoded via PreferenceDataset.
        src: encoder input; assumes a leading batch dimension of 1 (e.g. from ``.unsqueeze(0)``).
        epochs, batch_size, lr, beta: optimization hyperparameters.
        device: torch device.

    Prints the mean loss per epoch. Returns None.
    """
    if not preference_pairs:
        print("No preference pairs to train on; skipping DPO update.")
        return

    dataset = PreferenceDataset(preference_pairs, src)
    # Sequences have different lengths, so each batch is padded with PAD_IDX.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            collate_fn=_collate_preference_batch)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model = model.to(device)
    reference_model = reference_model.to(device)
    reference_model.eval()
    src = src.to(device)
    # The encoder input is the same for every batch, so its attention mask is built once.
    src_mask = torch.zeros((src.shape[1], src.shape[1]), device=device, dtype=torch.bool)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for better, worse in dataloader:
            better = better.to(device)
            worse = worse.to(device)
            batch = better.shape[0]
            # Broadcast the single encoder input to the batch of target sequences.
            batch_src = src.expand(batch, *src.shape[1:])
            src_padding_mask = torch.zeros((batch, src.shape[1]), device=device, dtype=torch.bool)

            optimizer.zero_grad()

            pi_logps_better = _sequence_logps(model, batch_src, better, src_mask, src_padding_mask, device)
            pi_logps_worse = _sequence_logps(model, batch_src, worse, src_mask, src_padding_mask, device)
            with torch.no_grad():
                ref_logps_better = _sequence_logps(reference_model, batch_src, better, src_mask, src_padding_mask, device)
                ref_logps_worse = _sequence_logps(reference_model, batch_src, worse, src_mask, src_padding_mask, device)

            # (batch, 2): column 0 = better, column 1 = worse, as dpo_loss expects.
            pi_logps = torch.stack([pi_logps_better, pi_logps_worse], dim=1)
            ref_logps = torch.stack([ref_logps_better, ref_logps_worse], dim=1)

            losses, _ = dpo_loss(pi_logps, ref_logps, beta)
            loss = losses.mean()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(dataloader)}")

In [75]:
def main_training_loop(model, reference_model, dataset, src, points, decoder_tokenizer, num_vars,
                       num_cycles=5, pop_size=100, ngen=20, cxpb=0.5, mutpb=0.2, beta=1.0,
                       device='cuda', max_tree_height=17, num_seed_samples=25):
    """Alternate transformer-seeded genetic programming with DPO fine-tuning of the transformer.

    Each cycle:
      1. decode seed expressions from ``model``;
      2. evolve a seeded population on ``points`` with eaSimple;
      3. rank the final population into (better, worse) pairs;
      4. DPO-train ``model`` against ``reference_model``.

    Assumes ``creator.Individual`` was created earlier in the notebook.

    Args:
        max_tree_height: static height limit applied to crossover and mutation offspring.
        num_seed_samples: number of dataset samples decoded for seeds each cycle.
        (Remaining arguments are forwarded to the step functions.)

    Returns:
        ``(model, population, stats, hof)`` from the last cycle; the last three are None
        when ``num_cycles`` is 0.
    """
    # Drawn once, so every cycle seeds from the same dataset samples.
    random_numbers = [random.randint(0, 999) for _ in range(num_seed_samples)]

    population = stats = hof = None
    for cycle in range(num_cycles):
        print(f"Cycle {cycle+1}/{num_cycles}")

        seed_expr = generate_seed_expressions(model, dataset, random_numbers, device, decoder_tokenizer)

        pset = make_pset(num_vars)

        toolbox = base.Toolbox()
        toolbox.register("expr", gp.genHalfAndHalf, pset=pset, min_=1, max_=2)
        toolbox.register("individual", tools.initIterate, creator.Individual, toolbox.expr)
        toolbox.register("compile", gp.compile, pset=pset)
        # evalSymbReg compiles individuals through the toolbox, so it must be bound here.
        toolbox.register("evaluate", evalSymbReg, points=points, toolbox=toolbox)
        # NOTE(review): e_lexicase_selection evaluates individuals via evalSymbReg without
        # passing this toolbox, which evalSymbReg needs for compilation; thread the toolbox
        # through before relying on this selection operator.
        toolbox.register("select", lambda individuals, k: e_lexicase_selection(individuals, k, points))
        toolbox.register("mate", gp.cxOnePoint)
        toolbox.register("mutate", gp.mutUniform, expr=toolbox.expr, pset=pset)
        # Cap tree height: unbounded crossover/mutation bloat slows evaluation, and very deep
        # trees cannot be compiled by gp.compile.
        height_limit = gp.staticLimit(key=operator.attrgetter("height"), max_value=max_tree_height)
        toolbox.decorate("mate", height_limit)
        toolbox.decorate("mutate", height_limit)
        toolbox.register("map", map)

        population = seed_population(pop_size, seed_expr, pset, toolbox)

        population, stats, hof = run_genetic_programming(population, toolbox, ngen, cxpb, mutpb)

        preference_pairs = generate_preference_pairs(population, points)

        train_transformer_dpo(model, reference_model, preference_pairs, src, device=device, beta=beta)

        print("Best individual:", hof[0])
        print("Fitness:", hof[0].fitness.values)

    return model, population, stats, hof

In [76]:
file_index = 0
# Data file of the selected held-out equation. INPUT_PATH is the directory these files live in;
# os.path.join also avoids a doubled separator from its trailing slash.
test_path = os.path.join(INPUT_PATH, test_file_paths[file_index])

In [77]:
# Each non-blank line of the data file holds the input variable values followed by the target.
with open(test_path) as file:
    data = file.readlines()
arr = np.array([line.split() for line in data if line.strip()], dtype=np.float32)
if arr.ndim != 2 or arr.shape[0] == 0:
    raise ValueError(f"No data rows found in {test_path}")

num_vars = arr.shape[1] - 1
# (inputs, target) pairs; evalSymbReg calls the compiled tree as func(*inputs).
points = [(list(row[:-1]), row[-1]) for row in arr]

In [None]:
# NOTE(review): which sample conditions the DPO update is a modelling choice. An explicit
# index is used so this cell does not depend on a loop variable left over from another cell.
src_sample_index = 0
src = datasets['test'][file_index*1000 + src_sample_index][0].unsqueeze(0).to(device)

In [None]:
# Use the notebook's detected device so this cell also runs on CPU-only kernels.
final_model, final_population, final_stats, final_hof = main_training_loop(
    model, reference_model, datasets['test'], src, points, decoder_tokenizer, num_vars,
    num_cycles=5, pop_size=100, ngen=20, cxpb=0.5, mutpb=0.2, beta=1.0, device=device,
)

In [None]:
# Report the best expression found across all cycles and its fitness tuple.
best_individual = final_hof[0]
print("Final Best individual:", best_individual)
print("Final Fitness:", best_individual.fitness.values)