In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import math

from torch import optim
from tqdm import tqdm

In [2]:
from cptn import *

In [3]:
Xs = torch.tensor([[0.,0.],[1.,0.],[0.,1.],[1.,1.]])

In [4]:
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.builder = CTNHelper()
        self.builder.add("in_node1", (2,2),   ["in1", "u1"])
        self.builder.add("in_node2", (2,2), ["in2", "u2"])
        
        self.builder.add("mid_node1", (2,2, 2,2), ["u1", "u2", "o1", "o2"])
        
        self.builder.add("out_node1", (2,2), ["o1", "o2"])
        
        self.params = nn.ParameterDict(map(toparam, self.builder.to_dict().items()))
            
    def forward(self, x):
        x1, x2 = x.permute(1, 0)
        
        x1, x2 = x1.view(-1, 1), x2.view(-1, 1)
        
        in_node1 = batch_eval_poly(legendrend_to_pbasis(self.params["in_node1"]), x1, "in1")
        in_node2 = batch_eval_poly(legendrend_to_pbasis(self.params["in_node2"]), x2, "in2")
        
        inner_int = batch_poly_mul(in_node1.align_to(*self.builder.all_axes_with_batch), legendrend_to_pbasis(self.params["mid_node1"]))
        merged_mid_in1 = batch_defn_integral(inner_int, (-1, 1), "u1").align_to(*self.builder.all_axes_with_batch)
                
        merged_mid_in1_in2 = batch_defn_integral(batch_poly_mul(in_node2.align_to(*self.builder.all_axes_with_batch), merged_mid_in1), (-1, 1), "u2").align_to(*self.builder.all_axes_with_batch)
        merged_mid_in1_in2_out = batch_defn_integral(batch_defn_integral(batch_poly_mul(merged_mid_in1_in2, legendrend_to_pbasis(self.params["out_node1"])), (-1, 1), "o1"), (-1, 1), "o2")
        
        return keep_axes(merged_mid_in1_in2_out, ["batch"])

In [5]:
model = Model()

all_axes=['u1', 'o2', 'in1', 'in2', 'o1', 'u2']


  return super(Tensor, self).refine_names(names)


In [6]:
sum(map(lambda x: x.numel(), model.parameters()))

108

In [7]:
Xs = torch.tensor([[0.,0.],[1.,0.],[0.,1.],[1.,1.]])

# and
# ys = torch.tensor([[0],[0],[0],[1]]).float()
# nand
# ys = torch.tensor([[1],[1],[1],[0]]).float()
# xor
# ys = torch.tensor([[0],[1],[1],[0]]).float()
# or
# ys = torch.tensor([[0],[1],[1],[1]]).float()
# nor
# ys = torch.tensor([[1],[0],[0],[0]]).float()
# ~xor
ys = torch.tensor([[1],[0],[0],[1]]).float()

In [8]:
model(Xs)

tensor([ 3.6433, -1.9709, -2.0122,  1.4906], grad_fn=<SqueezeBackward1>,
       names=('batch',))

In [9]:
EPOCHS = 300
opt = optim.Adam(model.parameters(), lr=1e-3)
lossf = nn.MSELoss()

# for _ in tqdm(range(EPOCHS)):
for idx in range(EPOCHS):
    opt.zero_grad()
    loss = lossf(model(Xs).rename(None).unsqueeze(-1), ys)
    print(f"[{idx}]: {loss.item()=}")

    loss.backward()
    opt.step()

[0]: loss.item()=3.7902841567993164
[1]: loss.item()=3.6702611446380615
[2]: loss.item()=3.5532331466674805


grad.sizes() = [1, 3, 1, 1, 3, 1], strides() = [0, 1, 0, 0, 3, 0]
param.sizes() = [1, 3, 1, 1, 3, 1], strides() = [0, 1, 0, 0, 3, 0] (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/accumulate_grad.h:202.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


[3]: loss.item()=3.4392051696777344
[4]: loss.item()=3.3281733989715576
[5]: loss.item()=3.22013258934021
[6]: loss.item()=3.115056276321411
[7]: loss.item()=3.0129220485687256
[8]: loss.item()=2.9137015342712402
[9]: loss.item()=2.817370653152466
[10]: loss.item()=2.7238967418670654
[11]: loss.item()=2.633244037628174
[12]: loss.item()=2.545368194580078
[13]: loss.item()=2.4602296352386475
[14]: loss.item()=2.3777787685394287
[15]: loss.item()=2.2979605197906494
[16]: loss.item()=2.22072172164917
[17]: loss.item()=2.1460039615631104
[18]: loss.item()=2.0737457275390625
[19]: loss.item()=2.0038809776306152
[20]: loss.item()=1.9363491535186768
[21]: loss.item()=1.8710784912109375
[22]: loss.item()=1.8080034255981445
[23]: loss.item()=1.7470598220825195
[24]: loss.item()=1.6881740093231201
[25]: loss.item()=1.6312793493270874
[26]: loss.item()=1.576311707496643
[27]: loss.item()=1.5232032537460327
[28]: loss.item()=1.471889615058899
[29]: loss.item()=1.422305703163147
[30]: loss.item()=1

In [10]:
model(Xs)

tensor([ 1.0195,  0.0080, -0.0049,  0.9991], grad_fn=<SqueezeBackward1>,
       names=('batch',))