In [None]:
import numpy as np
import pandas as pd

import zipfile

import matplotlib.pyplot as plt
import seaborn as sns

from implementations import *

In [None]:
# Output file name for the labelled dataset produced by this notebook.
SAVING_NAME = "labeled_data.csv"

In [None]:
zip_file_path = 'BindingDB_All_202409_tsv.zip'
file_path = 'BindingDB_All.tsv'

# Read only the needed columns of the BindingDB TSV straight out of the zip archive.
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    with zip_ref.open(file_path) as file:
        data = pd.read_csv(file, sep='\t',
                           usecols=['Ligand SMILES', 'Target Name', 'Ki (nM)',
                                    'Target Source Organism According to Curator or DataSource',
                                    'Number of Protein Chains in Target (>1 implies a multichain complex)'],
                           na_values=['', 'NULL'])

# Keep HIV-1 targets only.
hiv_data = data[data['Target Source Organism According to Curator or DataSource'] == 'Human immunodeficiency virus 1'].reset_index(drop=True)
# Drop missing Ki values and Ki values containing '<' or '>' (inequality bounds,
# presumably censored measurements rather than exact values).
mask_invalid_values = hiv_data['Ki (nM)'].str.contains('[<>]', regex=True, na=False) | hiv_data['Ki (nM)'].isna()
hiv_data = hiv_data[~mask_invalid_values]
# Keep single-chain targets, then renumber rows 0..n-1 so they line up
# positionally with the embedding rows loaded below.
hiv_data = hiv_data[hiv_data['Number of Protein Chains in Target (>1 implies a multichain complex)'] == 1].reset_index(drop=True)
print(f'final size: {hiv_data.shape[0]}')

ligand_embedding = pd.DataFrame(np.load('ligand_embeddings.npy'))
protein_embedding = pd.DataFrame(np.load('protein_embeddings.npy'))
# The ligand embeddings are attached to hiv_data by row position; a row-count
# mismatch would silently pad with NaN instead of failing, so check it explicitly.
assert len(ligand_embedding) == len(hiv_data), (
    f'ligand_embeddings.npy has {len(ligand_embedding)} rows, expected {len(hiv_data)}')
ligand_embedding = pd.concat((hiv_data[['Ligand SMILES', 'Target Name', 'Ki (nM)']], ligand_embedding), axis=1)

In [None]:
# Preview the first rows: metadata columns followed by the ligand embedding columns.
ligand_embedding.head(5)

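### Labelling

Each unique ligand SMILES is embedded with ChemBERTa, projected to 3-D with UMAP and clustered with k-means; the cluster index becomes the ligand's label.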

In [None]:
# --- UMAP settings, shared by the ligand and the protein projections ---
NN = 50            # UMAP n_neighbors
D = 0.05           # UMAP min_dist
METRIC = 'cosine'  # UMAP distance metric

# --- Ligand clustering ---
N_CLUSTER = 3      # NOTE(review): not referenced elsewhere in this notebook; K_LIGAND is what k-means uses
K_LIGAND = 3       # number of k-means clusters, i.e. number of distinct ligand labels
cluster_model = KMeans
cluster_ligand = dict(n_clusters=K_LIGAND, random_state=42)

# --- Embedding sizes ---
N_EMB_LIGAND = 768  # width of each ChemBERTa SMILES vector; presumably the model's hidden size — confirm
NC_PROT = 1000      # UMAP output dimensions for the protein embeddings (could be the same size as ligand embedding)

In [None]:
# Unique ligand SMILES in first-occurrence order; row i of the ChemBERTa
# embedding matrix computed next corresponds to smiles_list[i].
smiles_list = hiv_data['Ligand SMILES'].drop_duplicates().tolist()

In [None]:
# Embed every unique SMILES with ChemBERTa: tokenize one string at a time and
# average the last hidden layer over the token axis (dim=1) to get one vector
# per SMILES. Assumes last_hidden_state is (batch, seq, hidden) — HF convention.
model_name = "seyonec/ChemBERTa-zinc-base-v1"
model, tokenizer = get_BERT_model(model_name)

n_smiles = len(smiles_list)
bert_SMILES = np.zeros((n_smiles, N_EMB_LIGAND))

for i in range(n_smiles):
    smiles = smiles_list[i]
    encoded = tokenizer(smiles, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
    if i % 500 == 0:
        print(f'iteration {i}')
    bert_SMILES[i] = hidden.mean(dim=1).numpy()

In [None]:
# Project the ChemBERTa ligand embeddings to 3-D with UMAP, cluster the projection
# with k-means, and show the clusters in an interactive 3-D scatter plot.
# A dedicated name (instead of reusing `model`) leaves the BERT model from the
# previous cell intact. A fixed random_state makes the projection, and hence the
# cluster labels that end up in the saved CSV, reproducible between runs.
ligand_reducer = umap.UMAP(n_components=3, n_neighbors=NN, min_dist=D, metric=METRIC,
                           random_state=42)  # note: seeding makes umap-learn run single-threaded
umap_ligand = ligand_reducer.fit_transform(bert_SMILES)
kmeans = cluster_model(**cluster_ligand)
kmeans.fit(umap_ligand)

umap_df = pd.DataFrame(umap_ligand, columns=['UMAP1', 'UMAP2', 'UMAP3'])
umap_df['Cluster'] = kmeans.labels_

# Create 3D scatter plot: one marker per unique ligand, coloured by cluster.
fig = go.Figure()
scatter = go.Scatter3d(
    x=umap_df['UMAP1'],
    y=umap_df['UMAP2'],
    z=umap_df['UMAP3'],
    mode='markers',
    marker=dict(size=5, color=umap_df['Cluster'], colorscale='Inferno', opacity=0.7)
)
fig.add_trace(scatter)

# Customize layout
fig.update_layout(
    title="UMAP Clusters",
    scene=dict(
        xaxis_title='UMAP1',
        yaxis_title='UMAP2',
        zaxis_title='UMAP3',
        xaxis=dict(showgrid=True, gridwidth=2, gridcolor='gray'),
        yaxis=dict(showgrid=True, gridwidth=2, gridcolor='gray'),
        zaxis=dict(showgrid=True, gridwidth=2, gridcolor='gray'),
    ),
    height=800, width=1200,
    template="plotly_white"
)

# Display the figure
fig.show()

In [None]:
# Pair each unique ligand SMILES with its k-means cluster label.
# The SMILES are re-derived from hiv_data (same unique() first-occurrence order
# as the list that was embedded) instead of being read from `smiles_list`, which
# this cell overwrites with a DataFrame; reading it would make a re-run fail.
unique_smiles = hiv_data['Ligand SMILES'].unique()
assert len(unique_smiles) == len(kmeans.labels_), 'expected one cluster label per unique SMILES'
smiles_list = pd.DataFrame({'Ligand SMILES': unique_smiles, 'Labels': kmeans.labels_})

In [None]:
# Reduce the protein embeddings to NC_PROT dimensions with UMAP.
# The raw matrix is read into its own name instead of being taken from
# `protein_embedding`, which this cell overwrites with the reduced result.
# This keeps the cell safe to re-run: otherwise a second run would feed the
# already-reduced matrix back into UMAP.
# NOTE(review): NC_PROT (1000 when written) is unusually large for a UMAP target;
# confirm it is intended and smaller than the raw embedding width. This
# projection is also unseeded, so it differs between runs.
protein_embedding_raw = np.load('protein_embeddings.npy')
protein_reducer = umap.UMAP(n_components=NC_PROT, n_neighbors=NN, min_dist=D, metric=METRIC)
protein_embedding = protein_reducer.fit_transform(protein_embedding_raw)

In [None]:
# Attach each ligand's cluster label to every (ligand, target) row, then append
# the reduced protein embedding columns side by side (by row position).
# how='left' keeps rows in their original order. With the default how='inner',
# pandas < 2.2 groups rows with repeated keys together; that would misalign them
# with the positional concat against `protein_embedding`. validate guards that
# each SMILES has exactly one label.
# The result gets a new name so that `ligand_embedding` is not mutated and the
# cell stays re-runnable.
ligand_labeled = pd.merge(ligand_embedding, smiles_list, on='Ligand SMILES',
                          how='left', validate='many_to_one')
assert len(ligand_labeled) == len(protein_embedding), (
    f'{len(ligand_labeled)} ligand rows vs {len(protein_embedding)} protein rows; they must align 1:1')
labeled_data = pd.concat((ligand_labeled, pd.DataFrame(protein_embedding)), axis=1)
# to avoid multi type in columns
# NOTE(review): both embedding blocks are integer-labelled from 0, so after this
# cast names such as '0' appear twice; confirm downstream readers of the CSV expect that.
labeled_data.columns = labeled_data.columns.astype(str)

In [None]:
# Show the final column Index (DataFrame.keys() returns the columns).
labeled_data.keys()

In [None]:
# Persist the labelled dataset under the file name configured in SAVING_NAME,
# rather than repeating the literal here.
labeled_data.to_csv(SAVING_NAME, index=False)