In [None]:
import torch
import torch.nn.functional as F

In [None]:
# Dataset: the complete XOR truth table.
# The four (a, b) input pairs are enumerated once and reused, so inputs and
# targets are guaranteed to be in the same order.
_input_pairs = [(a, b) for a in (0, 1) for b in (0, 1)]

# Transposed so that each *column* of x is one (a, b) sample: shape (2, 4).
x = torch.tensor(_input_pairs, dtype=torch.float).t()

# Target for each column of x: a XOR b, shape (4,).
y = torch.tensor([a ^ b for a, b in _input_pairs], dtype=torch.float)


def do_nothing(*args):
    """No-op callback: accept any positional arguments and ignore them."""
    return None


class Model:
    """Tiny hand-rolled network: 2 inputs -> 2 hidden units -> activation -> 1 output.

    The parameters are plain tensors with ``requires_grad`` enabled (this is not
    an ``nn.Module``); ``train`` updates them with manual gradient-descent steps.
    """

    def __init__(self, activation=F.relu):
        """Randomly initialise the parameters and store the hidden activation.

        Args:
            activation: callable applied element-wise to the hidden layer's
                pre-activations.
        """
        # Hidden layer: weight matrix W (2x2) and bias column b (2x1).
        hidden_weight = torch.randn((2, 2)) * 2
        hidden_bias = torch.randn((2, 1))

        # Output layer: weight vector w (2,) and bias c (1,), both scaled down
        # relative to the hidden layer.
        output_weight = torch.randn(2) / 32 * 2
        output_bias = torch.randn(1) / 32

        self.W, self.b = hidden_weight, hidden_bias
        self.w, self.c = output_weight, output_bias
        self.activation = activation

        # Every tensor that `train` updates; all of them must track gradients.
        self.parameters = [self.W, self.w, self.b, self.c]
        for tensor in self.parameters:
            tensor.requires_grad_(True)

    def forward(self, input):
        """Compute the network output for ``input``.

        ``input`` is multiplied from the left by the 2x2 matrix ``W``, so it
        holds one sample per column; the bias column ``b`` broadcasts across
        those columns.
        """
        pre_activation = self.W @ input + self.b
        hidden = self.activation(pre_activation)
        return self.w @ hidden + self.c

    def train(self, n, α=1e-3, report=do_nothing):
        """Run ``n`` gradient-descent steps on the module-level ``x`` and ``y``.

        Args:
            n: number of update steps.
            α: learning rate (step size).
            report: callback invoked as ``report(iteration, loss)`` on every
                step, before the parameters are updated.
        """
        for step in range(n):
            prediction = self.forward(x)
            loss = mse_loss(prediction, y)
            report(step, loss)
            loss.backward()

            # The update itself must not be recorded by autograd.
            with torch.no_grad():
                for tensor in self.parameters:
                    tensor.sub_(α * tensor.grad)
                    # Gradients accumulate across backward() calls; reset them.
                    tensor.grad.zero_()

    def __str__(self):
        """Return the four parameter tensors, one labelled entry per line."""
        labelled = (("W", self.W), ("b", self.b), ("w", self.w), ("c", self.c))
        return "\n".join(f"{label}: {value}" for label, value in labelled)

def mse_loss(y0, y1):
    """Return the sum of squared element-wise differences between ``y0`` and ``y1``.

    Note: despite the name, the squared errors are summed, not averaged.
    """
    squared_error = (y0 - y1) ** 2
    return squared_error.sum()

In [None]:
# Train 50 independently initialised models and count how many learn XOR.
# Seed the global RNG first so the success rate is reproducible on
# Restart & Run All (Model draws its initial weights from torch.randn).
torch.manual_seed(0)

good_count = 0
bad_count = 0

for attempt in range(50):
    xor_model = Model(activation=F.relu)
    xor_model.train(5000, α=3e-3)

    # Evaluation needs no gradients: no_grad avoids building an autograd graph
    # and keeps `grad_fn=...` out of the printed tensor below.
    with torch.no_grad():
        # Round to one decimal so a prediction within roughly 0.05 of its
        # target compares equal to it.
        result = torch.round(xor_model.forward(x) * 10) / 10

    # Compare against the dataset targets rather than a duplicated literal.
    if torch.equal(result, y):
        good_count += 1
        print(f'Attempt {attempt}: good')
    else:
        bad_count += 1
        print(f'Attempt {attempt}: bad ({result})')

In [None]:
# The model finds the global optimum only about 30% of the time =(
# In all other runs it gets stuck on a plateau.
total_attempts = good_count + bad_count
good_count / total_attempts