<a href="https://colab.research.google.com/github/ccccclw/ColabDesign/blob/main/af/examples/alphafolding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### This notebook supports
- running iterative predictions with AlphaFold2 (monomer models 1 and 2) and visualizing the predicted structures. For predictions that successfully find the native state, the structures predicted before the native state is reached may resemble protein folding intermediates.


In [1]:
#@title setup {"vertical-output":true,"form-width":"50%","display-mode":"form"}
%%time
import os
if not os.path.isdir("params"):
  # get code
  os.system("pip -q install git+https://github.com/ccccclw/ColabDesign.git")
  # for debugging
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign")
  # download params
  os.system("mkdir params")
  os.system("apt-get install aria2 -qq")
  os.system("aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar")
  os.system("tar -xf alphafold_params_2022-12-06.tar -C params")

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from Bio.PDB import *
import os, re
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.mpnn import mk_mpnn_model
from colabdesign.af.alphafold.common import residue_constants
from colabdesign.shared.protein import _np_get_cb
from colabdesign.shared.plot import plot_pseudo_3D, make_animation, show_pdb
import pickle
from colabdesign import af
from google.colab import files
import numpy as np
from IPython.display import HTML
import jax.numpy as jnp
import jax
from scipy.special import softmax
import sys
import tqdm.notebook
import argparse
import matplotlib.pyplot as plt
# Layout string with tqdm-style placeholders (bar, n/total counter, elapsed and remaining time).
# NOTE(review): no cell visible in this notebook passes it to a progress bar.
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

##util functions
def get_pdb(pdb_code=""):
  """Resolve `pdb_code` to the path of a local PDB file.

  - None or "": open the Colab upload dialog and store the first uploaded file as "tmp.pdb".
  - an existing file path: returned unchanged.
  - a 4-character code: fetched from files.rcsb.org as "<code>.pdb".
  - anything else: fetched from alphafold.ebi.ac.uk as "AF-<code>-F1-model_v3.pdb".

  Returns:
    The path of the PDB file on disk.

  Raises:
    ValueError: the upload dialog returned no file.
    FileNotFoundError: the download did not produce the expected file.
  """
  import subprocess  # local import: only the download branches need it

  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("no PDB file was uploaded")
    pdb_string = upload_dict[list(upload_dict.keys())[0]]
    with open("tmp.pdb","wb") as out: out.write(pdb_string)
    return "tmp.pdb"
  elif os.path.isfile(pdb_code):
    return pdb_code
  elif len(pdb_code) == 4:
    url = f"https://files.rcsb.org/view/{pdb_code}.pdb"
    target = f"{pdb_code}.pdb"
  else:
    url = f"https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb"
    target = f"AF-{pdb_code}-F1-model_v3.pdb"

  # Pass an argument list (no shell) so that characters in `pdb_code` are never
  # interpreted as shell syntax; `-nc` keeps an already downloaded file.
  subprocess.run(["wget", "-qnc", url])
  if not os.path.isfile(target):
    raise FileNotFoundError(f"could not download {url}")
  return target


def get_dgram(positions, num_bins=39, min_bin=3.25, max_bin=50.75):
  """Bin pairwise distances between the points returned by `_np_get_cb` into a 0/1 distogram.

  The N, CA and C coordinates are picked out of `positions` with
  `residue_constants.atom_order` and handed to `_np_get_cb`; squared distances
  between the returned points are compared with `num_bins` squared bin edges
  spaced linearly from `min_bin` to `max_bin` (the last bin is closed by 1e8).
  A pair whose squared distance lies strictly inside a bin gets 1.0 in that bin.
  """
  order = residue_constants.atom_order
  backbone = {name: positions[..., order[name], :] for name in ["N", "CA", "C"]}
  cb = _np_get_cb(**backbone, use_jax=False)

  # squared pairwise distances, with a trailing axis to broadcast against the bin edges
  delta = cb[None, :] - cb[:, None]
  dist2 = np.square(delta).sum(-1, keepdims=True)

  lower_breaks = np.square(np.linspace(min_bin, max_bin, num_bins))
  last_edge = np.array([1e8], dtype=jnp.float32)
  upper_breaks = np.concatenate([lower_breaks[1:], last_edge], axis=-1)

  inside_bin = (dist2 > lower_breaks) * (dist2 < upper_breaks)
  return inside_bin.astype(float)

