<a href="https://colab.research.google.com/github/patrickbryant1/MoLPC/blob/master/MoLPC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MoLPC
**M**odelling **o**f **L**arge **P**rotein **C**omplexes

This directory contains a pipeline for predicting very large protein complexes using the
[FoldDock pipeline](https://gitlab.com/ElofssonLab/FoldDock) based on [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2).
\
AlphaFold2 is available under the [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0) and so is FoldDock, which is a derivative thereof.  \
The AlphaFold2 parameters are made available under the terms of the [CC BY 4.0 license](https://creativecommons.org/licenses/by/4.0/legalcode) and have not been modified.
\
MoLPC is licensed under the [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0).
\
\
**You may not use these files except in compliance with the licenses.**

### MoLPC is available for local installation here: https://github.com/patrickbryant1/MoLPC

Please see the publication in *Nature Communications*: [Predicting the structure of large protein complexes using AlphaFold and Monte Carlo tree search](https://www.nature.com/articles/s41467-022-33729-4) for more information.

**DEBUGGING INFO**
If you are experiencing problems running MoLPC:
1. Try removing all MoLPC-related files stored on your Google Drive after connecting.
2. Ensure that the MSAs and chains are named correctly. Read the naming instructions carefully: if the names are not accurate, MoLPC does not know which files to use.
3. Open a GitHub issue at https://github.com/patrickbryant1/MoLPC.

If you like MoLPC, please star the [GitHub repo](https://github.com/patrickbryant1/MoLPC).
\
If you use MoLPC in your research, please cite the publication in *Nature Communications*: [Predicting the structure of large protein complexes using AlphaFold and Monte Carlo tree search](https://www.nature.com/articles/s41467-022-33729-4).

In [22]:
#@title Install dependencies
#@markdown Make sure your runtime is GPU.
#@markdown In the menu above do: Runtime --> Change runtime type --> Hardware accelerator (set to GPU)

#@markdown **Press play.**

#@markdown Simply press play on each cell below and follow the instructions.

#@markdown You will have to restart the runtime after this finishes to include the new packages.
#@markdown In the menu above do: Runtime --> Restart runtime

#@markdown Don't worry about the errors that pip gives below; these are resolved in the end.
!pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
!pip install ml-collections==0.1.1
!pip install dm-haiku==0.0.11
!pip install pandas==1.3.5
!pip install biopython==1.81
!pip install chex==0.1.5
!pip install dm-tree==0.1.8
!pip install immutabledict==2.0.0
!pip install scipy==1.7.3
!pip install tensorflow==2.11.0
!pip install py3Dmol
!pip install numpy --upgrade

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jaxlib==0.4.26+cuda12.cudnn89 (from jax[cuda12_pip])
  Downloading https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.26%2Bcuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl (144.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 144.2/144.2 MB 1.9 MB/s eta 0:00:00
Collecting nvidia-cublas-cu12>=12.1.3.1 (from jax[cuda12_pip])
  Using cached nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12>=12.1.105 (from jax[cuda12_pip])
  Using cached nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cuda-nvcc-cu12>=12.1.105 (from jax[cuda12_pip])
  Using cached nvidia_cuda_nvcc_cu12-12.4.131-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12>=12.1.

In [2]:
#@title Check that the GPU is accessible. The response from this cell should be "gpu".
import jax
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu


In [None]:
#@title Clone MoLPC git
import shutil
#Remove any previous clone (ignore_errors=True already suppresses a missing directory)
shutil.rmtree('/content/MoLPC', ignore_errors=True)

!git clone https://github.com/patrickbryant1/MoLPC.git

In [26]:
#@title #Follow all steps outlined below to run the assembly pipeline
#@markdown To try the **test case** 1A8R, tick the box "test_case". Then press the play button to the left. \
#@markdown If you don't want to run the test case, **leave the box blank**.

#@markdown #Parameters
#@markdown - *ID* - name of output \
#@markdown - *SUBSIZE* - the size of the subcomponents to use for the assembly (2 or 3) \
##@markdown - *GET_ALL* - get all possible subcomponents or only ones according to specified interactions below \
#@markdown - *USEQs* and *STs* - the unique sequences in the complex and the stoichiometry of each. \
#@markdown Up to 5 unique sequences are allowed with a total of up to 50 chains. \
#@markdown (If more are required, please install the local version: https://github.com/patrickbryant1/MoLPC) \
##@markdown - *INTERACTIONS* - Interactions between chains (if known) \
#@markdown - **MSAs** - MSA search is currently not available directly in the browser, so you have to provide your own MSAs in a3m format and upload them here. \
#@markdown There are two ways of doing this: \
#@markdown 1. Search uniclust_30 locally with HHblits \
#@markdown 2. Go to https://toolkit.tuebingen.mpg.de/tools/hhblits \
#@markdown Paste each unique chain sequence into the search field in FASTA format --> Submit. \
#@markdown When the search is finished, go to the tab "Query Template MSA" and click "Download Full A3M" \
#@markdown - Upload the MSAs here: \
#@markdown Click the folder icon (Files) to the left and select the upload file icon. Upload your files.
#@markdown - **Make sure to name your MSAs according to the convention** ID_1, ..., ID_N (both here and for the uploaded files)
import sys, os
#from google.colab import files
import pandas as pd
import numpy as np
import glob
#sys.path.insert(0,'/content/MoLPC/src')
test_case = True #@param {type:"boolean"}
ID = "1A8R" #@param {type:"string"}
SUBSIZE = 3 #@param ["2", "3"] {type:"raw"}
GET_ALL = True
#Encode GET_ALL as 1=True, 0=False
GET_ALL = int(GET_ALL)
INTERACTIONS = ""
#GET_ALL must be 1 when no INTERACTIONS are specified
if INTERACTIONS=="":
  GET_ALL=1
USEQ1 = "PSLSKEAALVHEALVARGLETPLRPPVHEMDNETRKSLIAGHMTEIMQLLNLDLADDSLMETPHRIAKMYVDEIFSGLDYANFPKITLIENKMKVDEMVTVRDITLTSTCESHFVTIDGKATVAYIPKDSVIGLSKINRIVQFFAQRPQVQERLTQQILIALQTLLGTNNVAVSIDAVHYCVKARGIRDATSATTTTSLGGLFKSSQNTRHEFLRAVRHHN" #@param {type:"string"}
ST1 = 10 #@param ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"] {type:"raw"}
MSA1 = "1A8R_1.a3m" #@param {type:"string"}
USEQ2 = "" #@param {type:"string"}
ST2 = 0 #@param ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"] {type:"raw"}
MSA2 = "" #@param {type:"string"}
USEQ3 = "" #@param {type:"string"}
ST3 = 0 #@param ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"] {type:"raw"}
MSA3 = "" #@param {type:"string"}
USEQ4 = "" #@param {type:"string"}
ST4 = 0 #@param ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"] {type:"raw"}
MSA4 = "" #@param {type:"string"}
USEQ5 = "" #@param {type:"string"}
ST5 = 0 #@param ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"] {type:"raw"}
MSA5 = "" #@param {type:"string"}

#Create DFs
USEQS, CHAINS = pd.DataFrame(), pd.DataFrame()
ALPHABET='ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
USEQS['Sequence']=[x for x in [USEQ1, USEQ2, USEQ3, USEQ4, USEQ5] if len(x)>0]
USEQS['SeqID']=np.arange(1,len(USEQS)+1)
USEQS['Stoichiometry']=[ST1, ST2, ST3, ST4, ST5][:len(USEQS)]
chain_useqs = []
for ind,row in USEQS.iterrows():
  chain_useqs.extend([row.SeqID]*row.Stoichiometry)
CHAINS['Useq']=chain_useqs
CHAINS['Chain']=[x for x in ALPHABET[:len(CHAINS)]]

OUTDIR="/content/" #Output directory used by all steps below
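The chain table built above assigns consecutive letters from `ALPHABET` to the copies of each unique sequence, in the order USEQ1, USEQ2, and so on. A standalone sketch of the same mapping, together with the MSA file names expected by the `ID_1, ..., ID_N` convention; the helper names `assign_chains` and `expected_msa_names` are hypothetical and not part of MoLPC:

```python
import string

# Same 52-letter alphabet as ALPHABET above (A-Z, then a-z)
ALPHABET = string.ascii_uppercase + string.ascii_lowercase

def assign_chains(stoichiometries):
    """Map each chain letter to the 1-based ID of its unique sequence.

    stoichiometries[i] is the copy number of unique sequence i+1.
    """
    seq_ids = []
    for seq_id, count in enumerate(stoichiometries, start=1):
        seq_ids.extend([seq_id] * count)
    if len(seq_ids) > len(ALPHABET):
        raise ValueError(f"{len(seq_ids)} chains exceed the {len(ALPHABET)} available labels")
    return dict(zip(ALPHABET, seq_ids))

def expected_msa_names(complex_id, n_unique):
    """MSA file names following the ID_1, ..., ID_N naming convention."""
    return [f"{complex_id}_{i}.a3m" for i in range(1, n_unique + 1)]

print(assign_chains([10]))      # test case 1A8R: chains A..J all map to sequence 1
print(assign_chains([2, 3]))    # {'A': 1, 'B': 1, 'C': 2, 'D': 2, 'E': 2}
print(expected_msa_names("1A8R", 1))
```

Because the letters are handed out in sequence order, reordering USEQ1..USEQ5 changes which chains belong to which sequence, which is why the MSA names must match the sequence numbering.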


In [5]:
#@title Step 1: MSA PIPELINE
#@markdown The provided MSAs are read in; no MSA search is performed here
sys.path.insert(0,'/content/MoLPC/src/')
#Get MSA
if test_case==True:
  MSADIR='/content/MoLPC/data/test/'
else:
  MSADIR='/content/'
  msas = glob.glob(MSADIR+'*.a3m')
  msa_ids = [x.split('/')[-1] for x in msas]
  #See if all the MSAs are provided
  for msa in [MSA1, MSA2, MSA3, MSA4, MSA5][:len(USEQS)]:
    if msa not in msa_ids:
      print(msa,'is missing.')
    else:
      print(msa, 'is uploaded')


#@markdown Write the Paired and Block Diagonalized MSAs to predict sub-components
from preprocess import preprocess_colab
preprocess_colab.create_folder_structure(MSADIR, ID, OUTDIR, USEQS, INTERACTIONS, CHAINS, GET_ALL, SUBSIZE)

Creating all interactions of size 3 ...
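Step 1 writes two MSAs per subcomponent: a paired one and a block-diagonalized one. In a block-diagonalized MSA, each chain's homologues occupy only that chain's columns and are padded with gaps everywhere else, so the unpaired sequences of all chains fit in one alignment. A minimal sketch of that layout (an illustration, not the `preprocess_colab` implementation), assuming each per-chain MSA is a list of equal-length aligned strings, query first, with a3m lowercase insertions already removed:

```python
def block_diagonalize(chain_msas):
    """Concatenate per-chain MSAs into one block-diagonal MSA.

    chain_msas: one MSA per chain; each MSA is a list of aligned
    sequences (strings) of equal length, query sequence first.
    """
    lengths = [len(msa[0]) for msa in chain_msas]
    # First row: the concatenated query sequences (the whole subcomponent)
    rows = ["".join(msa[0] for msa in chain_msas)]
    for i, msa in enumerate(chain_msas):
        left = "-" * sum(lengths[:i])
        right = "-" * sum(lengths[i + 1:])
        # Homologues of chain i fill only chain i's columns
        rows.extend(left + seq + right for seq in msa[1:])
    return rows

for row in block_diagonalize([["MKV", "MRV"], ["GS", "GT", "AS"]]):
    print(row)
```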


In [6]:
#@title Step 2: FOLDING PIPELINE
#Create structure dir
STRUCTURE_DIR=OUTDIR+"AF/"
if not os.path.exists(STRUCTURE_DIR):
  os.mkdir(STRUCTURE_DIR)
#Get the sub_ids and lengths
import glob
files = glob.glob(OUTDIR+'*.fasta')
sub_ids = {}
for filename in files:
  with open(filename, 'r') as file:
    for line in file:
      line = line.rstrip()[1:].split('|')
      sub_ids[line[0]]=line[-1].split('-')[:-1]
      break

#@markdown Get the AF2 params
import shutil
PARAMS=STRUCTURE_DIR+'params/'
if not os.path.exists(PARAMS):
  os.mkdir(PARAMS)
  !wget https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar
  shutil.move('alphafold_params_2021-07-14.tar', PARAMS)
  #Extract into the params directory
  !tar -xvf {PARAMS}alphafold_params_2021-07-14.tar -C {PARAMS}

In [None]:
#@markdown Predict the subcomponents
sys.path.insert(0,'/content/MoLPC/src/AF2')
from AF2 import run_alphafold_colab
import collections, collections.abc
#Compatibility shim: the collections.Iterable alias was removed in Python 3.10
collections.Iterable = collections.abc.Iterable

##### AF2 CONFIGURATION #####
PARAM=STRUCTURE_DIR
PRESET='full_dbs' #Preset model configuration: full_dbs or reduced_dbs (no ensembling), or casp14 (8 model ensemblings)
MAX_RECYCLES=10 #Maximum number of recycles (AF2 default: 3)
MODEL_NAME='model_1' #Alternative: model_1_ptm

#Go through all subcomponents and predict their structure
for sub_id in sub_ids:
  #Check if predictions already exist
  if len(glob.glob(OUTDIR+sub_id+'/*.pdb'))>0:
    print('Prediction for',sub_id,'exists')
    continue
  else:
    print('Predicting subcomponent', sub_id)
  ####Get fasta file####
  FASTAFILE=OUTDIR+sub_id+'.fasta'
  ####Get chain break#### Note! This is now set for trimer subcomponents
  CB=np.cumsum([int(x) for x in sub_ids[sub_id]])
  CB = [str(x) for x in CB]
  ####Get MSAs####
  #HHblits paired
  PAIREDMSA=OUTDIR+sub_id+'_paired.a3m'
  ##HHblits block diagonalized
  BLOCKEDMSA=OUTDIR+sub_id+'_blocked.a3m'
  MSAS=[PAIREDMSA,BLOCKEDMSA] #List of MSA paths
  run_alphafold_colab.main([MODEL_NAME], 1, MAX_RECYCLES, STRUCTURE_DIR, FASTAFILE, sub_id, MSAS, CB, OUTDIR)
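The chain breaks (`CB`) passed to `run_alphafold_colab.main` are the cumulative chain lengths of the subcomponent, taken from the last `|`-separated field of its FASTA header. A standalone sketch of that parsing, assuming a hypothetical header such as `>1A8R_ABC|100-100-100-` where the trailing `-` leaves an empty last element that `[:-1]` drops (the real header is written by `preprocess_colab`):

```python
import numpy as np

def parse_header(header):
    """Return (sub_id, chain_lengths) from a subcomponent FASTA header line."""
    fields = header.rstrip()[1:].split('|')  # drop '>' and split on '|'
    lengths = fields[-1].split('-')[:-1]     # drop the trailing element
    return fields[0], [int(x) for x in lengths]

def chain_breaks(lengths):
    """Residue index (as a string) at which each chain ends."""
    return [str(x) for x in np.cumsum(lengths)]

sub_id, lengths = parse_header('>1A8R_ABC|100-100-100-\n')
print(sub_id, lengths, chain_breaks(lengths))
# 1A8R_ABC [100, 100, 100] ['100', '200', '300']
```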

In [None]:
#@title Step 3: ASSEMBLY PIPELINE
#@markdown Prepare the assembly
COMPLEXDIR=OUTDIR+'/assembly/complex/' #Where all the output for the complex assembly will be directed
PAIRDIR=OUTDIR+'/assembly/pairs/'
META=OUTDIR+'/assembly/meta.csv' #where to write all interactions
from complex_assembly import prepare_assembly_colab
#Make complex directory (including the parent assembly directory)
os.makedirs(COMPLEXDIR, exist_ok=True)
#Rewrite the FoldDock preds to have separate chains according to the fasta file seqlens
files = glob.glob(OUTDIR+ID+'*/*1.pdb')
if len(files)>0:
    for pdbname in files:
        chains = prepare_assembly_colab.read_all_chains_coords(pdbname)
        if len(chains.keys())>1:
            continue
        subid = pdbname.split('/')[-2]
        print(subid)
        #Rewrite the files
        prepare_assembly_colab.write_pdb(chains, os.path.splitext(pdbname)[0]+'_rw.pdb')

#Copy the predictions to reflect all chains
prepare_assembly_colab.copy_uints(ID, OUTDIR, OUTDIR+'/assembly/', USEQS,INTERACTIONS, CHAINS, GET_ALL, SUBSIZE)
##Rewrite AF predicted complexes to have proper numbering and chain labels
files = glob.glob(OUTDIR+'/assembly/'+ID+'*/*.pdb')
if len(files)>0:
    for pdbname in files:
        chains = prepare_assembly_colab.read_all_chains_coords(pdbname)
        subid = pdbname.split('/')[-2]
        chain_names = subid.split('_')[-1]
        #Rewrite the files
        prepare_assembly_colab.write_pdb_chain_labels(chains, chain_names, OUTDIR+'/assembly/'+subid+'.pdb')
#Write all pairs
#Chains must be labelled consecutively: the first unique sequence uses A..N, the second continues from the next letter, and so on
if not os.path.exists(PAIRDIR):
  os.mkdir(PAIRDIR)

prepare_assembly_colab.get_all_pairs(OUTDIR+'/assembly/', PAIRDIR, INTERACTIONS, GET_ALL, META)
#Cleanup
for filename in glob.glob(OUTDIR+'/assembly/'+ID+'_*.pdb'):
  os.remove(filename)
for dirname in glob.glob(OUTDIR+'/assembly/'+ID+'_*'):
  if os.path.isdir(dirname):
    shutil.rmtree(dirname)
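The `_rw` file name in the rewrite loop is derived from the prediction path. Splitting on the first `.` silently breaks for a relative output directory such as `./data/test/` (used in the local cells further down), because the path itself starts with a dot; `os.path.splitext` removes only the extension. A short illustration with a hypothetical file name (`rw_name` is an illustrative helper, not part of MoLPC):

```python
import os

def rw_name(pdbname):
    """Same path with '_rw' inserted before the '.pdb' extension."""
    return os.path.splitext(pdbname)[0] + '_rw.pdb'

# Hypothetical prediction path under a relative OUTDIR
pdbname = './data/test/1A8R_ABC/unrelaxed_model_1.pdb'
print(pdbname.split('.')[0] + '_rw.pdb')  # _rw.pdb  (directory and file name lost)
print(rw_name(pdbname))                   # ./data/test/1A8R_ABC/unrelaxed_model_1_rw.pdb
```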


In [None]:
#@markdown Assemble: find the best non-overlapping path that connects all nodes using Monte Carlo tree search
META_DF=pd.read_csv(META)
CHAIN_SEQS=pd.read_csv(OUTDIR+'/assembly/'+ID+'_chains.csv')
from complex_assembly import mcts_colab
mcts_colab.assemble(META_DF, PAIRDIR, OUTDIR+'/assembly/plddt/', USEQS, CHAIN_SEQS, COMPLEXDIR)
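The search only accepts paths without overlapping chains. The criterion documented in `find_paths` (further down) is that two chains overlap when 50% of the shorter chain's CA atoms lie within 5 Å of the other chain. A NumPy sketch of that test, assuming `ca_a` and `ca_b` are (N, 3) arrays of CA coordinates and reading "within 5 Å" as "within 5 Å of any CA of the other chain"; this interpretation is an assumption, and `mcts_colab` may compute it differently:

```python
import numpy as np

def chains_overlap(ca_a, ca_b, dist_cutoff=5.0, frac_cutoff=0.5):
    """Return True if the two chains clash under the 50% / 5 Å criterion."""
    # Pairwise CA-CA distances, shape (len(ca_a), len(ca_b))
    dists = np.linalg.norm(ca_a[:, None, :] - ca_b[None, :, :], axis=-1)
    # Fraction of each chain's CAs with a partner closer than the cutoff
    frac_a = np.mean(dists.min(axis=1) < dist_cutoff)
    frac_b = np.mean(dists.min(axis=0) < dist_cutoff)
    # Judge the overlap relative to the shorter chain
    frac = frac_a if len(ca_a) <= len(ca_b) else frac_b
    return bool(frac >= frac_cutoff)

ca = np.array([[3.8 * i, 0.0, 0.0] for i in range(10)])  # toy straight chain
print(chains_overlap(ca, ca + [0.0, 1.0, 0.0]))   # True: copy shifted by 1 Å
print(chains_overlap(ca, ca + [0.0, 20.0, 0.0]))  # False: copy 20 Å away
```

The full distance matrix is quadratic in chain length, which is fine for a few hundred residues per chain.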


In [None]:
#@title Score the assembly and download the result
COMPLEXDIR=OUTDIR+'/assembly/complex/'
from google.colab import files
from complex_assembly import score_entire_complex_colab
score_entire_complex_colab.main(ID, COMPLEXDIR+'best_complex.pdb', COMPLEXDIR+'optimal_path.csv', USEQS, CHAINS, COMPLEXDIR+'scores.csv')

#Clean up files used in the assembly and scoring
#Pair dir
try:
  shutil.rmtree(PAIRDIR)
  shutil.rmtree(OUTDIR+'/assembly/plddt/')
  for subcomponent_dir in glob.glob(OUTDIR+ID+'*_*'):
    if os.path.isdir(subcomponent_dir)==True:
      shutil.rmtree(subcomponent_dir)

except FileNotFoundError:
  print('No dirs to remove')

#Download
files.download(COMPLEXDIR+'best_complex.pdb')


mpDockQ: 0.7677392355732878



In [None]:
#@title Display 3D structure {run: "auto"}
#From: https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb
import py3Dmol

view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
view.addModel(open(COMPLEXDIR+'best_complex.pdb','r').read(),'pdb')

#Colour each chain, cycling through the palette when there are more chains than colours
palette = ["lime","cyan","magenta","yellow","salmon","white","blue","orange","grey","brown"]
for n, chain in enumerate(CHAINS.Chain):
      view.setStyle({'chain':chain},{'cartoon': {'color':palette[n % len(palette)]}})
#view.setStyle({'cartoon': {'color':'spectrum'}})
view.zoomTo()
view.show()

In [27]:

OUTDIR="./data/test/"
#OUTDIR="."
ID = "1A8R" #@param {type:"string"}
COMPLEXDIR=OUTDIR+'/assembly/complex_colab/' #Where all the output for the complex assembly will be directed
PAIRDIR=OUTDIR+'/assembly/pairs/'
META=OUTDIR+'/assembly/meta.csv' #where to write all interactions
#import pandas as pd


In [14]:
#@markdown Assemble: find the best non-overlapping path that connects all nodes using Monte Carlo tree search
META_DF=pd.read_csv(META)
CHAIN_SEQS=pd.read_csv(OUTDIR+'/assembly/'+ID+'_chains.csv')
#from complex_assembly import mcts_colab
#mcts_colab.assemble(META_DF, PAIRDIR, OUTDIR+'/assembly/plddt/', USEQS, CHAIN_SEQS, COMPLEXDIR)


In [15]:
import os

current_directory = os.getcwd()
print(current_directory)

/mnt/nfs_accretion/BioInfo/MCTS/MoLPC


In [23]:
import sys

sys.path.insert(0,'./src')
from complex_assembly import mcts_colab


In [28]:
mcts_colab.assemble(META_DF, PAIRDIR, OUTDIR+'/assembly/plddt/', USEQS, CHAIN_SEQS, COMPLEXDIR)

['A', 'D']
['A', 'J']
['A', 'G']
['A', 'E']
['A', 'J']
['A', 'F']
['A', 'C']
['A', 'J']
['A', 'F']
['A', 'J']
['A', 'C']
['A', 'F']
['A', 'E']
['A', 'F']
['A', 'E']
['A', 'C']
['A', 'D']
['A', 'F']
['A', 'G']
['A', 'I']
['A', 'B']
['A', 'F']
['A', 'B']
['A', 'D']
['A', 'H']
['A', 'C']
['A', 'G']
['A', 'E']
['A', 'I']
['A', 'B']
['A', 'F']
['A', 'H']
['A', 'B']
['A', 'F']
['A', 'C']
['A', 'E']
['A', 'E']
['A', 'D']
['A', 'J']
['A', 'G']
['A', 'E']
['A', 'H']
['A', 'C']
['A', 'I']
['A', 'E']
['A', 'B']
['A', 'B']
['A', 'J']
['A', 'F']
['A', 'C']
['A', 'B']
['A', 'H']
['A', 'I']
['A', 'E']
['A', 'D']
['A', 'C']
['A', 'J']
['A', 'G']
['A', 'H']
['A', 'D']
['A', 'G']
['A', 'D']
['A', 'F']
['A', 'G']
['A', 'C']
['A', 'G']
['A', 'J']
['A', 'C']
['A', 'G']
['A', 'H']
['A', 'D']
['A', 'G']
['A', 'E']
['A', 'B']
['A', 'G']
['A', 'I']
['A', 'F']
['A', 'D']
['A', 'H']
['A', 'B']
['A', 'B']
['A', 'G']
['A', 'I']
['A', 'C']
['A', 'B']
['A', 'E']
['A', 'C']
['A', 'G']
['A', 'B']
['A', 'J']
['A', 'D']

ImportError: cannot import name 'FilePathOrBuffer' from 'pandas._typing' (/home/morris/.conda/envs/molpc/lib/python3.12/site-packages/pandas/_typing.py)

In [33]:
#def assemble(network, pairdir, plddt_dir, useqs, chain_seqs, outdir):
network = META_DF
pairdir = PAIRDIR
plddt_dir = OUTDIR+'/assembly/plddt/'
useqs = USEQS.copy() #copy so that the global USEQS is not modified below
chain_seqs = CHAIN_SEQS
outdir = COMPLEXDIR

#Get all edges
edges = np.array(network[['Chain1', 'Chain2']])
sources = np.array(network['Source'])

#Get all chain lengths
useqs['Chain_length'] = [len(x) for x in useqs.Sequence]
useqs = useqs[['SeqID', 'Chain_length']]
chain_lens = pd.merge(chain_seqs, useqs, left_on='Useq', right_on='SeqID', how='left')
chain_lens = dict(zip(chain_lens.Chain.values, chain_lens.Chain_length.values))
#Find paths and assemble
#best_path = find_paths(edges, sources, pairdir, plddt_dir, chain_lens, outdir)
#Write PDB files of all complete paths
#write_pdb(best_path, outdir)
#Create and save path df
#create_path_df(best_path, outdir)


#def find_paths(edges, sources, pairdir, plddt_dir, chain_lens, outdir):
'''Find all paths that visit all nodes fulfilling the criterion:
No overlapping chains (50% of the shortest chain's CAs within 5 Å between two chains)
'''

#Get all nodes
nodes = np.unique(edges)
num_nodes = len(nodes)
#Run Monte Carlo Tree Search
#Read source - start at chain A
sps = edges[np.argwhere(edges=='A')[:,0]][0]
ssr = sources[np.argwhere(edges=='A')[:,0]][0]
pdb_chains, chain_coords, chain_CA_inds, chain_CB_inds = mcts_colab.read_pdb(pairdir+ssr+'_'+sps[0]+'-'+ssr+'_'+sps[1]+'.pdb')
#plDDT: slice chain A's values out of the concatenated per-residue array
source_plDDT =  np.load(plddt_dir+ssr+'.npy', allow_pickle=True)
si = 0
for p_chain in ssr.split('_')[-1]:
    if p_chain=='A':
        chain_plddt=source_plDDT[si:si+chain_lens['A']]
    else:
        si += chain_lens[p_chain]

root = mcts_colab.MonteCarloTreeSearchNode('A', '', np.array(chain_coords['A']), np.array(chain_CA_inds['A']),
        np.array(chain_CB_inds['A']), np.array(pdb_chains['A']), chain_plddt,
        edges, sources, pairdir, plddt_dir, chain_lens, outdir,
        source=None, complex_scores=[0], parent=None, parent_path=[], total_chains=num_nodes)

best_path = root.best_path()
#return best_path



In [36]:
len(root.edges)

2160

In [37]:
root.edges[2111]

array(['E', 'H'], dtype=object)

In [41]:
len(root.children)

432

In [56]:
root.children[421].complex_scores

[336.34782022063274, 1386.6372151212547]

In [46]:
root.children[4].children

[<complex_assembly.mcts_colab.MonteCarloTreeSearchNode at 0x7f59169edd00>,
 <complex_assembly.mcts_colab.MonteCarloTreeSearchNode at 0x7f5916b055e0>,
 <complex_assembly.mcts_colab.MonteCarloTreeSearchNode at 0x7f59169ef350>,
 <complex_assembly.mcts_colab.MonteCarloTreeSearchNode at 0x7f5917981e20>,
 <complex_assembly.mcts_colab.MonteCarloTreeSearchNode at 0x7f590f062f60>,
 <complex_assembly.mcts_colab.MonteCarloTreeSearchNode at 0x7f58f58868a0>,
 <complex_assembly.mcts_colab.MonteCarloTreeSearchNode at 0x7f590d7475c0>]

In [54]:
root.children[4].children[2].complex_scores

[848.5487587246585, 2235.0463038415496]

In [65]:
root.children[4].children[2].chain

'F'

In [60]:
max(root.children[4].complex_scores)

3742.4841584615024

In [58]:
len(root.complex_scores)

4097

In [70]:
def nodename(node):
    """
    Constructs a name by concatenating the 'chain' attribute of the node with those of its ancestors.
    The chains are ordered from the root down to the current node and separated by '-'.
    
    Parameters:
    - node: The node whose path from the root should be named.
    
    Returns:
    - A string such as 'A-D-G' listing the chains from the root to this node.
    """
    if node.parent is None:
        return node.chain  # If this is the root node, return its chain
    else:
        return nodename(node.parent) + '-' + node.chain  # Recursively build the name from root to this node

# Example of calling this function:
# Suppose `some_node` is a node within your tree structure
# result_name = nodename(some_node)
# print(result_name)
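The `MonteCarloTreeSearchNode` objects inspected above expose `chain`, `parent`, `children` and `complex_scores`. A small sketch that walks greedily from the root to a leaf, always following the child with the highest `max(complex_scores)`, and names the result root-first like `nodename`, but iteratively. The `Node` dataclass is a hypothetical stand-in so the example runs without MoLPC; a greedy descent is only a quick inspection aid and need not coincide with `root.best_path()`:

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass(eq=False)
class Node:
    # Hypothetical stand-in exposing the attributes used above
    chain: str
    complex_scores: List[float]
    parent: Optional["Node"] = None
    children: List["Node"] = field(default_factory=list)

    def add(self, chain, scores):
        child = Node(chain, scores, parent=self)
        self.children.append(child)
        return child

def path_name(node):
    """Iterative equivalent of nodename(): chains joined root-first with '-'."""
    chains = []
    while node is not None:
        chains.append(node.chain)
        node = node.parent
    return '-'.join(reversed(chains))

def greedy_leaf(node):
    """Follow the child with the highest max(complex_scores) down to a leaf."""
    while node.children:
        node = max(node.children, key=lambda c: max(c.complex_scores))
    return node

root = Node('A', [0])
b = root.add('B', [10.0, 30.0])
root.add('C', [25.0])
b.add('D', [5.0])
print(path_name(greedy_leaf(root)))  # A-B-D
```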


In [77]:
import csv

def dump_tree_to_csv(root_node):
    with open('nodes.csv', 'w', newline='') as nodes_file, \
         open('edges.csv', 'w', newline='') as edges_file:
        node_writer = csv.writer(nodes_file)
        edge_writer = csv.writer(edges_file)

        node_writer.writerow(['NodeChains', 'EdgeChain', 'Source', 'MaxScore'])
        edge_writer.writerow(['From', 'To', 'Confidence'])

        def traverse(node):
            # Write node details
            node_writer.writerow([nodename(node), node.edge_chain, node.source, max(node.complex_scores)])
            for child in node.children:  #  children is a list of child nodes
                # Write edge details; the child's best complex score serves as the edge confidence
                confidence = max(child.complex_scores)
                edge_writer.writerow([nodename(node), nodename(node) + '-' + child.chain, confidence])
                traverse(child)

        traverse(root_node)

# This function should be called after the tree is fully constructed


In [79]:
import csv

def dump_tree_to_csv(root_node):
    with open('nodes1.csv', 'w', newline='') as nodes_file, \
         open('edges1.csv', 'w', newline='') as edges_file:
        node_writer = csv.writer(nodes_file)
        edge_writer = csv.writer(edges_file)

        node_writer.writerow(['NodeID', 'Chain', 'EdgeChain', 'Source', 'TotalChains'])
        edge_writer.writerow(['From', 'To', 'Confidence'])

        node_id = 0  # Counter for the next free node ID

        def traverse(node):
            """Write node and its subtree in pre-order; return the ID assigned to node."""
            nonlocal node_id  # Allow updating the enclosing counter
            current_node_id = node_id
            node_id += 1

            # Write node details
            node_writer.writerow([current_node_id, nodename(node), node.edge_chain, node.source, node.total_chains])
            for child in node.children:  # Assuming children is a list of child nodes
                # The child's ID is assigned when it is traversed, so the edge uses the same ID as the nodes file
                child_node_id = traverse(child)
                confidence = getattr(child, 'plddt', None)  # Assuming 'plddt' can serve as confidence, modify as necessary
                edge_writer.writerow([current_node_id, child_node_id, confidence])
            return current_node_id

        traverse(root_node)

# This function should be called after the tree is fully constructed


In [84]:
dump_tree_to_csv(root)

In [90]:
import pandas as pd

def dump_tree_to_dataframe(root_node):
    nodes_data = []
    edges_data = []

    def traverse(node):
        # Collect node details using nodename function for node chains
        node_name = nodename(node)
        nodes_data.append([node_name, node.edge_chain, node.source, max(node.complex_scores)])
        #print(node_name)
        
        for child in node.children:  # Assuming children is a list of child nodes
            child_name = nodename(child)
            # Collect edge details using nodename function for from and to names
            confidence = max(child.complex_scores)  # Assuming max of complex_scores can serve as confidence
            edges_data.append([node_name, child_name, confidence])
            traverse(child)

    traverse(root_node)

    # Convert lists to Pandas DataFrames
    nodes_df = pd.DataFrame(nodes_data, columns=['NodeChains', 'EdgeChain', 'Source', 'Score'])
    edges_df = pd.DataFrame(edges_data, columns=['From', 'To', 'Confidence'])

    # Process nodes_df to deduplicate and sort (the root has Source=None, so fill it before joining)
    nodes_df[['EdgeChain', 'Source']] = nodes_df[['EdgeChain', 'Source']].fillna('')
    nodes_df = nodes_df.groupby('NodeChains', as_index=False).agg({
        'EdgeChain': ' '.join,
        'Source': ' '.join,
        'Score': 'max'
    }).sort_values('NodeChains')

    # Add a label column starting from 1
    nodes_df['Label'] = range(1, len(nodes_df) + 1)

    # Create a mapping from NodeChains to Label
    label_map = dict(zip(nodes_df['NodeChains'], nodes_df['Label']))

    # Update edges_df to use new Label values
    edges_df['From'] = edges_df['From'].map(label_map)
    edges_df['To'] = edges_df['To'].map(label_map)

    # Write to CSV files
    nodes_df.to_csv('nodes_final.csv', index=False)
    edges_df.to_csv('edges_final.csv', index=False)

# This function should be called after the tree is fully constructed


In [127]:
import pandas as pd

#def dump_tree_to_dataframe(root_node):
nodes_data = []
edges_data = []

def traverse(node):
    # Use the 'nodename' function to generate names based on the node's position in the tree
    node_name = nodename(node)
    nodes_data.append([node_name, node.edge_chain, node.source, max(node.complex_scores)])
    
    for child in node.children:  # Assuming children is a list of child nodes
        child_name = nodename(child)
        confidence = max(child.complex_scores)  # Assume max complex_scores as confidence
        edges_data.append([node_name, child_name, confidence])
        traverse(child)




    # Write to CSV files (optional, depending on whether you want to write or use the DataFrame)
    # nodes_df.to_csv('nodes_final.csv', index=False)
    # edges_df.to_csv('edges_final.csv', index=False)

    #return nodes_df, edges_df

# Function call and node definition (examples) should be adapted to your specific use case


In [128]:
traverse(root)

In [159]:
# Convert lists to Pandas DataFrames
nodes_df = pd.DataFrame(nodes_data, columns=['NodeChains', 'EdgeChain', 'Source', 'Score'])
edges_df = pd.DataFrame(edges_data, columns=['From', 'To', 'Confidence'])


In [160]:
nodes_df

,NodeChains,EdgeChain,Source,Score
0,A,,,6234.591301
1,A-D,A,1A8R_JAD,1725.790994
2,A-J,A,1A8R_JAD,1729.696006
3,A-G,A,1A8R_EGA,2178.169190
4,A-E,A,1A8R_EGA,146.852392
...,...,...,...,...
4092,A-H,A,1A8R_HAF,2183.825096
4093,A-H,A,1A8R_DAH,398.457573
4094,A-D,A,1A8R_DAH,881.657785
4095,A-I,A,1A8R_EIA,897.113263


In [161]:
unodes_df = nodes_df.sort_values(by='NodeChains').groupby('NodeChains', as_index=False).agg(
    Score_Max=pd.NamedAgg(column="Score", aggfunc="max")
)

In [163]:
# Correctly use groupby and aggregate functions
# unodes_df = nodes_df.groupby('NodeChains', as_index=False).agg({
#     'EdgeChain': ' '.join,      # Joining text with spaces
#     'Source': ' '.join,         # Joining text with spaces
#     'TotalChains': 'mean'       # Calculating the mean of numeric values
# }).sort_values(by='NodeChains')  # Sorting by 'NodeChains'


# Add a label column starting from 1
unodes_df['Label'] = range(1, len(unodes_df) + 1)


In [168]:
unodes_df = unodes_df[['Label', 'NodeChains', 'Score_Max']]

# Now unodes_df has columns in the order [Label, NodeChains, Score_Max]

In [169]:
unodes_df

,Label,NodeChains,Score_Max
0,1,A,6234.591301
1,2,A-B,3163.698013
2,3,A-B-I,881.310664
3,4,A-C,2662.009592
4,5,A-D,4018.246139
...,...,...,...
141,142,A-J-D,3073.641614
142,143,A-J-E,1364.857997
143,144,A-J-F,2235.046304
144,145,A-J-G,881.657785


In [165]:

# Create a mapping from NodeChains to Label
label_map = dict(zip(unodes_df['NodeChains'], unodes_df['Label']))



In [166]:
# dump_tree_to_dataframe(root)

# Update edges_df to use new Label values from nodes_df
edges_df['From'] = edges_df['From'].map(label_map)
edges_df['To'] = edges_df['To'].map(label_map)


In [167]:
edges_df

,From,To,Confidence
0,1,5,1725.790994
1,1,140,1729.696006
2,1,130,2178.169190
3,1,9,146.852392
4,1,140,3742.484158
...,...,...,...
4091,1,132,2183.825096
4092,1,132,398.457573
4093,1,5,881.657785
4094,1,134,897.113263


In [171]:
topo_df = pd.DataFrame()
topo_df['Label'] = unodes_df['Label']
# Tree depth = number of '-' separators in the path name (root 'A' is level 0)
topo_df['Level'] = unodes_df['NodeChains'].str.count('-')

In [172]:
topo_df

,Label,Level
0,1,0
1,2,1
2,3,2
3,4,1
4,5,1
...,...,...
141,142,2
142,143,2
143,144,2
144,145,2
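The relabelling done in the preceding cells (collapse duplicate `NodeChains` keeping the maximum score, number the unique nodes from 1, map edge endpoints to those labels, derive each node's level) can be condensed into one function. A sketch using the same column names, with toy data standing in for `nodes_df`/`edges_df`; `relabel` is a hypothetical helper, and the level is computed as the number of `-` separators, which equals `floor(len / 2)` when every chain ID is a single character:

```python
import pandas as pd

def relabel(nodes_df, edges_df):
    """Deduplicate nodes by path name and renumber edges with 1-based labels."""
    unodes = (nodes_df.groupby('NodeChains', as_index=False)   # groupby sorts the keys
                      .agg(Score_Max=('Score', 'max')))
    unodes.insert(0, 'Label', list(range(1, len(unodes) + 1)))
    label_map = dict(zip(unodes['NodeChains'], unodes['Label']))
    edges = edges_df.assign(From=edges_df['From'].map(label_map),
                            To=edges_df['To'].map(label_map))
    topo = pd.DataFrame({'Label': unodes['Label'],
                         'Level': unodes['NodeChains'].str.count('-')})
    return unodes, edges, topo

nodes = pd.DataFrame({'NodeChains': ['A', 'A-B', 'A-C', 'A-B'],
                      'Score': [50.0, 10.0, 20.0, 30.0]})
edges = pd.DataFrame({'From': ['A', 'A', 'A'], 'To': ['A-B', 'A-C', 'A-B'],
                      'Confidence': [10.0, 20.0, 30.0]})
unodes, new_edges, topo = relabel(nodes, edges)
print(unodes)
print(new_edges)
print(topo)
```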


In [173]:
# Write topo_df to 'mctstopo.csv'
topo_df.to_csv('mctstopo.csv', index=False)

# Write unodes_df to 'LookupTable_MCTS.csv' with semicolon as the separator
unodes_df.to_csv('LookupTable_MCTS.csv', sep=';', index=False)

# Write edges_df to 'mctsnet.csv'
edges_df.to_csv('mctsnet.csv', index=False)

ImportError: cannot import name 'FilePathOrBuffer' from 'pandas._typing' (/home/morris/.conda/envs/molpc/lib/python3.12/site-packages/pandas/_typing.py)

In [178]:
import csv

# Ensure that 'From' and 'To' columns are integers
edges_df['From'] = edges_df['From'].astype(int)
edges_df['To'] = edges_df['To'].astype(int)

# For unodes_df (similar logic can be applied to other dataframes)
with open('LookupTable_MCTS.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    # Write the header
    writer.writerow(['Label', 'NodeChains', 'Score_Max'])
    # Write the data rows
    for index, row in unodes_df.iterrows():
        writer.writerow([row['Label'], row['NodeChains'], row['Score_Max']])

# For edges_df: no header row, and node IDs are written 0-based (Label - 1)
with open('mctsnet.csv', 'w', newline='') as file:
    writer = csv.writer(file)
#    writer.writerow(['From', 'To', 'Confidence'])
    for index, row in edges_df.iterrows():
        #writer.writerow([row['From'], row['To'], row['Confidence']])
        writer.writerow([str(int(row['From'])-1), str(int(row['To'])-1), row['Confidence']])

# For topo_df
with open('mctstopo.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Label', 'Level'])
    for index, row in topo_df.iterrows():
        writer.writerow([row['Label'], row['Level']])
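The three exported files use different conventions: `LookupTable_MCTS.csv` and `mctstopo.csv` carry 1-based labels with a header, while `mctsnet.csv` has no header and 0-based endpoints. A small consistency check for that layout (the helper `check_exports` is hypothetical), demonstrated on toy files in a temporary directory:

```python
import csv
import os
import tempfile

def check_exports(lookup_path, net_path, topo_path):
    """Cross-check the lookup table, topology and 0-based edge list.

    Returns the number of labelled nodes; raises ValueError on a mismatch.
    """
    with open(lookup_path, newline='') as f:
        labels = {int(row['Label']) for row in csv.DictReader(f)}
    with open(topo_path, newline='') as f:
        topo_labels = {int(row['Label']) for row in csv.DictReader(f)}
    if labels != topo_labels:
        raise ValueError('LookupTable and topology labels differ')
    with open(net_path, newline='') as f:
        for src, dst, _confidence in csv.reader(f):
            for node in (int(src) + 1, int(dst) + 1):  # shift 0-based IDs back to labels
                if node not in labels:
                    raise ValueError(f'edge endpoint {node} has no label')
    return len(labels)

# Toy files in the same layout as the cell above
with tempfile.TemporaryDirectory() as tmp:
    paths = [os.path.join(tmp, name) for name in ('lookup.csv', 'net.csv', 'topo.csv')]
    contents = ['Label,NodeChains,Score_Max\n1,A,50.0\n2,A-B,30.0\n',
                '0,1,30.0\n',
                'Label,Level\n1,0\n2,1\n']
    for path, text in zip(paths, contents):
        with open(path, 'w', newline='') as f:
            f.write(text)
    n_nodes = check_exports(*paths)
print(n_nodes)  # 2
```

In the working directory of the cell above, `check_exports('LookupTable_MCTS.csv', 'mctsnet.csv', 'mctstopo.csv')` should return the number of unique nodes if the files are consistent.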
