In [1]:
import matplotlib.pyplot as plt
import numpy as np
from numpy.polynomial.polynomial import polyval, polyvander
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from binarypredictor.dataset import FunctionDataset
from binarypredictor.net import BinaryPredictor

In [2]:
def epoch(net, train_loader, loss_fn, optimizer, debug=False):
    # Create x range
    x = np.arange(0., 1.0, step=0.01)
    
    for pc, qc, targets in train_loader:
        # Get the polynomials
        pc, qc = pc.numpy(), qc.numpy()
        p = torch.tensor(polyval(x, pc.T))
        q = torch.tensor(polyval(x, qc.T))
        
        inp = torch.hstack((p, q))
        
        out = net(inp.float()).reshape(-1, 10, 2)[:, :1]
                
        targets = targets.reshape(-1, 10, 2)[:, :1]
        
        if debug:
            print(out[0])
            print(targets[0])
        
        loss = loss_fn(out, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(loss)
    
def train(net, train_loader, loss_fn, optimizer, nr_epochs):
    for i in range(nr_epochs):
        epoch(net, train_loader, loss_fn, optimizer)

In [7]:
fd = FunctionDataset()
net = BinaryPredictor(train=True)

train_loader = DataLoader(fd, batch_size=256, shuffle=True)

loss_fn = nn.MSELoss()

optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)

train(net, train_loader, loss_fn, optimizer, 1000)

