In [1]:
#COLAB - mount drive
# Mount the user's Google Drive at /content/gdrive so ROOT_DIR (next cell)
# resolves; force_remount=True remounts even if Drive is already mounted.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)


Mounted at /content/gdrive


In [2]:
#@title LOAD LIBRARIES AND MODELS { output-height: 10, form-width: "200px", display-mode: "form" }
#@markdown Enter your code root directory on gdrive here:


#AI PAINTER MAIN PROGRAM WITH UI
#by JON THUM
#AI MSc PROJECT CITY UNIVERSITY
# Setup cell: reports GPU/RAM, then downloads and loads the StyleGAN2 (TF1)
# generators, the VGG19 network for style transfer and the ESRGAN upscaler.


#COLAB - root directory
# Project root on the mounted Drive (exposed as a Colab form field via #@param).
# Paths below are built by string concatenation, so keep the trailing '/'.
ROOT_DIR = '/content/gdrive/My Drive/AiPainter/' #@param {type:"string"}

#CHECK CUDA MEMORY
# List attached GPUs via the shell; nvidia_smi (NVML bindings) is used by
# GPU_memory() below.
!nvidia-smi -L
import nvidia_smi
def GPU_memory(device_index=0):
    """Print total and free memory (bytes) of one GPU using NVML.

    NVML is initialised and shut down on every call. The shutdown sits in a
    ``finally`` block, so a failing query does not leave NVML initialised.

    Args:
        device_index: NVML index of the GPU to query (default 0, the first GPU).
    """
    nvidia_smi.nvmlInit()
    try:
        handle = nvidia_smi.nvmlDeviceGetHandleByIndex(device_index)
        info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
        print("Total memory:", info.total)
        print("Free memory:", info.free)
    finally:
        nvidia_smi.nvmlShutdown()
# Report GPU 0 total/free memory once at start-up.
GPU_memory()

#CHECK RAM
# Total system RAM in decimal gigabytes (bytes / 1e9), via psutil.
from psutil import virtual_memory
ram = virtual_memory().total / 1e9
print('RAM {:.1f} gigabytes'.format(ram))


### STYLEGAN2###

#COLAB - import .py files
# Make the StyleGAN2 repo copy on Drive importable (dnnlib lives there).
import sys
SG2_DIR = ROOT_DIR + 'lib/stylegan2/'
sys.path.append(SG2_DIR)

#CHECK TF SETTINGS
# Colab magic selecting TensorFlow 1.x; it must run before tensorflow is imported.
# NOTE(review): recent Colab runtimes no longer provide TF 1.x through this
# magic — confirm the runtime in use still supports it.
%tensorflow_version 1.x
import tensorflow as tf
#print('Tensorflow version: {}'.format(tf.__version__) )

#IMPORT LIBRARIES
import dnnlib
import dnnlib.tflib as tflib
import pickle

# StyleGAN2 generators, one per style; the list index is the style index used
# by generate_painting (Gmodels[artstyle]) and the STYLE dropdown.
Gmodels = []

#LIMIT TF MEMORY USE
# Cap TensorFlow at 20% of GPU memory, presumably to leave room for the
# PyTorch models loaded later in this cell.
cfg = dict()
cfg["gpu_options.per_process_gpu_memory_fraction"]=0.2
tflib.init_tf(cfg)

#DOWNLOAD CUSTOM TRAINED SG2 MODELS
# gdown fetches each Google Drive file ID into the working directory; the
# pickles are then opened by relative file name below.
print('Downloading SG2 models ..')
!gdown --id 1-4vOSwOjEn6bRkirUDm9pzPTI1mr10Rb -q
!gdown --id 1-5CurIFM-sEprbytN0bWtlZ8HxXeESXY -q
!gdown --id 1-FJoffvHQffJtlzz6t725PJnbvFQLJdR -q
!gdown --id 1-HJI2gegpAXJGi7odlxKVKvESTGCtadZ -q
!gdown --id 1-K2oUvzk3AphYq1QoymrHoDwyp_JiP5I -q

#LOAD SG2 MODELS
# Load the five generator pickles in style order A..E and append them to
# Gmodels, so Gmodels[i] is the generator for style index i.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# These files are presumably the ones fetched by the gdown commands above;
# only load pickles from a trusted source.
print('Loading SG2 models ..')
for model_id in 'ABCDE':
    with open('network{}_Gs.pkl'.format(model_id), 'rb') as f:
        Gs = pickle.load(f)
    Gmodels.append(Gs)
    print('Loaded model {}'.format(model_id))
print('SG2 models loaded')


### STYLE TRANSFER ###

#COLAB DIRECTORY
# Make the project's custom style-transfer modules (styletransfer, utils) importable.
ST_DIR = ROOT_DIR + 'custom/styletransfer/'
sys.path.append(ST_DIR)

#DOWNLOAD LANDMARKS DETECTION MODEL
# Downloaded into the working directory; presumably consumed by the face/eye
# mask extraction (mask_extract) — TODO confirm.
!gdown --id 1ye96bZvOJxiChg7cApGZ5Op2Wbhrr61B -q

