In [None]:
import os
import glob
import sys
import numpy as np
import pickle
import tensorflow as tf
import PIL
import ipywidgets
import io

""" make sure this notebook is running from root directory """
# Climb one directory at a time until the working directory is no longer
# named 'notebooks' or 'src', so the relative './src/...' and './asset_*'
# paths used below resolve.
while os.path.split(os.getcwd())[1] in {'notebooks', 'src'}:
    os.chdir(os.pardir)

import src.tl_gan.generate_x_ray_image as generate_image
import src.tl_gan.feature_axis as feature_axis
import src.tl_gan.get_normal_images as get_normal_images

In [None]:
""" load feature directions """
# Button-panel layout: each inner list is one row, holding the feature indices
# whose button groups are placed side by side in that row.
feature_layout = [
    [0, ],
    [1, ],
    [2, ],
    [3, ]
]
path_feature_direction = './asset_results/pg_gan_x_ray_integrated_norm_feature_direction_5/'

# glob.glob() returns matches in arbitrary (filesystem-dependent) order, so the
# candidates are sorted before taking the last one; otherwise the chosen file
# could change from run to run.
# NOTE(review): this picks the lexicographically last file name -- presumably
# the newest result if the names embed a timestamp; confirm.
list_pathfile_feature_direction = sorted(
    glob.glob(os.path.join(path_feature_direction, 'feature_direction_*.pkl')))
if not list_pathfile_feature_direction:
    raise FileNotFoundError(
        'no feature_direction_*.pkl file found in {}'.format(path_feature_direction))
pathfile_feature_direction = list_pathfile_feature_direction[-1]

# SECURITY: pickle.load can execute arbitrary code; only load files you trust.
with open(pathfile_feature_direction, 'rb') as f:
    feature_direction_name = pickle.load(f)

# The pickle holds a dict with the direction matrix and the feature names.
feature_direction = feature_direction_name['direction']
feature_name = feature_direction_name['name']
# One column of the direction matrix per feature.
num_feature = feature_direction.shape[1]

In [None]:
""" start tf session and load GAN model """

# Locations of the PG-GAN source code and of the trained network snapshot.
path_pg_gan_code = './src/model/pggan'
path_model = './asset_model/x-ray_integrated_20190218_network-snapshot-014000.pkl'
# The model code directory is put on sys.path before the snapshot is unpickled.
sys.path.append(path_pg_gan_code)


""" create tf session """
# Switch to True to hide all GPUs from TensorFlow.
yn_CPU_only = False

config = (
    tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True)
    if yn_CPU_only
    else tf.ConfigProto(allow_soft_placement=True)
)
if not yn_CPU_only:
    # Let GPU memory usage grow on demand.
    config.gpu_options.allow_growth = True

sess = tf.InteractiveSession(config=config)

# The snapshot pickle unpacks into three objects; only Gs is used below.
# NOTE: pickle.load can execute arbitrary code -- only load trusted snapshots.
try:
    with open(path_model, 'rb') as file:
        G, D, Gs = pickle.load(file)
except FileNotFoundError:
    print('before running the code, download pre-trained model to project_root/asset_model/')
    raise

# Size of the generator's latent input (second entry of its first input shape).
len_z = Gs.input_shapes[0][1]

# The starting latent comes from get_normal_images rather than from
# np.random.randn(len_z), so the first picture is a "normal" image.
z_sample = get_normal_images.get_single_image()
x_sample = generate_image.gen_single_img(z_sample, Gs=Gs)
# Independent copy of the starting latent vector.
z_init = z_sample.copy()

In [None]:
from io import StringIO, BytesIO
from chestPA.cam_py3 import plot_cam

def bytes_to_img(imgBytes):
    """Decode encoded image bytes (e.g. PNG) into a numpy array.

    Args:
        imgBytes: bytes of an image file in a format PIL can open.

    Returns:
        numpy array built from the decoded image.
    """
    # A bare `import PIL` does not load the `PIL.Image` submodule, so
    # `PIL.Image` only resolves if some other library happened to import it
    # first. Import the submodule explicitly instead of relying on that.
    from PIL import Image

    # np.array() forces the lazily opened image to be fully decoded while the
    # image (and its underlying stream) is still open.
    with Image.open(BytesIO(imgBytes)) as img:
        return np.array(img)

