In [1]:
import os
import sys
import glob
import numpy as np
import pandas as pd
import math
import sys
import random
import pickle
import csv

import dask.dataframe as dd
from dask.distributed import Client

import torch
import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
from pyro.nn import PyroModule
from pyro.infer import Predictive

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

sys.path.insert(0, '/home/djl34/lab_pd/kl/git/KL/raklette')
import raklette_updated
from run_raklette import run_raklette
from run_raklette import TSVDataset

##############################################################################################################

KL_data_dir = "/home/djl34/lab_pd/kl/data"
scratch_dir = "/n/scratch3/users/d/djl34"

base_set = ["A", "C", "T", "G"]
chrom_set = [str(x) for x in range(1, 23)]
# chrom_set = ["22"]

  from .autonotebook import tqdm as notebook_tqdm


## make smaller sample file

In [52]:
header = "window_1k/1_split_500_0"

variants = os.path.join(scratch_dir, "kl_input/{header}.tsv")
variants = variants.replace("{header}", header)
input_filename = variants

In [53]:
df = pd.read_csv(variants, sep = "\t")

In [68]:
df_window = df[df["window_1k"] == 570]

In [70]:
df_window["gene"] = 0

In [71]:
df_window.to_csv("1_split_500_window1k_570.tsv", sep = "\t", index = None)

## load files

In [95]:
variants = "1_split_500_window1k_570.tsv"
input_filename = variants

# length_file = os.path.join(scratch_dir, "kl_input/{header}_length.txt")
# length_file = length_file.replace("{header}", header)
neutral_sfs = KL_data_dir + "/whole_genome/neutral/5_bins/all.tsv"
n_covs = 0

df = pd.read_csv(variants, sep = "\t")
nb_samples = len(df)
n_genes = 1

print("number of samples: " + str(nb_samples))
print("number of genes: " + str(n_genes))

   
with open(variants) as f:
    first_line = f.readline()
header = first_line.split("\t")
header = [x.strip() for x in header]

chunksize = 1000000

print("number of chunks " + str(nb_samples/chunksize))

dataset = TSVDataset(input_filename, chunksize=chunksize, nb_samples = nb_samples, header_all = header, features = header)
loader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False)

num_epochs = 2000

number of samples: 2955
number of genes: 1
number of chunks 0.002955


In [96]:
header

['mu_index', 'Freq_bin', 'gene', 'window_1k']

In [97]:
#now with checkpoints
neutral_sfs_filename = neutral_sfs
lr = 0.01
gamma = 0.5
cov_sigma_prior = torch.tensor(0.1, dtype=torch.float32)
fit_prior = True
gene_col = 2
mu_col = 0
bin_col = 1

print("running raklette", flush=True)

# read neutral sfs
sfs = pd.read_csv(neutral_sfs_filename, sep = "\t")
bin_columns = []
for i in range(5):
    bin_columns.append(str(float(i)))
neutral_sfs = torch.tensor(sfs[bin_columns].values)
mu_ref = torch.tensor(sfs["mu"].values)
n_bins = len(neutral_sfs[1]) - 1
print("number of bins: " + str(n_bins), flush = True)

running raklette
number of bins: 4


In [98]:
KL = raklette_updated.raklette(neutral_sfs, n_bins, mu_ref, n_covs, n_genes, cov_sigma_prior = cov_sigma_prior, fit_prior = fit_prior)

if fit_prior == False:
    print("fitting with predefined prior for genes", flush = True)

model = KL.model
guide = pyro.infer.autoguide.AutoNormal(model)

#run inference
pyro.clear_param_store()

num_steps = num_epochs * len(loader)
lrd = gamma ** (1 / num_steps)

# run SVI
optimizer = pyro.optim.ClippedAdam({"lr":lr, 'lrd': lrd})
#     optimizer = pyro.optim.Adam({"lr":lr})
elbo = pyro.infer.Trace_ELBO(num_particles=1, vectorize_particles=True)
svi = pyro.infer.SVI(model, guide, optimizer, elbo)

In [99]:
losses = []

