Skip to content

Commit

Permalink
standard alone script without ui
Browse files Browse the repository at this point in the history
  • Loading branch information
junyanz committed Dec 14, 2016
1 parent 086e15d commit c078b1b
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 13 deletions.
11 changes: 10 additions & 1 deletion README.md
Expand Up @@ -143,14 +143,23 @@ THEANO_FLAGS='device=gpu0, floatX=float32, nvcc.fastmath=True' python iGAN_predi
* Check the result saved in `./pics/shoes_test_cnn_opt.png`
* We provide three methods: `opt` for optimization method; `cnn` for feed-forward network method (fastest); `cnn_opt` hybrid of the previous methods (default and best). Type `python iGAN_predict.py --help` for a complete list of the arguments.

## Script without UI
<img src='pics/script_result.png' width=1000>

We also provide a standalone script that works without the UI. Given user constraints (i.e. a color map, a color mask and an edge map), the script generates multiple images that mostly satisfy the user constraints. See `python iGAN_script.py --help` for more details.
```bash
THEANO_FLAGS='device=gpu0, floatX=float32, nvcc.fastmath=True' python iGAN_script.py --model_name outdoor_64
```


## TODO
* ~~Support Python3.~~
* ~~Add image datasets.~~
* ~~Support average image mode.~~
* ~~Add the DCGAN model training script.~~
* ~~Support sketch models for sketching guidance.~~
* ~~Add the script for projecting an image to the latent vector `z`.~~
* Add a standard alone script without UI.
* ~~Add a standalone script without UI.~~
* Add 128x128 models.
* Support other deep learning frameworks (e.g. Tensorflow).
* Support other deep generative models (e.g. variational autoencoder).
Expand Down
12 changes: 9 additions & 3 deletions constrained_opt.py
Expand Up @@ -5,7 +5,7 @@
import sys
from lib import utils
from PyQt4.QtCore import *

import cv2

class Constrained_OPT(QThread):
def __init__(self, opt_solver, batch_size=32, n_iters=25, topK=16, morph_steps=16, interp='linear'):
Expand Down Expand Up @@ -79,6 +79,11 @@ def update(self): # update ui

def save_constraints(self):
[im_c, mask_c, im_e, mask_e] = self.combine_constraints(self.constraints)
# write image
# im_c2 = cv2.cvtColor(im_c, cv2.COLOR_RGB2BGR)
# cv2.imwrite('input_color_image.png', im_c2)
# cv2.imwrite('input_color_mask.png', mask_c)
# cv2.imwrite('input_edge_map.png', im_e)
self.prev_im_c = im_c.copy()
self.prev_mask_c = mask_c.copy()
self.prev_im_e = im_e.copy()
Expand Down Expand Up @@ -179,6 +184,9 @@ def get_num_frames(self):
else:
return self.img_seq.shape[1]

def get_current_results(self):
    """Return the object currently stored in ``self.current_ims``."""
    current = self.current_ims
    return current

def run(self): # main function
time_to_wait = 33 # 33 millisecond
while (1):
Expand All @@ -202,10 +210,8 @@ def run(self): # main function
if t_c < time_to_wait:
self.msleep(time_to_wait-t_c)


def update_invert(self, constraints):
constraints_c = self.combine_constraints(constraints)
t=time()
gx_t, z_t, cost_all = self.opt_solver.invert(constraints_c, self.z_const)

order = np.argsort(cost_all)
Expand Down
67 changes: 67 additions & 0 deletions iGAN_script.py
@@ -0,0 +1,67 @@
from __future__ import print_function
import argparse
from pydoc import locate
import constrained_opt
import cv2
import numpy as np
from pdb import set_trace as st

