In [None]:
from __future__ import print_function

import os
import StringIO
import scipy.misc
import numpy as np
from glob import glob
from tqdm import trange
from itertools import chain
from collections import deque
from PIL import Image, ImageDraw


from models3 import *
from utils import save_image
import random

## Some utilities

In [None]:
def next(loader):
    """Fetch the next batch from `loader` and return its first element as a NumPy array.

    The batch's first element is expected to expose `.data.numpy()`. Works with
    Python 2 style iterators (`.next()`) and Python 3 style iterators (`__next__`).
    Note: this module-level name shadows the builtin `next`, so the builtin cannot
    be used inside this function.
    """
    advance = getattr(loader, 'next', None)
    if advance is None:
        # Python 3 iterators only provide __next__.
        advance = loader.__next__
    batch = advance()
    return batch[0].data.numpy()

def to_nhwc(image, data_format):
    """Return `image` in NHWC layout.

    Only data_format == 'NCHW' triggers a conversion (via `nchw_to_nhwc`);
    any other value is treated as already channels-last and passed through.
    """
    if data_format != 'NCHW':
        return image
    return nchw_to_nhwc(image)

def to_nchw_numpy(image):
    """Move a trailing 1- or 3-channel axis of a 4-D array to axis 1 (NHWC -> NCHW).

    Arrays whose last axis is not of size 1 or 3 are assumed to be NCHW
    already and are returned unchanged (same object).
    """
    channels_last = image.shape[3] in (1, 3)
    return image.transpose([0, 3, 1, 2]) if channels_last else image

def norm_img(image, data_format=None):
    """Scale pixel values from [0, 255] to [-1, 1].

    When `data_format` is truthy, the scaled image is additionally passed
    through `to_nhwc` (which only transposes for 'NCHW').
    """
    scaled = image / 127.5 - 1.
    return to_nhwc(scaled, data_format) if data_format else scaled

def denorm_img(norm, data_format):
    """Inverse of `norm_img`: map [-1, 1] back to pixel range, NHWC, clipped to [0, 255]."""
    pixels = to_nhwc((norm + 1) * 127.5, data_format)
    return tf.clip_by_value(pixels, 0, 255)

def slerp(val, low, high):
    """Spherical linear interpolation between vectors `low` and `high`.

    `val` in [0, 1] selects the point along the great-circle arc from `low`
    (val=0) to `high` (val=1). Falls back to plain linear interpolation when
    the angle between the vectors has zero sine (parallel inputs).
    Code from https://github.com/soumith/dcgan.torch/issues/14
    """
    unit_low = low / np.linalg.norm(low)
    unit_high = high / np.linalg.norm(high)
    cos_omega = np.clip(np.dot(unit_low, unit_high), -1, 1)
    omega = np.arccos(cos_omega)
    sin_omega = np.sin(omega)
    if sin_omega == 0:
        # L'Hopital's rule / LERP
        return (1.0 - val) * low + val * high
    weight_low = np.sin((1.0 - val) * omega) / sin_omega
    weight_high = np.sin(val * omega) / sin_omega
    return weight_low * low + weight_high * high

# Define model