for epoch in range(num_epochs):
    print("epoch: " + str(epoch), flush = True)

    # Take a gradient step for each mini-batch in the dataset
    for batch_idx, data in enumerate(loader):
        gene_ids = data[:,:,gene_col].reshape(-1)
        gene_ids = gene_ids.type(torch.LongTensor)

        mu_vals = data[:,:,mu_col].reshape(-1)
        mu_vals = mu_vals.type(torch.LongTensor)

        freq_bins = data[:,:,bin_col].reshape(-1)

        if n_covs == 0:
            loss = svi.step(mu_vals, gene_ids, None, freq_bins)
        else:
            # covariate_vals = data[:,:,3:].reshape((data.shape[0]*data.shape[1],) + data.shape[3:])
            covariate_vals = data[:,:,3:].reshape(-1).unsqueeze(0)
            covariate_vals = torch.transpose(covariate_vals, 0, 1)
            covariate_vals = covariate_vals.type(torch.LongTensor)

            loss = svi.step(mu_vals, gene_ids, covariate_vals, freq_bins)

        losses.append(loss/data.shape[1])

        if batch_idx % 10 == 0:
#             print(batch_idx, flush=True)
            print(loss/data.shape[1], flush=True)

epoch: 0
3.2555118842586515
epoch: 1
3.2123412729391707
epoch: 2
3.2286714337279285
epoch: 3
3.2329276557698337
epoch: 4
3.1592467811371465
epoch: 5
3.1506986625372524
epoch: 6
3.1478089764455968
epoch: 7
3.168924670712242
epoch: 8
3.202349952921474
epoch: 9
3.065362420946221
epoch: 10
3.122384371169541
epoch: 11
3.173739982593516
epoch: 12
3.226503353336815
epoch: 13
3.1125713697611355
epoch: 14
3.1011351076216145
epoch: 15
2.987710798059239
epoch: 16
3.187373786628495
epoch: 17
3.1619593808280446
epoch: 18
3.193905216524531
epoch: 19
3.155765998075853
epoch: 20
3.0326476399870623
epoch: 21
3.0873199523971926
epoch: 22
3.251467690946849
epoch: 23
2.914809922512603
epoch: 24
3.034536113551521
epoch: 25
3.0577061223895026
epoch: 26
2.955598045460873
epoch: 27
3.0289994606235835
epoch: 28
3.0374413176524646
epoch: 29
3.004231208830502
epoch: 30
2.959547062998199
epoch: 31
3.043530720585352
epoch: 32
3.000355542989675
epoch: 33
3.184201356478114
epoch: 34
3.029953600047348
epoch: 35
2.993

1.2025929471401426
epoch: 281
1.033315289437062
epoch: 282
1.1317129692344337
epoch: 283
0.9636887525618367
epoch: 284
0.9777744792776345
epoch: 285
0.9531208546852162
epoch: 286
0.9824733970760166
epoch: 287
0.9511052122187406
epoch: 288
0.8522220250884829
epoch: 289
0.8764508293293586
epoch: 290
1.0229946840055577
epoch: 291
0.949909550899155
epoch: 292
1.020280902683496
epoch: 293
0.9279309862596677
epoch: 294
0.9184336345574637
epoch: 295
1.0235062512091635
epoch: 296
0.9372543600378263
epoch: 297
0.7851319342044475
epoch: 298
0.8510032304255505
epoch: 299
0.8947038608544955
epoch: 300
0.9424133959028872
epoch: 301
0.9100368705181159
epoch: 302
0.8567070001075513
epoch: 303
0.8518226364667698
epoch: 304
0.9055853172439378
epoch: 305
0.9374890425799707
epoch: 306
0.852906101132901
epoch: 307
1.0419397936392407
epoch: 308
0.9371621866372657
epoch: 309
0.7902213429343737
epoch: 310
0.8957358632853625
epoch: 311
0.853496675404122
epoch: 312
0.9010907072409234
epoch: 313
0.8855875304116

