## Run KL analysis 

In [1]:
import os
import sys
import glob
import numpy as np
import pandas as pd
import math
import sys
import random
import pickle

import dask.dataframe as dd
from dask.distributed import Client

import torch
import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
from pyro.nn import PyroModule

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

sys.path.insert(0, '/home/djl34/lab_pd/kl/git/KL')
import raklette_updated

sys.path.insert(0, '/home/djl34/lab_pd/kl/git/KL/kl_gene')
from raklette_gene import run_raklette


KL_data_dir = "/home/djl34/lab_pd/kl/data"
scratch_dir = "/n/scratch3/users/d/djl34"

base_set = ["A", "C", "T", "G"]
chrom_set = [str(x) for x in range(1, 23)]

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class TSVDataset(Dataset):
    def __init__(self, path, chunksize, nb_samples, header_all, features):
        self.path = path
        self.chunksize = chunksize
        self.len = nb_samples // self.chunksize
        self.header = header_all
        self.features = features

    def __getitem__(self, index):
        x = next(
            pd.read_csv(
                self.path,
                sep = "\t",
                skiprows=index * self.chunksize + 1,  #+1, since we skip the header
                chunksize=self.chunksize,
                names=self.header))

        x = x[self.features]
        x = torch.from_numpy(x.values)
        return x

    def __len__(self):
        return self.len

## load file

In [3]:
input_filename = os.path.join(scratch_dir, "kl_input/enhancer_module/22.tsv")
output_filename = os.path.join(KL_data_dir, "raklette_output/kl_gene/enhancer_module/chr22_noncoding.tsv")
neutral_sfs_filename = neutral_sfs = KL_data_dir + "/neutral_SFS_5bins.tsv"

# input_filename = os.path.join(scratch_dir, "kl_input/enhancer_module/gene.tsv")
# output_filename = os.path.join(KL_data_dir, "raklette_output/kl_gene/enhancer_module/whole_genome_noncoding.tsv")


with Client() as client:
    df = dd.read_csv(input_filename, sep = "\t")
    nb_samples = len(df)
    header = df.columns
    n_genes = len(df["gene"].unique()) + 1
    del df
    print("done reading")

chunksize = 100000
dataset = TSVDataset(input_filename, chunksize=chunksize, nb_samples = nb_samples, header_all = header, features = header)

print("number of chunks " + str(nb_samples/chunksize))

loader = DataLoader(dataset, batch_size=10, num_workers=1, shuffle=False)

n_covs = 0
num_epochs = 1

done reading
number of chunks 772.01997


In [4]:
print("running raklette")

# read neutral sfs
sfs = pd.read_csv(neutral_sfs_filename, sep = "\t")
bin_columns = []
for i in range(5):
    bin_columns.append(str(i) + "_bin")
neutral_sfs = torch.tensor(sfs[bin_columns].values)
mu_ref = torch.tensor(sfs["mu"].values)

n_bins = len(neutral_sfs[1]) - 1

#define model and guide
KL = raklette_updated.raklette(neutral_sfs, n_bins, mu_ref, n_covs, n_genes)
model = KL.model
guide = pyro.infer.autoguide.AutoNormal(model)

#run inference
pyro.clear_param_store()
# run SVI
adam = pyro.optim.Adam({"lr":0.005})
elbo = pyro.infer.Trace_ELBO(num_particles=1, vectorize_particles=True)
svi = pyro.infer.SVI(model, guide, adam, elbo)
losses = []


running raklette


In [5]:
for epoch in range(num_epochs):
    # Take a gradient step for each mini-batch in the dataset
    for batch_idx, data in enumerate(loader):
        gene_ids = data[:,:,2].reshape(-1)
        gene_ids = gene_ids.type(torch.LongTensor)

        mu_vals = data[:,:,0].reshape(-1)
        mu_vals = mu_vals.type(torch.LongTensor)

        loss = svi.step(mu_vals, gene_ids, None, data[:,:,1].reshape(-1))
        losses.append(loss)
        

        if batch_idx % 10 == 0:
            print(batch_idx)
            print(loss)
            break


print("Finished training!")

0
3770883.531110159
Finished training!


In [22]:
loss

3783954.6413213485

In [23]:
print(loss)

3783954.6413213485


In [24]:
pyro.param()

TypeError: param() missing 1 required positional argument: 'name'