## Fluorine Model

In [2]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer
import numpy, sys
import wandb
from pathlib import Path
import torch.optim as optim

from massspecgym.data import MassSpecDataset, MassSpecDataModule
from massspecgym.data.transforms import SpecTokenizer, MolFingerprinter
from massspecgym.models.base import Stage
from massspecgym.models.retrieval.base import MassSpecGymModel
from sklearn.metrics import precision_score, recall_score
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities import grad_norm

from torch import nn
import torch.nn.functional as F
from massspecgym.models.base import Stage
from dreams.api import PreTrainedModel
from dreams.models.dreams.dreams import DreaMS as DreaMSModel
from torchmetrics.classification import BinaryPrecision, BinaryRecall, BinaryAccuracy

import numpy as np

numpy.set_printoptions(threshold=sys.maxsize)
torch.set_float32_matmul_precision('high')


In [None]:
import numpy as np
from rdkit import Chem
from massspecgym.data.transforms import MolToHalogensVector, MolToPFASVector

# Example usage
checker = MolToHalogensVector() # creating an object of type MolToHalogensVector
smiles_string = "CC(F)(F)F"
halogen_vector = checker.from_smiles(smiles_string)
print(halogen_vector)
# Example usage
smiles_string = "CCBr"
halogen_vector = checker.from_smiles(smiles_string)
print(halogen_vector)

checker = MolToPFASVector()
smiles_string = "CC(F)(F)F"
halogen_vector = checker.from_smiles(smiles_string)
print(halogen_vector)

# Example usage
smiles_string = "CCBr"
halogen_vector = checker.from_smiles(smiles_string)
print(halogen_vector)

In [None]:
pl.seed_everything(0)

DEBUG = False

if DEBUG:
    mgf_pth = Path("/teamspace/studios/this_studio/MassSpecGym/data/debug/example_5_spectra.mgf")
    split_pth = Path("/teamspace/studios/this_studio/MassSpecGym/data/debug/example_5_spectra_split.tsv")
else:
    mgf_pth = None
    split_pth = None

# Check if MPS is available, otherwise use CUDA
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
else:
    mps_device = None

In [14]:
# final model containing the network definition

