In [39]:
#!pip install git+https://github.com/evolutionaryscale/esm
#!pip install py3Dmol
#!pip install umap-learn

In [3]:
import py3Dmol
import numpy as np
import torch
import pandas as pd
from tqdm import tqdm
import pickle
import os
from esm.utils.structure.protein_chain import ProteinChain
from esm.models.esm3 import ESM3
from huggingface_hub import login
from esm.sdk import client
import requests
from bs4 import BeautifulSoup
from Bio.PDB import PDBParser, Superimposer, PDBIO, Structure
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import wandb
import tempfile
import plotly.graph_objects as go
from pathlib import Path
from urllib.parse import urljoin
import plotly.express as px
from scipy.spatial.distance import cdist
import scipy.stats as stats
from esm.sdk.api import (
    ESM3InferenceClient,
    ESMProtein,
    GenerationConfig,
    SamplingConfig
)
import umap
from sklearn.manifold import TSNE
from Bio.Align import substitution_matrices
# Will instruct you how to get an API key from huggingface hub, make one with "Read" permission.
login()
model: ESM3InferenceClient = ESM3.from_pretrained("esm3-open").to("cuda") # or "cpu"
url = "http://prodata.swmed.edu/ecod/af2_pdb/domain/"


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

.gitattributes:   0%|          | 0.00/1.52k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/3.00 [00:00<?, ?B/s]

data/entry_list_safety_29026.list:   0%|          | 0.00/1.60M [00:00<?, ?B/s]

data/1utn.pdb:   0%|          | 0.00/569k [00:00<?, ?B/s]

data/ParentChildTreeFile.txt:   0%|          | 0.00/595k [00:00<?, ?B/s]

data/esm3_entry.list:   0%|          | 0.00/1.93M [00:00<?, ?B/s]

hyperplanes_8bit_58641.npz:   0%|          | 0.00/30.0M [00:00<?, ?B/s]

hyperplanes_8bit_68103.npz:   0%|          | 0.00/34.9M [00:00<?, ?B/s]

data/interpro2keywords.csv:   0%|          | 0.00/7.32M [00:00<?, ?B/s]

(…)ata/interpro_29026_to_keywords_58641.csv:   0%|          | 0.00/10.1M [00:00<?, ?B/s]

data/keywords.txt:   0%|          | 0.00/788k [00:00<?, ?B/s]

(…)ord_vocabulary_safety_filtered_58641.txt:   0%|          | 0.00/788k [00:00<?, ?B/s]

keyword_idf_safety_filtered_58641.npy:   0%|          | 0.00/469k [00:00<?, ?B/s]

data/tag_dict_4.json:   0%|          | 0.00/691k [00:00<?, ?B/s]

data/tag_dict_4_safety_filtered.json:   0%|          | 0.00/569k [00:00<?, ?B/s]

tfidf_safety_filtered_58641.pkl:   0%|          | 0.00/2.02M [00:00<?, ?B/s]

esm3_function_decoder_v0.pth:   0%|          | 0.00/1.30G [00:00<?, ?B/s]

(…)0_residue_annotations_gt_1k_proteins.csv:   0%|          | 0.00/109k [00:00<?, ?B/s]

esm3_structure_decoder_v0.pth:   0%|          | 0.00/1.24G [00:00<?, ?B/s]

esm3_sm_open_v1.pth:   0%|          | 0.00/2.80G [00:00<?, ?B/s]

