<a href="https://colab.research.google.com/github/mlbfalchetti/Python/blob/main/Untitled0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
def build_generator(batch_size, seq_length, load_generator_function, n_classes=1, n_samples=None, sequence_templates=None, batch_normalize_pwm=False, anneal_pwm_logits=False, validation_sample_mode='max', supply_inputs=False) :

	sequence_class_input, sequence_class = None, None
	#Seed class input for all dense/embedding layers
	if not supply_inputs :
		sequence_class_input = Input(tensor=K.ones((batch_size, 1)), dtype='int32', name='sequence_class_seed')
		sequence_class = Lambda(lambda inp: K.cast(K.round(inp * K.random_uniform((batch_size, 1), minval=-0.4999, maxval=n_classes-0.5001)), dtype='int32'), name='lambda_rand_sequence_class')(sequence_class_input)
	else :
		sequence_class_input = Input(batch_shape=(batch_size, 1), dtype='int32', name='sequence_class_seed')
		sequence_class = Lambda(lambda inp: inp, name='lambda_rand_sequence_class')(sequence_class_input)

	#Get generated policy pwm logits (non-masked)
	generator_inputs, [raw_logits_1, raw_logits_2], extra_outputs = load_generator_function(batch_size, sequence_class, n_classes=n_classes, seq_length=seq_length, supply_inputs=supply_inputs)

	reshape_layer = Reshape((seq_length, 4, 1))
	
	onehot_template_dense = Embedding(n_classes, seq_length * 4, embeddings_initializer='zeros', name='template_dense')
	onehot_mask_dense = Embedding(n_classes, seq_length * 4, embeddings_initializer='ones', name='mask_dense')
	
	onehot_template = reshape_layer(onehot_template_dense(sequence_class))
	onehot_mask = reshape_layer(onehot_mask_dense(sequence_class))

	#Initialize Templating and Masking Lambda layer
	masking_layer = Lambda(mask_pwm, output_shape = (seq_length, 4, 1), name='masking_layer')

	#Batch Normalize PWM Logits
	if batch_normalize_pwm :
		raw_logit_batch_norm = BatchNormalization(name='policy_raw_logit_batch_norm')
		raw_logits_1 = raw_logit_batch_norm(raw_logits_1)
		raw_logits_2 = raw_logit_batch_norm(raw_logits_2)
	
	#Add Template and Multiply Mask
	pwm_logits_1 = masking_layer([raw_logits_1, onehot_template, onehot_mask])
	pwm_logits_2 = masking_layer([raw_logits_2, onehot_template, onehot_mask])
	
	#Compute PWMs (Nucleotide-wise Softmax)
	pwm_1 = Softmax(axis=-2, name='pwm_1')(pwm_logits_1)
	pwm_2 = Softmax(axis=-2, name='pwm_2')(pwm_logits_2)
	
	anneal_temp = None
	if anneal_pwm_logits :
		anneal_temp = K.variable(1.0)
		
		interpolated_pwm_1 = Lambda(lambda x: (1. - anneal_temp) * x + anneal_temp * 0.25)(pwm_1)
		interpolated_pwm_2 = Lambda(lambda x: (1. - anneal_temp) * x + anneal_temp * 0.25)(pwm_2)
		
		pwm_logits_1 = Lambda(lambda x: K.log(x / (1. - x)))(interpolated_pwm_1)
		pwm_logits_2 = Lambda(lambda x: K.log(x / (1. - x)))(interpolated_pwm_2)
	
	#Sample proper One-hot coded sequences from PWMs
	sampled_pwm_1, sampled_pwm_2, sampled_onehot_mask = None, None, None

	sample_func = sample_pwm
	if validation_sample_mode == 'sample' :
		sample_func = sample_pwm_only

	#Optionally tile each PWM to sample from and create sample axis
	if use_samples :
		pwm_logits_upsampled_1 = Lambda(lambda x: K.tile(x, [n_samples, 1, 1, 1]))(pwm_logits_1)
		pwm_logits_upsampled_2 = Lambda(lambda x: K.tile(x, [n_samples, 1, 1, 1]))(pwm_logits_2)
		sampled_onehot_mask = Lambda(lambda x: K.tile(x, [n_samples, 1, 1, 1]))(onehot_mask)

		sampled_pwm_1 = Lambda(sample_func, name='pwm_sampler_1')(pwm_logits_upsampled_1)
		#sampled_pwm_1 = Lambda(lambda x: K.reshape(x, (n_samples, batch_size, seq_length, 4, 1)))(sampled_pwm_1)
		sampled_pwm_1 = Lambda(lambda x: K.permute_dimensions(K.reshape(x, (n_samples, batch_size, seq_length, 4, 1)), (1, 0, 2, 3, 4)))(sampled_pwm_1)

		sampled_pwm_2 = Lambda(sample_func, name='pwm_sampler_2')(pwm_logits_upsampled_2)
		#sampled_pwm_2 = Lambda(lambda x: K.reshape(x, (n_samples, batch_size, seq_length, 4, 1)))(sampled_pwm_2)
		sampled_pwm_2 = Lambda(lambda x: K.permute_dimensions(K.reshape(x, (n_samples, batch_size, seq_length, 4, 1)), (1, 0, 2, 3, 4)))(sampled_pwm_2)

		
		#sampled_onehot_mask = Lambda(lambda x: K.reshape(x, (n_samples, batch_size, seq_length, 4, 1)), (1, 0, 2, 3, 4))(sampled_onehot_mask)
		sampled_onehot_mask = Lambda(lambda x: K.permute_dimensions(K.reshape(x, (n_samples, batch_size, seq_length, 4, 1)), (1, 0, 2, 3, 4)))(sampled_onehot_mask)

	else :
		sampled_pwm_1 = Lambda(sample_func, name='pwm_sampler_1')(pwm_logits_1)
		sampled_pwm_2 = Lambda(sample_func, name='pwm_sampler_2')(pwm_logits_2)
		sampled_onehot_mask = onehot_mask
	
	
	generator_model = Model(
		inputs=[
			sequence_class_input
		] + generator_inputs,
		outputs=[
			sequence_class,
			pwm_logits_1,
			pwm_logits_2,
			pwm_1,
			pwm_2,
			sampled_pwm_1,
			sampled_pwm_2

			,onehot_mask
			,sampled_onehot_mask
		] + extra_outputs
	)

	if sequence_templates is not None :
		initialize_sequence_templates(generator_model, sequence_templates)

	#Lock all generator layers except policy layers
	for generator_layer in generator_model.layers :
		generator_layer.trainable = False
		
		if 'policy' in generator_layer.name :
			generator_layer.trainable = True

	if anneal_pwm_logits :
		return 'genesis_generator', generator_model, anneal_temp
	return 'genesis_generator', generator_model

