In [None]:
import sys
sys.path.append('../')
sys.path.append('../../')

import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import load_model
from scipy.stats import pearsonr
from utils import get_data, set_data_size
import seaborn as sns
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import *
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.callbacks import Callback
import numpy as np
from utils import one_hot_enc
from hyper_params import HyperParams
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import Model
from models import gen_model
from models import TransformerBlock
from Bio import SeqIO
from tensorflow.keras.callbacks import EarlyStopping

In [None]:
def load_data(data_path):
    """Load the held-out test split used for evaluation.

    Calls ``get_data(data_path, min_read=2000)``, which is expected to yield
    three ``(x, y, extra)`` triples in the order train, test, validation
    (the same order ``train`` unpacks them in). Only the test triple is kept.

    Args:
        data_path: Dataset location forwarded unchanged to ``get_data``.

    Returns:
        tuple: ``(x_test, y_test)`` where ``x_test`` has been passed through
        ``set_data_size(80, ...)`` and ``y_test`` is returned untouched.
    """
    splits = get_data(data_path, min_read=2000)
    # Keep the strict 3x3 unpacking so a malformed result still fails loudly.
    (_, _, _), (features, targets, _), (_, _, _) = splits

    # 80 is presumably the target sequence length -- TODO confirm in utils.
    (features,) = set_data_size(80, [features])

    return features, targets

In [None]:
def train(kernel_size):
    """Train a model for ``kernel_size`` and score its best checkpoint.

    The model built by ``gen_model(kernel_size)`` is fitted on the training
    split with MSE loss and an exponentially decaying Adam learning rate.
    During training the epoch with the lowest validation loss is written to
    ``./model_multi_kernel/{kernel_size}.h5``; after training that same file
    is reloaded and its predictions on the validation split are compared to
    the targets with Pearson correlation.

    Args:
        kernel_size: Forwarded to ``gen_model`` and used as the checkpoint
            file stem.

    Returns:
        tuple: ``(pr_corr, model)`` -- the validation Pearson correlation and
        the model reloaded from the best checkpoint.
    """
    # Single source of truth for where the checkpoint is written AND read.
    # Previously it was written under ./model_multi_kernel/ but read from
    # ./models/, so the reloaded model was not the one trained here.
    checkpoint_path = f'./model_multi_kernel/{kernel_size}.h5'

    # NOTE(review): evaluate_model reads '../data/human' while this reads
    # './data/human' -- confirm which relative path is intended.
    [x_train, y_train, _], [x_test, _, _], [x_val, y_val, _] = get_data('./data/human')
    # 80 is presumably the target sequence length -- TODO confirm in utils.
    [x_train, x_test, x_val] = set_data_size(80, [x_train, x_test, x_val])

    es_callback = EarlyStopping(monitor='val_loss', patience=3)
    # No metric named 'pearson' is compiled below, so 'val_pearson' would never
    # appear in the logs; monitor the validation loss (lower is better) instead.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path,
        save_best_only=True,
        mode='min',
        monitor='val_loss'
    )

    model = gen_model(kernel_size)
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        0.001,
        decay_steps=20,
        decay_rate=0.99
    )
    model.compile(loss=MeanSquaredError(), optimizer=tf.keras.optimizers.Adam(lr_schedule),
                  metrics=[tf.keras.metrics.MeanAbsoluteError(), tf.keras.metrics.MeanSquaredError()])
    # The callbacks must be handed to fit(); otherwise neither early stopping
    # nor checkpointing takes place.
    model.fit(x_train, y_train, batch_size=128, epochs=20, validation_data=(x_val, y_val),
              shuffle=True, callbacks=[es_callback, cp_callback])

    # Score the best saved epoch rather than whatever the last epoch produced.
    best_model = load_model(checkpoint_path, custom_objects={'TransformerBlock': TransformerBlock})
    y_hat = best_model(x_val).numpy()
    y_hat = y_hat.reshape(len(y_hat))
    y_val = y_val.reshape(len(y_val))
    # evaluate model performance
    pr_corr = pearsonr(y_hat, y_val)[0]
    print(f"val corr = {pr_corr}")
    return pr_corr, best_model

In [None]:
def evaluate_model(pre_path, n_models=15):
    """Report how Pearson correlation evolves as ensemble members are added.

    Models are loaded from ``{pre_path}/model_{i}.h5`` for
    ``i = 0 .. n_models - 1``. After each model is added, the mean of all
    predictions so far is correlated with the targets returned by
    ``load_data('../data/human')``.

    Args:
        pre_path: Directory (without trailing slash) holding the
            ``model_{i}.h5`` files.
        n_models: Number of ensemble members to load. Defaults to 15.

    Returns:
        list: ``n_models`` Pearson correlations; entry ``i`` is the score of
        the averaged predictions of models ``0..i``. The list is also printed.
    """
    # load_data returns the *test* split, even though the messages printed
    # below label the score as "val corr".
    x_eval, y_eval = load_data('../data/human')
    # Flatten the targets once; they do not change between iterations.
    y_eval = y_eval.reshape(len(y_eval))

    prediction_sum = None  # running element-wise sum of member predictions
    corr_history = []
    for i in range(n_models):
        model_path = pre_path + '/model_' + str(i) + '.h5'
        # NOTE(review): `pearson` is not defined or imported in the visible
        # cells; it must already exist in the kernel when this runs.
        model = load_model(model_path, custom_objects={"pearson": pearson, 'TransformerBlock': TransformerBlock})

        member_pred = model(x_eval).numpy()
        member_pred = member_pred.reshape(len(member_pred))
        if prediction_sum is None:
            prediction_sum = member_pred
        else:
            prediction_sum = member_pred + prediction_sum

        # Score the ensemble mean over the i + 1 members seen so far.
        pr_corr = pearsonr(prediction_sum / (i + 1), y_eval)[0]
        print(f"item:{i}, val corr = {pr_corr}")
        corr_history.append(pr_corr)

    print(corr_history)
    return corr_history