In [None]:
#default_exp high_level

In [None]:
#export
from combinatorial_gwas.simulation import SimulatedPheno, SNPInfoUnit
from combinatorial_gwas.genotype import load_genetic_file
from combinatorial_gwas.data_catalog import get_catalog, get_config, get_parameters
from combinatorial_gwas.phenotypes import get_phenotype, get_GWAS_snps_for_trait, upsample_pheno

from typing import List, Union, Literal
import numpy as np
import logging
from tqdm.auto import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
from dataclasses import dataclass

  return pd.read_csv(fs_file, **self._load_args)


In [None]:
#export

# def get_tune_train_validation_test_data(datasource):
#     x = datasource.get_data(slice(0, max_samples))
#     y = datasource.pheno_df_ordered.to_numpy()[:max_samples]
#     train_x, validation_x, train_y, validation_y = sklearn.model_selection.train_test_split(x, y, test_size=validation_split, stratify=y)
#     train_x, test_x, train_y, test_y = sklearn.model_selection.train_test_split(train_x, train_y, test_size=test_split, stratify=train_y)
#     train_x, tune_x, train_y, tune_y = sklearn.model_selection.train_test_split(train_x, train_y, test_size=tune_split, stratify=train_y)
#     tune_x, optimize_x, tune_y, optimize_y = sklearn.model_selection.train_test_split(tune_x, tune_y, test_size=optimize_split, stratify=tune_y)
#     return tune_x, optimize_x, tune_y, optimize_y, train_x, train_y, validation_x, validation_y, test_x, test_y

def get_info_dict(data, balance_pheno, name="test"):
    """Summarise the case/control balance of one data split.

    Args:
        data: DataFrame with a `balance_pheno` column whose values are 0 (control) / 1 (case).
        balance_pheno: name of the phenotype column to summarise.
        name: prefix for the returned keys (e.g. the split name).

    Returns:
        dict with keys `{name}_data` (row count), `{name}_data_n_controls`,
        `{name}_data_n_cases`, `{name}_data_frac_controls` and `{name}_data_frac_cases`.
        A class that does not occur in `data` is reported as 0 instead of raising KeyError
        (small splits such as `validation_tune` can easily lack one class).
    """
    pheno = data[balance_pheno]
    # Tally once per kind instead of calling value_counts() four times.
    counts = pheno.value_counts().to_dict()
    fractions = pheno.value_counts(normalize=True).to_dict()
    return {
        f"{name}_data": data.shape[0],
        f"{name}_data_n_controls": counts.get(0, 0),
        f"{name}_data_n_cases": counts.get(1, 0),
        f"{name}_data_frac_controls": fractions.get(0, 0.0),
        f"{name}_data_frac_cases": fractions.get(1, 0.0),
    }

In [None]:
#export
# Load the simulated I83 phenotypes from the project data catalog. Because this cell is
# `#export`, the load runs whenever the generated module is imported.
# Judging by the output displayed in this notebook, keys are (snp_id, snp_id) tuples mapping to
# {query string: SimulatedPheno}; `chromosome_datasource.get_simulated_data` reads this dict.
catalog = get_catalog()
simulation_I83_queries_pheno_dict = catalog.load('simulation_I83_queries_pheno_dict')


In [None]:
# Inspect the simulated phenotypes: one entry per SNP pair, one SimulatedPheno per query.
simulation_I83_queries_pheno_dict

{('6:134911816_G_A',
  '6:26118570_T_C'): {' (not (`6:134911816_G_A` == 0.0)) and (not (`6:26118570_T_C` == 2.0))': SimulatedPheno(snps=[SNPInfoUnit(negation='not', snp_id='6:134911816_G_A', geno=0.0), SNPInfoUnit(negation='not', snp_id='6:26118570_T_C', geno=2.0)], op=['and'], query=' (not (`6:134911816_G_A` == 0.0)) and (not (`6:26118570_T_C` == 2.0))', pheno_col=5542886    0
  5137974    1
  3758348    1
  1391800    1
  3165331    1
            ..
  5512806    1
  5548469    1
  2956972    1
  5229561    1
  3665101    0
  Length: 487409, dtype: int64, case_count=404360, control_count=83049), ' (not (`6:134911816_G_A` == 0.0)) and (not (`6:26118570_T_C` == 0.0))': SimulatedPheno(snps=[SNPInfoUnit(negation='not', snp_id='6:134911816_G_A', geno=0.0), SNPInfoUnit(negation='not', snp_id='6:26118570_T_C', geno=0.0)], op=['and'], query=' (not (`6:134911816_G_A` == 0.0)) and (not (`6:26118570_T_C` == 0.0))', pheno_col=5542886    0
  5137974    0
  3758348    1
  1391800    1
  3165331    

In [None]:
#export

class snp_filter:
    """Describes how to pick GWAS SNPs for one trait.

    `sort` names the GWAS summary-statistic column used both to order and to threshold
    the results in `chromosome_datasource.load_snps_for_traits_all_chroms`: with
    `SORT_PVALUE` rows with `pval < threshold` are kept, otherwise rows with
    `<sort column> > threshold` are kept.
    """

    # Column names understood by the GWAS result tables.
    SORT_PVALUE = 'pval'
    SORT_BETA = 'beta'

    def __init__(self, phenotype:str, sort:str, threshold=1e-5):
        self.phenotype, self.sort, self.threshold = phenotype, sort, threshold

@dataclass
class DataSplitParams:
    """Fractions for the successive stratified splits done by `chromosome_datasource.split_data`."""
    validation_split: float  # share of the non-test data held out as `validation`
    test_split: float  # share of all samples held out as `test`
    tune_split: float  # share of the remaining training data held out as `tune`
    validation_tune_split: float  # share of `tune` held out as `validation_tune`

default_data_split_params = DataSplitParams(
    test_split=0.1,
    validation_split=0.2,
    validation_tune_split=0.2,
    tune_split=0.01,
)

class chromosome_datasource:
    """Genotype + phenotype data source for a set of chromosomes.

    Construction opens one genotype file per chromosome (`load_genetic_file`), builds the
    phenotype matrix for `self.samples` (`get_phenotype`), splits it into the subsets stored
    in `self.data_dict` (summarised in `self.data_split_info_df`) and collects the GWAS SNPs
    passing each `snp_filter` into `self.snp_full_info` / `self.snp_trait_info`.
    Genotype values are only read on demand by `get_data` / `get_simulated_data`.
    """

    def __init__(self, chromosomes: Union[List[int], None] = None, snp_filters: Union[List[snp_filter], None] = None, samples: np.array = None, max_samples: int = 100_000, random_state: int = 42, balance_pheno: str = "I84", data_split_params: DataSplitParams = default_data_split_params):
        """Load genotype files, build and split the phenotypes, and find the SNPs to use.

        Args:
            chromosomes: chromosomes to load; `None` means `list(range(1, 23))`.
            snp_filters: one filter per trait; `None` means
                `[snp_filter('I84', snp_filter.SORT_PVALUE)]`.
            samples: sample ids to build phenotypes for; `None` means the `.samples` of the
                first chromosome's genotype file.
            max_samples: forwarded to `upsample_pheno` for the training split; `None`
                disables upsampling.
            random_state: seed forwarded to every `train_test_split` / `upsample_pheno` call.
            balance_pheno: phenotype column used for stratification and upsampling; it must
                be one of the snp_filters' phenotypes.
            data_split_params: split fractions (see `DataSplitParams`).
        """
        # Build fresh defaults per call instead of sharing mutable default arguments.
        if chromosomes is None:
            chromosomes = list(range(1, 23))
        if snp_filters is None:
            snp_filters = [snp_filter('I84', snp_filter.SORT_PVALUE)]

        self.genome_files = {chromosome: load_genetic_file(chromosome) for chromosome in tqdm(chromosomes, "Loading genotype file(s)")}
        self.chromosomes = chromosomes
        self.phenotypes = [snp.phenotype for snp in snp_filters]
        self.data_split_params = data_split_params
        self.balance_pheno = balance_pheno
        self.max_samples = max_samples
        self.random_state = random_state

        # Always set self.samples: explicitly passed samples used to be dropped, which made
        # the get_phenotype call below fail with AttributeError.
        if samples is None:
            samples = self.genome_files[chromosomes[0]].samples
        self.samples = samples

        logging.warning("creating phenotype matrix")

        self.pheno_df_ordered = get_phenotype(self.phenotypes, samples = self.samples, max_samples = None)
        self.data_dict, self.data_split_info_df = self.split_data(self.pheno_df_ordered, self.balance_pheno, self.data_split_params, self.random_state, train_upsampled_max_samples = self.max_samples)
        logging.warning("finished creating phenotype matrix")

        self.all_snps_dict = {}
        self.snp_filters = snp_filters

        logging.warning(f"Loading SNPs list for {len(chromosomes)} chromosomes: {chromosomes=}")
        self.snp_full_info, self.snp_trait_info = self.load_snps_for_traits_all_chroms()
        logging.warning(f"Found {self.snp_full_info.shape[0]} SNPs associated with traits {self.phenotypes} for {chromosomes=}")

    def split_data(self, pheno_df, balance_pheno, data_split_params_obj, random_state, train_upsampled_max_samples):
        """Split `pheno_df` into stratified subsets and summarise each one.

        Split chain (every split stratified on `balance_pheno`, seeded with `random_state`):
            pheno_df       -> train_original + test
            train_original -> train_remaining + validation
            train_remaining -> train_before_upsample + tune
            tune           -> train_tune + validation_tune
        When `train_upsampled_max_samples` is not None, `train_tune` and the final `train`
        subset are passed through `upsample_pheno`; otherwise `train` is `train_before_upsample`.

        Returns:
            (data_dict, info_df): the named subsets, and a one-column DataFrame (column
            `balance_pheno`) with each subset's size and case/control counts and fractions.
        """
        def _stratified_split(df, test_size):
            # Each split keeps the case/control ratio of `balance_pheno`.
            return train_test_split(df, test_size=test_size, stratify=df[balance_pheno], random_state=random_state)

        train_original, test = _stratified_split(pheno_df, data_split_params_obj.test_split)
        train_remaining, validation = _stratified_split(train_original, data_split_params_obj.validation_split)
        train_before_upsample, tune = _stratified_split(train_remaining, data_split_params_obj.tune_split)
        train_tune, validation_tune = _stratified_split(tune, data_split_params_obj.validation_tune_split)

        if train_upsampled_max_samples is None:
            train_final = train_before_upsample
        else:
            # NOTE(review): the train_tune target size divides by 4 without a stated reason — confirm intent.
            train_tune_max_samples = int(train_upsampled_max_samples * ((1 - data_split_params_obj.validation_tune_split)/4))
            train_tune = upsample_pheno(pheno_df = train_tune, balance_pheno = balance_pheno, max_samples = train_tune_max_samples, random_state = random_state)
            train_final = upsample_pheno(pheno_df = train_before_upsample, balance_pheno = balance_pheno, max_samples = train_upsampled_max_samples, random_state = random_state)

        data_dict = {"train_original": train_original,
                    "test": test,
                    "train_before_upsample": train_before_upsample,
                    "validation": validation,
                    "tune": tune,
                    "validation_tune": validation_tune,
                    "train_tune": train_tune,
                    "train": train_final}

        info_dict = {}
        for name, data in data_dict.items():
            info_dict.update(get_info_dict(data, balance_pheno = balance_pheno, name = name))
        info_df = pd.DataFrame.from_dict(info_dict, orient="index", columns = [balance_pheno])

        return data_dict, info_df

    def load_snps_for_traits_all_chroms(self):
        """Collect the GWAS SNPs that pass every filter in `self.snp_filters`.

        For each filter, GWAS results on `self.chromosomes` come from `get_GWAS_snps_for_trait`
        and are kept when `pval < threshold` (SORT_PVALUE) or `<sort column> > threshold`
        (any other sort column). Kept rows are tagged with trait, sort column and threshold.

        Returns:
            (snp_full_info, snp_trait_info): all kept rows sorted by chr/position, and per
            `full_id` the unique traits it was found for plus their count.
        """
        all_SNPs_df = []
        pbar = tqdm(self.snp_filters)
        for snp in pbar:
            pbar.set_description(f"Searching GWAS result for trait: {snp.phenotype}")
            is_pvalue = snp.sort == snp_filter.SORT_PVALUE
            trait_snps_df = get_GWAS_snps_for_trait(
                snp.phenotype,
                chromosome=self.chromosomes,
                id_only=False,
                sort_val_cols_list=snp.sort,
                ascending_bool_list=[is_pvalue],
            ).query(f'{snp.sort} {"<" if is_pvalue else ">"} {snp.threshold}')
            trait_snps_df[["chr", "position"]] = trait_snps_df[["chr", "position"]].astype(int)
            trait_snps_df["trait"] = snp.phenotype
            trait_snps_df["sort_cols"] = snp.sort
            trait_snps_df["threshold"] = snp.threshold
            all_SNPs_df.append(trait_snps_df)

        snp_full_info = pd.concat(all_SNPs_df).sort_values(["chr", "position"])
        snp_trait_info = snp_full_info.groupby("full_id")["trait"].agg(["unique", "nunique"])
        return snp_full_info, snp_trait_info

    def get_geno_matrix_specific_chrom(self, sample_id_subset, chrom):
        """Read genotypes of `sample_id_subset` at this chromosome's selected SNPs (ordered by position)."""
        chrom_specific_variant_ids = self.snp_full_info.drop_duplicates("full_id").query(f"chr == {chrom}").sort_values("position")["full_id"].values
        logging.warning(f"Loading {len(chrom_specific_variant_ids)} SNPs for chromosome {chrom}, {sample_id_subset.shape[0]} people")
        genos = self.genome_files[chrom].get_geno_each_sample(prob_to_geno_func = "max", sample_ids=sample_id_subset, variant_ids=chrom_specific_variant_ids)
        return genos

    def get_sample_id_in_split(self, sample_slice, split:Literal["train", "validation", "train_tune", "validation_tune", "test"]):
        """Return rows `sample_slice` (positional) of split `split` and their sample ids.

        Raises:
            KeyError: if `split` is not a key of `self.data_dict`.
        """
        if split not in self.data_dict:
            raise KeyError(f"Unknown split {split!r}; available splits: {list(self.data_dict)}")
        split_df_subset = self.data_dict[split].iloc[sample_slice, :]
        sample_id_subset = split_df_subset.index.values
        return split_df_subset, sample_id_subset

    def get_X(self, sample_id_subset, pbar):
        """Concatenate (via np.hstack) the per-chromosome genotypes for each chromosome in `pbar`."""
        all_genos_all_chrom = []
        for chrom in pbar:
            genos = self.get_geno_matrix_specific_chrom(sample_id_subset, chrom)
            all_genos_all_chrom.append(genos)
        X = np.hstack(all_genos_all_chrom)
        return X

    def _load_split_genotypes(self, sample_slice, split):
        """Select rows `sample_slice` of `split` and load their genotypes for every chromosome with selected SNPs."""
        split_df_subset, sample_id_subset = self.get_sample_id_in_split(sample_slice, split)
        pbar = tqdm(self.snp_full_info.chr.unique(), f"Loading genotype data from {split=}")
        X = self.get_X(sample_id_subset, pbar)
        return split_df_subset, sample_id_subset, X

    def get_data(self, sample_slice: slice, split:Literal["train", "validation", "train_tune", "validation_tune", "test"]):
        """Return `(X, y)` for rows `sample_slice` of split `split`.

        X is the stacked genotype array from `get_X`. y holds the phenotype values: 1-D when
        there is a single phenotype column, otherwise an (n_samples, n_phenotypes) array.
        """
        split_df_subset, _, X = self._load_split_genotypes(sample_slice, split)
        pheno_values = split_df_subset.to_numpy()
        # Flattening several phenotype columns would interleave them into n_samples * n_phenotypes
        # values that no longer line up with the rows of X, so only flatten a single column.
        y = pheno_values.reshape(-1) if pheno_values.shape[1] == 1 else pheno_values
        return X, y

    def get_simulated_data(self, sample_slice: slice, split:Literal["train", "validation", "train_tune", "validation_tune", "test"], snp_pair: List):
        """Return `(X, y)` where y maps each simulated query for `snp_pair` to its phenotype.

        Phenotypes come from the module-level `simulation_I83_queries_pheno_dict`, restricted
        to the sample ids of the selected rows. The pair may be given in either order.
        """
        _, sample_id_subset, X = self._load_split_genotypes(sample_slice, split)
        try:
            pair_queries = simulation_I83_queries_pheno_dict[(snp_pair[0], snp_pair[1])]
        except KeyError:
            pair_queries = simulation_I83_queries_pheno_dict[(snp_pair[1], snp_pair[0])]
        y = {query: simulated.pheno_col.loc[sample_id_subset] for query, simulated in pair_queries.items()}

        return X, y

### How to use the datasource

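The first step is to create the datasource, which will load
1. the memory-mapped genotype files (takes little memory)
2. the phenotypes for all samples, split into stratified train/validation/test/tune subsets; the training split is upsampled towards balanced case/control with `max_samples` rows (pass `max_samples=None` to skip upsampling)
3. the list of GWAS SNPs that pass each `snp_filter`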

In [None]:
# Datasource for trait I83 on autosomes 1-22: keep SNPs with pval < 1e-6, balance/stratify on I83,
# and pass max_samples=100_000 to the training-split upsampling.
test_datasource = chromosome_datasource(chromosomes = list(range(1, 23)), snp_filters= [snp_filter('I83', snp_filter.SORT_PVALUE, threshold= 1e-6)], max_samples=100_000, balance_pheno="I83")

Loading genotype file(s):   0%|          | 0/22 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]



In [None]:
# Genotypes of rows 1-2 of the train split, plus each simulated phenotype for this SNP pair on those samples.
test_datasource.get_simulated_data(slice(1, 3), split = "train", snp_pair = ["6:134911816_G_A", "6:26118570_T_C"])

Loading genotype data from split='train':   0%|          | 0/22 [00:00<?, ?it/s]



reading -- time=0:00:00.16, thread 1 of 48, part 4 of 4




reading -- time=0:00:00.37, thread 1 of 48, part 8 of 8




reading -- time=0:00:00.05, thread 1 of 48, part 2 of 2




reading -- time=0:00:00.05, thread 1 of 48, part 2 of 2




reading -- time=0:00:00.20, thread 1 of 48, part 6 of 6




reading -- time=0:00:00.49, thread 1 of 48, part 10 of 10




reading -- time=0:00:00.08, thread 1 of 48, part 3 of 3




reading -- time=0:00:00.00, thread 1 of 41, part 1 of 1




reading -- time=0:00:00.45, thread 1 of 48, part 7 of 7




reading -- time=0:00:00.00, thread 1 of 26, part 1 of 1




reading -- time=0:00:00.00, thread 1 of 47, part 1 of 1




reading -- time=0:00:00.12, thread 1 of 48, part 2 of 2




reading -- time=0:00:00.00, thread 1 of 33, part 1 of 1




reading -- time=0:00:00.00, thread 1 of 19, part 1 of 1




reading -- time=0:00:00.00, thread 1 of 24, part 1 of 1




reading -- time=0:00:00.44, thread 1 of 48, part 7 of 7




reading -- time=0:00:00.08, thread 1 of 48, part 3 of 3




reading -- time=0:00:00.00, thread 1 of 12, part 1 of 1




reading -- time=0:00:00.02, thread 1 of 48, part 2 of 2




reading -- time=0:00:00.04, thread 1 of 48, part 2 of 2




reading -- time=0:00:00.00, thread 1 of 11, part 1 of 1




reading -- time=0:00:00.00, thread 1 of 9, part 1 of 1


(array([[[1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         ...,
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.]],
 
        [[1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         ...,
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.]]]),
 {' (not (`6:134911816_G_A` == 0.0)) and (not (`6:26118570_T_C` == 2.0))': 5776670    1
  4662161    1
  dtype: int64,
  ' (not (`6:134911816_G_A` == 0.0)) and (not (`6:26118570_T_C` == 0.0))': 5776670    0
  4662161    0
  dtype: int64,
  ' (not (`6:134911816_G_A` == 0.0)) and (not (`6:26118570_T_C` == 1.0))': 5776670    1
  4662161    1
  dtype: int64,
  ' (not (`6:134911816_G_A` == 0.0)) and ( (`6:26118570_T_C` == 2.0))': 5776670    0
  4662161    0
  dtype: int64,
  ' (not (`6:134911816_G_A` == 0.0)) and ( (`6:26118570_T_C` == 0.0))': 5776670    1
  4662161    1
  dtype: int64,
  ' (not (`6:134911816_G_A` == 0.0)) and ( (`6:261185

In [None]:
# GWAS summary statistics of every SNP that passed the snp_filters, sorted by chr and position.
test_datasource.snp_full_info

Unnamed: 0,position_rank,variant,minor_allele,minor_AF,expected_case_minor_AC,low_confidence_variant,n_complete_samples,AC,ytx,beta,se,tstat,pval,chr,position,major_allele,full_id,trait,sort_cols,threshold
53977,53977,6:9977741:G:A,A,2.616590e-06,0.066341,True,361194,1.890200,1.031370,0.950167,0.182670,5.20154,1.977490e-07,6,9977741,G,6:9977741_G_A,I84,pval,0.000001
142540,142540,6:26234880:C:T,T,1.541730e-06,0.039089,True,361194,1.113730,1.015690,0.981198,0.183849,5.33697,9.457020e-08,6,26234880,C,6:26234880_C_T,I84,pval,0.000001
147709,147709,6:27277595:C:T,T,4.342890e-08,0.001101,True,361194,0.031372,0.031372,30.830200,5.863970,5.25756,1.460590e-07,6,27277595,C,6:27277595_C_T,I84,pval,0.000001
155910,155910,6:28891396:C:T,T,1.400580e-06,0.035510,True,361194,1.011760,1.000000,0.970334,0.183962,5.27463,1.330990e-07,6,28891396,C,6:28891396_C_T,I84,pval,0.000001
176844,176844,6:30885567:C:T,T,1.509160e-06,0.038263,True,361194,1.090200,1.003920,0.967270,0.183858,5.26097,1.433750e-07,6,30885567,C,6:30885567_C_T,I84,pval,0.000001
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1191044,1191044,16:74685860:C:G,G,1.411440e-06,0.035786,True,361194,1.019610,1.000000,0.961236,0.183960,5.22524,1.740290e-07,16,74685860,C,16:74685860_C_G,I84,pval,0.000001
1193638,1193638,16:75147947:G:A,A,4.885760e-08,0.001239,True,361194,0.035294,0.011765,76.703800,15.637600,4.90508,9.343140e-07,16,75147947,G,16:75147947_G_A,I84,pval,0.000001
1272924,1272924,16:85133830:A:T,T,3.246310e-06,0.082307,True,361194,2.345100,1.078430,0.934747,0.178569,5.23465,1.653860e-07,16,85133830,A,16:85133830_A_T,I84,pval,0.000001
1297111,1297111,16:88501594:G:A,A,1.031440e-05,0.261511,True,361194,7.450980,1.752940,0.730517,0.128540,5.68317,1.323210e-08,16,88501594,G,16:88501594_G_A,I84,pval,0.000001


In [None]:
# Held-out splits must share no sample with the (possibly upsampled) training split or with each other.
train_ids = set(test_datasource.data_dict["train"].index)
test_ids = set(test_datasource.data_dict["test"].index)
validation_ids = set(test_datasource.data_dict["validation"].index)
tune_ids = set(test_datasource.data_dict["tune"].index)
assert train_ids.isdisjoint(test_ids)
assert train_ids.isdisjoint(validation_ids)
assert train_ids.isdisjoint(tune_ids)
assert test_ids.isdisjoint(validation_ids)

In [None]:
# `test_datasource_not_upsampled` was never defined in this notebook, so this cell failed on a fresh kernel.
# max_samples=None makes split_data skip upsampling, so train_tune keeps its original case/control ratio.
test_datasource_not_upsampled = chromosome_datasource(chromosomes = list(range(1, 23)), snp_filters= [snp_filter('I83', snp_filter.SORT_PVALUE, threshold= 1e-6)], max_samples=None, balance_pheno="I83")
test_datasource_not_upsampled.get_data(sample_slice = slice(0, None), split = "train_tune")[1]

Loading genotype data from split='train_tune':   0%|          | 0/1 [00:00<?, ?it/s]



reading -- time=0:00:00.00, thread 1 of 10, part 1 of 1


array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
       0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

Only the training split (`data_dict["train"]`) is upsampled with replacement, so that is where duplicated samples appear. `pheno_df_ordered` is built with `max_samples=None` and is not upsampled.

In [None]:
# pheno_df_ordered is built with max_samples=None; duplicates can only come from the upsampled training split.
test_datasource.data_dict["train"].index.nunique()  # how many unique samples among the upsampled training rows?

59487

In [None]:
# Case/control counts per phenotype in the (upsampled) training split, where balancing is applied.
test_datasource.data_dict["train"].apply(pd.Series.value_counts)

Unnamed: 0,I84,R07
0,49939,90303
1,50061,9697


The `datasource` object stores, in `snp_full_info`, the GWAS SNPs on the requested chromosomes that pass each `snp_filter` (here: a p-value below the filter's threshold for its trait).

In [None]:
# Compact view: SNP id, location, the trait it was selected for and its GWAS p-value.
test_datasource.snp_full_info[["full_id", "chr", "position", "trait", "pval"]]

Unnamed: 0,full_id,chr,position,trait,pval
1248,1:981475_G_A,1,981475,I84,4.775580e-06
1346,1:990297_A_G,1,990297,I84,3.585510e-09
1410,1:999867_T_C,1,999867,I84,1.639050e-06
1419,1:1000861_T_C,1,1000861,I84,1.597670e-06
1457,1:1007834_C_T,1,1007834,I84,1.909400e-06
...,...,...,...,...,...
13328645,22:45731077_C_T,22,45731077,I84,2.764850e-07
13328950,22:45792291_G_C,22,45792291,I84,3.162620e-07
13334324,22:46792651_G_A,22,46792651,I84,2.084790e-07
13336473,22:47087541_G_T,22,47087541,I84,1.143940e-07


We now get the genotype data matrix by reading the genotype files. The `get_data` function takes a `slice` object and a split name. It returns `(X, y)`: the genotypes at all SNPs in `snp_full_info` for the rows of that split selected by the slice, together with their phenotype values. Here we load the genotypes of the first 1000 samples (rows 0–999) of `test_datasource.data_dict["train"]`; the run shown below had 80 SNPs.

In [None]:
# get_data needs a split name and returns (X, y); X holds the genotypes.
X, y = test_datasource.get_data(slice(0, 1000), split="train")
X.shape

reading -- time=0:00:00.00, thread 1 of 11, part 1 of 1
reading -- time=0:00:00.03, thread 1 of 48, part 2 of 2


(1000, 80, 4)