# Data Generator
Created May 2023 HBP<br>

In [9]:
import sys

try:
    
    from google.colab import drive
    drive.mount('/content/gdrive')
    
    BASE = '/content/gdrive/My Drive/transformer'
    sys.path.append(BASE)
    
    def pathname(filename):
        return f'{BASE:s}/{filename:s}'
    
    print('\nIn Google Colab\n')
    
except:
    
    BASE = ''
    def pathname(filename):
        return filename
    
    print('\nRunning locally\n')

import re
import numpy as np
import random as rn
import torch
from tqdm import tqdm

from IPython.display import display

from sympy import symbols, series, simplify, Rational, UnevaluatedExpr, \
    exp, cos, sin, tan, cosh, sinh, tanh, E, sympify
from sympy.abc import x
a,b,c,d,f,g,h,O = symbols('a,b,c,d,f,g,h,O', real=True)


Running locally



In [2]:
#2**19937-1
0.30103*19937

6001.63511

## Utilities

In [15]:
# pretty print symbolic expression
def pprint(expr):
    display(sympify(expr))

In [4]:
filename='seq2seq_data_2terms.txt'
N       = 50000
    
rnd= rn.randint
OP = ['+', '-', '/', '*'] 
EX = ['**2', '**3']
FN = ['exp', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh']
SC = ['a', 'b', 'c', 'd', 'f', 'g', 'h']

datafile= pathname(filename)

records = []

for n in tqdm(range(N)):
    
    expr = ''
    K = rnd(1, 2)
    
    for k in range(K):

        # choose function
        fn = rn.choice(FN)

        # choose coefficient of x
        nm = rn.choice(SC)

        arg= f'{nm:s}*x'

        expr += f'{fn:s}({arg:s})'

        # choose exponent of function
        u = rn.uniform(0, 1)
        if u < 0.2:
            expr += rn.choice(EX)

        # choose operator
        expr += rn.choice(OP)

    expr   = expr[:-1]
    s_expr = str(simplify(expr))

    # expand in Taylor series
    try:
        cmd    = f'y_expr=series({s_expr:s},x)'
        exec(cmd)
    except:
        continue
        
    y_expr = str(y_expr)

    if y_expr.find('x') < 0: continue

    s_expr = s_expr.replace(' ', '')
    y_expr = y_expr.replace(' ', '')
        
    expr   = f'{s_expr:s}|{y_expr:s}'
    line   = '%s\n' % expr
    records.append(line)

    open(datafile, 'a').write(line)

print()

100%|█████████████████████████████████████| 50000/50000 [40:18<00:00, 20.67it/s]







Ensure terms are ordered in increasing powers of $x$.

In [16]:
filename ='seq2seq_data_2terms.txt'
LOAD = True

if LOAD:
    records = open(filename).readlines()

# remove duplicate lines
print(len(records))
recs = list(set(records))
print(len(recs))

data = [x.strip().split('|') for x in recs]
getx = re.compile('x[*][*][1-9]|x')

records = []
j = 0
for i, (src, trg) in enumerate(data):
    if src.find('n*') > 0: continue
        
    trg = [str(x) for x in sympify(trg).as_ordered_terms()]
    
    Ord = trg[-1]    # O(x**6)
    trg = trg[:-1]   # all terms except O(x**6)
    if len(trg) == 0: continue
    
    xp  = [getx.findall(x) for x in trg]
    tt  = list(zip(xp, trg))
    tt.sort()
    
    xp, trg = zip(*tt)
    trg = '+'.join(trg)+'+'+Ord
    trg = trg.replace('+-', '-').replace(' ', '')
    record = f'{src:s}|{trg:s}'
            
    j += 1
    if j % 1000 == 0:
        print(f'\r{j:10d}', end='')
        
    if j <= 5:
        print(record)
        pprint(src)
        print('-'*20)
        pprint(trg)
        print('-'*60)
        
    records.append(f'{record:s}\n')
    
print()
print(f'saving {len(records):d} lines...')
open('seq2seq_series_2terms.txt', 'w').writelines(records)
print('done!')

51180
14635
cosh(a*x)**3+tanh(b*x)|1+b*x+3*a**2*x**2/2-b**3*x**3/3+7*a**4*x**4/8+2*b**5*x**5/15+O(x**6)


cosh(a*x)**3 + tanh(b*x)

--------------------


1 + b*x - b**3*x**3/3 + 2*b**5*x**5/15 + 3*a**2*x**2/2 + 7*a**4*x**4/8 + O(x**6)

------------------------------------------------------------
exp(b*x)*tanh(f*x)**3|f**3*x**3+b*f**3*x**4+x**5*(b**2*f**3/2-f**5)+O(x**6)


exp(b*x)*tanh(f*x)**3

--------------------


x**5*(b**2*f**3/2 - f**5) + f**3*x**3 + b*f**3*x**4 + O(x**6)

------------------------------------------------------------
sin(b*x)**3/tanh(c*x)**3|b**3/c**3+x**2*(-b**5/(2*c**3)+b**3/c)+x**4*(13*b**7/(120*c**3)-b**5/(2*c)+4*b**3*c/15)+O(x**6)


sin(b*x)**3/tanh(c*x)**3

--------------------


x**2*(-b**5/(2*c**3) + b**3/c) + x**4*(13*b**7/(120*c**3) - b**5/(2*c) + 4*b**3*c/15) + b**3/c**3 + O(x**6)

------------------------------------------------------------
exp(3*h*x)+tanh(h*x)|1+4*h*x+9*h**2*x**2/2+25*h**3*x**3/6+27*h**4*x**4/8+259*h**5*x**5/120+O(x**6)


exp(3*h*x) + tanh(h*x)

--------------------


1 + 4*h*x + 9*h**2*x**2/2 + 25*h**3*x**3/6 + 27*h**4*x**4/8 + 259*h**5*x**5/120 + O(x**6)

------------------------------------------------------------
cos(d*x)*tanh(g*x)|g*x+x**3*(-d**2*g/2-g**3/3)+x**5*(d**4*g/24+d**2*g**3/6+2*g**5/15)+O(x**6)


cos(d*x)*tanh(g*x)

--------------------


x**3*(-d**2*g/2 - g**3/3) + x**5*(d**4*g/24 + d**2*g**3/6 + 2*g**5/15) + g*x + O(x**6)

------------------------------------------------------------
     14000
saving 14367 lines...
done!
