In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Resistance-mechanism classes, in the order used for the stacked histogram and heatmap rows.
# 'others' must stay last: later cells drop it with mechanism_list[:-1].
mechanism_list = [
    'antibiotic inactivation',
    'antibiotic target alteration',
    'antibiotic efflux',
    'antibiotic target replacement',
    'antibiotic target protection',
    'others',
]

# Fig.2

In [None]:
# One frame per evaluation metric; each frame gets one column per method, taken from that
# method's LHD_all_<method>.csv (rows aligned on the CSV's index column).
result_acc, result_pre, result_rec, result_f1 = (pd.DataFrame() for _ in range(4))
metric_frames = {
    'Accuracy': result_acc,
    'Precision': result_pre,
    'Recall': result_rec,
    'F1 Score': result_f1,
}

for method in ['Proposed', 'LM-ARG', 'BLAST']:
    result_method = pd.read_csv('Prediction results/Raw_data/LHD_all_' + method + '.csv', index_col=0)
    for metric_name, metric_frame in metric_frames.items():
        metric_frame[method] = result_method[metric_name]
# `method` and `result_method` remain bound to the last method ('BLAST') after the loop;
# later cells read both names.

In [None]:
# Attach the shared 'threshold' and 'fold' key columns to every metric frame.
# NOTE(review): these come from `result_method`, i.e. the last CSV loaded in the previous
# cell; this assumes every method's CSV maps the same index labels to the same
# threshold/fold values — confirm.
for metric_frame in (result_acc, result_pre, result_rec, result_f1):
    for key_column in ('threshold', 'fold'):
        metric_frame[key_column] = result_method[key_column]

In [None]:
def _average_by_threshold(metric_frame):
    """Return the fold-averaged rows of ``metric_frame`` indexed by threshold.

    Keeps only rows whose ``fold`` is ``'Average'`` and the three method columns,
    in plotting order.
    """
    average_rows = metric_frame[metric_frame['fold'] == 'Average']
    return average_rows.set_index('threshold')[['Proposed', 'LM-ARG', 'BLAST']]


result_for_heatmap_acc = _average_by_threshold(result_acc)
result_for_heatmap_pre = _average_by_threshold(result_pre)
result_for_heatmap_rec = _average_by_threshold(result_rec)
result_for_heatmap_f1 = _average_by_threshold(result_f1)

In [None]:
# Metric shown in each panel, left to right.
category_order = ['Accuracy', 'Precision', 'Recall', 'F1 Score']
# The y axis is shared, so only the leftmost panel carries a y label.
ylabel_heatmap = ['threshold', '', '', '']

fig_heat, ax_heat = plt.subplots(1, len(category_order), figsize=(18, 10), sharey=True)
plt.rcParams["font.size"] = 15
cmap = plt.get_cmap("Set2")  # not used in this cell; the Fig.S1 cell reads `cmap`
# NOTE(review): sns.set() may reset rcParams such as font.size set just above — confirm ordering.
sns.set(style="whitegrid")

# Draw every heatmap first, then label the axes (same call order as before).
heatmap_frames = [result_for_heatmap_acc, result_for_heatmap_pre,
                  result_for_heatmap_rec, result_for_heatmap_f1]
last_panel = len(heatmap_frames) - 1
for panel, frame in enumerate(heatmap_frames):
    show_cbar = panel == last_panel  # a single colour bar, on the rightmost panel
    sns.heatmap(frame, vmax=1, vmin=0.70, annot=True, fmt='.3f', ax=ax_heat[panel],
                annot_kws={"size": 20}, cbar=show_cbar,
                cbar_kws={"location": "right"} if show_cbar else None, cmap='pink_r')

for i, indicate in enumerate(category_order):
    panel_ax = ax_heat[i]
    panel_ax.set_xlabel(indicate, fontsize=20)
    panel_ax.set_ylabel(ylabel_heatmap[i], fontsize=20)
    panel_ax.tick_params(axis='x', labelrotation=45, labelsize=20)
    panel_ax.tick_params(axis='y', labelsize=20)

fig_heat.tight_layout()

# Fig.S1

In [None]:
fold_table = pd.read_csv('../Sample_data/fold_table.csv',index_col=0)
# Sequences per mechanism as a one-column frame named 'Resistance mechanism'.
# Naming the Series explicitly works on every pandas version: pandas >= 2.0 names the
# value_counts() result 'count', so renaming a 'mechanism' column would silently do nothing.
mechanism_count = (
    fold_table['mechanism']
    .value_counts()
    .rename('Resistance mechanism')
    .to_frame()
)

hmdargdb = pd.read_csv('../Sample_data/input.csv',index_col=0) ## Change the path to your data.
# One list of sequence lengths per mechanism, in mechanism_list order (input of the stacked
# histogram). The boolean mask must select rows of `hmdargdb`; indexing a bare list with
# 'Length' raises TypeError.
length_list = [
    hmdargdb.loc[hmdargdb['mechanism'] == m, 'Length'].tolist()
    for m in mechanism_list
]

In [None]:
# Fig.S1: (top) number of sequences per resistance mechanism,
#         (bottom) stacked histogram of sequence lengths per mechanism.
fig, ax = plt.subplots(2, 1, figsize=(15, 15))

# x / y passed by keyword: seaborn >= 0.12 no longer accepts them positionally.
sns.barplot(x=mechanism_count['Resistance mechanism'], y=mechanism_count.index,
            palette='Set2', ax=ax[0])
# Write each count at the end of its bar (bar i sits at y position i).
for i, count in enumerate(mechanism_count['Resistance mechanism']):
    ax[0].text(count, i, count, fontsize=20)

