# Calculate amino acid conservation score.

In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from Bio import AlignIO
from Bio import SeqIO
from Bio.Align.AlignInfo import PSSM
from Bio.Align.AlignInfo import SummaryInfo

In [None]:
# Gene name -> protein accession. `gene_id` is the value later matched against
# the fold table's 'ID' column and used as the column name in the attention CSV.
target_dict = {
    'blaOXA-114s': 'U3N8W9',  # antibiotic inactivation ABW87257
    'rpoB': 'NP_273190.1',  # antibiotic target alteration AF:B4RQW2
    'macB': 'A0A011P660',  # antibiotic efflux
    'tetW': 'ABN80187',  # antibiotic target protection
}
gene_name = 'tetW'
gene_id = target_dict[gene_name]

# Create the per-gene working directory next to the notebook. The alignment is
# read from `gene_name + "/clustalw.fasta"`, so the directory must be created at
# that same relative location (not under a separate prefix). `exist_ok=True`
# makes re-running the cell harmless.
os.makedirs(gene_name, exist_ok=True)

In [None]:
# Load the multiple sequence alignment for the selected gene.
fasta_file = gene_name + "/clustalw.fasta"

align = AlignIO.read(fasta_file, "fasta")
summary_align = SummaryInfo(align)

# Aligned (gapped) row of the target gene: the first record whose id equals
# `gene_name`. Fail loudly when it is missing; otherwise `sequence` would be
# undefined or, in a re-run notebook, silently keep the value from a previous gene.
sequence = next((record.seq for record in align if record.id == gene_name), None)
if sequence is None:
    raise ValueError(
        f"No record with id {gene_name!r} found in alignment {fasta_file!r}"
    )

# Uniform 1/20 weight for every residue letter present in the target row
# (gap characters excluded).
freq = {aa: 1/20 for aa in set(sequence) if aa != '-'}
    

In [None]:
from scipy.stats import entropy

# Position-specific residue counts over the alignment, labelled along the
# target row; gap characters are not counted.
pssm = summary_align.pos_specific_score_matrix(axis_seq=sequence, chars_to_ignore=['-'])

# Rows = residue letters, columns = alignment positions (0-based).
# The frame is built straight from the PSSM rows, looking each count up by its
# residue key, so no consensus sequence has to be computed just to learn the
# number of columns and values cannot be misaligned with the row labels.
# NOTE(review): `dumb_consensus` appears to be deprecated in recent Biopython
# releases - confirm against the installed version.
residues = list(pssm[0].keys())
df_pssm = pd.DataFrame(
    {pos: [counts[res] for res in residues] for pos, counts in enumerate(pssm)},
    index=residues,
)

# Uniform background over the 20 amino acids. `entropy(pk, qk)` normalises both
# distributions and returns the relative entropy of pk with respect to qk, so the
# PSSM must contain exactly one row per background entry.
qk = [1/20] * 20
if len(residues) != len(qk):
    raise ValueError(
        f"PSSM has {len(residues)} residue rows ({''.join(map(str, residues))}); "
        f"expected {len(qk)} to match the uniform background"
    )

# Relative entropy (in bits) of every alignment column.
info = [entropy(df_pssm[col], qk=qk, base=2) for col in df_pssm.columns]

# Keep only the columns where the target sequence has a residue (not a gap),
# giving one score per residue of the target protein.
amino_index = [i for i, a in enumerate(sequence) if a != '-']
info_content = [info[a_i] for a_i in amino_index]

In [None]:
path = '/'

# Find the fold assigned to the selected protein (first row whose 'ID' matches).
fold_table = pd.read_csv('../Sample_data/fold_table.csv', index_col=0)
matching_folds = fold_table.loc[fold_table['ID'] == gene_id, 'fold']
fold = matching_folds.tolist()[0]

# Load that fold's attention table and take the column for this protein.
# Missing values are dropped, then the first and last entries are removed
# (presumably start/end special tokens - TODO confirm).
attention_table = pd.read_csv(f'attention/fold_{fold}_attention.csv', index_col=0)
attention = attention_table[gene_id].dropna().iloc[1:-1]