esm3_structure_encoder_v0.pth:   0%|          | 0.00/62.3M [00:00<?, ?B/s]

  state_dict = torch.load(


In [4]:
df = pd.read_csv('OMBB_data.csv')
df.head()

Unnamed: 0,id,strands,seq,seq_len
0,e1af6A1,18,VDFHGYARSGIGWTGSGGEQQCFQTTGAQSKYRLGNECETYAELKL...,421
1,e1kmoA2,22,IPQDFGIEAGVEGQLSPTSSQNNPKETHNLMVGGTADNGFGTALLY...,523
2,e1p4tA1,8,EGASGFYVQADAAHAKASSSLGSAKGFSPRISAGYRINDLRFAVDY...,155
3,e1prnA1,16,EISLNGYGRFGLQYVEDRGVGLEDTIISSRLRINIVGTTETDQGVT...,289
4,e1qd5A1,12,AVRGSIIANMLQEHDNPFTLYPYDTNYLIYTQTSDLNKEAIASYDW...,257


In [5]:
def download_pdb_file(id, base_url, output_filename):
    try:
        # Request the page
        response = requests.get(base_url + id)
        response.raise_for_status()  # Raise HTTPError for bad responses
        soup = BeautifulSoup(response.text, 'html.parser')

        # Find the "Coordinates" link under "Download files"
        link = soup.find('a', string="Coordinates")
        if link:
            # Handle relative URL by combining with the base URL
            href = urljoin(base_url, link['href'])

            # Download the file
            coord_response = requests.get(href)
            coord_response.raise_for_status()  # Raise HTTPError for bad responses
            with open(output_filename, 'wb') as file:
                file.write(coord_response.content)
            print(f'Coordinates file downloaded successfully as {output_filename}')
        else:
            print(f'No Coordinates link found for {id}')
    except Exception as e:
        print(f'Error: {e}')



In [6]:
def getPdbId(id, url):
    try:
        response = requests.get(url + id)
        soup = BeautifulSoup(response.text, 'html.parser')
        pdb_id = None
        link = soup.find('a', title="Link to PDB")
        if link:
            href = link['href']
            pdb_id = href.split("structureId=")[-1]
        if pdb_id is None:
            print(f'No PDB ID found for {id}')
        return pdb_id
    except Exception as e:
        print(f'Error: {e}')
        return None

In [7]:
def get_chain_and_range(id, url):
  try:
    # Request the page
    response = requests.get(url + id)
    response.raise_for_status()  # Raise HTTPError for bad responses

    # Parse the HTML
    soup = BeautifulSoup(response.text, "html.parser")

    # Find the <td> tag next to the <th> with the text "Range:"
    range_td = soup.find("th", string="Range:").find_next_sibling("td")
    if range_td:
      # Extract the range text
      range_text = range_td.text.strip()

      # Parse the ranges into a structured format
      ranges = []
      for segment in range_text.split(","):
          chain, residues = segment.split(":")
          start, end = map(int, residues.split("-"))
          ranges.append({"chain": chain, "start": start, "end": end})
      return ranges
  except Exception as e:
    print(f'Error: {e}')
    return None

In [8]:

# error in files e2wjrA1, e3bryA1, e3qq2A1, e3sy7A2, e3szvA1, e4afkA1, e4c00A4, e4cu4A2, e4fqeA1, e4frxA1, e4fspA1, e4q35A2, e4rdrA2, e4rjwA1, e5dl5A1, e5fokA1, e5fp1A1, e5fq8B2, e5fr8A2, e5fvnA1,
#e5ldvA1, e5m9bA1, e5mdoA1, e5o65A1, e5t3rD1, e6e4vA1, e6ehbA1, e6ehdA1, e6fokA1, e6gieA1, e6i96A1, e6r2qB1, e6sljA1, e6ucuA1, e6v81A2

def load_protein_chains(cache_path):
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as file:
            protein_chains = pickle.load(file)
        print("Loaded list")
    else:
        protein_chains = []
        for idx, row in tqdm(df.iterrows(), total=len(df), desc='Fetching ProteinChains'):
            id = row['id']
            path = f"pdb_files/{id}.pdb"
            if not os.path.exists(path):
                download_pdb_file(id, url, path)
            try:
                out_membraine_chain = ProteinChain.from_pdb(path)
            except ValueError as e:
                print(f"ValueError while processing {id} at {path}: {e}")
                # Attempt to fetch the PDB from an alternative source if ValueError occurs
                pdb_id = getPdbId(id, url)

                # Get chain and ranges from ECOD
                range_data = get_chain_and_range(id, url)

                # Load pdb from RCSB and filter it using chain and range data
                filtered_residues = []
                residues = []
                for range in range_data:
                  out_membraine_chain = ProteinChain.from_rcsb(pdb_id, range['chain'])
                  start, end = range['start'], range['end']
                  for residue in out_membraine_chain.residue_index:
                      if start <= residue <= end:
                          filtered_residues.append(residue)
                      residues.append(residue)
                  filtered_residues_idx = [residues.index(value) for value in filtered_residues if value in residues]

                out_membraine_chain = out_membraine_chain[filtered_residues_idx]

            except Exception as e:
                print(f"MAX ERROR!!!!. Error while processing {id} at {path}: {e}")
            protein_chains.append(out_membraine_chain)

        with open(cache_path, 'wb') as file:
            pickle.dump(protein_chains, file)
        print("List saved successfully!")

    return protein_chains

In [9]:
def log_py3Dmol_to_wandb(view, pdb_id):
    temp_html_path = Path("temp_protein_view.html")

    # Write HTML content using a file handle
    with temp_html_path.open('w', encoding='utf-8') as f:
        view.write_html(f, fullpage=True)

    # Read the HTML content
    with temp_html_path.open('r', encoding='utf-8') as f:
        html_content = f.read()

    # Log to W&B as HTML
    wandb.log({
        pdb_id: wandb.Html(html_content)
    })

    # Clean up temporary file
    temp_html_path.unlink()

In [10]:
def view_protein_chain(protein_chain,id):
  view = py3Dmol.view(width=500, height=500)

  # py3Dmol requires the atomic coordinates to be in PDB format, so we convert the `ProteinChain` object to a PDB string
  pdb_str = protein_chain.to_pdb_string()
  # Load the PDB string into the `py3Dmol` view object
  view.addModel(pdb_str, "pdb")
  # Set the style of the protein chain
  view.setStyle({"cartoon": {"color": "spectrum"}})
  # Zoom in on the protein chain
  view.zoomTo()
  # Display the protein chain
  #view.show()
  log_py3Dmol_to_wandb(view, id)


In [11]:
def create_masked_protein_chain(protein_chain, mask_percent=0.1):
  mask_pos = int(len(protein_chain.sequence) * (1-mask_percent))
  mask_amount = len(protein_chain.sequence) - mask_pos
  # Create a mask for the sequence
  #print(f"Masking {mask_percent * 100}% ({mask_amount}) chars at the end of the sequence.")
  sequence_prompt = protein_chain.sequence[:mask_pos] + ''.join(['_'] * mask_amount)
  #print("Sequence prompt:", sequence_prompt)
  return sequence_prompt, mask_pos

In [12]:
def view_masked_protein_chain(inds, pdb_str, id):
  view = py3Dmol.view(width=500, height=500)
  view.addModel(pdb_str, format='pdb')
  view.setStyle({"cartoon": {"color": "lightgrey"}})
  view.addStyle({"resi": (inds + 1).tolist()}, {"cartoon": {"color": "cyan"}})
  view.zoomTo()
  #view.show()
  log_py3Dmol_to_wandb(view, id)

In [13]:
def calc_seq_similarity(sequence1, sequence2):
  blosum62 = substitution_matrices.load("BLOSUM62")

  if len(sequence1) != len(sequence2):
        raise ValueError("Sequences must have the same length!")

  blosum62 = substitution_matrices.load("BLOSUM62")

  # Calculate similarity
  similarity_score = sum(blosum62[(res_seq1, res_seq2)] for res_seq1, res_seq2 in zip(sequence1, sequence2))

  # Normalize score
  normalized_score = similarity_score / len(sequence2) if sequence2 else 0

  return similarity_score, normalized_score

In [14]:
def calc_seq_identity(sequence1, sequence2):
  if len(sequence1) != len(sequence2):
    raise ValueError("Sequences must have the same length")
  else:
    return sum(1 for a, b in zip(sequence1, sequence2) if a == b) / len(sequence1)

In [15]:
def predict(sequence_prompt, prior_sequence, sequence_generation_config, structure_prediction_config):

  protein = ESMProtein(sequence=sequence_prompt)
  sequence_generation = model.generate(protein, sequence_generation_config)
  #print("Sequence Prompt:\n\t", protein.sequence)
  #print("Generated sequence:\n\t", sequence_generation.sequence)

  # gets embeddings for the sequence generation
  protein_tensor_gen = model.encode(sequence_generation)
  output_gen = model.forward_and_sample(
      protein_tensor_gen, SamplingConfig(return_per_residue_embeddings=True)
  )
  embeddings_sequence_generation = output_gen.per_residue_embedding

  # generate structure
  structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)
  structure_prediction = model.generate(
      structure_prediction_prompt, structure_prediction_config
  )
  # get structure embeddings
  protein_tensor_struct = model.encode(structure_prediction)
  output_struct = model.forward_and_sample(
      protein_tensor_struct, SamplingConfig(return_per_residue_embeddings=True)
  )
  embeddings_structure_prediction = output_struct.per_residue_embedding


  # generate structure for PRIOR
  structure_prior = ESMProtein(sequence=prior_sequence)
  # get structure embeddings for PRIOR
  protein_tensor_struct_prior = model.encode(structure_prior)
  output_prior = model.forward_and_sample(
      protein_tensor_struct_prior, SamplingConfig(return_per_residue_embeddings=True)
  )
  embeddings_structure_prior = output_prior.per_residue_embedding

  #get the average of all embeddings
  embeddings_sequence_generation = torch.mean(embeddings_sequence_generation, dim=0)
  embeddings_structure_prediction = torch.mean(embeddings_structure_prediction, dim=0)
  embeddings_structure_prior = torch.mean(embeddings_structure_prior, dim=0)

  del structure_prediction_prompt
  torch.cuda.empty_cache()
  return structure_prediction, embeddings_sequence_generation, embeddings_structure_prediction,embeddings_structure_prior

