# Imports

In [1]:
# This code is not functional without individual level data!

# %load_ext rpy2.ipython
########################################################
## Base Imports:

# Sys Imports:
import time
import os

# Standard Imports:
import numpy as np
import scipy as sp
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Scientific computing
from scipy import stats
from scipy import linalg

#########################################################
## Experiment Specific Imports

# Logistics Imports:
import psutil
from sys import getsizeof
import inspect
from importlib import reload 
import copy
import IPython as ip
import pickle
import sys
import submitit
# from mjwt.tools import Timer, Tree, beep, fullvars #dont forget about mro 
# from mjwt.tools import sizegb, implot, redo_all_above, Struct, ResultsStorage
# from mjwt.tools import jobinfo, corr
from IPython.display import display, HTML, Audio, Javascript
from tqdm.notebook import tqdm
import glob, re
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor


# ML Imports:
from sklearn.metrics import r2_score, roc_auc_score
from sklearn.model_selection import ParameterGrid
from scipy.stats import pearsonr, spearmanr


# Genomics Imports:
from pysnptools.snpreader import Bed, Pheno, SnpHdf5, SnpData
from pysnptools.pstreader import PstData, PstHdf5, PstReader
from pysnptools.kernelreader import KernelHdf5, KernelData
from pysnptools.standardizer import UnitTrained
import pysnptools.util as pstutil



########################################################
## Configuration & Initialisation

# Display Configuration:
from IPython.display import set_matplotlib_formats
plt.rcParams['figure.figsize'] = [10, 5]
pd.set_option('display.max_colwidth', None) # No pandas truncation of cell contents
# Widen the Jupyter notebook container to 75% of the browser window.
display(HTML("<style>.container { width:75% !important; }</style>"))

# Initializations:
# NOTE(review): `Timer` is not imported anywhere in this file -- its import
# (`from mjwt.tools import Timer, ...`) is commented out above -- so this line
# raises NameError unless Timer already exists in the kernel. `tic`/`toc` are
# used as module-level globals by later cells (e.g. make_pgs_df).
timer = Timer(); toc = timer.toc; tic = timer.tic; tic('')
# Keep the existing lists when the cell is re-run in the same kernel; otherwise start empty.
job_dt_lst   = [] if not 'job_dt_lst'   in locals() else job_dt_lst
model_dt_lst = [] if not 'model_dt_lst' in locals() else model_dt_lst
# `__file__` is not defined in an interactive kernel, so notebook=True there.
notebook = False  if '__file__' in locals() else True
# os.environ["OMP_NUM_THREADS"] = str(int(os.environ['SLURM_JOB_CPUS_PER_NODE']) - 1)
# Shorthand: `log` is the base-10 logarithm here, not the natural log.
log=np.log10
# get_ipython().run_line_magic('load_ext', 'line_profiler')
# import line_profiler
# %load_ext line_profiler
# sys.version_info


In [None]:
# This code is not functional without individual level data!

# LinkageData

In [9]:


"""
LinkageData
durr tst
"""

import scipy as sp
import numpy as np
from scipy import linalg
from sys import getsizeof

random = []  # NOTE(review): not referenced anywhere in the visible code and reuses the stdlib module name `random`; consider removing.
from collections import OrderedDict
from pysnptools.standardizer import Unit
import pysnptools
import warnings
from collections import deque, defaultdict
import pandas as pd
import pysnptools.util as pstutil
from pysnptools.standardizer import UnitTrained


class SqrtNinv(Unit):
    """Subclass of the pysnptools ``Unit`` standardizer.

    It currently adds no behavior of its own; construction simply defers to
    ``Unit.__init__``.
    """

    def __init__(self):
        super().__init__()


