In [1]:
import os
import re
import shutil

import esm
import numpy as np
import pandas as pd
import plotnine as gg
import torch

# Read inference results

In [2]:
def _natural_key(name):
    """Sort key that orders embedded integers numerically ('sample_2' before 'sample_10')."""
    # re.split with a capture group alternates text / digit-run, so odd positions are digits.
    return [int(tok) if idx % 2 else tok for idx, tok in enumerate(re.split(r'(\d+)', name))]


def read_samples(results_dir):
    """Load every self-consistency CSV under one inference run directory.

    Expected layout::

        results_dir/<prefix>_<length>/<sample_name>/self_consistency/sc_results.csv

    Directory entries whose names contain a '.' (files such as logs or configs)
    are skipped at both levels. Sample directories that have no
    ``sc_results.csv`` are ignored, but they still consume a ``sample_id``.

    Args:
        results_dir: Root directory of one inference run.

    Returns:
        A DataFrame concatenating all CSVs, with two extra columns:
        ``length`` — the integer after the first '_' in the length directory name;
        ``sample_id`` — position of the sample directory in the naturally sorted
        listing of its length directory (deterministic across filesystems).

    Raises:
        ValueError: If no ``sc_results.csv`` file is found under ``results_dir``.
    """
    all_csvs = []
    print(f'Reading samples from {results_dir}')
    # Sort listings: os.listdir order is arbitrary, and sample_id must be reproducible
    # because downstream cells filter on it.
    length_names = sorted((n for n in os.listdir(results_dir) if '.' not in n), key=_natural_key)
    for sample_length in length_names:
        length_dir = os.path.join(results_dir, sample_length)
        length = int(sample_length.split('_')[1])
        sample_names = sorted((n for n in os.listdir(length_dir) if '.' not in n), key=_natural_key)
        for sample_id, sample_name in enumerate(sample_names):
            csv_path = os.path.join(length_dir, sample_name, 'self_consistency', 'sc_results.csv')
            if os.path.exists(csv_path):
                design_csv = pd.read_csv(csv_path, index_col=0)
                design_csv['length'] = length
                design_csv['sample_id'] = sample_id
                all_csvs.append(design_csv)
    if not all_csvs:
        raise ValueError(f'No sc_results.csv files found under {results_dir}')
    return pd.concat(all_csvs)


def sc_filter(raw_df, metric, threshold=None):
    """Keep the best self-consistency row per (length, sample_id) and flag designability.

    Args:
        raw_df: Frame with ``length``, ``sample_id`` and the chosen metric column.
        metric: ``'tm_score'`` (higher is better, designable if > threshold) or
            ``'rmsd'`` (lower is better, designable if < threshold).
        threshold: Designability cutoff. Defaults to 0.5 for ``tm_score`` and
            2.0 for ``rmsd``.

    Returns:
        One row per (length, sample_id), sorted by those keys, with the keys as
        the first columns and a boolean ``designable`` column. Each row is a
        complete original row (the best-scoring one for that group).

    Raises:
        ValueError: If ``metric`` is not recognised.
    """
    # Stable sort so ties keep their input order instead of an arbitrary one.
    if metric == 'tm_score':
        cutoff = 0.5 if threshold is None else threshold
        df = raw_df.sort_values('tm_score', ascending=False, kind='mergesort')
        designable = df['tm_score'] > cutoff
    elif metric == 'rmsd':
        cutoff = 2.0 if threshold is None else threshold
        df = raw_df.sort_values('rmsd', ascending=True, kind='mergesort')
        designable = df['rmsd'] < cutoff
    else:
        raise ValueError(f'Unknown metric {metric}')
    df = df.assign(designable=designable)

    keys = ['length', 'sample_id']
    # groupby().first() returns the first *non-null* value of each column independently,
    # which can mix fields from different rows; drop_duplicates keeps whole rows.
    best = (
        df.drop_duplicates(subset=keys, keep='first')
        .sort_values(keys)
        .reset_index(drop=True)
    )
    # Same column layout as a groupby(keys)...reset_index(): grouping keys first.
    best = best[keys + [c for c in best.columns if c not in keys]]

    percent_designable = best['designable'].mean()
    print(f'Percent designable: {percent_designable}')
    return best

In [3]:
# Earlier inference runs, kept for reference:
# results_dir = '/data/rsg/chemistry/jyim/projects/flow-matching/se3_diffusion/inference_outputs/scope_noise_scale_10_ts_100'
# results_dir = '/data/rsg/chemistry/jyim/projects/flow-matching/se3_diffusion/inference_outputs/scope_noise_scale_10'
# results_dir = '/data/rsg/chemistry/jyim/projects/flow-matching/se3_diffusion/inference_outputs/scope_noise_scale_10_ts_10'
# results_dir = '/data/rsg/chemistry/jyim/projects/flow-matching/se3_diffusion/inference_outputs/scope_noise_scale_10_ts_10_ode'
results_dir = '/data/rsg/chemistry/jyim/projects/flow-matching/se3_diffusion/inference_outputs/scope_ts_10_ode'
samples_df = read_samples(results_dir)

