In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
from tqdm import tqdm
import numpy as np
import torch
import h5py
import pandas as pd
import rdkit.Chem as Chem
from rdkit.Chem import Draw
import io
from PIL import Image
import hickle as hkl

In [None]:
# Two directories above the current working directory.
DELQSAR_ROOT = os.getcwd() + '/../../'
# Extend the module search path with the directory above DELQSAR_ROOT
# (presumably so the `del_qsar` import below resolves -- confirm layout).
sys.path.append(DELQSAR_ROOT + '/../')

from del_qsar import models, featurizers

# Create the output directory on first run; pathify() resolves file names into it.
_OUTPUT_DIR = 'single_substructure_analysis'
if not os.path.isdir(_OUTPUT_DIR):
    os.mkdir(_OUTPUT_DIR)
def pathify(fname):
    """Return the path of ``fname`` inside the ``single_substructure_analysis`` directory."""
    out_dir = 'single_substructure_analysis'
    return os.path.join(out_dir, fname)

def save_png(data, out_path):
    """Open the in-memory image bytes ``data`` with PIL and save the image to ``out_path`` as PNG."""
    Image.open(io.BytesIO(data)).save(out_path, 'PNG')

In [None]:
# Fingerprint files (HDF5).
DD1S_FINGERPRINTS_FILENAME = 'x_DD1S_CAIX_2048_bits_all_fps.h5' # should be in experiments folder
triazine_FINGERPRINTS_FILENAME = 'x_triazine_2048_bits_all_fps.h5' # should be in experiments folder

# Saved FP-FFNN weights are addressed as
# <DELQSAR_ROOT>/experiments/models/<dataset>/FP-FFNN/random_seed_<seed>.torch
_MODELS_DIR = os.path.join(DELQSAR_ROOT, 'experiments', 'models')
DD1S_CAIX_RANDOM_SPLIT_FP_FFNN_SEED_0_MODEL_PATH = os.path.join(
    _MODELS_DIR, 'DD1S_CAIX', 'FP-FFNN', 'random_seed_0.torch')
DD1S_CAIX_RANDOM_SPLIT_FP_FFNN_SEED_1_MODEL_PATH = os.path.join(
    _MODELS_DIR, 'DD1S_CAIX', 'FP-FFNN', 'random_seed_1.torch')
DD1S_CAIX_RANDOM_SPLIT_FP_FFNN_SEED_2_MODEL_PATH = os.path.join(
    _MODELS_DIR, 'DD1S_CAIX', 'FP-FFNN', 'random_seed_2.torch')
triazine_sEH_RANDOM_SPLIT_FP_FFNN_SEED_0_MODEL_PATH = os.path.join(
    _MODELS_DIR, 'triazine_sEH', 'FP-FFNN', 'random_seed_0.torch')
triazine_SIRT2_RANDOM_SPLIT_FP_FFNN_SEED_0_MODEL_PATH = os.path.join(
    _MODELS_DIR, 'triazine_SIRT2', 'FP-FFNN', 'random_seed_0.torch')

In [None]:
import matplotlib
import matplotlib.pyplot as plt

# Global figure styling.
matplotlib.rc('font', family='sans-serif')
# The active family is sans-serif, so Arial must be listed under 'sans-serif';
# an entry in the 'serif' list is never consulted for this family.
matplotlib.rc('font', **{'sans-serif': ['Arial']})
# Pass a real boolean rather than the string 'false'.
matplotlib.rc('text', usetex=False)
matplotlib.rcParams.update({'font.size': 8})

In [None]:
def charge(a):
    """Return the formal charge of atom ``a`` as a signed SMARTS token.

    Non-negative charges (including zero) are written as ``+N``; negative
    charges as ``-N``.
    """
    formal_charge = a.GetFormalCharge()
    sign = '+' if formal_charge >= 0 else '-'
    return f'{sign}{abs(formal_charge)}'
def getMorganFingerprintAtomSymbols(mol):
    """Return one SMARTS atom symbol per atom of ``mol``, in ``mol.GetAtoms()`` order.

    Each symbol encodes the atom properties listed below, chosen to match the
    specificity of the atom definition used by Morgan fingerprints:
        - atomic number
        - degree
        - total number of Hs
        - number of rings the atom belongs to
        - formal charge (see ``charge``)
    These are based on getConnectivityInvariants from FingerprintUtil.cpp at
    https://github.com/rdkit/rdkit/blob/75f03412ef151a4dc14dfee986e29c3690a4c071/Code/GraphMol/Fingerprints/FingerprintUtil.cpp#L254
    """
    ring_info = mol.GetRingInfo()
    return [
        '[#{};D{};H{};R{};{}]'.format(
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetTotalNumHs(),
            ring_info.NumAtomRings(atom.GetIdx()),
            charge(atom),
        )
        for atom in mol.GetAtoms()
    ]

In [None]:
def GetWeightsForBit(bit_id):
    """Measure how much the model output changes when bit ``bit_id`` is masked.

    Reads the notebook globals ``x`` (fingerprint matrix), ``basePreds_all``
    (predictions on the unmodified ``x``), ``model`` and ``DEVICE``.

    Returns a tuple ``(weights, avg_weight)``:
        - weights: ``basePreds - newPreds`` for every molecule that sets the bit,
          where ``newPreds`` are the predictions with the bit zeroed
        - avg_weight: the mean of ``weights``; ``-inf`` if no molecule sets the bit
          (``weights`` is then an empty float64 array)
    """
    # flatnonzero always yields a 1-D index array, so the selected block below
    # stays 2-D even when exactly one molecule sets the bit. (Squeezing it to 1-D
    # made shape[0] the number of features and corrupted the average.)
    indices = np.flatnonzero(x[:, bit_id] == 1)
    n_mols = indices.size
    if n_mols == 0:
        # Nothing to predict on; skip the model call entirely.
        return np.array([], dtype='float64'), -np.inf

    # Indexing with an integer array returns a copy, so `x` is not modified.
    x_hasbit_masked = x[indices, :]
    x_hasbit_masked[:, bit_id] = 0

    basePreds = basePreds_all[indices]
    newPreds = np.array(model.predict_on_x(x_hasbit_masked, device=DEVICE), dtype='float64')
    weights = np.subtract(basePreds, newPreds)
    return weights, np.sum(weights) / n_mols

In [None]:
def GetWeightsForSubstructure(weights, indices):
    """Select the molecule-level bit weights belonging to one substructure.

    Parameters:
        - weights: molecule-level bit weights for a model and specific bit
        - indices: indices in ::weights:: for the molecules with the substructure

    Returns the selected weights and their average.
    Raises ValueError if ::indices:: is empty.
    """
    selected = weights[indices]
    n_selected = len(indices)
    if n_selected == 0:
        raise ValueError('No indices specified')
    return selected, np.sum(selected) / n_selected

In [None]:
def getFragmentForMolBit(smi, mol, mol_idx, atomSymbols, cpd_id, bit, info_all, submol_freq_distrib, 
                        smarts_to_smis, submol_to_cpd_indices, submol_to_bit, bits_to_draw):
    """Record the substructure(s) of one molecule that set the specified bit.

    Every ``(atom, radius)`` entry in ``info_all[mol_idx][bit]`` is turned into a
    fragment: the atom environment of that radius, or the single centre atom when
    the environment contains no bonds. The fragment is keyed by the string
    generated with ``atomSymbols``.

    Parameters:
        - smi: SMILES of the molecule (only used in the printed report)
        - mol: RDKit molecule
        - mol_idx: index of the molecule in ::info_all::; also the value stored in
          ::submol_to_cpd_indices::
        - atomSymbols: per-atom symbols passed to ``Chem.MolFragmentToSmiles``
        - cpd_id: compound ID (printed, and stored in ::bits_to_draw::)
        - bit: bit ID
        - info_all: per-molecule mapping of bit ID to ``(atom, radius)`` entries
        - submol_freq_distrib: ``{bit: {fragment key: count}}``
        - smarts_to_smis: ``{fragment key: [fragment SMILES, ...]}``
        - submol_to_cpd_indices: ``{fragment key: [mol_idx, ...]}``
        - submol_to_bit: ``{fragment key: [bit, ...]}``
        - bits_to_draw: list of ``(cpd_id, mol, bit, info_all[mol_idx])`` tuples

    The five containers are updated in place and also returned, followed by
    ``examples``: the positions within ``info_all[mol_idx][bit]`` of the entries
    whose fragment had not been seen before for this bit.

    NOTE(review): ``submol_to_cpd_indices`` is keyed by fragment only, so a fragment
    first seen under a second bit replaces the list recorded for the first bit.
    """
    molAdded = False
    examples = []
    for j, example in enumerate(info_all[mol_idx][bit]):
        atom = example[0]
        radius = example[1]
        env = Chem.FindAtomEnvironmentOfRadiusN(mol, radius, atom)

        # Atoms touched by the bonds of the environment.
        atoms = set()
        for bidx in env:
            bond = mol.GetBondWithIdx(bidx)
            atoms.add(bond.GetBeginAtomIdx())
            atoms.add(bond.GetEndAtomIdx())

        if atoms:
            atoms_to_use = list(atoms)
            fragment_smi = Chem.MolFragmentToSmiles(mol, atomsToUse=atoms_to_use, bondsToUse=env,
                                                    rootedAtAtom=atom, isomericSmiles=True,
                                                    allBondsExplicit=True)
            fragment_sm = Chem.MolFragmentToSmiles(mol, atomsToUse=atoms_to_use, atomSymbols=atomSymbols,
                                                   bondsToUse=env, isomericSmiles=True,
                                                   allBondsExplicit=True)
            fragment_report = [f'molecular fragment (SMILES): {fragment_smi}',
                               f'molecular fragment (SMARTS): {fragment_sm}']
        else:
            # No bonds in the environment: the fragment is the centre atom alone.
            fragment_smi = mol.GetAtomWithIdx(atom).GetSmarts()
            # atomsToUse must be a sequence of atom indices, not a bare int.
            fragment_sm = Chem.MolFragmentToSmiles(mol,
                                                   atomsToUse=[atom],
                                                   atomSymbols=atomSymbols,
                                                   isomericSmiles=True,
                                                   allBondsExplicit=True)
            fragment_report = [f'atom: {fragment_sm}']

        if fragment_sm in submol_freq_distrib[bit]:
            # Already recorded for this bit: only update the counters.
            submol_freq_distrib[bit][fragment_sm] += 1
            submol_to_cpd_indices[fragment_sm].append(mol_idx)
            continue

        # First occurrence of this fragment for this bit: report and register it.
        print(f'cpd_id: {cpd_id}')
        print(f'SMILES string: {smi}')
        print(f'bit ID: {bit}')
        print(f'(atom, radius): {(atom, radius)}')
        for line in fragment_report:
            print(line)
        print()

        submol_freq_distrib[bit][fragment_sm] = 1
        smarts_to_smis.setdefault(fragment_sm, []).append(fragment_smi)
        submol_to_bit.setdefault(fragment_sm, []).append(bit)
        examples.append(j)
        submol_to_cpd_indices[fragment_sm] = [mol_idx]
        if not molAdded:
            # Queue each molecule for drawing at most once per call.
            bits_to_draw.append((cpd_id, mol, bit, info_all[mol_idx]))
            molAdded = True
    return submol_freq_distrib, smarts_to_smis, submol_to_cpd_indices, submol_to_bit, bits_to_draw, examples

In [None]:
# 'cuda:0' when a CUDA device is available, otherwise None.
DEVICE = 'cuda:0' if torch.cuda.is_available() else None

# DD1S CAIX (random split, seed 0)

## Load data

In [None]:
os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'
df_data = pd.read_csv(os.path.join(DELQSAR_ROOT, 'experiments', 'datasets', 'DD1S_CAIX_QSAR.csv'))
# Open read-only and let the context manager close the file even if reading fails.
with h5py.File(os.path.join(DELQSAR_ROOT, 'experiments', DD1S_FINGERPRINTS_FILENAME), 'r') as hf:
    x = np.array(hf['all_fps'])  # copy into memory before the file is closed
INPUT_SIZE = x.shape[1]  # number of columns of the fingerprint matrix

## Load model

In [None]:
model = models.MLP(INPUT_SIZE, [64, 64, 64],
            dropout=0.1)
# Load the weights onto the CPU first so a checkpoint saved from a GPU can still
# be read on a machine without CUDA; the model is moved to DEVICE afterwards.
model.load_state_dict(torch.load(DD1S_CAIX_RANDOM_SPLIT_FP_FFNN_SEED_0_MODEL_PATH, map_location='cpu'))
if DEVICE:
    model = model.to(DEVICE)

## Bit analysis

### Calculate bit weights

In [None]:
# IDs of the bits set by at least one molecule. The range follows the width of
# `x` instead of a hard-coded 2048, and np.any avoids iterating each column
# element by element in Python.
set_bit_ids = [bit_id for bit_id in tqdm(range(x.shape[1])) if np.any(x[:, bit_id] == 1)]
print(f'Number of bits set by at least one molecule in the data set: {len(set_bit_ids)}')

In [None]:
# check if any bit is set by exactly one molecule
# for bit_id in tqdm(range(2048)):
#     indices = np.squeeze(np.where(x[:,bit_id]==1))
#     if indices.shape == ():
#         print(f'Bit {bit_id} is set by only one molecule')
#         break

In [None]:
# Predictions on the unmodified fingerprints; GetWeightsForBit reads this global.
basePreds_all = np.array(model.predict_on_x(x, device=DEVICE), dtype='float64')
# bit ID -> (molecule-level weights, average weight)
bit_to_weights = {}
for bit_id in tqdm(set_bit_ids):
    bit_to_weights[bit_id] = GetWeightsForBit(bit_id)
hkl.dump(bit_to_weights, 'bit_to_weights_DD1S_CAIX_FP-FFNN_random_seed_0.hkl', mode='w')

In [None]:
# bit_to_weights_hkl = hkl.load('bit_to_weights_DD1S_CAIX_FP-FFNN_random_seed_0.hkl')
# bit_to_weights = {int(bit): weights for bit, weights in bit_to_weights_hkl.items()}

### Plot distribution of average bit weights

In [None]:
# Second element of every bit_to_weights entry is the average weight of that bit.
avg_bit_weights = [entry[1] for entry in bit_to_weights.values()]
lowest_avg_weight = min(avg_bit_weights)
highest_avg_weight = max(avg_bit_weights)
print(f'Lowest average bit weight: {lowest_avg_weight}')
print(f'Highest average bit weight: {highest_avg_weight}')

In [None]:
# histogram of average bit weights (only including bits set by at least one molecule)
def make_hist_avg_bit_weights(zoomIn=False):
    """Plot, save and show the histogram of the global ``avg_bit_weights``.

    With ``zoomIn=True`` a finer bin width is used, the y-axis is limited to
    [0, 8] and the figure is saved under the ``zoomed_in`` file name.
    """
    lower_edge = -0.08
    bin_width = 0.003 if zoomIn else 0.0062
    fig = plt.figure(figsize=(3.5, 1.6), dpi=300)
    bins = np.arange(lower_edge, 0.27, bin_width)
    # Values outside [lower_edge, last bin edge] are clipped into the outermost bins.
    clipped = np.clip(avg_bit_weights, lower_edge, bins[-1])
    _, bins, patches = plt.hist(clipped, bins=bins, density=False, zorder=2)
    fig.canvas.draw() # required to get tick labels
    ax = plt.gca()
    if zoomIn:
        ax.set_ylim([0, 8])
    ax.grid(zorder=1)
    ax.set_xlabel('Average weight')
    ax.set_ylabel('Number of bits')
    plt.tight_layout()
    zoom_tag = '_zoomed_in' if zoomIn else ''
    plt.savefig(pathify(f'bit_weight_histogram{zoom_tag}_DD1S_CAIX_FP-FFNN_random_seed_0.png'))
    plt.show()