class BaseLinkageData():
    """Region-wise container for genotypes, summary statistics and linkage blocks.

    Variants listed in ``sst_df`` are assigned to the regions defined in
    ``regdef_df`` (columns ``regid``, ``chrom``, ``start``, ``stop``).  Every
    region that contains at least one variant gets a dictionary
    ``self.reg_dt[i]`` (``i`` = 0, 1, 2, ...) holding the sliced readers, the
    standardizer, the region's part of ``sst_df`` and -- once computed -- the
    linkage blocks:

    * ``'D'``: linkage of the region with itself,
    * ``'L'`` / ``'R'``: linkage with variants to the left / right, taken from
      ``shift`` neighbouring regions or from a ``cm`` centimorgan window,
    * ``'Z'``: the array produced by ``compute_linkage_globalregion`` (needs ``grd``).

    Standardized genotype data (``'sda'``) is loaded lazily per region and
    evicted in FIFO order once ``gb_size_limit`` (GB) is exceeded.
    """

    def __init__(self, *, sst_df, regdef_df, n_samples_sst=None,
                 srd=None, bim_df=None, sda_standardizer=Unit,
                 prd=None, fam_df=None, pda_standardizer=Unit,
                 lrd=None, lda_standardizer=None,
                 grd=None, gda_standardizer=False,
                 singurd=None,
                 distal_linkage='shiftblocks', shift=0, cm=None, do_setzero=True,
                 allow_onthefly_linkage_gen=False,
                 clear_decomp_local_linkage=False,
                 clear_xda=True,
                 clear_linkage=False,
                 compute_sumstats=False,
                 calc_allelefreq=False,
                 always_unit=True,
                 _region_filter_fun=None,
                 gb_size_limit=10., dtype='float32', verbose=False):
        """Store readers and settings, validate them and build the regions.

        Parameters (main ones)
        ----------
        sst_df : pandas.DataFrame
            Variant table with at least ``chrom``, ``pos`` and ``snp`` columns
            (plus ``beta_mrg`` if summary statistics are available and ``cm``
            if a centimorgan window is used). Its index labels must equal the
            cumulative variant positions 0..n-1 in region order (asserted in
            ``init_regions``).
        regdef_df : pandas.DataFrame
            Region definitions with ``regid``, ``chrom``, ``start``, ``stop``;
            a variant belongs to a region if ``start <= pos < stop``.
        n_samples_sst : int, optional
            Sample size of the summary statistics; required unless
            ``compute_sumstats`` is True (then taken from ``prd``).
        srd, prd, grd : pysnptools readers
            Genotype reader (required), phenotype reader, and a second genotype
            reader whose SNPs must match ``srd``.
        sda_standardizer, pda_standardizer, gda_standardizer : class or None
            Standardizer *classes*; they are instantiated per region / use.
        shift : int
            Number of neighbouring regions used on each side for L/R linkage.
        cm : float, optional
            Centimorgan window for L/R linkage when ``shift == 0``; must be > 0.
        gb_size_limit : float
            Memory budget (GB) for loaded standardized genotype data.

        Raises
        ------
        NotImplementedError
            If ``lrd`` is given or ``srd`` is missing.
        AssertionError
            On failed input checks (types, ``cm > 0``, sample size).
        """
        # bim and fam df have to be supplied because pysnptools only partially
        # implemented these portions of the genetic data in its objects,
        # meaning that srd cannot be relied upon.

        # New rule: blx have to be created from the inside
        # Perhaps later it can be made into a special load instead of a compute

        # Initial checks:
        assert type(sst_df) is pd.DataFrame
        if cm is not None: assert cm > 0

        self.regdef_df = regdef_df
        # TODO: checks still needed: region ranges count upward, regid is present
        # (else create it), chrom contains no None's and NaNs.

        self.srd = srd  # SNP (genotype) reader
        self.prd = prd  # Phenotype reader
        self.grd = grd
        self.singurd = singurd  # Singular value reader.
        self.lrd = lrd
        if lrd is not None: raise NotImplementedError('lrd not possible atm.')
        # Standardizer *classes* are stored and instantiated per region.
        # This should go different in the future: clone objects instead of
        # messing with classes (cf. sklearn -> clone(estimator)):
        # https://github.com/scikit-learn/scikit-learn/blob/7e1e6d09bcc2eaeba98f7e737aac2ac782f0e5f1/sklearn/base.py#L31
        self.sda_Standardizer = sda_standardizer
        self.pda_Standardizer = pda_standardizer
        self.lda_Standardizer = lda_standardizer
        if grd is not None:
            # With grd a truthy standardizer or None is required (the default False is rejected).
            assert gda_standardizer or (gda_standardizer is None)
        self.gda_Standardizer = gda_standardizer
        self.dtype = dtype  # dtype to work in
        self.verbose = verbose

        # Linkage algorithm settings:
        self.distal_linkage = distal_linkage  # Approach used to determine left and right linkage
        self.shift = shift  # Integer number of neighbouring regions used left & right
        self.cm = cm  # Centimorgan window used left & right when shift == 0
        self.do_setzero = do_setzero
        self.allow_onthefly_linkage_gen = allow_onthefly_linkage_gen
        # NOTE: clear_decomp_local_linkage is accepted for compatibility but currently unused.
        self.clear_xda = clear_xda  # Whether to clear the gda, sda, etc. when linkage is determined
        self.clear_linkage = clear_linkage
        self.always_unit = always_unit

        # Sumstats:
        assert type(compute_sumstats) is bool
        self.compute_sumstats = compute_sumstats  # Determines whether sumstats should be computed.

        # DataFrames & other non pst-reader objects:
        self.sst_df = sst_df  # Should become a bim compliant sst dataframe.
        self.bim_df = bim_df  # TODO: assert chrom contains no NaNs & more
        self.fam_df = fam_df  # Not really used atm.
        self.reg_dt = OrderedDict()
        self.regid2i_dt = OrderedDict()
        self.cur_total_size_in_gb = 0.0
        self.gb_size_limit = gb_size_limit
        # FIFO of (region index, key) for loaded data, padded with 5 (-1, '') entries.
        self.xda_q = deque([(-1, '')] * 5)
        self.reloaded_xda_cnt = 0
        self.stansda_dt = OrderedDict()

        # Checks:
        if self.srd is not None:
            self._check_xrd()
        else:
            raise NotImplementedError('need srd for now')

        # Determination of # of samples in this data:
        if self.compute_sumstats:
            assert self.prd is not None
            self.n_samples_sst = self.prd.shape[0]
        else:
            assert n_samples_sst is not None
            self.n_samples_sst = n_samples_sst
        assert type(self.n_samples_sst) is int
        self.calc_allelefreq = calc_allelefreq

        # Init the regions:
        self.init_regions()

        # Filtering of regions functionality:
        if _region_filter_fun is None:
            self._region_filter_fun = self._default_region_filter_fun
        else:
            self._region_filter_fun = _region_filter_fun

    def _check_xrd(self):
        """Validate the readers and align ``srd`` and ``prd`` on samples.

        ``srd``/``prd`` are replaced by their sample intersection (a warning
        reports lost samples); ``grd`` must have exactly the SNPs of ``srd``.
        """
        if self.srd is not None:
            assert pysnptools.snpreader.SnpReader in self.srd.__class__.__mro__

        if self.prd is not None:
            n_start = len(self.prd.iid)
            self.srd, self.prd = pstutil.intersect_apply([self.srd, self.prd])
            if len(self.prd.iid) != n_start:
                warnings.warn(f'Number of samples do not match up after internal intersection, samples were lost: {n_start - len(self.prd.iid)}, start = {n_start}, after_intersection = {len(self.prd.iid)}')

        if self.grd is not None:
            # Check alignment for now, auto alignment needs work because of iids:
            if self.srd is not None:
                if not np.all(self.grd.sid == self.srd.sid):
                    raise Exception('snps of grd and srd not matching up, align first,'
                                    ' auto align will be implemented later')
            else:
                raise NotImplementedError('Not sure what to do with grd if no srd is present. not implemented.')

    def init(self):
        """Return ``self`` (same ``init()`` interface as ``DelayedLinkageData``)."""
        return self

    ###########################
    # Regions Administration:

    def init_regions(self):
        """Assign the variants of ``sst_df`` to the regions of ``regdef_df``.

        Regions without variants are skipped. Each kept region gets a
        ``geno_dt`` under consecutive key ``i`` in ``self.reg_dt`` with its
        variant range ``[start_j, stop_j)``, its slice of ``sst_df`` (with an
        added ``i`` column), the sliced readers and fresh standardizers.
        Afterwards ``self.sst_df`` only contains variants inside some region
        and ``self.n_snps_total`` is their number.

        Raises
        ------
        ValueError
            If no variant of ``sst_df`` falls inside any region.
        """
        do_beta_moving = ('beta_mrg' in self.sst_df.columns)
        if not do_beta_moving:
            warnings.warn('No \'beta_mrg\' column detected in sst_df! This means that no summary stats were detected.')
        i = 0
        n_snps_cumsum = 0
        sst_df_lst = []
        for _, row in self.regdef_df.iterrows():
            regid = row['regid']
            chrom = row['chrom']
            start = row['start']
            stop = row['stop']

            # Map variants to region: same chrom and start <= pos < stop.
            ind = self.sst_df['chrom'] == chrom
            ind = (self.sst_df['pos'] >= start) & ind
            ind = (self.sst_df['pos'] < stop) & ind
            sid = self.sst_df['snp'][ind].values
            indices = self.srd.sid_to_index(sid)  # If a sid is not present in srd this gives an error!
            n_snps_reg = len(indices)
            if n_snps_reg == 0:
                continue  # Empty regions get no entry in reg_dt.

            geno_dt = dict(regid=regid,
                           chrom=chrom,
                           start=start,
                           stop=stop,
                           start_j=n_snps_cumsum)
            n_snps_cumsum += n_snps_reg
            geno_dt['stop_j'] = n_snps_cumsum
            geno_dt['n_snps_reg'] = n_snps_reg
            sst_df = self.sst_df[ind].copy()
            sst_df['i'] = i
            geno_dt['sst_df'] = sst_df
            # start_j/stop_j are cumulative positions and must coincide with sst_df's index labels.
            assert geno_dt['start_j'] == sst_df.index[0]
            sst_df_lst.append(sst_df)
            assert geno_dt['stop_j'] == sst_df.index[-1] + 1
            if do_beta_moving:
                geno_dt['beta_mrg'] = geno_dt['sst_df']['beta_mrg'].copy().values[:, np.newaxis]
                assert len(geno_dt['beta_mrg'].shape) == 2
            self.regid2i_dt[regid] = i
            if self.srd is not None:
                geno_dt['srd'] = self.srd[:, indices]
                geno_dt['stansda'] = self.sda_Standardizer() if self.sda_Standardizer is not None else None
            else:
                raise NotImplementedError()
            if self.grd is not None:
                geno_dt['grd'] = self.grd[:, indices]
                geno_dt['stangda'] = self.gda_Standardizer() if self.gda_Standardizer is not None else None
            # Count up only if things are actually stored in reg_dt
            self.reg_dt[i] = geno_dt
            i += 1
        self.n_snps_total = n_snps_cumsum
        if not sst_df_lst:
            raise ValueError('No variants of sst_df fall inside any region of regdef_df.')
        self.sst_df = pd.concat(sst_df_lst, axis=0)

    def get_i_regions_lst(self, filter_like_lst=None):
        """Return the list of region indices ``i`` (filtering is not implemented)."""
        if filter_like_lst is not None:
            raise NotImplementedError()
        return list(self.reg_dt.keys())

    def _default_region_filter_fun(self, *, geno_dt):
        """Default grouping key of a region: its chromosome."""
        return geno_dt['chrom']

    def set_region_filter_fun(self, fun):
        raise NotImplementedError()

    def gen_filter_dt(self):
        """Group region indices by ``_region_filter_fun`` (``None`` keys are dropped).

        Returns the ``defaultdict(list)`` mapping key -> list of ``i`` and also
        stores it as ``self.filter_dt``.
        """
        filter_dt = defaultdict(list)
        for i, geno_dt in self.reg_dt.items():
            filter_key = self._region_filter_fun(geno_dt=geno_dt)
            if filter_key is not None:
                filter_dt[filter_key].append(i)
        self.filter_dt = filter_dt
        return filter_dt

    def _load_all_snpdata(self):
        """Read and standardize (in place) the genotypes of all regions.

        Data loaded here is not registered in the eviction queue.
        """
        for i, geno_dt in self.reg_dt.items():
            sda = geno_dt['srd'].read(dtype=self.dtype)
            stansda = sda.train_standardizer(apply_in_place=True,
                                             standardizer=geno_dt['stansda'])
            geno_dt['sda'] = sda
            geno_dt['stansda'] = stansda

    #####################################
    # General Linkage Retrieval Methods:

    ###########################
    ## Compute: ###############

    # Local Linkage Stuff: ####

    def compute_linkage_sameregion(self, *, i):
        """Linkage of region ``i`` with itself (the ``D`` block)."""
        return self.compute_linkage_shiftregion(i=i, shift=0)

    def regions_compatible(self, *, i, j):
        """True iff regions ``i`` and ``j`` both exist and lie on the same chromosome."""
        if (i not in self.reg_dt) or (j not in self.reg_dt):
            return False
        return bool(self.reg_dt[i]['chrom'] == self.reg_dt[j]['chrom'])

    def compute_linkage_shiftregion(self, *, i, shift):
        """Return ``X_i^T X_j / n`` for ``j = i + shift``, or ``(p_i, 0)`` zeros if incompatible.

        ``X_i``, ``X_j`` are the standardized genotypes from ``get_sda`` and
        ``n`` is their number of samples.
        """
        j = i + shift
        if self.regions_compatible(i=i, j=j):
            self_sda = self.get_sda(i=i)
            dist_sda = self.get_sda(i=j)
            n = len(self_sda.iid)
            S_shift = self_sda.val.T.dot(dist_sda.val) / n
            return S_shift
        else:
            self_sda = self.get_sda(i=i)
            return np.zeros((self_sda.val.shape[1], 0))

    def compute_linkage_cmfromregion(self, *, i, cm):
        """Linkage of region ``i`` with the variants inside a centimorgan window.

        ``cm < 0`` gives the left block (variants before the region), ``cm > 0``
        the right block. The window contains same-chromosome variants whose
        ``cm`` lies strictly within ``|cm|`` of the region's first (left) or
        last (right) variant. The result has shape ``(n_snps_reg, n_window)``.
        With ``do_setzero``, entries whose pairwise cM distance exceeds ``|cm|``
        are set to 0.
        """
        geno_dt = self.reg_dt[i]
        lst = []
        if cm < 0:  # Doing left:
            stop_j = geno_dt['start_j']
            cm_left = geno_dt['sst_df'].loc[stop_j]['cm']
            slc_df = self.sst_df.loc[:stop_j-1]
            slc_df = slc_df[slc_df.chrom == geno_dt['chrom']]
            slc_df = slc_df[slc_df.cm > (cm_left + cm)]
            start_i = slc_df['i'].min()
            # -7 is a sentinel region index (regions are numbered from 0): the loop
            # then runs once, yields an empty (p, 0) block and stops.
            start_i = -7 if np.isnan(start_i) else start_i
            for cur_i in range(start_i, i):
                lst.append(self.compute_linkage_shiftregion(i=i, shift=cur_i-i))
                if start_i == -7: break
            L = np.concatenate(lst, axis=1)[:, -slc_df.shape[0]:]  # concat & clip to the window
            if self.do_setzero:
                cms_reg = geno_dt['sst_df']['cm'].values
                cms_distal = slc_df['cm'].values
                cms_L = cms_distal[np.newaxis, :] - cms_reg[:, np.newaxis]
                setzero_L = cms_L < cm
                L[setzero_L] = 0
                assert L.shape == setzero_L.shape
            return L
        else:
            start_j = geno_dt['stop_j']
            cm_right = geno_dt['sst_df'].loc[start_j-1]['cm']
            slc_df = self.sst_df.loc[start_j:]
            slc_df = slc_df[slc_df.chrom == geno_dt['chrom']]
            slc_df = slc_df[slc_df.cm < (cm_right + cm)]
            stop_i = slc_df['i'].max()
            # Empty window: one neighbour block is computed and clipped to 0 columns below.
            stop_i = i+2 if np.isnan(stop_i) else stop_i + 1
            for cur_i in range(i+1, stop_i):
                lst.append(self.compute_linkage_shiftregion(i=i, shift=cur_i-i))
            R = np.concatenate(lst, axis=1)[:, :slc_df.shape[0]]  # concat & clip to the window
            if self.do_setzero:
                cms_reg = geno_dt['sst_df']['cm'].values
                cms_distal = slc_df['cm'].values
                cms_R = cms_distal[np.newaxis, :] - cms_reg[:, np.newaxis]
                setzero_R = cms_R > cm
                R[setzero_R] = 0
                assert R.shape == setzero_R.shape
            return R

    # Glocal Linkage Stuff: ####

    def compute_linkage_decompshiftregion(self, *, i, shift):
        """Return ``G_i^T G_j`` (no division by n!) for ``j = i + shift``.

        NOTE(review): relies on ``self.get_gda``, which is not defined in this
        class -- presumably supplied elsewhere; verify before use.
        """
        j = i + shift
        if self.regions_compatible(i=i, j=j):
            self_gda = self.get_gda(i=i)
            dist_gda = self.get_gda(i=j)
            # There is no division by n here, important difference!!
            S_glocalshift = self_gda.val.T.dot(dist_gda.val)
            return S_glocalshift
        else:
            self_gda = self.get_gda(i=i)
            return np.zeros((self_gda.val.shape[1], 0))

    def compute_linkage_globalregion(self, *, i):
        """Return ``gda.val``, divided by sqrt(n_rows) when ``gda.sqrtninvscaling`` is set.

        NOTE(review): relies on ``self.get_gda``, which is not defined in this
        class -- presumably supplied elsewhere; verify before use.
        """
        gda = self.get_gda(i=i)
        Z = gda.val
        if gda.sqrtninvscaling:
            n_q = Z.shape[0]
            Z = Z / np.sqrt(n_q)
        return Z

    # NON Linkage Stuff: #######

    def compute_sumstats_region(self, *, i):
        """Marginal effects ``X_i^T y / n`` of region ``i`` from standardized data."""
        sda = self.get_sda(i=i)
        X = sda.val
        y = self.get_pda().val
        n = len(y)
        c_reg = X.T.dot(y) / n
        return c_reg

    def compute_allelefreq_region(self, *, i):
        """Return a copy of region ``i``'s ``sst_df`` with genotype-count columns.

        Counts are taken on the raw (unstandardized) genotypes of the region's
        ``srd``; values must be 0, 1, 2 or NaN (asserted). Added columns:
        ``altcnt=0/1/2/nan`` (samples per value), ``altfreq`` (frequency of the
        counted allele among called genotypes, presumably the alt allele per
        the column names) and ``missfreq`` (fraction of missing genotypes).
        """
        # Speed might be improved by using dot prod here, instead of sums
        # np.unique was way slower (5x)
        geno_dt = self.reg_dt[i]
        # Counting 0/1/2 needs raw genotypes, so read from the reader directly
        # instead of get_sda() (which returns standardized values).
        G = geno_dt['srd'].read(dtype=self.dtype).val
        n = G.shape[0]
        sst_df = geno_dt['sst_df'].copy()
        cnt0 = np.sum(G == 0, axis=0)
        cnt1 = np.sum(G == 1, axis=0)
        cnt2 = np.sum(G == 2, axis=0)
        cntnan = np.sum(np.isnan(G), axis=0)
        assert np.allclose(cnt0 + cnt1 + cnt2 + cntnan, n)
        sst_df['altcnt=0'] = cnt0
        sst_df['altcnt=1'] = cnt1
        sst_df['altcnt=2'] = cnt2
        sst_df['altcnt=nan'] = cntnan
        # Allele frequency: each called sample carries 2 alleles.
        sst_df['altfreq'] = (cnt1 + 2 * cnt2) / (2 * (n - cntnan))
        sst_df['missfreq'] = cntnan / n
        return sst_df

    def compute_ldscores_region(self, *, i):
        """Return a copy of region ``i``'s ``sst_df`` with an ``lds`` column.

        ``lds`` is the square root of the sum of squared entries of the
        variant's row in the L, D and R linkage blocks.
        """
        sst_df = self.reg_dt[i]['sst_df'].copy()
        L = self.get_left_linkage_region(i=i)
        D = self.get_auto_linkage_region(i=i)
        R = self.get_right_linkage_region(i=i)
        for k, j in enumerate(sst_df.index):
            slds = np.sum(L[k]**2) + np.sum(D[k]**2) + np.sum(R[k]**2)
            sst_df.loc[j, 'lds'] = np.sqrt(slds)
        return sst_df

    ############################
    ## Retrieve: ###############

    # Local Linkage: ############

    def retrieve_linkage_allregions_auto(self, compute_sumstats=False):
        """Compute ``D`` (and optionally ``beta_mrg``) for every region."""
        for i, geno_dt in self.reg_dt.items():
            geno_dt['D'] = self.compute_linkage_sameregion(i=i)
            if compute_sumstats:
                geno_dt['beta_mrg'] = self.compute_sumstats_region(i=i)

    def retrieve_linkage_allregions_shiftwindow(self):
        """Retrieve local linkage for all regions, then clear loaded data if ``clear_xda``."""
        for i, geno_dt in self.reg_dt.items():
            if self.verbose:
                print(f'Processing region #{i} on chr{geno_dt["chrom"]}', end='\r')
            self.retrieve_linkage_region_shiftwindow(i=i)
        if self.clear_xda:
            self.clear_all_xda()

    def retrieve_linkage_selectedregions_shiftwindow(self, *, filter_i_lst):
        """Retrieve local linkage only for regions in ``filter_i_lst``.

        Skipped regions get zero ``beta_mrg`` padding when ``compute_sumstats``.
        """
        filter_i_set = set(filter_i_lst)
        for i, geno_dt in self.reg_dt.items():
            if i in filter_i_set:
                if self.verbose:
                    print(f'Processing region #{i} on chr{geno_dt["chrom"]}', end='\r')
                self.retrieve_linkage_region_shiftwindow(i=i)
            elif self.compute_sumstats:  # Padding of skipped regions
                geno_dt['beta_mrg'] = np.zeros((geno_dt['n_snps_reg'], 1), dtype=self.dtype)
        if self.clear_xda:
            self.clear_all_xda()

    def _set_distal_ranges(self, geno_dt):
        """Store the variant index ranges covered by ``geno_dt['L']`` and ``geno_dt['R']``."""
        geno_dt['start_j_L'] = geno_dt['start_j'] - geno_dt['L'].shape[1]
        geno_dt['stop_j_L'] = geno_dt['start_j']
        geno_dt['start_j_R'] = geno_dt['stop_j']
        geno_dt['stop_j_R'] = geno_dt['stop_j'] + geno_dt['R'].shape[1]

    def retrieve_linkage_region_shiftwindow(self, *, i):
        """Compute (if not yet present) and store the local linkage blocks of region ``i``.

        * ``shift > 0``: ``L``/``R`` from ``shift`` neighbouring regions, plus ``D``.
        * ``shift == 0`` and ``cm is None``: only ``D``.
        * ``shift == 0`` and ``cm > 0``: ``L``/``R`` from the ``cm`` window, plus ``D``.

        When ``L``/``R`` are computed their index ranges are stored too. If
        ``compute_sumstats`` is set the region's ``beta_mrg`` is retrieved
        (a no-op when already present).
        """
        shift = self.shift
        cm = self.cm
        geno_dt = self.reg_dt[i]

        # Only D is produced when no distal linkage is requested, so the cache
        # check must not require L and R in that case (else D is recomputed on every call).
        only_auto = (shift == 0) and (cm is None)
        required_keys = ('D',) if only_auto else ('L', 'D', 'R')
        linkage_done = all(key in geno_dt for key in required_keys)

        if not linkage_done:
            if shift > 0:
                L_lst = []
                R_lst = []
                for cur_shift in range(1, shift + 1):
                    L_lst.append(self.compute_linkage_shiftregion(i=i, shift=-cur_shift))
                    R_lst.append(self.compute_linkage_shiftregion(i=i, shift=cur_shift))
                geno_dt['L'] = np.concatenate(L_lst[::-1], axis=1)  # L stands for left
                geno_dt['D'] = self.compute_linkage_sameregion(i=i)  # Within-region linkage, D is convention from LDpred 1
                geno_dt['R'] = np.concatenate(R_lst, axis=1)  # R stands for right
                self._set_distal_ranges(geno_dt)
            elif only_auto:
                geno_dt['D'] = self.compute_linkage_sameregion(i=i)
            elif (shift == 0) and cm > 0:
                geno_dt['L'] = self.compute_linkage_cmfromregion(i=i, cm=-cm)
                geno_dt['D'] = self.compute_linkage_sameregion(i=i)
                geno_dt['R'] = self.compute_linkage_cmfromregion(i=i, cm=cm)
                self._set_distal_ranges(geno_dt)

        if self.compute_sumstats:
            self.retrieve_sumstats_region(i=i)

    # Global Linkage: #######

    def retrieve_linkage_allregions_global(self):
        """Retrieve the ``Z`` array for all regions."""
        for i, geno_dt in self.reg_dt.items():
            if self.verbose:
                print(f'Processing region #{i} on chr{geno_dt["chrom"]}', end='\r')
            self.retrieve_linkage_region_global(i=i)

    def retrieve_linkage_selectedregions_global(self, *, filter_i_lst):
        raise NotImplementedError('Contact dev.')

    def retrieve_linkage_region_global(self, *, i):
        """Compute and store ``Z`` for region ``i`` unless already present."""
        geno_dt = self.reg_dt[i]
        if 'Z' in geno_dt:
            return None  # All is done already.
        geno_dt['Z'] = self.compute_linkage_globalregion(i=i)

    def retrieve_linkage_allregions_all(self):
        """Retrieve local linkage for all regions, and ``Z`` too when ``grd`` is set."""
        self.retrieve_linkage_allregions_shiftwindow()
        if self.grd is not None:
            self.retrieve_linkage_allregions_global()

    # SumStat: ##############

    def retrieve_sumstats_allregions(self):
        """Retrieve ``beta_mrg`` for every region."""
        for i in self.reg_dt:
            self.retrieve_sumstats_region(i=i)

    def retrieve_sumstats_region(self, *, i):
        """Compute and store ``beta_mrg`` of region ``i`` unless already present.

        The values are also copied into the region's ``sst_df`` when it lacks a
        ``beta_mrg`` column.
        """
        geno_dt = self.reg_dt[i]
        if 'beta_mrg' in geno_dt:
            return None  # Sumstats present, so no need to compute anything.
        geno_dt['beta_mrg'] = self.compute_sumstats_region(i=i)
        if 'beta_mrg' not in geno_dt['sst_df'].columns:
            geno_dt['sst_df']['beta_mrg'] = geno_dt['beta_mrg']

    def retrieve_betamrg_region(self, *, i):
        """Alias of ``retrieve_sumstats_region`` (kept for backward compatibility)."""
        return self.retrieve_sumstats_region(i=i)

    def retrieve_ldscores_allregions(self):
        """Retrieve LD scores for every region."""
        for i in self.reg_dt:
            self.retrieve_ldscores_region(i=i)

    def retrieve_ldscores_region(self, *, i):
        """Add an ``lds`` column to region ``i``'s ``sst_df`` if missing.

        Afterwards the linkage blocks are dropped when ``clear_linkage`` is set.
        """
        geno_dt = self.reg_dt[i]
        if 'lds' not in geno_dt['sst_df'].columns:
            geno_dt['sst_df'] = self.compute_ldscores_region(i=i)
        if self.clear_linkage:
            self.clear_linkage_region(i=i)

    # Clearance Functions: #####

    def _evict_xda_entry(self, i_2_rm, key):
        """Drop ``reg_dt[i_2_rm][key]`` and subtract its size from the running GB total."""
        rmgeno_dt = self.reg_dt[i_2_rm]
        self.cur_total_size_in_gb -= getsizeof(rmgeno_dt[key].val) / 1024 ** 3
        rmgeno_dt.pop(key)

    def clear_all_xda(self):
        """Evict every queued data object and reset the queue to its 5 padding entries."""
        while len(self.xda_q) != 0:
            i_2_rm, key = self.xda_q.popleft()
            if i_2_rm == -1:
                continue  # Skip padding entries.
            self._evict_xda_entry(i_2_rm, key)
        self.xda_q.extend([(-1, '')] * 5)

    def clear_linkage_region(self, *, i):
        """Remove all stored linkage arrays (local, decomp, glocal, Z) of region ``i``."""
        geno_dt = self.reg_dt[i]
        for key in ('L', 'D', 'R',
                    'L_decomp', 'D_decomp', 'R_decomp',
                    'L_glocal', 'D_glocal', 'R_glocal',
                    'Z'):
            geno_dt.pop(key, None)

    ############################
    ## Get: ####################

    def get_auto_linkage_region(self, *, i):
        return self.get_specificied_linkage_region(i=i, shiftletter='D')

    def get_left_linkage_region(self, *, i):
        return self.get_specificied_linkage_region(i=i, shiftletter='L')

    def get_right_linkage_region(self, *, i):
        return self.get_specificied_linkage_region(i=i, shiftletter='R')

    def get_specificied_linkage_region(self, *, i, shiftletter):
        """Return ``reg_dt[i][shiftletter]``, generating it on the fly if allowed.

        Raises
        ------
        NotImplementedError
            If the block is missing and ``allow_onthefly_linkage_gen`` is False.
        Exception
            If ``shiftletter`` has no on-the-fly generator, or generation did
            not produce the requested block.
        """
        try:
            return self.reg_dt[i][shiftletter]
        except KeyError as e:
            if not self.allow_onthefly_linkage_gen:
                raise NotImplementedError('on-the-fly compute blocked, enable if desired') from e
            if '_glocal' in shiftletter:
                # NOTE(review): retrieve_linkage_region_glocalshiftwindow is not defined in this class.
                self.retrieve_linkage_region_glocalshiftwindow(i=i)
            elif shiftletter in ('L', 'D', 'R'):
                self.retrieve_linkage_region_shiftwindow(i=i)
            elif shiftletter == 'Z':
                self.retrieve_linkage_region_global(i=i)
            else:
                raise Exception(f'shiftletter={shiftletter}, on-the-fly loading not possible ')
            try:
                return self.reg_dt[i][shiftletter]
            except Exception as err:
                print('Fail eventough trying to do on the spot recovery.')
                raise err

    def get_auto_range_region(self, *, i):
        return self.reg_dt[i]['start_j'], self.reg_dt[i]['stop_j']

    def get_left_range_region(self, *, i):
        return self.reg_dt[i]['start_j_L'], self.reg_dt[i]['stop_j_L']

    def get_right_range_region(self, *, i):
        return self.reg_dt[i]['start_j_R'], self.reg_dt[i]['stop_j_R']

    # Linkage utility functions: sda
    def get_sda(self, *, i):
        """Return the standardized genotype data of region ``i``, loading it if needed.

        Newly loaded data is queued for FIFO eviction; older entries are
        evicted while the total exceeds ``gb_size_limit``. A warning is issued
        after 5, 20, 100 and 400 reloads.

        Raises
        ------
        Exception
            If the region has neither ``sda`` nor ``srd``, or if eviction would
            shrink the queue to 4 or fewer entries (memory limit too low).
        """
        geno_dt = self.reg_dt[i]
        if 'sda' in geno_dt:
            return geno_dt['sda']
        if 'srd' not in geno_dt:
            raise Exception(f'No srd or sda found in region i={i}, this is not supposed to happen.')

        sda = geno_dt['srd'].read(dtype=self.dtype)
        sda, stansda = sda.standardize(standardizer=geno_dt['stansda'], return_trained=True)
        geno_dt['sda'] = sda
        geno_dt['stansda'] = stansda

        if 'loaded_sda' in geno_dt:
            self.reloaded_xda_cnt += 1
            if self.reloaded_xda_cnt in [5, 20, 100, 400]:
                warnings.warn(
                    f'Reloaded sda for the {self.reloaded_xda_cnt}\'th time. This causes memory swapping,'
                    ' that might make the computation of linkage quite slow.'
                    'Probably because memory limits and/or linkage size.')
        # Size determination and accounting:
        geno_dt['loaded_sda'] = True
        self.cur_total_size_in_gb += getsizeof(sda.val) / 1024 ** 3
        self.xda_q.append((i, 'sda'))
        while self.cur_total_size_in_gb > self.gb_size_limit:  # Keep removing till size is ok
            i_2_rm, key = self.xda_q.popleft()
            if i_2_rm == -1:
                continue  # Skip padding entries.
            self._evict_xda_entry(i_2_rm, key)
            if len(self.xda_q) <= 4:
                raise Exception('The memory footprint of current settings is too high, '
                                'reduce blocksize and/or correction windows or increase memory limits.')
        return sda

    def get_pda(self):
        """Return the standardized phenotype data (read and cached on first call)."""
        if not hasattr(self, 'pda'):
            pda = self.prd.read(dtype=self.dtype)
            pda, self.stanpda = pda.standardize(return_trained=True,
                                                standardizer=self.pda_Standardizer())
            self.pda = pda
        return self.pda

    def get_beta_marginal_full(self):
        """Concatenate ``beta_mrg`` of all regions in region order."""
        return np.concatenate([geno_dt['beta_mrg'] for geno_dt in self.reg_dt.values()])

    def get_beta_marginal_region(self, *, i):
        return self.reg_dt[i]['beta_mrg']

    def get_cur_sumstats_dataframe(self):
        """Concatenate the current per-region ``sst_df`` tables."""
        return pd.concat([geno_dt['sst_df'] for geno_dt in self.reg_dt.values()], axis=0)

    # Standardisation Functionality:
    def get_combined_unit_stansda(self):
        """Combine the trained per-region ``UnitTrained`` standardizers into one.

        The result is cached as ``self.stansda``.

        Raises
        ------
        ValueError
            If no region holds a trained ``UnitTrained`` standardizer yet.
        """
        if hasattr(self, 'stansda'):
            if type(self.stansda) is UnitTrained:
                return self.stansda

        standardizer_list = []
        for geno_dt in self.reg_dt.values():
            if 'stansda' in geno_dt:
                if type(geno_dt['stansda']) is UnitTrained:
                    standardizer_list.append(geno_dt['stansda'])
        if not standardizer_list:
            raise ValueError('No trained UnitTrained standardizers found in any region; '
                             'load/standardize the genotype data first.')

        test = np.all([type(stan) is UnitTrained for stan in standardizer_list])
        assert test
        sid = np.concatenate([stan.sid for stan in standardizer_list])

        test = np.unique(sid).shape[0] == sid.shape[0]
        assert test

        stats = np.concatenate([stan.stats for stan in standardizer_list], dtype=self.dtype)
        combined_unit_standardizer = UnitTrained(sid, stats)
        self.stansda = combined_unit_standardizer
        return combined_unit_standardizer

    def get_combined_stansda(self):
        """Return the combined standardizer (only the unit variant is implemented)."""
        if self.always_unit:
            return self.get_combined_unit_stansda()
        else:
            raise NotImplementedError('contact dev')


