In [None]:
# Switch the working directory so the relative paths used later in this notebook
# (e.g. 'inputs/', 'preview_input/') resolve against this folder.
%cd '/content/drive/MyDrive/Colab Notebooks/1/'

/content/drive/MyDrive/Colab Notebooks/1


In [None]:
!pip install tensorflow_addons

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow.keras.layers as layers
import matplotlib.pyplot as plt
from datetime import datetime
import argparse
import cv2 as cv
import os
import numpy as np
import shutil

In [None]:
# generator

# Shape (height, width, channels) of a single network input; read by Generator
# when it builds its Keras Input layer.
input_shape = (256, 256, 1)

def downsample(filters, size, apply_batchnorm=True):
	"""Build one encoder block that halves the spatial resolution.

	The block is a stride-2 'same'-padded Conv2D (no bias, weights drawn from
	N(0, 0.02)), optionally followed by BatchNormalization, then LeakyReLU.

	Args:
		filters: number of convolution output channels.
		size: convolution kernel size.
		apply_batchnorm: when False the BatchNormalization layer is left out.

	Returns:
		A tf.keras.Sequential containing the layers described above.
	"""
	weight_init = tf.random_normal_initializer(0., 0.02)

	block = tf.keras.Sequential()
	conv = tf.keras.layers.Conv2D(
		filters, size,
		strides=2,
		padding='same',
		kernel_initializer=weight_init,
		use_bias=False)
	block.add(conv)

	if apply_batchnorm:
		block.add(tf.keras.layers.BatchNormalization())

	block.add(tf.keras.layers.LeakyReLU())
	return block

def upsample(filters, size, apply_dropout=False):
	"""Build one decoder block that doubles the spatial resolution.

	The block is a stride-2 'same'-padded Conv2DTranspose (no bias, weights
	drawn from N(0, 0.02)), BatchNormalization, an optional Dropout(0.5) and a
	final ReLU.

	Args:
		filters: number of transposed-convolution output channels.
		size: kernel size of the transposed convolution.
		apply_dropout: when True a Dropout(0.5) layer is inserted before the ReLU.

	Returns:
		A tf.keras.Sequential containing the layers described above.
	"""
	weight_init = tf.random_normal_initializer(0., 0.02)

	block = tf.keras.Sequential()
	deconv = tf.keras.layers.Conv2DTranspose(
		filters, size,
		strides=2,
		padding='same',
		kernel_initializer=weight_init,
		use_bias=False)
	block.add(deconv)
	block.add(tf.keras.layers.BatchNormalization())

	if apply_dropout:
		block.add(tf.keras.layers.Dropout(0.5))

	block.add(tf.keras.layers.ReLU())
	return block

class Generator(tf.keras.models.Model):
	"""Encoder-decoder generator with skip connections.

	The network is assembled once in __init__ as a functional Keras model stored
	in `self.model`; `call` simply forwards to it. Input has shape `input_shape`
	and the output is a single-channel tanh image of the same height and width.
	"""

	def __init__(self):
		super().__init__()

		output_channels = 1
		inputs = layers.Input(shape=input_shape)

		# Encoder: eight stride-2 blocks, (bs, 256, 256, 1) -> (bs, 1, 1, 512).
		# Only the first block skips batch normalisation.
		down_stack = [downsample(64, 4, apply_batchnorm=False)]  # (bs, 128, 128, 64)
		down_stack += [
			downsample(n_filters, 4)
			for n_filters in (128, 256, 512, 512, 512, 512, 512)
		]

		# Decoder: seven stride-2 blocks; the first three use dropout.
		# After concatenation with the matching skip the channel count doubles,
		# e.g. (bs, 2, 2, 1024) after the first block, (bs, 128, 128, 128) after the last.
		up_stack = [upsample(512, 4, apply_dropout=True) for _ in range(3)]
		up_stack += [upsample(n_filters, 4) for n_filters in (512, 256, 128, 64)]

		# Final transposed convolution back to full resolution: (bs, 256, 256, 1)
		last = tf.keras.layers.Conv2DTranspose(
			output_channels, 4,
			strides=2,
			padding='same',
			kernel_initializer=tf.random_normal_initializer(0., 0.02),
			activation='tanh')

		# Run the encoder and remember every intermediate activation.
		x = inputs
		encoder_outputs = []
		for down in down_stack:
			x = down(x)
			encoder_outputs.append(x)

		# The bottleneck itself is not a skip; walk the remaining ones deepest first.
		skips = encoder_outputs[-2::-1]

		# Run the decoder, concatenating each upsampled tensor with its skip.
		for up, skip in zip(up_stack, skips):
			x = up(x)
			x = tf.keras.layers.Concatenate()([x, skip])

		x = last(x)

		self.model = tf.keras.Model(inputs=inputs, outputs=x)

	@tf.function
	def call(self, input, training=True):
		"""Run the wrapped functional model; `training` defaults to True."""
		return self.model(input, training=training)


