In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_validate
from tqdm import tqdm
import matplotlib.pyplot as plt
import os

In [None]:
# Load the DBSCAN-scanned table. Later cells read its 'pairwised_sequence',
# 'aa_properties_dbscan', 'cluster_no' and 'major_species_in_cluster' columns.
df_result = pd.read_csv(filepath_or_buffer='DB_scanned.csv')

In [None]:
def sort_by_greek_alphabet(x):
    """Return the labels in ``x`` as a new list, ordered by a fixed label ranking.

    The ranking is the hard-coded list below. It is not strict Greek-alphabet
    order: 'N/A' sits between 'Lambda' and 'Omicron', and 'Zeta', 'Mu' and
    'other' come last. Labels with equal rank keep their input order.
    A label that is absent from the ranking raises ``ValueError``.
    """
    ranking = ['Alpha', 'Beta', 'Gamma', 'Delta', 'Epsilon', 'Eta', 'Iota', 'Kappa', 'Lambda', 'N/A', 'Omicron', 'Zeta', 'Mu', 'other']
    ordered = list(x)
    ordered.sort(key=lambda label: ranking.index(label))
    return ordered

def assign_amino_acid_parameter(df: pd.DataFrame) -> list:
    """Encode each aligned sequence as a flat numeric feature vector.

    Every residue of ``df['pairwised_sequence']`` is mapped to two relative
    properties: volume and hydrophilicity. For a sequence of length ``L`` the
    vector is ``[volume_1, ..., volume_L, hydrophilicity_1, ..., hydrophilicity_L]``,
    i.e. all volumes first, then all hydrophilicities.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have a ``'pairwised_sequence'`` column whose entries are strings of
        one-letter residue codes present in the table below (``'-'`` is a gap,
        encoded as ``[0, 0]``).

    Returns
    -------
    list of list
        One feature vector of length ``2 * len(sequence)`` per row, in row order.

    Raises
    ------
    KeyError
        If a sequence contains a character that is not in the property table.
    """
    # Amino acid property [amino acid volume, amino acid hydrophilicity]. Both of these are relative values.
    amino_acid_properties = {
        'A':[-2.90, -1.03], 'R':[2.41, 1.31], 'N':[-0.68, 0.79],
        'D':[-0.92, 1.23], 'C':[-1.89, 0.15], 'Q':[0.36, 1.09],
        'E':[0.16, 1.28], 'G':[-4.04, 0.01], 'H':[0.83, 1.15],
        'I':[0.51, -1.32], 'L':[0.52, -1.40], 'K':[0.92, 1.23],
        'M':[0.92, -1.42], 'F':[2.22, -1.47], 'P':[-1.25, -0.64],
        'S':[-2.36, 0.38], 'T':[-1.19, 0.28], 'W':[4.28, -0.18],
        'Y':[2.75, -0.18], 'V':[-0.65, -1.27], '-':[0, 0]
    }

    encoded = []
    # A single pass over the sequences, with one progress bar.
    for sequence in tqdm(df['pairwised_sequence'].tolist(), desc='encoding sequences'):
        properties = [amino_acid_properties[residue] for residue in sequence]
        volumes = [volume for volume, _ in properties]
        hydrophilicities = [hydrophilicity for _, hydrophilicity in properties]
        # Volumes for every position first, then hydrophilicities.
        encoded.append(volumes + hydrophilicities)
    return encoded

In [None]:
# Feature matrix with one row per sequence in df_result.
train_data = assign_amino_acid_parameter(df=df_result)

In [None]:
def cross_validate_random_forest(data_X, data_Y, cv = 5):
    """Fit a random forest on each fold of a shuffled, stratified ``cv``-fold split.

    Parameters
    ----------
    data_X : array-like
        Feature matrix, one row per sample.
    data_Y : array-like
        Class labels, one per row of ``data_X``.
    cv : int, default 5
        Number of stratified folds.

    Returns
    -------
    dict
        The ``sklearn.model_selection.cross_validate`` result. Because
        ``return_estimator=True``, the fitted forests are under ``'estimator'``.

    Side effect: calls ``np.set_printoptions(precision=5)``, which changes
    NumPy's print precision for the whole process.
    """
    np.set_printoptions(precision = 5)

    # Both the splitter and the forest use fixed seeds.
    splitter = StratifiedKFold(n_splits = cv, shuffle = True, random_state = 0)
    forest_settings = {
        'n_estimators': 100,
        'max_leaf_nodes': 16,
        'random_state': 42,
        # Weight classes inversely to their frequency in the training fold.
        'class_weight': 'balanced',
    }
    return cross_validate(
        RandomForestClassifier(**forest_settings),
        data_X,
        data_Y,
        cv = splitter,
        scoring = 'accuracy',
        return_estimator = True,
    )