def img_to_bytes(x_sample):
    """Encode an image array as PNG and return the encoded bytes.

    Args:
        x_sample: image array accepted by ``PIL.Image.fromarray``.

    Returns:
        bytes: the PNG-encoded image.
    """
    # A bare `import PIL` does not load the `PIL.Image` submodule, so
    # `PIL.Image` only resolves if some other library happened to import it
    # first. Import the submodule explicitly instead of relying on that.
    from PIL import Image

    imgByteArr = io.BytesIO()
    Image.fromarray(x_sample).save(imgByteArr, format='PNG')
    return imgByteArr.getvalue()

# Image widget showing the generated picture; GuiCallback.update_img overwrites
# its `value` with fresh PNG bytes whenever the latent vector changes.
w_img = ipywidgets.widgets.Image(
    value=img_to_bytes(x_sample),
    format='png',
    width=512,
    height=512,
)

In [None]:
class GuiCallback(object):
    """Callbacks wired to the GUI buttons; holds the latent vector being edited.

    Attributes (all set in ``__init__``):
        latents: latent vector passed to the generator. Changed in place by
            ``modify_along_feature`` and replaced by ``random_gen``.
        feature_direction: the module-level ``feature_direction`` matrix
            (one column per feature).
        feature_lock_status: boolean array with one entry per feature; True
            marks a feature as locked.
        feature_directoion_disentangled: result of
            ``feature_axis.disentangle_feature_axis_by_idx`` for the current
            lock state; its columns are the directions actually applied.
            (The misspelled name is kept so existing references keep working.)
    """

    def __init__(self):
        # Work on a private copy: modify_along_feature() updates self.latents
        # in place (+=), which would otherwise silently change the
        # module-level z_sample through the shared reference.
        self.latents = z_sample.copy()
        self.feature_direction = feature_direction
        self.feature_lock_status = np.zeros(num_feature).astype('bool')
        self._update_disentangled_direction()

    def _update_disentangled_direction(self):
        """Recompute the disentangled directions from the current lock state."""
        # The indices of the locked features are passed as `idx_base`.
        # NOTE(review): presumably the other axes are made independent of the
        # locked ones -- see feature_axis.disentangle_feature_axis_by_idx.
        self.feature_directoion_disentangled = feature_axis.disentangle_feature_axis_by_idx(
            self.feature_direction, idx_base=np.flatnonzero(self.feature_lock_status))

    def random_gen(self, event):
        """Button callback: fetch a new latent vector and redraw the image.

        Args:
            event: object supplied by the widget callback; unused.
        """
        # Copied defensively because the vector is later modified in place.
        # NOTE(review): unnecessary if get_single_image() already returns a
        # fresh array on every call -- confirm.
        self.latents = get_normal_images.get_single_image().copy()
        self.update_img()

    def modify_along_feature(self, event, idx_feature, step_size=0.01):
        """Button callback: move the latent vector along one feature direction.

        Args:
            event: object supplied by the widget callback; unused.
            idx_feature: column index of the feature direction to follow.
            step_size: signed multiplier applied to the direction vector.
        """
        self.latents += self.feature_directoion_disentangled[:, idx_feature] * step_size
        self.update_img()

    def set_feature_lock(self, event, idx_feature, set_to=None):
        """Lock or unlock one feature and refresh the disentangled directions.

        Args:
            event: object supplied by the widget callback; unused.
            idx_feature: index of the feature whose lock state changes.
            set_to: new lock state; ``None`` flips the current state.
        """
        if set_to is None:
            self.feature_lock_status[idx_feature] = np.logical_not(self.feature_lock_status[idx_feature])
        else:
            self.feature_lock_status[idx_feature] = set_to
        self._update_disentangled_direction()

    def update_img(self):
        """Generate the image for the current latents and show it in ``w_img``."""
        x_sample = generate_image.gen_single_img(z=self.latents, Gs=Gs)
        x_byte = img_to_bytes(x_sample)
        w_img.value = x_byte

In [None]:
# Single callback object shared by every button handler created below.
guicallback = GuiCallback()

# Step applied per "+"/"-" click, looked up by feature index
# (entry i belongs to feature i).
# Original author's note, meaning unclear: 0,0,0,8
step_sizes = [1, 0.4, 0.4, 0.4]