class LinkageData(BaseLinkageData):
    """Concrete linkage-data class; currently adds nothing to BaseLinkageData."""

class DelayedLinkageData():
    """Deferred construction of a LinkageData object.

    Stores a factory callable with its positional and keyword arguments;
    ``init()`` calls the factory and returns the LinkageData it produced.
    """

    def __init__(self, *, generating_fun, args=None, kwargs=None, test=False):
        self.generating_fun = generating_fun
        self.args = list() if args is None else args
        self.kwargs = dict() if kwargs is None else kwargs
        assert test is False  # Pre-testing is not supported yet.
        self.test = test

    def init(self):
        """Call the stored factory and return its result (must be exactly a LinkageData)."""
        result = self.generating_fun(*self.args, **self.kwargs)
        assert type(result) is LinkageData
        return result





# Experimental Setup

In [None]:

# Utility fun:
def nonenan_casting_and_copy_fun(dt):
    """Return a deep copy of ``dt`` in which "missing" markers are replaced by ``None``.

    A value counts as missing when it is the string ``'none'`` or a float NaN
    (Python ``float`` or NumPy floating scalar). Every other value, including
    ordinary strings, ``None`` and non-float objects, is copied unchanged.
    """
    new_dt = dict()
    for key, item in copy.deepcopy(dt).items():
        # Type-check before np.isnan: it raises TypeError on str/None/arbitrary objects.
        is_none_str = isinstance(item, str) and item == 'none'
        is_nan = isinstance(item, (float, np.floating)) and bool(np.isnan(item))
        new_dt[key] = None if (is_none_str or is_nan) else item
    return new_dt

