In [None]:
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole, rdMolDraw2D
from rdkit.Chem import AllChem, DataStructs, Descriptors, PandasTools

import os
from os.path import join, getsize

from useful_rdkit_utils import mol2numpy_fp

import pandas as pd
import numpy as np

from sklearn.manifold import TSNE

from umap import UMAP

from itertools import product

import concurrent.futures

import seaborn as sns
import matplotlib.pyplot as plt

from chemplot import Plotter

import numba

In [None]:
# Directory used as the base for every data path below. abspath() resolves 'plots.ipynb'
# against the current working directory, so this is the directory the kernel was started in.
current_dir = os.path.dirname(os.path.abspath('plots.ipynb'))

# Show the resolved directory (the working directory itself is not changed here)
print(current_dir)

In [None]:
# Build the path to the oracle CSV inside current_dir
oracle_csv_path = os.path.join(current_dir, 'cs_49k.csv') # precomputed cnnaffinity as oracle, for 49k mols
oracle_df = pd.read_csv(oracle_csv_path)

Functions for data analysis:

In [None]:
def create_test_df(directory):
    """
    Concatenate the selection.csv files found in the ``cycle_*`` sub-directories of ``directory``.

    Score columns named 'combo1' or 'plip' are renamed to 'cnnaffinity', the score is replaced
    by its absolute value, and the cycle number parsed from the directory name is stored in a
    'cycle' column. Rows are ordered by ascending cycle number.

    Parameters:
    - directory: str, the path to the directory containing cycle folders.

    Returns:
    - all_selections_df: DataFrame with all selections plus 'cycle' and 'expt' columns
      ('expt' holds ``directory``). If no selection.csv is found, the frame has no rows.
    """
    score_column = 'cnnaffinity'

    # (cycle number, frame) pairs; concatenated once at the end instead of growing a
    # DataFrame inside the loop, which copies all previous rows on every iteration.
    cycle_frames = []

    for cycle_dir in os.listdir(directory):
        if not (cycle_dir.startswith('cycle_') and os.path.isdir(os.path.join(directory, cycle_dir))):
            continue
        cycle_path = os.path.join(directory, cycle_dir, 'selection.csv')
        if not os.path.exists(cycle_path):
            continue

        cycle_df = pd.read_csv(cycle_path)
        print(cycle_df.columns)

        # Some experiments store the score under a different column name
        for alias in ('combo1', 'plip'):
            if alias in cycle_df.columns:
                cycle_df = cycle_df.rename(columns={alias: score_column})

        # Cycle number is the token after the first underscore of the directory name
        cycle_number = int(cycle_dir.split('_')[1])
        cycle_df['cycle'] = cycle_number
        cycle_df[score_column] = cycle_df[score_column].abs()
        cycle_frames.append((cycle_number, cycle_df))

    if cycle_frames:
        # os.listdir order is arbitrary; sort so the output is reproducible
        cycle_frames.sort(key=lambda pair: pair[0])
        all_selections_df = pd.concat([frame for _, frame in cycle_frames], ignore_index=True)
    else:
        all_selections_df = pd.DataFrame()
    all_selections_df['expt'] = directory

    print(all_selections_df.columns)
    return all_selections_df


def calculate_performance_metrics(data_df, oracle_df, top_percentage):
    """
    Compute confusion-matrix counts and recall / precision / F1 for a set of selected molecules.

    The activity threshold is the (1 - top_percentage) quantile of the oracle's 'cnnaffinity'.
    Every row of ``data_df`` counts as a predicted positive; rows flagged 'is_active' are the
    true positives.

    Parameters:
    - data_df: DataFrame of selected molecules with a boolean 'is_active' column and an 'expt' column.
    - oracle_df: DataFrame of the oracle data containing a 'cnnaffinity' column.
    - top_percentage: float, fraction of the oracle regarded as active (e.g. 0.02).

    Returns:
    - dict of one-element lists with keys 'expt', 'TP', 'FP', 'TN', 'FN', 'recall',
      'precision' and 'f1_score'. 'expt' is the first 'expt' value with '_generated' removed.
    """
    threshold = oracle_df['cnnaffinity'].quantile(1 - top_percentage)
    oracle_is_active = oracle_df['cnnaffinity'].astype('float') >= float(threshold)
    n_oracle_actives = oracle_df[oracle_is_active].shape[0]

    TP = data_df['is_active'].sum()      # selected and active
    FN = n_oracle_actives - TP           # active in the oracle but not selected
    FP = data_df.shape[0] - TP           # selected but not active
    TN = oracle_df.shape[0] - TP - FP - FN  # everything else

    # the four counts must partition the oracle
    assert TP + FN + FP + TN == len(oracle_df), "Count mismatch"

    if TP + FN > 0:
        recall = TP / (TP + FN)
    else:
        recall = 0

    if TP + FP > 0:
        precision = TP / (TP + FP)
    else:
        precision = 0

    if precision + recall > 0:
        f1_score = 2 * precision * recall / (precision + recall)
    else:
        f1_score = 0

    print(f"True Positive (TP): {TP}")
    print(f"False Negative (FN): {FN}")
    print(f"False Positive (FP): {FP}")
    print(f"True Negative (TN): {TN}")

    print(f"Total Samples: {len(oracle_df)}")
    print(f"Sum of Counts (TP+FN+FP+TN): {TP + FN + FP + TN}")

    print(f"Recall: {recall}")
    print(f"Precision: {precision}")
    print(f"F1 Score: {f1_score}")

    experiment_label = data_df['expt'].iloc[0].replace('_generated', '')
    return {
        'expt': [experiment_label],
        'TP': [TP],
        'FP': [FP],
        'TN': [TN],
        'FN': [FN],
        'recall': [recall],
        'precision': [precision],
        'f1_score': [f1_score]
    }