# Disabled manual smoke test: change the condition to True to build a Generator,
# push a dummy batch through it and print the output shape and model summary.
if False:
	gen = Generator()

	# The random batch is immediately replaced by the all-ones batch below.
	# NOTE(review): both tensors are created without the trailing channel axis that
	# input_shape declares — presumably Keras reconciles the rank; confirm before enabling.
	x = tf.random.uniform((32, input_shape[0], input_shape[1]), minval=-1, maxval=1)
	x = tf.ones((64, input_shape[0], input_shape[1]))


	y = gen(x)
	print(y.shape)

	print(gen.model.summary())

In [None]:
# discriminator

# Shape (height, width, channels) of each discriminator input image. This repeats
# the value assigned in the generator cell, so the two cells agree on input size.
input_shape = (256, 256, 1)

def downsample(filters, size, apply_batchnorm=True):
	"""Build one stride-2 convolution block that halves the spatial resolution.

	This re-declares the helper of the same name from the generator cell with the
	same behaviour: Conv2D (stride 2, 'same' padding, no bias, N(0, 0.02)
	weights), optional BatchNormalization, then LeakyReLU.

	Args:
		filters: number of convolution output channels.
		size: convolution kernel size.
		apply_batchnorm: when False the BatchNormalization layer is left out.

	Returns:
		A tf.keras.Sequential containing the layers described above.
	"""
	weight_init = tf.random_normal_initializer(0., 0.02)

	block = tf.keras.Sequential()
	conv = tf.keras.layers.Conv2D(
		filters, size,
		strides=2,
		padding='same',
		kernel_initializer=weight_init,
		use_bias=False)
	block.add(conv)

	if apply_batchnorm:
		block.add(tf.keras.layers.BatchNormalization())

	block.add(tf.keras.layers.LeakyReLU())
	return block

def upsample(filters, size, apply_dropout=False):
	"""Build one stride-2 transposed-convolution block that doubles the resolution.

	This re-declares the helper of the same name from the generator cell with the
	same behaviour; nothing else in this cell calls it. Layers: Conv2DTranspose
	(stride 2, 'same' padding, no bias, N(0, 0.02) weights), BatchNormalization,
	optional Dropout(0.5), ReLU.

	Args:
		filters: number of transposed-convolution output channels.
		size: kernel size of the transposed convolution.
		apply_dropout: when True a Dropout(0.5) layer is inserted before the ReLU.

	Returns:
		A tf.keras.Sequential containing the layers described above.
	"""
	weight_init = tf.random_normal_initializer(0., 0.02)

	block = tf.keras.Sequential()
	deconv = tf.keras.layers.Conv2DTranspose(
		filters, size,
		strides=2,
		padding='same',
		kernel_initializer=weight_init,
		use_bias=False)
	block.add(deconv)
	block.add(tf.keras.layers.BatchNormalization())

	if apply_dropout:
		block.add(tf.keras.layers.Dropout(0.5))

	block.add(tf.keras.layers.ReLU())
	return block