In [16]:
def view_aligned_structures(pdb1, pdb2, gt_inds, pred_inds, id):
  # Display the aligned structures using py3Dmol
  view = py3Dmol.view(width=1000, height=500)
  view.addModel(pdb1, "pdb")
  view.setStyle({'model': 0}, {"cartoon": {"color": "lightgrey"}})
  view.addStyle({"resi": (gt_inds + 1).tolist()}, {"cartoon": {"color": "cyan"}})
  view.addModel(pdb2, "pdb")
  view.setStyle({'model': 1}, {"cartoon": {"color": "lightgreen"}})
  view.addStyle({"resi": (pred_inds + 1).tolist()}, {"cartoon": {"color": "cyan"}})
  view.zoomTo()
  #view.show()
  log_py3Dmol_to_wandb(view, id)

In [17]:
def view_side_by_side_structures(pdb1, pdb2, gt_inds, pred_inds, id):
    # Display the side-by-side structures using py3Dmol
    view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))
    view.addModel(pdb1, "pdb", viewer=(0, 0))
    view.setStyle({"cartoon": {"color": "lightgrey"}}, viewer=(0, 0))
    view.addStyle({"resi": (gt_inds + 1).tolist()}, {"cartoon": {"color": "cyan"}}, viewer=(0, 0))
    view.addModel(pdb2, "pdb", viewer=(0, 1))
    view.setStyle({"cartoon": {"color": "lightgreen"}}, viewer=(0, 1))
    view.addStyle({"resi": (pred_inds + 1).tolist()}, {"cartoon": {"color": "cyan"}}, viewer=(0, 1))
    view.zoomTo()
    #view.show()
    log_py3Dmol_to_wandb(view, id)


In [18]:
def getGenerationConfigs(sequence_prompt):
    sequence_generation_config = GenerationConfig(
        track="sequence",
        num_steps=sequence_prompt.count("_")
        // 4,
        temperature=0.5,
        top_p = 1,
        schedule='cosine'
    )
    structure_prediction_config = GenerationConfig(
        track="structure",
        num_steps=len(sequence_prompt) // 10,
        temperature=0.7,
        top_p = 1,
        schedule='cosine'
    )
    return sequence_generation_config, structure_prediction_config

In [19]:
def is_closed_protein(protein_chain, threshold=8):
    """
    Calculate the Euclidean distance between the first and last residue's CA atoms return if its closed
    """

    # Get coordinates of first and last residue CA atoms
    first_residue = protein_chain.atom37_positions[0,1,:]  # first CA ATOMS
    last_residue = protein_chain.atom37_positions[-1,1,:] # last CA ATOM

    # Euclidean distance
    distance = np.linalg.norm(first_residue - last_residue)

    is_closed = distance < threshold

    if is_closed:
        return 'Closed'
    return 'Open'

In [20]:
def find_interesting_cases(cosine_distances, rmsd_results, labels, gt_barrel,
                           pred_barrel, seq_identities, seq_similarities,
                           norm_seq_similarities, cosine_threshold=0.1, rmsd_threshold=8):
    interesting_cases = []
    for i in range(len(cosine_distances)):
        for j in range(len(cosine_distances[i])):
            if cosine_distances[i][j] < cosine_threshold and rmsd_results[j] > rmsd_threshold:
                interesting_cases.append([
                    labels[j],
                    cosine_distances[i][j],
                    rmsd_results[j],
                    gt_barrel[j],
                    pred_barrel[j],
                    seq_identities[j],
                    seq_similarities[j],
                    norm_seq_similarities[j]
                ])
    # Sort by the ratio of RMSD to cosine distance
    interesting_cases.sort(key=lambda x: x[2] / x[1], reverse=True)
    wandb.log({
            "interesting_cases": wandb.Table(
                data=interesting_cases,
                columns=["Protein ID", "Cosine Distance", "RMSD", "gt closed/open",
                         "pred closed/open", "Sequence Identity", "Sequence Similarity",
                         "Norm. Sequence Similarity"]
            )
    })

In [21]:
def analyze_barrel_predictions(labels,open_or_closed_pred, rmsd_results, open_or_closed_gt, rmsd_threshold=8):
    mismatched_data = []

    for label,pred, gt, rmsd in zip(labels,open_or_closed_pred, open_or_closed_gt, rmsd_results):
        if pred != gt and rmsd > rmsd_threshold:
            mismatched_data.append([label,pred, gt, rmsd])

    wandb.log({
        "high_rmsd_mismatches": wandb.Table(
            data=mismatched_data,
            columns=["ID","Predicted Type", "Ground Truth", "RMSD"]
        )
    })

In [37]:
def log_summary_statistics(rmsd_results, cosine_distance, correlations):
    mean = np.mean(rmsd_results)
    std = np.std(rmsd_results)
    min_val = np.min(rmsd_results)
    max_val = np.max(rmsd_results)
    median = np.median(rmsd_results)
    min_consine = np.min(cosine_distance)
    max_consine = np.max(cosine_distance)
    mean_consine = np.mean(cosine_distance)
    median_consine = np.median(cosine_distance)


    wandb.run.summary.update({
        "mean_rmsd": mean,
        "std_rmsd": std,
        "min_rmsd": min_val,
        "max_rmsd": max_val,
        "median_rmsd": median,
        "mean_cosine_distance": mean_consine,
        "min_cosine_distance": min_consine,
        "max_cosine_distance": max_consine,
        "median_cosine_distance": median_consine,
        "pearson_correlation": correlations["Pearson Correlation"],
        "spearman_correlation": correlations["Spearman Correlation"],
    })

def log_histogram(rmsd_results):
    fig = go.Figure(data=[go.Histogram(x=rmsd_results, nbinsx=30)])
    with tempfile.NamedTemporaryFile(suffix='.html', delete=False) as tmp_file:
        html_path = tmp_file.name
        fig.write_html(html_path)
        wandb.log({"Histogram of rmsd values": wandb.Html(html_path)})

