In [None]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader
import sys


In [None]:


from utils.sequence import uniprot2sequence, encode_sequences
from utils.chem import *
from utils.parallel import *
from utils.sequence import encode_sequences
from utils.chem import get_mols, get_fp  

In [None]:
import pandas as pd
import numpy as np
import pickle
from tqdm.notebook import tqdm
from sklearn.metrics import roc_auc_score, average_precision_score
from rdkit import Chem
from xgboost import XGBClassifier
from base_model import BaseModel
from preprocessor import Preprocessor
from barlow_twins import BarlowTwins
from rdkit.Chem import AllChem
from sklearn.svm import SVC
import seaborn as sns
import matplotlib.pyplot as plt

from rdkit import Chem
from rdkit.Chem import Draw
from IPython.display import display
from PIL import Image as PILImage
import io, os, math, random

### Model Run

In [None]:
# --- Inputs ---------------------------------------------------------------
test_path = "all_train_val_tb_papyrus.csv"
barlow_model_path = "Papyrus_tb"
# Trained XGBoost classifier that the SHAP section explains.
# NOTE(review): absolute, machine-specific path taken from the previously
# commented-out load call -- point it at your local copy of the model.
xgb_model_path = "/home/resperanca/Tuberculosis_Tese/Source/Models/BarlowDTI/model/xxl_stash/Papyrus_tb_barlowdti_xxl_model_tb_pap.json"

# --- Barlow Twins encoder -------------------------------------------------
bt_model = BarlowTwins()
bt_model.load_model(barlow_model_path)

test_df = pd.read_csv(test_path)

# --- ECFPs and sequence embeddings ----------------------------------------
test_mols = [Chem.MolFromSmiles(smi) for smi in test_df["smiles"]]

# MolFromSmiles returns None for unparseable SMILES; fail with a readable
# message instead of an opaque error inside the fingerprint call.
invalid_smiles_idx = [i for i, m in enumerate(test_mols) if m is None]
if invalid_smiles_idx:
    raise ValueError(
        f"{len(invalid_smiles_idx)} SMILES in {test_path!r} could not be parsed "
        f"(first row positions: {invalid_smiles_idx[:10]})"
    )

test_ecfp = [AllChem.GetMorganFingerprintAsBitVect(m, 2, nBits=1024) for m in test_mols]
test_ecfp = np.array(test_ecfp)

test_emb = encode_sequences(test_df["sequence"].tolist(), encoder="prost_t5")
test_emb = np.array([np.array(x) for x in test_emb])

# --- Combined molecule/protein feature vectors ------------------------------
test_vectors = bt_model.zero_shot(test_ecfp, test_emb)
true_labels = test_df["label"].values

# --- Classifier -------------------------------------------------------------
# The SHAP cell below needs `model`; without this load a fresh
# "Restart & Run All" stops with a NameError.
model = XGBClassifier()
model.load_model(xgb_model_path)




## SHAP (top 20 bits)

In [None]:

import shap

# Fit an explainer on the classifier, using the feature vectors both as
# background data and as the samples to explain.
explainer = shap.Explainer(model, test_vectors)
shap_values = explainer(test_vectors)

# NOTE(review): the first 1024 feature columns are treated as the ECFP bits.
# Confirm that bt_model.zero_shot() keeps that column layout.
ecfp_bits_only = test_vectors[:, :1024]
shap_ecfp_values = shap_values.values[:, :1024]

# Rank those columns by mean absolute SHAP value and keep the 20 largest.
mean_abs_shap = np.mean(np.abs(shap_ecfp_values), axis=0)
ranking_desc = np.argsort(mean_abs_shap)[::-1]
top_bits_idx = ranking_desc[:20]

for i in top_bits_idx:
    print(f"Bit {i}: SHAP value média = {mean_abs_shap[i]:.4f}")


Top 50 molecules with the most top-SHAP bits set

In [None]:
important_bits = top_bits_idx

# One row per molecule, one column per important bit (integer value of the
# fingerprint at that bit).
bit_presence_matrix = np.array(
    [[int(fp[i]) for i in important_bits] for fp in test_ecfp]
)