def find_active_mols(df, top_percentage, oracle_df):
    """
    Flag the molecules in ``df`` whose score reaches the oracle's activity threshold.

    Parameters:
    - df: DataFrame containing 'cnnaffinity' and 'expt' columns; an 'is_active' column is added in place.
    - top_percentage: float, the top fraction of oracle 'cnnaffinity' values regarded as active.
    - oracle_df: DataFrame with the oracle 'cnnaffinity' values the threshold is taken from.

    Returns:
    - df: the same DataFrame object, now with a boolean 'is_active' column.
    """
    # score at or above the (1 - top_percentage) quantile of the oracle counts as active
    cutoff = oracle_df['cnnaffinity'].quantile(1 - top_percentage)
    df['is_active'] = df['cnnaffinity'].ge(cutoff)

    print(f'Threshold for {top_percentage} is {cutoff}.')
    print(f'Number of active mols for {df.expt.iloc[0]} is {df.is_active.sum()}.')

    return df

def gather_data(data_dir):
    """
    List the immediate sub-directories of ``data_dir``.

    Parameters:
    - data_dir: str, directory whose sub-directory names are wanted.

    Returns:
    - list of sub-directory names (not full paths); empty if ``data_dir`` cannot be walked.
    """
    # Only the first os.walk() entry is needed, so stop there instead of traversing the
    # whole tree. The previous version walked '.' and ignored data_dir altogether.
    try:
        _, subdirs, _ = next(os.walk(data_dir))
    except StopIteration:
        return []
    return subdirs


def get_params(exp_list):
    """
    Collect the distinct underscore-separated tokens that occur in a list of experiment names.

    Parameters:
    - exp_list: iterable of str, experiment names such as 'gp_200_UCB'.

    Returns:
    - list of unique tokens (order is that of the underlying set, i.e. not guaranteed).
    """
    unique_tokens = set()
    for name in exp_list:
        unique_tokens.update(name.split('_'))
    return list(unique_tokens)

def gen_rep_data(data_dir, percent=0.02, trunc=True):
    """
    Load every experiment found by gather_data(data_dir) and compute its performance metrics.

    Experiment names are passed to create_test_df() as given, i.e. they are resolved
    relative to the current working directory. The module-level ``oracle_df`` is used
    as the oracle.

    Parameters:
    - data_dir: str, directory handed to gather_data().
    - percent: float, top fraction of the oracle regarded as active.
    - trunc: bool, if True keep only the first 2500 rows (after sorting by cycle) of each experiment.

    Returns:
    - df_list: list of per-experiment DataFrames, sorted by 'cycle', with 'is_active' added.
    - metric_df: DataFrame with one row of metrics per experiment.
    """
    top_fraction = float(percent)

    by_cycle = (create_test_df(exp).sort_values(by=['cycle']) for exp in gather_data(data_dir))
    if trunc:
        df_list = [frame[:2500] for frame in by_cycle]
    else:
        df_list = list(by_cycle)

    # label all experiments first, then score them (keeps the printed output grouped)
    labelled = [find_active_mols(frame, top_fraction, oracle_df) for frame in df_list]
    per_experiment = [calculate_performance_metrics(frame, oracle_df, top_fraction) for frame in labelled]

    metric_df = pd.concat([pd.DataFrame(entry) for entry in per_experiment])
    return df_list, metric_df


def visualize_data(search_string, final_df, percent=None):
    """
    Bar-plot F1 score and recall for the experiments whose name matches ``search_string``.

    The figure is saved as '<search_string>_<percent>.png' in the working directory and shown.

    Parameters:
    - search_string: str, pattern passed to ``Series.str.contains`` on the 'expt' column.
    - final_df: DataFrame with 'expt', 'f1_score' and 'recall' columns.
    - percent: value used in the title and file name. When omitted, the notebook-level
      variable ``percent`` is used, which is what this function relied on implicitly before.

    Raises:
    - NameError: if ``percent`` is neither passed nor defined at notebook level.
    """
    if percent is None:
        try:
            percent = globals()['percent']
        except KeyError:
            raise NameError("name 'percent' is not defined; pass percent=... to visualize_data") from None

    # Filter DataFrame using Search String
    search_df = final_df[final_df['expt'].str.contains(search_string)]

    # Melt the DataFrame for Visualisation: one row per (experiment, metric)
    metrics_melted = pd.melt(search_df, id_vars='expt', value_vars=['f1_score', 'recall'], var_name='metric', value_name='score')

    # Plot Data
    fig, ax1 = plt.subplots()
    custom_palette = sns.color_palette("colorblind", n_colors=len(metrics_melted['metric'].unique()))
    sns.barplot(x='expt', y='score', hue='metric', data=metrics_melted, ax=ax1, palette=custom_palette, hatch='/')

    # Customize Plot
    plt.xticks(rotation=90)
    plt.title(f'{search_string} - Top {percent}')
    ax1.set_ylabel('Score')
    ax1.set_xlabel('Cycle size')
    plt.legend(loc=(1.04, 0))

    # NOTE(review): called after the figure has been drawn, as in the original code —
    # confirm whether the style was meant to apply to this figure.
    sns.set_style("darkgrid", {"axes.facecolor": ".9"})
    plt.savefig(f'{search_string}_{percent}.png', bbox_inches="tight")
    plt.show()


