In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

In [2]:
# --- Step 1: Generate the Fake Dataset ---

def create_fake_survival_data(num_samples=10000):
    """
    Generates a synthetic dataset for mouse survival based on a set of rules.
    """
    # Generate random base features
    x_position = np.random.rand(num_samples) * 100  # Env size 100x100
    y_position = np.random.rand(num_samples) * 100
    hunger_level = np.random.rand(num_samples)      # 0 to 1
    stress_level = np.random.rand(num_samples)      # 0 to 1
    density = np.random.rand(num_samples) * 50      # 0 to 50 mice nearby

    # Combine features into a single matrix (num_samples x 5 features)
    # The columns are the features, and the rows are different mice (samples).
    features = np.stack([
        x_position, y_position, hunger_level, stress_level, density
    ], axis=1)

    # print(features)

    # Generate labels based on our "rule of nature"
    # Label 1 = dead, Label 0 = alive
    # Start by assuming all are alive
    labels = np.zeros(num_samples)

    # Rule 1: High stress AND high hunger is very bad.
    condition1 = (stress_level > 0.8) & (hunger_level > 0.85)

    # Rule 2: Extreme population density is very bad.
    condition2 = density > 45

    # Apply these rules to set the labels for "dead" mice
    labels[condition1 | condition2] = 1

        
    # Add a little bit of random noise to make it more realistic
    # (Sometimes a healthy mouse dies, or a stressed one survives)
    noise = np.random.rand(num_samples) < 0.05 # 5% chance of flipping the label
    labels = np.abs(labels - noise)

    # print(labels)


    print(f"Generated {num_samples} samples.")
    print(f"Number of 'alive' mice (0): {np.count_nonzero(labels == 0)}")
    print(f"Number of 'dead' mice (1): {np.count_nonzero(labels == 1)}\n")
    
    return features, labels

In [3]:
# --- Step 2: Prepare Data for PyTorch ---

# Create the data
features_np, labels_np = create_fake_survival_data(num_samples=100000)

# Convert NumPy arrays to PyTorch Tensors
# The features are our 'X' and the labels are our 'y'
features_tensor = torch.tensor(features_np, dtype=torch.float32)
labels_tensor = torch.tensor(labels_np, dtype=torch.float32).unsqueeze(1) # Add a dimension for the loss function
# print(labels_tensor)

# Create a Dataset and DataLoader to handle batching
batch_size = 32
dataset = TensorDataset(features_tensor, labels_tensor)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Generated 100000 samples.
Number of 'alive' mice (0): 83668
Number of 'dead' mice (1): 16332



In [4]:
class SurvivalClassifier(nn.Module):
    def __init__(self):
        super(SurvivalClassifier, self).__init__()
        # Input layer: 5 features (x, y, hunger, stress, density)
        # Hidden layer 1: 16 neurons
        # Hidden layer 2: 8 neurons
        # Output layer: 1 neuron (a single value representing the prediction)
        self.network = nn.Sequential(
            nn.Linear(5, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 1) # Output is a single value, called a 'logit'
        )

    def forward(self, x):
        return self.network(x)

In [5]:
# --- Step 4: Train the Network ---

# Hyperparameters
epochs = 200
learning_rate = 0.0005 

# Instantiate the model, loss function, and optimizer
model = SurvivalClassifier()
# BCEWithLogitsLoss is perfect for binary classification. It's numerically stable.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

print("Starting training...")
# The training loop
for epoch in range(epochs):
    # *** FIX 3: TRACK AVERAGE EPOCH LOSS ***
    total_loss = 0.0
    num_batches = 0
    
    for batch_features, batch_labels in data_loader:
        outputs = model(batch_features)
        loss = criterion(outputs, batch_labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
    
    avg_loss = total_loss / num_batches
    print(f'Epoch [{epoch+1}/{epochs}], Average Loss: {avg_loss:.4f}')

print("Training finished.\n")


Starting training...
Epoch [1/200], Average Loss: 0.4332
Epoch [2/200], Average Loss: 0.3784
Epoch [3/200], Average Loss: 0.3528
Epoch [4/200], Average Loss: 0.3343
Epoch [5/200], Average Loss: 0.3188
Epoch [6/200], Average Loss: 0.3050
Epoch [7/200], Average Loss: 0.2943
Epoch [8/200], Average Loss: 0.2859
Epoch [9/200], Average Loss: 0.2802
Epoch [10/200], Average Loss: 0.2753
Epoch [11/200], Average Loss: 0.2704
Epoch [12/200], Average Loss: 0.2683
Epoch [13/200], Average Loss: 0.2644
Epoch [14/200], Average Loss: 0.2607
Epoch [15/200], Average Loss: 0.2577
Epoch [16/200], Average Loss: 0.2565
Epoch [17/200], Average Loss: 0.2537
Epoch [18/200], Average Loss: 0.2522
Epoch [19/200], Average Loss: 0.2507
Epoch [20/200], Average Loss: 0.2497
Epoch [21/200], Average Loss: 0.2471
Epoch [22/200], Average Loss: 0.2469
Epoch [23/200], Average Loss: 0.2460
Epoch [24/200], Average Loss: 0.2456
Epoch [25/200], Average Loss: 0.2453
Epoch [26/200], Average Loss: 0.2458
Epoch [27/200], Average Lo

In [6]:
# --- Step 5: Use the Trained Model for Prediction ---

# Put the model in evaluation mode
model.eval()

# Create two new test mice
# Mouse 1: Should be ALIVE (low stress, low hunger, low density)
mouse_healthy = torch.tensor([[50.0, 50.0, 0.1, 0.1, 5.0]], dtype=torch.float32)

# Mouse 2: Should be DEAD (high stress, high hunger)
mouse_stressed = torch.tensor([[25.0, 30.0, 0.9, 0.9, 15.0]], dtype=torch.float32)

with torch.no_grad(): # We don't need to calculate gradients for prediction
    # Get the raw output (logit) from the model
    healthy_logit = model(mouse_healthy)
    stressed_logit = model(mouse_stressed)

    # Convert the logit to a probability (0 to 1) using the sigmoid function
    healthy_prob = torch.sigmoid(healthy_logit)
    stressed_prob = torch.sigmoid(stressed_logit)

    # Make a final decision
    healthy_prediction = "Alive" if healthy_prob.item() < 0.5 else "Dead"
    stressed_prediction = "Alive" if stressed_prob.item() < 0.5 else "Dead"

print("--- Making Predictions on New Data ---")
print(f"Healthy Mouse -> Survival Probability: {healthy_prob.item():.4f}, Prediction: {healthy_prediction}")
print(f"Stressed Mouse -> Survival Probability: {stressed_prob.item():.4f}, Prediction: {stressed_prediction}")

--- Making Predictions on New Data ---
Healthy Mouse -> Survival Probability: 0.0492, Prediction: Alive
Stressed Mouse -> Survival Probability: 0.7331, Prediction: Dead
