In [None]:
import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf
import skimage.io
from segmentation import segmentation
import tensorflow.keras.backend as K
import glob
from segmentation import plotHist
import pickle

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# --- Texture-synthesis settings ---------------------------------------------
# These are handed to the generateImageFromStyle / generateMultiImagesFromStyle
# helpers further down the notebook.
H = 448            # passed as H (presumably output height in pixels)
W = 448            # passed as W (presumably output width in pixels)
N_CH = 3           # passed as the channel count (C)
S_WEIGHT = 1e-6    # passed as S_weight (presumably the style-loss weight)
TV_WEIGHT = 1e-8   # passed as V_weight (presumably total-variation weight)
ITERS = 1600       # iteration budget for the single-image demo runs
EXP_ITERS = 1200   # iteration budget for the multi-image experiments
EXP_INTERS = EXP_ITERS  # original (misspelled) name, kept so existing cells still work

# --- Input / figure locations -----------------------------------------------
SMALL_PATH = "../data/small/"
LARGE_PATH = "../data/large/"
LARGE_APTH = LARGE_PATH  # original (misspelled) name, kept for backward compatibility
FIGURE_PATH = "../data/figures/"

# Style image used by the single-image demo runs.
ORIG_SIM3_IMG_PATH = SMALL_PATH + "sim3.png"
# Use a larger default font for every figure produced by this notebook.
plt.rcParams['font.size'] = 16

# Where generated images and their derived artefacts are written.
OUT_DIR = "../data/result/"
CLEANED_DIR = OUT_DIR + "cleaned/"
HIST_ORIG_DIR = OUT_DIR + "hist_ori/"
HIST_ORIG_RES = OUT_DIR + "hist_res/"

# Cache file for the batch of regenerated images, and the single sample image
# that the segmentation demo reads back.
PICKLE_PATH = "../data/figures/regen_imgs.p"
SAMPLE_FN = FIGURE_PATH + "sim_gen.png"
# How many times the list of style images is repeated for batch generation.
num_for_each = 3

# File names of the style images (relative to the data directories).
sim_path1, sim_path2, sim_path3 = "sim1.png", "sim2.png", "sim3.png"
copper_path1, copper_path2 = "300.3.png", "300.9.png"
copper_path3, copper_path4 = "600.3.png", "600.9.png"
al_path1, al_path2 = "200.3.png", "600.3_Al.png"

# Bare file names, in the order used for every figure in this notebook.
images_for_gen = [
    sim_path1, sim_path2, sim_path3,
    copper_path1, copper_path2, copper_path3, copper_path4,
    al_path1, al_path2,
]
# Independent copy: images_for_gen is rebound to full paths just below.
base_names = list(images_for_gen)
# At this point images_for_gen still holds bare file names.
result_images = [OUT_DIR + name for name in images_for_gen]

# One full input path per style image (used for side-by-side comparisons) ...
images_for_compare = [SMALL_PATH + name for name in images_for_gen]
# ... and the same list repeated num_for_each times for batch generation.
images_for_gen = images_for_compare * num_for_each

# Column titles, aligned index-for-index with base_names.
gen_titles = ["Simulation"] * 3 + [
    "Copper (300, 3)",
    "Copper (300, 9)",
    "Copper (600, 3)",
    "Copper (600, 9)",
    "Aluminium (200, 3)",
    "Aluminium (600, 3)",
]

def readImages(filenames):
    """Read every image file in ``filenames`` with ``skimage.io.imread``.

    Args:
        filenames: iterable of image file paths.

    Returns:
        A list with one loaded image per path, in the same order as
        ``filenames``.
    """
    from skimage.io import imread

    return [imread(path) for path in filenames]


In [None]:
# Write a uniform-noise image to the figures directory.
# Fixes relative to the previous version of this cell:
#   * skimage.io has no `imwsave`; the writer is `imsave`.
#   * image arrays are indexed (rows, cols, channels), i.e. (H, W, N_CH).
#   * uniform(0, 255) truncated to uint8 never produced the value 255;
#     randint with an exclusive upper bound of 256 covers the full 0..255 range.
randimg = np.random.randint(0, 256, size=(H, W, N_CH), dtype=np.uint8)
skimage.io.imsave(FIGURE_PATH + "rand_img.png", randimg)

In [None]:
# Single-image synthesis with the TextureSynthesistf2 implementation.
# get_mid_res=True additionally returns per-snapshot losses and images, which
# the next cell plots (saved there as 'iters_adam.eps', so this is presumably
# the Adam-based variant — confirm against TextureSynthesistf2).
from TextureSynthesistf2 import generateImageFromStyle

