## AF3Plot
Plotting MSA, pLDDT, and PAE output from the new [AlphaFold3 server](https://alphafoldserver.com/)

### 1. Importing dependencies

In [None]:
import json
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

### 2. Sorting through output folder files
The folder should contain something like this:

| 📁 Output folder | | | |
|---|---|---|---|
| `fold_[...]_full_data_0.json` | `fold_[...]_full_data_1.json` | `fold_[...]_full_data_2.json` | `fold_[...]_full_data_3.json` |
| `fold_[...]_full_data_4.json` | `fold_[...]_job_request.json` | `fold_[...]_model_0.cif` | `fold_[...]_model_1.cif` |
| `fold_[...]_model_2.cif` | `fold_[...]_model_3.cif` | `fold_[...]_model_4.cif` | `fold_[...]_summary_confidences_0.json` |
| `fold_[...]_summary_confidences_1.json` | `fold_[...]_summary_confidences_2.json` | `fold_[...]_summary_confidences_3.json` | `fold_[...]_summary_confidences_4.json` |
| `terms_of_use.md` | | | |
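
If you are unsure of the exact job name, it can be recovered from the file names themselves. Below is a minimal sketch, assuming every data file follows the `fold_<job>_full_data_<k>.json` pattern shown in the table; `find_af3_outputs` and `FULL_DATA_RE` are hypothetical names, not part of any AlphaFold tool.

```python
import re
from pathlib import Path

# Hypothetical pattern: assumes files are named fold_<job>_full_data_<k>.json
FULL_DATA_RE = re.compile(r"^fold_(?P<job>.+)_full_data_(?P<model>\d+)\.json$")


def find_af3_outputs(folder):
    """Return (job_name, sorted model indices) inferred from the full_data files."""
    jobs = {}
    for path in Path(folder).iterdir():
        match = FULL_DATA_RE.match(path.name)
        if match:
            jobs.setdefault(match.group("job"), []).append(int(match.group("model")))
    if len(jobs) != 1:
        # Zero or several distinct job names: let the caller decide what to do
        raise ValueError(f"Expected exactly one job in {folder}, found {sorted(jobs)}")
    job, models = jobs.popitem()
    return job, sorted(models)
```

For example, `name, model_ids = find_af3_outputs(repo[0])` could stand in for typing the job name by hand.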

In [None]:
class ARG:
    def __init__(self, repo):
        self.input_dir = repo
        self.output_dir = repo
        self.name = name  # reads the global job name entered below, so run this whole cell first

repo = [input('Please paste the path to your output folder (e.g. /Users/yourname/Documents/fold_2024_05_22_protein1): ').strip()] # List of output folders to process
name = input('Please enter your AlphaFold3 job name (e.g. 2024-05-22_protein1): ').strip()
name = name.replace('-', '_')  # file names use underscores where the job name has hyphens
print(name, f'({repo})')

### 3. Loading JSON files into Jupyter

In [None]:
# Load the summary_confidences and full_data JSON files for models 0-4 with json.load
for r in repo:
    args = ARG(r)
    # Note: the lists are reset for every folder, so only the last folder in `repo` is kept
    summary_confidences = []
    full_data = []
    for i in range(5):
        with open(os.path.join(r, f"fold_{name}_summary_confidences_{i}.json")) as scFile:
            summary_confidences.append(json.load(scFile))
        with open(os.path.join(r, f"fold_{name}_full_data_{i}.json")) as fdFile:
            full_data.append(json.load(fdFile))
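
The loop above assumes exactly five models per job. Here is a sketch of a variant that loads however many `full_data` files are present, ordered numerically by model index. It assumes the naming shown in section 2, and `load_af3_models` is a hypothetical helper.

```python
import json
from pathlib import Path


def load_af3_models(folder, job_name):
    """Load every fold_<job>_{summary_confidences,full_data}_<k>.json pair, ordered by k.

    Assumes the file naming shown in section 2; the model count is taken from
    whichever full_data files exist instead of being fixed at five.
    """
    folder = Path(folder)
    # Sort numerically on the trailing index so model_10 would follow model_9
    full_paths = sorted(
        folder.glob(f"fold_{job_name}_full_data_*.json"),
        key=lambda p: int(p.stem.rsplit("_", 1)[1]),
    )
    summary_confidences, full_data = [], []
    for fd_path in full_paths:
        k = fd_path.stem.rsplit("_", 1)[1]
        sc_path = folder / f"fold_{job_name}_summary_confidences_{k}.json"
        full_data.append(json.loads(fd_path.read_text()))
        summary_confidences.append(json.loads(sc_path.read_text()))
    return summary_confidences, full_data
```

Usage: `summary_confidences, full_data = load_af3_models(repo[0], name)`.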

In [None]:
summary_confidences[0].keys()

In [None]:
full_data[0].keys()

### Structure of JSON files:

#### full_data

`dict_keys(['atom_chain_ids', 'atom_plddts', 'contact_probs', 'pae', 'token_chain_ids', 'token_res_ids'])`

From [AlphaFold Server FAQ page](https://alphafoldserver.com/welcome#how-do-i-interpret-all-the-outputs-in-the-downloaded-json-files):
- **atom_chain_ids**: A [num_atoms] array indicating the chain ids corresponding to each atom in the prediction.
- **atom_plddts**: A [num_atoms] array, element i indicates the predicted local distance difference test (pLDDT) for atom i in the prediction.
- **contact_probs**: A [num_tokens, num_tokens] array. Element (i, j) indicates the predicted probability that token i and token j are in contact (within 8 Å between the representative atoms of each token); see the paper for details.
- **pae**: A [num_tokens, num_tokens] array. Element (i, j) indicates the predicted error in the position of token j, when the prediction is aligned to the ground truth using the frame of token i.
- **token_chain_ids**: A [num_tokens] array indicating the chain ids corresponding to each token in the prediction.
- **token_res_ids**: A [num_tokens] array of residue (position) numbers for each token, numbered from 1 up to the total number of residues. _This key is not described on the FAQ page._
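
Because `atom_plddts` and `atom_chain_ids` are aligned per atom, per-chain confidence reduces to a group-by. A sketch, assuming a `full_data` dict with the keys listed above; the shape checks encode the array sizes given in the FAQ, and `per_chain_plddt` is a hypothetical name.

```python
import numpy as np
import pandas as pd


def per_chain_plddt(fd):
    """Atom count, mean and minimum pLDDT per chain for one full_data dict."""
    plddt = np.asarray(fd["atom_plddts"], dtype=float)
    chains = np.asarray(fd["atom_chain_ids"])
    pae = np.asarray(fd["pae"], dtype=float)
    n_tokens = len(fd["token_chain_ids"])
    # Sanity checks based on the shapes documented on the FAQ page
    if plddt.shape != chains.shape:
        raise ValueError("expected one chain id per atom")
    if pae.shape != (n_tokens, n_tokens):
        raise ValueError("expected PAE to be num_tokens x num_tokens")
    df = pd.DataFrame({"chain": chains, "plddt": plddt})
    return df.groupby("chain")["plddt"].agg(["count", "mean", "min"])
```

Usage: `per_chain_plddt(full_data[0])` returns one row per chain ID.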

#### summary_confidences

From [AlphaFold Server FAQ page](https://alphafoldserver.com/welcome#how-do-i-interpret-all-the-outputs-in-the-downloaded-json-files):
- **ptm**: A scalar in the range 0-1 indicating the predicted TM-score for the full structure.
- **iptm**: A scalar in the range 0-1 indicating predicted interface TM-score (confidence in the predicted interfaces) for all interfaces in the structure.
- **fraction_disordered**: A scalar in the range 0-1 that indicates what fraction of the predicted structure is disordered, as measured by accessible surface area, see our [paper](https://www.nature.com/articles/s41586-024-07487-w) for details.
- **has_clash**: A boolean indicating if the structure has a significant number of clashing atoms (more than 50% of a chain, or a chain with more than 100 clashing atoms).
- **ranking_score**: A scalar in the range [-100, 1.5] that can be used for ranking predictions. It incorporates ptm, iptm, fraction_disordered and has_clash into a single number with the following equation: 0.8 × ipTM + 0.2 × pTM + 0.5 × disorder − 100 × has_clash.
- **chain_pair_pae_min**: A [num_chains, num_chains] array. Element (i, j) of the array contains the lowest PAE value across rows restricted to chain i and columns restricted to chain j. This has been found to correlate with whether two chains interact or not, and in some cases can be used to distinguish binders from non-binders.
- **chain_pair_iptm**: A [num_chains, num_chains] array. Off-diagonal element (i, j) of the array contains the ipTM restricted to tokens from chains i and j. Diagonal element (i, i) contains the pTM restricted to chain i. Can be used for ranking a specific interface between two chains, when you know that they interact, e.g. for antibody-antigen interactions.
- **chain_ptm**: A [num_chains] array. Element i contains the pTM restricted to chain i. Can be used for ranking individual chains when the structure of that chain is most of interest, rather than the cross-chain interactions it is involved with.
- **chain_iptm**: A [num_chains] array that gives the average confidence (interface pTM) in the interface between each chain and all other chains. Can be used for ranking a specific chain, when you care about where the chain binds to the rest of the complex and you do not know which other chains you expect it to interact with. This is often the case with ligands.
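
Two of these quantities can be reproduced from the definitions above, which is a useful check that the files are being read correctly. This sketch recomputes `ranking_score` from the quoted equation and rebuilds `chain_pair_pae_min` from the `pae` and `token_chain_ids` arrays in `full_data`. Both function names are hypothetical, and ordering chains by first appearance in `token_chain_ids` is an assumption; if the server's internal computation differs in detail from the FAQ wording, the values may not match exactly.

```python
import numpy as np


def ranking_score(sc):
    """Recompute ranking_score with the FAQ equation quoted above."""
    return (0.8 * sc["iptm"] + 0.2 * sc["ptm"]
            + 0.5 * sc["fraction_disordered"] - 100 * float(sc["has_clash"]))


def chain_pair_pae_min(fd):
    """Lowest PAE over rows of chain i and columns of chain j, from a full_data dict.

    Chains are ordered by first appearance in token_chain_ids (an assumption).
    """
    pae = np.asarray(fd["pae"], dtype=float)
    token_chains = np.asarray(fd["token_chain_ids"])
    _, first = np.unique(token_chains, return_index=True)
    chains = token_chains[np.sort(first)]
    out = np.empty((len(chains), len(chains)))
    for a, ci in enumerate(chains):
        for b, cj in enumerate(chains):
            # Rows = aligned tokens of chain ci, columns = scored tokens of chain cj
            out[a, b] = pae[np.ix_(token_chains == ci, token_chains == cj)].min()
    return chains.tolist(), out
```

For instance, `ranking_score(summary_confidences[0])` can be compared against `summary_confidences[0]['ranking_score']`, and `chain_pair_pae_min(full_data[0])[1]` against `summary_confidences[0]['chain_pair_pae_min']`.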

In [None]:
num_models = len(full_data)
rows = []
for i, sc in enumerate(summary_confidences):
    # AlphaFold2-Multimer-style score, for comparison: 0.8*ipTM + 0.2*pTM
    # Weighted score using the ranking_score equation: 0.8*ipTM + 0.2*pTM + 0.5*fraction_disordered - 100*has_clash
    C_w = 0.8*sc['iptm'] + 0.2*sc['ptm'] + 0.5*sc['fraction_disordered'] - 100*sc['has_clash']
    rows.append({'Model': i, 'iptm': sc['iptm'], 'ptm': sc['ptm'], 'Weighted Confidence Score': C_w})
    print(f'Model {i} C_w = {C_w:.3f} (reported ranking_score = {sc["ranking_score"]})')
    if sc['has_clash']:
        print('Model', i, 'has a clash')
model_scores = pd.DataFrame(rows)
model_scores
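
The `Model` column above is the file index. Sorting by the server's own `ranking_score` makes it explicit which model index is the top-ranked one, rather than relying on file order. A sketch, assuming the summary_confidences keys listed earlier; `rank_models` is a hypothetical helper.

```python
import pandas as pd


def rank_models(summary_confidences):
    """Per-model scores sorted best-first by the reported ranking_score."""
    rows = [
        {
            "Model": i,  # index k of the summary_confidences_<k>.json file
            "ranking_score": sc["ranking_score"],
            "iptm": sc["iptm"],
            "ptm": sc["ptm"],
            "has_clash": bool(sc["has_clash"]),
        }
        for i, sc in enumerate(summary_confidences)
    ]
    return (pd.DataFrame(rows)
              .sort_values("ranking_score", ascending=False)
              .reset_index(drop=True))
```

Usage: `rank_models(summary_confidences)`; the first row is the best-scoring model.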
    

In [None]:
len(full_data[0]['contact_probs'])  # number of tokens (contact_probs is num_tokens x num_tokens)
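
Since `contact_probs` is a [num_tokens, num_tokens] matrix, it can be filtered for likely inter-chain contacts. A sketch, assuming chain and residue labels come from `token_chain_ids` and `token_res_ids`; the 0.5 cutoff and the `interchain_contacts` name are illustrative choices, not server defaults.

```python
import numpy as np


def interchain_contacts(fd, threshold=0.5):
    """Token pairs (i < j) on different chains with contact_probs >= threshold.

    Returns (chain_i, res_i, chain_j, res_j, probability) tuples.
    The default cutoff of 0.5 is an arbitrary illustrative choice.
    """
    probs = np.asarray(fd["contact_probs"], dtype=float)
    chains = np.asarray(fd["token_chain_ids"])
    res_ids = np.asarray(fd["token_res_ids"])
    # Upper triangle only (k=1), so each pair is reported once and the diagonal is skipped
    i, j = np.nonzero(np.triu(probs >= threshold, k=1))
    keep = chains[i] != chains[j]
    return [
        (str(chains[a]), int(res_ids[a]), str(chains[b]), int(res_ids[b]), float(probs[a, b]))
        for a, b in zip(i[keep], j[keep])
    ]
```

Usage: `interchain_contacts(full_data[0], threshold=0.8)` lists only high-probability contacts.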

### Plotting the _per atom_ pLDDT

In [None]:
plot_param = 'atom_plddts'

# Set font face before creating the figure (matplotlib falls back to a default font if Arial is not installed)
matplotlib.rcParams['font.sans-serif'] = "Arial"
matplotlib.rcParams['font.family'] = "sans-serif"
matplotlib.rcParams["mathtext.default"] = "regular"

# Plotting top-ranked prediction pLDDT scores
max_rank = input('How many of the top-ranked predictions would you like to plot? (leave blank for all) ')
# Never ask for more models than were loaded
NUM_COLORS = num_models if max_rank == '' else min(int(max_rank), num_models)
rank_range = np.arange(NUM_COLORS)
#cm = plt.get_cmap('gist_rainbow')(np.linspace(0, 1, NUM_COLORS)) #color option 1
#cm = sns.color_palette('husl', n_colors=NUM_COLORS) #color option 2
cm = plt.cm.jet(np.linspace(0, 1, NUM_COLORS)) #color option 3

fig, ax = plt.subplots(figsize=(10, 6))
for i in rank_range:
    ax.plot(full_data[i][plot_param], label='Rank '+str(i), color=cm[i], linewidth=.75)

ax.set_title(name, fontsize=14)
ax.set_ylabel('Predicted LDDT', fontsize=12)
ax.set_xlabel('Atoms', fontsize=12)
ax.legend(bbox_to_anchor=(1.01, 1.05))

#xrange = input('What range of atoms do you want to plot (e.g. "0-100"; leave blank for all): ')
#if xrange == '':
xmin, xmax = 0, len(full_data[0][plot_param])
#else:
    #xrng = xrange.split('-')
    #xmin, xmax = int(xrng[0]), int(xrng[1])

#vline = input('Plot a vertical divider? (y/n)')
#if vline == 'y':
    #ax.axvline(int(input('Which position? ')), color='gray', linestyle='--')

ax.set_xlim(xmin, xmax)
ax.set_ylim(0, 100)
# Atom indices are integers, so place 11 evenly spaced integer ticks
ax.set_xticks(np.linspace(xmin, xmax, 11).round().astype(int));
#plt.savefig(os.path.join(args.output_dir, f"{name}_plddt-replot.pdf"))   #<-- Remove the # on this line to automatically save the figure as .pdf
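
To summarize the per-atom curve in a few numbers, atoms can be binned into the pLDDT confidence bands that AlphaFold DB uses for coloring (very low below 50, low 50 to 70, confident 70 to 90, very high above 90). A sketch; which band an exact edge value belongs to is a choice made here, and `plddt_band_fractions` is a hypothetical name.

```python
import numpy as np

# Band edges follow the AlphaFold DB coloring scheme (50 / 70 / 90);
# with right=True an exact edge value (e.g. 90) falls in the lower band.
BANDS = ["very low (<=50)", "low (50-70]", "confident (70-90]", "very high (>90)"]


def plddt_band_fractions(atom_plddts):
    """Fraction of atoms in each pLDDT confidence band."""
    plddt = np.asarray(atom_plddts, dtype=float)
    idx = np.digitize(plddt, [50, 70, 90], right=True)  # band index 0..3
    counts = np.bincount(idx, minlength=len(BANDS))
    return dict(zip(BANDS, counts / plddt.size))
```

Usage: `plddt_band_fractions(full_data[0]['atom_plddts'])`.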

### Plotting the Predicted Aligned Error (PAE)

In [None]:
################# PLOT THE PREDICTED ALIGNED ERROR ################
plt.figure(figsize=(3 * num_models, 2), dpi=100)

for i in range(num_models):
    plt.subplot(1, num_models, i + 1)
    plt.title(f"Model {i}: " + r"$C_{w}$ = " + f"{model_scores.loc[i]['Weighted Confidence Score']:.3f}")
    # Optional: widen the color scale if any PAE value exceeds 30 Å
    #vmax = max(30, np.max(full_data[i]['pae']))
    plt.imshow(full_data[i]['pae'], cmap="bwr", vmin=0, vmax=30)
    plt.xlabel("Scored Residue", fontsize=10)
    plt.ylabel("Aligned Residue", fontsize=10)
    plt.tick_params(labelsize=8)
    cbar = plt.colorbar(orientation="horizontal", fraction=0.046, pad=.3)
    cbar.ax.set_xlabel('Expected Position Error / Å', fontsize=8)
    # Fix the tick positions first so the labels match the plotted values
    cbar.set_ticks([0, 10, 20, 30])
    cbar.ax.tick_params(labelsize=8)
#plt.tight_layout()
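
For multi-chain predictions it helps to mark where each chain starts on the PAE map. A sketch, assuming each chain occupies a contiguous block of tokens in `token_chain_ids`; `chain_boundaries` and `plot_pae_with_chains` are hypothetical helpers that reuse the same `bwr`, 0 to 30 Å color scale as above.

```python
import numpy as np
import matplotlib.pyplot as plt


def chain_boundaries(token_chain_ids):
    """Token indices where a new chain starts (index 0 excluded)."""
    ids = np.asarray(token_chain_ids)
    return (np.nonzero(ids[1:] != ids[:-1])[0] + 1).tolist()


def plot_pae_with_chains(fd, ax=None, vmax=30):
    """PAE heatmap with dashed lines at chain boundaries."""
    if ax is None:
        ax = plt.gca()
    im = ax.imshow(np.asarray(fd["pae"], dtype=float), cmap="bwr", vmin=0, vmax=vmax)
    for b in chain_boundaries(fd["token_chain_ids"]):
        # Pixel centers sit at integer coordinates, so b - 0.5 falls between tokens b-1 and b
        ax.axhline(b - 0.5, color="k", linestyle="--", linewidth=0.75)
        ax.axvline(b - 0.5, color="k", linestyle="--", linewidth=0.75)
    ax.set_xlabel("Scored Residue")
    ax.set_ylabel("Aligned Residue")
    return im
```

Usage: `plot_pae_with_chains(full_data[0]); plt.show()`.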