### Set up notebook

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Import necessary packages and modules
import torch
import os
from rdkit import Chem
import shutil
import pandas as pd
from tqdm import tqdm

# Install necessary packages if they are not present
try:
    import biopandas
except:
    !pip install pyg==0.7.1 --quiet
    !pip install pyyaml==6.0 --quiet
    !pip install scipy==1.7.3 --quiet
    !pip install networkx==2.6.3 --quiet
    !pip install biopython==1.79 --quiet
    !pip install rdkit-pypi==2022.03.5 --quiet
    !pip install e3nn==0.5.0 --quiet
    !pip install spyrmsd==0.5.2 --quiet
    !pip install pandas==1.5.3 --quiet
    !pip install biopandas==0.4.1 --quiet


# Install PyTorch geometric dependencies if not already present
try:
    import torch_geometric
except ModuleNotFoundError:
    !pip uninstall torch-scatter torch-sparse torch-geometric torch-cluster --y
    !pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html --quiet
    !pip install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html --quiet
    !pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html --quiet
    !pip install git+https://github.com/pyg-team/pytorch_geometric.git --quiet


# Clone DiffDock repository if not already present
# NOTE(review): the pinned checkout below only happens on a fresh clone; an
# existing /content/DiffDock is used as-is, on whatever commit it is at.
if not os.path.exists("/content/DiffDock"):
    os.chdir('/content')
    !git clone https://github.com/gcorso/DiffDock.git
    os.chdir('/content/DiffDock')
    # Pin a specific commit so results are reproducible across runs.
    !git checkout a6c5275


# Clone the ESM repository inside DiffDock if not already present; the
# embedding step later runs /content/DiffDock/esm/scripts/extract.py from it.
if not os.path.exists("/content/DiffDock/esm"):
    os.chdir('/content/DiffDock')
    !git clone https://github.com/facebookresearch/esm
    os.chdir('/content/DiffDock/esm')
    # Pin a specific commit so results are reproducible across runs.
    !git checkout ca8a710
    # Install ESM from this checkout in editable (-e) mode.
    !sudo pip install -e .
    # Return to the DiffDock repository root.
    os.chdir('/content/DiffDock')

### Define paths

In [None]:
# Directory on Google Drive where the top-ranked DiffDock pose of each ligand
# is saved (as complex{i}.sdf). Created if missing; makedirs raises if the
# path cannot be created or exists as a regular file.
DIFFDOCK_RESULTS_PATH = "/content/drive/MyDrive/results/"
os.makedirs(DIFFDOCK_RESULTS_PATH, exist_ok=True)

# Sampled molecules to dock; must contain a 'smiles' column.
# NOTE(review): despite the .pkl extension this file is read with
# pd.read_csv later in the notebook — confirm it really is a CSV file.
DIFFDOCK_SAVE_PATH = "/content/drive/Othercomputers/My Laptop/ChemSpaceAL/ChemSpaceAL_package/run_ChemSpaceAL/3. Sampling/diffdock.pkl"

# CSV files (columns 'smiles' and 'score') from earlier scoring runs.
# Empty means no molecule has been scored yet.
DIFFDOCK_SCORED_PATH_LIST = []

# Path to the processed protein structure (PDB) that ligands are docked against.
PROTEIN_PATH = "/content/drive/MyDrive/Generative_ML/current_data/4. DiffDock/proteins/HNH_processed.pdb"

### Define necessary functions

In [None]:
def load_scored_mols(scored_path_list):
    """
    Load molecule scores from one or more CSV files.

    Files are read in the given order. If the same SMILES appears more than
    once, the later score wins, and a message is printed whenever it differs
    from the score seen earlier.

    Args:
    - scored_path_list (list): Paths to CSV files with 'smiles' and 'score' columns.

    Returns:
    - dict: Mapping of SMILES string to its (last seen) score. Empty if
      scored_path_list is empty.
    """
    scored_mols = {}
    for scored_path in scored_path_list:
        frame = pd.read_csv(scored_path)
        # Zip the two columns directly instead of DataFrame.iterrows(), which
        # builds a Series per row (slow) and does not preserve column dtypes.
        for smiles, score in zip(frame['smiles'], frame['score']):
            if smiles in scored_mols and scored_mols[smiles] != score:
                print(f"{smiles} scored {score} but was already scored {scored_mols[smiles]}")
            scored_mols[smiles] = score
    return scored_mols