In [None]:
# Draw and save the histogram of average bit weights (default: zoomIn=False).
make_hist_avg_bit_weights()

In [None]:
# Same histogram, drawn and saved in its zoomed-in variant.
make_hist_avg_bit_weights(zoomIn=True)

### Get bits of interest (based on average bit weight)

In [None]:
# Ascending by average weight.
bits_sorted_by_avg_weight = sorted(set_bit_ids, key=lambda b: bit_to_weights[b][1])
# Five highest-weight bits, highest first.
top_bits = bits_sorted_by_avg_weight[-5:][::-1]
# Three lowest-weight bits, listed so that the lowest comes last.
bottom_bits = bits_sorted_by_avg_weight[:3][::-1]
print(f'Top bits: {top_bits}')
print(f'Bottom bits: {bottom_bits}')

In [None]:
bits_of_interest = top_bits + bottom_bits
for b in bits_of_interest:
    print(f'Bit ID: {b}')
    print(f'Average weight: {bit_to_weights[b][1]}')
    # flatnonzero always returns a 1-D array, so the count is also correct when
    # exactly one molecule sets the bit (len() of a squeezed 0-d array raises).
    print(f'Number of molecules with the bit: {np.flatnonzero(x[:, b] == 1).size}')
    print()

In [None]:
# bit ID -> row indices (into x / df_data) of the molecules that set the bit.
# flatnonzero keeps the result 1-D, so a bit set by a single molecule yields a
# one-element list instead of failing on iteration over a 0-d array.
bit_to_cpd_row_indices = {bit: list(np.flatnonzero(x[:, bit] == 1)) for bit in bits_of_interest}
for bit_id, row_indices in bit_to_cpd_row_indices.items():
    print(f'Bit ID: {bit_id}')
    print(f'Number of molecules with the bit: {len(row_indices)}')
    print()

### Plot distributions of molecule-level bit weights

In [None]:
# histogram of molecule-level bit weights
def make_hist_mol_level_bit_weights(bit_id, x_lb, x_ub, stepsize_noZoom, stepsize_zoom, zoomIn=False):
    """Plot, save and show the histogram of the molecule-level weights of one bit.

    Parameters:
        - bit_id: bit whose weights (``bit_to_weights[bit_id][0]``) are plotted
        - x_lb, x_ub: start and stop passed to ``np.arange`` to build the bin edges
        - stepsize_noZoom: bin width used when ::zoomIn:: is False
        - stepsize_zoom: bin width used when ::zoomIn:: is True
        - zoomIn: if True, also limit the y-axis to [0, 8] and save under the
          ``zoomed_in`` file name
    """
    weights = bit_to_weights[bit_id][0]
    fig = plt.figure(figsize=(3.5, 1.6), dpi=300)
    stepsize = stepsize_zoom if zoomIn else stepsize_noZoom
    bins = np.arange(x_lb, x_ub, stepsize)
    # Values outside [x_lb, last bin edge] are clipped into the outermost bins.
    clipped = np.clip(weights, x_lb, bins[-1])
    _, bins, patches = plt.hist(clipped, bins=bins, density=False, zorder=2)
    fig.canvas.draw() # required to get tick labels
    ax = plt.gca()
    if zoomIn:
        ax.set_ylim([0, 8])
    ax.grid(zorder=1)
    ax.tick_params(labelsize=8)
    ax.set_xlabel('Weight', fontsize=8)
    ax.set_ylabel('Number of molecules', fontsize=8)
    ax.set_title(f'Bit {bit_id}', fontsize=8)
    plt.tight_layout()
    zoom_tag = '_zoomed_in' if zoomIn else ''
    plt.savefig(pathify(
        f'mol-level_bit_weight_histogram_bit_{bit_id}{zoom_tag}_DD1S_CAIX_FP-FFNN_random_seed_0.png'))
    plt.show()

In [None]:
bit = 1489
print(f'Bit {bit}')
mol_level_weights = bit_to_weights[bit][0]
print(f'Lowest molecule-level bit weight: {min(mol_level_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_level_weights)}')

make_hist_mol_level_bit_weights(bit, x_lb=-0.02, x_ub=0.62, stepsize_noZoom=0.01, stepsize_zoom=None)

In [None]:
bit = 833
# f-prefix added: the literal text 'Bit {bit}' was printed instead of the bit ID.
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, 0, 0.65, 0.01, None)

In [None]:
bit = 1785
print(f'Bit {bit}')
mol_level_weights = bit_to_weights[bit][0]
print(f'Lowest molecule-level bit weight: {min(mol_level_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_level_weights)}')

make_hist_mol_level_bit_weights(bit, x_lb=-0.01, x_ub=0.62, stepsize_noZoom=0.01, stepsize_zoom=None)

In [None]:
bit = 997
print(f'Bit {bit}')
mol_level_weights = bit_to_weights[bit][0]
print(f'Lowest molecule-level bit weight: {min(mol_level_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_level_weights)}')

make_hist_mol_level_bit_weights(bit, x_lb=-0.01, x_ub=0.63, stepsize_noZoom=0.01, stepsize_zoom=None)

In [None]:
bit = 1197
print(f'Bit {bit}')
mol_level_weights = bit_to_weights[bit][0]
print(f'Lowest molecule-level bit weight: {min(mol_level_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_level_weights)}')

make_hist_mol_level_bit_weights(bit, x_lb=-0.01, x_ub=0.59, stepsize_noZoom=0.01, stepsize_zoom=None)

In [None]:
bit = 1148
print(f'Bit {bit}')
mol_level_weights = bit_to_weights[bit][0]
print(f'Lowest molecule-level bit weight: {min(mol_level_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_level_weights)}')

make_hist_mol_level_bit_weights(bit, x_lb=-0.12, x_ub=-0.01, stepsize_noZoom=0.00265, stepsize_zoom=None)

In [None]:
bit = 365
print(f'Bit {bit}')
mol_level_weights = bit_to_weights[bit][0]
print(f'Lowest molecule-level bit weight: {min(mol_level_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_level_weights)}')

make_hist_mol_level_bit_weights(bit, x_lb=-0.20, x_ub=-0.01, stepsize_noZoom=0.0035, stepsize_zoom=None)

In [None]:
bit = 1165
print(f'Bit {bit}')
mol_level_weights = bit_to_weights[bit][0]
print(f'Lowest molecule-level bit weight: {min(mol_level_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_level_weights)}')

make_hist_mol_level_bit_weights(bit, x_lb=-0.20, x_ub=0, stepsize_noZoom=0.004, stepsize_zoom=None)

### Get and visualize substructures

In [None]:
bits_to_draw = []
submol_freq_distrib = {bit: {} for bit in bits_of_interest} # store frequency distribution of substructures that 
                                                            # set each bit
smarts_to_smis = {}
submol_to_cpd_indices = {} # mapping to indices in df_data_hasbit
submol_to_bit = {}
examples_all = {bit: {} for bit in bits_of_interest} # check if there's more than one distinct bit-setting 
                                                     # substructure in the same molecule

for bit in tqdm(bits_of_interest):
    df_data_hasbit = df_data.iloc[bit_to_cpd_row_indices[bit]]
    smis = df_data_hasbit['smiles']
    featurizer = featurizers.FingerprintFeaturizer()
    _, info_all = featurizer.prepare_x(df_data_hasbit, bitInfo=True)
    # Take cpd_id from the same row as the SMILES. Looking it up by scanning
    # df_data for a matching SMILES costs a full pass per molecule and returns
    # the first match, which is the wrong compound when SMILES are duplicated.
    for i, (smi, row_cpd_id) in enumerate(zip(smis, df_data_hasbit['cpd_id'])):
        if bit not in info_all[i]:
            continue
        # Parse the molecule only once it is known to be needed.
        mol = Chem.MolFromSmiles(smi)
        atomSymbols = getMorganFingerprintAtomSymbols(mol)
        cpd_id = int(row_cpd_id)
        (submol_freq_distrib, smarts_to_smis, submol_to_cpd_indices,
         submol_to_bit, bits_to_draw, examples) = getFragmentForMolBit(
            smi, mol, i, atomSymbols, cpd_id, bit, info_all, submol_freq_distrib,
            smarts_to_smis, submol_to_cpd_indices, submol_to_bit, bits_to_draw)
        if examples:
            examples_all[bit][i] = examples

In [None]:
# Persist every substructure log as <name>_DD1S_CAIX_FP-FFNN_random_seed_0.hkl
_hkl_suffix = 'DD1S_CAIX_FP-FFNN_random_seed_0.hkl'
_logs_to_dump = [
    ('submol_freq_distrib', submol_freq_distrib),
    ('smarts_to_smis', smarts_to_smis),
    ('submol_to_cpd_indices', submol_to_cpd_indices),
    ('submol_to_bit', submol_to_bit),
    ('bits_to_draw', bits_to_draw),
    ('examples_all', examples_all),
]
for _log_name, _log_obj in _logs_to_dump:
    hkl.dump(_log_obj, f'{_log_name}_{_hkl_suffix}', mode='w')

In [None]:
# submol_freq_distrib_hkl = hkl.load('submol_freq_distrib_DD1S_CAIX_FP-FFNN_random_seed_0.hkl')
# submol_freq_distrib = {int(bit): d for bit, d in submol_freq_distrib_hkl.items()}
# smarts_to_smis = hkl.load('smarts_to_smis_DD1S_CAIX_FP-FFNN_random_seed_0.hkl')
# submol_to_cpd_indices = hkl.load('submol_to_cpd_indices_DD1S_CAIX_FP-FFNN_random_seed_0.hkl')
# submol_to_bit = hkl.load('submol_to_bit_DD1S_CAIX_FP-FFNN_random_seed_0.hkl')
# examples_all_hkl = hkl.load('examples_all_DD1S_CAIX_FP-FFNN_random_seed_0.hkl')
# examples_all = {int(bit): {int(mol_idx): [int(ex_num) for ex_num in ex_nums] for mol_idx, ex_nums in exs.items()} for bit, exs in examples_all_hkl.items()}
# bits_to_draw_hkl = hkl.load('bits_to_draw_DD1S_CAIX_FP-FFNN_random_seed_0.hkl')
# bits_to_draw = [(int(item[0]), Chem.MolFromSmiles(df_data.iloc[int(item[0])-1]['smiles']), int(item[2]), 
#         {int(bit): tuple([(int(an[0]), int(an[1])) for an in ans]) for bit, ans in item[3].items()}) for item in bits_to_draw_hkl]

In [None]:
# Display {bit: {molecule index: [entry positions]}} for newly seen substructures.
examples_all

In [None]:
for smarts, smis_for_smarts in smarts_to_smis.items():
    print(f'SMARTS: {smarts}')
    print(f'SMILES: {np.squeeze(smis_for_smarts)}')
    print()
print()
distinct_smiles = {smi for group in smarts_to_smis.values() for smi in group}
print(f'Number of distinct SMARTS: {len(smarts_to_smis)}')
print(f'Number of distinct SMILES: {len(distinct_smiles)}')

In [None]:
# Per bit: every substructure that sets it and how often it was counted.
for bit, freq_by_submol in submol_freq_distrib.items():
    print(f'Bit ID: {bit}')
    for submol, frequency in freq_by_submol.items():
        print(f'Substructure (SMARTS): {submol}')
        print(f'Frequency: {frequency}')
    print()

In [None]:
# Drop the leading cpd_id from every entry, leaving (mol, bit, bit info).
_bits_to_draw = [entry[1:] for entry in bits_to_draw]

In [None]:
# visualize each substructure
bit_legends = [f'cpd_id {entry[0]}, bit: {entry[2]}' for entry in bits_to_draw]
d = Draw.DrawMorganBits(
    _bits_to_draw,
    molsPerRow=4,
    aromaticColor=None,
    ringColor=None,
    legends=bit_legends,
    subImgSize=(600, 600),
)
d.save(pathify('bits_visualization_DD1S_CAIX_FP-FFNN_random_seed_0.png'))
d

## Substructure analysis

### Number substructures

In [None]:
# numbering corresponds to ordering of bars from left to right on plot of substructure weights (see below)
submol_to_id = {}
ctr = 1
for bit in bits_of_interest:
    freq_by_submol = submol_freq_distrib[bit]
    # most frequent substructure first within each bit
    submols = sorted(freq_by_submol.keys(), key=freq_by_submol.get, reverse=True)
    for s in submols:
        submol_to_id[s] = ctr
        ctr += 1

In [None]:
# Substructure number, its SMARTS key and the SMILES recorded for it.
for smarts, substruct_id in submol_to_id.items():
    print(substruct_id)
    print(smarts)
    print(np.squeeze(smarts_to_smis[smarts]))
    print()

### Calculate substructure weights

In [None]:
# substructure -> average molecule-level weight over the molecules containing it
substruct_to_weight = {}
for submol in submol_to_id:
    if submol not in submol_to_bit:
        continue
    # int(np.squeeze(...)) requires exactly one bit to be recorded for the substructure
    bit = int(np.squeeze(submol_to_bit[submol]))
    unique_cpd_indices = list(set(submol_to_cpd_indices[submol]))
    substruct_to_weight[submol] = GetWeightsForSubstructure(bit_to_weights[bit][0], unique_cpd_indices)[1]

In [None]:
# Listed in substructure-number order.
for smarts in sorted(substruct_to_weight, key=lambda s: submol_to_id[s]):
    print(f'Substructure (SMARTS): {smarts}')
    print()
    print(f'Weight: {substruct_to_weight[smarts]}')
    print()
    print()

### Plot substructure weights

In [None]:
substructs = sorted(substruct_to_weight.keys(), key=lambda s: submol_to_id[s])
weights = [substruct_to_weight[s] for s in substructs]
# One list of bar heights per bit of interest, in substructure-number order.
bars = [
    [substruct_to_weight[s] for s in substructs if s in submol_freq_distrib[bit_id].keys()]
    for bit_id in bits_of_interest
]

# Bar-chart geometry, in x-axis units.
barWidth = 1.5
offset = 1.7
space = 4
pos = len(bars[0])*barWidth / 2

In [None]:
# x positions of the bars: one list per bit of interest, sized from `bars`
# rather than hard-coded per group. Group centres use the same formula as the
# x ticks in the plotting cell, and `pos` is not modified, so re-running this
# cell gives the same positions.
group_step = space + 1.5*barWidth + (offset-barWidth)
first_center = len(bars[0])*barWidth / 2
rs = []
for group_idx, group_bars in enumerate(bars):
    center = first_center + group_idx*group_step
    n_bars = len(group_bars)
    # Bars of a group are spaced `offset` apart and centred on `center`.
    rs.append([center + (k - (n_bars - 1) / 2)*offset for k in range(n_bars)])

In [None]:
fig = plt.figure(figsize=(7, 4), dpi=300)