def create_violin_plot(df_combined, output_dir='/home/cree/code/gal/cs50k'):
    """
    Draw a violin plot of 'cnnaffinity' per cycle and save it.

    The figure is written to '<output_dir>/<expt>', where <expt> is the first value of the
    'expt' column. 'cycle' and 'cnnaffinity' are cast to int / float on ``df_combined`` itself.

    Parameters:
    - df_combined: DataFrame with 'cycle', 'cnnaffinity' and 'expt' columns.
    - output_dir: str, directory the figure is saved in. Defaults to the location that was
      previously hard-coded, so existing calls behave as before.

    Returns:
    - df_combined: the same DataFrame with the two columns cast.
    """
    df_combined['cycle'] = df_combined['cycle'].astype('int')
    df_combined['cnnaffinity'] = df_combined['cnnaffinity'].astype('float')
    plot_title = df_combined['expt'].iloc[0]

    # Violin plot of predicted affinity vs cycle
    sns.violinplot(x='cycle', y='cnnaffinity', data=df_combined)
    plt.xlabel('Cycle')
    plt.ylabel('Predicted pK')
    # tick every second cycle
    plt.xticks(range(df_combined['cycle'].min(), df_combined['cycle'].max() + 1, 2))
    plt.savefig(f'{output_dir}/{plot_title}')
    plt.show()
    return df_combined

def plot_metric_over_cycles(df, metric, oracle_df, percent):
    """
    Plot a cumulative performance metric against the cycle number and save the figure.

    For every cycle i from 1 to the highest cycle in ``df``, the metric is computed on all
    rows with cycle <= i. The figure is saved as '<expt>_<metric>_<percent>.png'.

    Parameters:
    - df: DataFrame with 'cycle', 'expt' and 'is_active' columns.
    - metric: one of 'recall', 'precision', 'accuracy', 'f1_score'.
    - oracle_df: DataFrame of the oracle data containing a 'cnnaffinity' column.
    - percent: float, top fraction of the oracle regarded as active.

    Raises:
    - ValueError: if ``metric`` is not one of the supported names.
    """
    if metric not in ['recall', 'precision', 'accuracy', 'f1_score']:
        raise ValueError("Invalid metric specified.")
    expt = df['expt'].iloc[0]

    # Cumulative subsets: everything selected up to and including cycle i
    max_cycle = df['cycle'].max()
    cumulative = [df[df['cycle'] <= cycle] for cycle in range(1, max_cycle + 1)]

    data = [calculate_performance_metrics(subset, oracle_df, percent) for subset in cumulative]

    if metric == 'accuracy':
        # calculate_performance_metrics() does not return an 'accuracy' entry, so derive it
        # from the confusion-matrix counts: (TP + TN) / total.
        metric_values = []
        for entry in data:
            correct = entry['TP'][0] + entry['TN'][0]
            total = correct + entry['FP'][0] + entry['FN'][0]
            metric_values.append(correct / total if total > 0 else 0)
    else:
        metric_values = [entry[metric][0] for entry in data]

    # Plotting
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(metric_values) + 1), metric_values, marker='o')
    plt.title(f'{metric.capitalize()} Values Over Different Cycles')
    plt.xlabel('Cycle')
    plt.ylabel(metric.capitalize())
    plt.grid(True)
    plt.xticks(range(1, len(metric_values) + 1))

    # Save the plot
    plt.savefig(f'{expt}_{metric}_{percent}.png')
    plt.show()


def visualize_avg_data(search_terms, df_list):
    """
    Bar-plot the mean F1 score and recall per experiment, averaged over several metric frames.

    Parameters:
    - search_terms: str or list of str; an experiment is kept only if every term is a
      substring of its 'expt' value.
    - df_list: list of DataFrames with 'expt', 'f1_score' and 'recall' columns.

    The merged table printed before plotting has a 'std_dev' column that holds the standard
    error of the mean (``sem``). It is not drawn on the plot.
    """
    metric_cols = ['f1_score', 'recall']
    combined_df = pd.concat(df_list)

    # A single string is treated as a one-element list of terms
    if isinstance(search_terms, str):
        search_terms = [search_terms]

    def _has_all_terms(expt_name):
        return all(term in expt_name for term in search_terms)

    search_df = combined_df[combined_df['expt'].apply(_has_all_terms)]

    # Per-experiment mean and standard error, reshaped to one row per (experiment, metric)
    mean_df = search_df.groupby('expt')[metric_cols].mean().reset_index()
    std_df = search_df.groupby('expt')[metric_cols].sem().reset_index()
    mean_melted = pd.melt(mean_df, id_vars='expt', value_vars=metric_cols, var_name='metric', value_name='mean_score')
    std_melted = pd.melt(std_df, id_vars='expt', value_vars=metric_cols, var_name='metric', value_name='std_dev')

    merged_df = pd.merge(mean_melted, std_melted, on=['expt', 'metric'])
    print(merged_df)

    fig, ax1 = plt.subplots()
    custom_palette = sns.color_palette("colorblind", n_colors=len(merged_df['metric'].unique()))
    g = sns.barplot(x='expt', y='mean_score', hue='metric', data=merged_df, ax=ax1, palette=custom_palette, hatch='/', )

    # Narrow every bar
    for bar in g.patches:
        bar.set_width(0.3)

    # Customize Plot
    plt.xticks(rotation=90)
    ax1.set_xticklabels(ax1.get_xticklabels(), fontsize=10)
    ax1.set_ylabel('Score')
    ax1.set_xlabel('Experiment')
    plt.legend(loc=(1.04, 0))
    plt.xticks(rotation=90)
    plt.title(f'{search_terms[0]} Metrics')
    sns.set_style("darkgrid", {"axes.facecolor": ".9"})
    plt.show()