def make_pgs_df(dn='../results/betas/'):
    """Load the SBayesR, PRS-CS, LDpred2 and lassosum beta tables side by side.

    Each CSV under *dn* is read and its columns prefixed with the method name.
    The four frames are then concatenated column-wise. Timing checkpoints
    ``toc(1)``..``toc(4)`` are emitted as before; there is no checkpoint right
    after lassosum.
    """
    # (column prefix, path relative to dn, toc label or None)
    sources = [
        ('sbayesr_',  'final/sbayesr.csv',                  1),
        ('prscs_',    'final/prscs.csv',                    2),
        ('ldpred2_',  'ldpred2/ldpred2-selectpred.csv',     3),
        ('lassosum_', 'lassosum/lassosum-pseudobetas.csv',  None),
    ]

    tic()
    frames = []
    for prefix, rel_path, tick in sources:
        frames.append(pd.read_csv(dn + rel_path).add_prefix(prefix))
        if tick is not None:
            toc(tick)

    # One wide frame with every method's betas.
    pgs_df = pd.concat(frames, axis=1)
    toc(4)
    return pgs_df


# Make Phenos:
def mapperfun(arg):
    """Shorten a simulated-phenotype file stem; return other names unchanged.

    For names containing ``'sim_'`` the result is ``'SIM'`` + the text captured
    after ``'_i='``, followed by the last two dot-separated parts of *arg*.
    Example: ``'sim_h2=0.5_i=3_p=0.01.m16.test'`` -> ``'SIM3.m16.test'``.
    """
    if 'sim_' not in arg:
        return arg
    tail_parts = arg.split('.')[-2:]
    sim_index = re.search("_i=\\s*(.*?)\\s*_.", arg).group(1)
    return 'SIM' + sim_index + '.' + '.'.join(tail_parts)