# One bar per substructure, grouped by bit of interest.
rects = []
for group_idx in range(len(bars)):
    for bar_idx, x_pos in enumerate(rs[group_idx]):
        container = plt.bar(x_pos, bars[group_idx][bar_idx], color='#1f77b4', width=barWidth, zorder=2)
        rects.append(container[0])

# Label every bar with its height. The last three bars get a downward label
# offset (-8 points); all others get an upward offset (3 points).
n_rects = len(rects)
for rect_idx, rect in enumerate(rects):
    height = rect.get_height()
    label_shift = -8 if rect_idx >= n_rects - 3 else 3
    plt.annotate('{:.3f}'.format(height),
                xy=(rect.get_x() + rect.get_width() / 2, height),
                xytext=(0, label_shift),
                textcoords="offset points",
                ha='center', va='bottom',
                fontsize=4.5)

fig.canvas.draw()
ax = plt.gca()
ax.grid(zorder=1)
ax.set_xlabel('Bit ID', fontsize=9)
ax.set_ylabel('Substructure weight', fontsize=9)
ax.tick_params(labelsize=9)
# One tick per bit of interest, at the centre of its group of bars.
tick_step = space + 1.5*barWidth + (offset-barWidth)
ax.set_xticks([((len(bars[0])*barWidth / 2) + i*tick_step) for i in range(len(bars))])
ax.set_xticklabels([str(bit_id) for bit_id in bits_of_interest], ha='center')
ax.tick_params(axis='x', length=0)
plt.tight_layout()
plt.savefig(pathify('substructure_weights_DD1S_CAIX_FP-FFNN_random_seed_0.png'))
plt.show()

# Triazine sEH (random split, seed 0)

## Load data

In [None]:
os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'
df_data = pd.read_csv(os.path.join(DELQSAR_ROOT, 'experiments', 'datasets', 'triazine_lib_sEH_SIRT2_QSAR.csv'))
# Open read-only and let the context manager close the file even if reading fails.
with h5py.File(os.path.join(DELQSAR_ROOT, 'experiments', triazine_FINGERPRINTS_FILENAME), 'r') as hf:
    x = np.array(hf['all_fps'])  # copy into memory before the file is closed
INPUT_SIZE = x.shape[1]  # number of columns of the fingerprint matrix

## Load model

In [None]:
model = models.MLP(INPUT_SIZE, [256, 128, 64],
            dropout=0.4)
# Load the weights onto the CPU first so a checkpoint saved from a GPU can still
# be read on a machine without CUDA; the model is moved to DEVICE afterwards.
model.load_state_dict(torch.load(triazine_sEH_RANDOM_SPLIT_FP_FFNN_SEED_0_MODEL_PATH, map_location='cpu'))
if DEVICE:
    model = model.to(DEVICE)

## Bit analysis

### Calculate bit weights

In [None]:
# IDs of the bits set by at least one molecule. The range follows the width of
# `x` instead of a hard-coded 2048, and np.any avoids iterating each column
# element by element in Python.
set_bit_ids = [bit_id for bit_id in tqdm(range(x.shape[1])) if np.any(x[:, bit_id] == 1)]
print(f'Number of bits set by at least one molecule in the data set: {len(set_bit_ids)}')

In [None]:
# Cache the set-bit IDs; the context manager closes the file even if the write fails.
with h5py.File('triazine_set_bit_ids', 'w') as hf:
    hf.create_dataset('set_bit_ids', data=np.array(set_bit_ids))

In [None]:
# hf = h5py.File('triazine_set_bit_ids', 'r')
# set_bit_ids = np.array(hf['set_bit_ids'])
# hf.close()
# print(f'Number of bits set by at least one molecule in the data set: {len(set_bit_ids)}')

In [None]:
# check if any bit is set by exactly one molecule
# for bit_id in tqdm(range(2048)):
#     indices = np.squeeze(np.where(x[:,bit_id]==1))
#     if indices.shape == ():
#         print(f'Bit {bit_id} is set by only one molecule')
#         break

In [None]:
# Predictions on the unmodified fingerprints; GetWeightsForBit reads this global.
basePreds_all = np.array(model.predict_on_x(x, device=DEVICE), dtype='float64')
# bit ID -> (molecule-level weights, average weight)
bit_to_weights = {}
for bit_id in tqdm(set_bit_ids):
    bit_to_weights[bit_id] = GetWeightsForBit(bit_id)
hkl.dump(bit_to_weights, 'bit_to_weights_triazine_sEH_FP-FFNN_random_seed_0.hkl', mode='w')

In [None]:
# bit_to_weights_hkl = hkl.load('bit_to_weights_triazine_sEH_FP-FFNN_random_seed_0.hkl')
# bit_to_weights = {int(bit): weights for bit, weights in bit_to_weights_hkl.items()}

### Plot distribution of average bit weights

In [None]:
# Average weight of every set bit, in set_bit_ids order.
avg_bit_weights = [bit_to_weights[bit_id][1] for bit_id in set_bit_ids]
lowest_avg_weight = min(avg_bit_weights)
highest_avg_weight = max(avg_bit_weights)
print(f'Lowest average bit weight: {lowest_avg_weight}')
print(f'Highest average bit weight: {highest_avg_weight}')

In [None]:
# histogram of average bit weights (only including bits set by at least one molecule)
def make_hist_avg_bit_weights(zoomIn=False):
    """Plot, save and show the histogram of the global ``avg_bit_weights``.

    With ``zoomIn=True`` the y-axis is limited to [0, 8] and the figure is
    saved under the ``zoomed_in`` file name; the bins are the same in both modes.
    """
    lower_edge = -0.23
    fig = plt.figure(figsize=(3.5, 1.6), dpi=300)
    bins = np.arange(lower_edge, 3.53, 0.03)
    # Values outside [lower_edge, last bin edge] are clipped into the outermost bins.
    clipped = np.clip(avg_bit_weights, lower_edge, bins[-1])
    _, bins, patches = plt.hist(clipped, bins=bins, density=False, zorder=2)
    fig.canvas.draw() # required to get tick labels
    ax = plt.gca()
    if zoomIn:
        ax.set_ylim([0, 8])
    ax.grid(zorder=1)
    ax.set_xlabel('Average weight')
    ax.set_ylabel('Number of bits')
    plt.tight_layout()
    zoom_tag = '_zoomed_in' if zoomIn else ''
    plt.savefig(pathify(f'bit_weight_histogram{zoom_tag}_triazine_sEH_FP-FFNN_random_seed_0.png'))
    plt.show()

In [None]:
# Draw and save the histogram of average bit weights (default: zoomIn=False).
make_hist_avg_bit_weights()

In [None]:
# Same histogram, drawn and saved in its zoomed-in variant.
make_hist_avg_bit_weights(zoomIn=True)

### Get bits of interest (based on average bit weight)

In [None]:
# Ascending by average weight.
bits_sorted_by_avg_weight = sorted(set_bit_ids, key=lambda b: bit_to_weights[b][1])
# Five highest-weight bits, highest first.
top_bits = bits_sorted_by_avg_weight[-5:][::-1]
# Three lowest-weight bits, listed so that the lowest comes last.
bottom_bits = bits_sorted_by_avg_weight[:3][::-1]
print(f'Top bits: {top_bits}')
print(f'Bottom bits: {bottom_bits}')

In [None]:
bits_of_interest = top_bits + bottom_bits
for b in bits_of_interest:
    print(f'Bit ID: {b}')
    print(f'Average weight: {bit_to_weights[b][1]}')
    # flatnonzero always returns a 1-D array, so the count is also correct when
    # exactly one molecule sets the bit (len() of a squeezed 0-d array raises).
    print(f'Number of molecules with the bit: {np.flatnonzero(x[:, b] == 1).size}')
    print()

In [None]:
# bit ID -> row indices (into x / df_data) of the molecules that set the bit.
# flatnonzero keeps the result 1-D, so a bit set by a single molecule yields a
# one-element list instead of failing on iteration over a 0-d array.
bit_to_cpd_row_indices = {bit: list(np.flatnonzero(x[:, bit] == 1)) for bit in bits_of_interest}
for bit_id, row_indices in bit_to_cpd_row_indices.items():
    print(f'Bit ID: {bit_id}')
    print(f'Number of molecules with the bit: {len(row_indices)}')
    print()

### Plot distributions of molecule-level bit weights

In [None]:
# histogram of molecule-level bit weights
def make_hist_mol_level_bit_weights(bit_id, x_lb, x_ub, stepsize_noZoom, stepsize_zoom, zoomIn=False,
                                    zoomIn_y_ub=None, xticks=None):
    """Plot, save and show the histogram of molecule-level weights for one bit (triazine sEH model).

    The weights are read from the notebook-level ``bit_to_weights[bit_id][0]``
    and clipped to ``[x_lb, last bin edge]`` so that out-of-range values are
    counted in the outermost bins.

    Args:
        bit_id: fingerprint bit whose molecule-level weights are plotted.
        x_lb: lower bound of the binning range (also the lower clip value).
        x_ub: exclusive upper bound passed to ``np.arange`` for the bin edges.
        stepsize_noZoom: bin width used when ``zoomIn`` is False.
        stepsize_zoom: bin width used when ``zoomIn`` is True.
        zoomIn: when True, the y-axis is limited to ``[0, zoomIn_y_ub]`` and the
            figure is saved under the "zoomed_in" file name.
        zoomIn_y_ub: upper y-limit for the zoomed-in view.
        xticks: optional sequence of x tick positions; ignored when None or empty.

    Raises:
        ValueError: if the bin width for the selected mode is None.
    """
    stepsize = stepsize_zoom if zoomIn else stepsize_noZoom
    if stepsize is None:
        # Fail before creating a figure, with a message naming the missing argument
        # (np.arange would otherwise raise an opaque TypeError).
        missing = 'stepsize_zoom' if zoomIn else 'stepsize_noZoom'
        raise ValueError(f'{missing} must be provided when zoomIn={zoomIn}')

    weights = bit_to_weights[bit_id][0]
    fig = plt.figure(figsize=(1.85, 1.6), dpi=300)
    bins = np.arange(x_lb, x_ub, stepsize)
    _, bins, patches = plt.hist(
        np.clip(weights, x_lb, bins[-1]), 
        bins=bins,  
        density=False,
        zorder=2
    )
    fig.canvas.draw() # required to get tick labels
    ax = plt.gca()
    if zoomIn:
        ax.set_ylim([0, zoomIn_y_ub])
    ax.grid(zorder=1)
    ax.tick_params(labelsize=8)
    # Explicit None/empty check: `if xticks:` raises ValueError for a
    # multi-element numpy array (ambiguous truth value).
    if xticks is not None and len(xticks) > 0:
        plt.xticks(xticks)
    ax.set_xlabel('Weight', fontsize=8)
    ax.set_ylabel('Number of molecules', fontsize=8)
    ax.set_title(f'Bit {bit_id}', fontsize=8)
    plt.tight_layout()
    if zoomIn:
        plt.savefig(pathify(f'mol-level_bit_weight_histogram_bit_{bit_id}_zoomed_in_triazine_sEH_FP-FFNN_random_seed_0.png'))
    else:
        plt.savefig(pathify(f'mol-level_bit_weight_histogram_bit_{bit_id}_triazine_sEH_FP-FFNN_random_seed_0.png'))
    plt.show()

In [None]:
bit = 720
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -18.04, 65.91, 2.1, None)

In [None]:
bit = 720
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -18.04, 65.91, None, 2, zoomIn=True, zoomIn_y_ub=200)

In [None]:
bit = 60
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -17.44, 46.71, 1.75, None)

In [None]:
bit = 60
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -17.44, 46.71, None, 2, zoomIn=True, zoomIn_y_ub=220)

In [None]:
bit = 793
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -0.16, 79.83, 2, None)

In [None]:
bit = 793
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -0.16, 79.83, None, 2, zoomIn=True, zoomIn_y_ub=200)

In [None]:
bit = 1767
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -18.26, 51.71, 2.1, None)

In [None]:
bit = 1767
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -18.26, 51.71, None, 2, zoomIn=True, zoomIn_y_ub=220)

In [None]:
bit = 237
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -21.09, 42.49, 1.87, None)

In [None]:
bit = 237
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -21.09, 42.49, None, 2, zoomIn=True, zoomIn_y_ub=250)

In [None]:
bit = 411
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -51.74, 11.70, 1.9, None)

In [None]:
bit = 411
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -51.74, 11.70, None, 2, zoomIn=True, zoomIn_y_ub=100)

In [None]:
bit = 2024
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -52.01, 11.27, 2, None)

In [None]:
bit = 2024
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -52.01, 11.27, None, 2, zoomIn=True, zoomIn_y_ub=100)

In [None]:
bit = 864
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -105.44, 0.57, 3.2, None)

In [None]:
bit = 864
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -105.44, 0.57, None, 3.1, zoomIn=True, zoomIn_y_ub=30)

### Get and visualize substructures

In [None]:
bits_to_draw = []
submol_freq_distrib = {bit: {} for bit in bits_of_interest} # store frequency distribution of substructures that 
                                                            # set each bit
smarts_to_smis = {}
submol_to_cpd_indices = {} # mapping to indices in df_data_hasbit
submol_to_bit = {}
examples_all = {bit: {} for bit in bits_of_interest} # check if there's more than one distinct bit-setting 
                                                     # substructure in the same molecule

for bit in tqdm(bits_of_interest):
    df_data_hasbit = df_data.iloc[bit_to_cpd_row_indices[bit]]
    smis = df_data_hasbit['smiles']
    # Take cpd_id from the same row as the SMILES. Looking it up with
    # df_data[df_data['smiles']==smi] scans the whole table for every molecule
    # and returns the first match, which is the wrong compound if a SMILES
    # string occurs more than once.
    cpd_ids = df_data_hasbit['cpd_id'].to_numpy()
    featurizer = featurizers.FingerprintFeaturizer()
    _, info_all = featurizer.prepare_x(df_data_hasbit, bitInfo=True)
    for i, smi in enumerate(smis):
        # Cheap membership test first, so molecules without the bit in their
        # bit info are skipped before any parsing work is done.
        if bit not in info_all[i]:
            continue
        mol = Chem.MolFromSmiles(smi)
        atomSymbols = getMorganFingerprintAtomSymbols(mol)
        cpd_id = int(cpd_ids[i])
        fragment_logs = getFragmentForMolBit(smi, mol, i, atomSymbols, cpd_id, bit, info_all, submol_freq_distrib, 
                        smarts_to_smis, submol_to_cpd_indices, submol_to_bit, bits_to_draw)
        # The first five entries are the updated bookkeeping containers; the
        # sixth holds the examples recorded for this molecule.
        submol_freq_distrib = fragment_logs[0]
        smarts_to_smis = fragment_logs[1]
        submol_to_cpd_indices = fragment_logs[2]
        submol_to_bit = fragment_logs[3]
        bits_to_draw = fragment_logs[4]
        examples = fragment_logs[5]
        if examples:
            examples_all[bit][i] = examples

In [None]:
hkl.dump(submol_freq_distrib, 'submol_freq_distrib_triazine_sEH_FP-FFNN_random_seed_0.hkl', mode='w')
hkl.dump(smarts_to_smis, 'smarts_to_smis_triazine_sEH_FP-FFNN_random_seed_0.hkl', mode='w')
hkl.dump(submol_to_cpd_indices, 'submol_to_cpd_indices_triazine_sEH_FP-FFNN_random_seed_0.hkl', mode='w')
hkl.dump(submol_to_bit, 'submol_to_bit_triazine_sEH_FP-FFNN_random_seed_0.hkl', mode='w')
hkl.dump(bits_to_draw, 'bits_to_draw_triazine_sEH_FP-FFNN_random_seed_0.hkl', mode='w')
hkl.dump(examples_all, 'examples_all_triazine_sEH_FP-FFNN_random_seed_0.hkl', mode='w')