epoch: 553
0.44310900199539077
epoch: 554
0.46263129738365466
epoch: 555
0.44290359923104833
epoch: 556
0.43700051783239474
epoch: 557
0.43589531397163894
epoch: 558
0.4384613678291758
epoch: 559
0.4348149587465535
epoch: 560
0.43826379554235734
epoch: 561
0.4364652645844776
epoch: 562
0.4427743228346603
epoch: 563
0.4448961819424345
epoch: 564
0.438452152176851
epoch: 565
0.4462881764916711
epoch: 566
0.4519315465360307
epoch: 567
0.4341459049727486
epoch: 568
0.4362685002708762
epoch: 569
0.44768601911693745
epoch: 570
0.4474814720328936
epoch: 571
0.4361478453229506
epoch: 572
0.432586212611414
epoch: 573
0.4469814381880779
epoch: 574
0.4404316746892137
epoch: 575
0.437228576958665
epoch: 576
0.43813289215446344
epoch: 577
0.4324696918653923
epoch: 578
0.4468732479013767
epoch: 579
0.43271557218955176
epoch: 580
0.43897918024677013
epoch: 581
0.4345070673898424
epoch: 582
0.4348878957845132
epoch: 583
0.4354641470480778
epoch: 584
0.4408789345063706
epoch: 585
0.44353816150092823
ep

0.4188694508894243
epoch: 824
0.41876999922869285
epoch: 825
0.4188785835588193
epoch: 826
0.41918812784421017
epoch: 827
0.418833205184017
epoch: 828
0.41920417148831246
epoch: 829
0.42033922861195017
epoch: 830
0.4190329334266733
epoch: 831
0.418316913750078
epoch: 832
0.41960251503384216
epoch: 833
0.4193020804396341
epoch: 834
0.41914749890068803
epoch: 835
0.4189805268407969
epoch: 836
0.4204477184637526
epoch: 837
0.4184297441697093
epoch: 838
0.41874257043991514
epoch: 839
0.42018225999823733
epoch: 840
0.41867564293979415
epoch: 841
0.42060512567881386
epoch: 842
0.41925046070858335
epoch: 843
0.4184041493123092
epoch: 844
0.4190591154913298
epoch: 845
0.4191538897133323
epoch: 846
0.4190761168846719
epoch: 847
0.4190302611524246
epoch: 848
0.4187073216337934
epoch: 849
0.41871800896513334
epoch: 850
0.41844144505231606
epoch: 851
0.41850195647286736
epoch: 852
0.41872624642330203
epoch: 853
0.41907109417524063
epoch: 854
0.4190396813600124
epoch: 855
0.4187419946859304
epoch: 

epoch: 1090
0.4170524998104803
epoch: 1091
0.4172107396629482
epoch: 1092
0.4172122549993595
epoch: 1093
0.4169917013235174
epoch: 1094
0.4169493293180435
epoch: 1095
0.4168918228977398
epoch: 1096
0.4168824487038088
epoch: 1097
0.41690711428518856
epoch: 1098
0.41729616247699824
epoch: 1099
0.41670460630840084
epoch: 1100
0.417936942731705
epoch: 1101
0.41719042905142795
epoch: 1102
0.4173684711639586
epoch: 1103
0.41667142837579857
epoch: 1104
0.4180829981459318
epoch: 1105
0.4168448110382387
epoch: 1106
0.4177084980890186
epoch: 1107
0.4177993725037598
epoch: 1108
0.417923266417722
epoch: 1109
0.4177826520756218
epoch: 1110
0.41719432678631907
epoch: 1111
0.4175818243344909
epoch: 1112
0.41735896492006486
epoch: 1113
0.41721988830578377
epoch: 1114
0.41722337343863874
epoch: 1115
0.41770810774922484
epoch: 1116
0.41780995317648784
epoch: 1117
0.41708538270490386
epoch: 1118
0.41759329353902147
epoch: 1119
0.4167472692511583
epoch: 1120
0.4174195937024854
epoch: 1121
0.41781573018574

