In [2]:
%load_ext autoreload
%autoreload 2

import os 
import sys

# Put the parent of the current working directory on the import path;
# presumably this is the repository root that holds the `src_` package.
ROOT_PATH = os.path.dirname(os.getcwd())
if ROOT_PATH not in sys.path:
    # Guard against stacking duplicate entries when this cell is re-run.
    sys.path.append(ROOT_PATH)

# Drop another checkout's path if it happens to be present, presumably so its
# modules cannot shadow this one — TODO confirm this is still needed.
try:
    sys.path.remove('/projects/p30802/Karina/protease_stability/')
except ValueError:
    # list.remove raises ValueError when the entry is absent. That is the
    # expected case on most machines, so it is ignored on purpose; any other
    # exception is allowed to propagate.
    pass

import numpy as np
import matplotlib.pyplot as plt

from src_.evals.run_model import get_params
from src_.evals.data_processing import get_and_process_data, get_folded_unfolded_data_splits
from src_.utils.general import sample_arrays
from src_.utils.plotting import plot_losses, plot_losses_unfolded_kT_kC, plot_scatter_predictions
from src_.models.wrapper import ProtNet
from src_.config import Config
from src_.evals.stability_score import plot_stability_score_correlation

# Seed NumPy's global RNG once up front. Presumably the project's sampling and
# splitting helpers draw from it — TODO confirm.
np.random.seed(42)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


_____

### Custom Config

In [3]:
# Input CSVs, resolved relative to ROOT_PATH.
# DATA_PATH is loaded as the "unfolded" set and DATA2_PATH as the "folded" set
# in the data-preparation cell.
DATA_PATH = os.path.join(ROOT_PATH, "data/210728_scrambles_for_unstructure_model.csv")
DATA2_PATH = os.path.join(ROOT_PATH, "data/210728_dmsv2_alldata.csv")

In [4]:
# Model key passed to both get_params() and ProtNet(model_type=...) further down.
MODEL_TYPE = "convnet_1d"

_____

### Prepare data

In [None]:
# Load each CSV as an (X, kT, kC) triple: DATA_PATH -> unfolded, DATA2_PATH -> folded.
# NOTE(review): fit_to_range="remove" presumably drops out-of-range rows —
# confirm against get_and_process_data.
X_unfolded, kT_unfolded, kC_unfolded = get_and_process_data(DATA_PATH, fit_to_range="remove")
X_folded, kT_folded, kC_folded = get_and_process_data(DATA2_PATH, fit_to_range="remove")

In [None]:
# Split both datasets. The two results are indexed later with keys such as
# "X_train" / "kT_train" / "kC_test".
unfolded_data, folded_data = get_folded_unfolded_data_splits(
    X_unfolded, kT_unfolded, kC_unfolded,
    X_folded, kT_folded, kC_folded,
)

_____

### Train the model

In [None]:
# get_params returns a (params, epochs) pair for the chosen model type;
# params is later unpacked into ProtNet and epochs is passed to model.train.
params, epochs = get_params(MODEL_TYPE)

# Add two values read from the project Config.
# NOTE(review): the config key is "n_char" while the model parameter is
# "num_char" — the mismatch looks deliberate, but confirm.
params["num_char"] = Config.get("n_char")
params["seq_length"] = Config.get("seq_length")

In [None]:
# Build the model wrapper; every entry of params is forwarded as a keyword argument.
model = ProtNet(model_type = MODEL_TYPE, **params)

In [None]:
# Train on the "*_train" entries of both the unfolded and folded splits,
# for the epoch count returned by get_params.
model.train(
    X_unfolded=unfolded_data["X_train"],
    kT_unfolded=unfolded_data["kT_train"],
    kC_unfolded=unfolded_data["kC_train"],
    X_folded=folded_data["X_train"],
    kT_folded=folded_data["kT_train"],
    kC_folded=folded_data["kC_train"],
    epochs=epochs,
)

_____

_____

### Evaluate model on the test data

In [None]:
# Evaluate on the "*_test" entries of both splits. The result is unpacked into
# three values, named here as the kT, kC and stability MSEs.
mse_kT, mse_kC, mse_stability = model.evaluate(
    X_unfolded=unfolded_data["X_test"],
    kT_unfolded=unfolded_data["kT_test"],
    kC_unfolded=unfolded_data["kC_test"],
    X_folded=folded_data["X_test"],
    kT_folded=folded_data["kT_test"],
    kC_folded=folded_data["kC_test"],
)

_____

In [None]:
# Scatter plot of model predictions for the unfolded test split
# (sample=2000 presumably caps the number of points drawn — TODO confirm).
plot_scatter_predictions(
    model,
    unfolded_data["X_test"],
    unfolded_data["kT_test"],
    unfolded_data["kC_test"],
    sample=2000,
)

_____

### Plot losses

In [None]:
# Plot the losses for the trained model (presumably the history recorded
# during model.train — TODO confirm).
plot_losses(model)

In [None]:
# Same model, but the loss view for the unfolded data, presumably split into
# its kT and kC components — TODO confirm against the plotting helper.
plot_losses_unfolded_kT_kC(model)

### Plot stability scores

In [None]:
# Take a sample of the folded test inputs (n_samples=1000) and look at where
# the model's kT / kC predictions for them fall.
# NOTE(review): unlike the stability-score cell that follows, no seed is set
# right before sampling here, so the drawn subset depends on prior RNG use.
X_test_folded_samples = sample_arrays([folded_data["X_test"]], n_samples=1000)[0]

kT_pred, kC_pred = model.predict(X_test_folded_samples)

# Model name with underscores turned into spaces, used as the figure subtitle.
title = model.model.name.replace("_", " ")

fig, ax = plt.subplots()
ax.scatter(kT_pred, kC_pred, alpha=0.3)
ax.set_title(f"USM-Predicted Range \n {title}")
# Label both axes so the figure can be read without the code above it.
ax.set_xlabel("Predicted kT")
ax.set_ylabel("Predicted kC")
plt.show()

In [None]:
# Re-seed right before plotting so the sample=1000 subset is the same each time
# this cell runs (assumes the helper samples via NumPy's global RNG — TODO confirm).
np.random.seed(42)
plot_stability_score_correlation(
    model.model,
    folded_data["X_test"],
    folded_data["kT_test"],
    folded_data["kC_test"],
    sample=1000,
)

#### Local stability scores

In [None]:
# data_folded = pd.read_csv(DATA2_PATH)
# data_folded["pdb_code"] =  [name.split(".")[0] for name in data_folded.name]
# grouped_indices = data_folded.groupby(by="pdb_code").indices

# X, kT, kC = get_and_process_data(DATA2_PATH, return_as_df=True)


# save = False
# n_to_plot = 20
# save_dir = save_path = os.path.join(ROOT_PATH, f"results/stability_scores/{MODEL_TYPE}/mutations/")


# if not os.path.exists(save_dir):
#     os.makedirs(save_dir)

# for i, (group_name, indices) in enumerate(grouped_indices.items()):
#     if len(indices) > 10:
#         X_, kT_, kC_ = X.loc[indices], kT.loc[indices], kC.loc[indices]
#         kT_, kC_ = np.array(kT_), np.array(kC_)
        
#         save_path = os.path.join(save_dir, f"{group_name}.png")

#         if save:   
#             plot_stability_score_correlation(model.model, X_, kT_, kC_, title=group_name, save_path = save_path)
#         else:
#             plot_stability_score_correlation(model.model, X_, kT_, kC_, title=group_name)

#     if i == n_to_plot:
#         break

_____

_____