def parse_args(argv=None):
    """Parse the command-line options of the UI-less iGAN script.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process command line, which is
            the behavior callers that pass no argument have always had.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description='iGAN: Interactive Visual Synthesis Powered by GAN')
    # model selection
    parser.add_argument('--model_name', dest='model_name', help='the model name', default='outdoor_64', type=str)
    parser.add_argument('--model_type', dest='model_type', help='the generative models and its deep learning framework', default='dcgan_theano', type=str)
    parser.add_argument('--framework', dest='framework', help='deep learning framework', default='theano')
    # user constraint images and the output location
    parser.add_argument('--input_color', dest='input_color', help='input color image', default='./pics/input_color.png')
    parser.add_argument('--input_color_mask', dest='input_color_mask', help='input color mask', default='./pics/input_color_mask.png')
    parser.add_argument('--input_edge', dest='input_edge', help='input edge image', default='./pics/input_edge.png')
    parser.add_argument('--output_result', dest='output_result', help='output_result', default='./pics/script_result.png')
    # optimization settings
    parser.add_argument('--batch_size', dest='batch_size', help='the number of random initializations', type=int, default=64)
    parser.add_argument('--n_iters', dest='n_iters', help='the number of total optimization iterations', type=int, default=100)
    parser.add_argument('--top_k', dest='top_k', help='the number of the thumbnail results being displayed', type=int, default=16)
    parser.add_argument('--model_file', dest='model_file', help='the file that stores the generative model', type=str, default=None)
    parser.add_argument('--d_weight', dest='d_weight', help='captures the visual realism based on GAN discriminator', type=float, default=0.0)
    return parser.parse_args(argv)

def preprocess_image(img_path, npx):
    """Load an image file and return it as an ``npx`` x ``npx`` RGB array.

    Args:
        img_path: path of the image file to read.
        npx: target height and width in pixels.

    Returns:
        The image decoded as 3-channel color, resized to ``(npx, npx)`` when
        its height or width differs from ``npx``, and converted from
        OpenCV's BGR channel order to RGB.

    Raises:
        IOError: if OpenCV cannot read ``img_path``.
    """
    im = cv2.imread(img_path, 1)  # flag 1: decode as 3-channel color (BGR)
    if im is None:
        # cv2.imread reports a missing or unreadable file by returning None
        # rather than raising; fail here with the offending path instead of
        # an AttributeError on ``im.shape`` below.
        raise IOError('cannot read image file: %s' % img_path)

    if im.shape[0] != npx or im.shape[1] != npx:
        out = cv2.resize(im, (npx, npx))
    else:
        # copy so the returned array never aliases the decoded buffer
        out = np.copy(im)

    return cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
# Script entry point: load a generative model, read the three user-constraint
# images (color image, color mask, edge map), run the constrained optimization
# and write a visualization of the inputs next to the generated results.
if __name__ == '__main__':
args = parse_args()
# fall back to ./models/<model_name>.<model_type> when no file was given
if not args.model_file: #if the model_file is not specified
args.model_file = './models/%s.%s' % (args.model_name, args.model_type)

# echo every parsed option so the run configuration is visible in the log
for arg in vars(args):
print('[%s] =' % arg, getattr(args, arg))

# initialize model and constrained optimization problem
# the model module and the solver module are looked up by name at run time,
# driven by --model_type and --framework
model_class = locate('model_def.%s' % args.model_type)
model = model_class.Model(model_name=args.model_name, model_file=args.model_file)
opt_class = locate('constrained_opt_%s' % args.framework)
opt_solver = opt_class.OPT_Solver(model, batch_size=args.batch_size, d_weight=args.d_weight)
img_size = opt_solver.get_image_size()
opt_engine = constrained_opt.Constrained_OPT(opt_solver, batch_size=args.batch_size, n_iters=args.n_iters, topK=args.top_k)
# load user inputs
# every input is resized to the model resolution (npx x npx) and converted to RGB
npx = model.npx
im_color = preprocess_image(args.input_color, npx)
im_color_mask = preprocess_image(args.input_color_mask, npx)
im_edge = preprocess_image(args.input_edge, npx)
# run the optimization
opt_engine.init_z()
# constraint order: color image, color mask, edge image, edge mask.
# The masks are the first channel of the respective image; indexing with the
# list [0] keeps a trailing channel axis of size 1.
constraints = [im_color, im_color_mask[... ,[0]], im_edge, im_edge[...,[0]]]
for n in range(args.n_iters):
opt_engine.update_invert(constraints=constraints)
results = opt_engine.get_current_results()
# join the result images along axis 1
# NOTE(review): presumably each result is an (npx, npx, 3) RGB image so this tiles them side by side — confirm
final_result= np.concatenate(results, 1)
# combine input and output
final_vis = np. hstack([im_color, im_color_mask, im_edge, final_result])
# back to OpenCV's BGR channel order before writing
final_vis = cv2.cvtColor(final_vis, cv2.COLOR_RGB2BGR)
# enlarge the visualization by a factor of 2 in both directions
final_vis = cv2.resize(final_vis, (0, 0), fx=2.0, fy=2.0)
# save
cv2.imwrite(args.output_result, final_vis)
Binary file added pics/input_color.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added pics/input_color_mask.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added pics/input_edge.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added pics/script_result.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 1 addition & 9 deletions ui/gui_draw.py
Expand Up @@ -132,6 +132,7 @@ def get_image_id(self):

def get_frame_id(self):
    """Return the value stored in ``self.frame_id``."""
    current_frame = self.frame_id
    return current_frame

def get_z(self):
    """Log the current image/frame ids and return the engine's z for them.

    The ids are read once for the log message and once more for the
    ``self.opt_engine.get_z`` call.
    """
    logged_ids = (self.get_image_id(), self.get_frame_id())
    print('get z from image %d, frame %d' % logged_ids)
    engine = self.opt_engine
    return engine.get_z(self.get_image_id(), self.get_frame_id())
Expand Down Expand Up @@ -205,14 +206,11 @@ def paintEvent(self, event):
painter.drawEllipse(pnt, w, w)

if self.type is 'warp' and self.show_ui:
# print 'paint warp brush'
color = Qt.green
w = 10
painter.setPen(QPen(color, w, Qt.DotLine, cap=Qt.RoundCap, join=Qt.RoundJoin)) # ,)
pnt1 = self.uiWarp.StartPoint()
# print 'start_point', pnt1
if pnt1 is not None:
# print 'paint warp brush 2'
pnt1f = QPointF(pnt1[0]*self.scale, pnt1[1]*self.scale)
pnt2f = QPointF(self.pos.x(), self.pos.y())
painter.drawLine(pnt1f, pnt2f)
Expand Down Expand Up @@ -279,17 +277,13 @@ def mousePressEvent(self, event):
self.update()

def mouseMoveEvent(self, event):
# Mouse-move handler: stores the rounded event position in self.pos; when
# self.isPressed is set and the active tool type is 'color' or 'edge', the
# position is appended to self.points.
# NOTE(review): presumably a Qt event-handler override — confirm against the enclosing class.
# print('mouse move', self.pos)
self.pos = self.round_point(event.pos())
if self.isPressed:
# point = event.pos()
if self.type in ['color','edge']:
self.points.append(self.pos)
self.update_ui()
self.update_opt_engine()
# self.update()
self.update()
# print(point)

def mouseReleaseEvent(self, event):
if event.button() == Qt.LeftButton and self.isPressed:
Expand All @@ -311,12 +305,10 @@ def sizeHint(self):


def update_frame(self, dif):
# Shift self.frame_id by `dif`, wrapping around modulo the number of frames
# reported by self.opt_engine.get_num_frames().
# frame_id_old = self.frame_id
num_frames = self.opt_engine.get_num_frames()
# guard: with zero frames the modulo below would divide by zero
if num_frames > 0:
self.frame_id = (self.frame_id+dif) % num_frames
print("show frame id = %d"%self.frame_id)
# self.update()

def fix_z(self):
self.opt_engine.init_z(self.get_frame_id(), self.get_image_id())
Expand Down
3 changes: 3 additions & 0 deletions ui/ui_sketch.py
Expand Up @@ -16,6 +16,9 @@ def __init__(self, img_size, scale, accu=True, nc=3):
def update(self, points, color):
num_pnts = len(points)
c = 255 - int(color.red())
if c > 0:
c = 255

for i in range(0, num_pnts - 1):
pnt1 = (int(points[i].x()/self.scale), int(points[i].y()/self.scale))
pnt2 = (int(points[i + 1].x()/self.scale), int(points[i + 1].y()/self.scale))
Expand Down

0 comments on commit c078b1b

Please sign in to comment.