In [None]:
# submol_freq_distrib_hkl = hkl.load('submol_freq_distrib_triazine_sEH_FP-FFNN_random_seed_0.hkl')
# submol_freq_distrib = {int(bit): d for bit, d in submol_freq_distrib_hkl.items()}
# smarts_to_smis = hkl.load('smarts_to_smis_triazine_sEH_FP-FFNN_random_seed_0.hkl')
# submol_to_cpd_indices = hkl.load('submol_to_cpd_indices_triazine_sEH_FP-FFNN_random_seed_0.hkl')
# submol_to_bit = hkl.load('submol_to_bit_triazine_sEH_FP-FFNN_random_seed_0.hkl')
# examples_all_hkl = hkl.load('examples_all_triazine_sEH_FP-FFNN_random_seed_0.hkl')
# examples_all = {int(bit): {int(mol_idx): [int(ex_num) for ex_num in ex_nums] for mol_idx, ex_nums in exs.items()} for bit, exs in examples_all_hkl.items()}
# bits_to_draw_hkl = hkl.load('bits_to_draw_triazine_sEH_FP-FFNN_random_seed_0.hkl')
# bits_to_draw = [(int(item[0]), Chem.MolFromSmiles(df_data.iloc[int(item[0])-1]['smiles']), int(item[2]), 
#         {int(bit): tuple([(int(an[0]), int(an[1])) for an in ans]) for bit, ans in item[3].items()}) for item in bits_to_draw_hkl]

In [None]:
examples_all

In [None]:
# List every substructure SMARTS together with its SMILES, then summary counts.
for smarts, smis_for_smarts in smarts_to_smis.items():
    print(f'SMARTS: {smarts}')
    print(f'SMILES: {np.squeeze(smis_for_smarts)}')
    print()
print()
n_distinct_smarts = len(smarts_to_smis)
n_distinct_smiles = len({smi for smis_for_smarts in smarts_to_smis.values() for smi in smis_for_smarts})
print(f'Number of distinct SMARTS: {n_distinct_smarts}')
print(f'Number of distinct SMILES: {n_distinct_smiles}')

In [None]:
for bit in submol_freq_distrib:
    print(f'Bit ID: {bit}')
    for submol in submol_freq_distrib[bit]:
        print(f'Substructure (SMARTS): {submol}')
        print(f'Frequency: {submol_freq_distrib[bit][submol]}')
    print()

In [None]:
_bits_to_draw = [bit[1:] for bit in bits_to_draw]

In [None]:
d = Draw.DrawMorganBits(_bits_to_draw, molsPerRow=4, aromaticColor=None, ringColor=None, 
                        legends=[f'cpd_id {bit[0]}, bit: {bit[2]}' for bit in bits_to_draw], subImgSize=(600, 600))
d.save(pathify(f'bits_visualization_triazine_sEH_FP-FFNN_random_seed_0.png'))
d

## Substructure analysis

### Number substructures

In [None]:
# numbering corresponds to ordering of bars from left to right on plot of substructure weights (see below)
submol_to_id = {}
ctr = 1
for bit in bits_of_interest:
    freq_for_bit = submol_freq_distrib[bit]
    # most frequent substructure first within each bit
    submols = sorted(freq_for_bit, key=freq_for_bit.get, reverse=True)
    for s in submols:
        submol_to_id[s] = ctr
        ctr += 1

In [None]:
for item in submol_to_id.items():
    print(item[1])
    print(item[0])
    print(np.squeeze(smarts_to_smis[item[0]]))
    print()

### Calculate substructure weights

In [None]:
# Weight of each substructure: second element of GetWeightsForSubstructure's
# result, computed from the molecule-level weights of the bit it sets.
substruct_to_weight = {}
for submol, cpd_indices in submol_to_cpd_indices.items():
    bit = int(np.squeeze(submol_to_bit[submol]))
    unique_cpd_indices = list(set(cpd_indices))
    mol_level_weights = bit_to_weights[bit][0]
    substruct_to_weight[submol] = GetWeightsForSubstructure(mol_level_weights, unique_cpd_indices)[1]

In [None]:
# Print substructure weights in substructure-ID order (the bar order of the plot below).
for smarts in sorted(substruct_to_weight, key=submol_to_id.__getitem__):
    print(f'Substructure (SMARTS): {smarts}', end='\n\n')
    print(f'Weight: {substruct_to_weight[smarts]}', end='\n\n\n')

### Plot substructure weights

In [None]:
# Substructures in plotting order, and their weights in the same order.
substructs = sorted(substruct_to_weight, key=submol_to_id.__getitem__)
weights = [substruct_to_weight[s] for s in substructs]

# One list of bar heights per bit of interest, keeping the substructure-ID order within a bit.
bars = [
    [substruct_to_weight[s] for s in substructs if s in submol_freq_distrib[bit]]
    for bit in bits_of_interest
]

# Layout constants for the grouped bar plot (x-axis data units).
barWidth, offset, space = 1.5, 1.7, 4
pos = len(bars[0])*barWidth / 2

In [None]:
# x-positions of the bars, one list per bit of interest (same nesting as `bars`).
#
# Positions are derived from the sizes of the groups in `bars` instead of being
# written out per group, and the cell no longer advances the global `pos`.
# It can therefore be re-run without shifting the bars, and it stays correct if
# the number of substructures per bit changes.
#
# Group centres are `group_step` apart, starting at the same first centre that is
# used for the x ticks in the plotting cell; bars within a group are centred on
# the group centre with their centres `offset` apart.
group_step = space + 1.5*barWidth + (offset-barWidth)
first_center = len(bars[0])*barWidth / 2
rs = [
    [
        first_center + group_idx*group_step + (bar_idx - (len(group_bars) - 1) / 2)*offset
        for bar_idx in range(len(group_bars))
    ]
    for group_idx, group_bars in enumerate(bars)
]

In [None]:
fig = plt.figure(figsize=(7, 4), dpi=300)

# Draw one bar per substructure; rs[g][k] is the x-position of bar k in bit group g.
rects = []
for group_idx in range(len(bars)):
    for bar_idx, x_pos in enumerate(rs[group_idx]):
        container = plt.bar(x_pos, bars[group_idx][bar_idx], color='#1f77b4', width=barWidth, zorder=2)
        rects.append(container[0])

# (bar, label, vertical label offset in points): every bar except the last four is
# labelled with offset +3, the last four with offset -8. The fourth-from-last bar
# gets a fixed label instead of its formatted height.
label_specs = [(r, '{:.3f}'.format(r.get_height()), 3) for r in rects[:-4]]
label_specs.append((rects[-4], -8e-5, -8))
label_specs += [(r, '{:.3f}'.format(r.get_height()), -8) for r in rects[-3:]]
for rect, label, y_offset in label_specs:
    height = rect.get_height()
    plt.annotate(label,
                 xy=(rect.get_x() + rect.get_width() / 2, height),
                 xytext=(0, y_offset),
                 textcoords="offset points",
                 ha='center', va='bottom',
                 fontsize=4.5)

fig.canvas.draw()
ax = plt.gca()
ax.grid(zorder=1)
ax.set_xlabel('Bit ID', fontsize=9)
ax.set_ylabel('Substructure weight', fontsize=9)
ax.tick_params(labelsize=9)
# one tick per bit group, at the group centre
first_tick = len(bars[0])*barWidth / 2
tick_step = space + 1.5*barWidth + (offset-barWidth)
ax.set_xticks([first_tick + i*tick_step for i in range(len(bars))])
ax.set_xticklabels(list(map(str, bits_of_interest)), ha='center')
ax.tick_params(axis='x', length=0)
plt.tight_layout()
plt.savefig(pathify('substructure_weights_triazine_sEH_FP-FFNN_random_seed_0.png'))
plt.show()

# Triazine SIRT2 (random split, seed 0)

## Load data

In [None]:
os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'
df_data = pd.read_csv(os.path.join(DELQSAR_ROOT, 'experiments', 'datasets', 'triazine_lib_sEH_SIRT2_QSAR.csv'))
# Open read-only and use a context manager so the HDF5 handle is released even
# if reading the dataset raises.
with h5py.File(os.path.join(DELQSAR_ROOT, 'experiments', triazine_FINGERPRINTS_FILENAME), 'r') as hf:
    x = np.array(hf['all_fps'])
INPUT_SIZE = x.shape[1]  # fingerprint length; used as the model's input size

## Load model

In [None]:
model = models.MLP(INPUT_SIZE, [256, 128, 64],
            dropout=0.1)
# map_location='cpu' lets the checkpoint load on a machine without the device it
# was saved from; the model is moved to DEVICE afterwards if one is set.
model.load_state_dict(torch.load(triazine_SIRT2_RANDOM_SPLIT_FP_FFNN_SEED_0_MODEL_PATH, map_location='cpu'))
if DEVICE:
    model = model.to(DEVICE)

## Bit analysis

### Calculate bit weights

In [None]:
# Bits set by at least one molecule. The per-column check is vectorised
# (the builtin any() walks the numpy column element by element in Python), and
# the bit count comes from x rather than a hard-coded 2048.
set_bit_ids = [bit_id for bit_id in tqdm(range(x.shape[1])) if (x[:, bit_id] == 1).any()]
print(f'Number of bits set by at least one molecule in the data set: {len(set_bit_ids)}')

In [None]:
# Cache the set-bit IDs; the context manager closes the file even if the write fails.
with h5py.File('triazine_set_bit_ids', 'w') as hf:
    hf.create_dataset('set_bit_ids', data=np.array(set_bit_ids))

In [None]:
# hf = h5py.File('triazine_set_bit_ids', 'r')
# set_bit_ids = np.array(hf['set_bit_ids'])
# hf.close()
# print(f'Number of bits set by at least one molecule in the data set: {len(set_bit_ids)}')

In [None]:
# check if any bit is set by exactly one molecule
# for bit_id in tqdm(range(2048)):
#     indices = np.squeeze(np.where(x[:,bit_id]==1))
#     if indices.shape == ():
#         print(f'Bit {bit_id} is set by only one molecule')
#         break

In [None]:
basePreds_all = np.array(model.predict_on_x(x, device=DEVICE), dtype='float64')
bit_to_weights = {bit_id: GetWeightsForBit(bit_id) for bit_id in tqdm(set_bit_ids)}
hkl.dump(bit_to_weights, 'bit_to_weights_triazine_SIRT2_FP-FFNN_random_seed_0.hkl', mode='w')

In [None]:
# bit_to_weights_hkl = hkl.load('bit_to_weights_triazine_SIRT2_FP-FFNN_random_seed_0.hkl')
# bit_to_weights = {int(bit): weights for bit, weights in bit_to_weights_hkl.items()}

### Plot distribution of average bit weights

In [None]:
# Average weight of every bit (second element of each bit_to_weights entry).
avg_bit_weights = [bit_weights[1] for bit_weights in bit_to_weights.values()]
lowest_avg_weight, highest_avg_weight = min(avg_bit_weights), max(avg_bit_weights)
print(f'Lowest average bit weight: {lowest_avg_weight}')
print(f'Highest average bit weight: {highest_avg_weight}')

In [None]:
# histogram of average bit weights (only including bits set by at least one molecule)
def make_hist_avg_bit_weights(zoomIn=False):
    """Plot, save and show the histogram of average bit weights (triazine SIRT2 model).

    Reads the notebook-level ``avg_bit_weights``. Values are clipped to the
    bin range so that anything outside it is counted in the outermost bins.

    Args:
        zoomIn: when True, a finer bin width is used, the y-axis is limited to
            [0, 8], and the figure is saved under the "zoomed_in" file name.
    """
    lowest_edge = -0.04
    fig = plt.figure(figsize=(3.5, 1.6), dpi=300)
    bin_width = 0.012 if zoomIn else 0.014
    bin_edges = np.arange(lowest_edge, 1.44, bin_width)
    plt.hist(
        np.clip(avg_bit_weights, lowest_edge, bin_edges[-1]),
        bins=bin_edges,
        density=False,
        zorder=2,
    )
    fig.canvas.draw() # required to get tick labels
    ax = plt.gca()
    if zoomIn:
        ax.set_ylim([0, 8])
    ax.grid(zorder=1)
    ax.set_xlabel('Average weight')
    ax.set_ylabel('Number of bits')
    plt.tight_layout()
    out_name = (
        'bit_weight_histogram_zoomed_in_triazine_SIRT2_FP-FFNN_random_seed_0.png'
        if zoomIn
        else 'bit_weight_histogram_triazine_SIRT2_FP-FFNN_random_seed_0.png'
    )
    plt.savefig(pathify(out_name))
    plt.show()

In [None]:
make_hist_avg_bit_weights()

In [None]:
make_hist_avg_bit_weights(zoomIn=True)

### Get bits of interest (based on average bit weight)

In [None]:
# Rank the bits by ascending average weight, then take the five highest
# (highest first) and the three lowest (listed from the third-lowest down to the lowest).
bits_sorted_by_avg_weight = list(bit_to_weights)
bits_sorted_by_avg_weight.sort(key=lambda bit_id: bit_to_weights[bit_id][1])
top_bits = bits_sorted_by_avg_weight[-5:][::-1]
bottom_bits = bits_sorted_by_avg_weight[:3][::-1]
print(f'Top bits: {top_bits}')
print(f'Bottom bits: {bottom_bits}')

In [None]:
bits_of_interest = top_bits + bottom_bits
for b in bits_of_interest:
    # Count directly instead of len(np.squeeze(np.where(...))): squeezing the
    # index array yields a 0-d array when exactly one molecule sets the bit,
    # and len() of a 0-d array raises TypeError.
    n_mols_with_bit = int(np.count_nonzero(x[:, b] == 1))
    print(f'Bit ID: {b}')
    print(f'Average weight: {bit_to_weights[b][1]}')
    print(f'Number of molecules with the bit: {n_mols_with_bit}')
    print()

In [None]:
# Row indices (into x / df_data) of the molecules that set each bit of interest.
# np.flatnonzero always returns a 1-D index array, so this also works when a bit
# is set by exactly one molecule (np.squeeze(np.where(...)) collapses that case
# to a 0-d array, which cannot be iterated).
bit_to_cpd_row_indices = {bit: list(np.flatnonzero(x[:, bit] == 1)) for bit in bits_of_interest}
for bit_id, row_indices in bit_to_cpd_row_indices.items():
    print(f'Bit ID: {bit_id}')
    print(f'Number of molecules with the bit: {len(row_indices)}')
    print()

### Plot distributions of molecule-level bit weights