def sample_gumbel(shape, eps=1e-10):
  """Draw Gumbel(0, 1) noise of the given shape from uniform samples.

  `eps` is added inside both logarithms to keep their arguments away from zero.
  """
  uniform = np.random.uniform(size=shape)
  neg_log_uniform = -np.log(uniform + eps)
  return -np.log(neg_log_uniform + eps)
 
def sample_uniform(shape, eps=1e-10):
  """Draw uniform samples of the given shape from [0, 1) and shift them up by `eps`."""
  return np.random.uniform(size=shape) + eps
 
from colabdesign.af.alphafold.common import residue_constants
def xyz_atom37(pdb_file):
  """
  Read a PDB file and place its atom coordinates into an atom37-style array.

  Only the first model of the file is used, so a multi-model file (for example an
  NMR ensemble) yields one row per residue instead of one row per residue per model.
  Atoms whose name has no slot in `residue_constants.atom_order` (e.g. hydrogens)
  are skipped. A residue numbered n is written to row n-1, i.e. the file is expected
  to be numbered 1..length without gaps.

  Parameters:
  - pdb_file: path (or handle) accepted by Bio.PDB's PDBParser.get_structure.

  Returns:
  - np.ndarray of shape (length, 37, 3); slots without an atom stay zero.

  Raises:
  - ValueError: a residue number falls outside 1..length.
  """
  atom37_order = residue_constants.atom_order
  parser = PDBParser()
  structure = parser.get_structure("A", pdb_file)
  # use a single model; iterating the whole structure would count every model's residues
  model = next(structure.get_models())
  length = len(list(model.get_residues()))
  atom37_coord = np.zeros((length, 37, 3))

  for atom in model.get_atoms():
    atom37_index = atom37_order.get(atom.get_name())
    if atom37_index is None:
      # atom type not representable in atom37
      continue
    residue_number = atom.get_parent().id[1]
    row = residue_number - 1
    if not 0 <= row < length:
      # fail loudly: a negative row would otherwise silently wrap around to the end
      raise ValueError(
        f"residue number {residue_number} is outside 1..{length}; "
        "renumber the template so residues run from 1 without gaps")
    atom37_coord[row][atom37_index] = atom.get_coord()
  return atom37_coord

def sequence_to_one_hot(sequence):
    """
    One-hot encode an amino-acid string.

    Each residue letter is looked up in `residue_constants.restype_order` and the
    resulting indices select rows of a 20x20 identity matrix.

    Parameters:
    - sequence: str, amino-acid letters (e.g., "ACDE").

    Returns:
    - np.ndarray of shape (N, 20), where N is len(sequence).

    Raises:
    - ValueError: a letter is missing from `restype_order`.
    """
    index_of = residue_constants.restype_order

    seq_indices = []
    has_unknown = False
    for letter in sequence:
        index = index_of.get(letter, -1)  # -1 marks an unknown letter
        has_unknown = has_unknown or index == -1
        seq_indices.append(index)

    if has_unknown:
        raise ValueError("Sequence contains invalid amino acid(s) not present in aa_order.")

    identity = np.eye(20)
    return identity[seq_indices]



UsageError: Line magic function `%%time` not found.


In [None]:
#@title input preparation {"vertical-output":true,"form-width":"50%","display-mode":"form"}
starting_seq = "" #@param {type:"string"}
# upper-case the input and drop everything that is not a letter A-Z
starting_seq = re.sub("[^A-Z]", "", starting_seq.upper())
##default sequence is PDB:3GB1 if no sequence is provided
if len(starting_seq) == 0:
  starting_seq = "MTYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE"
length = len(starting_seq)
template = "None" #@param ["custom","None"]
if template == "custom":
  # uploaded template files are moved into custom/template/
  custom_template_path = os.path.join(template, "template")
  os.makedirs(custom_template_path, exist_ok=True)
  uploaded = files.upload()
  if not uploaded:
    # without this check the loop below never binds `fn` and the cell dies with a NameError
    raise ValueError("template is set to 'custom' but no template file was uploaded")
  for fn in uploaded.keys():
    os.rename(fn,os.path.join(custom_template_path,fn))
  # when several files are uploaded, the last one is used as the template
  template_path = os.path.join(custom_template_path,fn)


In [None]:
#@title initialize the model with parameters and run {"vertical-output":true,"form-width":"50%","display-mode":"form"}
# Build the ColabDesign model object and prepare the distogram handling for the chosen `mode`.
clear_mem()
model_name = "model_1_ptm" #@param ["model_1_ptm", "model_2_ptm","both"]
use_multimer = False 
# From here on `model_name` is either None (form value "both") or a one-element list.
model_name = None if model_name == "both" else [model_name]
af_model = mk_afdesign_model(protocol="hallucination",
                             use_templates=True,
                             debug=True, 
                             model_names=model_name,
                             use_multimer=use_multimer)