#LOAD CUSTOM STYLE TRANSFER MODULES
# NOTE(review): torch, np, Image, transforms and helpers such as image_to_bytes,
# PIL_to_bytes, make_quarters and run_style_transfer are never imported
# explicitly in this notebook. They appear to come from these star imports;
# explicit imports would make the dependencies visible.
import styletransfer
import utils
from styletransfer import *
from utils import *
 
#GET DEVICE
# Use CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device:', device)
#torch.cuda.empty_cache()

#LOAD CNN MODEL   
# ImageNet-pretrained VGG19 convolutional layers in eval mode. They are passed
# as `cnn` to run_style_transfer.
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in
# favour of `weights=...`.
print('Loading ST model ..')
import torchvision.models as models
cnn = models.vgg19(pretrained=True).features.to(device).eval() 
print('ST model loaded')

#TORCH SETTINGS (set deterministic=True for debugging: repeatable results but SLOW)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False


### UPRES WITH ESRGAN ###

#COLAB DIRECTORY
ESR_DIR = ROOT_DIR + 'lib/esrgan/'
sys.path.append(ESR_DIR)

import RRDBNet_arch as arch

#DOWNLOAD MODEL
print('Downloading ESRGAN model ..')
!gdown --id 1nZJNl0Go87MPV88sHu6veARGfDdOZA3t -q

#LOAD MODEL
device = torch.device('cuda') 
esrgan = arch.RRDBNet(3, 3, 64, 23, gc=32)
esrgan.load_state_dict(torch.load('RRDB_ESRGAN_x4.pth'), strict=True)
esrgan.eval()
esrgan = esrgan.to(device)

#PROCESS
def upresx4(img):
    """Upscale an image 4x with the global ESRGAN model (`esrgan` on `device`).

    Args:
        img: HxWx3 numpy array with values in 0..255.

    Returns:
        A float numpy array of shape (4H, 4W, 3), rounded to whole numbers in
        [0, 255]; callers convert it with ``astype(np.uint8)``.

    The [2, 1, 0] index reverses the channel order before inference and again
    afterwards. NOTE(review): callers here appear to pass RGB arrays, so the
    network receives BGR; confirm this swap is intended.
    """
    # Scale to [0, 1], reverse channels and reorder HWC -> CHW.
    scaled = img * 1.0 / 255
    chw = np.transpose(scaled[:, :, [2, 1, 0]], (2, 0, 1))
    batch = torch.from_numpy(chw).float().unsqueeze(0).to(device)

    with torch.no_grad():
        result = esrgan(batch).data.squeeze().float().cpu().clamp_(0, 1).numpy()

    # Undo the channel reversal, return to HWC, rescale to 0..255.
    hwc = np.transpose(result[[2, 1, 0], :, :], (1, 2, 0))
    torch.cuda.empty_cache()
    return (hwc * 255.0).round()


### GENERAL ###

#IMPORT IPYWIDGETS AND SYS LIBRARIES
import ipywidgets as widgets
from ipywidgets import interact
import os
import time
import warnings
# NOTE(review): this silences *all* Python warnings for the rest of the
# session, including deprecation warnings from TF1/torchvision; consider
# narrowing it with a category or module filter.
warnings.filterwarnings('ignore')



GPU 0: Tesla V100-SXM2-16GB (UUID: GPU-a18582b0-a68c-9c70-8504-d6eb0fa612f2)
Total memory: 16914055168
Free memory: 16913989632
RAM 27.4 gigabytes
TensorFlow 1.x selected.
Downloading SG2 models ..
Loading SG2 models ..
Setting up TensorFlow plugin "fused_bias_act.cu": Preprocessing... Compiling... Loading... Done.
Setting up TensorFlow plugin "upfirdn_2d.cu": Preprocessing... Compiling... Loading... Done.
Loaded model A
Loaded model B
Loaded model C
Loaded model D
Loaded model E
SG2 models loaded
Device: cuda
Loading ST model ..


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


HBox(children=(FloatProgress(value=0.0, max=574673361.0), HTML(value='')))


ST model loaded
Downloading ESRGAN model ..


In [5]:
#@title AI PAINTER { vertical-output: true, output-height: 20, form-width: "200px", display-mode: "form" }

#AI PAINTER UI
#by JON THUM


#INITIAL PARAMETERS
SEED = 11             #initial painting on display
FILTER_SEED = 19      #random shuffle for filtering 

#SAVED IMAGE RESOLUTION
SIZE = ['1K', '2K', '3K', '4K']

#STYLES
STYLE = ['GENERIC', 'MODERN', 'IMPRESSIONIST', 'BAROQUE', 'SKETCH']

#LOAD IMAGE ANALYSIS FILES
# One analysis table per style model A..E, in the same order as STYLE/Gmodels.
# display_sample() filters rows on columns 0-10 (image metrics) and column 11
# (genre), and reads the generator seed from the last column.
datafiles = [np.load(ROOT_DIR + 'custom/imageanalysis/ImageAnalysis_{}.npy'.format(model_id))
             for model_id in 'ABCDE']

#MODEL ARGS
# Generator output transform: uint8 pixels in NHWC layout.
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
# NOTE(review): Gs_syn_kwargs is not used in this section; generate_painting
# passes `fmt` directly. Kept in case other cells rely on it.
Gs_syn_kwargs = dnnlib.EasyDict()
Gs_syn_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
Gs_syn_kwargs.minibatch_size = 1

