In [1]:
import os
import torch

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import jax
import numpyro
import jax.numpy as jnp
import arviz as az
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torch.utils.data import random_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
df = pd.read_csv(f'{os.getcwd()}/data/mean_tpm.csv', index_col=0)

In [3]:
def normalize_group(group):
    scaler = StandardScaler()
    group['normalized_tpm'] = scaler.fit_transform(group[['mean_tpm']])
    return group

df = df.groupby('species').apply(normalize_group)
df['mean_tpm'] = df['normalized_tpm']
df = df.drop(columns=['normalized_tpm'])

  df = df.groupby('species').apply(normalize_group)


In [4]:
df

Unnamed: 0_level_0,Unnamed: 1_level_0,species,upstream200,stress_condition,mean_tpm
species,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Achromobacter xylosoxidans SOLR10,0,Achromobacter xylosoxidans SOLR10,AAAAAAGGCGGGCAGGATGAAGAGCGAACGGCCGCGTCACGGCAGT...,achx_as,-0.168032
Achromobacter xylosoxidans SOLR10,1,Achromobacter xylosoxidans SOLR10,AAAAAAGGCGGGCAGGATGAAGAGCGAACGGCCGCGTCACGGCAGT...,achx_bs,-0.168032
Achromobacter xylosoxidans SOLR10,2,Achromobacter xylosoxidans SOLR10,AAAAAAGGCGGGCAGGATGAAGAGCGAACGGCCGCGTCACGGCAGT...,achx_ctrl,-0.168032
Achromobacter xylosoxidans SOLR10,3,Achromobacter xylosoxidans SOLR10,AAAAAAGGCGGGCAGGATGAAGAGCGAACGGCCGCGTCACGGCAGT...,achx_li,-0.168032
Achromobacter xylosoxidans SOLR10,4,Achromobacter xylosoxidans SOLR10,AAAAAAGGCGGGCAGGATGAAGAGCGAACGGCCGCGTCACGGCAGT...,achx_mig,-0.168032
...,...,...,...,...,...
Vibrio cholerae O1 biovar El Tor str. N16961,1124704,Vibrio cholerae O1 biovar El Tor str. N16961,TTTTTTGACCGCTAATTAAGTGTTACTATACCTCGCTTGTCAGCCA...,vibrio_oss,-0.027653
Vibrio cholerae O1 biovar El Tor str. N16961,1124705,Vibrio cholerae O1 biovar El Tor str. N16961,TTTTTTGACCGCTAATTAAGTGTTACTATACCTCGCTTGTCAGCCA...,vibrio_oxs,-0.123126
Vibrio cholerae O1 biovar El Tor str. N16961,1124706,Vibrio cholerae O1 biovar El Tor str. N16961,TTTTTTGACCGCTAATTAAGTGTTACTATACCTCGCTTGTCAGCCA...,vibrio_sp,-0.088967
Vibrio cholerae O1 biovar El Tor str. N16961,1124707,Vibrio cholerae O1 biovar El Tor str. N16961,TTTTTTGACCGCTAATTAAGTGTTACTATACCTCGCTTGTCAGCCA...,vibrio_tm,0.006088


In [5]:
enc_species = OneHotEncoder()
encoded_species = jnp.array(enc_species.fit_transform(df[['species']]).toarray())

enc_stress = OneHotEncoder()
encoded_stress_condition = jnp.array(enc_stress.fit_transform(df[['stress_condition']]).toarray())

nucleotide_mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

# Convert sequences to numerical indices
def sequence_to_indices(sequence, max_length):
    indices = [nucleotide_mapping.get(nucleotide, -1) for nucleotide in sequence]
    indices += [-1] * (max_length - len(indices))  # Padding with -1 for shorter sequences
    return jnp.array(indices)

# Function to one-hot encode numerical indices using JAX
@jax.jit
def one_hot_encode_indices(indices, num_classes=4):
    one_hot_encoded = jax.nn.one_hot(indices, num_classes)
    one_hot_encoded = jnp.where(indices[:, None] == -1, 0, one_hot_encoded)  # Mask padding with zeros
    return one_hot_encoded

max_length = df['upstream200'].apply(len).max()

# Convert all sequences to numerical indices
sequences_indices = jnp.array([sequence_to_indices(seq, max_length) for seq in df['upstream200'].values])

# One-hot encode all sequences
sequences_encoded = jax.vmap(one_hot_encode_indices)(sequences_indices)

# Combine all features
X = jnp.hstack((sequences_encoded.reshape((sequences_encoded.shape[0], -1)), encoded_stress_condition, encoded_species))
y = jnp.array(df['mean_tpm'].values)

In [6]:
def hierarchical_model(X, y=None):
    # Hyperpriors for the hierarchical model
    mu_alpha = numpyro.sample('mu_alpha', dist.Normal(0, 10))
    sigma_alpha = numpyro.sample('sigma_alpha', dist.HalfNormal(10))
    
    mu_beta = numpyro.sample('mu_beta', dist.Normal(0, 10))
    sigma_beta = numpyro.sample('sigma_beta', dist.HalfNormal(10))
    
    # Priors for the coefficients
    alpha = numpyro.sample('alpha', dist.Normal(mu_alpha, sigma_alpha).expand([X.shape[1]]))
    beta = numpyro.sample('beta', dist.Normal(mu_beta, sigma_beta))
    
    # Linear model
    mu = jnp.dot(X, alpha) + beta
    
    # Likelihood
    sigma = numpyro.sample('sigma', dist.HalfNormal(1))
    numpyro.sample('obs', dist.LogNormal(mu, sigma), obs=y)

# Set up the MCMC
nuts_kernel = NUTS(hierarchical_model)
mcmc = MCMC(nuts_kernel, num_warmup=10, num_samples=10)
mcmc.run(jax.random.PRNGKey(0), X=X, y=y)

# Extract the samples
posterior_samples = mcmc.get_samples()

In [None]:
# Convert samples to an ArviZ InferenceData object
idata = az.from_numpyro(mcmc)

# Plot diagnostics
az.plot_trace(idata)
az.plot_posterior(idata)

# Predicting on new data
def predict(X_new, samples):
    alpha = samples['alpha']
    beta = samples['beta']
    mu_pred = jnp.dot(X_new, alpha.T) + beta
    return mu_pred.mean(axis=1)

# Example of new data
X_new = X[:5]  # Replace with your new data
predictions = predict(X_new, posterior_samples)