In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

import matplotlib.pyplot as plt

In [None]:
class XORSolver(nn.Module):
  """Small network for the XOR task: Linear(2 -> 4), LayerNorm(4), Linear(4 -> 1).

  No activation function is used; the layer normalisation between the two
  linear layers is the only non-affine step in the forward pass.
  """

  def __init__(self):
    super().__init__()

    # Sub-modules are created in this order (h1, n1, h2) so that parameter
    # initialisation consumes the RNG in a fixed sequence.
    self.h1 = nn.Linear(in_features=2, out_features=4)
    self.n1 = nn.LayerNorm(4)
    self.h2 = nn.Linear(in_features=4, out_features=1)

  def forward(self, x):
    """Map inputs whose last dimension is 2 to outputs whose last dimension is 1."""
    hidden = self.n1(self.h1(x))
    return self.h2(hidden)

In [None]:
# Seed PyTorch's global RNG before the model is built so that weight
# initialisation (and any later shuffling that draws from the global RNG)
# is repeatable after "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = XORSolver()
model.to(device)

# Adam over all model parameters, minimising mean squared error.
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_func = nn.MSELoss()

# Number of full passes over the training data.
n_epoch = 1000

In [None]:
# The four 2-bit input pairs (one per row) and their XOR targets, as float32.
x_train = torch.tensor(
    [[0, 0],
     [1, 0],
     [0, 1],
     [1, 1]],
    dtype=torch.float32,
)
y_train = torch.tensor([0, 1, 1, 0], dtype=torch.float32)

# Pair each input row with its target and serve them one sample at a time,
# reshuffled on every pass; a trailing partial batch would be kept.
train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, drop_last=False)

In [None]:
model.to(device)
model.train()

# Inputs shown in the periodic progress report, in the order they are printed.
probe_inputs = [[0, 0], [1, 0], [0, 1], [1, 1]]
# Create the probe batch directly on the model's device. Feeding a CPU tensor
# to a model that lives on CUDA raises a device-mismatch error, so the input
# (not the output) is what has to be placed on `device`.
probe_batch = torch.tensor(probe_inputs, dtype=torch.float, device=device)

for i in range(n_epoch):
  # One optimisation step per mini-batch served by the loader.
  for (x, y_) in train_loader:
    x = x.to(device)
    y_ = y_.to(device)

    optimizer.zero_grad()

    # Flatten the model output so its shape matches the target tensor.
    predict = model(x).view(-1)
    loss = loss_func(predict, y_)

    loss.backward()
    optimizer.step()

  # Every 100 epochs, report the current prediction for each probe input.
  if i % 100 == 0:
    # Reporting only: no gradients are needed for these forward passes.
    with torch.no_grad():
      probe_outputs = model(probe_batch).view(-1).tolist()

    print(f'epoch: {i}')
    for (a, b), y_hat in zip(probe_inputs, probe_outputs):
      print(f'x:[{a},{b}], y:[{y_hat}]')
    print()

epoch: 0
x:[0,0], y:[1.4901161193847656e-08
x:[1,0], y:[1.0
x:[0,1], y:[0.9999998807907104
x:[1,1], y:[-4.470348358154297e-08

epoch: 100
x:[0,0], y:[1.4901161193847656e-08
x:[1,0], y:[1.0
x:[0,1], y:[0.9999998807907104
x:[1,1], y:[-4.470348358154297e-08

epoch: 200
x:[0,0], y:[1.4901161193847656e-08
x:[1,0], y:[1.0
x:[0,1], y:[1.0
x:[1,1], y:[1.043081283569336e-07

epoch: 300
x:[0,0], y:[-0.0024657994508743286
x:[1,0], y:[0.9951086640357971
x:[0,1], y:[0.9957215189933777
x:[1,1], y:[0.000850290060043335

epoch: 400
x:[0,0], y:[-2.2351741790771484e-07
x:[1,0], y:[0.9999996423721313
x:[0,1], y:[0.9999998807907104
x:[1,1], y:[-2.5331974029541016e-07

epoch: 500
x:[0,0], y:[-2.9802322387695312e-08
x:[1,0], y:[1.0
x:[0,1], y:[1.0
x:[1,1], y:[-1.4901161193847656e-08

epoch: 600
x:[0,0], y:[2.9802322387695312e-08
x:[1,0], y:[1.0
x:[0,1], y:[1.0
x:[1,1], y:[-1.4901161193847656e-08

epoch: 700
x:[0,0], y:[-2.9802322387695312e-08
x:[1,0], y:[1.0
x:[0,1], y:[1.0
x:[1,1], y:[-1.4901161193847656e-

In [None]:
model.eval()

# Inference only: disable autograd and evaluate all training inputs in one
# batched forward pass, then flatten to a plain list of Python floats
# (one prediction per row of x_train, in order).
with torch.no_grad():
  output = model(x_train.to(device)).view(-1).tolist()

print(output)

[9.149312973022461e-06, 1.0000128746032715, 0.9999709129333496, -0.00018653273582458496]