In [None]:
def smi2svg(smi):
    """
    Draw the molecule parsed from a SMILES string on a 200 x 100 SVG canvas.

    Parameters:
    - smi: str, SMILES string.

    Returns:
    - str, the SVG text produced by the drawer.
    """
    drawer = rdMolDraw2D.MolDraw2DSVG(200, 100)
    drawer.DrawMolecule(Chem.MolFromSmiles(smi))
    drawer.FinishDrawing()
    return drawer.GetDrawingText()
    

def umap(df, nbits=2048):
    """
    Compute a UMAP projection of Morgan fingerprints using Tanimoto distance.

    Adds two columns to ``df`` in place: 'fp' (radius-2 Morgan bit vector of ``nbits`` bits,
    built from 'ROMol') and 'svg' (SVG drawing built from 'Smiles').

    Parameters:
    - df: DataFrame containing 'ROMol' and 'Smiles' columns.
    - nbits: int, fingerprint length.

    Returns:
    - res: output of ``UMAP.fit_transform`` for the fingerprints.
    """
    df['fp'] = df['ROMol'].apply(lambda mol: AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=nbits))
    df['svg'] = df['Smiles'].apply(smi2svg)

    # UMAP requires a user-supplied metric to be a numba-compiled function; a plain Python
    # callable cannot be used from UMAP's own jitted code.
    @numba.njit()
    def tanimoto_dist(a, b):
        dotprod = np.dot(a, b)
        denom = np.sum(a) + np.sum(b) - dotprod
        # two all-zero fingerprints: treat them as identical instead of dividing by zero
        if denom == 0:
            return 0.0
        return 1.0 - dotprod / denom

    # One numpy array per molecule
    fps = df['fp'].apply(lambda fp: np.array(fp)).tolist()
    reducer = UMAP(metric=tanimoto_dist)
    res = reducer.fit_transform(fps)

    return res


def cluster_data(df, res, min_samples, min_cluster_size):
    """
    Cluster projected coordinates with HDBSCAN and store the labels on ``df``.

    Parameters:
    - df: DataFrame with one row per row of ``res``; a 'cluster' column is added in place.
    - res: array of coordinates to cluster (e.g. the output of umap()).
    - min_samples: passed to ``hdbscan.HDBSCAN``.
    - min_cluster_size: passed to ``hdbscan.HDBSCAN``.

    Returns:
    - df: the same DataFrame with the 'cluster' column added.
    """
    # hdbscan is not imported in the notebook's import cell, so import it where it is used;
    # without this the function fails with NameError.
    import hdbscan

    clusterer = hdbscan.HDBSCAN(min_samples=min_samples, min_cluster_size=min_cluster_size, cluster_selection_method='leaf')
    cluster_labels = clusterer.fit_predict(res)

    # Add the cluster labels to the original DataFrame
    df['cluster'] = cluster_labels
    return df

### Reproduce plots for analysis of the chemical space of the oracle dataset:

In [None]:
from rdkit.Chem.Descriptors import ExactMolWt


# Molecular-weight distribution of the oracle set
oracle_df['sf1'] = oracle_df['cnnaffinity']
oracle_df['ROMol'] = oracle_df['Smiles'].apply(Chem.MolFromSmiles)
oracle_df['MW'] = oracle_df['ROMol'].apply(ExactMolWt)

# Histogram of exact molecular weight, 50 bins, drawn without grid lines
plt.hist(oracle_df['MW'], bins=50, color='lightgrey', edgecolor='black')
plt.grid(False)
plt.xlabel('MW')

# Write the figure to 'oracle_mw', then display it
plt.savefig('oracle_mw')
plt.show()

In [None]:
oracle_csv_path = os.path.join(current_dir, 'cs_49k.csv')

oracle_df = pd.read_csv(oracle_csv_path)
# Rows that share both 'Smiles' and 'cnnaffinity' with at least one other row. This uses the
# same key as the de-duplication done later; matching on 'cnnaffinity' alone would also flag
# different molecules that merely have the same score. keep=False shows every copy.
duplicates = oracle_df[oracle_df.duplicated(subset=['Smiles', 'cnnaffinity'], keep=False)]
duplicates

In [None]:
# Extra Mpro inhibitors from the prospective search; they are only used to build the UMAP,
# not plotted. The SDF's 'filename' column is renamed to 'Smiles'.
sdf = 'onebyone_it14_over6cnnaffinity.sdf'
enamine_df = PandasTools.LoadSDF(sdf).rename(columns={'filename': 'Smiles'})
enamine_df

In [None]:
oracle_df = oracle_df.reset_index(drop=True)
enamine_df = enamine_df.reset_index(drop=True)

