# Simple Neural Network to learn XOR in PyTorch

In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable

In [26]:
class NN(nn.Module):
  def __init__(self):
    super(NN, self).__init__()
    self.l1_linear = nn.Linear(2, 8, bias=False)
    nn.init.kaiming_normal_(self.l1_linear.weight)
    self.l2_linear = nn.Linear(8, 1, bias=False)
    nn.init.kaiming_normal_(self.l2_linear.weight)
  
  def forward(self, x):
    l1 = self.l1_linear(x)
    out = F.sigmoid(self.l1_linear(x))
    out = F.sigmoid(self.l2_linear(out))
    return out

xor_nn = NN()
optimizer = torch.optim.Adam(xor_nn.parameters(), lr=1e-3, weight_decay=1e-5)

# prepare the training data
x_in = Variable(torch.FloatTensor([[1,1],[0,0],[1,0],[0,1]]))
y_out = Variable(torch.FloatTensor([[0],[0],[1],[1]]))

In [27]:
for i in range(10000):
  predict = xor_nn(x_in)
  loss = F.smooth_l1_loss(predict,y_out)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  if (i+1)%2000 == 0:
    print('i: %d, loss: %.4f'%(i+1,loss.data[0]))



i: 2000, loss: 0.0322
i: 4000, loss: 0.0065
i: 6000, loss: 0.0014
i: 8000, loss: 0.0005
i: 10000, loss: 0.0004


In [28]:
print('input:\n',x_in)
print('output:\n',xor_nn(x_in))

input:
 tensor([[ 1.,  1.],
        [ 0.,  0.],
        [ 1.,  0.],
        [ 0.,  1.]])
output:
 tensor([[ 0.0259],
        [ 0.0280],
        [ 0.9722],
        [ 0.9731]])


In [29]:
# weights in the model
xor_nn.state_dict()

OrderedDict([('l1_linear.weight', tensor([[-2.5086,  5.3487],
                      [-3.6336, -3.6250],
                      [ 4.3361, -1.9668],
                      [-3.2398, -2.9609],
                      [ 3.6643,  3.6338],
                      [ 2.4062, -5.0674],
                      [ 4.4046, -2.0273],
                      [-4.4977,  2.0641]])),
             ('l2_linear.weight',
              tensor([[-4.7847, -5.6596, -3.0932, -4.9510,  6.4870,  4.9780, -3.6748,
                        3.6012]]))])

fast and slow

In [30]:
def copy_grad(source, target):
  grads = []
  for param in source.parameters():
    grads.append(param.grad.clone())
  grads.reverse()
  for param in target.parameters():
    param.grad = grads.pop()

def update_slow(slow, fast, tau):
  s_d = slow.state_dict()
  f_d = fast.state_dict()
  for name in s_d:
    s_d[name] = s_d[name] * ( 1. - tau) + f_d[name] * tau
  slow.load_state_dict(s_d)

In [42]:
xor_fast = NN()
xor_slow = NN()
xor_fast.load_state_dict(xor_slow.state_dict())
optimizer = torch.optim.Adam(xor_fast.parameters(), lr=1e-3, weight_decay=1e-5)

In [43]:
for i in range(10000):
  predict = xor_slow(x_in)
  loss = F.smooth_l1_loss(predict,y_out)
  optimizer.zero_grad()
  loss.backward()
  copy_grad(xor_slow, xor_fast)
  optimizer.step()
  update_slow(xor_slow, xor_fast, 0.1)
#   print(xor_slow.state_dict())
#   print(xor_fast.state_dict())
  if (i+1)%2000 == 0:
    print('i: %d, loss: %.4f'%(i+1,loss.data[0]))



i: 2000, loss: 0.0960
i: 4000, loss: 0.0390
i: 6000, loss: 0.0014
i: 8000, loss: 0.0004
i: 10000, loss: 0.0002


In [44]:
print('input:\n',x_in)
print('output:\n',xor_slow(x_in))

input:
 tensor([[ 1.,  1.],
        [ 0.,  0.],
        [ 1.,  0.],
        [ 0.,  1.]])
output:
 tensor([[ 0.0020],
        [ 0.0003],
        [ 1.0000],
        [ 0.9604]])
