In [229]:
import glob
import json
import ast
import matplotlib
import pandas as pd
from Bio import SeqIO

In [271]:
def get_adaptation_results(virus, subtype, gene):
    """
    Read in the codon-specific rates of adaptation as inferred by bhatt_nextstrain-epitopes
    """
    if subtype:
        path_to_results = f'../3_3/{virus}_{subtype}_{gene}*.json'
    
    else:
        path_to_results = f'../3_3/{virus}_{gene}*.json'
    
    #keep track of rate of adaptation at each residue
    codons = []
    adaptation_rates = []
    
    for f in glob.glob(path_to_results):
        if subtype:
            nt_pos_start = ast.literal_eval(f.split('_')[4])[0]
        else:
            nt_pos_start = ast.literal_eval(f.split('_')[3])[0]
        residue = int(nt_pos_start/3)
        if nt_pos_start%3!=0:
            print(nt_pos_start)
            print('codon counting error')
        with open(f) as json_handle:
            json_dict = json.load(json_handle)
            adaptation_rate = float(json_dict['rate_of_adaptation'])*10**3
            #make negative adaptation rates be 0 for now
            if adaptation_rate < 0.0:
                adaptation_rate = 0.0
        codons.append(residue)
        adaptation_rates.append(adaptation_rate)

        
    return codons, adaptation_rates

In [265]:
def make_color_scale(virus, subtype, gene, determine_max='from_oc43'):
    """
    Make a heatmap color code for the rates of adaptation
    """
    
    codons, adaptation_rates = get_adaptation_results(virus, subtype, gene)
    if determine_max == 'from_oc43':
        #max rate seen in any gene so far was in OC43 spike. Use this as max so all viruses have same scale
        max_rate = 111.27944203064257
    elif determine_max == 'from_self':
        max_rate = max(adaptation_rates)
        
            
    #normalize all adaptation rates so they're between 0 and 1 for the color-coding
    norm = matplotlib.colors.Normalize(vmin=0.0, vmax=max_rate)
    normalized_rates = [norm(x) for x in adaptation_rates]
    
    #make cmap
    cmap = matplotlib.cm.get_cmap('Reds')
    rgba_colors = [cmap(x) for x in normalized_rates] 
    hex_colors = [matplotlib.colors.to_hex(x)[1:] for x in rgba_colors]
    
    codon_adaptation_rate_colors = zip(codons, adaptation_rates, hex_colors)
    
    #get heatmap scale bar info
    heatmap_numbers = list(range(0,round(max_rate),5))
    heatmap_numbers_normalized = [norm(x) for x in heatmap_numbers]
    heatmap_colors = [cmap(x) for x in heatmap_numbers_normalized] 
    heatmap_hex_colors = [f'0x{matplotlib.colors.to_hex(x)[1:]}' for x in heatmap_colors] 


    return codon_adaptation_rate_colors, heatmap_numbers, heatmap_hex_colors