# Stack oracle and Enamine molecules, keep three columns, replace missing values with 0
# and store the score as a non-negative float
enamine_oracle_df = (
    pd.concat([oracle_df, enamine_df])[['Smiles', 'cnnaffinity', 'enamine_id']]
    .fillna(0)
    .assign(cnnaffinity=lambda frame: frame['cnnaffinity'].astype('float').abs())
)
enamine_oracle_df

## Fig 2b

In [None]:
# Distribution of predicted pK in the oracle set, after dropping rows that repeat
# the same ('Smiles', 'cnnaffinity') pair (first occurrence is kept)
oracle_df = oracle_df.drop_duplicates(subset=['Smiles', 'cnnaffinity'])
oracle_df.hist(column='cnnaffinity', bins=175, color='grey')

## Fig 2c

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm, ListedColormap
import seaborn as sns
import numpy as np

# Custom color map and normalization.
# The four intervals between `bounds` plus the open-ended region above the last bound make
# five colour regions, one per entry of `my_colors`. Without extend='max' BoundaryNorm has
# only four bins for five colours and spreads the bins over the colour list, so 'red' was
# never used. Requires matplotlib >= 3.3.
my_colors = ['#c7c7c7', 'gold', 'orange', 'red', 'darkred']
my_cmap = ListedColormap(my_colors)
bounds = [3, 4.5, 5, 5.5, 6]
my_norm = BoundaryNorm(bounds, ncolors=len(my_colors), extend='max')

def plot_umap(df, title, size=100):
    """
    Scatter-plot UMAP coordinates coloured by predicted pK and save the figure.

    Uses the module-level ``my_cmap``, ``my_norm`` and ``bounds`` for colouring.

    Parameters:
    - df: DataFrame with 'UMAP-1', 'UMAP-2' and 'cnnaffinity' columns.
    - title: str, passed to ``savefig`` as the output file name (no title is drawn).
    - size: marker size.
    """
    sns.set_style('white')
    fig, ax = plt.subplots(figsize=(10, 8))

    # points coloured through the custom colormap and norm
    points = ax.scatter(df['UMAP-1'], df['UMAP-2'], c=df['cnnaffinity'], s=size, cmap=my_cmap, norm=my_norm)

    # colour bar with ticks at the bin boundaries
    cbar = fig.colorbar(points, ax=ax, spacing='proportional', ticks=bounds, shrink=0.6, aspect=30)
    cbar.ax.tick_params(labelsize=16)
    cbar.set_label('Predicted pK', fontsize=16)

    # axis labels only; UMAP coordinates carry no meaningful scale, so ticks are hidden
    ax.set_xlabel('UMAP-1', fontsize=16)
    ax.set_ylabel('UMAP-2', fontsize=16)
    ax.set_xticks([])
    ax.set_yticks([])

    fig.savefig(title, bbox_inches='tight')
    plt.show()

In [None]:
# calculate umap for the combined oracle + Enamine set.
# A fixed random_state makes the embedding, and therefore the figures, repeatable between runs.
cp = Plotter.from_smiles(enamine_oracle_df["Smiles"],)
res = cp.umap(random_state=42)

In [None]:
# plot oracle data onto umap:
oracle_df = oracle_df.reset_index(drop=True,)
res = res.reset_index(drop=True,)

# `res` was computed from enamine_oracle_df, which was built before duplicate rows were
# dropped from oracle_df. Pasting the two side by side by position would therefore give
# molecules the coordinates of a different row. Attach each coordinate row to the SMILES
# it was computed for and join on SMILES instead.
# NOTE(review): assumes `res` has one row per row of enamine_oracle_df, in the same order
# (the positional concat used before made the same assumption) — confirm for chemplot.
umap_coords = pd.concat([enamine_oracle_df[['Smiles']].reset_index(drop=True), res], axis=1)
umap_coords = umap_coords.dropna(subset=['Smiles']).drop_duplicates(subset='Smiles', keep='first')

df = oracle_df.merge(umap_coords, on='Smiles', how='left')
df['cnnaffinity'] = df['cnnaffinity'].astype(float)
# keep scored molecules only
umap_df = df[df['cnnaffinity'] > 0]

In [None]:
# note that umap is a stochastic algorithm, so different runs may show variations.
# The second argument is the file name the figure is saved under (plot_umap passes it to savefig).
plot_umap(umap_df, '50K', size=5)

## Fig 3 

In [None]:
# some examples of calculation of recall and F1 score for different active learning hyperparameters:

os.chdir(current_dir)
base_dir = os.path.dirname(os.path.abspath('plots.ipynb'))
df_list = []
percent=0.02
# one metrics DataFrame per replicate, keyed rep1_df ... rep5_df
dfs = {}
try:
    for i in [1,2,3,4,5]:
        data_dir = f'{base_dir}/rep_{i}'  # change directory for each rep
        # gen_rep_data() resolves experiment folders relative to the working directory
        os.chdir(data_dir)
        print(data_dir)
        df_name = f'rep{i}_df'  # naming each df as rep1_df, rep2_df, etc.
        expt_df, dfs[df_name] = gen_rep_data(data_dir, percent, trunc=False)
        df_list.append(expt_df)
finally:
    # always return to the notebook directory, even if a replicate fails part-way,
    # so later cells do not run from inside a rep_* folder
    os.chdir(current_dir)

# Access dataframes: dfs['rep1_df'], dfs['rep2_df'], ..., dfs['rep5_df']
dfs['rep3_df']

