<a href="https://colab.research.google.com/github/engelberger/ACPI/blob/main/rf_aa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://colab.research.google.com/github/engelberger/all_atom_binder_diffusion/blob/dev/rf_aa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **RF2 aa**
RF2 aa (RoseTTAFold All-Atom) is a method for structure prediction. It can tackle a wide range of biomolecular modeling and protein design challenges, as outlined in the RoseTTAFold All-Atom [manuscript](https://www.science.org/doi/10.1126/science.adl2528).

**<font color="red">NOTE:</font>** This notebook is in development; we are still working on adding all the options from the manuscript above.

For **instructions**, see the end of the notebook.



This notebook is a version of Sergey's notebook modified by Felipe Engelberger; see the [original version](https://colab.research.google.com/github/sokrypton/ColabDesign/blob/main/rf/examples/diffusion_ori.ipynb) of this notebook (from 31Mar2023).



In [None]:
# @title COLAB ONLY setup **RosettaFold2 All Atom** (~5m)
%%time

import os
import subprocess
import time
import sys

# Function to detect if running on Google Colab
def is_colab():
    """Return True when the ``COLAB_GPU`` environment variable is set."""
    colab_marker = os.environ.get("COLAB_GPU")
    return colab_marker is not None


def run_command(command, progress_message, wait=True):
    """
    Run a shell command, announcing it with a progress message.

    Args:
        command: Command line handed to the shell (``shell=True``).
        progress_message: Label printed when the command starts, finishes
            or fails.
        wait: If True (default), block until the command exits. If False,
            return immediately and leave the command running in the
            background.

    Returns:
        The ``subprocess.Popen`` handle. When ``wait`` is False the caller is
        responsible for calling ``communicate()`` on it and checking
        ``returncode``.

    Raises:
        subprocess.CalledProcessError: If ``wait`` is True and the command
            exits with a non-zero status. The captured stdout and stderr are
            attached to the exception.
    """
    print(f"Starting: {progress_message}")
    process = subprocess.Popen(
        command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
    )
    if not wait:
        # stdout/stderr stay attached to pipes that are only drained once the
        # caller invokes communicate() on the returned handle.
        return process

    stdout, stderr = process.communicate()
    if process.returncode != 0:
        # Show stderr right away, then fail loudly with the output attached so
        # callers that catch the exception can still inspect it.
        print(f"Error during {progress_message}: {stderr}")
        raise subprocess.CalledProcessError(
            process.returncode, command, output=stdout, stderr=stderr
        )

    print(f"Completed: {progress_message}")
    return process

def setup_environment_colab():
    """
    Install the system packages, Python packages, model weights and source
    repositories used by this notebook.

    Reads the module-level ``PARAMS_DIR``, ``BASE_DIR``, ``RF_DIFFUSION_DIR``,
    ``RF2_ALL_ATOM_DIR``, ``RF_DIFFUSION_ALL_ATOM_DIR`` and ``setup_dict``.
    The weight download is started in the background and joined at the end,
    so it overlaps with the package installation steps.

    Raises:
        subprocess.CalledProcessError: If any install step, or the background
            weight download, exits with a non-zero status.
    """
    # Install aria2 if not already installed (for faster downloads)
    run_command("apt-get install -y aria2", "Installing aria2 for faster downloads")

    # If parameters are already downloaded, skip the download process.
    # The handle is kept so that completion is decided by the process itself
    # rather than by re-checking done.txt, which other setup code may also
    # create under the same directory.
    download_process = None
    if not os.path.isfile(os.path.join(PARAMS_DIR, "done.txt")):
        print("Downloading parameters and models...")

        # Start downloading parameters and models in the background.
        # (The AlphaFold parameter tarball is not part of this chain; it is
        # fetched by SetupColabDesign.install_colab_design.)
        download_command = (
            f"cd {PARAMS_DIR} && "
            "aria2c -q -x 16 https://files.ipd.uw.edu/krypton/schedules.zip && "
            "aria2c -q -x 16 http://files.ipd.uw.edu/pub/RF-All-Atom/weights/RFDiffusionAA_paper_weights.pt && "
            "aria2c -q -x 16 http://files.ipd.uw.edu/pub/RF-All-Atom/weights/RFAA_paper_weights.pt && "
            "touch done.txt"
        )
        download_process = run_command(
            download_command, "Downloading and extracting parameters", wait=False
        )

    # Install Open Babel if not already installed
    if not os.path.isfile("/usr/bin/obabel"):
        run_command(
            "apt-get install -y openbabel libopenbabel-dev && ln -s /usr/include/openbabel3 /usr/local/include/openbabel3",
            "Installing Open Babel and its development files",
        )

    # Install SWIG
    if not os.path.isfile("/usr/bin/swig"):
        run_command(
            "apt-get remove -y swig && apt-get install -y swig3.0 && ln -sf /usr/bin/swig3.0 /usr/bin/swig",
            "Installing SWIG",
        )

    # Install Python dependencies
    run_command(
        "pip install jedi omegaconf hydra-core icecream pyrsistent assertpy deepdiff fire git+https://github.com/sokrypton/ColabDesign.git@gamma py3Dmol openbabel",
        "Installing Python dependencies",
    )

    # Download ColabFold Utils into the current working directory
    run_command(
        "wget https://raw.githubusercontent.com/sokrypton/ColabFold/main/colabfold/colabfold.py -O colabfold_utils.py",
        "Download ColabFold Utils",
    )

    # Clone RFdiffusion repository
    if not os.path.isdir(RF_DIFFUSION_DIR):
        run_command(
            f"git clone --branch max https://github.com/engelberger/RFdiffusion.git {RF_DIFFUSION_DIR}",
            "Cloning RFdiffusion repository",
        )

    # Clone RoseTTAFold all atom repository
    if not os.path.isdir(RF2_ALL_ATOM_DIR):
        run_command(
            f"git clone --recurse-submodules --branch colab_march_2024 https://github.com/engelberger/RoseTTAFold-All-Atom.git {RF2_ALL_ATOM_DIR}",
            "Cloning RoseTTAFold all atom repository",
        )

    # Clone RFdiffusion all atom repository
    if not os.path.isdir(RF_DIFFUSION_ALL_ATOM_DIR):
        run_command(
            f"git clone --recurse-submodules --branch colab_march_2024 https://github.com/engelberger/rf_diffusion_all_atom.git {RF_DIFFUSION_ALL_ATOM_DIR}",
            "Cloning RFdiffusion all atom repository",
        )

    # Install DGL
    run_command(
        "pip install dgl -f https://data.dgl.ai/wheels/cu121/repo.html",
        "Installing DGL",
    )

    # Install SE3 Transformer
    run_command(
        f"cd {os.path.join(RF_DIFFUSION_DIR, 'env/SE3Transformer')} && pip install -q --no-cache-dir -r requirements.txt && pip install -q .",
        "Installing SE3 Transformer",
    )

    # Download and set execute permissions for 'ananas'
    run_command(
        f"wget -qnc https://files.ipd.uw.edu/krypton/ananas -P {BASE_DIR} && chmod +x {os.path.join(BASE_DIR, 'ananas')}",
        "Downloading and setting up 'ananas'",
    )

    predictor = SetupColabDesign(
        setup_dict["unified_memory"],
        setup_dict["parentPath"],
        setup_dict["setupPath"],
    )
    predictor.setup()

    # Join the background download (if one was started) and surface failures
    # instead of silently continuing without the weights.
    if download_process is not None:
        _, download_stderr = download_process.communicate()
        if download_process.returncode != 0:
            print(
                f"Error during Downloading and extracting parameters: {download_stderr}"
            )
            raise subprocess.CalledProcessError(
                download_process.returncode,
                download_process.args,
                stderr=download_stderr,
            )
        print("Completed: Downloading and extracting parameters")
    print("Environment setup complete.")


class SetupColabDesign:
    """
    Install ColabDesign (with AlphaFold parameters) and HH-suite under
    ``setupPath`` and expose the HH-suite tools on ``PATH``.

    Args:
        unified_memory: When truthy, ``setup`` exports
            ``TF_FORCE_UNIFIED_MEMORY`` and ``XLA_PYTHON_CLIENT_MEM_FRACTION``.
        parentPath: Stored on the instance; not used by this class itself.
        setupPath: Directory that receives ``params/``, ``hhsuite/``, the
            ``colabdesign`` symlink and ``colabfold_utils.py``.
        python_colab: Interpreter used for ``pip install``.
        colabdesign_path: Installed ``colabdesign`` package directory that is
            symlinked into ``setupPath``.

    NOTE(review): the two path defaults are pinned to Python 3.10; confirm
    they match the interpreter of the runtime in use.
    """

    # Sub-directories of setupPath that hold the HH-suite executables/scripts.
    _HHSUITE_SUBDIRS = ("hhsuite/bin", "hhsuite/scripts")

    def __init__(
        self,
        unified_memory,
        parentPath,
        setupPath,
        python_colab="/usr/bin/python3.10",
        colabdesign_path="/usr/local/lib/python3.10/dist-packages/colabdesign",
    ):
        self.unified_memory = unified_memory
        self.parentPath = parentPath
        self.setupPath = setupPath
        self.python_colab = python_colab
        self.colabdesign_path = colabdesign_path
        self.ENV = (
            {"TF_FORCE_UNIFIED_MEMORY": "1", "XLA_PYTHON_CLIENT_MEM_FRACTION": "4.0"}
            if unified_memory
            else {}
        )

    def setup(self):
        """
        Export the memory environment variables and install whatever is
        missing under ``setupPath``.

        ColabDesign is installed when ``setupPath/params`` is absent and
        HH-suite when ``setupPath/hhsuite`` is absent. The HH-suite
        directories are put on ``PATH`` in both cases, so a fresh kernel that
        reuses an existing installation can still find the tools.
        """
        for k, v in self.ENV.items():
            os.environ[k] = v

        os.makedirs(self.setupPath, exist_ok=True)

        if os.path.isdir(os.path.join(self.setupPath, "params")):
            print("Setup path is present.")
        else:
            print("Setup path is not present. Installing ColabDesign...")
            self.install_colab_design()

        if os.path.isdir(os.path.join(self.setupPath, "hhsuite")):
            print("HHsuite is present.")
        else:
            print("HHsuite is not present. Installing HHsuite...")
            self.install_hhsuite()
        self._add_hhsuite_to_path()

        print("mmseqs2 is imported.")

    def _add_hhsuite_to_path(self):
        """Append the HH-suite bin/scripts directories to PATH if missing."""
        current = os.environ.get("PATH", "")
        entries = current.split(os.pathsep) if current else []
        missing = [
            directory
            for directory in (
                os.path.join(self.setupPath, sub) for sub in self._HHSUITE_SUBDIRS
            )
            if directory not in entries
        ]
        if missing:
            # Join without any padding: a space after the separator would
            # become part of the directory name and make the entry useless.
            os.environ["PATH"] = os.pathsep.join(([current] if current else []) + missing)

    def install_colab_design(self):
        """
        Download the AlphaFold parameters into ``setupPath/params`` and install
        ColabDesign plus the ColabFold helper module.

        Raises:
            Exception: If any of the shell commands exits with a non-zero
                status.
        """

        def run_command(command):
            # Decide success from the exit status. Tools such as pip write
            # warnings to stderr even when they succeed, so stderr content
            # alone is not a failure signal.
            completed = subprocess.run(
                command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            if completed.returncode != 0:
                error = completed.stderr.decode(errors="replace")
                raise Exception(f"Error occurred while executing command: {error}")

            return completed.stdout

        params_path = os.path.join(self.setupPath, "params")
        # Make params_path absolute
        params_path = os.path.abspath(params_path)

        if not os.path.exists(params_path):
            os.makedirs(params_path)
            print(f"Created params directory at {params_path}")

        print("Installing ColabDesign...")

        commands = [
            "apt-get install aria2 -qq",
            f"cd {self.setupPath} && aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar",
            f"tar -xf {self.setupPath}/alphafold_params_2022-12-06.tar -C {params_path}",
            f"touch {os.path.join(params_path, 'done.txt')}",
            # f"cd {self.setupPath} && rm alphafold_params_2022-12-06.tar",
        ]

        for command in commands:
            run_command(command)

        print("Installing Python dependencies...")

        run_command(
            f"{self.python_colab} -m pip -q install git+https://github.com/sokrypton/ColabDesign.git@gamma"
        )
        # -f/-n keep the step repeatable when the link is already there.
        run_command(
            f"ln -sfn {self.colabdesign_path} {os.path.join(self.setupPath, 'colabdesign')}"
        )
        run_command(
            f"wget https://raw.githubusercontent.com/sokrypton/ColabFold/main/colabfold/colabfold.py -O {os.path.join(self.setupPath, 'colabfold_utils.py')}"
        )

    def install_hhsuite(self):
        """Download HH-suite 3.3.0 into ``setupPath/hhsuite`` and add it to PATH."""
        os.makedirs(os.path.join(self.setupPath, "hhsuite"), exist_ok=True)
        status = os.system(
            f"curl -fsSL https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-SSE2-Linux.tar.gz | tar xz -C {os.path.join(self.setupPath, 'hhsuite')}"
        )
        if status != 0:
            # Best effort: report the failure but let the caller continue.
            print(f"WARNING: HHsuite download/extraction exited with status {status}")

        self._add_hhsuite_to_path()


# Base directory setup
if is_colab():
    BASE_DIR = "/content"
else:
    # For local setup, adjust this path as per your local environment
    # By default we will assume you are in the devcontainer path
    BASE_DIR = "/workspaces/all_atom_binder_diffusion"

    # For setups outside the devcontainer, you may need to adjust this path
    # BASE_DIR = os.path.expanduser("~")

# Adjust paths based on the environment
PARAMS_DIR = os.path.join(BASE_DIR, "params")
RF_DIFFUSION_DIR = os.path.join(BASE_DIR, "RFdiffusion")
RF_DIFFUSION_ALL_ATOM_DIR = os.path.join(BASE_DIR, "rf_diffusion_all_atom")
RF2_ALL_ATOM_DIR = os.path.join(BASE_DIR, "RoseTTAFold-All-Atom")

# DGL backend selection used by the RFdiffusion code.
os.environ["DGLBACKEND"] = "pytorch"

# Make the cloned repositories importable. The membership test uses the same
# absolute path that is appended, so re-running the cell does not add
# duplicates and the entries do not depend on the current working directory.
for _repo_dir in (RF_DIFFUSION_DIR, RF2_ALL_ATOM_DIR):
    if _repo_dir not in sys.path:
        sys.path.append(_repo_dir)


# Ensure the params directory exists
os.makedirs(PARAMS_DIR, exist_ok=True)

# NOTE(review): these paths are fixed to /content even when BASE_DIR points
# elsewhere (non-Colab runs) - confirm whether they should follow BASE_DIR.
setup_dict = {
            "unified_memory": False,
            "parentPath": "/content/output",
            "setupPath": "/content/",
        }


# Call the setup function
setup_environment_colab()

# Standard library imports
import argparse
import gc
import os
import re
import subprocess
import sys
import tempfile
import time

# Third-party imports
from IPython.display import HTML
from IPython.core import ultratb
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
import yaml

# Local application/library specific imports
from colabdesign import mk_af_model, clear_mem
from colabdesign.af.contrib import predict
from colabdesign.af.contrib.cyclic import add_cyclic_offset
from colabdesign.shared.plot import plot_pseudo_3D, pymol_cmap
from colabdesign.shared.protein import _np_rmsd, _np_kabsch
from typing import List, Optional, Tuple

# Exception hook setup: format uncaught exceptions with IPython's FormattedTB
# using the "Linux" colour scheme; call_pdb=False is passed so the hook is
# configured without the debugger option.
sys.excepthook = ultratb.FormattedTB(color_scheme="Linux", call_pdb=False)

#pio.renderers.default = "vscode"
# Import the colabdesign utils to get the MSA
class ColabDesignUtils:
    """
    Helpers around HH-suite command-line tools, a coevolution estimate and a
    structure plot.

    Args:
        setupPath: Directory whose ``hhsuite/bin`` and ``hhsuite/scripts``
            sub-directories hold the HH-suite tools.
    """

    def __init__(self, setupPath):
        self.setupPath = setupPath

    def _ensure_hhsuite_on_path(self):
        """Append the HH-suite bin/scripts directories to PATH if missing."""
        current = os.environ.get("PATH", "")
        entries = current.split(os.pathsep) if current else []
        missing = [
            directory
            for directory in (
                os.path.join(self.setupPath, "hhsuite/bin"),
                os.path.join(self.setupPath, "hhsuite/scripts"),
            )
            if directory not in entries
        ]
        if missing:
            # No whitespace around the separator: it would become part of the
            # directory name and the entry would never match.
            os.environ["PATH"] = os.pathsep.join(([current] if current else []) + missing)

    def run_hhalign(
        self, query_sequence, target_sequence, query_a3m=None, target_a3m=None
    ):
        """
        Align a query against a target with ``hhalign``.

        When ``query_a3m`` / ``target_a3m`` are not given, a single-sequence
        temporary file is written from the corresponding sequence.

        Returns:
            The ``(X, start_indices)`` pair produced by
            ``predict.parse_hhalign_output`` on the hhalign output file.
        """
        self._ensure_hhsuite_on_path()
        with tempfile.NamedTemporaryFile() as tmp_query, tempfile.NamedTemporaryFile() as tmp_target, tempfile.NamedTemporaryFile() as tmp_alignment:
            if query_a3m is None:
                tmp_query.write(f">Q\n{query_sequence}\n".encode())
                tmp_query.flush()
                query_a3m = tmp_query.name
            if target_a3m is None:
                tmp_target.write(f">T\n{target_sequence}\n".encode())
                tmp_target.flush()
                target_a3m = tmp_target.name
            os.system(
                f"hhalign -hide_cons -i {query_a3m} -t {target_a3m} -o {tmp_alignment.name}"
            )
            # Parse while the temporary files still exist.
            X, start_indices = predict.parse_hhalign_output(tmp_alignment.name)
        return X, start_indices

    def run_do_not_align(self, query_sequence, target_sequence, **arg):
        """Return both sequences unchanged with zero start offsets (no alignment)."""
        return [query_sequence, target_sequence], [0, 0]

    def run_hhfilter(self, input, output, id=90, qid=10):
        """
        Run ``hhfilter`` on ``input`` and write the result to ``output``.

        Args:
            input: Path of the alignment to filter.
            output: Path of the filtered alignment.
            id: Value passed to hhfilter's ``-id`` option.
            qid: Value passed to hhfilter's ``-qid`` option.
        """
        self._ensure_hhsuite_on_path()
        os.system(f"hhfilter -id {id} -qid {qid} -i {input} -o {output}")

    @staticmethod
    @jax.jit
    def _coevolution(X):
        """Jitted kernel behind ``get_coevolution``; takes only the array."""
        Y = jax.nn.one_hot(X, 22)
        N, L, A = Y.shape
        Y_flat = Y.reshape(N, -1)
        c = jnp.cov(Y_flat.T)
        # Regularise the covariance before inversion.
        shrink = 4.5 / jnp.sqrt(N) * jnp.eye(c.shape[0])
        ic = jnp.linalg.inv(c + shrink)
        ic_diag = jnp.diag(ic)
        pcc = ic / jnp.sqrt(ic_diag[:, None] * ic_diag[None, :])
        # Norm over the first 20 of the 22 one-hot channels for each pair.
        raw = jnp.sqrt(jnp.square(pcc.reshape(L, A, L, A)[:, :20, :, :20]).sum((1, 3)))
        i = jnp.arange(L)
        raw = raw.at[i, i].set(0)
        # Subtract the row-sum * column-sum / total-sum background term.
        ap = raw.sum(0, keepdims=True) * raw.sum(1, keepdims=True) / raw.sum()
        return (raw - ap).at[i, i].set(0)

    def get_coevolution(self, X):
        """
        Given an MSA, return a position-by-position coupling matrix.

        ``X`` is one-hot encoded with 22 classes inside the computation, so it
        is expected to be a 2-D integer array (sequences x positions).

        The jitted computation lives in a static helper because ``jax.jit``
        would otherwise try to trace ``self`` as an array argument.
        """
        return self._coevolution(X)

    @staticmethod
    def plot_3D(aux, Ls, file_name, show=False):
        """
        Save a two-panel pseudo-3D plot (chain/length colouring and pLDDT).

        Args:
            aux: Mapping providing ``"atom_positions"`` and ``"plddt"``.
            Ls: List of chain lengths; more than one entry selects per-chain
                colouring.
            file_name: Path the figure is saved to.
            show: If True the figure is shown, otherwise it is closed.
        """
        plt.figure(figsize=(10, 5))
        xyz = aux["atom_positions"][:, 1]
        xyz = xyz @ _np_kabsch(xyz, xyz, return_v=True, use_jax=False)
        ax = plt.subplot(1, 2, 1)
        if len(Ls) > 1:
            plt.title("chain")
            c = np.concatenate([[n] * L for n, L in enumerate(Ls)])
            plot_pseudo_3D(xyz=xyz, c=c, cmap=pymol_cmap, cmin=0, cmax=39, Ls=Ls, ax=ax)
        else:
            plt.title("length")
            plot_pseudo_3D(xyz=xyz, Ls=Ls, ax=ax)
        plt.axis(False)
        ax = plt.subplot(1, 2, 2)
        plt.title("plddt")
        plot_pseudo_3D(xyz=xyz, c=aux["plddt"], cmin=0.5, cmax=0.9, Ls=Ls, ax=ax)
        plt.axis(False)
        plt.savefig(file_name, dpi=200, bbox_inches="tight")
        if show:
            plt.show()
        else:
            plt.close()

class PrepInputs:
    """
    Prepare the sequence, MSA and template inputs for one prediction job.

    The methods build on each other's attributes: ``filter_options``
    normalises the raw input, ``process_sequence`` derives the unique
    sequences and the final job name, ``get_msa`` builds the MSA, and
    ``use_templates`` builds the template feature batches.

    Job files are written below ``{parentPath}/{jobname}/in``. HH-suite and
    ``colabfold_utils.py`` are looked up under ``setupPath``.
    """

    def __init__(
        self,
        sequence,
        jobname,
        copies,
        msa_method,
        custom_a3m_path,
        pair_mode,
        cov,
        id,
        qid,
        do_not_filter,
        template_mode,
        pdb,
        chain,
        rm_template_seq,
        propagate_to_copies,
        do_not_align,
        setupPath,
        parentPath,
        overwrite,
    ):
        self.sequence = sequence
        self.jobname = jobname
        self.copies = copies
        self.msa_method = msa_method
        self.custom_a3m_path = custom_a3m_path
        self.pair_mode = pair_mode
        self.cov = cov
        self.id = id
        self.qid = qid
        self.do_not_filter = do_not_filter
        self.template_mode = template_mode
        self.pdb = pdb
        self.chain = chain
        self.rm_template_seq = rm_template_seq
        self.propagate_to_copies = propagate_to_copies
        self.do_not_align = do_not_align
        # Both flags mirror rm_template_seq.
        self.rm_sidechain = rm_template_seq
        self.rm_sequence = rm_template_seq
        self.setupPath = setupPath
        self.parentPath = parentPath
        self.overwrite = overwrite

    # ------------------------------------------------------------------ helpers
    def _ensure_hhsuite_on_path(self):
        """Append the HH-suite bin/scripts directories to PATH if missing."""
        current = os.environ.get("PATH", "")
        entries = current.split(os.pathsep) if current else []
        missing = [
            directory
            for directory in (
                os.path.join(self.setupPath, "hhsuite/bin"),
                os.path.join(self.setupPath, "hhsuite/scripts"),
            )
            if directory not in entries
        ]
        if missing:
            # No whitespace around the separator: it would become part of the
            # directory name and the entry would never match.
            os.environ["PATH"] = os.pathsep.join(([current] if current else []) + missing)

    def _import_run_mmseqs2(self):
        """Import ``run_mmseqs2`` from ``colabfold_utils.py`` under setupPath."""
        if self.setupPath not in sys.path:
            sys.path.append(self.setupPath)
        from colabfold_utils import run_mmseqs2

        return run_mmseqs2

    # ------------------------------------------------------------------ steps
    def filter_options(self):
        """
        Normalise ``self.sequence`` and ``self.jobname`` in place.

        The sequence is upper-cased and reduced to letters plus ``: / ( )``;
        parenthesised segments are separated with ``:``, repeated separators
        are collapsed and leading/trailing separators removed. Every
        non-word character is removed from the job name.
        """
        self.sequence = self.sequence.upper()
        self.sequence = re.sub(r"[^A-Z:/()]", "", self.sequence.upper())
        self.sequence = re.sub(r"\(", ":(", self.sequence)
        self.sequence = re.sub(r"\)", "):", self.sequence)
        self.sequence = re.sub(r":+", ":", self.sequence)
        self.sequence = re.sub(r"/+", "/", self.sequence)
        self.sequence = re.sub(r"^[:/]+", "", self.sequence)
        self.sequence = re.sub(r"[:/]+$", "", self.sequence)
        self.jobname = re.sub(r"\W+", "", self.jobname)

    def process_sequence(self):
        """
        Split the sequence into unique chains and finalise the job name.

        Sets ``u_sequences``, ``u_cyclic``, ``u_sub_lengths``, ``u_lengths``,
        ``sub_seq`` and ``input_opts``. A short hash of the full sequence
        (unique chains concatenated, times ``copies``) is appended to the job
        name; if that folder already exists under ``parentPath`` the name is
        kept when ``overwrite`` is set, otherwise a numeric suffix is added.
        """
        sequences = self.sequence.split(":")
        self.u_sequences = predict.get_unique_sequences(sequences)
        # A leading "(" marks a chain as cyclic; "/" separates sub-segments.
        self.u_cyclic = [x.startswith("(") for x in self.u_sequences]
        self.u_sub_lengths = [[len(y) for y in x.split("/")] for x in self.u_sequences]
        self.u_sequences = [
            x.replace("(", "").replace(")", "").replace("/", "")
            for x in self.u_sequences
        ]
        if len(sequences) > len(self.u_sequences):
            print("WARNING: use copies to define homooligomers")
        self.u_lengths = [len(x) for x in self.u_sequences]
        # Kept on the instance: get_msa writes it out in single_sequence mode.
        self.sub_seq = "".join(self.u_sequences)
        seq = self.sub_seq * self.copies

        self.jobname = f"{self.jobname}_{predict.get_hash(seq)[:5]}"

        def check(folder):
            return os.path.exists(f"{self.parentPath}/{folder}")

        if check(self.jobname):
            # If the jobname already exists, print a warning, if overwrite is True do not change the jobname, if overwrite is False change the jobname
            if self.overwrite:
                print(
                    f"WARNING: {self.jobname} already exists. Using the same jobname. If you want to run the job with a different jobname, set overwrite to False. If you did not change other parameters your files will be overwritten."
                )
            else:
                n = 0
                while check(f"{self.jobname}_{n}"):
                    n += 1
                print(
                    f"WARNING: {self.jobname} already exists. Changing jobname to {self.jobname}_{n}"
                )
                self.jobname = f"{self.jobname}_{n}"

        print("jobname", self.jobname)
        print(f"length={self.u_lengths} copies={self.copies}")

        self.input_opts = {
            "sequence": self.u_sequences,
            "copies": self.copies,
            "msa_method": self.msa_method,
            "pair_mode": self.pair_mode,
            "do_not_filter": self.do_not_filter,
            "cov": self.cov,
            "id": self.id,
            "template_mode": self.template_mode,
            "propagate_to_copies": self.propagate_to_copies,
        }

    def get_msa(self):
        """
        Build ``self.msa`` and ``self.deletion_matrix`` according to
        ``msa_method``.

        * ``"mmseqs2"``: reuse ``in/msa.a3m`` if it already exists, otherwise
          query MMseqs2 through ``predict.get_msa``.
        * ``"single_sequence"``: write a one-entry ``in/msa.a3m``.
        * anything else (``custom_<format>``): read ``custom_a3m_path``,
          convert it to a3m if needed and filter it with ``hhfilter``.

        Also sets ``Ls``, ``job_path`` and ``msa_path`` and, when the MSA has
        more than one sequence, saves ``in/msa_feats.png``.

        Raises:
            ValueError: If a custom MSA is requested and ``custom_a3m_path``
                is not an existing file.
        """
        job_dir = f"{self.parentPath}/{self.jobname}"
        # Create "in" folder path variable
        input_path = f"{job_dir}/in"
        os.makedirs(input_path, exist_ok=True)
        msa_file = f"{input_path}/msa.a3m"

        # Every branch below leaves the unfiltered alignment at in/msa.a3m.
        self.job_path = job_dir
        self.msa_path = msa_file

        utils = ColabDesignUtils(self.setupPath)
        self.Ls = [len(x) for x in self.u_sequences]

        if self.msa_method == "mmseqs2":
            # If the msa.a3m file is already present in the input
            # folder then skip the search and tell the user
            if os.path.isfile(msa_file):
                print("msa.a3m file is already present, loading from file.")
                print(
                    "please check that this a3m file contains the sequencens you expect"
                )
                print(
                    "this behaviour was implmented for HPC Slurm cluster usage and may not be the best for most users"
                )
                self.msa, self.deletion_matrix = predict.parse_a3m(msa_file)
            else:
                run_mmseqs2 = self._import_run_mmseqs2()

                def run_mmseqs2_wrapper(*args, **kwargs):
                    kwargs["user_agent"] = "colabdesign/gamma"
                    return run_mmseqs2(*args, **kwargs)

                self.msa, self.deletion_matrix = predict.get_msa(
                    self.u_sequences,
                    input_path,
                    mode=self.pair_mode,
                    cov=self.cov,
                    id=self.id,
                    qid=self.qid,
                    max_msa=4096,
                    do_not_filter=self.do_not_filter,
                    mmseqs2_fn=run_mmseqs2_wrapper,
                    hhfilter_fn=utils.run_hhfilter,
                )
                print(msa_file)
        # Else if the MSA method is single sequence
        elif self.msa_method == "single_sequence":
            with open(msa_file, "w") as a3m:
                a3m.write(f">{self.jobname}\n{self.sub_seq}\n")
            self.msa, self.deletion_matrix = predict.parse_a3m(msa_file)
        # If the MSA method is custom_X
        else:
            msa_format = self.msa_method.split("_")[1]
            print(f"MSA mode: {self.msa_method}")
            # TODO : Add support for Google Colab!
            google_colab = False
            local_run = True
            # If google colab is used, the file is stored in the google colab environment
            if google_colab:
                print(f"upload {self.msa_method}")
                from google.colab import files

                msa_dict = files.upload()
                lines = []
                for k, v in msa_dict.items():
                    lines += v.decode().splitlines()
            # Here we handle the custom MSA case
            if local_run:
                print(f"Reading MSA from {self.custom_a3m_path}")
                # Check if the path is valid, in other words, if the file exists
                if not os.path.isfile(self.custom_a3m_path):
                    raise ValueError(
                        f"Invalid path: {self.custom_a3m_path}. The file does not exist."
                    )
                with open(self.custom_a3m_path, "r") as file:
                    lines = file.read().splitlines()
            # Drop NUL bytes, empty lines and "#" comment lines.
            input_lines = []
            for line in lines:
                line = line.replace("\x00", "")
                if len(line) > 0 and not line.startswith("#"):
                    input_lines.append(line)
            # The following is to avoid errors when running parallel jobs in the cluster
            # Only write the file if it does not exist
            custom_file = f"{input_path}/msa.{msa_format}"
            if not os.path.isfile(custom_file):
                with open(custom_file, "w") as msa:
                    msa.write("\n".join(input_lines))
            if msa_format != "a3m":
                # Use the copy of reformat.pl under setupPath so the call does
                # not depend on the current working directory.
                reformat_script = os.path.join(
                    self.setupPath, "hhsuite/scripts/reformat.pl"
                )
                os.system(
                    f"perl {reformat_script} {msa_format} a3m {custom_file} {msa_file}"
                )

            if self.do_not_filter:
                # Print that we are not filtering the MSA and the relevant parameters
                print(f"WARNING: not filtering MSA. Using 0 cov, 0 qid and 100 id")
                print(msa_file)
                filter_qid, filter_id, filter_cov = 0, 100, 0
            else:
                # Else we proceed to filter the MSA with the user's settings
                filter_qid, filter_id, filter_cov = self.qid, self.id, self.cov

            self._ensure_hhsuite_on_path()
            filtered_file = f"{input_path}/msa.filt.a3m"
            # Only run hhfilter if the filtered MSA is not already present in the in folder
            # to avoid unnecessarily running hhfilter when running multiple jobs in the cluster
            if not os.path.isfile(filtered_file):
                os.system(
                    f"hhfilter -qid {filter_qid} -id {filter_id} -cov {filter_cov} -i {msa_file} -o {filtered_file}"
                )
            self.msa, self.deletion_matrix = predict.parse_a3m(filtered_file)

        if len(self.msa) > 1:
            predict.plot_msa(self.msa, self.Ls)
            plt.savefig(
                f"{input_path}/msa_feats.png",
                dpi=200,
                bbox_inches="tight",
            )
            # If this is being run in a jupyter notebook, show the image
            if "ipykernel" in sys.modules:
                plt.show()

    def use_templates(self):
        """
        Build template feature batches into ``self.batches``.

        NOTE: calling this method replaces the ``use_templates`` attribute on
        the instance with a bool (True when ``template_mode`` is ``"mmseqs2"``
        or ``"custom"`` and templates are available), so it can only be
        called once per instance.

        In ``"mmseqs2"`` mode the template hits are read from
        ``in/msa/_env/pdb70.m8`` (at most four per query index); in
        ``"custom"`` mode ``self.pdb`` / ``self.chain`` are used. When no
        templates are used ``self.batches`` is ``[None]``.
        """
        self.use_templates = self.template_mode in [
            "mmseqs2",
            "custom",
        ]  # Here we define the template mode which is either mmseqs2 or custom
        input_dir = f"{self.parentPath}/{self.jobname}/in"

        if self.use_templates:
            print("aligning template")
            template_msa = f"{input_dir}/msa.a3m"
            if self.template_mode == "mmseqs2":
                run_mmseqs2 = self._import_run_mmseqs2()
                # Search from the job's "in" folder so the hit table ends up
                # where it is read back below (in/msa/_env/pdb70.m8).
                predict.get_msa(
                    self.u_sequences,
                    input_dir,
                    mode="unpaired",
                    mmseqs2_fn=lambda *x: run_mmseqs2(
                        *x, user_agent="colabdesign/gamma"
                    ),
                    do_not_filter=True,
                    do_not_return=True,
                    output_a3m=f"{input_dir}/msa_tmp.a3m",
                )
                template_msa = f"{input_dir}/msa_tmp.a3m"
                if not self.propagate_to_copies and self.copies > 1:
                    # Repeat every aligned sequence once per copy.
                    with open(template_msa, "r") as handle:
                        new_msa = [
                            line.rstrip()
                            for line in handle
                            if not line.startswith(">")
                        ]
                    with open(template_msa, "w") as handle:
                        for n, seq in enumerate(new_msa):
                            handle.write(f">{n}\n{seq*self.copies}\n")

                templates = {}
                print("ID\tpdb\tcid\tevalue")
                with open(f"{input_dir}/msa/_env/pdb70.m8", "r") as hits:
                    for line in hits:
                        p = line.rstrip().split()
                        M, target_id, qid, e_value = p[0], p[1], p[2], p[10]
                        M = int(M)
                        if M not in templates:
                            templates[M] = []
                        if len(templates[M]) < 4:
                            print(f"{int(M)}\t{target_id}\t{qid}\t{e_value}")
                            templates[M].append(target_id)
                if len(templates) == 0:
                    # Record the outcome on the instance so the template
                    # branch below is skipped instead of failing on
                    # undefined pdbs/chains.
                    self.use_templates = False
                    print("ERROR: no templates found...")
                else:
                    pdbs, chains = [], []
                    for M in sorted(templates):
                        for n, target_id in enumerate(templates[M]):
                            pdb_id, chain_id = target_id.split("_")
                            if len(pdbs) < n + 1:
                                pdbs.append([])
                                chains.append([])
                            pdbs[n].append(pdb_id)
                            chains[n].append(chain_id)
                    print(pdbs)
            else:
                pdbs, chains = [self.pdb], [self.chain]

        if self.use_templates:
            self.input_opts.update({"pdbs": pdbs, "chains": chains})
            # The alignment helpers are methods of ColabDesignUtils.
            utils = ColabDesignUtils(self.setupPath)
            align_fn = utils.run_do_not_align if self.do_not_align else utils.run_hhalign
            query_seq = "".join(self.u_sequences)
            self.batches = []
            for pdb, chain in zip(pdbs, chains):
                batch = predict.get_template_feats(
                    pdb,
                    chain,
                    query_seq=query_seq,
                    query_a3m=template_msa,
                    copies=self.copies,
                    propagate_to_copies=self.propagate_to_copies,
                    use_seq=not self.rm_sequence,
                    # NOTE(review): get_pdb is not defined in this class; it
                    # has to be provided on the instance (or by a subclass)
                    # before templates are used - confirm where it comes from.
                    get_pdb_fn=self.get_pdb,
                    align_fn=align_fn,
                )
                self.batches.append(batch)

            plt.figure(figsize=(3 * len(self.batches), 3))
            for n, batch in enumerate(self.batches):
                plt.subplot(1, len(self.batches), n + 1)
                plt.title(f"template features {n+1}")
                dgram = batch["dgram"].argmax(-1).astype(float)
                # Blank out positions that carry no distogram signal.
                dgram[batch["dgram"].sum(-1) == 0] = np.nan
                Ln = dgram.shape[0]
                plt.imshow(dgram, extent=(0, Ln, Ln, 0))
                predict.plot_ticks(self.Ls * self.copies)
            plt.savefig(
                f"{input_dir}/template_feats.png",
                dpi=200,
                bbox_inches="tight",
            )
            plt.show()
        else:
            self.batches = [None]

        print("GC", gc.collect())

class PrepModel:
    """Build and configure the ``mk_af_model`` predictor from prepared inputs.

    The methods are meant to be called in order: ``model_options``,
    ``initialize_model``, ``prep_inputs``, ``set_templates``, ``set_msa``,
    ``set_chainbreaks`` and ``set_cyclic_constraints``.

    Parameters
    ----------
    model_type : str
        "monomer (ptm)", "multimer (v3)" or "pseudo_multimer (v3)"; any other
        value selects multimer mode when there is more than one unique chain
        or more than one copy.
    rank_by : str
        Metric used as ``best_metric``; "auto" is resolved in ``model_options``.
    debug, use_initial_guess, num_msa, num_extra_msa, use_cluster_profile
        Forwarded to the model options / ``set_opt`` unchanged.
    u_lengths : list
        Lengths handed to ``af.prep_inputs`` (one entry per unique chain,
        judging by how ``copies`` multiplies them -- confirm against caller).
    copies : int
        Number of copies of the input.
    use_templates : bool
        Whether template features in ``batches`` are applied.
    batches : list
        Template batches; its length becomes ``num_templates``.
    msa, deletion_matrix
        Passed verbatim to ``af.set_msa``.
    u_sub_lengths : list of list of int
        Per-chain segment lengths; a chain break is inserted between
        consecutive segments of the same chain.
    u_cyclic : list of bool
        Per-chain flags selecting chains that receive the cyclic offset.
    setupPath : str
        Directory passed to ``mk_af_model`` as ``data_dir``.
    rm_sidechain, rm_sequence, rm_interchain : bool, optional
        Template options forwarded as ``rm_sc`` / ``rm_seq`` / ``rm_ic``.
        They default to ``False`` so existing positional callers keep working;
        previously these attributes were never set and ``set_templates``
        raised ``AttributeError`` whenever templates were enabled.
    """

    def __init__(
        self,
        model_type,
        rank_by,
        debug,
        use_initial_guess,
        num_msa,
        num_extra_msa,
        use_cluster_profile,
        u_lengths,
        copies,
        use_templates,
        batches,
        msa,
        deletion_matrix,
        u_sub_lengths,
        u_cyclic,
        setupPath,
        rm_sidechain=False,
        rm_sequence=False,
        rm_interchain=False,
    ):
        self.model_type = model_type
        self.rank_by = rank_by
        self.debug = debug
        self.use_initial_guess = use_initial_guess
        self.num_msa = num_msa
        self.num_extra_msa = num_extra_msa
        self.use_cluster_profile = use_cluster_profile
        self.u_lengths = u_lengths
        self.copies = copies
        self.use_templates = use_templates
        self.batches = batches
        self.msa = msa
        self.deletion_matrix = deletion_matrix
        self.u_sub_lengths = u_sub_lengths
        self.u_cyclic = u_cyclic
        self.setupPath = setupPath
        self.rm_sidechain = rm_sidechain
        self.rm_sequence = rm_sequence
        self.rm_interchain = rm_interchain

    def _is_complex(self):
        """Return True when there are several unique chains or several copies."""
        return len(self.u_lengths) > 1 or self.copies > 1

    def model_options(self):
        """Resolve ``model_type``/``rank_by`` and populate ``self.model_opts``.

        Side effect: ``self.rank_by`` is overwritten when it was "auto"
        ("multi" for complexes, "plddt" otherwise).
        """
        # (use_multimer, pseudo_multimer) for the explicitly named model types.
        explicit = {
            "monomer (ptm)": (False, False),
            "multimer (v3)": (True, False),
            "pseudo_multimer (v3)": (True, True),
        }
        # Anything else (e.g. "auto") falls back to the shape of the input.
        use_multimer, pseudo_multimer = explicit.get(self.model_type) or (
            self._is_complex(),
            False,
        )

        if self.rank_by == "auto":
            self.rank_by = "multi" if self._is_complex() else "plddt"

        self.model_opts = {
            "num_msa": self.num_msa,
            "num_extra_msa": self.num_extra_msa,
            "num_templates": len(self.batches),
            "use_cluster_profile": self.use_cluster_profile,
            "use_multimer": use_multimer,
            "pseudo_multimer": pseudo_multimer,
            "use_templates": self.use_templates,
            "use_batch_as_template": False,
            "use_dgram": True,
            "protocol": "hallucination",
            "best_metric": self.rank_by,
            "optimize_seq": False,
            "debug": self.debug,
            "clear_prev": False,
        }

    def initialize_model(self):
        """Create ``self.af``, reusing loaded parameters when possible.

        The model and the options it was built with are cached on the
        instance. (The previous ``"af" in dir()`` test only inspected the
        method's local names, so the reuse branch could never run.)
        """
        previous_opts = getattr(self, "_model_opts_", None)
        if hasattr(self, "af") and previous_opts is not None:
            if self.model_opts == previous_opts:
                # Same options as last time: the existing model is still valid.
                return
            same_architecture = (
                self.model_opts["use_multimer"] == self.af._args["use_multimer"]
                and self.model_opts["use_templates"]
                == self.af._args["use_templates"]
            )
            if same_architecture:
                old_params = dict(zip(self.af._model_names, self.af._model_params))
            else:
                print("loading alphafold params")
                old_params = {}
                clear_mem()
            self.af = mk_af_model(
                old_params=old_params,
                use_mlm=True,
                data_dir=f"{self.setupPath}",
                **self.model_opts,
            )
        else:
            print("loading alphafold params 1")
            self.af = mk_af_model(
                use_mlm=True, data_dir=f"{self.setupPath}", **self.model_opts
            )
        self._model_opts_ = predict.copy_dict(self.model_opts)

    def prep_inputs(self):
        """Prepare model inputs and choose which metrics get reported."""
        self.af.prep_inputs(self.u_lengths, copies=self.copies, seed=0)
        self.print_key = ["plddt", "ptm"]
        if len(self.af._lengths) > 1:
            self.print_key += ["i_ptm", "multi"]
        self.af.set_opt("con", cutoff=8.0)

    def set_templates(self):
        """Load every template batch into the model (no-op without templates)."""
        if not self.use_templates:
            return
        self.af.set_opt(use_initial_guess=self.use_initial_guess)
        for n, batch in enumerate(self.batches):
            self.af.set_template(batch=batch, n=n)
        self.af.set_opt(
            "template",
            rm_sc=self.rm_sidechain,
            rm_seq=self.rm_sequence,
            rm_ic=self.rm_interchain,
        )

    def set_msa(self):
        """Hand the MSA and its deletion matrix to the model."""
        self.af.set_msa(self.msa, self.deletion_matrix)

    def set_chainbreaks(self):
        """Add a +32 residue-index gap after every non-final chain segment."""
        offset = 0
        for segments in self.u_sub_lengths * self.copies:
            for segment_length in segments[:-1]:
                # Shift everything after this segment so it reads as a break.
                self.af._inputs["residue_index"][offset + segment_length :] += 32
                offset += segment_length
            offset += segments[-1]

    def set_cyclic_constraints(self):
        """Apply the cyclic offset to chains flagged in ``u_cyclic``."""
        i_cyclic = [n for n, c in enumerate(self.u_cyclic * self.copies) if c]
        if i_cyclic:
            add_cyclic_offset(self.af, i_cyclic)


class RunAlphaFold:
    """Run predictions for every seed/model/recycle and write the outputs.

    Files are written under ``{parentPath}/{jobname}``: ``log.txt`` plus
    ``out/pdbs``, ``out/figs`` and ``out/npz``. When MSA masking is enabled
    (``mask_msa`` with ``masking_mode == "list"``) every output filename gets
    a ``_mask_<cols>_id_<mask_identity>`` suffix; otherwise no suffix is used.

    Parameters
    ----------
    jobname, parentPath : str
        Together they locate the job directory.
    model : str
        "all", or the 1-based index (as a string) into ``af._model_names``.
    num_recycles : int
        Each model is run for at most ``num_recycles + 1`` passes.
    recycle_early_stop_tolerance : float
        Recycling stops once the RMSD to the previous pass drops below this.
    select_best_across_recycles : bool
        If True every recycle is ranked; otherwise only the last one per model.
    use_mlm, use_dropout, seed, num_seeds, show_images, use_initial_guess
        Run options; seeds ``seed .. seed + num_seeds - 1`` are used.
    af
        The prepared model object.
    copies, Ls
        ``Ls * copies`` is passed to the plotting helpers.
    print_key : list of str
        Keys of ``af.aux["log"]`` written to the log line.
    rank_by : str
        Key of ``af.aux["log"]`` used for ranking (higher is better).
    masking_mode, mask_msa, cols, mask_identity
        Control the filename suffix (see ``_mask_suffix``).
    mask_deletion_matrix, cols_range
        Stored but not used by this class.
    """

    def __init__(
        self,
        jobname,
        model,
        num_recycles,
        recycle_early_stop_tolerance,
        select_best_across_recycles,
        use_mlm,
        use_dropout,
        seed,
        num_seeds,
        show_images,
        use_initial_guess,
        af,
        copies,
        print_key,
        rank_by,
        Ls,
        parentPath,
        masking_mode,
        mask_msa,
        mask_deletion_matrix,
        cols,
        cols_range,
        mask_identity,
    ):
        self.jobname = jobname
        self.model = model
        self.num_recycles = num_recycles
        self.recycle_early_stop_tolerance = recycle_early_stop_tolerance
        self.select_best_across_recycles = select_best_across_recycles
        self.use_mlm = use_mlm
        self.use_dropout = use_dropout
        self.seed = seed
        self.num_seeds = num_seeds
        self.show_images = show_images
        self.use_initial_guess = use_initial_guess
        self.af = af
        self.copies = copies
        self.print_key = print_key
        self.rank_by = rank_by
        self.Ls = Ls
        self.parentPath = parentPath
        self.masking_mode = masking_mode
        self.mask_msa = mask_msa
        self.mask_deletion_matrix = mask_deletion_matrix
        self.cols = cols
        self.cols_range = cols_range
        self.mask_identity = mask_identity

    def _mask_suffix(self):
        """Return the filename suffix describing the MSA masking in effect.

        Returns ``""`` when ``mask_msa`` is false. For ``masking_mode ==
        "list"`` the columns are summarised as the single column, as
        ``first-last``, or as ``false`` when the list is empty.

        Raises
        ------
        NotImplementedError
            For ``masking_mode == "range"``.
        ValueError
            For any other masking mode while ``mask_msa`` is true.
        """
        if not self.mask_msa:
            return ""
        if self.masking_mode == "list":
            if not self.cols:
                cols_str = "false"
            elif len(self.cols) == 1:
                cols_str = str(self.cols[0])
            else:
                # Only the end points are kept so the filename stays short.
                cols_str = f"{self.cols[0]}-{self.cols[-1]}"
            return f"_mask_{cols_str}_id_{self.mask_identity}"
        if self.masking_mode == "range":
            raise NotImplementedError(
                "The masking_mode 'range' is not fully implemented yet."
            )
        raise ValueError(
            f"Unknown masking_mode {self.masking_mode!r}; expected 'list' or 'range'"
        )

    def _record_result(self, info, tag, print_str):
        """Append the current prediction to ``info`` and let ``af`` keep the best."""
        info.append([tag, print_str, self.af.aux["log"][self.rank_by]])
        self.af._save_results(
            save_best=True,
            best_metric=self.rank_by,
            metric_higher_better=True,
            verbose=False,
        )
        self.af._k += 1

    def run(self):
        """Predict, log and save per-recycle, per-model and best-overall results."""
        # Resolved up front: the suffix is needed for every output file, and an
        # unsupported masking setup should fail before any prediction is run.
        # (Previously the figure names used a variable that was only defined
        # on the masked path, so unmasked runs crashed with NameError.)
        suffix = self._mask_suffix()

        # decide which models to use
        if self.model == "all":
            models = self.af._model_names
        else:
            models = [self.af._model_names[int(self.model) - 1]]

        # set options
        self.af.set_opt("mlm", replace_fraction=0.15 if self.use_mlm else 0.0)

        pdb_path = f"{self.parentPath}/{self.jobname}/out"
        os.makedirs(pdb_path, exist_ok=True)
        for subdir in ("figs", "pdbs", "npz"):
            os.makedirs(f"{pdb_path}/{subdir}", exist_ok=True)

        # keep track of results: one [tag, log line, ranking score] per entry
        info = []
        self.af._tmp = {
            "traj": {"seq": [], "xyz": [], "plddt": [], "pae": []},
            "log": [],
            "best": {},
        }
        plot_lengths = self.Ls * self.copies

        print("running prediction")
        with open(f"{self.parentPath}/{self.jobname}/log.txt", "w") as handle:
            for seed in range(self.seed, self.seed + self.num_seeds):
                seed_tag = str(seed).zfill(3)
                self.af.set_seed(seed)
                for model in models:
                    # Every model starts recycling from scratch.
                    self.af._inputs.pop("prev", None)
                    prev_pos = None
                    for recycle in range(self.num_recycles + 1):
                        print_str = f"seed={seed_tag} model={model} recycle={recycle}"
                        self.af.predict(
                            dropout=self.use_dropout, models=[model], verbose=False
                        )

                        # feed this pass back in as the next recycle's input
                        self.af._inputs["prev"] = self.af.aux["prev"]

                        if len(self.af._lengths) > 1:
                            self.af.aux["log"]["multi"] = (
                                0.8 * self.af.aux["log"]["i_ptm"]
                                + 0.2 * self.af.aux["log"]["ptm"]
                            )

                        self.af.save_current_pdb(
                            f"{pdb_path}/pdbs/{self.jobname}_{model}_r{recycle}"
                            f"_seed_{seed_tag}{suffix}.pdb"
                        )

                        for k in self.print_key:
                            print_str += f" {k}={self.af.aux['log'][k]:.3f}"

                        # early stop check on column 1 of the atom positions
                        current_pos = self.af.aux["atom_positions"][:, 1]
                        stop_recycle = False
                        if recycle > 0:
                            rmsd_tol = _np_rmsd(prev_pos, current_pos, use_jax=False)
                            stop_recycle = rmsd_tol < self.recycle_early_stop_tolerance
                            print_str += f" rmsd_tol={rmsd_tol:.3f}"
                        prev_pos = current_pos
                        handle.write(f"{print_str}\n")

                        tag = f"{model}_r{recycle}_seed_{seed_tag}"
                        if self.select_best_across_recycles:
                            self._record_result(info, tag, print_str)

                        if stop_recycle:
                            break

                    # Without per-recycle selection only the final recycle of
                    # each model competes for "best".
                    if not self.select_best_across_recycles:
                        self._record_result(info, tag, print_str)

                    # save current results for each model (after n recycles)
                    fig_stem = (
                        f"{pdb_path}/figs/{self.jobname}_{model}_seed_{seed_tag}{suffix}"
                    )
                    ColabDesignUtils.plot_3D(
                        aux=self.af.aux,
                        Ls=plot_lengths,
                        file_name=f"{fig_stem}.pdf",
                        show=self.show_images,
                    )
                    predict.plot_confidence(
                        self.af.aux["plddt"] * 100,
                        self.af.aux["pae"],
                        plot_lengths,
                    )
                    plt.savefig(f"{fig_stem}.png", dpi=200, bbox_inches="tight")
                    plt.close()

        # save best results
        rank = np.argsort([x[2] for x in info])[::-1][:5]
        best_tag, best_metrics = info[rank[0]][0], info[rank[0]][1]
        print(f"best_tag={best_tag} {best_metrics}")

        aux_best = self.af._tmp["best"]["aux"]
        self.af.save_pdb(f"{pdb_path}/pdbs/{self.jobname}_best_{best_tag}{suffix}.pdb")
        np.savez_compressed(
            f"{pdb_path}/npz/{self.jobname}_best_{best_tag}{suffix}.npz",
            plddt=aux_best["plddt"].astype(np.float16),
            pae=aux_best["pae"].astype(np.float16),
            tag=np.array(best_tag),
            metrics=np.array(best_metrics),
        )
        # Everything accumulated in the trajectory, alongside all tags/metrics.
        np.savez_compressed(
            f"{pdb_path}/npz/{self.jobname}_all_{best_tag}{suffix}.npz",
            plddt=np.array(self.af._tmp["traj"]["plddt"], dtype=np.float16),
            pae=np.array(self.af._tmp["traj"]["pae"], dtype=np.float16),
            tag=np.array([x[0] for x in info]),
            metrics=np.array([x[1] for x in info]),
        )

        # Only render the "best" figures when running inside a Jupyter kernel.
        if "ipykernel" in sys.modules:
            ColabDesignUtils.plot_3D(
                aux_best, plot_lengths, f"{pdb_path}/figs/best.pdf", show=False
            )
            predict.plot_confidence(
                aux_best["plddt"] * 100, aux_best["pae"], plot_lengths
            )
            plt.savefig(f"{pdb_path}/figs/best.png", dpi=200, bbox_inches="tight")
            plt.close()

        # garbage collection
        print("GC", gc.collect())


class DefaultPipeline:
    """
    Validate a flat parameter set and run the default prediction pipeline.

    Parameters are supplied either as a dict (``params``) or loaded from a
    YAML file (``yaml_file``, which takes precedence). Every name listed below
    is required; validated values are readable as attributes
    (``pipeline.jobname``).

    Parameters
    ----------
    SetupAlphaFoldColabDesign : SetupAlphaFoldColabDesign

        unified_memory : bool
            If True, use unified memory.

        parentPath : str
            The path to the parent directory.

        setupPath : str
            Passed to PrepInputs and, via PrepModel, used as the model
            ``data_dir``.

    PrepInputs : PrepInputs

        sequence : str
            The sequence to be processed.

        jobname : str
            The name of the job.

        copies : int
            The number of copies (1-12).

        ## MSA retrieval options

        ### MMseqs2 options
        msa_method : str
            The method used for multiple sequence alignment. Options are "mmseqs2","single_sequence", "custom_fas", "custom_a3m", "custom_sto".

        custom_a3m_path : str
            Forwarded to PrepInputs (presumably the alignment file for the
            custom MSA methods -- confirm against PrepInputs).

        pair_mode : str
            The pairing mode. Options are "unpaired_paired","paired","unpaired".


        ### HHfilter options
            https://sarata.com/manpages/hhfilter.1.html

        cov : int
            [0,100]  minimum coverage with query (%) (def=0)

        id : int
            [0,100]  maximum pairwise sequence identity (%) (def=90)

        qid : int
            [0,100]  minimum sequence identity with query (%) (def=0)

        do_not_filter : bool
            If True, do not filter.

        ## Template options

        template_mode : str
            The template mode. Options are "none", "mmseqs2", "custom".

        pdb : str
            The pdb.

        chain : str
            The chain.

        rm_template_seq : bool
            If True, remove template sequence.

        propagate_to_copies : bool
            If True, propagate to copies.

        do_not_align : bool
            If True, do not align.

    PrepModel : PrepModel
    ## AF2 Model preparation options

        model_type : str
            The model type. Options are "monomer (ptm)", "pseudo_multimer (v3)", "multimer (v3)", "auto".

        rank_by : str
            The rank by. Options are "auto", "plddt", "ptm".

        debug : bool
            If True, debug.

        use_initial_guess : bool
            If True, use initial guess.
    ## AF2 MSA options

        num_msa : int
            The number of msa. Options are 1, 2, 4, 8, 16, 32, 64, 128, 256, 512.

        num_extra_msa : int
            The number of extra msa. Options are 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096.

        use_cluster_profile : bool
            If True, use cluster profile.

    RunAlphaFold : RunAlphaFold
    ## AF2 Model run options
        model : str
            The model. Options are "1", "2", "3", "4", "5", "all".

        num_recycles : int
            The number of recycles.

        recycle_early_stop_tolerance : float
            The recycle early stop tolerance. Options are 0.0, 0.5, 1.0
            (integers 0 and 1 are accepted as well).

        select_best_across_recycles : bool
            If True, select best across recycles.

    ## AF2 stochastic options
        use_mlm : bool
            If True, use mlm.

        use_dropout : bool
            If True, use dropout.

        seed : int
            The seed.

        num_seeds : int
            The number of seeds.
    ## Plotting options
        show_images : bool
            If True, show images.
    ## Masking options
        masking_mode : str
            How masked columns are specified ("list" or "range").

        mask_msa : bool
            If True, MSA masking is in effect.

        mask_deletion_matrix : bool
            Forwarded to RunAlphaFold.

        cols : list
            List of integer column indices.

        cols_range : list
            List of tuples (ranges of columns).

        mask_identity : str
            Target identity of the masked columns ("X" is the usual choice;
            "-" is the alternative under consideration).
    """

    def __init__(self, params=None, yaml_file=None):
        if yaml_file is not None:
            with open(yaml_file, "r") as f:
                params = yaml.safe_load(f)

        self.params = params or {}
        self.param_types = {
            "unified_memory": bool,
            "parentPath": str,
            "setupPath": str,
            "sequence": str,
            "jobname": str,
            "copies": int,
            "msa_method": str,
            "custom_a3m_path": str,
            "pair_mode": str,
            "cov": int,
            "id": int,
            "qid": int,
            "do_not_filter": bool,
            "template_mode": str,
            "pdb": str,
            "chain": str,
            "rm_template_seq": bool,
            "propagate_to_copies": bool,
            "do_not_align": bool,
            "model_type": str,
            "rank_by": str,
            "debug": bool,
            "use_initial_guess": bool,
            "num_msa": int,
            "num_extra_msa": int,
            "use_cluster_profile": bool,
            "model": str,
            "num_recycles": int,
            "recycle_early_stop_tolerance": float,
            "select_best_across_recycles": bool,
            "use_mlm": bool,
            "use_dropout": bool,
            "seed": int,
            "num_seeds": int,
            "show_images": bool,
            # Masking parameters
            "masking_mode": str,
            "mask_msa": bool,
            "mask_deletion_matrix": bool,
            # List of integers
            "cols": list,
            # List of tuples(ranges of columns)
            "cols_range": list,
            # Target identity of the masked columns
            # X is used by default
            # We should consider comparing between - and X
            "mask_identity": str,
        }
        # Every typed parameter is required; deriving the list keeps the two
        # from drifting apart (same names, same order as before).
        self.required_parameters = list(self.param_types)

        self.param_ranges = {
            "copies": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
            "msa_method": [
                "mmseqs2",
                "single_sequence",
                "custom_fas",
                "custom_a3m",
                "custom_sto",
            ],
            "pair_mode": ["unpaired_paired", "paired", "unpaired"],
            "cov": [0, 25, 50, 75, 90, 99],
            "id": [90, 100],
            "qid": [0, 10, 15, 20, 30],
            "template_mode": ["none", "mmseqs2", "custom"],
            "model_type": [
                "monomer (ptm)",
                "pseudo_multimer (v3)",
                "multimer (v3)",
                "auto",
            ],
            "rank_by": ["auto", "plddt", "ptm"],
            "num_msa": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
            "num_extra_msa": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
            "model": ["1", "2", "3", "4", "5", "all"],
            "recycle_early_stop_tolerance": [0.0, 0.5, 1.0],
            "color": ["pLDDT", "chain", "rainbow"],
        }
        self._validate_parameters()

    @staticmethod
    def _matches_type(value, expected):
        """Return True if ``value`` is acceptable for the declared type.

        ``bool`` is a subclass of ``int`` in Python, so it is only accepted
        where ``bool`` is declared. Integers are accepted for ``float``
        parameters because YAML loads ``0`` / ``1`` as ``int``.
        """
        if expected is bool:
            return isinstance(value, bool)
        if isinstance(value, bool):
            return False
        if expected is float:
            return isinstance(value, (int, float))
        return isinstance(value, expected)

    def _validate_parameters(self):
        """Check presence, type and allowed values of every required parameter.

        Raises
        ------
        ValueError
            On the first missing, mistyped or out-of-range parameter.
        """
        for param in self.required_parameters:
            if param not in self.params:
                raise ValueError(f"{param} is a required parameter")

            value = self.params[param]
            expected = self.param_types[param]
            if not self._matches_type(value, expected):
                raise ValueError(f"Parameter {param} should be of type {expected}")

            allowed = self.param_ranges.get(param)
            if allowed is not None and value not in allowed:
                raise ValueError(f"Parameter {param} should be one of {allowed}")

    def __getattr__(self, attr):
        """Expose validated parameters as attributes.

        ``params`` is read from the instance ``__dict__`` directly: going
        through ``self.params`` would re-enter ``__getattr__`` and recurse
        without bound whenever ``params`` has not been set yet (bare
        ``__new__`` instances, copy/unpickle).
        """
        params = self.__dict__.get("params")
        if params is not None and attr in params:
            return params[attr]
        raise AttributeError(f"Attribute {attr} not found")

    def _save_config(self, path=None):
        """Return the validated configuration, optionally writing it as YAML.

        Parameters
        ----------
        path : str, optional
            When given, the configuration is dumped to this file with
            ``yaml.safe_dump``. Without it nothing is written.

        Returns
        -------
        dict
            A copy of every required parameter (masking options included).
        """
        config = {name: self.params[name] for name in self.required_parameters}
        if path is not None:
            import yaml

            with open(path, "w") as handle:
                yaml.safe_dump(config, handle, sort_keys=False)
        return config

    def run(self):
        """Run setup, input preparation, model preparation and prediction."""
        predictor = SetupAlphaFoldColabDesign(self.unified_memory, self.parentPath)
        predictor.setup()
        prep_inputs = PrepInputs(
            self.sequence,
            self.jobname,
            self.copies,
            self.msa_method,
            self.custom_a3m_path,
            self.pair_mode,
            self.cov,
            self.id,
            self.qid,
            self.do_not_filter,
            self.template_mode,
            self.pdb,
            self.chain,
            self.rm_template_seq,
            self.propagate_to_copies,
            self.do_not_align,
            self.setupPath,
            self.parentPath,
        )
        prep_inputs.filter_options()
        prep_inputs.process_sequence()
        prep_inputs.get_msa()
        prep_inputs.use_templates()
        # No path is passed, so the configuration is only assembled here.
        self._save_config()
        prep_model = PrepModel(
            self.model_type,
            self.rank_by,
            self.debug,
            self.use_initial_guess,
            self.num_msa,
            self.num_extra_msa,
            self.use_cluster_profile,
            prep_inputs.u_lengths,
            self.copies,
            # NOTE(review): ``use_templates`` is also called as a method just
            # above; confirm this attribute really is the boolean flag.
            prep_inputs.use_templates,
            prep_inputs.batches,
            prep_inputs.msa,
            prep_inputs.deletion_matrix,
            prep_inputs.u_sub_lengths,
            prep_inputs.u_cyclic,
            self.setupPath,
        )
        prep_model.model_options()
        prep_model.initialize_model()
        prep_model.prep_inputs()
        prep_model.set_templates()
        prep_model.set_msa()
        prep_model.set_chainbreaks()
        prep_model.set_cyclic_constraints()

        run_alphafold = RunAlphaFold(
            prep_inputs.jobname,
            self.model,
            self.num_recycles,
            self.recycle_early_stop_tolerance,
            self.select_best_across_recycles,
            self.use_mlm,
            self.use_dropout,
            self.seed,
            self.num_seeds,
            self.show_images,
            self.use_initial_guess,
            prep_model.af,
            copies=prep_inputs.input_opts["copies"],
            print_key=prep_model.print_key,
            rank_by=prep_model.rank_by,
            Ls=prep_inputs.Ls,
            parentPath=self.parentPath,
            # RunAlphaFold requires the masking options; omitting them made
            # this constructor call fail with TypeError.
            masking_mode=self.masking_mode,
            mask_msa=self.mask_msa,
            mask_deletion_matrix=self.mask_deletion_matrix,
            cols=self.cols,
            cols_range=self.cols_range,
            mask_identity=self.mask_identity,
        )
        run_alphafold.run()

In [None]:
import os
import requests
import random
import string
import yaml
import subprocess

import shutil
import glob
import sys

from google.colab import files


def is_colab():
    """Return True when the ``COLAB_GPU`` environment variable is defined."""
    return os.environ.get("COLAB_GPU") is not None


# Working directories: rooted at /content when is_colab() is true, otherwise
# at the current user's home directory.
BASE_DIR = os.path.expanduser("~") if not is_colab() else "/content"
INPUT_DIR, OUTPUT_DIR = (
    os.path.join(BASE_DIR, leaf) for leaf in ("input", "output")
)

# Create both directories up front; already-existing ones are left alone.
os.makedirs(INPUT_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)


def handle_small_molecule_input(
    small_molecule_chain_ID, small_molecule_chain_path, small_molecule_chain_type
):
    """Return the small-molecule file path, asking for an upload if it is "".

    A non-empty path is returned untouched. For an empty string the
    ``files.upload()`` dialog is opened and the name of the last uploaded
    file is returned (the path stays "" if nothing was uploaded).
    ``small_molecule_chain_type`` is accepted but not used here.
    """
    if small_molecule_chain_path != "":
        return small_molecule_chain_path

    print(f"Please upload the {small_molecule_chain_ID} small molecule file:")
    # Upload dialog from google.colab: maps each uploaded filename to its bytes.
    received = files.upload()
    for filename, payload in received.items():
        small_molecule_chain_path = filename
        print(f'User uploaded file "{filename}" with length {len(payload)} bytes')
    return small_molecule_chain_path


def handle_protein_input(protein_chain_ID, protein_chain_path, protein_chain_type):
    """
    Return the protein file path, prompting for an upload when it is empty.

    A non-empty ``protein_chain_path`` is returned unchanged. Otherwise
    ``files.upload()`` is called and the name of the last uploaded file is
    returned (the empty string if nothing was uploaded).
    ``protein_chain_type`` is accepted but not used.
    """
    if protein_chain_path != "":
        return protein_chain_path

    print(f"Please upload the {protein_chain_ID} protein file:")
    # Relies on the Colab upload dialog; the last uploaded name is kept
    uploaded = files.upload()
    for fn, payload in uploaded.items():
        protein_chain_path = fn
        print(f'User uploaded file "{fn}" with length {len(payload)} bytes')
    return protein_chain_path


def handle_msa_input(msa_path):
    """
    Return the MSA file path, prompting for an upload when it is empty.

    A non-empty ``msa_path`` is returned unchanged. Otherwise ``files.upload()``
    is called and the name of the last uploaded file is returned (the empty
    string if nothing was uploaded).
    """
    if msa_path != "":
        return msa_path

    print("Please upload the MSA file:")
    # Relies on the Colab upload dialog; the last uploaded name is kept
    uploaded = files.upload()
    for fn, payload in uploaded.items():
        msa_path = fn
        print(f'User uploaded file "{fn}" with length {len(payload)} bytes')
    return msa_path


def run_rf_all_atom(
    config,
    output_subfolder=None,
    output_prefix=None,
    show_last_n_lines=5,
    save_stdout=True,
    overwrite=None,
    msa_path=None,
):
    """
    Run RoseTTAFold All-Atom inference from a configuration dictionary.

    The dictionary is dumped to ``config.yaml`` inside the output directory and
    ``rf2aa/run_inference.py`` is launched on it as a subprocess. The combined
    stdout/stderr of that subprocess is streamed back line by line.

    Args:
        config: Configuration dictionary, written with ``yaml.dump``.
        output_subfolder: Path joined onto ``OUTPUT_DIR`` to form the base
            output directory; ``None`` or "" selects ``OUTPUT_DIR`` itself.
        output_prefix: Name stem of the numbered run directory; only used when
            ``overwrite`` is truthy.
        show_last_n_lines: After every output line, print this many of the most
            recent lines; values <= 0 disable printing.
        save_stdout: When true, copy the output to ``run_inference.log`` in the
            output directory.
        overwrite: When truthy, a new ``<output_prefix>_<n>`` directory (first
            unused ``n``) is created under the base directory; otherwise the
            base directory itself is used.
            NOTE(review): despite the name, a truthy value avoids overwriting
            earlier results - confirm this is the intended meaning.
        msa_path: Value for the ``+protein_inputs.A.msa_file`` override; the
            override is omitted when this is ``None``.

    Returns:
        The directory that holds ``config.yaml`` and the optional log file.
    """
    from collections import deque  # local import: only needed by this function

    # Generate a base output directory name without duplicating parts of the path
    if output_subfolder:
        base_output_path = os.path.join(OUTPUT_DIR, output_subfolder)
    else:
        base_output_path = OUTPUT_DIR
    print(f"Output base directory: {base_output_path}")

    if overwrite:
        # Pick the first numbered directory that does not exist yet
        counter = 0
        final_output_path = f"{base_output_path}/{output_prefix}_{counter}"
        while os.path.exists(final_output_path):
            counter += 1
            final_output_path = f"{base_output_path}/{output_prefix}_{counter}"
    else:
        final_output_path = base_output_path
    print(f"Final output path: {final_output_path}")

    # Ensure the final output directory exists
    os.makedirs(final_output_path, exist_ok=True)

    # Write the configuration to a YAML file inside the output directory
    config_filename = "config.yaml"
    config_file_path = os.path.join(final_output_path, config_filename)
    with open(config_file_path, "w") as file:
        yaml.dump(config, file)

    # The config is referenced by name (without the '.yaml' extension) and directory
    cmd = [
        "python",
        f"{BASE_DIR}/RoseTTAFold-All-Atom/rf2aa/run_inference.py",
        f"--config-name={os.path.splitext(config_filename)[0]}",
        f"--config-dir={final_output_path}",
    ]
    if msa_path is not None:
        # Without this guard the literal string "None" would be sent as the MSA file
        cmd.append(f"+protein_inputs.A.msa_file={msa_path}")
    cmd.append(f"output_path={final_output_path}")

    # Print the command to the console
    print(f"Running command: {' '.join(cmd)}")

    # Only the lines that may still be displayed are kept in memory
    recent_lines = deque(maxlen=max(show_last_n_lines, 0))

    log_file = None
    try:
        if save_stdout:
            log_file = open(os.path.join(final_output_path, "run_inference.log"), "w")

        # The context manager closes the pipe and waits for the process to exit
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
        ) as process:
            for output in process.stdout:
                recent_lines.append(output.strip())
                if log_file is not None:
                    log_file.write(output)
                if show_last_n_lines > 0:
                    print("\n".join(recent_lines))
    finally:
        # Close the log even when streaming is interrupted
        if log_file is not None:
            log_file.close()

    if process.returncode != 0:
        # Report failures instead of silently returning as if the run succeeded
        print(
            f"Warning: inference exited with return code {process.returncode}; "
            "check the output above."
        )

    return final_output_path


# @title ### RF all atom

# Interface for specifying PDB input
# `msa_input_type` selects the MSA branch further down; only the "mmseqs2"
# branch builds `prep_inputs`, which the config and the run call rely on.
msa_input_type = "mmseqs2"  # @param ["mmseqs2", "custom_a3m"]
jobname = "7u7w"  # @param {type:"string"}
custom_path = ""  # @param {type:"string"}

# Protein sequence handed to PrepInputs through yaml_dict["sequence"]
sequence = "TATGDEWWAKCKQVDVLDSEMSYYDSDPGKHKNTVIFLHGNPTSSYLWRNVIPHVEPLARCLAPDLIGMGKSGKLPNHSYRFVDHYRYLSAWFDSVNLPEKVTIVCHDWGSGLGFHWCNEHRDRVKGIVHMESVVDVIESWDEWPDIEEDIALIKSEAGEEMVLKKNFFIERLLPSSIMRKLSEEEMDAYREPFVEPGESRRPTLTWPREIPIKGDGPEDVIEIVKSYNKWLSTSKDIPKLFINADPGFFSNAIKKVTKNWPNQKTVTVKGLHFLQEDSPEEIGEAIADFLNELT"  # @param {type:"string"}
# Chain IDs become the keys of config["protein_inputs"] / config["sm_inputs"];
# an empty second ID means that no second chain is added.
protein_chain_1_ID = "A"  # @param ["A", "B", "C", "D", "E", "F"]
protein_chain_2_ID = ""  # @param ["","A", "B", "C", "D", "E", "F"]
# A small molecule with an ID but an empty path triggers an upload prompt below.
small_molecule_chain_1_ID = "B"  # @param {type:"string"}
small_molecule_chain_1_path = "/content/RoseTTAFold-All-Atom/examples/small_molecule/NSW_ideal.sdf"  # @param {type:"string"}
small_molecule_chain_1_type = "sdf"  # @param ["sdf"]

small_molecule_chain_2_ID = ""  # @param ["","A", "B", "C", "D", "E", "F"]
small_molecule_chain_2_path = ""  # @param {type:"string"}
small_molecule_chain_2_type = "sdf"  # @param ["sdf"]


# Build the MSA for the query sequence. Only the "mmseqs2" mode is wired up;
# it creates `yaml_dict` and `prep_inputs`, both of which are read further down.
if msa_input_type == "mmseqs2":
    yaml_dict = {
        "unified_memory": False,
        "parentPath": "/content/output",
        "setupPath": "/content/",
        "sequence": sequence,  # Here
        "jobname": jobname,  # Here
        "copies": 1,
        "msa_method": msa_input_type,
        "custom_a3m_path": "",
        "pair_mode": "unpaired_paired",
        "cov": 75,
        "id": 90,
        "qid": 0,
        "do_not_filter": False,
        "template_mode": "none",
        "pdb": "",
        "chain": "A",
        "rm_template_seq": False,
        "propagate_to_copies": True,
        "do_not_align": False,
        "model_type": "monomer (ptm)",
        "rank_by": "auto",
        "debug": True,
        "use_initial_guess": False,
        "num_msa": 512,
        "num_extra_msa": 1024,
        "use_cluster_profile": True,
        "model": "all",
        "num_recycles": 1,  # Here
        "recycle_early_stop_tolerance": 0.0,
        "select_best_across_recycles": False,
        "use_mlm": False,
        "use_dropout": False,
        "seed": 0,
        "num_seeds": 1,  # Here
        "show_images": False,
        "overwrite": True,
    }

    prep_inputs = PrepInputs(
        yaml_dict["sequence"],
        yaml_dict["jobname"],
        yaml_dict["copies"],
        yaml_dict["msa_method"],
        yaml_dict["custom_a3m_path"],
        yaml_dict["pair_mode"],
        yaml_dict["cov"],
        yaml_dict["id"],
        yaml_dict["qid"],
        yaml_dict["do_not_filter"],
        yaml_dict["template_mode"],
        yaml_dict["pdb"],
        yaml_dict["chain"],
        yaml_dict["rm_template_seq"],
        yaml_dict["propagate_to_copies"],
        yaml_dict["do_not_align"],
        yaml_dict["setupPath"],
        yaml_dict["parentPath"],
        yaml_dict["overwrite"],
    )
    prep_inputs.filter_options()
    prep_inputs.process_sequence()
    prep_inputs.get_msa()

elif msa_input_type in ("upload", "custom_a3m"):
    # These modes never produced `prep_inputs`/`yaml_dict`, so the cell either
    # crashed later or silently reused values left over from an earlier run.
    # Fail right away, as the original placeholder comments intended.
    raise NotImplementedError(
        f'MSA input type "{msa_input_type}" is not implemented yet; use "mmseqs2".'
    )
else:
    raise ValueError(f"Invalid MSA input type: {msa_input_type!r}")

# Validate all chain IDs first so that no upload dialog is shown for a setup
# that is going to be rejected anyway.

# Protein chain 1 is mandatory
if protein_chain_1_ID == "":
    raise ValueError("Please specify a protein chain 1 ID")

# A second protein chain, when given, must not reuse the ID of the first one
if protein_chain_2_ID != "" and protein_chain_1_ID == protein_chain_2_ID:
    raise ValueError("Please specify a different protein chain 2 ID")

# Same rule for small molecules: both IDs are used as keys of the same
# config section, so a duplicate would make one entry replace the other.
if (
    small_molecule_chain_2_ID != ""
    and small_molecule_chain_1_ID == small_molecule_chain_2_ID
):
    raise ValueError("Please specify a different small molecule chain 2 ID")

# If the user specified a small molecule chain 1, but not the path, prompt the user to upload the file
if small_molecule_chain_1_ID != "" and small_molecule_chain_1_path == "":
    small_molecule_chain_1_path = handle_small_molecule_input(
        small_molecule_chain_1_ID,
        small_molecule_chain_1_path,
        small_molecule_chain_1_type,
    )

# If the user specified a small molecule chain 2, but not the path, prompt the user to upload the file
if small_molecule_chain_2_ID != "" and small_molecule_chain_2_path == "":
    small_molecule_chain_2_path = handle_small_molecule_input(
        small_molecule_chain_2_ID,
        small_molecule_chain_2_path,
        small_molecule_chain_2_type,
    )


# Define the configuration dictionary based on the user inputs
config = {
    "protein_inputs": {
        f"{protein_chain_1_ID}": {
            # NOTE(review): this FASTA path is fixed to the 7u7w example and does
            # not follow `sequence`/`jobname` - confirm that this is intended.
            "fasta_file": "/content/RoseTTAFold-All-Atom/examples/protein/7u7w_A.fasta"
        }
    },
    "job_name": f"{prep_inputs.jobname}",
    "checkpoint_path": "/content/params/RFAA_paper_weights.pt",
    "defaults": ["base"],
}

# A second protein chain gets its own entry in the protein_inputs section
if protein_chain_2_ID != "":
    config["protein_inputs"][f"{protein_chain_2_ID}"] = {
        "fasta_file": f"/content/RoseTTAFold-All-Atom/examples/protein/{jobname}_{protein_chain_2_ID}.fasta"
    }

# Small molecules are added to one shared sm_inputs section. setdefault keeps
# an existing section, so chain 2 is added next to chain 1 instead of
# replacing it.
if small_molecule_chain_1_ID != "":
    config.setdefault("sm_inputs", {})[f"{small_molecule_chain_1_ID}"] = {
        "input": small_molecule_chain_1_path,
        "input_type": small_molecule_chain_1_type,
    }

if small_molecule_chain_2_ID != "":
    config.setdefault("sm_inputs", {})[f"{small_molecule_chain_2_ID}"] = {
        "input": small_molecule_chain_2_path,
        "input_type": small_molecule_chain_2_type,
    }

# Display/logging options forwarded to run_rf_all_atom
# Only the newest stdout line is echoed to keep the notebook output short
# (values above 1 are reported as not working at the moment).
show_last_n_lines = 1
# Keep a copy of stdout as a log file in the output folder
save_stdout = True

# Launch the run; the returned directory is stored for later cells
final_output_path = run_rf_all_atom(
    config,
    output_prefix="out",
    show_last_n_lines=show_last_n_lines,
    save_stdout=save_stdout,
    output_subfolder=prep_inputs.job_path,
    overwrite=yaml_dict["overwrite"],
    msa_path=prep_inputs.msa_path,
)


In [None]:
import os
import requests
import random
import string
import yaml
import subprocess

import shutil
import glob


# Colab runtime detection helper
def is_colab():
    """Return True when the ``COLAB_GPU`` environment variable is set."""
    return os.getenv("COLAB_GPU") is not None

# Working directories: rooted at /content on Colab, at the user's home otherwise
BASE_DIR = os.path.expanduser("~") if not is_colab() else "/content"
INPUT_DIR, OUTPUT_DIR = (
    os.path.join(BASE_DIR, leaf) for leaf in ("input", "output")
)

# Create both directories up front; existing ones are left untouched
os.makedirs(INPUT_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

def download_pdb(pdb_code, output_dir=INPUT_DIR):
    """
    Download a PDB file given a PDB code.

    Args:
        pdb_code: Entry code inserted into the ``files.rcsb.org`` download URL.
        output_dir: Directory that receives ``<pdb_code>.pdb``; it is created
            if it does not exist yet.

    Returns:
        Path of the downloaded file.

    Raises:
        ValueError: If ``pdb_code`` is empty or the server does not answer
            with status 200.
    """
    if not pdb_code:
        # Otherwise a request for "None.pdb" (or ".pdb") would be sent
        raise ValueError("A PDB code is required to download a PDB file")

    url = f"https://files.rcsb.org/download/{pdb_code}.pdb"
    # A timeout keeps the cell from hanging forever on a stalled connection
    response = requests.get(url, timeout=60)
    if response.status_code != 200:
        raise ValueError(f"Failed to download PDB file for {pdb_code}")

    os.makedirs(output_dir, exist_ok=True)
    pdb_path = os.path.join(output_dir, f"{pdb_code}.pdb")
    # Store the payload byte-for-byte instead of decoding and re-encoding it
    with open(pdb_path, "wb") as file:
        file.write(response.content)
    return pdb_path

def handle_pdb_input(pdb_input_type, pdb_code=None, custom_path=None, output_dir=INPUT_DIR):
    """
    Resolve the input PDB file from an upload, a PDB code, or an existing path.

    Args:
        pdb_input_type: One of "upload", "pdb_code", "custom_path" or
            "manual_path" (the name the notebook form uses for a custom path).
        pdb_code: Code passed to ``download_pdb`` for the "pdb_code" type.
        custom_path: Path returned as-is for the custom path types.
        output_dir: Directory that receives uploaded or downloaded files.

    Returns:
        Path of the PDB file.

    Raises:
        EnvironmentError: If an upload is requested outside Google Colab.
        ValueError: If the input type is unknown, nothing was uploaded, or the
            custom path is missing.
    """
    if pdb_input_type == "upload":
        if not is_colab():
            raise EnvironmentError("File upload is only supported on Google Colab.")
        from google.colab import files

        uploaded = files.upload()
        if not uploaded:
            raise ValueError("No PDB file was uploaded")
        # Only the first uploaded file is used
        pdb_filename = next(iter(uploaded))
        pdb_path = os.path.join(output_dir, pdb_filename)
        with open(pdb_path, 'wb') as file:
            file.write(uploaded[pdb_filename])
        return pdb_path

    if pdb_input_type == "pdb_code":
        return download_pdb(pdb_code, output_dir)

    if pdb_input_type in ("custom_path", "manual_path"):
        if not custom_path:
            # Previously this fell through and returned None (or "") silently
            raise ValueError("A custom PDB path must be provided for this input type")
        return custom_path

    raise ValueError("Invalid PDB input type")

def run_rfdiffusion_all_atom(config, output_subfolder="ligand_protein_motif", output_prefix="sample", show_last_n_lines=5, save_stdout=True):
    """
    Run RFdiffusion All-Atom inference from a configuration dictionary.

    A new ``<output_prefix>_<n>`` directory (first unused ``n``) is created
    under ``OUTPUT_DIR/<output_subfolder>``. The configuration is dumped there
    as ``config.yaml`` and ``./rf_diffusion_all_atom/run_inference.py`` is run
    on it as a subprocess. The combined stdout/stderr of that subprocess is
    streamed back line by line.

    Args:
        config: Configuration dictionary. It is modified in place:
            ``config["inference"]["output_prefix"]`` is set to the prefix
            inside the new run directory.
        output_subfolder: Directory below ``OUTPUT_DIR`` holding the runs.
        output_prefix: Name stem of the run directory and of the outputs.
        show_last_n_lines: After every output line, print this many of the most
            recent lines; values <= 0 disable printing.
        save_stdout: When true, copy the output to ``run_inference.log`` in the
            run directory.

    Returns:
        The run directory.
    """
    from collections import deque  # local import: only needed by this function

    # Generate a base output directory name without duplicating parts of the path
    base_output_path = os.path.join(OUTPUT_DIR, output_subfolder)
    print(f"Output base directory: {base_output_path}")

    # Pick the first numbered directory that does not exist yet
    counter = 0
    final_output_path = f"{base_output_path}/{output_prefix}_{counter}"
    while os.path.exists(final_output_path):
        counter += 1
        final_output_path = f"{base_output_path}/{output_prefix}_{counter}"

    print(f"Final output path: {final_output_path}")
    # Ensure the final output directory exists
    os.makedirs(final_output_path, exist_ok=True)

    # Point the output prefix at the new run directory (creates the
    # "inference" section when the caller did not supply one)
    config.setdefault("inference", {})["output_prefix"] = os.path.join(final_output_path, output_prefix)

    # Write the configuration to a YAML file inside the run directory
    config_filename = "config.yaml"
    config_file_path = os.path.join(final_output_path, config_filename)
    with open(config_file_path, 'w') as file:
        yaml.dump(config, file)

    # The config is referenced by name (without the '.yaml' extension) and directory
    cmd = [
        "python", "./rf_diffusion_all_atom/run_inference.py",
        f"--config-name={os.path.splitext(config_filename)[0]}",
        f"--config-dir={final_output_path}",
    ]
    # diffuser.T is repeated as a command-line override; the original author
    # noted that this was needed although the value is already in config.yaml.
    diffuser_steps = config.get("diffuser", {}).get("T")
    if diffuser_steps is not None:
        cmd.append(f"diffuser.T={diffuser_steps}")

    # Print the command to the console
    print(f"Running command: {' '.join(cmd)}")

    # Only the lines that may still be displayed are kept in memory
    recent_lines = deque(maxlen=max(show_last_n_lines, 0))

    log_file = None
    try:
        if save_stdout:
            log_file = open(os.path.join(final_output_path, "run_inference.log"), "w")

        # The context manager closes the pipe and waits for the process to exit
        with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) as process:
            for output in process.stdout:
                recent_lines.append(output.strip())
                if log_file is not None:
                    log_file.write(output)
                if show_last_n_lines > 0:
                    print("\n".join(recent_lines))
    finally:
        # Close the log even when streaming is interrupted
        if log_file is not None:
            log_file.close()

    if process.returncode != 0:
        # Report failures instead of silently returning as if the run succeeded
        print(
            f"Warning: inference exited with return code {process.returncode}; "
            "check the output above."
        )

    return final_output_path



#@title ### Small molecule binder design with protein motif

# Interface for specifying PDB input
pdb_input_type = "pdb_code" #@param ["upload", "pdb_code", "manual_path"]
pdb_code = "7v11" #@param {type:"string"}
custom_path = "" #@param {type:"string"}

if pdb_input_type == "pdb_code":
    input_pdb = handle_pdb_input(pdb_input_type, pdb_code)
elif pdb_input_type == "upload":
    print("Please upload your PDB file:")
    # Assuming running in Colab
    input_pdb = handle_pdb_input(pdb_input_type)
elif pdb_input_type in ("custom_path", "manual_path"):
    # The form calls this option "manual_path"; it is passed on under the
    # name "custom_path" so that `custom_path` is returned as the input file.
    input_pdb = handle_pdb_input("custom_path", pdb_code, custom_path)
else:
    # Without this branch `input_pdb` would stay undefined (or stale from an
    # earlier run) and fail much later when the config is built.
    raise ValueError(f"Invalid PDB input type: {pdb_input_type!r}")


contigs = "100-100" #@param {type:"string"}
contig_length = "" #@param {type:"string"}
ligand = "OQO" #@param {type:"string"}
num_designs = 1 #@param {type:"integer"}
design_startnum = 0 #@param {type:"integer"}
output_prefix = "sample" #@param {type:"string"}
output_subfolder = "ligand_protein_motif" #@param {type:"string"}

deterministic = True #@param {type:"boolean"}
T = 25 #@param {type:"integer"}

# One contig entry per comma-separated piece of the form string
contigs_list = contigs.split(',')

# The YAML config receives the pieces as a list of strings
contigs_yaml = list(contigs_list)
# An empty length field is stored as None
contig_length = contig_length or None

# Assemble the run configuration from the form values
config = dict(
    inference=dict(
        deterministic=deterministic,
        input_pdb=input_pdb,
        ligand=ligand,
        num_designs=num_designs,
        design_startnum=design_startnum,
        ckpt_path="./params/RFDiffusionAA_paper_weights.pt",
        model_runner="NRBStyleSelfCond",
    ),
    diffuser=dict(T=T),
    contigmap=dict(contigs=contigs_yaml, length=contig_length),
    model=dict(freeze_track_motif="True"),
    defaults=["aa"],
)

# Display/logging options forwarded to run_rfdiffusion_all_atom
# Only the newest stdout line is echoed to keep the notebook output short
# (values above 1 are reported as not working at the moment).
show_last_n_lines = 1
# Keep a copy of stdout as a log file in the output folder
save_stdout = True

# Launch the run; the returned directory is stored for later cells
final_output_path = run_rfdiffusion_all_atom(
    config,
    output_subfolder=output_subfolder,
    output_prefix=output_prefix,
    show_last_n_lines=show_last_n_lines,
    save_stdout=save_stdout,
)

#@markdown After running the diffusion function, you can zip the last job's output for download:

#@markdown Run the following cell to zip and download the last job's output.

In [None]:
#@title Display 3D structure {run: "auto"}
# Form fields for the structure viewer defined below.
# NOTE(review): `animate`, `denoise` and `dpi` are not referenced by the code
# visible in this cell - confirm whether they are still needed.
animate = "interactive" #@param ["none", "movie", "interactive"]
# `color` selects the colouring branch inside plot_pdb
color = "chain" #@param ["rainbow", "chain", "plddt"]
denoise = True
dpi = 100 #@param ["100", "200", "400"] {type:"raw"}

from colabdesign.shared.plot import pymol_color_list
from colabdesign.rf.utils import get_ca, get_Ls, make_animation
from string import ascii_uppercase, ascii_lowercase
import os
import ipywidgets as widgets
from IPython.display import display, HTML
import py3Dmol

# Candidate chain labels: uppercase letters first, then lowercase
alphabet_list = [*ascii_uppercase, *ascii_lowercase]


# Directory that plot_pdb searches for the run folders
base_output_dir = os.path.join(OUTPUT_DIR, output_subfolder)

def find_latest_output_dir(base_dir, prefix):
    """
    Find the latest output directory based on the prefix.

    Only subdirectories named exactly ``<prefix>_<number>`` are considered.
    Other entries that merely start with the prefix (for example
    ``<prefix>_old`` or a different prefix sharing the same beginning) are
    ignored instead of breaking the numeric comparison.

    Args:
        base_dir: Directory whose entries are scanned.
        prefix: Name stem of the numbered directories.

    Returns:
        Path of the matching directory with the largest number.

    Raises:
        FileNotFoundError: If no matching directory exists.
    """
    stem = f"{prefix}_"
    numbered = {}
    for entry in os.listdir(base_dir):
        suffix = entry[len(stem):]
        if (
            entry.startswith(stem)
            and suffix.isdecimal()
            and os.path.isdir(os.path.join(base_dir, entry))
        ):
            numbered[entry] = int(suffix)

    if not numbered:
        raise FileNotFoundError(f"No output directories found with prefix '{prefix}' in '{base_dir}'")

    latest_dir = max(numbered, key=numbered.get)
    return os.path.join(base_dir, latest_dir)

def plot_pdb(num=0):
    """
    Show design ``num`` of the latest run in a py3Dmol viewer.

    The file ``<output_prefix>_<num>.pdb`` is read from the directory returned
    by ``find_latest_output_dir`` and coloured according to the global
    ``color`` setting: "rainbow" uses a spectrum cartoon, "chain" colours each
    chain found in the file, and any other value uses a B-factor gradient.
    """
    # Find the latest output directory
    latest_output_dir = find_latest_output_dir(base_output_dir, output_prefix)

    # Construct the path to the PDB file
    pdb_path = os.path.join(latest_output_dir, f"{output_prefix}_{num}.pdb")

    # Load the PDB file; the context manager closes the handle again
    with open(pdb_path, 'r') as handle:
        pdb_str = handle.read()

    # Initialize the 3Dmol.js viewer
    view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
    view.addModel(pdb_str, 'pdb')

    # Apply color scheme
    if color == "rainbow":
        view.setStyle({'cartoon': {'color':'spectrum'}})
    elif color == "chain":
        # Chain IDs are read from column 22 of the ATOM/HETATM records. The
        # previous loop bound was the character count of the contig string,
        # which has no relation to the number of chains.
        chains_present = {
            line[21]
            for line in pdb_str.splitlines()
            if line.startswith(("ATOM", "HETATM")) and len(line) > 21
        }
        # Each chain letter keeps its colour by position in alphabet_list
        for chain, c in zip(alphabet_list, pymol_color_list):
            if chain not in chains_present:
                continue
            if chain == "B":
                # Chain B is drawn as sticks (presumably the ligand - confirm)
                view.setStyle({'chain': chain}, {'stick': {}})
            else:
                view.setStyle({'chain': chain}, {'cartoon': {'color': c}})
    else:
        # Colour the cartoon by the B-factor column
        view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':0.5,'max':0.9}}})

    # Zoom to fit and display the viewer
    view.zoomTo()
    view.show()



# A single design is drawn directly; several designs get a dropdown selector
if num_designs <= 1:
  plot_pdb()
else:
  output = widgets.Output()

  def on_change(change):
    # Redraw only when the selected value itself changed
    if change['name'] != 'value':
      return
    with output:
      output.clear_output(wait=True)
      plot_pdb(change['new'])

  design_choices = [(f'{k}', k) for k in range(num_designs)]
  dropdown = widgets.Dropdown(
      options=design_choices,
      value=0,
      description='design:',
  )
  dropdown.observe(on_change)
  display(widgets.VBox([dropdown, output]))
  # Draw the initially selected design
  with output:
    plot_pdb(dropdown.value)

In [None]:
#@title Package and download results
#@markdown If you are having issues downloading the result archive,
#@markdown try disabling your adblocker and run this cell again.
#@markdown  If that fails click on the little folder icon to the
#@markdown  left, navigate to file: `name.result.zip`,
#@markdown  right-click and select \"Download\"
#@markdown (see [screenshot](https://pbs.twimg.com/media/E6wRW2lWUAEOuoe?format=jpg&name=small)).
import shutil

def zip_last_job(final_output_path, counter):
    """
    Zip the last job's output directory for download.

    Args:
        final_output_path: Output directory of the job, joined onto
            ``OUTPUT_DIR`` (an absolute path is used as-is by the join).
        counter: Accepted for compatibility; not used.

    Returns:
        Path of the created ``.zip`` archive, or ``None`` when
        ``final_output_path`` is empty or ``None``.
    """
    # Check first: joining ``None`` onto OUTPUT_DIR would raise a TypeError
    if not final_output_path:
        print("No output directory found.")
        return None

    # normpath drops a trailing separator, which would otherwise place the
    # archive inside the directory that is being archived
    output_path = os.path.normpath(os.path.join(OUTPUT_DIR, final_output_path))
    shutil.make_archive(output_path, 'zip', output_path)
    return f"{output_path}.zip"


# Assuming you're in a Jupyter notebook cell
# The Colab helper is imported only where it exists; an unconditional import
# would fail outside Colab before the is_colab() check below is reached.
if is_colab():
    from google.colab import files

# This cell should be run after the diffusion function to zip and download the output
zip_path = zip_last_job(final_output_path=final_output_path, counter=0)

if zip_path and is_colab():
    files.download(zip_path)
else:
    print("Zip file path:", zip_path)
    print("Note: Automatic download is only supported in Google Colab.")