# read_samples assigns sample_id per sample *directory* inside each length directory,
# so this caps the number of sample directories per length, not rows within a sample.
# NOTE(review): the previous comment said "8 sequences per backbone" — confirm which limit is intended.
MAX_SAMPLE_ID = 8
samples_df = samples_df[samples_df.sample_id < MAX_SAMPLE_ID]

scrmsd_results = sc_filter(samples_df, 'rmsd')
sctm_results = sc_filter(samples_df, 'tm_score')

Reading samples from /data/rsg/chemistry/jyim/projects/flow-matching/se3_diffusion/inference_outputs/scope_ts_10_ode
Percent designable: 0.1709090909090909
Percent designable: 0.6327272727272727


# Recalculate pLDDT with ESMFold

In [None]:
# GPU index is machine-specific; edit FOLDING_DEVICE to match the available hardware.
FOLDING_DEVICE = 'cuda:4'
folding_model = esm.pretrained.esmfold_v1().eval().to(FOLDING_DEVICE)

In [15]:
# Explicit copy: a later cell adds a `plddt` column, and writing into a boolean-mask slice
# of sctm_results triggers pandas' chained-assignment (SettingWithCopy) warning.
designable_sctm_results = sctm_results[sctm_results.designable].copy()

In [20]:
# Re-fold each designable sequence with ESMFold and record its mean pLDDT.
all_sequences = designable_sctm_results.sequence.tolist()
all_plddt = []
with torch.no_grad():
    for seq in all_sequences:
        output = folding_model.infer(seq)
        # .item() assumes mean_plddt holds a single value for a single input sequence — TODO confirm.
        all_plddt.append(output['mean_plddt'].item())
# Plain float array: unlike torch.stack(...).squeeze(), this works for zero sequences
# and keeps shape (1,) for a single sequence.
all_plddt = np.asarray(all_plddt, dtype=float)
# assign() returns a new frame instead of writing into a possible slice of sctm_results.
designable_sctm_results = designable_sctm_results.assign(plddt=all_plddt)

'AAARERAERLARIRALWEEARAENPNATLREIGERAGISPETVSRGIREVEEEERRAAGI'

In [41]:
# pLDDT cutoff for a "confident" ESMFold prediction.
# Assumes mean_plddt is on a 0-100 scale — TODO confirm.
PLDDT_THRESHOLD = 70
confident_sctm_results = designable_sctm_results[designable_sctm_results.plddt > PLDDT_THRESHOLD]
# Fraction of designable samples that are also confident; NaN when nothing is designable
# (avoids ZeroDivisionError).
confident_designable = (
    len(confident_sctm_results) / len(designable_sctm_results)
    if len(designable_sctm_results) else float('nan')
)
confident_designable

In [50]:
import shutil

In [53]:
# For each confident row, drop the last three components of sample_path, point at the
# sample_1.pdb in the remaining directory, and prefix with the project root.
SE3_DIFFUSION_ROOT = '/data/rsg/chemistry/jyim/projects/flow-matching/se3_diffusion/'
confident_paths = [
    SE3_DIFFUSION_ROOT + '/'.join(sample_path.split('/')[:-3]) + '/sample_1.pdb'
    for sample_path in confident_sctm_results.sample_path
]

In [60]:
# Copy each confident structure into one flat directory and list the copies in pdbs.txt
# (presumably the input list for the maxcluster run whose report is read below).
write_dir = '/data/rsg/chemistry/jyim/projects/flow-matching/se3_diffusion/cluster_results/scope'
pdb_text_path = os.path.join(write_dir, 'pdbs.txt')
# open()/shutil.copy fail if the target directory does not exist yet.
os.makedirs(write_dir, exist_ok=True)
with open(pdb_text_path, 'w') as f:
    for i, path in enumerate(confident_paths):
        write_path = os.path.join(write_dir, f'sample_{i}.pdb')
        shutil.copy(path, write_path)
        f.write(write_path + '\n')

In [65]:
# NOTE(review): hard-coded counts; where 167 and 404 come from is not shown in this notebook.
# Replace them with the computed quantities so the ratio stays in sync with the data.
167 / 404

0.41336633663366334

In [61]:
# Load the raw maxcluster report as a list of lines (trailing newlines retained).
cluster_results_path = '/data/rsg/chemistry/jyim/projects/flow-matching/se3_diffusion/cluster_results/scope/maxcluster_results.txt'
with open(cluster_results_path) as report:
    lines = list(report)

In [None]:
# FIXME: unfinished cell — `if x` has no condition body or colon, so this cell raises SyntaxError.
# Intended to collect cluster-assignment lines from `lines` into `cluster_lines`; the filter
# condition is not specified here — complete it (or delete the cell) before Restart & Run All.
cluster_lines = []
for x in lines:
    if x

In [None]:
[x for x in lines if 'INFO  : Item     Cluster'