class HalogenDetectorDreamsTest(MassSpecGymModel):
    def __init__(
        self,
        alpha: float=0.25,
        gamma: float=0.5,
        batch_size: int=64,
        threshold: float=0.5,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        if mps_device is not None:
            self.alpha = torch.tensor([1-alpha, alpha], device=mps_device)
        else:
            self.alpha = torch.tensor([1-alpha, alpha]).cuda()
        self.gamma = gamma
        self.batch_size = batch_size
        self.threshold = threshold
        print(f"Training with threshold: {self.threshold}, alpha: {self.alpha}, gamma: {self.gamma}, batch_size: {self.batch_size}")
        
        # Metrics
        self.train_precision = BinaryPrecision()
        self.train_recall = BinaryRecall()
        self.val_precision = BinaryPrecision()
        self.val_recall = BinaryRecall()
        self.train_accuracy = BinaryAccuracy()
        self.val_accuracy = BinaryAccuracy()

        # loading the DreaMS model weights from the internet
        self.main_model = PreTrainedModel.from_ckpt(
            # ckpt_path should be replaced with the path to the ssl_model.ckpt model downloaded from https://zenodo.org/records/10997887
            ckpt_path="https://zenodo.org/records/10997887/files/ssl_model.ckpt?download=1", ckpt_cls=DreaMSModel, n_highest_peaks=60
        ).model.train()
        self.lin_out = nn.Linear(1024, 1) # for F

    def forward(self, x):
        output_main_model = self.main_model(x)[:, 0, :] # to get the precursor peak token embedding 
        fl_probability = F.sigmoid(self.lin_out(output_main_model))
        return fl_probability

    def step(
        self, batch: dict, stage: Stage
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Implement your custom logic of using predictions for training and inference."""
        # Unpack inputs
        x = batch["spec"]  # shape: [batch_size, num_peaks + 1, 2]
        #print("--batch.keys", batch.keys())

        halogen_vector_true = batch["mol"] # shape: [batch_size, 4]

        # Extract the 1st column --> fluorine predictions
        true_values = halogen_vector_true[:, 0] # shape [batch_size]

        # the forward pass
        predicted_probs = self.forward(x) # shape [batch_size x 1]
        
        if DEBUG:
            predicted_probs = predicted_probs[0] # for testing
        else:
            predicted_probs = predicted_probs.squeeze() # shape [batch_size]

        #print("--predicted_probs", predicted_probs)

        ### Focal Loss: https://amaarora.github.io/posts/2020-06-29-FocalLoss.html ### 
        # Increase loss for minority misclassification (F = 1 but predicted as 0) and 
        # decreases loss for majority class misclassification (F = 0 but predicted as 1)
        # Our MassSpecGym training data is skewed with only 5% of molecules containing Fluorine
       
        bce_loss = nn.BCELoss(reduction='none')
        loss = bce_loss(predicted_probs, true_values)
        targets = true_values.type(torch.long)
        at = self.alpha.gather(0, targets.data.view(-1))
        pt = torch.exp(-loss)
        F_loss = at * (1 - pt)**self.gamma * loss
        return { 'loss': F_loss.mean() } 

    def on_batch_end(
        self, outputs: [], batch: dict, batch_idx: int, stage: Stage
    ) -> None:
        x = batch["spec"] # shape: [batch_size, num_peaks + 1, 2]
        halogen_vector_true = batch["mol"] # shape [batch_size]
        # updated predictions with the updated weights at the end of the batch
        pred_probs = self.forward(x) # shape [batch_size x 1]

        # thresholding
        halogen_vector_pred_binary = torch.where(pred_probs >= self.threshold, 1, 0)

        # Extract the 1st column --> fluorine predictions
        true_labels = halogen_vector_true[:, 0] # shape [batch_size]
        
        # make shape [batch_size x 1] into shape [batch_size]
        pred_bool_labels = halogen_vector_pred_binary.squeeze() # shape [batch_size]

        if stage.to_pref() == 'train_':
            self.train_precision.update(pred_bool_labels, true_labels)
            self.train_recall.update(pred_bool_labels, true_labels)
            self.train_accuracy.update(pred_bool_labels, true_labels)
        elif stage.to_pref() == 'val_':
            self.val_precision.update(pred_bool_labels, true_labels)
            self.val_recall.update(pred_bool_labels, true_labels)
            self.val_accuracy.update(pred_bool_labels, true_labels)

        self.log_dict({ f"{stage.to_pref()}/loss": outputs['loss'] },
                prog_bar=True,
                on_epoch=True,
                batch_size=self.batch_size
        )

    def _reset_metrics_train(self):
        # Reset states for next epoch
        self.train_precision.reset()
        self.train_recall.reset()
        self.train_accuracy.reset()

    def _reset_metrics_val(self):
        # Reset states for next epoch
        self.val_precision.reset()
        self.val_recall.reset()
        self.val_accuracy.reset()
        self.all_predicted_probs = []  # reset the list of predicted probabilities for validation

    def on_train_epoch_start(self) -> None:
        self._reset_metrics_train()

    def on_validation_epoch_start(self) -> None:
        self._reset_metrics_val()

    def on_train_epoch_end(self) -> None:
        precision = self.train_precision.compute()
        recall = self.train_recall.compute()
        accuracy = self.train_accuracy.compute()
        f1_score = (2*precision*recall)/(precision + recall) if (precision + recall) != 0 else 0
        self.log_dict({
                f"train_/precision": precision,
                f"train_/recall": recall,
                f"train_/accuracy": accuracy,
                f"train_/f1_score": f1_score
            },
            prog_bar=True,
            on_epoch=True,
            on_step=False
        )
        
    def on_validation_epoch_end(self) -> None:
        precision = self.val_precision.compute()
        recall = self.val_recall.compute()
        accuracy = self.val_accuracy.compute()
        f1_score = (2*precision*recall)/(precision + recall) if (precision + recall) != 0 else 0
        self.log_dict({
                f"val_/precision": precision,
                f"val_/recall": recall,
                f"val_/accuracy": accuracy,
                f"val_/f1_score": f1_score
            },
            prog_bar=True,
            on_epoch=True,
            on_step=False
        )

In [15]:
# removed adduct due to a str error
class TestMassSpecDataset(MassSpecDataset):

    def __getitem__(
        self, i: int, transform_spec: bool = True, transform_mol: bool = True
    ) -> dict:
        spec = self.spectra[i]
        metadata = self.metadata.iloc[i]
        mol = metadata["smiles"]

        # Apply all transformations to the spectrum
        item = {}
        if transform_spec and self.spec_transform:
            if isinstance(self.spec_transform, dict):
                for key, transform in self.spec_transform.items():
                    item[key] = transform(spec) if transform is not None else spec
            else:
                item["spec"] = self.spec_transform(spec)
        else:
            item["spec"] = spec

        # Apply all transformations to the molecule
        if transform_mol and self.mol_transform:
            if isinstance(self.mol_transform, dict):
                for key, transform in self.mol_transform.items():
                    item[key] = transform(mol) if transform is not None else mol
            else:
                item["mol"] = self.mol_transform(mol)
        else:
            item["mol"] = mol

        # Add other metadata to the item
        item.update({
            k: metadata[k] for k in ["precursor_mz"] # removed adduct due to a str error
        })

        if self.return_mol_freq:
            item["mol_freq"] = metadata["mol_freq"]

        if self.return_identifier:
            item["identifier"] = metadata["identifier"]

        # TODO: this should be refactored
        for k, v in item.items():
            if not isinstance(v, str):
                item[k] = torch.as_tensor(v, dtype=self.dtype)
        
        return item

## Training Code

In [None]:
from pytorch_lightning.loggers import WandbLogger

torch.set_float32_matmul_precision('high')

# Init hyperparameters
max_epochs = 1
n_peaks = 60
threshold = 0.9
alpha = 0.75 # 0.25, 0.5, 0.75, 1 - found 0.25 as best
gamma = 0.75 # 0.25, 0.5, 0.75, 1 - found 0.75 as best
lr = 1e-5
num_iterations = 1

if DEBUG:
    batch_size = 1
else:
    batch_size = 64

for i in range (0, num_iterations):
    # Load dataset
    dataset = TestMassSpecDataset(
        spec_transform=SpecTokenizer(n_peaks=n_peaks),
        mol_transform = MolToPFASVector(),
        pth='/teamspace/studios/this_studio/files/merged_massspec_nist20_with_fold.tsv'
    )

    # Init data module
    data_module = MassSpecDataModule(
        dataset=dataset,
        batch_size=batch_size,
        split_pth=split_pth,
        num_workers=4
    )

    # Init model
    model = HalogenDetectorDreamsTest(
        threshold=threshold,
        alpha=alpha,
        gamma=gamma,
        batch_size=batch_size,
        lr=lr
    )
    # initialise the wandb logger and name your wandb project
    wandb_logger = WandbLogger(project='PFASDetection-MergedMassSpecNIST20-HyperParam')
    # add your batch size to the wandb config
    wandb_logger.experiment.config["batch_size"] = batch_size
    wandb_logger.experiment.config["n_peaks"] = n_peaks
    wandb_logger.experiment.config["threshold"] = threshold
    wandb_logger.experiment.config["alpha"] = alpha
    wandb_logger.experiment.config["gamma"] = gamma

    trainer = Trainer(accelerator="auto", devices="auto", max_epochs=max_epochs, logger=wandb_logger, val_check_interval=0.2)

    # Validate before training
    data_module.prepare_data()  # Explicit call needed for validate before fit
    data_module.setup()  # Explicit call needed for validate before fit
    trainer.validate(model, datamodule=data_module)

    # # Train
    trainer.fit(model, datamodule=data_module)

    # [optional] finish the wandb run, necessary in notebooks
    wandb.finish()

## Detecting Fluorine

In [16]:
from pathlib import Path
from tqdm import tqdm
from dreams.utils.data import MSData
from dreams.api import dreams_predictions, PreTrainedModel
from dreams.models.heads.heads import BinClassificationHead
from dreams.utils.io import append_to_stem


def find_fluorine():

    # in_pth = 'data/teo/<in_file>.mgf'  # or .mzML
    # out_csv_pth = 'data/teo/<in_file>_f_preds.csv'

    # in_pths = list(Path('/scratch/project/open-26-5/DreaMS/data/andrej/fluorine_dataset').glob('*.mzML'))
    # model_ckpt_111k = '/scratch/project/open-26-5/DreaMS/dreams/HAS_F_1.0/Feb2025_8bs_5e-5lr_bce/epoch=30-step=111000.ckpt'
    # model_ckpt_7k = '/scratch/project/open-26-5/DreaMS/dreams/HAS_F_1.0/Feb2025_8bs_5e-5lr_bce/epoch=1-step=7000.ckpt'

    in_pth = Path('/teamspace/studios/this_studio/20250627_pa_flo_nmr_pos_norm.mzML')
    # threshold = 0.9 model
    model_ckpt = '/teamspace/studios/this_studio/HalogenDetection-FocalLoss-MergedMassSpecNIST20/opi4lx8s/checkpoints/epoch=0-step=8920.ckpt'

    n_highest_peaks = 60

    # Load model
    model = HalogenDetectorDreamsTest.load_from_checkpoint(model_ckpt)
    print(model)

    print(f'Processing {in_pth}...')

    # Load data
    try:
        msdata = MSData.load(in_pth, in_mem=True)
    except ValueError as e:
        print(f'Skipping {in_pth} because of {e}.')
        return

    # Compute fluorine probabilties
    df = msdata.to_pandas()
    
    f_preds = dreams_predictions(
        spectra=msdata,
        model_ckpt=model,
        n_highest_peaks=n_highest_peaks
    )
    df[f'F_preds'] = f_preds

        # Store predictions
    df.to_csv(append_to_stem(in_pth, 'F_preds').with_suffix('.csv'), index=False)


In [None]:
find_fluorine()

## Merging NIST20 and MassSpecGym

In [None]:
import pandas as pd

# Replace with your actual file path
file_path = '/teamspace/studios/this_studio/MassSpecGym/NIST20_MoNA_A_all_with_F_Murcko_split_MCE_test_minimum_cols.pkl'

# Load the pickle file
nist20_df = pd.read_pickle(file_path)

# Check the result
nist20_df.info()

In [None]:
nist20_df.head(3)

In [None]:
# Filter rows where the 'ID' starts with "NIST20"
nist20_df = nist20_df[nist20_df['ID'].str.startswith("NIST20")].copy()

nist20_df.info()

In [77]:
from massspecgym.utils import load_massspecgym
massspec_df = load_massspecgym().reset_index()

In [None]:
massspec_df.head(1)

In [None]:
import pandas as pd

# -----------------------------
# STEP 1: Preprocess nist20_df
# -----------------------------
nist20_df = nist20_df.copy()

# Split 'PARSED PEAKS' into two columns
nist20_df['mzs'] = nist20_df['PARSED PEAKS'].apply(lambda x: x[0])
nist20_df['intensities'] = nist20_df['PARSED PEAKS'].apply(lambda x: x[1])

# Build a MassSpec-compatible DataFrame from NIST20
nist20_converted = pd.DataFrame({
    'identifier': nist20_df['ID'],
    'mzs': nist20_df['mzs'],
    'intensities': nist20_df['intensities'],
    'smiles': nist20_df['SMILES'],
    'inchikey': None,  # Not available in NIST20
    'formula': nist20_df['FORMULA'],
    'precursor_formula': nist20_df['FORMULA'],  # Assume it's the same
    'parent_mass': nist20_df['PRECURSOR M/Z'],  # Approximate
    'precursor_mz': nist20_df['PRECURSOR M/Z'],
    'adduct': '[M+H]+',
    'instrument_type': None,
    'collision_energy': None,
    'fold': nist20_df['fold'],
    'simulation_challenge': False  # NIST20 is real, not simulated
})

# -----------------------------
# STEP 2: Normalize MassSpec df
# -----------------------------
expected_columns = [
    'identifier', 'mzs', 'intensities', 'smiles', 'inchikey', 'formula', 'precursor_formula',
    'parent_mass', 'precursor_mz', 'adduct', 'instrument_type',
    'collision_energy', 'fold', 'simulation_challenge'
]

nist20_converted = nist20_converted[expected_columns]
massspec_gym_df = massspec_df.copy()
massspec_gym_df = massspec_gym_df[expected_columns]

# -----------------------------
# STEP 3: Merge the datasets
# -----------------------------
merged_df = pd.concat([massspec_gym_df, nist20_converted], ignore_index=True)

# -----------------------------
# STEP 4: Save merged dataset
# -----------------------------
# Save as TSV
merged_df.to_pickle('merged_massspec_nist20.pkl')

# Check result
print(f"Merged dataset shape: {merged_df.shape}")


## Murcko Histogram Based Training/Validation split

In [3]:
import pandas as pd

# Replace with your actual file path
file_path = '/teamspace/studios/this_studio/files/merged_massspec_nist20.pkl'

# Load the pickle file
df = pd.read_pickle(file_path)

In [None]:
# Check the result
df.tail(2)

In [2]:
# Load the necessary libraries
from rdkit import Chem
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
from dreams.algorithms.murcko_hist import murcko_hist
from dreams.utils.data import MSData, evaluate_split
from dreams.utils.plots import init_plotting
from dreams.definitions import *
tqdm.pandas()
%load_ext autoreload
%autoreload 2

In [None]:
hist = murcko_hist.murcko_hist(Chem.MolFromSmiles('O=C(O)[C@@H]1/N=C(\SC1)c2sc3cc(O)ccc3n2'), show_mol_scaffold=True)
print('Murcko histogram:', hist)

In [None]:
df_us = df.drop_duplicates(subset=[SMILES]).copy()  # Uniquify SMILES

# Compute Murcko histograms
df_us['MurckoHist'] = df_us[SMILES].progress_apply(
    lambda x: murcko_hist.murcko_hist(Chem.MolFromSmiles(x))
)

# Convert dictionaries to strings for easier handling
df_us['MurckoHistStr'] = df_us['MurckoHist'].astype(str)

In [None]:
print('Num. unique SMILES:', df_us[SMILES].nunique(), 'Num. unique Murcko histograms:', df_us['MurckoHistStr'].nunique())
print('Top 20 most common Murcko histograms:')
df_us['MurckoHistStr'].value_counts()[:20]

In [None]:
# Group by MurckoHistStr and aggregate
df_gb = df_us.groupby('MurckoHistStr').agg(
    count=(SMILES, 'count'),
    smiles_list=(SMILES, list)
).reset_index()

# Convert MurckoHistStr to MurckoHist
df_gb['MurckoHist'] = df_gb['MurckoHistStr'].apply(eval)

# Sort by 'n' in descending order and reset index
df_gb = df_gb.sort_values('count', ascending=False).reset_index(drop=True)

df_gb

In [None]:
median_i = len(df_gb) // 2
cum_val_mols = 0
val_mols_frac = 0.15  # Approximately 15% of the molecules go to validation set
val_idx, train_idx = [], []

# Iterate from median to start, assigning molecules to train or val sets
for i in range(median_i, -1, -1):
    current_hist = df_gb.iloc[i]['MurckoHist']
    is_val_subhist = any(
        murcko_hist.are_sub_hists(current_hist, df_gb.iloc[j]['MurckoHist'], k=3, d=4)
        for j in val_idx
    )

    if is_val_subhist:
        train_idx.append(i)
    else:
        if cum_val_mols / len(df_us) <= val_mols_frac:
            cum_val_mols += df_gb.iloc[i]['count']
            val_idx.append(i)
        else:
            train_idx.append(i)

# Add remaining indices to train set
train_idx.extend(range(median_i + 1, len(df_gb)))
assert(len(train_idx) + len(val_idx) == len(df_gb))

# Map SMILES to their assigned fold
smiles_to_fold = {}
for i, row in df_gb.iterrows():
    fold = 'val' if i in val_idx else 'train'
    for smiles in row['smiles_list']:
        smiles_to_fold[smiles] = fold
df[FOLD] = df[SMILES].map(smiles_to_fold)

# Display fold distributions
print('Distribution of spectra:')
display(df[FOLD].value_counts(normalize=True))
print('Distribution of SMILES:')
display(df.drop_duplicates(subset=[SMILES])[FOLD].value_counts(normalize=True))

In [None]:
eval_res = evaluate_split(df, n_workers=4)
init_plotting(figsize=(3, 3))
sns.histplot(eval_res['val'], bins=100)
plt.xlabel('Max Tanimoto similarity to training set')
plt.ylabel('Num. validation set molecules')
plt.show()

In [None]:
df.info()
print('Num. unique inchikey:', df['inchikey'].nunique())

In [None]:
df_t = df.groupby('inchikey').agg(
    count=(SMILES, 'count')
).reset_index()

df_t = df_t.sort_values(by='count', ascending=False).reset_index()
df_t

In [None]:
df.head(5)

In [58]:
def remove_zero_peaks(mzs, intensities):
    # Filter out zero values in either mz or intensity
    filtered = [(mz, inten) for mz, inten in zip(mzs, intensities) if mz != 0 and inten != 0]
    
    if not filtered:
        return [], []
    
    # Sort by mz
    filtered.sort(key=lambda x: x[0])
    
    mzs_clean, intensities_clean = zip(*filtered)
    return list(mzs_clean), list(intensities_clean)

# Apply to entire DataFrame
df[['mzs', 'intensities']] = df.apply(
    lambda row: pd.Series(remove_zero_peaks(row['mzs'], row['intensities'])),
    axis=1
)

In [None]:
df.head(5)

In [60]:
# convert mzs and intensities into a comma separated list for serializing to disk
df['mzs'] = df['mzs'].apply(lambda x: ','.join(map(str, x)))
df['intensities'] = df['intensities'].apply(lambda x: ','.join(map(str, x)))

In [None]:
df.head(1)

In [62]:
#df.to_csv('merged_massspec_nist20_with_fold.tsv', sep='\t')

In [None]:
df_t = pd.read_csv('/teamspace/studios/this_studio/files/merged_massspec_nist20_with_fold.tsv', sep='\t')
df_t.head(3)

## Detect PFAS in dataset

In [5]:
from rdkit import Chem

# SMARTS for –CF3 and –CF2– groups (saturated and fully fluorinated)
cf3_smarts = '[CX4](F)(F)F'       # –CF3
cf2_smarts = '[CX4H0](F)(F)'      # –CF2– (not terminal, excludes CF3)

cf3_pattern = Chem.MolFromSmarts(cf3_smarts)
cf2_pattern = Chem.MolFromSmarts(cf2_smarts)

def is_pfas_oecd(smiles):
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return False

        if mol.HasSubstructMatch(cf3_pattern) or mol.HasSubstructMatch(cf2_pattern):
            return True
        else:
            return False
    except Exception:
        return 'Error'



In [None]:
import pandas as pd
from rdkit import Chem

# Example: Load your dataframe
df = pd.read_csv('/teamspace/studios/this_studio/files/merged_massspec_nist20_with_fold.tsv', sep='\t')

# Apply the function
df['is_PFAS'] = df['smiles'].apply(is_pfas_oecd)

# View how many were identified
print(f"Identified {df['is_PFAS'].sum()} potential PFAS compounds out of {len(df)} total.")

# Optionally: get only the PFAS rows
pfas_df = df[df['is_PFAS']]

#df.to_csv('merged_massspec_nist20_with_pfas_fold.tsv', sep='\t')

In [None]:
from rdkit.Chem import Draw
import random as r
import pandas as pd

# read pfas dataset
pfas_df = pd.read_csv('/teamspace/studios/this_studio/files/pfas_only_records.tsv', sep='\t')
unique_pfas_train = pfas_df[pfas_df['fold'] == 'train']['smiles'].unique()
unique_pfas_val = pfas_df[pfas_df['fold'] == 'val']['smiles'].unique()
print(f"Train PFAS = {len(unique_pfas_train)}, Val PFAS = {len(unique_pfas_val)}")

print(f"Drawing a random molecule from train")
smiles_list = unique_pfas_train.tolist()
m = Chem.MolFromSmiles(r.choice(smiles_list))
img = Draw.MolToImage(m)
img

In [None]:
#pfas_df.to_csv('pfas_only_records.csv', sep='\t')
#only_val_df = pfas_df[pfas_df['fold'] == 'val']
#only_val_df.to_csv('pfas_only_records_val.tsv', sep='\t')

In [None]:
# Load dataset
pfas_dataset = TestMassSpecDataset(
    spec_transform=SpecTokenizer(n_peaks=60),
    mol_transform = MolToHalogensVector(),
    pth='/teamspace/studios/this_studio/pfas_only_records.tsv'
)

print(len(pfas_dataset))

# Init data module
pfas_data_module = MassSpecDataModule(
    dataset=pfas_dataset,
    batch_size=64,
    num_workers=1
)
pfas_data_module.setup()

ckpt_path = '/teamspace/studios/this_studio/HalogenDetection-FocalLoss-MergedMassSpecNIST20/opi4lx8s/checkpoints/epoch=0-step=8920.ckpt'
model = HalogenDetectorDreamsTest.load_from_checkpoint(ckpt_path)


trainer = Trainer(accelerator="auto", devices="auto", max_epochs=1)
trainer.validate(model=model, datamodule=pfas_data_module)

## Playground

In [6]:
import torch

# Fluorine Model
# threshold - 0.75
## /teamspace/studios/this_studio/HalogenDetection-FocalLoss-MergedMassSpecNIST20/6qcft1pp
# threshold - 0.9
## /teamspace/studios/this_studio/HalogenDetection-FocalLoss-MergedMassSpecNIST20/opi4lx8s/checkpoints/epoch=0-step=8920.ckpt
## opi4lx8s - threshold - 0.9


# PFAS Model
# threshold - 0.9
## /teamspace/studios/this_studio/PFASDetection-FocalLoss-MergedMassSpecNIST20/31hkfun1/checkpoints/epoch=0-step=8920.ckpt
## /teamspace/studios/this_studio/PFASDetection-FocalLoss-MergedMassSpecNIST20/kxfsf9c5/checkpoints/epoch=0-step=8920.ckpt
## /teamspace/studios/this_studio/PFASDetection-FocalLoss-MergedMassSpecNIST20/nrw1m4b9/checkpoints/epoch=0-step=8920.ckpt

# Path to your checkpoint file
ckpt_path = '/teamspace/studios/this_studio/HalogenDetection-FocalLoss-MergedMassSpecNIST20/opi4lx8s/checkpoints/epoch=0-step=8920.ckpt'

# Load the checkpoint
checkpoint = torch.load(ckpt_path, map_location='cpu')

# Print available metadata keys
print("Checkpoint keys:")
print(checkpoint.keys())

# Optionally, display specific metadata if available
if 'state_dict' in checkpoint:
    print(f"Model state_dict {checkpoint['hyper_parameters']}")


Checkpoint keys:
dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters'])
Model state_dict {'lr': 1e-05, 'weight_decay': 0.0, 'log_only_loss_at_stages': (<Stage.TRAIN: 'train'>,), 'no_mces_metrics_at_stages': (<Stage.VAL: 'val'>,), 'bootstrap_metrics': False, 'df_test_path': None, 'alpha': 0.25, 'gamma': 0.75, 'batch_size': 64, 'threshold': 0.9}


In [4]:
# read the file merged_massspec_nist20.pkl
# do EDA on nist20 vs masspec data and look at the number of unique inchikeys and smiles
import pandas as pd
df_unique = pd.read_pickle('/teamspace/studios/this_studio/files/merged_massspec_nist20.pkl')

In [None]:
df_unique.head(2)

In [16]:
df_massspec = df_unique[df_unique["identifier"].str.startswith("MassSpecGym")]
df_nist = df_unique[df_unique["identifier"].str.startswith("NIST20")]

In [None]:
num_inchi_ms = df_massspec["inchikey"].nunique()
num_inchi_nist = df_nist["inchikey"].nunique()
num_smil_ms = df_massspec["smiles"].nunique()
num_smil_nist = df_nist["smiles"].nunique()

print(f"num_inchi_ms = {num_inchi_ms}, num_inchi_nist = {num_inchi_nist}, num_smil_ms = {num_smil_ms}, num_smil_nist = {num_smil_nist}")

In [18]:
import numpy as np
import pandas as pd
import itertools
import urllib
import json
import time
import ase
import rdkit
import base64
from io import BytesIO
from tqdm import tqdm
from rdkit import DataStructs, RDLogger
from rdkit.Chem import AllChem as Chem
from rdkit.Chem import rdchem, Draw, rdMolDescriptors, QED, Crippen, Lipinski
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem.MACCSkeys import GenMACCSKeys
from rdkit.Contrib.SA_Score import sascorer
from rdkit.Chem.Descriptors import ExactMolWt
from collections import defaultdict
from typing import List, Optional
from pathlib import Path
import dreams.utils.misc as utils


def show_mols(mols, legends='new_indices', smiles_in=None, svg=False, sort_by_legend=False, max_mols=500,
              legend_float_decimals=4, mols_per_row=6, save_pth: Optional[Path] = None):
    """
    Returns svg image representing a grid of skeletal structures of the given molecules

    :param mols: list of rdkit molecules
    :param legends: list of labels for each molecule, length must be equal to the length of mols. 
                   Can be 'new_indices' for default numbering, 'masses' for molecular weights,
                   or a list of custom labels
    :param smiles_in: True - SMILES inputs, False - RDKit mols, None - determine automatically
    :param svg: True - return svg image, False - return png image
    :param sort_by_legend: True - sort molecules by legend values
    :param max_mols: maximum number of molecules to show
    :param legend_float_decimals: number of decimal places to show for float legends
    :param mols_per_row: number of molecules per row to show
    :param save_pth: path to save the .svg image to
    """
    disable_rdkit_log()

    if smiles_in is None:
        smiles_in = all(isinstance(e, str) for e in mols)

    if smiles_in:
        mols = [Chem.MolFromSmiles(e) for e in mols]

    if isinstance(legends, str):
        if legends == 'new_indices':
            legends = list(range(len(mols)))
        elif legends == 'masses':
            legends = [ExactMolWt(m) for m in mols]
    elif callable(legends):
        legends = [legends(e) for e in mols]
    elif isinstance(legends, (list, np.ndarray, pd.Series)):
        legends = [str(l) for l in legends]
    else:
        raise ValueError(f'Invalid legends type: {type(legends)}. Must be a list, numpy array, pandas series or'
                         '"new_indices" or "masses".')

    if sort_by_legend:
        idx = np.argsort(legends).tolist()
        legends = [legends[i] for i in idx]
        mols = [mols[i] for i in idx]

    legends = [f'{l:.{legend_float_decimals}f}' if isinstance(l, float) else str(l) for l in legends]

    img = Draw.MolsToGridImage(mols, maxMols=max_mols, legends=legends, molsPerRow=min(max_mols, mols_per_row),
                         useSVG=svg, returnPNG=False)

    if save_pth:
        with open(save_pth, 'w') as f:
            f.write(img.data)

    return img


def mol_to_formula(mol, as_dict=False):
    formula = rdMolDescriptors.CalcMolFormula(mol)
    return formula_to_dict(formula) if as_dict else formula


def smiles_to_formula(s, as_dict=False, invalid_mol_smiles=''):
    mol = Chem.MolFromSmiles(s)
    if not mol and invalid_mol_smiles is not None:
        f = invalid_mol_smiles
    else:
        f = rdMolDescriptors.CalcMolFormula(mol)
    if as_dict:
        f = formula_to_dict(f)
    return f


class MolPropertyCalculator:
    def __init__(self):
        # Estimates of min and max values from the training part of MoNA and NIST20 Murcko histograms split
        self.min_maxs = {
            'AtomicLogP': {'min': -13.054800000000025, 'max': 26.849200000000053},
            'NumHAcceptors': {'min': 0.0, 'max': 36.0},
            'NumHDonors': {'min': 0.0, 'max': 20.0},
            'PolarSurfaceArea': {'min': 0.0, 'max': 585.0300000000002},
            'NumRotatableBonds': {'min': 0.0, 'max': 68.0},
            'NumAromaticRings': {'min': 0.0, 'max': 8.0},
            'NumAliphaticRings': {'min': 0.0, 'max': 22.0},
            'FractionCSP3': {'min': 0.0, 'max': 1.0},
            'QED': {'min': 0.0, 'max': 1.0},  # 'QED': {'min': 0.008950206972239864, 'max': 0.9479380820623227},
            'SyntheticAccessibility': {'min': 1.0, 'max': 10.0},  # 'SyntheticAccessibility': {'min': 1.0549172379947862, 'max': 8.043981630210263},
            'BertzComplexity': {'min': 2.7548875021634682, 'max': 3748.669248605835}
        }
        self.prop_names = list(self.min_maxs.keys())

    def mol_to_props(self, mol, min_max_norm=False):
        props = {
            'AtomicLogP': Crippen.MolLogP(mol),
            'NumHAcceptors': Lipinski.NumHAcceptors(mol),
            'NumHDonors': Lipinski.NumHDonors(mol),
            'PolarSurfaceArea': rdMolDescriptors.CalcTPSA(mol),
            'NumRotatableBonds': Lipinski.NumRotatableBonds(mol),
            'NumAromaticRings': Lipinski.NumAromaticRings(mol),
            'NumAliphaticRings': Lipinski.NumAliphaticRings(mol),
            'FractionCSP3': Lipinski.FractionCSP3(mol),
            'QED': QED.qed(mol),
            'SyntheticAccessibility': sascorer.calculateScore(mol),
            'BertzComplexity': rdkit.Chem.GraphDescriptors.BertzCT(mol)
        }
        if min_max_norm:
            props = self.normalize_props(props)
        return props

    def normalize_prop(self, prop, prop_name):
        return (prop - self.min_maxs[prop_name]['min']) / (self.min_maxs[prop_name]['max'] - self.min_maxs[prop_name]['min'])

    def denormalize_prop(self, prop, prop_name, do_not_add_min=False):
        res = prop * (self.min_maxs[prop_name]['max'] - self.min_maxs[prop_name]['min'])
        if not do_not_add_min:
            res = res + self.min_maxs[prop_name]['min']
        return res

    def normalize_props(self, props):
        return {k: self.normalize_prop(v, k) for k, v in props.items()}

    def denormalize_props(self, props):
        return {k: self.denormalize_prop(v, k) for k, v in props.items()}

    def __len__(self):
        return len(self.prop_names)


def formula_to_dict(formula):
    """
    Transforms chemical formula string to dictionary mapping elements to their frequencies
    e.g. 'C15H24' -> {'C': 15, 'H': 24}
    """
    elem_count = defaultdict(int)
    #try:
    formula = formula.replace('+', '').replace('-', '').replace('[', '').replace(']', '')
    formula_counts = ase.formula.Formula(formula)
    formula_counts = formula_counts.count().items()
    for k, v in formula_counts:
        elem_count[k] += v
    #except Exception as e:
    #    print(f'Invalid formula: {formula} ({e.__class__.__name__})')

    return elem_count


def rdkit_fp(mol, fp_size=4096):
    """Default RDKit fingerprint."""
    return Chem.RDKFingerprint(mol, fpSize=fp_size)


def tanimoto_sim(fp1, fp2):
    """Default RDKit Tanimoto distance."""
    return DataStructs.TanimotoSimilarity(fp1, fp2)


def rdkit_mol_sim(m1, m2, fp_size=4096):
    """Default RDKit Tanimoto distance on default RDKit fingerprint."""
    return tanimoto_sim(rdkit_fp(m1, fp_size=fp_size), rdkit_fp(m2, fp_size=fp_size))


def rdkit_smiles_sim(s1, s2, fp_size=4096):
    """Default RDKit Tanimoto distance on default RDKit fingerprint."""
    return rdkit_mol_sim(Chem.MolFromSmiles(s1), Chem.MolFromSmiles(s2), fp_size=fp_size)


def morgan_fp(mol, binary=True, fp_size=4096, radius=2, as_numpy=True):
    if binary:
        fp = Chem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=fp_size)
    else:
        fp = Chem.GetHashedMorganFingerprint(mol, radius=radius, nBits=fp_size)

    if as_numpy:
        return rdkit_fp_to_np(fp)
    return fp


def maccs_fp(mol, as_numpy=True):
    """
    NOTE: Since indexing of MACCS keys starts from 1, when converting to numpy array with `as_numpy`, the first element
          is removed, so the resulting array has 166 elements instead of 167.
    """
    fp = GenMACCSKeys(mol)
    if as_numpy:
        return rdkit_fp_to_np(fp)[1:]
    return fp


def fp_func_from_str(s):
    """
    :param s: E.g. "fp_rdkit_2048", "fp_rdkit_2048" or "fp_maccs_166".
    """
    _, fp_type, n_bits = s.split('_')
    n_bits = int(n_bits)
    if fp_type == 'rdkit':
        return lambda mol: np.array(rdkit_fp(mol, fp_size=n_bits), dtype=float)
    elif fp_type == 'morgan':
        return lambda mol: morgan_fp(mol, fp_size=n_bits).astype(float, copy=False)
    elif fp_type == 'maccs':
        return lambda mol: maccs_fp(mol).astype(float, copy=False)
    else:
        raise ValueError(f'Invalid fingerprint function name: "{s}".')


def morgan_mol_sim(m1, m2, fp_size=4096, radius=2):
    return tanimoto_sim(
        morgan_fp(m1, fp_size=fp_size, radius=radius, as_numpy=False),
        morgan_fp(m2, fp_size=fp_size, radius=radius, as_numpy=False)
    )


def morgan_smiles_sim(s1, s2, fp_size=4096, radius=2):
    return morgan_mol_sim(Chem.MolFromSmiles(s1), Chem.MolFromSmiles(s2), fp_size=fp_size, radius=radius)


def rdkit_fp_to_np(fp):
    fp_np = np.zeros((0,), dtype=np.int32)
    DataStructs.ConvertToNumpyArray(fp, fp_np)
    return fp_np


def np_to_rdkit_fp(fp):
    fp = fp.round().astype(int, copy=False)
    bitstring = ''.join(fp.astype(str))
    return DataStructs.cDataStructs.CreateFromBitString(bitstring)


def mol_to_inchi14(mol: Chem.Mol):
    return Chem.MolToInchiKey(mol).split('-')[0]


def smiles_to_inchi14(s):
    return mol_to_inchi14(Chem.MolFromSmiles(s))


def generate_fragments(mol: Chem.Mol, max_cuts: int = None):
    """
    Generates all possible fragments of a molecule up to a certain number of bond cuts or without the restriction if
    `max_cuts` is not specified.

    :param mol: an RDKit molecule object
    :param max_cuts: the maximum number of bonds to cut
    :return a set of RDKit Mol objects representing all possible fragments
    """

    bonds = mol.GetBonds()
    # bonds = [bond for bond in bonds if bond.GetBondType() in [rdchem.BondType.SINGLE, rdchem.BondType.DOUBLE]]
    fragments = set()
    for i in range(1, len(bonds) + 1):

        if max_cuts and i > max_cuts:
            break

        for combination in itertools.combinations(bonds, i):
            new_mol = rdchem.RWMol(mol)
            for bond in combination:
                new_mol.RemoveBond(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())

            # Update properties such as ring membership after changing the molecule's structure.
            for fragment in Chem.GetMolFrags(new_mol, asMols=True, sanitizeFrags=False):
                fragments.add(Chem.MolToSmiles(fragment))

    fragments = [Chem.MolFromSmiles(f) for f in fragments]
    return [f for f in fragments if f is not None]


def generate_spectrum(mol: Chem.Mol, prec_mz: float = None, fragments: List = None, max_cuts: int = None):
    """
    Generates an MS/MS spectrum by exhaustively simulating the m/z values of theoretical fragments of a given molecule.
    The algorithm is very simplistic since it considers only subgraph-like fragments, does not consider isotopes, etc.

    :param mol: An RDKit molecule object.
    :param prec_mz: The m/z value of a molecule. If not specified, it is calculated as the sum of the
                    exact molecular weight of the molecule and 1.
    :param fragments: A list of RDKit Mol objects representing pre-generated fragments of the molecule. If not specified,
                     the function will generate the fragments automatically.
    :param max_cuts: The maximum number of bonds to cut when generating fragments. If not specified, all possible
                     fragments will be generated without any restriction on the number of cuts.
    :return: A spectrum represented as a numpy array with two columns: m/z values and their respective intensities.
    """

    # Simulate the m/z of "protonated adduct"
    if not prec_mz:
        prec_mz = ExactMolWt(mol) + 1

    # Fragment molecule
    if not fragments:
        fragments = generate_fragments(mol, max_cuts=max_cuts)

    # Simulate spectrum
    masses = np.round(np.array([prec_mz - ExactMolWt(f) for f in fragments]))
    ins, mzs = np.histogram(masses, bins=np.arange(0, np.ceil(max(masses)), 1))
    spec = np.stack([mzs[1:], ins]).T

    return spec


def closest_mz_frags(query_mz, frags, n=1, mass_shift=1, return_masses=False, print_masses=True):
    masses = [ExactMolWt(f) + mass_shift for f in frags]
    idx = utils.get_closest_values(masses, query_mz, n=n, return_idx=True)
    frags, masses = [frags[i] for i in idx], [masses[i] for i in idx]
    if n == 1:
        frags, masses = frags[0], masses[0]
    if print_masses:
        print(masses)
    if return_masses:
        return frags, masses
    return frags


def disable_rdkit_log():
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.CRITICAL)


def np_classify(smiles: List[str], progress_bar=True, sleep_each_n_requests=100):
    np_classes = []
    for i, s in enumerate(tqdm(smiles) if progress_bar else smiles):
        if i % sleep_each_n_requests == 0 and i > 0:
            time.sleep(1)
        print(s)
        with urllib.request.urlopen(f'https://npclassifier.ucsd.edu/classify?smiles={urllib.parse.quote(s)}') as url:
            res = json.load(url)
            for k in list(res.keys()):
                if 'fp' in k:
                    res.pop(k)
            np_classes.append(res)
    return np_classes


def mol_to_img_str(mol, svg_size=200):
    """
    Supposed to be used with `pyvis` for showing molecule images as graph nodes.
    """
    buffered = BytesIO()
    d2d = rdMolDraw2D.MolDraw2DSVG(svg_size, svg_size)
    opts = d2d.drawOptions()
    opts.clearBackground = False
    d2d.DrawMolecule(mol)
    d2d.FinishDrawing()
    img_str = d2d.GetDrawingText()
    buffered.write(str.encode(img_str))
    img_str = base64.b64encode(buffered.getvalue())
    img_str = f"data:image/svg+xml;base64,{repr(img_str)[2:-1]}"
    return img_str


def formula_is_carbohydrate(formula):
    return set(formula.keys()) <= {'C', 'H', 'O'}


def formula_is_halogenated(formula):
    return sum([(formula[e] if e in formula else 0) for e in ['F', 'Cl', 'Br', 'I']]) > 0


def formula_type(f):
    if isinstance(f, str):
        f = formula_to_dict(f)

    if not f:
        return 'No formula'
    elif formula_is_carbohydrate(f):
        return 'Carbohydrate'
    elif set(f.keys()) <= {'C', 'H', 'O', 'N'}:
        return 'Carbohydrate with nitrogen'
    elif set(f.keys()) <= {'C', 'H', 'O', 'N', 'S'} and 'N' in f and 'S' in f:
        return 'Carbohydrate with nitrogen and sulfur'
    elif formula_is_halogenated(f):
        return 'Compound with halogens'
    else:
        return 'Other'


def get_mol_mass(mol):
    return ExactMolWt(mol)

In [None]:
#df_massspec = df_unique[df_unique["identifier"].str.startswith("MassSpecGym")]
#df_nist = df_unique[df_unique["identifier"].str.startswith("NIST20")]
#df_unique = pd.read_pickle('merged_massspec_nist20.pkl')

df_nist['inchikey'] = df_nist['smiles'].apply(smiles_to_inchi14)

In [None]:
num_inchi_ms = df_massspec["inchikey"].nunique()
num_inchi_nist = df_nist["inchikey"].nunique()
print("NIST unique # inchikeys: " + str(num_inchi_nist))
print("MassSpecGym unique # inchikeys: " + str(num_inchi_ms))

In [3]:
# Overlap our PFAS training and PFAS suspect list from data.gov
import pandas as pd

# Load both TSV files
df_records = pd.read_csv("/teamspace/studios/this_studio/files/pfas_only_records.tsv", sep='\t')
df_suspects = pd.read_csv("/teamspace/studios/this_studio/files/PFAS_suspect_list_data_gov.tsv", sep='\t')

# Preview column names
print("Records columns:", df_records.columns.tolist())
print("Suspects columns:", df_suspects.columns.tolist())

# Standardize column names
smiles_records = df_records[df_records['fold'] == 'train']['smiles'].dropna().str.strip().unique()
smiles_suspects = df_suspects['SMILES'].dropna().str.strip().unique()

# Convert to sets for comparison
set_records = set(smiles_records)
set_suspects = set(smiles_suspects)

# Find overlap
overlap = set_suspects.intersection(set_records)

# Report results
print(f"Total in PFAS_Suspect_List: {len(set_suspects)}")
print(f"Total in pfas_only_records: {len(set_records)}")
print(f"Overlapping SMILES: {len(overlap)}")

for smile in sorted(overlap):
        print(smile)

Records columns: ['Unnamed: 0.1', 'Unnamed: 0', 'identifier', 'mzs', 'intensities', 'smiles', 'inchikey', 'formula', 'precursor_formula', 'parent_mass', 'precursor_mz', 'adduct', 'instrument_type', 'collision_energy', 'fold', 'simulation_challenge', 'is_PFAS']
Suspects columns: ['SUSPECTID', 'CHEMICAL_NAME', 'INCHI', 'SMILES', 'INCHIKEY', 'FIXEDINCHI', 'FORMULA', 'FIXEDMASS', 'NETCHARGE', 'LOCAL_POSITIVE', 'LOCAL_NEGATIVE', 'DOI', 'CITATION_TYPE', 'ADDITIONAL', 'INSPECTEDBY']
Total in PFAS_Suspect_List: 4964
Total in pfas_only_records: 1469
Overlapping SMILES: 4
C(COP(=O)(O)O)C(C(C(C(C(C(C(C(F)(F)F)(F)F)(F)F)(F)F)(F)F)(F)F)(F)F)(F)F
C(COP(=O)(O)O)C(C(C(C(C(C(F)(F)F)(F)F)(F)F)(F)F)(F)F)(F)F
C(COP(=O)(O)OCCC(C(C(C(C(C(C(C(F)(F)F)(F)F)(F)F)(F)F)(F)F)(F)F)(F)F)(F)F)C(C(C(C(C(C(C(C(F)(F)F)(F)F)(F)F)(F)F)(F)F)(F)F)(F)F)(F)F
C(COP(=O)(O)OCCC(C(C(C(C(C(F)(F)F)(F)F)(F)F)(F)F)(F)F)(F)F)C(C(C(C(C(C(F)(F)F)(F)F)(F)F)(F)F)(F)F)(F)F