In [None]:
os.environ["CUDA_VISIBLE_DEVICES"]=str(3)
class Trainer(object):
    """Conditional autoencoder + label discriminator + label regressor (TF1 graph mode).

    Builds three sub-networks (`AE`, `Discriminator`, `reg`, from models3) on top of the
    tensors in `data_loader`, opens a Supervisor-managed session in `self.sess`, and
    provides helpers to train each part and to save visualisations.

    Args:
        config: parsed configuration object; the fields read are listed in __init__.
        data_loader: (images, labels) tensor pair; becomes `self.x` / `self.label`.
        test_data_loader: (images, labels) pair used by get_test_image_from_loader.
        generated_data_loader: (images, labels) pair used by get_generated_image_from_loader.
    """

    def __init__(self, config, data_loader, test_data_loader, generated_data_loader):
        """Copy hyper-parameters from `config`, build the graph and prepare a session.

        `self.saver` covers only variables whose top-level scope is 'AE' or 'D'.
        In test mode (config.is_train False) build_test_model() is run afterwards.
        """
        self.config = config
        self.data_loader = data_loader
        self.test_data_loader = test_data_loader
        self.generated_data_loader = generated_data_loader
        self.dataset = config.dataset

        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.optimizer = config.optimizer
        self.batch_size = config.batch_size

        self.step = tf.Variable(0, name='step', trainable=False)

        self.gamma = config.gamma
        self.lambda_k = config.lambda_k

        # Learning rates and loss weights are fed on every sess.run by the train_* methods.
        self.ae_lr = tf.placeholder(tf.float32)
        self.d_lr = tf.placeholder(tf.float32)
        self.reg_lr = tf.placeholder(tf.float32)
        self.weight_dis = tf.placeholder(tf.float32)
        self.weight_reg = tf.placeholder(tf.float32)

        self.z_num = config.z_num
        self.conv_hidden_num = config.conv_hidden_num
        self.input_scale_size = config.input_scale_size

        self.model_dir = config.model_dir
        self.load_path = config.load_path

        self.use_gpu = config.use_gpu
        self.data_format = config.data_format

        # Image geometry is fixed here (64x64x3) rather than read from config.
        height = 64
        self.channel = 3

        self.repeat_num = int(np.log2(height)) - 2

        self.start_step = 0
        self.log_step = config.log_step
        self.max_step = config.max_step
        self.save_step = config.save_step
        self.lr_update_step = config.lr_update_step

        self.is_train = config.is_train
        self.build_model()

        # Only the autoencoder ('AE/...') and discriminator ('D/...') variables are checkpointed.
        variables = slim.get_variables_to_restore()
        variables_to_restore = [v for v in variables if v.name.split('/')[0]=='AE' or v.name.split('/')[0]=='D']

        self.saver = tf.train.Saver(variables_to_restore)
        self.summary_writer = tf.summary.FileWriter(self.model_dir)

        sv = tf.train.Supervisor(logdir=self.model_dir,
                                is_chief=True,
                                saver=self.saver,
                                summary_op=None,
                                summary_writer=self.summary_writer,
                                save_model_secs=1800,
                                global_step=self.step,
                                ready_for_local_init_op=None)

        gpu_options = tf.GPUOptions(allow_growth=True)
        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                    gpu_options=gpu_options)
        self.sess = sv.prepare_or_wait_for_session(config=sess_config)

        if not self.is_train:
            # Dirty way to bypass the graph-finalization error (private TF attribute).
            g = tf.get_default_graph()
            g._finalized = False
            self.build_test_model()

    def build_model(self):
        """Build the AE, regressor and discriminator graphs plus their losses/optimizers.

        Losses:
          ae_loss  = MSE(AE(x), x)
                     + weight_dis * mean(sum(onehot(label) * log(D(z))))  (= -weight_dis * d_loss)
                     + weight_reg * a
          d_loss   = cross-entropy of the discriminator's label prediction from the AE code
          d_loss2  = per-sample sum(onehot * log(D(z))) (not reduced over the batch)
          reg_loss = MSE between the regressor's prediction on real x and the true label
          a        = mean over labels i = 0..9 of MSE(reg(AE(x, label=i)), i)
        """
        print("start to build training model...")

        self.x, self.label = self.data_loader
        # self.x: [-1, 3, 64, 64], transpose(self.x): real image

        self.label1 = tf.reshape(self.label,[-1,1])
        self.label2 = tf.one_hot(tf.to_int32(self.label), 10)

        x = norm_img(self.x) # x: [-1, 3, 64, 64]
        # Regressor on normalised real images; this is the branch reg_optim trains.
        self.pred_label, self.reg_var, self.x_tmp, self.x_tmp2 = reg(x)
        d_out, self.D_z, self.AE_var = AE(x, self.label, self.batch_size, self.channel, self.z_num, self.repeat_num,self.conv_hidden_num, self.data_format)
        # d_out: [-1,3,64,64]

        # Accumulate, over every target label, the regressor's error on the images the AE
        # produces when conditioned on that label (weights shared via reuse=True).
        a = tf.convert_to_tensor(0, dtype=tf.float32)

        for i in range(10):
            # One copy of label i per batch element (previously hard-coded to 16 entries).
            changed_label = tf.to_float(tf.convert_to_tensor([i] * self.batch_size))

            d_out_changed, _, self.AE_var2 = AE(x, changed_label, self.batch_size, self.channel, self.z_num, self.repeat_num,self.conv_hidden_num, self.data_format, reuse=True)
            pred_, _, _, _ = reg(d_out_changed, reuse=True)

            a = tf.add(a, tf.reduce_mean(tf.square(pred_ - tf.to_float(tf.reshape(changed_label, [-1,1])))))

            # Debug handles: predictions, and the running sum of `a`, after labels 0 and 4.
            if i == 0:
                self.debug0 = pred_
                self.a0 = a
            if i == 4:
                self.debug4 = pred_
                self.a4 = a

        a = tf.divide(tf.to_float(a), 10) # a is another loss
        self.a = a

        AE_x = d_out
        self.AE_x = denorm_img(AE_x, self.data_format)

        self.out_label, self.D_var, self.z_vec, self.out_before_softmax = Discriminator(self.D_z, self.batch_size)
        self.label_reshaped = tf.reshape(self.out_label,[-1])

        #---------------------------define the optimizer for 3 components in this network-----------------------##
        optimizer = tf.train.AdamOptimizer

        ae_optimizer = optimizer(self.ae_lr)
        d_optimizer = optimizer(self.d_lr)
        reg_optimizer = optimizer(self.reg_lr)

        ##...........................Define the loss function here.May be changed................................##
        # The discriminator term is the negated D cross-entropy: the AE is rewarded for codes
        # from which D cannot recover the label.
        self.ae_loss = tf.reduce_mean(tf.square(AE_x-x)) + self.weight_dis* tf.reduce_mean(tf.reduce_sum(self.label2 * tf.log(self.out_label+1e-8), 1)) + self.weight_reg * self.a

        print("size of ae loss")
        print(self.ae_loss.get_shape())
        self.ae_optim = ae_optimizer.minimize(self.ae_loss, var_list=self.AE_var)

        # label2 is the real (one-hot) label, out_label is the predicted label distribution.
        self.d_loss = -tf.reduce_mean(tf.reduce_sum(self.label2 * tf.log(self.out_label+1e-8), 1))
        self.d_optim = d_optimizer.minimize(self.d_loss, var_list=self.D_var)

        self.d_loss2 = tf.reduce_sum(self.label2 * tf.log(self.out_label+1e-8), 1)

        # NOTE(review): original comment: "not sure here is self.label or self.label1";
        # label1 is the [-1, 1] reshape of label, matching pred_label as used above.
        self.reg_loss = tf.reduce_mean(tf.square(self.pred_label-self.label1))
        self.reg_optim = reg_optimizer.minimize(self.reg_loss, var_list = self.reg_var)

    def build_test_model(self):
        """Create the 'test'-scope latent variable `z_r` and initialise only its variables."""
        with tf.variable_scope("test") as vs:
            # Extra ops for interpolation
            self.z_r = tf.get_variable("z_r", [self.batch_size, self.z_num], tf.float32)

        test_variables = tf.contrib.framework.get_variables(vs)
        self.sess.run(tf.variables_initializer(test_variables))

    def autoencode(self, inputs, label, path, idx=None):
        """Reconstruct `inputs` conditioned on `label` and save the result as an image grid.

        Args:
            inputs: image batch; arrays whose last axis is 1 or 3 are transposed to NCHW.
            label: per-image condition labels fed to `self.label`.
            path: unused (kept for interface compatibility).
            idx: output-path prefix. The grid is written to "<idx>_real.png", which is
                the file the notebook's display cells open; with idx=None it goes to
                "aaa.jpg" as before.

        Returns:
            The reconstruction from `self.AE_x` (NHWC, clipped to [0, 255]).
        """
        img = inputs
        if img.shape[3] in [1, 3]:
            img = img.transpose([0, 3, 1, 2])

        x = self.sess.run(self.AE_x, {self.x: img, self.label: label})
        out_path = "aaa.jpg" if idx is None else "{}_real.png".format(idx)
        save_image(x, out_path)
        print("[*] Samples saved")
        return x

    def Do_reg(self, inputs):
        """Return the regressor's predictions (`self.pred_label`) for an image batch.

        Arrays whose last axis is 1 or 3 are transposed to NCHW before feeding `self.x`.
        """
        if inputs.shape[3] in [1, 3]:
            inputs = inputs.transpose([0, 3, 1, 2])
        return self.sess.run(self.pred_label, {self.x: inputs})

    def encode(self, inputs):
        """Return the AE latent code `self.D_z` for an image batch (NHWC input is transposed)."""
        if inputs.shape[3] in [1, 3]:
            inputs = inputs.transpose([0, 3, 1, 2])
        return self.sess.run(self.D_z, {self.x: inputs})

    def decode(self, z, label):
        """Decode latent codes `z` under condition `label` by feeding `self.D_z` directly."""
        return self.sess.run(self.AE_x, {self.D_z: z, self.label: label})

    def interpolate_D(self, real1_batch, real2_batch, label1, label2, step=0, root_path="."):
        """Save, per image, a row sweeping the condition label over 0..9.

        Row idx is [real1, decode(z, 0), ..., decode(z, 9), real2] and is written to
        "./interpolation<idx>.png". The tile for the true label `label1[idx]` is boxed in red.
        `label2`, `step` and `root_path` are unused (kept for interface compatibility).
        """
        # Import PIL locally: later notebook cells rebind the global name `Image`
        # to IPython.display.Image, which has no `.open`.
        from PIL import Image as PILImage

        n_images = len(real1_batch)
        real1_encode = self.encode(real1_batch)

        decodes = []
        for target_label in range(10):
            decodes.append(self.decode(real1_encode, np.repeat(target_label, n_images)))

        decodes = np.stack(decodes).transpose([1, 0, 2, 3, 4])
        for idx, img in enumerate(decodes):
            img = np.concatenate([[real1_batch[idx]], img, [real2_batch[idx]]], 0)
            save_image(img, "./interpolation"+str(idx)+".png", nrow=10 + 2)
        for idx in range(n_images):
            im = PILImage.open("./interpolation"+str(idx)+".png")
            draw = ImageDraw.Draw(im)
            label = label1[idx]
            # Tile 0 is real1, so label L is tile L+1. The 66 px pitch / 68 px height
            # match the box geometry of the label-sweep cell later in this notebook
            # (presumably 64 px tiles + 2 px save_image padding — TODO confirm).
            draw.rectangle(((label + 1) * 66, 0, (label + 2) * 66, 68), fill=None, outline="red")
            im.save("./interpolation"+str(idx)+".png")

    def change_attributes(self, real_batch, root_path='.'):
        """Decode `real_batch`'s latent codes under new labels and save two grids.

        ./test_changed_attri.png: 16 labels sampled without replacement from two copies
        of 0..9 (assumes a batch of 16). ./test_changed_attri_0.png: every image decoded
        with label 0. `root_path` is unused.
        """
        real_encode = self.encode(real_batch)
        test_sample = random.sample([0,1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,6,7,8,9], 16)
        decodes = self.decode(real_encode, test_sample)
        save_image(decodes, os.path.join("./", "test_changed_attri.png"))

        decodes = self.decode(real_encode, [0] * len(real_batch))
        save_image(decodes, os.path.join("./", "test_changed_attri_0.png"))

    def test(self):
        """Run three rounds of save-real / autoencode / change-attributes on training batches."""
        root_path = "./"
        for step in range(3):
            real1_batch, label1_batch = self.get_image_from_loader()
            save_image(real1_batch, os.path.join(root_path, 'test{}_real1.png'.format(step)))
            self.autoencode(real1_batch, label1_batch, self.model_dir, idx=os.path.join(root_path, "test{}_real1".format(step)))
            self.change_attributes(real1_batch, self.model_dir)

    def get_image_from_loader(self):
        """Evaluate one (images, labels) batch from `data_loader`; images returned as NHWC."""
        tmp, label = self.data_loader
        x, label2 = self.sess.run([tmp,label])
        if self.data_format == 'NCHW':
            x = x.transpose([0, 2, 3, 1])
        return x, label2

    def get_test_image_from_loader(self):
        """Evaluate one (images, labels) batch from `test_data_loader`; images returned as NHWC."""
        tmp, label = self.test_data_loader
        x, label2 = self.sess.run([tmp,label])
        if self.data_format == 'NCHW':
            x = x.transpose([0, 2, 3, 1])
        return x, label2

    def get_generated_image_from_loader(self):
        """Evaluate one (images, labels) batch from `generated_data_loader`; images returned as NHWC."""
        tmp, label = self.generated_data_loader
        x, label2 = self.sess.run([tmp,label])
        if self.data_format == 'NCHW':
            x = x.transpose([0, 2, 3, 1])
        return x, label2

    def train_reg(self, reglr, step):
        """Run `step` regressor updates at learning rate `reglr`, logging every 5 steps."""
        for i in trange(0, step):
            self.sess.run(self.reg_optim, feed_dict={self.weight_dis:0.0, self.ae_lr: 0.0, self.d_lr: 0.0, self.reg_lr: reglr, self.weight_reg:0.0})
            # NOTE(review): this second run evaluates the loss on a fresh batch, not the one just trained on.
            Reg_loss, pred_label_, real_label_  = self.sess.run([self.reg_loss, self.pred_label, self.label], feed_dict={self.weight_dis:0.0, self.ae_lr: 0.0, self.d_lr: 0.0, self.reg_lr: reglr, self.weight_reg: 0.0})
            if i % 5 == 0:
                print("[train regressor] step "+str(i)+", mse loss for regression is: "+str(Reg_loss))
                print("pred label:")
                print(pred_label_)
                print("real_label:")
                print(real_label_)

    def train_AE(self, aelr, dlr, step):
        """Run `step` autoencoder-only updates (weight_dis = weight_reg = 0), logging every 25."""
        for step in trange(0, step):
            self.sess.run(self.ae_optim, feed_dict={self.weight_dis: 0.0, self.ae_lr: aelr, self.d_lr: dlr, self.weight_reg: 0})
            AE_loss, D_loss, D_loss2, D_out, real_label, real_label_one_hot, zz, before = self.sess.run([self.ae_loss,self.d_loss,self.d_loss2, self.out_label, self.label, self.label2, self.z_vec, self.out_before_softmax], feed_dict={self.weight_dis: 0.0, self.ae_lr: aelr, self.d_lr: dlr, self.weight_reg:0.0})

            if step % 25 == 0:
                print("D out after softmax is:")
                print(D_out)
                print("[train AE] step: "+str(step)+", ae_loss: "+str(AE_loss)+", d_loss = " + str(D_loss)+"\n")

    def train_D(self, aelr, dlr, start=5000, stop=10000):
        """Run discriminator-only updates for step in [start, stop), logging every 25 steps.

        `start`/`stop` default to the previously hard-coded range 5000..10000.
        """
        for step in trange(start, stop):
            self.sess.run(self.d_optim, feed_dict={self.weight_dis: 0.0, self.ae_lr: aelr, self.d_lr: dlr, self.weight_reg:0.0})
            D_loss, D_loss2, D_out, real_label, real_label_one_hot = self.sess.run([self.d_loss, self.d_loss2, self.out_label, self.label, self.label2], feed_dict={self.weight_dis: 0.0, self.ae_lr: aelr, self.d_lr: dlr, self.weight_reg:0.0})
            AE_loss = self.sess.run(self.ae_loss, feed_dict={self.weight_dis: 0.0, self.ae_lr: aelr, self.d_lr: dlr, self.weight_reg:0.0})
            if step % 25 == 0:
                print("the output of dis (after softmax):")
                print(D_out)
                print("[train D] step: "+str(step)+", ae_loss: "+str(AE_loss)+", d_loss = " + str(D_loss)+"\n")

    def train_both(self, weightdis, weightreg, aelr, dlr, step):
        """Jointly update AE and discriminator for `step` iterations.

        weightdis / weightreg weight the adversarial and regression terms of ae_loss.
        Prints a summary every 25 steps and the regression debug values every step.
        """
        for step in trange(0, step):
            self.sess.run([self.ae_optim, self.d_optim], feed_dict={self.weight_dis: weightdis, self.ae_lr: aelr, self.d_lr: dlr, self.weight_reg: weightreg})

            a_0, a_4, pred_0,pred_4, reg_loss_on_gen,D_loss,D_loss2, AE_loss,D_out, real_label, real_label_one_hot, zz, before = self.sess.run([self.a0, self.a4, self.debug0, self.debug4, self.a, self.d_loss, self.d_loss2, self.ae_loss, self.out_label, self.label, self.label2, self.z_vec, self.out_before_softmax], feed_dict={self.weight_dis: weightdis, self.ae_lr: aelr, self.d_lr: dlr, self.weight_reg: weightreg})

            if step % 25 == 0:
                print("the output of dis (after softmax):")
                print(D_out)

                print("[train together]: step: "+str(step)+", ae_loss: "+str(AE_loss)+", d_loss = " + str(D_loss)+"\n")
            print("regression loss on transformed images is: " + str(reg_loss_on_gen))
            print("the prediction value for transformed image with conditioned label 0 is:")
            print(pred_0)
            print(a_0)
         