# Per-molecule count of important bits, then the 50 highest-scoring molecules
# in descending order.
bit_activation_score = bit_presence_matrix.sum(axis=1)
top_mol_indices = np.argsort(bit_activation_score)[::-1][:50]

# Detailed listing: position, SMILES, label and score.
for idx in top_mol_indices:
    record = test_df.iloc[idx]
    print(f"idx {idx}, SMILES: {record['smiles']}, activate: {record['label']}, Score: {bit_activation_score[idx]}")


# Plain SMILES listing (one per line) for copy/paste.
for idx in top_mol_indices:
    print(test_df.iloc[idx]['smiles'])


SHAP vs. results of the filter (bit presence)

In [None]:
# Candidate SMILES sets to check for the important ECFP bits.
smiles_list_14 = [
    'CNC(=O)C(Cc1ccccc1)NC(=O)C(CC(C)C)C1(C(=O)NO)CCCC1',
    'CNC(=O)C(Cc1ccccc1)NC(=O)C(CC(C)C)C1(C(=O)NO)CCCC1',
    'COc1cc(C#N)cc(-c2csc(C34CN(c5ncc(F)cn5)CC3C(=O)N(C)C(N)=N4)c2)c1',
    'Cc1c(CC(=O)O)cc2ccc(F)cc2c1-c1ccc(S(=O)(=O)N2CCCC2)cc1',
    'O=C(NCC1CN(c2ccc(N3CCCC(O)C3)c(F)c2)C(=O)O1)c1ccc(Cl)s1',
    'CCc1coc(NC(=O)c2cc3c(nc2OC)nc(C(O)(c2ccccc2)c2ccccc2)n3CC)n1',
    'CNC(=O)C1(c2cccc(OCc3cc(C)nc4ccccc34)c2)CC1C(=O)NO',
    'Cn1cnc(S(=O)(=O)Nc2cccc(-c3ccc(C(=O)c4cc(F)c(F)c(O)c4F)s3)c2)c1',
    'Nc1cncc(C2=NN(C(=O)c3ccc(-c4ccccn4)s3)C(c3ccccc3O)C2)c1',
    'O=C(NCC1CN(c2ccc(N3CCC(O)C3)c(F)c2)C(=O)O1)c1ccc(Cl)s1',
    'O=C(NCC1CN(c2ccc(N3CCC(O)C3)c(F)c2)C(=O)O1)c1ccc(Cl)s1',
    'O=C(NCC1CN(c2ccc(N3CCC(O)C3)c(F)c2)C(=O)O1)c1ccc(Cl)s1',
    'NC(=O)c1cc(-c2ccc(Cl)c(Cl)c2)cc2c1[nH]c1ccc(C(=O)N3CCOC(CO)C3)cc12',
    'N#Cc1ccc(NC(=O)c2ccc(N3CCCCC3=O)cc2)c(C(=O)Nc2ccc(Cl)cn2)c1',
]

smiles_list_1 = ['CCCS(=O)(=O)C1CCN(c2ccn3cc(-c4cc(Cl)c(OC)cc4OC)nc3n2)CC1']

# CSV with a "bit" and a "shap_value_mean" column; set to a falsy value to
# reuse the `important_bits` already defined in the kernel.
bits_csv = 'shap_top_ecfp_bits.csv'
topN = 20

# Default: no per-bit weights. Replaced below when the CSV is read.
shap_means = pd.Series(dtype=float)
if bits_csv:
    df_bits = pd.read_csv(bits_csv)
    # NOTE(review): the first topN rows are taken as-is -- this presumably
    # relies on the CSV already being ordered by importance; confirm.
    important_bits = df_bits["bit"].astype(int).tolist()[:topN]
    shap_means = df_bits.set_index("bit")["shap_value_mean"]

important_bits = [int(b) for b in important_bits]
important_bits_sorted = sorted(important_bits)