def log_scatter_plot(x, y, ids, x_title, y_title, plot_title, log_name):
    fig = go.Figure(data=[go.Scatter(
        x=x,
        y=y,
        mode='markers',
        marker=dict(size=10, color=y, colorscale='Viridis', colorbar=dict(title='RMSD')),
        text=ids
    )])

    fig.update_layout(
        title=plot_title,
        xaxis_title=x_title,
        yaxis_title=y_title
    )

    with tempfile.NamedTemporaryFile(suffix='.html', delete=False) as tmp_file:
        html_path = tmp_file.name
        fig.write_html(html_path)
        wandb.log({log_name: wandb.Html(html_path)})

def log_umap_plot(embeddings, labels, ids, plot_title, log_name, useBuckets = True):
    if useBuckets:
        # Create 5 buckets using percentiles and get the bins
        buckets, bin_edges = pd.qcut(labels, q=5, retbins=True, duplicates="drop")

        # Create bucket labels from the bin edges
        bucket_labels = [f'{bin_edges[i]:.2f}-{bin_edges[i+1]:.2f}'
                        for i in range(len(bin_edges)-1)]

        # Map each value to its bucket label
        label_mapping = dict(zip(sorted(set(buckets)), bucket_labels))
        labels = [label_mapping[val] for val in buckets]
    traces = []
    for unique_id in sorted(set(labels)):
        # Filter points for this specific bucket/ID
        mask = [i for i, id_val in enumerate(labels) if id_val == unique_id]

        trace = go.Scatter(
            x=embeddings[mask, 0],
            y=embeddings[mask, 1],
            mode='markers',
            name=f'{unique_id}',
            marker=dict(size=6),
            text=[f'{ids[i]}, {labels[i]}' for i in mask],
            hoverinfo='text'
        )
        traces.append(trace)

    fig = go.Figure(data=traces)

    fig.update_layout(
        title=plot_title,
        xaxis_title='TSNE Component 1',
        yaxis_title='TSNE Component 2',
        legend_title_text='Value Ranges',
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01,
            font=dict(size=8),
            itemsizing='constant',
            traceorder='grouped',
            itemwidth=30
        ),
        legend_itemwidth=30
    )

    with tempfile.NamedTemporaryFile(suffix='.html', delete=False) as tmp_file:
        html_path = tmp_file.name
        fig.write_html(html_path)
        wandb.log({log_name: wandb.Html(html_path)})

def log_heatmaps(df):
    df_sorted_by_length = df.sort_values(by="seq_len", ascending=False)
    x_axis = df_sorted_by_length['seq_len']

    # Create heatmaps
    def create_heatmap(z, title):
      return go.Figure(
          data=go.Heatmap(
              z=[z],
              x=x_axis,
              colorscale="Viridis",
              colorbar_title="Score",
          )
      ).update_layout(
          title=title,
          xaxis_title="Sequence Length",
          yaxis=dict(showticklabels=False),  # Move showticklabels here
          template="plotly",
      )

    # Identity Heatmap
    heatmap_identity = create_heatmap(df_sorted_by_length["seq_identity"], "Identity Heatmap")
    wandb.log({"Identity Heatmap": wandb.Html(heatmap_identity.to_html())})

    # Similarity Heatmap
    heatmap_similarity = create_heatmap(df_sorted_by_length["seq_similarity"], "Similarity Heatmap")
    wandb.log({"Similarity Heatmap": wandb.Html(heatmap_similarity.to_html())})

    # Normalized Similarity Heatmap
    heatmap_normalized_similarity = create_heatmap(
        df_sorted_by_length["norm_seq_similarity"], "Normalized Similarity Heatmap"
    )
    wandb.log({"Normalized Similarity Heatmap": wandb.Html(heatmap_normalized_similarity.to_html())})

def log_corrolation_plot(df, correlations):
  pearson_corr = correlations["Pearson Correlation"]
  spearman_corr = correlations["Spearman Correlation"]

  fig = px.scatter(df, x='seq_len', y='rmsd',
                  trendline="ols",  # Adds regression line
                  title=f'Correlation between Sequence Length and RMSD\nPearson: {pearson_corr:.2f}, Spearman: {spearman_corr:.2f}',
                  labels={"sequence_length": "Sequence Length", "rmsd": "RMSD"})

  wandb.log({"scatter_plot": wandb.Html(fig.to_html())})

def wand_logs(df, sequence_embeddings, structure_embeddings, all_priors, labels, open_or_closed_gt, open_or_closed_pred, correlations):
    rmsd_results = df['rmsd'].to_list()
    ids = df['id'].to_list()
    seq_len = df['seq_len'].to_list()
    strands = df['strands'].to_list()
    seq_identities = df['seq_identity'].to_list()
    seq_similarities = df['seq_similarity'].to_list()
    norm_seq_similarities = df['norm_seq_similarity'].to_list()

    seq_len_per_embedding = []
    strand_per_embedding = []
    rmsd_per_embedding = []
    for id in labels:
        seq_length = df[df['id'] == id]['seq_len'].values[0].astype(int)
        strand_value = df[df['id'] == id]['strands'].values[0].astype(int)
        rmsd_value = df[df['id'] == id]['rmsd'].values[0].astype(float)
        seq_len_per_embedding.append(seq_length)
        strand_per_embedding.append(strand_value)
        rmsd_per_embedding.append(rmsd_value)


    protein_metrics = sorted(list(zip(ids, rmsd_results, seq_len, strands,
                                      open_or_closed_gt, open_or_closed_pred,
                                      seq_identities, seq_similarities,
                                      norm_seq_similarities)), key= lambda x: x[2]) # sort by seq length
    wandb.log({
        "rmsd_results": wandb.Table(
            data=protein_metrics,
            columns=["ID", "RMSD", "Sequence Length", "Strands", "Open/Closed GT",
                     "Open/Closed Pred", "Sequence Identity", "Sequence Similarity",
                     "Norm. Sequence Similarity"]
        )
    })

    log_histogram(rmsd_results)

    log_corrolation_plot(df, correlations)

    log_scatter_plot(seq_len, rmsd_results, ids, "Sequence Length", "RMSD Value", "RMSD Values by sequence length", "RMSD Scatter plot")
    log_scatter_plot(strands, rmsd_results, ids, "Number of Strands", "RMSD Value", "RMSD Values by number of strands", "RMSD Scatter plot")

    log_heatmaps(df)

    sequence_reducer = TSNE(n_components=2, perplexity=5)
    sequence_emb = sequence_reducer.fit_transform(sequence_embeddings)

    structure_reducer = TSNE(n_components=2, perplexity=5)
    structure_emb = structure_reducer.fit_transform(structure_embeddings)

    log_umap_plot(sequence_emb, seq_len_per_embedding, labels, "Protein Sequence Embedding Colored By Sequence Len", "Embeddings Sequence")
    log_umap_plot(sequence_emb, strand_per_embedding, labels, "Protein Sequence Embeddings Colored By Strands", "Embeddings Sequence")

    log_umap_plot(structure_emb, seq_len_per_embedding,labels,"Protein Structure Embeddings Colored By Sequence Len", "Embeddings Structure")
    log_umap_plot(structure_emb, strand_per_embedding,labels,"Protein Structure Embeddings Colored By Strands", "Embeddings Structure")

    log_umap_plot(sequence_emb, rmsd_per_embedding, labels,"Protein Sequence Embeddings Colored By RMSD", "Embeddings Sequence")
    log_umap_plot(structure_emb, rmsd_per_embedding, labels,"Protein Structure Embeddings Colored By RMSD", "Embeddings Structure")

    combined_embeddings = np.concatenate((structure_embeddings, all_priors))
    prior_labels = ["Original"] * len(all_priors)
    struct_labels = ["Generated"] * len(structure_embeddings)
    all_labels = struct_labels + prior_labels
    structure_reducer = TSNE(n_components=2, perplexity=5)
    combined_embs = structure_reducer.fit_transform(combined_embeddings)

    id_labels = labels + labels

    log_umap_plot(combined_embs, all_labels, id_labels, "Protein Structure Original vs Generated Embeddings", "Embeddings Space", useBuckets=False)

    cosine_distances = cdist(all_priors, structure_embeddings, metric='cosine')

    find_interesting_cases(cosine_distances, rmsd_results, labels, open_or_closed_gt,
                           open_or_closed_pred, seq_identities, seq_similarities,
                           norm_seq_similarities)

    analyze_barrel_predictions(labels, open_or_closed_pred, rmsd_results, open_or_closed_gt)

    log_summary_statistics(rmsd_results, cosine_distances, correlations)