In [None]:
import numpy as np
import tensorflow as tf

#from trainer import Trainer
from config import get_config
from data_loader import get_loader
from data_loader import get_loader_test
from data_loader import get_loader_generated
from utils import prepare_dirs_and_logger, save_config


# Load data and initialize the model

In [None]:
# Pin the process to GPU 3 before any TensorFlow session is created.
os.environ["CUDA_VISIBLE_DEVICES"] = str(3)
config, unparsed = get_config()
prepare_dirs_and_logger(config)

# Seed a dedicated NumPy RandomState and TensorFlow's graph-level RNG.
rng = np.random.RandomState(config.random_seed)
tf.set_random_seed(config.random_seed)

do_shuffle = False
if not config.is_train:
    # Evaluation mode: fixed batch of 16; prefer the dedicated test data path when given.
    config.batch_size = 16
    data_path = config.data_path if config.test_data_path is None else config.test_data_path
    batch_size = config.sample_per_image
else:
    data_path = config.data_path
    batch_size = config.batch_size

# All three loaders share the same batching/scaling/format/split settings.
loader_args = (config.batch_size, config.input_scale_size, config.data_format, config.split)
data_loader = get_loader(data_path, *loader_args)
test_data_loader = get_loader_test(data_path, *loader_args)
generated_data_loader = get_loader_generated(data_path, *loader_args)
print("has loaded data....")