# `length` is the sequence length set by the input-preparation cell.
af_model.prep_inputs(length=length)

mode = "dgram" #@param ["dgram","dgram_retrain"]
if "dgram" in mode:
  if "retrain" in mode and not use_multimer:
    # update distogram head to return all 39 bins
    af_model._cfg.model.heads.distogram.first_break = 3.25
    af_model._cfg.model.heads.distogram.last_break = 50.75
    af_model._cfg.model.heads.distogram.num_bins = 39
    # rebuild the model from the modified config
    af_model._model = af_model._get_model(af_model._cfg)
    from colabdesign.af.weights import __file__ as af_path
    template_dgram_head = np.load(os.path.join(os.path.dirname(af_path),'template_dgram_head.npy'))
    # replace the distogram-head parameters of every loaded parameter set:
    # weights come from template_dgram_head.npy, the bias is 39 zeros
    for k in range(len(af_model._model_params)):
      params = {"weights":jnp.array(template_dgram_head[k]),"bias":jnp.zeros(39)}
      af_model._model_params[k]["alphafold/alphafold_iteration/distogram_head/half_logits"] = params
  else:
    # (64, 39) 0/1 matrix: rows 0-3 select column 0, every following group of four rows
    # selects the next column (0, 1, ..., 14); the last row is then zeroed.
    # The iteration loop multiplies the predicted distogram probabilities by it.
    dgram_map = np.eye(39)[np.repeat(np.append(0,np.arange(15)),4)]
    dgram_map[-1,:] = 0 

iterations = 50 #@param [50, 100, 200] {type:"raw"}
use_dgram_noise = None #@param ["g","u","None"]
use_dropout = False #@param {type:"boolean"}
seqsep_mask =  0 #@param {type:"integer"}
num_recycles = 2 #@param {type:"integer"}

# `model_name` was already rewritten to None (form value "both") or a one-element list,
# so the "both" case has to be detected through None, not through the string.
sample_models = model_name is None
# "g" = Gumbel noise, "u" = uniform noise; None or the string "None" switches noise off
dgram_noise_type = None if use_dgram_noise in (None, "None") else use_dgram_noise
use_dgram_noise = dgram_noise_type is not None

L = sum(af_model._lengths)
af_model.restart(mode="gumbel")
af_model._inputs["rm_template_seq"] = False
# gather info about inputs
if "offset" in af_model._inputs:
  offset = af_model._inputs["offset"]
else:
  # pairwise difference of residue indices
  idx = af_model._inputs["residue_index"]
  offset = idx[:,None] - idx[None,:]

# initialize sequence
if len(starting_seq) > 1:
  af_model.set_seq(seq=starting_seq)
af_model._inputs["bias"] = np.zeros((L,20))

# initialize coordinates/dgram
af_model._inputs["batch"] = {"aatype":np.zeros(L).astype(int),
                             "all_atom_mask":np.zeros((L,37)),
                             "all_atom_positions":np.zeros((L,37,3)),
                             "dgram":np.zeros((L,L,39))}

if template == "custom":
  xyz = xyz_atom37(pdb_file=template_path)
  af_model._inputs["batch"]["all_atom_positions"] = xyz
  dgram = get_dgram(xyz)
  if use_dgram_noise:
    # The template is the input of step 0, so the annealing factor (1 - step/iterations)
    # used by the iteration loop is 1 here. The noise is applied before the distogram is
    # stored, otherwise it would never reach the model.
    if dgram_noise_type == "g":
      noise = sample_gumbel(dgram.shape)
      dgram = softmax(np.log(dgram + 1e-8) + noise, -1)
    elif dgram_noise_type == 'u':
      noise = sample_uniform(dgram.shape)
      dgram = softmax(np.log(dgram + 1e-8) + noise, -1)
  # keep only residue pairs whose index offset exceeds `seqsep_mask`
  mask = np.abs(offset) > seqsep_mask
  af_model._inputs["batch"]["dgram"] = dgram * mask[:,:,None]
