In [None]:
# This is the old code for when we dealt with class imbalance by removing samples

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# Split row positions first (80/20 train+val / test, then 80/20 train / val) so
# that the scaler below can be fitted on the training rows only. Fitting it on
# the full table would leak validation/test statistics into the training features.
row_positions = list(range(len(data)))
train_val_rows, test_rows = train_test_split(row_positions, test_size=0.2, random_state=42)
train_rows, val_rows = train_test_split(train_val_rows, test_size=0.2, random_state=42)

# Standardize the features: statistics come from the training rows,
# the transform is applied to every row (data is still updated in place).
scaler = StandardScaler()
scaler.fit(data[features].iloc[train_rows])
data[features] = scaler.transform(data[features])

# Convert data to PyTorch tensors
X = torch.tensor(data[features].values, dtype=torch.float32)
y = torch.tensor(data['target'].values, dtype=torch.float32)  # For BCE loss, use float labels
weights_nominal = torch.tensor(data['weight_nominal_scaled'].values, dtype=torch.float32)  # Use this one for saving the weights when balancing by removing

# Positional index tensors for each split
train_val_idx = torch.tensor(train_val_rows, dtype=torch.long)
test_idx = torch.tensor(test_rows, dtype=torch.long)
train_idx = torch.tensor(train_rows, dtype=torch.long)
val_idx = torch.tensor(val_rows, dtype=torch.long)

# Training+validation and test sets (80/20)
X_train_val, y_train_val, wn_train_val = X[train_val_idx], y[train_val_idx], weights_nominal[train_val_idx]
X_test, y_test, wn_test = X[test_idx], y[test_idx], weights_nominal[test_idx]

# Train and validation sets (80/20 of the 80%)
X_train, y_train, wn_train = X[train_idx], y[train_idx], weights_nominal[train_idx]
X_val, y_val, wn_val = X[val_idx], y[val_idx], weights_nominal[val_idx]

# Create DataLoaders for each set; only the training loader is shuffled
train_loader = DataLoader(TensorDataset(X_train, y_train, wn_train), batch_size=64, shuffle=True)
val_loader = DataLoader(TensorDataset(X_val, y_val, wn_val), batch_size=64, shuffle=False)
test_loader = DataLoader(TensorDataset(X_test, y_test, wn_test), batch_size=64, shuffle=False)

# Define the model
n_features = len(features)

class SimpleNN(nn.Module):
    """Fully connected network for binary classification.

    The modules are stored in ``self.layers`` in this order: a
    ``Linear -> BatchNorm1d`` pair for the input layer, one further
    ``Linear -> BatchNorm1d`` pair per additional hidden layer, and a final
    ``Linear`` layer with a single output unit.

    Args:
        n_layers: Number of hidden ``Linear -> BatchNorm1d`` pairs. Values
            below 1 still produce one pair, because the input pair is always
            created.
        n_neurons: Width of every hidden layer.
        n_inputs: Number of input features. When ``None`` (the default) the
            module-level ``n_features`` is used, which is the original
            behaviour.
    """

    def __init__(self, n_layers=1, n_neurons=64, n_inputs=None):
        super().__init__()

        if n_inputs is None:
            # Backward-compatible fallback to the global defined next to the data
            n_inputs = n_features

        # Input layer followed by BatchNorm
        layers = [nn.Linear(n_inputs, n_neurons), nn.BatchNorm1d(n_neurons)]

        # Remaining hidden layers, each followed by BatchNorm
        for _ in range(n_layers - 1):
            layers.append(nn.Linear(n_neurons, n_neurons))
            layers.append(nn.BatchNorm1d(n_neurons))

        # Final layer with 1 neuron for binary classification
        layers.append(nn.Linear(n_neurons, 1))

        self.layers = nn.ModuleList(layers)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid outputs of shape ``(batch, 1)`` for input ``x``."""
        *hidden, head = self.layers
        # NOTE(review): ReLU is applied after every module in `hidden`, i.e. after
        # each Linear AND after each BatchNorm1d. Kept as-is so results do not
        # change; confirm whether Linear -> BatchNorm -> ReLU was intended.
        for module in hidden:
            x = torch.relu(module(x))
        # For BCEWithLogitsLoss, return head(x) directly instead of the sigmoid.
        return self.sigmoid(head(x))

# Network under training: two hidden layers of 64 neurons each
model = SimpleNN(2, 64)

# Binary cross entropy on probabilities, averaged over the batch (the default
# reduction). The model's forward pass already applies the sigmoid; switch to
# nn.BCEWithLogitsLoss only if the model is changed to return raw logits.
criterion = nn.BCELoss()

# Adam over all model parameters
optimizer = optim.Adam(params=model.parameters(), lr=1e-2)

# Training configuration and per-epoch loss history (sample-averaged losses)
n_epochs = 20
train_losses, val_losses = [], []

for epoch in range(n_epochs):
    # ---- Training pass: gradients enabled, parameters updated every batch ----
    model.train()
    running_train_loss = 0.0
    for inputs, targets, _sample_weights in train_loader:  # weights are loaded but not used here
        targets = targets.reshape(-1, 1)  # column vector, same shape as the model output
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        # criterion returns a batch mean; scale by batch size to accumulate a sum
        running_train_loss += batch_loss.item() * len(inputs)

    epoch_train_loss = running_train_loss / len(train_loader.dataset)
    train_losses.append(epoch_train_loss)

    # ---- Validation pass: eval mode, no gradient tracking ----
    model.eval()
    running_val_loss = 0.0
    with torch.no_grad():
        for inputs, targets, _sample_weights in val_loader:
            targets = targets.reshape(-1, 1)
            batch_loss = criterion(model(inputs), targets)
            running_val_loss += batch_loss.item() * len(inputs)

    epoch_val_loss = running_val_loss / len(val_loader.dataset)
    val_losses.append(epoch_val_loss)

    print(f'Epoch [{epoch+1}/{n_epochs}], Training Loss: {epoch_train_loss:.4f}, Validation Loss: {epoch_val_loss:.4f}')

# Plot the training and validation losses.
# Epochs are numbered from 1 so the x axis matches the "Epoch [k/n]" lines
# printed during training (the loop index itself is 0-based).
epoch_numbers = range(1, n_epochs + 1)
plt.plot(epoch_numbers, train_losses, label='Training Loss')
plt.plot(epoch_numbers, val_losses, label='Validation Loss')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.xticks(epoch_numbers)
plt.legend()
plt.show()
