In [1]:
from sklearn.model_selection import KFold, cross_val_predict
from itertools import combinations
import pandas as pd
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.metrics import r2_score, mean_absolute_error
from sklearn.preprocessing import OneHotEncoder
import time
import pickle
import warnings
warnings.filterwarnings('ignore')



In [2]:
# Load the precomputed feature sets and the cleaned activity table.
# NOTE(review): pickle.load can execute arbitrary code; only load these files
# from a trusted source (presumably they are produced upstream in this repo).
FINGERPRINTS_PATH = '../Features/Fingerprints/fingerprints_df_grouped.pickle'
CLEAN_DATA_PATH = '../Data/processed/clean_df_grouped.pkl'

with open(FINGERPRINTS_PATH, 'rb') as fingerprints_file:
    fingerprints = pickle.load(fingerprints_file)

with open(CLEAN_DATA_PATH, 'rb') as clean_data_file:
    df = pickle.load(clean_data_file)

# Relabel rows 0..n-1; later cells select fingerprint / one-hot rows by these labels.
df = df.reset_index(drop=True)

In [3]:
# List the names of the feature sets stored in the fingerprint pickle.
fingerprints.keys()

dict_keys(['map4', 'erg', 'atompair-count', 'ecfp', 'layered', 'topological', 'rdkit', 'binary profile of physicochemical property', 'one-hot-encoded-sequence', 'peptide_descriptors', 'ankh_base_embedding', 'mean_embeddings_large', 'maccs', 'avalon', 'fcfp', 'atompair', 'pattern', 'secfp', 'estate', 'avalon-count', 'rdkit-count', 'ecfp-count', 'fcfp-count', 'topological-count'])

In [4]:
# Feature sets that feed the ensemble below. The order matters only for the
# order of the training log lines, since predictions are averaged.
selected_fingerprint_names = ['erg', 'atompair-count', 'ankh_base_embedding']
fingerprints_filtered = {name: fingerprints[name] for name in selected_fingerprint_names}



In [13]:
variants = ['GpTx-1', 'Protoxin II', 'JzTx-V', 'Huwentoxin-IV']

# Per-variant metrics: {variant: {training_subset_tuple: score}}.
variant_r2_scores = {variant: {} for variant in variants}
variant_mean_absolute_errors = {variant: {} for variant in variants}

# Pooled metrics over all rows of a training subset: {training_subset_tuple: score}.
subsets_r2_scores = {}
subsets_mean_absolute_errors = {}

# Subsets already processed; the training loop skips anything recorded here.
trained_subsets = set()

In [18]:
# One-hot encode the assay and region annotations once for the whole table.
# fit_transform yields one row per row of `df`, in the same order.
categorical_columns = ['Assay', 'REGION_NOTES']
enc = OneHotEncoder()
encoded_sparse = enc.fit_transform(df[categorical_columns])
one_hot_features = encoded_sparse.toarray()

def train_and_predict_model(X, y, model_type, cv=None, n_jobs=-1):
    """Return out-of-fold predictions of ``y`` from the features ``X``.

    Uses a Ridge regressor for the 'ankh_base_embedding' feature set and a
    HistGradientBoostingRegressor for every other feature set.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Feature matrix for the rows being modelled.
    y : array-like of shape (n_samples,)
        Regression target (the 'pIC50' column in this notebook).
    model_type : str
        Name of the feature set. It selects the estimator and labels the log lines.
    cv : int, cross-validation splitter or None, default None
        Splitting strategy forwarded to ``cross_val_predict``. ``None`` falls
        back to the module-level ``kf``, which keeps existing calls unchanged.
    n_jobs : int, default -1
        Parallel jobs for ``cross_val_predict`` (-1 uses all processors).

    Returns
    -------
    numpy.ndarray
        One out-of-fold prediction per row of ``X``.
    """
    if cv is None:
        # Backward compatibility: earlier versions read the global `kf` at call
        # time. Prefer passing `cv` explicitly to avoid hidden notebook state.
        cv = kf
    start_time_model = time.perf_counter()  # monotonic clock for durations
    if model_type == 'ankh_base_embedding':
        print("Training Ridge model for ankh_base_embedding")
        model = Ridge()
    else:
        print(f"Training HistGradientBoostingRegressor model for {model_type}")
        model = HistGradientBoostingRegressor()
    predictions = cross_val_predict(model, X, y, cv=cv, n_jobs=n_jobs)
    time_taken_model = time.perf_counter() - start_time_model
    print(f"Finished training model for {model_type}. Time taken: {time_taken_model} seconds")
    return predictions


