In [0]:
# use github data
!git clone https://github.com/caibojun/psfnet.git
!mv ./psfnet/* ./
!rm -r ./psfnet
!rm -r sample_data

In [0]:
# Alternative to the GitHub data: upload your own files from the local machine.
from google.colab import files

uploaded = files.upload()  # dict: uploaded file name -> uploaded content

for fn, payload in uploaded.items():
    print('User uploaded file "{name}" with length {length} bytes'.format(
        name=fn, length=len(payload)))

In [0]:
#create new folders
!mkdir ./input
!mkdir ./input/trainA
!mkdir ./input/trainB
!mkdir ./input/valid
!mkdir ./input/valid_gt
!rm -r sample_data

In [0]:
from __future__ import division
from mobilenet import *
from network import *
import tensorflow as tf
import numpy as np
from astropy.io import fits
import os


In [0]:
class PSFNET(object):
    """Cycle-consistent deconvolution network for single-channel FITS images.

    Two generators are built with ``self.deconvnet``: one under scope
    ``'restore'`` (blurred -> restored) and one under scope ``'blur'``
    (ground truth -> blurred). They are trained jointly with direct,
    cycle-consistency and identity L1 losses (see ``build_net``).

    Args:
        sess: TensorFlow 1.x session used for all graph execution.
        img_sz: side length of the square, single-channel input images.
        checkpoint_dir: directory where checkpoints are written and read.
        save_folder: directory where restored images are written as FITS.
        graph_dir: directory for the TensorBoard graph log.
        use_mobilenet: if true, use ``mobilenet`` as the generator, otherwise
            ``deconv_resnet`` (both come from the project's star imports).
    """

    def __init__(self, sess, img_sz, checkpoint_dir, save_folder, graph_dir, use_mobilenet):
        self.sess = sess
        self.sz = img_sz
        self.deconvnet = mobilenet if use_mobilenet else deconv_resnet
        self.save_folder = save_folder
        self.graph_dir = graph_dir
        self.checkpoint_dir = checkpoint_dir
        for dirpath in (self.save_folder, self.checkpoint_dir, self.graph_dir):
            self._ensure_dir(dirpath)
        self.build_net()
        # Created before train_optimizer(), so it should track only the network
        # variables, not the Adam slot variables added later.
        self.saver = tf.train.Saver()

    @staticmethod
    def _ensure_dir(dirpath):
        """Create ``dirpath`` (including missing parents) if it does not exist."""
        if not os.path.isdir(dirpath):
            os.makedirs(dirpath)

    @staticmethod
    def _list_data_files(dir_path):
        """Return the sorted entries of ``dir_path``, skipping hidden ones.

        Hidden entries such as Jupyter's ``.ipynb_checkpoints`` directory are
        not images and would otherwise crash ``fits.open``.
        """
        return sorted(name for name in os.listdir(dir_path) if not name.startswith('.'))

    @staticmethod
    def _normalize(image):
        """Scale ``image`` as ``2 * image / max(image) - 1`` (max maps to 1).

        Raises:
            ValueError: if the image maximum is 0; the division would otherwise
                feed NaN/inf into the network and poison the weights.
        """
        peak = np.max(image)
        if peak == 0:
            raise ValueError('cannot normalize an image whose maximum is 0')
        return 2 * image / peak - 1.0

    @classmethod
    def _read_fits(cls, img_path):
        """Load the primary HDU of a FITS file as a normalized network input.

        Returns:
            float32 array with a leading batch axis and a trailing channel axis
            added, i.e. ``(1, H, W, 1)`` assuming the HDU holds a 2-D image
            (TODO confirm for all inputs).
        """
        with fits.open(img_path) as hdul:
            # Copy the pixels out before the file is closed (data may be memory-mapped).
            image = np.array(hdul[0].data, dtype=np.float32)
        image = np.expand_dims(image, axis=0)
        image = np.expand_dims(image, axis=-1)
        return cls._normalize(image)

    @staticmethod
    def _save_fits(savepath, img):
        """Write ``img`` as the primary HDU of ``savepath``, replacing any existing file."""
        if os.path.isfile(savepath):
            os.remove(savepath)
        fits.HDUList([fits.PrimaryHDU(img)]).writeto(savepath)

    @staticmethod
    def _exponential_decay(t, init=1e-3, m=1000, finish=1e-6):
        """Learning rate decaying exponentially from ``init`` at t=0 to ``finish`` at t=m."""
        alpha = np.log(init / finish) / m
        return init * np.exp(-alpha * t)

    def build_net(self, Lambda=10.0):
        """Build placeholders, both generators and the total training loss ``self.loss``.

        Args:
            Lambda: weight applied to the two cycle-consistency terms.
        """
        def loss(gt, restore, method='mse'):
            # 'mse' -> mean squared error; any other value -> mean absolute error.
            if method == 'mse':
                return tf.reduce_mean((gt - restore) ** 2)
            return tf.reduce_mean(tf.abs(gt - restore))

        self.blur_data = tf.placeholder(tf.float32,
                                        [None, self.sz, self.sz, 1],
                                        name='blur_image')
        self.gt_data = tf.placeholder(tf.float32,
                                      [None, self.sz, self.sz, 1],
                                      name='gt_image')
        # Forward mappings.
        self.restore_data = self.deconvnet(self.blur_data, scope='restore')
        self.fake_blur = self.deconvnet(self.gt_data, scope='blur')
        # Cycle mappings, built with reuse=True (presumably sharing the
        # generator weights created above).
        self.cycle_restore = self.deconvnet(self.fake_blur, reuse=True, scope='restore')
        self.cycle_blur = self.deconvnet(self.restore_data, reuse=True, scope='blur')
        # Identity mappings: each generator applied to its own output domain.
        self.id_blur = self.deconvnet(self.blur_data, reuse=True, scope='blur')
        self.id_restore = self.deconvnet(self.gt_data, reuse=True, scope='restore')

        direct_loss = (loss(self.blur_data, self.fake_blur, 'mae') +
                       loss(self.gt_data, self.restore_data, 'mae'))
        # Penalize both cycles: gt -> blur -> restore must recover gt, and
        # blur -> restore -> blur must recover the blurred input.
        cycle_loss = (loss(self.cycle_restore, self.gt_data, 'mae') +
                      loss(self.cycle_blur, self.blur_data, 'mae'))
        identity_loss = (loss(self.blur_data, self.id_blur, 'mae') +
                         loss(self.gt_data, self.id_restore, 'mae')) / 2
        self.loss = direct_loss + Lambda * cycle_loss + identity_loss

    def train_optimizer(self):
        """Create the learning-rate placeholder and the Adam update op ``self.optim``."""
        self.lr = tf.placeholder(tf.float32, None, name='learning_rate')
        self.optim = tf.train.AdamOptimizer(self.lr, beta1=0.5).minimize(
            self.loss, var_list=tf.trainable_variables())

    def train(self, dir_path, EPOCH=500, Continue=False, init_point=None):
        """Train both generators on paired FITS images.

        Directory layout, relative to the parent of ``dir_path`` (e.g. ``input/``):
        ``dir_path`` holds the blurred inputs, ``trainA`` the matching ground
        truth (same file names), and ``valid`` / ``valid_gt`` the validation pairs.
        After every epoch the restored training images are written to
        ``save_folder`` and a checkpoint is saved to ``checkpoint_dir/psfnet``.

        Args:
            dir_path: directory of blurred training images (e.g. ``input/trainB``).
            EPOCH: epochs run are ``range(init_point, EPOCH)``.
            Continue: if true, restore weights from ``checkpoint_dir/psfnet``
                before training.
            init_point: first epoch number. Defaults to 1, or is prompted for
                interactively when ``Continue`` is true.

        Raises:
            ValueError: if ``dir_path`` contains no training images.
        """
        if init_point is None:
            if Continue:
                try:
                    ask = raw_input  # Python 2
                except NameError:
                    ask = input  # Python 3 has no raw_input
                init_point = int(ask("Please input your initiate point:"))
            else:
                init_point = 1

        data_root = os.path.dirname(os.path.normpath(dir_path))
        gt_dir = os.path.join(data_root, 'trainA')
        valid_path = os.path.join(data_root, 'valid')
        valid_gt_dir = os.path.join(data_root, 'valid_gt')
        files = self._list_data_files(dir_path)
        valid_files = self._list_data_files(valid_path)
        if not files:
            raise ValueError('no training images found in {}'.format(dir_path))

        self.train_optimizer()
        self.sess.run(tf.global_variables_initializer())
        if Continue:
            # Restore only AFTER initialization; running the initializer after
            # restore would overwrite the restored weights.
            self.saver.restore(self.sess, os.path.join(self.checkpoint_dir, 'psfnet'))
        summary_writer = tf.summary.FileWriter(self.graph_dir, self.sess.graph)
        print('Training...')
        try:
            for epoch in range(init_point, EPOCH):
                # Constant rate for the first 30 epochs, exponential decay afterwards.
                lr = 1e-3 if epoch < 30 else self._exponential_decay(epoch)

                train_loss = 0.0
                for name in files:
                    image = self._read_fits(os.path.join(dir_path, name))
                    gt = self._read_fits(os.path.join(gt_dir, name))
                    _, restored, step_loss = self.sess.run(
                        [self.optim, self.restore_data, self.loss],
                        feed_dict={self.blur_data: image, self.gt_data: gt, self.lr: lr})
                    train_loss += step_loss
                    self._save_fits(os.path.join(self.save_folder, name),
                                    restored.reshape(restored.shape[1:3]))

                valid_loss = 0.0
                for name in valid_files:
                    image = self._read_fits(os.path.join(valid_path, name))
                    gt = self._read_fits(os.path.join(valid_gt_dir, name))
                    valid_loss += self.sess.run(
                        self.loss, feed_dict={self.blur_data: image, self.gt_data: gt})
                mean_valid = valid_loss / len(valid_files) if valid_files else float('nan')

                print('Epoch {}, mean loss {}, valid_loss {}'.format(
                    epoch, train_loss / len(files), mean_valid))
                self.saver.save(self.sess, os.path.join(self.checkpoint_dir, 'psfnet'))
        finally:
            summary_writer.close()

    def predict(self, dir_path):
        """Restore the latest checkpoint and deconvolve every image in ``dir_path``.

        Each restored image is written to ``save_folder`` under the same file
        name. The session is closed when all images have been processed.

        Raises:
            ValueError: if ``checkpoint_dir`` contains no checkpoint.
        """
        model = tf.train.latest_checkpoint(self.checkpoint_dir)
        if model is None:
            raise ValueError('no checkpoint found in {}'.format(self.checkpoint_dir))
        self.saver.restore(self.sess, model)

        for name in self._list_data_files(dir_path):
            print(name)
            image = self._read_fits(os.path.join(dir_path, name))
            restored = self.sess.run(self.restore_data, feed_dict={self.blur_data: image})
            self._save_fits(os.path.join(self.save_folder, name),
                            restored.reshape(restored.shape[1:3]))
        self.sess.close()

In [0]:
# Show the GPU(s) visible to this runtime before starting training.
!nvidia-smi

In [0]:
# Train from scratch on input/trainB with 200x200 images and the deconv_resnet
# generator (use_mobilenet=False); outputs go to test/, checkpoints to checkpoint/.
if __name__ == '__main__':
    session_config = tf.ConfigProto(allow_soft_placement=True)
    session_config.gpu_options.allow_growth = True  # allocate GPU memory on demand
    with tf.Session(config=session_config) as sess:
        psfnet = PSFNET(sess,
                        img_sz=200,
                        checkpoint_dir='checkpoint/',
                        save_folder='test',
                        graph_dir='log',
                        use_mobilenet=False)
        psfnet.train('input/trainB', Continue=False)

In [0]:
# Download the restored images from Colab to the local machine.
# 'test' is the save_folder passed to PSFNET during training.
import os

from google.colab import files

RESULT_DIR = 'test'
for name in sorted(os.listdir(RESULT_DIR)):
    path = os.path.join(RESULT_DIR, name)
    # Skip subdirectories such as .ipynb_checkpoints; files.download expects a file path.
    if os.path.isfile(path):
        files.download(path)