In [3]:
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import numpy as np
import itertools
from random import shuffle
from tensorflow_probability import distributions as tfd
# tfd = tfp.distributions

In [4]:
# sess = tf.InteractiveSession()

In [5]:
class CategoricalMixture(object):
    """Mixture of K components, each a product of L independent categoricals over m classes.

    The model is built with the TensorFlow 1.x graph API: the constructor
    creates the variables, the loss, an Adam training op and a dedicated
    session that is kept on the instance (``self.sess``).
    """

    def __init__(self, K, L, m):
        """Build the graph and initialise all variables.

        Args:
            K: number of mixture components.
            L: sequence length (number of positions per sequence).
            m: number of categories available at each position.
        """
        self.K = K  # number of components
        self.L = L  # sequence length
        self.m = m  # number of categories

        # Unnormalised parameters; softmax turns them into probabilities below.
        self.z_logits = tf.Variable(tf.random.uniform((K,)), dtype=tf.float32)
        self.p_logits = [tf.Variable(tf.random.uniform((L, m,)), dtype=tf.float32) for _ in range(K)]
        self.z = tf.nn.softmax(self.z_logits)
        self.ps = [tf.nn.softmax(self.p_logits[i]) for i in range(K)]
        # Each component treats the L positions as one event (reinterpreted_batch_ndims=1).
        self.components = [tfd.Independent(tfd.Categorical(probs=self.ps[i]), reinterpreted_batch_ndims=1) for i in range(K)]
        self.model = tfd.Mixture(
            cat=tfd.Categorical(probs=self.z),
            components=self.components
        )

        self.X_train = tf.placeholder(name="X_train", shape=[None, L], dtype=tf.float32)
        self.weights = tf.placeholder(name="loss_weights", shape=[None], dtype=tf.float32)
        # Weighted negative log-likelihood, averaged over the batch.
        self.loss = -tf.reduce_mean(self.weights * self.model.log_prob(self.X_train))
        self.train_op = tf.train.AdamOptimizer().minimize(self.loss)
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    def _iterate_minibatches(self, inputs1, inputs2, batch_size, shuffle=True):
        """Yield aligned ``(inputs1[batch], inputs2[batch])`` pairs of exactly ``batch_size`` rows.

        Only full batches are produced: when the number of rows is not a
        multiple of ``batch_size`` the trailing remainder is skipped, and when
        there are fewer rows than ``batch_size`` nothing is yielded.

        Args:
            inputs1: array whose first axis indexes the examples.
            inputs2: array aligned with ``inputs1`` along the first axis.
            batch_size: number of rows per batch.
            shuffle: if True, visit rows in a fresh random order
                (drawn with ``np.random.shuffle``).
        """
        if shuffle:
            indices = np.arange(inputs1.shape[0])
            np.random.shuffle(indices)
        for start_idx in range(0, inputs1.shape[0] - batch_size + 1, batch_size):
            if shuffle:
                excerpt = indices[start_idx:start_idx + batch_size]
            else:
                excerpt = slice(start_idx, start_idx + batch_size)
            yield inputs1[excerpt], inputs2[excerpt]

    def train(self, X, W, epochs=100, batch_size=10, shuffle=True, verbose=False, one_hot=False, print_every=100):
        """Run ``epochs`` passes of minibatch Adam updates on the weighted loss.

        Args:
            X: training sequences; category indices per position, or one-hot
                vectors along the last axis when ``one_hot`` is True.
            W: per-sequence loss weights, aligned with ``X`` along axis 0.
            epochs: number of passes over the data.
            batch_size: rows per minibatch (a trailing partial batch is skipped).
            shuffle: reshuffle the rows at the start of every epoch.
            verbose: print the mean batch loss every ``print_every`` epochs
                and after the final epoch.
            one_hot: if True, ``X`` is converted to indices with ``argmax``
                over its last axis before training.
            print_every: reporting period, in epochs, used when ``verbose``.
        """
        if one_hot:
            X = np.argmax(X, axis=-1)
        for t in range(epochs):
            e_loss = 0
            n_batches = 0
            for batch in self._iterate_minibatches(X, W, batch_size, shuffle=shuffle):
                xi, wi = batch
                _, np_loss = self.sess.run([self.train_op, self.loss], feed_dict={self.X_train: xi, self.weights: wi})
                e_loss += np_loss
                n_batches += 1
            if verbose:
                if t % print_every == 0 or t == epochs - 1:
                    # No full batch fits when batch_size exceeds the number of
                    # rows; report NaN instead of dividing by zero.
                    mean_loss = e_loss / n_batches if n_batches else float("nan")
                    print("Training loss at %i/%i: %.3f" % (t, epochs, mean_loss))

    def sample(self, n, one_hot=False):
        """Draw ``n`` sequences from the mixture.

        Args:
            n: number of sequences to draw.
            one_hot: if True, return an ``(n, L, m)`` array of one-hot vectors
                instead of the raw category indices.

        Returns:
            The evaluated samples (one row of L category indices per
            sequence), or their one-hot encoding when ``one_hot`` is True.
        """
        samples = self.model.sample(n).eval(session=self.sess)
        if one_hot:
            samples_one_hot = np.zeros((n, self.L, self.m))
            # Row indices (n, 1) and position indices (L,) broadcast against
            # the (n, L) category indices to set exactly one entry per position.
            samples_one_hot[np.arange(n).reshape(n, 1), np.arange(self.L), samples] = 1
            samples = samples_one_hot
        return samples