#GENERATE IMAGE FROM MODEL
def generate_painting(seed, artstyle, truncation, noise):
    """Render a single painting with the StyleGAN2 generator of one style.

    Args:
        seed: integer seed for the latent vector (same seed -> same latent).
        artstyle: index into Gmodels (matches STYLE).
        truncation: value passed as truncation_psi.
        noise: value passed as randomize_noise.

    Returns:
        The only image of the generated batch, in the uint8 / NHWC format
        selected by ``fmt``.
    """
    generator = Gmodels[artstyle]
    latent_dim = generator.input_shape[1]
    latent = np.random.RandomState(seed).randn(1, latent_dim)
    batch = generator.run(latent, None, truncation_psi=truncation,
                          randomize_noise=noise, output_transform=fmt)
    return batch[0]

#GENERATE INITIAL IMAGE
# Placeholder for the image widget. It uses the same style the STYLE dropdown
# starts on, so the first render matches the UI; display_sample() replaces it
# once the controls are wired up.
INIT_STYLE = 1
im_init = generate_painting(SEED, INIT_STYLE, 1.0, False)

#IMAGE DISPLAY WIDGET
im_select = widgets.Image(value=image_to_bytes(im_init), width=600, height=600, layout=widgets.Layout(margin='10px 0px 10px 20px'))

#IMAGE SAVE WIDGETS
b_save_p = widgets.Button(description='Save', layout=widgets.Layout(margin='10px 0px 0px 230px'))
# Options are (label, index into SIZE) pairs.
w_size_p = widgets.Dropdown(options=[(label, i) for i, label in enumerate(SIZE)], value=0, description='',
                            layout=widgets.Layout(width='10%', height='27px', margin='10px 0px 0px 5px'))

#INTERACTIVE SELECTION WIDGETS
w_filt = widgets.Dropdown(options=[('SELECT BY ID', False), ('SELECT BY FILTERING', True)], value=False, description='MODE: ',
                          layout=widgets.Layout(margin='0px 0px 40px 0px'))
# Options are (style name, index into STYLE/Gmodels) pairs.
w_artstyle = widgets.Dropdown(options=[(name, i) for i, name in enumerate(STYLE)], value=INIT_STYLE, description='STYLE: ',
                              layout=widgets.Layout(margin='0px 0px 10px 0px'))
w_seed = widgets.IntSlider(min=0, max=5000, step=1, value=SEED, description='ID #: ', continuous_update=False,
                           layout=widgets.Layout(width='95%', margin='0px 0px 0px 5px'))
w_sample = widgets.IntSlider(value=1, min=1, max=5001, step=1, description='FILTER #:', continuous_update=False,
                             layout=widgets.Layout(width='95%', margin='0px 0px 0px 5px'))

#INTERACTIVE FINE-TUNING WIDGETS
info4 = widgets.HTML(value="<b>&nbsp; &nbsp; &nbsp; &nbsp; FINE-TUNING</b>", placeholder=' ', description=' ',
                     layout=widgets.Layout(margin='25px 0px 0px 0px'))
w_truncation = widgets.FloatSlider(min=-1, max=3, step=0.01, value=1, description='Experimental: ', continuous_update=False)
# display_sample() ignores this toggle's value. Toggling it only forces a
# re-render, which presumably gives new noise when 'Random' is ON.
w_add_noise = widgets.ToggleButton(description='Randomise', layout=widgets.Layout(width='50%', height='30px', margin='5px 0px 10px 90px'))
w_noise = widgets.Dropdown(options=[('ON', True), ('OFF', False)], value=False, description='Random: ')

#INTERACTIVE FILTER WIDGETS
info1 = widgets.HTML(value='<b>&nbsp; &nbsp; &nbsp; FILTERING</b>', placeholder=' ', description=' ',
                     layout=widgets.Layout(margin='10px 0px 0px 0px'))


def _range_slider(description):
    """Return a 0-100 IntRangeSlider covering the full range that updates only on release."""
    return widgets.IntRangeSlider(value=[0, 100], min=0, max=100, step=1,
                                  description=description, continuous_update=False)


w_det = _range_slider('Detail:')
w_bri = _range_slider('Brightness:')
w_con = _range_slider('Contrast:')
w_sat = _range_slider('Saturation:')
w_var = _range_slider('Col Variety:')
w_R = _range_slider('Red:')
w_Y = _range_slider('Yellow:')
w_G = _range_slider('Green:')
w_C = _range_slider('Cyan:')
w_B = _range_slider('Blue:')
w_M = _range_slider('Magenta:')
w_genre = widgets.Dropdown(options=[('ALL', 0), ('NON-PORTRAIT', 1), ('PORTRAIT', 2)], value=0, description='Genre: ')


#BUTTON COLOURS
# lightblue marks the active control; display_sample() swaps the seed/sample
# handle colours when the MODE dropdown changes.
w_seed.style.handle_color = 'white'
w_sample.style.handle_color = 'lightblue'
w_truncation.style.handle_color = 'lightblue'
b_save_p.style.button_color = 'lightblue'

