In [3]:
import os

# Root of the multi-stage pipeline. Set PIPELINE_BASE_PATH in the environment to run from another
# checkout instead of editing the notebook; the default is the original author's location.
# A trailing separator is enforced because every stage path below is built by string concatenation.
BASE_PATH = os.path.join(os.environ.get('PIPELINE_BASE_PATH', '/Users/morgunov/batista/Summer/pipeline/'), '')
PRETRAINING_PATH = BASE_PATH + '1. Pretraining/'
GENERATION_PATH = BASE_PATH + '2. Generation/'
SAMPLING_PATH = BASE_PATH + '3. Sampling/'
DIFFDOCK_PATH = BASE_PATH + '4. DiffDock/'
SCORING_PATH = BASE_PATH + '5. Scoring/'
AL_PATH = BASE_PATH + '6. ActiveLearning/'
# 'Pretraining' or 'Active Learning': selects which stage directory checkpoints are loaded from.
MODE = 'Active Learning'

# Generation

In [4]:
# @title Run this cell and check all parameters!
CONTEXT = "!" # @param {type:"string"}
TEMPERATURE = 1.0 #@param {type:"slider", min:0, max:2, step:0.1}
VAL_FNAME = "moses_and_binding_no_rare_tokens_test.csv.gz"
LOAD_CKPT_NAME = "model1_softsub_al2.pt" #@param ["GPT_pretrain_07_14_23:39_1end_ignore_moses+bindingdb.pt", "model1_softsub_al1.pt", "model1_softsub_al2.pt"] {allow-input: true}
NUM_TO_GENERATE = 100_000 #@param
# @markdown Please use the following naming scheme: "model1_al{round of AL}"
# GENERATION_FNAME = "model1_al1" #@param {type:"string"}
#
# Pick the stage directory whose model_weights/ folder holds the checkpoint for the current MODE.
if MODE == 'Pretraining':
    _ckpt_root = PRETRAINING_PATH
elif MODE == 'Active Learning':
    _ckpt_root = AL_PATH
else:
    raise KeyError(f'requested {MODE} but only Pretraining and Active Learning are supported')
LOAD_CKPT_PATH = f"{_ckpt_root}model_weights/{LOAD_CKPT_NAME}"

inference_parameters = dict(
    batch_size=64,
    gen_size=NUM_TO_GENERATE,
    generation_context=CONTEXT,
    load_ckpt_path=LOAD_CKPT_PATH,
)
# NOTE(review): CONFIG_DICT and CURRENT_CYCLE_PREFIX are not defined in this part of the notebook
# (the saved output shows a NameError) -- they must come from an earlier cell.
CONFIG_DICT.update(
    generation_path=f"{GENERATION_PATH}smiles/{CURRENT_CYCLE_PREFIX}",
    inference_temp=TEMPERATURE,
)


def _pipeline_relative(path):
    """Drop the first six '/'-separated components so paths print relative to the pipeline root."""
    return '/'.join(path.split('/')[6:])


print("Generation will use the following dataset descriptors\n", _pipeline_relative(CONFIG_DICT['desc_path']))
print("... and following model weights\n", _pipeline_relative(inference_parameters['load_ckpt_path']))
print("... and molecules will be saved to\n", _pipeline_relative(CONFIG_DICT['generation_path']) + f"_temp{TEMPERATURE}_processed" + '.csv')

NameError: name 'CONFIG_DICT' is not defined

In [None]:
def generate_SMILES(config_dict, inference_parameters):
    """Sample SMILES from a GPT checkpoint, then save raw completions and processed molecules.

    Runs ``ceil(gen_size / batch_size)`` sampling batches from the given context. Completions that
    contain the '~' end marker and parse to a molecule via ``get_mol`` are kept. Two CSV files are
    written next to ``config_dict["generation_path"]``:

    * ``<generation_path>_temp<T>_completions.csv`` -- every raw completion string;
    * ``<generation_path>_temp<T>_processed.csv`` -- canonical SMILES of the valid molecules plus
      constant ``validity``, ``unique`` and ``novelty`` columns.

    Parameters
    ----------
    config_dict : dict
        Reads ``desc_path``, ``device``, ``inference_temp``, ``generation_path``, ``train_path`` and
        ``smiles_key``. The whole dict is also forwarded to ``GPTConfig`` as keyword arguments.
    inference_parameters : dict
        Reads ``load_ckpt_path``, ``batch_size``, ``gen_size`` and ``generation_context``.

    Returns
    -------
    pandas.DataFrame
        The processed-molecules table that is also written to disk.
    """
    device = config_dict["device"]
    temperature = config_dict["inference_temp"]
    regex = re.compile(REGEX_PATTERN)
    dataset = SMILESDataset()
    dataset.load_desc_attributes(config_dict["desc_path"])

    mconf = GPTConfig(dataset.vocab_size, dataset.block_size, **config_dict)
    model = GPT(mconf).to(device)
    model.load_state_dict(torch.load(inference_parameters["load_ckpt_path"], map_location=torch.device(device)))
    # Sampling is inference only: switch off train-mode behaviour such as dropout.
    # (The former bare `torch.compile(model)` call discarded its result and therefore did nothing.)
    model.eval()

    block_size = model.get_block_size()
    assert (block_size == dataset.block_size), "Warning: model block size and dataset block size are different"
    batch_size = inference_parameters["batch_size"]
    gen_iter = int(np.ceil(inference_parameters["gen_size"] / batch_size))
    stoi = dataset.stoi  # token string -> integer id
    itos = dataset.itos  # integer id -> token string
    # The context never changes, so tokenize it once; the tensor is still rebuilt per batch in case
    # `sample` modifies its input.
    context_ids = [stoi[token] for token in regex.findall(inference_parameters["generation_context"])]

    molecules = []
    completions = []
    with torch.no_grad():
        for _ in tqdm(range(gen_iter)):
            # (batch_size, len(context)) tensor: the same context repeated for every row of the batch.
            x = torch.tensor(context_ids, dtype=torch.long)[None, ...].repeat(batch_size, 1).to(device)
            y = sample(model, x, block_size, temperature=temperature)
            for gen_mol in y:
                completion = "".join(itos[int(token_id)] for token_id in gen_mol)
                completions.append(completion)
                if "~" not in completion:
                    continue
                # Drop the first character (presumably the one-token start context such as '!') and
                # everything from the '~' end marker on.
                mol = get_mol(completion[1:completion.index("~")])
                if mol is not None:
                    molecules.append(mol)

    temp_suffix = f"_temp{temperature}"
    pd.DataFrame({"smiles": completions}).to_csv(config_dict["generation_path"] + temp_suffix + "_completions.csv")
    # Building from a column dict keeps the 'smiles' column even when no molecule was valid.
    molecules_df = pd.DataFrame({"smiles": [Chem.MolToSmiles(m) for m in molecules]})

    unique_smiles = list(set(molecules_df["smiles"]))
    data = pd.read_csv(config_dict["train_path"])  # training set, used for the novelty check
    # check_novelty's result is divided by 100 below, so it is presumably a percentage -- TODO confirm.
    novel_ratio = check_novelty(unique_smiles, set(data[config_dict["smiles_key"]])) if unique_smiles else 0.0

    n_valid = len(molecules_df)
    molecules_df["validity"] = np.round(n_valid / len(completions), 3) if completions else 0.0
    molecules_df["unique"] = np.round(len(unique_smiles) / n_valid, 3) if n_valid else 0.0
    molecules_df["novelty"] = np.round(novel_ratio / 100, 3)
    molecules_df.to_csv(config_dict["generation_path"] + temp_suffix + "_processed.csv")
    return molecules_df