In [1]:
!pip install isolearn

import keras
from keras.models import Sequential, Model, load_model
from keras.layers import Dense, Dropout, Activation, Flatten, Input, Lambda
from keras.layers import Conv2D, MaxPooling2D, Conv1D, MaxPooling1D, LSTM, ConvLSTM2D, GRU, BatchNormalization, LocallyConnected2D, Permute
from keras.layers import Concatenate, Reshape, Softmax, Conv2DTranspose, Embedding, Multiply
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras import regularizers
from keras import backend as K
import keras.losses

import tensorflow as tf
from tensorflow.python.framework import ops

import isolearn.keras as iso

import numpy as np

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
#Number of PWMs to generate per objective
batch_size = 36
#Number of One-hot sequences to sample from the PWM at each grad step
n_samples = 10
#Number of epochs per objective to optimize
n_epochs = 50
#Number of steps (grad updates) per epoch
steps_per_epoch = 500

n_classes = 1

supply_inputs = False

# Then: True

sequence_class_input = Input(tensor = K.ones((batch_size, 1)), dtype = "float32", name = "sequence_class_seed")
sequence_class = Lambda(lambda inp: K.cast(K.round(inp * K.random_uniform((batch_size, 1), minval = -0.4999, maxval = n_classes - 0.5001)), dtype = "float32"), name = "lambda_random_sequence_class")(sequence_class_input)

In [3]:
sequence_class_input

<KerasTensor: shape=(36, 1) dtype=float32 (created by layer 'sequence_class_seed')>

In [4]:
sequence_class