In [None]:
# average f1/recall values over 5 independent runs
# dfs is a dictionary of all reps, accessed via their keys, repX_df
# Only experiments whose 'expt' name contains both "0.1" and "gp" are plotted.
# NOTE(review): the original note here says "ignore error" without recording which error
# is raised — confirm and document or fix it.
visualize_avg_data(["0.1", "gp",], [dfs[key] for key in ['rep1_df', 'rep2_df', 'rep3_df', 'rep4_df', 'rep5_df']])

## Fig 5

In [None]:
def extract_cycles(df):
    """
    Split out the rows of cycle 1 and of the highest cycle.

    Parameters:
    - df: DataFrame with a 'cycle' column.

    Returns:
    - dict with keys 'first' (rows where cycle == 1) and 'last' (rows of the maximum cycle).
    """
    cycles = df['cycle']
    return {'first': df[cycles == 1], 'last': df[cycles == cycles.max()]}
    
def create_test_df(directory, feature_column):
    """
    Concatenate the selection.csv files found in the ``cycle_*`` sub-directories of ``directory``.

    ``feature_column`` is replaced by its absolute value and the cycle number parsed from the
    directory name is stored in a 'cycle' column. Rows are ordered by ascending cycle number.

    Parameters:
    - directory: str, the path to the directory containing cycle folders.
    - feature_column: str, name of the score column present in every selection.csv.

    Returns:
    - all_selections_df: DataFrame with all selections plus 'cycle' and 'expt' columns
      ('expt' holds ``directory``). If no selection.csv is found, the frame has no rows.
    """
    # (cycle number, frame) pairs; concatenated once at the end instead of growing a
    # DataFrame inside the loop, which copies all previous rows on every iteration.
    cycle_frames = []

    for cycle_dir in os.listdir(directory):
        if not (cycle_dir.startswith('cycle_') and os.path.isdir(os.path.join(directory, cycle_dir))):
            continue
        cycle_path = os.path.join(directory, cycle_dir, 'selection.csv')
        if not os.path.exists(cycle_path):
            continue

        cycle_df = pd.read_csv(cycle_path)

        # Cycle number is the token after the first underscore of the directory name
        cycle_number = int(cycle_dir.split('_')[1])
        cycle_df['cycle'] = cycle_number
        cycle_df[feature_column] = cycle_df[feature_column].abs()
        cycle_frames.append((cycle_number, cycle_df))

    if cycle_frames:
        # os.listdir order is arbitrary; sort so the output is reproducible
        cycle_frames.sort(key=lambda pair: pair[0])
        all_selections_df = pd.concat([frame for _, frame in cycle_frames], ignore_index=True)
    else:
        all_selections_df = pd.DataFrame()
    all_selections_df['expt'] = directory
    return all_selections_df

import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, PandasTools
import itertools

# One selection table per replicate of the gp_200_UCB_True_10 run, each tagged with
# the replicate it came from in a 'source' column
dfs = []
for i in range(1, 6):
    path = f'{base_dir}/rep_{i}/gp_200_UCB_True_10'
    df = create_test_df(path, 'cnnaffinity').assign(source=f'rep_{i}')
    dfs.append(df)

# Stack all replicates into a single frame with a fresh index
concatenated_df = pd.concat(dfs, ignore_index=True)
concatenated_df

In [None]:
def extract_cycles(df):
    """
    Split out the rows of cycle 1 and of the highest cycle.

    Parameters:
    - df: DataFrame with a 'cycle' column.

    Returns:
    - dict with keys 'first' (rows where cycle == 1) and 'last' (rows of the maximum cycle).
    """
    cycles = df['cycle']
    return {'first': df[cycles == 1], 'last': df[cycles == cycles.max()]}



# {'first': rows of cycle 1, 'last': rows of the highest cycle}, pooled over all replicates
cycle_dict = extract_cycles(concatenated_df)

def fid_filter(mask_df):
    """
    Return the rows of the module-level ``umap_df`` whose 'fid' occurs in ``mask_df``.

    Parameters:
    - mask_df: DataFrame with a 'fid' column.

    Returns:
    - DataFrame, the matching subset of ``umap_df``.
    """
    return umap_df[umap_df['fid'].isin(mask_df['fid'])]

import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm, ListedColormap
import seaborn as sns
import numpy as np

# Custom color map and normalization.
# The four intervals between `bounds` plus the open-ended region above the last bound make
# five colour regions, one per entry of `my_colors`. Without extend='max' BoundaryNorm has
# only four bins for five colours and spreads the bins over the colour list, so 'red' was
# never used. Requires matplotlib >= 3.3.
my_colors = ['#c7c7c7', 'gold', 'orange', 'red', 'darkred']
my_cmap = ListedColormap(my_colors)
bounds = [3, 4.5, 5, 5.5, 6]
my_norm = BoundaryNorm(bounds, ncolors=len(my_colors), extend='max')