trainer = Trainer(config, data_loader, test_data_loader, generated_data_loader)
print("has initialized trainer...")

save_config(config)

# Begin training

In [None]:
# The graph was finalized when Trainer.__init__ prepared its Supervisor session
# (see the "graph finalization" hack there); un-finalize it so the initializer
# op below can be added. This relies on a private TF attribute.
CHECKPOINT_DIR = "./logs/dis"

g = tf.get_default_graph()
g._finalized = False
init = tf.global_variables_initializer()
trainer.sess.run(init)

# trainer.saver only covers AE/* and D/* variables, so the regressor keeps its fresh init.
chkpt_fname = tf.train.latest_checkpoint(CHECKPOINT_DIR)
if chkpt_fname is None:
    # latest_checkpoint returns None when the directory holds no checkpoint;
    # Saver.restore(None) would raise, so continue from the fresh initialization instead.
    print("No checkpoint found in " + CHECKPOINT_DIR + "; continuing with freshly initialized variables.")
else:
    trainer.saver.restore(trainer.sess, chkpt_fname)

In [None]:
# Debug cell: compare the regressor's predictions on label-conditioned AE outputs with
# its predictions on real images, and count the variables in the original vs. reused
# sub-networks.
# FIXME(review): `trainer.x0`, `trainer.x4`, `trainer.gen` and `trainer.var2` are read
# below, but Trainer.build_model in this notebook never creates them (among the debug
# tensors it only sets debug0/debug4/a0/a4/a), so this cell raises AttributeError as written.
#x, _ = trainer.data_loader
#x = norm_img(x)
#d_out_changed,_,_ = AE(x, tf.to_float([9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9]), trainer.batch_size, trainer.channel, trainer.z_num, trainer.repeat_num, trainer.conv_hidden_num, trainer.data_format, reuse=True)
#rel = denorm_img(d_out_changed, trainer.data_format) # the output of the denorm img is good