kf = KFold(n_splits=5, shuffle=True, random_state=123)

# Build each feature frame once; the subset loop below only selects rows.
# NOTE(review): assumes every fingerprint is row-aligned with `df` (labels 0..n-1
# after reset_index in the loading cell). Confirm upstream.
fingerprint_frames = {name: pd.DataFrame(fp) for name, fp in fingerprints_filtered.items()}
one_hot_frame = pd.DataFrame(one_hot_features)
assert fingerprint_frames, "fingerprints_filtered is empty - nothing to train"

# Evaluate every non-empty combination of toxin variants as a training pool.
for r in range(1, len(variants) + 1):
    for subset in combinations(variants, r):
        if subset in trained_subsets:
            print(f"Skipping subset: {subset} as it has already been trained")
            continue
        print(f"Processing subset: {subset}")
        start_time_subset = time.perf_counter()

        subset_df = df[df['Variants_of'].isin(subset)]
        # The target is the same for every feature set. Defining it here keeps it
        # defined for the pooled metrics below regardless of the inner loop.
        y = subset_df['pIC50']
        one_hot_subset = one_hot_frame.loc[subset_df.index]

        fold_predictions = []
        for fingerprint_name, fingerprint_frame in fingerprint_frames.items():
            # ignore_index=True relabels the columns 0..k-1. The one-hot frame's
            # labels are 0..m-1 and can collide with the fingerprint's labels, and
            # mixed str/int labels make recent scikit-learn versions raise a TypeError.
            X = pd.concat(
                [fingerprint_frame.loc[subset_df.index], one_hot_subset],
                axis=1,
                ignore_index=True,
            )
            fold_predictions.append(train_and_predict_model(X, y, fingerprint_name))

        # Unweighted mean of the per-feature-set out-of-fold predictions,
        # labelled with the subset's row labels for per-variant lookup.
        ensemble_predictions = pd.Series(np.mean(fold_predictions, axis=0), index=subset_df.index)

        for variant in subset:
            variant_df = subset_df[subset_df['Variants_of'] == variant]
            y_true = variant_df['pIC50']
            y_pred = ensemble_predictions.loc[variant_df.index]
            variant_r2_scores[variant][subset] = r2_score(y_true, y_pred)
            variant_mean_absolute_errors[variant][subset] = mean_absolute_error(y_true, y_pred)

        subsets_r2_scores[subset] = r2_score(y, ensemble_predictions)
        subsets_mean_absolute_errors[subset] = mean_absolute_error(y, ensemble_predictions)

        time_taken_subset = time.perf_counter() - start_time_subset
        trained_subsets.add(subset)
        print(f"Finished processing subset: {subset}. Time taken: {time_taken_subset} seconds")



Skipping subset: ('GpTx-1',) as it has already been trained
Skipping subset: ('Protoxin II',) as it has already been trained
Skipping subset: ('JzTx-V',) as it has already been trained
Processing subset: ('Huwentoxin-IV',)
Training HistGradientBoostingRegressor model for erg
Finished training model for erg. Time taken: 3.074171543121338 seconds
Training HistGradientBoostingRegressor model for atompair-count
Finished training model for atompair-count. Time taken: 9.219215631484985 seconds
Training Ridge model for ankh_base_embedding
Finished training model for ankh_base_embedding. Time taken: 1.2085340023040771 seconds
Finished processing subset: ('Huwentoxin-IV',). Time taken: 26.29340672492981 seconds
Processing subset: ('GpTx-1', 'Protoxin II')
Training HistGradientBoostingRegressor model for erg
Finished training model for erg. Time taken: 2.6185436248779297 seconds
Training HistGradientBoostingRegressor model for atompair-count
Finished training model for atompair-count. Time taken

In [21]:
for variant, r2_by_subset in variant_r2_scores.items():
    # Rank training subsets from best to worst R² for this variant.
    ranked_subsets = sorted(r2_by_subset.items(), key=lambda pair: pair[1], reverse=True)
    for subset, score in ranked_subsets:
        print(variant, subset)
        print(round(score, 2))
    print("______________________")