# Iterative prediction: each step feeds the previous step's outputs (coordinates, distogram,
# sequence, confidence) back into the template inputs of the next prediction.
plddts = []  # mean of aux["plddt"] per step, plotted by the visualization cell
print(f"running seq {starting_seq} with model: {'both' if model_name is None else model_name} for {iterations} steps")
for k in range(iterations):
  # noise
  # Step 0 runs with the inputs prepared before the loop; later steps rebuild the distogram
  # input from the previous prediction.
  if k > 0:
    dgram_xyz = get_dgram(xyz)             # 0/1 distogram from the previous coordinates
    dgram_prob = softmax(dgram_logits,-1)  # probabilities from the previous distogram logits

    # NOTE(review): "xyz" is not among the options offered by the `mode` form field.
    if mode == "xyz":
      dgram = dgram_xyz
    if mode == "dgram":
      # project the predicted probabilities onto the 39 input bins through dgram_map, then
      # fill bins 14 and above from the coordinate distogram weighted by the last predicted bin
      dgram = dgram_prob @ dgram_map
      dgram[...,14:] = dgram_xyz[...,14:] * dgram_prob[...,-1:]
    if mode == "dgram_retrain":
      dgram = dgram_prob
    
    if use_dgram_noise:
      # the noise amplitude decays linearly with the step number: factor (1 - k/iterations)
      if dgram_noise_type == "g":   
        noise = sample_gumbel(dgram.shape) * (1 - k/iterations)
        dgram = softmax(np.log(dgram + 1e-8) + noise, -1)
      elif dgram_noise_type == 'u':  
        noise = sample_uniform(dgram.shape) * (1 - k/iterations)
        dgram = softmax(np.log(dgram + 1e-8) + noise, -1)

    # add mask to avoid local contacts being fixed (otherwise there is a bias toward helix)
    mask = np.abs(offset) > seqsep_mask
    af_model._inputs["batch"]["dgram"] = dgram * mask[:,:,None]

  # prediction
  aux = af_model.predict(return_aux=True, verbose=False,
                        sample_models=sample_models,
                        dropout=use_dropout, num_recycles=num_recycles)
  plddt = aux["plddt"]
  plddts.append(np.average(plddt))
  seq = aux["seq"]["hard"][0].argmax(-1)   
  xyz = aux["atom_positions"].copy()
  # available because the model was created with debug=True
  dgram_logits = aux["debug"]["outputs"]["distogram"]["logits"] 
  
  # update inputs    
  af_model._inputs["batch"]["aatype"] = seq
  # the first four atom37 slots of every residue get sqrt(plddt) as their mask value
  af_model._inputs["batch"]["all_atom_mask"][:,:4] = np.sqrt(plddt)[:,None]
  af_model._inputs["batch"]["all_atom_positions"] = xyz
  
  # save results
  af_model._save_results(aux)
  af_model._k += 1
  af_model.save_pdb(f"iter_{k}.pdb")

In [None]:
#@title visualization
# Mean pLDDT per prediction step (scaled by 100), followed by the animation of the run.
plddt_percent = np.array(plddts)*100
fig, ax = plt.subplots(1, 1, figsize=(7.4, 2))
ax.scatter(range(len(plddts)), plddt_percent, s=12, color='grey', zorder=1)
ax.plot(plddt_percent, 'darkorange', zorder=0)
ax.set_xlabel("Prediction iteration")
ax.set_ylabel("pLDDT")

# place the recycle label at 85% of the x-range and 5% of the y-range (data coordinates)
x_low, x_high = ax.get_xlim()
y_low, y_high = ax.get_ylim()
label_x = x_low+(x_high-x_low)*0.85
label_y = y_low+(y_high-y_low)*0.05
ax.text(label_x, label_y, f"recycle# {num_recycles}")
HTML(af_model.animate(dpi=80, interval=300))

In [None]:
#@title Visualize precalculated iterative structure predictions from PDB {"vertical-output":true,"form-width":"50%","display-mode":"form"}
!pip install plotly
import plotly
import plotly.express as px
#import nglview as nv
import plotly.graph_objects as go
import pandas as pd

In [None]:
#@title helpful functions for download data from zenodo and visualization {"vertical-output":true,"form-width":"50%","display-mode":"form"}
import os
import requests
import zipfile
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import shutil

