# 📌 <span style="font-size:18px; color:#007acc;"><b>Introduction</b></span>

This is the second notebook in the series. The first notebook covered data preprocessing and cleaning: handling missing values, detecting and removing outliers, and exploratory data analysis (EDA). This notebook continues from that point. It imports the required modules and defines custom functions and classes for model training and testing, including a custom batch sampler that keeps each batch proportional to the state distribution in the data. SHAP analysis is then used to explain feature importance. Finally, the trained model is saved and served through a local FastAPI server, and predictions are obtained by sending requests to the API. Note: this work was carried out on Microsoft Azure, so some settings may need adjusting in other environments.



# <span style="font-size:18px; color:#007acc;"><b> Table of Contents</b></span>
1. [Install libraries](#Install-Libraries)
2. [Import modules](#Import-Modules) 
3. [Custom functions and classes for model training](#Custom-Functions-and-Classes-for-Model-Training)  
4. [Execution code for model training](#Execution-Code-for-Model-Training)  
5. [SHAP analysis](#SHAP-Analysis)
6. [Creating a FastAPI server](#Create-FAST-API)
7. [Inference from API](#Inference-from-API)

## <span style="font-size:18px; color:#007acc;"><b> 1. Install libraries <a id="Install-Libraries"></a></b></span> ##

In [None]:
# Required libraries to be installed
!pip install numpy pandas scikit-learn  # Data handling and classical ML utilities
!pip install matplotlib seaborn  # Plotting and visualization
!pip install azureml-core  # AzureML Core for interacting with Azure Machine Learning services
!pip install azure-storage-blob  # Azure Storage Blob for managing files in Azure Blob Storage
!pip install azureml-dataset-runtime  # AzureML Dataset Runtime for working with datasets in Azure ML
!pip install torch torchvision torchmetrics  # PyTorch for deep learning, torchvision for computer vision, torchmetrics for metrics
!pip install pyyaml  # YAML configuration files
!pip install mlflow  # MLflow for managing machine learning workflows
!pip install shap  # SHAP for model interpretability (used in section 5)
!pip install fastapi uvicorn requests  # FastAPI and uvicorn for serving the model, requests for calling the API
!pip install fsspec  # fsspec for working with file systems

## <span style="font-size:18px; color:#007acc;"><b> 2. Import modules <a id="Import-Modules"></a></b></span> ##

In [None]:
# Standard Library
import sys
import math
import logging
import datetime
import threading

# Third-Party Utilities (HTTP client and YAML parsing)
import requests
import yaml

# PyTorch and Related
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Sampler, random_split, TensorDataset
from torchvision import datasets, transforms, models
from torchmetrics import Precision, Recall, F1Score

# Sklearn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve

# Plotting and Visualization
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Data Handling
import numpy as np
import pandas as pd

# API
from fastapi import FastAPI
from pydantic import BaseModel
from uvicorn import run

# Remote File Access
import fsspec

# Interpretability (used in the SHAP analysis section)
import shap


## <span style="font-size:18px; color:#007acc;"><b> 3. Custom functions and classes for model training <a id="Custom-Functions-and-Classes-for-Model-Training"></a></b></span> ##

In [None]:

"""
This code defines a custom PyTorch dataset, neural network model, and several utility functions for training, validation, and evaluation of a binary classification model.

1. `LoadDataset`: A custom dataset class to load tabular data from feature and label files, with support for YAML-based configuration.
2. `NNModel`: A neural network model with customizable layers, weight initialization, and dropout.
3. `ProportionalBatchSampler`: A custom batch sampler to ensure proportional sampling based on state labels.
4. Functions for training, validation, accuracy computation, ROC curve computation, and precision-recall curve computation are also defined.

Each class and function is documented with arguments, return types, and functionality explained.

PEP 8 recommendations have been followed, with the exception of line length exceeding 79 characters in some instances. 
"""


class LoadDataset(Dataset):
    """
    Custom PyTorch Dataset for loading tabular data from feature and label files,
    optionally using a YAML config for cleaner initialization.
    """

    def __init__(self, feature_file=None, label_file=None, sep='\t', skiprows=1, config_path=None):
        """
        Initializes the dataset from file paths or a YAML config.

        Args:
            feature_file (str): Path to the feature file (CSV/TSV).
            label_file (str): Path to the label file (CSV/TSV).
            sep (str): Column separator in the file. Default is tab ('\t').
            skiprows (int): Number of rows to skip (usually header). Default is 1.
            config_path (str): Optional path to a YAML config file containing keys:
                               'feature_file', 'label_file', 'sep', and 'skiprows'.
        """
        if config_path:
            with open(config_path, 'r') as config_file:
                # safe_load avoids constructing arbitrary Python objects from the config
                config = yaml.safe_load(config_file)
            feature_file = config.get('feature_file')
            label_file = config.get('label_file')
            sep = config.get('sep', sep)
            skiprows = config.get('skiprows', skiprows)

        # Validate file paths
        if not feature_file or not label_file:
            raise ValueError("Both feature_file and label_file must be provided either directly or through the config file.")

        # Load data
        feature_data = pd.read_csv(feature_file, sep=sep, skiprows=skiprows)
        label_data = pd.read_csv(label_file, sep=sep, skiprows=skiprows)

        # Convert to PyTorch tensors
        self.X = torch.tensor(feature_data.values, dtype=torch.float32)
        self.Y = torch.tensor(label_data.values, dtype=torch.float32)
        self.n_samples = self.X.shape[0]

    def __getitem__(self, index):
        """
        Returns a single (feature, label) pair at the given index.
        """
        return self.X[index], self.Y[index]

    def __len__(self):
        """
        Returns the total number of samples.
        """
        return self.n_samples


class NNModel(nn.Module):
    """
    Neural network model with customizable hidden layers and initialization.
    """

    def __init__(self, input_size, hidden_sizes, output_size, initialization, dropout):
        """
        Initialize the neural network model.

        Args:
            input_size (int): Number of input features.
            hidden_sizes (list): List of hidden layer sizes.
            output_size (int): Number of output units.
            initialization (str): Type of weight initialization.
            dropout (float): Dropout probability.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size

        self.fc1 = nn.Linear(input_size, hidden_sizes[0])
        self.initialize_weights(self.fc1, initialization)

        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_sizes) - 1):
            layer = nn.Linear(hidden_sizes[i], hidden_sizes[i + 1])
            self.hidden_layers.append(layer)
            self.initialize_weights(layer, initialization)

        self.fc_out = nn.Linear(hidden_sizes[-1], output_size)
        self.initialize_weights(self.fc_out, initialization)

    def forward(self, x):
        """
        Forward pass through the network.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
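        x = x.view(-1, self.num_flat_features(x))
        # Dropout is applied after each hidden activation; it is a no-op in eval mode
        x = self.dropout(F.relu(self.fc1(x)))
        for hidden_layer in self.hidden_layers:
            x = self.dropout(F.relu(hidden_layer(x)))
        x = self.fc_out(x)
        return x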

    def num_flat_features(self, x):
        """
        Compute number of features after flattening.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            int: Number of flattened features.
        """
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

    def initialize_weights(self, layer, initialization):
        """
        Initialize weights of a given layer.

        Args:
            layer (nn.Module): Layer to initialize.
            initialization (str): Initialization type.
        """
        if initialization == 'uniform':
            nn.init.uniform_(layer.weight, -0.1, 0.1)
            nn.init.constant_(layer.bias, 0)
        elif initialization == 'normal':
            nn.init.normal_(layer.weight, mean=0, std=0.01)
            nn.init.constant_(layer.bias, 0)
        elif initialization == 'xavier':
            nn.init.xavier_uniform_(layer.weight)
            nn.init.constant_(layer.bias, 0)
        elif initialization == 'he':
            nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
            nn.init.constant_(layer.bias, 0)
        else:
            raise ValueError("Invalid initialization type. Choose from 'uniform', 'normal', 'xavier', or 'he'.")

    def print_output_layer_weights(self, epoch):
        """
        Print weights of the output layer at a given epoch.

        Args:
            epoch (int): Current training epoch.
        """
        print(f"Epoch {epoch}: fc_out")
        print(self.fc_out.weight.data)


class ProportionalBatchSampler(Sampler):
    """
    Custom batch sampler that ensures proportional sampling based on state labels.
    """

    def __init__(self, data_source, batch_size, state_start_idx, state_end_idx):
        """
        Initialize the batch sampler.

        Args:
            data_source (torch.utils.data.TensorDataset): Dataset to sample from.
            batch_size (int): Size of each batch.
            state_start_idx (int): Start index for state slice.
            state_end_idx (int): End index for state slice.
        """
        self.data_source = data_source
        self.batch_size = batch_size
        self.state_start_idx = state_start_idx
        self.state_end_idx = state_end_idx
        self.num_samples = len(data_source)
        self._prepare_indices()

    def _prepare_indices(self):
        """Prepare internal indices and proportions for sampling."""
        input_tensor = self.data_source.tensors[0]
        states_tensor = input_tensor[:, self.state_start_idx:self.state_end_idx]
        state_indices = torch.argmax(states_tensor, dim=1).numpy()

        self.state_to_indices = {}
        for idx, state in enumerate(state_indices):
            self.state_to_indices.setdefault(state, []).append(idx)

        for indices in self.state_to_indices.values():
            np.random.shuffle(indices)

        self.state_proportions = {
            state: len(indices) / self.num_samples
            for state, indices in self.state_to_indices.items()
        }

        self.state_pointers = {state: 0 for state in self.state_to_indices}
        self.num_batches = math.ceil(self.num_samples / self.batch_size)

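    def __iter__(self):
        """Yield a batch of indices at each iteration."""
        # Reset the per-state pointers so that every epoch starts from the first index;
        # otherwise all batches after the first epoch consist of random fill only
        self.state_pointers = {state: 0 for state in self.state_to_indices}
        for _ in range(self.num_batches):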
            batch_indices = []

            for state, proportion in self.state_proportions.items():
                n = int(round(proportion * self.batch_size))
                start = self.state_pointers[state]
                end = min(start + n, len(self.state_to_indices[state]))
                batch_indices.extend(self.state_to_indices[state][start:end])
                self.state_pointers[state] = end

            if len(batch_indices) < self.batch_size:
                remaining = self.batch_size - len(batch_indices)
                all_indices = sum(self.state_to_indices.values(), [])
                np.random.shuffle(all_indices)
                batch_indices.extend(all_indices[:remaining])

            np.random.shuffle(batch_indices)
            yield batch_indices

    def __len__(self):
        """Return the number of batches."""
        return self.num_batches


def train_model(model, device, train_loader, optimizer, pos_weight):
    """
    Train the neural network on training data.

    Args:
        model (nn.Module): Neural network model.
        device (torch.device): Computation device.
        train_loader (DataLoader): Training data loader.
        optimizer (torch.optim.Optimizer): Optimizer.
        pos_weight (torch.Tensor): Weight for positive class.

    Returns:
        float: Average training loss per sample.
    """
    model = model.double()
    model.train()
    train_loss = 0.0
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    for data, target in train_loader:
        data, target = data.to(device).double(), target.to(device).double()
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
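        # loss is the batch mean, so weight it by the batch size to get a per-sample average
        train_loss += loss.item() * data.size(0)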

    return train_loss / len(train_loader.dataset)


def valid_model(model, device, valid_loader, pos_weight):
    """
    Evaluate the model on validation data.

    Args:
        model (nn.Module): Neural network model.
        device (torch.device): Computation device.
        valid_loader (DataLoader): Validation data loader.
        pos_weight (torch.Tensor): Weight for positive class.

    Returns:
        float: Average validation loss per sample.
    """
    model.eval()
    model = model.to(device).double()
    valid_loss = 0.0
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(device).double()

    with torch.no_grad():
        for data, target in valid_loader:
            data, target = data.to(device).double(), target.to(device).double()
            output = model(data).to(device).double()
            loss = criterion(output, target).to(device).double()
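            # loss is the batch mean, so weight it by the batch size to get a per-sample average
            valid_loss += loss.item() * data.size(0)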

    return valid_loss / len(valid_loader.dataset)


def compute_accuracy(model, data_loader, device):
    """
    Compute accuracy and classification metrics.

    Args:
        model (nn.Module): Neural network model.
        data_loader (DataLoader): Data loader.
        device (torch.device): Computation device.

    Returns:
        tuple: accuracy, precision, recall, f1, FP, FN, TP, TN, confusion matrix
    """
    model = model.to(device).float()
    model.eval()

    CM = torch.zeros(2, 2, dtype=torch.int32)
    correct = 0
    total = 0
    all_predictions = torch.tensor([], dtype=torch.float32, device=device)
    all_targets = torch.tensor([], dtype=torch.float32, device=device)

    precision = Precision(average='macro', num_classes=1, task='binary').to(device)
    recall = Recall(average='macro', num_classes=1, task='binary').to(device)
    f1_score = F1Score(average='macro', num_classes=1, task='binary').to(device)

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device).float(), target.to(device).float()
            output = model(data)
            predicted = (torch.sigmoid(output) >= 0.5).float()

            correct += (predicted == target).sum().item()
            total += target.size(0)

            all_predictions = torch.cat((all_predictions, predicted), dim=0)
            all_targets = torch.cat((all_targets, target), dim=0)

            CM += torch.tensor(confusion_matrix(target.cpu(), predicted.cpu(), labels=[0, 1]))

    accuracy = correct / total * 100
    precision.update(all_predictions, all_targets)
    recall.update(all_predictions, all_targets)
    f1_score.update(all_predictions, all_targets)

    return (
        accuracy,
        precision.compute().item(),
        recall.compute().item(),
        f1_score.compute().item(),
        CM[0][1].item(),  # False Positive
        CM[1][0].item(),  # False Negative
        CM[1][1].item(),  # True Positive
        CM[0][0].item(),  # True Negative
        CM
    )


def compute_roc(model, data_loader, device):
    """
    Compute ROC curve and AUC score.

    Args:
        model (nn.Module): Neural network model.
        data_loader (DataLoader): Data loader.
        device (torch.device): Computation device.

    Returns:
        tuple: fpr, tpr, thresholds, auc_score
    """
    model = model.to(device).float()
    model.eval()

    all_predictions = torch.tensor([], dtype=torch.float32, device=device)
    all_targets = torch.tensor([], dtype=torch.float32, device=device)

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device).float(), target.to(device).float()
            output = model(data)
            predicted = torch.sigmoid(output)

            all_predictions = torch.cat((all_predictions, predicted), dim=0)
            all_targets = torch.cat((all_targets, target), dim=0)

    fpr, tpr, thresholds = roc_curve(all_targets.cpu().numpy(), all_predictions.cpu().numpy())
    roc_auc = auc(fpr, tpr)

    print(f'Number of thresholds: {len(thresholds)}')
    return fpr, tpr, thresholds, roc_auc


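def compute_prc(model, data_loader, device):
    """
    Compute the precision-recall curve and the area under it.

    Args:
        model (nn.Module): Neural network model.
        data_loader (DataLoader): Data loader.
        device (torch.device): Computation device.

    Returns:
        tuple: precision, recall, prc_auc, thresholds
    """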
    model = model.to(device).float()
    model.eval()  # Set the model to evaluation mode
    
    all_predictions = torch.tensor([], dtype=torch.float32, device=device)
    all_targets = torch.tensor([], dtype=torch.float32, device=device)

    with torch.no_grad():  # No need to compute gradients during inference
        for data, target in data_loader:
            data, target = data.to(device).float(), target.to(device).float()
            output = model(data)
            predicted = torch.sigmoid(output)

            all_predictions = torch.cat((all_predictions, predicted), dim=0)
            all_targets = torch.cat((all_targets, target), dim=0)

    # Convert predictions and targets to CPU and numpy arrays
    precision, recall, thresholds = precision_recall_curve(all_targets.cpu().numpy(), all_predictions.cpu().numpy())
    prc_auc = auc(recall, precision)
    
    return precision, recall, prc_auc, thresholds
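
The rounding rule that `ProportionalBatchSampler` applies to each state can be illustrated in isolation. The sketch below is a simplified, standard-library-only version of that one step. The function name `per_state_quota` and the label lists are hypothetical and are not part of the notebook's pipeline.

```python
from collections import Counter


def per_state_quota(state_labels, batch_size):
    """Return how many samples each state contributes to a single batch.

    Uses the same rule as the sampler: round(proportion * batch_size).
    """
    counts = Counter(state_labels)
    total = len(state_labels)
    return {
        state: int(round(count / total * batch_size))
        for state, count in counts.items()
    }


# Hypothetical label distribution: 60% state 0, 30% state 1, 10% state 2
labels = [0] * 60 + [1] * 30 + [2] * 10
quota = per_state_quota(labels, batch_size=20)
print(quota)

# Rounding can leave a batch short of batch_size; the sampler fills such
# gaps with randomly drawn indices.
uneven = per_state_quota([0, 1, 2], batch_size=10)
print(uneven, "total:", sum(uneven.values()))
```

Because the quotas are rounded independently, they do not always sum to the batch size, which is why the sampler tops up short batches with random indices.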



## <span style="font-size:18px; color:#007acc;"><b> 4. Execution code for model training <a id="Execution-Code-for-Model-Training"></a></b></span> ##

In [None]:
"""
This script performs the following tasks:
1. Loads configuration settings from YAML files.
2. Instantiates the dataset and preprocesses it (scaling and splitting).
3. Defines and trains a neural network model with specified configurations.
4. Calculates evaluation metrics (accuracy, loss, false positives/negatives, ROC/PR curves).
5. Saves the trained model's weights.
"""

import torch
import numpy as np
import yaml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import DataLoader, TensorDataset
from torch import optim
import matplotlib.pyplot as plt

# Load config from YAML
with open('config.yaml', 'r') as file:
    config = yaml.safe_load(file)

# Set random seed for reproducibility
np.random.seed(config['random_seed'])
torch.manual_seed(config['random_seed'])

# Load dataset
dataset = LoadDataset(config_path='config.yaml')

# Split the dataset into features (X) and labels (Y)
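# Convert to NumPy arrays so that scikit-learn splitting/scaling and torch.from_numpy below work
X, Y = dataset.X.numpy(), dataset.Y.numpy()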

# Perform initial split of the data into training and temporary sets, stratifying by Y
X_train, X_temp, Y_train, Y_temp = train_test_split(
    X, Y,
    test_size=config['test_size_initial'],
    random_state=config['random_seed'],
    stratify=Y
)

# Further split the temporary set into validation and test sets, stratifying by Y_temp
X_val, X_test, Y_val, Y_test = train_test_split(
    X_temp, Y_temp,
    test_size=config['test_size_final'],
    random_state=config['random_seed'],
    stratify=Y_temp
)

# Instantiate scaler (support only MinMaxScaler here, but can be extended)
if config['scaling'] == 'MinMaxScaler':
    sc = MinMaxScaler()
else:
    raise ValueError(f"Scaler '{config['scaling']}' is not supported.")

# Fit and transform the training data
X_train_normalized = sc.fit_transform(X_train)

# Transform the validation and test data using the same scaler
X_val_normalized = sc.transform(X_val)
X_test_normalized = sc.transform(X_test)

# Convert to PyTorch tensors
train_input_tensor = torch.from_numpy(X_train_normalized).float()
train_output_tensor = torch.from_numpy(Y_train).float()
valid_input_tensor = torch.from_numpy(X_val_normalized).float()
valid_output_tensor = torch.from_numpy(Y_val).float()
test_input_tensor = torch.from_numpy(X_test_normalized).float()
test_output_tensor = torch.from_numpy(Y_test).float()

# PyTorch train, validation, and test sets
train = TensorDataset(train_input_tensor, train_output_tensor)
valid = TensorDataset(valid_input_tensor, valid_output_tensor)
test = TensorDataset(test_input_tensor, test_output_tensor)

# Class weights calculation for handling class imbalance
train_num_positives = torch.sum(train_output_tensor == 1)
train_num_negatives = torch.sum(train_output_tensor == 0)
print("train_num_positives", train_num_positives)
print("train_num_negatives", train_num_negatives)

valid_num_positives = torch.sum(valid_output_tensor == 1)
valid_num_negatives = torch.sum(valid_output_tensor == 0)
print("valid_num_positives", valid_num_positives)
print("valid_num_negatives", valid_num_negatives)

# Set device
if config['device_preference'] == 'auto':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
    device = torch.device(config['device_preference'])
print(device)

##### Load training config from YAML #####
with open('train_config.yaml', 'r') as file:
    train_config = yaml.safe_load(file)['training']

# Extract config values
num_epochs = train_config['num_epochs']
input_size = train_config['input_size']
output_size = train_config['output_size']
hidden_sizes = train_config['hidden_sizes']
initialization = train_config['initialization']
dropout = train_config['dropout']
learning_rate = train_config['learning_rate']
weight_decay = train_config['weight_decay']
pos_weights = train_config['pos_weights']
batch_percentage = train_config['batch_percentage']
scheduler_milestones = train_config['scheduler']['milestones']
scheduler_gamma = train_config['scheduler']['gamma']
state_start_idx = train_config['state_start_index']
state_end_idx = train_config['state_end_index']

# Calculate batch size
total_train_samples = len(train)
batch_size = int(total_train_samples * batch_percentage)
batch_size = max(1, batch_size)

# Initialize dictionaries to store values for plotting later
all_training_losses = {}
all_valid_losses = {}
all_train_accuracies = {}
all_valid_accuracies = {}
all_train_false_positives = {}
all_train_false_negatives = {}
all_valid_false_positives = {}
all_valid_false_negatives = {}
all_train_fpr = {}
all_train_tpr = {}
all_train_roc_auc = {}
all_valid_fpr = {}
all_valid_tpr = {}
all_valid_roc_auc = {}
all_train_precision = {}
all_train_recall = {}
all_valid_precision = {}
all_valid_recall = {}
all_train_prc_auc = {}
all_valid_prc_auc = {}
all_train_thresholds = {}
all_valid_thresholds = {}

# DataLoaders
train_sampler = ProportionalBatchSampler(train, batch_size=batch_size, state_start_idx=state_start_idx, state_end_idx=state_end_idx)
train_loader = DataLoader(train, batch_sampler=train_sampler)
valid_loader = DataLoader(valid, batch_size=batch_size, shuffle=False)

# Iterate through positive weights
for pos_weight_mul in pos_weights:
    pos_count = torch.sum(train_output_tensor == 1).item()
    neg_count = torch.sum(train_output_tensor == 0).item()

    if pos_count == 0:
        raise ValueError("No positive samples in the dataset, cannot compute pos_weight.")

    pos_weight = (neg_count / pos_count) * pos_weight_mul
    pos_weight = torch.tensor([pos_weight], device=device)

    # Initialize lists for the current run
    Epoch_ind = []
    training_losses = []
    valid_losses = []
    train_accuracies = []
    valid_accuracies = []
    train_false_positives = []
    train_false_negatives = []
    valid_false_positives = []
    valid_false_negatives = []

    # Initialize model
    model = NNModel(input_size, hidden_sizes, output_size, initialization, dropout).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=scheduler_milestones, gamma=scheduler_gamma)

    for epoch in range(1, num_epochs + 1):
        train_loss = train_model(model, device, train_loader, optimizer, pos_weight)
        training_losses.append(train_loss)

        train_accuracy, _, _, _, train_fp, train_fn, _, _, _ = compute_accuracy(model, train_loader, device)
        train_accuracies.append(train_accuracy)
        train_false_positives.append(train_fp)
        train_false_negatives.append(train_fn)

        valid_loss = valid_model(model, device, valid_loader, pos_weight)
        valid_losses.append(valid_loss)

        valid_accuracy, _, _, _, valid_fp, valid_fn, _, _, _ = compute_accuracy(model, valid_loader, device)
        valid_accuracies.append(valid_accuracy)
        valid_false_positives.append(valid_fp)
        valid_false_negatives.append(valid_fn)

        Epoch_ind.append(epoch)
        scheduler.step()
        print(f"epoch {epoch}")


    # Compute ROC curve
    # compute_roc returns (fpr, tpr, thresholds, roc_auc), in that order
    train_fpr, train_tpr, train_thresholds, train_roc_auc = compute_roc(model, train_loader, device)
    valid_fpr, valid_tpr, valid_thresholds, valid_roc_auc = compute_roc(model, valid_loader, device)
    
    # Compute PRC
    train_precision, train_recall, train_prc_auc, train_thresholds = compute_prc(model, train_loader, device)
    valid_precision, valid_recall, valid_prc_auc, valid_thresholds = compute_prc(model, valid_loader, device)
    
     # Store results for the current positive weight
    all_training_losses[pos_weight_mul] = training_losses
    all_valid_losses[pos_weight_mul] = valid_losses
    all_train_accuracies[pos_weight_mul] = train_accuracies
    all_valid_accuracies[pos_weight_mul] = valid_accuracies
    all_train_false_positives[pos_weight_mul] = train_false_positives
    all_train_false_negatives[pos_weight_mul] = train_false_negatives
    all_valid_false_positives[pos_weight_mul] = valid_false_positives
    all_valid_false_negatives[pos_weight_mul] = valid_false_negatives
    
    all_train_fpr[pos_weight_mul] = train_fpr
    all_train_tpr[pos_weight_mul] = train_tpr
    all_train_roc_auc[pos_weight_mul] = train_roc_auc
    all_valid_fpr[pos_weight_mul] = valid_fpr
    all_valid_tpr[pos_weight_mul] = valid_tpr
    all_valid_roc_auc[pos_weight_mul] = valid_roc_auc

    all_train_precision[pos_weight_mul] = train_precision
    all_train_recall[pos_weight_mul] = train_recall
    all_valid_precision[pos_weight_mul] = valid_precision
    all_valid_recall[pos_weight_mul] = valid_recall
    
    all_train_prc_auc[pos_weight_mul] = train_prc_auc
    all_valid_prc_auc[pos_weight_mul] = valid_prc_auc

    all_train_thresholds[pos_weight_mul] = train_thresholds
    all_valid_thresholds[pos_weight_mul] = valid_thresholds

# Plotting results for all positive weights
plt.figure(figsize=(15, 25))

# Plot Train False Positives vs Epochs
plt.subplot(5, 2, 1)
for pos_weight_mul in pos_weights:
    plt.plot(Epoch_ind, all_train_false_positives[pos_weight_mul], label=f"Pos Weight {pos_weight_mul}")
plt.xlabel('Epochs')
plt.ylabel('Count')
plt.title('Train False Positives')
plt.legend()

# Plot Train False Negatives vs Epochs
plt.subplot(5, 2, 2)
for pos_weight_mul in pos_weights:
    plt.plot(Epoch_ind, all_train_false_negatives[pos_weight_mul], label=f"Pos Weight {pos_weight_mul}")
plt.xlabel('Epochs')
plt.ylabel('Count')
plt.title('Train False Negatives')
plt.legend()

# Plot Valid False Positives vs Epochs
plt.subplot(5, 2, 3)
for pos_weight_mul in pos_weights:
    plt.plot(Epoch_ind, all_valid_false_positives[pos_weight_mul], label=f"Pos Weight {pos_weight_mul}")
plt.xlabel('Epochs')
plt.ylabel('Count')
plt.title('Valid False Positives')
plt.legend()

# Plot Valid False Negatives vs Epochs
plt.subplot(5, 2, 4)
for pos_weight_mul in pos_weights:
    plt.plot(Epoch_ind, all_valid_false_negatives[pos_weight_mul], label=f"Pos Weight {pos_weight_mul}")
plt.xlabel('Epochs')
plt.ylabel('Count')
plt.title('Valid False Negatives')
plt.legend()

# Plot Training and Validation Loss vs Epochs
plt.subplot(5, 2, 5)
for pos_weight_mul in pos_weights:
    plt.plot(Epoch_ind, all_training_losses[pos_weight_mul], label=f"Train Loss (Pos Weight {pos_weight_mul})")
    plt.plot(Epoch_ind, all_valid_losses[pos_weight_mul], label=f"Validation Loss (Pos Weight {pos_weight_mul})")
plt.xlabel('Epochs')
plt.ylabel('Loss')
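plt.title('Loss')
plt.legend()

plt.tight_layout()
plt.show()

# Save the weights of the last trained model; the FastAPI section loads this file
torch.save(model.state_dict(), 'model_weights.pth')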
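
The training loop scales `pos_weight` as `(neg_count / pos_count) * pos_weight_mul` and passes it to `nn.BCEWithLogitsLoss`. The following standard-library sketch shows the effect of that weight on the per-sample loss. The helper names and the class counts are hypothetical, and the plain-Python formula is for illustration only. It is not numerically stabilized the way the PyTorch implementation is.

```python
import math


def balanced_pos_weight(num_negatives, num_positives, multiplier=1.0):
    """Ratio of negatives to positives, scaled by an optional multiplier."""
    if num_positives == 0:
        raise ValueError("No positive samples, cannot compute pos_weight.")
    return (num_negatives / num_positives) * multiplier


def weighted_bce_with_logits(logit, target, pos_weight):
    """Per-sample loss: -[w * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]."""
    prob = 1.0 / (1.0 + math.exp(-logit))
    return -(pos_weight * target * math.log(prob)
             + (1.0 - target) * math.log(1.0 - prob))


# Hypothetical class counts: 900 negatives, 100 positives
weight = balanced_pos_weight(900, 100)
loss_on_positive = weighted_bce_with_logits(0.0, 1.0, weight)
loss_on_negative = weighted_bce_with_logits(0.0, 0.0, weight)

print("pos_weight:", weight)
print("loss for a positive sample at logit 0:", loss_on_positive)
print("loss for a negative sample at logit 0:", loss_on_negative)
```

With these counts, an equally uncertain prediction costs nine times more on a positive sample than on a negative one. Weighting the loss this way pushes the model to reduce false negatives, at the expense of more false positives.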

## <span style="font-size:18px; color:#007acc;"><b> 5. SHAP analysis <a id="SHAP-Analysis"></a></b></span> ##

In [None]:

"""
Performs SHAP (Shapley Additive Explanations) analysis on a machine learning model to 
interpret feature contributions to predictions. 

1. Computes SHAP values using SHAP's DeepExplainer.
2. Visualizes feature importance with a summary plot and a bar plot.

Requires a trained model and input data (`valid_input_tensor` and `train_input_tensor`).
"""

def shap_analysis(model, input_tensor):
    """Performs SHAP analysis using the DeepExplainer."""
    # Ensure model is in evaluation mode and on CPU
    model.eval()
    model = model.cpu()

    # Create SHAP explainer
    explainer = shap.DeepExplainer(model, input_tensor)

    # Compute SHAP values
    shap_values = explainer.shap_values(input_tensor, check_additivity=False)
    
    return shap_values


# Perform SHAP analysis
subset_tensor = valid_input_tensor[:10000]
model.eval()
model = model.cpu()
model_predictions = model(subset_tensor).detach().numpy()[:, 0]

shap_values = shap_analysis(model, subset_tensor)
shap_values = np.squeeze(shap_values)  # This will remove all dimensions of size 1

shap_values = np.array(shap_values)

# Sum over the feature dimension (axis=1)
instance_sum = np.sum(shap_values, axis=1)

# Generate the summary plot
shap.summary_plot(
    shap_values, 
    features=subset_tensor.cpu().numpy(), 
    feature_names=["Feature_" + str(i) for i in range(train_input_tensor.size(1))]
)

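# Generate the bar plot (mean absolute SHAP value per feature)
shap.summary_plot(
    shap_values,
    features=subset_tensor.cpu().numpy(),
    feature_names=["Feature_" + str(i) for i in range(train_input_tensor.size(1))],
    plot_type="bar"
)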
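
To make the quantity behind the plots concrete, the sketch below computes exact Shapley values for a tiny hand-written function by averaging marginal contributions over all feature orderings. It is not what `DeepExplainer` does internally, which approximates these values for deep networks. The toy model, input, and baseline are hypothetical.

```python
from itertools import permutations


def exact_shapley(predict, x, baseline):
    """Exact Shapley values: average marginal contribution over all feature orderings."""
    n = len(x)
    values = [0.0] * n
    orderings = list(permutations(range(n)))
    for order in orderings:
        current = list(baseline)           # start from the baseline input
        previous_output = predict(current)
        for i in order:
            current[i] = x[i]              # switch feature i to its actual value
            output = predict(current)
            values[i] += output - previous_output
            previous_output = output
    return [v / len(orderings) for v in values]


def toy_model(features):
    # Hypothetical model: one additive term and one interaction term
    return 2.0 * features[0] + features[1] * features[2]


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phi = exact_shapley(toy_model, x, baseline)

print("Shapley values:", phi)
print("Sum of values:", sum(phi))
print("f(x) - f(baseline):", toy_model(x) - toy_model(baseline))
```

The values sum to the difference between the model output at `x` and at the baseline. This is the additivity property that the `check_additivity` argument of `shap_values` refers to. The interaction term is split equally between the two features involved.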


## <span style="font-size:18px; color:#007acc;"><b> 6. Creating a FastAPI server <a id="Create-FAST-API"></a></b></span> ##

In [None]:
"""
This script serves a trained machine learning model as an API using FastAPI. The model is loaded from a configuration 
file, and the weights are loaded from a pre-trained model file. It accepts input data through HTTP POST requests, 
makes predictions using the loaded model, and returns the results.

Functions:
- load_model_config: Loads the model configuration (training parameters) from a YAML configuration file.
- run_app: Runs the FastAPI server in a separate thread to serve predictions.
- predict: A POST endpoint that takes input features, processes them, and returns the model's prediction.

Classes:
- InputData: A Pydantic model for validating the structure of the incoming data.

The FastAPI server listens on port 8000 and provides a `/predict/` endpoint for model inference. The server is run in 
a separate thread so that the notebook kernel is not blocked and later cells can send requests to it.

Requirements:
- FastAPI for serving the model as an API.
- PyTorch for loading the model and making predictions.
- YAML for loading configuration from the 'train_config.yaml' file.
"""

# Load model configuration from the training config file
def load_model_config():
    # 'train_config.yaml' is the file whose 'training' section was read during training
    with open('train_config.yaml', 'r') as file:
        config = yaml.safe_load(file)
    return config['training']

# Initialize FastAPI app
app = FastAPI()

# Define a Pydantic model for input data
class InputData(BaseModel):
    features: list

# Load model parameters from the config file
model_config = load_model_config()

input_size = model_config['input_size']
output_size = model_config['output_size']
hidden_sizes = model_config['hidden_sizes']
initialization = model_config['initialization']
dropout = model_config['dropout']

# Create the model instance (NNModel requires the dropout probability as its fifth argument)
model = NNModel(input_size, hidden_sizes, output_size, initialization, dropout)
model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))  # Load pre-trained weights on CPU
model.eval()  # Set the model to evaluation mode (disables dropout)

@app.post("/predict/")
async def predict(data: InputData):
    # Convert input data to tensor and ensure it has the correct shape
    input_tensor = torch.tensor(data.features, dtype=torch.float32).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        prediction = model(input_tensor)
    return {"prediction": prediction.tolist()}

# Function to run the app in a separate thread
def run_app():
    run(app, host="0.0.0.0", port=8000, log_level="info")

# Start the FastAPI server in a separate thread
thread = threading.Thread(target=run_app)
thread.start()
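
Because the server starts in a background thread, a request sent immediately afterwards may arrive before the port is open. One possible guard, sketched here with the standard library only, is to poll the port until a TCP connection succeeds. The helper name `wait_for_port` and the timeout values are assumptions for illustration.

```python
import socket
import time


def wait_for_port(host, port, timeout=10.0, interval=0.2):
    """Return True once a TCP connection to (host, port) succeeds, False on timeout."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # A successful connection means something is listening on the port
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            # Connection refused or timed out: wait briefly and retry
            time.sleep(interval)
    return False


# Port 8000 matches the server above; the short timeout is only for demonstration
if wait_for_port("127.0.0.1", 8000, timeout=0.5):
    print("Server is accepting connections.")
else:
    print("Server did not become reachable in time.")
```

An open port only shows that the server socket is listening. It does not confirm that the model loaded correctly, so a failed first request should still be handled.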


## <span style="font-size:18px; color:#007acc;"><b> 7. Inference from API <a id="Inference-from-API"></a></b></span> ##

In [None]:
"""
This script extracts a single example from the test dataset, sends it to a FastAPI model server for prediction, and 
prints both the predicted and true labels. The steps include:

1. Extract the first example (feature set) from the input tensor and the corresponding true label.
2. Send a POST request to the FastAPI endpoint with the feature data.
3. Process the response, which is assumed to contain logits from the model, apply the sigmoid activation function 
   to get probabilities, and convert them into binary predictions (0 or 1).
4. Print the status code and response from the server, as well as the predicted labels and true label for comparison.

Error handling is included for failed requests and invalid JSON responses.

Dependencies:
- requests for making HTTP requests.
- The sigmoid function to convert logits into probabilities.
"""


# Extract the first example (row) from test_input_tensor and convert it to a list
features = test_input_tensor[:1].tolist()[0]  # Extracts the first example and converts it to a list

# Extract the true label corresponding to the first example
true_label = test_output_tensor[0].item()  # Converts the tensor value to a Python scalar

# Define the endpoint
url = "http://localhost:8000/predict/"

data = {"features": features}  # Use the extracted features

try:
    # Make the POST request
    response = requests.post(url, json=data)

    # Print status code
    print(f"Status Code: {response.status_code}")

    # Print response text for debugging
    print(f"Response Text: {response.text}")

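    # Stop here with a RequestException if the server returned an error status
    response.raise_for_status()

    # The server returns {"prediction": [[logit, ...]]}; flatten it to a 1-D array of logits
    logits = np.array(response.json()["prediction"], dtype=float).ravel()

    # Apply the sigmoid function, then threshold the probabilities (not the raw logits) at 0.5
    probabilities = 1.0 / (1.0 + np.exp(-logits))
    prediction = [1 if p >= 0.5 else 0 for p in probabilities]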

    # Print the resulting predictions (0 or 1)
    print("Predicted Labels:", prediction)

    # Print the true label
    print("True Label:", true_label)

except requests.exceptions.RequestException as e:
    print(f"Request failed: {e}")
except ValueError as e:
    print(f"JSON decode error: {e}")
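
The post-processing step can also be tried without a running server. The sketch below assumes the response has the shape `{"prediction": [[logit, ...]]}`, as returned by the `/predict/` endpoint above. The helper names and the example payload are hypothetical.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function for a single float."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def logits_to_labels(payload, threshold=0.5):
    """Convert a {"prediction": [[logit, ...], ...]} payload into probabilities and 0/1 labels."""
    probabilities = [sigmoid(logit) for row in payload["prediction"] for logit in row]
    labels = [1 if p >= threshold else 0 for p in probabilities]
    return probabilities, labels


# Hypothetical response with a single, slightly positive logit
example_payload = {"prediction": [[0.3]]}
probabilities, labels = logits_to_labels(example_payload)
print("Probabilities:", probabilities)
print("Labels:", labels)
```

A logit of 0.3 is below 0.5 but its probability is above 0.5, so thresholding the raw logit instead of the probability would produce the wrong label.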