# M3CAD Identification Inference Demo

This notebook demonstrates the inference process for the M3CAD identification model. It covers:
1.  Loading pre-trained models (Antimicrobial, Toxin, etc.).
2.  Loading sample data (Sequences and PDB voxels).
3.  Performing predictions.
4.  Visualizing results.

In [1]:
import os
import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from dataset import GDataset, ADataset
from network import MMPeptide, SEQPeptide

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")





Using device: cuda


In [2]:
# Load Data
# We will use a sample dataset. Ensure you have 'metadata/data_processed.csv' or similar.
# Here we use GDataset which reads from a CSV path.

data_path = 'gendata/v2_filter_r3.csv'  # Example path, adjust if needed
if not os.path.exists(data_path):
    print(f"Data file {data_path} not found. Using a placeholder dataset for demonstration.")
    # You might want to point to an existing CSV or create a dummy one

# Assuming we use GDataset for generated data or ADataset for training data
try:
    test_set = GDataset(path=data_path)
    print(f"Loaded dataset with {len(test_set)} samples.")
except Exception as e:
    print(f"Could not load dataset: {e}")
    test_set = []

if len(test_set) > 0:
    valid_loader = DataLoader(test_set, batch_size=1, shuffle=False)
    print("DataLoader created.")

  0%|                                                                                    | 0/3000 [00:00<?, ?it/s]

  seq = str(row[1])
 14%|██████████▏                                                             | 422/3000 [00:00<00:00, 4218.42it/s]

 36%|█████████████████████████▋                                             | 1087/3000 [00:00<00:00, 5646.68it/s]

 80%|████████████████████████████████████████████████████████▋              | 2396/3000 [00:00<00:00, 9043.60it/s]

100%|███████████████████████████████████████████████████████████████████████| 3000/3000 [00:00<00:00, 8265.79it/s]

Loaded dataset with 3000 samples.
DataLoader created.





In [3]:
# Load Model
# Adjust number of classes based on task (e.g., 6 for anti, 1 for mic/toxin)
num_classes = 6 
model = MMPeptide(classes=num_classes).to(device)

# Path to weights
weight_path = 'run/anti-mm-mlce1280.00250/model_1.pth'

if os.path.exists(weight_path):
    model.load_state_dict(torch.load(weight_path, map_location=device))
    print(f"Loaded weights from {weight_path}")
else:
    print(f"Warning: Weight file not found at {weight_path}. Using random weights.")

model.eval()

Loaded weights from run/anti-mm-mlce1280.00250/model_1.pth


