In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import uproot
import awkward as ak

In [2]:
class ParamsCVAE(nn.Module):
    """Conditional VAE over flattened histograms, conditioned on parameters and one-hot labels.

    The encoder consumes ``[flattened hist, params, labels]``. The decoder consumes
    ``[z, params, labels]`` and returns a *flattened* reconstruction of width
    ``hist_dim``, passed through a sigmoid so values lie in [0, 1].

    Args:
        hidden_dim: Width of the single hidden encoder layer.
        latent_dim: Dimensionality of the latent code ``z``.
        hist_dim: Number of histogram features after flattening everything but the
            batch axis (default ``20 * 20``, i.e. x-bins * y-bins).
        param_dim: Number of conditioning parameters per sample (default 3).
        label_dim: Width of the one-hot label vector (default 2).
    """

    def __init__(self, hidden_dim, latent_dim, hist_dim=20 * 20, param_dim=3, label_dim=2):
        super().__init__()
        self.hist_dim = hist_dim
        self.param_dim = param_dim
        self.label_dim = label_dim
        cond_dim = param_dim + label_dim  # params + one-hot labels, appended to both encoder and decoder inputs

        # Encoder: [hist, params, labels] -> hidden features
        self.fc_encoder = nn.Sequential(
            nn.Linear(hist_dim + cond_dim, hidden_dim, bias=True),
            nn.ReLU(),
        )

        # Heads producing the mean and log-variance of the approximate posterior
        self.fc_mu = nn.Linear(hidden_dim, latent_dim, bias=True)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim, bias=True)

        # Decoder: [z, params, labels] -> reconstructed (flattened) histogram.
        # Sigmoid bounds the output to [0, 1]; presumably the target histograms are
        # normalized to that range — TODO confirm against the training loss.
        self.fc_decoder = nn.Sequential(
            nn.Linear(latent_dim + cond_dim, hist_dim, bias=True),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Map the conditioned input ``x`` (batch, hist_dim + param_dim + label_dim) to ``(mu, logvar)``."""
        h = self.fc_encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode an already-conditioned latent ``[z, params, labels]`` to a flat (batch, hist_dim) output."""
        return self.fc_decoder(z)

    def forward(self, hist, params, labels):
        """Run encode -> sample -> decode.

        Args:
            hist: Tensor whose non-batch axes flatten to ``hist_dim`` features
                (e.g. (batch, 20, 20) with the defaults).
            params: (batch, param_dim) conditioning parameters.
            labels: (batch, label_dim) one-hot labels.

        Returns:
            ``(recon, mu, logvar)``. ``recon`` is flat (batch, hist_dim) and is not
            reshaped back to the input histogram shape.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or sliced histograms, also work
        x = hist.reshape(hist.size(0), -1)
        x = torch.cat((x, params, labels), dim=1)
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        # Condition the decoder on the same params/labels the encoder saw
        z_cond = torch.cat((z, params, labels), dim=1)
        recon = self.decode(z_cond)
        return recon, mu, logvar