def plot_embedding(embedding,ss,rmsd,pdb_id,rmsd_cutoff=1,selection=None,selection_min=0, selection_max=1):
    """Show an interactive plotly scatter plot of a 2-D embedding.

    Parameters:
    - embedding: array whose columns 0 and 1 are the plot coordinates.
    - ss: array; columns 3, 2 and 4 are shown as '%alpha', '%beta' and '%coil'.
    - rmsd: array of RMSD values; shown rounded to two decimals and multiplied by 10
      (presumably a nm -> Angstrom conversion; verify against the data source).
    - pdb_id: iterable of strings; characters [-5:-1] of each, upper-cased, are the label.
    - rmsd_cutoff: accepted for compatibility; not used.
    - selection: optional column name ('dim1', 'dim2', '%alpha', '%beta', '%coil', 'RMSD');
      only rows with selection_min < value < selection_max are plotted.
    - selection_min, selection_max: exclusive bounds for `selection`.

    Returns:
    - None; the figure is displayed with the "notebook" renderer.
    """
    embedding_df = pd.DataFrame()
    embedding_df['dim1'] = np.round(embedding[:, 0],2)
    embedding_df['dim2'] = np.round(embedding[:, 1],2)
    embedding_df['%alpha'] = np.round(ss[:,3],2)
    embedding_df['%beta'] = np.round(ss[:,2],2)
    embedding_df['%coil'] = np.round(ss[:,4],2)
    embedding_df['RMSD'] = np.round(rmsd,2)*10
    embedding_df['pdb_id'] = np.array([i[-5:-1].upper() for i in pdb_id])
    if selection:
        # Build the row mask once and filter the whole frame with it, instead of
        # rewriting the frame column by column while re-reading the selection column.
        in_range = (embedding_df[selection] > selection_min) & (embedding_df[selection] < selection_max)
        embedding_df = embedding_df[in_range].reset_index(drop=True)
    fig = px.scatter(embedding_df,x='dim1',y='dim2',color='%alpha',
                     custom_data=[embedding_df['pdb_id'],
                                 embedding_df['%alpha'],
                                 embedding_df['%beta'],
                                 embedding_df['%coil'],
                                 embedding_df['RMSD']],
                     hover_data={'dim1':False,
                                 'dim2':False,
                                 'pdb_id':True,
                                 '%alpha': True,
                                 '%beta': True,
                                 '%coil': True,
                                 'RMSD': (':.2f')},
                    color_continuous_scale='RdBu')
    fig.update_layout(width=800,height=800)
    fig.update_traces(marker=dict(size=10,
                                  line=dict(width=2,
                                            color='DarkSlateGrey')),
                      selector=dict(mode='markers'))
    fig.update_xaxes(showspikes=True,spikecolor="black", spikesnap="cursor", spikemode="across")
    fig.update_yaxes(showspikes=True,spikecolor="black", spikesnap="cursor", spikemode="across")

    # Dropdown that recolours the markers by one of the data columns.
    colour_choices = [("%alpha", "helix percentage"),
                      ("%beta", "beta percentage"),
                      ("%coil", "coil percentage"),
                      ("RMSD", "RMSD")]
    buttons = [dict(args=[{"marker.color": [embedding_df[column]]}],
                    label=label,
                    method="restyle")
               for column, label in colour_choices]
    fig.update_layout(
        updatemenus=[go.layout.Updatemenu(
                active=0,
                buttons=buttons,
                direction="down",
                pad={"l": -30, "t": 1},
                showactive=True,
                x=0.1,
                xanchor="left",
                y=1.1,
                yanchor="top"
            ),
        ]
    )
    fig.update_traces(hovertemplate="PDB: %{customdata[0]}<br> \u03B1: %{customdata[1]}; \u03B2: %{customdata[2]}; C: %{customdata[3]}<br> RMSD: %{customdata[4]:.2f} ") #
    fig.show("notebook")

# Function to download part of a file in small chunks
# Function to download part of a file in small chunks
def download_chunk(file_url, start, end, file_name, pbar, chunk_size=1024*1024):
    """Download bytes `start`..`end` (inclusive) of `file_url` into an existing file.

    The bytes are requested with an HTTP Range header and written to `file_name`
    (which must already exist) beginning at offset `start`; `pbar.update` is called
    with the size of every chunk written.

    Raises:
    - requests.HTTPError: the server answered with an error status.
    - RuntimeError: the server ignored the Range header for a chunk that does not
      start at byte 0 (writing the full body at `start` would corrupt the file).
    """
    headers = {'Range': f'bytes={start}-{end}'}
    # the context manager releases the connection even if writing fails
    with requests.get(file_url, headers=headers, stream=True) as response:
        response.raise_for_status()
        if response.status_code != 206 and start != 0:
            raise RuntimeError(
                f"server did not honour the byte range {start}-{end} "
                f"(HTTP {response.status_code}) for {file_url}")

        with open(file_name, 'r+b') as file:
            file.seek(start)

            for chunk in response.iter_content(chunk_size=chunk_size):
                if chunk:  # Filter out keep-alive new chunks
                    file.write(chunk)
                    pbar.update(len(chunk))

