### E Karvelis | 5/24/2023
### Purpose
Test the performance of trained models

In [2]:
# import modules
from transformer_1 import *

import sys
sys.path.append('/data/karvelis03/dl_kcat/scripts/')
from prep_data import *

import matplotlib.pyplot as plt
import seaborn as sns

import time

from scipy.stats import spearmanr

# Select the compute device: the GPU when CUDA is usable, otherwise the CPU
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

In [3]:
"""
Helper functions
"""

'\nHelper functions\n'

In [4]:
class ModelTest(PathDataset):
    """Pair a trained model with its dataset and train/held-out variant split.

    Organizes a model, a dataset, and the set of held-out variants (within
    the dataset) so that various model performance analyses can be executed.
    """

    # Fallback hyperparameters, used for every setting the config file does
    # not define.
    # NOTE(review): these are the values this class previously hard-coded when
    # building the model -- confirm they match the training script's defaults.
    _DEFAULT_HPARAMS = {
        'd_model': 128,
        'n_head': 2,
        'd_tran_ffn': 256,
        'dropout_tran_encoder': 0.2,
        'n_tran_layers': 1,
        'd_mlp_head': 64,
        'dropout_mlp_head': 0.2,
    }

    def __init__(self, model_file, data_file, meta_file=None, output_file=None, scaler=None):
        """Load the dataset, the variant split, the model weights and a scaler.

        INPUT:
        model_file  -- path to the saved model state dict; a file matching
                       '*conf*txt' is expected in the same directory
        data_file   -- forwarded to PathDataset.__init__
        meta_file   -- forwarded to PathDataset.__init__
        output_file -- training output text file (see get_val_variants)
        scaler      -- scaler wrapped in DataScaler; when None, a NormalScaler
                       is fit on the rows of self.data that belong to the
                       training variants

        RAISES:
        FileNotFoundError -- no '*conf*txt' file next to model_file
        """
        super().__init__(data_file, meta_file, path_set_size=1)
        self.scaler = scaler
        self.model_file = model_file

        config_matches = glob(ModelTest.get_dir(model_file) + '*conf*txt')
        if not config_matches:
            raise FileNotFoundError(
                f"No '*conf*txt' config file found in {ModelTest.get_dir(model_file)}")
        self.config_file = config_matches[0]

        self.train_vars, self.test_vars = ModelTest.get_val_variants(model_file, output_file=output_file)

        # Build the model, then load the pre-trained weights.  map_location
        # lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
        self.model = self.make_model()
        self.model.load_state_dict(torch.load(self.model_file,
                                              map_location=ModelTest._get_device()))

        # Fit the scaler to the training data only (unless one was supplied)
        if self.scaler is None:
            train_idx = np.nonzero(np.isin(np.array(self.meta.variant),
                                           np.unique(self.train_vars)))[0]
            self.scaler = NormalScaler()
            self.scaler.fit(self.data[train_idx,:,:])
        self.data_scaler = DataScaler(self.scaler)

    @staticmethod
    def _get_device():
        # CUDA device when available, otherwise the CPU
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def make_model(self):
        """Create the PyTorch model object into which pre-trained weights are loaded.

        Hyperparameters are taken from self.config_file; any setting the file
        does not assign falls back to _DEFAULT_HPARAMS.  The config is run in
        an explicit namespace because names assigned by exec() inside a
        function never become that function's local variables.
        """
        # Seed the namespace with the fallbacks so the config can override them
        settings_ns = dict(self._DEFAULT_HPARAMS)
        with open(self.config_file, 'r') as f:
            settings = f.read()

        # SECURITY: exec() runs the config file as Python code.  Only point
        # this class at config files you trust.
        exec(settings, globals(), settings_ns)

        hparams = {name: settings_ns[name] for name in self._DEFAULT_HPARAMS}
        model = TransformerModel(input_size=self.data.shape[-1],
                                 input_length=self.data.shape[-2],
                                 **hparams)

        return model.to(ModelTest._get_device())

    def score_var(self, var):
        """Run the model on every observation of one variant.

        Requires self.obs to be populated (see make_observations).

        INPUT:
        var -- variant name, matched against self.obs.variant

        RETURNS:
        list of 1-D numpy arrays, one per batch of up to 32 observations,
        holding the flattened model outputs
        """
        var_idx = np.nonzero(np.isin(np.array(self.obs.variant), var))[0]
        var_dataset = PathTorchDataset(self, elligible_idxs=var_idx, transform=self.data_scaler)

        # Evaluation does not need shuffling; a fixed order keeps the
        # returned batches reproducible.
        loader = DataLoader(var_dataset, batch_size=32, shuffle=False)

        # Send each batch to wherever the model lives (no-op if already there)
        model_device = next(self.model.parameters()).device

        self.model.eval()
        results = []
        with torch.no_grad():
            for batch in loader:
                output = self.model(batch['paths'].to(model_device))
                results.append(output.cpu().numpy().reshape((-1)))

        return results

    def _plot_var_pred(self, variants, path_set_sizes, figname):
        # Shared implementation of plot_test_var_pred / plot_train_var_pred:
        # scores every variant at every path_set_size and draws one violin
        # plot per variant.
        sizes_to_test = list(path_set_sizes)
        unique_vars = np.unique(variants)

        results = {var: {size: [] for size in sizes_to_test} for var in variants}

        for path_set_size in sizes_to_test:

            print (f'Testing path_set_size: {path_set_size}')
            self.path_set_size = path_set_size

            # Populate the obs attribute
            self.make_observations()

            for i, var in enumerate(unique_vars):
                results[var][path_set_size] = self.score_var(var)
                print (f'Completed {i+1}/{unique_vars.shape[0]} variants...')
            print ()

            # Empty the obs attribute to reset it
            self.obs = None

        # Merge the per-batch outputs into one array per (variant, size)
        for var in results:
            for size in results[var]:
                results[var][size] = np.concatenate(results[var][size])

        # Plot
        ncols = 3
        nrows = max(1, int(np.ceil(len(results)/ncols)))
        fig, axes = plt.subplots(nrows, ncols, figsize=(9*ncols,6*nrows))
        flat_axes = np.asarray(axes).flatten()
        for ax, var in zip(flat_axes, results):

            var_kcat = np.log10(self.meta.kcat[np.where(np.array(self.meta.variant) == var)[0][0]])

            size_labels, y_pred = [], []
            for size, data in results[var].items():
                y_pred += list(data)
                size_labels += [size]*data.shape[0]

            df = pd.DataFrame({'path_set_size': size_labels, 'Pred. log(kcat)': y_pred})

            # Create violin plot with seaborn
            sns.violinplot(data=df, x='path_set_size', y='Pred. log(kcat)', ax=ax)
            ax.axhline(var_kcat, ls='--', c='gray', label=r'TIS log($k_{cat}$)')
            ax.set_xlabel('')
            ax.set_ylabel('')
            ax.set_title(var, fontsize=22)
            ax.legend(fontsize=20)
            ax.tick_params(axis='both', labelsize=22)

        # Hide the unused panels of the last row
        for ax in flat_axes[len(results):]:
            ax.set_visible(False)

        # Shared axis labels: drawn once, not once per variant
        fig.text(0.5, 0.07, 'Paths per prediction', ha='center', fontsize=32)
        fig.text(0.05, 0.5, r'Predicted log($k_{cat}$)', va='center', rotation='vertical', fontsize=32)

        if figname:
            fig.savefig(figname, dpi=300)

        return results

    def plot_test_var_pred(self, path_set_sizes=(10, 100, 1000), figname=False):
        """Plot predicted log(kcat) distributions for each held-out variant.

        One violin plot per held-out (i.e., validation) variant, as a function
        of the path_set_size, which is the number of paths included in each
        observation of the variant.

        Side effects: self.path_set_size is left at the last size tested and
        self.obs is reset to None.

        INPUT:
        path_set_sizes -- iterable of path set sizes to test
        figname        -- file to save the figure to; falsy to skip saving

        RETURNS:
        dict {variant: {path_set_size: 1-D numpy array of predictions}}
        """
        return self._plot_var_pred(self.test_vars, path_set_sizes, figname)

    def plot_train_var_pred(self, path_set_sizes=(10, 100, 1000), figname=False):
        """Plot predicted log(kcat) distributions for each training variant.

        One violin plot per variant included during training, as a function
        of the path_set_size, which is the number of paths included in each
        observation of the variant.

        Side effects: self.path_set_size is left at the last size tested and
        self.obs is reset to None.

        INPUT:
        path_set_sizes -- iterable of path set sizes to test
        figname        -- file to save the figure to; falsy to skip saving

        RETURNS:
        dict {variant: {path_set_size: 1-D numpy array of predictions}}
        """
        return self._plot_var_pred(self.train_vars, path_set_sizes, figname)

    def _spearman(self, variants, verbose):
        # Shared implementation of spearman_test / spearman_train: compares
        # the mean prediction of each unique variant with the log10 of its
        # kcat and returns the scipy.stats.spearmanr result.

        # Populate the obs attribute
        self.make_observations()

        unique_vars = np.unique(variants)
        names, preds, targets = [], [], []
        for i, var in enumerate(unique_vars):
            res = np.concatenate(self.score_var(var))

            names.append(var)
            preds.append(np.mean(res))
            targets.append(np.log10(self.obs.kcat[np.where(self.obs.variant==var)[0][0]]))

            print (f'Completed {i+1}/{unique_vars.shape[0]} variants...')

        if verbose:
            for name, pred, target in zip(names, preds, targets):
                print (f'{name}: {pred:.2f}   |   {target:.2f}')

        return spearmanr(preds, targets)

    def spearman_test(self, verbose=False):
        """Spearman rank correlation for the variants in the test set.

        Correlates the mean prediction per variant with the TIS measured
        log10(kcat), using the current self.path_set_size.

        INPUT:
        verbose -- also print 'variant: prediction | target' for each variant

        RETURNS:
        the object returned by scipy.stats.spearmanr (correlation coefficient
        and p-value); see
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.spearmanr.html
        """
        return self._spearman(self.test_vars, verbose)

    def spearman_train(self, verbose=False):
        """Spearman rank correlation for the variants in the train set.

        Correlates the mean prediction per variant with the TIS measured
        log10(kcat), using the current self.path_set_size.

        INPUT:
        verbose -- also print 'variant: prediction | target' for each variant

        RETURNS:
        the object returned by scipy.stats.spearmanr (correlation coefficient
        and p-value); see
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.spearmanr.html
        """
        return self._spearman(self.train_vars, verbose)

    @staticmethod
    def get_dir(file):
        """Return the directory under which file is stored, with a trailing '/'.

        A bare filename (no '/') yields './' rather than the filesystem root.
        """
        head, sep, _ = file.rpartition('/')
        if not sep:
            return './'
        return head + '/'

    @staticmethod
    def _parse_variant_list(text):
        # Extracts the names from the last bracketed, printed string array in
        # text, e.g. "['A' 'B'\n 'C']" -> ['A', 'B', 'C']
        flattened = text.replace("' '",',').replace("'",'').replace('\n ',',')
        return flattened.split('[')[-1].split(']')[0].split(',')

    @staticmethod
    def get_val_variants(model_file, output_file=None):
        """Return the variants used for training and those held out.

        INPUT:
        model_file  -- full path to the file to which the model was saved;
                       the CV fold is read from between 'cvfold' and '.pt'
        output_file -- the text file to which output was written by the
                       script that trained and saved the model; defaults to
                       'transformer_1_output.txt' next to model_file

        RETURNS:
        (train_vars, val_vars) -- two numpy arrays of variant names

        RAISES:
        ValueError -- model_file and output_file are in different directories
        """
        if output_file is None:
            output_file = ModelTest.get_dir(model_file) + 'transformer_1_output.txt'

        if ModelTest.get_dir(model_file) != ModelTest.get_dir(output_file):
            raise ValueError('model_file and output_file directories do not match. '
                             'Are you sure you have the right ones?')

        # get the CV fold
        fold = model_file.split('cvfold')[-1].split('.pt')[0]

        # read the output file
        with open(output_file, 'r') as f:
            data = f.read()

        # Training variants: last 'Training variants:' section preceding the
        # first epoch of this fold.
        # NOTE(review): if a marker string is absent, split() returns the whole
        # log and the last section in the file is used -- verify the log format.
        train_text = data.split(f"epoch 1 | CV fold {fold}")[0].split('Training variants:')[-1]

        # Validation variants: last 'Validation variants:' section preceding
        # the first time this fold's best model was saved.
        val_text = data.split(f"Best loss achieved, saving model state to best_model_cvfold{fold}.pt")[0]
        val_text = val_text.split('Validation variants:')[-1]

        train_vars = ModelTest._parse_variant_list(train_text)
        val_vars = ModelTest._parse_variant_list(val_text)

        return np.array(train_vars), np.array(val_vars)
        

