In [1]:
from datetime import datetime
import tftables
import tensorflow as tf
import numpy as np

In [2]:
def input_transform(tbl_batch):
    """Convert one raw table batch into float32 tensors for training.

    Every spectrum is divided by its own maximum taken along axis 1, so the
    largest value in each row becomes 1.

    Args:
        tbl_batch: batch indexable by column name; the 'spectrum' and
            'MH_ratio' entries are read.

    Returns:
        A ``(normalized_data, metals_float)`` tuple of float32 tensors.
    """
    # tf.cast(..., tf.float32) is the non-deprecated spelling of tf.to_float.
    data_float = tf.cast(tbl_batch['spectrum'], tf.float32)
    metals_float = tf.cast(tbl_batch['MH_ratio'], tf.float32)

    # keepdims leaves the reduced axis in place so the per-row maximum
    # broadcasts against the data without a separate expand_dims.
    data_max = tf.reduce_max(data_float, axis=1, keepdims=True)

    # A row whose maximum is 0 would be normalised as 0/0 and yield NaNs;
    # divide such rows by 1 instead so they simply stay at zero.
    safe_max = tf.where(tf.equal(data_max, 0.0),
                        tf.ones_like(data_max),
                        data_max)
    normalized_data = tf.divide(data_float, safe_max)

    return normalized_data, metals_float

In [None]:
# Factorise the spectra as a @ b, where `a` has one row per spectrum and `b`
# one row per feature, by minimising the per-element reconstruction error
# one batch of spectra at a time.

# --- Hyperparameters ---------------------------------------------------------
NUM_SPECTRA = 8000
BATCH_SIZE = 100
NUM_FEATURES = 40
# Number of columns of `b`, i.e. the width of the a_block @ b product that is
# subtracted from each dequeued data batch.
SPECTRUM_LENGTH = 1569128

LEARNING_RATE = 0.01

NUM_ITERATIONS = 20000
PRINT_FREQ = 10
# Number of BATCH_SIZE-row blocks that `a` is split into.
NUM_BLOCKS = NUM_SPECTRA // BATCH_SIZE

# Host-side copy of `a`, filled in once training stops. `b_star` is bound at
# the same point, so no placeholder array is allocated for it up front.
a_star = np.zeros((NUM_SPECTRA, NUM_FEATURES))

# Errors collected since the last progress line.
ses = []

tf.reset_default_graph()

with tf.device('/cpu:0'):
    loader = tftables.load_dataset(filename='sample_8k.h5',
                                   dataset_path='/spectra',
                                   input_transform=input_transform,
                                   batch_size=BATCH_SIZE,
                                   cyclic=True,
                                   ordered=True)
    data_batch, metals_batch = loader.dequeue()

    a = tf.get_variable('a',
                        shape=(NUM_SPECTRA, NUM_FEATURES),
                        dtype=tf.float32,
                        initializer=tf.orthogonal_initializer)
    b = tf.get_variable('b',
                        shape=(NUM_FEATURES, SPECTRUM_LENGTH),
                        dtype=tf.float32,
                        initializer=tf.orthogonal_initializer)

    # Only the rows of `a` named by `a_indices` take part in a given step.
    a_indices = tf.placeholder(tf.int32, shape=(BATCH_SIZE,))
    a_block = tf.gather(a, a_indices)

    product = tf.matmul(a_block, b)
    # NOTE(review): despite its name (and the 'standard error' column header
    # printed below) this is the mean *absolute* error per element. The name
    # and header are kept unchanged so anything relying on them still works.
    squared_error = tf.reduce_mean(tf.abs(tf.subtract(data_batch, product)))

    global_step = tf.Variable(0, trainable=False)
    optimize = tf.train.AdamOptimizer(LEARNING_RATE).minimize(squared_error, global_step=global_step)

    with tf.Session() as sess:
        with loader.begin(sess):
            sess.run(tf.global_variables_initializer())
            print('iter', '\t', 'a_index', '\t', 'standard error (per el)')
            try:
                for i in range(NUM_ITERATIONS):
                    # Assumes the ordered, cyclic loader hands out rows in
                    # step with this block index and that the dataset holds
                    # exactly NUM_SPECTRA rows -- TODO confirm.
                    a_block_index = i % NUM_BLOCKS
                    a_start_index = a_block_index * BATCH_SIZE
                    a_stop_index = a_start_index + BATCH_SIZE
                    # Fetch only the scalar error here; pulling `b` back on
                    # every step copied the whole matrix to the host just to
                    # overwrite it on the next iteration.
                    _, se = sess.run(
                        [optimize, squared_error],
                        feed_dict={a_indices: np.arange(a_start_index,
                                                        a_stop_index,
                                                        dtype=np.int32)}
                    )
                    ses.append(se)
                    if i % PRINT_FREQ == 0:
                        print(i, '\t', a_block_index, '\t', np.mean(ses))
                        ses = []
            except KeyboardInterrupt:
                # Stopping the cell by hand should still leave usable factors.
                print('Training interrupted; keeping the factors learned so far.')

            # Read both factors in a single call once training has stopped so
            # that a_star and b_star describe the same state of the model.
            a_final, b_star = sess.run([a, b])
            a_star[:] = a_final

iter 	 a_index 	 standard error (per el)
0 	 0 	 0.37081346
10 	 10 	 0.36100164
20 	 20 	 0.36558855


In [4]:
# Persist the learned factors together with the run's hyperparameters in a
# timestamped, compressed .npz archive.
run_stamp = datetime.now().strftime('%y%m%d-%H%M%S')
outfile_name = 'output-{}'.format(run_stamp)

# Order: NUM_SPECTRA, BATCH_SIZE, NUM_FEATURES, LEARNING_RATE.
metadata = np.array([NUM_SPECTRA, BATCH_SIZE, NUM_FEATURES, LEARNING_RATE])

archive_contents = dict(metadata=metadata, a_star=a_star, b_star=b_star)
np.savez_compressed(outfile_name, **archive_contents)

print('Saved {}.npz with metadata: {}'.format(outfile_name, metadata))

Saved sgd-output-190430-220702.npz with metadata: [2.e+03 1.e+02 4.e+01 1.e-02]