mapper_dt = dict()  # fold -> DataFrame mapping raw phenotype file stems to short column names
def make_pheno_df(dn = '../lnk/data/ukbb/imp/pheno/', fold='test'):
    """Load every ``*.{fold}.pheno`` file in *dn* into one wide phenotype frame.

    Each file is read without a header (columns: fid, iid, pheno) together with
    its path. The file stem (basename without ``'.pheno'``) becomes a column, and
    files are inner-joined on (fid, iid). Column names are then shortened with
    ``mapperfun``. The name mapping is also stored in the module-level
    ``mapper_dt[fold]``.

    Returns
    -------
    pandas.DataFrame
        Index (fid, iid); one column per phenotype file.
    """
    # NOTE(review): `dd` (presumably dask.dataframe) is not imported in the
    # visible file header; confirm it is imported elsewhere.
    ddf = dd.read_table(dn + f'*.{fold}.pheno', header=None, include_path_column=True)
    df = ddf.compute()

    # File stem: drop the directory and the literal '.pheno' suffix.
    # regex=False so '.' is never treated as a wildcard (older pandas defaulted to regex=True).
    df['path'] = df['path'].str.split('/').str[-1].str.replace('.pheno', '', regex=False)
    df = df.rename(columns={0: 'fid', 1: 'iid', 2: 'pheno'})
    df['iid'] = df['iid'].astype(str)
    df = df.set_index(['fid', 'iid']); ori_index = df.index

    # One column per file; join='inner' keeps individuals present in every file.
    df = pd.concat({key: sub_df['pheno'] for key, sub_df in df.groupby('path')}, axis=1, join='inner')
    # Sanity check: the joined index equals the leading rows of the stacked input index.
    assert np.all(df.index == ori_index[:len(df.index)])

    mapper_df = pd.DataFrame(df.columns.to_series().apply(mapperfun), index=df.columns, columns=['map'])
    mapper_dt[fold] = mapper_df
    df.columns = mapper_df['map']
    return df


# The setup of xp:
def experimental_setup(return_variable=None, trait='Asthma', adj_type='m16', geno_fn='UKBB_imp_HM3', fold='test',
                       regdef_key='1blk_shift=0',shift=0, cm=None, random_state=42, gb_size_limit=10, dtype='float32'
                       ):
    """Load genotype/phenotype readers, region definitions and a ``LinkageData``.

    Steps:
    - build the file paths from the arguments;
    - read the region definition ``regions_{regdef_key}.regdef.tsv``;
    - read the PLINK bim/fam and attach cM positions from
      ``../data/ukbb/mapper.csv``, asserting the same SNP order as the bim;
    - intersect the PLINK reader with the phenotype file
      ``pheno/{trait}.{adj_type}.{fold}.pheno``;
    - construct ``tst_linkdata``.

    Parameters
    ----------
    return_variable : str or None
        If None, return the whole ``locals()`` dict of this call. That covers
        the arguments plus every intermediate, e.g. ``tst_srd``, ``tst_prd``,
        ``bim_df``, ``regdef_df``, ``tst_linkdata``. Otherwise return only
        ``locals()[return_variable]``.
    trait, adj_type, fold : str
        Select the phenotype file ``pheno/{trait}.{adj_type}.{fold}.pheno``.
    geno_fn : str
        PLINK prefix relative to ``../lnk/data/ukbb/imp/``.
    regdef_key : str
        Selects the region-definition file.
    shift, cm, gb_size_limit
        Passed through to ``LinkageData``.
    random_state : int
        Only stored in the config; not read elsewhere in this function.
    dtype : str
        Passed to ``LinkageData`` (via ``cfg.dtype``).

    NOTE: local variable names are part of the return value (``locals()``),
    so renaming any of them changes what callers can look up.
    """
    
    # Defining Parameters (cfg is a Struct populated from cfg_dt):
    cfg = Struct()
    cfg_dt = dict()
    cfg_dt.update(dict(

        test_fn    = f'pheno/{trait}.{adj_type}.{fold}.pheno',
        geno_fn    = geno_fn,
        base_fn    = '../lnk/data/ukbb/imp/',
        regdef_dn  = '../lnk/data/regdef/',
        n_pca        = None,  # not read in this function
        random_state = random_state,
        dtype        = dtype
        
    ))
    cfg.update(cfg_dt)
    
    # Computing full paths:
    plink_fn      = cfg.base_fn + cfg.geno_fn
    pheno_tst_fn  = cfg.base_fn + cfg.test_fn
    regdef_fn     = cfg.regdef_dn + f'regions_{regdef_key}.regdef.tsv'

    # Load region definition:
    regdef_df = pd.read_csv(regdef_fn, delimiter='\t') 
    
    # Load Genotype data (mostly 'implicit': Bed only opens the reader)
    bim_df, fam_df = load_bimfam(plink_fn, fil_arr=None)
    df = pd.read_csv('../data/ukbb/mapper.csv') # a bunch of ugly stuff to get cM's in
    # The mapper must list the same SNPs in the same order as the bim, otherwise the cm column is misaligned.
    assert np.all(df[['chromosome','marker.ID']].values == bim_df[['chrom','snp']].values)
    mapper_df = df
    bim_df['cm'] = mapper_df['cm']
    tot_srd  = Bed(plink_fn, count_A1=True)
    
    # Load Pheno & SNP Reader (=implicit loading); keep only individuals present in both:
    tst_prd  = Pheno(pheno_tst_fn) # Test
    tst_srd, tst_prd = pstutil.intersect_apply([tot_srd, tst_prd])
    
    # Build test linkdata, as consumed later by the PPB metrics computer (PrivacyPreservingMetricsComputer)
    tst_linkdata = LinkageData(regdef_df=regdef_df, sst_df=bim_df.copy(), 
                               bim_df=bim_df, srd=tst_srd, prd=tst_prd, 
                               grd=None, gda_standardizer=None, gb_size_limit=gb_size_limit,
                               compute_sumstats=True, verbose=True, shift=shift, cm=cm, 
                               allow_onthefly_linkage_gen=True, dtype=cfg.dtype);
    
    if return_variable is None: # Return everything
        return locals()
    else:
        return locals()[return_variable]

    
