In [1]:
from utils import LOCODataLoader
import numpy as np

In [2]:
# Create the project data loader and have it read its files.
# NOTE(review): `load_files()` presumably populates the attributes read later
# in this notebook (e.g. `plate_id`, `feat_vecs`) — confirm in `utils`.
loader = LOCODataLoader()
loader.load_files()

In [3]:
# Show the distinct plate ids present in the loaded data (sorted, de-duplicated).
np.unique(loader.plate_id)

array([0., 1., 2., 3.])

In [16]:
class LOCOPlotter:
    """k-NN accuracy evaluation and plotting for held-out ("dropped") compounds.

    The instance keeps references to the arrays and lookup tables exposed by a
    loader object.  `knn_clf` classifies the samples of one held-out compound
    with a k-NN model fitted on the samples of all other compounds, and
    `plot_knn_acuracy` summarises the result per dose and compound group.

    NOTE(review): "LOCO" presumably stands for leave-one-compound-out — confirm.
    """

    # Dose levels (suffix of the class names `<compound>_<concentration>`).
    CONCENTRATIONS = ('0.125xIC50', '0.25xIC50', '0.5xIC50', '1xIC50')
    # Plate ids that are evaluated separately.
    PLATE_IDS = (0, 1, 2, 3)
    # Compounds that are never used as the held-out compound (and are therefore
    # absent from the accuracy dict and from the plot).
    EXCLUDED_COMPOUNDS = ('Colistin', 'PolymyxinB')
    # (legend label, compounds averaged together) in plotting order.
    MOA_GROUPS = (
        ('Cell wall (PBP 1)', ('Cefsulodin', 'PenicillinG', 'Sulbactam')),
        ('Cell wall (PBP 2)', ('Avibactam', 'Mecillinam', 'Meropenem', 'Relebactam', 'Clavulanate')),
        ('Cell wall (PBP 3)', ('Aztreonam', 'Ceftriaxone', 'Cefepime')),
        ('Gyrase', ('Ciprofloxacin', 'Levofloxacin', 'Norfloxacin')),
        ('Ribosome', ('Doxycycline', 'Kanamycin', 'Chloramphenicol', 'Clarithromycin')),
    )

    def __init__(
        self,
        loader
    ):
        """Store references to the loader's data.

        Args:
            loader: object providing the attributes `cmpd_names`, `compound_id`,
                `labels`, `plate_id`, `feat_vecs`, `label_to_name`,
                `moa_dict_w_dose`, `classes` and `moa_to_num`.
        """
        self.cmpd_names = loader.cmpd_names
        self.compound_id = loader.compound_id
        self.labels = loader.labels
        self.plate_id = loader.plate_id
        self.feat_vecs = loader.feat_vecs
        self.label_to_name = loader.label_to_name
        self.moa_dict_w_dose = loader.moa_dict_w_dose
        self.classes = loader.classes
        self.moa_to_num = loader.moa_to_num

    def index(
        self,
        input_maps,
        input_choices
    ):
        """Return a boolean mask selecting rows that satisfy every criterion.

        Args:
            input_maps: sequence of equally long array-likes, one per criterion.
            input_choices: sequence (same length as `input_maps`) of iterables
                with the accepted values for the corresponding map.

        Returns:
            Boolean array that is True where, for every map, the element equals
            at least one of that map's accepted values (OR within a map, AND
            across maps).
        """
        per_map_masks = []
        for maps, choices in zip(input_maps, input_choices):
            values = np.asarray(maps)  # converted once instead of once per choice
            per_map_masks.append(np.logical_or.reduce([values == c for c in choices]))

        return np.logical_and.reduce(per_map_masks)

    def knn_clf(
        self,
        n_neighbors=150,
        metric='cosine'
    ):
        """Score a k-NN classifier on each held-out compound.

        For every plate id, concentration and held-out compound, a
        distance-weighted k-NN classifier is fitted on the mean-centred feature
        vectors of all *other* compounds at that concentration and used to
        predict a label for each sample of the held-out compound.  The labels
        are the first entry of `moa_dict_w_dose[class_name]`.

        Args:
            n_neighbors: number of neighbours passed to `KNeighborsClassifier`.
            metric: distance metric passed to `KNeighborsClassifier`.

        Returns:
            dict mapping `(plate_id, concentration, compound_name)` to 1.0 when
            the most frequent predicted label equals the most frequent true
            label of the held-out samples, else 0.0.
        """
        from sklearn.neighbors import KNeighborsClassifier
        from sklearn.preprocessing import StandardScaler
        from collections import Counter

        # The last two entries of `cmpd_names` and the excluded compounds are
        # never held out.
        held_out = [c for c in self.cmpd_names[:-2] if c not in self.EXCLUDED_COMPOUNDS]
        selectors = [self.compound_id, self.labels, self.plate_id]

        acc_dict = dict()
        for p_id in self.PLATE_IDS:
            for conc in self.CONCENTRATIONS:
                for cmpd_name in held_out:
                    cmpd_id = self.cmpd_names.index(cmpd_name)
                    drop_class = self.classes.index(f'{cmpd_name}_{conc}')
                    train_classes = [
                        self.classes.index(f'{c}_{conc}')
                        for c in self.cmpd_names if c != cmpd_name
                    ]

                    # Both subsets are restricted to compound_id == cmpd_id.
                    # NOTE(review): presumably `compound_id` identifies which
                    # held-out-compound model produced a feature vector rather
                    # than the sample's own compound — confirm with the loader.
                    idx_drop = self.index(selectors, [[cmpd_id], [drop_class], [p_id]])
                    idx_train = self.index(selectors, [[cmpd_id], train_classes, [p_id]])

                    # Centre (no variance scaling) with statistics of the
                    # training subset only.
                    scaler = StandardScaler(with_std=False)
                    feat_vecs_train = scaler.fit_transform(self.feat_vecs[idx_train])
                    feat_vecs_drop = scaler.transform(self.feat_vecs[idx_drop])

                    moa_labels_drop = [
                        self.moa_dict_w_dose[self.label_to_name[l]][0]
                        for l in self.labels[idx_drop]
                    ]
                    moa_labels_train = [
                        self.moa_dict_w_dose[self.label_to_name[l]][0]
                        for l in self.labels[idx_train]
                    ]

                    # Honour the caller's settings (previously hard-coded).
                    clf = KNeighborsClassifier(
                        n_neighbors=n_neighbors, metric=metric, weights='distance'
                    ).fit(feat_vecs_train, moa_labels_train)
                    moa_labels_drop_hat = clf.predict(feat_vecs_drop)

                    # Majority vote over the held-out samples on both sides.
                    mc_labels = Counter(moa_labels_drop).most_common(1)[0][0]
                    mc_preds = Counter(moa_labels_drop_hat).most_common(1)[0][0]

                    # Same value as accuracy_score([mc_labels], [mc_preds]).
                    acc_dict[p_id, conc, cmpd_name] = float(mc_labels == mc_preds)
        return acc_dict

    def _accuracy_by_dose(
        self,
        acc_dict
    ):
        """Average per-compound accuracies into a (dose, group, plate) array.

        Args:
            acc_dict: mapping `(plate_id, concentration, compound_name)` to a
                scalar accuracy, as returned by `knn_clf`.

        Returns:
            Array of shape `(len(CONCENTRATIONS), len(MOA_GROUPS), len(PLATE_IDS))`
            holding the mean accuracy over the compounds of each group.
        """
        return np.array([
            [
                # Values are scalars, so they are averaged directly.
                [np.mean([acc_dict[p_id, conc, cmpd] for cmpd in cmpds]) for p_id in self.PLATE_IDS]
                for _, cmpds in self.MOA_GROUPS
            ]
            for conc in self.CONCENTRATIONS
        ])

    def plot_knn_acuracy(
        self
    ):
        """Plot mean ± std (across plates) of the k-NN accuracy per dose.

        Runs `knn_clf` with its default settings and draws one line per entry
        of `MOA_GROUPS` on a new matplotlib figure.  Returns None.
        """
        # Local import, like the sklearn imports in `knn_clf`, so the method
        # does not rely on `plt` having been imported by an earlier cell.
        import matplotlib.pyplot as plt

        acc_grid = self._accuracy_by_dose(self.knn_clf())
        x_pos = np.arange(len(self.CONCENTRATIONS))

        # This is strong out-of-training-data prediction.
        # A weaker form would be to re-train the last layer to do a linear
        # combination from feature vectors alone, using all but the dropped
        # compound for prediction.
        plt.figure(figsize=(7,5))
        for group_idx, (moa_label, _) in enumerate(self.MOA_GROUPS):
            mean_vals = acc_grid[:, group_idx, :].mean(axis=1)
            std_vals = acc_grid[:, group_idx, :].std(axis=1)
            plt.plot(mean_vals, 'o--', linewidth=2, label=moa_label)
            plt.fill_between(x_pos, mean_vals-std_vals, mean_vals+std_vals, alpha=0.1, edgecolor='grey')

        plt.xticks(x_pos, [c.replace('xIC50', '') for c in self.CONCENTRATIONS], fontsize=16)
        plt.yticks(fontsize=16)
        plt.xlabel('Antibiotic concentration (x IC50)', fontsize=16)
        plt.ylabel('LOCO accuracy', fontsize=16)
        plt.title('k-NN LOCO accuracy', fontsize=18)
        plt.legend(frameon=True, fontsize=12)
        # plt.savefig('kNN_150_loco_accuracy_by_dose.svg')

    # Correctly spelled alias; the original (misspelled) name is kept for callers.
    plot_knn_accuracy = plot_knn_acuracy