epoch: 1351
0.41688810586670777
epoch: 1352
0.4168866316530003
epoch: 1353
0.4167002150603465
epoch: 1354
0.4167466671802185
epoch: 1355
0.4170845398615676
epoch: 1356
0.4167106177550551
epoch: 1357
0.4168765887013097
epoch: 1358
0.41671971931418306
epoch: 1359
0.41615198711680645
epoch: 1360
0.4165630321529427
epoch: 1361
0.41641114178959804
epoch: 1362
0.4165635709366191
epoch: 1363
0.4166745294025066
epoch: 1364
0.4167656667650877
epoch: 1365
0.41661991054823855
epoch: 1366
0.4174327478607585
epoch: 1367
0.41679338927996545
epoch: 1368
0.41723242928566073
epoch: 1369
0.4162380382449906
epoch: 1370
0.41695838929444634
epoch: 1371
0.4172777054984819
epoch: 1372
0.41663960821894935
epoch: 1373
0.41668155456022904
epoch: 1374
0.41755279315461263
epoch: 1375
0.41673893694090897
epoch: 1376
0.41664329151728424
epoch: 1377
0.4163132541063525
epoch: 1378
0.4163752441587081
epoch: 1379
0.4165607667925174
epoch: 1380
0.4165694006115508
epoch: 1381
0.41668168960129953
epoch: 1382
0.41655395643

0.41678991500464346
epoch: 1613
0.4168803907245204
epoch: 1614
0.41607318780709734
epoch: 1615
0.4166238717503024
epoch: 1616
0.4171014742485948
epoch: 1617
0.4167664608195298
epoch: 1618
0.41630177535139185
epoch: 1619
0.4167031995021986
epoch: 1620
0.4162197889207503
epoch: 1621
0.4162258499356151
epoch: 1622
0.4167349118904932
epoch: 1623
0.41640074338576977
epoch: 1624
0.41661322901780345
epoch: 1625
0.4170446832842646
epoch: 1626
0.4165986001769834
epoch: 1627
0.4164413621588593
epoch: 1628
0.4165126344206656
epoch: 1629
0.4166006652209548
epoch: 1630
0.4165002615237305
epoch: 1631
0.4164044896839971
epoch: 1632
0.4167210252662881
epoch: 1633
0.41648530494352837
epoch: 1634
0.41690371960032685
epoch: 1635
0.4167019585248469
epoch: 1636
0.41691424066111205
epoch: 1637
0.41676226744438993
epoch: 1638
0.416572028111801
epoch: 1639
0.4161310664849351
epoch: 1640
0.4165135968027003
epoch: 1641
0.4162752285829917
epoch: 1642
0.41665001475018704
epoch: 1643
0.4165194010019352
epoch: 1644

0.4164137093932631
epoch: 1874
0.41650818956524543
epoch: 1875
0.4165553406487928
epoch: 1876
0.4168971476046808
epoch: 1877
0.416773222749553
epoch: 1878
0.4165080641320768
epoch: 1879
0.4164260772467834
epoch: 1880
0.4164306731524676
epoch: 1881
0.416480576637115
epoch: 1882
0.4164738101821903
epoch: 1883
0.41677827259237393
epoch: 1884
0.4163302762540213
epoch: 1885
0.41652635793220655
epoch: 1886
0.41633126486383615
epoch: 1887
0.4165378917572271
epoch: 1888
0.4158814152179147
epoch: 1889
0.4162907897007652
epoch: 1890
0.41638046384521094
epoch: 1891
0.416369754541697
epoch: 1892
0.4161787904509775
epoch: 1893
0.41642854768618287
epoch: 1894
0.41634922393779944
epoch: 1895
0.4164554987753867
epoch: 1896
0.41635295768362873
epoch: 1897
0.4164005509996391
epoch: 1898
0.41671965974578334
epoch: 1899
0.41629207944178287
epoch: 1900
0.4165663789031788
epoch: 1901
0.4165385523856917
epoch: 1902
0.4161271574687405
epoch: 1903
0.41635263996931154
epoch: 1904
0.4166618094620508
epoch: 1905


## save model

In [179]:
svi.params

AttributeError: 'SVI' object has no attribute 'params'

In [176]:
import pickle

In [180]:
output_dict = {}
output_dict['model']=model
output_dict['guide']=guide
# output_dict['params']=svi_result.params

In [181]:
with open('1_split_500_window1k_570.model', 'wb') as handle:
    pickle.dump(output_dict, handle)