In [23]:
def calc_resi_pos(protein_chain, mask_percantage):
  # Get actual residue values
  residues = protein_chain.residue_index

  total_residues = len(residues)
  mask_pos = int(total_residues * (1-mask_percantage))

  return residues[0], residues[mask_pos], residues[total_residues-1]

In [24]:
def calculate_correlation(chain_lengths, rmsds):
    if len(chain_lengths) != len(rmsds):
        raise ValueError("Chain lengths and RMSDs must have the same number of elements!")

    # Pearson correlation
    pearson_corr, pearson_p = stats.pearsonr(chain_lengths, rmsds)

    # Spearman correlation
    spearman_corr, spearman_p = stats.spearmanr(chain_lengths, rmsds)

    return {
        "Pearson Correlation": pearson_corr,
        "Pearson p-value": pearson_p,
        "Spearman Correlation": spearman_corr,
        "Spearman p-value": spearman_p,
    }


In [38]:
number_of_runs = 1
for i in range(number_of_runs):
    print("Run number: ", i)
    mask_percantage = 0.1
    wandb.init(project="DFold", config={"mask_percantage": mask_percantage})
    cache_path = 'protein_chains.pkl'
    protein_chains = load_protein_chains(cache_path)
    rmsd_results = []
    ids = df['id'].to_list()
    all_sequence_embeddings = []
    all_structure_embeddings = []
    all_priors = []
    labels = []
    open_or_closed_gt = []
    open_or_closed_pred = []
    seq_identities = []
    seq_similarities = []
    norm_seq_similarities = []
    for protein_chain,id in zip(protein_chains, ids):
        print("id", id)
        #print("Protein sequence length: {}".format(len(protein_chain)))
        #print("Sequence: {}".format(protein_chain.sequence))

        # View loaded protein chain
        #print("Loaded protein:")
        view_protein_chain(protein_chain, id)
        masked_protein, mask_pos = create_masked_protein_chain(protein_chain, mask_percent=mask_percantage)

        gt_resi_start, gt_resi_mask_start, gt_resi_end = calc_resi_pos(protein_chain, mask_percantage)
        gt_inds = np.arange(gt_resi_start-1, gt_resi_mask_start-1)
        gt_masked_inds = np.arange(gt_resi_mask_start-1, gt_resi_end-1)

        # View masked protein
        #print("Masked protein:")
        pdb_str = protein_chain.to_pdb_string()
        view_masked_protein_chain(gt_inds, pdb_str, id)

        # get generation configs
        sequence_generation_config, structure_prediction_config = getGenerationConfigs(masked_protein)

        # add to wandb config
        wandb.config.update({
            "sequence_num_steps": sequence_generation_config.num_steps,
            "sequence_temperature": sequence_generation_config.temperature,
            "sequence_top_p": sequence_generation_config.top_p,
            "sequence_schedule": sequence_generation_config.schedule,
            "structure_num_steps": structure_prediction_config.num_steps,
            "structure_temperature": structure_prediction_config.temperature,
            "structure_top_p": structure_prediction_config.top_p,
            "structure_schedule": structure_prediction_config.schedule
        })

        # get the prior
        prior_sequence = protein_chain.sequence

        # Use ESM3 to predict protein structure of the masked protein
        structure_prediction, sequence_embeddings, structure_embeddings, prior_embeddings = predict(masked_protein, prior_sequence, sequence_generation_config, structure_prediction_config)

        labels.extend([id])
        # Convert the structure prediction to a ProteinChain object
        structure_prediction_chain = structure_prediction.to_protein_chain()

        pred_resi_start, pred_resi_mask_start, pred_resi_end = calc_resi_pos(structure_prediction_chain, mask_percantage)
        pred_inds = np.arange(pred_resi_start-1, pred_resi_mask_start-1)
        pred_masked_inds = np.arange(pred_resi_mask_start-1, pred_resi_end-1)

        # Align the generated structure with the original structure using the non-masked sequence
        aligned_chain = structure_prediction_chain.align(
            protein_chain, mobile_inds=pred_inds, target_inds=pred_inds)

        # View aligned structures
        pdb1 = protein_chain.to_pdb_string()
        pdb2 = aligned_chain.to_pdb_string()
        view_aligned_structures(pdb1, pdb2, gt_inds, pred_inds, id)

        # Calculate RMSD on the masked part
        view_side_by_side_structures(pdb1, pdb2, gt_inds, pred_inds, id)
        crmsd_masked = aligned_chain.rmsd(protein_chain, mobile_inds=pred_masked_inds,
                                        target_inds=pred_masked_inds, only_compute_backbone_rmsd=True)
        # print("RMSD: ", crmsd_masked)

        # Calculate sequence identity
        sequence_identity = calc_seq_identity(protein_chain.sequence[pred_masked_inds[0]:],
                                              aligned_chain.sequence[pred_masked_inds[0]:])
        # print("Sequence identity: ", sequence_identity)

        # Calculate sequence similarity (using BLOSUM62)
        seq_similarity, norm_seq_similarity = calc_seq_similarity(protein_chain.sequence[pred_masked_inds[0]:],
                                                                  aligned_chain.sequence[pred_masked_inds[0]:])
        # print("Sequence similarity: ", seq_similarity)
        # print("Normalized sequence similarity: ", norm_seq_similarity)

        open_or_closed_gt.append(is_closed_protein(protein_chain))
        open_or_closed_pred.append(is_closed_protein(structure_prediction_chain))
        rmsd_results.append(crmsd_masked)
        all_sequence_embeddings.append(sequence_embeddings)
        all_structure_embeddings.append(structure_embeddings)
        all_priors.append(prior_embeddings)
        seq_identities.append(sequence_identity)
        seq_similarities.append(seq_similarity)
        norm_seq_similarities.append(norm_seq_similarity)
        del aligned_chain, structure_prediction_chain, protein_chain

    df['rmsd'] = rmsd_results
    df['seq_identity'] = seq_identities
    df['seq_similarity'] = seq_similarities
    df['norm_seq_similarity'] = norm_seq_similarities

    # Calculate corrolation between sequence length and RSMD
    correlations = calculate_correlation(df['seq_len'], rmsd_results)

    all_sequence_embeddings = torch.stack(all_sequence_embeddings).detach().cpu().numpy()
    all_structure_embeddings = torch.stack(all_structure_embeddings).detach().cpu().numpy()
    all_priors = torch.stack(all_priors).detach().cpu().numpy()

    wand_logs(df, all_sequence_embeddings, all_structure_embeddings, all_priors, labels, open_or_closed_gt, open_or_closed_pred, correlations)
    wandb.finish()

    #df.to_csv('OMBB_data_crmsd.csv', index=False)