In [None]:
# histogram of molecule-level bit weights
def make_hist_mol_level_bit_weights(bit_id, x_lb, x_ub, stepsize_noZoom, stepsize_zoom, full=False,
                                    zoomIn=False, zoomIn_y_ub=None):
    """Plot, save and show the histogram of molecule-level weights for one bit (triazine SIRT2 model).

    The weights come from the notebook-level ``bit_to_weights[bit_id][0]`` and are
    clipped to ``[x_lb, last bin edge]`` so out-of-range values land in the outermost bins.

    Args:
        bit_id: fingerprint bit whose molecule-level weights are plotted.
        x_lb: lower bound of the binning range (also the lower clip value).
        x_ub: exclusive upper bound passed to ``np.arange`` for the bin edges.
        stepsize_noZoom: bin width used when ``zoomIn`` is False.
        stepsize_zoom: bin width used when ``zoomIn`` is True.
        full: when True, use the wide (3.5 x 1.6) figure instead of the narrow (1.85 x 1.6) one.
        zoomIn: when True, limit the y-axis to ``[0, zoomIn_y_ub]`` and save under the
            "zoomed_in" file name.
        zoomIn_y_ub: upper y-limit for the zoomed-in view.
    """
    mol_weights = bit_to_weights[bit_id][0]
    fig_size = (3.5, 1.6) if full else (1.85, 1.6)
    fig = plt.figure(figsize=fig_size, dpi=300)
    step = stepsize_zoom if zoomIn else stepsize_noZoom
    bin_edges = np.arange(x_lb, x_ub, step)
    plt.hist(
        np.clip(mol_weights, x_lb, bin_edges[-1]),
        bins=bin_edges,
        density=False,
        zorder=2,
    )
    fig.canvas.draw() # required to get tick labels
    ax = plt.gca()
    if zoomIn:
        ax.set_ylim([0, zoomIn_y_ub])
    ax.grid(zorder=1)
    ax.tick_params(labelsize=8)
    ax.set_xlabel('Weight', fontsize=8)
    ax.set_ylabel('Number of molecules', fontsize=8)
    ax.set_title(f'Bit {bit_id}', fontsize=8)
    plt.tight_layout()
    zoom_tag = '_zoomed_in' if zoomIn else ''
    plt.savefig(pathify(f'mol-level_bit_weight_histogram_bit_{bit_id}{zoom_tag}_triazine_SIRT2_FP-FFNN_random_seed_0.png'))
    plt.show()

In [None]:
bit = 348
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -0.15, 2.34, 0.03, None, full=True)

In [None]:
bit = 330
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -0.11, 2.27, 0.08, None)

In [None]:
bit = 330
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -0.11, 2.27, None, 0.08, zoomIn=True, zoomIn_y_ub=5000)

In [None]:
bit = 1643
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -0.07, 1.97, 0.07, None)

In [None]:
bit = 1643
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -0.07, 1.97, None, 0.07, zoomIn=True, zoomIn_y_ub=5000)

In [None]:
bit = 991
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -0.09, 1.09, 0.017, None, full=True)

In [None]:
bit = 1272
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -0.13, 2.31, 0.09, None)

In [None]:
bit = 1272
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -0.13, 2.31, None, 0.09, zoomIn=True, zoomIn_y_ub=5500)

In [None]:
bit = 1334
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -0.64, 0.07, 0.027, None)

In [None]:
bit = 1334
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -0.3, 0.07, None, 0.014, zoomIn=True, zoomIn_y_ub=4500)

In [None]:
bit = 873
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -0.78, 0.02, 0.019, None)

In [None]:
bit = 873
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -0.2, 0.02, None, 0.007, zoomIn=True, zoomIn_y_ub=10000)

In [None]:
bit = 430
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -0.63, 0.07, 0.025, None)

In [None]:
bit = 430
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(bit_to_weights[bit][0])}')
print(f'Highest molecule-level bit weight: {max(bit_to_weights[bit][0])}')

make_hist_mol_level_bit_weights(bit, -0.3, 0.07, None, 0.0145, zoomIn=True, zoomIn_y_ub=3500)

### Get and visualize substructures

In [None]:
bits_to_draw = []
submol_freq_distrib = {bit: {} for bit in bits_of_interest} # store frequency distribution of substructures that 
                                                            # set each bit
smarts_to_smis = {}
submol_to_cpd_indices = {} # mapping to indices in df_data_hasbit
submol_to_bit = {}
examples_all = {bit: {} for bit in bits_of_interest} # check if there's more than one distinct bit-setting 
                                                     # substructure in the same molecule

for bit in tqdm(bits_of_interest):
    df_data_hasbit = df_data.iloc[bit_to_cpd_row_indices[bit]]
    smis = df_data_hasbit['smiles']
    # Take cpd_id from the same row as the SMILES. Looking it up with
    # df_data[df_data['smiles']==smi] scans the whole table for every molecule
    # and returns the first match, which is the wrong compound if a SMILES
    # string occurs more than once.
    cpd_ids = df_data_hasbit['cpd_id'].to_numpy()
    featurizer = featurizers.FingerprintFeaturizer()
    _, info_all = featurizer.prepare_x(df_data_hasbit, bitInfo=True)
    for i, smi in enumerate(smis):
        # Cheap membership test first, so molecules without the bit in their
        # bit info are skipped before any parsing work is done.
        if bit not in info_all[i]:
            continue
        mol = Chem.MolFromSmiles(smi)
        atomSymbols = getMorganFingerprintAtomSymbols(mol)
        cpd_id = int(cpd_ids[i])
        fragment_logs = getFragmentForMolBit(smi, mol, i, atomSymbols, cpd_id, bit, info_all, submol_freq_distrib, 
                        smarts_to_smis, submol_to_cpd_indices, submol_to_bit, bits_to_draw)
        # The first five entries are the updated bookkeeping containers; the
        # sixth holds the examples recorded for this molecule.
        submol_freq_distrib = fragment_logs[0]
        smarts_to_smis = fragment_logs[1]
        submol_to_cpd_indices = fragment_logs[2]
        submol_to_bit = fragment_logs[3]
        bits_to_draw = fragment_logs[4]
        examples = fragment_logs[5]
        if examples:
            examples_all[bit][i] = examples

In [None]:
hkl.dump(submol_freq_distrib, 'submol_freq_distrib_triazine_SIRT2_FP-FFNN_random_seed_0.hkl', mode='w')
hkl.dump(smarts_to_smis, 'smarts_to_smis_triazine_SIRT2_FP-FFNN_random_seed_0.hkl', mode='w')
hkl.dump(submol_to_cpd_indices, 'submol_to_cpd_indices_triazine_SIRT2_FP-FFNN_random_seed_0.hkl', mode='w')
hkl.dump(submol_to_bit, 'submol_to_bit_triazine_SIRT2_FP-FFNN_random_seed_0.hkl', mode='w')
hkl.dump(bits_to_draw, 'bits_to_draw_triazine_SIRT2_FP-FFNN_random_seed_0.hkl', mode='w')
hkl.dump(examples_all, 'examples_all_triazine_SIRT2_FP-FFNN_random_seed_0.hkl', mode='w')

In [None]:
# submol_freq_distrib_hkl = hkl.load('submol_freq_distrib_triazine_SIRT2_FP-FFNN_random_seed_0.hkl')
# submol_freq_distrib = {int(bit): d for bit, d in submol_freq_distrib_hkl.items()}
# smarts_to_smis = hkl.load('smarts_to_smis_triazine_SIRT2_FP-FFNN_random_seed_0.hkl')
# submol_to_cpd_indices = hkl.load('submol_to_cpd_indices_triazine_SIRT2_FP-FFNN_random_seed_0.hkl')
# submol_to_bit = hkl.load('submol_to_bit_triazine_SIRT2_FP-FFNN_random_seed_0.hkl')
# examples_all_hkl = hkl.load('examples_all_triazine_SIRT2_FP-FFNN_random_seed_0.hkl')
# examples_all = {int(bit): {int(mol_idx): [int(ex_num) for ex_num in ex_nums] for mol_idx, ex_nums in exs.items()} for bit, exs in examples_all_hkl.items()}
# bits_to_draw_hkl = hkl.load('bits_to_draw_triazine_SIRT2_FP-FFNN_random_seed_0.hkl')
# bits_to_draw = [(int(item[0]), Chem.MolFromSmiles(df_data.iloc[int(item[0])-1]['smiles']), int(item[2]), 
#         {int(bit): tuple([(int(an[0]), int(an[1])) for an in ans]) for bit, ans in item[3].items()}) for item in bits_to_draw_hkl]

In [None]:
examples_all

In [None]:
# List every substructure SMARTS together with its SMILES, then summary counts.
for smarts, smis_for_smarts in smarts_to_smis.items():
    print(f'SMARTS: {smarts}')
    print(f'SMILES: {np.squeeze(smis_for_smarts)}')
    print()
print()
n_distinct_smarts = len(smarts_to_smis)
n_distinct_smiles = len({smi for smis_for_smarts in smarts_to_smis.values() for smi in smis_for_smarts})
print(f'Number of distinct SMARTS: {n_distinct_smarts}')
print(f'Number of distinct SMILES: {n_distinct_smiles}')

In [None]:
for bit in submol_freq_distrib:
    print(f'Bit ID: {bit}')
    for submol in submol_freq_distrib[bit]:
        print(f'Substructure (SMARTS): {submol}')
        print(f'Frequency: {submol_freq_distrib[bit][submol]}')
    print()

In [None]:
# visualize central atom for unkekulizable substructure (cpd_id 74, atom 20, radius 2)
mol = Chem.MolFromSmiles(df_data.iloc[73]['smiles'])
d = Draw.MolDraw2DCairo(600, 600)
Draw.PrepareAndDrawMolecule(d, mol, highlightAtoms=[20])
d.FinishDrawing()
save_png(d.GetDrawingText(), pathify('cpd_id 74_atom 20.png'))

In [None]:
# Draw and save one image per recorded example of each bit-setting substructure.
# Each entry b of bits_to_draw is indexed as: b[0] = cpd_id, b[1] = molecule,
# b[2] = bit ID, b[3] = bit info passed to DrawMorganBit.
# ctrs[bit] is the position of the next unused entry in examples_all[bit].
ctrs = {bit: 0 for bit in bits_of_interest}
for b in bits_to_draw:
    if b[0] == 74:
        # NOTE(review): skipping here leaves ctrs[b[2]] unchanged. If examples_all
        # also holds an entry for this molecule, later molecules with the same bit
        # would be paired with the wrong examples - confirm against getFragmentForMolBit.
        continue # Unkekulizable; skip
    # Positional lookup: assumes entries of examples_all[bit] are in the same order as
    # the bits_to_draw entries for that bit - TODO confirm.
    mol_examples = list(examples_all[b[2]].values())[ctrs[b[2]]]
    ctrs[b[2]] += 1
    for ex_num in mol_examples:
        d = Draw.DrawMorganBit(b[1], b[2], b[3], whichExample=ex_num, aromaticColor=None, ringColor=None)
        d.save(pathify(f'triazine_SIRT2_FP-FFNN_random_seed_0_cpd_id_{b[0]}_bit_{b[2]}_example_{ex_num}.png'), 'PNG')

## Substructure analysis

### Number substructures

In [None]:
# numbering corresponds to ordering of bars from left to right on plot of substructure weights (see below)
submol_to_id = {}
ctr = 1
for bit in bits_of_interest:
    freq_for_bit = submol_freq_distrib[bit]
    # most frequent substructure first within each bit
    submols = sorted(freq_for_bit, key=freq_for_bit.get, reverse=True)
    for s in submols:
        submol_to_id[s] = ctr
        ctr += 1

In [None]:
for item in submol_to_id.items():
    print(item[1])
    print(item[0])
    print(np.squeeze(smarts_to_smis[item[0]]))
    print()

### Calculate substructure weights

In [None]:
# Weight of each substructure: second element of GetWeightsForSubstructure's
# result, computed from the molecule-level weights of the bit it sets.
substruct_to_weight = {}
for submol, cpd_indices in submol_to_cpd_indices.items():
    bit = int(np.squeeze(submol_to_bit[submol]))
    unique_cpd_indices = list(set(cpd_indices))
    mol_level_weights = bit_to_weights[bit][0]
    substruct_to_weight[submol] = GetWeightsForSubstructure(mol_level_weights, unique_cpd_indices)[1]

In [None]:
# Print substructure weights in substructure-ID order (the left-to-right bar order
# of the plot below), matching the printout in the triazine sEH section, rather
# than in dict insertion order.
for item in sorted(substruct_to_weight.items(), key=lambda kv: submol_to_id[kv[0]]):
    print(f'Substructure (SMARTS): {item[0]}')
    print()
    print(f'Weight: {item[1]}')
    print()
    print()

### Plot substructure weights

In [None]:
# Substructures in plotting order, and their weights in the same order.
substructs = sorted(substruct_to_weight, key=submol_to_id.__getitem__)
weights = [substruct_to_weight[s] for s in substructs]

# One list of bar heights per bit of interest, keeping the substructure-ID order within a bit.
bars = [
    [substruct_to_weight[s] for s in substructs if s in submol_freq_distrib[bit]]
    for bit in bits_of_interest
]

# Layout constants for the grouped bar plot (x-axis data units).
barWidth, offset, space = 2, 2.2, 4.5

pos = len(bars[0])*barWidth / 2

In [None]:
# x-positions of the bars, one list per bit of interest (same nesting as `bars`).
#
# Positions are derived from the sizes of the groups in `bars` instead of being
# written out per group, and the cell no longer advances the global `pos`.
# It can therefore be re-run without shifting the bars, and it stays correct if
# the number of substructures per bit changes.
#
# Group centres are `group_step` apart, starting at the same first centre that is
# used for the x ticks in the plotting cell; bars within a group are centred on
# the group centre with their centres `offset` apart.
group_step = space + 2*barWidth + (offset-barWidth)
first_center = len(bars[0])*barWidth / 2
rs = [
    [
        first_center + group_idx*group_step + (bar_idx - (len(group_bars) - 1) / 2)*offset
        for bar_idx in range(len(group_bars))
    ]
    for group_idx, group_bars in enumerate(bars)
]

In [None]:
fig = plt.figure(figsize=(7, 4), dpi=300)

# Draw one bar per substructure; rs[g][k] is the x-position of bar k in bit group g.
rects = []
for group_idx in range(len(bars)):
    for bar_idx, x_pos in enumerate(rs[group_idx]):
        container = plt.bar(x_pos, bars[group_idx][bar_idx], color='#1f77b4', width=barWidth, zorder=2)
        rects.append(container[0])

# Label the first 16 bars with their height: bars 0-12 with a +3 pt label offset,
# bars 13-15 with a -8 pt offset.
for rect_idx in range(16):
    rect = rects[rect_idx]
    height = rect.get_height()
    y_offset = 3 if rect_idx < 13 else -8
    plt.annotate('{:.3f}'.format(height),
                 xy=(rect.get_x() + rect.get_width() / 2, height),
                 xytext=(0, y_offset),
                 textcoords="offset points",
                 ha='center', va='bottom',
                 fontsize=4.5)

fig.canvas.draw()
ax = plt.gca()
ax.grid(zorder=1)
ax.set_xlabel('Bit ID', fontsize=9)
ax.set_ylabel('Substructure weight', fontsize=9)
ax.tick_params(labelsize=9)
# one tick per bit group, at the group centre
first_tick = len(bars[0])*barWidth / 2
tick_step = space + 2*barWidth + (offset-barWidth)
ax.set_xticks([first_tick + i*tick_step for i in range(len(bars))])
ax.set_xticklabels(list(map(str, bits_of_interest)), ha='center')
ax.tick_params(axis='x', length=0)
plt.tight_layout()
plt.savefig(pathify('substructure_weights_triazine_SIRT2_FP-FFNN_random_seed_0.png'))
plt.show()

# DD1S CAIX (random split, seed 1)

## Load data

