## Calculate abundances with new decoder and FactorDis

In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.autograd as autograd
from torch.autograd import Variable

import numpy as np

from tagging.paths import path_dataset
from tagging.src.datasets import ApogeeDataset
from tagging.src.networks import ConditioningAutoencoder,Embedding_Decoder,Feedforward,ParallelDecoder
import pandas as pd
# Query CUDA availability once and derive both the flag and the device from it.
cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda else "cpu")


In [None]:
# --- Data / batching ---------------------------------------------------------
n_bins = 1000        # passed to ApogeeDataset; also the spectrum width fed to the networks
n_batch = 64         # DataLoader batch size

# --- Network sizes -----------------------------------------------------------
n_z = 20             # width of the encoder output (latent vector)
n_conditioned = 2    # number of conditioning values concatenated to spectrum / latent
n_hidden = 10        # forwarded to ParallelDecoder as n_hidden
n_cat = 30           # not referenced in the visible part of this notebook

# --- Optimisation ------------------------------------------------------------
lr = 1e-4            # Adam learning rate for both optimizers
n_critic = 10        # the generator is updated once every n_critic batches
lambda_gp = 10       # weight of the gradient penalty in the discriminator loss
lambda_fact = 10 ** -4   # weight of the adversarial term in the generator loss
# NOTE(review): 10e-4 equals 1e-3, not 1e-4 -- confirm which value was intended.
loss_ratio = 10e-4   # not referenced in the visible part of this notebook

In [None]:
# Load the pickled table and wrap its first 50000 rows in the project dataset.
# NOTE: read_pickle unpickles arbitrary objects -- only point it at trusted files.
data = pd.read_pickle(path_dataset)
dataset = ApogeeDataset(data[:50000], n_bins)

# Shuffled mini-batches; the incomplete final batch is dropped so every batch
# has exactly n_batch rows.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=n_batch,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Encoder: (spectrum + conditioning values) -> n_z-dimensional latent.
encoder = Feedforward(
    [n_bins + n_conditioned, 512, 128, 32, n_z],
    activation=nn.SELU(),
).to(device)

# Decoder: receives a latent of width n_z + n_conditioned.
decoder = ParallelDecoder(
    n_bins=n_bins,
    n_hidden=n_hidden,
    n_latent=n_z + n_conditioned,
    activation=nn.SELU(),
).to(device)

conditioning_autoencoder = ConditioningAutoencoder(
    encoder,
    decoder,
    n_bins=n_bins,
).to(device)

# The "generator" optimizer updates every parameter of the autoencoder.
optimizer_G = torch.optim.Adam(
    conditioning_autoencoder.parameters(),
    lr=lr,
)

In [None]:
# Discriminator input is the concatenation (spectrum, latent, conditioning values);
# it outputs a single score per sample.
discriminator = Feedforward(
    [n_bins + n_z + n_conditioned, 4096, 1024, 512, 128, 32, 1],
    activation=nn.SELU(),
).to(device)
optimizer_D = torch.optim.Adam(
    discriminator.parameters(),
    lr=lr,
)


In [None]:
# Mean-squared-error criterion used for the reconstruction term.
loss = nn.MSELoss()

# Constant (n_batch, 2) tensors on the training device.
zeros = torch.zeros((n_batch, 2), device=device)
ones = torch.ones((n_batch, 2), device=device)

# Float tensor constructor matching the availability of CUDA.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor


In [None]:

def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculate the gradient penalty loss for WGAN-GP.

    ``D`` is evaluated on random per-sample interpolations between
    ``real_samples`` and ``fake_samples``; the penalty is the batch mean of
    ``(||dD/dx||_2 - 1) ** 2`` measured at those interpolated points.

    All temporaries are created on the device and dtype of ``real_samples``,
    so the function does not depend on any module-level tensor alias and works
    for inputs of any rank (the first dimension is the batch).

    Args:
        D: Callable mapping a batch of samples to a tensor of scores.
        real_samples: Tensor of real samples, batch first.
        fake_samples: Tensor of generated samples, same shape as
            ``real_samples``.

    Returns:
        A scalar tensor with the gradient penalty. It is built with
        ``create_graph=True`` so it can be back-propagated into ``D``.
    """
    batch_size = real_samples.size(0)

    # One random mixing weight per sample, shaped (batch, 1, ..., 1) so that it
    # broadcasts over every non-batch dimension.
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(
        alpha_shape,
        device=real_samples.device,
        dtype=real_samples.dtype,
    )

    # Random point on the segment between each real/fake pair.
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)

    # Ones shaped like D's output: yields d(sum of scores)/d(interpolates)
    # whatever output shape D produces.
    grad_outputs = torch.ones_like(d_interpolates)

    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten per sample; reshape (not view) also handles non-contiguous grads.
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty


In [None]:

# Adversarial training of the conditioning autoencoder.
#
# Every batch updates the discriminator (critic); the autoencoder ("generator")
# is updated once every `n_critic` batches with a reconstruction loss plus an
# adversarial term weighted by `lambda_fact`.
n_epochs = 10000
log_every = 30 * n_critic  # print once every 30 generator updates
batches_done = 0

for epoch in range(n_epochs):

    for i, (x, u, _v, _idx) in enumerate(loader):
        # The networks live on `device`; make sure the batch does as well
        # (a no-op if the dataset already yields tensors on that device).
        x = x.to(device)
        u = u.to(device)

        # Conditioning values and a shuffled copy of them within the batch.
        cond = u[:, :n_conditioned]
        perm = torch.randperm(u.size(0), device=u.device)
        cond_perm = cond[perm]

        # Reconstruction with the true conditioning, and a decode of the same
        # latent with the permuted conditioning.
        x_pred, z = conditioning_autoencoder(x, cond)
        x_perm, _ = conditioning_autoencoder(z, cond_perm, train_encoder=False)

        fake = torch.cat((x_perm, z, cond_perm), 1)
        real = torch.cat((x, z, cond), 1)

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Detach so the discriminator step does not back-propagate into the
        # autoencoder: those gradients were discarded anyway, and this keeps
        # the autoencoder graph intact for the generator step below.
        real_d = real.detach()
        fake_d = fake.detach()

        real_validity = discriminator(real_d)
        fake_validity = discriminator(fake_d)
        gradient_penalty = compute_gradient_penalty(discriminator, real_d, fake_d)

        # Wasserstein critic loss with gradient penalty.
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty

        d_loss.backward()
        optimizer_D.step()

        # Train the generator every n_critic steps
        if i % n_critic == 0:

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Re-score with the freshly updated discriminator, this time with
            # gradients flowing back into the autoencoder.
            fake_validity = discriminator(fake)
            real_validity = discriminator(real)

            err_pred = loss(x_pred, x)

            g_loss = err_pred - lambda_fact * (torch.mean(fake_validity) - torch.mean(real_validity))

            g_loss.backward()
            optimizer_G.step()

            # Count the batches covered by this generator update.
            batches_done += n_critic

            # Parenthesised on purpose: `i % 30 * n_critic` parses as
            # `(i % 30) * n_critic`.
            if i % log_every == 0:
                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D: %f] [G: %f] [R: %f]"
                    % (epoch, n_epochs, i, len(loader), d_loss.item(), torch.mean(fake_validity).item(), err_pred.item())
                )




In [None]:
# Persist the whole autoencoder module object (not just its state_dict).
torch.save(
    obj=conditioning_autoencoder,
    f="conditional_parallel_decoder1.p",
)

In [None]:
# Notebook inspection cell: evaluates lambda_fact so its current value is shown.
lambda_fact

In [None]:
# Notebook scratch cell: evaluates the literal expression used for lambda_fact.
10**-4