# Build the palette here rather than relying on `cmap` left over from the Fig.2 cell.
set2_cmap = plt.get_cmap("Set2")
colors = [set2_cmap(i) for i in range(len(set(mechanism_list)))]
ax[1].hist(length_list, histtype='barstacked', label=mechanism_list, color=colors)
ax[1].legend(title="Resistance mechanism")
ax[1].set_xlabel('Sequence length')
ax[1].set_ylabel('# of sequences')
plt.show()

# Fig.S2-S4

In [None]:
# Parameters for the Fig.S2-S4 cells below.
category = 'mechanism'  # 'mechanism' or 'threshold': which per-category result table to plot
dataset = 'hmdargdb'    # 'LHD_0.4' or 'hmdargdb'

In [None]:
fold_table = pd.read_csv('../Sample_data/fold_table.csv', index_col=0)

# Number of sequences per mechanism (a Series indexed by mechanism name).
mechanism_count = fold_table['mechanism'].value_counts()

c_list = [0.4, 0.6, 0.7, 0.8, 0.9]

# Heatmap row labels "<mechanism>\n(<count>)" for every mechanism except the last ('others').
mechanism_count_dict = {
    m: m + '\n(' + str(mechanism_count.loc[m]) + ')'
    for m in mechanism_list[:-1]
}

In [None]:
def make_dataset_for_heatmap(method, dataset, category):
    """Load one method's per-``category`` results and return the fold-averaged table.

    Reads ``Prediction results/<category>/<dataset>_<method>.csv``, keeps the rows whose
    ``fold`` is ``'Average'``, drops the first data column and indexes by the
    ``category`` column.

    When ``category == 'mechanism'``, the last averaged row is also dropped. The rows are
    then reordered to ``mechanism_list`` without its final entry, and relabelled through
    ``mechanism_count_dict`` (mechanism name plus sequence count). Both names are
    module-level globals defined in earlier cells.
    """
    source_csv = 'Prediction results/' + category + '/' + dataset + '_' + method + '.csv'
    per_category = pd.read_csv(source_csv, index_col=0)
    averaged = per_category[per_category['fold'] == 'Average']

    if category != 'mechanism':
        return averaged.iloc[:, 1:].set_index(category)

    by_mechanism = averaged.iloc[:-1, 1:].set_index(category)
    ordered = by_mechanism.reindex(index=mechanism_list).iloc[:-1]
    return ordered.rename(index=mechanism_count_dict)

In [None]:
# Fig.S2-S4: per-method bar + strip plots of the four metrics (fig / axis) and per-category
# heatmaps of the fold-averaged metrics (fig_heat / ax_heat).

# NOTE(review): this file used to be selected through the leftover loop variable `method`
# from the Fig.2 cell, whose last value is 'BLAST'. It is pinned here so the cell no longer
# depends on that hidden state — confirm which raw result file is intended.
raw_result_method = 'BLAST'
result = pd.read_csv('Prediction results/Raw_data/' + dataset + '_' + raw_result_method + '.csv',
                     index_col=0).reset_index(drop=True)
df = pd.DataFrame()  # mean value of each metric, one column per method

category_order = ['Accuracy', 'Precision', 'Recall', 'F1 Score']

if dataset == 'hmdargdb':
    method_list = ['Proposed', 'LM-ARG', 'HMD-ARG', 'BLAST', 'CARD-RGI']
else:
    method_list = ['Proposed', 'LM-ARG', 'BLAST']

# Only the leftmost heatmap gets a y label. The list is sized from method_list so that
# every layout (e.g. 5 methods with category == 'threshold') has one entry per panel.
first_ylabel = 'Resistance mechanism(# of sequence)' if category == 'mechanism' else category
ylabel_heatmap = [first_ylabel] + [''] * (len(method_list) - 1)

fig, axis = plt.subplots(1, len(method_list), figsize=(15, 10), sharey=True)
fig_heat, ax_heat = plt.subplots(1, len(method_list), figsize=(18, 10), sharey=True)
plt.rcParams["font.size"] = 15
cmap = plt.get_cmap("Set2")
sns.set(style="whitegrid")

for i, method_name in enumerate(method_list):
    method_rows = result[result['method'] == method_name]

    # Individual values as black dots over a bar showing the per-metric mean.
    sns.stripplot(x="metrics", y="Value", data=method_rows, jitter=True, alpha=0.7,
                  ax=axis[i], order=category_order, size=10, color='black')
    category_means = method_rows.groupby('metrics')[['Value']].mean().reindex(category_order)
    # legend=False suppresses the pandas legend directly (replaces `axis[i].legend('')`).
    category_means.plot(kind='bar', alpha=1, ax=axis[i], color=cmap(i), legend=False)
    df[method_name] = category_means['Value']

    axis[i].set_xlabel(method_name, fontsize=20)
    axis[i].set_ylabel('', fontsize=20)
    axis[i].set_ylim(0.6, 1)
    axis[i].tick_params(axis='x', labelrotation=45, labelsize=20)
    axis[i].tick_params(axis='y', labelsize=20)

    sns.heatmap(make_dataset_for_heatmap(method_name, dataset, category), vmax=1, vmin=0.70,
                annot=True, fmt='.3f', ax=ax_heat[i], annot_kws={"size": 13},
                cbar_kws={"location": "top"}, cmap='pink_r')
    ax_heat[i].set_xlabel(method_name, fontsize=20)
    ax_heat[i].set_ylabel(ylabel_heatmap[i], fontsize=20)
    ax_heat[i].tick_params(axis='x', labelrotation=45, labelsize=20)
    ax_heat[i].tick_params(axis='y', labelsize=13)

fig.tight_layout()
fig_heat.tight_layout()
plt.show()
df  # last expression, so the per-method mean table is displayed as the cell output
