### Resistance Prediction Notebook  

This notebook contains all the code for the **Inference** step of the pipeline. Each code cell starts with a header comment describing what it does.  

#### Key Function:
**Prediction of Resistance Probabilities:**  
- Computes resistance probabilities for each **"Klebsiella test sample" - "vendor compound"** pair.  
- Saves a **.csv file per sample** in the `klebsiella_resistance_predictions/{input_type}` folder, containing all predicted resistance probabilities for vendor compounds.  
 

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
import os
import numpy as np
import json
import re
from gv_experiments.models.classifier import Residual_AMR_Classifier
import polars as pl

In [None]:
## DATASET CLASSES FOR SAMPLE SPECTRA AND VENDOR COMPOUND EMBEDDINGS ## 

class SampleEmbDataset(Dataset):
    def __init__(self, long_table, spectra_matrix):
        """
        Dataset class for storing sample spectra embeddings.

        Each item is a ``(sample_id, spectrum_embedding)`` pair. Every unique sample ID
        is yielded exactly once, in order of first appearance in ``long_table``.

        Args:
            long_table (pd.DataFrame): DataFrame containing sample information. Must have a
                ``sample_id`` column; ``drug`` and ``response`` columns are dropped if present.
            spectra_matrix (np.array): Spectra embeddings, one row per sample. The lookup below
                assumes row ``i`` belongs to the ``i``-th sample ID in sorted order.
        """
        # Convert spectra matrix to a PyTorch tensor
        self.spectra_tensor = torch.tensor(spectra_matrix).float()

        # Map sample IDs to spectra rows: row i <-> i-th sample ID in sorted order
        sorted_samples = sorted(long_table["sample_id"].unique())
        self.idx2sample = {i: smp for i, smp in enumerate(sorted_samples)}
        self.sample2idx = {smp: i for i, smp in self.idx2sample.items()}

        # Keep one row per sample. Deduplicating on sample_id (not on all remaining columns)
        # guarantees each sample is yielded once even if other per-row columns differ.
        long_table = long_table.drop(columns=["drug", "response"], errors="ignore")
        long_table = long_table.drop_duplicates(subset="sample_id").reset_index(drop=True)
        self.long_table = long_table

    def __len__(self):
        # Number of unique samples (one row per sample_id after deduplication)
        return len(self.long_table)

    def __getitem__(self, idx):
        """Return ``(sample_id, spectrum_embedding)`` for the ``idx``-th unique sample."""
        sample_id = self.long_table["sample_id"].iloc[idx]
        spectrum_embedding = self.spectra_tensor[self.sample2idx[sample_id], :]
        return sample_id, spectrum_embedding


class CompoundDataset_Pandas(Dataset):
    # Maps each supported drug_emb_type to the column holding its JSON-encoded vector.
    EMB_COLUMNS = {
        "morgan_1024": "Morgan_Fingerprint",
        "molformer": "Molformer_Predictions",
    }

    def __init__(self, compound_df, drug_emb_type="morgan_1024"):
        """
        Dataset class for storing vendor compound embeddings.

        Args:
            compound_df (pd.DataFrame): DataFrame with a ``SMILES`` column and the JSON-encoded
                embedding column for ``drug_emb_type`` (Morgan fingerprints and/or Molformer embeddings).
            drug_emb_type (str): Type of drug embeddings to use ("morgan_1024" or "molformer").

        Raises:
            ValueError: If ``drug_emb_type`` is not a supported type.
        """
        # Validate up front; an unknown type would otherwise only surface later,
        # inside __getitem__, as an UnboundLocalError.
        if drug_emb_type not in self.EMB_COLUMNS:
            raise ValueError(
                f"Unsupported drug_emb_type {drug_emb_type!r}; "
                f"expected one of {sorted(self.EMB_COLUMNS)}"
            )
        self.compound_df = compound_df
        self.drug_emb_type = drug_emb_type

    def __len__(self):
        # Return the number of compounds in the dataset
        return len(self.compound_df)

    def __getitem__(self, idx):
        """Return ``(smiles, embedding)`` for row ``idx``; the embedding is parsed from JSON into a float32 tensor."""
        row = self.compound_df.iloc[idx]
        emb_column = self.EMB_COLUMNS[self.drug_emb_type]
        drug_emb = torch.tensor(json.loads(row[emb_column]), dtype=torch.float32)
        return row["SMILES"], drug_emb