class Discriminator(tf.keras.models.Model):
	"""Patch-based discriminator over an (input image, target image) pair.

	The two images are concatenated along the channel axis and reduced by
	convolutions to a (bs, 30, 30, 1) map. The final convolution has no
	activation, so the map holds raw scores. The functional model is stored in
	`self.model` and expects its inputs as the list `[input_image, target_image]`.
	"""

	def __init__(self):
		super().__init__()

		initializer = tf.random_normal_initializer(0., 0.02)
		inp = tf.keras.layers.Input(shape=input_shape, name='input_image')
		tar = tf.keras.layers.Input(shape=input_shape, name='target_image')

		# (bs, 256, 256, channels*2)
		features = tf.keras.layers.concatenate([inp, tar])

		# Three stride-2 blocks: (bs, 128, 128, 64) -> (bs, 64, 64, 128) -> (bs, 32, 32, 256).
		# Only the first one skips batch normalisation.
		for n_filters, use_batchnorm in ((64, False), (128, True), (256, True)):
			features = downsample(n_filters, 4, use_batchnorm)(features)

		# Pad to (bs, 34, 34, 256), then an unpadded 4x4 conv gives (bs, 31, 31, 512).
		features = tf.keras.layers.ZeroPadding2D()(features)
		features = tf.keras.layers.Conv2D(
			512, 4,
			strides=1,
			kernel_initializer=initializer,
			use_bias=False)(features)
		features = tf.keras.layers.BatchNormalization()(features)
		features = tf.keras.layers.LeakyReLU()(features)

		# Pad to (bs, 33, 33, 512), then the final 4x4 conv gives (bs, 30, 30, 1).
		features = tf.keras.layers.ZeroPadding2D()(features)
		patch_scores = tf.keras.layers.Conv2D(
			1, 4,
			strides=1,
			kernel_initializer=initializer)(features)

		self.model = tf.keras.Model(inputs=[inp, tar], outputs=patch_scores)

	@tf.function
	def call(self, input, training=True):
		"""Run the wrapped functional model; `training` defaults to True."""
		return self.model(input, training=training)


# Disabled manual smoke test: change the condition to True to build a Discriminator,
# feed it the same dummy batch as both input and target, and print the result and summary.
if False:
	disc = Discriminator()

	# The random batch is immediately replaced by the all-ones batch below.
	# NOTE(review): both tensors are created without the trailing channel axis that
	# input_shape declares — presumably Keras reconciles the rank; confirm before enabling.
	x = tf.random.uniform((32, input_shape[0], input_shape[1]), minval=-1, maxval=1)
	x = tf.ones((64, input_shape[0], input_shape[1]))


	y = disc([x,x])
	print(y)

	print(disc.model.summary())

In [None]:
# sketch

def skeletonize(img):
	"""Compute a morphological skeleton of a single-channel image.

	Each iteration opens the image with a 3x3 cross, keeps what the opening
	removed (image minus opening) as part of the skeleton, and then erodes the
	image. Iteration stops once the image has been eroded to all zeros.

	The loop also stops when an erosion no longer changes the image. OpenCV's
	default erosion border never lowers pixel values, so an image without any
	zero pixel converges to a non-zero constant and would otherwise loop forever.

	Args:
		img: single-channel image array (it is passed to cv.countNonZero).

	Returns:
		uint8 array with the accumulated skeleton, created with the shape of `img`.
	"""
	skel = np.zeros(img.shape, np.uint8)

	# Cross-shaped 3x3 structuring element
	element = cv.getStructuringElement(cv.MORPH_CROSS, (3,3))

	while True:
		# Open the image, then keep only what the opening removed
		opened = cv.morphologyEx(img, cv.MORPH_OPEN, element)
		residue = cv.subtract(img, opened)
		skel = cv.bitwise_or(skel, residue)

		# Erode for the next round and detect a fixed point of the erosion
		eroded = cv.erode(img, element)
		stalled = np.array_equal(eroded, img)
		img = eroded

		# Done when fully eroded, or when further erosion cannot change anything
		if stalled or cv.countNonZero(img) == 0:
			break
	return skel