MMPeptide(
  (v_encoder): ResNet3D(
    (conv1): Conv3d(3, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (max_pool): MaxPool3d(kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
        (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU()
        (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
        (bn3): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       

In [4]:
# Inference Loop
results = []

if len(test_set) > 0:
    print("Starting inference...")
    with torch.no_grad():
        for i, data in enumerate(valid_loader):
            if i >= 10: break # Limit to 10 samples for demo
            
            voxel, seq, exist_info, index = data
            voxel = voxel.to(device)
            seq = seq.to(device)
            
            # Forward pass
            pred, feature = model((voxel, seq))
            
            # Process result
            pred_vals = pred.cpu().numpy().flatten()
            results.append({
                'Index': index.item() if isinstance(index, torch.Tensor) else index,
                'Info': exist_info[0] if isinstance(exist_info, list) else exist_info,
                'Predictions': pred_vals
            })
    
    print("Inference complete.")
else:
    print("No data to process.")

Starting inference...


Inference complete.


In [5]:
# Display Results
if results:
    df_res = pd.DataFrame(results)
    print(df_res.head())
    
    # Expand predictions if multi-class
    preds = np.stack(df_res['Predictions'].values)
    for c in range(preds.shape[1]):
        df_res[f'Class_{c}'] = preds[:, c]
    
    print("Expanded results:")
    print(df_res.drop(columns=['Predictions']).head())


# Save intermediate results
df_res.to_csv('classification_results.csv', index=False)
print('Classification results saved to classification_results.csv')

                    Index                                               Info  \
0       (EYCVKSYTKFYWNL,)  (127896,EYCVKSYTKFYWNL,0.99937861,0.798733688,...   
1  (RYNPKWFCNFWTCLVTWFN,)  (131315,RYNPKWFCNFWTCLVTWFN,0.998720965,0.7747...   
2       (KYMLKSYTEYYWQI,)  (124445,KYMLKSYTEYYWQI,0.991360021,0.45319928,...   
3       (KPMIRSYMEFWWQI,)  (140972,KPMIRSYMEFWWQI,0.99246768,0.821931983,...   
4       (KYNPKTFCDWWSML,)  (109446,KYNPKTFCDWWSML,0.998968741,0.864479621...   

                                         Predictions  
0  [-8.739532, 0.5637505, -0.5338131, -6.4296203,...  
1  [-9.240669, 2.4427419, -3.3392594, -0.65752065...  
2  [-8.387549, 0.4254674, 0.04264795, -5.8507676,...  
3  [-9.023673, 2.5661252, 0.14928277, -6.693563, ...  
4  [-7.6231074, 1.6117508, 0.43183252, -4.6166224...  
Expanded results:
                    Index                                               Info  \
0       (EYCVKSYTKFYWNL,)  (127896,EYCVKSYTKFYWNL,0.99937861,0.798733688,...   
1  (RYNPKWF

# Regression Inference

This section demonstrates how to use the model for regression tasks, such as predicting Minimum Inhibitory Concentration (MIC) or toxicity levels (if treated as continuous values).

In [6]:
# Initialize Regression Model
# For regression tasks, we typically set classes=1 to get a single continuous output value.
reg_model = SEQPeptide(classes=1).to(device)

# Path to regression weights (Example: MIC prediction model)
# Adjust this path to your trained regression model checkpoint
reg_weight_path = 'run/regression-seq-mse1280.00220/model_1.pth'

if os.path.exists(reg_weight_path):
    reg_model.load_state_dict(torch.load(reg_weight_path, map_location=device))
    print(f"Loaded regression weights from {reg_weight_path}")
else:
    print(f"Warning: Regression weight file not found at {reg_weight_path}. Using random weights for demonstration.")

reg_model.eval()

Loaded regression weights from run/regression-seq-mse1280.00220/model_1.pth


SEQPeptide(
  (q_encoder): SEQ(
    (rnn): Sequential(
      (0): Linear(in_features=50, out_features=50, bias=False)
      (1): ReLU()
      (2): Linear(in_features=50, out_features=1024, bias=True)
      (3): ReLU()
      (4): Linear(in_features=1024, out_features=256, bias=True)
    )
    (rnn_fc): Sequential(
      (0): Linear(in_features=512, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
    )
  )
  (fusion): Linear(in_features=2304, out_features=1, bias=True)
  (vox_fc): Linear(in_features=2048, out_features=1, bias=True)
  (seq_fc): Linear(in_features=256, out_features=1, bias=True)
)

In [7]:
# Run Regression Inference
reg_results = []

if len(test_set) > 0:
    print("Starting regression inference...")
    with torch.no_grad():
        for i, data in enumerate(valid_loader):
            if i >= 10: break # Limit to 10 samples for demo
            
            voxel, seq, exist_info, index = data
            voxel = voxel.to(device)
            seq = seq.to(device)
            
            # Forward pass
            # For regression, the model returns a scalar value per sample
            # Some implementations might return (pred, feature), let's handle that
            output = reg_model((voxel, seq))
            if isinstance(output, tuple):
                pred_reg = output[0]
            else:
                pred_reg = output
            
            # Extract value
            val = pred_reg.item() if pred_reg.numel() == 1 else pred_reg.cpu().numpy().flatten()
            if isinstance(val, np.ndarray) and val.size == 1:
                val = val.item()
                
            reg_results.append({
                'Index': index.item() if isinstance(index, torch.Tensor) else index,
                'Info': exist_info[0] if isinstance(exist_info, list) else exist_info,
                'Predicted Value': val
            })
    
    print("Regression inference complete.")
else:
    print("No data to process.")

Starting regression inference...


Regression inference complete.


In [8]:
# Display Regression Results
if reg_results:
    df_reg = pd.DataFrame(reg_results)
    print(df_reg.head())

# Save intermediate results
df_reg.to_csv('regression_results.csv', index=False)
print('Regression results saved to regression_results.csv')

                    Index                                               Info  \
0       (EYCVKSYTKFYWNL,)  (127896,EYCVKSYTKFYWNL,0.99937861,0.798733688,...   
1  (RYNPKWFCNFWTCLVTWFN,)  (131315,RYNPKWFCNFWTCLVTWFN,0.998720965,0.7747...   
2       (KYMLKSYTEYYWQI,)  (124445,KYMLKSYTEYYWQI,0.991360021,0.45319928,...   
3       (KPMIRSYMEFWWQI,)  (140972,KPMIRSYMEFWWQI,0.99246768,0.821931983,...   
4       (KYNPKTFCDWWSML,)  (109446,KYNPKTFCDWWSML,0.998968741,0.864479621...   

   Predicted Value  
0         0.968370  
1         0.907830  
2         0.923516  
3         0.855975  
4         0.774706  
Regression results saved to regression_results.csv
