In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from random import randint
# Pick the compute device in order of preference:
# first CUDA GPU, then Apple-silicon MPS, and finally the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
from tqdm.notebook import tqdm

In [2]:
class XORDataset(Dataset):
    """Map-style dataset of randomly generated XOR examples.

    Every item is a pair ``(inputs, target)``:

    * ``inputs`` -- int64 tensor of shape ``(2,)`` holding two bits (0 or 1)
    * ``target`` -- int64 tensor of shape ``(1,)`` holding ``inputs[0] ^ inputs[1]``

    All samples are drawn once, at construction time.  ``dataset[i]`` therefore
    returns the same example on every access (so a held-out set stays fixed
    between epochs), and an out-of-range index raises ``IndexError`` as a
    map-style dataset is expected to.

    Args:
        size: Number of examples in the dataset; must be non-negative.
        seed: Optional integer seed.  When given, the samples are drawn from a
            private ``torch.Generator`` seeded with it, which makes the dataset
            reproducible without touching the global RNG state.  When ``None``
            (default) the global torch RNG is used.

    Raises:
        ValueError: If ``size`` is negative.
    """

    def __init__(self, size, seed=None):
        if size < 0:
            raise ValueError(f"size must be non-negative, got {size}")
        self.size = size

        generator = None
        if seed is not None:
            generator = torch.Generator().manual_seed(seed)

        # inputs: (size, 2) matrix of random bits; high bound of randint is exclusive.
        self.inputs = torch.randint(0, 2, (size, 2), generator=generator)
        # targets: (size, 1) column with the XOR of the two input bits of each row.
        self.targets = self.inputs[:, :1] ^ self.inputs[:, 1:]

    def __len__(self):
        """Return the number of examples."""
        return self.size

    def __getitem__(self, idx):
        """Return the ``(inputs, target)`` pair stored at position ``idx``."""
        return self.inputs[idx], self.targets[idx]

In [3]:
# Training and held-out datasets of random XOR examples.
xor_train = XORDataset(1000)
xor_test = XORDataset(100)

# Mini-batches of 4 samples. Only the training loader is shuffled: sample
# order has no effect on the evaluation metrics, so the test loader keeps a
# fixed order.
train_loader = DataLoader(xor_train, batch_size=4, shuffle=True)
test_loader = DataLoader(xor_test, batch_size=4, shuffle=False)

In [4]:
class XORNet(nn.Module):
    """Two-layer sigmoid network written with raw parameters.

    Computes ``sigmoid(sigmoid(x @ w1 + b1) @ w2 + b2)``, producing one value
    in ``[0, 1]`` per input row.

    Args:
        input_size: Number of input features per sample (default 2).
        hidden_size: Number of hidden units (default 4).

    With the defaults the model has 2*4 + 4 + 4*1 + 1 = 17 parameters.
    """

    def __init__(self, input_size=2, hidden_size=4):
        super().__init__()

        # nn.Parameter already sets requires_grad=True, so it is not passed
        # to torch.randn explicitly.
        self.w1 = nn.Parameter(torch.randn(input_size, hidden_size))
        self.b1 = nn.Parameter(torch.randn(hidden_size))
        self.w2 = nn.Parameter(torch.randn(hidden_size, 1))
        self.b2 = nn.Parameter(torch.randn(1))

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Float tensor of shape (batch_size, input_size).

        Returns:
            Tensor of shape (batch_size, 1) with sigmoid outputs.
        """
        # Verbose version
        # w1 = input_size x hidden_size
        # b1 = hidden_size, broadcast over the batch dimension
        # torch.matmul(x, w1)      = (batch_size x input_size) x (input_size x hidden_size)
        #                          = batch_size x hidden_size
        # torch.matmul(x, w1) + b1 = batch_size x hidden_size
        x = torch.matmul(x, self.w1) + self.b1
        x = torch.sigmoid(x)
        # (batch_size x hidden_size) x (hidden_size x 1) = batch_size x 1
        x = torch.matmul(x, self.w2) + self.b2
        x = torch.sigmoid(x)

        # Concise version
        # x = torch.sigmoid(x @ self.w1 + self.b1)
        # x = torch.sigmoid(x @ self.w2 + self.b2)
        return x

In [5]:
# Build the network and move its parameters to the selected device.
# nn.Module.to() returns the module itself, so the call can be chained.
xor_net = XORNet().to(device)
# Keep the module as the cell's last expression so its repr is displayed.
xor_net

XORNet()

In [6]:
def count_parameters(model):
    """Return the total number of scalar elements across the trainable
    (``requires_grad=True``) parameters of ``model``."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            # Frozen parameters are not counted.
            continue
        total += param.numel()
    return total