def plot_umap(df, title, size=100):
    """
    Scatter-plot UMAP coordinates coloured by predicted pK and save the figure.

    Uses the module-level ``my_cmap``, ``my_norm`` and ``bounds`` for colouring.

    Parameters:
    - df: DataFrame with 'UMAP-1', 'UMAP-2' and 'cnnaffinity' columns.
    - title: str, passed to ``savefig`` as the output file name (no title is drawn).
    - size: marker size.
    """
    sns.set_style('white')
    fig, ax = plt.subplots(figsize=(10, 8))

    # points coloured through the custom colormap and norm
    points = ax.scatter(df['UMAP-1'], df['UMAP-2'], c=df['cnnaffinity'], s=size, cmap=my_cmap, norm=my_norm)

    # colour bar with ticks at the bin boundaries
    cbar = fig.colorbar(points, ax=ax, spacing='proportional', ticks=bounds, shrink=0.6, aspect=30)
    cbar.ax.tick_params(labelsize=16)
    cbar.set_label('Predicted pK', fontsize=16)

    # axis labels only; UMAP coordinates carry no meaningful scale, so ticks are hidden
    ax.set_xlabel('UMAP-1', fontsize=16)
    ax.set_ylabel('UMAP-2', fontsize=16)
    ax.set_xticks([])
    ax.set_yticks([])

    fig.savefig(title, bbox_inches='tight')
    plt.show()
    
# File names for the two figures, in the order of cycle_dict ('first', 'last').
# NOTE(review): '13' is hard-coded as the label of the last cycle — confirm it matches
# the highest cycle in the data.
cycles = ['1','13']
for i, df in enumerate(cycle_dict.values()):
    # fid_filter() returns a filtered view of umap_df; copy it so the dtype cast below
    # writes to an independent frame rather than to a slice of umap_df
    filtered_df = fid_filter(df).copy()
    filtered_df['cnnaffinity'] = filtered_df['cnnaffinity'].astype(float)
    # keep scored molecules only
    filtered_df = filtered_df[filtered_df['cnnaffinity'] > 0]
    plot_umap(filtered_df, f'{cycles[i]}',)
## Fig 6

In [None]:
def create_violin_plot(df_combined, title, hline=None):
    """
    Draw a violin plot of 'cnnaffinity' per cycle, with an optional horizontal line, and save it.

    'cycle' and 'cnnaffinity' are cast to int / float on ``df_combined`` itself.

    Parameters:
    - df_combined: DataFrame with 'cycle' and 'cnnaffinity' columns.
    - title: str, passed to ``savefig`` as the output file name (no title is drawn).
    - hline: Optional; y-value for a horizontal line across the plot.

    Returns:
    - df_combined: the same DataFrame with the two columns cast.
    """
    df_combined['cycle'] = df_combined['cycle'].astype('int')
    df_combined['cnnaffinity'] = df_combined['cnnaffinity'].astype('float')

    # violins without outlines
    sns.violinplot(x='cycle', y='cnnaffinity', data=df_combined, linewidth=0, edgecolor='none')

    # add optional horizontal line
    if hline is not None:
        plt.axhline(y=hline, color='black', linewidth=2)

    plt.xlabel('Cycle')
    plt.ylabel('Predicted pK')

    # tick every second cycle
    plt.xticks(range(df_combined['cycle'].min(), df_combined['cycle'].max() + 1, 2))
    plt.savefig(f'{title}')
    plt.show()
    return df_combined

def create_test_df(directory):
    """
    Concatenate the selection.csv files found in the ``cycle_*`` sub-directories of ``directory``.

    Score columns named 'combo1' or 'plip' are renamed to 'cnnaffinity', the score is replaced
    by its absolute value, and the cycle number parsed from the directory name is stored in a
    'cycle' column. Rows are ordered by ascending cycle number.

    Parameters:
    - directory: str, the path to the directory containing cycle folders.

    Returns:
    - all_selections_df: DataFrame with all selections plus 'cycle' and 'expt' columns
      ('expt' holds ``directory``). If no selection.csv is found, the frame has no rows.
    """
    score_column = 'cnnaffinity'

    # (cycle number, frame) pairs; concatenated once at the end instead of growing a
    # DataFrame inside the loop, which copies all previous rows on every iteration.
    cycle_frames = []

    for cycle_dir in os.listdir(directory):
        if not (cycle_dir.startswith('cycle_') and os.path.isdir(os.path.join(directory, cycle_dir))):
            continue
        cycle_path = os.path.join(directory, cycle_dir, 'selection.csv')
        if not os.path.exists(cycle_path):
            continue

        cycle_df = pd.read_csv(cycle_path)
        print(cycle_df.columns)

        # Some experiments store the score under a different column name
        for alias in ('combo1', 'plip'):
            if alias in cycle_df.columns:
                cycle_df = cycle_df.rename(columns={alias: score_column})

        # Cycle number is the token after the first underscore of the directory name
        cycle_number = int(cycle_dir.split('_')[1])
        cycle_df['cycle'] = cycle_number
        cycle_df[score_column] = cycle_df[score_column].abs()
        cycle_frames.append((cycle_number, cycle_df))

    if cycle_frames:
        # os.listdir order is arbitrary; sort so the output is reproducible
        cycle_frames.sort(key=lambda pair: pair[0])
        all_selections_df = pd.concat([frame for _, frame in cycle_frames], ignore_index=True)
    else:
        all_selections_df = pd.DataFrame()
    all_selections_df['expt'] = directory

    print(all_selections_df.columns)
    return all_selections_df
    
os.chdir(current_dir)
base_dir = os.path.dirname(os.path.abspath('plots.ipynb'))

# One violin plot per experiment folder, saved next to it as '<folder>.png'
dirs = ['mpro-al-pK-beta01','mpro-al-pK-beta10', 'mpro-al-plip', 'mpro-al-cs']