In [17]:
# Build the plotter; it keeps references to the loader's arrays and lookup tables.
plotter = LOCOPlotter(loader)

In [18]:
# Run the k-NN evaluation with its default arguments. The result is a dict keyed by
# (plate id, concentration, held-out compound) with a 0.0 / 1.0 score per entry.
plotter.knn_clf()

{(0, '0.125xIC50', 'Cefsulodin'): 0.0,
 (0, '0.125xIC50', 'PenicillinG'): 0.0,
 (0, '0.125xIC50', 'Avibactam'): 1.0,
 (0, '0.125xIC50', 'Mecillinam'): 1.0,
 (0, '0.125xIC50', 'Meropenem'): 1.0,
 (0, '0.125xIC50', 'Aztreonam'): 0.0,
 (0, '0.125xIC50', 'Ceftriaxone'): 0.0,
 (0, '0.125xIC50', 'Cefepime'): 0.0,
 (0, '0.125xIC50', 'Clavulanate'): 1.0,
 (0, '0.125xIC50', 'Relebactam'): 1.0,
 (0, '0.125xIC50', 'Sulbactam'): 0.0,
 (0, '0.125xIC50', 'Ciprofloxacin'): 0.0,
 (0, '0.125xIC50', 'Levofloxacin'): 0.0,
 (0, '0.125xIC50', 'Norfloxacin'): 0.0,
 (0, '0.125xIC50', 'Doxycycline'): 0.0,
 (0, '0.125xIC50', 'Kanamycin'): 0.0,
 (0, '0.125xIC50', 'Chloramphenicol'): 1.0,
 (0, '0.125xIC50', 'Clarithromycin'): 1.0,
 (0, '0.25xIC50', 'Cefsulodin'): 1.0,
 (0, '0.25xIC50', 'PenicillinG'): 0.0,
 (0, '0.25xIC50', 'Avibactam'): 1.0,
 (0, '0.25xIC50', 'Mecillinam'): 1.0,
 (0, '0.25xIC50', 'Meropenem'): 1.0,
 (0, '0.25xIC50', 'Aztreonam'): 0.0,
 (0, '0.25xIC50', 'Ceftriaxone'): 0.0,
 (0, '0.25xIC50', 'Ce