# ChromStateTransferVAE

An attempt at building a variational autoencoder (VAE) to annotate chromatin state in samples of interest that have only a small number of signals,
by transferring state annotations from reference epigenomes.

## The problem statement

The data for the VAE will consist of:

- $m$: a signal matrix of size $L \times N$ ($L$ genomic positions by $N$ observed signals)
- $r$: a reference epigenome state annotation indicator tensor of size $L \times R \times S$ ($R$ references, $S$ states)

The goal is to generate a state annotation (one of the $S$ states at each genomic position) for the signal matrix $m$.

## Generating simulated data

To generate data we will use the following generative model. Simulation parameters are:

- $\alpha$: prior over references
- $r$: reference state indicator matrix
- $p$: state-to-signal parameter matrix

With those, pseudocode for the generative model at a specific genomic location is:

```
1. generate probability distribution over references 
   pi (shape: (num_references,)) ~ Dirichlet(alpha)
2. collapse pi to a probability distribution over states 
   collapsed_pi (shape: (num_states,)) = transpose(r[l]) @ pi
   where r[l] (shape: (num_references, num_states)) is the slice of r at location l
3. generate sample state z (shape (1,)) ~ Categorical(collapsed_pi)
4. generate signal vector m (shape: (num_signals,)) ~ Bernoulli(p[z,:])
```
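
The four steps above can be sketched for a single location in PyTorch. This is a minimal illustration, not the final simulator: the sizes, the random `r_l` slice, and the random `p` are placeholder assumptions, and `r_l` stands for the `(num_references, num_states)` slice of $r$ at one position.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli, Categorical, Dirichlet

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
num_references, num_states, num_signals = 4, 3, 5

alpha = torch.ones(num_references)                    # prior over references
# One-hot state of each reference at this location, shape (R, S).
ref_states = torch.randint(num_states, (num_references,))
r_l = F.one_hot(ref_states, num_classes=num_states).float()
p = torch.rand(num_states, num_signals)               # placeholder state-to-signal matrix

pi = Dirichlet(alpha).sample()                        # step 1: shape (R,)
collapsed_pi = r_l.T @ pi                             # step 2: (S, R) @ (R,) -> (S,)
z = Categorical(probs=collapsed_pi).sample()          # step 3: scalar state index
m = Bernoulli(probs=p[z]).sample()                    # step 4: shape (N,), binary
```

Because every row of `r_l` is one-hot, `collapsed_pi` sums to one whenever `pi` does, so it is a valid distribution over states. States that no reference carries at this location get probability zero.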

### Prior over reference

$\alpha$ is a vector of size $R$ used as the parameter of the Dirichlet prior over reference weights. It may be interpreted as how similar the sample of
interest is to each reference. Here are some useful cases:



In [1]:
import torch

# case 1: sample of interest is essentially identical to one of the references
def generate_single_reference_alpha(num_references, w=10):
    alpha = torch.ones((num_references))
    alpha[0] = w
    return alpha

# case 2: sample of interest is equally similar to a small number of the references
def generate_batch_reference_alpha(num_references, batch_size, w=10):
    alpha = torch.ones((num_references))
    alpha[0:batch_size] = w
    return alpha

# note case 1 is a special case of case 2, so only a single function needed
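
To get a feel for what these priors imply, the sketch below computes the expected reference weights under a Dirichlet with such an $\alpha$. The sizes and the value of `w` are illustrative assumptions; the only fact relied on is that the mean of a Dirichlet is $\alpha / \sum_i \alpha_i$.

```python
import torch
from torch.distributions import Dirichlet

# Hypothetical sizes, for illustration only.
num_references, batch_size, w = 5, 2, 10.0

alpha = torch.ones(num_references)
alpha[:batch_size] = w      # batch_size=1 reproduces case 1

# Expected weight per reference: the Dirichlet mean alpha / alpha.sum().
expected_pi = alpha / alpha.sum()
print(expected_pi)

# A single draw is a probability vector over the references.
pi = Dirichlet(alpha).sample()
print(pi, pi.sum())
```

With these numbers each up-weighted reference has expected weight 10/23 and every other reference 1/23, so a larger `w` concentrates the sample more strongly on the chosen references.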

### Reference state indicator matrix

This encodes the state annotation along the genome for each of the references

In [17]:
import math
import torch.nn.functional as F

# helper function to create an indicator matrix from an assignment matrix
def generate_indicator(assignments, num_states):
    return F.one_hot(assignments.long(), num_classes=num_states)

# case 1: there is a single reference that matters
def generate_single_reference_r(num_positions, num_references, num_states, state_sequence_length=3):
    assignments = torch.zeros(num_positions, num_references)
    # the important reference has a unique state sequence
    state_sequence = torch.arange(state_sequence_length)
    num_times = math.ceil(num_positions / len(state_sequence))
    state_sequence = state_sequence.repeat(num_times)
    assignments[:,0] = state_sequence[:num_positions]

    random_assignments = torch.multinomial(
            torch.ones(num_states), 
            num_samples=num_positions*(num_references - 1),
            replacement=True)\
        .reshape((num_positions,-1))
        
    assignments[:,1:] = random_assignments
    return generate_indicator(assignments, num_states)

In [19]:
generate_single_reference_r(20, 10, 8)[:3,:,:]

tensor([[[1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 0, 0],
         [0, 0, 0, 0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 1, 0, 0],
         [0, 0, 0, 0, 0, 1, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1]],

        [[0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 1, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 1, 0, 0],
         [0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0],
         [0, 1, 

In [84]:
# case 2: there is a batch of similar references, and their state sequences
# are identical
def generate_batch_identical_references_r(
        num_positions, 
        num_references, 
        num_states, 
        batch_size, 
        state_sequence_length=3):
    assignments = torch.zeros(num_positions, num_references)
    # the important references share one identical state sequence
    num_times = math.ceil(num_positions / state_sequence_length)
    state_sequence = torch.arange(state_sequence_length)\
        .repeat_interleave(batch_size)\
        .reshape(-1,batch_size)\
        .repeat((num_times,1))[:num_positions,:]
    assignments[:,0:batch_size] = state_sequence
    random_assignments = torch.multinomial(
            torch.ones(num_states), 
            num_samples=num_positions*(num_references - batch_size),
            replacement=True)\
        .reshape((num_positions,-1))
        
    assignments[:,batch_size:] = random_assignments
    return generate_indicator(assignments, num_states)

In [83]:
generate_batch_identical_references_r(20,10,8,3)[:3,:,:]

tensor([[[1, 0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0],
         [0, 1, 0, 0, 0, 0, 0, 0]],

        [[0, 1, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1]],

        [[0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 0, 0],
         [0, 0, 

In [93]:
# case 3: there is a batch of references the sample is similar to
# but their state sequences are different (this is where it's a mixture)
def generate_batch_different_references_r(
        num_positions,
        num_references,
        num_states,
        batch_size,
        state_sequence_length=3):
    assignments = torch.zeros(num_positions, num_references)
    num_times = math.ceil(num_positions / state_sequence_length)
    state_sequence = torch.arange(state_sequence_length)\
        .repeat_interleave(batch_size)\
        .reshape(-1, batch_size)\
        .add(torch.arange(batch_size).repeat(state_sequence_length,1))\
        .repeat((num_times,1))[:num_positions,:]
    assignments[:,0:batch_size] = state_sequence
    random_assignments = torch.multinomial(
        torch.ones(num_states),
        num_samples=num_positions * (num_references - batch_size),
        replacement=True)\
    .reshape((num_positions,-1))
    
    assignments[:,batch_size:] = random_assignments
    return generate_indicator(assignments, num_states)
        

In [94]:
generate_batch_different_references_r(20,10,8,3)[:3,:,:]

tensor([[[1, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 1, 0]],

        [[0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 1, 0, 0]],

        [[0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 1, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1],
         [1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 

### State to signal parameter matrix

This is a matrix of shape $S \times N$ (number of states by number of signals).

In [140]:
# helper function to turn a state number to a signal pattern
def state_to_signal_pattern(state, num_signals):
    # zero-padded binary encoding of state+1, one bit per signal
    bits = format(state + 1, 'b').zfill(num_signals)
    return torch.tensor([int(x) for x in bits])

# case 1: the "important states" are perfectly specified by the signal pattern
# signal for other states is random (0.5)
def generate_specific_signal_p(
        num_states,
        num_signals,
        num_important_states,
        weight=10):
    assert num_important_states < 2**num_signals
    p = 0.5 * torch.ones((num_states,num_signals))
    for i in range(num_important_states):
        p[i,:] = state_to_signal_pattern(i, num_signals)
    p = weight * (2*p - 1)
    return torch.sigmoid(p)

In [141]:
generate_specific_signal_p(12,3,4)

tensor([[4.5398e-05, 4.5398e-05, 9.9995e-01],
        [4.5398e-05, 9.9995e-01, 4.5398e-05],
        [4.5398e-05, 9.9995e-01, 9.9995e-01],
        [9.9995e-01, 4.5398e-05, 4.5398e-05],
        [5.0000e-01, 5.0000e-01, 5.0000e-01],
        [5.0000e-01, 5.0000e-01, 5.0000e-01],
        [5.0000e-01, 5.0000e-01, 5.0000e-01],
        [5.0000e-01, 5.0000e-01, 5.0000e-01],
        [5.0000e-01, 5.0000e-01, 5.0000e-01],
        [5.0000e-01, 5.0000e-01, 5.0000e-01],
        [5.0000e-01, 5.0000e-01, 5.0000e-01],
        [5.0000e-01, 5.0000e-01, 5.0000e-01]])

In [145]:
# case 2: the "similar states" are perfectly specified by the signal
# pattern, but one other "non-similar state" shares the same
# signal pattern
def generate_similar_signal_p(
        num_states,
        num_signals,
        num_important_states,
        weight=10):
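    assert 2 * num_important_states < 2**num_signals
    # each important state needs a twin row, so both sets must fit in p
    assert 2 * num_important_states <= num_states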
    p = 0.5 * torch.ones((num_states, num_signals))
    for i in range(num_important_states):
        p[i,:] = state_to_signal_pattern(i, num_signals)
        p[i+num_important_states,:] = p[i,:]
    p = weight * (2*p - 1)
    return torch.sigmoid(p)

In [147]:
generate_similar_signal_p(12,3,3)

tensor([[4.5398e-05, 4.5398e-05, 9.9995e-01],
        [4.5398e-05, 9.9995e-01, 4.5398e-05],
        [4.5398e-05, 9.9995e-01, 9.9995e-01],
        [4.5398e-05, 4.5398e-05, 9.9995e-01],
        [4.5398e-05, 9.9995e-01, 4.5398e-05],
        [4.5398e-05, 9.9995e-01, 9.9995e-01],
        [5.0000e-01, 5.0000e-01, 5.0000e-01],
        [5.0000e-01, 5.0000e-01, 5.0000e-01],
        [5.0000e-01, 5.0000e-01, 5.0000e-01],
        [5.0000e-01, 5.0000e-01, 5.0000e-01],
        [5.0000e-01, 5.0000e-01, 5.0000e-01],
        [5.0000e-01, 5.0000e-01, 5.0000e-01]])
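
A small sketch of what this case implies for inference: when two states share a signal pattern, the signal likelihood alone cannot separate them. The 4-state matrix below is a hand-written, hypothetical example (not output of the functions above), with `eps` set to the value of `sigmoid(-10)` seen in the printed tensors.

```python
import torch
from torch.distributions import Bernoulli

# Hypothetical matrix: states 0 and 2 share a pattern, state 1 has its own,
# and state 3 is uninformative (all 0.5).
eps = 4.5398e-05  # sigmoid(-10), matching weight=10
p = torch.tensor([
    [eps, eps, 1 - eps],
    [eps, 1 - eps, eps],
    [eps, eps, 1 - eps],
    [0.5, 0.5, 0.5],
])

m = torch.tensor([0.0, 0.0, 1.0])  # one observed signal vector

# log P(m | z) for each state z: sum of independent Bernoulli log-probabilities.
log_lik = Bernoulli(probs=p).log_prob(m).sum(dim=-1)
print(log_lik)
```

States 0 and 2 get identical log-likelihoods here, so any preference between them would have to come from the reference annotations rather than from the signal.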

Ok, let's refactor some of these and create classes to encapsulate all of these behaviors

In [154]:
class AlphaGenerator:
    def __init__(self, num_references, w=10):
        self.num_references = num_references
        self.w = w
        
    def init_alpha(self):
        return torch.ones((self.num_references))
    
    def generate(self):
        return self.update_alpha(self.init_alpha())
    
    def update_alpha(self, alpha):
        # hook for subclasses: the base class leaves alpha uniform
        return alpha
    
class SingleAlphaGenerator(AlphaGenerator):
    def update_alpha(self, alpha):
        # only the first reference is up-weighted
        alpha[0] = self.w
        return alpha
        
class BatchAlphaGenerator(AlphaGenerator):
    def __init__(self, num_references, batch_size, w=10):
        super().__init__(num_references, w)
        self.batch_size = batch_size
        
    def update_alpha(self, alpha):
        # the first batch_size references are up-weighted
        alpha[:self.batch_size] = self.w
        return alpha
    

In [155]:
print(SingleAlphaGenerator(12).generate())

tensor([10.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.])
