In [1]:
%matplotlib inline
from matplotlib import pyplot as plt
import matplotlib as mpl
import seaborn as sns

import glob
import os

import numpy as np
import pandas as pd

mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80
mpl.rcParams['figure.figsize'] = np.array((10,6))*.6

# Load galaxy image data
If you need the `.npy` files, download them from [this directory](https://www.dropbox.com/sh/raq8y8wt2rii7d6/AACA6OO3eKxfTevTJj19w_rea?dl=0). It'll be called `npy_files.tar.gz`; I'd link to it directly, but it hasn't uploaded to dropbox yet.

In [2]:
img_dir = "data/galaxy_images_training/npy_files/"
filename_formatter = os.path.join(img_dir, "{}-cutout.npy")

In [3]:
npy_files = glob.glob(filename_formatter.format("*"))

HSC_ids = np.array([int(os.path.basename(f).split("-")[0])
                    for f in npy_files])

HSC_ids

array([43158322471266302, 43159013960976629, 43158335356159832, ...,
       43159001076101367, 43159134220058785, 43158459910224182])

In [None]:
X_img = np.empty([len(HSC_ids), 3, 50, 50])
for i, HSC_id in enumerate(HSC_ids):
    X_img[i] = np.load(filename_formatter.format(HSC_id))

X_img = X_img.transpose([0,2,3,1])
X_img.shape

In [None]:
image_size = X_img.shape[1]
image_size

# Load targets
`HSC_ids` are in the same order as the `X_img` data. These ids are then used to cross-reference the table read into `df`.

In [None]:
df = pd.read_csv("data/2018_02_23-all_objects.csv")
# df = df[df.selected]


df = df.drop_duplicates("HSC_id") \
       .set_index("HSC_id") \
       [["photo_z", "log_mass"]]


df.head()


In [None]:
y = df.loc[HSC_ids].values

# Run GAN
Modeled after: https://github.com/hwalsuklee/tensorflow-generative-model-collections/blob/f24a27feba327a1086298a810fdf83bb30d5128a/CGAN.py

In [None]:
import tensorflow as tf

import gan

In [None]:
num_threads = 3

sess = tf.Session(config=tf.ConfigProto(
    intra_op_parallelism_threads=num_threads,
    inter_op_parallelism_threads=num_threads,
))

num_epochs = 10
batch_size = 64
z_dim = 100
dataset_name = "galaxy_all"
checkpoint_dir = "models/gan/checkpoints"
result_dir = "models/gan/results"
log_dir = "models/gan/log"

model = gan.CGAN(sess, num_epochs, batch_size, z_dim, dataset_name,
                 image_size, X_img, y,
                 checkpoint_dir, result_dir, log_dir, )

model.build_model()


In [None]:
model.train()