In [1]:
from jax_unirep import get_reps



In [2]:
import pandas as pd
from os import path
import numpy as np

In [40]:
from sklearn.model_selection import RandomizedSearchCV
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import LeaveOneGroupOut
from sklearn import metrics
from sklearn.utils.fixes import loguniform 

In [3]:
# Relative path to the data root; the CSV paths in the loading cells are joined onto it.
DATA_DIR = "../../data"

In [22]:
# Training split of the Chen data set, including the per-row cluster columns.
# The first CSV column is used as the DataFrame index.
chen_train_csv = path.join(DATA_DIR, "chen/deduplicated/chen_train_data_w_clusters.csv")
chen_train = pd.read_csv(chen_train_csv, index_col=0)
chen_train.head()

Unnamed: 0,Antibody_ID,heavy,light,Y,cluster,cluster_merged
2073,6aod,EVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLE...,DIVMTKSPSSLSASVGDRVTITCRASQGIRNDLGWYQQKPGKAPKR...,0,313,3
1517,4yny,EVQLVESGGGLVQPGRSLKLSCAASGFTFSNYGMAWVRQTPTKGLE...,EFVLTQPNSVSTNLGSTVKLSCKRSTGNIGSNYVNWYQQHEGRSPT...,1,347,3
2025,5xcv,EVQLVESGGGLVQPGRSLKLSCAASGFTFSNYGMAWVRQTPTKGLE...,QFVLTQPNSVSTNLGSTVKLSCKRSTGNIGSNYVNWYQQHEGRSPT...,1,347,3
2070,6and,EVQLVESGGGLVQPGGSLRLSCAASGYEFSRSWMNWVRQAPGKGLE...,DIQMTQSPSSLSASVGDRVTITCRSSQSIVHSVGNTFLEWYQQKPG...,1,458,4
666,2xqy,QVQLQQPGAELVKPGASVKMSCKASGYSFTSYWMNWVKQRPGRGLE...,DIVLTQSPASLALSLGQRATISCRASKSVSTSGYSYMYWYQQKPGQ...,0,465,4


In [7]:
# The validation and test CSVs are stacked (validation rows first) into a single
# held-out frame that keeps the name `chen_test`.
chen_valid = pd.read_csv(path.join(DATA_DIR, "chen/deduplicated/chen_valid_data.csv"), index_col=0)
chen_test = pd.concat(
    [
        chen_valid,
        pd.read_csv(path.join(DATA_DIR, "chen/deduplicated/chen_test_data.csv"), index_col=0),
    ]
)
chen_test.head()

Unnamed: 0,Antibody_ID,heavy,light,Y
2169,6ct7,EVQLVESGGGLVEPGGSLRLSCAVSGFDFEKAWMSWVRQAPGQGLQ...,SYELTQPPSVSVSPGQTARITCSGEALPMQFAHWYQQRPGKAPVIV...,0
1342,4nzu,AVSLVESGGGTVEPGSTLRLSCAASGFTFGSYAFHWVRQAPGDGLE...,DIEMTQSPSSLSASTGDKVTITCQASQDIAKFLDWYQQRPGKTPKL...,0
1728,5i8c,QEVLVQSGAEVKKPGASVKVSCRAFGYTFTGNALHWVRQAPGQGLE...,DIQLTQSPSFLSASVGDKVTITCRASQGVRNELAWYQQKPGKAPNL...,1
1729,5i8e,QEVLVQSGAEVKKPGASVKVSCRAFGYTFTGNALHWVRQAPGQGLE...,IQLTQSPSFLSASVGDKVTITCRASQGVRNELAWYQQKPGKAPNLL...,0
2114,6bb4,QVQLQQSDAELVKPGASVKISCKASGYTFTDRTIHWVKQRPEQGLE...,DVQMIQSPSSLSASLGDIVTMTCQASQDTSINLNWFQQKPGKAPKL...,0


## Only heavy sequences

In [31]:
# Model inputs for this section are the heavy-chain sequences only;
# the "Y" column supplies the labels for both splits.
train_h_seqs = chen_train["heavy"].tolist()
test_h_seqs = chen_test["heavy"].tolist()

y_train = chen_train["Y"].tolist()
y_test = chen_test["Y"].tolist()