# Process GPT Predictions

In [None]:
import rdkit.Chem
import rdkit.Chem.Descriptors
from tqdm import tqdm
import pandas as pd

def descriptors_for_gpt_predictions(path_to_predicted, path_to_save):
    """Compute RDKit descriptors for every parseable SMILES in a CSV and pickle the resulting table.

    Parameters
    ----------
    path_to_predicted : str
        CSV file with a ``smiles`` column.
    path_to_save : str
        Destination of the pickled DataFrame (written with ``DataFrame.to_pickle``).

    Returns
    -------
    pandas.DataFrame
        One row per SMILES that RDKit could parse. It has one column per descriptor returned by
        ``CalcMolDescriptors``, in RDKit's order, followed by a ``smiles`` column. Missing cells and
        unparseable strings are skipped.
    """
    gpt_mols = pd.read_csv(path_to_predicted)
    # Descriptor names are fixed by the first parseable molecule. Keeping them in a list, rather
    # than a set, makes the column order deterministic across runs.
    descriptor_names = None
    columns = {}
    for smile in tqdm(gpt_mols['smiles'], total=len(gpt_mols)):
        if not isinstance(smile, str):  # empty CSV cells arrive as NaN floats
            continue
        mol = rdkit.Chem.MolFromSmiles(smile)
        if not mol:
            continue
        mol_data = rdkit.Chem.Descriptors.CalcMolDescriptors(mol)
        if descriptor_names is None:
            descriptor_names = list(mol_data)
        for name in descriptor_names:
            columns.setdefault(name, []).append(mol_data[name])
        columns.setdefault('smiles', []).append(smile)
    gpt_df = pd.DataFrame(columns)
    gpt_df.to_pickle(path_to_save)
    return gpt_df

# Data roots for the earlier, pre-pipeline experiments.
BASE = '/Users/morgunov/batista/Summer/'
CHECKPOINTS = f"{BASE}bindingDB/checkpoints/"
GPT_DATA = f"{BASE}data/"

# Two descriptor tables are built; `gpt_df` is left holding the second one.
gpt_df = descriptors_for_gpt_predictions(
    path_to_predicted=GPT_DATA + 'molgpt_generated_nocond_06_10_fintetune2.csv',
    path_to_save=CHECKPOINTS + 'gptMols_ft2.pickle',
)
gpt_df = descriptors_for_gpt_predictions(
    path_to_predicted=GPT_DATA + 'molgpt_generated_nocond_06_10.csv',
    path_to_save=CHECKPOINTS + 'gptMols.pickle',
)

# PCA-Transform

In [2]:
import pickle
import pandas as pd

def project_into_pca_space(path_to_pca, path_to_mols):
    """Project pickled descriptor rows into a previously fitted scaler + PCA space.

    Parameters
    ----------
    path_to_pca : str
        Pickle holding a ``(scaler, pca)`` tuple. ``scaler`` must expose ``get_feature_names_out``
        and ``transform``; ``pca`` must expose ``transform``.
    path_to_mols : str
        Pickled DataFrame with a ``smiles`` column plus every feature column the scaler expects.

    Returns
    -------
    tuple
        ``(smiles_series, projected)``, where ``projected`` is ``pca.transform(scaler.transform(features))``.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code; only use this on files produced by this pipeline.
    """
    with open(path_to_pca, 'rb') as fh:  # close the handle deterministically
        scaler, pca = pickle.load(fh)
    gpt_mols = pd.read_pickle(path_to_mols)
    features = gpt_mols[scaler.get_feature_names_out()]
    return gpt_mols['smiles'], pca.transform(scaler.transform(features))

# NOTE(review): PICKLES and INFERENCES are not defined in this part of the notebook -- they must
# come from an earlier cell.
gpt_smiles, pca_transformed = project_into_pca_space(
    path_to_pca=PICKLES + 'scaler_pca_moses+bindingdb.pkl',
    path_to_mols=INFERENCES + 'GPT_pretrain_inference_07_14_23_39_1end_ignore_moses+bindingdb_temp1.0_descriptors.pkl',
)
pca_transformed.shape, gpt_smiles.shape