def gradients(image, use_sobel=True):
	"""Return per-pixel gradient magnitude and direction of an image.

	Args:
		image: image array handed to the OpenCV derivative filters.
		use_sobel: True selects a 3x3 Sobel filter scaled by 0.5, False selects
			the Scharr filter.

	Returns:
		Tuple (magnitude, angle): magnitude is sqrt(gx^2 + gy^2) and angle is
		arctan2(gy, gx), both computed from 64-bit float derivatives.
	"""
	if use_sobel:
		def derivative(dx, dy):
			return cv.Sobel(image,cv.CV_64F,dx,dy,ksize=3,scale=0.5)
	else:
		def derivative(dx, dy):
			return cv.Scharr(image,cv.CV_64F,dx,dy)

	grad_x = derivative(1, 0)
	grad_y = derivative(0, 1)

	magnitude = np.sqrt(grad_x**2 + grad_y**2)
	angle = np.arctan2(grad_y, grad_x)
	return magnitude, angle

def heightmap_sketch(src):
	"""Derive a sparse line sketch from a heightmap image.

	Pipeline: blur -> gradient magnitude -> ridge filter -> blur/erode/threshold
	into a binary mask -> mask the source -> skeletonize -> drop faint pixels.

	Args:
		src: heightmap array; it is cast to uint8 before processing.

	Returns:
		The sketch image produced by the final to-zero threshold.
	"""
	src = src.astype(np.uint8)

	# Strong smoothing, then gradient magnitude (direction is not needed)
	smoothed = cv.GaussianBlur(src, (31, 31), 0)
	slope, _ = gradients(smoothed)
	slope = slope.astype(np.uint8)

	# Light smoothing followed by the ximgproc ridge detection filter
	slope = cv.GaussianBlur(slope, (5, 5), 0)
	ridge_filter = cv.ximgproc.RidgeDetectionFilter_create(ksize=3)
	ridges = ridge_filter.getRidgeFilteredImage(slope)

	# Smooth, thin slightly and binarise the ridge response
	ridges = cv.GaussianBlur(ridges, (7, 7), 0)
	ridges = erode(ridges, 1, cv.MORPH_ELLIPSE)
	_, ridge_mask = cv.threshold(ridges, 19, 255, cv.THRESH_BINARY)

	# Keep source heights only where the mask is set, then reduce to a skeleton
	masked_heights = cv.bitwise_and(src, ridge_mask)
	skeleton = skeletonize(masked_heights)

	# Zero out every pixel that is not above 25
	_, sketch = cv.threshold(skeleton, 25, 0, cv.THRESH_TOZERO)
	return sketch



def morph_shape(val):
	"""Map 0, 1 or 2 to the rectangle, cross or ellipse OpenCV shape constant.

	Any other value falls through and yields None.
	"""
	shapes_by_code = (cv.MORPH_RECT, cv.MORPH_CROSS, cv.MORPH_ELLIPSE)
	for code, shape in enumerate(shapes_by_code):
		if val == code:
			return shape

def morph(target, op, size, shape):
	"""Apply the morphological operation `op` to `target`.

	The structuring element has the given `shape`, side length 2*size + 1 and
	its anchor at (size, size).
	"""
	side = 2 * size + 1
	kernel = cv.getStructuringElement(shape, (side, side), (size, size))
	return cv.morphologyEx(target, op, kernel)

def erode(target, size, shape):
	"""Erode `target` with a structuring element of the given `shape`.

	The element has side length 2*size + 1 and its anchor at (size, size).
	"""
	side = 2 * size + 1
	kernel = cv.getStructuringElement(shape, (side, side), (size, size))
	return cv.erode(target, kernel)


def dilate(target, size, shape):
	"""Dilate `target` with a structuring element of the given `shape`.

	The element has side length 2*size + 1 and its anchor at (size, size).
	"""
	side = 2 * size + 1
	kernel = cv.getStructuringElement(shape, (side, side), (size, size))
	return cv.dilate(target, kernel)