In [None]:
os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'
df_data = pd.read_csv(os.path.join(DELQSAR_ROOT, 'experiments', 'datasets', 'DD1S_CAIX_QSAR.csv'))
# Open read-only and use a context manager so the HDF5 handle is released even
# if reading the dataset raises.
with h5py.File(os.path.join(DELQSAR_ROOT, 'experiments', DD1S_FINGERPRINTS_FILENAME), 'r') as hf:
    x = np.array(hf['all_fps'])
INPUT_SIZE = x.shape[1]  # fingerprint length; used as the model's input size

## Load model

In [None]:
model = models.MLP(INPUT_SIZE, [128, 64, 32],
            dropout=0.4)
# map_location='cpu' lets the checkpoint load on a machine without the device it
# was saved from; the model is moved to DEVICE afterwards if one is set.
model.load_state_dict(torch.load(DD1S_CAIX_RANDOM_SPLIT_FP_FFNN_SEED_1_MODEL_PATH, map_location='cpu'))
if DEVICE:
    model = model.to(DEVICE)

## Bit analysis

### Calculate bit weights

In [None]:
# Bits set by at least one molecule. The per-column check is vectorised
# (the builtin any() walks the numpy column element by element in Python), and
# the bit count comes from x rather than a hard-coded 2048.
set_bit_ids = [bit_id for bit_id in tqdm(range(x.shape[1])) if (x[:, bit_id] == 1).any()]
print(f'Number of bits set by at least one molecule in the data set: {len(set_bit_ids)}')

In [None]:
# check if any bit is set by exactly one molecule
# for bit_id in tqdm(range(2048)):
#     indices = np.squeeze(np.where(x[:,bit_id]==1))
#     if indices.shape == ():
#         print(f'Bit {bit_id} is set by only one molecule')
#         break

In [None]:
basePreds_all = np.array(model.predict_on_x(x, device=DEVICE), dtype='float64')
bit_to_weights = {bit_id: GetWeightsForBit(bit_id) for bit_id in tqdm(set_bit_ids)}
hkl.dump(bit_to_weights, 'bit_to_weights_DD1S_CAIX_FP-FFNN_random_seed_1.hkl', mode='w')

In [None]:
# bit_to_weights_hkl = hkl.load('bit_to_weights_DD1S_CAIX_FP-FFNN_random_seed_1.hkl')
# bit_to_weights = {int(bit): weights for bit, weights in bit_to_weights_hkl.items()}

### Plot distribution of average bit weights

In [None]:
# Average weight of every bit (second element of each bit_to_weights entry).
avg_bit_weights = [bit_weights[1] for bit_weights in bit_to_weights.values()]
lowest_avg_weight, highest_avg_weight = min(avg_bit_weights), max(avg_bit_weights)
print(f'Lowest average bit weight: {lowest_avg_weight}')
print(f'Highest average bit weight: {highest_avg_weight}')

In [None]:
def make_hist_avg_bit_weights(zoomIn=False):
    """Plot, save and show a histogram of the average bit weights (seed-1 model).

    Reads the notebook-level ``avg_bit_weights`` (one value per bit set by at least one
    molecule). Values are clipped to [-0.07, last bin edge] so outliers land in the edge bins.

    zoomIn: when True, use 0.003-wide bins instead of 0.005, cap the y-axis at 8 and
        save under the "zoomed_in" file name.
    """
    lower_edge = -0.07
    bin_width = 0.003 if zoomIn else 0.005
    bin_edges = np.arange(lower_edge, 0.21, bin_width)
    out_name = ('bit_weight_histogram_zoomed_in_DD1S_CAIX_FP-FFNN_random_seed_1.png' if zoomIn
                else 'bit_weight_histogram_DD1S_CAIX_FP-FFNN_random_seed_1.png')

    fig = plt.figure(figsize=(3.5, 1.6), dpi=300)
    clipped = np.clip(avg_bit_weights, lower_edge, bin_edges[-1])
    plt.hist(clipped, bins=bin_edges, density=False, zorder=2)
    fig.canvas.draw()  # required to get tick labels
    ax = plt.gca()
    if zoomIn:
        ax.set_ylim([0, 8])
    ax.grid(zorder=1)
    ax.set_xlabel('Average weight')
    ax.set_ylabel('Number of bits')
    plt.tight_layout()
    plt.savefig(pathify(out_name))
    plt.show()

In [None]:
make_hist_avg_bit_weights()

In [None]:
make_hist_avg_bit_weights(zoomIn=True)

### Plot distributions of molecule-level bit weights

In [None]:
def make_hist_mol_level_bit_weights(bit_id, x_lb, x_ub, stepsize_noZoom, stepsize_zoom, zoomIn=False, xticks=None):
    """Plot, save and show a histogram of the molecule-level weights of one bit (seed-1 model).

    Reads ``bit_to_weights[bit_id][0]`` from the notebook namespace.

    bit_id: fingerprint bit whose molecule-level weights are plotted.
    x_lb, x_ub: start/stop passed to ``np.arange`` for the bin edges; weights are clipped to
        [x_lb, last bin edge].
    stepsize_noZoom, stepsize_zoom: bin width used when ``zoomIn`` is False / True.
    zoomIn: when True, also cap the y-axis at 8 and save under the "zoomed_in" file name.
    xticks: optional sequence (list or array) of x tick positions; None or empty keeps the defaults.

    Raises ValueError if the bin width selected by ``zoomIn`` is None.
    """
    weights = bit_to_weights[bit_id][0]
    step = stepsize_zoom if zoomIn else stepsize_noZoom
    if step is None:
        # np.arange would silently fall back to a step of 1 and produce a meaningless histogram.
        raise ValueError('bin width for the selected zoom mode is None')
    fig = plt.figure(figsize=(3.5, 1.6), dpi=300)
    bins = np.arange(x_lb, x_ub, step)
    plt.hist(
        np.clip(weights, x_lb, bins[-1]),
        bins=bins,
        density=False,
        zorder=2
    )
    fig.canvas.draw() # required to get tick labels
    ax = plt.gca()
    if zoomIn:
        ax.set_ylim([0, 8])
    ax.grid(zorder=1)
    ax.tick_params(labelsize=8)
    # Length check instead of truthiness so array-like tick positions do not raise
    # "truth value of an array is ambiguous"; None / empty still keep the default ticks.
    if xticks is not None and len(xticks) > 0:
        plt.xticks(xticks)
    ax.set_xlabel('Weight', fontsize=8)
    ax.set_ylabel('Number of molecules', fontsize=8)
    ax.set_title(f'Bit {bit_id}', fontsize=8)
    plt.tight_layout()
    if zoomIn:
        plt.savefig(pathify(f'mol-level_bit_weight_histogram_bit_{bit_id}_zoomed_in_DD1S_CAIX_FP-FFNN_random_seed_1.png'))
    else:
        plt.savefig(pathify(f'mol-level_bit_weight_histogram_bit_{bit_id}_DD1S_CAIX_FP-FFNN_random_seed_1.png'))
    plt.show()

In [None]:
bit = 1489
mol_weights = bit_to_weights[bit][0]  # molecule-level weights for this bit
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(mol_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_weights)}')

# Histogram range and bin width are hand-set for this bit.
make_hist_mol_level_bit_weights(bit, x_lb=0.02, x_ub=0.26, stepsize_noZoom=0.004, stepsize_zoom=None)

In [None]:
bit = 833
mol_weights = bit_to_weights[bit][0]  # molecule-level weights for this bit
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(mol_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_weights)}')

# Histogram range and bin width are hand-set for this bit.
make_hist_mol_level_bit_weights(bit, x_lb=0.02, x_ub=0.27, stepsize_noZoom=0.004, stepsize_zoom=None)

In [None]:
bit = 1785
mol_weights = bit_to_weights[bit][0]  # molecule-level weights for this bit
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(mol_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_weights)}')

# Histogram range and bin width are hand-set for this bit.
make_hist_mol_level_bit_weights(bit, x_lb=0, x_ub=0.27, stepsize_noZoom=0.004, stepsize_zoom=None)

In [None]:
bit = 997
mol_weights = bit_to_weights[bit][0]  # molecule-level weights for this bit
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(mol_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_weights)}')

# Histogram range and bin width are hand-set for this bit.
make_hist_mol_level_bit_weights(bit, x_lb=0.01, x_ub=0.27, stepsize_noZoom=0.004, stepsize_zoom=None)

In [None]:
bit = 1197
mol_weights = bit_to_weights[bit][0]  # molecule-level weights for this bit
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(mol_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_weights)}')

# Histogram range and bin width are hand-set for this bit.
make_hist_mol_level_bit_weights(bit, x_lb=0, x_ub=0.25, stepsize_noZoom=0.004, stepsize_zoom=None)

In [None]:
bit = 1736
mol_weights = bit_to_weights[bit][0]  # molecule-level weights for this bit
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(mol_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_weights)}')

# Histogram range, bin width and tick positions are hand-set for this bit.
make_hist_mol_level_bit_weights(bit, x_lb=-0.19, x_ub=-0.01, stepsize_noZoom=0.003, stepsize_zoom=None,
                                xticks=[-0.175, -0.125, -0.075, -0.025])

In [None]:
bit = 258
mol_weights = bit_to_weights[bit][0]  # molecule-level weights for this bit
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(mol_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_weights)}')

# Histogram range and bin width are hand-set for this bit.
make_hist_mol_level_bit_weights(bit, x_lb=-0.1, x_ub=-0.02, stepsize_noZoom=0.0018, stepsize_zoom=None)

In [None]:
bit = 1165
mol_weights = bit_to_weights[bit][0]  # molecule-level weights for this bit
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(mol_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_weights)}')

# Histogram range and bin width are hand-set for this bit.
make_hist_mol_level_bit_weights(bit, x_lb=-0.13, x_ub=-0.02, stepsize_noZoom=0.0025, stepsize_zoom=None)

### Get bits of interest (based on average bit weight)

In [None]:
# Rank the set bits by average weight, ascending.
bits_sorted_by_avg_weight = sorted(set_bit_ids, key=lambda bit_id: bit_to_weights[bit_id][1])
# Five highest-weight bits, highest first.
top_bits = bits_sorted_by_avg_weight[-5:][::-1]
# Three lowest-weight bits, reversed so the lowest comes last.
bottom_bits = bits_sorted_by_avg_weight[:3][::-1]
print(f'Top bits: {top_bits}')
print(f'Bottom bits: {bottom_bits}')

In [None]:
# Bits analysed below: top bits first, then bottom bits.
bits_of_interest = top_bits + bottom_bits
for b in bits_of_interest:
    print(f'Bit ID: {b}')
    print(f'Average weight: {bit_to_weights[b][1]}')
    # count_nonzero also works when exactly one molecule sets the bit; the previous
    # len(np.squeeze(np.where(...))) raised TypeError there because squeeze yields a 0-d array.
    print(f'Number of molecules with the bit: {np.count_nonzero(x[:, b] == 1)}')
    print()

In [None]:
# Row positions (into x / df_data) of the molecules that set each bit of interest.
# flatnonzero always returns a 1-D index array, so a bit set by exactly one molecule no longer
# fails (iterating the squeezed 0-d result of np.where raised TypeError).
bit_to_cpd_row_indices = {bit: list(np.flatnonzero(x[:, bit] == 1)) for bit in bits_of_interest}
for item in bit_to_cpd_row_indices.items():
    print(f'Bit ID: {item[0]}')
    print(f'Number of molecules with the bit: {len(item[1])}')
    print()

### Get and visualize substructures

In [None]:
# Collect, for every bit of interest, the substructures that set it.
# All accumulators below are passed to getFragmentForMolBit and replaced by the versions it returns.

# Entries are later read as (cpd_id, mol, bit, bitInfo): index 0 and 2 feed the legends and
# entry[1:] is handed to Draw.DrawMorganBits.
bits_to_draw = []
submol_freq_distrib = {bit: {} for bit in bits_of_interest} # store frequency distribution of substructures that 
                                                            # set each bit
smarts_to_smis = {}
submol_to_cpd_indices = {} # mapping to indices in df_data_hasbit
submol_to_bit = {}
examples_all = {bit: {} for bit in bits_of_interest} # check if there's more than one distinct bit-setting 
                                                     # substructure in the same molecule

for bit in tqdm(bits_of_interest):
    # Rows of df_data (by position) whose fingerprint has this bit set.
    df_data_hasbit = df_data.iloc[bit_to_cpd_row_indices[bit]]
    smis = df_data_hasbit['smiles']
    featurizer = featurizers.FingerprintFeaturizer()
    # info_all is indexed by position i within df_data_hasbit and tested for `bit` membership below;
    # presumably one Morgan bitInfo mapping per molecule — TODO confirm against prepare_x.
    _, info_all = featurizer.prepare_x(df_data_hasbit, bitInfo=True)
    for i, smi in enumerate(smis):
        mol = Chem.MolFromSmiles(smi)
        atomSymbols = getMorganFingerprintAtomSymbols(mol)
        # Takes the cpd_id of the FIRST row in the full table with this SMILES.
        # NOTE(review): if SMILES are not unique this may differ from row i of df_data_hasbit — confirm.
        cpd_id = int(df_data[df_data['smiles']==smi]['cpd_id'].to_numpy()[0])
        # Skip molecules whose bit info does not list this bit.
        if bit not in info_all[i]:
            continue
        fragment_logs = getFragmentForMolBit(smi, mol, i, atomSymbols, cpd_id, bit, info_all, submol_freq_distrib, 
                        smarts_to_smis, submol_to_cpd_indices, submol_to_bit, bits_to_draw)
        # The helper returns the updated accumulators in this fixed order, plus `examples` last.
        submol_freq_distrib = fragment_logs[0]
        smarts_to_smis = fragment_logs[1]
        submol_to_cpd_indices = fragment_logs[2]
        submol_to_bit = fragment_logs[3]
        bits_to_draw = fragment_logs[4]
        examples = fragment_logs[5]
        # Only non-empty example lists are recorded, keyed by position i within df_data_hasbit.
        if examples:
            examples_all[bit][i] = examples

In [None]:
# Cache every substructure-analysis result in the working directory, in a fixed order.
for _obj, _fname in (
    (submol_freq_distrib, 'submol_freq_distrib_DD1S_CAIX_FP-FFNN_random_seed_1.hkl'),
    (smarts_to_smis, 'smarts_to_smis_DD1S_CAIX_FP-FFNN_random_seed_1.hkl'),
    (submol_to_cpd_indices, 'submol_to_cpd_indices_DD1S_CAIX_FP-FFNN_random_seed_1.hkl'),
    (submol_to_bit, 'submol_to_bit_DD1S_CAIX_FP-FFNN_random_seed_1.hkl'),
    (bits_to_draw, 'bits_to_draw_DD1S_CAIX_FP-FFNN_random_seed_1.hkl'),
    (examples_all, 'examples_all_DD1S_CAIX_FP-FFNN_random_seed_1.hkl'),
):
    hkl.dump(_obj, _fname, mode='w')

