In [None]:
import kagglehub
kagglehub.login()

VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Kaggle credentials set.
Kaggle credentials successfully validated.


In [None]:
sakhawat18_asteroid_dataset_path = kagglehub.dataset_download('sakhawat18/asteroid-dataset')

In [None]:
sakhawat18_asteroid_dataset_path

'/kaggle/input/asteroid-dataset'

In [None]:
!pip install torch_geometric -q

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)


import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

In [None]:
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import logging

from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

from sklearn.metrics import classification_report, confusion_matrix, precision_recall_curve, roc_curve, roc_auc_score
from sklearn.neighbors import kneighbors_graph

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.75, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        F_loss = alpha_t * (1 - pt) ** self.gamma * BCE_loss

        if self.reduction == 'mean':
            return torch.mean(F_loss)
        elif self.reduction == 'sum':
            return torch.sum(F_loss)
        else:
            return F_loss

class AsteroidGraphDataset:
    def __init__(self, filepath):
        logger.info(f"Loading dataset from {filepath}")
        self.df = pd.read_csv(filepath, low_memory=False)
        self.df = self.df.dropna(subset=['pha'])
        logger.info(f"After dropping rows with missing 'pha', dataset size: {len(self.df)}")

        features = ['e', 'a', 'i', 'q', 'tp', 'H', 'diameter', 'albedo', 'moid_ld']
        X = self.df[features].astype(float).values
        y = (self.df['pha'] == 'Y').astype(int).values

        indices = np.arange(len(self.df))
        train_val_indices, test_indices = train_test_split(
            indices, test_size=0.2, stratify=y, random_state=42
        )
        train_indices, val_indices = train_test_split(
            train_val_indices, test_size=0.2, stratify=y[train_val_indices], random_state=42
        )

        imputer = SimpleImputer(strategy='median')
        imputer.fit(X[train_indices])
        X_imputed = imputer.transform(X)

        scaler = StandardScaler()
        scaler.fit(X_imputed[train_indices])
        X_scaled = scaler.transform(X_imputed)

        train_mask = torch.zeros(len(self.df), dtype=torch.bool)
        val_mask = torch.zeros(len(self.df), dtype=torch.bool)
        test_mask = torch.zeros(len(self.df), dtype=torch.bool)
        train_mask[train_indices] = True
        val_mask[val_indices] = True
        test_mask[test_indices] = True

        connectivity = kneighbors_graph(X_scaled, n_neighbors=5, mode='connectivity')
        rows, cols = connectivity.nonzero()
        edge_index_np = np.vstack((rows, cols))
        edge_index = torch.from_numpy(edge_index_np).to(torch.long)
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

        self.data = Data(x=torch.tensor(X_scaled, dtype=torch.float),
                         edge_index=edge_index,
                         y=torch.tensor(y, dtype=torch.float),
                         train_mask=train_mask,
                         val_mask=val_mask,
                         test_mask=test_mask)

        logger.info(f"Graph constructed with {self.data.num_nodes} nodes and {self.data.num_edges} edges")