# NOTE(review): `loss` and `reg_loss_on_gen` both fetch `trainer.a`; `a_0`/`a_4` are the
# running sums of the per-label regression losses after labels 0 and 4, not per-label values.
loss, x0, x4, a_0, a_4, pred_0, pred_4,reg_loss_on_gen = trainer.sess.run([trainer.a, trainer.x0, trainer.x4, trainer.a0, trainer.a4, trainer.debug0, trainer.debug4, trainer.a], feed_dict={trainer.weight_dis: 0, trainer.ae_lr: 0, trainer.d_lr: 0, trainer.weight_reg:0})
#print(pred_0)
#print(pred_1)
#print(pred_2)
#print(pred_3)
#print(pred_4)
#print(pred_8)
#print(pred_9)
#print("mse loss: "+str(loss))
save_image(x0,'./img0.jpg')
save_image(x4,'./img4.jpg')


inputs = x4
#print(inputs.shape)  #(16,64,64,3)

# Same layout handling as Trainer.Do_reg/encode: channels-last input is moved to NCHW.
if inputs.shape[3] in [1, 3]:
        inputs = inputs.transpose([0, 3, 1, 2])
pred_wrong =  trainer.sess.run(trainer.debug4, {trainer.gen: inputs})
print(pred_wrong)
t1, t_ae2 = trainer.sess.run([trainer.var2, trainer.AE_var2], {trainer.gen:inputs})
#
print("--------------")
print("number of variables in reused reg: " + str(len(t1)))
print("------------")
print("number of variables in reused AE: " + str(len(t_ae2)))
pred_right = trainer.sess.run(trainer.pred_label, {trainer.x: inputs})
print(pred_right)
print("--------------")
print("--------------")
t2, t_ae = trainer.sess.run([trainer.reg_var, trainer.AE_var], {trainer.x: inputs})
print("in original reg model there are :" + str(len(t2)) +" set variables.")

print("----------")

print("in original AE model there are :" + str(len(t_ae)) +" set variables.")