def return_class_flag(labels, name):
    """One-vs-rest indicator: return 1 if ``labels == name`` is truthy, else 0.

    Despite its plural name, ``labels`` is a single label here; the caller
    invokes this once per element of a label sequence.
    """
    return 1 if labels == name else 0

def extract_important_amino_acid_by_random_forest(train_data, labels_list, species_list):
    """Per-fold random-forest feature importances for a one-vs-rest task.

    Parameters
    ----------
    train_data : array-like
        Feature matrix, passed unchanged to ``cross_validate_random_forest``.
    labels_list : iterable
        One label per row of ``train_data``.
    species_list : hashable
        Despite its name, a single target label. Rows whose label equals it
        become the positive class (1); all other rows become the negative class (0).

    Returns
    -------
    list
        The ``feature_importances_`` of each fold's fitted forest, in fold order.
    """
    positive_label = species_list
    binary_target = []
    for label in labels_list:
        binary_target.append(return_class_flag(label, positive_label))

    cv_result = cross_validate_random_forest(train_data, binary_target)
    return [forest.feature_importances_ for forest in cv_result['estimator']]

In [None]:
def output_important_features_box_plot(feature_importance_list, species_name, save_folder: str, top_n: int = 20):
    """Box-plot the ``top_n`` features with the highest mean importance across folds.

    Each element of ``feature_importance_list`` is one importance vector (one
    per CV fold). Feature columns are named ``<residue><position>_V`` for the
    first half of the vector and ``<residue><position>_H`` for the second half,
    using the hard-coded reference sequence below. This matches a volume-first,
    then hydrophilicity feature layout. The figure is saved to
    ``Results/<save_folder>/RF_feature_importance_<species_name>.png`` and shown.

    Parameters
    ----------
    feature_importance_list : sequence of array-like
        Per-fold importance vectors. Each must have length ``2 * len(original_seq)``.
    species_name : str
        Plot title. It is also used in the file name, with path separators
        replaced by ``'_'`` so that names such as 'N/A' do not create a
        non-existent sub-directory.
    save_folder : str
        Sub-folder of ``Results`` that receives the PNG. It is created if missing.
    top_n : int, default 20
        Number of features to plot. The columns are ordered by ascending mean
        importance, so the most important feature is on the right.

    Raises
    ------
    ValueError
        If ``feature_importance_list`` is empty, or if a vector's length does
        not match the number of feature columns.
    """
    if len(feature_importance_list) == 0:
        raise ValueError('feature_importance_list is empty; nothing to plot')

    out_dir = os.path.join('Results', f'{save_folder}')
    os.makedirs(out_dir, exist_ok = True)

    # Reference sequence whose residues and 1-based positions name the feature columns.
    # NOTE(review): presumably the reference spike sequence used for the alignment — confirm.
    original_seq = (
        'MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFSNVTWFHAIHVSGTNGTKRFDNPVLPFNDGVYFASTEKSNIIRGWIF'
        'GTTLDSKTQSLLIVNNATNVVIKVCEFQFCNDPFLGVYYHKNNKSWMESEFRVYSSANNCTFEYVSQPFLMDLEGKQGNFKNLREFVFKNIDGYFKIYSKHTPINL'
        'VRDLPQGFSALEPLVDLPIGINITRFQTLLALHRSYLTPGDSSSGWTAGAAAYYVGYLQPRTFLLKYNENGTITDAVDCALDPLSETKCTLKSFTVEKGIYQTSNF'
        'RVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYK'
        'LPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKS'
        'TNLVKNKCVNFNFNGLTGTGVLTESNKKFLPFQQFGRDIADTTDAVRDPQTLEILDITPCSFGGVSVITPGTNTSNQVAVLYQDVNCTEVPVAIHADQLTPTWRVY'
        'STGSNVFQTRAGCLIGAEHVNNSYECDIPIGAGICASYQTQTNSPRRARSVASQSIIAYTMSLGAENSVAYSNNSIAIPTNFTISVTTEILPVSMTKTSVDCTMYI'
        'CGDSTECSNLLLQYGSFCTQLNRALTGIAVEQDKNTQEVFAQVKQIYKTPPIKDFGGFNFSQILPDPSKPSKRSFIEDLLFNKVTLADAGFIKQYGDCLGDIAARD'
        'LICAQKFNGLTVLPPLLTDEMIAQYTSALLAGTITSGWTFGAGAALQIPFAMQMAYRFNGIGVTQNVLYENQKLIANQFNSAIGKIQDSLSSTASALGKLQDVVNQ'
        'NAQALNTLVKQLSSNFGAISSVLNDILSRLDKVEAEVQIDRLITGRLQSLQTYVTQQLIRAAEIRASANLAATKMSECVLGQSKRVDFCGKGYHLMSFPQSAPHGV'
        'VFLHVTYVPAQEKNFTTAPAICHDGKAHFPREGVFVSNGTHWFVTQRNFYEPQIITTDNTFVSGNCDVVIGIVNNTVYDPLQPELDSFKEELDKYFKNHTSPDVDL'
        'GDISGINASVVNIQKEIDRLNEVAKNLNESLIDLQELGKYEQYIKWPWYIWLGFIAGLIAIVMVTIMLCCMTSCCSCLKGCCSCGSCCKFDEDDSEPVLKGVKLHYT'
    )

    # All volume ('V') columns first, then all hydrophilicity ('H') columns.
    amino_acid_name_list = [
        f'{residue}{position}_{param}'
        for param in ['V', 'H']
        for position, residue in enumerate(original_seq, start = 1)
    ]

    # One row per CV fold. Every fold is plotted, so no fold count is hard-coded.
    importances = pd.DataFrame(feature_importance_list, columns = amino_acid_name_list)

    # Rank columns by mean importance, ascending, and keep the last top_n.
    ranked = importances.mean().sort_values().index
    top_features = list(ranked[max(len(ranked) - top_n, 0):])

    fig, ax = plt.subplots(figsize = (12, 8))
    importances.boxplot(column = top_features, ax = ax, rot = 90)
    ax.tick_params(axis = 'both', labelsize = 18)
    ax.set_ylabel('Importance', fontsize = 20, fontweight = 'bold')
    ax.set_title(species_name, fontsize = 20)
    ax.set_ylim(bottom = 0)

    # A path separator in the title (e.g. 'N/A') would otherwise be read as a directory.
    file_label = str(species_name).replace(os.sep, '_')
    if os.altsep:
        file_label = file_label.replace(os.altsep, '_')
    fig.savefig(os.path.join(out_dir, f'RF_feature_importance_{file_label}.png'), dpi = 300, bbox_inches = 'tight')
    plt.show()
    # Release the figure; the caller creates one figure per cluster in a loop.
    plt.close(fig)

In [None]:
# Fig. 6
# For each DBSCAN cluster label, train a one-vs-rest random forest and box-plot
# the residue features with the highest mean importance.
# NOTE(review): this cell is headed "Fig. 6" but its figures are saved under
# Results/Fig_4 — confirm which figure number is intended.

# Map each label to its plot title. If a label occurs on several rows, the
# last row's cluster_no / major species wins.
major_dict = {
    DB_label: f'{index}, The major species is {species_name}'
    for DB_label, index, species_name in zip(
        df_result['aa_properties_dbscan'],
        df_result['cluster_no'],
        df_result['major_species_in_cluster'],
    )
}

for DB_major_species, plot_title in major_dict.items():
    fold_importances = extract_important_amino_acid_by_random_forest(
        train_data, df_result['aa_properties_dbscan'], DB_major_species
    )
    output_important_features_box_plot(fold_importances, plot_title, 'Fig_4')