In [1]:
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F

from src.models import Net
from src.misc import vectorize_weights, get_cardinality, format_to_shapes, teach
from IPython import display
from torchviz import make_dot

In [2]:
# set up the xor data
X = torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=torch.float)
target = torch.tensor([0, 1, 1, 0], dtype=torch.float).unsqueeze(1)

In [3]:
arch = [(2, 40), (40,), (40, 1), (1,)]

net = Net(arch)

cardinality = get_cardinality(arch)
print("Cardinality: {}".format(cardinality))

teacher = nn.Sequential(nn.Linear(cardinality, 100),
                                     nn.ReLU(),
                                     nn.Linear(100, 100),
                                     nn.ReLU(),
                                     nn.Linear(100, 100),
                                     nn.ReLU(),
                                     nn.Linear(100, 100),
                                     nn.ReLU(), 
                                     nn.Linear(100, cardinality))

Cardinality: 161


In [4]:
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(teacher.parameters(), lr=0.01)
    
weights = torch.randn(cardinality)
for e in range(10000):
    new_weights = teacher(weights)
    net.set_weights(format_to_shapes(new_weights, arch))
    
    pred = net(X)    
    loss = criterion(pred, target)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    weights = new_weights.detach()
        
    if e % 100 == 0:
        display.clear_output(wait=True)
        print("Epoch: {} ~ BCE: {}".format(e, loss))
        
    if loss < 0.0001:
        break

Epoch: 2100 ~ BCE: 0.00010965594265144318


In [5]:
with torch.no_grad():
    print(torch.sigmoid(net(X)))
    print(target)

tensor([[3.2868e-04],
        [9.9996e-01],
        [9.9998e-01],
        [8.2559e-06]])
tensor([[0.],
        [1.],
        [1.],
        [0.]])