# Function to download a file using multiple threads, each handling a part of the file
def download_large_file_multithreaded(url, destination, num_threads=4, chunk_size=1024*1024):
    """Download `url` to `destination` with several threads, one byte range per thread.

    The size is read from the `content-length` header of a HEAD request, the
    destination file is pre-allocated to that size, and each worker fills its own
    range through `download_chunk`.

    Raises:
    - requests.HTTPError: the HEAD request failed.
    - RuntimeError: the server did not report a content length.
    - any exception raised by a worker.
    """
    # requests.head does not follow redirects unless asked to; without this the size
    # of the redirect response would be used instead of the size of the file.
    response = requests.head(url, allow_redirects=True)
    response.raise_for_status()
    if 'content-length' not in response.headers:
        raise RuntimeError(f"{url} did not report a content-length; cannot split the download")
    total_size = int(response.headers['content-length'])

    # Create an empty file with the appropriate size
    with open(destination, 'wb') as f:
        f.truncate(total_size)
    if total_size == 0:
        return  # nothing to download

    # never use more workers than there are bytes, so every range is non-empty
    num_threads = max(1, min(num_threads, total_size))
    part_size = total_size // num_threads

    pbar = tqdm(total=total_size, unit='B', position=0, leave=True, unit_scale=True, desc=destination)
    try:
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = []
            for i in range(num_threads):
                start = i * part_size
                # the last worker also takes the remainder of the integer division
                end = (i + 1) * part_size - 1 if i < num_threads - 1 else total_size - 1
                futures.append(executor.submit(download_chunk, url, start, end, destination, pbar, chunk_size))

            # Wait for all threads to finish; result() re-raises worker exceptions
            for future in futures:
                future.result()
    finally:
        pbar.close()

import zipfile
import os
import threading

# Function to extract a single file and update progress bar
def extract_file(zip_ref, file_info, destination, progress_bar):
    """Extract one archive member into `destination`, then advance `progress_bar` by one."""
    zip_ref.extract(member=file_info, path=destination)
    progress_bar.update(1)

# Multi-threaded decompression function with progress bar
def extract_zip_multithreaded(zip_path, destination, file_name=None, num_threads=4):
    """Extract a zip archive into `destination` with a pool of `num_threads` workers.

    Parameters:
    - zip_path: path of the archive.
    - destination: output directory (created if missing).
    - file_name: None to extract everything, or a string / iterable of strings; only
      members whose name contains at least one of them are extracted (each matched
      member name is printed).
    - num_threads: number of worker threads.

    Raises:
    - any exception raised while extracting a member.
    """
    from concurrent.futures import ThreadPoolExecutor

    # A bare string would otherwise be iterated character by character.
    if isinstance(file_name, str):
        file_name = [file_name]

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        file_list = zip_ref.infolist()
        if file_name is not None:
            wanted = list(file_name)
            matched = []
            for member in file_list:
                # any() queues a member once even when several names match it
                if any(name in member.filename for name in wanted):
                    print(member.filename)
                    matched.append(member)
            file_list = matched

        # Create the destination folder if it doesn't exist
        os.makedirs(destination, exist_ok=True)
        # Create parent folders up front so workers do not race to create the same directory.
        for member in file_list:
            parent = os.path.dirname(os.path.join(destination, member.filename))
            if parent:
                os.makedirs(parent, exist_ok=True)

        with tqdm(total=len(file_list), position=0, leave=True, desc="Decompressing", unit="file") as progress_bar:
            # A bounded pool replaces one thread per member.
            with ThreadPoolExecutor(max_workers=max(1, num_threads)) as executor:
                futures = [executor.submit(extract_file, zip_ref, member, destination, progress_bar)
                           for member in file_list]
                # result() re-raises extraction errors instead of losing them
                for future in futures:
                    future.result()