In [24]:
# Embed the heavy-chain sequences with jax-unirep, following
# https://elarkk.github.io/jax-unirep/getting-started/#basic-usage
# get_reps returns three values; only the first one is used as the feature
# matrix below (h_final and c_final are not used by the cells shown here).
X_train, h_final, c_final = get_reps(train_h_seqs)

# Same embedding for the held-out sequences; the two extra return values are discarded.
X_test, _, _ = get_reps(test_h_seqs)

In [25]:
# Sanity check on the feature matrix: (number of training sequences, embedding width).
X_train.shape

(1338, 1900)

In [26]:
# Per-row group labels for the training set; passed to LeaveOneGroupOut so that
# each cross-validation fold holds out one merged cluster.
groups = chen_train["cluster_merged"]

In [20]:
def svm(n=None):
    """Build an SVM classifier together with its random-search space.

    Args:
        n: Unused. Kept (and now optional) so that existing calls such as
            ``svm(1338)`` keep working.

    Returns:
        A ``(estimator, param_distributions, label)`` tuple: an unfitted
        ``SVC``, the distributions/lists to sample ``C``, ``kernel`` and
        ``gamma`` from, and the string ``"SVM"``.
    """
    # Local import: scipy ships loguniform directly, whereas the
    # sklearn.utils.fixes shim is deprecated/removed in newer scikit-learn releases.
    from scipy.stats import loguniform

    # random_state only seeds the shuffling behind probability=True; it is set
    # so that the estimator is seeded like the MLP in this notebook.
    svc = SVC(max_iter=8000, probability=True, class_weight='balanced', random_state=42)
    parameters = {
        'C': loguniform(0.001, 100),
        'kernel': ["linear", "rbf"],
        # gamma is only used by the rbf kernel; the linear kernel ignores it.
        'gamma': loguniform(1e-3, 1e0),
    }
    return svc, parameters, "SVM"

In [34]:
# Tune the SVM with a randomized search scored by F1. Each CV fold holds out one
# group from `groups` (leave-one-group-out).
classifier, params, model_label = svm(X_train.shape[0])
splitter = LeaveOneGroupOut()
# The splitter is handed over as-is and the group labels are supplied at fit
# time (instead of a one-shot generator); random_state makes the sampled
# candidates reproducible across runs.
grid = RandomizedSearchCV(
    classifier,
    params,
    verbose=1,
    scoring="f1",
    cv=splitter,
    random_state=42,
)
grid.fit(X_train, y_train, groups=groups)
estimator = grid.best_estimator_
best_params = grid.best_params_
# Predict on the held-out set with the best configuration found by the search.
y_pred = estimator.predict(X_test)

Fitting 10 folds for each of 10 candidates, totalling 100 fits


In [35]:
# F1 of the tuned SVM on the held-out set.
held_out_f1 = metrics.f1_score(y_test, y_pred)
print(held_out_f1)

0.49681528662420377


In [37]:
# Accuracy and Matthews correlation coefficient for the same predictions, one per line.
for score_fn in (metrics.accuracy_score, metrics.matthews_corrcoef):
    print(score_fn(y_test, y_pred))

0.6694560669456067
0.3510547884992857


In [38]:
def multilayer_perceptron(n=None):
    """Build an MLP classifier together with its hyper-parameter grid.

    Args:
        n: Unused. Kept (and now optional) so that existing calls such as
            ``multilayer_perceptron(1338)`` keep working.

    Returns:
        A ``(estimator, param_grid, label)`` tuple: an unfitted
        ``MLPClassifier``, the candidate values for ``hidden_layer_sizes`` and
        ``activation``, and the string ``"multilayer_perceptron"``.
    """
    mlp = MLPClassifier(random_state=42, max_iter=1000)
    parameters = {
        'hidden_layer_sizes': [(100,), (50,), (100, 100)],
        "activation": ["relu", "logistic"],
    }
    return mlp, parameters, "multilayer_perceptron"