class CompoundDataset(Dataset):
    # Maps each supported drug_emb_type to the column holding its JSON-encoded vector.
    EMB_COLUMNS = {
        "morgan_1024": "Morgan_Fingerprint",
        "molformer": "Molformer_Predictions",
    }

    def __init__(self, compound_df, drug_emb_type="morgan_1024"):
        """
        Dataset class for storing vendor compound embeddings.

        Args:
            compound_df (pl.DataFrame): Polars DataFrame with a ``SMILES`` column and the JSON-encoded
                embedding column for ``drug_emb_type`` (Morgan fingerprints and/or Molformer embeddings).
            drug_emb_type (str): Type of drug embeddings to use ("morgan_1024" or "molformer").

        Raises:
            ValueError: If ``drug_emb_type`` is not a supported type.
        """
        # Validate up front; an unknown type would otherwise only surface later,
        # inside __getitem__, as an UnboundLocalError.
        if drug_emb_type not in self.EMB_COLUMNS:
            raise ValueError(
                f"Unsupported drug_emb_type {drug_emb_type!r}; "
                f"expected one of {sorted(self.EMB_COLUMNS)}"
            )
        self.compound_df = compound_df
        self.drug_emb_type = drug_emb_type

    def __len__(self):
        # Polars exposes the row count as .height
        return self.compound_df.height

    def __getitem__(self, idx):
        """Return ``(smiles, embedding)`` for row ``idx``; the embedding is parsed from JSON into a float32 tensor."""
        row = self.compound_df.row(idx, named=True)  # Get row as a dictionary
        emb_column = self.EMB_COLUMNS[self.drug_emb_type]
        drug_emb = torch.tensor(json.loads(row[emb_column]), dtype=torch.float32)
        return row["SMILES"], drug_emb


## HELPER FUNCTIONS TO PROCESS SAMPLE ID FILES ##

def load_samples(tracking_file):
    """
    Read the set of already-processed sample IDs from a tracking file.

    Args:
        tracking_file (str): Path to the tracking file, one sample ID per line.

    Returns:
        set: The sample IDs found in the file, or an empty set when the file
        does not exist yet.
    """
    if not os.path.exists(tracking_file):
        return set()
    with open(tracking_file, 'r') as handle:
        contents = handle.read()
    return set(contents.splitlines())

def save_processed_sample(tracking_file, sample_id):
    """
    Record one processed sample by appending its ID as a new line.

    The file is created if it does not exist; existing entries are kept.

    Args:
        tracking_file (str): Path to the file storing processed sample IDs.
        sample_id (str): ID of the processed sample to be saved.
    """
    entry = f"{sample_id}\n"
    with open(tracking_file, 'a') as handle:
        handle.write(entry)

def load_excluded_samples(file_path):
    """
    Load excluded sample IDs from a text file.

    Each line is stripped of surrounding whitespace and collected into a set.
    The file must exist (opening a missing path raises).

    Args:
        file_path (str): Path to the exclusion list, one sample ID per line.

    Returns:
        set: The stripped sample IDs.
    """
    excluded = set()
    with open(file_path, 'r') as handle:
        for raw_line in handle:
            excluded.add(raw_line.strip())
    return excluded


In [None]:
## SETS UP THE 2 DATALOADERS NEEDED: ONE FOR THE KLEBSIELLA TEST SAMPLES AND ONE FOR THE VENDOR COMPOUND EMBEDDINGS ##

spectra_type = "MAE"  # raw or MAE
fingerprint_type = "molformer"  # molformer or morgan_1024

# Vendor compound library CSV for each supported fingerprint_type
COMPOUND_LIBRARY_PATHS = {
    "molformer": "compound_lists/Enamine_Hit_Locator_morgan_molformer.csv",
    "morgan_1024": "compound_lists/Enamine_Hit_Locator_with_fingerprints.csv",
}
COMPOUND_BATCH_SIZE = 128

# Fail fast on an invalid setting: printing and continuing would only crash later
# with a NameError on compound_df.
if fingerprint_type not in COMPOUND_LIBRARY_PATHS:
    raise ValueError(
        f"The input fingerprint_type {fingerprint_type!r} is invalid; "
        f"expected one of {sorted(COMPOUND_LIBRARY_PATHS)}"
    )

# Load the vendor compound library
compound_df = pl.read_csv(COMPOUND_LIBRARY_PATHS[fingerprint_type])