gen, loss, imgs = generateImageFromStyle(
    H, W, N_CH, S_WEIGHT, TV_WEIGHT,
    style_path=ORIG_SIM3_IMG_PATH,
    iters=ITERS,
    get_mid_res=True,
)

In [None]:
# One-row strip of the intermediate results from the first synthesis run.
# Snapshot i is labelled as iteration i * 200 (assumes the helper records one
# snapshot every 200 iterations — confirm against TextureSynthesistf2).
n_cols = int(ITERS / 200) + 1
title_fmt = "n={:d}, loss={:.2f}"

fig = plt.figure(figsize=(36, 5))
fig.set_dpi(300)
for counter, step_loss in enumerate(loss, start=1):
    i = counter - 1
    fig.add_subplot(1, n_cols, counter)
    plt.axis('off')
    # Each loss entry exposes .numpy() to obtain a plain number for formatting.
    plt.title(title_fmt.format(i * 200, step_loss.numpy()))
    plt.imshow(imgs[i])

plt.savefig(FIGURE_PATH + 'iters_adam.eps', transparent=True, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Same experiment with the TextureSynthesis implementation. Note that this
# import rebinds generateImageFromStyle, replacing the TextureSynthesistf2
# version imported earlier in the notebook.
from TextureSynthesis import generateImageFromStyle

gen_2, loss_2, imgs_2 = generateImageFromStyle(
    H, W, N_CH, S_WEIGHT, TV_WEIGHT,
    style_path=ORIG_SIM3_IMG_PATH,
    iters=ITERS,
    get_mid_res=True,
)

# Persist the final image; the segmentation demo reads it back from SAMPLE_FN.
skimage.io.imsave(SAMPLE_FN, gen_2)

In [None]:
# One-row strip of the intermediate results from the second synthesis run.
# Snapshot i is labelled as iteration i * 200 (assumes the helper records one
# snapshot every 200 iterations — confirm against TextureSynthesis).
n_cols = int(ITERS / 200) + 1
title_fmt = "n={:d}, loss={:.2f}"

fig = plt.figure(figsize=(36, 5))
fig.set_dpi(300)
for counter, step_loss in enumerate(loss_2, start=1):
    i = counter - 1
    fig.add_subplot(1, n_cols, counter)
    plt.axis('off')
    # Unlike the first run, these loss entries are formatted directly.
    plt.title(title_fmt.format(i * 200, step_loss))
    plt.imshow(imgs_2[i])

# FIGURE_PATH already ends with '/', so no leading slash on the file name
# (the previous '/iters_lbfgsb.eps' produced a '//' in the path).
plt.savefig(FIGURE_PATH + 'iters_lbfgsb.eps', transparent=True, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Segment the sample generated image and show each intermediate stage.
# The third return value is not needed here. It is bound to `_hist` rather than
# `__`: in IPython `__` is the "second-to-last output" history variable, and
# assigning to it stops IPython from updating `_`, `__` and `___`.
cleaned, mid_res, _hist = segmentation(SAMPLE_FN, get_mid_res=True)

# One title per entry of mid_res, in order.
titles = ["Generated Microstructure", "Gradient Magnitude", "Thresholding", "Skeletonization", "Overlay Visualization"]

fig = plt.figure(figsize=(23, 5))
fig.set_dpi(300)
n_stages = len(mid_res)
for i, stage_img in enumerate(mid_res):
    fig.add_subplot(1, n_stages, i + 1)
    plt.axis('off')
    plt.title(titles[i])
    plt.imshow(stage_img, cmap=None)

plt.savefig(FIGURE_PATH + 'seg.eps', transparent=True, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Generate num_for_each images per style image and cache the result on disk,
# so the plotting cell can be re-run without repeating the generation.
from TextureSynthesis import generateMultiImagesFromStyle

generated_imgs = generateMultiImagesFromStyle(
    images_for_gen,
    H=H, W=W, C=N_CH,
    S_weight=S_WEIGHT, V_weight=TV_WEIGHT,
    iters=EXP_INTERS,
)

# Context manager guarantees the file is flushed and closed; the previous
# bare open() left the handle to be closed by the garbage collector.
with open(PICKLE_PATH, "wb") as pickle_file:
    pickle.dump(generated_imgs, pickle_file)


In [None]:
# Reload the cached generated images and plot them under their originals.
# NOTE: pickle.load can execute arbitrary code; only load a pickle file that
# this notebook wrote itself.
with open(PICKLE_PATH, "rb") as pickle_file:
    generated_imgs = pickle.load(pickle_file)
orig_imgs = readImages(images_for_compare)
images2plot = orig_imgs + generated_imgs

# Grid layout: the first row holds the originals, the remaining rows hold the
# generated images. Deriving the shape from the data (instead of a fixed 4 rows)
# keeps the figure correct if the number of generated images per style changes.
n_cols = len(orig_imgs)
n_rows = -(-len(images2plot) // n_cols)  # ceiling division

fig = plt.figure(figsize=(40, 18))
fig.set_dpi(300)
for counter, img in enumerate(images2plot, start=1):
    i = counter - 1
    fig.add_subplot(n_rows, n_cols, counter)
    # Row labels, drawn to the left of the first panel of each group.
    if counter == 1:
        plt.text(-60, 200, "Original", size=18, rotation="vertical", va="center")
    if counter == n_cols + 1:
        plt.text(-60, 200, "Generated", size=18, rotation="vertical", va="center")
    # Only the leading panels (one per entry of gen_titles) get a title.
    if i < len(gen_titles):
        plt.title(gen_titles[i])
    plt.axis('off')
    plt.imshow(img, cmap=None)

plt.savefig(FIGURE_PATH + 'F5.pdf', transparent=True, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Generate one image per style image and let the helper write the results to
# OUT_DIR. This run uses a V_weight five times larger than TV_WEIGHT.
from TextureSynthesis import generateMultiImagesFromStyle

generated = generateMultiImagesFromStyle(
    images_for_compare,
    H=H, W=W, C=N_CH,
    S_weight=S_WEIGHT,
    V_weight=TV_WEIGHT * 5,
    iters=EXP_INTERS,
    output_dir=OUT_DIR,
)

In [None]:
# Collect every PNG directly inside OUT_DIR (glob does not descend into
# sub-directories such as CLEANED_DIR).
filenames = glob.glob(OUT_DIR + "*.png")
# The shell magic below interpolates a single string, so join the paths with
# spaces. NOTE(review): paths are not quoted — a path containing whitespace
# would be split into several arguments by the shell.
filenames_str = " ".join(filenames)
# Run the segmentation script on all collected files; -o is given CLEANED_DIR
# (presumably the output directory — confirm against segmentation.py).
!python segmentation.py -i {filenames_str} -o {CLEANED_DIR}

In [None]:
# For every style image, segment both the original and the generated image,
# overlay their histograms in one panel, and keep the overlay visualisations
# (entry 4 of mid_res) for the comparison figure in the next cell.
vis_orig = []
vis_res = []

fig = plt.figure(figsize=(60, 5))
fig.set_dpi(300)
for counter, fn in enumerate(base_names, start=1):
    fig.add_subplot(1, len(base_names), counter)
    # Axis labels only on the first panel.
    if counter == 1:
        plt.xlabel("Grain size")
        plt.ylabel("Density")
    plt.title(gen_titles[counter - 1])

    # Original image.
    cleaned, mid_res, hist = segmentation(SMALL_PATH + fn, get_mid_res=True)
    vis_orig.append(mid_res[4])
    plotHist(hist[0], hist[1], True)

    # Generated counterpart, stored under the same file name in OUT_DIR.
    cleaned, mid_res, hist = segmentation(OUT_DIR + fn, get_mid_res=True)
    plotHist(hist[0], hist[1], False)
    vis_res.append(mid_res[4])

    # Legend only on the first panel; entries follow the plotting order above.
    if counter == 1:
        plt.legend(["Original", "Generated"])

plt.savefig(FIGURE_PATH + 'hist.eps', transparent=True, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Two-row comparison of the overlay visualisations collected in the histogram
# cell: originals on top, generated images below.
n_cols = len(vis_orig)

fig = plt.figure(figsize=(36, 8))
fig.set_dpi(300)

# Top row: originals, with column titles.
for i, overlay in enumerate(vis_orig):
    fig.add_subplot(2, n_cols, i + 1)
    if i == 0:
        plt.text(-60, 200, "Original", size=18, rotation="vertical", va="center")
    plt.axis('off')
    plt.title(gen_titles[i])
    plt.imshow(overlay, cmap=None)

# Bottom row: generated images; subplot indices continue after the top row.
for i, overlay in enumerate(vis_res):
    fig.add_subplot(2, n_cols, n_cols + i + 1)
    plt.axis('off')
    plt.imshow(overlay, cmap=None)
    if i == 0:
        plt.text(-60, 200, "Generated", size=18, rotation="vertical", va="center")

plt.savefig('../data/figures/seg_vis.eps', transparent=True, bbox_inches='tight', pad_inches=0)
plt.show()