def convert_to_labels(dir_src='out256_clip', dir_label='out256_clip_label'):
	"""Write a sketch image for every file found in a source directory.

	Each file in `dir_src` is read as an 8-bit grayscale image, converted with
	heightmap_sketch and written under the same file name into `dir_label`.
	Images are read and written with OpenCV, which this notebook already
	imports; the previous implementation referenced an `Image` name that was
	never imported and therefore failed with NameError.

	Args:
		dir_src: directory containing the heightmap images.
		dir_label: directory receiving the sketches; created if missing.

	Raises:
		ValueError: if a file in `dir_src` cannot be decoded as an image.
		IOError: if a sketch cannot be written.
	"""
	os.makedirs(dir_label, exist_ok=True)

	for file in os.listdir(dir_src):
		path_src = os.path.join(dir_src, file)
		path_dst = os.path.join(dir_label, file)

		# cv.imread signals failure by returning None instead of raising
		image = cv.imread(path_src, cv.IMREAD_GRAYSCALE)
		if image is None:
			raise ValueError(f'Could not read image: {path_src}')

		res = heightmap_sketch(image)

		# cv.imwrite signals failure by returning False instead of raising
		if not cv.imwrite(path_dst, res):
			raise IOError(f'Could not write image: {path_dst}')

In [None]:
def img_to_network(img):
	'''
		Casts image data to float32 and maps the 0..255 range into -1 to 1.
	'''
	as_float = tf.cast(img, tf.float32)
	return (as_float / 127.5) - 1.0

def restore_img(img):
	'''
		Converts image data from float in the -1 to 1 range back to uint8.
	'''
	rescaled = (img + 1.0) * 127.5
	return tf.cast(rescaled, tf.uint8)

def split_image(img, tile_shape):
	"""Rearrange one image tensor into a batch of tiles.

	Args:
		img: image tensor; assumed to be rank 3 (height, width, channels) since
			axes 0 and 2 of its shape are read — TODO confirm.
		tile_shape: two-element sequence giving the tile dimensions.

	Returns:
		A rank-4 tensor shaped [-1, tile_shape[1], tile_shape[0], channels].

	NOTE(review): the first reshape uses tile_shape[1] for axis 2 and the last
	one uses (tile_shape[1], tile_shape[0]); this is presumably only meant for
	square tiles and dimensions that divide evenly — verify before using it with
	other shapes.
	"""
	img_shape = tf.shape(img)
	# Split axis 1 into groups of tile_shape[1] columns
	tile_rows = tf.reshape(img, [img_shape[0], -1, tile_shape[1], img_shape[2]])
	# Bring the column-group axis to the front
	serial_tiles = tf.transpose(tile_rows, [1, 0, 2, 3])
	# Cut the result into individual tiles along the leading axis
	return tf.reshape(serial_tiles, [-1, tile_shape[1], tile_shape[0], img_shape[2]])

def heightmap_sketch_tensor(img):
	"""Run heightmap_sketch on the NumPy value of an eager tensor."""
	as_array = img.numpy()
	return heightmap_sketch(as_array)

def preprocess_dataset(img_path):
	"""Load one PNG heightmap and build its (heightmap, sketch) training pair.

	Args:
		img_path: path of a PNG file; it is decoded as single-channel uint8.

	Returns:
		Tuple (hm, hm_sketch): the heightmap mapped to the -1..1 range, and the
		sketch computed by heightmap_sketch_tensor, mapped the same way and
		given an extra axis at position 2.
	"""
	png_bytes = tf.io.read_file(img_path)
	img = tf.io.decode_png(png_bytes, channels=1, dtype=tf.uint8)

	hm = img_to_network(img)

	# The sketch is produced by OpenCV code, so it has to run through py_function
	sketch_u8 = tf.py_function(func=heightmap_sketch_tensor, inp=[img], Tout=tf.uint8)
	hm_sketch = tf.expand_dims(img_to_network(sketch_u8), axis=2)

	return hm, hm_sketch




