In [64]:
import torch
from typing import Tuple

In [66]:
class XORNet(torch.nn.Module):
    """Small fully connected network with one hidden layer for the XOR problem.

    The network maps 2 input features to a single output. A sigmoid is applied
    after both linear layers, so every output lies strictly between 0 and 1.

    Args:
        hidden_size: Number of units in the hidden layer. Defaults to 128.

    Raises:
        ValueError: If ``hidden_size`` is smaller than 1.
    """

    def __init__(self, hidden_size: int = 128) -> None:
        super().__init__()
        if hidden_size < 1:
            raise ValueError(f"hidden_size must be >= 1, got {hidden_size}")
        self.fc1 = torch.nn.Linear(2, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, 1)
        self.activation = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass.

        Args:
            x: Input tensor whose last dimension has size 2 (required by
                ``fc1``).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1, holding sigmoid activations.
        """
        hidden = self.activation(self.fc1(x))
        return self.activation(self.fc2(hidden))

In [67]:
def train(
    model: torch.nn.Module,
    x: torch.Tensor,
    y: torch.Tensor,
    learn_rate: float = 0.1,
    epochs: int = 10_000,
    log_every: int = 1000,
) -> None:
    """Fit ``model`` to ``(x, y)`` with SGD on a mean-squared-error loss.

    Every epoch performs one optimizer step computed from the whole of ``x``
    and ``y``. The model's parameters are updated in place.

    Args:
        model: Module to optimise.
        x: Input tensor passed to ``model``.
        y: Target tensor compared against the model output by ``MSELoss``.
        learn_rate: Learning rate handed to ``torch.optim.SGD``.
        epochs: Number of optimisation steps to run.
        log_every: Print the loss every ``log_every`` epochs (starting with
            epoch 0). A value of 0 or less disables logging.
    """
    criterion: torch.nn.Module = torch.nn.MSELoss()
    optimizer: torch.optim.Optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)

    for epoch in range(epochs):
        # Call the modules rather than ``.forward`` so that any registered
        # forward hooks are executed.
        outputs = model(x)
        loss = criterion(outputs, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The ``log_every > 0`` guard also protects the modulo from a zero divisor.
        if log_every > 0 and epoch % log_every == 0:
            print(f"Epoch {epoch} | Loss {loss.item():.4f}")

In [68]:
def evaluate(
    model: torch.nn.Module,
    X: torch.Tensor,
    threshold: float = 0.5,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run ``model`` on ``X`` without gradient tracking and binarise the result.

    The model is switched to evaluation mode for the forward pass and put back
    into whatever mode it was in before the call, so the caller's training
    state is preserved.

    Args:
        model: Module to evaluate.
        X: Input tensor passed to ``model``.
        threshold: Outputs strictly greater than this value are predicted as
            1.0, all others as 0.0. Defaults to 0.5.

    Returns:
        A ``(outputs, predictions)`` tuple: the raw model outputs and a float
        tensor of the same shape containing 0.0 / 1.0 predictions.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Call the module rather than ``.forward`` so hooks are executed.
            outputs = model(X)
    finally:
        # Restore the previous mode even if the forward pass raised.
        model.train(was_training)

    predictions = (outputs > threshold).float()
    return outputs, predictions

In [69]:
def main(seed: int = 0) -> None:
    """Train an ``XORNet`` on the four-row XOR truth table and print the results.

    Args:
        seed: Value passed to ``torch.manual_seed`` before the model is
            created, so that weight initialisation (and therefore the printed
            losses and outputs) is repeatable between runs.
    """
    # Seed before constructing the model: the Linear layers draw their initial
    # weights from the global RNG.
    torch.manual_seed(seed)

    # All four input combinations of the two-bit XOR truth table.
    X = torch.tensor([
        [0, 0],
        [0, 1],
        [1, 0],
        [1, 1],
    ], dtype=torch.float32)

    # Matching XOR targets, one row per input row.
    y = torch.tensor([
        [0],
        [1],
        [1],
        [0],
    ], dtype=torch.float32)

    model = XORNet()
    train(model, X, y)
    outputs, predictions = evaluate(model, X)

    print("\nFinal Raw Outputs:")
    print(outputs)
    print("Rounded Predictions:")
    print(predictions)

In [70]:
# Entry point: run the XOR demo only when this code is executed directly
# (as a script or notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()

Epoch 0 | Loss 0.2509
Epoch 1000 | Loss 0.2464
Epoch 2000 | Loss 0.2397
Epoch 3000 | Loss 0.2224
Epoch 4000 | Loss 0.1760
Epoch 5000 | Loss 0.0981
Epoch 6000 | Loss 0.0432
Epoch 7000 | Loss 0.0213
Epoch 8000 | Loss 0.0126
Epoch 9000 | Loss 0.0084

Final Raw Outputs:
tensor([[0.0682],
        [0.9227],
        [0.9205],
        [0.0873]])
Rounded Predictions:
tensor([[0.],
        [1.],
        [1.],
        [0.]])