GpTx-1 ('GpTx-1', 'JzTx-V', 'Huwentoxin-IV')
0.83
GpTx-1 ('GpTx-1', 'JzTx-V')
0.83
GpTx-1 ('GpTx-1', 'Protoxin II', 'JzTx-V', 'Huwentoxin-IV')
0.82
GpTx-1 ('GpTx-1', 'Protoxin II', 'JzTx-V')
0.82
GpTx-1 ('GpTx-1',)
0.82
GpTx-1 ('GpTx-1', 'Protoxin II', 'Huwentoxin-IV')
0.82
GpTx-1 ('GpTx-1', 'Protoxin II')
0.81
GpTx-1 ('GpTx-1', 'Huwentoxin-IV')
0.8
______________________
Protoxin II ('GpTx-1', 'Protoxin II', 'Huwentoxin-IV')
0.75
Protoxin II ('Protoxin II', 'JzTx-V', 'Huwentoxin-IV')
0.74
Protoxin II ('Protoxin II',)
0.73
Protoxin II ('GpTx-1', 'Protoxin II', 'JzTx-V', 'Huwentoxin-IV')
0.72
Protoxin II ('GpTx-1', 'Protoxin II', 'JzTx-V')
0.72
Protoxin II ('GpTx-1', 'Protoxin II')
0.72
Protoxin II ('Protoxin II', 'JzTx-V')
0.71
Protoxin II ('Protoxin II', 'Huwentoxin-IV')
0.71
______________________
JzTx-V ('GpTx-1', 'JzTx-V', 'Huwentoxin-IV')
0.72
JzTx-V ('GpTx-1', 'JzTx-V')
0.7
JzTx-V ('GpTx-1', 'Protoxin II', 'JzTx-V')
0.69
JzTx-V ('GpTx-1', 'Protoxin II', 'JzTx-V', 'Huwentoxin-IV')

In [12]:
# Pooled R² over all rows of each training subset, keyed by the subset tuple.
subsets_r2_scores

{('GpTx-1',): 0.8180553562848694,
 ('Protoxin II',): 0.7254140899963385,
 ('JzTx-V',): 0.5627314085386583,
 ('Huwentoxin-IV',): 0.2330262313852589,
 ('GpTx-1', 'Protoxin II'): 0.8449050548910697,
 ('GpTx-1', 'JzTx-V'): 0.8097635405164636,
 ('GpTx-1', 'Huwentoxin-IV'): 0.7494324106433555,
 ('Protoxin II', 'JzTx-V'): 0.695452046388268,
 ('Protoxin II', 'Huwentoxin-IV'): 0.6559082868407982,
 ('JzTx-V', 'Huwentoxin-IV'): 0.5050765766043527,
 ('GpTx-1', 'Protoxin II', 'JzTx-V'): 0.831052080727239,
 ('GpTx-1', 'Protoxin II', 'Huwentoxin-IV'): 0.8067006365581478,
 ('GpTx-1', 'JzTx-V', 'Huwentoxin-IV'): 0.7676833231212137,
 ('Protoxin II', 'JzTx-V', 'Huwentoxin-IV'): 0.6687983611727613,
 ('GpTx-1', 'Protoxin II', 'JzTx-V', 'Huwentoxin-IV'): 0.7914781925163633}

In [14]:
for variant, mae_by_subset in variant_mean_absolute_errors.items():
    # Subsets appear in the order the training loop recorded them.
    for subset, mae in mae_by_subset.items():
        print(variant, subset)
        print(mae)
    print("______________________")


GpTx-1 ('GpTx-1',)
0.29838764564518855
GpTx-1 ('GpTx-1', 'Protoxin II')
0.29998697027212073
GpTx-1 ('GpTx-1', 'JzTx-V')
0.29897996495297985
GpTx-1 ('GpTx-1', 'Huwentoxin-IV')
0.3031337587187484
GpTx-1 ('GpTx-1', 'Protoxin II', 'JzTx-V')
0.3015420878800774
GpTx-1 ('GpTx-1', 'Protoxin II', 'Huwentoxin-IV')
0.3006505755306909
GpTx-1 ('GpTx-1', 'JzTx-V', 'Huwentoxin-IV')
0.29144927086972955
GpTx-1 ('GpTx-1', 'Protoxin II', 'JzTx-V', 'Huwentoxin-IV')
0.29903079671474203
______________________
Protoxin II ('Protoxin II',)
0.25124995795121885
Protoxin II ('GpTx-1', 'Protoxin II')
0.24946041681009
Protoxin II ('Protoxin II', 'JzTx-V')
0.24442884928508374
Protoxin II ('Protoxin II', 'Huwentoxin-IV')
0.24790006802672418
Protoxin II ('GpTx-1', 'Protoxin II', 'JzTx-V')
0.2436259201771511
Protoxin II ('GpTx-1', 'Protoxin II', 'Huwentoxin-IV')
0.2377635480658832
Protoxin II ('Protoxin II', 'JzTx-V', 'Huwentoxin-IV')
0.2390817240110308
Protoxin II ('GpTx-1', 'Protoxin II', 'JzTx-V', 'Huwentoxin-IV')


