# seq2seq: Data Preparation

Quarks To Cosmos with AI Virtual Conference: July 12-16, 2021, Carnegie Mellon University

## Contributors

Abdulhakim Alnuqaydan, Ali Kadhim, Sergei Gleyzer, Harrison Prosper

## Hackathon Contributors

Andrew Roberts, Jessica Howard, Samuel Hori, Arvind Balasubramanian, Xiaosheng Zhao, Michael Andrews

July 2021

## Description

Use an encoder/decoder model built using LSTMs to map symbolic mathematical expressions $f(x)$ to their Taylor series expansions to ${\cal O}(x^5)$.

We have borrowed heavily from Charon Guo's excellent tutorial at:

https://charon.me/posts/pytorch/pytorch_seq2seq_1/

### Data Preparation

This notebook performs the following tasks:
  1. Read the sequence pairs from __data/seq2seq_data.txt__.
  1. Exclude
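     1. sequences with complex numbers, and pairs whose source or Taylor series expansion is shorter than 5 or longer than 1000 characters.
     1. trivial source expressions (those in which x does not follow a call to exp, cos, sin, tan, ln, log, cosh, sinh, or tanh).
  1. Simplify the expressions with SymPy and write the filtered sequences to __data/seq2seq_data_count.txt__, where count is the number of sequence pairs written (e.g., 100, 10,000, or 60,000).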
  1. Write out __seq2sequtil.py__.
  1. Read the filtered data and delimit the source (i.e., input) and target (i.e., output) sequences with a tab at the start and a newline at the end of each sequence.

In [1]:
# Notebook setup: regular expressions, SymPy, NumPy and PyTorch.
import re
import sympy as sp
import numpy as np
import torch

# symbolic mathematics: SymPy functions and the symbol x used in the data
from sympy import exp, \
    cos, sin, tan, \
    cosh, sinh, tanh, ln, log, E
x = sp.Symbol('x')

# display() renders SymPy objects as formatted equations in the notebook
from IPython.display import display
    
# enable pretty printing of equations
sp.init_printing(use_latex='mathjax')

In [2]:
# Reproducing a subtle bug!
# Only the first piece of the pattern below carries the r prefix.  Adjacent
# string literals are concatenated, but "rawness" is per literal, so in the
# second piece '\b' is the backspace character (\x08), not the regex word
# boundary.  No source contains a backspace, so nothing matches and the
# result is [] (with r'...' on both pieces, 'a(4*x)' and '3(x)' would match).
data = [['a(4*x)','b(x)'], ['1(exp)','2'], ['3(x)','4']]
non_trivial = re.compile(r'(a|3)'\
                         '[(].*\bx\b')
# eliminate expressions that do not involve x, exp, cos etc.
data = filter(lambda d: len(non_trivial.findall(d[0])) > 0, data)
data = list(data)
data

[]

### Include non-trivial expressions involving only 'x', such as:
1. $(f(x)\pm g(x))^{h(x)}$
2. $\frac{h(x)}{(f(x)\pm g(x))}$

In [3]:
# Toy (source, target) pairs used to exercise the broader pattern below.
data = [
    ['a(4*x)', 'b(x)'], ['1(exp(x))', '2'], ['3(x+5)', '4'],
    ['sin(5*x)', '5'], ['x**5', 'x*5'], ['(1+x)**6', '4'],
    ['(1+x)/(2-x)', '4'], ['(1+x)/(2+1)', '4'],
]

# A source counts as non-trivial when any of three alternatives matches:
#   1. a function name and '(' followed, somewhere later, by the symbol x;
#   2. '(' ... a sign ... x ... ')**', i.e. a power of a sum/difference;
#   3. '/(' ... a sign immediately followed by x, i.e. a division by a
#      sum/difference in x.
non_trivial = re.compile(
    r'(exp|cos|sin|tan|ln|log|cosh|sinh|tanh)[(].*\bx\b|'
    r'[(].*[+-].*\bx\b.*[)][*][*]|'
    r'.*/[(].*[+-]\bx\b.*'
)

# Keep the pairs whose source matches somewhere.
data = [pair for pair in data if non_trivial.search(pair[0]) is not None]
data

[['1(exp(x))', '2'],
 ['sin(5*x)', '5'],
 ['(1+x)**6', '4'],
 ['(1+x)/(2-x)', '4']]

### Filter sequences

In [4]:
%%time
# Matches the trailing " + O(x**5...)" order term of a Taylor expansion so
# that it can be removed.
of_order = re.compile(' [+] O[(]x[*][*]5.*[)]')

# A source is "non-trivial" when one of the listed function names and '('
# are followed, somewhere later in the string, by the standalone symbol x.
# Every piece of an implicitly concatenated pattern needs its own r prefix:
# "rawness" does not carry over from one string literal to the next.
non_trivial = re.compile(
    r'(exp|cos|sin|tan|ln|log|cosh|sinh|tanh)'
    r'[(].*\bx\b'
)

# Broader alternative, currently unused, that also admits powers and
# quotients of sums/differences in x:
#   r'(exp|cos|sin|tan|ln|log|cosh|sinh|tanh)[(].*\bx\b|'
#   r'[(].*[+-].*\bx\b.*[)][*][*]|'
#   r'.*/[(].*[+-]\bx\b.*'

# Used to turn ".../seq2seq_data.txt" into ".../seq2seq_data_<count>.txt".
add_count = re.compile('_data')

def filterData(inpfile='data/seq2seq_data.txt',
               num_seq=60000,  # number of filtered sequence pairs
               min_len=5,      # minimum length of a sequence
               max_len=1000):  # maximum length of a sequence
    '''
    Filter, simplify and save the first num_seq usable sequence pairs.

    Each line of inpfile should have the form
        source expression<tab>target expression<newline>
    A pair is kept only if
      * the line contains no 'I' (expressions with complex numbers),
      * after removing the " + O(x**5...)" term the line has a tab,
      * the source matches the module-level pattern non_trivial, and
      * source and target both have between min_len and max_len
        characters (inclusive).  Lengths are measured before
        simplification, so simplified expressions can be longer.
    The kept pairs are simplified with sympy and written one per line to
    a file whose name is that of inpfile with "_data" replaced by
    "_data_<count>", e.g. data/seq2seq_data_100.txt.

    Raises ValueError if the file name of inpfile contains no "_data";
    the output would otherwise overwrite the input file.
    '''
    import os  # local import: the notebook's import cell has no os

    with open(inpfile) as f:
        lines = f.readlines()

    data = []
    for line in lines:
        # eliminate expressions involving complex numbers
        if 'I' in line:
            continue
        # strip away O(...) (of order ...) and split the pair at the tab
        pair = of_order.sub('', line).split('\t')
        if len(pair) < 2:
            continue  # malformed line without a target
        source, target = pair[0], pair[1]
        # keep source expressions that involve exp, cos, etc.,
        # that is, eliminate trivial expressions
        if not non_trivial.search(source):
            continue
        # keep expressions whose lengths lie within [min_len, max_len]
        if not (min_len <= len(source) <= max_len
                and min_len <= len(target) <= max_len):
            continue
        data.append((source, target))
        if len(data) == num_seq:
            break  # enough pairs; the rest of the file is not needed

    N = min(num_seq, len(data))
    data = data[:N]

    # Work out the output name before the slow simplification step.  Only
    # the file name (not the directory) is rewritten, and a name without
    # "_data" is rejected so that the input is never overwritten.
    head, tail = os.path.split(inpfile)
    new_tail = add_count.sub('_data_%d' % N, tail)
    if new_tail == tail:
        raise ValueError("input file name %r contains no '_data'; "
                         "refusing to overwrite it" % inpfile)
    outfile = os.path.join(head, new_tail)
    print('output file:', outfile)

    # Simplify both sides with sp.simplify.  Cheaper alternatives are
    # sp.cancel (cancel common factors in numerator and denominator) and
    # sp.trigsimp (simplify only the trigonometric parts); the author's
    # timings for one run were about 2.6 s, 23.6 s and 36.2 s for
    # cancel, trigsimp and simplify respectively.
    # NOTE(review): sympy parses these strings with sympify, which uses
    # eval(); only run this on trusted data files.
    out = []
    for source, target in data:
        # all whitespace is removed from the simplified source; the target
        # keeps sympy's spacing and is terminated by a newline
        src = ''.join(str(sp.simplify(source)).split())
        tgt = str(sp.simplify(target)) + '\n'
        out.append('\t'.join([src, tgt]))

    with open(outfile, 'w') as f:
        f.writelines(out)
    
# Quick run: keep the first 100 usable pairs.  Larger runs:
filterData(inpfile='data/seq2seq_data.txt', num_seq=100)
# filterData(num_seq=10000)
# filterData(num_seq=60000)

output file: data/seq2seq_data_100.txt
CPU times: user 59.5 s, sys: 54.8 ms, total: 59.5 s
Wall time: 59.5 s


### Map sequences to lists of indices

  1. Split data into a train, validation, and test set.
  1. Create a token (i.e., a character) to index map from training data.
  1. Map sequences to arrays of indices.
  1. Implement custom DataLoader.

In [5]:
%%writefile seq2sequtil.py
import numpy as np
import re
import torch
from IPython.display import display

# symbolic symbols
# Apart from Symbol, these sympy names (and x) are not used directly by the
# code below.  Together with x, they let eval() in pprint(), which runs in
# this module's namespace, turn expression strings such as 'cosh(9*x-7)'
# into sympy objects.
from sympy import Symbol, exp, \
    cos, sin, tan, \
    cosh, sinh, tanh, ln, log, E
x = Symbol('x')

class Seq2SeqDataPreparer:
    '''
    This class maps the source (i.e., input) and target (i.e., output)
    expressions into sequences of token indices, grouped into blocks.
    The source data are split into x_train, x_valid, and x_test sets and
    similarly for the target data.

    Create a data preparer using

    dd = Seq2SeqDataPreparer(X, Y, fractions)

    where,

      X, Y:         lists of source and target strings, each delimited
                    by a leading tab and a trailing newline (see loadData).
      fractions:    a 2-sequence of cut points that define the three-way
                    split of the data, e.g.: (50/60, 55/60) splits 60,000
                    pairs as (50000, 5000, 5000).

    Tokenization (see split_expr): each of exp, cos, sin, tan, ln, log,
    cosh, sinh, tanh, '**', the digits, 'x', tab and newline is one token,
    as is each of the characters ( ) * + , - . /; any other run of
    characters between these is kept as a single token.  The token maps
    are built from the training split only, so tokens that occur only in
    the validation or test splits are coded as '?'.
    '''

    # Multi-character tokens that each map to a single index.
    SPECIAL_TOKENS = ('exp', 'cos', 'sin', 'tan', 'ln', 'log',
                      'cosh', 'sinh', 'tanh', '**')

    # Default atomic tokens used by split_expr.
    INDIVIDUAL_TOKENS = (SPECIAL_TOKENS + tuple('1234567890')
                         + ('x', '\n', '\t'))

    # Compiled split patterns, keyed by the tuple of atomic tokens.
    _split_regex_cache = {}

    def __init__(self, X, Y,
                 fractions=(50/60, 55/60)):

        self.fractions = fractions

        # Maximum sequence lengths in tokens ("sin" counts as one token).
        self.x_max_seq_len = max((self.split_expr(z)[1] for z in X),
                                 default=0)
        self.y_max_seq_len = max((self.split_expr(z)[1] for z in Y),
                                 default=0)

        # Indices at which the data are cut into train, valid and test.
        N = int(len(X)*fractions[0])
        M = int(len(X)*fractions[1])

        # Token <-> index maps for the source and target sequences,
        # built from the training data.
        self.x_token2index, self.x_index2token = \
            self.token_tofrom_index(X[:N])
        self.y_token2index, self.y_index2token = \
            self.token_tofrom_index(Y[:N])

        # Structure data into a list of blocks, where each block
        # comprises a tuple (x_data, y_data) whose elements have
        #   x_data.shape: (x_seq_len, batch_size)
        #   y_data.shape: (y_seq_len, batch_size)
        #
        # The sequence and batch sizes can vary from block to block.
        self.train_data, self.n_train = self.code_data(X[:N],  Y[:N])
        self.valid_data, self.n_valid = self.code_data(X[N:M], Y[N:M])
        self.test_data,  self.n_test  = self.code_data(X[M:],  Y[M:])

    def __len__(self):
        # total number of sequence pairs over all three splits
        return self.n_train + self.n_valid + self.n_test

    def __str__(self):
        s  = ''
        s += 'number of seq-pairs (train): %8d\n' % self.n_train
        s += 'number of seq-pairs (valid): %8d\n' % self.n_valid
        s += 'number of seq-pairs (test):  %8d\n' % self.n_test
        s += '\n'
        s += 'number of source tokens:     %8d\n' % \
        len(self.x_token2index)
        s += 'max source sequence length:  %8d\n' % \
        self.x_max_seq_len

        try:
            s += '\n'
            s += 'number of target tokens:     %8d\n' % \
            len(self.y_token2index)
            s += 'max target sequence length:  %8d' % \
            self.y_max_seq_len
        except AttributeError:
            pass  # target maps not built (e.g. object made via __new__)

        return s

    def num_tokens(self, which='source'):
        # 'source'/'input' -> source vocabulary size, otherwise target
        if which[0] in ['s', 'i']:
            return len(self.x_token2index)
        else:
            return len(self.y_token2index)

    def max_seq_len(self, which='source'):
        # 'source'/'input' -> longest source sequence, otherwise target
        if which[0] in ['s', 'i']:
            return self.x_max_seq_len
        else:
            return self.y_max_seq_len

    def decode(self, indices):
        # Map a sequence of target-token indices to a string.  int() lets
        # this accept numpy integers and torch scalar tensors alike.
        return ''.join(self.y_index2token[int(i)] for i in indices)

    def token_tofrom_index(self, expressions,
                           special_expressions_str=SPECIAL_TOKENS):
        '''
        Build the token <-> index maps for a list of expressions.

        The vocabulary consists of ' ' (padding), '?' (unknown token),
        every entry of special_expressions_str, and every token that
        split_expr produces for the given expressions.  Because the
        vocabulary comes from the same tokenizer that make_block uses,
        every token seen here can be coded.  Tokens are sorted before
        indices are assigned.

        Returns (token2index, index2token).
        '''
        tokens = {' ', '?'}
        tokens.update(special_expressions_str)
        for expression in expressions:
            tokens.update(self.split_expr(expression)[0])
        tokens = sorted(tokens)

        token2index = {token: i for i, token in enumerate(tokens)}
        index2token = dict(enumerate(tokens))
        return (token2index, index2token)

    def get_block_indices(self, X, Y):
        '''
        Group the positions of the pairs (X[i], Y[i]) into blocks in which
        all targets have the same length and all sources have the same
        length (in tokens).

        Following Michael Andrews' suggestion, the pairs are sorted by
        target length, then source length, then position, so the blocks
        come out in that order.  Returns a list of lists of positions;
        empty input gives an empty list.
        '''
        sizes = sorted((self.split_expr(a)[1], self.split_expr(b)[1], i)
                       for i, (a, b) in enumerate(zip(Y, X)))

        block_indices = []
        previous = None
        for n, m, i in sizes:  # n, m = len(target), len(source)
            if (n, m) != previous:
                # found a new boundary, so start a new block
                block_indices.append([])
                previous = (n, m)
            block_indices[-1].append(i)

        return block_indices

    def make_block(self, expressions, indices, token2index, unknown):
        '''
        Code the expressions at the given positions into an int64 array
        of shape (seq_len, batch_size), one column per expression.

        By construction all expressions of a block have the same length,
        so the length of the first one is used.  Tokens missing from
        token2index are coded as `unknown`.
        '''
        batch_size = len(indices)
        seq_len = self.split_expr(expressions[indices[0]])[1]

        # explicit int64: 'long' is only 32 bits on some platforms
        data = np.zeros((seq_len, batch_size), dtype=np.int64)

        # m: position of the expression within the block
        # k: position of the expression in the original list
        # n: position of the token within the expression
        for m, k in enumerate(indices):
            tokens = self.split_expr(expressions[k])[0]
            for n, token in enumerate(tokens):
                data[n, m] = token2index.get(token, unknown)

        return data

    def code_data(self, X, Y):
        '''
        Implement Arvind's idea: code the delimited strings X, Y as a
        list of (x_data, y_data) blocks of equal-length sequences.

        Returns (blocks, number of sequence pairs coded).
        '''
        x_unknown = self.x_token2index['?']
        y_unknown = self.y_token2index['?']

        blocks = []
        n_data = 0
        for indices in self.get_block_indices(X, Y):
            x_data = self.make_block(X, indices,
                                     self.x_token2index, x_unknown)
            y_data = self.make_block(Y, indices,
                                     self.y_token2index, y_unknown)
            blocks.append((x_data, y_data))
            n_data += len(indices)

        assert n_data == len(X)

        return blocks, n_data

    def split_expr(self, expression, individual_tokens=INDIVIDUAL_TOKENS):
        '''
        Split an expression into tokens.

        Each entry of individual_tokens and each of the characters
        ( ) * + , - . / is a token; any other run of characters between
        them is kept as one token.  Returns (tokens, len(tokens)); the
        length counts e.g. "sin" as 1 and is used for block generation.
        '''
        regex = self._split_regex(individual_tokens)
        tokens = [t for t in regex.split(expression) if t != '']
        return tokens, len(tokens)

    @classmethod
    def _split_regex(cls, individual_tokens):
        # Return (and cache) the compiled split pattern for these tokens.
        key = tuple(individual_tokens)
        regex = cls._split_regex_cache.get(key)
        if regex is None:
            # Longest tokens first: re takes the first alternative that
            # matches, so 'cos' listed before 'cosh' would split "cosh"
            # into 'cos' + 'h'.
            ordered = sorted(key, key=len, reverse=True)
            alternatives = '|'.join(re.escape(t) for t in ordered)
            # [()*+,\-./] is the same set as the original [()+-/*], whose
            # '+-/' is a character range that also covers ',' and '.'.
            regex = re.compile('(' + alternatives + r'|[()*+,\-./])')
            cls._split_regex_cache[key] = regex
        return regex
    
class Seq2SeqDataLoader:
    '''
    Iterate over blocks of coded sequence pairs as torch tensors.

    dataloader = Seq2seqDataLoader(dataset, device, sample=True)

    dataset is a list of (x_data, y_data) blocks, e.g. the train_data of
    a Seq2SeqDataPreparer.  Each step yields (X, Y) tensors created on
    `device` from the block arrays, with shape (seq_len, batch_size).
    With sample=True each step picks a block uniformly at random (with
    replacement); with sample=False the blocks are returned in order.
    One pass has max_count steps (by default one per block); after the
    pass the counter resets, so the loader can be iterated again.
    '''
    def __init__(self, dataset, device, sample=True):
        self.dataset = dataset
        self.device  = device
        self.sample  = sample
        # Pass `sample` on explicitly so the constructor's choice is kept.
        self.init(sample=sample)

    def __iter__(self):
        return self

    def __next__(self):
        # increment iteration counter
        self.count += 1

        if self.count > self.max_count:
            self.count = 0
            raise StopIteration

        # 1. randomly pick a block or return blocks in order
        #    (count starts at 1, hence the subtraction).
        if self.sample:
            k = np.random.randint(len(self.dataset))
        else:
            k = self.count - 1

        # 2. create tensors directly on the device of interest
        block = self.dataset[k]
        X = torch.tensor(block[0], device=self.device)
        Y = torch.tensor(block[1], device=self.device)

        # shape of X and Y: (seq_len, batch_size)
        return X, Y

    def init(self, max_count=0, sample=None):
        '''
        Reset the iteration counter.

        max_count < 1 means one step per block; otherwise the number of
        steps is min(max_count, number of blocks).  sample=None keeps the
        current sampling mode; True/False switches it.
        '''
        n_data = len(self.dataset)
        self.max_count = n_data if max_count < 1 else min(max_count,
                                                          n_data)
        if sample is not None:
            self.sample = sample
        self.count = 0
        
# Delimit each sequence in filtered sequences
# The start of sequence (SOS) and end of sequence (EOS) 
# tokens are "\t" and "\n", respectively.

def loadData(inpfile):
    '''
    Read filtered sequence pairs and delimit each sequence.

    inpfile holds one pair per line, in the format
        input expression<tab>target expression<newline>
    Each source and target is wrapped with a tab at the start (start of
    sequence) and a newline at the end (end of sequence); all whitespace
    is removed from the targets.  Blank lines are skipped.  The last pair,
    if any, is printed and pretty-printed as an example.

    Returns (X, Y), the lists of delimited sources and targets.
    '''
    with open(inpfile) as f:
        data = [line.split('\t') for line in f if line.strip()]

    X, Y = [], []
    for x, y in data:
        X.append('\t%s\n' % x)
        # get rid of spaces in target sequence
        y = ''.join(y.split())
        Y.append('\t%s\n' % y)

    if X:
        # NOTE(review): pprint() eval()s these strings; use trusted files only.
        print('Example source:')
        print(X[-1])
        pprint(X[-1])
        print('Example target:')
        print(Y[-1])
        pprint(Y[-1])

    return (X, Y)

def loadData_simplify(inpfile):
    '''
    Like loadData, but show the last (up to) ten pairs together with
    their sympy-simplified forms.  (Contributed by Xiaosheng.)

    inpfile holds one pair per line, in the format
        input expression<tab>target expression<newline>
    Blank lines are skipped.  Returns (X, Y), the lists of sources and
    targets, each delimited by a leading tab and a trailing newline, with
    all whitespace removed from the targets.
    '''
    with open(inpfile) as f:
        data = [line.split('\t') for line in f if line.strip()]

    X, Y = [], []
    for x, y in data:
        X.append('\t%s\n' % x)
        # get rid of spaces in target sequence
        y = ''.join(y.split())
        Y.append('\t%s\n' % y)

    examples = list(zip(X[-10:], Y[-10:]))
    if examples:
        # seq2sequtil.py imports only selected names from sympy, not the
        # module itself, so bring it in here (only when it is needed).
        import sympy as sp

    for source, target in examples:
        # simplify the expressions (delimiters stripped first)
        X_simp = sp.simplify(source.strip())
        targets_simp = sp.simplify(target.strip())

        # print the sources
        print('Example source:')
        print(source)
        pprint(source)
        print('\033[1;34m'+'Simplified example source:'+'\033[0m')
        print(X_simp)
        pprint(str(X_simp))

        # print the targets
        print('Example target:')
        print(target)
        pprint(target)
        print('\033[1;34m'+ 'Simplified example target:'+'\033[0m')
        print(targets_simp)
        pprint(str(targets_simp))

        print('\033[1;31m'+'*'*100+'\033[0m')

    return (X, Y)

def pprint(expr):
    # Evaluate an expression string in this module's namespace (where x,
    # exp, cos, ... are sympy objects) and render the result with display().
    # NOTE(review): eval() runs arbitrary code; expr comes from the data
    # files, so only call this on trusted input.
    rendered = eval(expr)
    display(rendered)

Overwriting seq2sequtil.py


#### Load the data and display an example sequence pair

In [6]:
import importlib
import seq2sequtil as sq

# Reload so that the seq2sequtil.py just written by %%writefile is used
# without restarting the kernel.
sq = importlib.reload(sq)
# inputs, targets = sq.loadData('data/seq2seq_data_10000.txt')
inputs, targets = sq.loadData('data/seq2seq_data_100.txt')

Example source:
	cosh(9*x-7)



cosh(9⋅x - 7)

Example target:
	2187*x**4*cosh(7)/8-243*x**3*sinh(7)/2+81*x**2*cosh(7)/2-9*x*sinh(7)+cosh(7)



      4                3               2                                
2187⋅x ⋅cosh(7)   243⋅x ⋅sinh(7)   81⋅x ⋅cosh(7)                        
─────────────── - ────────────── + ───────────── - 9⋅x⋅sinh(7) + cosh(7)
       8                2                2                              

### Check data preparer

In [7]:
fractions = [8/10, 9/10]  # cut points: 80% train, 10% valid, 10% test
db = sq.Seq2SeqDataPreparer(X=inputs, Y=targets, fractions=fractions)
print(str(db))

number of seq-pairs (train):       80
number of seq-pairs (valid):       10
number of seq-pairs (test):        10

number of source tokens:           31
max source sequence length:        74

number of target tokens:           32
max target sequence length:      3881


### Test tokenization

In [8]:
# Index -> token maps of the source and target vocabularies.
print(db.x_index2token)
print(db.y_index2token)
test_expression = '\texp(-1)-5*x*exp(-1)+25*x**2*exp(-1)/2-125*x**3*exp(-1)/6+625*x**4*exp(-1)/24\n'

# Code the data: a one-expression block of shape (seq_len, 1).
unknown_index = db.y_token2index['?']
data = db.make_block(expressions=[test_expression], indices=[0],
                     token2index=db.y_token2index, unknown=unknown_index)
column = data[:, 0]
print(column)
# Test by decoding the coded data: the result should equal test_expression.
db.decode(indices=column)

{0: '\t', 1: '\n', 2: ' ', 3: '(', 4: ')', 5: '*', 6: '**', 7: '+', 8: '-', 9: '/', 10: '0', 11: '1', 12: '2', 13: '3', 14: '4', 15: '5', 16: '6', 17: '7', 18: '8', 19: '9', 20: '?', 21: 'cos', 22: 'cosh', 23: 'exp', 24: 'ln', 25: 'log', 26: 'sin', 27: 'sinh', 28: 'tan', 29: 'tanh', 30: 'x'}
{0: '\t', 1: '\n', 2: ' ', 3: '(', 4: ')', 5: '*', 6: '**', 7: '+', 8: '-', 9: '/', 10: '0', 11: '1', 12: '2', 13: '3', 14: '4', 15: '5', 16: '6', 17: '7', 18: '8', 19: '9', 20: '?', 21: 'E', 22: 'cos', 23: 'cosh', 24: 'exp', 25: 'ln', 26: 'log', 27: 'sin', 28: 'sinh', 29: 'tan', 30: 'tanh', 31: 'x'}
[ 0 24  3  8 11  4  8 15  5 31  5 24  3  8 11  4  7 12 15  5 31  6 12  5
 24  3  8 11  4  9 12  8 11 12 15  5 31  6 13  5 24  3  8 11  4  9 16  7
 16 12 15  5 31  6 14  5 24  3  8 11  4  9 12 14  1]


'\texp(-1)-5*x*exp(-1)+25*x**2*exp(-1)/2-125*x**3*exp(-1)/6+625*x**4*exp(-1)/24\n'

### Check data loader 

In [8]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

train_loader = sq.Seq2SeqDataLoader(db.train_data, device)

print('%5s\t%-20s\t%-20s' % ('block', 'X.shape', 'Y.shape'))
n = 0
for i, (X, Y) in enumerate(train_loader):
    n += 1
    # report the shapes of every 1000th block only
    if not i % 1000:
        print('%5d\t%-20s\t%-20s' % (i, X.shape, Y.shape))
print("\nnumber of blocks: %d" % n)

block	X.shape             	Y.shape             
    0	torch.Size([21, 1]) 	torch.Size([71, 1]) 
 1000	torch.Size([27, 1]) 	torch.Size([27, 1]) 
 2000	torch.Size([11, 1]) 	torch.Size([50, 1]) 
 3000	torch.Size([36, 1]) 	torch.Size([183, 1])
 4000	torch.Size([24, 1]) 	torch.Size([199, 1])

number of blocks: 4058