In [None]:
def create_button(idx_feature, width=200, height=40):
    """Build the button group for one feature.

    The group is a toggle button labelled with the feature name (pressed means
    the feature is locked) placed above a "-" and a "+" button that move the
    latent vector against / along the feature direction.

    Args:
        idx_feature: index of the feature the buttons control.
        width: width of the group in pixels; "-" and "+" get half of it each.
        height: nominal height in pixels; every button is half of it tall.

    Returns:
        ipywidgets.VBox containing the toggle and the "-"/"+" row.
    """
    px_half_height = '{:.0f}px'.format(height/2)
    px_full_width = '{:.0f}px'.format(width)
    px_half_width = '{:.0f}px'.format(width/2)

    w_name_toggle = ipywidgets.widgets.ToggleButton(
        value=False, description=feature_name[idx_feature],
        tooltip='{}, Press down to lock this feature'.format(feature_name[idx_feature]),
        layout=ipywidgets.Layout(height=px_half_height,
                                 width=px_full_width,
                                 margin='2px 2px 2px 2px')
    )
    w_neg = ipywidgets.widgets.Button(description='-',
                                      layout=ipywidgets.Layout(height=px_half_height,
                                                               width=px_half_width,
                                                               margin='1px 1px 5px 1px'))
    w_pos = ipywidgets.widgets.Button(description='+',
                                      layout=ipywidgets.Layout(height=px_half_height,
                                                               width=px_half_width,
                                                               margin='1px 1px 5px 1px'))

    # Observe only the `value` trait and copy the new value into the lock
    # state. Without `names='value'` the handler fires on every trait change of
    # the toggle, and blindly flipping the lock on each event could leave the
    # lock state out of sync with what the button shows.
    w_name_toggle.observe(
        lambda change: guicallback.set_feature_lock(
            change, idx_feature, set_to=bool(change['new'])),
        names='value')
    # step_sizes is read at click time, so later edits to it take effect.
    w_neg.on_click(lambda event:
                   guicallback.modify_along_feature(event, idx_feature, step_size=-1 * step_sizes[idx_feature]))
    w_pos.on_click(lambda event:
                   guicallback.modify_along_feature(event, idx_feature, step_size=+1 * step_sizes[idx_feature]))

    button_group = ipywidgets.VBox([w_name_toggle, ipywidgets.HBox([w_neg, w_pos])],
                                   layout=ipywidgets.Layout(border='1px solid gray'))

    return button_group

In [None]:
# One button group per feature, in feature-index order. A comprehension is used
# so no loop variable is left behind in the notebook's global namespace.
list_buttons = [create_button(idx_feature) for idx_feature in range(num_feature)]

# True: arrange the buttons according to `feature_layout`; False: plain grid.
yn_button_select = True
def arrange_buttons(list_buttons, yn_button_select=True, ncol=4):
    """Stack the button groups into rows and return the resulting VBox.

    Args:
        list_buttons: button groups, indexable by feature index.
        yn_button_select: if true, rows follow the module-level
            ``feature_layout`` (each inner list names the buttons of one row);
            otherwise the buttons fill a grid row by row.
        ncol: buttons per row in grid mode; ignored when
            ``yn_button_select`` is true.

    Returns:
        ipywidgets.VBox holding one HBox per row.
    """
    num = len(list_buttons)

    if not yn_button_select:
        # Grid mode: as many full rows as fit, plus one more for any leftovers.
        full_rows, leftover = divmod(num, ncol)
        num_rows = full_rows + int(leftover > 0)
        rows = []
        for i_row in range(num_rows):
            rows.append(ipywidgets.HBox(list_buttons[i_row * ncol:(i_row + 1) * ncol]))
        return ipywidgets.VBox(rows)

    # Layout mode: pick the buttons listed for each row of feature_layout.
    rows = []
    for row in feature_layout:
        rows.append(ipywidgets.HBox([list_buttons[item] for item in row]))
    return ipywidgets.VBox(rows)

In [None]:
# Render the current latent vector into the image widget before showing the GUI.
guicallback.update_img()

# Button that swaps in a new latent vector (see GuiCallback.random_gen).
w_button_random = ipywidgets.widgets.Button(
    description='random chest x-ray',
    button_style='success',
    layout=ipywidgets.Layout(height='60px', width='165px', margin='1px 1px 5px 1px'),
)
w_button_random.on_click(guicallback.random_gen)

# Image first, then a column with the random button above the feature buttons.
w_box = ipywidgets.HBox(
    [
        w_img,
        ipywidgets.VBox([
            w_button_random,
            arrange_buttons(list_buttons, yn_button_select=True),
        ]),
    ],
    layout=ipywidgets.Layout(height='512px', width='700px'),
)

print('press +/- to adjust feature, toggle feature name to lock the feature')
display(w_box)

In [None]:
# Decode the PNG currently held by the image widget back into a pixel array
# and pass it, as a one-element dataset, to plot_cam.
# NOTE(review): plot_cam comes from chestPA.cam_py3; presumably it plots a
# class activation map for the image -- confirm against that module.
arr_img = bytes_to_img(w_img.value)
plot_cam(show_cam=True, plot_fig=True, dataset_test=[arr_img])