In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
import matplotlib as mpl
import seaborn as sns

import glob
import os

import numpy as np
import pandas as pd
import tensorflow as tf

# Notebook-wide plotting defaults: 80 dpi for both on-screen and saved
# figures, and a default figure size of 60% of (10, 6) inches.
mpl.rcParams.update({
    'savefig.dpi': 80,
    'figure.dpi': 80,
    'figure.figsize': np.array((10, 6)) * .6,
})

# my code
import misc
import gan


# Load galaxy image data
If you need the `.npy` files, download them from [this directory](https://www.dropbox.com/sh/raq8y8wt2rii7d6/AACA6OO3eKxfTevTJj19w_rea?dl=0). The archive is called `npy_files.tar.gz`; I'd link to it directly, but it hasn't finished uploading to Dropbox yet.

In [None]:
# Directory holding one `.npy` cutout per galaxy.
img_dir = "data/galaxy_images_training/npy_files/"

# Path template: fill the `{}` slot with an id (or "*" for a glob pattern).
filename_formatter = os.path.join(
    img_dir,
    "{}-cutout.npy",
)


In [None]:
# Find every cutout file on disk by substituting a wildcard into the template.
npy_files = glob.glob(filename_formatter.format("*"))

# Each basename has the form "<id>-cutout.npy"; the integer id is the text
# before the first "-". The ids keep the order in which glob returned the files.
HSC_ids = np.array(
    [int(os.path.basename(path).partition("-")[0]) for path in npy_files]
)

HSC_ids


In [None]:
# Pre-allocate one (3, 50, 50) slot per id, then fill each slot from the
# matching file on disk (each file is expected to fit that slot).
X_img = np.empty((len(HSC_ids), 3, 50, 50))
for i, HSC_id in enumerate(HSC_ids):
    X_img[i, ...] = np.load(filename_formatter.format(HSC_id))

# Move axis 1 (size 3) to the end: (N, 3, 50, 50) -> (N, 50, 50, 3).
X_img = np.transpose(X_img, (0, 2, 3, 1))
X_img.shape


In [None]:
# Length of axis 1 of the image array (50 after the transpose above);
# passed to the GAN constructor further down.
image_size = np.shape(X_img)[1]
image_size


# Load targets
`HSC_ids` are in the same order as the `X_img` data. These ids are then used to cross-reference the table read into `df`.

In [None]:
# Read the object table, keep one row per HSC_id, index by that id, and
# retain only the two target columns.
# (disabled filter, kept for reference: df = df[df.selected])
df = (
    pd.read_csv("data/2018_02_23-all_objects.csv")
    .drop_duplicates("HSC_id")
    .set_index("HSC_id")
    .loc[:, ["photo_z", "log_mass"]]
)

df.head()



In [None]:
# Look up the target rows by id so that row k of `y` matches image k of `X_img`.
y = df.loc[HSC_ids, :].values

# Fixed conditioning vector used for the visualization samples; presumably in
# the same (photo_z, log_mass) column order as `y` -- confirm against gan.py.
y_for_visualization_samples = np.array([0.14, 8.51])


In [None]:
# Fixed standardization constants; values copied from output of `simple gan.ipynb`.
standardizer = misc.Standardizer(
    means=np.array([0.21093612, 8.62739865]),
    std=np.array([0.30696933, 0.63783586]),
)
# standardizer.train(y)
print("means: ", standardizer.means)
print("std:   ", standardizer.std)

# Apply the same standardizer to the training targets and to the fixed
# visualization conditioning vector.
y_standard = standardizer(y)
y_for_visualization_samples_standard = standardizer(y_for_visualization_samples)

y_standard.shape


# Run GAN
Modeled after: https://github.com/hwalsuklee/tensorflow-generative-model-collections/blob/f24a27feba327a1086298a810fdf83bb30d5128a/CGAN.py

In [None]:
# TensorFlow session with both thread pools capped at `num_threads`.
num_threads = 4
sess = tf.Session(
    config=tf.ConfigProto(
        intra_op_parallelism_threads=num_threads,
        inter_op_parallelism_threads=num_threads,
    )
)

# Run mode: when `train` is True, use 100 epochs and a checkpoint dir under
# the home directory (outside of dropbox); otherwise use 1 epoch and the
# checkpoint dir inside the repo.
train = True
num_epochs = 100 if train else 1
checkpoint_dir = (
    os.path.join(os.path.expanduser("~"), "tmp - models", "models/gan/checkpoints")
    if train
    else "models/gan/checkpoints"
)

# Model settings and output locations.
batch_size = 64
z_dim = 100
dataset_name = "galaxy_all"
result_dir = "models/gan/results"
log_dir = "models/gan/log"

# Positional arguments are kept in the original order; the CGAN signature
# is defined in the project-local `gan` module.
model = gan.CGAN(
    sess, num_epochs, batch_size, z_dim, dataset_name,
    image_size, X_img,
    y_standard, y_for_visualization_samples_standard,
    checkpoint_dir, result_dir, log_dir,
    d_learning_rate=.0001,
    relative_learning_rate=4.,
    loss_weighting=50.,
)

model.build_model()


In [None]:
# Call the model's training entry point on the CGAN built in the cell above.
# NOTE(review): presumably runs for `num_epochs` epochs and writes to the
# checkpoint/result/log dirs passed to the constructor -- confirm in gan.py.
model.train()