#MENU LAYOUT
# Left column: mode, style, analysis filters and fine-tuning controls.
box1 = widgets.VBox([
    w_filt, w_artstyle,
    info1, w_det, w_var, w_sat,
    w_R, w_Y, w_G, w_C, w_B, w_M,
    w_bri, w_con, w_genre,
    info4, w_truncation, w_noise, w_add_noise,
])
# Right column: ID/filter sliders, the painting, then the save row.
box2 = widgets.VBox([w_seed, w_sample])
box = widgets.HBox([b_save_p, w_size_p])
box3 = widgets.VBox([box2, im_select, box])
ui1 = widgets.HBox([box1, box3])


#GENERATE SAMPLE
def display_sample(seed, truncation, artstyle, noise, add_noise, det, bri, con,
                   sat, var, R, Y, G, C, B, M, sample, filt, genre):
    """Choose a seed (directly or through the analysis filters) and show its painting.

    Wired to the controls by ``widgets.interactive_output``.

    Args:
        seed: ID slider value; used directly when filtering is off, or when
            no analysis row matches the filters.
        truncation: truncation_psi for the generator.
        artstyle: style index into STYLE / Gmodels / datafiles.
        noise: randomize_noise flag for the generator.
        add_noise: 'Randomise' toggle; its value is unused (it only forces a re-render).
        det, bri, con, sat, var, R, Y, G, C, B, M: (low, high) slider ranges,
            applied in that order to analysis columns 0-10. A high of 100 is
            treated as unbounded.
        sample: 1-based position in the shuffled list of matching seeds,
            clamped to the last match.
        filt: True for "select by filtering", False for "select by ID".
        genre: 0 = all; 1 = non-portrait (column 11 < 1); 2 = portrait (column 11 > 0).

    Side effects: sets the global ``painting`` and updates im_select, info1 and
    the slider handle colours. In filtering mode it also sets w_seed.value and
    w_sample.max.
    """
    global painting

    if filt:
        info1.value = '<b>&nbsp; &nbsp; &nbsp; FILTERING ON</b>'
        w_seed.style.handle_color = 'white'
        w_sample.style.handle_color = 'lightblue'

        #FILTER VALUES (order matches analysis columns 0-10)
        ranges = (det, bri, con, sat, var, R, Y, G, C, B, M)

        #APPLY GENRE FILTER (boolean indexing returns a copy; datafiles is never mutated)
        df = datafiles[artstyle]
        if genre == 1:
            df = df[df[:, 11] < 1]
        elif genre:
            df = df[df[:, 11] > 0]

        #APPLY REMAINING FILTERS
        for col, (low, high) in enumerate(ranges):
            #MAX VAL NEEDS TO INCLUDE VALUES OVER 100
            if high == 100:
                high = 999
            df = df[np.logical_and(df[:, col] >= low, df[:, col] <= high)]

        #NUMBER OF PICKS IN RANGE
        num = df.shape[0]
        info1.value = '<b>&nbsp; &nbsp; &nbsp; FILTERING {}</b>'.format(num)

        if num:
            #SHUFFLE the matching seeds (last column) with a private, fixed-seed
            #RNG: the order is repeatable without reseeding numpy's global RNG
            #(same permutation as np.random.seed + np.random.shuffle).
            shuffle = df[:, -1].copy()
            np.random.RandomState(FILTER_SEED).shuffle(shuffle)

            #GET SAMPLE (1-based, clamped to the last match). Cast to int: the
            #analysis array may be float, and np.random.RandomState rejects
            #non-integer seeds.
            seed_id = int(shuffle[min(sample, num) - 1])

            #FEEDBACK
            print('\n  ID: {}   Style: {}   Filtering: ON   Samples: {}\n'.format(seed_id, STYLE[artstyle], num))
            w_seed.value = seed_id
        else:
            seed_id = seed
            print('\n  ID: {}   Style: {}   NO SAMPLES IN RANGE\n'.format(seed_id, STYLE[artstyle]))

        w_sample.max = max(1, num)

    else:
        info1.value = '<b>&nbsp; &nbsp; &nbsp; FILTERING OFF</b>'
        seed_id = seed
        print('\n  ID: {}   Style: {}   Filtering: OFF\n'.format(seed_id, STYLE[artstyle]))
        w_seed.style.handle_color = 'lightblue'
        w_sample.style.handle_color = 'white'

    #GENERATE IMAGE FROM SEED
    painting = generate_painting(seed_id, artstyle, truncation, noise)
    im_select.value = image_to_bytes(painting)


#MENU DISPLAY
# Each control is bound to the display_sample() parameter of the same name;
# a value change re-runs display_sample() with the current control values.
interactive_output = widgets.interactive_output(display_sample, dict(
    seed=w_seed, truncation=w_truncation, artstyle=w_artstyle,
    noise=w_noise, add_noise=w_add_noise,
    det=w_det, bri=w_bri, con=w_con, sat=w_sat, var=w_var,
    R=w_R, Y=w_Y, G=w_G, C=w_C, B=w_B, M=w_M,
    sample=w_sample, filt=w_filt, genre=w_genre))
display(ui1, interactive_output)