def load_dataset(dataset_folder_paths):
	"""Build a shuffled tf.data pipeline over every PNG in the given folders.

	Args:
		dataset_folder_paths: iterable of directory paths. Paths are joined with
			os.path.join, so a trailing slash is optional.

	Returns:
		Tuple (ds, ds_len): a dataset yielding the pairs produced by
		preprocess_dataset, reshuffled on every iteration, and the number of
		image files found.

	Raises:
		ValueError: if no '.png' file is found in any of the folders.
	"""
	img_paths = []
	for folder_path in dataset_folder_paths:
		img_paths += [
			os.path.join(folder_path, img_name)
			for img_name in os.listdir(folder_path)
			if img_name.endswith('.png')
		]

	ds_len = len(img_paths)
	if ds_len == 0:
		# A zero-sized shuffle buffer would otherwise fail with an obscure error
		raise ValueError(f'No .png files found in: {list(dataset_folder_paths)}')

	AUTOTUNE = tf.data.experimental.AUTOTUNE

	ds = tf.data.Dataset.from_tensor_slices(img_paths)
	# Buffer covers the whole path list, so every epoch sees a full shuffle
	ds = ds.shuffle(ds_len, reshuffle_each_iteration=True)
	ds = ds.map(preprocess_dataset, AUTOTUNE)
	return ds, ds_len

In [None]:
class DataAugmentation():
	"""Optional random augmentation of (heightmap, sketch) pairs.

	Augmentation is off by default, which matches the previously hard-coded
	behaviour; pass `enabled=True` to switch it on without editing the code.

	When enabled, a single uniform draw `p` in [0, 1) selects the augmentation,
	so the thresholds nest:
		p < identity_map_prob  -> both maps replaced by a constant -1 map
		p < merge_aug_prob     -> random 90-degree rotation, then merge with the
		                          previously returned pair
		p < rot_aug_prob       -> random 90-degree rotation only
		otherwise              -> pair returned unchanged
	"""

	def __init__(self, enabled=False):
		"""Set up probabilities and the 'previous pair' used for merging.

		Args:
			enabled: when False (default) `augment` returns its inputs untouched.
		"""
		shape = (256,256,1)
		# Start with constant -1 maps so the first merge has something to merge with
		self.prev_hm = tf.constant(-1, shape=shape, dtype=tf.float32)
		self.prev_hm_sketch = tf.constant(-1, shape=shape, dtype=tf.float32)

		self.rot_aug_prob = 0.35
		self.merge_aug_prob = 0.15
		self.identity_map_prob = 0.08

		self.enabled = enabled


	def augment(self, hm, hm_sketch):
		"""Return the (possibly augmented) pair; a no-op unless enabled."""
		if not self.enabled:
			return hm, hm_sketch

		p = tf.random.uniform(shape=(), minval=0, maxval=1)

		# Returns early, so the stored previous pair is not updated in this case
		if p < self.identity_map_prob:
			return self.aug_identity_map(hm, hm_sketch)

		if p < self.rot_aug_prob:
			hm, hm_sketch = self.aug_rotate_rand(hm, hm_sketch)

		if p < self.merge_aug_prob:
			hm, hm_sketch = self.aug_merge_maps(hm, hm_sketch)

		self.prev_hm = hm
		self.prev_hm_sketch = hm_sketch
		return hm, hm_sketch

	def aug_identity_map(self, hm, hm_sketch):
		"""Return a constant -1 map of hm's shape for both heightmap and sketch."""
		value = tf.constant(-1, shape=hm.shape, dtype=tf.float32)
		return value, value

	def aug_rotate_rand(self, hm, hm_sketch):
		"""Rotate both maps by the same random multiple (1, 2 or 3) of 90 degrees."""
		rand_k = tf.random.uniform(shape=[], minval=1, maxval=4, dtype=tf.int32) # k = <1,2,3>
		hm = tf.image.rot90(hm, k=rand_k)
		hm_sketch = tf.image.rot90(hm_sketch, k=rand_k)

		return hm, hm_sketch

	def aug_merge_maps(self, hm, hm_sketch):
		"""Merge the pair with the previously returned pair by element-wise maximum."""
		hm = tf.math.maximum(hm, self.prev_hm)
		hm_sketch = tf.math.maximum(hm_sketch, self.prev_hm_sketch)
		# TODO actually math.max is theoretically less correct than doing sketch on combined but results seem better
		#hm_sketch = tf.map_fn(data.heightmap_sketch_tensor_network, data.restore_img(hm), dtype=tf.float32)

		return hm, hm_sketch

