## Run KL analysis for intergenic region, footprinting region, and DHS sites

In [1]:
import sys
import pandas as pd
import torch
import pyro
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

sys.path.insert(0, '/home/djl34/lab_pd/kl/git/KL')

import raklette_updated

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
####################################################
# Create Dataset
####################################################
class TSVDataset(Dataset):
    """Map-style dataset that serves a tab-separated file in fixed-size row chunks.

    Item ``i`` holds data rows ``[i * chunksize, (i + 1) * chunksize)`` of the
    file (the header line is not counted), restricted to ``features`` and
    returned as one 2-D tensor. ``len()`` counts complete chunks only, so a
    trailing partial chunk is never served.

    Args:
        path: Path of the TSV file; its first line is a header row.
        chunksize: Number of data rows per item.
        nb_samples: Total number of data rows in the file.
        header_all: Names of all columns in the file, in file order.
        features: Columns to keep in each item; must include ``"e_module"``.
    """

    def __init__(self, path, chunksize, nb_samples, header_all, features):
        self.path = path
        self.chunksize = chunksize
        # Floor division: only complete chunks are addressable.
        self.len = nb_samples // self.chunksize
        self.header = header_all
        self.features = features

    def __getitem__(self, index):
        """Return chunk ``index`` as a tensor of shape (rows, len(features)).

        The ``e_module`` column is shifted up by one before conversion.

        Raises:
            IndexError: If ``index`` is outside ``range(len(self))``.
        """
        if not 0 <= index < self.len:
            raise IndexError(
                f"chunk index {index} out of range for {self.len} chunks"
            )

        # nrows (instead of chunksize + next()) returns the frame directly and
        # lets pandas close the file handle before returning.
        chunk = pd.read_csv(
            self.path,
            sep="\t",
            skiprows=index * self.chunksize + 1,  # +1 skips the header line
            nrows=self.chunksize,
            names=self.header,
        )

        # Copy so the column update below acts on an independent frame rather
        # than on a selection of the frame that was just read.
        chunk = chunk[self.features].copy()
        chunk["e_module"] = chunk["e_module"] + 1
        return torch.from_numpy(chunk.to_numpy())

    def __len__(self):
        """Return the number of complete chunks in the file."""
        return self.len

## Load file

In [5]:
# Base directories for the KL reference data and the scratch inputs.
KL_data_dir = "/home/djl34/lab_pd/kl/data/"
scratch_dir = "/n/scratch3/users/d/djl34/"

# Neutral SFS table: the five "<i>_bin" columns become one tensor and the
# "mu" column another.
sfs = pd.read_csv(KL_data_dir + "neutral_SFS_5bins.tsv", sep="\t")
bin_columns = [str(i) + "_bin" for i in range(5)]
neutral_sfs = torch.tensor(sfs[bin_columns].values)
mu_ref = torch.tensor(sfs["mu"].values)

# Read the input table once in full to learn its row count and column names;
# the dataset below then re-reads the same file chunk by chunk.
input_path = scratch_dir + "kl_input/enhancer_module/22.tsv"
df = pd.read_csv(input_path, sep="\t")
nb_samples, header = len(df), df.columns

dataset = TSVDataset(
    input_path,
    chunksize=100000,
    nb_samples=nb_samples,
    header_all=header,
    features=header,
)

# Each loader batch stacks 10 chunks, served in file order.
loader = DataLoader(dataset, batch_size=10, num_workers=1, shuffle=False)


## Run raklette

In [None]:
# Number of distinct e_module values plus one extra slot.
# NOTE(review): presumably the extra slot leaves room for the +1 shift that
# TSVDataset applies to e_module — confirm against raklette_updated.
n_genes = len(df["e_module"].unique()) + 1
n_covs = 0

# One less than the number of SFS bin columns. Reading the column count from
# the shape avoids indexing a specific row of the table.
n_bins = neutral_sfs.shape[1] - 1

# The full table was only needed for the counts above; free it before training.
del df

# Define model and guide.
KL = raklette_updated.raklette(neutral_sfs, n_bins, mu_ref, n_covs, n_genes)
model = KL.model
guide = pyro.infer.autoguide.AutoNormal(model)

# Start from an empty parameter store, then set up SVI with a fixed-rate Adam.
pyro.clear_param_store()
adam = pyro.optim.Adam({"lr": 0.005})
elbo = pyro.infer.Trace_ELBO(num_particles=1, vectorize_particles=True)
svi = pyro.infer.SVI(model, guide, adam, elbo)

# Every per-batch loss across all epochs, in the order the steps were taken.
losses = []

num_epochs = 1

print("lets start running")

for epoch in range(num_epochs):
    epoch_losses = []

    # Take a gradient step for each mini-batch in the dataset.
    for batch_idx, data in enumerate(loader):
        print(batch_idx)

        # `data` is indexed as (batch, row, column); each column is flattened
        # to one long vector. Columns 2 and 0 are cast to integer tensors.
        gene_ids = data[:, :, 2].reshape(-1).type(torch.LongTensor)
        mu_vals = data[:, :, 0].reshape(-1).type(torch.LongTensor)
        # Column 1 is passed through unchanged as the last svi.step argument.
        col1_vals = data[:, :, 1].reshape(-1)

        loss = svi.step(mu_vals, gene_ids, None, col1_vals)
        epoch_losses.append(loss)
        losses.append(loss)
        print(loss)

    # No learning-rate scheduler is configured (plain Adam with a fixed lr),
    # so there is nothing to step at the end of an epoch.

    # Mean over this epoch's batches only; NaN if the loader yielded nothing.
    if epoch_losses:
        mean_loss = sum(epoch_losses) / len(epoch_losses)
    else:
        mean_loss = float("nan")
    print("[Epoch %02d]  Loss: %.5f" % (epoch, mean_loss))

print("Finished training!")

lets start running
0
3774122.0995813925
1
3688026.0496023977
2
3795006.187812071
3
3774722.4471480334
4
3721177.7994393357
5
3653188.4336485583
6
3717426.325801814


In [7]:
# Display the per-batch losses recorded by the training loop.
losses

[3774122.0995813925,
 3688026.0496023977,
 3795006.187812071,
 3774722.4471480334,
 3721177.7994393357,
 3653188.4336485583,
 3717426.325801814,
 3627309.021757453,
 3723481.503836682,
 3643854.6385830026,
 3720541.5758641926,
 3783693.8939951444,
 3772715.5778691536,
 3734375.7529152613,
 3615489.5409375024,
 3740072.3755401047,
 3728837.6883581844,
 3757568.5838964074,
 3677971.3174640844,
 3649052.3038019505,
 3748807.7162428307,
 3780207.8677345477,
 3768676.1648491253,
 3769426.6493610013,
 3780196.3918031743,
 3639361.445603936,
 3695648.295733693,
 3749811.3858994436,
 3809752.218985416,
 3698004.475128364,
 3762648.734965913,
 3774777.41416014,
 3754826.0175651945,
 3608521.195523982,
 3728258.7287544124,
 3700756.6786036496,
 3583568.810217082,
 3654187.4872947107,
 3728170.5515762344,
 3681315.8185349638,
 3670170.1582644708,
 3645550.3723968444,
 3586410.191857693,
 3526315.8630386377,
 3703319.347211928,
 3732739.8720382797,
 3732698.751738348,
 3655957.0159287658,
 3637585