# Neural Network Binning

In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
from tomo_challenge import load_data, load_redshift
from tomo_challenge.jax_metrics import ell_binning

Found classifier Random
Found classifier RandomForest
Found classifier IBandOnly


Initialize fast metric calculations:

In [3]:
from zotbin.binned import *

In [4]:
init_data = load_binned('binned_28.npz')

Load the challenge data:

In [5]:
bands='griz'
include_colors=False
include_errors=False

In [6]:
train_file='/media/data2/tomo_challenge_data/ugrizy/training.hdf5'
train_data = load_data(train_file, bands, 
                       errors=include_errors,
                       colors=include_colors, array=True)
train_z = load_redshift(train_file)
print(f'Loaded {len(train_data)} training rows.')

Loaded 8615613 training rows.




Preprocess the training data:

In [7]:
colors = np.diff(train_data, axis=1)

In [8]:
iband = bands.index('i')
data = np.concatenate((colors, train_data[:, iband:iband+1]), axis=1)

In [9]:
from sklearn.preprocessing import RobustScaler

In [10]:
preproc = RobustScaler()
features = preproc.fit_transform(train_data)

In [37]:
ndata = 10000
features = jnp.array(features[:ndata])
labels = jnp.array(train_z[:ndata])

Define a network:

In [11]:
import jax.random
import jax.numpy as jnp

In [12]:
from flax import nn, optim, serialization

In [34]:
nbins = 4

In [24]:
class NN(nn.Module):
    def apply(self, x, nbins):
        x = nn.Dense(x, 100, name='L1')
        x = nn.relu(x)
        x = nn.Dense(x, 100, name='L2')
        x = nn.relu(x)
        x = nn.Dense(x, nbins, name='L3')
        return nn.softmax(x)

In [25]:
module = NN.partial(nbins=4)

In [26]:
_, nn_init = module.init(jax.random.PRNGKey(0), features)

In [27]:
jax.tree_map(jnp.shape, nn_init)

{'L1': {'bias': (100,), 'kernel': (4, 100)},
 'L2': {'bias': (100,), 'kernel': (100, 100)},
 'L3': {'bias': (4,), 'kernel': (100, 4)}}

In [28]:
model = nn.Model(module, nn_init)

In [29]:
jax.tree_map(jnp.shape, model.params)

{'L1': {'bias': (100,), 'kernel': (4, 100)},
 'L2': {'bias': (100,), 'kernel': (100, 100)},
 'L3': {'bias': (4,), 'kernel': (100, 4)}}

In [30]:
model(features[0])

DeviceArray([0.20889738, 0.25362805, 0.26156127, 0.27591333], dtype=float32)

In [61]:
optimizer = optim.Adam(learning_rate=0.001).create(model)

In [68]:
def train_step(optimizer, batch):
    
    def loss_fn(model):
        out = model(batch['features'])
        idx = jnp.argmax(out, axis=-1)
        print(idx)
        scores = get_binned_scores(idx, batch['z'], *init_data)
        print(scores)
        return scores['FOM_DETF_3x2']
    
    loss, g = jax.value_and_grad(loss_fn)(optimizer.target)
    optimizer = optimizer.apply_gradient(g)
    return optimizer, loss

In [69]:
nbatch = 10
gen = np.random.RandomState(123)

In [70]:
def get_batch():
    idx = gen.choice(ndata, nbatch)
    return {'features': features[idx], 'z': labels[idx]}

In [71]:
def train(opt, niter=10):
    
    losses = []
    for i in range(niter):
        opt, loss = train_step(opt, get_batch())
        print(loss)
        
train(optimizer)

[3 3 3 3 3 3 2 0 3 1]
{'SNR_3x2': DeviceArray(1319.0688, dtype=float32), 'FOM_3x2': DeviceArray(2835.5964, dtype=float32), 'FOM_DETF_3x2': DeviceArray(29.425533, dtype=float32)}
29.425533
[0 3 3 3 3 0 3 3 3 3]
{'SNR_3x2': DeviceArray(nan, dtype=float32), 'FOM_3x2': DeviceArray(nan, dtype=float32), 'FOM_DETF_3x2': DeviceArray(nan, dtype=float32)}
nan
[3 2 3 3 3 3 3 3 3 3]
{'SNR_3x2': DeviceArray(nan, dtype=float32), 'FOM_3x2': DeviceArray(nan, dtype=float32), 'FOM_DETF_3x2': DeviceArray(nan, dtype=float32)}
nan
[3 2 3 3 3 3 3 3 3 2]
{'SNR_3x2': DeviceArray(nan, dtype=float32), 'FOM_3x2': DeviceArray(nan, dtype=float32), 'FOM_DETF_3x2': DeviceArray(nan, dtype=float32)}
nan
[3 2 2 2 2 3 3 3 3 3]
{'SNR_3x2': DeviceArray(nan, dtype=float32), 'FOM_3x2': DeviceArray(nan, dtype=float32), 'FOM_DETF_3x2': DeviceArray(nan, dtype=float32)}
nan
[3 3 3 3 2 2 3 3 3 2]
{'SNR_3x2': DeviceArray(nan, dtype=float32), 'FOM_3x2': DeviceArray(nan, dtype=float32), 'FOM_DETF_3x2': DeviceArray(nan, dtype=float3