In [1]:
import os

import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import anndata
import sc_utils
import scvi

Global seed set to 0
  new_rank_zero_deprecation(
  return new_rank_zero_deprecation(*args, **kwargs)


In [2]:
import h5py
import hdf5plugin

In [3]:
# Open the input matrix read-only. The mode is spelled out so that the call
# can never create or modify the file, whatever the installed h5py's default is.
# The handle stays open on purpose: later cells slice datasets from it lazily.
f = h5py.File("../data/test_multi_inputs.h5", "r")

In [4]:
# Read every entry of the "axis0" dataset as Python strings.
# NOTE(review): the axis0/block0_values layout looks like a pandas fixed-format
# HDF store, in which case these would be the column labels -- confirm.
names = f["test_multi_inputs"]["axis0"].asstr()[:]

In [5]:
# Promote the label array to a Series so the `.str` accessor is available below.
names = pd.Series(data=names)

In [6]:
# Distinct text found before the first ":" of every label, in order of first
# appearance.
chromosomes = pd.Series(
    names.str.split(":").str.get(0).unique()
)

In [7]:
# Drop every prefix that does not start with "chr".
chromosomes = chromosomes.loc[chromosomes.str.startswith("chr")]

In [9]:
from typing import Iterator, Mapping, Tuple, NamedTuple, Sequence, Union

from absl import app
from absl import flags
from absl import logging
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import matplotlib.pyplot as plt
import math

In [10]:
# Hyperparameters. Only `learning_rate` and `l1_coef` are referenced by the
# cells visible in this notebook; the others are presumably used by a training
# loop that lives elsewhere -- TODO confirm.
batch_size = 128  # Size of the batch to train on.
learning_rate = 0.001  # Learning rate for the optimizer.
training_steps = 1000  # Number of training steps to run.
eval_frequency = 10  # How often to evaluate the model.
random_seed = 42  # Random seed.
l1_coef = 0.5  # Weight of the activation penalty added to the loss in loss_fn.

In [11]:
# Type aliases used in the annotations below.
PRNGKey = jnp.ndarray  # alias for a JAX array; not referenced in the visible cells
Batch = Mapping[str, np.ndarray]  # annotation used for `update`'s batch argument
HiddenSize = Union[int, tuple]  # one layer width, or a tuple of layer widths

# Default `output_shape` of Decoder / AutoEncoder when the caller passes none.
MNIST_IMAGE_SHAPE: Sequence[int] = (28, 28, 1)

In [12]:
class Encoder(hk.Module):
    """Feed-forward encoder that also reports an activation-magnitude penalty.

    Args:
        hidden_size: Width of the single hidden layer, or a tuple/list of
            widths for several stacked hidden layers (applied in order).
        latent_size: Width of the final linear projection.
    """

    def __init__(self, hidden_size: HiddenSize = 512, latent_size: int = 10):
        super().__init__()
        # Lists are accepted alongside tuples. Checking for `tuple` alone would
        # wrap a list whole and hand it to hk.Linear as a layer width.
        if isinstance(hidden_size, (list, tuple)):
            hidden_size = tuple(hidden_size)
        else:
            hidden_size = (hidden_size,)
        self._hidden_size = hidden_size
        self._latent_size = latent_size

    def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Encode `x`.

        Returns:
            A pair `(latent, activation_sum)`. `activation_sum` adds up, over
            the hidden layers, the mean absolute value of each layer's linear
            output taken before the leaky-ReLU; the final projection is not
            included.
        """
        x = hk.Flatten()(x)
        activation_sum = 0
        # Layers are created in iteration order; Haiku names parameters by
        # creation order, so this order must not change if saved parameters
        # are to be reloaded.
        for layer_size in self._hidden_size:
            x = hk.Linear(layer_size)(x)
            activation_sum += jnp.mean(jnp.abs(x))
            x = jax.nn.leaky_relu(x)

        return hk.Linear(self._latent_size)(x), activation_sum


class Decoder(hk.Module):
    """Feed-forward decoder that also reports an activation-magnitude penalty.

    Args:
        hidden_size: Width of the single hidden layer, or a tuple/list of
            widths; the widths are applied in reverse order.
        output_shape: Shape whose element count fixes the width of the final
            linear layer. The output is left flat; it is not reshaped to
            `output_shape`.
    """

    def __init__(
        self,
        hidden_size: HiddenSize = 512,
        output_shape: Sequence[int] = MNIST_IMAGE_SHAPE,
    ):
        super().__init__()
        # Lists are accepted alongside tuples. Checking for `tuple` alone would
        # wrap a list whole and hand it to hk.Linear as a layer width.
        if isinstance(hidden_size, (list, tuple)):
            hidden_size = tuple(hidden_size)
        else:
            hidden_size = (hidden_size,)
        self._hidden_size = hidden_size
        self._output_shape = output_shape

    def __call__(self, z: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Decode `z`.

        Returns:
            A pair `(output, activation_sum)`. `activation_sum` adds up, over
            the hidden layers, the mean absolute value of each layer's linear
            output taken before the leaky-ReLU; the final projection is not
            included.
        """
        activation_sum = 0
        for layer_size in reversed(self._hidden_size):
            z = hk.Linear(layer_size)(z)
            activation_sum += jnp.mean(jnp.abs(z))
            z = jax.nn.leaky_relu(z)

        # np.prod yields a NumPy scalar (a float for an empty shape); hand the
        # layer a plain Python int instead.
        output_size = int(np.prod(self._output_shape))
        return hk.Linear(output_size)(z), activation_sum


class AEOutput(NamedTuple):
    """Values returned by `AutoEncoder.__call__`."""

    # Decoder output; loss_fn compares it against the input batch.
    data: jnp.ndarray
    # Encoder output (the latent code).
    latent: jnp.ndarray
    # Encoder penalty + decoder penalty + mean absolute latent value.
    act_sum: jnp.ndarray


class AutoEncoder(hk.Module):
    """Autoencoder that chains an Encoder and a Decoder.

    The forward pass is deterministic (nothing is sampled). Besides the
    reconstruction and the latent code it returns a combined
    activation-magnitude penalty.
    """

    def __init__(
        self,
        hidden_size: HiddenSize = 512,
        latent_size: int = 10,
        output_shape: Sequence[int] = MNIST_IMAGE_SHAPE,
    ):
        super().__init__()
        # Stored untouched; Encoder and Decoder normalise `hidden_size` themselves.
        self._hidden_size, self._latent_size, self._output_shape = (
            hidden_size,
            latent_size,
            output_shape,
        )

    def __call__(self, x: jnp.ndarray) -> AEOutput:
        """Encode then decode `x`, which is cast to float32 first."""
        inputs = x.astype(jnp.float32)

        encoder = Encoder(self._hidden_size, self._latent_size)
        latent, encoder_penalty = encoder(inputs)

        decoder = Decoder(self._hidden_size, self._output_shape)
        reconstruction, decoder_penalty = decoder(latent)

        # The latent code contributes its own mean-absolute-value term.
        latent_penalty = jnp.mean(jnp.abs(latent))
        total_penalty = encoder_penalty + decoder_penalty + latent_penalty
        return AEOutput(reconstruction, latent, total_penalty)

In [13]:
@jax.jit
def loss_fn(params: hk.Params, batch) -> jnp.ndarray:
    """Mean squared reconstruction error plus the weighted activation penalty.

    Relies on the notebook globals `model` and `l1_coef`. Because the function
    is jitted, those globals are read when it is traced, so rebinding them
    later only takes effect after a retrace.
    """
    prediction: AEOutput = model.apply(params, batch)
    residual = batch - prediction.data
    reconstruction_mse = jnp.mean(jnp.square(residual))
    activation_penalty = l1_coef * prediction.act_sum
    return reconstruction_mse + activation_penalty

In [14]:
@jax.jit
def update(
    params: hk.Params,
    opt_state: optax.OptState,
    batch: Batch,
) -> Tuple[hk.Params, optax.OptState]:
    """Take one optimisation step with the notebook-global `optimizer`.

    Differentiates `loss_fn` with respect to `params` on `batch`, passes the
    gradients through the optimizer, and returns the updated parameters
    together with the new optimizer state.
    """
    gradient_fn = jax.grad(loss_fn)
    gradients = gradient_fn(params, batch)
    param_updates, next_opt_state = optimizer.update(gradients, opt_state)
    next_params = optax.apply_updates(params, param_updates)
    return next_params, next_opt_state

In [None]:
def _build_autoencoder(n_features: int):
    """Return the rng-free transformed AutoEncoder for `n_features` inputs.

    A factory is used so that `n_features` is bound when the model is built.
    A lambda defined inside the loop would look up the loop variable only when
    called, and so could pick up a later chromosome's size.
    """

    def forward(x):
        return AutoEncoder(
            hidden_size=(n_features, n_features // 16),
            latent_size=64,  # TODO: try different sizes, including 1
            output_shape=(n_features,),
        )(x)

    return hk.without_apply_rng(hk.transform(forward))


# Create the output directory up front so a missing folder cannot fail the
# save after an expensive forward pass.
os.makedirs("11_latent", exist_ok=True)

# Loop-invariant objects are built once instead of on every iteration.
block_values = f["test_multi_inputs"]["block0_values"]
# `optimizer` is read as a global by `update`; the forward pass does not need it.
optimizer = optax.adam(learning_rate)
# Not used in this cell; kept in case other cells rely on the name.
rng = jax.random.PRNGKey(1066)

for c in chromosomes:
    # The trailing ":" stops e.g. "chr1" from also matching "chr10:...".
    c_names = names.str.startswith(c + ":")
    c_data = block_values[:, c_names.values]
    # A plain int avoids passing NumPy scalars on as layer sizes.
    c_shape = int(c_names.sum())
    print(f"Applying for chromosome {c}: {c_shape}")
    # `model` stays a notebook global because `loss_fn` looks it up by name.
    model = _build_autoencoder(c_shape)
    # NOTE(security): allow_pickle=True unpickles the file and can execute
    # arbitrary code; only load model files produced by a trusted run.
    # Presumably the file holds a 0-d object array wrapping the params
    # mapping, hence ravel()[0] -- confirm against the cell that saved it.
    params = jnp.load(f"10_models/{c}.model.npy", allow_pickle=True).ravel()[0]
    output: AEOutput = model.apply(params, c_data)
    jnp.save(f"11_latent/{c}.latent", output.latent)

Applying for chromosome chr10: 11009
Applying for chromosome chr11: 10889
Applying for chromosome chr12: 11354
Applying for chromosome chr13: 6168
Applying for chromosome chr14: 7299
Applying for chromosome chr15: 7541
Applying for chromosome chr16: 6996
Applying for chromosome chr17: 9674
Applying for chromosome chr18: 4969
Applying for chromosome chr19: 7391
Applying for chromosome chr1: 21706
Applying for chromosome chr20: 6223
Applying for chromosome chr21: 2751
Applying for chromosome chr22: 4506
Applying for chromosome chr2: 19071
Applying for chromosome chr3: 16018
Applying for chromosome chr4: 11120
Applying for chromosome chr5: 12306
Applying for chromosome chr6: 13826
Applying for chromosome chr7: 11976
Applying for chromosome chr8: 10361
