<a href="https://colab.research.google.com/github/bckang-ben/exercise/blob/main/XOR_MLP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

In [2]:
class XORsolver(nn.Module):
  def __init__(self):
    super(XORsolver, self).__init__()
    self.h1 = nn.Linear(2,4)
    self.n1 = nn.LayerNorm(4)
    self.h2 = nn.Linear(4,1)
  
  def forward(self, x):
    out = self.h1(x)
    out = self.n1(out)
    out = self.h2(out)
    
    return out

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = XORsolver().to(device)

In [3]:
x_tensor = torch.tensor([[0,0],[1,0],[0,1],[1,1]], dtype=torch.float)
y_tensor = torch.tensor([0,1,1,0], dtype=torch.float)

tensor_dataset = TensorDataset(x_tensor, y_tensor)
data_loader = DataLoader(
    dataset = tensor_dataset,
    batch_size = 1,
    shuffle = True,
    drop_last = False
)

In [4]:
n_epoch = 1000
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [5]:
model.train()

for n in range(n_epoch):
  for (x, y) in data_loader:
    optimizer.zero_grad()
    x = x.to(device)
    y = y.to(device)

    pred = model(x).view(-1)
    loss = loss_func(pred, y)
    loss.backward()
    optimizer.step()
  
  if n % 100 == 0:
    print(f'epoch: {n}, x:[0,0], y:[{model(torch.tensor([0,0], dtype=torch.float)).to(device).item():.4f}]')
    print(f'epoch: {n}, x:[0,1], y:[{model(torch.tensor([0,1], dtype=torch.float)).to(device).item():.4f}]')
    print(f'epoch: {n}, x:[1,0], y:[{model(torch.tensor([1,0], dtype=torch.float)).to(device).item():.4f}]')
    print(f'epoch: {n}, x:[1,1], y:[{model(torch.tensor([0,0], dtype=torch.float)).to(device).item():.4f}]\n')

epoch: 0, x:[0,0], y:[0.3643]
epoch: 0, x:[0,1], y:[0.0514]
epoch: 0, x:[1,0], y:[0.3348]
epoch: 0, x:[1,1], y:[0.3643]

epoch: 100, x:[0,0], y:[0.0055]
epoch: 100, x:[0,1], y:[1.0037]
epoch: 100, x:[1,0], y:[1.0027]
epoch: 100, x:[1,1], y:[0.0055]

epoch: 200, x:[0,0], y:[0.0000]
epoch: 200, x:[0,1], y:[1.0000]
epoch: 200, x:[1,0], y:[1.0000]
epoch: 200, x:[1,1], y:[0.0000]

epoch: 300, x:[0,0], y:[0.0000]
epoch: 300, x:[0,1], y:[1.0000]
epoch: 300, x:[1,0], y:[1.0000]
epoch: 300, x:[1,1], y:[0.0000]

epoch: 400, x:[0,0], y:[0.0000]
epoch: 400, x:[0,1], y:[1.0000]
epoch: 400, x:[1,0], y:[1.0000]
epoch: 400, x:[1,1], y:[0.0000]

epoch: 500, x:[0,0], y:[0.0000]
epoch: 500, x:[0,1], y:[1.0000]
epoch: 500, x:[1,0], y:[1.0000]
epoch: 500, x:[1,1], y:[0.0000]

epoch: 600, x:[0,0], y:[0.0000]
epoch: 600, x:[0,1], y:[1.0000]
epoch: 600, x:[1,0], y:[1.0000]
epoch: 600, x:[1,1], y:[0.0000]

epoch: 700, x:[0,0], y:[0.0000]
epoch: 700, x:[0,1], y:[1.0000]
epoch: 700, x:[1,0], y:[1.0000]
epoch: 70