cloneofsimo/poly2SOP

poly2SOP

A Transformer that takes a polynomial and expresses it as a sum of powers. Implemented with PyTorch.

Introduction

Tasks such as https://en.wikipedia.org/wiki/Sums_of_three_cubes and https://en.wikipedia.org/wiki/Taxicab_number, and conjectures about non-trivial integer solutions of similar equations, require a deeper understanding of polynomials' characteristics.

For example, if one finds a non-trivial polynomial with no odd-degree terms that is a sum of two quintics, then a clever substitution turns it into a parameterized family of solutions.

This approach was first used by Ramanujan to obtain a set of parameterized solutions.
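One way to make the substitution idea concrete is sketched below. Two things here are assumptions for illustration only: that the target equation has the form a^5 + b^5 = c^5 + d^5, and that the substitution is x -> -x. The symbols P, f, g and t are hypothetical names.

```latex
% Assumption: P is a sum of two quintics with only even-degree terms.
P(x) = f(x)^5 + g(x)^5, \qquad P(-x) = P(x)
\;\Longrightarrow\;
f(x)^5 + g(x)^5 = f(-x)^5 + g(-x)^5 .

% Evaluating at an integer t would then give a candidate solution:
a = f(t), \quad b = g(t), \quad c = f(-t), \quad d = g(-t),
\qquad a^5 + b^5 = c^5 + d^5 .
```

Under this reading, the solution is non-trivial only if the pair (f(-x), g(-x)) is not just a reordering of (f(x), g(x)).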

Inspired by https://arxiv.org/pdf/1912.01412.pdf, I wondered whether deep neural networks are capable of such natural polynomial manipulations.

Here, we implement a seq2seq model using PyTorch's built-in TransformerEncoder and TransformerDecoder modules. Datasets were created with SymPy.

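import torch
from models import SOP

chars = list("0987654321-+*()^xy")
n_vocab = len(chars) + 2  # one extra index for padding, one for the init token
max_len = 64  # placeholder value, not taken from the repository
# Fall back to the CPU when no GPU is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = SOP(
    d_model = 512,
    n_head = 8,
    num_layers = 6,
    n_vocab = n_vocab,
    max_len = max_len,
    chars = chars,
    device = device
)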
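To illustrate how a character vocabulary like the one above can be turned into model inputs, here is a minimal sketch of a character-level encoder and decoder. The index layout (characters first, then padding, then the init token) and the names `encode`, `decode`, `PAD_IDX` and `INIT_IDX` are assumptions for illustration; the repository's own tokenization may differ.

```python
# Hypothetical index layout: characters first, then padding, then the init token.
CHARS = list("0987654321-+*()^xy")
PAD_IDX = len(CHARS)       # assumption: index used for padding
INIT_IDX = len(CHARS) + 1  # equals n_vocab - 1, the value prepended in the training loop
N_VOCAB = len(CHARS) + 2

CHAR_TO_IDX = {c: i for i, c in enumerate(CHARS)}


def encode(expr: str, max_len: int) -> list[int]:
    """Map an expression string to a fixed-length list of token indices."""
    if len(expr) > max_len:
        raise ValueError(f"expression longer than max_len={max_len}")
    ids = [CHAR_TO_IDX[c] for c in expr]  # raises KeyError on unknown characters
    return ids + [PAD_IDX] * (max_len - len(ids))


def decode(ids: list[int]) -> str:
    """Invert encode(), dropping padding and init tokens."""
    return "".join(CHARS[i] for i in ids if i < len(CHARS))


tokens = encode("2*y^4-2*y^3-y^2+1", max_len=24)
print(tokens)
print(decode(tokens))
```

Padding to a fixed length keeps every sample the same shape, which is what allows a DataLoader to stack samples into a batch.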

A simple use case with the dataset I created can also be found in the repository:

...
chars = list("0987654321-+*()^xy")
n_vocab = len(chars) + 2

model = SOP(
    d_model = 512,
    n_head = 8,
    num_layers = 6,
    n_vocab = n_vocab, 
    max_len = max_len, 
    chars = chars,
    device = device
)

opt = optim.AdamW(model.parameters(), lr = lr, weight_decay = 1e-10)
dataset = eq_dataset(max_len = max_len, chars = chars)
dl = DataLoader(dataset, shuffle= True, batch_size= batch_size,  drop_last= True, num_workers = 3)
criterion = nn.CrossEntropyLoss()
model.to(device)

for epoch in range(1, epochs + 1):
    pbar = tqdm(dl)
    tot_loss = 0
    cnt = 0
    for (x, yin, yout) in pbar:
        x = x.to(device)
        # Prepend the init token (index n_vocab - 1) to the decoder input.
        yin = torch.cat([torch.ones(batch_size, 1) * (n_vocab - 1), yin], dim = 1).long()
        yin = yin.to(device)
        yout = yout.to(device)
        y_pred = model(x, yin)

        # Flatten batch and sequence dimensions before computing the token-level loss.
        loss = criterion(y_pred.view(-1, n_vocab - 1), yout.view(-1))
        model.zero_grad()
        loss.backward()
        opt.step()
        tot_loss += loss.item()
        cnt += 1
        pbar.set_description(f"current loss : {tot_loss/cnt:.5f}")

    # Sanity check after every epoch: `ans` is one valid decomposition of `eq`.
    eq = "2*y^4-2*y^3-y^2+1"
    ans = "(1-y^2)^2+(-y^2+y)^2"

    ral = model.toSOP(eq, gen_len = max_len - 1)
    print(f'Epoch {epoch} : Loss : {tot_loss/cnt :.5f}, Example : {ral[0]}')
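A predicted decomposition does not have to match the reference string character by character to be correct, so it can help to check it numerically. The sketch below is one possible check and is not part of the repository: it evaluates both strings at random integer points and reports whether they agree everywhere. The helper names `evaluate` and `probably_equal` are hypothetical, and agreement at random points is only strong evidence of equality, not a proof.

```python
import random

ALLOWED = set("0987654321-+*()^xy")


def evaluate(expr: str, x: int, y: int):
    """Evaluate a polynomial string written in the README's '^' notation."""
    if not set(expr) <= ALLOWED:
        raise ValueError("expression contains characters outside the vocabulary")
    # Python spells exponentiation '**'; builtins are disabled for the eval call.
    return eval(expr.replace("^", "**"), {"__builtins__": {}}, {"x": x, "y": y})


def probably_equal(lhs: str, rhs: str, trials: int = 20, bound: int = 50, seed: int = 0) -> bool:
    """Return True if both expressions agree at `trials` random integer points."""
    rng = random.Random(seed)
    try:
        for _ in range(trials):
            x, y = rng.randint(-bound, bound), rng.randint(-bound, bound)
            if evaluate(lhs, x, y) != evaluate(rhs, x, y):
                return False
    except (SyntaxError, ValueError, TypeError, ZeroDivisionError):
        # Malformed model output counts as a failed prediction.
        return False
    return True


eq = "2*y^4-2*y^3-y^2+1"
ans = "(1-y^2)^2+(-y^2+y)^2"
print(probably_equal(eq, ans))
```

In the training loop above, the same check could be applied to `eq` and `ral[0]`, assuming `ral[0]` is a plain string in the same notation.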

You can easily create your own dataset with SymPy.

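from random import randint as ri  # assumption: `ri` is an alias for random.randint
from sympy import symbols, expand

x, y = symbols('x y')
pol = [x, y, 1, x*y, x*x, y*y]  # monomials of degree at most 2
n = 2  # exponent used in the sum of powers

def random_function(cr = 2):
    # Random polynomial with integer coefficients drawn from [-cr, cr].
    f = 0
    for mo in pol:
        f = f + mo*ri(-cr, cr)
    return expand(f)

# Later on...
f1, f2 = random_function(), random_function()
f3 = f1**n + f2**n  # target: sum-of-powers form
f4 = expand(f3)     # input: expanded polynomial
# FILE_x and FILE_y are file handles opened for writing elsewhere.
FILE_x.write(str(f4).replace(' ', '').replace('**', '^') + '\n')
FILE_y.write(str(f3).replace(' ', '').replace('**', '^') + '\n')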
