# Imports

In [1]:
import numpy as np
import pandas as pd
from pathlib import Path
import yaml
from tqdm import tqdm
import avoidome.uniprot as uniprot
from importlib import reload
from avoidome.target import TargetStructureData
from asapdiscovery.data.openeye import load_openeye_cif, oechem
import plotly.express as px

# Load the data

In [5]:
# Project-relative locations; the three download/cache folders live under `data_dir`.
data_dir, fig_dir = Path('../data'), Path('../figures')
uniprot_dir, af_dir, schema_cache = (
    data_dir / sub for sub in ('uniprot_downloads', 'alphafold_downloads', 'schema_cache')
)

In [6]:
# Curated ADMET target list; entries are read below via their 'uniprot',
# 'grant_name' and 'admet_category' keys.
adme_names = yaml.safe_load((data_dir / 'admet_names_curated.yml').read_text())

In [7]:
# Cached TargetStructureData records. `Path.glob` yields files in filesystem order, which
# differs between machines, so sort them to make every downstream ordering reproducible.
tsd_paths = sorted(schema_cache.glob('*.yml'))

In [9]:
# Deserialize every cached TargetStructureData record (one YAML file per target).
tsds = [
    TargetStructureData.from_dict(yaml.safe_load(tsd_path.read_text()))
    for tsd_path in tqdm(tsd_paths)
]

100%|██████████| 57/57 [00:01<00:00, 35.54it/s]


# Calculate AF Confidence

In [10]:
def get_af_confidence(tsd):
    """Return the per-residue B-factors of the AlphaFold model for ``tsd``, in file order.

    Loads ``af_dir / f'{tsd.target_name}.cif'`` and reads ``GetBFactor()`` from each unique
    residue. The B-factor column is presumably the AlphaFold pLDDT here (it is plotted as
    "Confidence (pLDDT)" downstream) — TODO confirm.

    Parameters
    ----------
    tsd : TargetStructureData
        Target whose ``target_name`` names the AlphaFold CIF file.

    Returns
    -------
    numpy.ndarray
        One value per unique residue, ordered by the first atom of each residue in the file.
    """
    mol = load_openeye_cif(af_dir / f'{tsd.target_name}.cif')
    # dict.fromkeys deduplicates with the same hash/eq semantics as the previous `set`, but
    # keeps insertion (file) order. A set's iteration order is arbitrary, which would scramble
    # the sequence positions that are later plotted along the x axis.
    residues = dict.fromkeys(
        oechem.OEAtomGetResidue(atom) for atom in mol.GetAtoms() if oechem.OEHasResidue(atom)
    )
    return np.array([res.GetBFactor() for res in residues])

In [11]:
# target_name -> per-residue confidence array for every loaded target.
af_confidences = dict((tsd.target_name, get_af_confidence(tsd)) for tsd in tqdm(tsds))

100%|██████████| 57/57 [00:04<00:00, 13.29it/s]


In [12]:
# Number of residues per target (one entry per AlphaFold model).
lengths = np.array(list(map(len, af_confidences.values())))

In [13]:
np.max(lengths)  # longest model, in residues

2221

In [14]:
def convert_to_relative_v2(af_confidences, length=10):
    """Resample a per-residue sequence onto ``length`` evenly spaced fractional positions.

    Output sample ``i`` sits at fractional position ``i / length`` and takes the value of
    the residue covering that position, i.e. residue ``floor(i * n / length)`` with
    ``n = len(af_confidences)``. Every residue therefore covers an equal ``1 / n`` share of
    the output, so sequences of different lengths can share one heatmap x axis.

    The previous loop-based version mapped position ``i / length`` to residue
    ``ceil(i * n / length)``, clamped to the last residue. As a result the first residue
    only received the sample at exactly 0, everything else was shifted by one, and the
    last residue got twice its share. It also rescanned the residues for every output
    sample (O(length * n)).

    Parameters
    ----------
    af_confidences : sequence of numbers (e.g. numpy.ndarray)
        Per-residue values in sequence order.
    length : int, default 10
        Number of output samples; ``length <= 0`` yields an empty list.

    Returns
    -------
    list
        ``length`` values taken from ``af_confidences``.

    Raises
    ------
    ValueError
        If ``af_confidences`` is empty.
    """
    values = np.asarray(af_confidences)
    n = len(values)
    if n == 0:
        raise ValueError('af_confidences must contain at least one value')
    if length <= 0:
        return []
    # Integer arithmetic keeps residue boundaries exact (no float round-off).
    indices = (np.arange(length) * n) // length
    return list(values[indices])

