In [None]:
import pandas as pd
import numpy as np
import argparse
import glob
# import wandb
import json
import joblib
import os

from sklearn.model_selection import KFold, train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score
from sklearn.preprocessing import LabelEncoder, StandardScaler

CALC_PATHS = '/home/nfs/jludwiczak/af2_cc/af2_multimer/calc'

In [84]:
def get_x(id_: int, rank: int, model: str = "af2", 
          use_pairwise: bool = True):
    
    single_repr_fns = sorted(glob.glob(f"{CALC_PATHS}/{id_}/*_single_repr_rank_00*"))
    pair_repr_fns = sorted(glob.glob(f"{CALC_PATHS}{id_}/*_pair_repr_rank_00*"))
    
    mat = np.load(single_repr_fns[rank]).mean(axis=0)
    if use_pairwise:
        mat = np.hstack((mat, np.load(pair_repr_fns[rank]).mean(axis=0).mean(axis=0)))
    return mat

def get_af2_emb(id_: int, model_id: int, use_pairwise: bool):
    
    single_repr_fns = sorted(glob.glob(f"{CALC_PATHS}/{id_}/*_single_repr_rank_*_model_{model_id+1}_*"))
    pair_repr_fns = sorted(glob.glob(f"{CALC_PATHS}/{id_}/*_pair_repr_rank_*_model_{model_id+1}_*"))
    

    mat = np.load(single_repr_fns[0]).mean(axis=0)
    
    if use_pairwise:
        mat = np.hstack((mat, np.load(pair_repr_fns[0]).mean(axis=0).mean(axis=0)))
    
    return mat

def get_af2_emb_upd(colabfold_output_dir, use_pairwise: bool):
    """Order of the models is inconsistent with Janeks code, so we need to do some sorting"""

    representations = sorted(glob.glob(f"{colabfold_output_dir}/*_repr_rank_*"))
    # order representations by 'rank_
    # representations = sorted(representations, key=lambda x: int(x.split('_rank_')[1].split('_')[0]))
    
    single_repr_fns = sorted([x for  x in representations if "single" in x])
    pair_repr_fns = sorted([x for  x in representations if "pair" in x])
    
    print(representations)

    mat = np.load(single_repr_fns[0]).mean(axis=0)
    
    if use_pairwise:
        mat = np.hstack((mat, np.load(pair_repr_fns[0]).mean(axis=0).mean(axis=0)))
    
    return mat

get_af2_emb_upd('./sample/10/',True);
get_af2_emb(10,0,True);

['./sample/10/3w8v_pair_repr_rank_001_alphafold2_multimer_v3_model_4_seed_000.npy', './sample/10/3w8v_pair_repr_rank_002_alphafold2_multimer_v3_model_3_seed_000.npy', './sample/10/3w8v_pair_repr_rank_003_alphafold2_multimer_v3_model_5_seed_000.npy', './sample/10/3w8v_pair_repr_rank_004_alphafold2_multimer_v3_model_2_seed_000.npy', './sample/10/3w8v_pair_repr_rank_005_alphafold2_multimer_v3_model_1_seed_000.npy', './sample/10/3w8v_single_repr_rank_001_alphafold2_multimer_v3_model_4_seed_000.npy', './sample/10/3w8v_single_repr_rank_002_alphafold2_multimer_v3_model_3_seed_000.npy', './sample/10/3w8v_single_repr_rank_003_alphafold2_multimer_v3_model_5_seed_000.npy', './sample/10/3w8v_single_repr_rank_004_alphafold2_multimer_v3_model_2_seed_000.npy', './sample/10/3w8v_single_repr_rank_005_alphafold2_multimer_v3_model_1_seed_000.npy']
['/home/nfs/jludwiczak/af2_cc/af2_multimer/calc/10/3w8v_pair_repr_rank_005_alphafold2_multimer_v3_model_1_seed_000.npy']


In [None]:
def train(c=10, balanced=0, dual=1, ensemble_size=1, use_pairwise=True, use_scaler=True):
    
    # Load dataset
    df = pd.read_csv("src/data/set4_homooligomers.csv", sep="\t")
    df = df.drop_duplicates(subset="full_sequence", keep="first")
    
    le = LabelEncoder()
    df['y'] = le.fit_transform(df['chains'])\

    results = np.zeros((ensemble_size, 5, len(df), 3))
    model = {}

    for j in range(0, ensemble_size):
        for i in range(0, 5): # 5 since we have 5 AF2 models

            X = np.asarray([get_af2_emb(id_, model_id=i, use_pairwise=use_pairwise) for id_ in df.index])
            y = df['y'].values

            cv = KFold(n_splits=5, shuffle=True)

            for k, (tr_idx, te_idx) in enumerate(cv.split(X, y)):

                X_tr, X_te = X[tr_idx], X[te_idx]
                y_tr, y_te = y[tr_idx], y[te_idx]

                if use_scaler == True:
                    sc = StandardScaler()
                    X_tr = sc.fit_transform(X_tr)
                    X_te = sc.transform(X_te)
                    model[f"scaler_{j}_{i}_{k}"] = sc
                clf = LogisticRegression(C=c, max_iter=1000, solver='liblinear',
                                         dual = False if dual == 0 else True, 
                                         class_weight = 'balanced' if balanced == 1 else None) 
                clf.fit(X_tr, y_tr)
                results[j, i, te_idx, :] = clf.predict_proba(X_te)
                model[f"clf_{j}_{i}_{k}"] = clf


    y_pred_bin = results.mean(axis=0).mean(axis=0).argmax(axis=1)
    joblib.dump(clf, 'weights.p')

    results_ = {}
    results_["accuracy"] = accuracy_score(y, y_pred_bin)
    results_["f1"] = f1_score(y, y_pred_bin, average='macro')

    df["y_pred"] = y_pred_bin
    df["prob_dimer"] = results.mean(axis=0).mean(axis=0)[:, 0]
    df["prob_trimer"] = results.mean(axis=0).mean(axis=0)[:, 1]
    df["prob_tetramer"] = results.mean(axis=0).mean(axis=0)[:, 2]
    df.to_csv('results.csv')

    return results_, model, df