((99095, 100), (99095,))

# Exploring KMeans clustering 

In [None]:
from sklearn.cluster import KMeans
import numpy as np
from tqdm import tqdm 

def _cluster_mols_experimental_loss(mols, n_clusters, n_iter):
    """Return the lowest-inertia model out of ``n_iter`` independent K-Means fits.

    Every fit uses k-means++ initialisation; on ties the earliest fit wins. Returns None when
    ``n_iter`` is not positive.
    """
    lowest_inertia = float('inf')
    winner = None
    for _ in range(n_iter):
        candidate = KMeans(n_clusters=n_clusters, n_init='auto', init='k-means++').fit(mols)
        if not candidate.inertia_ < lowest_inertia:
            continue
        lowest_inertia, winner = candidate.inertia_, candidate
    return winner

def _cluster_mols_experimental_variance(mols, n_clusters, n_iter):
    """Return the fit, out of ``n_iter`` K-Means runs, whose cluster sizes have the largest variance.

    Variance is taken over the per-cluster member counts of ``labels_``. On ties the earliest fit
    wins. Returns None when ``n_iter`` is not positive.
    """
    highest_variance = float('-inf')
    winner = None
    for _ in range(n_iter):
        candidate = KMeans(n_clusters=n_clusters, n_init='auto', init='k-means++').fit(mols)
        _, cluster_sizes = np.unique(candidate.labels_, return_counts=True)
        size_variance = np.var(cluster_sizes)
        if size_variance > highest_variance:
            highest_variance, winner = size_variance, candidate
    return winner