# Make Dataframe

In [15]:
# UniProt accession -> 'grant_name' (used as the display 'Protein Name'); later entries win.
uniprot_to_grant_name = {entry['uniprot']: entry['grant_name'] for entry in adme_names}

In [16]:
# UniProt accession -> 'admet_category'; later entries win.
uniprot_to_category = {entry['uniprot']: entry['admet_category'] for entry in adme_names}

In [17]:
# One long-format frame per target: 1000 resampled confidence values along the sequence,
# with the target's identifiers and ADMET category repeated on every row. pd.concat keeps
# each frame's own 0..999 index (no ignore_index).
dfs = [
    pd.DataFrame({
        'Target': tsd.target_name,
        'AF Confidence': convert_to_relative_v2(af_confidences[tsd.target_name], 1000),
        'Uniprot': tsd.uniprot_id,
        'Protein Name': uniprot_to_grant_name.get(tsd.uniprot_id, tsd.uniprot_id),
        'Category': uniprot_to_category.get(tsd.uniprot_id, 'Unknown'),
    })
    for tsd in tsds
]
df = pd.concat(dfs)

# Build the AlphaFold confidence colorscale

## Make a discrete colorscale helper

In [18]:
def discrete_colorscale(bvals, colors):
    """Build a stepped (piecewise-constant) plotly colorscale.

    Adapted from <https://chart-studio.plotly.com/~empet/15229/heatmap-with-a-discrete-colorscale/#/>

    bvals  - boundary values; after sorting, colors[k] fills [bvals[k], bvals[k+1]]
    colors - rgb or hex color codes, exactly one fewer than bvals
    Returns a list of [position, color] pairs with positions normalized to [0, 1];
    each color appears twice (start and end of its interval) to create hard steps.
    Raises ValueError when len(bvals) != len(colors) + 1.
    """
    if len(bvals) != len(colors) + 1:
        raise ValueError('len(boundary values) should be equal to  len(colors)+1')
    ordered = sorted(bvals)
    low, high = ordered[0], ordered[-1]
    span = high - low
    positions = [(value - low) / span for value in ordered]

    steps = []
    for start, stop, color in zip(positions, positions[1:], colors):
        steps += [[start, color], [stop, color]]
    return steps

## Map the AlphaFold pLDDT bands to colors

In [19]:
# AlphaFold pLDDT bands and the colours used on the AlphaFold website.
bvals = [0, 50, 70, 90, 100]
colors = ['#f47c48', '#feda0c', '#66caf3', '#355daa']
colorscale = discrete_colorscale(bvals, colors)
bvals = np.array(bvals)
# Colorbar ticks sit at the midpoint of each band.
tickvals = [(lo + hi) / 2 for lo, hi in zip(bvals[:-1], bvals[1:])]
text_names = ['Very low', 'Low', 'High', 'Very high']
band_ranges = [f'<{bvals[1]}', *(f'{lo}-{hi}' for lo, hi in zip(bvals[1:-2], bvals[2:-1])), f'>{bvals[-2]}']
ticktext = [f"{text} ({value})" for text, value in zip(text_names, band_ranges)]

# Order the data for plotting

In [88]:
# Descending order so that, with plotly drawing the first y entry at the bottom, names read A->Z top-down.
df = df.sort_values(['Category', 'Protein Name'], ascending=False)

In [89]:
categories = pd.unique(df['Category'])  # in sorted-frame order of first appearance

In [90]:
protein_names = pd.unique(df['Protein Name'])  # heatmap row order

In [91]:
# One row of resampled confidence values per protein, in `protein_names` order.
heatmap = [df.loc[df['Protein Name'] == name, 'AF Confidence'].to_numpy() for name in protein_names]

