In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from random import randint
# Pick the compute backend in order of preference: first CUDA GPU, then
# Apple's MPS backend, and finally the CPU as the fallback.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
from tqdm.notebook import tqdm

In [2]:
class XORDataset(Dataset):
    """Map-style dataset that generates XOR examples on the fly.

    No samples are stored: every lookup draws two fresh random bits, so the
    same index returns a different example on each access. ``size`` only
    controls the reported length (and therefore how many samples a
    ``DataLoader`` draws per epoch).
    """

    def __init__(self, size):
        """Remember how many samples the dataset should report.

        Args:
            size: Value returned by ``len(dataset)``.
        """
        self.size = size

    def __len__(self):
        """Return the nominal number of samples."""
        return self.size

    def __getitem__(self, idx):
        """Draw a random XOR example.

        Args:
            idx: Sample index. It does not influence the sampled bits; it is
                only validated against the dataset length. Negative indices
                within ``[-size, size)`` are accepted.

        Returns:
            A pair ``(inputs, target)`` of integer tensors: ``inputs`` holds
            the two bits and ``target`` holds their XOR as a single element.

        Raises:
            IndexError: If ``idx`` lies outside ``[-size, size)``.
        """
        # Without this check, plain iteration (``for s in ds`` / ``list(ds)``)
        # never stops: Python's fallback sequence iteration calls
        # __getitem__ with 0, 1, 2, ... until IndexError is raised.
        if not -self.size <= idx < self.size:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.size}"
            )
        x = randint(0, 1)
        y = randint(0, 1)
        return torch.tensor([x, y]), torch.tensor([x ^ y])

In [3]:
# Full XOR truth table as float32 tensors: one row per (a, b) bit pair in
# the order 00, 01, 10, 11, with the matching XOR result as a column vector.
_BITS = (0, 1)
x = torch.tensor([[a, b] for a in _BITS for b in _BITS], dtype=torch.float32)
y = torch.tensor([[a ^ b] for a in _BITS for b in _BITS], dtype=torch.float32)

In [4]:
# 1000 randomly generated XOR samples per training epoch, 100 per evaluation pass.
xor_train, xor_test = XORDataset(1000), XORDataset(100)

# Both loaders share the same settings: shuffled mini-batches of four samples.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=4, shuffle=True)
    for dataset in (xor_train, xor_test)
)

In [5]:
# Trainable parameters of a 2 -> 4 -> 1 network, drawn from a standard normal.
# Hidden layer: weights (2, 4) and bias (4,).
w1, b1 = torch.randn((2, 4), requires_grad=True), torch.randn((4,), requires_grad=True)
# Output layer: weights (4, 1) and bias (1,).
w2, b2 = torch.randn((4, 1), requires_grad=True), torch.randn((1,), requires_grad=True)

In [6]:
# Forward pass: affine layer -> sigmoid -> affine layer -> sigmoid.
z1 = torch.matmul(x.float(), w1) + b1  # hidden pre-activations
a1 = z1.sigmoid()                      # hidden activations
z2 = torch.matmul(a1, w2) + b2         # output pre-activation (logit)
y_pred = z2.sigmoid()                  # value in (0, 1), compared against y in the loss

In [7]:
# Binary cross-entropy written out by hand:
#   loss = -mean( y * log(p) + (1 - y) * log(1 - p) )
# NOTE: the logs are not clamped, so a prediction of exactly 0 or 1 would
# yield inf/NaN here.
loss = -(y * y_pred.log() + (1 - y) * (1 - y_pred).log()).mean()
loss

tensor(1.0115, grad_fn=<NegBackward0>)

In [8]:
# Same quantity computed by the library routine, for comparison with the
# hand-written loss.
nn.functional.binary_cross_entropy(input=y_pred, target=y)

tensor(1.0115, grad_fn=<BinaryCrossEntropyBackward0>)

In [10]:
# Hyperparameters of the hand-written SGD loop.
LEARNING_RATE = 0.01
NUM_EPOCHS = 100