In [243]:
def make_coordinates_consistent(virus, subtype, gene, codon_adaptation_rate_colors):
    #make coordinates consistent between 
    #will need to do this somewhat manually for each virus based on the pdb file 
    reference_file_paths =  {'h3n2':{None:{'ha':f'../../../seasonal-flu_60y/config/reference_h3n2_{gene}.gb'}}, 
                             'h1n1pdm':{None:{'ha':f'../../../seasonal-flu_60y/config/reference_h1n1pdm_{gene}.gb'}}, 
                             'oc43':{'a':{'spike':f'../../../seasonal-cov/oc43/separate_lineages/config/oc43_{gene}_reference.gb'}}, 
                             '229e':{None:{'spike':f'../../../seasonal-cov/229e/config/229e_{gene}_reference.gb'}}}
    
    reference_file = reference_file_paths[virus][subtype][gene]
    
    #some pdb files have subunits listed as separate chains, some have the listed as one
    chainA_subunit = {'h3n2':'ha1', 'h1n1pdm':'ha1','oc43':'spike', '229e':'spike'}
    chainB_subunit = {'h3n2':'ha2', 'h1n1pdm':'ha2'}
    
    for seq_record in SeqIO.parse(reference_file, "genbank"):
        for feature in seq_record.features:
            if feature.type == 'CDS':
                if 'gene' in feature.qualifiers.keys():
                    if feature.qualifiers['gene'][0].lower() == chainA_subunit[virus]:
                        #if sigpep comes before chain A, will need to renumber so chainA residues 
                        #start with codon 1, rather than 15 (or whatever it would be after sigpep)
                        startA = int(feature.location.start/3)
                        chainA_len = len(feature.location.extract(seq_record.seq).translate())
                        #one-based numbering
                        endA = int(startA+chainA_len-1)
                    if virus in chainB_subunit.keys():
                        if feature.qualifiers['gene'][0].lower() == chainB_subunit[virus]:
                            chainB_len = len(feature.location.extract(seq_record.seq).translate())
                            startB = int(endA+1)
                            if startB != feature.location.start/3:
                                print('check coordinates')
                            endB = int(startB+chainB_len-1)
    
    adjusted_codon_adaptation_rate_colors = []
    
    for x in codon_adaptation_rate_colors:
        codon = int(x[0])
        if codon in range(startA, endA+1):
            chain = 'A'
            #add 1 for 1-based instead of 0. Noticed all coordinates were off by one
            adjusted_codon = codon-startA+1
        elif codon in range(startB, endB+1):
            chain = 'B'
            adjusted_codon = codon-startB+1
        adjusted_codon_adaptation_rate_colors.append((chain, adjusted_codon, x[0], x[1], x[2]))
    return adjusted_codon_adaptation_rate_colors
        

In [225]:
def get_timespan(virus, subtype, gene):
    """
    Find the time span covered by analysis
    """
    if subtype:
        path_to_results = f'../3_3/{virus}_{subtype}_{gene}*.json'
    
    else:
        path_to_results = f'../3_3/{virus}_{gene}*.json'
        
    for f in glob.glob(path_to_results):
        with open(f) as json_handle:
            json_dict = json.load(json_handle)
            #time windows are 3 years, so start date is 1.5 years before first window_midpoint
            #and end date is 1.5 years after last 
            start_date = json_dict['window_midpoint'][0] -1.5
            end_date = json_dict['window_midpoint'][-1] +1.5
            break
    timespan = end_date - start_date
    
    return timespan

In [239]:
#based on the pdb files, chains that are identical to chains A or B, in the trimer
trimerization_chains = {'h3n2':{'A':['C', 'E'], 'B':['D', 'F']}, 
                        'h1n1pdm':{'A':['C', 'E'], 'B':['D', 'F']}, 
                        'oc43':{'A':['B','C']}, 
                        '229e':{'A':['B','C']}}

In [240]:
#chains to show surface on (other will be cartoon)
surface_chains = {'h3n2': 'A+B+C+D', 'h1n1pdm': 'A+B+C+D', 
                  'oc43': 'A+B', '229e': 'A+B'}