def _cluster_mols_experimental_mixed(mols, n_clusters, n_iter, mixed_objective_loss_quantile):
    """Fit K-Means ``n_iter`` times and choose a model that is both low-loss and size-balanced.

    The fits are ranked by inertia, and only the lowest ``mixed_objective_loss_quantile`` fraction
    is kept (at least one fit). From those, the model whose cluster-size variance is smallest is
    returned; on ties the lower-inertia fit wins.

    Raises
    ------
    ValueError
        If ``n_iter`` is smaller than 1.
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be at least 1, got {n_iter}")
    candidates = []  # (inertia, cluster-size variance, fitted model)
    for _ in range(n_iter):
        kmeans = KMeans(n_clusters=n_clusters, n_init='auto', init='k-means++').fit(mols)
        counts = np.unique(kmeans.labels_, return_counts=True)[1]
        candidates.append((kmeans.inertia_, np.var(counts), kmeans))
    candidates.sort(key=lambda candidate: candidate[0])
    # int() rounds down, so a small quantile times a small n_iter would otherwise keep nothing.
    n_keep = max(1, int(len(candidates) * mixed_objective_loss_quantile))
    return min(candidates[:n_keep], key=lambda candidate: candidate[1])[2]

def _cluster_mols_experimental(mols, n_clusters, save_path=None, n_iter=1, objective='loss', mixed_objective_loss_quantile=0.1):
    """Cluster ``mols`` with K-Means, choosing among ``n_iter`` fits by the given objective.

    Parameters
    ----------
    mols : array-like, shape (n_samples, n_features)
    n_clusters : int
    save_path : str or None, optional
        When given, the chosen model is pickled there. When None, it is not saved.
    n_iter : int, optional
        With 1, a single fit is returned and ``objective`` is ignored.
    objective : {'loss', 'variance', 'mixed'}
        'loss' picks the lowest inertia; 'variance' the largest cluster-size variance; 'mixed' the
        smallest size variance among the lowest-inertia ``mixed_objective_loss_quantile`` fits.
    mixed_objective_loss_quantile : float, optional
        Only used by the 'mixed' objective.

    Returns
    -------
    The selected fitted KMeans model.

    Raises
    ------
    ValueError
        If ``n_iter != 1`` and ``objective`` is not one of the supported names.
    """
    if n_iter == 1:
        kmeans = KMeans(n_clusters=n_clusters, n_init='auto', init='k-means++').fit(mols)
    elif objective == 'loss':
        kmeans = _cluster_mols_experimental_loss(mols, n_clusters, n_iter)
    elif objective == 'variance':
        kmeans = _cluster_mols_experimental_variance(mols, n_clusters, n_iter)
    elif objective == 'mixed':
        kmeans = _cluster_mols_experimental_mixed(mols, n_clusters, n_iter, mixed_objective_loss_quantile)
    else:
        raise ValueError(f'Unknown objective {objective}')

    if save_path is not None:
        with open(save_path, 'wb') as fh:
            pickle.dump(kmeans, fh)
    return kmeans

# NOTE(review): the plotting cells below read `out[0]` / `out[1]` as sequences of losses and
# variances, but this call returns a single fitted KMeans -- those cells need a different `out`.
out = _cluster_mols_experimental(
    mols=pca_transformed,
    n_clusters=100,
    n_iter=1_000,
)

In [None]:
class Graph:
    """Shared Plotly styling: stores layout settings and applies them to figures.

    Every attribute assigned in ``__init__`` can be overridden with ``update_parameters``.
    """

    def __init__(self):
        # Font sizes for the figure title, axis titles and tick labels / legend.
        self.title_size = 20
        self.axis_title_size = 14
        self.tick_font_size = 12
        # Colours.
        self.text_color = "#333333"
        self.background = "white"
        self.grid_color = "#e2e2e2"
        self.line_color = "#000000"
        self.font_family = 'Helvetica'
        # Figure size in pixels: 600 x 400 is a 3:2 width:height ratio.
        self.width = 600
        self.height = 400
        # Text shown on the figure; empty by default.
        self.title = ''
        self.xaxis_title = ''
        self.yaxis_title = ''

    def update_parameters(self, params):
        """Override styling attributes from a ``{attribute_name: value}`` mapping (names are not validated)."""
        for key, val in params.items():
            setattr(self, key, val)

    def style_figure(self, figure):
        """Apply the stored styling to ``figure`` in place and return that same figure.

        Parameters
        ----------
        figure : plotly figure
            Any object exposing ``update_layout``, ``update_xaxes`` and ``update_yaxes``.

        Returns
        -------
        The ``figure`` argument itself.
        """
        figure.update_layout({
            'margin': {'t': 50, 'b': 50, 'l': 50, 'r': 50},
            'plot_bgcolor': self.background,
            'paper_bgcolor': self.background,
            'title': {
                'text': self.title,
                'font': {
                    'size': self.title_size,
                    'color': self.text_color,
                    'family': self.font_family
                },
            },
            'height': self.height,
            'width': self.width,
            'font': {
                'family': self.font_family,
                'size': self.tick_font_size,
                'color': self.text_color
            },
            'legend': {
                'font': {
                    'family': self.font_family,
                    'size': self.tick_font_size,
                    'color': self.text_color
                },
            },
        })

        # Axis title size/colour plus grid and axis lines, applied identically to every x and y axis.
        figure.update_xaxes(
            title=self.xaxis_title,
            title_font={'size': self.axis_title_size, 'color': self.text_color, 'family': self.font_family},
            tickfont={'size': self.tick_font_size, 'color': self.text_color, 'family': self.font_family},
            showgrid=True,
            gridwidth=1,
            gridcolor=self.grid_color,
            linecolor=self.line_color,  # make x axis line visible
            linewidth=2
        )

        figure.update_yaxes(
            title=self.yaxis_title,
            title_standoff=0,
            title_font={'size': self.axis_title_size, 'color': self.text_color, 'family': self.font_family},
            tickfont={'size': self.tick_font_size, 'color': self.text_color, 'family': self.font_family},
            showgrid=True,
            gridwidth=1,
            gridcolor=self.grid_color,
            linecolor=self.line_color,  # make y axis line visible
            linewidth=2
        )
        # Return the argument, not a notebook-global figure.
        return figure


In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# NOTE(review): assumes `out` is a pair (losses, variances) with one entry per K-Means fit;
# `_cluster_mols_experimental` returns a single model, so `out` must come from another run. TODO confirm.
# Sort the (loss, variance) pairs by loss. Sorting the pairs directly, instead of going through a
# loss -> variance dict, keeps fits that happen to share an identical loss.
sort_loss, variances = zip(*sorted(zip(out[0], out[1]), key=lambda pair: pair[0]))

graph = Graph()
graph.update_parameters({'xaxis_title': 'Rank after sorting by loss'})
fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(go.Scatter(x=np.arange(len(sort_loss)), y=sort_loss, mode='markers', name='Loss'), secondary_y=False)
fig.add_trace(go.Scatter(x=np.arange(len(sort_loss)), y=variances, mode='markers', name='Variances'), secondary_y=True)
graph.style_figure(fig)
fig.show()
fig.write_html(CHECKPOINTS + 'kmeans_sort_loss_vs_variance.html', include_plotlyjs='cdn')

In [None]:
# NOTE(review): assumes `out` is a pair (losses, variances) with one entry per K-Means fit. TODO confirm.
# Same (loss, variance) pairs, this time ordered by cluster-size variance; fits sharing a loss are kept.
loss, sort_variances = zip(*sorted(zip(out[0], out[1]), key=lambda pair: pair[1]))

graph = Graph()
graph.update_parameters({'xaxis_title': 'Rank after sorting by variance'})
fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(go.Scatter(x=np.arange(len(loss)), y=loss, mode='markers', name='Loss'), secondary_y=False)
fig.add_trace(go.Scatter(x=np.arange(len(loss)), y=sort_variances, mode='markers', name='Variances'), secondary_y=True)
graph.style_figure(fig)
fig.show()
# Distinct file name so this plot does not overwrite the loss-sorted one.
fig.write_html(CHECKPOINTS + 'kmeans_sort_variance_vs_loss.html', include_plotlyjs='cdn')

# K-Means clustering

In [10]:
import numpy as np

def sample_based_on_distance_percentiles(elements, distances, n_samples, n_percentiles):
    """Sample exactly ``n_samples`` elements spread across distance-percentile bins.

    Elements are ordered by distance and split into ``n_percentiles`` bins, using the distance
    percentiles as cut points (the last bin is open-ended). Each bin is allotted
    ``n_samples // n_percentiles`` samples, and the ``n_samples % n_percentiles`` leftovers go one
    each to the last bins. If a bin holds fewer elements than its allotment, it contributes all of
    them and the shortfall carries over to the following bins. Any shortfall still left after the
    last bin is filled at random from the unsampled elements of all bins.

    Parameters
    ----------
    elements : sequence
        Items to sample from (taken positionally).
    distances : sequence of float
        One distance per element.
    n_samples : int
        Number of elements to return; at most ``len(elements)``.
    n_percentiles : int
        Number of bins; must be positive.

    Returns
    -------
    list
        The sampled elements. Each input position is used at most once.

    Raises
    ------
    AssertionError
        On mismatched lengths, too many requested samples, or a non-positive ``n_percentiles``.
    """
    assert len(elements) == len(distances), "Elements and distances lists must be of the same length"
    assert n_samples <= len(elements), "Number of samples cannot exceed the total number of elements"
    assert n_percentiles > 0, "Number of percentiles must be a positive integer"
    elements = list(elements)
    distances = list(distances)
    if not elements:
        return []

    # Order positions by distance only, so ties never fall back to comparing the elements themselves.
    order = sorted(range(len(distances)), key=lambda idx: distances[idx])
    sorted_distances = [distances[idx] for idx in order]

    cut_points = [np.percentile(sorted_distances, p * 100 / n_percentiles) for p in range(1, n_percentiles)]
    cut_points.append(np.inf)  # the highest bin takes every remaining point

    # Split the distance-ordered positions into consecutive bins.
    bins = []
    start = 0
    for cut in cut_points:
        end = start
        while end < len(sorted_distances) and sorted_distances[end] <= cut:
            end += 1
        bins.append(order[start:end])
        start = end

    base, extra = divmod(n_samples, n_percentiles)
    chosen = []     # selected positions into `elements`
    leftovers = []  # unselected positions, used to cover a final shortfall
    carry = 0       # shortfall from earlier bins, pushed onto later bins
    for bin_idx, members in enumerate(bins):
        want = base + (1 if bin_idx >= n_percentiles - extra else 0) + carry
        if len(members) <= want:
            chosen.extend(members)
            carry = want - len(members)
        else:
            picked = np.random.choice(len(members), size=want, replace=False)
            picked_set = set(picked.tolist())
            chosen.extend(members[k] for k in picked)
            leftovers.extend(members[k] for k in range(len(members)) if k not in picked_set)
            carry = 0

    if carry > 0:
        # Guaranteed possible: len(leftovers) = len(elements) - (n_samples - carry) >= carry.
        extra_picks = np.random.choice(len(leftovers), size=carry, replace=False)
        chosen.extend(leftovers[k] for k in extra_picks)

    return [elements[idx] for idx in chosen]

In [11]:
from sklearn.cluster import KMeans
from pprint import pprint as pp 
import numpy as np
import random

def _cluster_mols(mols, n_clusters, save_path, n_iter=1):
    """
        Performs K-Means clustering on a given list of molecules and saves the model to a specified file.

        The clustering is run ``n_iter`` times with k-means++ initialisation. The fit with the lowest
        inertia (sum of squared distances of samples to their closest cluster centre) is kept; on
        ties the earliest fit wins. The chosen model is pickled to ``save_path`` and returned.

        Parameters

            mols : array-like or sparse matrix, shape (n_samples, n_features)
            The input samples where n_samples is the number of samples and n_features is the number of features.

            n_clusters : int
            The number of clusters to form as well as the number of centroids to generate.

            save_path : str
            The path (including file name) where the resulting KMeans object should be saved.

            n_iter : int, optional (default=1)
            The number of times to perform the clustering; must be at least 1.

        Returns

            kmeans : sklearn.cluster._kmeans.KMeans
            The lowest-inertia KMeans instance trained on the input molecules.

        Raises

            ValueError
            If n_iter is smaller than 1. Otherwise no model would be fitted and None would be pickled.

    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be at least 1, got {n_iter}")
    best_kmeans = None
    best_inertia = float('inf')
    for _ in range(n_iter):
        kmeans = KMeans(n_clusters=n_clusters, n_init='auto', init='k-means++').fit(mols)
        # The first fit is always accepted, so a single run is returned as-is.
        if best_kmeans is None or kmeans.inertia_ < best_inertia:
            best_kmeans = kmeans
            best_inertia = kmeans.inertia_
    with open(save_path, 'wb') as fh:  # close the file deterministically
        pickle.dump(best_kmeans, fh)
    return best_kmeans