In [None]:
train(use_pairwise=True)

In [92]:
def predict(colabfold_output_dir: str, use_pairwise: bool = True):
    # load model from sklearn saved weights
    model = joblib.load('weights.p')
    # load colabfold output and convert to numpy array
    # single_repr_fns  = [f for f in sorted(glob.glob(f'{colabfold_output_dir}/*_single_repr_rank_*_model_*')) if f.endswith('.npy')]
    # pair_repr_fns  = [f for f in sorted(glob.glob(f'{colabfold_output_dir}/*_pair_repr_rank_*_model_*')) if f.endswith('.npy')]

    # for i in range(0, 5):
    X = np.asarray(get_af2_emb_upd(colabfold_output_dir, use_pairwise=use_pairwise)).reshape(1, -1)

    print(model.predict(X))

    # X = np.asarray([get_af2_emb(id_, model_id=i, use_pairwise=use_pairwise) for id_ in df.index])
    # for i in range(0, 5): # 5 since we have 5 AF2 models

    #     X = np.asarray([get_af2_emb(id_, model_id=i, use_pairwise=use_pairwise) for id_ in df.index])




predict('./sample/7/')

['./sample/7/7bji_pair_repr_rank_001_alphafold2_multimer_v3_model_5_seed_000.npy', './sample/7/7bji_pair_repr_rank_002_alphafold2_multimer_v3_model_4_seed_000.npy', './sample/7/7bji_pair_repr_rank_003_alphafold2_multimer_v3_model_3_seed_000.npy', './sample/7/7bji_pair_repr_rank_004_alphafold2_multimer_v3_model_2_seed_000.npy', './sample/7/7bji_pair_repr_rank_005_alphafold2_multimer_v3_model_1_seed_000.npy', './sample/7/7bji_single_repr_rank_001_alphafold2_multimer_v3_model_5_seed_000.npy', './sample/7/7bji_single_repr_rank_002_alphafold2_multimer_v3_model_4_seed_000.npy', './sample/7/7bji_single_repr_rank_003_alphafold2_multimer_v3_model_3_seed_000.npy', './sample/7/7bji_single_repr_rank_004_alphafold2_multimer_v3_model_2_seed_000.npy', './sample/7/7bji_single_repr_rank_005_alphafold2_multimer_v3_model_1_seed_000.npy']
[0]


In [155]:
from src.predict import predict_oligo_state

path_to_colabfold_train_data = glob.glob('/home/nfs/jludwiczak/af2_cc/af2_multimer/calc/*')
pdbs = [x.split('/')[-1][:4] for x in glob.glob('/home/nfs/jludwiczak/af2_cc/af2_multimer/calc/*/*unrelaxed_rank_001*.pdb')]
df = pd.read_csv("src/results/results.csv", sep=",")[['pdb','y_pred']]
df
for pdb, _id in zip(pdbs, path_to_colabfold_train_data):
    print(pdb, os.listdir(_id)[2][:4])
    result = predict_oligo_state(_id, use_pairwise=True)
    print(df.loc[df['pdb'] == pdb]['y_pred'].values[0], result)
    # print(result, df.iloc[int(_id.split('/')[-1])].values[0])
    # # assert result == df.iloc[int(_id.split('/')[-1])].values[0], f"Predicted oligomer state: {result} is not equal to {df.iloc[int(_id.split('/')[-1])].values[0]}"
    # print("\n")

# pdb




6otn 6otn
Predicted oligomer state: Dimer
0 0
1kyc 245_
Predicted oligomer state: Trimer
1 1
2qdq 2qdq
Predicted oligomer state: Dimer
0 0
5aps 5aps
Predicted oligomer state: Dimer
0 0
3qh9 3qh9
Predicted oligomer state: Dimer
0 0
4w80 4w80
Predicted oligomer state: Dimer
2 0
5wll 5wll
Predicted oligomer state: Dimer
2 0
1uii conf
Predicted oligomer state: Dimer
0 0
1t6f 1t6f
Predicted oligomer state: Dimer
0 0
7bji 7bji
Predicted oligomer state: Dimer
0 0
4pxj 4pxj
Predicted oligomer state: Dimer
0 0
2xg7 2xg7
Predicted oligomer state: Dimer
0 0
5oiy 5oiy
Predicted oligomer state: Dimer
2 0
2o1j 2o1j
Predicted oligomer state: Dimer
2 0
4j4a 4j4a
Predicted oligomer state: Trimer
2 1
5kb1 5kb1
Predicted oligomer state: Dimer
1 0
1w5l 1w5l
Predicted oligomer state: Dimer
2 0
1w5k 1w5k
Predicted oligomer state: Dimer
2 0
5hhe 5hhe
Predicted oligomer state: Dimer
0 0
5y2e 5y2e
Predicted oligomer state: Dimer
2 0
6v5i 6v5i
Predicted oligomer state: Trimer
1 1
2z5i 2z5i
Predicted oligomer st