# model_file = '/data/karvelis03/dl_kcat/transformer_1s/job0/best_model_cvfold1.pt'
# data_file = '/data/karvelis03/dl_kcat/data/total/tptrue_gsfalse_o-0dot4_0dot8_s1_2_3_4_5_r1_2_t-110_0_sub500_numNone.470000-111-70memnpy'

# test = ModelTest(model_file, data_file)#, scaler=scaler)

# start_time = time.time()
# results = test.plot_test_var_pred()
# print (f'\n\nRuntime: {time.time() - start_time}')
# print (test.train_vars)
# print (test.test_vars)
# print (test.data.shape)
# print (len(test.meta.variant))

# Spearman rank correlation

In [13]:
# The trajectory array and its metadata share one path stem and differ only
# in the trailing 'memnpy' / 'metadata' tag
_dataset_stem = ('/data/karvelis03/dl_kcat/data/total/'
                 'tptrue_gsfalse_o-0dot4_0dot8_s1_2_3_4_5_r1_2_t-35_75_sub500_numNone.550000-111-70')
data_file = _dataset_stem + 'memnpy'
meta_file = _dataset_stem + 'metadata'

In [14]:
# Spearman correlation on the held-out variants of every saved CV-fold model.
# Match only '*.pt' checkpoints: the figures this notebook saves are also named
# 'best_model*' and would otherwise be picked up as model files.  glob() gives
# no ordering guarantee, so sort to keep the printed coefficients in a
# reproducible (lexicographic) order.
spearman_rs = []
for model_file in sorted(glob('./best_model*.pt')):

    test = ModelTest(model_file, data_file, meta_file=meta_file)

    test.path_set_size = 10
    res = test.spearman_test()
    spearman_rs.append(res.correlation)

