Train a neural network that takes a 4-bit input (e.g., [1, 0, 1, 1]) and outputs a one-hot encoded vector representing the number of 1s (from 0 to 4). There are 5 possible output classes.

In [3]:
%pip install torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

# Enumerate every 4-bit pattern, 0000 through 1111, most significant bit first
NUM_BITS = 4
trainx = torch.tensor(
    [[(value >> shift) & 1 for shift in reversed(range(NUM_BITS))]
     for value in range(2 ** NUM_BITS)],
    dtype=torch.float32,
)

# The label of each pattern is its number of set bits (the row sum), stored as
# int64 because CrossEntropyLoss takes class indices rather than one-hot vectors
trainy = trainx.sum(dim=1).to(torch.long)

# Pair inputs with labels and serve them in shuffled mini-batches of four
dataset = TensorDataset(trainx, trainy)
loader = DataLoader(dataset, shuffle=True, batch_size=4)

# Define the model
class Net(nn.Module):
    """Small fully connected classifier with two hidden ReLU layers.

    With the default sizes it maps a 4-bit input to 5 class scores, one per
    possible count of set bits (0 through 4).

    Args:
        in_features: Number of input features (bits). Defaults to 4.
        hidden_features: Width of both hidden layers. Defaults to 10.
        num_classes: Number of output classes. Defaults to 5.
    """

    def __init__(
        self,
        in_features: int = 4,
        hidden_features: int = 10,
        num_classes: int = 5,
    ) -> None:
        super().__init__()
        # The attribute names fc1/fc2/fc3 are kept so state_dict keys are unchanged.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.fc3 = nn.Linear(hidden_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class scores (logits) for the input batch ``x``.

        The last dimension of ``x`` must equal ``in_features``; the last
        dimension of the result equals ``num_classes``.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No softmax here: nn.CrossEntropyLoss applies log-softmax to the logits itself.
        return self.fc3(x)

# Build the classifier
model = Net()

# Cross-entropy over integer class labels, minimised with Adam
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training: sweep the shuffled mini-batches once per epoch
NUM_EPOCHS = 1000
for epoch in range(NUM_EPOCHS):
    for inputs, targets in loader:
        # Clear the gradients left over from the previous step before backprop
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()

# Score the trained model on all 16 patterns; only the forward pass needs
# autograd switched off, the remaining steps are plain tensor arithmetic
with torch.no_grad():
    preds = model(trainx)

# The predicted class is the index of the largest logit in each row
predicted_classes = torch.argmax(preds, dim=1)
accuracy = torch.eq(predicted_classes, trainy).float().mean()
print(f"Accuracy: {accuracy.item() * 100:.2f}%")

# Print each input pattern next to its predicted and true bit count
for bits, guess, truth in zip(trainx, predicted_classes.tolist(), trainy.tolist()):
    print(f"Input: {bits.numpy()} → Predicted count: {guess}, Actual: {truth}")


Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.
Accuracy: 100.00%
Input: [0. 0. 0. 0.] → Predicted count: 0, Actual: 0
Input: [0. 0. 0. 1.] → Predicted count: 1, Actual: 1
Input: [0. 0. 1. 0.] → Predicted count: 1, Actual: 1
Input: [0. 0. 1. 1.] → Predicted count: 2, Actual: 2
Input: [0. 1. 0. 0.] → Predicted count: 1, Actual: 1
Input: [0. 1. 0. 1.] → Predicted count: 2, Actual: 2
Input: [0. 1. 1. 0.] → Predicted count: 2, Actual: 2
Input: [0. 1. 1. 1.] → Predicted count: 3, Actual: 3
Input: [1. 0. 0. 0.] → Predicted count: 1, Actual: 1
Input: [1. 0. 0. 1.] → Predicted count: 2, Actual: 2
Input: [1. 0. 1. 0.] → Predicted count: 2, Actual: 2
Input: [1. 0. 1. 1.] → Predicted count: 3, Actual: 3
Input: [1. 1. 0. 0.] → Predicted count: 2, Actual: 2
Input: [1. 1. 0. 1.] → Predicted count: 3, Actual: 3
Input: [1. 1. 1. 0.] → Predicted count: 3, Actual: 3
Input: [1. 1. 1. 1.] → Predicted count: 4, 