tensor(0.6956, grad_fn=<MseLossBackward0>)
tensor(0.7130, grad_fn=<MseLossBackward0>)
tensor(0.6369, grad_fn=<MseLossBackward0>)
tensor(0.5812, grad_fn=<MseLossBackward0>)
tensor(0.5959, grad_fn=<MseLossBackward0>)
tensor(0.6246, grad_fn=<MseLossBackward0>)
tensor(0.5864, grad_fn=<MseLossBackward0>)
tensor(0.4974, grad_fn=<MseLossBackward0>)
tensor(0.4834, grad_fn=<MseLossBackward0>)
tensor(0.4746, grad_fn=<MseLossBackward0>)
tensor(0.5497, grad_fn=<MseLossBackward0>)
tensor(0.4716, grad_fn=<MseLossBackward0>)
tensor(0.5333, grad_fn=<MseLossBackward0>)
tensor(0.4550, grad_fn=<MseLossBackward0>)
tensor(0.4984, grad_fn=<MseLossBackward0>)
tensor(0.4651, grad_fn=<MseLossBackward0>)
tensor(0.4843, grad_fn=<MseLossBackward0>)
tensor(0.4065, grad_fn=<MseLossBackward0>)
tensor(0.4471, grad_fn=<MseLossBackward0>)
tensor(0.4329, grad_fn=<MseLossBackward0>)
tensor(0.4532, grad_fn=<MseLossBackward0>)
tensor(0.4779, grad_fn=<MseLossBackward0>)
tensor(0.4307, grad_fn=<MseLossBackward0>)
tensor(0.41

tensor(0.4076, grad_fn=<MseLossBackward0>)
tensor(0.3395, grad_fn=<MseLossBackward0>)
tensor(0.3612, grad_fn=<MseLossBackward0>)
tensor(0.3253, grad_fn=<MseLossBackward0>)
tensor(0.3274, grad_fn=<MseLossBackward0>)
tensor(0.3983, grad_fn=<MseLossBackward0>)
tensor(0.3611, grad_fn=<MseLossBackward0>)
tensor(0.3843, grad_fn=<MseLossBackward0>)
tensor(0.3934, grad_fn=<MseLossBackward0>)
tensor(0.3366, grad_fn=<MseLossBackward0>)
tensor(0.3813, grad_fn=<MseLossBackward0>)
tensor(0.4114, grad_fn=<MseLossBackward0>)
tensor(0.3582, grad_fn=<MseLossBackward0>)
tensor(0.3823, grad_fn=<MseLossBackward0>)
tensor(0.3416, grad_fn=<MseLossBackward0>)
tensor(0.3366, grad_fn=<MseLossBackward0>)
tensor(0.3710, grad_fn=<MseLossBackward0>)
tensor(0.3451, grad_fn=<MseLossBackward0>)
tensor(0.3848, grad_fn=<MseLossBackward0>)
tensor(0.3810, grad_fn=<MseLossBackward0>)
tensor(0.3260, grad_fn=<MseLossBackward0>)
tensor(0.3521, grad_fn=<MseLossBackward0>)
tensor(0.3854, grad_fn=<MseLossBackward0>)
tensor(0.36

tensor(0.3611, grad_fn=<MseLossBackward0>)
tensor(0.4079, grad_fn=<MseLossBackward0>)
tensor(0.3329, grad_fn=<MseLossBackward0>)
tensor(0.3571, grad_fn=<MseLossBackward0>)
tensor(0.3907, grad_fn=<MseLossBackward0>)
tensor(0.3197, grad_fn=<MseLossBackward0>)
tensor(0.3452, grad_fn=<MseLossBackward0>)
tensor(0.3329, grad_fn=<MseLossBackward0>)
tensor(0.3547, grad_fn=<MseLossBackward0>)
tensor(0.3873, grad_fn=<MseLossBackward0>)
tensor(0.3282, grad_fn=<MseLossBackward0>)
tensor(0.3502, grad_fn=<MseLossBackward0>)
tensor(0.3921, grad_fn=<MseLossBackward0>)
tensor(0.4015, grad_fn=<MseLossBackward0>)
tensor(0.3675, grad_fn=<MseLossBackward0>)
tensor(0.3509, grad_fn=<MseLossBackward0>)
tensor(0.3970, grad_fn=<MseLossBackward0>)
tensor(0.3883, grad_fn=<MseLossBackward0>)
tensor(0.3699, grad_fn=<MseLossBackward0>)
tensor(0.3596, grad_fn=<MseLossBackward0>)
tensor(0.3515, grad_fn=<MseLossBackward0>)
tensor(0.3210, grad_fn=<MseLossBackward0>)
tensor(0.3412, grad_fn=<MseLossBackward0>)
tensor(0.32

tensor(0.3694, grad_fn=<MseLossBackward0>)
tensor(0.3581, grad_fn=<MseLossBackward0>)
tensor(0.3579, grad_fn=<MseLossBackward0>)
tensor(0.3722, grad_fn=<MseLossBackward0>)
tensor(0.3746, grad_fn=<MseLossBackward0>)
tensor(0.3930, grad_fn=<MseLossBackward0>)
tensor(0.3691, grad_fn=<MseLossBackward0>)
tensor(0.3586, grad_fn=<MseLossBackward0>)
tensor(0.3266, grad_fn=<MseLossBackward0>)
tensor(0.4105, grad_fn=<MseLossBackward0>)
tensor(0.3802, grad_fn=<MseLossBackward0>)
tensor(0.3416, grad_fn=<MseLossBackward0>)
tensor(0.3535, grad_fn=<MseLossBackward0>)
tensor(0.3337, grad_fn=<MseLossBackward0>)
tensor(0.3338, grad_fn=<MseLossBackward0>)
tensor(0.4150, grad_fn=<MseLossBackward0>)
tensor(0.3149, grad_fn=<MseLossBackward0>)
tensor(0.3311, grad_fn=<MseLossBackward0>)
tensor(0.3790, grad_fn=<MseLossBackward0>)
tensor(0.3792, grad_fn=<MseLossBackward0>)
tensor(0.3327, grad_fn=<MseLossBackward0>)
tensor(0.4077, grad_fn=<MseLossBackward0>)
tensor(0.3253, grad_fn=<MseLossBackward0>)
tensor(0.34

tensor(0.3270, grad_fn=<MseLossBackward0>)
tensor(0.3410, grad_fn=<MseLossBackward0>)
tensor(0.3637, grad_fn=<MseLossBackward0>)
tensor(0.3938, grad_fn=<MseLossBackward0>)
tensor(0.3004, grad_fn=<MseLossBackward0>)
tensor(0.4135, grad_fn=<MseLossBackward0>)
tensor(0.3559, grad_fn=<MseLossBackward0>)
tensor(0.3769, grad_fn=<MseLossBackward0>)
tensor(0.3657, grad_fn=<MseLossBackward0>)
tensor(0.3596, grad_fn=<MseLossBackward0>)
tensor(0.3108, grad_fn=<MseLossBackward0>)
tensor(0.3721, grad_fn=<MseLossBackward0>)
tensor(0.3337, grad_fn=<MseLossBackward0>)
tensor(0.3879, grad_fn=<MseLossBackward0>)
tensor(0.3359, grad_fn=<MseLossBackward0>)
tensor(0.3597, grad_fn=<MseLossBackward0>)
tensor(0.3891, grad_fn=<MseLossBackward0>)
tensor(0.3178, grad_fn=<MseLossBackward0>)
tensor(0.3602, grad_fn=<MseLossBackward0>)
tensor(0.3396, grad_fn=<MseLossBackward0>)
tensor(0.3575, grad_fn=<MseLossBackward0>)
tensor(0.3270, grad_fn=<MseLossBackward0>)
tensor(0.3371, grad_fn=<MseLossBackward0>)
tensor(0.36

tensor(0.3479, grad_fn=<MseLossBackward0>)
tensor(0.3577, grad_fn=<MseLossBackward0>)
tensor(0.3861, grad_fn=<MseLossBackward0>)
tensor(0.3313, grad_fn=<MseLossBackward0>)
tensor(0.3207, grad_fn=<MseLossBackward0>)
tensor(0.3777, grad_fn=<MseLossBackward0>)
tensor(0.3460, grad_fn=<MseLossBackward0>)
tensor(0.3271, grad_fn=<MseLossBackward0>)
tensor(0.2958, grad_fn=<MseLossBackward0>)
tensor(0.3518, grad_fn=<MseLossBackward0>)
tensor(0.4114, grad_fn=<MseLossBackward0>)
tensor(0.3705, grad_fn=<MseLossBackward0>)
tensor(0.3790, grad_fn=<MseLossBackward0>)
tensor(0.3204, grad_fn=<MseLossBackward0>)
tensor(0.3572, grad_fn=<MseLossBackward0>)
tensor(0.3152, grad_fn=<MseLossBackward0>)
tensor(0.3692, grad_fn=<MseLossBackward0>)
tensor(0.3239, grad_fn=<MseLossBackward0>)
tensor(0.2786, grad_fn=<MseLossBackward0>)
tensor(0.3586, grad_fn=<MseLossBackward0>)
tensor(0.3728, grad_fn=<MseLossBackward0>)
tensor(0.3535, grad_fn=<MseLossBackward0>)
tensor(0.3142, grad_fn=<MseLossBackward0>)
tensor(0.34

In [8]:
epoch(net, train_loader, loss_fn, optimizer, debug=True)

tensor([[0.1116, 0.2707]], grad_fn=<SelectBackward0>)
tensor([[0.5595, 0.3667]])
tensor([[0.0831, 0.0991]], grad_fn=<SelectBackward0>)
tensor([[-1., -1.]])
tensor([[0.0335, 0.0271]], grad_fn=<SelectBackward0>)
tensor([[0.7914, 0.6269]])
tensor([[0.3250, 0.3005]], grad_fn=<SelectBackward0>)
tensor([[0.3666, 0.6392]])
tensor(0.3442, grad_fn=<MseLossBackward0>)


In [5]:
ttt = torch.tensor([1., 0.5, .6, .5, .3, .5, -1., -1.])
print(ttt[ttt > -1.])

tensor([1.0000, 0.5000, 0.6000, 0.5000, 0.3000, 0.5000])
