In [56]:
import pandas as pd
import json
import os
import pymol
from pymol import cmd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import scipy as sp

# from statannotations.Annotator import Annotator
# from alphafetcher import AlphaFetcher

In [57]:
def annotate(fn, fp, label, uniprot, pdb, reg1, reg2):
    """Colour one structure file in PyMOL and save it as a ``.pse`` session.

    Starts from an empty PyMOL session and loads ``fn``. It then selects
    ``reg1``/``reg2`` on the ``native`` and ``pred`` objects, aligns
    ``native_2`` onto ``pred_2``, and colours both objects. The result is
    written to ``{fp}/{uniprot}_{pdb}_{label}.pse``.

    Args:
        fn: Path of the structure file to load. NOTE(review): the selections
            assume loading this file yields objects named ``native`` and
            ``pred`` — confirm against how the files were written.
        fp: Output directory for the session file; created if missing.
        label: Prediction label, which picks the colour of ``pred``:
            ``'full_depth'`` -> ``forest``; any label containing ``'U'`` ->
            custom colour ``uc``; anything else -> custom colour ``lc``.
            Also used in the output filename.
        uniprot: UniProt accession, used only in the output filename.
        pdb: PDB identifier, used only in the output filename.
        reg1: PyMOL ``resi`` expression for region 1 (e.g. ``'414-551'``).
        reg2: PyMOL ``resi`` expression for region 2, used for the alignment.
    """
    cmd.delete('all')
    cmd.load(fn)

    for obj in ('native', 'pred'):
        cmd.select(f'{obj}_1', f'{obj} and resi {reg1}')
        cmd.select(f'{obj}_2', f'{obj} and resi {reg2}')

    # Superpose on region 2 only: native_2 is the mobile selection, pred_2 the target
    cmd.align('native_2', 'pred_2')

    cmd.color('grey40', 'native')
    cmd.color('grey90', 'native_1')

    # Colour of pred depends on what it was predicted with
    if label == 'full_depth':
        cmd.color('forest', 'pred')
    elif 'U' in label:
        cmd.set_color('uc', [0.3843, 0.6392, 0.1804])
        cmd.color('uc', 'pred')
    else:
        cmd.set_color('lc', [0.6902, 0.8627, 0.4902])
        cmd.color('lc', 'pred')

    cmd.color('cyan', 'pred_1')

    # Callers pass per-UniProt subdirectories that may not exist yet;
    # an empty fp means "current directory" and needs no creation.
    if fp:
        os.makedirs(fp, exist_ok=True)
    cmd.save(os.path.join(fp, f'{uniprot}_{pdb}_{label}.pse'))

def annotate_uc(fn, fp, label, uniprot, pdb, reg1, reg2):
    """Colour one structure file with a fixed scheme and save it as a ``.pse`` session.

    This is the same load/select/align sequence as ``annotate``, but ``pred``
    is always ``forest`` regardless of ``label``, and region 2 of ``native``
    is highlighted in red. The session is written to
    ``{fp}/{uniprot}_{pdb}_{label}.pse``.

    Args:
        fn: Path of the structure file to load. NOTE(review): the selections
            assume loading this file yields objects named ``native`` and
            ``pred`` — confirm.
        fp: Output directory for the session file; created if missing.
        label: Used only in the output filename.
        uniprot: UniProt accession, used only in the output filename.
        pdb: PDB identifier, used only in the output filename.
        reg1: PyMOL ``resi`` expression for region 1.
        reg2: PyMOL ``resi`` expression for region 2, used for the alignment.
    """
    cmd.delete('all')
    cmd.load(fn)

    for obj in ('native', 'pred'):
        cmd.select(f'{obj}_1', f'{obj} and resi {reg1}')
        cmd.select(f'{obj}_2', f'{obj} and resi {reg2}')

    # Superpose on region 2 only: native_2 is the mobile selection, pred_2 the target
    cmd.align('native_2', 'pred_2')

    cmd.color('grey40', 'native')
    cmd.color('red', 'native_2')
    cmd.color('grey90', 'native_1')

    cmd.color('forest', 'pred')
    cmd.color('cyan', 'pred_1')

    # An empty fp means "current directory" and needs no creation
    if fp:
        os.makedirs(fp, exist_ok=True)
    cmd.save(os.path.join(fp, f'{uniprot}_{pdb}_{label}.pse'))

