# M3CAD Identification Inference Demo

This notebook demonstrates the inference process for the M3CAD identification model (`MMPeptide`). It covers:
1.  Loading pre-trained models (e.g. antimicrobial activity, toxicity).
2.  Loading sample data (sequences and PDB voxels).
3.  Running classification predictions.
4.  Displaying the results as tables.
5.  Running regression inference (e.g. MIC prediction).

In [None]:
import os
import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from dataset import GDataset, ADataset
from network import MMPeptide

# Set device: run on CUDA when a GPU is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Load Data
# GDataset reads samples from a CSV path (used here for generated data);
# ADataset is the alternative for training data. Adjust `data_path` as needed.
data_path = 'gendata/v2_filter_r3.csv'  # Example path, adjust if needed
if not os.path.exists(data_path):
    # No placeholder dataset is created: the load below is expected to fail and
    # the inference cells will then skip processing.
    print(f"Data file {data_path} not found. Point `data_path` at an existing CSV to run inference.")

try:
    test_set = GDataset(path=data_path)
    print(f"Loaded dataset with {len(test_set)} samples.")
except Exception as e:
    # Best-effort for the demo: report the failure and continue with no data.
    print(f"Could not load dataset: {e}")
    test_set = []

# Always (re)bind valid_loader so a re-run after a failed load cannot leave a
# stale loader from an earlier kernel state; None means "no data".
# batch_size=1 keeps one sample per batch, which downstream cells assume
# (they call `.item()` on the batch index).
if len(test_set) > 0:
    valid_loader = DataLoader(test_set, batch_size=1, shuffle=False)
    print("DataLoader created.")
else:
    valid_loader = None

In [None]:
# Load Model
# The number of output classes must match the checkpoint being loaded
# (e.g. 6 for anti, 1 for mic/toxin).
num_classes = 6
model = MMPeptide(classes=num_classes).to(device)

# Path to weights
weight_path = 'run/anti-mm-mlce1280.00250/model_1.pth'

if os.path.exists(weight_path):
    # The checkpoint is treated as a state_dict (it is passed to load_state_dict).
    # NOTE(review): torch.load unpickles the file; consider weights_only=True
    # (torch >= 1.13) if checkpoints may come from untrusted sources.
    model.load_state_dict(torch.load(weight_path, map_location=device))
    print(f"Loaded weights from {weight_path}")
else:
    # Demo fallback: predictions from randomly initialised weights are not meaningful.
    print(f"Warning: Weight file not found at {weight_path}. Using random weights.")

# Switch to inference mode. The trailing semicolon stops Jupyter from rendering
# the full module repr as this cell's output.
model.eval();

In [None]:
# Inference Loop
# Runs the classification model over the first MAX_DEMO_SAMPLES batches of
# valid_loader and records the raw model outputs for each sample.
MAX_DEMO_SAMPLES = 10  # Limit for the demo; raise for a full run
results = []

if len(test_set) > 0:
    print("Starting inference...")
    with torch.no_grad():
        for i, data in enumerate(valid_loader):
            if i >= MAX_DEMO_SAMPLES:
                break

            # Each batch is unpacked as (voxel, seq, exist_info, index); the
            # layout comes from the dataset's items -- TODO confirm in dataset.py.
            voxel, seq, exist_info, index = data
            voxel = voxel.to(device)
            seq = seq.to(device)

            # Forward pass. The model takes a (voxel, seq) tuple and may return
            # either (pred, feature) or just pred; handle both.
            output = model((voxel, seq))
            pred = output[0] if isinstance(output, tuple) else output

            # Raw outputs: this cell applies no sigmoid/softmax (whether MMPeptide
            # does so internally is not visible here). With one sample per batch,
            # flatten() yields one value per output class.
            pred_vals = pred.cpu().numpy().flatten()
            results.append({
                'Index': index.item() if isinstance(index, torch.Tensor) else index,
                'Info': exist_info[0] if isinstance(exist_info, list) else exist_info,
                'Predictions': pred_vals
            })

    print("Inference complete.")
else:
    print("No data to process.")

In [None]:
# Display Results
# Show the raw per-sample records, then spread each prediction vector into
# one Class_<k> column per output.
if len(results) > 0:
    df_res = pd.DataFrame(results)
    print(df_res.head())

    # np.stack builds an (n_samples, n_outputs) matrix; assumes every
    # 'Predictions' entry is a 1-D array of the same length.
    preds = np.stack(df_res['Predictions'].values)
    df_res = df_res.assign(**{f'Class_{c}': preds[:, c] for c in range(preds.shape[1])})

    print("Expanded results:")
    print(df_res.drop(columns=['Predictions']).head())


# Regression Inference

This section demonstrates how to use the model for regression tasks, such as predicting the Minimum Inhibitory Concentration (MIC) or toxicity treated as a continuous value. It builds a single-output model (`classes=1`) and reuses the data loader created above.

In [None]:
# Initialize Regression Model
# For regression tasks, classes=1 gives a single continuous output value.
reg_model = MMPeptide(classes=1).to(device)

# Path to regression weights (Example: MIC prediction model)
# Adjust this path to your trained regression model checkpoint
reg_weight_path = 'run/mic-mm-mse1280.00249/model_1.pth'

if os.path.exists(reg_weight_path):
    # The checkpoint is treated as a state_dict (it is passed to load_state_dict).
    # NOTE(review): torch.load unpickles the file; consider weights_only=True
    # (torch >= 1.13) if checkpoints may come from untrusted sources.
    reg_model.load_state_dict(torch.load(reg_weight_path, map_location=device))
    print(f"Loaded regression weights from {reg_weight_path}")
else:
    # Demo fallback: predictions from randomly initialised weights are not meaningful.
    print(f"Warning: Regression weight file not found at {reg_weight_path}. Using random weights for demonstration.")

# Switch to inference mode. The trailing semicolon stops Jupyter from rendering
# the full module repr as this cell's output.
reg_model.eval();

In [None]:
# Run Regression Inference
# Runs the regression model over the first MAX_DEMO_SAMPLES batches of
# valid_loader and records one predicted value per sample.
MAX_DEMO_SAMPLES = 10  # Limit for the demo; raise for a full run
reg_results = []

if len(test_set) > 0:
    print("Starting regression inference...")
    with torch.no_grad():
        for i, data in enumerate(valid_loader):
            if i >= MAX_DEMO_SAMPLES:
                break

            voxel, seq, exist_info, index = data
            voxel = voxel.to(device)
            seq = seq.to(device)

            # Forward pass. The model may return (pred, feature) or just pred;
            # keep only the prediction.
            output = reg_model((voxel, seq))
            pred_reg = output[0] if isinstance(output, tuple) else output

            # A single-element output becomes a Python float; anything larger is
            # kept as a flat numpy array.
            val = pred_reg.item() if pred_reg.numel() == 1 else pred_reg.cpu().numpy().flatten()

            reg_results.append({
                'Index': index.item() if isinstance(index, torch.Tensor) else index,
                'Info': exist_info[0] if isinstance(exist_info, list) else exist_info,
                'Predicted Value': val
            })

    print("Regression inference complete.")
else:
    print("No data to process.")

In [None]:
# Display Regression Results
# Tabulate the per-sample regression outputs and show the first five rows.
if len(reg_results) > 0:
    df_reg = pd.DataFrame(reg_results)
    print(df_reg.head(5))