def count_new_samples(sampled_path, scored_set):
    """
    Report how many of the sampled molecules have already been scored.

    Prints two lines: the number of scored molecules, and how many of the
    unique sampled SMILES are among them. Nothing is returned.

    Args:
    - sampled_path (str): CSV file whose 'smiles' column lists the sampled molecules.
    - scored_set (set): SMILES that were already scored (a set, or a set-like
      view such as dict keys).
    """
    sampled_column = pd.read_csv(sampled_path)['smiles']
    unique_samples = set(sampled_column)
    overlap = unique_samples & scored_set
    scored_total = len(scored_set)
    print(f"scored directory contains {scored_total} unique molecules")
    print(f"{len(overlap)} out of {len(unique_samples)} sampled molecules were already scored")

def get_top_poses(ligands_csv, protein_pdb_path, scored_set):
    """
    Process ligands, generate embeddings, and run the DiffDock inference.
    
    Args:
    - ligands_csv (str): Path to the CSV file containing ligands.
    - protein_pdb_path (str): Path to the protein structure file.
    - scored_set (set): Set of previously scored molecule smiles.

    Returns:
    - list: List of paths to ligand files.
    """
    data = pd.read_csv(ligands_csv)
    ligand_files = []

    os.environ['HOME'] = 'esm/model_weights'
    os.environ['PYTHONPATH'] = f'{os.environ.get("PYTHONPATH", "")}:/content/DiffDock/esm'
    
    pbar = tqdm(range(len(data)), total=len(data))
    for i in pbar:
        smiles = data['smiles'][i]
        if smiles in scored_set: 
            continue
        rdkit_mol = Chem.MolFromSmiles(smiles)

        if rdkit_mol is not None:
            # Prepare input data for ESM
            with open('/content/input_protein_ligand.csv', 'w') as out:
                out.write('protein_path,ligand\n')
                out.write(f'{protein_pdb_path},{smiles}\n')

            # Remove previous results
            shutil.rmtree('/content/DiffDock/results', ignore_errors=True)

            # Generate ESM embeddings
            os.chdir('/content/DiffDock')
            !python /content/DiffDock/datasets/esm_embedding_preparation.py --protein_ligand_csv /content/input_protein_ligand.csv --out_file /content/DiffDock/data/prepared_for_esm.fasta
            !python /content/DiffDock/esm/scripts/extract.py esm2_t33_650M_UR50D /content/DiffDock/data/prepared_for_esm.fasta /content/DiffDock/data/esm2_output --repr_layers 33 --include per_tok --truncation_seq_length 30000

            # Run DiffDock inference
            !python /content/DiffDock/inference.py --protein_ligand_csv /content/input_protein_ligand.csv --out_dir /content/DiffDock/results/user_predictions_small --inference_steps 20 --samples_per_complex 10 --batch_size 6

            # Collect results
            for root, _, files in os.walk('/content/DiffDock/results/user_predictions_small'):
                for file in files:
                    if file.startswith('rank1_confidence'):
                        shutil.move(os.path.join(root, file), os.path.join(DIFFDOCK_RESULTS_PATH, f'complex{i}.sdf'))
                        ligand_files.append(f'{DIFFDOCK_RESULTS_PATH}complex{i}.sdf')
    return ligand_files

### Run DiffDock

In [None]:
# Load previously scored molecules
scored_mols = load_scored_mols(DIFFDOCK_SCORED_PATH_LIST)

# Count and print how many of the new samples have already been scored
count_new_samples(DIFFDOCK_SAVE_PATH, scored_mols.keys())

# Retrieve top poses for the molecules
top_diffdock_poses = get_top_poses(DIFFDOCK_SAVE_PATH, PROTEIN_PATH, set())