for d in dirs:
    # a 'Cycle' column, if present, is renamed to the lower-case name the plot expects
    combined_df = create_test_df(f'{base_dir}/{d}/generated').rename(columns={'Cycle': 'cycle'})
    create_violin_plot(combined_df, title=f'{base_dir}/{d}.png', hline=None)

## AL regression vs CNNaffinity plot

In [None]:
# al regression model vs cnnaffinity for a random rep & parameter
# The working directory is moved into that experiment's folder and left there:
# the rest of this cell reads from '.'.
os.chdir(current_dir)
base_dir = os.path.dirname(os.path.abspath('plots.ipynb'))
os.chdir(f'{base_dir}/rep_3/gp_300_UCB_True_10/')

import os
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

def create_test_df(directory):
    """
    Load the per-cycle prediction tables of an experiment.

    Reads 'virtual_library_with_predictions.csv' from every ``cycle_*`` sub-directory of
    ``directory`` whose name does not end in '0001', in ascending cycle order. Each table gets
    a 'cycle' column, and its 'regression' column is replaced by its absolute value.

    Parameters:
    - directory: str, the path to the directory containing cycle folders.

    Returns:
    - list of DataFrames, one per cycle folder that contains the CSV.
    """
    def _cycle_number(name):
        # token after the first underscore of the folder name
        return int(name.split('_')[1])

    cycle_dirs = sorted(
        [entry for entry in os.listdir(directory)
         if os.path.isdir(os.path.join(directory, entry))
         and entry.startswith('cycle_')
         and not entry.endswith('0001')],
        key=_cycle_number,
    )
    print(cycle_dirs)

    frames = []
    for cycle_dir in cycle_dirs:
        cycle_path = os.path.join(directory, cycle_dir, 'virtual_library_with_predictions.csv')
        if not os.path.exists(cycle_path):
            continue

        cycle_df = pd.read_csv(cycle_path)
        cycle_df['cycle'] = _cycle_number(cycle_dir)
        cycle_df['regression'] = cycle_df['regression'].abs()
        frames.append(cycle_df)
    return frames

# Reload the oracle table (overwrites the de-duplicated oracle_df used for the UMAP figures)
oracle_csv_path = f'{base_dir}/cs_49k.csv'
oracle_df = pd.read_csv(oracle_csv_path)

# All sub-directories of the current experiment folder. This expression is not the last
# statement of the cell, so it is not displayed.
directory = '.'
cycle_dirs = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))]
cycle_dirs

# List of per-cycle prediction tables for the experiment in the working directory
a = create_test_df('.')
#len(a)

def calc_rmse(df_list):
    """
    Root-mean-square error between the regression prediction and the oracle score, per cycle.

    Each frame is merged with the module-level ``oracle_df`` on 'Smiles'. The oracle score is
    read from 'cnnaffinity_y', so each frame must itself have a 'cnnaffinity' column for the
    merge to create that suffix. Only molecules that have both a prediction and an oracle
    score contribute to the error.

    Parameters:
    - df_list: list of DataFrames with 'Smiles', 'regression' and 'cnnaffinity' columns.

    Returns:
    - rmse_list: list of floats, one per frame (NaN if a frame shares no molecule with the oracle).
    """
    rmse_list = []
    print(f'calculating rmses for {len(df_list)} cycles')
    for i, df in enumerate(df_list):
        # Merge with oracle_df on 'Smiles'
        merged_df = pd.merge(df, oracle_df, how='outer', on='Smiles')

        # The outer merge leaves NaN where a molecule is missing on either side;
        # those rows cannot contribute to an error, so drop them.
        paired = merged_df[['cnnaffinity_y', 'regression']].dropna()
        if paired.empty:
            rmse = float('nan')
        else:
            residuals = paired['cnnaffinity_y'].astype(float) - paired['regression'].astype(float)
            # computed with numpy: mean_squared_error is not imported in this notebook
            rmse = float(np.sqrt(np.mean(np.square(residuals))))

        rmse_list.append(rmse)
        print(f'rmse for cycle {i} is {rmse}')
        print(f'df:\n {df}')
    return rmse_list

# Predictions of the last loaded cycle joined with the oracle on 'Smiles'.
# The outer join keeps molecules that appear in only one of the two tables.
b = pd.merge(a[-1],oracle_df, how='outer', on='Smiles')
b

In [None]:
merged_df_last = b

# Prepare the data: regression-model prediction (y) against the 'cnnaffinity_y' column
# that the merge with oracle_df produced (x)
y = merged_df_last['regression']
x = merged_df_last['cnnaffinity_y']

# Create a scatter plot
plt.figure(figsize=(10, 6))
plt.scatter(x, y, alpha=0.1)  # alpha for transparency in case of overlapping points

# Adding labels
plt.xlabel('CNNaffinity')
plt.ylabel('Regression Model Pred. CNNaffinity')
#plt.title('Scatter Plot of RMSE vs cnnaffinity_y')

# Display the plot
plt.show()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Prepare the data: same two columns as the scatter plot above the density plot
x = merged_df_last['cnnaffinity_y']
y = merged_df_last['regression']

# Create a filled 2-D density plot of prediction against oracle score
plt.figure(figsize=(10, 6))
sns.kdeplot(x=x, y=y, cmap="Reds", fill=True, bw_adjust=0.5)

# Adding labels
plt.xlabel('Oracle Predicted pK')
plt.ylabel('Regression Predicted pK')

# Display the plot
plt.show()