def cluster_and_sample(mols, mols_smiles, n_clusters, n_samples, kmeans_save_path, clusters_save_path, diffdock_save_path, 
                        ensure_correctness=False, path_to_pca=None, probabilistic_sampling=True, load_kmeans=False,
                        percentile_sampling=True, n_percentiles=1):
    """
        Clusters molecules with K-Means, samples from every cluster, and saves clusters and samples.

        Sampling targets ``n_samples`` molecules per cluster, ``n_clusters * n_samples`` in total.
        Clusters are visited from smallest to largest. A cluster with too few members contributes all
        of them, and its shortfall is handed to later (larger) clusters roughly in proportion to their
        size, so the total still matches.

        Parameters
        ----------
        mols : array-like, shape (n_samples, n_features)
            Feature vectors (e.g. PCA-projected descriptors), one row per molecule.
        mols_smiles : sequence of str
            SMILES strings aligned with the rows of ``mols``.
        n_clusters : int
            Number of K-Means clusters.
        n_samples : int
            Target number of molecules sampled from each cluster.
        kmeans_save_path : str
            Where the fitted KMeans model is pickled, or read from when ``load_kmeans`` is True.
        clusters_save_path : str
            Pickle path for the ``{cluster: [smiles]}`` mapping. Two sibling files are also written:
            ``<stem>_cl_to_d.pickle`` with ``(labels, {cluster: [distance]})`` and
            ``<stem>_samples.pickle`` with the sampled molecules, where ``<stem>`` is the path up to its first '.'.
        diffdock_save_path : str
            CSV path for the sampled molecules (columns ``smiles`` and ``cluster_id``).
        ensure_correctness : bool, optional (default=False)
            Recompute each molecule's PCA features from its SMILES and assert they match ``mols``.
            Requires ``path_to_pca``.
        path_to_pca : str, optional (default=None)
            Pickle holding the ``(scaler, pca)`` pair used for ``ensure_correctness``.
        probabilistic_sampling : bool, optional (default=True)
            Sample within a cluster with probability proportional to the distance to the nearest centroid.
        load_kmeans : bool, optional (default=False)
            Load the model from ``kmeans_save_path`` instead of fitting a new one.
        percentile_sampling : bool, optional (default=True)
            When ``probabilistic_sampling`` is False, spread samples over distance percentiles.
        n_percentiles : int, optional (default=1)
            Number of distance bins for percentile sampling.

        Returns
        -------
        cluster_to_samples : dict
            Maps each cluster label to the SMILES sampled from it.

        Raises
        ------
        AssertionError
            If the number of requested samples exceeds the total number of molecules provided.
            If ensure_correctness is True but path_to_pca is None.
            If a freshly fitted model's label count differs from the number of molecules.
            If features calculated from a smile string differ from features in the mols array.
            If the total number of sampled molecules doesn't equal n_clusters * n_samples.

    """
    assert n_clusters * n_samples <= len(mols), f"{n_clusters=} * {n_samples=} = {n_clusters*n_samples} requested but only {len(mols)} molecules provided"
    if ensure_correctness:
        assert path_to_pca is not None, "path_to_pca must be provided to ensure correctness"
        # NOTE: pickle.load can execute arbitrary code -- only load files this pipeline produced.
        with open(path_to_pca, 'rb') as fh:
            scaler, pca = pickle.load(fh)

    if load_kmeans:
        # NOTE(review): a loaded model's labels_ describe the data it was fitted on, and they are zipped
        # with `mols_smiles` below without a length check -- confirm the same molecules are passed in.
        with open(kmeans_save_path, 'rb') as fh:
            kmeans = pickle.load(fh)
    else:
        kmeans = _cluster_mols(mols=mols, n_clusters=n_clusters, save_path=kmeans_save_path)
        assert len(kmeans.labels_) == len(mols_smiles), "Number of labels differs from number of molecules"
    # Distances of every molecule to every centroid; the row minimum is used as the molecule's distance.
    distances = kmeans.transform(mols)

    cluster_to_mols = {}
    cluster_to_distances = {}
    for features, distance, label, smile in zip(mols, distances, kmeans.labels_, mols_smiles):
        cluster_to_mols.setdefault(label, []).append(smile)
        cluster_to_distances.setdefault(label, []).append(distance.min())
        if ensure_correctness:
            # Recompute descriptors from the SMILES string and check they project onto the same features.
            descriptors = rdkit.Chem.Descriptors.CalcMolDescriptors(rdkit.Chem.MolFromSmiles(smile))
            descriptor_frame = pd.DataFrame({k: [v] for k, v in descriptors.items()})
            smile_features = pca.transform(scaler.transform(descriptor_frame[scaler.get_feature_names_out()]))
            assert np.allclose(smile_features[0], features), "Features calculated from a smile string differ from features in the array"

    # NOTE(review): split('.')[0] cuts the path at its first '.' anywhere (e.g. a '3. Sampling/'
    # directory). sample_clusters_for_active_learning derives '_samples.pickle' the same way, so change both together.
    stem = clusters_save_path.split('.')[0]
    with open(stem + '_cl_to_d.pickle', 'wb') as fh:
        pickle.dump((kmeans.labels_, cluster_to_distances), fh)

    # Per-cluster allocation: visit clusters from smallest to largest. Clusters with fewer members than
    # their quota give all members and accumulate the shortfall in `extra_mols`. Later clusters take
    # an extra share of it scaled by their size relative to the average. `left_to_sample` caps the total.
    avg_len = np.mean([len(v) for v in cluster_to_mols.values()])
    cluster_to_samples = {}
    extra_mols = 0
    left_to_sample = n_clusters*n_samples
    cluster_to_len = {cluster: len(members) for cluster, members in cluster_to_mols.items()}
    for i, (cluster, _) in enumerate(sorted(cluster_to_len.items(), key=lambda x: x[1])):
        smiles = cluster_to_mols[cluster]
        if extra_mols > 0:
            cur_extra = int(1+extra_mols/(len(cluster_to_mols) - i) * len(smiles)/avg_len)
            cur_samples = n_samples + cur_extra
            extra_mols -= cur_extra
        else:
            cur_samples = n_samples
        cur_samples = min(cur_samples, left_to_sample)

        if len(smiles) > cur_samples:
            if probabilistic_sampling:
                cluster_distances = np.asarray(cluster_to_distances[cluster])
                cluster_to_samples[cluster] = np.random.choice(smiles, cur_samples, p=cluster_distances/np.sum(cluster_distances), replace=False)
            elif percentile_sampling:
                cluster_to_samples[cluster] = sample_based_on_distance_percentiles(smiles, cluster_to_distances[cluster], n_samples=cur_samples, n_percentiles=n_percentiles)
            else:
                cluster_to_samples[cluster] = np.random.choice(smiles, cur_samples, replace=False)
            left_to_sample -= cur_samples
        else:
            cluster_to_samples[cluster] = smiles
            left_to_sample -= len(smiles)
            extra_mols += cur_samples - len(smiles)

    assert (n_sampled:=sum(len(vals) for vals in cluster_to_samples.values())) == n_clusters*n_samples, f"Sampled {n_sampled} but were requested {n_clusters*n_samples}"
    with open(clusters_save_path, 'wb') as fh:
        pickle.dump(cluster_to_mols, fh)
    with open(stem + '_samples.pickle', 'wb') as fh:
        pickle.dump(cluster_to_samples, fh)

    sampled_smiles = [smile for sampled in cluster_to_samples.values() for smile in sampled]
    sampled_clusters = [cluster for cluster, sampled in cluster_to_samples.items() for _ in sampled]
    pd.DataFrame({'smiles': sampled_smiles, 'cluster_id': sampled_clusters}).to_csv(diffdock_save_path)
    return cluster_to_samples