## Plot the AlphaFold confidence heatmap

In [92]:
import plotly.graph_objects as go
import numpy as np

# Standalone heatmap: one row per protein (categorical y), one column per fractional
# sequence position (the 1000 resampled values per protein).
my_heatmap = go.Heatmap(
    z=heatmap,
    x=np.arange(0, 1, 1/1000),
    y=protein_names,
    colorscale=colorscale,
    ygap=5,
    colorbar=dict(
        title='Confidence (pLDDT)',
        ticktext=ticktext,
        tickvals=tickvals,
        len=0.25,
        lenmode='fraction',
        yanchor='top',
        y=1,
    ),
)
fig = go.Figure(data=[my_heatmap])
fig.update_layout(
    template="seaborn",
    title='AlphaFold Confidence',
    height=1600,
    width=1600,
    xaxis=dict(title='Fractional Sequence Position'),
)
fig.show()

In [93]:
for out_path in (fig_dir / 'af_confidence_heatmap.png', fig_dir / 'af_confidence_heatmap.svg'):
    fig.write_image(out_path)

# Put it all together

In [208]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

## Make resolution traces

In [209]:
def _structure_rows(tsd):
    """Return one row per experimental structure of ``tsd``.

    Scalar columns are broadcast to the length of the per-structure list columns, so a
    target with no experimental structures contributes no rows.
    """
    structures = tsd.experimental_structures
    uniprot_id = tsd.uniprot_id
    return pd.DataFrame({
        'Uniprot ID': uniprot_id,
        'Category': uniprot_to_category.get(uniprot_id, 'Unknown'),
        'Average Sequence Coverage': tsd.average_coverage,
        "Sequence Coverage": [s.sequence_coverage / tsd.sequence_length for s in structures],
        'Number of Experimental Structures': [tsd.n_experimental_structures for _ in structures],
        "Sequence Length": [tsd.sequence_length for _ in structures],
        "Resolution (A)": [s.resolution for s in structures],
        "Method": [s.method for s in structures],
        'Protein Name': uniprot_to_grant_name.get(uniprot_id, uniprot_id),
    })


resolution_df = pd.concat([_structure_rows(tsd) for tsd in tsds])


Mean of empty slice.


invalid value encountered in double_scalars



In [438]:
resolution_df['Protein Name'].nunique(dropna=False)  # proteins with >= 1 experimental structure

49

In [439]:
# Create the resolution-panel scatter traces.
# Every (category, method) subset is drawn twice with identical markers: once in the
# "ADMET Category (Color)" legend group and once in the "Method (Symbol)" group, so colour
# and marker symbol each get their own legend.
jitter_amount = 0.1
jitter_seed = 0  # fixed seed: the jittered figure is identical on every Restart & Run All
rng = np.random.default_rng(jitter_seed)
protein_name_mapping = {name: i for i, name in enumerate(protein_names)}
symbols = {'X-ray': 'diamond-tall', 'EM': 'circle', 'AlphaFold': 'circle-open'}
fallback_symbol = 'square'  # for any method without an entry in `symbols`
colors = {'metabolism': '#ffb418', 'distribution': '#11efb7', 'transporters': '#d07c09', 'toxicity': '#5e2bcb'}
# NOTE: rows whose Category is not a key of `colors` (e.g. 'Unknown') are not plotted.
category_list = sorted(colors)

