In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import config
import loader

In [2]:
# Define the model
class SimpleNN(nn.Module):
    """Small fully connected binary classifier: in_features -> hidden1 -> hidden2 -> 1.

    Each hidden layer is followed by a ReLU and the single output unit by a
    sigmoid, so ``forward`` returns probabilities in [0, 1] (the input that
    ``nn.BCELoss`` expects).

    Args:
        in_features: number of input features per example (default 9).
        hidden1: width of the first hidden layer (default 32).
        hidden2: width of the second hidden layer (default 16).

    NOTE(review): PyTorch documents ``nn.BCEWithLogitsLoss`` (raw logits) as more
    numerically stable than a sigmoid followed by ``nn.BCELoss``. Switching means
    dropping the final sigmoid here AND changing the criterion at the same time.
    """

    def __init__(self, in_features=9, hidden1=32, hidden2=16):
        super().__init__()
        # Attribute names are kept as fc1/fc2/fc3 so state_dict keys stay stable.
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, 1)  # one output unit for binary classification
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map inputs of shape (..., in_features) to probabilities of shape (..., 1)."""
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.sigmoid(self.fc3(x))

# Seed torch's global RNG before building the model so the random weight
# initialisation is reproducible across Restart & Run All.
# (torch.manual_seed seeds the generators on all devices.)
SEED = 42
torch.manual_seed(SEED)

# Create the model on the configured device
model = SimpleNN().to(config.device)
total_params = sum(p.numel() for p in model.parameters())
print(f'Total number of parameters: {total_params}')

# Define loss function and optimizer.
# SimpleNN already applies a sigmoid, so plain BCELoss on probabilities is the matching criterion.
LEARNING_RATE = 1e-3
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

Total number of parameters: 865


In [3]:
dataloader = loader.dataloader()

# Training loop
num_epochs = 20

# NOTE(review): the loader reports 5289 positive vs 386005 negative examples
# (~1.4% positives), so a low average BCE alone says little about how well the
# positive class is detected -- evaluate precision/recall on held-out data.

# Put layers in training mode explicitly. The current model has no dropout or
# batch-norm, so this is a no-op today, but it keeps the loop correct if such
# layers are added (or if an earlier cell left the model in eval mode).
model.train()

for epoch in range(num_epochs):
    losses = []
    for inputs, labels in dataloader:
        # The model was moved to config.device; move each batch there as well.
        # This is a no-op if loader.dataloader() already yields tensors on that device.
        inputs = inputs.to(config.device)
        labels = labels.to(config.device)

        # Forward pass: the model emits one probability per row (shape (batch, 1)),
        # so the 1-D label batch is unsqueezed to the same shape for BCELoss.
        outputs = model(inputs)
        loss = criterion(outputs, labels.unsqueeze(1))

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Track the per-batch losses
        losses.append(loss.item())

    # An empty loader would otherwise surface as a ZeroDivisionError below.
    if not losses:
        raise RuntimeError('dataloader yielded no batches; cannot compute epoch loss')

    # Mean of per-batch losses (unweighted: a short final batch counts like a full one).
    average_loss = sum(losses) / len(losses)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {average_loss:.4f}')

Features shape: (391294, 9)
Labels shape: (391294,)
5289 positive examples, 386005 negative examples
data loaded with 14.9 MiB
Epoch [1/20], Loss: 0.0152
Epoch [2/20], Loss: 0.0050
Epoch [3/20], Loss: 0.0056
Epoch [4/20], Loss: 0.0218
Epoch [5/20], Loss: 0.0061
Epoch [6/20], Loss: 0.0070
Epoch [7/20], Loss: 0.0082
Epoch [8/20], Loss: 0.0040
Epoch [9/20], Loss: 0.0232
Epoch [10/20], Loss: 0.0588
Epoch [11/20], Loss: 0.0219
Epoch [12/20], Loss: 0.0044
Epoch [13/20], Loss: 0.0062
Epoch [14/20], Loss: 0.3253
Epoch [15/20], Loss: 0.0459
Epoch [16/20], Loss: 0.0468
Epoch [17/20], Loss: 0.1707
Epoch [18/20], Loss: 0.0078
Epoch [19/20], Loss: 0.0049
Epoch [20/20], Loss: 0.0053