In [None]:
# One row per residue: attention value alongside its conservation score.
# The frame takes its index from the `attention` Series; `info_content` is a
# plain list and must therefore have the same length.
conserv_score = pd.DataFrame({'Attention': attention})
conserv_score['Conservation score'] = info_content

# 33rd and 66th percentiles of attention, used as the Low/Medium/High cut-offs.
attention_quantiles = conserv_score['Attention'].quantile(q=[0.33, 0.66])
threshold = attention_quantiles.tolist()
threshold

In [None]:
def annotation(c):
    """Classify an attention value against the module-level `threshold` pair.

    Returns 'Low' when `c` is at or below the first cut-off, 'Medium' when it
    lies between the first cut-off and (exclusive) the second, and 'High' when
    it is at or above the second. Values that satisfy none of the comparisons
    (e.g. NaN) yield None.
    """
    if c <= threshold[0]:
        return 'Low'
    if threshold[0] <= c < threshold[1]:
        return 'Medium'
    if c >= threshold[1]:
        return 'High'
    return None
# Label every residue Low/Medium/High according to its attention value.
conserv_score['Groups'] = [annotation(value) for value in conserv_score['Attention']]

In [None]:
# Display the distinct group labels that were actually assigned.
conserv_score.loc[:, 'Groups'].unique()

In [None]:
# Category order and colours; `categories` and `palette` are reused by the
# violin-plot cell further down.
categories = ['Low', 'Medium', 'High']
palette = sns.color_palette('Set2', n_colors=len(categories))

# Legend text and line style for the two percentile cut-offs, in threshold order.
line_label = ['33 percentile', '66 percentile']
line_style = ['dashed', 'dashdot']

plt.figure(figsize=(10, 7.5))

# Attention vs. conservation score, coloured by attention group.
# NOTE(review): no `hue_order` is passed, so which palette colour each group
# receives is decided by seaborn from the data - confirm this is intended.
sns.scatterplot(
    data=conserv_score,
    x='Attention',
    y='Conservation score',
    hue='Groups',
    palette=palette,
)

# Vertical red lines at the attention cut-offs separating the groups.
for idx, cutoff in enumerate(threshold):
    plt.axvline(
        x=cutoff,
        color='red',
        linestyle=line_style[idx],
        label=line_label[idx],
    )

plt.xlabel('Attention', fontsize=30)
plt.ylabel('Conservation score', fontsize=30)
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.legend(fontsize=20, markerscale=3)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# A few helper functions:
from statannotations.Annotator import Annotator
from statannotations.stats.utils import check_alpha


import numpy as np
from scipy.stats import mannwhitneyu

# Colours picked from `palette` for the categories listed in `new_order`.
# NOTE(review): `new_palette` is ordered for `new_order` (Low, High, Medium),
# but the plot below is drawn with order=categories (Low, Medium, High), so the
# Medium and High violins receive each other's palette entry. Confirm whether
# this is intentional (e.g. to match colours in the scatter plot) or an oversight.
new_order = ['Low', 'High', 'Medium']
new_palette = [palette[categories.index(name)] for name in new_order]

# Shared keyword arguments: the same settings go to both `sns.violinplot`
# and `Annotator`, so they are defined once.
plotting_parameters = dict(
    data=conserv_score,
    x='Groups',
    y='Conservation score',
    order=categories,
    palette=new_palette,
)

# Group pairs compared by the statistical test.
pairs = [
    ('Low', 'Medium'),
    ('Medium', 'High'),
    ('Low', 'High'),
]

# Global font size is set before the figure is created.
plt.rcParams["font.size"] = 20
fig, ax = plt.subplots(figsize=(10, 7.5))

with sns.plotting_context('notebook', font_scale=1.4):
    # Distribution of conservation scores per attention group.
    sns.violinplot(ax=ax, **plotting_parameters)

    # Pairwise Mann-Whitney tests with Bonferroni correction, drawn on the axes.
    annotator = Annotator(ax, pairs, **plotting_parameters)
    annotator.configure(test='Mann-Whitney', comparisons_correction="bonferroni")
    _, corrected_results = annotator.apply_and_annotate()

# The group names on the ticks are self-explanatory; drop the axis title.
ax.set_xlabel("")