In [16]:
nclusters = 100
N_PERCENTILES = 4
# NOTE(review): PICKLES and INFERENCES are not defined in this part of the notebook -- earlier cell.
c_to_s = cluster_and_sample(
    mols=pca_transformed,
    mols_smiles=gpt_smiles,
    n_clusters=nclusters,
    n_samples=10,
    kmeans_save_path=PICKLES + f'k{nclusters}means_07_14_23_39_1end_ignore_moses+bindingdb.pickle',
    ensure_correctness=False,
    # Same scaler/PCA file the projection cell loads successfully (.pkl); only read when ensure_correctness=True.
    path_to_pca=PICKLES + 'scaler_pca_moses+bindingdb.pkl',
    clusters_save_path=PICKLES + 'cluster_to_samples_07_16.pickle',
    probabilistic_sampling=False,
    percentile_sampling=True,
    n_percentiles=N_PERCENTILES,
    diffdock_save_path=INFERENCES + f'samples_percentiles_{N_PERCENTILES}.csv',
)

In [None]:
# Reload (kmeans labels, {cluster: [distance to nearest centroid]}) as written by cluster_and_sample.
# NOTE(review): this reads the 07_11 run from CHECKPOINTS, not the 07_16 run saved under PICKLES -- confirm.
with open(CHECKPOINTS + 'cluster_to_samples_07_11_cl_to_d.pickle', 'rb') as fh:
    labels, cluster_to_distances = pickle.load(fh)