#SAVE IMAGE
def save_p(b):
    """Save the current painting to ROOT_DIR/images/results/ at the chosen size.

    '1K' saves the generated image unchanged. '2K'/'3K'/'4K' first upscale it
    4x with ESRGAN, processed in quarters via make_quarters/combine_quarters
    (presumably to limit GPU memory). 2K and 3K are then resized to 2x / 3x
    the painting's own dimensions; 4K keeps the 4x result.

    Args:
        b: the clicked Button (unused; required by ``Button.on_click``).
    """
    b_save_p.style.button_color = 'red'
    try:
        size = SIZE[w_size_p.value]
        # round() so e.g. truncation 0.29 is recorded as 29, not 28
        # (0.29 * 100 == 28.999999999999996 in floating point).
        name = 'P_{}_{}-{}_{}.jpg'.format(STYLE[w_artstyle.value][:3], w_seed.value,
                                          int(round(w_truncation.value * 100)), size)
        path = ROOT_DIR + 'images/results/' + name

        scale = w_size_p.value + 1   # 1K -> 1x, 2K -> 2x, 3K -> 3x, 4K -> 4x
        if scale > 1:
            b_save_p.description = 'Saving ..'
            height, width = painting.shape[:2]

            quarters, x, y, halfx, halfy = make_quarters(painting, 10)
            for i in range(4):
                quarters[i] = upresx4(quarters[i])
            output = combine_quarters(quarters, 4*x, 4*y, 4*halfx, 4*halfy, 40)

            output = Image.fromarray(output.astype(np.uint8))
            if scale < 4:
                output = output.resize((scale * width, scale * height))
        else:
            time.sleep(0.8)   # brief pause so the red 'busy' colour is visible
            output = Image.fromarray(painting)

        output.save(path)
        print('SAVED {}'.format(name))
    finally:
        # Restore the button even when upscaling or saving raises.
        b_save_p.description = 'Save'
        b_save_p.style.button_color = 'lightblue'


b_save_p.on_click(save_p)