# NOTE(review): this rebinds the global `Image` (imported from PIL at the top of the
# notebook) to IPython.display.Image; code that relies on the global PIL `Image`
# (`Image.open`, `Image.fromarray`) breaks if run afterwards unless it re-imports PIL.
from IPython.display import Image
from IPython.display import display



x0 = Image(filename='./img0.jpg')
x4 = Image(filename='./img4.jpg')

display(x0, x4)


In [None]:
from IPython.display import Image, display

# Reconstruct one test batch with all 16 condition labels forced to 0;
# autoencode() with idx=None writes the reconstruction grid to ./aaa.jpg.
real_batch, label_batch = trainer.get_test_image_from_loader()
zero_labels = [0] * 16
trainer.autoencode(real_batch, zero_labels, path=None)

x = Image(filename='./aaa.jpg')
display(x)


In [None]:
# Pre-train the regressor alone on real images.
trainer.train_reg(reglr=0.00005, step=2000)

In [None]:
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

# Evaluate the regressor on N_EVAL_BATCHES batches from trainer.data_loader (the loader
# get_image_from_loader reads) and plot predicted vs. true labels.
N_EVAL_BATCHES = 100
tot = 0
real_chunks, pred_chunks = [], []
for i in range(N_EVAL_BATCHES):
    real_batch, label_batch = trainer.get_image_from_loader()
    test_pred = trainer.Do_reg(real_batch)
    test_real = np.reshape(label_batch, (-1, 1))
    real_chunks.append(test_real)
    pred_chunks.append(test_pred)
    # sklearn's order is (y_true, y_pred); MSE is symmetric, so the value is unchanged.
    loss = mean_squared_error(test_real, test_pred)
    tot += loss
# Concatenate instead of pre-allocating 100 * 16 rows, so any batch size works.
test_real_label = np.concatenate(real_chunks)
test_pred_label = np.concatenate(pred_chunks)
print("mse loss is: " + str(tot / N_EVAL_BATCHES))

fig, ax = plt.subplots()
ax.scatter(test_pred_label, test_real_label, marker='x', s = 2)
plt.xlim(-4, 12)
ax.set_xlabel('regression value')
ax.set_ylabel('real label')
# This is the same loader the later evaluation cell labels "training set".
ax.set_title("predicted BTDR value distribution for real galaxies on training set")
plt.plot([0,9], [0,9],'k-')
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')
ax.spines['bottom'].set_position(('data', 0))
ax.spines['left'].set_position(('data', 0))
ax.set_xticks([-4, -3, -2, -1,0, 1,2,3,4,5,6,7,8,9,10,11,12])
ax.set_yticks([0,1,2,3,4,5,6,7,8,9])

plt.show()

In [None]:
# Autoencoder-only training (adversarial and regression weights are zero inside train_AE).
trainer.train_AE(aelr=0.00004, dlr=0.00004, step=1000)

In [None]:
# Continue autoencoder-only training.
trainer.train_AE(aelr=0.00004, dlr=0.00004, step=1000)

In [None]:
from matplotlib.pyplot import imshow
from PIL import Image
%matplotlib inline

print("a")
real_batch, label_batch = trainer.get_test_image_from_loader()
trainer.autoencode(real_batch, label_batch, trainer.model_dir, idx=os.path.join("./", "recovered"))
im = Image.open('./recovered_real.png', 'r')
#imshow(np.asarray(im))
print(label_batch)

save_image(real_batch, os.path.join("./", 'original.png'))
im_real = Image.open('./original.png','r')
#imshow(np.asarray(im))


from IPython.display import Image
from IPython.display import display
x = Image(filename='./recovered_real.png') 
y = Image(filename='./original.png') 
display(x, y)

In [None]:
# Short burst of autoencoder-only training.
trainer.train_AE(aelr=0.00004, dlr=0.00004, step=100)

In [None]:
from matplotlib.pyplot import imshow
from PIL import Image
%matplotlib inline

print("a")
real_batch, label_batch = trainer.get_test_image_from_loader()
trainer.autoencode(real_batch, label_batch, trainer.model_dir, idx=os.path.join("./", "recovered"))
im = Image.open('./recovered_real.png', 'r')
#imshow(np.asarray(im))
print(label_batch)

save_image(real_batch, os.path.join("./", 'original.png'))
im_real = Image.open('./original.png','r')
#imshow(np.asarray(im))


from IPython.display import Image
from IPython.display import display
x = Image(filename='./recovered_real.png') 
y = Image(filename='./original.png') 
display(x, y)

In [None]:
# More autoencoder-only training.
trainer.train_AE(aelr=0.00004, dlr=0.00004, step=500)

In [None]:
from matplotlib.pyplot import imshow
from PIL import Image
%matplotlib inline

print("a")
real_batch, label_batch = trainer.get_test_image_from_loader()
trainer.autoencode(real_batch, label_batch, trainer.model_dir, idx=os.path.join("./", "recovered"))
im = Image.open('./recovered_real.png', 'r')
#imshow(np.asarray(im))
print(label_batch)

save_image(real_batch, os.path.join("./", 'original.png'))
im_real = Image.open('./original.png','r')
#imshow(np.asarray(im))


from IPython.display import Image
from IPython.display import display
x = Image(filename='./recovered_real.png') 
y = Image(filename='./original.png') 
display(x, y)

In [None]:
# More autoencoder-only training.
trainer.train_AE(aelr=0.00004, dlr=0.00004, step=500)

In [None]:
from matplotlib.pyplot import imshow
from PIL import Image
%matplotlib inline