spearman_rs = np.array(spearman_rs)
print (spearman_rs)
print (np.mean(spearman_rs))

Completed 1/6 variants...
Completed 2/6 variants...
Completed 3/6 variants...
Completed 4/6 variants...
Completed 5/6 variants...
Completed 6/6 variants...
Completed 1/6 variants...
Completed 2/6 variants...
Completed 3/6 variants...
Completed 4/6 variants...
Completed 5/6 variants...
Completed 6/6 variants...
Completed 1/6 variants...
Completed 2/6 variants...
Completed 3/6 variants...
Completed 4/6 variants...
Completed 5/6 variants...
Completed 6/6 variants...
Completed 1/6 variants...
Completed 2/6 variants...
Completed 3/6 variants...
Completed 4/6 variants...
Completed 5/6 variants...
Completed 6/6 variants...
Completed 1/6 variants...
Completed 2/6 variants...
Completed 3/6 variants...
Completed 4/6 variants...
Completed 5/6 variants...
Completed 6/6 variants...
Completed 1/5 variants...
Completed 2/5 variants...
Completed 3/5 variants...
Completed 4/5 variants...
Completed 5/5 variants...
Completed 1/5 variants...
Completed 2/5 variants...
Completed 3/5 variants...
Completed 4/

# CV Fold 1