# Collect every beta column from all result CSVs into one dict of float32 arrays.
nrows = None  # set to a small int to read only the first rows while debugging
big_dt = dict()
# sorted(): glob order is filesystem-dependent; sorting makes the run reproducible.
for fn in tqdm(sorted(glob.glob('../results/betas/*/*.csv'))):
    df = pd.read_csv(fn, nrows=nrows)
    for col in df.columns:
        big_dt[col] = df[col].values.astype('float32')


# Parse the column names ('<method>_<mtype>_<params>_<pheno>') into a metadata table,
# drop the 'sp=yes' variants and order by phenotype, method and model type.
df = pd.DataFrame(list(big_dt))
bnfo_df = df[0].str.split('_').apply(pd.Series)
bnfo_df.columns = ['method', 'mtype', 'params', 'pheno']
bnfo_df = bnfo_df[~bnfo_df.params.str.contains('sp=yes')]
bnfo_df = bnfo_df.sort_values(['pheno', 'method', 'mtype'])
bnfo_df.head(4)

# Optional: cache the raw betas on disk.
with open('big_dt.pkl', 'wb') as f:
    pickle.dump(big_dt, f, protocol=pickle.HIGHEST_PROTOCOL)
with open('big_dt.pkl', 'rb') as handle:
    b = pickle.load(handle)  # NOTE(review): the reloaded copy `b` is not used below


# Assemble the selected beta columns (in bnfo_df order) with a single constructor call;
# inserting columns one by one into an empty frame fragments it and is slow for many columns.
key_lst = bnfo_df.apply(lambda x: '_'.join(x), axis=1).to_list()
prepgs_df = pd.DataFrame({key: big_dt[key] for key in key_lst})


# ## Save Betas in Pst with wide matrix
# Rows = model names (prepgs_df columns), columns = SNP ids of the validation genotype reader.
# NOTE(review): assumes the beta CSV rows are in the same SNP order as xp_dt's srd; only the
# lengths are implicitly checked by PstData. Also, the job grid below uses
# 'UKBB_imp_HM3.val' while this line uses 'UKBB_imp_HM3.valid'; confirm which is intended.
xp_dt = experimental_setup(geno_fn='UKBB_imp_HM3.valid', shift=1, fold='val', gb_size_limit=100.)
modelstr_arr = prepgs_df.columns.values.astype(str)
sid = xp_dt['tst_linkdata'].srd.sid
pstda = PstData(row=modelstr_arr, col=sid, val=prepgs_df.values.T)
PstHdf5.write('../results/betas/final/all-betas.pst.h5', pstda, col_major=True)






# Evaluate PGS with PPB Approaches

## Metrics Computers

In [None]:

import numpy as np
from scipy.stats import pearsonr
from pysnptools.standardizer import UnitTrained

locals_dt = {}  # scratch container; not referenced by the code in this section


class PrivacyPreservingMetricsComputer():
    """Compute privacy-preserving (PPB) R^2 estimates for many PGS models at once.

    For every region of ``linkdata`` the scaled weights of all models in ``brd``
    are combined with the left/auto/right linkage blocks. Two quantities are
    accumulated over regions:

    - ``bCb``: the quadratic form of the scaled weights with the linkage matrices;
    - ``BmBt``: the cross product of the weights with the marginal statistics ``Bm``.

    The PPB R^2 table is ``BmBt**2 / bCb``.

    Parameters
    ----------
    linkdata : LinkageData
        Provides ``reg_dt`` and the per-region ``get_*_linkage_region`` /
        ``get_*_range_region`` / ``clear_*`` methods used below.
    brd : PstReader-like
        Weights; rows = models, columns = SNPs.
    s : array
        Per-SNP scaling factors applied to the weights; must be free of NaN/inf.
    Bm : pandas.DataFrame
        Marginal statistics; rows = SNPs, columns = phenotypes.
    dtype : str
        Floating dtype used when reading the weights.
    cov_method : {'local', 'global', 'glocal'}
        Stored as the ``_do_global`` / ``_do_local`` flags.
        NOTE(review): ``evaluate`` does not currently branch on these flags.
    clear_linkage : bool
        If true, free the previous region's linkage while iterating.
    verbose : bool
        Print the region currently being processed.
    """

    # cov_method -> (do_global, do_local)
    _COV_METHOD_FLAGS = {
        'local':  (False, True),   # Local Residualized Marginals
        'global': (True, False),   # Global Residualized Marginals
        'glocal': (True, True),    # Global Local Residualized Marginals
    }

    def __init__(self, *, linkdata, brd, s, Bm, dtype='float32', cov_method='local',
                 clear_linkage=True, verbose=True):

        self.linkdata   = linkdata
        self.brd        = brd
        assert np.isfinite(s).all()  # scaling factors must contain no NaN or inf
        self.s          = s
        self.Bm         = Bm
        self.dtype      = dtype
        self.cov_method = cov_method
        self.clear_linkage = clear_linkage
        self.verbose    = verbose

        if cov_method not in self._COV_METHOD_FLAGS:
            raise ValueError(f'Option not recognized: \'{cov_method}\' ')
        self._do_global, self._do_local = self._COV_METHOD_FLAGS[cov_method]

    def evaluate(self, debug=False):
        """Accumulate b'Cb and Bm'b over all regions and return the PPB tables.

        Parameters
        ----------
        debug : bool
            If true, stop after the first region whose index exceeds 38; the
            returned tables then cover only the processed regions.

        Returns
        -------
        dict
            ``ppbr2_df`` (phenotypes x models), ``bCb`` (1 x models),
            ``BmBt`` (phenotypes x models), ``info_dt`` (per-region block
            shapes and SNP ranges) and ``s``.
        """
        linkdata = self.linkdata
        brd, s, Bm = self.brd, self.s, self.Bm
        Bm_arr = np.asarray(Bm)  # hoisted: positional slicing of the marginals per region
        bCb = 0.
        BmBt = 0.
        info_dt = dict()

        # Cycle through the blocks:
        for i in tqdm(linkdata.reg_dt.keys()):
            if self.verbose:
                print(f'PPB: Processing region {i}', end='\r')

            # Linkage blocks left of / on / right of the region's diagonal, and their SNP ranges.
            L = linkdata.get_left_linkage_region(i=i)
            D = linkdata.get_auto_linkage_region(i=i)
            R = linkdata.get_right_linkage_region(i=i)
            lr = linkdata.get_left_range_region(i=i)
            ar = linkdata.get_auto_range_region(i=i)
            rr = linkdata.get_right_range_region(i=i)

            # Scaled weights (SNPs x models) for each range, all read in self.dtype.
            B_L = s[lr[0]:lr[1]] * brd[:, lr[0]:lr[1]].read(dtype=self.dtype).val.T
            B_D = s[ar[0]:ar[1]] * brd[:, ar[0]:ar[1]].read(dtype=self.dtype).val.T
            B_R = s[rr[0]:rr[1]] * brd[:, rr[0]:rr[1]].read(dtype=self.dtype).val.T

            # Do the computation:
            CB = L.dot(B_L) + D.dot(B_D) + R.dot(B_R)
            bCb += (B_D * CB).sum(axis=0)
            BmBt += (B_D.T.dot(Bm_arr[ar[0]:ar[1], :])).T
            info_dt[i] = dict(shapeL=L.shape, shapeD=D.shape, shapeR=R.shape,
                              lr=lr, ar=ar, rr=rr)

            # Pruning to minimize memory overhead:
            if (i > 0) and self.clear_linkage:
                linkdata.clear_linkage_region(i=i - 1)
            if debug and (i > 38):
                break

        # Complete results:
        linkdata.clear_all_xda()
        cols = brd.row.astype(str).flatten()
        bCb  = pd.DataFrame(bCb[np.newaxis, :], index=['bCb'], columns=cols)
        BmBt = pd.DataFrame(BmBt, index=Bm.columns, columns=cols)
        ppbr2_df = (BmBt ** 2) / bCb.loc['bCb']
        return dict(ppbr2_df=ppbr2_df, bCb=bCb, BmBt=BmBt, info_dt=info_dt, s=s)
    