# Load the long-table containing Klebsiella pneumoniae test sample-drug-response triplets
long_table = pd.read_csv("klebsiella_data/klebsiella_test_samples_long_table.csv")

# Load the spectra embeddings for Klebsiella pneumoniae test samples
spectra_matrix = np.load(f'klebsiella_data/klebsiella_test_embeddings_filtered_{spectra_type}.npy')

# Initialize dataset for Klebsiella pneumoniae test samples
species_dataset = SampleEmbDataset(long_table, spectra_matrix)

# Initialize dataset for vendor compounds using Morgan 1024-bit fingerprints or Molformer embeddings
compound_dataset = CompoundDataset(compound_df, drug_emb_type=fingerprint_type)

# DataLoader for test sample spectra (one sample at a time, fixed order)
species_loader = DataLoader(species_dataset, batch_size=1, shuffle=False)

# DataLoader for vendor compound embeddings. shuffle=False keeps the output rows
# aligned with the compound library order.
compound_loader = DataLoader(compound_dataset, batch_size=COMPOUND_BATCH_SIZE, shuffle=False)

print("Data loaders set")


Data loaders set


In [4]:
## LOADS THE ResMLP MODEL FOR INFERENCE ##

# (json and torch are imported in the first cell.)
input_type = "mae_molformer"

# (config_path, checkpoint_path) for each supported input combination
MODEL_PATHS = {
    "raw_morgan": (
        "ABCD/ResultsAndCheckpoints/ABCD/raw_fing/new_loader_rawMS_fing_ABCD_DRIAMS-any_specific/config.json",
        # NOTE(review): original path began with ".ABCD/" (stray dot); aligned with the config path
        "ABCD/ResultsAndCheckpoints/ABCD/raw_fing/new_loader_rawMS_fing_ABCD_DRIAMS-any_specific/0/lightning_logs/version_0/checkpoints/epoch=99-step=394700.ckpt",
    ),
    "mae_molformer": (
        "ABCD/ResultsAndCheckpoints/ABCD/MAE_Mol/new_loader_MAE_Mol_ABCD_DRIAMS-any_specific/config.json",
        "ABCD/ResultsAndCheckpoints/ABCD/MAE_Mol/new_loader_MAE_Mol_ABCD_DRIAMS-any_specific/0/lightning_logs/version_0/checkpoints/epoch=99-step=394700.ckpt",
    ),
}

# Fail fast: printing and continuing would crash below with a NameError on config_path
if input_type not in MODEL_PATHS:
    raise ValueError(f"Invalid input type {input_type!r}; expected one of {sorted(MODEL_PATHS)}")
config_path, checkpoint_path = MODEL_PATHS[input_type]

with open(config_path, 'r') as f:
    config = json.load(f)

# Map to CPU: the model below is built on CPU and load_state_dict copies the weights
# into it anyway, while "mps" only exists on Apple-silicon machines.
checkpoint = torch.load(checkpoint_path, weights_only=True, map_location="cpu")

# Strip only the leading "model." prefix (presumably added by the Lightning wrapper module).
# str.replace would also mangle any key containing "model." further inside its name.
prefix = "model."
state_dict = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in checkpoint['state_dict'].items()
}

model = Residual_AMR_Classifier(config)
model.load_state_dict(state_dict)
model.eval()

print("Model loaded, initiating inference")

Model loaded, initiating inference


In [None]:
## RUNS THE INFERENCE STEP: PREDICTS RESISTANCE PROBABILITIES FOR EVERY TEST SAMPLE / VENDOR COMPOUND PAIR ##

# Define input type and tracking file for processed samples.
# Must match the input_type used when loading the model above.
input_type = "mae_molformer"
output_dir = f"klebsiella_resistance_predictions/{input_type}"
tracking_file = f"{output_dir}/processed_samples.txt"

# Species index passed to the model for every pair (original note: "Assume 172 is the species index").
# NOTE(review): kept as float32 as in the original run; confirm the model expects a float here.
KLEBSIELLA_SPECIES_IDX = 172

# Create the output folders (no-op if they already exist)
os.makedirs(output_dir, exist_ok=True)

# Load the list of samples that have already been processed
processed_samples = load_samples(tracking_file)

# Load the list of samples that will be excluded
excluded_samples = load_excluded_samples(f"sample_lists/excluded_samples_{input_type}.txt")