In [None]:
# Pool the per-molecule centroid distances of every cluster with more than 10 members.
large_cluster_distances = [dists for dists in cluster_to_distances.values() if len(dists) > 10]
all_distances = np.array([d for dists in large_cluster_distances for d in dists])
min_d = np.min(all_distances)
max_d = np.max(all_distances)
# 10-bin histogram; bin centres are the midpoints of consecutive edges.
hist, bin_edges = np.histogram(all_distances, bins=10)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5
# A sine of the bin centres, scaled by the largest distance, overlaid on the histogram below.
sine_distances = np.sin(2*np.pi*bin_centers) * max(all_distances)
len(all_distances)

In [None]:
import plotly.graph_objects as go


# Overlay a 100-bin histogram of the pooled centroid distances with the sine curve computed above.
histogram_bins = {'start': min_d, 'end': max_d, 'size': 0.01 * (max_d - min_d)}
fig = go.Figure()
fig.add_trace(go.Histogram(x=all_distances, xbins=histogram_bins, name='all_distances'))
fig.add_trace(go.Scatter(x=bin_centers, y=sine_distances, name='sine transformed distances'))
fig.update_layout(barmode='overlay').update_traces(opacity=0.75)
fig.show()

In [None]:
import plotly.graph_objects as go


# Scatter PCA components 7 and 50 of the generated molecules, coloured by the loaded cluster labels.
# NOTE(review): assumes `labels` is aligned row-for-row with `pca_transformed` -- confirm.
scatter = go.Scatter(
    x=pca_transformed[:, 7],
    y=pca_transformed[:, 50],
    marker={'color': labels},
    text=labels,
    mode='markers',
    name='gpt generated',
)
fig = go.Figure()
fig.add_trace(scatter)
fig.update_traces(opacity=0.75)
fig.show()

# Prepare training dataset for active learning

In [None]:

def _preprocess_scores_uniformly(scores, remove_positives=False, lowest_score=1):
    """
        Turn scores into probabilities proportional to their negated values (lower score -> higher weight).

        All scores are negated first. If any negated value is below zero (i.e. some original score is
        positive), one of two things happens:
        - with ``remove_positives``, every entry whose negated value is not strictly positive is dropped
          (i.e. entries whose original score is >= 0);
        - otherwise, all values are shifted so the smallest becomes ``lowest_score``.
        The remaining values are then normalised to sum to 1.

        Parameters
        ----------
        scores : dict
            Maps identifiers to scores.

        remove_positives : bool, optional (default=False)
            Drop entries with non-negative original scores instead of shifting (see above).

        lowest_score : int, optional (default=1)
            Value the smallest negated score is shifted to when shifting happens.

        Returns
        -------
        normalized : dict
            Maps identifiers to probabilities summing to 1. It is empty if ``remove_positives`` dropped
            every entry. If every remaining weight is zero, all entries get equal probability.

        Raises
        ------
        ValueError
            If ``scores`` is empty.

    """
    if not scores:
        raise ValueError("scores must contain at least one entry")
    negated = {k: -v for k, v in scores.items()}
    min_value = min(negated.values())
    if min_value < 0:
        if remove_positives:
            negated = {k: v for k, v in negated.items() if v > 0}
        else:
            negated = {k: v - min_value + lowest_score for k, v in negated.items()}
    if not negated:
        return {}
    total = sum(negated.values())
    if total == 0:
        # Every weight is zero (e.g. all scores are 0): use equal probabilities instead of dividing by zero.
        return {k: 1 / len(negated) for k in negated}
    return {k: v / total for k, v in negated.items()}

def _preprocess_scores_softmax(scores):
    """Softmax over the negated scores, so lower scores get higher probability.

    The largest negated score is subtracted before exponentiating for numerical stability.
    Raises ValueError (from ``max``) when ``scores`` is empty.
    """
    negated_items = [(key, -value) for key, value in scores.items()]
    shift = max(value for _, value in negated_items)
    weights = {key: np.exp(value - shift) for key, value in negated_items}
    normalizer = sum(weights.values())
    return {key: weight / normalizer for key, weight in weights.items()}