resolution_traces = []
for category in category_list:
    for method in resolution_df['Method'].unique():
        df_subset = resolution_df[(resolution_df['Method'] == method) & (resolution_df['Category'] == category)]

        # Jitter y positions so several structures of the same protein do not overlap.
        jitter = rng.uniform(low=-jitter_amount, high=jitter_amount, size=len(df_subset))
        y_values = [protein_name_mapping[name] + offset for name, offset in zip(df_subset['Protein Name'], jitter)]

        marker_dict = dict(
            size=10,
            color=colors[category],
            symbol=symbols.get(method, fallback_symbol),
            line_width=2 if method == "EM" else 1,
            line_color='black',
        )

        # Copy listed in the category legend (one entry per category, on its X-ray trace).
        resolution_traces.append(go.Scatter(
            x=df_subset['Resolution (A)'],
            y=y_values,
            mode='markers',
            name=category,
            marker=marker_dict,
            showlegend=bool(method == 'X-ray'),
            opacity=1,
            legendgroup=1,
            legendgrouptitle=dict(text="ADMET Category (Color)"),
        ))

        # Copy listed in the method legend (one entry per method, on the first category's trace).
        resolution_traces.append(go.Scatter(
            x=df_subset['Resolution (A)'],
            y=y_values,
            mode='markers',
            name=method,
            marker=marker_dict,
            showlegend=bool(category == category_list[0]),
            opacity=1,
            legendgroup=2,
            legendgrouptitle=dict(text="Method (Symbol)"),
        ))

## Make Coverage Traces

### Make the colorscale for the "fraction good" columns

In [440]:
# Traffic-light bands (in percent) shared by the coverage and very-high-confidence columns.
bvals = [0, 80, 95, 100]
colors = ['#ed5d4c', '#eadc00', '#78c616']
colorscale2 = discrete_colorscale(bvals, colors)
bvals = np.array(bvals)
# Colorbar ticks sit at the midpoint of each band.
tickvals2 = [(lo + hi) / 2 for lo, hi in zip(bvals[:-1], bvals[1:])]
ticktext2 = [f'<{bvals[1]}', *(f'{lo}-{hi}' for lo, hi in zip(bvals[1:-2], bvals[2:-1])), f'>{bvals[-2]}']

In [441]:
# Per-protein text annotation showing the best (max) experimental sequence coverage, drawn
# in subplot column 2 (xref "x2"). Proteins without experimental structures show 0.00%.
coverage_traces = []
for protein in protein_names:
    df_subset = resolution_df[resolution_df['Protein Name'] == protein]
    coverage = df_subset['Sequence Coverage'].max() if len(df_subset) else 0
    # Colour = first colorscale2 stop that the value does not exceed. A value above every
    # stop (coverage > 1) falls back to the top colour. Previously it silently reused the
    # previous protein's colour, or raised NameError for the first protein.
    text_color = next(
        (color for stop, color in colorscale2 if coverage <= stop),
        colorscale2[-1][1],
    )
    coverage_traces.append(dict(
        y=protein_name_mapping[protein],
        x=0,
        text=f"{coverage:.2%}",
        font=dict(color=text_color, size=12),
        xref="x2",
        yref="y",
        align="center",
        showarrow=False,
        ax=0,
        ay=0,
    ))

## Make AF Confidence trace

In [442]:
# Per-protein text annotation showing the fraction of residues with pLDDT > 90 (the
# "Very high" AlphaFold band), drawn in subplot column 4 (xref "x4").
very_high_plddt = 90
confidence_traces = []
for protein in protein_names:
    # .iloc[0] is positional. `df` repeats index labels 0..999 for every target, so the old
    # label lookup `[0]` returned a Series when two targets shared a protein name and raised
    # KeyError when label 0 was absent.
    confidence_dict_name = df.loc[df['Protein Name'] == protein, 'Target'].iloc[0]
    plddt = af_confidences[confidence_dict_name]
    af_confidence = (plddt > very_high_plddt).sum() / len(plddt)
    # Colour = first colorscale2 stop that the value does not exceed, with the top colour
    # as a fallback instead of a stale value from the previous protein.
    text_color = next(
        (color for stop, color in colorscale2 if af_confidence <= stop),
        colorscale2[-1][1],
    )
    confidence_traces.append(dict(
        y=protein_name_mapping[protein],
        x=0,
        text=f"{af_confidence:.2%}",
        font=dict(color=text_color, size=12),
        xref="x4",
        yref="y",
        align="center",
        showarrow=False,
        ax=0,
        ay=0,
    ))

## Make heatmap