In [5]:
# Load the CV-fold-1 checkpoint together with the shared dataset and metadata
model_file = ('/data/karvelis03/dl_kcat/transformer_1s/denseweight/job12-1/'
              'stoch_labels/test/best_model_cvfold1.pt')
test = ModelTest(model_file=model_file, data_file=data_file, meta_file=meta_file)

In [6]:
# Violin plots of the predictions for the training variants of CV fold 1
_ = test.plot_train_var_pred(
    figname='best_model_cvfold1_train.png',
)

Testing path_set_size: 10
Completed 1/49 variants...
Completed 2/49 variants...
Completed 3/49 variants...
Completed 4/49 variants...
Completed 5/49 variants...
Completed 6/49 variants...
Completed 7/49 variants...
Completed 8/49 variants...
Completed 9/49 variants...
Completed 10/49 variants...
Completed 11/49 variants...
Completed 12/49 variants...
Completed 13/49 variants...
Completed 14/49 variants...
Completed 15/49 variants...
Completed 16/49 variants...
Completed 17/49 variants...
Completed 18/49 variants...
Completed 19/49 variants...
Completed 20/49 variants...
Completed 21/49 variants...
Completed 22/49 variants...
Completed 23/49 variants...
Completed 24/49 variants...
Completed 25/49 variants...
Completed 26/49 variants...
Completed 27/49 variants...
Completed 28/49 variants...
Completed 29/49 variants...
Completed 30/49 variants...
Completed 31/49 variants...
Completed 32/49 variants...
Completed 33/49 variants...
Completed 34/49 variants...
Completed 35/49 variants...
Com

