In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from random import randint
# Pick the compute backend in order of preference: first CUDA device,
# then Apple's MPS backend, and finally the CPU as a fallback.
if torch.cuda.is_available():
    _device_name = "cuda:0"
elif torch.backends.mps.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"
device = torch.device(_device_name)
from tqdm.notebook import tqdm

In [2]:
class XORDataset(Dataset):
    """Map-style dataset of random 2-bit inputs labelled with their XOR.

    All samples are drawn once at construction time, so indexing is
    deterministic: ``ds[i]`` returns the same pair every time it is called.
    This keeps the evaluation split fixed across epochs.

    Args:
        size: Number of samples in the dataset (must be >= 0).
        seed: Optional integer seed. When given, the samples are drawn from a
            dedicated ``torch.Generator`` seeded with this value, making the
            dataset reproducible. When ``None``, torch's global RNG is used.

    Raises:
        ValueError: If ``size`` is negative.
    """

    def __init__(self, size, seed=None):
        if size < 0:
            raise ValueError(f"size must be non-negative, got {size}")
        self.size = size
        if seed is None:
            # Fall back to the global RNG so torch.manual_seed() still applies.
            self.inputs = torch.randint(0, 2, (size, 2))
        else:
            generator = torch.Generator().manual_seed(seed)
            self.inputs = torch.randint(0, 2, (size, 2), generator=generator)
        # Targets keep a trailing dimension of 1 so that each sample's label
        # has shape (1,), matching the network's single output unit.
        self.targets = (self.inputs[:, 0] ^ self.inputs[:, 1]).unsqueeze(1)

    def __len__(self):
        """Return the number of samples."""
        return self.size

    def __getitem__(self, idx):
        """Return ``(inputs, target)`` for sample ``idx``.

        ``inputs`` is an int64 tensor of shape (2,) and ``target`` an int64
        tensor of shape (1,) holding ``inputs[0] ^ inputs[1]``.

        Raises:
            IndexError: If ``idx`` is outside ``[-size, size)``. Raising here
                is what lets plain iteration over the dataset terminate.
        """
        if not -self.size <= idx < self.size:
            raise IndexError(f"index {idx} out of range for dataset of size {self.size}")
        return self.inputs[idx], self.targets[idx]

In [3]:
# Build the training and evaluation splits and wrap them in mini-batch loaders.
xor_train = XORDataset(1000)
xor_test = XORDataset(100)
# Shuffle only the training data; SGD benefits from a fresh order each epoch.
train_loader = DataLoader(xor_train, batch_size=4, shuffle=True)
# Evaluation metrics do not depend on sample order, so keep it fixed.
test_loader = DataLoader(xor_test, batch_size=4, shuffle=False)

In [4]:
class XORNet(nn.Module):
    """Two-layer perceptron mapping 2 inputs to a single probability.

    Architecture: Linear(2 -> 4) -> sigmoid -> Linear(4 -> 1) -> sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Hidden layer followed by the single-unit output layer.
        self.fc1 = nn.Linear(in_features=2, out_features=4)
        self.fc2 = nn.Linear(in_features=4, out_features=1)

    def forward(self, x):
        """Return sigmoid-activated outputs in (0, 1) for a float input batch."""
        hidden = self.fc1(x).sigmoid()
        logits = self.fc2(hidden)
        return logits.sigmoid()

In [5]:
# Instantiate the network directly on the selected device; Module.to()
# returns the module itself, so the binding refers to the same object.
xor_net = XORNet().to(device)
# Leave the module as the cell's last expression so its repr is displayed.
xor_net

XORNet(
  (fc1): Linear(in_features=2, out_features=4, bias=True)
  (fc2): Linear(in_features=4, out_features=1, bias=True)
)

In [6]:
def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters.

    Parameters whose ``requires_grad`` flag is False are not counted.
    """
    trainable_elements = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        trainable_elements += param.numel()
    return trainable_elements

# Report how many trainable scalar parameters xor_net has.
print(f'Number of parameters in the model: {count_parameters(xor_net)}')

Number of parameters in the model: 17


In [7]:
# Adam updates every parameter of xor_net with a learning rate of 0.01.
optimizer = torch.optim.Adam(params=xor_net.parameters(), lr=1e-2)
# Binary cross-entropy computed on probabilities (the network's forward
# pass already applies a sigmoid to its output).
loss_fn = nn.BCELoss()

In [8]:
# Train for a fixed number of epochs, evaluating on the test loader after each.
# Losses and accuracy are accumulated per sample (batch value * batch size)
# and divided by the number of samples seen, so the reported means stay
# correct even when the final batch is smaller than the others.
for epoch in range(15):
    # ---- training pass ----
    xor_net.train()
    total_train_loss = 0.0
    train_samples = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = xor_net(x.float())
        loss = loss_fn(y_pred, y.float())
        loss.backward()
        optimizer.step()
        # BCELoss returns the batch mean; re-weight by batch size.
        total_train_loss += loss.item() * y.shape[0]
        train_samples += y.shape[0]

    # ---- evaluation pass (no gradient tracking) ----
    xor_net.eval()
    total_validation_loss = 0.0
    correct_predictions = 0
    test_samples = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            y_pred = xor_net(x.float())
            loss = loss_fn(y_pred, y.float())
            total_validation_loss += loss.item() * y.shape[0]
            # Threshold probabilities at 0.5 and compare with the 0/1 labels.
            correct_predictions += ((y_pred > 0.5) == y).sum().item()
            test_samples += y.shape[0]

    train_loss = total_train_loss / train_samples
    validation_loss = total_validation_loss / test_samples
    test_accuracy = correct_predictions / test_samples
    print(f"Epoch {epoch}: train loss {train_loss} validation loss {validation_loss} test accuracy {test_accuracy}")

Epoch 0: train loss 0.6953675804138184 validation loss 0.7047996068000794 test accuracy 0.44
Epoch 1: train loss 0.6505474356412888 validation loss 0.5804151773452759 test accuracy 0.73
Epoch 2: train loss 0.5328610790371895 validation loss 0.52809679210186 test accuracy 0.71
Epoch 3: train loss 0.4638965935409069 validation loss 0.45193952560424805 test accuracy 0.71
Epoch 4: train loss 0.3635666768997908 validation loss 0.2958161759376526 test accuracy 1.0
Epoch 5: train loss 0.23187883818149566 validation loss 0.16704378485679627 test accuracy 1.0
Epoch 6: train loss 0.1390726335197687 validation loss 0.11331006944179535 test accuracy 1.0
Epoch 7: train loss 0.08679552430659533 validation loss 0.07040661692619324 test accuracy 1.0
Epoch 8: train loss 0.058327246747910975 validation loss 0.04631608545780182 test accuracy 1.0
Epoch 9: train loss 0.04042630817368627 validation loss 0.034579760394990444 test accuracy 1.0
Epoch 10: train loss 0.03053604332357645 validation loss 0.0255818

In [9]:
# Evaluate the trained network on the full XOR truth table. Gradients are not
# needed for inference, so disable autograd to avoid building a graph.
with torch.no_grad():
    pred = xor_net(torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]]).float().to(device))

In [10]:
# Threshold the predicted probabilities at 0.5 to obtain hard 0/1 labels.
pred.gt(0.5).int()

tensor([[0],
        [1],
        [1],
        [0]], device='mps:0', dtype=torch.int32)