<KerasTensor: shape=(36, 1) dtype=float32 (created by layer 'lambda_random_sequence_class')>

In [5]:
sequence_templates = [
    'TCCCTACACGACGCTCTTCCGATCTNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNANTAAANNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNAATAAATTGTTCGTTGGTCGGCTTGAGTGCGTGTGTCTCGTTTAGATGCTGCGCCTAACCCTAAGCAGATTCTTCATGCAATTG',
    'TCCCTACACGACGCTCTTCCGATCTNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNANTAAANNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNAATAAATTGTTCGTTGGTCGGCTTGAGTGCGTGTGTCTCGTTTAGATGCTGCGCCTAACCCTAAGCAGATTCTTCATGCAATTG',
    'TCCCTACACGACGCTCTTCCGATCTNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNANTAAANNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNAATAAATTGTTCGTTGGTCGGCTTGAGTGCGTGTGTCTCGTTTAGATGCTGCGCCTAACCCTAAGCAGATTCTTCATGCAATTG',
    'TCCCTACACGACGCTCTTCCGATCTNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNANTAAANNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNAATAAATTGTTCGTTGGTCGGCTTGAGTGCGTGTGTCTCGTTTAGATGCTGCGCCTAACCCTAAGCAGATTCTTCATGCAATTG',
    'TCCCTACACGACGCTCTTCCGATCTNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNANTAAANNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNAATAAATTGTTCGTTGGTCGGCTTGAGTGCGTGTGTCTCGTTTAGATGCTGCGCCTAACCCTAAGCAGATTCTTCATGCAATTG'
]

library_contexts = [
    'simple',
    'simple',
    'simple',
    'simple',
    'simple'
]

target_isos = [
    0.05,
    0.25,
    0.5,
    0.75,
    1.0
]

margin_similarities = [
    0.3,
    0.3,
    0.3,
    0.3,
    0.5
]

In [87]:
sequence_templates = [
    'TCCCTACACGACGCTCTTCCGATCTNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNANTAAANNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNAATAAATTGTTCGTTGGTCGGCTTGAGTGCGTGTGTCTCGTTTAGATGCTGCGCCTAACCCTAAGCAGATTCTTCATGCAATTG',
]

library_contexts = [
    'simple',
]

target_isos = [
    1.0
]

margin_similarities = [
    0.5
]

In [88]:
import keras
from keras.models import Sequential, Model, load_model
from keras.layers import Dense, Dropout, Activation, Flatten, Input, Lambda
from keras.layers import Conv2D, MaxPooling2D, Conv1D, MaxPooling1D, LSTM, ConvLSTM2D, GRU, BatchNormalization, LocallyConnected2D, Permute
from keras.layers import Concatenate, Reshape, Softmax, Conv2DTranspose, Embedding, Multiply
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras import regularizers
from keras import backend as K
import keras.losses

import tensorflow as tf

import isolearn.keras as iso

import numpy as np