OutOfMemoryError: CUDA out of memory. Tried to allocate 174.00 MiB (GPU 0; 11.92 GiB total capacity; 269.54 MiB already allocated; 14.06 MiB free; 272.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Held-out-variant plot for CV fold 1.  Saved with a '_test' suffix so it does
# not overwrite the training-variant figure 'best_model_cvfold1_train.png'.
_ = test.plot_test_var_pred(figname='best_model_cvfold1_test.png')


In [7]:
# Rank correlation on the held-out variants, using 10 paths per observation
test.path_set_size = 10
res = test.spearman_test()
print(res)

Completed 1/6 variants...
Completed 2/6 variants...
Completed 3/6 variants...
Completed 4/6 variants...
Completed 5/6 variants...
Completed 6/6 variants...
SpearmanrResult(correlation=0.7714285714285715, pvalue=0.07239650145772594)


# CV Fold 2

In [8]:
# Load the CV-fold-2 checkpoint together with the shared dataset and metadata
model_file = ('/data/karvelis03/dl_kcat/transformer_1s/denseweight/job12-1/'
              'stoch_labels/test/best_model_cvfold2.pt')
test = ModelTest(model_file=model_file, data_file=data_file, meta_file=meta_file)

In [None]:
# Violin plots of the predictions for the training variants of CV fold 2
_ = test.plot_train_var_pred(
    figname='best_model_cvfold2_train.png',
)

In [None]:
# Held-out-variant plot for CV fold 2.  Saved with a '_test' suffix so it does
# not overwrite the training-variant figure 'best_model_cvfold2_train.png'.
_ = test.plot_test_var_pred(figname='best_model_cvfold2_test.png')

In [9]:
# Rank correlation on the held-out variants, using 10 paths per observation
test.path_set_size = 10
res = test.spearman_test()
print(res)

Completed 1/6 variants...
Completed 2/6 variants...
Completed 3/6 variants...
Completed 4/6 variants...
Completed 5/6 variants...
Completed 6/6 variants...
SpearmanrResult(correlation=0.6, pvalue=0.20799999999999982)


# CV Fold 3

In [10]:
# Load the CV-fold-3 checkpoint together with the shared dataset and metadata
model_file = ('/data/karvelis03/dl_kcat/transformer_1s/denseweight/job12-1/'
              'stoch_labels/test/best_model_cvfold3.pt')
test = ModelTest(model_file=model_file, data_file=data_file, meta_file=meta_file)

In [None]:
# Violin plots of the predictions for the training variants of CV fold 3
_ = test.plot_train_var_pred(
    figname='best_model_cvfold3_train.png',
)

In [None]:
# Held-out-variant plot for CV fold 3.  Saved with a '_test' suffix so it does
# not overwrite the training-variant figure 'best_model_cvfold3_train.png'.
_ = test.plot_test_var_pred(figname='best_model_cvfold3_test.png')

In [11]:
# Rank correlation on the held-out variants, using 10 paths per observation
test.path_set_size = 10
res = test.spearman_test()
print(res)

Completed 1/6 variants...
Completed 2/6 variants...
Completed 3/6 variants...
Completed 4/6 variants...
Completed 5/6 variants...
Completed 6/6 variants...
SpearmanrResult(correlation=0.3142857142857143, pvalue=0.5440932944606414)