def download_zenodo(record_id,download_file=None,pdb=None):
    """Download files of a Zenodo record into ./zenodo_downloads and unpack zip archives.

    Parameters:
    - record_id: Zenodo record id used in the API URL.
    - download_file: file name or list of file names (the record's file keys) to fetch;
      None fetches every file of the record.
    - pdb: optional string or list of strings; when given, only archive members whose
      name contains one of them are extracted.

    Zip archives are extracted and then deleted, except archives with "trajs" in
    their name and '0_0_iter_0_pdb.zip', which are kept on disk unextracted so that
    single entries can be pulled out of them later.

    Raises:
    - requests.HTTPError: the record metadata request failed.
    """
    zenodo_url = f'https://zenodo.org/api/records/{record_id}'

    # Get record metadata
    response = requests.get(zenodo_url)
    response.raise_for_status()
    metadata = response.json()

    # Get all file entries of the record
    record_files = metadata['files']

    # A bare string would be matched by substring / iterated per character further down.
    if isinstance(download_file, str):
        download_file = [download_file]
    if isinstance(pdb, str):
        pdb = [pdb]

    # Create a folder to store downloaded files
    os.makedirs('zenodo_downloads', exist_ok=True)
    # Download and decompress each requested file in the record
    for file_info in record_files:
        file_url = file_info['links']['self']
        file_name = file_info['key']
        if download_file is not None and file_name not in download_file:
            continue

        file_path = os.path.join('zenodo_downloads', file_name)
        print(f'Downloading {file_name} with memory-efficient multi-threading...')
        download_large_file_multithreaded(file_url, file_path, num_threads=4, chunk_size=1024*1024)

        # Compare the bare file name: `file_path` carries the 'zenodo_downloads' prefix
        # and therefore never equals '0_0_iter_0_pdb.zip'.
        keep_archive = "trajs" in file_name or file_name == '0_0_iter_0_pdb.zip'
        if zipfile.is_zipfile(file_path) and not keep_archive:
            print(f'Decompressing {file_name}...')
            extract_zip_multithreaded(file_path, 'zenodo_downloads', file_name=pdb, num_threads=4)
            os.remove(file_path)  # remove the zip file after extraction


print('Download files useful for both trajectory and embedding visualization.')

# Zenodo record id for each group of PDB ids.
# NOTE(review): the keys look like leading characters of PDB ids; confirm.
record_ids = {
    '6': '13766276',
    '3': '13766281',
    '4': '13777196',
    '2': '13772757',
    '18_and_embeddings': '13826566',
    '57': '13788039',
}
# PDB ids that are shipped as their own "<id>.zip" archive
fig_1_6_pdbs = [
    '3GB1',
    '1MI0',
    '1HZ5',
    '1KH0',
    '1UBQ',
    '2HDA',
]

download_file = [
    'rmsds_plddts_embeddings.zip',
    'model_1_2_gap_0_6_per_residue_plddts.zip',
    '0_0_iter_0_pdb.zip',
]
record_id = record_ids['18_and_embeddings']
download_zenodo(download_file=download_file, record_id=record_id)

print('Download and decompression complete.')


In [None]:
#@title visualize embeddings {"vertical-output":true,"form-width":"50%","display-mode":"form"}
visualize_embeddings = True #@param {type:"boolean"}
if visualize_embeddings:
    download_file = ['all_ss_gap0.npy', 'all_ss_gap6.npy', 'all_pdbs_gap0.npy', 'all_seq_length_gap0.npy', 'rmsds_plddts_embeddings.zip']
    record_id = record_ids['18_and_embeddings']
    download_zenodo(record_id=record_id,download_file=download_file)

import glob
import plotly.io as pio
import plotly.offline as pyo

all_ss = np.load("./zenodo_downloads/all_ss_gap0.npy")
all_rmsd = np.array(np.load("./zenodo_downloads/all_rmsd_model1_gap0.npy"))
all_rmsd2 = np.array(np.load("./zenodo_downloads/all_rmsd_model2_gap0.npy"))
all_seq_length = np.load("./zenodo_downloads/all_seq_length_gap0.npy")
all_tmfile_pd=np.load("./zenodo_downloads/all_pdbs_gap0.npy")

def _ss_fractions(ss_string):
    """Return [E/(E+H), H/(E+H), E/len, H/len, C/len] for one secondary-structure string.

    The 'C' count is taken after stripping leading and trailing 'C' characters.
    """
    letters = np.array([*ss_string])
    n_strand = (letters=='E').sum()
    n_helix = (letters=='H').sum()
    n_coil_inner = (np.array([*(ss_string.strip('C'))])=='C').sum()
    n_total = len(ss_string)
    # the counts are numpy integers, so an all-coil string yields nan (with a warning) rather than an exception
    return [n_strand/(n_strand+n_helix),
            n_helix/(n_strand+n_helix),
            n_strand/n_total,
            n_helix/n_total,
            n_coil_inner/n_total]

all_EH = np.array([_ss_fractions(i) for i in all_ss])
all_H = list(all_EH[:,1])

pio.renderers.default = 'notebook'
pyo.init_notebook_mode(connected=True)
seq_length_threshold = 100
long_seq_index = np.where(all_seq_length>seq_length_threshold)[0]

