<a href="https://colab.research.google.com/github/changhoonhahn/provabgs/blob/main/nb/nmfburst_decoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# decoder for `nmfburst` SPS model
Instead of using a PCA encoding and training a neural net to predict the PCA coefficients, I'm going to try to train a decoder that maps theta directly to the SED, using the (theta, SED) data set.

This notebook contains code lifted from:
- https://github.com/stephenportillo/SDSS-VAE/blob/master/trainVAE.py
- https://github.com/stephenportillo/SDSS-VAE/blob/master/InfoVAE.py

In [1]:
# Colab only: mount Google Drive so the data files under "My Drive" are reachable.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [2]:
# Working directory for this notebook: the relative paths used below
# (fsps.nmfburst.*.npy, decoder.pth) resolve against it. Path is specific to this Drive layout.
%cd /content/drive/My\ Drive/provabgs

/content/drive/My Drive/provabgs


In [3]:
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F

In [4]:
# Decoder inputs (theta); row i is paired with row i of the ln-spectrum array loaded below.
theta = np.load(file="fsps.nmfburst.theta.test.npy")

In [5]:
# whiten the spectra
mu_lnspec = np.mean(np.load('fsps.nmfburst.lnspectrum.test.npy'), axis=0)
sig_lnspec = np.std(np.load('fsps.nmfburst.lnspectrum.test.npy'), axis=0)

lnspec_white = (np.load('fsps.nmfburst.lnspectrum.test.npy') - mu_lnspec)/sig_lnspec

In [8]:
# Input (theta) and output (whitened ln-spectrum) dimensionalities.
n_theta, n_lnspec = theta.shape[1], lnspec_white.shape[1]
print(f'n theta = {n_theta}')
print(f'n ln(spec) = {n_lnspec}')

n theta = 12
n ln(spec) = 4469


In [9]:
# 90/10 split of the rows: first Ntrain for training, the remaining Ntest held out.
Ntrain = int(0.9 * theta.shape[0])
Ntest = len(theta) - Ntrain
print(f'Ntrain = {Ntrain}')
print(f'Ntest = {Ntest}')

Ntrain = 90000
Ntest = 10000


In [10]:
batch_size = 64
# Training pairs (theta -> whitened ln-spectrum) from the first Ntrain rows,
# reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(
        torch.tensor(theta[:Ntrain], dtype=torch.float32),
        torch.tensor(lnspec_white[:Ntrain], dtype=torch.float32),
    ),
    batch_size=batch_size,
    shuffle=True,
)

In [11]:
class Decoder(nn.Module):
    """Fully connected decoder: ncode -> nhidden2 -> nhidden -> nfeat.

    Both hidden layers apply a leaky ReLU followed by dropout; the output layer
    is linear.

    Parameters
    ----------
    nfeat : int
        Output dimension.
    ncode : int
        Input dimension.
    nhidden : int
        Width of the second hidden layer.
    nhidden2 : int
        Width of the first hidden layer.
    dropout : float
        Dropout probability applied after each hidden layer.
    """

    def __init__(self, nfeat=1000, ncode=5, nhidden=128, nhidden2=35, dropout=0.2):
        super().__init__()

        self.ncode = int(ncode)

        # Attribute names and registration order are kept: they define the
        # state_dict keys and the order in which parameters are initialized.
        self.decd = nn.Linear(ncode, nhidden2)
        self.d3 = nn.Dropout(p=dropout)
        self.dec2 = nn.Linear(nhidden2, nhidden)
        self.d4 = nn.Dropout(p=dropout)
        self.outp = nn.Linear(nhidden, nfeat)

    def decode(self, x):
        """Map a batch of inputs ``x`` to the ``nfeat``-dimensional output."""
        h = x
        for layer, drop in ((self.decd, self.d3), (self.dec2, self.d4)):
            h = drop(F.leaky_relu(layer(h)))
        return self.outp(h)

    def forward(self, x):
        """Alias for :meth:`decode`."""
        return self.decode(x)

    def loss(self, x, y):
        """Squared error 0.5 * (y - decode(x))**2, summed (not averaged) over all elements."""
        residual = y - self.forward(x)
        return torch.sum(0.5 * residual.pow(2))

