# Decoder Modification

This notebook explores modifying the decoder: I want to switch to a neural network that processes every pixel bin in parallel.

In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F


from tagging.paths import path_dataset
from tagging.src.datasets import ApogeeDataset
from tagging.src.networks import ConditioningAutoencoder,Embedding_Decoder,Feedforward
import pandas as pd

# Run on the first CUDA GPU when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


#### Defining the input tensors

In [None]:
# Toy sizes for exercising the parallel decoder.
n_batch, n_latent, n_bins, n_hidden = 32, 22, 500, 10
# Dummy latent batch of ones, shape (n_batch, n_latent).
latent = torch.ones((n_batch, n_latent))

In [None]:
# Tile the full latent vector once per bin along dim 1, then add a length-1
# "sequence" axis so the result can be fed to Conv1d.
repeated_latent = latent.repeat(1, n_bins).unsqueeze(2)

In [None]:
# Expected: (n_batch, n_latent * n_bins, 1) -> (32, 11000, 1) with the sizes above.
repeated_latent.shape

#### Define the parallel network using a grouped convolution

In [None]:
# Number of output values per bin (mirrors ParallelDecoder.n_output below).
n_output = 1

In [None]:
# A kernel-size-1 convolution with groups=n_bins behaves like n_bins
# independent Linear(n_latent -> n_hidden) layers evaluated in one call:
# group g only sees input channels g*n_latent:(g+1)*n_latent.
conv1 = nn.Conv1d(
    n_latent * n_bins,
    n_hidden * n_bins,
    1,
    groups=n_bins,
)

In [None]:
# Smoke test: the output should have shape (n_batch, n_hidden * n_bins, 1).
conv1(repeated_latent)

#### Define the parallel decoder

In [None]:
class ParallelDecoder(nn.Module):
    """Decode a latent vector into ``n_bins`` values with an independent MLP per bin.

    The latent vector is tiled once per bin and passed through three grouped
    kernel-size-1 ``Conv1d`` layers with ``groups=n_bins``. Each bin therefore
    has its own weights, forming an ``n_latent -> n_hidden -> n_hidden -> 1``
    network, and all bins are evaluated in a single batched call.

    Args:
        n_bins: Number of output bins (one independent network per bin).
        n_hidden: Width of each per-bin hidden layer.
        n_latent: Size of the latent vector fed to every bin.
        activation: Module applied after each of the two hidden layers.
            Defaults to a fresh ``nn.LeakyReLU()`` per instance.
    """

    def __init__(self, n_bins=100, n_hidden=10, n_latent=22, activation=None):
        super().__init__()
        self.n_bins = n_bins
        self.n_hidden = n_hidden
        self.n_latent = n_latent
        self.n_output = 1
        # Build the default here instead of in the signature so separate
        # decoders never share one module object.
        self.activation = nn.LeakyReLU() if activation is None else activation
        # groups=n_bins splits channels into n_bins contiguous slices, so each
        # bin's weights only ever see that bin's slice.
        self.conv1 = nn.Conv1d(in_channels=self.n_latent * self.n_bins, out_channels=self.n_hidden * self.n_bins, kernel_size=1, groups=self.n_bins)
        self.conv2 = nn.Conv1d(in_channels=self.n_hidden * self.n_bins, out_channels=self.n_hidden * self.n_bins, kernel_size=1, groups=self.n_bins)
        self.conv3 = nn.Conv1d(in_channels=self.n_hidden * self.n_bins, out_channels=self.n_output * self.n_bins, kernel_size=1, groups=self.n_bins)

    def forward(self, latent):
        """Decode a batch of latent vectors.

        Args:
            latent: Tensor of shape ``(batch, n_latent)``.

        Returns:
            Tensor of shape ``(batch, n_bins)``, including when ``batch == 1``.
        """
        # repeat(1, n_bins) tiles the whole latent vector n_bins times along
        # dim 1, so channel slice g*n_latent:(g+1)*n_latent is a full copy for
        # bin g; unsqueeze adds the length-1 axis Conv1d convolves over.
        repeated_latent = latent.repeat(1, self.n_bins).unsqueeze(2)

        hidden1 = self.activation(self.conv1(repeated_latent))
        hidden2 = self.activation(self.conv2(hidden1))
        output = self.conv3(hidden2)
        # Drop only the length-1 conv axis. A bare squeeze() would also drop
        # a batch axis of size 1 and return shape (n_bins,).
        return output.squeeze(2)