# locals_dt = dict()
class MultiPGSModel():
    """Score many polygenic-score weight vectors against one genotype reader.

    ``predict`` streams the genotypes in SNP chunks, standardises each chunk
    and accumulates predictions for all models at once.

    Parameters
    ----------
    brd : PstReader-like
        Weights; rows = models, columns = SNP ids (must equal ``srd.sid``).
        If it exposes ``val`` it is checked to be NaN-free.
    unscaled : bool
        Stored only; not read by this class.
    verbose : bool
        Stored only; not read by this class.
    dtype : str
        Floating dtype used when reading genotypes, phenotypes and weights.
    allow_nan : bool
        If False, ``predict`` asserts that the marginal statistics are NaN-free.
    """

    def __init__(self, *, brd, unscaled=True, verbose=False, dtype='float32', allow_nan=False):
        self.brd   = brd
        if hasattr(brd, 'val'):
            assert np.sum(np.isnan(brd.val)) == 0
        self.unscaled = unscaled
        self.verbose = verbose
        self.dtype  = dtype
        self.allow_nan = allow_nan

    def predict(self, *, srd, prd=None, n_inchunk=1000, stansda=None):
        """Compute PGS predictions and, if phenotypes are given, marginal statistics.

        Parameters
        ----------
        srd : SnpReader
            Genotypes; ``srd.sid`` must equal ``self.brd.col`` element-wise.
        prd : reader, optional
            Phenotypes with the same ``iid`` order as ``srd``. When given, the
            standardised phenotypes ``Ytru`` and ``Bm = X'Y / n`` are computed.
        n_inchunk : int
            Number of SNPs read and standardised per chunk.
        stansda : ignored
            Kept for interface compatibility; it is overwritten inside the loop.

        Returns
        -------
        dict
            The ``locals()`` of this call. Callers use the keys:

            - ``Yhat``: DataFrame, individuals x models;
            - ``Ytru``;
            - ``Bm``: DataFrame, SNPs x phenotypes, or None;
            - ``brd``;
            - ``stansda``: combined ``UnitTrained``;
            - ``s``: per-SNP scale as a column vector, inf replaced by 1.

            Local variable names are therefore part of the interface.
        """
        # Load that PGS (& optionally phenos)
        brd = self.brd
        has_pheno = prd is not None
        Yhat = np.zeros((srd.shape[0], brd.shape[0]), dtype=self.dtype)
        assert np.all(brd.col.astype(str) == srd.sid)
        if has_pheno:
            assert np.all(srd.iid == prd.iid)
            pda   = prd.read(dtype=self.dtype).standardize()
            Ytru  = pda.val
            Bm    = np.full((srd.shape[1], Ytru.shape[1]), np.nan)  # float64, filled chunk by chunk

        # Loop through Genome:
        stansda_lst = []
        start = 0
        for start in tqdm(range(0, srd.shape[1], n_inchunk)):
            stop = min(start + n_inchunk, srd.shape[1])
            sda, stansda = srd[:, start:stop].read(dtype=self.dtype).standardize(return_trained=True)
            X = sda.val
            # `s` is a view into stansda.stats, so replacing inf std (presumably
            # zero-variance SNPs) with 1 also updates the stored stats.
            s = stansda.stats[:, 1][:, np.newaxis]
            s[np.isinf(s)] = 1
            B = s * brd[:, start:stop].read(dtype=self.dtype).val.T
            Yhat += X @ B
            if has_pheno:
                Bm[start:stop] = X.T @ Ytru
            stansda_lst.append(stansda)

        if has_pheno:
            Bm = Bm / Ytru.shape[0]
            if not self.allow_nan:
                assert np.isnan(Bm).sum() == 0
            Bm = pd.DataFrame(Bm, index=srd.sid, columns=prd.col)
            Ytru = pd.DataFrame(Ytru,  # Make Ytru a proper dataframe
                index=pd.MultiIndex.from_arrays(prd.iid.T, names=('fid', 'iid')),
                columns=prd.col)
        else:
            Ytru = None
            Bm = None

        # Combine the per-chunk standardizers into one:
        sid     = np.concatenate([stan.sid   for stan in stansda_lst])
        assert  np.unique(sid).shape[0] == sid.shape[0]
        stats   = np.concatenate([stan.stats for stan in stansda_lst])
        stansda = UnitTrained(sid, stats)
        s = stansda.stats[:, 1][:, np.newaxis]
        s[np.isinf(s)] = 1

        # Create Yhat dataframe:
        Yhat  = pd.DataFrame(
            data    = Yhat,
            index   = pd.MultiIndex.from_arrays(srd.iid.T, names=('fid', 'iid')),
            columns = self.brd.row.astype(str)
        )
        assert Yhat.isna().sum().sum() == 0

        res_dt = dict(Yhat=Yhat, Bm=Bm, brd=brd, Ytru=Ytru, stansda=stansda, s=s)

        return locals()

    def run(self):
        pass

    def fit(self):
        pass

## Job Processing:

In [None]:


# Open the saved betas (rows = models, cols = SNPs); `brd` is also what the debug make_brd_df returns.
brd = PstHdf5('../results/betas/final/all-betas.pst.h5')
# Keep only models whose name contains 'auto' for the pgs_df inspection table.
ind = ['auto' in str(elem) for elem in brd.row]
bda = brd[ind,:].read()
pgs_df = pd.DataFrame(bda.val.T, index=bda.col.astype(str), columns=bda.row.astype(str))
# Test-fold phenotypes; also returned by debug_make_pheno_df when debug is on.
pheno_df = make_pheno_df(fold='test')



# [chaindt(minidt for minidt in dt.values()) for dt in params_lst]
def nans2nones(in_df):
    """Return an object-dtype copy of *in_df* with every missing value replaced by None.

    The input frame is never modified. The previous version wrote into
    ``in_df.values``, which can be a view of the caller's data. For float
    frames, None was also silently coerced back to NaN.
    """
    missing = in_df.isna().to_numpy()
    vals = in_df.to_numpy(dtype=object, copy=True)
    vals[missing] = None
    return pd.DataFrame(vals, index=in_df.index, columns=in_df.columns, dtype=object)


# Experiment grid: every combination of LD setting x fold x dtype becomes one job.
# Each list entry is a dict; ParameterGrid picks one dict per key, and the columns are
# then flattened so every job row has scalar cm/shift/fold/geno_fn/dtype entries.
# (The trailing #0/#3/#6 look like indices from an earlier, longer list.)
ld_lst = [
    dict(cm=0.01, shift=0), #0
    dict(cm=2.0,  shift=0), #3
    dict(cm=10.0, shift=0), #6
]

fold_lst = [
    dict(fold='val', geno_fn='UKBB_imp_HM3.val'),
    dict(fold='test', geno_fn='UKBB_imp_HM3')
]

dtype_lst = [
    dict(dtype='float32'),
    dict(dtype='float64')
]

dt = dict(
    ld    = ld_lst,
    fold  = fold_lst,
    dtype = dtype_lst
)

params_lst = list(ParameterGrid(dt))
df = pd.DataFrame(params_lst)
# Expand each dict-valued column into one object-dtype column per key, then put them side by side.
param_df = pd.concat([df[cols].apply(lambda x: pd.Series(x, dtype='object')) for cols in df.columns], axis=1)
param_df = nans2nones(param_df)
# param_df.T.to_dict()
param_df


# # Define Internal Executor:
# debug=True reuses the in-notebook brd/pheno_df and makes evaluate() stop early.
debug = False; print('debug =', debug)
folder = "../lnk/menno/log_test/%j"
executor = submitit.AutoExecutor(folder=folder)
executor.update_parameters(slurm_mem='119G', cpus_per_task=15, slurm_time='11:54:00',
                           slurm_additional_parameters={'account': 'NCRR'})

if debug:
    make_brd_df = lambda : brd 
else:
    def make_brd_df(fn='../results/betas/final/all-betas.pst.h5'):
        # Reopen the betas inside the job instead of shipping the reader from the notebook.
        brd = PstHdf5(fn)
        return brd
def debug_make_pheno_df(fold='test'):
    # Ignores `fold` and returns the notebook's (test-fold) pheno_df.
    return pheno_df
wrapped_make_pheno_df = debug_make_pheno_df if debug else make_pheno_df

job_dt = dict()
for i, cfg_dt in param_df.T.to_dict().items(): 

    # NOTE(review): ppb_fun reads `cfg_dt` as a module global instead of binding it as an
    # argument. This presumably works because submitit pickles the function (with the
    # current cfg_dt) at submit time, and the results cell later reads it back via
    # `fun.__globals__['cfg_dt']`. Confirm before refactoring this into a closure/partial.
    def ppb_fun():
        # Per-job pipeline: genotypes + linkage -> predictions/marginals -> PPB metrics + correlations.
        xp_dt = experimental_setup(geno_fn=cfg_dt['geno_fn'], fold=cfg_dt['fold'], shift=cfg_dt['shift'], gb_size_limit=19., cm=cfg_dt['cm'], dtype=cfg_dt['dtype']) 
        pheno_df = wrapped_make_pheno_df(fold=cfg_dt['fold'])
        brd = make_brd_df()

        # Wrap the phenotype frame as a Pheno reader and align individuals with the genotypes.
        tst_prd = Pheno(dict(iid=pheno_df.index.to_frame()[['fid','iid']].values.astype(str),
                   vals=pheno_df.values, header=list(pheno_df.columns)))
        cur_srd, cur_prd = pstutil.intersect_apply([xp_dt['tst_srd'], tst_prd])

        pgm = MultiPGSModel(brd=brd, verbose=True, dtype=cfg_dt['dtype'])
        pred_dt = pgm.predict(srd=cur_srd, prd=cur_prd)

        linkdata = xp_dt['tst_linkdata']
        mc = PrivacyPreservingMetricsComputer(linkdata=linkdata, brd=brd, s=pred_dt['s'], Bm=pred_dt['Bm'], cov_method='local', dtype=cfg_dt['dtype'], clear_linkage=False)
        mcres_dt = mc.evaluate(debug=debug)
        
        # Individual-level correlation between true and predicted phenotypes.
        C = corr(pred_dt['Ytru'].astype(cfg_dt['dtype']), pred_dt['Yhat'].astype(cfg_dt['dtype']))
        
        # Note: Yhat/Ytru themselves are not returned, only C.
        res_dt = dict(brd=pred_dt['brd'], C=C,
                      Bm=pred_dt['Bm'], stansda=pred_dt['stansda'], 
                      mcres_dt=mcres_dt, linkdata=linkdata)
        
        return res_dt
                           
    job_dt[i] = executor.submit(ppb_fun)
    print('submitted: ', i, end='\r')