# Report the number of trainable parameters of the network.
print(
    f'Number of parameters in the model: {count_parameters(xor_net)}'
)

Number of parameters in the model: 17


In [7]:
# Binary cross-entropy on probabilities: the network's forward pass already
# ends in a sigmoid, so its outputs can be fed to BCELoss directly.
loss_fn = nn.BCELoss()

# Adam over every parameter of the network, learning rate 0.01.
optimizer = torch.optim.Adam(params=xor_net.parameters(), lr=1e-2)

In [8]:
# Train for 10 epochs; after each epoch evaluate on the test loader and
# print the mean per-batch train/validation loss and the test accuracy.
for epoch in range(10):
    total_train_loss = 0
    total_validation_loss = 0
    correct = 0  # number of correctly classified test samples this epoch
    seen = 0     # number of test samples evaluated this epoch

    # --- training pass ---
    # train()/eval() only toggle the module's mode flag for this model, but
    # keep the loop correct if mode-dependent layers are ever added.
    xor_net.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        # The dataset yields integer tensors; the model and BCELoss need floats.
        y_pred = xor_net(x.float())
        loss = loss_fn(y_pred, y.float())
        total_train_loss += loss.item()
        loss.backward()
        optimizer.step()

    # --- evaluation pass (no gradient tracking) ---
    xor_net.eval()
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            y_pred = xor_net(x.float())
            loss = loss_fn(y_pred, y.float())
            total_validation_loss += loss.item()
            # Threshold the probabilities at 0.5 and count exact matches.
            correct += ((y_pred > 0.5) == y).sum().item()
            seen += y.shape[0]

    # Accuracy over samples rather than a mean of per-batch accuracies, so a
    # smaller final batch is not over-weighted.
    test_accuracy = correct / seen
    print(f"Epoch {epoch}: train loss {total_train_loss / len(train_loader)} validation loss {total_validation_loss / len(test_loader)} test accuracy {test_accuracy}")

Epoch 0: train loss 0.6714922738075256 validation loss 0.6674622559547424 test accuracy 0.63
Epoch 1: train loss 0.5793871027231217 validation loss 0.51549560546875 test accuracy 0.71
Epoch 2: train loss 0.45628410899639127 validation loss 0.4161962366104126 test accuracy 1.0
Epoch 3: train loss 0.3334219079017639 validation loss 0.27913627207279207 test accuracy 1.0
Epoch 4: train loss 0.2222734870314598 validation loss 0.1778747010231018 test accuracy 1.0
Epoch 5: train loss 0.13979574462771416 validation loss 0.10922032535076141 test accuracy 1.0
Epoch 6: train loss 0.0896419452726841 validation loss 0.0722236853837967 test accuracy 1.0
Epoch 7: train loss 0.06057548549771309 validation loss 0.0502554976940155 test accuracy 1.0
Epoch 8: train loss 0.043482094347476956 validation loss 0.036636829823255536 test accuracy 1.0
Epoch 9: train loss 0.03233708334714174 validation loss 0.02796132765710354 test accuracy 1.0


In [9]:
# Evaluate the trained network on the full XOR truth table.
# This is pure inference, so gradient tracking is disabled, matching the
# evaluation pass of the training loop.
with torch.no_grad():
    pred = xor_net(torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]]).float().to(device))

In [10]:
# Turn the sigmoid outputs into hard 0/1 predictions by thresholding at 0.5.
pred.gt(0.5).int()

tensor([[0],
        [1],
        [1],
        [0]], device='mps:0', dtype=torch.int32)

In [11]:
# The four XOR input combinations as a float32 matrix, one sample per row.
a_data = torch.tensor(
    [[0, 0], [0, 1], [1, 0], [1, 1]],
    dtype=torch.float32,
)
# Randomly initialised 2 x 4 weight matrix and length-4 bias vector.
w = torch.randn((2, 4))
b = torch.randn((4,))


In [12]:
# Show the dimensions of the input matrix (rows = samples, columns = features).
a_data.size()

torch.Size([4, 2])

In [13]:
# Show the dimensions of the weight matrix.
w.size()

torch.Size([2, 4])