def annotate_all(df, fp1, fp2, fp3):
    """Write a PyMOL session for every cluster row and one full-depth session per structure.

    For each row, the cluster model is annotated with ``annotate``. The
    full-depth model is named by (UniProt, PDB) only, so it is annotated the
    first time each (UniProt, PDB) pair is seen.

    Args:
        df: Frame with columns ``UniProt``, ``PDB``, ``Cluster``, ``region_1``
            and ``region_2``; regions are comma-separated residue ranges.
        fp1: Directory holding full-depth models named ``{pdb}_{uniprot}.pdb``.
        fp2: Directory holding per-UniProt subdirectories of cluster models
            named ``{pdb}_{uniprot}_{cluster}.pdb``.
        fp3: Output root; sessions are written to ``fp3/{uniprot}/``.
    """
    # (UniProt, PDB) pairs whose full-depth model has already been annotated
    done_full_depth = set()

    for _, row in df.iterrows():
        uniprot = row['UniProt']
        pdb = row['PDB']
        cluster = row['Cluster']
        # PyMOL joins residue ranges with '+'; the table stores them comma-separated
        reg1 = row['region_1'].replace(',', '+')
        reg2 = row['region_2'].replace(',', '+')
        out_dir = os.path.join(fp3, uniprot)

        cl_fn = os.path.join(fp2, uniprot, f'{pdb}_{uniprot}_{cluster}.pdb')
        annotate(cl_fn, out_dir, cluster, uniprot, pdb, reg1, reg2)

        # Full-depth model depends only on (uniprot, pdb): annotate it once per pair
        key = (uniprot, pdb)
        if key not in done_full_depth:
            done_full_depth.add(key)
            fd_fn = os.path.join(fp1, f'{pdb}_{uniprot}.pdb')
            annotate(fd_fn, out_dir, 'full_depth', uniprot, pdb, reg1, reg2)

In [62]:
# df1 = pd.read_csv('../two_state_best.csv')
# df = pd.read_csv('./project_pipeline/data/ai_pdb_cluster_compared.tsv', sep='\t').rename(columns={'uniprot': 'UniProt', 'pdb': 'PDB', 'cluster': 'Cluster'}) \
#                 [['UniProt', 'PDB', 'Cluster', 'region_1', 'region_2', 'complex_rmsd', '2_comp']]
# pae = pd.read_csv('./project_pipeline/data/ai_cluster_pae.tsv', sep='\t').rename(columns={'uniprot': 'UniProt', 'pdb': 'PDB', 'cluster': 'Cluster'})
# fp1 = './project_pipeline/data/output/complexes/'
# fp2 = './project_pipeline/data/output/cf_pdb_complexes/'
# fp3 = '../paper/'

In [63]:
# # Make list of clusters
# clusters = ['U10-005', 'U100-007', 'U10-008', '201', '065']
# pdbs = ['4coo', '4pcu']

# df = df[(df['Cluster'].isin(clusters)) & (df['PDB'].isin(pdbs))].reset_index(drop=True)

# # best_both = pd.concat([best_uc, best_lc])
# annotate_all(df, fp1, fp2, fp3)
# df

Unnamed: 0,UniProt,PDB,Cluster,region_1,region_2,complex_rmsd,2_comp
0,P35520,4coo,201,414-551,40-413,19.826,27.842
1,P35520,4coo,065,414-551,40-413,21.616,42.729
2,P35520,4coo,U10-008,414-551,40-413,19.712,28.141
3,P35520,4coo,U10-005,414-551,40-413,17.057,43.017
4,P35520,4coo,U100-007,414-551,40-413,19.705,42.371
5,P35520,4pcu,201,414-551,40-413,27.887,57.363
6,P35520,4pcu,065,414-551,40-413,14.529,13.847
7,P35520,4pcu,U10-008,414-551,40-413,24.002,57.283
8,P35520,4pcu,U10-005,414-551,40-413,12.48,14.721
9,P35520,4pcu,U100-007,414-551,40-413,6.095,13.052


In [109]:
# Read the four tab-separated inputs from ./project_pipeline/data (in order:
# ai_cluster_interface, alphafold_interface, ai_cluster_pae, disorder).
# Only the first table has its missing values zero-filled.
df1, df2, df3, df4 = (
    pd.read_csv(f'./project_pipeline/data/{stem}.tsv', sep='\t')
    for stem in ('ai_cluster_interface', 'alphafold_interface', 'ai_cluster_pae', 'disorder')
)
df1 = df1.fillna(0)
df1.head()