def ecfp_bits_presence(smi, bits, radius=2, nBits=1024):
    """Report which of the requested Morgan-fingerprint bits a molecule sets.

    Args:
        smi: SMILES string of the molecule.
        bits: iterable of bit indices to look up.
        radius: Morgan fingerprint radius.
        nBits: length of the folded bit vector.

    Returns:
        ``(mol, presence)`` where ``mol`` is the parsed RDKit molecule and
        ``presence`` maps each requested bit to ``int(fp[bit])``. When the
        SMILES cannot be parsed the result is ``(None, {})``.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return None, {}

    fingerprint = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=nBits)
    presence = {}
    for bit in bits:
        presence[bit] = int(fingerprint[bit])
    return mol, presence

# Build one report row per input SMILES.
rows = []
for smi in smiles_list_1:
    mol, presence = ecfp_bits_presence(smi, important_bits_sorted)

    if mol is None:
        # Unparseable SMILES: keep a placeholder row filled with NaN.
        record = {
            "smiles": smi,
            "valid_smiles": 0,
            "n_bits_present": np.nan,
            "bits_present": "",
            "weighted_score": np.nan,
        }
        for b in important_bits_sorted:
            record[f"bit_{b}"] = np.nan
    else:
        bits_present_list = [b for b in presence if presence[b] == 1]

        # Weighted score = sum of the per-bit SHAP means of the bits that are
        # set; bits missing from the table contribute 0.
        if shap_means.empty:
            weighted = np.nan
        else:
            weighted = float(sum(shap_means.get(b, 0.0) for b in bits_present_list))

        record = {
            "smiles": smi,
            "valid_smiles": 1,
            "n_bits_present": sum(presence.values()),
            "bits_present": ",".join(str(b) for b in bits_present_list),
            "weighted_score": weighted,
        }
        for b in important_bits_sorted:
            record[f"bit_{b}"] = presence[b]

    rows.append(record)

df = pd.DataFrame(rows)

# Rank by weighted score when at least one is available, otherwise by the
# plain count of bits present. Missing values go last in both cases.
has_weights = "weighted_score" in df.columns and df["weighted_score"].notna().any()
sort_keys = ["weighted_score", "n_bits_present"] if has_weights else "n_bits_present"
df = df.sort_values(sort_keys, ascending=False, na_position="last")

summary_cols = ["smiles", "valid_smiles", "n_bits_present", "bits_present", "weighted_score"]
print(df[summary_cols].to_string(index=False))
df.to_csv("bits_presence_report.csv", index=False)


Generate representations of the molecules in question with the present bits highlighted

In [None]:
# Molecules (SMILES) to render in the bit/fragment panels below.
smiles_list = [
    "NC(=O)c1cc(-c2ccc(Cl)c(Cl)c2)cc2c1[nH]c1ccc(C(=O)N3CCOC(CO)C3)cc12",
    "Cn1cnc(S(=O)(=O)Nc2cccc(-c3ccc(C(=O)c4cc(F)c(F)c(O)c4F)s3)c2)c1",
    "Nc1cncc(C2=NN(C(=O)c3ccc(-c4ccccn4)s3)C(c3ccccc3O)C2)c1",
    "CCCS(=O)(=O)C1CCN(c2ccn3cc(-c4cc(Cl)c(OC)cc4OC)nc3n2)CC1",
]
# Fingerprint bit indices to highlight.
# NOTE(review): presumably the top-20 ECFP bits from the SHAP ranking,
# hard-coded here -- confirm against the SHAP output before reuse.
top_bits = [49,472,178,123,100,0,137,223,39,88,156,91,106,72,71,160,218,61,6,58]

# Directory passed as `save_dir` to show_bits_mol_and_fragment.
save_dir = "bits_mol_e_fragmento"   
# Passed as `molsPerRow` (panels per grid row).
mols_per_row = 3                    
# Passed as `subImgSize` (size of each single image tile).
subimg_size = (240, 240)            
# Morgan fingerprint settings forwarded to the drawing routine.
radius = 2
nBits = 1024
# Forwarded as `max_smiles_len` to show_bits_mol_and_fragment.
max_smiles_len = 40

def _short(s, k=40):
    return s if len(s) <= k else s[:k-3] + "..."

def get_bit_envs(mol, bit, radius=2, nBits=1024):
    """Collect the atom environments that set one Morgan-fingerprint bit.

    Args:
        mol: RDKit molecule.
        bit: index of the fingerprint bit of interest.
        radius: Morgan fingerprint radius.
        nBits: length of the folded bit vector.

    Returns:
        A list with one dict per (centre atom, radius) entry recorded for
        ``bit``; the list is empty when the molecule does not set the bit.
        Each dict has the keys ``"atoms"`` (sorted atom indices), ``"bonds"``
        (bond indices), ``"center"`` and ``"rad"``.
    """
    bit_info = {}
    # Only the bitInfo side product is needed; the fingerprint is discarded.
    AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=nBits, bitInfo=bit_info)

    envs = []
    for atom_idx, rad in bit_info.get(bit, ()):
        bond_ids = Chem.FindAtomEnvironmentOfRadiusN(mol, rad, atom_idx)

        # With no bonds in the environment, the centre atom is the only member.
        atom_ids = set() if bond_ids else {atom_idx}
        for b in bond_ids:
            bond = mol.GetBondWithIdx(b)
            atom_ids.update((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))

        envs.append({
            "atoms": sorted(atom_ids),
            "bonds": list(bond_ids),
            "center": atom_idx,
            "rad": rad,
        })
    return envs

def submol_from_env(mol, env):
    """Build a standalone molecule for one fingerprint-bit environment.

    Args:
        mol: parent RDKit molecule.
        env: dict with ``"bonds"`` (bond indices) and ``"atoms"`` (atom
            indices), as produced by ``get_bit_envs``.

    Returns:
        An RDKit molecule containing the environment, with 2D coordinates
        computed when it has no conformer yet.
    """
    bonds = env["bonds"]
    if bonds:
        sub = Chem.PathToSubmol(mol, bonds)
    else:
        # No bonds in the environment: PathToSubmol only takes a bond path,
        # so copy the atoms on their own (element only, no bonds).
        em = Chem.EditableMol(Chem.Mol())
        for aidx in env["atoms"]:
            em.AddAtom(Chem.Atom(mol.GetAtomWithIdx(aidx).GetAtomicNum()))
        sub = em.GetMol()

    # Drawing needs coordinates; generate them only if none exist.
    if sub.GetNumConformers() == 0:
        AllChem.Compute2DCoords(sub)
    return sub

def mol_to_image(mol, legend, highlightAtoms=None, highlightBonds=None, size=(240,240), color=(0.98,0.25,0.25)):
    """Render a molecule as an image with optional atom/bond highlighting.

    Args:
        mol: RDKit molecule to draw.
        legend: caption passed to the drawer.
        highlightAtoms: atom indices to highlight (``None`` means none).
        highlightBonds: bond indices to highlight (``None`` means none).
        size: ``(width, height)`` of the image.
        color: highlight colour passed as ``highlightColor``.

    Returns:
        The object returned by ``Draw.MolToImage``.
    """
    draw_kwargs = dict(
        size=size,
        legend=legend,
        highlightAtoms=highlightAtoms or [],
        highlightBonds=highlightBonds or [],
    )
    try:
        return Draw.MolToImage(mol, highlightColor=color, **draw_kwargs)
    except TypeError:
        # Retry without the colour argument if the drawer rejects it.
        return Draw.MolToImage(mol, **draw_kwargs)

def hpair(left_img, right_img, tile_size=(240,240), gap=8, bg=(255,255,255)):
    """Place two images side by side on one canvas.

    Args:
        left_img: image for the left tile.
        right_img: image for the right tile.
        tile_size: ``(width, height)`` each tile is resized to if it differs.
        gap: horizontal space between the tiles.
        bg: background colour of the canvas.

    Returns:
        A new RGB image of size ``(2 * width + gap, height)``.
    """
    w, h = tile_size
    canvas = PILImage.new("RGB", (w*2 + gap, h), bg)
    for tile, x_offset in ((left_img, 0), (right_img, w + gap)):
        if tile.size != tile_size:
            tile = tile.resize(tile_size, resample=PILImage.BICUBIC)
        canvas.paste(tile, (x_offset, 0))
    return canvas

def make_grid(panels, molsPerRow=3, tile_size=(240,240), gap=8, bg=(255,255,255)):
    """Arrange side-by-side panels into a grid image, filled row by row.

    Args:
        panels: images of size ``(2 * tile_w + gap, tile_h)``; any panel with
            a different size is resized to that size.
        molsPerRow: number of panels per grid row.
        tile_size: ``(tile_w, tile_h)`` of a single tile inside a panel.
        gap: gap between the two tiles of a panel.
        bg: background colour of the grid.

    Returns:
        A new RGB image holding all panels.

    Raises:
        ValueError: if ``panels`` is empty.
    """
    if not panels:
        raise ValueError("empty.")

    tile_w, tile_h = tile_size
    panel_w, panel_h = tile_w*2 + gap, tile_h
    n_rows = math.ceil(len(panels) / molsPerRow)

    grid = PILImage.new("RGB", (molsPerRow * panel_w, n_rows * panel_h), bg)
    for position, panel in enumerate(panels):
        if panel.size != (panel_w, panel_h):
            panel = panel.resize((panel_w, panel_h), resample=PILImage.BICUBIC)
        row, col = divmod(position, molsPerRow)
        grid.paste(panel, (col * panel_w, row * panel_h))
    return grid


# MAIN

def show_bits_mol_and_fragment(smiles_list, bits, save_dir=None,
                               molsPerRow=3, subImgSize=(240,240),
                               radius=2, nBits=1024, max_smiles_len=40, pair_gap=8):
    """Display, per molecule, every environment that sets one of ``bits``.

    For each molecule one grid is built. Each grid panel shows the whole
    molecule with the environment highlighted (left) next to the extracted
    fragment (right). The grid is displayed once per molecule and, when
    ``save_dir`` is given, saved as ``mol_<i>_mol_e_fragments.png``.

    Args:
        smiles_list: SMILES strings to render; unparseable ones are reported
            and skipped.
        bits: fingerprint bit indices to look for.
        save_dir: output directory (created if missing); ``None`` disables
            saving.
        molsPerRow: number of panels per grid row.
        subImgSize: ``(width, height)`` of each single image tile.
        radius: Morgan fingerprint radius.
        nBits: length of the folded bit vector.
        max_smiles_len: currently unused; kept so existing callers keep working.
        pair_gap: gap between the two tiles of a panel.
    """
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Fixed seed so every bit gets the same colour on every run.
    rng = random.Random(42)
    bit_colors = {}

    def color_for(bit):
        if bit not in bit_colors:
            bit_colors[bit] = (0.9*rng.random(), 0.5+0.5*rng.random(), 0.4+0.6*rng.random())
        return bit_colors[bit]

    for i, smi in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            print(f"[Mol {i}] SMILES invalid: {smi}")
            continue

        if mol.GetNumConformers() == 0:
            AllChem.Compute2DCoords(mol)

        pairs = []
        for b in bits:
            envs = get_bit_envs(mol, b, radius=radius, nBits=nBits)
            # Requested for every bit (even absent ones) so the colour
            # assignment does not depend on which molecule is drawn.
            col = color_for(b)
            for k, env in enumerate(envs, 1):
                # Left tile: full molecule with the environment highlighted.
                mol_img = mol_to_image(
                    mol, f"Mol {i} | Bit {b} (oc {k})",
                    highlightAtoms=env["atoms"],
                    highlightBonds=env["bonds"],
                    size=subImgSize,
                    color=col
                )

                # Right tile: the fragment alone, fully highlighted.
                frag = submol_from_env(mol, env)
                frag_img = mol_to_image(
                    frag, f"Fragmento bit {b} (oc {k})",
                    highlightAtoms=list(range(frag.GetNumAtoms())),
                    size=subImgSize,
                    color=col
                )

                pairs.append(hpair(mol_img, frag_img, tile_size=subImgSize, gap=pair_gap))

        # Build, show and save the grid once per molecule, after all bits
        # have been processed.
        if not pairs:
            print(f"[Mol {i}] empty.")
            continue

        grid = make_grid(pairs, molsPerRow=molsPerRow, tile_size=subImgSize, gap=pair_gap)
        display(grid)

        if save_dir:
            out_path = os.path.join(save_dir, f"mol_{i}_mol_e_fragments.png")
            grid.save(out_path)
                



# RUN

# Render every configured molecule against the configured bits.
show_bits_mol_and_fragment(
    smiles_list=smiles_list,
    bits=top_bits,
    save_dir=save_dir,
    molsPerRow=mols_per_row,
    subImgSize=subimg_size,
    radius=radius,
    nBits=nBits,
    max_smiles_len=max_smiles_len,
    pair_gap=8,
)