In [175]:
pyro.get_param_store().save("1_split_500_window1k_570.params")

## get predictions

In [5]:
# file_dir = "/home/djl34/lab_pd/kl/data/raklette_output/gene_and_covariates/window_1k/"
# model_filename = file_dir + "1_split_500_0_lr_0.01_gamma_0.1_epoch_3000.save"

# pyro.get_param_store().load(model_filename)

In [161]:
df_neutral = pd.read_csv("/home/djl34/lab_pd/kl/data/whole_genome/neutral/5_bins/all.tsv", sep = "\t")

In [101]:
predictive = Predictive(model, guide=guide, num_samples = 800)

In [114]:
df.groupby("mu_index").size()

mu_index
0      32
1     183
2     304
3     257
4     199
5     223
6     220
7     165
8     145
9      96
10     62
11     86
12     77
13     80
14     69
15     55
16     64
17     86
18     75
19     65
20     38
21     40
22     25
23     22
24     25
25     15
26     31
27     15
28     23
29     15
30     17
31      5
32    101
33     29
34      7
35      4
dtype: int64

In [162]:
mu_bin = 2

df_mu = df[df["mu_index"] == mu_bin]

In [163]:
x = df_mu[header]
x = x.astype(float)
x = torch.from_numpy(x.values)

In [164]:
gene_ids = x[:,gene_col].reshape(-1)
gene_ids = gene_ids.type(torch.LongTensor)

mu_vals = x[:,mu_col].reshape(-1)
mu_vals = mu_vals.type(torch.LongTensor)

freq_bins = x[:,bin_col].reshape(-1)

In [165]:
samples = predictive(mu_vals, gene_ids, None)

In [166]:
samples.keys()

dict_keys(['beta_sel', 'obs'])

In [167]:
print(samples["beta_sel"].shape)
print(samples["obs"].shape)

torch.Size([800, 1, 4])
torch.Size([800, 304])


In [168]:
samples["beta_sel"]

tensor([[[-4.6882,  2.3927, -0.3579,  0.6352]],

        [[-4.7175,  2.3018, -0.5015,  0.5884]],

        [[-4.8170,  2.3385, -0.6223,  0.4911]],

        ...,

        [[-4.6873,  2.4799, -0.4353,  0.8995]],

        [[-4.6930,  2.4925, -0.5846,  0.7904]],

        [[-4.6282,  2.4631, -0.2165,  0.6480]]])

In [169]:
samples["beta_sel"].mean(axis = 1).mean(axis = 0)

tensor([-4.6893,  2.4687, -0.3954,  0.5626])

## plot posterior predictive distributions

In [170]:
df_mu_sfs = pd.DataFrame((df_mu.groupby("Freq_bin").size())/len(df_mu)).reset_index()

for i in range(5):
    if len(df_mu_sfs[df_mu_sfs["Freq_bin"] == i]) == 0:
        df_mu_sfs.loc[i] = [i, 0.0]
    
df_mu_sfs.rename({0: "observed"}, axis = 1, inplace = True)

df_mu_sfs["neutral"] = pd.Series(df_neutral.iloc[mu_bin, 1:6]).reset_index(drop = True)


In [171]:
df_mu_sfs

Unnamed: 0,Freq_bin,observed,neutral
0,0.0,0.921053,0.971404
1,1.0,0.069079,0.024632
2,2.0,0.003289,0.002315
3,3.0,0.006579,0.001247
4,4.0,0.0,0.000401


In [172]:
df_post_prediction = pd.DataFrame(samples["obs"].reshape(-1))
# df_post_prediction.rename({0: "posterior predictions"}, axis = 1, inplace = True)

In [173]:
df_mu_sfs["post pred"] = list(pd.Series(df_post_prediction.groupby(0).size()/len(df_post_prediction)))

In [174]:
df_mu_sfs

Unnamed: 0,Freq_bin,observed,neutral,post pred
0,0.0,0.921053,0.971404,0.880498
1,1.0,0.069079,0.024632,0.106316
2,2.0,0.003289,0.002315,0.009762
3,3.0,0.006579,0.001247,0.003174
4,4.0,0.0,0.000401,0.000251