In [41]:
# Tune the MLP with the same leave-one-group-out search, scored by F1.
# The grid has only 3 x 2 = 6 combinations, so every one of them is evaluated.
classifier, params, model_label = multilayer_perceptron(X_train.shape[0])
splitter = LeaveOneGroupOut()
# The splitter is handed over as-is and the group labels are supplied at fit
# time (instead of a one-shot generator); random_state fixes the search's
# random number generator so reruns behave the same.
grid = RandomizedSearchCV(
    classifier,
    params,
    verbose=1,
    scoring="f1",
    cv=splitter,
    random_state=42,
)
grid.fit(X_train, y_train, groups=groups)
estimator = grid.best_estimator_
best_params = grid.best_params_
# Predict on the held-out set with the best configuration found by the search.
y_pred = estimator.predict(X_test)



Fitting 10 folds for each of 6 candidates, totalling 60 fits


In [45]:
# F1, accuracy and Matthews correlation coefficient of the tuned MLP on the
# held-out set, printed one per line in that order.
for score_fn in (metrics.f1_score, metrics.accuracy_score, metrics.matthews_corrcoef):
    print(score_fn(y_test, y_pred))

0.36619718309859156
0.8117154811715481
0.3043448794404797


In [49]:
# Heavy + light: each antibody's heavy and light chains are joined end to end
# (no separator) into one input sequence. The labels are re-read from "Y".
train_seqs = (chen_train["heavy"] + chen_train["light"]).tolist()
test_seqs = (chen_test["heavy"] + chen_test["light"]).tolist()

y_train = chen_train["Y"].tolist()
y_test = chen_test["Y"].tolist()

In [50]:
# Embed the concatenated heavy+light sequences.
# NOTE: this overwrites X_train / X_test (and h_final / c_final) from the
# heavy-only section, so cells below this point work on heavy+light features.
X_train, h_final, c_final = get_reps(train_seqs)

X_test, _, _ = get_reps(test_seqs)

In [47]:
# Tune the SVM on the heavy+light features with a randomized search scored by
# F1. Each CV fold holds out one group from `groups` (leave-one-group-out).
# TODO(review): this cell's recorded execution count is lower than that of the
# feature cell above it; re-run top to bottom to confirm the reported metrics.
classifier, params, model_label = svm(X_train.shape[0])
splitter = LeaveOneGroupOut()
# The splitter is handed over as-is and the group labels are supplied at fit
# time (instead of a one-shot generator); random_state makes the sampled
# candidates reproducible across runs.
grid = RandomizedSearchCV(
    classifier,
    params,
    verbose=1,
    scoring="f1",
    cv=splitter,
    random_state=42,
)
grid.fit(X_train, y_train, groups=groups)
estimator = grid.best_estimator_
best_params = grid.best_params_
# Predict on the held-out set with the best configuration found by the search.
y_pred = estimator.predict(X_test)

Fitting 10 folds for each of 10 candidates, totalling 100 fits


In [53]:
# F1, accuracy and Matthews correlation coefficient of the tuned SVM on the
# held-out set, printed one per line in that order.
for score_fn in (metrics.f1_score, metrics.accuracy_score, metrics.matthews_corrcoef):
    print(score_fn(y_test, y_pred))

0.5492957746478873
0.7322175732217573
0.42372004540697317


In [54]:
# Tune the MLP on the heavy+light features with the same leave-one-group-out
# search, scored by F1. The grid has only 3 x 2 = 6 combinations, so every one
# of them is evaluated.
classifier, params, model_label = multilayer_perceptron(X_train.shape[0])
splitter = LeaveOneGroupOut()
# The splitter is handed over as-is and the group labels are supplied at fit
# time (instead of a one-shot generator); random_state fixes the search's
# random number generator so reruns behave the same.
grid = RandomizedSearchCV(
    classifier,
    params,
    verbose=1,
    scoring="f1",
    cv=splitter,
    random_state=42,
)
grid.fit(X_train, y_train, groups=groups)
estimator = grid.best_estimator_
best_params = grid.best_params_
# Predict on the held-out set with the best configuration found by the search.
y_pred = estimator.predict(X_test)



Fitting 10 folds for each of 6 candidates, totalling 60 fits


In [55]:
# F1, accuracy and Matthews correlation coefficient of the tuned MLP on the
# held-out set, printed one per line in that order.
for score_fn in (metrics.f1_score, metrics.accuracy_score, metrics.matthews_corrcoef):
    print(score_fn(y_test, y_pred))

0.3870967741935484
0.8410041841004184
0.42657837358124684