# './result/' is not created by any cell of this notebook; fall back to a copy of the
# file found below zenodo_downloads when the original location does not exist.
embedding_path = './result/TSE_embedding_gap0.npy'
if not os.path.isfile(embedding_path):
    embedding_candidates = sorted(glob.glob('zenodo_downloads/**/TSE_embedding_gap0.npy', recursive=True))
    if not embedding_candidates:
        raise FileNotFoundError("TSE_embedding_gap0.npy was found neither in ./result nor below ./zenodo_downloads")
    embedding_path = embedding_candidates[0]
all_X_embedded = np.load(embedding_path)
# NOTE(review): index 9 selects one of several stored embeddings; document which one.
X_embedded = all_X_embedded[9]
plot_embedding(X_embedded[long_seq_index],all_EH[long_seq_index],all_rmsd2[long_seq_index],np.array(all_tmfile_pd)[long_seq_index],selection='RMSD',selection_min=0.,selection_max=3)


In [None]:
#@title visualize individual pdb {"vertical-output":true,"form-width":"50%","display-mode":"form"}
pdb_id = '3gb1'.upper() #@param {type:"str"}

if pdb_id is not None:
    if pdb_id in fig_1_6_pdbs:
        download_file = f"{pdb_id.lower()}.zip"
        record_id = record_ids['18_and_embeddings']
        download_zenodo(record_id,download_file=download_file)
    elif pdb_id[0] in ['6','3','4','2']:
        record_id = record_ids[pdb_id[0]]
        download_file = [f"model_1_2_gap_0_6_trajs_{pdb_id[0]}.zip"]
        download_zenodo(record_id,download_file=download_file,pdb=pdb_id)
        extract_zip_multithreaded(f'zenodo_downloads/model_1_2_gap_0_6_trajs_{pdb_id[0]}.zip', 'zenodo_downloads', file_name=pdb_id, num_threads=4)
        extract_zip_multithreaded('zenodo_downloads/0_0_iter_0_pdb.zip', 'zenodo_downloads', file_name=pdb_id, num_threads=4)
        extract_zip_multithreaded('zenodo_downloads/model_1_2_gap_0_6_per_residue_plddts.zip', 'zenodo_downloads', file_name=pdb_id, num_threads=4)
    elif pdb_id[0] in ['1','8']:
        record_id = record_ids['18_and_embeddings']
        download_file = [f"model_1_2_gap_0_6_trajs_{pdb_id[0]}.zip"]
        download_zenodo(record_id,download_file=download_file,pdb=pdb_id)
        extract_zip_multithreaded(f'zenodo_downloads/model_1_2_gap_0_6_trajs_{pdb_id[0]}.zip', 'zenodo_downloads', file_name=pdb_id, num_threads=4)
        extract_zip_multithreaded('zenodo_downloads/0_0_iter_0_pdb.zip', 'zenodo_downloads', file_name=pdb_id, num_threads=4)
        extract_zip_multithreaded('zenodo_downloads/model_1_2_gap_0_6_per_residue_plddts.zip', 'zenodo_downloads', file_name=pdb_id, num_threads=4)
    elif pdb_id[0] in ['5','7']:
        record_id = record_ids['57']
        download_file = [f"model_1_2_gap_0_6_trajs_{pdb_id[0]}.zip"]
        download_zenodo(record_id,download_file=download_file,pdb=pdb_id)
        extract_zip_multithreaded(f'zenodo_downloads/model_1_2_gap_0_6_trajs_{pdb_id[0]}.zip', 'zenodo_downloads', file_name=pdb_id, num_threads=4)
        extract_zip_multithreaded('zenodo_downloads/0_0_iter_0_pdb.zip', 'zenodo_downloads', file_name=pdb_id, num_threads=4)
        extract_zip_multithreaded('zenodo_downloads/model_1_2_gap_0_6_per_residue_plddts.zip', 'zenodo_downloads', file_name=pdb_id, num_threads=4)
    else:
        print(f"{pdb_id} is not available.")

!pip install mdtraj
import glob
import mdtraj as md

traj = glob.glob(f"zenodo_downloads/*/{pdb_id}/*xtc")[0]
top = glob.glob(f"zenodo_downloads/*/{pdb_id}/{pdb_id}*pdb")[0]
traj = md.load(traj,top=top)
xyz = traj.xyz
seq = sequence_to_one_hot(traj.top.to_fasta()[0])
HTML(make_animation(seq, xyz=xyz, pae=None, dpi=80, interval=300))
        