class AsteroidGNN(nn.Module):
    def __init__(self, num_features):
        super(AsteroidGNN, self).__init__()
        self.conv1 = GCNConv(num_features, 128)
        self.conv2 = GCNConv(128, 128)
        self.conv3 = GCNConv(128, 128)
        self.attn = GATConv(128, 64, heads=2)
        self.fc = nn.Linear(64 * 2, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        x = F.relu(self.attn(x, edge_index))
        x = self.fc(x)
        return x.squeeze(-1)

class AsteroidHazardClassifier:
    def __init__(self, dataset_path):
        logger.info("Initializing Asteroid Hazard Classifier")
        self.dataset = AsteroidGraphDataset(dataset_path)
        self.model = AsteroidGNN(num_features=len(['e', 'a', 'i', 'q', 'tp', 'H', 'diameter', 'albedo', 'moid_ld']))
        self.optimizer = optim.AdamW(self.model.parameters(), lr=0.001, weight_decay=1e-4)
        self.criterion = FocalLoss(alpha=0.75, gamma=2)

    def train(self, epochs=100, patience=10):
        logger.info(f"Starting training for {epochs} epochs")
        best_val_loss = float('inf')
        patience_counter = 0
        best_model_state = None

        epoch_progress = tqdm(range(epochs), desc="Training Epochs", position=0)

        for epoch in epoch_progress:
            self.model.train()
            self.optimizer.zero_grad()
            output = self.model(self.dataset.data)
            loss = self.criterion(output[self.dataset.data.train_mask], self.dataset.data.y[self.dataset.data.train_mask])
            loss.backward()
            self.optimizer.step()

            self.model.eval()
            with torch.no_grad():
                val_output = self.model(self.dataset.data)
                val_loss = self.criterion(val_output[self.dataset.data.val_mask], self.dataset.data.y[self.dataset.data.val_mask])

            epoch_progress.set_postfix({'Loss': f'{loss.item():.4f}', 'Val Loss': f'{val_loss.item():.4f}'})

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                best_model_state = self.model.state_dict()
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    logger.info(f"Early stopping at epoch {epoch}")
                    break

        self.model.load_state_dict(best_model_state)
        logger.info("Training completed")

    def evaluate(self):
        logger.info("Starting model evaluation")
        self.model.eval()
        with torch.no_grad():
            output = self.model(self.dataset.data)
            val_output = output[self.dataset.data.val_mask].cpu().numpy()
            val_y = self.dataset.data.y[self.dataset.data.val_mask].cpu().numpy()

            precision, recall, thresholds = precision_recall_curve(val_y, val_output)
            f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
            optimal_idx = np.argmax(f1_scores)
            optimal_threshold = thresholds[optimal_idx]
            logger.info(f"Optimal threshold: {optimal_threshold}")

            test_output = output[self.dataset.data.test_mask].cpu().numpy()
            test_y = self.dataset.data.y[self.dataset.data.test_mask].cpu().numpy()
            test_pred = (test_output > optimal_threshold).astype(int)

            logger.info("Classification Report:")
            print(classification_report(test_y, test_pred))

            cm = confusion_matrix(test_y, test_pred)
            plt.figure(figsize=(8, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
            plt.title('Confusion Matrix', fontsize=10)
            plt.xlabel('Predicted', fontsize=10)
            plt.ylabel('Actual', fontsize=10)
            plt.text(0.5, -0.1, 'Data Source: NASA JPL Small-Body Database', ha='center', va='center', transform=plt.gca().transAxes, fontsize=8)
            plt.tight_layout()
            plt.savefig('confusion_matrix.png')
            plt.close()
            logger.info("Confusion matrix saved to confusion_matrix.png")

            fpr, tpr, _ = roc_curve(test_y, test_output)
            roc_auc = roc_auc_score(test_y, test_output)
            plt.figure(figsize=(8, 6))
            plt.plot(fpr, tpr, color='dodgerblue', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
            plt.plot([0, 1], [0, 1], color='crimson', lw=2, linestyle='--')
            plt.xlabel('False Positive Rate', fontsize=10)
            plt.ylabel('True Positive Rate', fontsize=10)
            plt.title('Receiver Operating Characteristic (ROC) Curve', fontsize=10)
            plt.legend(loc="lower right", fontsize=10)
            plt.text(0.5, -0.1, 'Data Source: NASA JPL Small-Body Database', ha='center', va='center', transform=plt.gca().transAxes, fontsize=8)
            plt.tight_layout()
            plt.savefig('roc_curve.png')
            plt.close()
            logger.info("ROC curve saved to roc_curve.png")

def main():
    # Set the path to the dataset directory
    #sakhawat18_asteroid_dataset_path = '/path/to/dataset'  # Replace with actual path
    dataset_path = f"{sakhawat18_asteroid_dataset_path}/dataset.csv"
    classifier = AsteroidHazardClassifier(dataset_path)
    classifier.train(epochs=100, patience=10)
    classifier.evaluate()

if __name__ == '__main__':
    main()

Training Epochs: 100%|██████████| 100/100 [44:18<00:00, 26.59s/it, Loss=0.0010, Val Loss=0.0010]


              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00    187308
         1.0       0.25      0.54      0.34       413

    accuracy                           1.00    187721
   macro avg       0.62      0.77      0.67    187721
weighted avg       1.00      1.00      1.00    187721



In [None]:
!pip install tabulate -q

In [None]:
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import logging
from tabulate import tabulate
from sklearn.ensemble import IsolationForest
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_curve, roc_curve, roc_auc_score
from sklearn.neighbors import kneighbors_graph

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Configuration
config = {
    'epochs': 100,
    'patience': 10,
    'save_dpi': 300,
    'display_dpi': 120
}

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.75, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        F_loss = alpha_t * (1 - pt) ** self.gamma * BCE_loss
        return torch.mean(F_loss) if self.reduction == 'mean' else F_loss

class AsteroidGraphDataset:
    def __init__(self, filepath):
        logger.info(f"Loading dataset from {filepath}")
        self.df = pd.read_csv(filepath, low_memory=False)
        self.df = self.df.dropna(subset=['pha'])
        logger.info(f"After dropping rows with missing 'pha', dataset size: {len(self.df)}")

        features = ['e', 'a', 'i', 'q', 'tp', 'H', 'diameter', 'albedo', 'moid_ld']
        X = self.df[features].astype(float).values
        y = (self.df['pha'] == 'Y').astype(int).values

        indices = np.arange(len(self.df))
        train_val_indices, test_indices = train_test_split(
            indices, test_size=0.2, stratify=y, random_state=42
        )
        train_indices, val_indices = train_test_split(
            train_val_indices, test_size=0.2, stratify=y[train_val_indices], random_state=42
        )

        imputer = SimpleImputer(strategy='median')
        imputer.fit(X[train_indices])
        X_imputed = imputer.transform(X)

        scaler = StandardScaler()
        scaler.fit(X_imputed[train_indices])
        X_scaled = scaler.transform(X_imputed)

        train_mask = torch.zeros(len(self.df), dtype=torch.bool)
        val_mask = torch.zeros(len(self.df), dtype=torch.bool)
        test_mask = torch.zeros(len(self.df), dtype=torch.bool)
        train_mask[train_indices] = True
        val_mask[val_indices] = True
        test_mask[test_indices] = True

        connectivity = kneighbors_graph(X_scaled, n_neighbors=5, mode='connectivity')
        rows, cols = connectivity.nonzero()
        edge_index_np = np.vstack((rows, cols))
        edge_index = torch.from_numpy(edge_index_np).to(torch.long)
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

        self.data = Data(x=torch.tensor(X_scaled, dtype=torch.float),
                         edge_index=edge_index,
                         y=torch.tensor(y, dtype=torch.float),
                         train_mask=train_mask,
                         val_mask=val_mask,
                         test_mask=test_mask)

        logger.info(f"Graph constructed with {self.data.num_nodes} nodes and {self.data.num_edges} edges")

class AsteroidGNN(nn.Module):
    def __init__(self, num_features):
        super(AsteroidGNN, self).__init__()
        self.conv1 = GCNConv(num_features, 128)
        self.conv2 = GCNConv(128, 128)
        self.conv3 = GCNConv(128, 128)
        self.attn = GATConv(128, 64, heads=2)
        self.fc = nn.Linear(64 * 2, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        x = F.relu(self.attn(x, edge_index))
        x = self.fc(x)
        return x.squeeze(-1)

class MLP(nn.Module):
    def __init__(self, input_dim):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x.squeeze(-1)

class GATModel(nn.Module):
    def __init__(self, num_features):
        super(GATModel, self).__init__()
        self.gat1 = GATConv(num_features, 64, heads=2)
        self.gat2 = GATConv(64 * 2, 64, heads=2)
        self.fc = nn.Linear(64 * 2, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.gat1(x, edge_index))
        x = F.relu(self.gat2(x, edge_index))
        x = self.fc(x)
        return x.squeeze(-1)

class AsteroidHazardClassifier:
    def __init__(self, dataset_path, model_type='GNN'):
        logger.info(f"Initializing {model_type} Classifier")
        self.dataset = AsteroidGraphDataset(dataset_path)
        self.model_type = model_type
        num_features = self.dataset.data.x.shape[1]
        self.model = AsteroidGNN(num_features) if model_type == 'GNN' else GATModel(num_features)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=0.001, weight_decay=1e-4)
        self.criterion = FocalLoss(alpha=0.75, gamma=2)

    def train(self, epochs=config['epochs'], patience=config['patience']):
        logger.info(f"Starting {self.model_type} training for {epochs} epochs")
        best_val_loss = float('inf')
        patience_counter = 0
        best_model_state = None
        history = []

        epoch_progress = tqdm(range(1, epochs + 1), desc=f"{self.model_type} Training Epochs", position=0)

        for epoch in epoch_progress:
            self.model.train()
            self.optimizer.zero_grad()
            output = self.model(self.dataset.data)
            loss = self.criterion(output[self.dataset.data.train_mask], self.dataset.data.y[self.dataset.data.train_mask])
            loss.backward()
            self.optimizer.step()

            self.model.eval()
            with torch.no_grad():
                val_output = self.model(self.dataset.data)
                val_loss = self.criterion(val_output[self.dataset.data.val_mask], self.dataset.data.y[self.dataset.data.val_mask])

            history.append({'epoch': epoch, 'loss': loss.item(), 'val_loss': val_loss.item()})
            epoch_progress.set_postfix({'Loss': f'{loss.item():.4f}', 'Val Loss': f'{val_loss.item():.4f}'})

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                best_model_state = self.model.state_dict()
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    logger.info(f"Early stopping at epoch {epoch}")
                    break

        self.model.load_state_dict(best_model_state)
        history_df = pd.DataFrame(history)
        history_df.to_csv(f'training_history_{self.model_type.lower()}.csv', index=False)
        logger.info(f"{self.model_type} Training completed")

    def evaluate(self):
        logger.info(f"Starting {self.model_type} model evaluation")
        self.model.eval()
        with torch.no_grad():
            output = self.model(self.dataset.data)
            val_output = output[self.dataset.data.val_mask].cpu().numpy()
            val_y = self.dataset.data.y[self.dataset.data.val_mask].cpu().numpy()

            precision, recall, thresholds = precision_recall_curve(val_y, val_output)
            f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
            optimal_idx = np.argmax(f1_scores)
            optimal_threshold = thresholds[optimal_idx]
            logger.info(f"Optimal threshold for {self.model_type}: {optimal_threshold}")

            test_output = output[self.dataset.data.test_mask].cpu().numpy()
            test_y = self.dataset.data.y[self.dataset.data.test_mask].cpu().numpy()
            test_pred = (test_output > optimal_threshold).astype(int)

            logger.info(f"{self.model_type} Classification Report:")
            report = classification_report(test_y, test_pred, output_dict=True)
            print(classification_report(test_y, test_pred))

            cm = confusion_matrix(test_y, test_pred)
            plt.figure(figsize=(8, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
            plt.title(f'{self.model_type} Confusion Matrix', fontsize=10)
            plt.xlabel('Predicted', fontsize=10)
            plt.ylabel('Actual', fontsize=10)
            plt.text(0.5, -0.1, 'Data Source: NASA JPL Small-Body Database', ha='center', va='center', transform=plt.gca().transAxes, fontsize=8)
            plt.tight_layout()
            plt.savefig(f'confusion_matrix_{self.model_type.lower()}.png', dpi=config['save_dpi'])
            plt.close()
            logger.info(f"{self.model_type} Confusion matrix saved")

            fpr, tpr, _ = roc_curve(test_y, test_output)
            roc_auc = roc_auc_score(test_y, test_output)
            plt.figure(figsize=(8, 6))
            plt.plot(fpr, tpr, color='dodgerblue', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
            plt.plot([0, 1], [0, 1], color='crimson', lw=2, linestyle='--')
            plt.xlabel('False Positive Rate', fontsize=10)
            plt.ylabel('True Positive Rate', fontsize=10)
            plt.title(f'{self.model_type} ROC Curve', fontsize=10)
            plt.legend(loc="lower right", fontsize=10)
            plt.text(0.5, -0.1, 'Data Source: NASA JPL Small-Body Database', ha='center', va='center', transform=plt.gca().transAxes, fontsize=8)
            plt.tight_layout()
            plt.savefig(f'roc_curve_{self.model_type.lower()}.png', dpi=config['save_dpi'])
            plt.close()
            logger.info(f"{self.model_type} ROC curve saved")

            return report, test_output

def train_mlp(X_train, y_train, X_val, y_val, epochs=config['epochs'], patience=config['patience']):
    logger.info("Starting MLP training")
    input_dim = X_train.shape[1]
    model = MLP(input_dim)
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
    criterion = FocalLoss(alpha=0.75, gamma=2)

    history = []
    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    epoch_progress = tqdm(range(1, epochs + 1), desc="MLP Training Epochs", position=0)

    for epoch in epoch_progress:
        model.train()
        optimizer.zero_grad()
        output = model(X_train)
        loss = criterion(output, y_train)
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            val_output = model(X_val)
            val_loss = criterion(val_output, y_val)

        history.append({'epoch': epoch, 'loss': loss.item(), 'val_loss': val_loss.item()})
        epoch_progress.set_postfix({'Loss': f'{loss.item():.4f}', 'Val Loss': f'{val_loss.item():.4f}'})

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            best_model_state = model.state_dict()
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(f"Early stopping at epoch {epoch}")
                break

    model.load_state_dict(best_model_state)
    history_df = pd.DataFrame(history)
    history_df.to_csv('training_history_mlp.csv', index=False)
    logger.info("MLP Training completed")
    return model

def evaluate_mlp(model, X_val, y_val, X_test, y_test):
    logger.info("Starting MLP model evaluation")
    model.eval()
    with torch.no_grad():
        val_output = model(X_val).cpu().numpy()
        test_output = model(X_test).cpu().numpy()

    precision, recall, thresholds = precision_recall_curve(y_val.cpu().numpy(), val_output)
    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
    optimal_idx = np.argmax(f1_scores)
    optimal_threshold = thresholds[optimal_idx]
    logger.info(f"Optimal threshold for MLP: {optimal_threshold}")

    test_pred = (test_output > optimal_threshold).astype(int)
    logger.info("MLP Classification Report:")
    report = classification_report(y_test.cpu().numpy(), test_pred, output_dict=True)
    print(classification_report(y_test.cpu().numpy(), test_pred))

    cm = confusion_matrix(y_test.cpu().numpy(), test_pred)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
    plt.title('MLP Confusion Matrix', fontsize=10)
    plt.xlabel('Predicted', fontsize=10)
    plt.ylabel('Actual', fontsize=10)
    plt.text(0.5, -0.1, 'Data Source: NASA JPL Small-Body Database', ha='center', va='center', transform=plt.gca().transAxes, fontsize=8)
    plt.tight_layout()
    plt.savefig('confusion_matrix_mlp.png', dpi=config['save_dpi'])
    plt.close()
    logger.info("MLP Confusion matrix saved")

    fpr, tpr, _ = roc_curve(y_test.cpu().numpy(), test_output)
    roc_auc = roc_auc_score(y_test.cpu().numpy(), test_output)
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='dodgerblue', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='crimson', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate', fontsize=10)
    plt.ylabel('True Positive Rate', fontsize=10)
    plt.title('MLP ROC Curve', fontsize=10)
    plt.legend(loc="lower right", fontsize=10)
    plt.text(0.5, -0.1, 'Data Source: NASA JPL Small-Body Database', ha='center', va='center', transform=plt.gca().transAxes, fontsize=8)
    plt.tight_layout()
    plt.savefig('roc_curve_mlp.png', dpi=config['save_dpi'])
    plt.close()
    logger.info("MLP ROC curve saved")

    return report, test_output

def train_iforest(X_train, contamination=0.0022):
    logger.info("Starting iForest training")
    model = IsolationForest(contamination=contamination, random_state=42)
    model.fit(X_train)
    logger.info("iForest training completed")
    return model

def evaluate_iforest(model, X_val, y_val, X_test, y_test):
    logger.info("Starting iForest model evaluation")
    val_scores = -model.decision_function(X_val)
    precision, recall, thresholds = precision_recall_curve(y_val, val_scores)
    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
    optimal_idx = np.argmax(f1_scores)
    optimal_threshold = thresholds[optimal_idx]
    logger.info(f"Optimal threshold for iForest: {optimal_threshold}")

    test_scores = -model.decision_function(X_test)
    test_pred = (test_scores > optimal_threshold).astype(int)
    logger.info("iForest Classification Report:")
    report = classification_report(y_test, test_pred, output_dict=True)
    print(classification_report(y_test, test_pred))

    cm = confusion_matrix(y_test, test_pred)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
    plt.title('iForest Confusion Matrix', fontsize=10)
    plt.xlabel('Predicted', fontsize=10)
    plt.ylabel('Actual', fontsize=10)
    plt.text(0.5, -0.1, 'Data Source: NASA JPL Small-Body Database', ha='center', va='center', transform=plt.gca().transAxes, fontsize=8)
    plt.tight_layout()
    plt.savefig('confusion_matrix_iforest.png', dpi=config['save_dpi'])
    plt.close()
    logger.info("iForest Confusion matrix saved")

    fpr, tpr, _ = roc_curve(y_test, test_scores)
    roc_auc = roc_auc_score(y_test, test_scores)
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='dodgerblue', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='crimson', lw=2, linestyle='--')
    plt.xlabel('False Positive Rate', fontsize=10)
    plt.ylabel('True Positive Rate', fontsize=10)
    plt.title('iForest ROC Curve', fontsize=10)
    plt.legend(loc="lower right", fontsize=10)
    plt.text(0.5, -0.1, 'Data Source: NASA JPL Small-Body Database', ha='center', va='center', transform=plt.gca().transAxes, fontsize=8)
    plt.tight_layout()
    plt.savefig('roc_curve_iforest.png', dpi=config['save_dpi'])
    plt.close()
    logger.info("iForest ROC curve saved")

    return report, test_scores

def plot_training_history(history_path, model_name, save_dpi=config['save_dpi']):
    history = pd.read_csv(history_path)
    plt.figure(figsize=(10, 6))
    plt.plot(history['epoch'], history['loss'], label='Training Loss', color='dodgerblue')
    plt.plot(history['epoch'], history['val_loss'], label='Validation Loss', color='crimson')
    plt.xlabel('Epoch', fontsize=10)
    plt.ylabel('Loss', fontsize=10)
    plt.title(f'{model_name} Training and Validation Loss', fontsize=10)
    plt.legend(fontsize=10)

    final_loss = history['loss'].iloc[-1]
    final_val_loss = history['val_loss'].iloc[-1]
    plt.annotate(f'Final Train Loss: {final_loss:.4f}', xy=(history['epoch'].iloc[-1], final_loss), xytext=(10, 10), textcoords='offset points', arrowprops=dict(arrowstyle='->'), fontsize=8)
    plt.annotate(f'Final Val Loss: {final_val_loss:.4f}', xy=(history['epoch'].iloc[-1], final_val_loss), xytext=(10, -10), textcoords='offset points', arrowprops=dict(arrowstyle='->'), fontsize=8)
    plt.text(0.5, -0.1, 'Data Source: NASA JPL Small-Body Database', ha='center', va='center', transform=plt.gca().transAxes, fontsize=8)
    plt.tight_layout()
    plt.savefig(f'training_history_{model_name.lower()}.png', dpi=save_dpi)
    plt.close()

def plot_comparative_curves(y_test, outputs, curve_type='roc'):
    plt.figure(figsize=(10, 8))
    colors = ['dodgerblue', 'crimson', 'orange', 'green']
    for (model_name, test_output), color in zip(outputs.items(), colors):
        if curve_type == 'roc':
            fpr, tpr, _ = roc_curve(y_test, test_output)
            auc = roc_auc_score(y_test, test_output)
            plt.plot(fpr, tpr, label=f'{model_name} (AUC = {auc:.2f})', color=color)
        else:
            precision, recall, _ = precision_recall_curve(y_test, test_output)
            plt.plot(recall, precision, label=model_name, color=color)

    if curve_type == 'roc':
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlabel('False Positive Rate', fontsize=10)
        plt.ylabel('True Positive Rate', fontsize=10)
        plt.title('Receiver Operating Characteristic (ROC) Curves', fontsize=10)
    else:
        plt.xlabel('Recall', fontsize=10)
        plt.ylabel('Precision', fontsize=10)
        plt.title('Precision-Recall Curves', fontsize=10)

    plt.legend(loc="lower right" if curve_type == 'roc' else "lower left", fontsize=10)
    plt.text(0.5, -0.1, 'Data Source: NASA JPL Small-Body Database', ha='center', va='center', transform=plt.gca().transAxes, fontsize=8)
    plt.tight_layout()
    plt.savefig(f'{curve_type}_curves.png', dpi=config['save_dpi'])
    plt.close()

def main():
    dataset_path = f"{sakhawat18_asteroid_dataset_path}/dataset.csv"
    dataset = AsteroidGraphDataset(dataset_path)

    X = dataset.data.x
    y = dataset.data.y
    train_mask = dataset.data.train_mask
    val_mask = dataset.data.val_mask
    test_mask = dataset.data.test_mask

    X_train = X[train_mask]
    y_train = y[train_mask]
    X_val = X[val_mask]
    y_val = y[val_mask]
    X_test = X[test_mask]
    y_test = y[test_mask]

    # Train and evaluate GNN
    gnn_classifier = AsteroidHazardClassifier(dataset_path, model_type='GNN')
    gnn_classifier.train(epochs=config['epochs'], patience=config['patience'])
    plot_training_history('training_history_gnn.csv', 'GNN')
    gnn_report, gnn_test_output = gnn_classifier.evaluate()

    # Train and evaluate MLP
    mlp_model = train_mlp(X_train, y_train, X_val, y_val)
    plot_training_history('training_history_mlp.csv', 'MLP')
    mlp_report, mlp_test_output = evaluate_mlp(mlp_model, X_val, y_val, X_test, y_test)

    # Train and evaluate GAT
    gat_classifier = AsteroidHazardClassifier(dataset_path, model_type='GAT')
    gat_classifier.train(epochs=config['epochs'], patience=config['patience'])
    plot_training_history('training_history_gat.csv', 'GAT')
    gat_report, gat_test_output = gat_classifier.evaluate()

    # Train and evaluate iForest
    iforest_model = train_iforest(X_train.cpu().numpy(), contamination=0.0022)
    iforest_report, iforest_test_output = evaluate_iforest(iforest_model, X_val.cpu().numpy(), y_val.cpu().numpy(), X_test.cpu().numpy(), y_test.cpu().numpy())

    # Compare models
    reports = {
        'GNN': gnn_report,
        'MLP': mlp_report,
        'GAT': gat_report,
        'iForest': iforest_report
    }
    outputs = {
        'GNN': gnn_test_output,
        'MLP': mlp_test_output,
        'GAT': gat_test_output,
        'iForest': iforest_test_output
    }

    table = [['Model', 'Precision', 'Recall', 'F1-Score', 'Support']]
    for model_name, report in reports.items():
        metrics = report['1.0']
        table.append([model_name, f"{metrics['precision']:.2f}", f"{metrics['recall']:.2f}", f"{metrics['f1-score']:.2f}", int(metrics['support'])])

    logger.info("Model Performance Comparison:")
    print(tabulate(table, headers='firstrow', tablefmt='grid'))

    # Plot comparative curves
    plot_comparative_curves(y_test.cpu().numpy())

In [None]:
if __name__ == "__main__":
    main()

GNN Training Epochs: 100%|██████████| 100/100 [44:20<00:00, 26.61s/it, Loss=0.0012, Val Loss=0.0012]


              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00    187308
         1.0       0.28      0.37      0.32       413

    accuracy                           1.00    187721
   macro avg       0.64      0.68      0.66    187721
weighted avg       1.00      1.00      1.00    187721



MLP Training Epochs: 100%|██████████| 100/100 [00:40<00:00,  2.47it/s, Loss=0.0015, Val Loss=0.0016]


              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00    187308
         1.0       0.20      0.41      0.27       413

    accuracy                           1.00    187721
   macro avg       0.60      0.71      0.64    187721
weighted avg       1.00      1.00      1.00    187721



GAT Training Epochs: 100%|██████████| 100/100 [29:36<00:00, 17.76s/it, Loss=0.0022, Val Loss=0.0023]


              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00    187308
         1.0       0.11      0.17      0.14       413

    accuracy                           1.00    187721
   macro avg       0.56      0.58      0.57    187721
weighted avg       1.00      1.00      1.00    187721

              precision    recall  f1-score   support

         0.0       1.00      0.97      0.98    187308
         1.0       0.04      0.57      0.08       413

    accuracy                           0.97    187721
   macro avg       0.52      0.77      0.53    187721
weighted avg       1.00      0.97      0.98    187721

+---------+-------------+----------+------------+-----------+
| Model   |   Precision |   Recall |   F1-Score |   Support |
| GNN     |        0.28 |     0.37 |       0.32 |       413 |
+---------+-------------+----------+------------+-----------+
| MLP     |        0.2  |     0.41 |       0.27 |       413 |
+---------+-------------+----------+

TypeError: plot_comparative_curves() missing 1 required positional argument: 'outputs'

In [None]:
!pip install imblearn -q

In [None]:
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import logging
from tabulate import tabulate
from imblearn.over_sampling import SMOTE
from sklearn.neighbors import NearestNeighbors, kneighbors_graph
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_curve, roc_curve, roc_auc_score
import os

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Configuration
config = {
    'epochs': 100,
    'patience': 10,
    'save_dpi': 300,
    'display_dpi': 120
}

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.75, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets, weights=None):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        F_loss = alpha_t * (1 - pt) ** self.gamma * BCE_loss
        if weights is not None:
            F_loss = F_loss * weights
        if self.reduction == 'mean':
            return torch.mean(F_loss)
        elif self.reduction == 'sum':
            return torch.sum(F_loss)
        else:
            return F_loss

class AsteroidGraphDataset:
    def __init__(self, filepath):
        logger.info(f"Loading dataset from {filepath}")
        self.df = pd.read_csv(filepath, low_memory=False)
        self.df = self.df.dropna(subset=['pha'])
        logger.info(f"After dropping rows with missing 'pha', dataset size: {len(self.df)}")

        features = ['e', 'a', 'i', 'q', 'tp', 'H', 'diameter', 'albedo', 'moid_ld']
        X = self.df[features].astype(float).values
        y = (self.df['pha'] == 'Y').astype(int).values

        indices = np.arange(len(self.df))
        train_val_indices, test_indices = train_test_split(
            indices, test_size=0.2, stratify=y, random_state=42
        )
        train_indices, val_indices = train_test_split(
            train_val_indices, test_size=0.2, stratify=y[train_val_indices], random_state=42
        )

        imputer = SimpleImputer(strategy='median')
        imputer.fit(X[train_indices])
        X_imputed = imputer.transform(X)

        scaler = StandardScaler()
        scaler.fit(X_imputed[train_indices])
        X_scaled = scaler.transform(X_imputed)

        train_mask = torch.zeros(len(self.df), dtype=torch.bool)
        val_mask = torch.zeros(len(self.df), dtype=torch.bool)
        test_mask = torch.zeros(len(self.df), dtype=torch.bool)
        train_mask[train_indices] = True
        val_mask[val_indices] = True
        test_mask[test_indices] = True

        connectivity = kneighbors_graph(X_scaled, n_neighbors=5, mode='connectivity')
        rows, cols = connectivity.nonzero()
        edge_index_np = np.vstack((rows, cols))
        edge_index = torch.from_numpy(edge_index_np).to(torch.long)
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

        self.data = Data(x=torch.tensor(X_scaled, dtype=torch.float),
                         edge_index=edge_index,
                         y=torch.tensor(y, dtype=torch.float),
                         train_mask=train_mask,
                         val_mask=val_mask,
                         test_mask=test_mask)

        logger.info(f"Graph constructed with {self.data.num_nodes} nodes and {self.data.num_edges} edges")

class AsteroidGNN(nn.Module):
    def __init__(self, num_features):
        super(AsteroidGNN, self).__init__()
        self.conv1 = GCNConv(num_features, 128)
        self.conv2 = GCNConv(128, 128)
        self.conv3 = GCNConv(128, 128)
        self.attn = GATConv(128, 64, heads=2)
        self.fc = nn.Linear(64 * 2, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        x = F.relu(self.attn(x, edge_index))
        x = self.fc(x)
        return x.squeeze(-1)

def augment_with_smote(data, sampling_strategy=0.1, k_neighbors=5):
    X_train = data.x[data.train_mask].cpu().numpy()
    y_train = data.y[data.train_mask].cpu().numpy()
    smote = SMOTE(sampling_strategy=sampling_strategy, k_neighbors=k_neighbors, random_state=42)
    X_synthetic, y_synthetic = smote.fit_resample(X_train, y_train)
    X_synthetic = X_synthetic[len(X_train):]
    y_synthetic = y_synthetic[len(y_train):]
    X_synthetic = torch.tensor(X_synthetic, dtype=torch.float)
    y_synthetic = torch.tensor(y_synthetic, dtype=torch.float)
    num_synthetic = X_synthetic.shape[0]
    original_num_nodes = data.x.shape[0]
    data.x = torch.cat([data.x, X_synthetic], dim=0)
    data.y = torch.cat([data.y, y_synthetic], dim=0)
    data.train_mask = torch.cat([data.train_mask, torch.ones(num_synthetic, dtype=torch.bool)], dim=0)
    data.val_mask = torch.cat([data.val_mask, torch.zeros(num_synthetic, dtype=torch.bool)], dim=0)
    data.test_mask = torch.cat([data.test_mask, torch.zeros(num_synthetic, dtype=torch.bool)], dim=0)
    nn = NearestNeighbors(n_neighbors=k_neighbors)
    nn.fit(data.x[:original_num_nodes].cpu().numpy())
    distances, indices = nn.kneighbors(X_synthetic.cpu().numpy())
    edge_list = []
    for i, orig_indices in enumerate(indices):
        syn_idx = original_num_nodes + i
        for orig_idx in orig_indices:
            edge_list.append([syn_idx, orig_idx])
            edge_list.append([orig_idx, syn_idx])
    edge_list = np.array(edge_list).T
    edge_list = torch.tensor(edge_list, dtype=torch.long)
    data.edge_index = torch.cat([data.edge_index, edge_list], dim=1)
    return data

class AsteroidHazardClassifier:
    def __init__(self, dataset_path, imbalance_method='original', smote_strategy=0.1, focal_alpha=0.75, focal_gamma=2, models_dir='/content/output/models', output_dir='/content/output', plots_dir='/content/output/plots'):
        logger.info(f"Initializing GNN Classifier with imbalance method {imbalance_method}")
        self.dataset = AsteroidGraphDataset(dataset_path)
        self.imbalance_method = imbalance_method
        self.smote_strategy = smote_strategy
        self.models_dir = models_dir
        self.output_dir = output_dir
        self.plots_dir = plots_dir
        num_features = len(['e', 'a', 'i', 'q', 'tp', 'H', 'diameter', 'albedo', 'moid_ld'])
        self.model = AsteroidGNN(num_features)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=0.001, weight_decay=1e-4)
        self.criterion = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)

        if imbalance_method == 'smote':
            self.data = augment_with_smote(self.dataset.data.clone(), sampling_strategy=smote_strategy)
        else:
            self.data = self.dataset.data

        if imbalance_method == 'weighted':
            y_train = self.data.y[self.data.train_mask]
            num_pos = (y_train == 1).sum().item()
            num_neg = (y_train == 0).sum().item()
            total = len(y_train)
            self.class_weight = {
                0: total / (2 * num_neg),
                1: total / (2 * num_pos)
            }
        else:
            self.class_weight = None

    def train(self, epochs=config['epochs'], patience=config['patience']):
        logger.info(f"Starting GNN training for {epochs} epochs")
        best_val_loss = float('inf')
        patience_counter = 0
        best_model_state = None
        history = []

        epoch_progress = tqdm(range(1, epochs + 1), desc=f"GNN {self.imbalance_method} Training Epochs", position=0)

        for epoch in epoch_progress:
            self.model.train()
            self.optimizer.zero_grad()
            output = self.model(self.data)
            if self.imbalance_method == 'weighted':
                weights = torch.tensor([self.class_weight[int(y)] for y in self.data.y[self.data.train_mask]], device=self.data.x.device)
                loss = self.criterion(output[self.data.train_mask], self.data.y[self.data.train_mask], weights=weights)
            else:
                loss = self.criterion(output[self.data.train_mask], self.data.y[self.data.train_mask])
            loss.backward()
            self.optimizer.step()

            self.model.eval()
            with torch.no_grad():
                val_output = self.model(self.data)
                logger.info(f"val_output shape: {val_output.shape}, val_mask shape: {self.data.val_mask.shape}")
                val_loss = self.criterion(val_output[self.data.val_mask], self.data.y[self.data.val_mask])

            history.append({'epoch': epoch, 'loss': loss.item(), 'val_loss': val_loss.item()})
            epoch_progress.set_postfix({'Loss': f'{loss.item():.4f}', 'Val Loss': f'{val_loss.item():.4f}'})

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                best_model_state = self.model.state_dict()
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    logger.info(f"Early stopping at epoch {epoch}")
                    break

        self.model.load_state_dict(best_model_state)
        torch.save(best_model_state, os.path.join(self.models_dir, f'gnn_{self.imbalance_method}.pth'))
        history_df = pd.DataFrame(history)
        history_df.to_csv(os.path.join(self.output_dir, f'training_history_gnn_{self.imbalance_method}.csv'), index=False)
        logger.info(f"GNN Training completed")

    def evaluate(self):
        logger.info(f"Starting GNN model evaluation")
        self.model.eval()
        with torch.no_grad():
            output = self.model(self.data)
            val_output = output[self.data.val_mask].cpu().numpy()
            val_y = self.data.y[self.data.val_mask].cpu().numpy()

            precision, recall, thresholds = precision_recall_curve(val_y, val_output)
            f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
            optimal_idx = np.argmax(f1_scores)
            optimal_threshold = thresholds[optimal_idx]
            logger.info(f"Optimal threshold for GNN: {optimal_threshold}")

            test_output = output[self.data.test_mask].cpu().numpy()
            test_y = self.data.y[self.data.test_mask].cpu().numpy()
            test_pred = (test_output > optimal_threshold).astype(int)

            logger.info(f"GNN Classification Report:")
            report = classification_report(test_y, test_pred, output_dict=True)
            print(classification_report(test_y, test_pred))

            cm = confusion_matrix(test_y, test_pred)
            plt.figure(figsize=(8, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
            plt.title(f'GNN {self.imbalance_method} Confusion Matrix', fontsize=10)
            plt.xlabel('Predicted', fontsize=10)
            plt.ylabel('Actual', fontsize=10)
            plt.text(0.5, -0.1, 'Data Source: NASA JPL Small-Body Database', ha='center', va='center', transform=plt.gca().transAxes, fontsize=8)
            plt.tight_layout()
            plt.savefig(os.path.join(self.plots_dir, f'confusion_matrix_gnn_{self.imbalance_method}.png'), dpi=config['save_dpi'])
            plt.close()
            logger.info(f"GNN Confusion matrix saved")

            fpr, tpr, _ = roc_curve(test_y, test_output)
            roc_auc = roc_auc_score(test_y, test_output)
            plt.figure(figsize=(8, 6))
            plt.plot(fpr, tpr, color='dodgerblue', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
            plt.plot([0, 1], [0, 1], color='crimson', lw=2, linestyle='--')
            plt.xlabel('False Positive Rate', fontsize=10)
            plt.ylabel('True Positive Rate', fontsize=10)
            plt.title(f'GNN {self.imbalance_method} ROC Curve', fontsize=10)
            plt.legend(loc="lower right", fontsize=10)
            plt.tight_layout()
            plt.savefig(os.path.join(self.plots_dir, f'roc_curve_gnn_{self.imbalance_method}.png'), dpi=config['save_dpi'])
            plt.close()
            logger.info(f"GNN ROC curve saved")

            return report, test_output

def plot_training_history(history_path, model_name, save_dpi=config['save_dpi'], plots_dir='/content/output/plots'):
    history = pd.read_csv(history_path)
    plt.figure(figsize=(10, 6))
    plt.plot(history['epoch'], history['loss'], label='Training Loss', color='dodgerblue')
    plt.plot(history['epoch'], history['val_loss'], label='Validation Loss', color='crimson')
    plt.xlabel('Epoch', fontsize=10)
    plt.ylabel('Loss', fontsize=10)
    plt.title(f'{model_name} Training and Validation Loss', fontsize=10)
    plt.legend(fontsize=10)

    final_loss = history['loss'].iloc[-1]
    final_val_loss = history['val_loss'].iloc[-1]
    plt.annotate(f'Final Train Loss: {final_loss:.4f}', xy=(history['epoch'].iloc[-1], final_loss), xytext=(10, 10), textcoords='offset points', arrowprops=dict(arrowstyle='->'), fontsize=8)
    plt.annotate(f'Final Val Loss: {final_val_loss:.4f}', xy=(history['epoch'].iloc[-1], final_val_loss), xytext=(10, -10), textcoords='offset points', arrowprops=dict(arrowstyle='->'), fontsize=8)
    plt.text(0.5, -0.1, 'Data Source: NASA JPL Small-Body Database', ha='center', va='center', transform=plt.gca().transAxes, fontsize=8)
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, f'training_history_{model_name.lower().replace(" ", "_")}.png'), dpi=save_dpi)
    plt.close()

def plot_comparative_curves(y_test, outputs, curve_type='roc', plots_dir='/content/output/plots'):
    plt.figure(figsize=(10, 8))
    colors = ['dodgerblue', 'crimson', 'orange']
    for (model_name, test_output), color in zip(outputs.items(), colors):
        if curve_type == 'roc':
            fpr, tpr, _ = roc_curve(y_test, test_output)
            auc = roc_auc_score(y_test, test_output)
            plt.plot(fpr, tpr, label=f'{model_name} (AUC = {auc:.2f})', color=color)
        elif curve_type == 'pr':
            precision, recall, _ = precision_recall_curve(y_test, test_output)
            plt.plot(recall, precision, label=model_name, color=color)

    if curve_type == 'roc':
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlabel('False Positive Rate', fontsize=10)
        plt.ylabel('True Positive Rate', fontsize=10)
        plt.title('Receiver Operating Characteristic (ROC) Curves', fontsize=10)
    else:
        plt.xlabel('Recall', fontsize=10)
        plt.ylabel('Precision', fontsize=10)
        plt.title('Precision-Recall Curves', fontsize=10)

    plt.legend(loc="lower right" if curve_type == 'roc' else "lower left", fontsize=10)
    plt.text(0.5, -0.1, 'Data Source: NASA JPL Small-Body Database', ha='center', va='center', transform=plt.gca().transAxes, fontsize=8)
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, f'{curve_type}_curves.png'), dpi=config['save_dpi'])
    plt.close()

def main():
    output_dir = '/content/output'
    models_dir = os.path.join(output_dir, 'models')
    plots_dir = os.path.join(output_dir, 'plots')
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(plots_dir, exist_ok=True)

    dataset_path = f"{sakhawat18_asteroid_dataset_path}/dataset.csv"
    dataset = AsteroidGraphDataset(dataset_path)

    # Train SMOTE GNN
    gnn_smote = AsteroidHazardClassifier(dataset_path, imbalance_method='smote', smote_strategy=0.1, models_dir=models_dir, output_dir=output_dir, plots_dir=plots_dir)
    gnn_smote.train()
    plot_training_history(os.path.join(output_dir, 'training_history_gnn_smote.csv'), 'GNN SMOTE', plots_dir=plots_dir)
    gnn_smote_report, gnn_smote_test_output = gnn_smote.evaluate()

    # Train adjusted Focal Loss GNN
    gnn_adjusted_focal = AsteroidHazardClassifier(dataset_path, imbalance_method='original', focal_alpha=0.9, focal_gamma=3, models_dir=models_dir, output_dir=output_dir, plots_dir=plots_dir)
    gnn_adjusted_focal.train()
    plot_training_history(os.path.join(output_dir, 'training_history_gnn_adjusted_focal.csv'), 'GNN Adjusted Focal', plots_dir=plots_dir)
    gnn_adjusted_focal_report, gnn_adjusted_focal_test_output = gnn_adjusted_focal.evaluate()

    # Train weighted GNN
    gnn_weighted = AsteroidHazardClassifier(dataset_path, imbalance_method='weighted', models_dir=models_dir, output_dir=output_dir, plots_dir=plots_dir)
    gnn_weighted.train()
    plot_training_history(os.path.join(output_dir, 'training_history_gnn_weighted.csv'), 'GNN Weighted', plots_dir=plots_dir)
    gnn_weighted_report, gnn_weighted_test_output = gnn_weighted.evaluate()

    # Collect reports and outputs
    reports = {
        'GNN SMOTE': gnn_smote_report,
        'GNN Adjusted Focal': gnn_adjusted_focal_report,
        'GNN Weighted': gnn_weighted_report
    }
    outputs = {
        'GNN SMOTE': gnn_smote_test_output,
        'GNN Adjusted Focal': gnn_adjusted_focal_test_output,
        'GNN Weighted': gnn_weighted_test_output
    }

    # Print comparison table
    data = []
    for model_name, report in reports.items():
        metrics = report['1.0']
        data.append([model_name, metrics['precision'], metrics['recall'], metrics['f1-score'], metrics['support']])
    df = pd.DataFrame(data, columns=['Model', 'Precision', 'Recall', 'F1-Score', 'Support'])
    logger.info("Model Performance Comparison:")
    print(df)

    # Plot comparative curves
    y_test = dataset.data.y[dataset.data.test_mask].cpu().numpy()
    plot_comparative_curves(y_test, outputs, curve_type='roc', plots_dir=plots_dir)
    plot_comparative_curves(y_test, outputs, curve_type='pr', plots_dir=plots_dir)

if __name__ == '__main__':
    main()

GNN smote Training Epochs: 100%|██████████| 100/100 [48:38<00:00, 29.19s/it, Loss=0.0011, Val Loss=0.0010]


              precision    recall  f1-score   support

         0.0       1.00      0.99      1.00    187308
         1.0       0.24      0.78      0.37       413

    accuracy                           0.99    187721
   macro avg       0.62      0.89      0.68    187721
weighted avg       1.00      0.99      1.00    187721



GNN original Training Epochs: 100%|██████████| 100/100 [45:47<00:00, 27.48s/it, Loss=0.0003, Val Loss=0.0003]


FileNotFoundError: [Errno 2] No such file or directory: '/content/output/training_history_gnn_adjusted_focal.csv'