print('done')
eval_job_dt = job_dt




# Process results:

In [None]:

# Harvest finished submitit jobs into res_dt. When an array equals the previous
# successful job's, re-use the previous object so identical results share memory.
# Previously an undefined name (`hkergkjhegr`) was used to force the except branch
# for unfinished jobs, and bare `except:` swallowed everything (incl. KeyboardInterrupt).
tic(); iprev = -1
dtype_Y = 'float64'
res_dt = dict()
for i, job in eval_job_dt.items():
    if not job.done():
        print(f'Job {i} not done.', end='')
        continue
    try:
        res_dt[i] = job.result()
    except Exception:  # the job raised remotely or its result could not be loaded
        print(f'Job {i} failed.', end='')
        continue
    print(f'Job {i} success.  --> ', end='')

    if iprev != -1:
        # NOTE(review): ppb_fun's result has no 'Yhat'/'Ytru' keys, so only 'Bm' can be shared.
        for key in ['Yhat', 'Ytru', 'Bm']:
            try:
                isclose = np.allclose(res_dt[iprev][key].values,
                                      res_dt[i][key].values)
            except (KeyError, AttributeError, TypeError, ValueError):
                # Missing key, non-array value or shape mismatch: treat as different.
                isclose = False
            if isclose:
                res_dt[i][key] = res_dt[iprev][key]
                if key == 'Yhat':
                    res_dt[i]['C'] = res_dt[iprev]['C']
                print(key, end=', ')

    if 'C' not in res_dt[i]:
        print('computing C', end='')
        res_dt[i]['C'] = corr(res_dt[i]['Ytru'].astype(dtype_Y), res_dt[i]['Yhat'].astype(dtype_Y))
    iprev = i; print()
toc()



# Build (once) the val/test phenotype frames and bundle them with the job results.
if not ('pheno_df_dt' in locals().keys()):
    pheno_df_dt = dict()
    pheno_df_dt['val'] = make_pheno_df(fold='val')
    pheno_df_dt['test'] = make_pheno_df(fold='test')
quickload_dt = dict(pheno_df_dt=pheno_df_dt, res_dt=res_dt)


# Optional: save the bundle, or (in a fresh kernel) reload it.
# NOTE(review): when run top-to-bottom quickload_dt was just assigned above, so the
# load branch is unreachable; it only triggers if the lines above are skipped.
# pickle.load should only be used on this trusted, self-written file.
if 'quickload_dt' in locals().keys():
    with open('quickload_dt.pkl', 'wb') as f:
        pickle.dump(quickload_dt, f, protocol=pickle.HIGHEST_PROTOCOL)
else:
    with open('quickload_dt.pkl', 'rb') as handle:
        quickload_dt = pickle.load(handle)
# At notebook top level locals() is globals(), so this defines pheno_df_dt and res_dt as globals.
locals().update(quickload_dt)
get_ipython().system('echo $TEMP')



# Prevalence Stuff:
def compute_corfac(K):
    """Return K(1-K) / phi(Phi^{-1}(K))^2 for prevalence *K*.

    phi and Phi are the standard-normal density and CDF. This has the form of
    the observed-to-liability-scale conversion factor; it is used below to
    adjust R^2 values by prevalence.
    """
    density_at_threshold = stats.norm.pdf(stats.norm.ppf(K))
    return K * (1 - K) / density_at_threshold ** 2

def compute_nagelkerkefac(K):
    """Return 1 / (1 - K^(2K) * (1-K)^(2(1-K))) for prevalence *K*.

    The denominator is the maximal Cox-Snell R^2 attainable for a binary
    outcome with case fraction K, so this factor rescales Cox-Snell R^2 to
    Nagelkerke's R^2. The unused normal-quantile computations of the previous
    version were removed.
    """
    max_cox_snell_r2 = 1 - (K ** (2 * K)) * ((1 - K) ** (2 * (1 - K)))
    return 1 / max_cox_snell_r2

def excl(arg, string, axis=1):
    """Keep only the labels along *axis* that do NOT contain the regex *string*."""
    keep_pattern = '^((?!{}).)*$'.format(string)
    return arg.filter(regex=keep_pattern, axis=axis)


# Prevalence:
# Per phenotype: prevalence = mean of its 'base' test column, plus the correction factors derived from it.
# NOTE(review): the mean equals prevalence only if cases/controls are coded 1/0 — confirm the coding.
pheno_df = pheno_df_dt['test']
prev_df = pheno_df.filter(like='base', axis=1).mean(axis=0).to_frame(name='preval')
prev_df['corfac'] = prev_df['preval'].apply(compute_corfac)
prev_df['nagelkerke'] = prev_df['preval'].apply(compute_nagelkerkefac)
prev_df['flip'] = 1.0/prev_df['preval']
# Index by phenotype base name (text before the first '.'), matching the lookups in the next loop.
prev_df.index = prev_df.index.str.split('.').str[0]
pheno_lst = list(prev_df.index.unique())
# arr = pheno_df.columns.str.split('.', expand=True).to_frame().values

# Per job: compute prevalence-adjusted PPB and vanilla R^2 tables and index them by (fold, cM).
# Fixes the `loc = locals()s` SyntaxError and narrows the bare `except:`.
bigres_dt = dict(); adjppbr2_dt = {}; adjvanr2_dt = {}; mega_dt = {}
gb = sizegb(np.random.randn(1000, 1000).astype('float32'))  # size of 1e6 float32 values (presumably in GB)
for i_res, job in job_dt.items():

    try:
        res = res_dt[i_res] if i_res in res_dt else job.result()
        print('res', i_res, end=', ')
    except Exception:  # unfinished or failed job
        print('continue, i_res: ', i_res, end=', ')
        continue

    # Recover the configuration stored with the pickled submission function.
    fun = job.submission().function
    cfg_dt = fun.__globals__['cfg_dt']
    # At notebook top level locals() is globals(): this (re)defines C, mcres_dt, brd, ... as globals.
    locals().update(res)
    C = res['C']
    mcres_dt = res['mcres_dt']

    # Total linkage entries over all regions (shapeL/shapeD/shapeR), in millions, scaled by `gb`.
    info_df = pd.DataFrame(mcres_dt['info_dt']).T
    gbs = (info_df.iloc[:, :3].applymap(lambda x: x[0] * x[1]).sum().sum() / 1e6) * gb
    cfg_dt.update(gbs=gbs)

    # Prevalence correction looked up by phenotype base name (text before the first '.').
    ppbr2_df = mcres_dt['ppbr2_df']
    adjppbr2_df = ppbr2_df.T * prev_df['corfac'].loc[ppbr2_df.index.str.split('.').str[0]].values
    adjvanr2_df = (C ** 2).T * prev_df['corfac'].loc[C.index.str.split('.').str[0]].values

    # Store stuff:
    bigres_dt[i_res] = dict(adjvanr2_df=adjvanr2_df, adjppbr2_df=adjppbr2_df, cfg_dt=cfg_dt)
    cm = cfg_dt['cm']
    fold = cfg_dt['fold']
    if cm is None:
        continue
    adjppbr2_dt[(fold, f'cm={cm}')] = adjppbr2_df
    adjvanr2_dt[(fold, f'cm={cm}')] = adjvanr2_df
    mega_dt[(fold, cm, 'van')] = adjvanr2_df
    mega_dt[(fold, cm, 'ppb')] = adjppbr2_df
    



In [None]:

tic()
# Reshape the per-job adjusted R^2 tables into one long table
# (one row per model x fold x cm x flavour x phenotype x adjustment).
df = pd.concat(mega_dt, axis=1)
df.columns.names = ['fold', 'cm', 'flav', 'pheno']
# Model names are '<method>_<app>_<params>_<phenoprs>'.
df.index = df.index.str.split('_', expand=True)
df.index.names = ['method', 'app', 'params', 'phenoprs']
df = df.stack(['cm', 'flav'])
# Phenotype columns are '<pheno>.<adj>.<fold>'.
df.columns = df.columns.get_level_values('pheno').str.split('.', expand=True)
df.columns.names = ['pheno', 'adj', 'fold']
df = df.stack('pheno')
df = df.stack().stack()
df.name = 'adjr2'
df = df.reset_index()
# Zero-pad single-digit simulation ids (e.g. 'SIM3' -> 'SIM03'). Raw strings avoid the
# invalid '\D' escape-sequence warning of the previous non-raw literals (same pattern value).
df.pheno = df.pheno.str.replace(r'S.+\D(?=[0-9]$)', 'SIM0', regex=True)
df.phenoprs = df.phenoprs.str.replace(r'S.+\D(?=[0-9]$)', 'SIM0', regex=True)
display(df.head())
mega_df = df
toc()


adj = 'm16'
# Keep rows whose PGS phenotype label (phenoprs) equals the evaluated phenotype.
# .copy(): the next line assigns a column; assigning into a filtered slice raises
# SettingWithCopyWarning and may not propagate under copy-on-write.
fullres_df = mega_df[mega_df.phenoprs == mega_df.pheno].copy()
# NOTE(review): doubles the cM value, presumably converting a half-window to a full window; confirm.
fullres_df['cm'] = fullres_df['cm'] * 2.0
fullres_df = fullres_df.set_index(fullres_df.columns[:-1].to_list())
fullres_df = fullres_df.droplevel('phenoprs')
fullres_df = fullres_df.unstack(['adj', 'flav', 'fold'])['adjr2']
fullres_df = fullres_df.sort_index(axis=1)
fullres_df = fullres_df[adj]
fullres_df



In [None]:
# fullres_df contains all the 'normal' and PPB R^2 results.