In [None]:
def train(dataset, ds_len, epochs=100, preview_epochs=0, model_save_freq_epochs=30, logs_dir=None):
	"""Train the sketch-to-heightmap generator against the patch discriminator.

	Every iteration draws one batch, updates the discriminator and then the
	generator. Mean losses per epoch are printed and written as TensorBoard
	scalars. Models are saved and preview images written at the given epoch
	intervals.

	Args:
		dataset: tf.data.Dataset yielding (hm, hm_sketch) pairs.
		ds_len: number of samples; iterations per epoch is ds_len // batch size.
		epochs: number of epochs to run.
		preview_epochs: write preview images every this many epochs; 0 disables
			previews. Previews use the first sketch of the last batch plus
			'preview_input/input1.png' .. 'input3.png'.
		model_save_freq_epochs: save both Keras models every this many epochs
			under '{logs_dir}/models/'.
		logs_dir: directory for summaries, previews and models. It is used to
			build every output path, so the default of None is presumably not
			usable — pass a real directory.
	"""
	train_summary_writer = tf.summary.create_file_writer(logs_dir)
	metric_gen_loss = tf.keras.metrics.Mean('gen_loss', dtype=tf.float32)
	metric_disc_loss = tf.keras.metrics.Mean('disc_loss', dtype=tf.float32)

	gen = Generator()
	disc = Discriminator()

	aug = DataAugmentation()

	learning_rate = 0.0002
	gen_optimizer = tf.keras.optimizers.Adam(learning_rate, 0.5)
	disc_optimizer = tf.keras.optimizers.Adam(learning_rate, 0.5)

	# The discriminator's last layer has no activation, hence from_logits=True
	bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
	l1 = tf.keras.losses.MeanAbsoluteError()

	batch_size = 4

	# Repeat before batching so every batch is full and the iterator never runs out
	batches = dataset.repeat().batch(batch_size).as_numpy_iterator()

	train_iter = ds_len // batch_size

	@tf.function
	def train_generator(hm, hm_sketch):
		"""One generator update: fool the discriminator and stay close to hm (L1)."""
		with tf.GradientTape() as tape:
			generated_hm = gen(hm_sketch)

			veracity_patch_gen = disc([hm_sketch, generated_hm])

			# The generator wants its output to be classified as real (label 1)
			gen_loss = bce(tf.ones_like(veracity_patch_gen), veracity_patch_gen)
			l1_loss = l1(hm, generated_hm)

			generator_loss = gen_loss + (100 * l1_loss)

		gradient = tape.gradient(generator_loss, gen.trainable_variables)
		gen_optimizer.apply_gradients(zip(gradient, gen.trainable_variables))

		metric_gen_loss(generator_loss)
		return generator_loss

	@tf.function
	def train_discriminator(hm, hm_sketch):
		"""One discriminator update: real pairs -> 1, generated pairs -> 0."""
		with tf.GradientTape() as tape:
			generated_hm = gen(hm_sketch)

			veracity_patch_real = disc([hm_sketch, hm])
			veracity_patch_gen = disc([hm_sketch, generated_hm])

			real_loss = bce(tf.ones_like(veracity_patch_real), veracity_patch_real)
			# Generated samples must be labelled fake (0). Labelling them real as
			# well lets the discriminator win by always answering "real", which
			# drives its loss to ~0 and removes any training signal for the generator.
			fake_loss = bce(tf.zeros_like(veracity_patch_gen), veracity_patch_gen)

			discriminator_loss = real_loss + fake_loss

		gradient = tape.gradient(discriminator_loss, disc.trainable_variables)
		disc_optimizer.apply_gradients(zip(gradient, disc.trainable_variables))

		metric_disc_loss(discriminator_loss)
		return discriminator_loss

	print(f'training started {datetime.now().strftime("%H:%M:%S")}')
	print(f'Iterations per epoch: {train_iter}')

	start_epoch = 1
	for epoch in range(start_epoch, epochs + 1):
		print(f'Epoch {epoch}')
		for it in range(train_iter):
			hm, hm_sketch = next(batches)
			hm, hm_sketch = aug.augment(hm, hm_sketch)
			train_discriminator(hm, hm_sketch)
			train_generator(hm, hm_sketch)

		print(f'gloss:{metric_gen_loss.result()} dloss:{metric_disc_loss.result()}')

		with train_summary_writer.as_default():
			tf.summary.scalar('gen_loss', metric_gen_loss.result(), step=epoch)
			tf.summary.scalar('disc_loss', metric_disc_loss.result(), step=epoch)
			# Start the running means from scratch for the next epoch
			metric_gen_loss.reset_states()
			metric_disc_loss.reset_states()


		if epoch % model_save_freq_epochs == 0:
			print(f'Saving models to disk... time: {datetime.now().strftime("%H:%M:%S")}')
			gen.model.save(f'{logs_dir}/models/gen_{epoch}')
			disc.model.save(f'{logs_dir}/models/disc_{epoch}')


		if preview_epochs and epoch % preview_epochs == 0:
			# First sketch of the most recent training batch ...
			tests = [hm_sketch[0]]

			# ... plus three fixed sketches read from disk
			for i in range(1, 4):
				sketch = tf.io.read_file(f'preview_input/input{i}.png')
				sketch = tf.io.decode_png(sketch, channels=1, dtype=tf.uint8)
				sketch = img_to_network(sketch)
				tests.append(sketch)

			for i, sketch in enumerate(tests):
				# Add the batch axis, generate, and store the result as an 8-bit image
				sketch = tf.expand_dims(sketch, axis=0)
				gan_hm = gen(sketch)[0]
				cv.imwrite(f'{logs_dir}/output_e{epoch}_{i}.png', restore_img(gan_hm).numpy())