print("a")
real_batch, label_batch = trainer.get_test_image_from_loader()
trainer.autoencode(real_batch, label_batch, trainer.model_dir, idx=os.path.join("./", "recovered"))
#im = Image.open('./recovered_real.png', 'r')
#imshow(np.asarray(im))
print(label_batch)

save_image(real_batch, os.path.join("./", 'original.png'))
im_real = Image.open('./original.png','r')
#imshow(np.asarray(im))


from IPython.display import Image
from IPython.display import display
#x = Image(filename='./recovered_real.png') 
y = Image(filename='./original.png') 
display( y)

In [None]:
# Checkpoint to ./ae_loss=1 (trainer.saver covers only AE/* and D/* variables, not the regressor).
trainer.saver.save(sess=trainer.sess, save_path='ae_loss=1')

In [None]:
# Short burst of autoencoder-only training.
trainer.train_AE(aelr=0.00004, dlr=0.00004, step=100)

In [None]:
from matplotlib.pyplot import imshow
from PIL import Image
%matplotlib inline

print("a")
real_batch, label_batch = trainer.get_test_image_from_loader()
trainer.autoencode(real_batch, label_batch, trainer.model_dir, idx=os.path.join("./", "recovered"))
im = Image.open('./recovered_real.png', 'r')
#imshow(np.asarray(im))
print(label_batch)

save_image(real_batch, os.path.join("./", 'original.png'))
im_real = Image.open('./original.png','r')
#imshow(np.asarray(im))


from IPython.display import Image
from IPython.display import display
x = Image(filename='./recovered_real.png') 
y = Image(filename='./original.png') 
display(x, y)

In [None]:
# Checkpoint to ./ae_loss=0.0001 (AE/* and D/* variables only).
trainer.saver.save(sess=trainer.sess, save_path='ae_loss=0.0001')

In [None]:
# Discriminator-only training over train_D's default step range.
trainer.train_D(aelr=0.00008, dlr=0.00008)

In [None]:
# Checkpoint to ./after_dis (AE/* and D/* variables only).
trainer.saver.save(sess=trainer.sess, save_path='after_dis')

In [None]:
# Second round of discriminator-only training.
trainer.train_D(aelr=0.00008, dlr=0.00008)

In [None]:
# Joint AE + discriminator training with a tiny regression weight.
trainer.train_both(weightdis=0.003, weightreg=0.0000001, aelr=0.00002, dlr=0.00008, step=100)

In [None]:
# Checkpoint to ./a_little (AE/* and D/* variables only).
trainer.saver.save(sess=trainer.sess, save_path='a_little')

In [None]:
# Joint training, regression weight 1e-5.
trainer.train_both(weightdis=0.0035, weightreg=0.00001, aelr=0.00002, dlr=0.00008, step=100)

In [None]:
# Joint training, regression weight 2e-5.
trainer.train_both(weightdis=0.0035, weightreg=0.00002, aelr=0.00002, dlr=0.00008, step=50)

In [None]:
# Checkpoint to ./try1 (AE/* and D/* variables only).
trainer.saver.save(sess=trainer.sess, save_path='try1')

In [None]:
# Joint training, regression weight 2e-4.
trainer.train_both(weightdis=0.0035, weightreg=0.0002, aelr=0.00002, dlr=0.00008, step=50)

In [None]:
# Joint training, regression weight 4e-4.
trainer.train_both(weightdis=0.0035, weightreg=0.0004, aelr=0.00002, dlr=0.00008, step=50)

In [None]:
# NOTE(review): CUDA_VISIBLE_DEVICES is normally read only when CUDA is first initialised;
# the TF session already exists (created in Trainer.__init__), so this presumably does not
# move the running session to GPU 4 — confirm. It does affect subprocesses started later.
os.environ["CUDA_VISIBLE_DEVICES"]=str(4)

In [None]:
# Joint training, regression weight 5e-5.
trainer.train_both(weightdis=0.0035, weightreg=0.00005, aelr=0.00002, dlr=0.00008, step=50)

In [None]:
# Joint training with a lower adversarial weight and higher AE learning rate.
trainer.train_both(weightdis=0.0015, weightreg=0.0001, aelr=0.00004, dlr=0.00008, step=20)

In [None]:
# Joint training without the regression term (adversarial weight 0.002).
trainer.train_both(weightdis=0.002, weightreg=0.0, aelr=0.00004, dlr=0.00008, step=25)

In [None]:
# Joint training without the regression term (adversarial weight 0.0025).
trainer.train_both(weightdis=0.0025, weightreg=0.0, aelr=0.00004, dlr=0.00008, step=25)

In [None]:
# Joint training without the regression term (adversarial weight 0.003).
trainer.train_both(weightdis=0.003, weightreg=0.0, aelr=0.00004, dlr=0.00008, step=25)

In [None]:
from sklearn.metrics import mean_squared_error
from IPython.display import Image
from IPython.display import display
from matplotlib.backends.backend_pdf import PdfPages

import matplotlib.pyplot as plt