In [None]:
# submol_freq_distrib_hkl = hkl.load('submol_freq_distrib_DD1S_CAIX_FP-FFNN_random_seed_1.hkl')
# submol_freq_distrib = {int(bit): d for bit, d in submol_freq_distrib_hkl.items()}
# smarts_to_smis = hkl.load('smarts_to_smis_DD1S_CAIX_FP-FFNN_random_seed_1.hkl')
# submol_to_cpd_indices = hkl.load('submol_to_cpd_indices_DD1S_CAIX_FP-FFNN_random_seed_1.hkl')
# submol_to_bit = hkl.load('submol_to_bit_DD1S_CAIX_FP-FFNN_random_seed_1.hkl')
# examples_all_hkl = hkl.load('examples_all_DD1S_CAIX_FP-FFNN_random_seed_1.hkl')
# examples_all = {int(bit): {int(mol_idx): [int(ex_num) for ex_num in ex_nums] for mol_idx, ex_nums in exs.items()} for bit, exs in examples_all_hkl.items()}
# bits_to_draw_hkl = hkl.load('bits_to_draw_DD1S_CAIX_FP-FFNN_random_seed_1.hkl')
# bits_to_draw = [(int(item[0]), Chem.MolFromSmiles(df_data.iloc[int(item[0])-1]['smiles']), int(item[2]), 
#         {int(bit): tuple([(int(an[0]), int(an[1])) for an in ans]) for bit, ans in item[3].items()}) for item in bits_to_draw_hkl]

In [None]:
examples_all

In [None]:
# List every substructure SMARTS with the SMILES recorded for it, then the distinct counts.
for smarts_key, smi_list in smarts_to_smis.items():
    print(f'SMARTS: {smarts_key}')
    print(f'SMILES: {np.squeeze(smi_list)}')
    print()
print()
distinct_smis = {smi_entry for smi_list in smarts_to_smis.values() for smi_entry in smi_list}
print(f'Number of distinct SMARTS: {len(smarts_to_smis)}')
print(f'Number of distinct SMILES: {len(distinct_smis)}')

In [None]:
# Per bit: each bit-setting substructure and how often it occurs.
for bit, freq_by_submol in submol_freq_distrib.items():
    print(f'Bit ID: {bit}')
    for submol, freq in freq_by_submol.items():
        print(f'Substructure (SMARTS): {submol}')
        print(f'Frequency: {freq}')
    print()

In [None]:
# Drop the leading cpd_id from every entry; the remainder is what the drawing call receives.
_bits_to_draw = [entry[1:] for entry in bits_to_draw]

In [None]:
# visualize each substructure; every legend shows the compound id and the bit
legends = [f'cpd_id {entry[0]}, bit: {entry[2]}' for entry in bits_to_draw]
d = Draw.DrawMorganBits(
    _bits_to_draw,
    molsPerRow=4,
    aromaticColor=None,
    ringColor=None,
    legends=legends,
    subImgSize=(600, 600),
)
d.save(pathify('bits_visualization_DD1S_CAIX_FP-FFNN_random_seed_1.png'))
d

## Substructure analysis

### Number substructures

In [None]:
# numbering corresponds to ordering of bars from left to right on plot of substructure weights (see below)
submol_to_id = {}
ctr = 1
for bit in bits_of_interest:
    freq_by_submol = submol_freq_distrib[bit]
    # Most frequent substructure first; ties keep insertion order because sorted() is stable.
    for s in sorted(freq_by_submol, key=freq_by_submol.get, reverse=True):
        submol_to_id[s] = ctr
        ctr += 1

In [None]:
# Substructure number, its SMARTS, and the SMILES recorded for it.
for smarts_key, submol_id in submol_to_id.items():
    print(submol_id)
    print(smarts_key)
    print(np.squeeze(smarts_to_smis[smarts_key]))
    print()

### Calculate substructure weights

In [None]:
# Weight of each substructure, computed from the molecule-level weights of the bit it sets.
substruct_to_weight = {}
for submol, cpd_indices in submol_to_cpd_indices.items():
    bit = int(np.squeeze(submol_to_bit[submol]))
    mol_level_weights = bit_to_weights[bit][0]
    unique_cpd_indices = list(set(cpd_indices))  # de-duplicate before handing to the helper
    # Only element [1] of the helper's result is kept.
    substruct_to_weight[submol] = GetWeightsForSubstructure(mol_level_weights, unique_cpd_indices)[1]

In [None]:
# One blank line after the SMARTS, two after the weight.
for smarts_key, weight_value in substruct_to_weight.items():
    print(f'Substructure (SMARTS): {smarts_key}', end='\n\n')
    print(f'Weight: {weight_value}', end='\n\n\n')

### Plot substructure weights

In [None]:
# Substructures in plot order (by their assigned number) and their weights.
substructs = sorted(substruct_to_weight, key=lambda smarts: submol_to_id[smarts])
weights = [substruct_to_weight[s] for s in substructs]

# One list of bar heights per bit of interest, keeping the global plot order within each bit.
bars = []
for bit in bits_of_interest:
    smarts_for_bit = submol_freq_distrib[bit]
    bars.append([substruct_to_weight[s] for s in substructs if s in smarts_for_bit])

# Layout constants for the grouped bar chart.
barWidth = 1.5   # width of one bar
offset = 1.7     # centre-to-centre distance of bars inside a group
space = 4        # extra gap between groups
pos = len(bars[0])*barWidth / 2   # x centre of the first group

In [None]:
# x positions of the bars, one list per bit group, derived from the actual group sizes in `bars`
# instead of a hard-coded (2, 1, 3, 3, 3, 1, 1, 1) layout. Nothing is mutated, so re-running this
# cell gives the same positions (the old version kept advancing `pos`).
first_center = len(bars[0])*barWidth / 2
# Centre-to-centre distance between neighbouring groups (same step the x ticks are placed with).
group_step = space + 1.5*barWidth + (offset-barWidth)
rs = [
    # Bars of a group are spaced `offset` apart and centred on the group centre.
    [first_center + g*group_step + (j - (len(group) - 1) / 2) * offset for j in range(len(group))]
    for g, group in enumerate(bars)
]

In [None]:
# Grouped bar chart of substructure weights, one group per bit of interest.
fig = plt.figure(figsize=(7, 4), dpi=300)

rects = []
for bit in range(len(bars)):
    for i in range(len(rs[bit])):
        rects.append(plt.bar(rs[bit][i], bars[bit][i], color='#1f77b4', width=barWidth, zorder=2)[0])

for rect in rects:
    height = rect.get_height()
    # Label above positive bars and below negative ones, decided by each bar's own sign rather
    # than by assuming that exactly the last three bars are the negative ones.
    label_offset = 3 if height >= 0 else -8
    plt.annotate('{:.3f}'.format(height),
                xy=(rect.get_x() + rect.get_width() / 2, height),
                xytext=(0, label_offset),
                textcoords="offset points",
                ha='center', va='bottom',
                fontsize=4.5)

fig.canvas.draw()
ax = plt.gca()
ax.grid(zorder=1)
ax.set_xlabel('Bit ID', fontsize=9)
ax.set_ylabel('Substructure weight', fontsize=9)
ax.tick_params(labelsize=9)
# One tick per group, at the group centre.
ax.set_xticks([((len(bars[0])*barWidth / 2) + i*(space + 1.5*barWidth + (offset-barWidth))) for i in range(len(bars))])
ax.set_xticklabels([str(bit_id) for bit_id in bits_of_interest], ha='center')
ax.tick_params(axis='x', length=0)
plt.tight_layout()
plt.savefig(pathify(f'substructure_weights_DD1S_CAIX_FP-FFNN_random_seed_1.png'))
plt.show()

# DD1S CAIX (random split, seed 2)

## Load data

In [None]:
# Load the DD1S CAIX data table and its precomputed fingerprint matrix.
# File locking is disabled before the HDF5 file is opened.
os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE'
df_data = pd.read_csv(os.path.join(DELQSAR_ROOT, 'experiments', 'datasets', 'DD1S_CAIX_QSAR.csv'))
# Open read-only inside a context manager so the handle is released even if
# reading 'all_fps' raises (and a missing file is never created by accident).
with h5py.File(os.path.join(DELQSAR_ROOT, 'experiments', DD1S_FINGERPRINTS_FILENAME), 'r') as hf:
    x = np.array(hf['all_fps'])
INPUT_SIZE = x.shape[1]  # number of fingerprint columns; used as the model input width

## Load model

In [None]:
# Rebuild the network before loading the seed-2 checkpoint.
# NOTE(review): [128, 128, 128] / dropout=0.05 are presumably the hidden-layer sizes and dropout
# rate the checkpoint was trained with — they must match the saved state dict.
model = models.MLP(INPUT_SIZE, [128, 128, 128],
            dropout=0.05)
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only machine;
# the model is moved to DEVICE afterwards when one is configured.
model.load_state_dict(torch.load(DD1S_CAIX_RANDOM_SPLIT_FP_FFNN_SEED_2_MODEL_PATH, map_location='cpu'))
if DEVICE:
    model = model.to(DEVICE)

## Bit analysis

### Calculate bit weights

In [None]:
# Bits set by at least one molecule. np.any reduces each column in C instead of the builtin
# any() walking the column element by element, and the bit count comes from the matrix itself.
set_bit_ids = [bit_id for bit_id in tqdm(range(x.shape[1])) if np.any(x[:, bit_id] == 1)]
print(f'Number of bits set by at least one molecule in the data set: {len(set_bit_ids)}')

In [None]:
# check if any bit is set by exactly one molecule
# for bit_id in tqdm(range(2048)):
#     indices = np.squeeze(np.where(x[:,bit_id]==1))
#     if indices.shape == ():
#         print(f'Bit {bit_id} is set by only one molecule')
#         break

In [None]:
# Predictions of the loaded model on the full fingerprint matrix, stored as float64.
basePreds_all = np.array(model.predict_on_x(x, device=DEVICE), dtype='float64')
# NOTE(review): GetWeightsForBit is defined elsewhere in the notebook and receives only the bit id,
# so it presumably reads the current globals (x, model, basePreds_all) — confirm before reordering cells.
# Later cells index each entry as [0] (molecule-level weights) and [1] (average weight).
bit_to_weights = {bit_id: GetWeightsForBit(bit_id) for bit_id in tqdm(set_bit_ids)}
# Cached in the current working directory (not under pathify()); the commented-out cell below reloads it.
hkl.dump(bit_to_weights, 'bit_to_weights_DD1S_CAIX_FP-FFNN_random_seed_2.hkl', mode='w')

In [None]:
# bit_to_weights_hkl = hkl.load('bit_to_weights_DD1S_CAIX_FP-FFNN_random_seed_2.hkl')
# bit_to_weights = {int(bit): weights for bit, weights in bit_to_weights_hkl.items()}

### Plot distribution of average bit weights

In [None]:
# Element [1] of each bit's entry is its average weight.
avg_bit_weights = [entry[1] for entry in bit_to_weights.values()]
lowest_avg, highest_avg = min(avg_bit_weights), max(avg_bit_weights)
print(f'Lowest average bit weight: {lowest_avg}')
print(f'Highest average bit weight: {highest_avg}')

In [None]:
def make_hist_avg_bit_weights(zoomIn=False):
    """Plot, save and show a histogram of the average bit weights (seed-2 model).

    Reads the notebook-level ``avg_bit_weights`` (one value per bit set by at least one
    molecule). Values are clipped to [-0.06, last bin edge] so outliers land in the edge bins.
    The bin range is set based on the range of the average bit weights.

    zoomIn: when True, use 0.003-wide bins instead of 0.0057, cap the y-axis at 8 and
        save under the "zoomed_in" file name.
    """
    lower_edge = -0.06
    bin_width = 0.003 if zoomIn else 0.0057
    bin_edges = np.arange(lower_edge, 0.32, bin_width)
    out_name = ('bit_weight_histogram_zoomed_in_DD1S_CAIX_FP-FFNN_random_seed_2.png' if zoomIn
                else 'bit_weight_histogram_DD1S_CAIX_FP-FFNN_random_seed_2.png')

    fig = plt.figure(figsize=(3.5, 1.6), dpi=300)
    clipped = np.clip(avg_bit_weights, lower_edge, bin_edges[-1])
    plt.hist(clipped, bins=bin_edges, density=False, zorder=2)
    fig.canvas.draw()  # required to get tick labels
    ax = plt.gca()
    if zoomIn:
        ax.set_ylim([0, 8])
    ax.grid(zorder=1)
    ax.set_xlabel('Average weight')
    ax.set_ylabel('Number of bits')
    plt.tight_layout()
    plt.savefig(pathify(out_name))
    plt.show()

In [None]:
make_hist_avg_bit_weights()

In [None]:
make_hist_avg_bit_weights(zoomIn=True)

### Get bits of interest (based on average bit weight)

In [None]:
# Rank the set bits by average weight, ascending.
bits_sorted_by_avg_weight = sorted(set_bit_ids, key=lambda bit_id: bit_to_weights[bit_id][1])
# Five highest-weight bits, highest first.
top_bits = bits_sorted_by_avg_weight[-5:][::-1]
# Three lowest-weight bits, reversed so the lowest comes last.
bottom_bits = bits_sorted_by_avg_weight[:3][::-1]
print(f'Top bits: {top_bits}')
print(f'Bottom bits: {bottom_bits}')

In [None]:
# Bits analysed below: top bits first, then bottom bits.
bits_of_interest = top_bits + bottom_bits
for b in bits_of_interest:
    print(f'Bit ID: {b}')
    print(f'Average weight: {bit_to_weights[b][1]}')
    # count_nonzero also works when exactly one molecule sets the bit; the previous
    # len(np.squeeze(np.where(...))) raised TypeError there because squeeze yields a 0-d array.
    print(f'Number of molecules with the bit: {np.count_nonzero(x[:, b] == 1)}')
    print()

In [None]:
# Row positions (into x / df_data) of the molecules that set each bit of interest.
# flatnonzero always returns a 1-D index array, so a bit set by exactly one molecule no longer
# fails (iterating the squeezed 0-d result of np.where raised TypeError).
bit_to_cpd_row_indices = {bit: list(np.flatnonzero(x[:, bit] == 1)) for bit in bits_of_interest}
for item in bit_to_cpd_row_indices.items():
    print(f'Bit ID: {item[0]}')
    print(f'Number of molecules with the bit: {len(item[1])}')
    print()

### Plot distributions of molecule-level bit weights

In [None]:
def make_hist_mol_level_bit_weights(bit_id, x_lb, x_ub, stepsize_noZoom, stepsize_zoom, zoomIn=False, xticks=None):
    """Plot, save and show a histogram of the molecule-level weights of one bit (seed-2 model).

    Reads ``bit_to_weights[bit_id][0]`` from the notebook namespace.

    bit_id: fingerprint bit whose molecule-level weights are plotted.
    x_lb, x_ub: start/stop passed to ``np.arange`` for the bin edges; weights are clipped to
        [x_lb, last bin edge].
    stepsize_noZoom, stepsize_zoom: bin width used when ``zoomIn`` is False / True.
    zoomIn: when True, also cap the y-axis at 8 and save under the "zoomed_in" file name.
    xticks: optional sequence (list or array) of x tick positions; None or empty keeps the defaults.

    Raises ValueError if the bin width selected by ``zoomIn`` is None.
    """
    weights = bit_to_weights[bit_id][0]
    step = stepsize_zoom if zoomIn else stepsize_noZoom
    if step is None:
        # np.arange would silently fall back to a step of 1 and produce a meaningless histogram.
        raise ValueError('bin width for the selected zoom mode is None')
    fig = plt.figure(figsize=(3.5, 1.6), dpi=300)
    bins = np.arange(x_lb, x_ub, step)
    plt.hist(
        np.clip(weights, x_lb, bins[-1]),
        bins=bins,
        density=False,
        zorder=2
    )
    fig.canvas.draw() # required to get tick labels
    ax = plt.gca()
    if zoomIn:
        ax.set_ylim([0, 8])
    ax.grid(zorder=1)
    ax.tick_params(labelsize=8)
    # Length check instead of truthiness so array-like tick positions do not raise
    # "truth value of an array is ambiguous"; None / empty still keep the default ticks.
    if xticks is not None and len(xticks) > 0:
        plt.xticks(xticks)
    ax.set_xlabel('Weight', fontsize=8)
    ax.set_ylabel('Number of molecules', fontsize=8)
    ax.set_title(f'Bit {bit_id}', fontsize=8)
    plt.tight_layout()
    if zoomIn:
        plt.savefig(pathify(f'mol-level_bit_weight_histogram_bit_{bit_id}_zoomed_in_DD1S_CAIX_FP-FFNN_random_seed_2.png'))
    else:
        plt.savefig(pathify(f'mol-level_bit_weight_histogram_bit_{bit_id}_DD1S_CAIX_FP-FFNN_random_seed_2.png'))
    plt.show()