def predict_sample(model, compound_loader, spectrum_embedding, species_index=KLEBSIELLA_SPECIES_IDX):
    """
    Score one sample's spectrum embedding against every compound in ``compound_loader``.

    Args:
        model: Classifier called as ``model(species_idx=..., x_spectrum=..., dr_tensor=...)``.
        compound_loader (DataLoader): Yields ``(smiles_batch, drug_emb_batch)`` pairs.
        spectrum_embedding (torch.Tensor): Spectrum embedding of a single sample (batch dim removed).
        species_index (int): Value used for ``species_idx`` on every pair.

    Returns:
        tuple[list[str], list[float]]: Compound SMILES and predicted resistance probabilities,
        in ``compound_loader`` order.
    """
    smiles_all = []
    predictions = []
    # Pure inference: skip building autograd graphs
    with torch.no_grad():
        for smiles_batch, drug_emb_batch in compound_loader:
            batch_size = drug_emb_batch.size(0)

            # Repeat the single spectrum and species index across the compound batch
            spectrum_batch = spectrum_embedding.expand(batch_size, -1)
            species_idx = torch.full((batch_size,), species_index, dtype=torch.float32)

            outputs = model(
                species_idx=species_idx,
                x_spectrum=spectrum_batch,
                dr_tensor=drug_emb_batch
            )
            probabilities = torch.sigmoid(outputs)

            # Expect one logit per (sample, compound) pair
            if probabilities.numel() != batch_size:
                raise RuntimeError(
                    f"Expected {batch_size} outputs, got tensor of shape {tuple(probabilities.shape)}"
                )
            # reshape(-1) rather than squeeze(): a final batch of size 1 would squeeze to a 0-d
            # tensor whose .tolist() is a bare float, making list.extend raise TypeError.
            predictions.extend(probabilities.reshape(-1).cpu().tolist())
            # Take SMILES from the same batch so names and scores cannot drift apart
            smiles_all.extend(smiles_batch)
    return smiles_all, predictions


# Iterate through each test sample in the species data loader
for sample_id, spectrum_embedding in species_loader:
    sample_id_str = str(sample_id[0])  # batch_size=1: unwrap the collated ID

    # Skip samples that were excluded during pre-filtering or already processed
    if sample_id_str in excluded_samples:
        print(f"Sample {sample_id_str} has been excluded during pre-filtering. Skipping.")
        continue

    if sample_id_str in processed_samples:
        print(f"Sample {sample_id_str} already processed. Skipping.")
        continue

    # Remove the batch dimension added by the DataLoader
    smiles, predictions = predict_sample(model, compound_loader, spectrum_embedding.squeeze(0))

    predictions_df = pd.DataFrame({
        "SMILES": smiles,
        "Predictions": predictions,
    })

    # Save one CSV of predictions per sample
    output_file = f"klebsiella_resistance_predictions/{input_type}/predictions_sample_{sample_id_str}.csv"
    predictions_df.to_csv(output_file, index=False)

    print(f"Saved predictions for sample {sample_id_str} to {output_file}")

    # Mark the sample as processed, both on disk (for resumed runs) and in memory
    save_processed_sample(tracking_file, sample_id_str)
    processed_samples.add(sample_id_str)

print("Inference completed")


Saved predictions for sample 007893a6-6ff9-4fe9-9b4b-d65fb4f181ca_MALDI2 to klebsiella_resistance_predictions/mae_molformer/predictions_sample_007893a6-6ff9-4fe9-9b4b-d65fb4f181ca_MALDI2.csv
Saved predictions for sample 00b8b1c1-1a90-42df-90cb-63388b4ca1c0_MALDI2 to klebsiella_resistance_predictions/mae_molformer/predictions_sample_00b8b1c1-1a90-42df-90cb-63388b4ca1c0_MALDI2.csv
Sample 02ad184c-5b89-4a88-9713-57f22524868a_3313 has been excluded during pre-filtering. Skipping.
Sample 032d7989-c976-4c24-b809-48d80c4f1571_MALDI2 has been excluded during pre-filtering. Skipping.
Sample 03c5aaad-64e3-4f0b-9bbf-c16f5766cc01_3312 has been excluded during pre-filtering. Skipping.
Saved predictions for sample 04799d35-3076-4565-967c-d770b0f935e3 to klebsiella_resistance_predictions/mae_molformer/predictions_sample_04799d35-3076-4565-967c-d770b0f935e3.csv
Saved predictions for sample 04f11af7-e7a4-40df-9803-f1b246b2076e_3312 to klebsiella_resistance_predictions/mae_molformer/predictions_sample_0