def balance_cluster_to_n(cluster_to_n, cluster_to_len):
    """
        Cap per-cluster sample targets at cluster size while keeping the total target unchanged.

        1. Clusters asked for more than they hold are capped at their size; the excess is the surplus.
        2. The total surplus is handed to the other clusters in proportion to their original targets
           (rounded down and capped at their size).
        3. Any remaining deficit is covered one molecule at a time. Clusters with the largest original
           target are served first, skipping any that are full, until the totals match.

        Parameters
        ----------
        cluster_to_n : dict
            Maps cluster identifiers to their target number of samples.

        cluster_to_len : dict
            Maps cluster identifiers to the number of molecules available in each cluster.

        Returns
        -------
        balanced : dict
            Maps cluster identifiers to their balanced target number of samples.

        Raises
        ------
        AssertionError
            If the totals before and after balancing differ. This happens when the clusters
            together hold fewer molecules than requested.

    """
    surplus = {key: cluster_to_n[key] - cluster_to_len[key] for key in cluster_to_n if cluster_to_n[key] > cluster_to_len[key]}
    balanced = dict(cluster_to_n)
    for key in surplus:
        balanced[key] = cluster_to_len[key]

    total_surplus = sum(surplus.values())
    initial_n_sum = sum(n for key, n in cluster_to_n.items() if key not in surplus)

    for key in balanced:
        if key in surplus:
            continue
        # When every receiving cluster had a zero target there is nothing to scale by; step 3 covers it.
        surplus_to_add = total_surplus * cluster_to_n[key] / initial_n_sum if initial_n_sum else 0
        balanced[key] = min(int(cluster_to_n[key] + surplus_to_add), cluster_to_len[key])

    deficit = sum(cluster_to_n.values()) - sum(balanced.values())
    # Every non-capped cluster is a candidate (ties keep insertion order), largest original target first.
    receivers = sorted((key for key in cluster_to_n if key not in surplus), key=lambda key: cluster_to_n[key], reverse=True)
    while deficit > 0:
        progressed = False
        for cluster in receivers:
            if deficit == 0:
                break  # stop exactly at the target instead of overshooting
            if balanced[cluster] < cluster_to_len[cluster]:
                balanced[cluster] += 1
                deficit -= 1
                progressed = True
        if not progressed:
            break  # no spare capacity left anywhere; the assertion below reports the shortfall

    assert sum(cluster_to_n.values()) == sum(balanced.values()), f"Before balancing had {sum(cluster_to_n.values())}, post balancing = {sum(balanced.values())}"
    return balanced

def sample_clusters_for_active_learning(cluster_to_scores, n_samples, path_to_clusters, probability_type='softmax', remove_positives=False, lowest_score=1):
    """
        Sample not-yet-docked molecules from clusters, allocating more samples to better-scoring clusters.

        Cluster scores are turned into probabilities (softmax or uniform normalisation of the negated
        scores). Those probabilities set each cluster's share of ``n_samples``; the rounding remainder
        goes to the most probable cluster. Molecules already present in the sibling
        ``<stem>_samples.pickle`` file (the previously sampled/docked set) are excluded. The targets are
        then balanced so that no cluster is asked for more new molecules than it has.

        Parameters
        ----------
        cluster_to_scores : dict
            Maps cluster identifiers to their scores.

        n_samples : int
            The total number of molecules to sample.

        path_to_clusters : str
            Pickle file storing a dictionary that maps each cluster to a list of molecules.

        probability_type : str, optional (default='softmax')
            'softmax' or 'uniform'; how scores are converted into per-cluster probabilities.

        remove_positives : bool, optional (default=False)
            Only used when probability_type is 'uniform'; forwarded to the uniform preprocessing.

        lowest_score : int, optional (default=1)
            Only used when probability_type is 'uniform'; forwarded to the uniform preprocessing.

        Returns
        -------
        training : list
            The randomly sampled molecules for active learning.

        Raises
        ------
        KeyError
            If an unsupported probability_type is provided.
        ValueError
            If no cluster is left with a probability to sample from.
        AssertionError
            If the number of sampled molecules doesn't equal n_samples.

    """
    if probability_type == 'softmax':
        probability_function = _preprocess_scores_softmax
    elif probability_type == 'uniform':
        probability_function = lambda x: _preprocess_scores_uniformly(x, remove_positives, lowest_score)
    else:
        raise KeyError("Only uniform and softmax probabilities are supported")
    # NOTE: pickle.load can execute arbitrary code -- only load files this pipeline produced.
    with open(path_to_clusters, 'rb') as fh:
        cluster_to_mols = pickle.load(fh)
    # NOTE(review): split('.')[0] truncates at the first '.' anywhere in the path; it must match how
    # cluster_and_sample names its '_samples.pickle' file, so change both together.
    with open(path_to_clusters.split('.')[0] + '_samples.pickle', 'rb') as fh:
        cluster_to_samples = pickle.load(fh)
    docked_mols = {smile for smiles in cluster_to_samples.values() for smile in smiles}
    cluster_to_new_mols = {k: [smile for smile in v if smile not in docked_mols] for k, v in cluster_to_mols.items()}

    probabilities = probability_function(cluster_to_scores)
    if not probabilities:
        raise ValueError("No cluster has a usable score to sample from")
    cluster_to_n = {k: int(v * n_samples) for k, v in probabilities.items()}
    # int() rounds every share down; give the remainder to the most probable cluster (first one on ties).
    max_cluster_id = max(probabilities, key=probabilities.get)
    cluster_to_n[max_cluster_id] += n_samples - sum(cluster_to_n.values())

    cluster_to_len = {k: len(v) for k, v in cluster_to_new_mols.items()}
    balanced = balance_cluster_to_n(cluster_to_n, cluster_to_len)

    training = []
    for cluster, n in balanced.items():
        training.extend(np.random.choice(cluster_to_new_mols[cluster], n, replace=False))

    assert len(training) == n_samples, f"{len(training)=} != {n_samples=}"
    return training

# Smoke run: random per-cluster scores, 10 molecules drawn from the 07_07 clustering.
sample_clusters_for_active_learning(
    {cluster_id: random.random() for cluster_id in range(nclusters)},
    n_samples=10,
    path_to_clusters=CHECKPOINTS + 'cluster_to_samples_07_07.pickle',
)