In [None]:
class Autoencoder(nn.Module):
    """Chain an encoder and a decoder, optionally running only one stage.

    Args:
        encoder: Module mapping inputs to latent codes.
        decoder: Module mapping latent codes back to the input space.
        n_bins: Stored as ``self.n_bins``; not used by ``forward``.
    """

    def __init__(self, encoder, decoder, n_bins=None):
        super().__init__()
        self.n_bins = n_bins
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, train_encoder=True, train_decoder=True):
        """Run the enabled stages on ``x``.

        Despite their names, the flags only choose which stages run; they do
        not toggle gradients. When the encoder is skipped, ``x`` is handed
        straight to the decoder.

        Returns:
            ``(output, latent)``; each is ``None`` when its stage was skipped.
        """
        latent = self.encoder(x) if train_encoder else None
        decoder_input = latent if train_encoder else x
        output = self.decoder(decoder_input) if train_decoder else None
        return output, latent

In [None]:
# Standalone parallel decoder sized with the toy dimensions defined above.
decoder = ParallelDecoder(
    n_latent=n_latent,
    n_hidden=n_hidden,
    n_bins=n_bins,
)

## Test an autoencoder with the new decoder


In [None]:
# Experiment hyperparameters (n_cat is not used by the cells in this notebook).
n_bins, n_batch, n_z, n_cat, n_hidden = 5000, 64, 20, 30, 10
lr = 0.001

In [None]:
data = pd.read_pickle(path_dataset)
# Only the first 50000 rows of the pickled dataset are used here.
dataset = ApogeeDataset(data[:50000], n_bins)
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=n_batch,
    shuffle=True,
    drop_last=True,  # every batch has exactly n_batch rows
)

In [None]:
# Reconstruction objective (independent of the model, so defined first).
loss = nn.MSELoss()

# SELU feed-forward encoder -> n_z latent, decoded by the per-bin parallel decoder.
encoder = Feedforward([n_bins, 512, 128, 32, n_z], activation=nn.SELU()).to(device)
decoder = ParallelDecoder(n_bins=n_bins, n_hidden=n_hidden, n_latent=n_z).to(device)
autoencoder = Autoencoder(encoder, decoder, n_bins=n_bins).to(device)

optimizer_autoencoder = torch.optim.Adam(autoencoder.parameters(), lr=lr)


In [None]:
# SELU
# NOTE(review): trains whichever `autoencoder` is currently defined; the label
# presumably records the encoder activation configured when this cell was run.
for i in range(100):

    for j, (x, u, v, idx) in enumerate(loader):
        # The model lives on `device`; the batch must too, or the forward pass
        # fails with a device mismatch on a CUDA machine.
        x = x.to(device)

        optimizer_autoencoder.zero_grad()
        x_pred, z = autoencoder(x)

        err_pred = loss(x_pred, x)
        err_tot = err_pred

        err_tot.backward()
        optimizer_autoencoder.step()
        if j % 10 == 0:
            # .item() logs a plain float rather than a tensor repr with grad_fn.
            print(f"epoch:{i},err:{err_tot.item()}")


In [None]:
# ELU
# NOTE(review): trains whichever `autoencoder` is currently defined; the label
# presumably records the encoder activation configured when this cell was run.
for i in range(100):

    for j, (x, u, v, idx) in enumerate(loader):
        # The model lives on `device`; the batch must too, or the forward pass
        # fails with a device mismatch on a CUDA machine.
        x = x.to(device)

        optimizer_autoencoder.zero_grad()
        x_pred, z = autoencoder(x)

        err_pred = loss(x_pred, x)
        err_tot = err_pred

        err_tot.backward()
        optimizer_autoencoder.step()
        if j % 10 == 0:
            # .item() logs a plain float rather than a tensor repr with grad_fn.
            print(f"epoch:{i},err:{err_tot.item()}")