In [None]:
bit = 1489
mol_weights = bit_to_weights[bit][0]  # molecule-level weights for this bit
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(mol_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_weights)}')

# Histogram range and bin width are hand-set for this bit.
make_hist_mol_level_bit_weights(bit, x_lb=0.01, x_ub=0.84, stepsize_noZoom=0.012, stepsize_zoom=None)

In [None]:
bit = 833
mol_weights = bit_to_weights[bit][0]  # molecule-level weights for this bit
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(mol_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_weights)}')

# Histogram range and bin width are hand-set for this bit.
make_hist_mol_level_bit_weights(bit, x_lb=0.02, x_ub=0.87, stepsize_noZoom=0.012, stepsize_zoom=None)

In [None]:
bit = 1785
mol_weights = bit_to_weights[bit][0]  # molecule-level weights for this bit
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(mol_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_weights)}')

# Histogram range and bin width are hand-set for this bit.
make_hist_mol_level_bit_weights(bit, x_lb=0, x_ub=0.82, stepsize_noZoom=0.011, stepsize_zoom=None)

In [None]:
bit = 997
mol_weights = bit_to_weights[bit][0]  # molecule-level weights for this bit
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(mol_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_weights)}')

# Histogram range and bin width are hand-set for this bit.
make_hist_mol_level_bit_weights(bit, x_lb=0.01, x_ub=0.87, stepsize_noZoom=0.012, stepsize_zoom=None)

In [None]:
bit = 1197
mol_weights = bit_to_weights[bit][0]  # molecule-level weights for this bit
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(mol_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_weights)}')

# Histogram range and bin width are hand-set for this bit.
make_hist_mol_level_bit_weights(bit, x_lb=0, x_ub=0.79, stepsize_noZoom=0.012, stepsize_zoom=None)

In [None]:
bit = 258
mol_weights = bit_to_weights[bit][0]  # molecule-level weights for this bit
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(mol_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_weights)}')

# Histogram range and bin width are hand-set for this bit.
make_hist_mol_level_bit_weights(bit, x_lb=-0.26, x_ub=-0.01, stepsize_noZoom=0.004, stepsize_zoom=None)

In [None]:
bit = 1844
mol_weights = bit_to_weights[bit][0]  # molecule-level weights for this bit
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(mol_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_weights)}')

# Histogram range, bin width and tick positions are hand-set for this bit.
make_hist_mol_level_bit_weights(bit, x_lb=-0.17, x_ub=-0.01, stepsize_noZoom=0.00305, stepsize_zoom=None,
                                xticks=[-0.175, -0.125, -0.075, -0.025])

In [None]:
bit = 1165
mol_weights = bit_to_weights[bit][0]  # molecule-level weights for this bit
print(f'Bit {bit}')
print(f'Lowest molecule-level bit weight: {min(mol_weights)}')
print(f'Highest molecule-level bit weight: {max(mol_weights)}')

# Histogram range, bin width and tick positions are hand-set for this bit.
make_hist_mol_level_bit_weights(bit, x_lb=-0.17, x_ub=-0.01, stepsize_noZoom=0.00331, stepsize_zoom=None,
                                xticks=[-0.175, -0.125, -0.075, -0.025])

### Get and visualize substructures

In [None]:
# Collect, for every bit of interest, the substructures that set it.
# All accumulators below are passed to getFragmentForMolBit and replaced by the versions it returns.

# Entries are later read as (cpd_id, mol, bit, bitInfo): index 0 and 2 feed the legends and
# entry[1:] is handed to Draw.DrawMorganBits.
bits_to_draw = []
submol_freq_distrib = {bit: {} for bit in bits_of_interest} # store frequency distribution of substructures that 
                                                            # set each bit
smarts_to_smis = {}
submol_to_cpd_indices = {} # mapping to indices in df_data_hasbit
submol_to_bit = {}
examples_all = {bit: {} for bit in bits_of_interest} # check if there's more than one distinct bit-setting 
                                                     # substructure in the same molecule

for bit in tqdm(bits_of_interest):
    # Rows of df_data (by position) whose fingerprint has this bit set.
    df_data_hasbit = df_data.iloc[bit_to_cpd_row_indices[bit]]
    smis = df_data_hasbit['smiles']
    featurizer = featurizers.FingerprintFeaturizer()
    # info_all is indexed by position i within df_data_hasbit and tested for `bit` membership below;
    # presumably one Morgan bitInfo mapping per molecule — TODO confirm against prepare_x.
    _, info_all = featurizer.prepare_x(df_data_hasbit, bitInfo=True)
    for i, smi in enumerate(smis):
        mol = Chem.MolFromSmiles(smi)
        atomSymbols = getMorganFingerprintAtomSymbols(mol)
        # Takes the cpd_id of the FIRST row in the full table with this SMILES.
        # NOTE(review): if SMILES are not unique this may differ from row i of df_data_hasbit — confirm.
        cpd_id = int(df_data[df_data['smiles']==smi]['cpd_id'].to_numpy()[0])
        # Skip molecules whose bit info does not list this bit.
        if bit not in info_all[i]:
            continue
        fragment_logs = getFragmentForMolBit(smi, mol, i, atomSymbols, cpd_id, bit, info_all, submol_freq_distrib, 
                        smarts_to_smis, submol_to_cpd_indices, submol_to_bit, bits_to_draw)
        # The helper returns the updated accumulators in this fixed order, plus `examples` last.
        submol_freq_distrib = fragment_logs[0]
        smarts_to_smis = fragment_logs[1]
        submol_to_cpd_indices = fragment_logs[2]
        submol_to_bit = fragment_logs[3]
        bits_to_draw = fragment_logs[4]
        examples = fragment_logs[5]
        # Only non-empty example lists are recorded, keyed by position i within df_data_hasbit.
        if examples:
            examples_all[bit][i] = examples

In [None]:
# Cache every substructure-analysis result in the working directory, in a fixed order.
for _obj, _fname in (
    (submol_freq_distrib, 'submol_freq_distrib_DD1S_CAIX_FP-FFNN_random_seed_2.hkl'),
    (smarts_to_smis, 'smarts_to_smis_DD1S_CAIX_FP-FFNN_random_seed_2.hkl'),
    (submol_to_cpd_indices, 'submol_to_cpd_indices_DD1S_CAIX_FP-FFNN_random_seed_2.hkl'),
    (submol_to_bit, 'submol_to_bit_DD1S_CAIX_FP-FFNN_random_seed_2.hkl'),
    (bits_to_draw, 'bits_to_draw_DD1S_CAIX_FP-FFNN_random_seed_2.hkl'),
    (examples_all, 'examples_all_DD1S_CAIX_FP-FFNN_random_seed_2.hkl'),
):
    hkl.dump(_obj, _fname, mode='w')

In [None]:
# submol_freq_distrib_hkl = hkl.load('submol_freq_distrib_DD1S_CAIX_FP-FFNN_random_seed_2.hkl')
# submol_freq_distrib = {int(bit): d for bit, d in submol_freq_distrib_hkl.items()}
# smarts_to_smis = hkl.load('smarts_to_smis_DD1S_CAIX_FP-FFNN_random_seed_2.hkl')
# submol_to_cpd_indices = hkl.load('submol_to_cpd_indices_DD1S_CAIX_FP-FFNN_random_seed_2.hkl')
# submol_to_bit = hkl.load('submol_to_bit_DD1S_CAIX_FP-FFNN_random_seed_2.hkl')
# examples_all_hkl = hkl.load('examples_all_DD1S_CAIX_FP-FFNN_random_seed_2.hkl')
# examples_all = {int(bit): {int(mol_idx): [int(ex_num) for ex_num in ex_nums] for mol_idx, ex_nums in exs.items()} for bit, exs in examples_all_hkl.items()}
# bits_to_draw_hkl = hkl.load('bits_to_draw_DD1S_CAIX_FP-FFNN_random_seed_2.hkl')
# bits_to_draw = [(int(item[0]), Chem.MolFromSmiles(df_data.iloc[int(item[0])-1]['smiles']), int(item[2]), 
#         {int(bit): tuple([(int(an[0]), int(an[1])) for an in ans]) for bit, ans in item[3].items()}) for item in bits_to_draw_hkl]

In [None]:
examples_all

In [None]:
# List every substructure SMARTS with the SMILES recorded for it, then the distinct counts.
for smarts_key, smi_list in smarts_to_smis.items():
    print(f'SMARTS: {smarts_key}')
    print(f'SMILES: {np.squeeze(smi_list)}')
    print()
print()
distinct_smis = {smi_entry for smi_list in smarts_to_smis.values() for smi_entry in smi_list}
print(f'Number of distinct SMARTS: {len(smarts_to_smis)}')
print(f'Number of distinct SMILES: {len(distinct_smis)}')

In [None]:
# Per bit: each bit-setting substructure and how often it occurs.
for bit, freq_by_submol in submol_freq_distrib.items():
    print(f'Bit ID: {bit}')
    for submol, freq in freq_by_submol.items():
        print(f'Substructure (SMARTS): {submol}')
        print(f'Frequency: {freq}')
    print()

In [None]:
# Drop the leading cpd_id from every entry; the remainder is what the drawing call receives.
_bits_to_draw = [entry[1:] for entry in bits_to_draw]

In [None]:
# visualize each substructure; every legend shows the compound id and the bit
legends = [f'cpd_id {entry[0]}, bit: {entry[2]}' for entry in bits_to_draw]
d = Draw.DrawMorganBits(
    _bits_to_draw,
    molsPerRow=4,
    aromaticColor=None,
    ringColor=None,
    legends=legends,
    subImgSize=(600, 600),
)
d.save(pathify('bits_visualization_DD1S_CAIX_FP-FFNN_random_seed_2.png'))
d

## Substructure analysis

### Number substructures

In [None]:
# numbering corresponds to ordering of bars from left to right on plot of substructure weights (see below)
submol_to_id = {}
ctr = 1
for bit in bits_of_interest:
    freq_by_submol = submol_freq_distrib[bit]
    # Most frequent substructure first; ties keep insertion order because sorted() is stable.
    for s in sorted(freq_by_submol, key=freq_by_submol.get, reverse=True):
        submol_to_id[s] = ctr
        ctr += 1

In [None]:
# Substructure number, its SMARTS, and the SMILES recorded for it.
for smarts_key, submol_id in submol_to_id.items():
    print(submol_id)
    print(smarts_key)
    print(np.squeeze(smarts_to_smis[smarts_key]))
    print()

### Calculate substructure weights

In [None]:
# Weight of each substructure, computed from the molecule-level weights of the bit it sets.
substruct_to_weight = {}
for submol, cpd_indices in submol_to_cpd_indices.items():
    bit = int(np.squeeze(submol_to_bit[submol]))
    mol_level_weights = bit_to_weights[bit][0]
    unique_cpd_indices = list(set(cpd_indices))  # de-duplicate before handing to the helper
    # Only element [1] of the helper's result is kept.
    substruct_to_weight[submol] = GetWeightsForSubstructure(mol_level_weights, unique_cpd_indices)[1]

In [None]:
# One blank line after the SMARTS, two after the weight.
for smarts_key, weight_value in substruct_to_weight.items():
    print(f'Substructure (SMARTS): {smarts_key}', end='\n\n')
    print(f'Weight: {weight_value}', end='\n\n\n')

### Plot substructure weights

In [None]:
# Substructures in plot order (by their assigned number) and their weights.
substructs = sorted(substruct_to_weight, key=lambda smarts: submol_to_id[smarts])
weights = [substruct_to_weight[s] for s in substructs]

# One list of bar heights per bit of interest, keeping the global plot order within each bit.
bars = []
for bit in bits_of_interest:
    smarts_for_bit = submol_freq_distrib[bit]
    bars.append([substruct_to_weight[s] for s in substructs if s in smarts_for_bit])

# Layout constants for the grouped bar chart.
barWidth = 1.5   # width of one bar
offset = 1.7     # centre-to-centre distance of bars inside a group
space = 4        # extra gap between groups
pos = len(bars[0])*barWidth / 2   # x centre of the first group

In [None]:
# x positions of the bars, one list per bit group, derived from the actual group sizes in `bars`
# instead of a hard-coded (2, 1, 3, 3, 3, 1, 1, 1) layout copied from another model. Nothing is
# mutated, so re-running this cell gives the same positions (the old version kept advancing `pos`).
first_center = len(bars[0])*barWidth / 2
# Centre-to-centre distance between neighbouring groups (same step the x ticks are placed with).
group_step = space + 1.5*barWidth + (offset-barWidth)
rs = [
    # Bars of a group are spaced `offset` apart and centred on the group centre.
    [first_center + g*group_step + (j - (len(group) - 1) / 2) * offset for j in range(len(group))]
    for g, group in enumerate(bars)
]

In [None]:
# Grouped bar chart of substructure weights, one group per bit of interest.
fig = plt.figure(figsize=(7, 4), dpi=300)

rects = []
for bit in range(len(bars)):
    for i in range(len(rs[bit])):
        rects.append(plt.bar(rs[bit][i], bars[bit][i], color='#1f77b4', width=barWidth, zorder=2)[0])

for rect in rects:
    height = rect.get_height()
    # Label above positive bars and below negative ones, decided by each bar's own sign rather
    # than by assuming that exactly the last three bars are the negative ones.
    label_offset = 3 if height >= 0 else -8
    plt.annotate('{:.3f}'.format(height),
                xy=(rect.get_x() + rect.get_width() / 2, height),
                xytext=(0, label_offset),
                textcoords="offset points",
                ha='center', va='bottom',
                fontsize=4.5)

fig.canvas.draw()
ax = plt.gca()
ax.grid(zorder=1)
ax.set_xlabel('Bit ID', fontsize=9)
ax.set_ylabel('Substructure weight', fontsize=9)
ax.tick_params(labelsize=9)
# One tick per group, at the group centre.
ax.set_xticks([((len(bars[0])*barWidth / 2) + i*(space + 1.5*barWidth + (offset-barWidth))) for i in range(len(bars))])
ax.set_xticklabels([str(bit_id) for bit_id in bits_of_interest], ha='center')
ax.tick_params(axis='x', length=0)
plt.tight_layout()
plt.savefig(pathify(f'substructure_weights_DD1S_CAIX_FP-FFNN_random_seed_2.png'))
plt.show()