if __name__ == '__main__':
	# Folders scanned for training PNGs
	inputs = [
		#'/media/krzysztof/g/Earthshaper/out256/',
		#'/media/krzysztof/g/Earthshaper/out256_clip/',
		'inputs/'
	]

	experiment_label = '1'
	# save run on colab
	logs_dir=f'/content/train_logs/{experiment_label}'
	epochs=2000
	preview_epochs=5
	model_save_freq_epochs=1000


	# Start every run from an empty log directory
	shutil.rmtree(logs_dir, ignore_errors=True)

	# makedirs also creates missing parent directories (e.g. '/content/train_logs'),
	# which os.mkdir would reject with FileNotFoundError on a fresh runtime.
	os.makedirs(f"{logs_dir}/models", mode=0o777, exist_ok=True)

	print(f'Loading data from: {inputs}')
	ds, ds_len = load_dataset(inputs)

	train(ds, ds_len,
		epochs=epochs,
		preview_epochs=preview_epochs,
		model_save_freq_epochs=model_save_freq_epochs,
		logs_dir=logs_dir)


Loading data from: ['inputs/']
training started 13:13:41
Iterations per epoch: 59
Epoch 1
gloss:30.084659576416016 dloss:0.06326471269130707
Epoch 2
gloss:26.013343811035156 dloss:7.237133104354143e-05
Epoch 3
gloss:26.00662612915039 dloss:3.3560361771378666e-05
Epoch 4
gloss:26.004165649414062 dloss:1.8850974811357446e-05
Epoch 5
gloss:26.003734588623047 dloss:1.1911878573300783e-05
Epoch 6
gloss:26.003061294555664 dloss:8.118426194414496e-06
Epoch 7
gloss:26.00372886657715 dloss:5.832075203215936e-06
Epoch 8
gloss:26.003093719482422 dloss:4.3587333493633196e-06
Epoch 9
gloss:26.003725051879883 dloss:3.356716661073733e-06
Epoch 10
gloss:26.00260353088379 dloss:2.6497143608139595e-06
Epoch 11
gloss:26.00334930419922 dloss:2.134182068402879e-06
Epoch 12
gloss:26.002756118774414 dloss:1.747896931192372e-06
Epoch 13
gloss:26.003704071044922 dloss:1.4524757716571912e-06
Epoch 14
gloss:26.0020751953125 dloss:1.2222537861816818e-06
Epoch 15
gloss:26.003860473632812 dloss:1.0400545988886734e-

KeyboardInterrupt: ignored