In [443]:
# Same pLDDT heatmap, but on numeric y positions (0..n-1, one per protein) so that it
# lines up with the jittered resolution scatter on the shared y axis.
my_heatmap = go.Heatmap(
    z=heatmap,
    x=np.arange(0, 1, 1/1000),
    y=list(protein_name_mapping.values()),
    colorscale=colorscale,
    ygap=5,
    colorbar=dict(
        title='AlphaFold <br>Confidence (pLDDT)',
        ticktext=ticktext,
        tickvals=tickvals,
        len=0.3,
        lenmode='fraction',
        yanchor='top',
        y=0.5,
    ),
)

## Make Plot!

In [444]:
# ANGSTROM SIGN (U+212B), for axis and title labels.
angstrom = chr(0x212B)


In [451]:
# Final 1x4 figure sharing one y axis (one row per protein):
# resolution scatter | max experimental coverage | per-residue pLDDT heatmap | % very-high pLDDT.
fig = make_subplots(
    rows=1,
    cols=4,
    shared_yaxes=True,
    column_widths=[0.4, 0.05, 0.5, 0.05],
    horizontal_spacing=0.01,
    subplot_titles=[
        "Structural Resolution (\u212B) <br> (<2\u212B Structures Highlighted)",
        "% Experimental Sequence <br> Coverage (Max) ",
        "AlphaFold Confidence (pLDDT) Per Residue",
        "% of Sequence with <br> Very High Confidence",
    ],
)

# Column 2: coverage annotations (each carries xref="x2").
for annotation in coverage_traces:
    fig.add_annotation(annotation)

# Column 1: shaded band over resolutions 0-2 \u212B, then the resolution scatter traces.
fig.add_trace(
    go.Scatter(
        x=[0, 0, 2, 2],
        y=[-0.5, len(protein_names), len(protein_names), -0.5],
        fill='toself',
        fillcolor='#66caf3',
        line_color='white',
        showlegend=False,
        opacity=0.5,
    ),
    row=1,
    col=1,
)
for scatter in resolution_traces:
    fig.add_trace(scatter, row=1, col=1)

# Column 3: the pLDDT heatmap.
fig.add_trace(my_heatmap, row=1, col=3)

# Column 4: very-high-confidence annotations (each carries xref="x4").
for annotation in confidence_traces:
    fig.add_annotation(annotation)

# Invisible one-point scatter whose only job is to draw a colorbar for colorscale2.
fig.add_trace(
    go.Scatter(
        x=[None],
        y=[None],
        mode='markers',
        showlegend=False,
        marker=dict(
            colorscale=colorscale2,
            cmin=0,
            cmax=100,
            colorbar=dict(
                title='High Quality <br>Sequence Coverage (%)',
                ticktext=ticktext2,
                tickvals=tickvals2,
                len=0.2,
                lenmode='fraction',
                yanchor='top',
                y=0.8,
            ),
        ),
    ),
    row=1,
    col=4,
)

fig.update_layout(
    template="simple_white",
    height=1600,
    width=1600,
    yaxis=dict(
        showticklabels=True,
        showline=True,
        range=(-0.5, len(protein_names)),
        showgrid=False,
        ticklen=0,
        tickvals=list(protein_name_mapping.values()),
        ticktext=protein_names,
    ),
    yaxis2=dict(showticklabels=False, showline=True, ticklen=0),
    yaxis3=dict(showticklabels=False, showline=True, showgrid=False, ticklen=0),
    yaxis4=dict(showticklabels=False, showline=False, showgrid=False, ticklen=0),
    # The 0-5 \u212B range cuts 1 xtal structure of the serotonin transporter off the graph.
    xaxis=dict(title='Resolution (\u212B)', showgrid=True, range=(0,5), dtick=0.5),
    xaxis2=dict(showticklabels=False, range=(-0.1, 0.1), showline=False, showgrid=False, ticklen=0),
    xaxis3=dict(title='Fractional Sequence Position', range=(0,1), showgrid=True, dtick=0.1),
    xaxis4=dict(showticklabels=False, range=(-0.1, 0.1), showline=False, showgrid=False, ticklen=0),
)
fig.show()

In [452]:
for out_path in (fig_dir / 'final_figure.png', fig_dir / 'final_figure.svg'):
    fig.write_image(out_path)