In [None]:
# Sigmoid
# NOTE(review): trains whichever `autoencoder` is currently defined; the label
# presumably records the encoder activation configured when this cell was run.
for i in range(100):

    for j, (x, u, v, idx) in enumerate(loader):
        # The model lives on `device`; the batch must too, or the forward pass
        # fails with a device mismatch on a CUDA machine.
        x = x.to(device)

        optimizer_autoencoder.zero_grad()
        x_pred, z = autoencoder(x)

        err_pred = loss(x_pred, x)
        err_tot = err_pred

        err_tot.backward()
        optimizer_autoencoder.step()
        if j % 10 == 0:
            # .item() logs a plain float rather than a tensor repr with grad_fn.
            print(f"epoch:{i},err:{err_tot.item()}")


In [None]:
# LeakyReLU
# NOTE(review): trains whichever `autoencoder` is currently defined; the label
# presumably records the encoder activation configured when this cell was run.
for i in range(100):

    for j, (x, u, v, idx) in enumerate(loader):
        # The model lives on `device`; the batch must too, or the forward pass
        # fails with a device mismatch on a CUDA machine.
        x = x.to(device)

        optimizer_autoencoder.zero_grad()
        x_pred, z = autoencoder(x)

        err_pred = loss(x_pred, x)
        err_tot = err_pred

        err_tot.backward()
        optimizer_autoencoder.step()
        if j % 10 == 0:
            # .item() logs a plain float rather than a tensor repr with grad_fn.
            print(f"epoch:{i},err:{err_tot.item()}")


In [None]:
# Sigmoid-activated encoder. torch.nn exposes the module class `nn.Sigmoid`;
# `nn.sigmoid` does not exist and raised AttributeError.
encoder = Feedforward([n_bins, 512, 128, 32, n_z], activation=nn.Sigmoid()).to(device)
decoder = ParallelDecoder(n_bins=n_bins, n_hidden=n_hidden, n_latent=n_z).to(device)
autoencoder = Autoencoder(encoder, decoder, n_bins=n_bins).to(device)


loss = nn.MSELoss()
optimizer_autoencoder = torch.optim.Adam(autoencoder.parameters(), lr=lr)


In [None]:
# Display the module tree of the current autoencoder.
autoencoder

In [None]:
# IPython alias: list the contents of the current working directory.
ls

In [None]:
# A bare `x_pred.shape` here is not the cell's last expression, so the
# notebook never displayed it; print it explicitly instead.
print(x_pred.shape)
# Switch to a much smaller number of bins for the next experiment.
# NOTE: the dataset/loader built earlier still use the previous n_bins; rebuild
# them before training a model whose input width is this new value.
n_bins = 100

In [None]:
# n_bins may have just changed, so rebuild the dataset/loader with it: batches
# from a loader built with the old n_bins would not match an encoder whose
# input width is the new n_bins.
dataset = ApogeeDataset(data[:50000], n_bins)
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=n_batch,
                                     shuffle=True,
                                     drop_last=True)

encoder = Feedforward([n_bins, 512, 128, 32, n_z], activation=nn.LeakyReLU()).to(device)
# Dense decoder whose layer sizes mirror the encoder's; presumably a baseline
# to compare against ParallelDecoder.
decoder = Feedforward([n_z, 32, 128, 512, n_bins], activation=nn.LeakyReLU()).to(device)

autoencoder = Autoencoder(encoder, decoder, n_bins=n_bins).to(device)


loss = nn.MSELoss()
optimizer_autoencoder = torch.optim.Adam(autoencoder.parameters(), lr=lr)


In [None]:
# NOTE(review): 20000 passes over the loader is effectively "train until
# interrupted"; stop the cell manually once the loss plateaus.
for i in range(20000):

    for j, (x, u, v, idx) in enumerate(loader):
        # The model lives on `device`; the batch must too, or the forward pass
        # fails with a device mismatch on a CUDA machine.
        x = x.to(device)

        optimizer_autoencoder.zero_grad()
        x_pred, z = autoencoder(x)

        err_pred = loss(x_pred, x)
        err_tot = err_pred

        err_tot.backward()
        optimizer_autoencoder.step()
        if j % 100 == 0:
            # Include the epoch and log a plain float instead of a tensor repr.
            print(f"epoch:{i},err:{err_tot.item()}")


In [None]:
%%timeit
with torch.no_grad():
    x_pred,z = autoencoder(x)


In [None]:
%%timeit
with torch.no_grad():
    x_pred,z = autoencoder(x)
