<a href="https://colab.research.google.com/github/mlbfalchetti/Python/blob/main/Untitled0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
def build_generator(batch_size, seq_length, load_generator_function, n_classes=1, n_samples=None, sequence_templates=None, batch_normalize_pwm=False, anneal_pwm_logits=False, validation_sample_mode='max', supply_inputs=False) :
	"""Assemble the two-PWM generator model around a user-supplied logit network.

	The model draws (or receives) a sequence class, obtains two sets of raw PWM
	logits from ``load_generator_function``, adds a per-class template and
	multiplies a per-class mask (via ``mask_pwm``), turns the logits into PWMs
	with a softmax over the 4-channel axis and samples one-hot sequences from
	them. All layers are frozen except those whose name contains 'policy'.

	Parameters
	----------
	batch_size : int
		Fixed batch size baked into the Input and random tensors.
	seq_length : int
		Sequence length; PWM tensors are shaped (seq_length, 4, 1) per example.
	load_generator_function : callable
		Called as ``load_generator_function(batch_size, sequence_class,
		n_classes=..., seq_length=..., supply_inputs=...)``; must return
		``(inputs_list, [raw_logits_1, raw_logits_2], extra_outputs_list)``.
	n_classes : int
		Number of sequence classes (rows of the template / mask embeddings).
	n_samples : int or None
		If given, each PWM is sampled ``n_samples`` times and samples are laid
		out as (batch_size, n_samples, seq_length, 4, 1). If None, one sample
		per PWM is drawn and no sample axis is added.
	sequence_templates : optional
		Forwarded to ``initialize_sequence_templates`` when not None.
	batch_normalize_pwm : bool
		Apply one shared BatchNormalization layer to both raw logit tensors.
	anneal_pwm_logits : bool
		Interpolate the PWMs towards uniform (0.25 per channel) by a backend
		variable ``anneal_temp`` (initially 1.0) and convert them back to
		logits before sampling.
	validation_sample_mode : str
		'sample' selects ``sample_pwm_only``; any other value selects
		``sample_pwm``.
	supply_inputs : bool
		If False, the class is drawn (approximately uniformly) at random inside
		the graph; if True, it is fed through an int32 Input of shape
		(batch_size, 1).

	Returns
	-------
	tuple
		``('genesis_generator', generator_model)``, or
		``('genesis_generator', generator_model, anneal_temp)`` when
		``anneal_pwm_logits`` is True. Model outputs, in order: sequence_class,
		pwm_logits_1, pwm_logits_2, pwm_1, pwm_2, sampled_pwm_1, sampled_pwm_2,
		onehot_mask, sampled_onehot_mask, followed by ``extra_outputs``.
	"""
	use_samples = n_samples is not None

	#Seed class input for all dense/embedding layers
	if not supply_inputs :
		#Constant ones tensor times a uniform draw in [-0.4999, n_classes - 0.5001],
		#rounded, gives a random integer class in [0, n_classes - 1] per row
		sequence_class_input = Input(tensor=K.ones((batch_size, 1)), dtype='int32', name='sequence_class_seed')
		sequence_class = Lambda(lambda inp: K.cast(K.round(inp * K.random_uniform((batch_size, 1), minval=-0.4999, maxval=n_classes-0.5001)), dtype='int32'), name='lambda_rand_sequence_class')(sequence_class_input)
	else :
		sequence_class_input = Input(batch_shape=(batch_size, 1), dtype='int32', name='sequence_class_seed')
		#Identity layer; presumably kept so the class tensor has the same layer name in both modes
		sequence_class = Lambda(lambda inp: inp, name='lambda_rand_sequence_class')(sequence_class_input)

	#Get generated policy pwm logits (non-masked)
	generator_inputs, [raw_logits_1, raw_logits_2], extra_outputs = load_generator_function(batch_size, sequence_class, n_classes=n_classes, seq_length=seq_length, supply_inputs=supply_inputs)

	reshape_layer = Reshape((seq_length, 4, 1))

	#Per-class template (zero-initialised) and mask (one-initialised), looked up by class id
	onehot_template_dense = Embedding(n_classes, seq_length * 4, embeddings_initializer='zeros', name='template_dense')
	onehot_mask_dense = Embedding(n_classes, seq_length * 4, embeddings_initializer='ones', name='mask_dense')

	onehot_template = reshape_layer(onehot_template_dense(sequence_class))
	onehot_mask = reshape_layer(onehot_mask_dense(sequence_class))

	#Initialize Templating and Masking Lambda layer
	masking_layer = Lambda(mask_pwm, output_shape = (seq_length, 4, 1), name='masking_layer')

	#Batch Normalize PWM Logits (one shared layer; 'policy' in the name keeps it trainable)
	if batch_normalize_pwm :
		raw_logit_batch_norm = BatchNormalization(name='policy_raw_logit_batch_norm')
		raw_logits_1 = raw_logit_batch_norm(raw_logits_1)
		raw_logits_2 = raw_logit_batch_norm(raw_logits_2)

	#Add Template and Multiply Mask
	pwm_logits_1 = masking_layer([raw_logits_1, onehot_template, onehot_mask])
	pwm_logits_2 = masking_layer([raw_logits_2, onehot_template, onehot_mask])

	#Compute PWMs (Nucleotide-wise Softmax over the size-4 axis)
	pwm_1 = Softmax(axis=-2, name='pwm_1')(pwm_logits_1)
	pwm_2 = Softmax(axis=-2, name='pwm_2')(pwm_logits_2)

	anneal_temp = None
	if anneal_pwm_logits :
		#Backend variable returned to the caller: 1.0 gives a uniform PWM, 0.0 the generated PWM
		anneal_temp = K.variable(1.0)

		def _interpolate_to_uniform(x) :
			return (1. - anneal_temp) * x + anneal_temp * 0.25

		def _pwm_to_logits(x) :
			#Clip exact 0 / 1 probabilities (a saturated softmax once anneal_temp nears 0):
			#otherwise log(0) = -inf or log(1 / 0) = +inf, which can become NaN downstream.
			#NOTE(review): log(p / (1 - p)) is an element-wise binary logit, not an inverse of
			#the softmax; if sample_func applies a softmax over axis -2 the resulting
			#distribution differs from the interpolated PWM — confirm this is intended.
			x = K.clip(x, K.epsilon(), 1. - K.epsilon())
			return K.log(x / (1. - x))

		interpolated_pwm_1 = Lambda(_interpolate_to_uniform)(pwm_1)
		interpolated_pwm_2 = Lambda(_interpolate_to_uniform)(pwm_2)

		pwm_logits_1 = Lambda(_pwm_to_logits)(interpolated_pwm_1)
		pwm_logits_2 = Lambda(_pwm_to_logits)(interpolated_pwm_2)

	#Sample proper One-hot coded sequences from PWMs
	sample_func = sample_pwm
	if validation_sample_mode == 'sample' :
		sample_func = sample_pwm_only

	if use_samples :
		def _tile_samples(x) :
			#(batch, L, 4, 1) -> (n_samples * batch, L, 4, 1): whole batch repeated n_samples times
			return K.tile(x, [n_samples, 1, 1, 1])

		def _to_sample_axis(x) :
			#Undo the tiling layout: (n_samples * batch, L, 4, 1) -> (batch, n_samples, L, 4, 1)
			return K.permute_dimensions(K.reshape(x, (n_samples, batch_size, seq_length, 4, 1)), (1, 0, 2, 3, 4))

		pwm_logits_upsampled_1 = Lambda(_tile_samples)(pwm_logits_1)
		pwm_logits_upsampled_2 = Lambda(_tile_samples)(pwm_logits_2)
		sampled_onehot_mask = Lambda(_tile_samples)(onehot_mask)

		sampled_pwm_1 = Lambda(sample_func, name='pwm_sampler_1')(pwm_logits_upsampled_1)
		sampled_pwm_1 = Lambda(_to_sample_axis)(sampled_pwm_1)

		sampled_pwm_2 = Lambda(sample_func, name='pwm_sampler_2')(pwm_logits_upsampled_2)
		sampled_pwm_2 = Lambda(_to_sample_axis)(sampled_pwm_2)

		sampled_onehot_mask = Lambda(_to_sample_axis)(sampled_onehot_mask)
	else :
		sampled_pwm_1 = Lambda(sample_func, name='pwm_sampler_1')(pwm_logits_1)
		sampled_pwm_2 = Lambda(sample_func, name='pwm_sampler_2')(pwm_logits_2)
		sampled_onehot_mask = onehot_mask

	generator_model = Model(
		inputs=[sequence_class_input] + generator_inputs,
		outputs=[
			sequence_class,
			pwm_logits_1,
			pwm_logits_2,
			pwm_1,
			pwm_2,
			sampled_pwm_1,
			sampled_pwm_2,
			onehot_mask,
			sampled_onehot_mask,
		] + extra_outputs
	)

	if sequence_templates is not None :
		initialize_sequence_templates(generator_model, sequence_templates)

	#Lock all generator layers except policy layers
	for generator_layer in generator_model.layers :
		generator_layer.trainable = 'policy' in generator_layer.name

	if anneal_pwm_logits :
		return 'genesis_generator', generator_model, anneal_temp
	return 'genesis_generator', generator_model