In [15]:
for variant, r2_by_subset in variant_r2_scores.items():
    # Subsets appear in the order the training loop recorded them.
    for subset, score in r2_by_subset.items():
        print(variant, subset)
        print(score)
    print("______________________")

GpTx-1 ('GpTx-1',)
0.8230169101684812
GpTx-1 ('GpTx-1', 'Protoxin II')
0.8278864133766886
GpTx-1 ('GpTx-1', 'JzTx-V')
0.8245664664215095
GpTx-1 ('GpTx-1', 'Huwentoxin-IV')
0.8168712525665042
GpTx-1 ('GpTx-1', 'Protoxin II', 'JzTx-V')
0.8226161884215969
GpTx-1 ('GpTx-1', 'Protoxin II', 'Huwentoxin-IV')
0.8243998371631742
GpTx-1 ('GpTx-1', 'JzTx-V', 'Huwentoxin-IV')
0.8344522568463153
GpTx-1 ('GpTx-1', 'Protoxin II', 'JzTx-V', 'Huwentoxin-IV')
0.8286413285173693
______________________
Protoxin II ('Protoxin II',)
0.7217835564385777
Protoxin II ('GpTx-1', 'Protoxin II')
0.7161981880583146
Protoxin II ('Protoxin II', 'JzTx-V')
0.7236927978341127
Protoxin II ('Protoxin II', 'Huwentoxin-IV')
0.7271157687089864
Protoxin II ('GpTx-1', 'Protoxin II', 'JzTx-V')
0.7203591573873451
Protoxin II ('GpTx-1', 'Protoxin II', 'Huwentoxin-IV')
0.7338149913388627
Protoxin II ('Protoxin II', 'JzTx-V', 'Huwentoxin-IV')
0.7439365829974885
Protoxin II ('GpTx-1', 'Protoxin II', 'JzTx-V', 'Huwentoxin-IV')
0.7159

In [22]:
# Pooled R² per training subset.
for subset, score in subsets_r2_scores.items():
    print(subset, score)


('GpTx-1',) 0.8180553562848694
('Protoxin II',) 0.7254140899963385
('JzTx-V',) 0.5627314085386583
('Huwentoxin-IV',) 0.2330262313852589
('GpTx-1', 'Protoxin II') 0.8449050548910697
('GpTx-1', 'JzTx-V') 0.8097635405164636
('GpTx-1', 'Huwentoxin-IV') 0.7494324106433555
('Protoxin II', 'JzTx-V') 0.695452046388268
('Protoxin II', 'Huwentoxin-IV') 0.6559082868407982
('JzTx-V', 'Huwentoxin-IV') 0.5050765766043527
('GpTx-1', 'Protoxin II', 'JzTx-V') 0.831052080727239
('GpTx-1', 'Protoxin II', 'Huwentoxin-IV') 0.8067006365581478
('GpTx-1', 'JzTx-V', 'Huwentoxin-IV') 0.7676833231212137
('Protoxin II', 'JzTx-V', 'Huwentoxin-IV') 0.6687983611727613
('GpTx-1', 'Protoxin II', 'JzTx-V', 'Huwentoxin-IV') 0.7914781925163633


In [17]:
# Pooled MAE per training subset. Iterate the MAE dict itself rather than the
# R² dict: the training loop fills them one statement apart, so an interrupted
# run can leave a subset with an R² but no MAE (which would raise KeyError).
for subset, mae in subsets_mean_absolute_errors.items():
    print(subset, mae)

('GpTx-1',) 0.29838764564518855
('Protoxin II',) 0.25124995795121885
('JzTx-V',) 0.445408687534699
('Huwentoxin-IV',) 0.2912038164224761
('GpTx-1', 'Protoxin II') 0.2803029283526695
('GpTx-1', 'JzTx-V') 0.3161909287017322
('GpTx-1', 'Huwentoxin-IV') 0.30098056971488724
('Protoxin II', 'JzTx-V') 0.29639774998117435
('Protoxin II', 'Huwentoxin-IV') 0.2698048554407912
('JzTx-V', 'Huwentoxin-IV') 0.3394351918620379
('GpTx-1', 'Protoxin II', 'JzTx-V') 0.29538464697635236
('GpTx-1', 'Protoxin II', 'Huwentoxin-IV') 0.2812753245307472
('GpTx-1', 'JzTx-V', 'Huwentoxin-IV') 0.306222749320957
('Protoxin II', 'JzTx-V', 'Huwentoxin-IV') 0.2939706486249919
('GpTx-1', 'Protoxin II', 'JzTx-V', 'Huwentoxin-IV') 0.2941069910000965