Unnamed: 0,uniprot,cluster,region_1,region_2,cf_filename,interacting_residue_pairs,interface_residues,number_interface_residues,region_1 search,region_2 search
0,P04637,004,364-393,102-292,P04637_004_unrelaxed_rank_001_alphafold2_multi...,0,0,0.0,"[364, 365, 366, 367, 368, 369, 370, 371, 372, ...","[102, 103, 104, 105, 106, 107, 108, 109, 110, ..."
1,P04637,000,364-393,102-292,P04637_000_unrelaxed_rank_001_alphafold2_multi...,0,0,0.0,"[364, 365, 366, 367, 368, 369, 370, 371, 372, ...","[102, 103, 104, 105, 106, 107, 108, 109, 110, ..."
2,P04637,015,364-393,102-292,P04637_015_unrelaxed_rank_001_alphafold2_multi...,0,0,0.0,"[364, 365, 366, 367, 368, 369, 370, 371, 372, ...","[102, 103, 104, 105, 106, 107, 108, 109, 110, ..."
3,P04637,009,364-393,102-292,P04637_009_unrelaxed_rank_001_alphafold2_multi...,0,0,0.0,"[364, 365, 366, 367, 368, 369, 370, 371, 372, ...","[102, 103, 104, 105, 106, 107, 108, 109, 110, ..."
4,P04637,U100-003,364-393,102-292,P04637_U100-003_unrelaxed_rank_001_alphafold2_...,0,0,0.0,"[364, 365, 366, 367, 368, 369, 370, 371, 372, ...","[102, 103, 104, 105, 106, 107, 108, 109, 110, ..."


In [110]:
# Determine the best cluster by PAE: keep the single row with the lowest
# mean_pae_1_2 for each UniProt. drop_duplicates keeps whole rows, whereas
# groupby().first() takes the first non-null value of every column
# independently and can stitch one row together from different clusters.
# The stable sort makes PAE ties deterministic. Dropping null UniProt rows and
# sorting by UniProt reproduce what groupby did.
best_pae = (
    df3.dropna(subset=['uniprot'])
    .sort_values('mean_pae_1_2', kind='mergesort')
    .drop_duplicates(subset='uniprot', keep='first')
    .sort_values('uniprot', kind='mergesort')
    .reset_index(drop=True)
)

# Get the interface metrics of each best-PAE cluster
best_pae_ints = pd.merge(best_pae, df1, on=['uniprot', 'cluster'], how='inner')

# Combine interfaces and PAEs for AF2
af2 = pd.merge(df2, df4, on='uniprot', how='inner')
af2 = af2[['uniprot', 'number_interface_residues', 'mean_pae_1_2']]

best_pae_ints = best_pae_ints[['uniprot', 'cluster', 'number_interface_residues', 'mean_pae_1_2']]

# af2 is the left frame, so overlapping columns get '_x' (AF2) and '_y' (best-PAE cluster).
# NOTE(review): fillna(0) also zero-fills missing PAE values, not only missing
# interface counts — confirm that a PAE of 0 is the intended placeholder.
ints_compared = pd.merge(af2, best_pae_ints, on=['uniprot'], how='inner').fillna(0)

In [116]:
# Differences are the _x columns minus the _y columns of ints_compared
# (AF2 minus best-PAE cluster, given the merge order that built ints_compared).
ints_compared['pae_change'] = ints_compared['mean_pae_1_2_x'] - ints_compared['mean_pae_1_2_y']
ints_compared['n_res_change'] = ints_compared['number_interface_residues_x'] - ints_compared['number_interface_residues_y']

ints_compared = ints_compared.astype({'pae_change': float, 'n_res_change': float})

# Vectorised quadrant labels (replaces a row-wise iterrows/.at loop, and always
# creates the column, even for an empty frame). A difference that is exactly
# zero or NaN matches no condition and falls through to 'Null'.
pae_lower = ints_compared['pae_change'] < 0
pae_higher = ints_compared['pae_change'] > 0
res_more = ints_compared['n_res_change'] > 0
res_fewer = ints_compared['n_res_change'] < 0

ints_compared['better_pae_closer_res'] = np.select(
    [pae_lower & res_more, pae_lower & res_fewer, pae_higher & res_more, pae_higher & res_fewer],
    ['Better & Closer', 'Better but Further', 'Worse and Closer', 'Worse and Further'],
    default='Null',
)



In [118]:
# How many proteins fall into each PAE / interface-size comparison category
ints_compared.better_pae_closer_res.value_counts()

Worse and Further     14
Better & Closer       11
Worse and Closer       5
Null                   4
Better but Further     4
Name: better_pae_closer_res, dtype: int64

In [119]:
# UniProt accessions of the 15-protein subset
fifteen_prots = [
    'P07038', 'Q8NQJ3', 'P60240', 'P28482', 'P62826',
    'P22681', 'P21333', 'P12931', 'Q9Y6K1', 'P26358',
    'P29350', 'P35520', 'P27577', 'O08967', 'P00579',
]

in_subset = ints_compared['uniprot'].isin(fifteen_prots)
fif_ints = ints_compared.loc[in_subset]

# Comparison categories restricted to the subset
fif_ints.better_pae_closer_res.value_counts()

Worse and Further     8
Better & Closer       3
Better but Further    2
Worse and Closer      1
Null                  1
Name: better_pae_closer_res, dtype: int64