In [12]:
def train(net=None, opt=None, loader=None):
    """Run one training epoch and return the loss per sample.

    Parameters
    ----------
    net : nn.Module, optional
        Model to train. It must provide ``loss(inputs, targets)`` returning a
        scalar tensor. The loss is assumed to be summed (not averaged) over the
        batch, so that dividing by the dataset size gives a per-sample mean.
        Defaults to the notebook-global ``model``.
    opt : torch.optim.Optimizer, optional
        Optimizer over ``net``'s parameters. Defaults to the global ``optimizer``.
    loader : torch.utils.data.DataLoader, optional
        Yields ``(inputs, targets)`` batches. Defaults to the global ``train_loader``.

    Returns
    -------
    float
        Sum of the batch losses divided by ``len(loader.dataset)``.
    """
    # Fall back to the notebook globals so the existing bare ``train()`` call keeps working.
    net = model if net is None else net
    opt = optimizer if opt is None else opt
    loader = train_loader if loader is None else loader

    net.train()
    total_loss = 0.0
    for inputs, targets in loader:
        opt.zero_grad()
        batch_loss = net.loss(inputs, targets)
        batch_loss.backward()
        total_loss += batch_loss.item()
        opt.step()
    return total_loss / len(loader.dataset)


class EarlyStopper:
    """Patience-based early stopping on a monitored loss (lower is better).

    A step counts as an improvement only when the loss drops below
    ``min_valid_loss * (1 - precision)``, i.e. by more than the relative
    ``precision``. This relative threshold assumes the loss is positive.
    After ``patience`` consecutive non-improving steps, ``step`` returns False.
    """

    def __init__(self, precision=1e-3, patience=10):
        self.precision = precision
        self.patience = patience
        self.badepochs = 0  # consecutive steps without sufficient improvement
        self.min_valid_loss = float('inf')  # best loss recorded so far

    def step(self, valid_loss):
        """Record one epoch's loss; return True to keep training, False to stop."""
        if valid_loss < self.min_valid_loss * (1 - self.precision):
            self.badepochs = 0
            self.min_valid_loss = valid_loss
        else:
            self.badepochs += 1
        # '<' rather than '!=': once patience is exhausted, keep returning False on
        # every later call, not only on the call where the limit is first reached.
        return self.badepochs < self.patience

In [None]:
# Random architecture search: each config draws two hidden-layer widths
# log-uniformly, trains a Decoder with Adam, and checkpoints it to disk.
epochs = 200
log_interval = 10
n_config = 1

for config in range(n_config):
    dropout = 0.  # alternative: 0.9*np.random.uniform()
    # presumably widens the layers to offset capacity lost to dropout (== 1 when dropout == 0)
    dfac = 1./(1.-dropout)
    # nhidden2 is the first layer in the forward pass; it is drawn no wider than nhidden.
    nhidden = int(np.ceil(np.exp(np.random.uniform(np.log(dfac*n_theta+1), np.log(dfac*2*n_lnspec)))))
    nhidden2 = int(np.ceil(np.exp(np.random.uniform(np.log(dfac*n_theta+1), np.log(nhidden)))))
    print('config %i, dropout = %0.2f; 2 hidden layers with %i, %i nodes' % (config, dropout, nhidden, nhidden2))
    model = Decoder(nfeat=n_lnspec, nhidden=nhidden, nhidden2=nhidden2, ncode=n_theta, dropout=dropout)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True, patience=5)
    stopper = EarlyStopper(patience=10)
    # One checkpoint file per config so a multi-config search does not overwrite
    # earlier models; with a single config the filename stays 'decoder.pth'.
    checkpoint = 'decoder.pth' if n_config == 1 else 'decoder.%04i.pth' % config

    # NOTE(review): the LR scheduler and early stopping monitor the *training* loss;
    # the held-out rows theta[Ntrain:] are never evaluated. Consider a validation loss.
    for epoch in range(1, epochs + 1):
        train_loss = train()
        print('====> Epoch: {} TRAINING Loss: {:.2e}'.format(epoch, train_loss))

        scheduler.step(train_loss)
        # Save before the stop check so the weights from the final epoch are written too.
        torch.save(model, checkpoint)
        if (not stopper.step(train_loss)) or (epoch == epochs):
            print('Stopping')
            break

config 0, dropout = 0.00; 2 hidden layers with 508, 31 nodes
====> Epoch: 1 TRAINING Loss: 1.57e+02
====> Epoch: 2 TRAINING Loss: 5.99e+01
====> Epoch: 3 TRAINING Loss: 4.13e+01
====> Epoch: 4 TRAINING Loss: 3.08e+01
====> Epoch: 5 TRAINING Loss: 2.58e+01
====> Epoch: 6 TRAINING Loss: 2.33e+01
====> Epoch: 7 TRAINING Loss: 2.08e+01
====> Epoch: 8 TRAINING Loss: 1.97e+01
====> Epoch: 9 TRAINING Loss: 1.80e+01
====> Epoch: 10 TRAINING Loss: 1.71e+01
