In [5]:
# Root of the pipeline checkout; each numbered stage directory below lives under it.
BASE_PATH = '/Users/morgunov/batista/Summer/pipeline/'
PRETRAINING_PATH = f"{BASE_PATH}1. Pretraining/"
GENERATION_PATH = f"{BASE_PATH}2. Generation/"
SAMPLING_PATH = f"{BASE_PATH}3. Sampling/"
DIFFDOCK_PATH = f"{BASE_PATH}4. DiffDock/"
SCORING_PATH = f"{BASE_PATH}5. Scoring/"
AL_PATH = f"{BASE_PATH}6. ActiveLearning/"
# Label of the pipeline stage this notebook is run for.
MODE = 'Active Learning'

In [None]:
import numpy as np
from sklearn.manifold import TSNE
import umap.umap_ as umap
import matplotlib.pyplot as plt

# Assume data is a numpy array of shape (n, d)
# data = np.random.rand(1000, 110)  # Uncomment this line to test with random data

def run_tsne(data, n_components=2, perplexity=30):
    """Fit t-SNE on `data` and return the embedded coordinates.

    Parameters
    ----------
    data : array-like
        Samples to embed, one row per sample.
    n_components : int
        Dimensionality of the embedding.
    perplexity : float
        Perplexity passed straight to `TSNE`.
    """
    embedder = TSNE(n_components=n_components, perplexity=perplexity)
    return embedder.fit_transform(data)

def run_umap(data, n_components=2, n_neighbors=15):
    """Fit UMAP on `data` and return the embedded coordinates.

    Parameters
    ----------
    data : array-like
        Samples to embed, one row per sample.
    n_components : int
        Dimensionality of the embedding.
    n_neighbors : int
        Neighbourhood size passed straight to `umap.UMAP`.
    """
    return umap.UMAP(n_components=n_components, n_neighbors=n_neighbors).fit_transform(data)

def plot_results(results, title):
    """Open a new 8x6 matplotlib figure with a scatter of the first two columns of `results`.

    The figure is left open on the pyplot state machine; the caller decides when to call `plt.show()`.
    """
    first_axis, second_axis = results[:, 0], results[:, 1]
    plt.figure(figsize=(8, 6))
    plt.scatter(first_axis, second_axis)
    plt.title(title)
    plt.grid(True)

# Run t-SNE and UMAP, and plot the results
# NOTE(review): `data` is not assigned anywhere in this cell (the only assignment above is
# commented out); presumably it is expected to already exist in the kernel — confirm.
tsne_results = run_tsne(data)
plot_results(tsne_results, 't-SNE results')

umap_results = run_umap(data)
plot_results(umap_results, 'UMAP results')

# Render both figures created by plot_results.
plt.show()


# PCA Exploration

In [114]:
import pickle

# Load the full descriptor table and the list of usable descriptor columns.
# Context managers close the files deterministically instead of leaving the handles to the GC.
# NOTE: pickle.load can execute arbitrary code; only load artefacts produced by this pipeline.
with open(f"{SAMPLING_PATH}descriptors/descriptors_moses+bindingdb.pkl", 'rb') as descriptors_file:
    all_descriptors = pickle.load(descriptors_file)
with open(f"{SAMPLING_PATH}descriptors/descriptors_moses+bindingdb_columnlist.pkl", 'rb') as columns_file:
    valid_columns = pickle.load(columns_file)

# NOTE(review): export_columns_to_yaml / process_columns are defined in a cell further down the
# notebook, so that cell has to be executed before this one on a fresh kernel.
export_columns_to_yaml(valid_columns, "valid_columns")
valid_descriptors = all_descriptors[valid_columns]
# Column subsets without functional-group columns, without count columns, and without either.
no_fr, no_counts, no_fr_counts = process_columns(valid_columns)

Removed 85 functional group column, 110 left
Removed 19 count columns, 176 left
Removed 104 count and functional group columns, 91 left


In [110]:
import yaml
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from graph import Graph
import numpy as np
import plotly.graph_objects as go
from tqdm import tqdm

def export_columns_to_yaml(columns, fname):
    """Dump `columns` as YAML to pca_study/columns/<fname>.yaml (relative to the working directory)."""
    target_path = f"pca_study/columns/{fname}.yaml"
    with open(target_path, 'w') as yaml_file:
        yaml.dump(columns, yaml_file)

# Load from yaml file
def load_columns_from_yaml(fname):
    """Read back a column list written by `export_columns_to_yaml`.

    Parameters
    ----------
    fname : str
        File stem; the file read is pca_study/columns/<fname>.yaml.

    Returns
    -------
    The deserialised YAML document (a list of column names when written by
    `export_columns_to_yaml`).
    """
    with open(f"pca_study/columns/{fname}.yaml", 'r') as f:
        # safe_load only builds plain YAML types (lists, strings, numbers, ...), which is all a
        # column list needs, and refuses to construct arbitrary Python objects from the file.
        columns = yaml.safe_load(f)
    return columns
def process_columns(columns):
    """Derive three reduced column lists from `columns`, report their sizes and export each to YAML.

    Returns
    -------
    (no_fr, no_counts, no_fr_counts)
        `no_fr` drops names starting with 'fr_'; `no_counts` drops names containing
        'count' or 'num' (case-insensitive); `no_fr_counts` is the sorted intersection of both.
    """
    total = len(columns)

    def _is_count_column(name):
        lowered = name.lower()
        return "count" in lowered or "num" in lowered

    # Functional-group descriptors are the ones prefixed with 'fr_'.
    no_fr = [name for name in columns if name[:3] != 'fr_']
    print(f"Removed {total - len(no_fr)} functional group column, {len(no_fr)} left")

    no_counts = [name for name in columns if not _is_count_column(name)]
    print(f"Removed {total - len(no_counts)} count columns, {len(no_counts)} left")

    # Persist the two single-filter lists before building the combined one.
    for stem, subset in (('no_fr', no_fr), ('no_counts', no_counts)):
        export_columns_to_yaml(subset, stem)

    no_fr_counts = sorted(set(no_fr).intersection(no_counts))
    print(f"Removed {total - len(no_fr_counts)} count and functional group columns, {len(no_fr_counts)} left")
    export_columns_to_yaml(no_fr_counts, 'no_fr_counts')
    return no_fr, no_counts, no_fr_counts

def fit_pca(data, n_comps=2, whiten=False):
    """Standardise `data`, fit a PCA on the standardised values and return both fitted objects.

    Returns
    -------
    (scaler, pca)
        The fitted `StandardScaler` and the `PCA` fitted on its output.
    """
    scaler = StandardScaler().fit(data)
    standardised = scaler.transform(data)
    pca = PCA(n_components=n_comps, whiten=whiten).fit(standardised)
    return scaler, pca

def plot_explained_variance(pca, title, suffix=None):
    """Plot cumulative explained variance of a fitted PCA and save it as an HTML file.

    Parameters
    ----------
    pca : fitted PCA-like object
        Must expose `explained_variance_ratio_`.
    title : str
        Figure title.
    suffix : str, optional
        Appended to the output file name: pca_study/plots/explained_variance<suffix>.html.

    Returns
    -------
    The plotly figure that was written to disk.
    """
    explained_variance = pca.explained_variance_ratio_.cumsum()
    n_fitted = len(explained_variance)
    x = np.arange(1, n_fitted + 1)

    graph = Graph()
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x, y=explained_variance, mode='markers', name='Explained variance'))
    # Dashed reference line at 100% variance, spanning exactly the fitted components.
    fig.add_shape(type='line', line=dict(dash='dash'), x0=0, x1=n_fitted, y0=1, y1=1)

    # Calculate the number of principal components explaining certain percentages of the variance
    percentages = [50, 75, 90, 95, 99, 99.99]
    summary_lines = []
    for p in percentages:
        reached = explained_variance >= p / 100.0
        if reached.any():
            # argmax of a boolean mask is the index of the first True; +1 turns it into a count.
            n = int(np.argmax(reached)) + 1
            summary_lines.append(f'{p}% of variance is explained by {n} components')
        else:
            # argmax would return 0 for an all-False mask and wrongly report "1 component".
            summary_lines.append(f'{p}% of variance is not reached with {n_fitted} components')
    text_block = '<br>'.join(summary_lines)

    # Add text block
    fig.add_annotation(x=0.8, y=0.1, xref='paper', yref='paper', text=text_block, showarrow=False)
    graph.update_parameters(dict(width=800, height=400, title=title, xaxis_title='Number of components', yaxis_title='Explained variance'))
    graph.style_figure(fig)
    if suffix is None: suffix = ''
    fig.write_html(f"pca_study/plots/explained_variance{suffix}.html")
    return fig

def plot_correlation_circle(pca, features, title, suffix):
    """Plot the first two rows of `pca.components_` against a unit circle and save the figure.

    Each feature is drawn at (component 1 weight, component 2 weight) and labelled with its name
    from `features`. The figure is written to pca_study/plots/corr_circle<suffix>.html and returned.
    """
    loadings = pca.components_
    variance_ratio = pca.explained_variance_ratio_
    angles = np.linspace(0, 2*np.pi, 100)

    # Reference unit circle.
    unit_circle = go.Scatter(
        x=np.cos(angles),
        y=np.sin(angles),
        mode='lines',
        line=dict(color='blue', width=1),
        showlegend=False
    )

    # One labelled point per feature.
    feature_points = go.Scatter(
        x=loadings[0, :],
        y=loadings[1, :],
        mode='lines+markers+text',
        text=features,
        textposition='top center',
        line=dict(color='red'),
        marker=dict(size=10, color='blue'),
        textfont=dict(size=8)
    )

    fig = go.Figure(data=[feature_points, unit_circle])

    styler = Graph()
    layout_params = dict(
        width=600,
        height=600,
        title=title,
        xaxis_title=f"PC1 ({variance_ratio[0]*100:.2f}%)",
        yaxis_title=f"PC2 ({variance_ratio[1]*100:.2f}%)",
        showlegend=False,
    )
    styler.update_parameters(layout_params)
    styler.style_figure(fig)

    fig.write_html(f"pca_study/plots/corr_circle{suffix}.html")
    return fig



In [112]:
# `cols` was only ever defined by kernel state that is no longer in the notebook; rebuild it from the
# three column subsets returned by process_columns in the descriptor-loading cell.
cols = (no_fr, no_counts, no_fr_counts)
configs = [
    dict(cols=valid_columns, title=f'Explained variance by PCA on df with {len(valid_columns)} valid columns', suffix='_valid_columns'),
    dict(cols=cols[0], title=f'Explained variance by PCA on df with {len(cols[0])} columns (no functional groups)', suffix='_no_fr'),
    dict(cols=cols[1], title=f'Explained variance by PCA on df with {len(cols[1])} columns (no counts)', suffix='_no_counts'),
    dict(cols=cols[2], title=f'Explained variance by PCA on df with {len(cols[2])} columns (no counts and functional groups)', suffix='_no_fr_counts'),
]
pbar = tqdm(configs, total=len(configs))
for config in pbar:
    pbar.set_description(f"Processing {config['suffix']}")
    # The fitted (scaler, pca) pairs are read from checkpoints; the two lines below show how a
    # pair can be re-fitted instead.
    # n_comps = min(100, len(config['cols']))
    # scaler, pca = fit_pca(valid_descriptors[config['cols']], n_comps=n_comps, whiten=False)
    with open(f"pca_study/checkpoints/scaler_pca{config['suffix']}.pkl", 'rb') as checkpoint_file:
        scaler, pca = pickle.load(checkpoint_file)
    plot_explained_variance(pca, title=config['title'], suffix=config['suffix'])
    plot_correlation_circle(pca, config['cols'], config['title'].replace('Explained variance by PCA', 'Correlation circle for PCA'), suffix=config['suffix'])

Removed 85 functional group column, 110 left
Removed 19 count columns, 176 left
Removed 104 count and functional group columns, 91 left


Processing _no_fr_counts: 100%|██████████| 4/4 [00:00<00:00, 25.11it/s]


In [145]:
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import wandb
# Seed PyTorch's global random number generator so repeated runs of the notebook are comparable.
torch.manual_seed(42)  # or any other seed you prefer

<torch._C.Generator at 0x2adb63b50>

In [146]:
# Standardise the descriptors restricted to the `no_fr` column subset and move them to the MPS device.
scaler = StandardScaler()
scaled = scaler.fit_transform(all_descriptors[no_fr]) #.sample(n=200_000, random_state=42)
scaled_tensor = torch.tensor(scaled).float().to('mps')
print(scaled_tensor.shape)
# 80% train / 10% test; validation takes whatever is left so the three sizes sum to the row count.
train_size = int(0.8*scaled_tensor.shape[0])
test_size = int(0.1*scaled_tensor.shape[0])
val_size = scaled_tensor.shape[0] - train_size - test_size
print(train_size, val_size, test_size)
# Sizes are passed in (train, val, test) order, matching the unpacking order on the left.
train_split, val_split, test_split = torch.utils.data.random_split(scaled_tensor, [train_size, val_size, test_size])

torch.Size([2894910, 110])
2315928 289491 289491


In [147]:
class Autoencoder(nn.Module):
    """Fully connected autoencoder with a mirrored encoder and decoder.

    Parameters
    ----------
    input_dim : int
        Width of the input (and of the reconstruction).
    hidden_units : sequence of int
        Widths of the hidden layers between input and latent space; may be empty.
    latent_dim : int
        Width of the bottleneck.
    do_norm_layer : bool
        If True, a LayerNorm is inserted before every hidden activation.
    activation : callable
        Zero-argument factory for the activation module (default `nn.ReLU`).

    The last Linear of the encoder and of the decoder is left without norm/activation. "Last" is
    decided by position in the stack, not by comparing widths, so a hidden layer that happens to
    have the same width as `latent_dim` or `input_dim` still gets its activation.
    """
    def __init__(self, input_dim, hidden_units, latent_dim, do_norm_layer=False, activation=nn.ReLU):
        super().__init__()
        units = [input_dim, *hidden_units, latent_dim]
        # Encoder is built first, then the decoder, so parameter initialisation order is stable.
        self.encoder = nn.Sequential(*self._build_stack(units, do_norm_layer, activation))
        self.decoder = nn.Sequential(*self._build_stack(units[::-1], do_norm_layer, activation))

    @staticmethod
    def _build_stack(units, do_norm_layer, activation):
        """Return the module list for consecutive Linear layers over `units`."""
        layers = []
        last_index = len(units) - 2
        for index, (n_in, n_out) in enumerate(zip(units, units[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if index != last_index:
                if do_norm_layer:
                    layers.append(nn.LayerNorm(n_out))
                layers.append(activation())
        return layers

    def forward(self, x):
        """Encode `x` to the latent space and decode it back."""
        return self.decoder(self.encoder(x))


def run_one_epoch(model, mode, data_loader, criterion, optimizer, wandb, nbatches_step_loss):
    """Run one pass over `data_loader`, optimising the model when `mode == "train"`.

    Parameters
    ----------
    model : nn.Module
        Reconstruction model; the loss is `criterion(model(batch), batch)`.
    mode : {"train", "val"}
        "train" performs optimizer steps; "val" only evaluates.
    data_loader : iterable with `len()`
        Yields batches of input tensors.
    criterion, optimizer
        Loss function and optimizer (the optimizer is only used in "train" mode).
    wandb
        Object with a `log(dict)` method; receives the running loss as "<mode>_loss".
    nbatches_step_loss : int
        Log the mean loss of the batches since the last log every this many batches.

    Returns
    -------
    (average_epoch_loss, epoch_losses)
        Mean of the per-batch losses and the list of per-batch losses.
    """
    assert mode in {"train", "val"}
    is_train = mode == "train"
    model.train(is_train)

    # Follow the model's device instead of a hard-coded one; the original `batch.to('mps')`
    # discarded its result and failed outright on machines without an MPS backend.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    epoch_losses = []
    step_losses = []
    pbar = tqdm(enumerate(data_loader), total=len(data_loader))
    # No autograd graph is needed during validation.
    with torch.set_grad_enabled(is_train):
        for batch_ind, batch in pbar:
            if device is not None:
                batch = batch.to(device)
            output = model(batch)
            loss = criterion(output, batch)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            current_loss = loss.item()
            epoch_losses.append(current_loss)
            step_losses.append(current_loss)

            if batch_ind != 0 and batch_ind % nbatches_step_loss == 0:
                average_step_loss = np.mean(step_losses)
                wandb.log({f"{mode}_loss": average_step_loss})
                step_losses = []  # reset the list after logging

    average_epoch_loss = np.mean(epoch_losses)
    return average_epoch_loss, epoch_losses

def train(model, train_data, valid_data, wandb, num_epochs=5, learning_rate=0.001, 
          do_validation=True, nbatches_step_loss=50, ckpt_fname='train_ckpt', do_warmup=False, do_decay=False, warmup_epochs=0, decay_epochs=0):
    """Train `model` with Adam and an MSE reconstruction loss, checkpointing the best epoch.

    Parameters
    ----------
    model : nn.Module
    train_data, valid_data
        Data loaders handed to `run_one_epoch`.
    wandb
        Object with a `log(dict)` method used for metric logging.
    num_epochs : int
    learning_rate : float
        Peak learning rate.
    do_validation : bool
        If True the checkpoint criterion is the validation loss, otherwise the training loss.
    nbatches_step_loss : int
        Forwarded to `run_one_epoch`.
    ckpt_fname : str
        The best state dict is saved to autoencoder/checkpoints/<ckpt_fname>.pth.
    do_warmup, warmup_epochs
        Linear warm-up of the learning rate over the first `warmup_epochs` epochs.
    do_decay, decay_epochs
        Half-cosine decay over the `decay_epochs` epochs following warm-up. Epochs after the
        decay window run at the full `learning_rate` again (behaviour kept from the original).

    Returns
    -------
    float
        The best (lowest) monitored epoch loss.
    """
    import os  # local import: only needed to make sure the checkpoint directory exists

    assert warmup_epochs + decay_epochs <= num_epochs, "The sum of warmup and decay epochs should not exceed total epochs"

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    best_loss = np.inf

    # Create the target directory up front so a missing folder cannot fail the first save
    # after a full epoch of training.
    ckpt_path = f'autoencoder/checkpoints/{ckpt_fname}.pth'
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        if do_warmup and epoch < warmup_epochs:  # learning rate warmup phase
            lr = learning_rate * (epoch + 1) / warmup_epochs
        elif do_decay and decay_epochs > 0 and epoch < warmup_epochs + decay_epochs:  # learning rate decay phase
            # `decay_epochs > 0` guards the division below.
            lr = learning_rate * 0.5 * (1 + np.cos(np.pi * (epoch - warmup_epochs) / decay_epochs))
        else:  # constant learning rate phase
            lr = learning_rate

        # apply new learning rate to the optimizer
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        epoch_train_loss, _ = run_one_epoch(model, "train", train_data, criterion, optimizer, wandb, nbatches_step_loss)
        wandb.log({"epoch_train_loss": epoch_train_loss, "learning_rate": lr, "epoch": epoch})

        if do_validation:
            epoch_validation_loss, _ = run_one_epoch(model, "val", valid_data, criterion, optimizer, wandb, nbatches_step_loss)
            wandb.log({"epoch_validation_loss": epoch_validation_loss})
            monitored_loss = epoch_validation_loss
        else:
            monitored_loss = epoch_train_loss

        # Keep only the weights of the best epoch seen so far.
        if monitored_loss < best_loss:
            best_loss = monitored_loss
            torch.save(model.state_dict(), ckpt_path)

    return best_loss

# Instantiate a small autoencoder (hidden widths 48 and 24, 2-D latent) so the cell displays its layer layout.
Autoencoder(input_dim=scaled_tensor.shape[1], hidden_units=[48, 24], latent_dim=2)

Autoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=110, out_features=48, bias=True)
    (1): ReLU()
    (2): Linear(in_features=48, out_features=24, bias=True)
    (3): ReLU()
    (4): Linear(in_features=24, out_features=2, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=2, out_features=24, bias=True)
    (1): ReLU()
    (2): Linear(in_features=24, out_features=48, bias=True)
    (3): ReLU()
    (4): Linear(in_features=48, out_features=110, bias=True)
  )
)

In [148]:
# Batch the three splits; all loaders use batches of 512 and reshuffle on every pass.
train_loader = DataLoader(train_split, batch_size=512, shuffle=True)
val_loader = DataLoader(val_split, batch_size=512, shuffle=True)
test_loader = DataLoader(test_split, batch_size=512, shuffle=True)

In [149]:
# Architectures to train. `name` doubles as the wandb run name and the checkpoint file stem;
# `hidden_units` and `latent_dim` are forwarded to Autoencoder.
CONFIGS = [
    dict(name="lat_110_hid_none", hidden_units=[], latent_dim=110),
]

In [None]:
# Train one autoencoder per entry in CONFIGS, logging each as a separate wandb run.
for config in CONFIGS:
    model = Autoencoder(input_dim=scaled_tensor.shape[1], hidden_units=config['hidden_units'], latent_dim=config['latent_dim'])
    model.to('mps')
    # NOTE(review): the return value of torch.compile is discarded, so training below uses the
    # uncompiled `model`. Assigning it would presumably change the saved state-dict keys — confirm
    # before changing.
    torch.compile(model)
    # Count of trainable parameters, recorded in the wandb run config.
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    wandb.init(project='autoencoder', entity='generative_ml', name=config['name'], config={'num_params': num_params})
    wandb.watch(model, log="all")
    print(f'Number of parameters: {num_params}')
    out = train(model, train_loader, val_loader, wandb, num_epochs=10, learning_rate=0.01, nbatches_step_loss=250, ckpt_fname=config['name'],
                do_warmup=False, do_decay=False, warmup_epochs=0, decay_epochs=0)
    wandb.finish()

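# wandb API key: <redacted>. Supply it via the WANDB_API_KEY environment variable or `wandb login`; never store it in the notebook. The key that was previously pasted here should be revoked.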

# Process GPT Predictions

In [None]:
import rdkit.Chem
import rdkit.Chem.Descriptors
from tqdm import tqdm
import pandas as pd

def descriptors_for_gpt_predictions(path_to_predicted, path_to_save):
    """Compute RDKit descriptors for every parsable SMILES in a CSV and pickle the result.

    Parameters
    ----------
    path_to_predicted : str
        CSV file with a 'smiles' column.
    path_to_save : str
        Destination of the pickled DataFrame.

    Returns
    -------
    pandas.DataFrame
        One row per SMILES that RDKit could parse: the descriptor columns (in the order RDKit
        returns them for the first valid molecule) followed by a 'smiles' column.
    """
    gpt_mols = pd.read_csv(path_to_predicted)
    descriptor_names = None
    records = []
    for smile in tqdm(gpt_mols['smiles'], total=len(gpt_mols)):
        # Missing values come back from read_csv as NaN floats, which MolFromSmiles rejects.
        if not isinstance(smile, str):
            continue
        mol = rdkit.Chem.MolFromSmiles(smile)
        if mol is None:  # unparsable SMILES
            continue
        mol_data = rdkit.Chem.Descriptors.CalcMolDescriptors(mol)
        if descriptor_names is None:
            # A list (not a set) keeps the column order deterministic between runs.
            descriptor_names = list(mol_data.keys())
        record = {key: mol_data[key] for key in descriptor_names}
        record['smiles'] = smile
        records.append(record)

    columns = [*(descriptor_names or []), 'smiles']
    gpt_df = pd.DataFrame(records, columns=columns)
    gpt_df.to_pickle(path_to_save)
    return gpt_df

# Locations of the generated-molecule CSVs and of the checkpoint directory for the descriptor pickles.
BASE = '/Users/morgunov/batista/Summer/'
CHECKPOINTS = BASE + 'bindingDB/checkpoints/'
GPT_DATA = BASE + 'data/'
# Compute and pickle descriptors for two generated sets. Both calls assign to `gpt_df`, so after
# this cell it holds the result of the second call only.
gpt_df = descriptors_for_gpt_predictions(path_to_predicted=GPT_DATA + 'molgpt_generated_nocond_06_10_fintetune2.csv', path_to_save=CHECKPOINTS+f'gptMols_ft2.pickle')
gpt_df = descriptors_for_gpt_predictions(path_to_predicted=GPT_DATA + 'molgpt_generated_nocond_06_10.csv', path_to_save=CHECKPOINTS+f'gptMols.pickle')

# PCA-Transform

In [2]:
import pickle
import pandas as pd

def project_into_pca_space(path_to_pca, path_to_mols):
    """Project pickled molecule descriptors into a previously fitted PCA space.

    Parameters
    ----------
    path_to_pca : str
        Pickle containing a `(scaler, pca)` pair.
    path_to_mols : str
        Pickled DataFrame with a 'smiles' column plus the descriptor columns the scaler was fitted on.

    Returns
    -------
    (smiles, transformed)
        The 'smiles' column and the PCA-transformed, scaled descriptors.
    """
    # NOTE: pickle.load can execute arbitrary code; only load artefacts produced by this pipeline.
    with open(path_to_pca, 'rb') as pca_file:
        scaler, pca = pickle.load(pca_file)
    gptMols = pd.read_pickle(path_to_mols)#.sample(n=10)
    # Select the columns by the names (and order) the scaler was fitted with.
    scaled = scaler.transform(gptMols[scaler.get_feature_names_out()])
    return gptMols['smiles'], pca.transform(scaled)

# Project the descriptors of the generated molecules into the PCA space.
# NOTE(review): PICKLES and INFERENCES are not defined in the cells visible here; presumably they
# are set elsewhere in the notebook — confirm.
gpt_smiles, pca_transformed = project_into_pca_space(path_to_pca=PICKLES + 'scaler_pca_moses+bindingdb.pkl', path_to_mols=INFERENCES + 'GPT_pretrain_inference_07_14_23_39_1end_ignore_moses+bindingdb_temp1.0_descriptors.pkl')
# Display both shapes as a quick check that rows and SMILES line up.
pca_transformed.shape, gpt_smiles.shape

((99095, 100), (99095,))

# Exploring KMeans clustering 

In [None]:
from sklearn.cluster import KMeans
import numpy as np
from tqdm import tqdm 

def _cluster_mols_experimental_loss(mols, n_clusters, n_iter):
    """Fit KMeans `n_iter` times and return the fit with the lowest inertia (None if `n_iter` < 1)."""
    winner = None
    lowest_inertia = float('inf')
    for _attempt in range(n_iter):
        candidate = KMeans(n_clusters=n_clusters, n_init='auto', init='k-means++')
        candidate = candidate.fit(mols)
        if candidate.inertia_ < lowest_inertia:
            lowest_inertia, winner = candidate.inertia_, candidate
    return winner

def _cluster_mols_experimental_variance(mols, n_clusters, n_iter):
    """Fit KMeans `n_iter` times and return the fit whose cluster sizes have the HIGHEST variance.

    Returns None if `n_iter` < 1.
    """
    winner = None
    highest_variance = float('-inf')
    for _attempt in range(n_iter):
        candidate = KMeans(n_clusters=n_clusters, n_init='auto', init='k-means++').fit(mols)
        # Number of molecules assigned to each cluster.
        _, cluster_sizes = np.unique(candidate.labels_, return_counts=True)
        size_variance = np.var(cluster_sizes)
        if size_variance > highest_variance:
            highest_variance, winner = size_variance, candidate
    return winner

def _cluster_mols_experimental_mixed(mols, n_clusters, n_iter, mixed_objective_loss_quantile):
    """Among the lowest-inertia fits, return the one with the most even cluster sizes.

    KMeans is fitted `n_iter` times. The fits are ranked by inertia, the best
    `mixed_objective_loss_quantile` fraction is kept (at least one fit), and from those the fit
    with the lowest variance of cluster sizes is returned. Returns None if `n_iter` < 1.
    """
    runs = []  # (inertia, variance of cluster sizes, fitted model)
    for _ in range(n_iter):
        kmeans = KMeans(n_clusters=n_clusters, n_init='auto', init='k-means++').fit(mols)
        counts = np.unique(kmeans.labels_, return_counts=True)[1]
        runs.append((kmeans.inertia_, np.var(counts), kmeans))
    if not runs:
        return None

    runs.sort(key=lambda run: run[0])
    # The original referenced an undefined name here; also keep at least one candidate so a
    # small n_iter * quantile product cannot produce an empty selection.
    n_keep = max(1, int(len(runs) * mixed_objective_loss_quantile))
    return min(runs[:n_keep], key=lambda run: run[1])[2]

def _cluster_mols_experimental(mols, n_clusters, save_path=None, n_iter=1, objective='loss', mixed_objective_loss_quantile=0.1):
    """Fit KMeans, optionally picking the best of several restarts, and optionally pickle it.

    Parameters
    ----------
    mols : array-like, shape (n_samples, n_features)
    n_clusters : int
    save_path : str, optional
        Where to pickle the selected model; nothing is written when None.
    n_iter : int
        Number of restarts. With 1 a single fit is returned regardless of `objective`.
    objective : {'loss', 'variance', 'mixed'}
        Selection rule across restarts, dispatched to the matching helper.
    mixed_objective_loss_quantile : float
        Only used by the 'mixed' objective.

    Returns
    -------
    The selected fitted KMeans model.

    Raises
    ------
    ValueError
        If `n_iter` != 1 and `objective` is not one of the supported names.
    """
    if n_iter == 1:
        kmeans = KMeans(n_clusters=n_clusters, n_init='auto', init='k-means++').fit(mols)
    elif objective == 'loss':
        kmeans = _cluster_mols_experimental_loss(mols, n_clusters, n_iter)
    elif objective == 'variance':  # was `kmeans == 'variance'`, which read an unbound local
        kmeans = _cluster_mols_experimental_variance(mols, n_clusters, n_iter)
    elif objective == 'mixed':
        kmeans = _cluster_mols_experimental_mixed(mols, n_clusters, n_iter, mixed_objective_loss_quantile)
    else:
        raise ValueError(f'Unknown objective {objective}')

    if save_path is not None:
        # Persist the model that is actually returned (the original dumped an undefined name).
        with open(save_path, 'wb') as model_file:
            pickle.dump(kmeans, model_file)
    return kmeans

# Exploratory run: 100 clusters, 1,000 restarts, `objective` left at its default.
out = _cluster_mols_experimental(mols=pca_transformed, n_clusters=100, n_iter=1_000)

In [None]:
class Graph:
    """Holds plot styling parameters and applies them to a plotly figure."""

    def __init__(self):
        # Font sizes
        self.title_size = 20
        self.axis_title_size = 14
        self.tick_font_size = 12
        # Colours
        self.text_color="#333333"
        self.background = "white"
        self.grid_color = "#e2e2e2"
        self.line_color = "#000000"
        self.font_family = 'Helvetica'
        # Figure size in pixels
        self.width = 600
        self.height = 400
        # Texts
        self.title = ''
        self.xaxis_title = ''
        self.yaxis_title = ''

    def update_parameters(self, params):
        """Overwrite (or add) attributes from a `{name: value}` mapping."""
        for key, val in params.items():
            setattr(self, key, val)

    def style_figure(self, figure):
        """Apply the stored styling to `figure` in place and return that same figure."""
        figure.update_layout({
            'margin': {'t': 50, 'b': 50, 'l': 50, 'r': 50},
            'plot_bgcolor': self.background,
            'paper_bgcolor': self.background,
            'title': {
                'text': self.title,
                'font': {
                    'size': self.title_size,
                    'color': self.text_color,
                    'family': self.font_family
                },
            },
            'height': self.height,
            'width': self.width,
            'font': {
                'family': self.font_family,
                'size': self.tick_font_size,
                'color': self.text_color
            },
            'legend': {
                'font': {
                    'family': self.font_family,
                    'size': self.tick_font_size,
                    'color': self.text_color
                },
            },
        })

        # Setting the title size and color and grid for both x and y axes
        figure.update_xaxes(
            title=self.xaxis_title,
            title_font={'size': self.axis_title_size, 'color': self.text_color, 'family': self.font_family},
            tickfont={'size': self.tick_font_size, 'color': self.text_color, 'family': self.font_family},
            showgrid=True,
            gridwidth=1,
            gridcolor=self.grid_color,
            linecolor=self.line_color,  # make x axis line visible
            linewidth=2
        )

        figure.update_yaxes(
            title=self.yaxis_title,
            title_standoff=0,
            title_font={'size': self.axis_title_size, 'color': self.text_color, 'family': self.font_family},
            tickfont={'size': self.tick_font_size, 'color': self.text_color, 'family': self.font_family},
            showgrid=True,
            gridwidth=1,
            gridcolor=self.grid_color,
            linecolor=self.line_color,  # make y axis line visible
            linewidth=2
        )
        # Return the argument itself; the original returned an unrelated global named `fig`.
        return figure


In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Plot loss and variance on two y-axes, ordered by increasing loss.
# NOTE(review): `out` is indexed here as (losses, variances), whereas the `out` assigned in the
# clustering cell above is the return value of _cluster_mols_experimental, which per its code is
# a fitted KMeans object — presumably this cell was written against an earlier version; confirm.
# Using the loss as dict key also collapses runs that share exactly the same loss.
loss_to_var = {loss:var for loss, var in zip(out[0], out[1])}
sort_loss, variances = zip(*sorted(loss_to_var.items(), key=lambda x: x[0]))

graph = Graph()
fig = make_subplots(specs=[[{"secondary_y": True}]])
# Loss on the primary y-axis, variance on the secondary one; x is just the rank.
fig.add_trace(go.Scatter(x=np.arange(len(sort_loss)), y=sort_loss, mode='markers', name='Loss'), secondary_y=False)
fig.add_trace(go.Scatter(x=np.arange(len(sort_loss)), y=variances, mode='markers', name='Variances'), secondary_y=True)
graph.style_figure(fig)
fig.show()
fig.write_html(CHECKPOINTS + 'kmeans_sort_loss_vs_variance.html', include_plotlyjs='cdn')

In [None]:
# Same data as the loss-sorted plot, but ordered by increasing variance.
# NOTE(review): `out` is indexed as (losses, variances); confirm it still has that shape, since
# _cluster_mols_experimental as written returns a fitted KMeans object.
loss_to_var = {loss:var for loss, var in zip(out[0], out[1])}
loss, sort_variances = zip(*sorted(loss_to_var.items(), key=lambda x: x[1]))

graph = Graph()
fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(go.Scatter(x=np.arange(len(loss)), y=loss, mode='markers', name='Loss'), secondary_y=False)
fig.add_trace(go.Scatter(x=np.arange(len(loss)), y=sort_variances, mode='markers', name='Variances'), secondary_y=True)
graph.style_figure(fig)
fig.show()
# Written to its own file: the original reused 'kmeans_sort_loss_vs_variance.html' and thereby
# overwrote the loss-sorted plot saved by the previous cell.
fig.write_html(CHECKPOINTS + 'kmeans_sort_variance_vs_loss.html', include_plotlyjs='cdn')

# K-Means clustering

In [10]:
import numpy as np

def sample_based_on_distance_percentiles(elements, distances, n_samples, n_percentiles):
    """Draw exactly `n_samples` elements, stratified by percentile bins of `distances`.

    The elements are ordered by distance and split into `n_percentiles` bins at the
    100*k/n_percentiles percentiles (a value equal to a boundary falls into the lower bin).
    Every bin contributes `n_samples // n_percentiles` elements; the remainder goes one each to
    the highest-distance bins. When a bin holds fewer elements than its share, the shortfall is
    handed out one at a time to bins that still have unused elements, so the result always has
    `n_samples` distinct positions of the input.

    Parameters
    ----------
    elements : sequence
        Items to sample from.
    distances : sequence of float
        One distance per element.
    n_samples : int
        Total number of elements to return.
    n_percentiles : int
        Number of bins.

    Returns
    -------
    list
        The sampled elements, grouped by bin in order of increasing distance. Sampling uses
        NumPy's global random state.

    Raises
    ------
    AssertionError
        If the inputs differ in length, `n_samples` exceeds the number of elements, or
        `n_percentiles` is not positive.
    """
    assert len(elements) == len(distances), "Elements and distances lists must be of the same length"
    assert n_samples <= len(elements), "Number of samples cannot exceed the total number of elements"
    assert n_percentiles > 0, "Number of percentiles must be a positive integer"
    if n_samples <= 0:
        return []

    # Sort positions (not the elements themselves) so elements never need to be comparable.
    element_list = list(elements)
    distance_array = np.asarray(distances, dtype=float)
    order = np.argsort(distance_array, kind='stable')
    sorted_distances = distance_array[order]
    sorted_elements = [element_list[int(position)] for position in order]
    n_total = len(sorted_elements)

    # Bin boundaries as index ranges into the sorted arrays.
    thresholds = [np.percentile(sorted_distances, p * 100 / n_percentiles) for p in range(1, n_percentiles)]
    inner_edges = [int(np.searchsorted(sorted_distances, value, side='right')) for value in thresholds]
    edges = [0, *inner_edges, n_total]  # the last bin takes every remaining element
    bin_bounds = list(zip(edges, edges[1:]))
    bin_sizes = [stop - start for start, stop in bin_bounds]

    # Nominal share per bin, remainder to the highest-distance bins, capped by bin size.
    base, remainder = divmod(n_samples, n_percentiles)
    quotas = [base] * n_percentiles
    for bin_index in range(n_percentiles - remainder, n_percentiles):
        quotas[bin_index] += 1
    quotas = [min(quota, size) for quota, size in zip(quotas, bin_sizes)]

    # Redistribute any shortfall round-robin; this terminates because n_samples <= n_total.
    deficit = n_samples - sum(quotas)
    while deficit > 0:
        for bin_index in range(n_percentiles):
            if deficit == 0:
                break
            if quotas[bin_index] < bin_sizes[bin_index]:
                quotas[bin_index] += 1
                deficit -= 1

    samples = []
    for (start, stop), quota in zip(bin_bounds, quotas):
        if quota == 0:
            continue
        picked = np.random.choice(stop - start, size=quota, replace=False)
        samples.extend(sorted_elements[start + int(offset)] for offset in picked)
    return samples

In [11]:
from sklearn.cluster import KMeans
from pprint import pprint as pp 
import numpy as np
import random

def _cluster_mols(mols, n_clusters, save_path, n_iter=1):
    """
        Fit K-Means on `mols`, keep the best of `n_iter` fits and pickle it to `save_path`.

        Parameters
        ----------
        mols : array-like or sparse matrix, shape (n_samples, n_features)
            The samples to cluster.

        n_clusters : int
            The number of clusters to form.

        save_path : str
            The path (including file name) where the resulting KMeans object is pickled.

        n_iter : int, optional (default=1)
            The number of independent fits. The fit with the lowest inertia (sum of squared
            distances of samples to their closest cluster center) is kept.

        Returns
        -------
        kmeans : sklearn.cluster.KMeans or None
            The fitted model with the lowest inertia; None (which is also what gets pickled)
            if `n_iter` is smaller than 1.
    """
    best_kmeans = None
    best_inertia = float('inf')
    # A single loop covers n_iter == 1 as well: the only fit is trivially the best one.
    for _ in range(n_iter):
        kmeans = KMeans(n_clusters=n_clusters, n_init='auto', init='k-means++').fit(mols)
        if kmeans.inertia_ < best_inertia:
            best_kmeans = kmeans
            best_inertia = kmeans.inertia_
    # Context manager so the file is flushed and closed before returning.
    with open(save_path, 'wb') as model_file:
        pickle.dump(best_kmeans, model_file)
    return best_kmeans

def cluster_and_sample(mols, mols_smiles, n_clusters, n_samples, kmeans_save_path, clusters_save_path, diffdock_save_path, 
                        ensure_correctness=False, path_to_pca=None, probabilistic_sampling=True, load_kmeans=False,
                        percentile_sampling=True, n_percentiles=1):
    """
        Clusters molecules with K-Means, samples SMILES from every cluster and saves the results.

        In total `n_clusters * n_samples` molecules are sampled. Clusters are visited from smallest
        to largest; a cluster that holds fewer molecules than requested contributes all of them and
        its shortfall is spread over the clusters that follow.

        Parameters
        ----------
        mols : array-like or sparse matrix, shape (n_samples, n_features)
            Feature vectors of the molecules.

        mols_smiles : sequence of str
            SMILES strings, aligned with `mols`.

        n_clusters : int
            The number of clusters to form.

        n_samples : int
            The nominal number of samples to draw from each cluster.

        kmeans_save_path : str
            Where the KMeans object is saved (or loaded from when `load_kmeans` is True).

        clusters_save_path : str
            Where the cluster -> SMILES mapping is pickled. Two sibling files are derived from it
            by replacing the extension: '<stem>_cl_to_d.pickle' (labels and per-cluster distances)
            and '<stem>_samples.pickle' (the sampled SMILES).

        diffdock_save_path : str
            CSV file receiving the sampled SMILES with their cluster id.

        ensure_correctness : bool, optional (default=False)
            If True, recomputes descriptors from every SMILES string and compares them with the
            row in `mols`. This requires 'path_to_pca' to be set.

        path_to_pca : str, optional (default=None)
            Pickle with the `(scaler, pca)` pair used when `ensure_correctness` is True.

        probabilistic_sampling : bool, optional (default=True)
            Sample with probability proportional to the distance to the cluster center.
            Takes precedence over `percentile_sampling`.

        load_kmeans : bool, optional (default=False)
            Load the model from `kmeans_save_path` instead of fitting a new one.

        percentile_sampling : bool, optional (default=True)
            Use `sample_based_on_distance_percentiles` (only if `probabilistic_sampling` is False).
            If both flags are False, sampling within a cluster is uniform.

        n_percentiles : int, optional (default=1)
            Number of distance bins for percentile sampling.

        Returns
        -------
        cluster_to_samples : dict
            A dictionary where the keys are cluster labels and the values are the sampled SMILES.

        Raises
        ------
        AssertionError
            If the number of requested samples exceeds the total number of molecules provided.
            If ensure_correctness is True but path_to_pca is None.
            If the number of labels of the KMeans model differs from the number of molecules.
            If features calculated from a smile string differ from features in the mols array.
            If the total number of sampled molecules doesn't equal to n_clusters * n_samples.

    """
    import os  # local import: used to derive the sibling output paths

    assert n_clusters * n_samples <= len(mols), f"{n_clusters=} * {n_samples=} = {n_clusters*n_samples} requested but only {len(mols)} molecules provided"
    if ensure_correctness:
        assert path_to_pca is not None, "path_to_pca must be provided to ensure correctness"
        with open(path_to_pca, 'rb') as pca_file:
            scaler, pca = pickle.load(pca_file)

    if load_kmeans:
        with open(kmeans_save_path, 'rb') as model_file:
            kmeans = pickle.load(model_file)
    else:
        kmeans = _cluster_mols(mols=mols, n_clusters=n_clusters, save_path=kmeans_save_path)
    # Checked for loaded models too: the zip below would otherwise silently truncate.
    assert len(kmeans.labels_) == len(mols_smiles), "Number of labels differs from number of molecules"
    distances = kmeans.transform(mols)

    cluster_to_mols = {}
    cluster_to_distances = {}
    for mol, distance, label, smile in zip(mols, distances, kmeans.labels_, mols_smiles):
        cluster_to_mols.setdefault(label, []).append(smile)
        # Smallest entry of the row returned by kmeans.transform, i.e. distance to the nearest center.
        cluster_to_distances.setdefault(label, []).append(distance.min())
        if ensure_correctness: # recalculate descriptors from a smile string and compare to the descriptors in the array
            smile_features = pca.transform(scaler.transform(pd.DataFrame({k: [v] for k, v in rdkit.Chem.Descriptors.CalcMolDescriptors(rdkit.Chem.MolFromSmiles(smile)).items()})[scaler.get_feature_names_out()]))
            assert np.allclose(smile_features[0], mol), "Features calculated from a smile string differ from features in the array"

    # splitext only strips the final extension; the original `split('.')[0]` cut the path at the
    # first dot anywhere, which breaks for directories such as '3. Sampling/'.
    clusters_stem = os.path.splitext(clusters_save_path)[0]
    with open(clusters_stem + '_cl_to_d.pickle', 'wb') as distances_file:
        pickle.dump((kmeans.labels_, cluster_to_distances), distances_file)

    # What happens below is sampling from each cluster. All the extra code is to ensure that the number of samples requested from each cluster
    # doesn't exceed the total number of available molecules. Clusters are processed from smallest to largest; whenever a cluster cannot
    # provide its share, the shortfall (`extra_mols`) is carried over and distributed among the remaining clusters, weighted by their size
    # relative to the average cluster size. `left_to_sample` caps the total at n_clusters * n_samples.
    avg_len = np.mean([len(v) for v in cluster_to_mols.values()])
    cluster_to_samples = {}
    extra_mols = 0
    left_to_sample = n_clusters*n_samples
    cluster_to_len = {cluster:len(mols) for cluster, mols in cluster_to_mols.items()}
    for i, (cluster, _) in enumerate(sorted(cluster_to_len.items(), key=lambda x: x[1], reverse=False)):
        smiles = cluster_to_mols[cluster]
        if extra_mols > 0:
            cur_extra = int(1+extra_mols/(len(cluster_to_mols) - i) * len(smiles)/avg_len)
            cur_samples = n_samples + cur_extra
            extra_mols -= cur_extra
        else:
            cur_samples = n_samples
        if cur_samples > left_to_sample:
            cur_samples = left_to_sample

        if len(smiles) > cur_samples:
            if probabilistic_sampling:
                cluster_to_samples[cluster] = np.random.choice(smiles, cur_samples, p=cluster_to_distances[cluster]/np.sum(cluster_to_distances[cluster]), replace=False)
            elif percentile_sampling:
                cluster_to_samples[cluster] = sample_based_on_distance_percentiles(smiles, cluster_to_distances[cluster], n_samples=cur_samples, n_percentiles=n_percentiles)
            else:
                cluster_to_samples[cluster] = np.random.choice(smiles, cur_samples, replace=False)
            left_to_sample -= cur_samples
        else:
            # Not enough molecules: take the whole cluster and remember the shortfall.
            cluster_to_samples[cluster] = smiles
            left_to_sample -= len(smiles)
            extra_mols += cur_samples - len(smiles)

    assert (n_sampled:=sum(len(vals) for vals in cluster_to_samples.values())) == n_clusters*n_samples, f"Sampled {n_sampled} but were requested {n_clusters*n_samples}"
    with open(clusters_save_path, 'wb') as clusters_file:
        pickle.dump(cluster_to_mols, clusters_file)
    with open(clusters_stem + '_samples.pickle', 'wb') as samples_file:
        pickle.dump(cluster_to_samples, samples_file)

    # Flat (smiles, cluster_id) table for the docking stage.
    keyToData = {}
    for cluster, mols in cluster_to_samples.items():
        for mol in mols:
            keyToData.setdefault('smiles', []).append(mol)
            keyToData.setdefault('cluster_id', []).append(cluster)
    pd.DataFrame(keyToData).to_csv(diffdock_save_path)
    return cluster_to_samples

In [16]:
# Cluster the PCA-projected molecules into 100 clusters and sample 10 SMILES per cluster, using
# percentile-based sampling with 4 distance bins; the samples are also written to a CSV.
nclusters = 100
# bkmeans = _cluster_mols(mols=pca_transformed, n_clusters=nclusters, save_path=CHECKPOINTS + 'k100means_07_11.pickle')
c_to_s = cluster_and_sample(mols=pca_transformed, mols_smiles=gpt_smiles, n_clusters=nclusters, n_samples=10, 
                   kmeans_save_path=PICKLES + 'k100means_07_14_23_39_1end_ignore_moses+bindingdb.pickle', ensure_correctness=False, path_to_pca=PICKLES + 'scaler_pca_moses+bindingdb.pickle',
                   clusters_save_path=PICKLES + 'cluster_to_samples_07_16.pickle', probabilistic_sampling=False, percentile_sampling=True, n_percentiles=4,
                   diffdock_save_path=INFERENCES + f'samples_percentiles_4.csv')

In [None]:
# Load the (labels, per-cluster distances) pair written by an earlier cluster_and_sample run.
# The context manager closes the file instead of leaving the handle to the garbage collector.
with open(CHECKPOINTS + 'cluster_to_samples_07_11_cl_to_d.pickle', 'rb') as distances_file:
    labels, cluster_to_distances = pickle.load(distances_file)

In [None]:
# Flatten the distances of every cluster that has more than 10 members into one array.
all_distances = np.array([distance for distances in cluster_to_distances.values() if len(distances) > 10 for distance in distances])
min_d, max_d = np.min(all_distances), np.max(all_distances)
# 10-bin histogram; only the bin edges are used below (`hist` itself is not read in this cell).
hist, bin_edges = np.histogram(all_distances, bins=10)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
# Sine of the bin centers, scaled by the largest distance.
sine_distances = np.sin(2*np.pi*bin_centers) * max(all_distances)
# sine_distances = np.sin(all_distances*2*np.pi/np.sum(all_distances))
len(all_distances)

In [None]:
import plotly.graph_objects as go


# Overlay two traces: a histogram of all pooled distances (bin width = 1% of the
# observed distance range) and the sine-transformed values at the coarse bin centers.
histogram_bins = dict(start=min_d, end=max_d, size=0.01*(max_d-min_d))
distance_histogram = go.Histogram(x=all_distances, xbins=histogram_bins, name='all_distances')
sine_curve = go.Scatter(x=bin_centers, y=sine_distances, name='sine transformed distances')

fig = go.Figure()
fig.add_trace(distance_histogram)
fig.add_trace(sine_curve)
fig.update_layout(barmode='overlay')
fig.update_traces(opacity=0.75)
fig.show()

In [None]:
import plotly.graph_objects as go


# Scatter column 7 of `pca_transformed` against column 50, using `labels` both as the
# marker colour and as the hover text of every point.
pca_scatter = go.Scatter(
    x=pca_transformed[:, 7],
    y=pca_transformed[:, 50],
    marker=dict(color=labels),
    text=labels,
    mode='markers',
    name='gpt generated',
)
fig = go.Figure()
fig.add_trace(pca_scatter)
fig.update_traces(opacity=0.75)
fig.show()

# Prepare training dataset for active learning

In [None]:

def _preprocess_scores_uniformly(scores, remove_positives=False, lowest_score=1):
    """
        Preprocesses a dictionary of scores by negating and normalizing them.

        The function negates all scores and optionally removes positive scores. If the minimum value among the negated scores 
        is less than zero, it shifts all values by subtracting the minimum value and adding 'lowest_score'. The final step is 
        to normalize the scores so that their total sum equals to 1.

        Parameters
        ----------
        scores : dict
            A dictionary of scores where the keys are identifiers and the values are their corresponding scores.

        remove_positives : bool, optional (default=False)
            If True, all positive scores are removed after negation.

        lowest_score : int, optional (default=1)
            This value is added to all scores if the minimum score is less than zero.

        Returns
        -------
        normalized : dict
            The normalized dictionary of scores.

    """
    negated = {k: -v for k, v in scores.items()}
    min_value = min(negated.values())
    if min_value < 0:
        if remove_positives:
            negated = {k: v for k, v in negated.items() if v > 0}
        else:
            negated = {k: v - min_value + lowest_score for k, v in negated.items()}
    total = sum(negated.values())
    normalized = {k: v / total for k, v in negated.items()}
    return normalized

def _preprocess_scores_softmax(scores):
    negated = {k: -v for k, v in scores.items()}
    max_value = max(negated.values())
    exponentiate = {k: np.exp(v - max_value) for k, v in negated.items()}
    total = sum(exponentiate.values())
    softmax = {k: v / total for k, v in exponentiate.items()}
    return softmax

def balance_cluster_to_n(cluster_to_n, cluster_to_len):
    """
    Cap every cluster's target at the cluster's size and move the excess to other clusters.

    The balancing happens in three steps:

    1. Each cluster whose target exceeds its size is capped at its size; the
       amount cut off is that cluster's surplus.
    2. The total surplus is shared among the uncapped clusters in proportion
       to their original targets (rounded down and never above the cluster's size).
    3. Whatever is still missing after rounding and capping is handed out one
       sample at a time, visiting the uncapped clusters from the largest to the
       smallest original target (ties keep their order in ``cluster_to_n``),
       until the original total is restored.

    Neither input dictionary is modified.

    Parameters
    ----------
    cluster_to_n : dict
        A dictionary mapping cluster identifiers to their target number of samples.

    cluster_to_len : dict
        A dictionary mapping cluster identifiers to the actual size of each cluster.
        Must contain every key of ``cluster_to_n``.

    Returns
    -------
    balanced : dict
        A dictionary mapping cluster identifiers to their balanced target number of samples.

    Raises
    ------
    ValueError
        If the clusters together cannot hold the requested total number of samples.
    AssertionError
        If the sum of target numbers before and after balancing don't match.

    """
    requested_total = sum(cluster_to_n.values())

    surplus = {key: cluster_to_n[key] - cluster_to_len[key] for key in cluster_to_n if cluster_to_n[key] > cluster_to_len[key]}
    balanced = dict(cluster_to_n)
    for key in surplus:
        balanced[key] = cluster_to_len[key]

    total_surplus = sum(surplus.values())
    initial_n_sum = sum(n for key, n in cluster_to_n.items() if key not in surplus)

    # With nothing requested from the uncapped clusters there are no proportions to
    # follow (and dividing would fail); the round-robin below handles that case.
    if initial_n_sum > 0:
        for key in balanced:
            if key in surplus:
                continue
            surplus_to_add = total_surplus * cluster_to_n[key] / initial_n_sum
            new_n = int(cluster_to_n[key] + surplus_to_add)
            balanced[key] = min(new_n, cluster_to_len[key])

    deficit = requested_total - sum(balanced.values())

    # Order clusters directly instead of inverting cluster_to_n into an {n: cluster}
    # map: the inverted map keeps only one cluster per distinct target value.
    # Capped clusters are already full, so they are left out.
    top_up_order = sorted(
        (key for key in cluster_to_n if key not in surplus),
        key=lambda key: cluster_to_n[key],
        reverse=True,
    )

    while deficit > 0:
        made_progress = False
        for cluster in top_up_order:
            if balanced[cluster] < cluster_to_len[cluster]:
                balanced[cluster] += 1
                deficit -= 1
                made_progress = True
                # Stop mid-pass, otherwise the remaining clusters would overshoot the total.
                if deficit == 0:
                    break
        if not made_progress:
            # Every cluster is full: fail loudly rather than loop forever.
            raise ValueError(
                f"Cannot place {requested_total} samples: clusters only hold "
                f"{sum(balanced.values())} molecules in total"
            )

    assert requested_total == sum(balanced.values()), f"Before balancing had {requested_total}, post balancing = {sum(balanced.values())}"
    return balanced

def sample_clusters_for_active_learning(cluster_to_scores, n_samples, path_to_clusters, probability_type='softmax', remove_positives=False, lowest_score=1):
    """
    Sample molecules from clusters for active learning, skipping molecules that were already sampled.

    Cluster scores are converted into probabilities (softmax or uniform), each cluster
    receives ``int(probability * n_samples)`` samples, and the rounding remainder goes
    to the most probable cluster. The per-cluster counts are then balanced so that no
    cluster is asked for more molecules than it has left, and the molecules are drawn
    without replacement.

    Molecules listed in the companion ``*_samples.pickle`` file (the ones sampled
    earlier) are excluded from the candidate pool.

    Parameters
    ----------
    cluster_to_scores : dict
        A dictionary mapping cluster identifiers to their scores.

    n_samples : int
        The total number of molecules to sample.

    path_to_clusters : str
        The path to a pickle file storing a dictionary that maps each cluster to a list of molecules.
        The previously sampled molecules are read from the path obtained by cutting this
        path at its first '.' and appending '_samples.pickle'.

    probability_type : str, optional (default='softmax')
        The type of probability distribution used to determine the number of samples per cluster.
        Options are 'softmax' and 'uniform'.

    remove_positives : bool, optional (default=False)
        Only used when probability_type is 'uniform'. Forwarded to ``_preprocess_scores_uniformly``.

    lowest_score : int, optional (default=1)
        Only used when probability_type is 'uniform'. Forwarded to ``_preprocess_scores_uniformly``.

    Returns
    -------
    training : list
        A list of randomly sampled molecules for active learning.

    Raises
    ------
    KeyError
        If an unsupported probability_type is provided.
    ValueError
        If no cluster is left with a probability, or if the clusters do not hold
        enough new molecules (raised by ``balance_cluster_to_n``).
    AssertionError
        If the number of sampled molecules doesn't equal to n_samples.

    """
    if probability_type == 'softmax':
        probability_function = _preprocess_scores_softmax
    elif probability_type == 'uniform':
        def probability_function(scores):
            return _preprocess_scores_uniformly(scores, remove_positives, lowest_score)
    else:
        raise KeyError("Only uniform and softmax probabilities are supported")

    # NOTE(review): unpickling can execute arbitrary code; only load files you trust.
    with open(path_to_clusters, 'rb') as clusters_file:
        cluster_to_mols = pickle.load(clusters_file)
    # Cutting at the first '.' mirrors how the sampling step names the file it writes,
    # so the rule must stay identical even though it mishandles dots in directory names.
    samples_path = path_to_clusters.split('.')[0] + '_samples.pickle'
    with open(samples_path, 'rb') as samples_file:
        cluster_to_samples = pickle.load(samples_file)

    docked_mols = {smile for smiles in cluster_to_samples.values() for smile in smiles}
    cluster_to_new_mols = {k: [smile for smile in v if smile not in docked_mols] for k, v in cluster_to_mols.items()}

    probabilities = probability_function(cluster_to_scores)
    if not probabilities:
        # Happens e.g. with 'uniform' + remove_positives when every cluster is filtered out.
        raise ValueError("No cluster received a probability, cannot distribute samples")

    cluster_to_n = {k: int(v * n_samples) for k, v in probabilities.items()}
    # int() rounds down, so hand the leftover samples to the most probable cluster.
    most_probable_cluster = max(probabilities, key=probabilities.get)
    cluster_to_n[most_probable_cluster] += n_samples - sum(cluster_to_n.values())

    cluster_to_len = {k: len(v) for k, v in cluster_to_new_mols.items()}
    balanced = balance_cluster_to_n(cluster_to_n, cluster_to_len)

    training = []
    for cluster, n in balanced.items():
        training.extend(np.random.choice(cluster_to_new_mols[cluster], n, replace=False))

    assert len(training) == n_samples, f"{len(training)=} != {n_samples=}"
    return training

# Smoke test: give every cluster a random score and draw 10 molecules in total.
random_cluster_scores = {cluster_id: random.random() for cluster_id in range(nclusters)}
sample_clusters_for_active_learning(
    random_cluster_scores,
    n_samples=10,
    path_to_clusters=CHECKPOINTS + 'cluster_to_samples_07_07.pickle',
)