def evaluate_regressor(fetch_batch, title, n_batches=100, preview_path='./generated.png'):
    """Score trainer's regressor on `n_batches` batches and plot predictions vs. labels.

    Args:
        fetch_batch: zero-argument callable returning (images, labels), e.g. one of the
            Trainer.get_*_image_from_loader methods.
        title: axes title for the scatter plot.
        n_batches: number of batches to evaluate.
        preview_path: where the first batch's images are saved and displayed from.

    Returns:
        The matplotlib Figure with the predicted-vs-true scatter plot.

    Side effects: prints the first batch's labels and predictions and the mean
    per-batch MSE, writes `preview_path`, and displays it.
    """
    tot = 0
    label_chunks, pred_chunks = [], []
    for i in range(n_batches):
        batch, label = fetch_batch()
        pred = trainer.Do_reg(batch)
        label = np.reshape(label, (-1, 1))
        label_chunks.append(label)
        pred_chunks.append(pred)
        tot += mean_squared_error(label, pred)
        if i == 0:
            print(label[0:16])
            print(pred[0:16])
            save_image(batch, preview_path)
    mse = tot / n_batches
    print("mse loss is: " + str(mse))
    display(Image(filename=preview_path))

    fig, ax = plt.subplots()
    ax.scatter(np.concatenate(pred_chunks), np.concatenate(label_chunks), marker='x', s=2)
    ax.set_xlim(-4, 12)
    ax.set_xlabel('regression value')
    ax.set_ylabel('real label')
    ax.set_title(title)
    ax.plot([0, 9], [0, 9], 'k-')  # identity line: perfect predictions
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')
    ax.spines['bottom'].set_position(('data', 0))
    ax.text(5, 1, r'mse = ' + str(mse), fontsize=15)
    ax.spines['left'].set_position(('data', 0))
    ax.set_xticks([-4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
    ax.set_yticks([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    return fig


EVAL_SPLITS = [
    (trainer.get_generated_image_from_loader,
     "predicted BTDR value distribution for generated galaxies from testing set"),
    (trainer.get_image_from_loader,
     "predicted BTDR value distribution for real galaxies on training set"),
    (trainer.get_test_image_from_loader,
     "predicted BTDR value distribution for real galaxies on testing set"),
]

pp = PdfPages('evaluation.pdf')
try:
    for fetch_batch, title in EVAL_SPLITS:
        pp.savefig(evaluate_regressor(fetch_batch, title))
finally:
    # Always finalize the PDF, even if a batch fetch or plot fails midway.
    pp.close()
plt.show()

In [None]:

# now, begin to test
# Label sweep: for each test image, decode its latent code under every label 0..9 and
# save a row [original, decoded@0, ..., decoded@9] with the true label's tile boxed in red.
from PIL import Image as PILImage
from PIL import ImageDraw
from IPython.display import Image
from IPython.display import display

N_LABELS = 10

real_batch, label_batch = trainer.get_test_image_from_loader()
n_images = len(real_batch)
real1_encode = trainer.encode(real_batch)

decodes = [trainer.decode(real1_encode, np.repeat(target, n_images)) for target in range(N_LABELS)]
decodes = np.stack(decodes).transpose([1, 0, 2, 3, 4])  # -> (image, label, ...)

for idx, img in enumerate(decodes):
    img = np.concatenate([[real_batch[idx]], img], 0)
    save_image(img, "./interpolation"+str(idx)+".png", nrow=N_LABELS + 1)

for idx in range(n_images):
    path = "./interpolation" + str(idx) + ".png"
    im = PILImage.open(path)
    draw = ImageDraw.Draw(im)
    label = label_batch[idx]
    # Tile 0 is the original, so label L is tile L+1; 66 px pitch / 68 px height presumably
    # reflect 64 px tiles plus save_image's grid padding — TODO confirm.
    draw.rectangle(((label+1) * 66, 0, (label+2)*66, 68), fill=None, outline="red")
    im.save(path)

display(*[Image(filename="./interpolation{}.png".format(idx)) for idx in range(n_images)])

In [None]:
# FIXME(review): Trainer.train_both takes (weightdis, weightreg, aelr, dlr, step); this call
# passes only three arguments, so it raises TypeError. Possibly left over from an older
# signature — supply weightreg/dlr/step explicitly before running.
trainer.train_both(0.003, 0.00008, 0.00008)

In [None]:
import os
from PIL import Image

# Export label-conditioned reconstructions as individual JPEGs, one folder per target label.
# NOTE(review): absolute, machine-specific output path — consider making it configurable.
OUTPUT_DIR = "/mnt/ds3lab/litian/AE_BEGAN/data/galaxy_64_bdtr/generated2/"
N_EXPORT_BATCHES = 50

for i in range(N_EXPORT_BATCHES):
    real_batch, label_batch = trainer.get_test_image_from_loader()
    real_encode = trainer.encode(real_batch)

    print("encoding has finished")

    # 16 target labels drawn without replacement from two copies of 0..9 (each label at most twice).
    test_sample = random.sample([0,1,2,3,4,5,6,7,8,9,0,1,2,3,4,5,6,7,8,9], 16)
    decodes = trainer.decode(real_encode, test_sample)

    print("decoding has finished")

    for j, k in enumerate(test_sample):
        label_dir = OUTPUT_DIR + str(k) + "/"
        # Create the per-label folder on first use instead of failing on a missing directory.
        if not os.path.isdir(label_dir):
            os.makedirs(label_dir)
        out_path = label_dir + str(i) + "_" + str(j) + ".jpg"
        Image.fromarray(decodes[j].astype(np.uint8)).save(out_path)
        print("have saved the iamge in:" + out_path)


In [None]:
# Checkpoint to ./try4 (AE/* and D/* variables only).
trainer.saver.save(sess=trainer.sess, save_path='try4')