In [273]:
def write_pml_file(pml_filename, pdb_accession, virus, subtype, gene, determine_max, trimerize=True):
    """
    Write .pml file to color every residue in the given pdb structure according to the inferred rate of adaptation
    """
    
    codon_adaptation_rate_colors, heatmap_numbers, heatmap_hex_colors = make_color_scale(virus, subtype, gene, determine_max)
    
    
    adjusted_codon_adaptation_rate_colors = make_coordinates_consistent(virus, subtype, gene, codon_adaptation_rate_colors)
    
    #"hide sticks",
    fetch_pdb = f"fetch {pdb_accession}"
    text_lines = [fetch_pdb, "bg_color white", "color 0xD3D3D3", "show surface", "hide sticks",
                  "remove solvent", "set seq_view, 1",  
                  f"ramp_new rate,  {pdb_accession}, {heatmap_numbers},  color={heatmap_hex_colors}"]
    
    timespan = get_timespan(virus, subtype, gene)
    #write summary file of all adaptive residues
    summary_of_adaptive_residues = []
    
    for residue in adjusted_codon_adaptation_rate_colors:
        text_lines.append(f"select chain {residue[0]} and resi {residue[1]}")
        text_lines.append(f"color 0x{residue[4]}, sele")
        #also color the corresponding residues on other chains of the trimer
        if trimerize ==True:
            other_chains = trimerization_chains[virus][residue[0]]
            for oc in other_chains:
                text_lines.append(f"select chain {oc} and resi {residue[1]}")
                text_lines.append(f"color 0x{residue[4]}, sele")
                
        if residue[3] > 0.0:
            adaptive_subs_per_year = residue[3]
            adaptive_subs_at_residue = residue[3]*10**-3*timespan
            summary_of_adaptive_residues.append({'virus': virus, 'gene': residue[0], 
                                                 'codon': residue[1], 
                                                 'adaptive_subs_per_year': adaptive_subs_per_year, 
                                                 'adaptive_subs_at_residue':adaptive_subs_at_residue})
            
    df = pd.DataFrame(summary_of_adaptive_residues)
    df = df.sort_values(by=['gene', 'codon'])
    if subtype:
        df.to_csv(f'adaptive_codons/{virus}_{subtype}_{gene}_adaptive_codons.csv', index=False)
    else:
        df.to_csv(f'adaptive_codons/{virus}_{gene}_adaptive_codons.csv', index=False)
    
                
        
    with open(pml_filename, 'w') as f:
        for line in text_lines:
            f.write(line)
            f.write('\n')
        
    

In [274]:
write_pml_file('h3n2_adaptation_colormap_oc43max.pml', '4fnk', 'h3n2', None, 'ha', 'from_oc43')

In [275]:
write_pml_file('h3n2_adaptation_colormap_selfmax.pml', '4fnk', 'h3n2', None, 'ha', 'from_self')

In [276]:
write_pml_file('oc43A_adaptation_colormap_selfmax.pml', '6ohw', 'oc43', 'a', 'spike', 'from_self')

In [277]:
write_pml_file('229e_adaptation_colormap_selfmax.pml', '6u7h', '229e', None, 'spike', 'from_self')

In [278]:
write_pml_file('229e_adaptation_colormap_oc43max.pml', '6u7h', '229e', None, 'spike', 'from_oc43')

In [272]:
write_pml_file('h1n1_adaptation_colormap_oc43max.pml', '4m4y', 'h1n1pdm', None, 'ha', 'from_oc43')

1637
codon counting error
317
codon counting error
1364
codon counting error
311
codon counting error
200
codon counting error
164
codon counting error
1106
codon counting error
1691
codon counting error
419
codon counting error
533
codon counting error
716
codon counting error
404
codon counting error
1487
codon counting error
989
codon counting error
671
codon counting error
1373
codon counting error
1514
codon counting error
536
codon counting error
359
codon counting error
713
codon counting error
1046
codon counting error
392
codon counting error
1184
codon counting error
794
codon counting error
344
codon counting error
131
codon counting error
1667
codon counting error
1430
codon counting error
1097
codon counting error
1193
codon counting error
1565
codon counting error
629
codon counting error
1148
codon counting error
1454
codon counting error
1229
codon counting error
1376
codon counting error
299
codon counting error
1025
codon counting error
662
codon counting error
1178
c

In [None]:
run /Users/katekistler/nextstrain/adaptive-evolution/adaptive_loci_results/pymol_adaptation_on_structure/h3n2_adaptation_colormap_selfmax.pml

run /Users/katekistler/nextstrain/adaptive-evolution/adaptive_loci_results/pymol_adaptation_on_structure/229e_adaptation_colormap_selfmax.pml