HBox(children=(VBox(children=(Dropdown(description='MODE: ', layout=Layout(margin='0px 0px 40px 0px'), options…

Output()

SAVED P_MOD_11-100_1K.jpg


In [6]:
#@title AI STYLE TRANSFER { vertical-output: true, output-height: 20, form-width: "200px", display-mode: "form" }

#STYLE TRANSFER UI
#by JON THUM
# Re-renders the painting selected in the AI PAINTER cell (global `painting`)
# onto a content photo in two passes: net1 at low resolution, then (if HI_RES)
# net2 refines net1's output at a higher resolution.


#PARAMETERS
NORM = (None, None, 1, 2)   #Normalisation = (style type, content type, gamma, sharpness) 
                            #normalisation types = 'MATCH' 'CONTRAST' 'HISTO' 'ADAPT' (None presumably = none; TODO confirm)
HI_RES = True               #Process hi res output
PLOTS = True                #Generate loss plots

#DIRECTORIES
CONTENT_DIR = ROOT_DIR + 'images/content/' 
OUT_DIR = ROOT_DIR + 'images/results/' 
PLOT_DIR = ROOT_DIR + 'images/plots/'

#LOW RES NETWORK PARAMETERS
# NOTE: NOISE, TILE, GLOBAL_STYLE_WEIGHT and FG_CONTENT_WEIGHT (in net1 and net2)
# are overwritten from the fine-tuning sliders on every style_transfer() call;
# the values here are only initial defaults.
net1 = {                     #Network1 parameters
    'EPOCHS':                120,                           #Max number of epochs
    'SIZE':                  512,                           #Image output width (keeping content aspect ratio)
    'TILE':                  0.5,                           #Adds detail to output, for no tiling set to 1
    'NOISE':                 [0.0, 0.0, 1.0],               #Background noise, foreground noise, noise range 0->1
    'FG_CONTENT_WEIGHT':     2,                            #Additional content weight through mask
    'GLOBAL_STYLE_WEIGHT':   1e6,                           #Global style weight
    'STYLE_WEIGHTS':         [1, 1, 1, 1, 1],               #Relative style weights layers 1->5
    'STYLE_LAYERS_SL1':      ['conv1_1','conv2_1','conv3_1','conv4_1','conv5_1'], #Style layers with SL1 loss 
    'STYLE_LAYERS_MSE':      [], #Style layers with MSE loss 
    'CONTENT_LAYERS':        ['conv4_2'],                   #Content layer with MSE loss
    'GLOBAL_CONTENT_WEIGHT': 1,                             #Global content weight
    'CONTENT_WEIGHTS':       [1],                           #Relative content weights
    'AVE_POOLING':           True,                          #AVE_POOLING fast, MAX POOLING better quality
    'CONV':                  0.99                          #Convergence factor 0.95->0.99     
    }

#HI RES NETWORK PARAMETERS
net2 = {                     #Network2 parameters
    'EPOCHS':                60,
    # NOTE(review): the trailing note appears to list the maximum usable SIZE for
    # various settings (likely GPU-memory limits); confirm its meaning.
    'SIZE':                  1024, # max 1560(allow) 1710(0.4) 1860(0.3) 1990(0.2) 2110(0.1)
    'TILE':                  1.0,
    'NOISE':                 [0.0, 0.0, 1.0],
    'FG_CONTENT_WEIGHT':     0,
    'GLOBAL_STYLE_WEIGHT':   1e6,
    'STYLE_WEIGHTS':         [1, 1, 1, 1, 1],
    'STYLE_LAYERS_SL1':      ['conv1_1','conv2_1','conv3_1','conv4_1','conv5_1'],
    'STYLE_LAYERS_MSE':      [],
    'CONTENT_LAYERS':        ['conv4_2'],
    'GLOBAL_CONTENT_WEIGHT': 1,
    'CONTENT_WEIGHTS':       [1], 
    'AVE_POOLING':           True,
    'CONV':                  0.97                           
    }
    
#FEATURES DICTIONARY
# Feature names passed to mask_extract(); keys match the w_feature dropdown values.
FEATURES = {0:'Face', 1:'Features', 2:'Eyes'} 


#WIDGETS

#IMAGE FILE BROWSER
# Content images offered in the file browser. Default to the 17th file (as
# before), but fall back to the first one so a smaller folder doesn't raise
# IndexError; fail with a clear message if the folder is empty.
contentlist = sorted(os.listdir(CONTENT_DIR))
if not contentlist:
    raise FileNotFoundError('No content images found in ' + CONTENT_DIR)
default_content = contentlist[16] if len(contentlist) > 16 else contentlist[0]

#IMAGE DISPLAY WIDGETS
# Placeholders: show_images() replaces the content preview with the selected
# file, and style_transfer() fills the mask and result panes. The preview
# starts from the default selection instead of a separately hard-coded file.
im_init = Image.open(CONTENT_DIR + default_content)
im_black = Image.new('L', (256,256))
im_content = widgets.Image(value=PIL_to_bytes(im_init), width=256, height=256,
                           layout=widgets.Layout(margin='10px 0px 0px 45px'))
im_mask = widgets.Image(value=PIL_to_bytes(im_black), width=256, height=256,
                        layout=widgets.Layout(margin='30px 0px 0px 45px'))
im_transfer = widgets.Image(value=PIL_to_bytes(im_black), width=768, height=768,
                            layout=widgets.Layout(margin='10px 0px 0px 20px'))

w_file = widgets.Dropdown(options=contentlist, value=default_content, ensure_option=True, description='File:', disabled=False,
                          layout=widgets.Layout(margin='0px 0px 0px 0px'))

#INTERACTIVE WIDGETS


def _html(value, **layout):
    """Return an HTML label widget with blank placeholder/description and the given layout."""
    return widgets.HTML(value=value, placeholder=' ', description=' ',
                        layout=widgets.Layout(**layout))


# Section headings and the process-time readout.
info9 = _html('<b>CONTENT</b>', margin='10px 120px 0px 40px')
info10 = _html('process time', width='50%', height='30px', margin='0px 0px 0px 0px')
info11 = _html("<b>&nbsp; &nbsp; &nbsp; &nbsp; FINE-TUNING</b>", margin='30px 0px 0px 0px')
info12 = _html("<b>&nbsp; &nbsp; FEATURE CONTROL</b>", margin='30px 0px 0px 0px')
#ORIGINAL COLOUR RESTORATION heading
info13 = _html("<b>&nbsp; &nbsp; &nbsp; &nbsp; INTERACTIVE</b>", margin='30px 0px 0px 0px')

b_process = widgets.Button(description='PROCESS',
                           layout=widgets.Layout(width='50%', height='30px', margin='0px 0px 0px 50px'))

# Fine-tuning and feature-control sliders (all update on release only).
w_abstraction = widgets.FloatSlider(min=0, max=1, step=0.01, value=0.75, description='Abstraction: ', continuous_update=False)
w_tex_scale = widgets.FloatSlider(min=0.25, max=3, step=0.01, value=1.0, description='Tex Scale: ', continuous_update=False)
w_strength = widgets.FloatSlider(min=0.01, max=2, step=0.01, value=1.0, description='Tex Strength: ', continuous_update=False)
w_preservation = widgets.FloatSlider(min=0, max=10, step=0.1, value=1.0, description='Preservation: ', continuous_update=False)
w_feature = widgets.Dropdown(options=[('FACE', 0), ('EYES/NOSE/MOUTH', 1), ('EYES', 2)], value=0, description='Feature: ')
w_colswap = widgets.FloatSlider(min=0, max=1, step=0.01, value=0.0, description='Original Col: ', continuous_update=False,
                                layout=widgets.Layout(margin='0px 0px 20px 0px'))

#IMAGE SAVE WIDGETS (options are (label, index into SIZE) pairs)
b_save_st = widgets.Button(description='Save', layout=widgets.Layout(margin='0px 0px 0px 20px'))
w_size_st = widgets.Dropdown(options=[(label, i) for i, label in enumerate(SIZE)], value=0, description='',
                             layout=widgets.Layout(width='10%', height='27px', margin='0px 100px 0px 5px'))

#COLOURS
for _w in (b_process, b_save_st):
    _w.style.button_color = 'lightblue'
for _w in (w_abstraction, w_tex_scale, w_strength, w_preservation, w_colswap):
    _w.style.handle_color = 'lightblue'

#MENU LAYOUT
# Left column: content preview, fine-tuning, feature control and mask preview.
box6 = widgets.VBox([info9, im_content,
                     info11, w_abstraction, w_tex_scale, w_strength,
                     info12, w_preservation, w_feature, im_mask,
                     info13])
# Right column: process/save row above the result image.
box5 = widgets.HBox([b_process, info10, b_save_st, w_size_st])
box7 = widgets.VBox([box5, im_transfer])
ui2 = widgets.HBox([box6, box7])


#RUN STYLE TRANSFER
def style_transfer(abstraction, tex_scale, strength, preservation, feature):
    """Run the two-stage neural style transfer and display the result.

    The current painting (global ``painting``) is the style image and the
    selected file (global ``ui_content``) is the content image. Stage 1 uses
    net1 at net1['SIZE']. If HI_RES is set, stage 1's output becomes the input
    of stage 2 (net2) at net2['SIZE'].

    Args:
        abstraction: 0..1, noise mixed into the initial input image.
        tex_scale: style tiling factor (net1 'TILE'; net2 uses (3 + tex_scale) / 4).
        strength: style strength; net1's global style weight is 1e4 * 10**(2*strength).
        preservation: extra foreground content weight (FG_CONTENT_WEIGHT = 1 + preservation).
        feature: key into FEATURES choosing the mask region.

    Side effects: overwrites NOISE/TILE/GLOBAL_STYLE_WEIGHT/FG_CONTENT_WEIGHT
    in the global net1/net2 dicts, reseeds the numpy and torch RNGs, resets
    w_colswap, sets the global ``transfer`` image and updates the UI widgets.
    """
    global transfer

    # Wall-clock timing: time.process_time() counts only this process's CPU time
    # and excludes time spent blocked (e.g. waiting on I/O or the GPU).
    start = time.perf_counter()

    def elapsed():
        return 'Time: {:.1f} secs'.format(time.perf_counter() - start)

    #WIDGET UPDATES
    b_process.style.button_color = 'red'
    b_process.description = 'processing ..'
    info10.value = elapsed()
    try:
        w_colswap.value = 0
        net1['NOISE'] = [abstraction, abstraction/(1+preservation/5), 0.05]
        net1['TILE'] = tex_scale
        net2['TILE'] = (3 + tex_scale)/4
        net1['GLOBAL_STYLE_WEIGHT'] = 1e4*10**(strength*2)
        net2['GLOBAL_STYLE_WEIGHT'] = (3e5 + 1e4*10**(strength*2))/4
        net1['FG_CONTENT_WEIGHT'] = 1 + preservation
        net2['FG_CONTENT_WEIGHT'] = 1 + preservation

        #FIX RANDOMNESS FOR CONSISTENCY
        # torch.manual_seed seeds the CPU generator and every CUDA device, so the
        # run is repeatable whichever device draws random numbers.
        np.random.seed(999)
        torch.manual_seed(999)

        #GET STYLE IMAGE
        style = Image.fromarray(painting)

        #GET CONTENT IMAGE
        content = ui_content

        #NORMALISE IMAGES
        style, content = image_normalise(style, content, S_TYPE=NORM[0],
                                         C_TYPE=NORM[1], gamma=NORM[2], sharpen=NORM[3])

        #GET NET1 PROCESSING SIZE
        im_size = aspect(content, net1['SIZE'])

        #GET MASK: if no alpha supplied then extract feature (face/eyes) into mask
        content_s = content.resize(im_size)
        mask = mask_extract(content_s, FEATURES[feature])
        im_mask.value = PIL_to_bytes(mask)
        mask = mask.resize(content.size)

        #LOAD IMAGES TO TENSORS, INPUT=CONTENT/NOISE MIX
        content_img, content_mask = content_loader(content, mask, net1, im_size)
        style_img = style_loader(style, net1['TILE'], im_size, keep_aspect=True)
        input_img = input_loader(content, mask, net1['NOISE'], im_size)

        #RUN ALGORITHM
        output, style_plots1, content_plots1, converged1 = run_style_transfer(
            net1, cnn, cnn_normalization_mean, cnn_normalization_std,
            content_img, style_img, input_img, content_mask)
        if PLOTS:
            gen_plots(PLOT_DIR, 'NET1', net1, style_plots1, content_plots1)

        #SECOND RUN FOR HI RES OUTPUT
        if HI_RES:
            if converged1:
                b_process.description = 'net1 converged at {} ..'.format(converged1)
            else:
                b_process.description = 'net1 NOT CONVERGED ..'
            info10.value = elapsed()

            #GET NET2 PROCESSING SIZE
            im_size = aspect(content, net2['SIZE'])

            #NEW INPUT = PREVIOUS OUTPUT (named first_pass to avoid shadowing builtin input)
            first_pass = transforms.ToPILImage()(output.cpu().squeeze(0))

            #MAKE INPUT AND MASK SAME SIZE
            first_pass = first_pass.resize(im_size, Image.LANCZOS)
            mask = mask.resize(im_size, Image.BICUBIC)

            #LOAD IMAGES TO TENSORS
            content_img, content_mask = content_loader(content, mask, net2, im_size)
            style_img = style_loader(style, net2['TILE'], im_size, keep_aspect=True)
            input_img = input_loader(first_pass, mask, net2['NOISE'], im_size)

            #RUN ALGORITHM AGAIN
            output, style_plots2, content_plots2, converged2 = run_style_transfer(
                net2, cnn, cnn_normalization_mean, cnn_normalization_std,
                content_img, style_img, input_img, content_mask)
            if PLOTS:
                gen_plots(PLOT_DIR, 'NET2', net2, style_plots2, content_plots2)

            if converged2:
                b_process.description = 'net2 converged at {} ..'.format(converged2)
            else:
                b_process.description = 'net2 NOT CONVERGED ..'
            info10.value = elapsed()

        #CONVERT FINAL IMAGE
        transfer = tensor_to_PIL(output)
        im_transfer.value = PIL_to_bytes(transfer)

        # Display at 90% of the final processing resolution.
        im_transfer.width = im_size[0]*0.9
        im_transfer.height = im_size[1]*0.9
    finally:
        # Restore the button and release cached GPU memory even if a stage fails.
        b_process.description = 'PROCESS'
        b_process.style.button_color = 'lightblue'
        torch.cuda.empty_cache()


#PICK CONTENT IMAGE
@interact
def show_images(file=w_file):
    """Load the chosen content file into the global ``ui_content`` and preview it.

    Four-band images (e.g. RGBA) are previewed as RGB. ``ui_content`` itself
    keeps its original mode.
    """
    global ui_content
    ui_content = Image.open(CONTENT_DIR + file)
    has_four_bands = len(ui_content.getbands()) == 4
    preview = ui_content.convert('RGB') if has_four_bands else ui_content
    im_content.value = PIL_to_bytes(preview)

#PROCESS STYLE TRANSFER
def process(b):
    """PROCESS button callback: run style_transfer with the current slider values.

    Args:
        b: the clicked Button (unused).
    """
    settings = (w_abstraction.value, w_tex_scale.value, w_strength.value,
                w_preservation.value, w_feature.value)
    style_transfer(*settings)
b_process.on_click(process)

#DISPLAY UI
display(ui2)

#INITIAL RUN: identical to pressing PROCESS once (the button argument is unused)
process(None)

#ORIGINAL COLOUR RESTORATION
@interact
def orig_colour(colswap=w_colswap):
    """Apply colour_swap between the content photo and the latest transfer.

    Presumably blends the original content colours back in by ``colswap``
    (0..1). The result is stored in the global ``col_swap`` (saved instead of
    ``transfer`` when the slider is non-zero) and shown in the result pane.
    """
    global col_swap
    resized_content = ui_content.resize(transfer.size)
    col_swap = colour_swap(resized_content, transfer, colswap)
    im_transfer.value = PIL_to_bytes(col_swap)

#SAVE IMAGE
def save_st(b):
    """Save the current style-transfer result to OUT_DIR at the chosen size.

    The file name encodes the painting (style, seed, truncation x100), the
    content file name without extension, and every fine-tuning setting (x100).
    If 'Original Col' is non-zero, the colour-restored image (``col_swap``) is
    saved instead of ``transfer``. 2K/3K/4K upscale 4x with ESRGAN (in
    quarters); 2K and 3K are then resized to 2x / 3x the current size.

    Args:
        b: the clicked Button (unused).
    """
    def pct(value):
        # round() avoids float truncation, e.g. int(0.57 * 100) == 56.
        return int(round(value * 100))

    b_save_st.style.button_color = 'red'
    try:
        size = SIZE[w_size_st.value]
        # splitext handles any extension length ('.jpeg', '.tif', ...), unlike [:-4].
        content_stem = os.path.splitext(w_file.value)[0]
        name = 'P_{}_{}-{}_ST_{}_{}-{}-{}-{}-{}-{}_{}.jpg'.format(
            STYLE[w_artstyle.value][:3], w_seed.value, pct(w_truncation.value),
            content_stem, pct(w_abstraction.value), pct(w_tex_scale.value),
            pct(w_strength.value), pct(w_preservation.value),
            w_feature.value, pct(w_colswap.value), size)
        path = OUT_DIR + name

        image = col_swap if w_colswap.value else transfer

        scale = w_size_st.value + 1   # 1K -> as-is, 2K -> 2x, 3K -> 3x, 4K -> 4x
        if scale > 1:
            b_save_st.description = 'Saving ..'
            width, height = image.size

            quarters, x, y, halfx, halfy = make_quarters(np.asarray(image), 10)
            for i in range(4):
                quarters[i] = upresx4(quarters[i])
            output = combine_quarters(quarters, 4*x, 4*y, 4*halfx, 4*halfy, 40)

            output = Image.fromarray(output.astype(np.uint8))
            if scale < 4:
                output = output.resize((scale * width, scale * height))
        else:
            time.sleep(0.8)   # brief pause so the red 'busy' colour is visible
            output = image

        output.save(path)
        print('SAVED {}'.format(name))
    finally:
        # Restore the button even when upscaling or saving raises.
        b_save_st.description = 'Save'
        b_save_st.style.button_color = 'lightblue'

b_save_st.on_click(save_st)



interactive(children=(Dropdown(description='File:', index=16, layout=Layout(margin='0px 0px 0px 0px'), options…

HBox(children=(VBox(children=(HTML(value='<b>CONTENT</b>', description=' ', layout=Layout(margin='10px 120px 0…

interactive(children=(FloatSlider(value=0.0, continuous_update=False, description='Original Col: ', layout=Lay…

SAVED P_MOD_11-100_ST_00813_75-100-100-100-0-0_1K.jpg