In [6]:
def get_all_sequences(length, arr=False):
    """Enumerate every sequence of the given length over a 4-letter alphabet.

    Sequences are produced in ``itertools.product`` order.

    Args:
        length: number of positions per sequence.
        arr: if True, build one-hot arrays; otherwise build strings over
            the alphabet ``'ATCG'``.

    Returns:
        A list of ``4**length`` strings, or (when ``arr`` is True) a float
        array of shape ``(4**length, length, 4)`` holding one one-hot row
        per position.
    """
    total = 4**length
    if arr:
        # One (1, 4) one-hot row per symbol, so a sequence is a row-wise stack.
        alphabet = [np.array([row]) for row in ([1, 0, 0, 0], [0, 1, 0, 0],
                                                [0, 0, 1, 0], [0, 0, 0, 1])]
    else:
        alphabet = 'ATCG'
    combos = itertools.product(alphabet, repeat=length)
    result = np.zeros((total, length, 4)) if arr else ["A" * length] * total

    for idx, combo in enumerate(combos):
        # Progress report once per million sequences (never at index 0).
        if idx > 0 and idx % int(10**6) == 0:
            print("Sequences constructed: %i / %i" % (idx, total))
        result[idx] = np.concatenate(combo, axis=0) if arr else "".join(combo)
    return result

In [7]:
# All 4**10 = 1,048,576 one-hot encoded sequences of length 10.
X_all = get_all_sequences(length=10, arr=True)

Sequences constructed: 1000000 / 1048576


In [8]:
# 4 mixture components over length-10 sequences with 4 categories per position.
mix = CategoricalMixture(K=4, L=10, m=4)

In [9]:
NUM_STEPS = 10000
N = 100

# Training set: four contiguous slices of X_all, a quarter of N rows each,
# taken at widely separated offsets and stacked in offset order.
chunk = int(N / 4)
xt = np.concatenate(
    [X_all[offset:offset + chunk] for offset in (200, 100000, 500000, 1000000)]
)
# One uniform random loss weight in [0, 1) per training sequence (unseeded).
w = np.random.rand(N)

mix.train(
    xt,
    w,
    epochs=500,
    batch_size=10,
    shuffle=True,
    one_hot=True,
    verbose=True,
    print_every=100,
)

Training loss at 0/500: 7.271
Training loss at 100/500: 4.344
Training loss at 200/500: 3.252
Training loss at 300/500: 2.890
Training loss at 400/500: 2.735
Training loss at 499/500: 2.657


In [11]:
# Draw 100 index-encoded sequences, then score them under the fitted mixture.
xs = mix.sample(n=100, one_hot=False)
mix.sess.run(mix.model.log_prob(xs))

array([ -4.72657  , -14.576713 ,  -9.229977 ,  -4.6774077, -10.106009 ,
        -4.4184985,  -5.7726083,  -4.44662  ,  -4.7952538,  -4.74765  ,
        -5.192702 ,  -5.8997326,  -4.2965436,  -5.4116755,  -5.487743 ,
        -5.1537056,  -5.2462425,  -5.080471 ,  -4.67406  ,  -5.569378 ,
        -4.9103756,  -4.408775 ,  -4.4054265,  -5.274855 ,  -5.2894278,
        -5.034159 ,  -4.6996064,  -8.545227 ,  -4.44662  ,  -4.8215613,
        -4.72657  ,  -6.099427 ,  -4.673725 ,  -5.2894278, -10.921806 ,
        -5.3745747,  -5.1537056,  -4.4054265, -10.177364 ,  -4.989058 ,
       -13.264703 ,  -4.8118377,  -4.5983047,  -4.537695 ,  -4.8678236,
        -4.8678236,  -5.511746 ,  -5.125141 ,  -5.6109385,  -5.872328 ,
        -4.408775 ,  -4.74765  ,  -9.093495 ,  -4.408775 ,  -5.3752794,
        -9.996938 ,  -5.1674185,  -5.040514 ,  -9.483656 , -10.359264 ,
       -10.35176  , -14.3752985,  -4.789639 ,  -4.41515  ,  -5.030811 ,
       -10.279336 ,  -4.6696653, -14.439966 ,  -4.649926 ,  -9.7