Run number:  0
Loaded list
id e1af6A1


100%|██████████| 10/10 [00:00<00:00, 12.81it/s]
100%|██████████| 42/42 [00:03<00:00, 12.83it/s]

`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.



id e1kmoA2


100%|██████████| 13/13 [00:01<00:00, 12.82it/s]
100%|██████████| 52/52 [00:04<00:00, 12.77it/s]


id e1p4tA1


100%|██████████| 4/4 [00:00<00:00, 12.14it/s]
100%|██████████| 15/15 [00:01<00:00, 12.59it/s]


id e1prnA1


100%|██████████| 7/7 [00:00<00:00, 12.34it/s]
100%|██████████| 28/28 [00:02<00:00, 12.25it/s]


id e1qd5A1


100%|██████████| 6/6 [00:00<00:00, 12.56it/s]
100%|██████████| 25/25 [00:01<00:00, 12.76it/s]


id e1qj8A1


100%|██████████| 3/3 [00:00<00:00, 12.49it/s]
100%|██████████| 14/14 [00:01<00:00, 12.67it/s]


id e1qjpA1


100%|██████████| 3/3 [00:00<00:00, 11.89it/s]
100%|██████████| 13/13 [00:01<00:00, 12.66it/s]


id e1t16A1


100%|██████████| 10/10 [00:00<00:00, 12.91it/s]
100%|██████████| 42/42 [00:03<00:00, 13.00it/s]


id e1tlyA1


100%|██████████| 6/6 [00:00<00:00, 12.66it/s]
100%|██████████| 25/25 [00:01<00:00, 12.79it/s]


id e1xkwA1


100%|██████████| 13/13 [00:00<00:00, 13.06it/s]
100%|██████████| 52/52 [00:04<00:00, 12.78it/s]


id e2ervA1


100%|██████████| 3/3 [00:00<00:00, 12.38it/s]
100%|██████████| 15/15 [00:01<00:00, 12.71it/s]


id e2f1vA1


100%|██████████| 4/4 [00:00<00:00, 12.55it/s]
100%|██████████| 18/18 [00:01<00:00, 12.71it/s]


id e2fgrA1


100%|██████████| 8/8 [00:00<00:00, 12.64it/s]
100%|██████████| 33/33 [00:02<00:00, 12.63it/s]


id e2gskA1


100%|██████████| 11/11 [00:00<00:00, 12.87it/s]
100%|██████████| 46/46 [00:03<00:00, 12.77it/s]


id e2hdiA1


100%|██████████| 11/11 [00:00<00:00, 12.92it/s]
100%|██████████| 46/46 [00:03<00:00, 12.86it/s]


id e2iahA3


100%|██████████| 13/13 [00:01<00:00, 12.92it/s]
100%|██████████| 54/54 [00:04<00:00, 12.76it/s]


id e2lhfA1


100%|██████████| 4/4 [00:00<00:00, 12.66it/s]
100%|██████████| 17/17 [00:01<00:00, 12.83it/s]


id e2mafA1


100%|██████████| 6/6 [00:00<00:00, 12.72it/s]
100%|██████████| 23/23 [00:01<00:00, 12.85it/s]


id e2porA1


100%|██████████| 7/7 [00:00<00:00, 12.70it/s]
100%|██████████| 30/30 [00:02<00:00, 12.73it/s]


id e2vdfA1


100%|██████████| 5/5 [00:00<00:00, 12.69it/s]
100%|██████████| 22/22 [00:01<00:00, 12.75it/s]


id e2wjrA1


100%|██████████| 5/5 [00:00<00:00, 12.19it/s]
100%|██████████| 20/20 [00:01<00:00, 12.76it/s]


id e2x55A1


100%|██████████| 7/7 [00:00<00:00, 12.79it/s]
100%|██████████| 27/27 [00:02<00:00, 12.87it/s]


id e2ynkA1


100%|██████████| 11/11 [00:00<00:00, 13.05it/s]
100%|██████████| 44/44 [00:03<00:00, 12.94it/s]


id e3aehA1


100%|██████████| 7/7 [00:00<00:00, 12.51it/s]
100%|██████████| 27/27 [00:02<00:00, 12.59it/s]


id e3bryA1


100%|██████████| 9/9 [00:00<00:00, 12.81it/s]
100%|██████████| 38/38 [00:02<00:00, 12.87it/s]


id e3bs0A1


100%|██████████| 10/10 [00:00<00:00, 12.88it/s]
100%|██████████| 41/41 [00:03<00:00, 12.78it/s]


id e3cslA1


100%|██████████| 16/16 [00:01<00:00, 12.60it/s]
100%|██████████| 63/63 [00:04<00:00, 12.71it/s]


id e3dzmA1


100%|██████████| 5/5 [00:00<00:00, 12.53it/s]
100%|██████████| 20/20 [00:01<00:00, 12.67it/s]


id e3efmA1


100%|██████████| 11/11 [00:00<00:00, 13.02it/s]
100%|██████████| 43/43 [00:03<00:00, 12.78it/s]


id e3fhhA2


100%|██████████| 12/12 [00:00<00:00, 12.91it/s]
100%|██████████| 49/49 [00:03<00:00, 12.81it/s]


id e3fidA1


100%|██████████| 7/7 [00:00<00:00, 12.62it/s]
100%|██████████| 29/29 [00:02<00:00, 12.66it/s]


id e3fipA1


100%|██████████| 9/9 [00:00<00:00, 12.95it/s]
100%|██████████| 36/36 [00:02<00:00, 12.61it/s]


id e3gp6A1


100%|██████████| 4/4 [00:00<00:00, 12.83it/s]
100%|██████████| 15/15 [00:01<00:00, 12.77it/s]


id e3kvnA1


100%|██████████| 8/8 [00:00<00:00, 12.76it/s]
100%|██████████| 31/31 [00:02<00:00, 12.79it/s]


id e3ohnA1


100%|██████████| 10/10 [00:00<00:00, 12.92it/s]
100%|██████████| 39/39 [00:03<00:00, 12.84it/s]


id e3qlbA1


100%|██████████| 13/13 [00:01<00:00, 12.80it/s]
100%|██████████| 53/53 [00:04<00:00, 12.88it/s]


id e3qq2A1


100%|██████████| 6/6 [00:00<00:00, 12.52it/s]
100%|██████████| 24/24 [00:01<00:00, 12.45it/s]


id e3qraA1


100%|██████████| 4/4 [00:00<00:00, 12.68it/s]
100%|██████████| 15/15 [00:01<00:00, 12.82it/s]


id e3sy7A2


100%|██████████| 9/9 [00:00<00:00, 12.75it/s]
100%|██████████| 38/38 [00:03<00:00, 12.56it/s]


id e3sy9A2


100%|██████████| 9/9 [00:00<00:00, 12.77it/s]
100%|██████████| 37/37 [00:02<00:00, 12.84it/s]


id e3sybA2


100%|██████████| 10/10 [00:00<00:00, 12.91it/s]
100%|██████████| 40/40 [00:03<00:00, 12.85it/s]


id e3sysA1


100%|██████████| 9/9 [00:00<00:00, 12.73it/s]
100%|██████████| 37/37 [00:02<00:00, 12.78it/s]


id e3szvA1


100%|██████████| 8/8 [00:00<00:00, 12.51it/s]
100%|██████████| 32/32 [00:02<00:00, 12.57it/s]


id e3t0sA1


100%|██████████| 9/9 [00:00<00:00, 12.74it/s]
100%|██████████| 36/36 [00:02<00:00, 12.69it/s]


id e3v8xA1


100%|██████████| 18/18 [00:01<00:00, 12.36it/s]
100%|██████████| 72/72 [00:05<00:00, 12.73it/s]


id e4afkA1


100%|██████████| 10/10 [00:00<00:00, 12.62it/s]
100%|██████████| 39/39 [00:03<00:00, 12.61it/s]


id e4aipA1


100%|██████████| 13/13 [00:01<00:00, 12.53it/s]
100%|██████████| 54/54 [00:04<00:00, 12.59it/s]


id e4c00A4


100%|██████████| 8/8 [00:00<00:00, 12.44it/s]
100%|██████████| 31/31 [00:02<00:00, 12.43it/s]


id e4cu4A2


100%|██████████| 13/13 [00:00<00:00, 13.05it/s]
100%|██████████| 54/54 [00:04<00:00, 12.43it/s]


id e4d5bA1


100%|██████████| 8/8 [00:00<00:00, 12.56it/s]
100%|██████████| 31/31 [00:02<00:00, 12.52it/s]


id e4e1sA1


100%|██████████| 6/6 [00:00<00:00, 12.20it/s]
100%|██████████| 24/24 [00:01<00:00, 12.45it/s]


id e4e1tA1


100%|██████████| 6/6 [00:00<00:00, 12.55it/s]
100%|██████████| 24/24 [00:01<00:00, 12.33it/s]


id e4epaA1


100%|██████████| 12/12 [00:00<00:00, 12.37it/s]
100%|██████████| 50/50 [00:03<00:00, 12.65it/s]


id e4fqeA1


100%|██████████| 4/4 [00:00<00:00, 12.52it/s]
100%|██████████| 16/16 [00:01<00:00, 12.74it/s]


id e4frxA1


100%|██████████| 10/10 [00:00<00:00, 13.00it/s]
100%|██████████| 39/39 [00:03<00:00, 12.66it/s]


id e4fsoA1


100%|██████████| 9/9 [00:00<00:00, 12.70it/s]
100%|██████████| 35/35 [00:02<00:00, 12.66it/s]


id e4fspA1


100%|██████████| 8/8 [00:00<00:00, 12.40it/s]
100%|██████████| 32/32 [00:02<00:00, 12.45it/s]


id e4fuvA1


100%|██████████| 5/5 [00:00<00:00, 12.43it/s]
100%|██████████| 21/21 [00:01<00:00, 12.64it/s]


id e4geyA1


100%|██████████| 10/10 [00:00<00:00, 12.69it/s]
100%|██████████| 42/42 [00:03<00:00, 12.72it/s]


id e4k3bA6


100%|██████████| 9/9 [00:00<00:00, 12.67it/s]
100%|██████████| 37/37 [00:02<00:00, 12.70it/s]


id e4k3cA3


100%|██████████| 9/9 [00:00<00:00, 12.66it/s]
100%|██████████| 37/37 [00:02<00:00, 12.60it/s]


id e4meeA1


100%|██████████| 7/7 [00:00<00:00, 11.92it/s]
100%|██████████| 30/30 [00:02<00:00, 12.62it/s]


id e4n75A1


100%|██████████| 9/9 [00:00<00:00, 12.96it/s]
100%|██████████| 37/37 [00:02<00:00, 12.82it/s]


id e4q35A2


100%|██████████| 14/14 [00:01<00:00, 12.94it/s]
100%|██████████| 55/55 [00:04<00:00, 12.56it/s]


id e4qkyA1


100%|██████████| 7/7 [00:00<00:00, 12.41it/s]
100%|██████████| 28/28 [00:02<00:00, 12.61it/s]


id e4rdrA2


100%|██████████| 14/14 [00:01<00:00, 12.91it/s]
100%|██████████| 57/57 [00:04<00:00, 12.71it/s]


id e4rjwA1


100%|██████████| 10/10 [00:00<00:00, 12.69it/s]
100%|██████████| 39/39 [00:03<00:00, 12.59it/s]


id e4rl8A1


100%|██████████| 6/6 [00:00<00:00, 12.43it/s]
100%|██████████| 26/26 [00:02<00:00, 12.65it/s]


id e4rlcA1


100%|██████████| 3/3 [00:00<00:00, 12.57it/s]
100%|██████████| 13/13 [00:01<00:00, 12.84it/s]


id e4y25A1


100%|██████████| 7/7 [00:00<00:00, 12.54it/s]
100%|██████████| 29/29 [00:02<00:00, 12.68it/s]


id e4zgvA1


100%|██████████| 16/16 [00:01<00:00, 12.44it/s]
100%|██████████| 66/66 [00:05<00:00, 12.49it/s]


id e5dl5A1


100%|██████████| 10/10 [00:00<00:00, 12.75it/s]
100%|██████████| 40/40 [00:03<00:00, 12.71it/s]


id e5dl6A1


100%|██████████| 10/10 [00:00<00:00, 12.40it/s]
100%|██████████| 39/39 [00:03<00:00, 12.67it/s]


id e5dl7A1


100%|██████████| 10/10 [00:00<00:00, 12.71it/s]
100%|██████████| 40/40 [00:03<00:00, 12.68it/s]


id e5dl8A1


100%|██████████| 9/9 [00:00<00:00, 12.66it/s]
100%|██████████| 37/37 [00:02<00:00, 12.53it/s]


id e5fokA1


100%|██████████| 13/13 [00:01<00:00, 12.60it/s]
100%|██████████| 53/53 [00:04<00:00, 12.55it/s]


id e5fp1A1


100%|██████████| 14/14 [00:01<00:00, 12.79it/s]
100%|██████████| 57/57 [00:04<00:00, 12.49it/s]


id e5fq8B2


100%|██████████| 20/20 [00:01<00:00, 12.77it/s]
100%|██████████| 81/81 [00:06<00:00, 12.47it/s]


id e5fr8A2


100%|██████████| 14/14 [00:01<00:00, 12.76it/s]
100%|██████████| 56/56 [00:04<00:00, 12.29it/s]


id e5fvnA1


100%|██████████| 8/8 [00:00<00:00, 12.57it/s]
100%|██████████| 34/34 [00:02<00:00, 12.58it/s]


id e5ivaA1


100%|██████████| 14/14 [00:01<00:00, 12.86it/s]
100%|██████████| 57/57 [00:04<00:00, 12.49it/s]


id e5ldvA1


100%|██████████| 10/10 [00:00<00:00, 12.65it/s]
100%|██████████| 40/40 [00:03<00:00, 12.64it/s]


id e5m9bA1


100%|██████████| 14/14 [00:01<00:00, 12.75it/s]
100%|██████████| 56/56 [00:04<00:00, 12.51it/s]


id e5mdoA1


100%|██████████| 8/8 [00:00<00:00, 12.50it/s]
100%|██████████| 33/33 [00:02<00:00, 12.63it/s]


id e5o65A1


100%|██████████| 5/5 [00:00<00:00, 12.42it/s]
100%|██████████| 20/20 [00:01<00:00, 12.51it/s]


id e5o8oA1


100%|██████████| 8/8 [00:00<00:00, 12.44it/s]
100%|██████████| 33/33 [00:02<00:00, 12.44it/s]


id e5t3rD1


100%|██████████| 19/19 [00:01<00:00, 12.50it/s]
100%|██████████| 78/78 [00:06<00:00, 12.47it/s]


id e6bpmA1


100%|██████████| 14/14 [00:01<00:00, 12.78it/s]
100%|██████████| 58/58 [00:04<00:00, 12.49it/s]


id e6e4vA1


100%|██████████| 13/13 [00:01<00:00, 12.68it/s]
100%|██████████| 54/54 [00:04<00:00, 12.47it/s]


id e6ehbA1


100%|██████████| 8/8 [00:00<00:00, 12.54it/s]
100%|██████████| 31/31 [00:02<00:00, 12.63it/s]


id e6ehdA1


100%|██████████| 8/8 [00:00<00:00, 12.26it/s]
100%|██████████| 32/32 [00:02<00:00, 12.54it/s]


id e6eusA1


100%|██████████| 8/8 [00:00<00:00, 12.63it/s]
100%|██████████| 34/34 [00:02<00:00, 12.69it/s]


id e6fokA1


100%|██████████| 12/12 [00:00<00:00, 12.70it/s]
100%|██████████| 47/47 [00:03<00:00, 12.69it/s]


id e6gieA1


100%|██████████| 4/4 [00:00<00:00, 12.39it/s]
100%|██████████| 17/17 [00:01<00:00, 12.48it/s]


id e6h3iF1


100%|██████████| 8/8 [00:00<00:00, 12.52it/s]
100%|██████████| 32/32 [00:02<00:00, 12.49it/s]


id e6h7fA1


100%|██████████| 13/13 [00:01<00:00, 12.89it/s]
100%|██████████| 53/53 [00:04<00:00, 12.76it/s]


id e6i96A1


100%|██████████| 13/13 [00:01<00:00, 12.41it/s]
100%|██████████| 53/53 [00:04<00:00, 12.65it/s]


id e6qwrA1


100%|██████████| 5/5 [00:00<00:00, 12.44it/s]
100%|██████████| 21/21 [00:01<00:00, 12.70it/s]


id e6r2qB1


100%|██████████| 16/16 [00:01<00:00, 12.74it/s]
100%|██████████| 64/64 [00:05<00:00, 12.60it/s]


id e6sljA1


100%|██████████| 19/19 [00:01<00:00, 12.71it/s]
100%|██████████| 75/75 [00:05<00:00, 12.58it/s]


id e6ucuA1


100%|██████████| 7/7 [00:00<00:00, 12.38it/s]
100%|██████████| 30/30 [00:02<00:00, 12.71it/s]


id e6v81A2


100%|██████████| 14/14 [00:01<00:00, 12.96it/s]
100%|██████████| 55/55 [00:04<00:00, 12.56it/s]


0,1
max_cosine_distance,0.11308
max_rmsd,12.22372
mean_cosine_distance,0.04387
mean_rmsd,2.76368
median_cosine_distance,0.0415
median_rmsd,1.81325
min_cosine_distance,0.01398
min_rmsd,0.25524
pearson_correlation,0.38064
spearman_correlation,0.41502