def forward_pass(inputs):
    """Run the two-layer sigmoid network on a batch.

    Reads the notebook-level parameters ``w1``, ``b1``, ``w2`` and ``b2``.

    Args:
        inputs: Batch tensor; it is cast to float before the first matmul.

    Returns:
        Sigmoid outputs of the second layer, one value per input row.
    """
    hidden = torch.sigmoid(inputs.float() @ w1 + b1)
    return torch.sigmoid(hidden @ w2 + b2)


for epoch in range(NUM_EPOCHS):
    total_train_loss = 0
    total_validation_loss = 0
    test_accuracy = 0

    # --- Training: one plain SGD step per mini-batch ---
    # Dedicated batch names keep the loop from overwriting the notebook-level
    # `x` / `y` truth-table tensors.
    for batch_x, batch_y in train_loader:
        # Drop gradients left over from the previous step so they do not accumulate.
        for param in (w1, b1, w2, b2):
            param.grad = None

        y_pred = forward_pass(batch_x)
        # Library BCE instead of the hand-written formula: it clamps the log
        # terms, so a saturated sigmoid (prediction exactly 0 or 1) cannot
        # turn the loss into inf/NaN. Targets are integer tensors, hence .float().
        loss = nn.functional.binary_cross_entropy(y_pred, batch_y.float())
        loss.backward()
        total_train_loss += loss.item()

        # Parameter updates must not be recorded by autograd.
        with torch.no_grad():
            for param in (w1, b1, w2, b2):
                param -= LEARNING_RATE * param.grad

    # --- Evaluation: no gradients needed ---
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            y_pred = forward_pass(batch_x)
            loss = nn.functional.binary_cross_entropy(y_pred, batch_y.float())
            total_validation_loss += loss.item()
            # Fraction of the batch whose thresholded prediction matches the label.
            accuracy = ((y_pred > 0.5) == batch_y).sum().item() / batch_y.shape[0]
            test_accuracy += accuracy

    # Per-epoch averages over the number of batches in each loader.
    print(f"Epoch {epoch}: train loss {total_train_loss / len(train_loader)} validation loss {total_validation_loss / len(test_loader)} test accuracy {test_accuracy / len(test_loader)}")

Epoch 0: train loss 0.679673070192337 validation loss 0.6768500113487244 test accuracy 0.75
Epoch 1: train loss 0.6801569554805755 validation loss 0.6725856328010559 test accuracy 0.58
Epoch 2: train loss 0.6780541615486145 validation loss 0.674308066368103 test accuracy 0.79
Epoch 3: train loss 0.6794984564781189 validation loss 0.6772174549102783 test accuracy 0.76
Epoch 4: train loss 0.6760360081195831 validation loss 0.6757626867294312 test accuracy 0.49
Epoch 5: train loss 0.6775535476207734 validation loss 0.6625968217849731 test accuracy 0.6
Epoch 6: train loss 0.6748241431713105 validation loss 0.6745300483703613 test accuracy 0.75
Epoch 7: train loss 0.6730584037303925 validation loss 0.6706727051734924 test accuracy 0.52
Epoch 8: train loss 0.6709829375743867 validation loss 0.6770709657669067 test accuracy 0.73
Epoch 9: train loss 0.6707784655094147 validation loss 0.6698770165443421 test accuracy 0.72
Epoch 10: train loss 0.6687397592067719 validation loss 0.674689197540283

In [11]:
# Evaluate the current parameters on the complete XOR truth table.
x = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32)
y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)

# Inference only: disabling autograd avoids building a graph nobody will
# backpropagate through. `x` is already float32, so no cast is needed.
with torch.no_grad():
    z1 = x @ w1 + b1
    a1 = torch.sigmoid(z1)
    z2 = a1 @ w2 + b2
    y_pred = torch.sigmoid(z2)

In [12]:
# Threshold the sigmoid outputs at 0.5 to obtain hard 0/1 class labels.
y_pred.gt(0.5).to(torch.int32)

tensor([[0],
        [1],
        [1],
        [0]], dtype=torch.int32)