In [17]:
import numpy as np
from read_script import read_dataset

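# Converts the embedding dict into a 2-D feature array (one row per key) plus a matching label array.
# When avg=True, each protein embedding is averaged across the sequence (axis 0) first.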
def converDictToArrAndAverageEmbedding(data_dict, labels, avg=True):
    """Stack per-key embeddings into a feature matrix with aligned labels.

    Args:
        data_dict: Mapping from key to a NumPy array. With ``avg=True`` each
            value is reduced with ``np.mean(..., axis=0)`` and ``shape[1]``
            of the first value is used as the feature size. With
            ``avg=False`` each value is stored unchanged and ``shape[0]`` of
            the first value is used as the feature size.
        labels: Object indexed with the same keys as ``data_dict``; each
            looked-up value is stored into a float array.
        avg: Whether to average each value over axis 0 before storing it.

    Returns:
        Tuple ``(X, y)``. ``X`` has shape ``(len(data_dict), emb_size)`` and
        ``y`` has shape ``(len(data_dict),)``. Row ``i`` of both belongs to
        the ``i``-th key in ``data_dict`` iteration order.

    Raises:
        ValueError: If ``data_dict`` is empty, because the feature size is
            inferred from its first value.

    Any lookup error raised by ``labels[key]`` for a key that is missing from
    ``labels`` is propagated unchanged.
    """
    if not data_dict:
        raise ValueError(
            "data_dict is empty: cannot infer the embedding size from its first value"
        )

    # Peek at a single value to size the matrix; avoids copying every value
    # into a temporary list just to read the first one.
    first_embedding = next(iter(data_dict.values()))
    emb_size = first_embedding.shape[1] if avg else first_embedding.shape[0]

    X = np.zeros((len(data_dict), emb_size))
    y = np.zeros(len(data_dict))

    for row, (key, embedding) in enumerate(data_dict.items()):
        X[row] = np.mean(embedding, axis=0) if avg else embedding
        y[row] = labels[key]

    print("X shape: ", X.shape)
    print("y shape: ", y.shape)

    return X, y


from timeit import default_timer as timer
from datetime import timedelta
from sklearn.linear_model import LinearRegression
from scipy import stats
from sklearn.neural_network import MLPRegressor

def _load_split(model, task, split, avg):
    """Print the split header, read one split's features and labels, and return ``(X, y)`` arrays."""
    print(f"\n{split.upper()}")
    X_dict = read_dataset(model, task, split)
    y_dict = read_dataset('label', task, split)
    print("reshape data...")
    return converDictToArrAndAverageEmbedding(X_dict, y_dict, avg)


def _score_split(reg, X, y, split):
    """Predict on ``X`` with ``reg`` and print the Spearman correlation against ``y``."""
    print(f"score {split}...")
    scores = reg.predict(X)
    rho = stats.spearmanr(y, scores)
    print(f"spearman rho {split}: ", rho)
    return rho


def linear_regression_base_eval(model, task, avg=True):
    """Fit a regressor on the train split and report Spearman rho on valid and test.

    Despite the function name, the estimator fitted here is
    ``MLPRegressor(random_state=1, max_iter=500)``, not a linear regression.

    Args:
        model: First argument forwarded to ``read_dataset`` to select the
            feature set (e.g. ``"elmo"``); labels are read with ``'label'``
            in that position.
        task: Task name forwarded to ``read_dataset``.
        avg: Forwarded to ``converDictToArrAndAverageEmbedding``; whether each
            embedding is averaged over axis 0 before being used as features.

    Returns:
        None. Progress, fit time and the Spearman results for the ``valid``
        and ``test`` splits are printed.
    """
    print(f"linear_regression_base_eval for model: {model}, task: {task}")
    print("===============================================")

    X_train, y_train = _load_split(model, task, 'train', avg)

    print("\nfitting reg...")
    start = timer()
    reg = MLPRegressor(random_state=1, max_iter=500).fit(X_train, y_train)
    end = timer()
    print("fit time: ", timedelta(seconds=end-start))

    # The held-out splits go through the same load -> predict -> correlate
    # sequence; only the split name differs.
    for split in ('valid', 'test'):
        X_split, y_split = _load_split(model, task, split, avg)
        _score_split(reg, X_split, y_split, split)




In [18]:
# ELMo embeddings on the stability task; avg is left at its default (True).
linear_regression_base_eval(model="elmo", task="stability")

linear_regression_base_eval for model: elmo, task: stability

TRAIN
reshape data...
X shape:  (50473, 1024)
y shape:  (50473,)

fitting reg...
fit time:  0:10:21.266920

VALID
reshape data...
X shape:  (2512, 1024)
y shape:  (2512,)
score valid...
spearman rho valid:  SpearmanrResult(correlation=0.6156721719480129, pvalue=4.9721998458262346e-262)

TEST
reshape data...
X shape:  (12851, 1024)
y shape:  (12851,)
score test...
spearman rho test:  SpearmanrResult(correlation=0.439839218865642, pvalue=0.0)


In [19]:
# UniRep embeddings on the stability task; avg is left at its default (True).
linear_regression_base_eval(model="unirep", task="stability")

linear_regression_base_eval for model: unirep, task: stability

TRAIN
reshape data...
X shape:  (50473, 1900)
y shape:  (50473,)

fitting reg...
fit time:  0:47:15.588245

VALID
reshape data...
X shape:  (2512, 1900)
y shape:  (2512,)
score valid...
spearman rho valid:  SpearmanrResult(correlation=0.585409959447197, pvalue=5.292124675003061e-231)

TEST
reshape data...
X shape:  (12851, 1900)
y shape:  (12851,)
score test...
spearman rho test:  SpearmanrResult(correlation=0.29860900087638365, pvalue=6.099808752913723e-263)


In [20]:
# "compact" embeddings on the stability task, used as-is (avg=False): no averaging over axis 0.
linear_regression_base_eval(model="compact", task="stability", avg=False)

linear_regression_base_eval for model: compact, task: stability

TRAIN
reshape data...
X shape:  (50473, 64)
y shape:  (50473,)

fitting reg...
fit time:  0:00:42.082265

VALID
reshape data...
X shape:  (2512, 64)
y shape:  (2512,)
score valid...
spearman rho valid:  SpearmanrResult(correlation=0.4503402875654906, pvalue=1.0243134823687443e-125)

TEST
reshape data...
X shape:  (12851, 64)
y shape:  (12851,)
score test...
spearman rho test:  SpearmanrResult(correlation=0.08562387954793516, pvalue=2.3956862733105413e-22)