#GENESIS Generator Model definitions
def load_generator_network(batch_size, sequence_class, n_classes=1, seq_length=205, supply_inputs=False) :

	sequence_class_onehots = np.eye(n_classes)

	#Generator network parameters
	latent_size = 100
	
	#Generator inputs
	latent_input_1, latent_input_2, latent_input_1_out, latent_input_2_out = None, None, None, None
	if not supply_inputs :
		latent_input_1 = Input(tensor=K.ones((batch_size, latent_size)), name='noise_input_1')
		latent_input_2 = Input(tensor=K.ones((batch_size, latent_size)), name='noise_input_2')
		latent_input_1_out = Lambda(lambda inp: inp * K.random_uniform((batch_size, latent_size), minval=-1.0, maxval=1.0), name='lambda_rand_input_1')(latent_input_1)
		latent_input_2_out = Lambda(lambda inp: inp * K.random_uniform((batch_size, latent_size), minval=-1.0, maxval=1.0), name='lambda_rand_input_2')(latent_input_2)
	else :
		latent_input_1 = Input(batch_shape=K.ones(batch_size, latent_size), name='noise_input_1')
		latent_input_2 = Input(batch_shape=K.ones(batch_size, latent_size), name='noise_input_2')
		latent_input_1_out = Lambda(lambda inp: inp, name='lambda_rand_input_1')(latent_input_1)
		latent_input_2_out = Lambda(lambda inp: inp, name='lambda_rand_input_2')(latent_input_2)
	
	class_embedding = Lambda(lambda x: K.gather(K.constant(sequence_class_onehots), K.cast(x[:, 0], dtype='int32')))(sequence_class)

	seed_input_1 = Concatenate(axis=-1)([latent_input_1_out, class_embedding])
	seed_input_2 = Concatenate(axis=-1)([latent_input_2_out, class_embedding])
	
	
	#Policy network definition
	policy_dense_1 = Dense(21 * 384, activation='relu', kernel_initializer='glorot_uniform', name='policy_dense_1')
	
	policy_dense_1_reshape = Reshape((21, 1, 384))
	
	policy_deconv_0 = Conv2DTranspose(256, (7, 1), strides=(2, 1), padding='valid', activation='linear', kernel_initializer='glorot_normal', name='policy_deconv_0')
	
	policy_deconv_1 = Conv2DTranspose(192, (8, 1), strides=(2, 1), padding='valid', activation='linear', kernel_initializer='glorot_normal', name='policy_deconv_1')
	
	policy_deconv_2 = Conv2DTranspose(128, (7, 1), strides=(2, 1), padding='valid', activation='linear', kernel_initializer='glorot_normal', name='policy_deconv_2')
	
	policy_conv_3 = Conv2D(128, (8, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='policy_conv_3')

	policy_conv_4 = Conv2D(64, (8, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='policy_conv_4')

	policy_conv_5 = Conv2D(4, (8, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='policy_conv_5')

	#policy_deconv_3 = Conv2DTranspose(4, (7, 1), strides=(1, 1), padding='valid', activation='linear', kernel_initializer='glorot_normal', name='policy_deconv_3')
	
	batch_norm_0 = BatchNormalization(name='policy_batch_norm_0')
	relu_0 = Lambda(lambda x: K.relu(x))
	batch_norm_1 = BatchNormalization(name='policy_batch_norm_1')
	relu_1 = Lambda(lambda x: K.relu(x))
	batch_norm_2 = BatchNormalization(name='policy_batch_norm_2')
	relu_2 = Lambda(lambda x: K.relu(x))

	batch_norm_3 = BatchNormalization(name='policy_batch_norm_3')
	relu_3 = Lambda(lambda x: K.relu(x))

	batch_norm_4 = BatchNormalization(name='policy_batch_norm_4')
	relu_4 = Lambda(lambda x: K.relu(x))

	policy_out_1 = Reshape((seq_length, 4, 1))(policy_conv_5(relu_4(batch_norm_4(policy_conv_4(relu_3(batch_norm_3(policy_conv_3(relu_2(batch_norm_2(policy_deconv_2(relu_1(batch_norm_1(policy_deconv_1(relu_0(batch_norm_0(policy_deconv_0(policy_dense_1_reshape(policy_dense_1(seed_input_1)))))))))))))))))))
	policy_out_2 = Reshape((seq_length, 4, 1))(policy_conv_5(relu_4(batch_norm_4(policy_conv_4(relu_3(batch_norm_3(policy_conv_3(relu_2(batch_norm_2(policy_deconv_2(relu_1(batch_norm_1(policy_deconv_1(relu_0(batch_norm_0(policy_deconv_0(policy_dense_1_reshape(policy_dense_1(seed_input_2)))))))))))))))))))
	
	return [latent_input_1, latent_input_2], [policy_out_1, policy_out_2], []

In [89]:
seq_length = len(sequence_templates[0])

load_generator_function = load_generator_network

#Get generated policy pwm logits (non-masked)
generator_inputs, [raw_logits_1, raw_logits_2], extra_outputs = load_generator_function(batch_size, sequence_class, n_classes = n_classes, seq_length = seq_length, supply_inputs = supply_inputs)

In [90]:
generator_inputs

[<KerasTensor: shape=(36, 100) dtype=float32 (created by layer 'noise_input_1')>,
 <KerasTensor: shape=(36, 100) dtype=float32 (created by layer 'noise_input_2')>]

In [91]:
raw_logits_1

<KerasTensor: shape=(36, 205, 4, 1) dtype=float32 (created by layer 'reshape_5')>

In [92]:
raw_logits_2

<KerasTensor: shape=(36, 205, 4, 1) dtype=float32 (created by layer 'reshape_6')>

In [93]:
extra_outputs

[]

In [94]:
reshape_layer = Reshape((seq_length, 4, 1))
	
onehot_template_dense = Embedding(n_classes, seq_length * 4, embeddings_initializer = 'zeros', name = 'template_dense')
onehot_mask_dense = Embedding(n_classes, seq_length * 4, embeddings_initializer = 'ones', name = 'mask_dense')
	
onehot_template = reshape_layer(onehot_template_dense(sequence_class))
onehot_mask = reshape_layer(onehot_mask_dense(sequence_class))

In [95]:
onehot_template

<KerasTensor: shape=(36, 205, 4, 1) dtype=float32 (created by layer 'reshape_7')>

In [96]:
onehot_mask

<KerasTensor: shape=(36, 205, 4, 1) dtype=float32 (created by layer 'reshape_7')>

In [97]:
def mask_pwm(inputs) :
	pwm, onehot_template, onehot_mask = inputs

	return pwm * onehot_mask + onehot_template

# Initialize templating and masking lambda layer
masking_layer = Lambda(mask_pwm, output_shape = (seq_length, 4, 1), name = "masking_layer")

In [98]:
batch_normalize_pwm = False

# Batch Normalize PWM Logits
if batch_normalize_pwm :
  raw_logit_batch_norm = BatchNormalization(name = "policy_raw_logit_batch_norm")
  raw_logits_1 = raw_logit_batch_norm(raw_logits_1)
  raw_logits_2 = raw_logit_batch_norm(raw_logits_2)

In [99]:
# Add Template and Multiply Mask
pwm_logits_1 = masking_layer([raw_logits_1, onehot_template, onehot_mask])
pwm_logits_2 = masking_layer([raw_logits_2, onehot_template, onehot_mask])

In [100]:
#Compute PWMs (Nucleotide-wise Softmax)
pwm_1 = Softmax(axis = -2, name = 'pwm_1')(pwm_logits_1)
pwm_2 = Softmax(axis = -2, name = 'pwm_2')(pwm_logits_2)	

In [101]:
#Sample proper One-hot coded sequences from PWMs
sampled_pwm_1, sampled_pwm_2, sampled_onehot_mask = None, None, None

def st_sampled_softmax(logits):
	with ops.name_scope("STSampledSoftmax") as namescope :
		nt_probs = tf.nn.softmax(logits)
		onehot_dim = logits.get_shape().as_list()[1]
		sampled_onehot = tf.one_hot(tf.squeeze(tf.random.categorical(logits, 1), 1), onehot_dim, 1.0, 0.0)
		with tf.Graph().gradient_override_map({'Ceil': 'Identity', 'Mul': 'STMul'}):
			return tf.math.ceil(sampled_onehot * nt_probs)

def st_hardmax_softmax(logits):
	with ops.name_scope("STHardmaxSoftmax") as namescope :
		nt_probs = tf.nn.softmax(logits)
		onehot_dim = logits.get_shape().as_list()[1]
		sampled_onehot = tf.one_hot(tf.argmax(nt_probs, 1), onehot_dim, 1.0, 0.0)
		with tf.Graph().gradient_override_map({'Ceil': 'Identity', 'Mul': 'STMul'}):
			return tf.math.ceil(sampled_onehot * nt_probs)

#@ops.RegisterGradient("STMul")
#def st_mul(op, grad):
#	return [grad, grad]

def sample_pwm(pwm_logits) :
	n_sequences = K.shape(pwm_logits)[0]
	seq_length = K.shape(pwm_logits)[1]
	
	flat_pwm = K.reshape(pwm_logits, (n_sequences * seq_length, 4))
	sampled_pwm = K.switch(tf.convert_to_tensor(K.learning_phase(), dtype = tf.int32), st_sampled_softmax(flat_pwm), st_hardmax_softmax(flat_pwm))
 
sample_func = sample_pwm

validation_sample_mode = 'max'

#if validation_sample_mode == 'sample' :
#  sample_func = sample_pwm_only

sampled_pwm_1 = Lambda(sample_func, name='pwm_sampler_1')(pwm_logits_1)
sampled_pwm_2 = Lambda(sample_func, name='pwm_sampler_2')(pwm_logits_2)
sampled_onehot_mask = onehot_mask

In [50]:
nt_probs = tf.nn.softmax(flat_pwm)
onehot_dim = flat_pwm.get_shape().as_list()[1]

#tf.compat.v1.distributions.Multinomial(flat_pwm, 1)

#tf.multinomial(flat_pwm, 1)
#tf.random.categorical(flat_pwm, 1)
#tf.squeeze(tf.random.categorical(flat_pwm, 1), 1)
tf.one_hot(tf.squeeze(tf.random.categorical(flat_pwm, 1), 1), onehot_dim, 1.0, 0.0)

#sampled_onehot = tf.one_hot(tf.squeeze(tf.compat.v1.distributions.Multinomial(flat_pwm, 1), 1), onehot_dim, 1.0, 0.0)

<KerasTensor: shape=(7380, 4) dtype=float32 (created by layer 'tf.one_hot_1')>

In [69]:
n_sequences = K.shape(pwm_logits_1)[0]
seq_length = K.shape(pwm_logits_1)[1]
	
flat_pwm = K.reshape(pwm_logits_1, (n_sequences * seq_length, 4))
#sampled_pwm = K.switch(K.learning_phase(), st_sampled_softmax(flat_pwm), st_hardmax_softmax(flat_pwm))
sampled_pwm = K.switch(tf.convert_to_tensor(K.learning_phase(), dtype = tf.int32), st_sampled_softmax(flat_pwm), st_hardmax_softmax(flat_pwm))

In [102]:
sampled_onehot_mask

<KerasTensor: shape=(36, 205, 4, 1) dtype=float32 (created by layer 'reshape_7')>

In [103]:
generator_model = Model(
	inputs=[
		sequence_class_input
	] + generator_inputs,
	outputs=[
		sequence_class,
		pwm_logits_1,
		pwm_logits_2,
		pwm_1,
		pwm_2,
		sampled_pwm_1,
		sampled_pwm_2,
    onehot_mask,
    sampled_onehot_mask
	] + extra_outputs
)

In [104]:
def initialize_sequence_templates(generator, sequence_templates) :

	embedding_templates = []
	embedding_masks = []

	for k in range(len(sequence_templates)) :
		sequence_template = sequence_templates[k]
		onehot_template = iso.OneHotEncoder(seq_length=len(sequence_template))(sequence_template).reshape((len(sequence_template), 4, 1))
		
		for j in range(len(sequence_template)) :
			if sequence_template[j] not in ['N', 'X'] :
				nt_ix = np.argmax(onehot_template[j, :, 0])
				onehot_template[j, :, :] = -4.0
				onehot_template[j, nt_ix, :] = 10.0
			elif sequence_template[j] == 'X' :
				onehot_template[j, :, :] = -1.0

		onehot_mask = np.zeros((len(sequence_template), 4, 1))
		for j in range(len(sequence_template)) :
			if sequence_template[j] == 'N' :
				onehot_mask[j, :, :] = 1.0
		
		embedding_templates.append(onehot_template.reshape(1, -1))
		embedding_masks.append(onehot_mask.reshape(1, -1))

	embedding_templates = np.concatenate(embedding_templates, axis=0)
	embedding_masks = np.concatenate(embedding_masks, axis=0)

	generator.get_layer('template_dense').set_weights([embedding_templates])
	generator.get_layer('template_dense').trainable = False

	generator.get_layer('mask_dense').set_weights([embedding_masks])
	generator.get_layer('mask_dense').trainable = False

initialize_sequence_templates(generator_model, sequence_templates)

In [111]:
for generator_layer in generator_model.layers :
  generator_layer.trainable = False
  if 'policy' in generator_layer.name :
    generator_layer.trainable = True

In [114]:
anneal_pwm_logits=False

if anneal_pwm_logits :
	return 'genesis_generator', generator_model, anneal_temp
return 'genesis_generator', generator_model

SyntaxError: ignored