In [1]:
import pickle
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np

from recipedataset import RecipeDataset, RECIPE_DATASET_FILENAME, NUM_GRAIN_SLOTS, NUM_ADJUNCT_SLOTS, NUM_HOP_SLOTS, NUM_MISC_SLOTS, NUM_MICROORGANISM_SLOTS, NUM_FERMENT_STAGE_SLOTS, NUM_MASH_STEPS

RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.backends.cudnn.deterministic = True

def layer_init_ortho(layer, std=np.sqrt(2)):
  nn.init.orthogonal_(layer.weight, std)
  nn.init.constant_(layer.bias, 0.0)
  return layer

def layer_init_xavier(layer, gain):
  nn.init.xavier_normal_(layer.weight, gain)
  nn.init.constant_(layer.bias, 0.0)
  return layer

def reparameterize(mu, logvar):
  std = torch.exp(0.5 * logvar)
  eps = torch.randn_like(std)
  return eps * std + mu

In [2]:
BATCH_SIZE = 256

# Load the dataset and create a dataloader for it
with open("../" + RECIPE_DATASET_FILENAME, 'rb') as f:
  dataset = pickle.load(f)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

In [3]:
GRAIN_TYPE_EMBED_SIZE         = 32
ADJUNCT_TYPE_EMBED_SIZE       = 32
HOP_TYPE_EMBED_SIZE           = 96
MISC_TYPE_EMBED_SIZE          = 64
MICROORGANISM_TYPE_EMBED_SIZE = 64

class RecipeNetArgs:
  def __init__(self, dataset) -> None:
    # Recipe-specific constraints ***
    self.num_mash_steps          = NUM_MASH_STEPS
    self.num_grain_slots         = NUM_GRAIN_SLOTS
    self.num_adjunct_slots       = NUM_ADJUNCT_SLOTS
    self.num_hop_slots           = NUM_HOP_SLOTS
    self.num_misc_slots          = NUM_MISC_SLOTS
    self.num_microorganism_slots = NUM_MICROORGANISM_SLOTS
    self.num_ferment_stage_slots = NUM_FERMENT_STAGE_SLOTS
    
    # NOTE: All types include a "None" (i.e., empty) category
    self.num_grain_types         = len(dataset.core_grains_idx_to_dbid) # Number of (core) grain types (rows in the DB)
    self.num_adjunct_types       = len(dataset.core_adjs_idx_to_dbid)   # Number of (core) adjunct types (rows in the DB)
    self.num_hop_types           = len(dataset.hops_idx_to_dbid)        # Number of hop types (rows in the DB)
    self.num_misc_types          = len(dataset.miscs_idx_to_dbid)       # Number of misc. types (rows in the DB)
    self.num_microorganism_types = len(dataset.mos_idx_to_dbid)         # Number of microrganism types (rows in the DB)
    
    self.num_mash_step_types  = len(dataset.mash_step_idx_to_name)  # Number of mash step types (e.g., Infusion, Decoction, Temperature)
    self.num_hop_stage_types  = len(dataset.hop_stage_idx_to_name)  # Number of hop stage types (e.g., Mash, Boil, Primary, ...)
    self.num_misc_stage_types = len(dataset.misc_stage_idx_to_name) # Number of misc stage types (e.g., Mash, Boil, Primary, ...)
    self.num_mo_stage_types   = len(dataset.mo_stage_idx_to_name)   # Number of microorganism stage types (e.g., Primary, Secondary)
    
    # Embedding sizes ***
    self.grain_type_embed_size         = GRAIN_TYPE_EMBED_SIZE
    self.adjunct_type_embed_size       = ADJUNCT_TYPE_EMBED_SIZE
    self.hop_type_embed_size           = HOP_TYPE_EMBED_SIZE
    self.misc_type_embed_size          = MISC_TYPE_EMBED_SIZE
    self.microorganism_type_embed_size = MICROORGANISM_TYPE_EMBED_SIZE
    
    # Network-specific hyperparameters/constraints ***
    self.num_hidden_layers = 1
    self.hidden_size = 1024
    self.z_size = 64 # Latent-bottleneck dimension
    self.activation_fn = nn.ELU
    self.gain = nn.init.calculate_gain('linear', None) # Make sure this corresponds to the activation function!
    self.num_inputs = self.calc_num_inputs()
    
    # VAE-specific hyperparameters
    self.beta_vae_gamma = 1000
    self.max_beta_vae_capacity = 25
    self.beta_vae_C_stop_iter = 1e5
    
  
  def calc_num_inputs(self):
    """Determine the number of inputs to the network.
    Returns:
        int: The total number of network inputs.
    """
    # (boil_time + mash_ph + sparge_temp)
    num_simple_inputs = 3 
    # Mash steps (step_type_index_size + step_time + step_temp) * (number of slots) - ordering assumed [0: step 1, 1: step 2, etc.]
    num_mash_step_inputs = self.num_mash_steps*(self.num_mash_step_types + 2)
    # Fermentation stages (step_time + step_temp) * (number of stages) - ordering assumed [0: primary, 1: secondary]
    num_ferment_step_inputs = self.num_ferment_stage_slots*(2)
    # Grain/Malt bill slots (grain_type_embed_size + amount) * (number of slots) - no ordering
    num_grain_slot_inputs = self.num_grain_slots*(self.grain_type_embed_size + 1)
    # Adjunct slots (adjunct_type_embed_size + amount) * (number of slots) - no ordering
    num_adjunct_slot_inputs = self.num_adjunct_slots*(self.adjunct_type_embed_size + 1)
    # Hop slots (hop_type_embed_size + stage_type_index_size + time + concentration) * (number of slots) - no ordering
    num_hop_slot_inputs = self.num_hop_slots*(self.hop_type_embed_size + self.num_hop_stage_types + 2)
    # Misc. slots (misc_type_embed_size + stage_type_index_size + time + amounts) * (number of slots) - no ordering
    num_misc_slot_inputs = self.num_misc_slots*(self.misc_type_embed_size + self.num_misc_stage_types + 2)
    # Microorganism slots (mo_type_embed_size + stage_type_index_size) * (number of slots) - no ordering
    num_mo_slot_inputs = self.num_microorganism_slots*(self.microorganism_type_embed_size + self.num_mo_stage_types)

    return num_simple_inputs + num_mash_step_inputs + num_ferment_step_inputs + num_grain_slot_inputs + \
      num_adjunct_slot_inputs + num_hop_slot_inputs + num_misc_slot_inputs + num_mo_slot_inputs   

args = RecipeNetArgs(dataset)
args.num_inputs

5899

In [10]:
class RecipeVAE(nn.Module):

  def __init__(self, args) -> None:
    super().__init__()
    
    hidden_size = args.hidden_size
    z_size = args.z_size
    activation_fn = args.activation_fn
    gain = args.gain
    
    assert args.num_inputs >= 1
    assert args.num_hidden_layers >= 1
    assert hidden_size >= 1
    assert z_size >= 1 and z_size < args.num_inputs

    # Encoder and decoder networks
    self.encoder = nn.Sequential()
    self.encoder.append(layer_init_xavier(nn.Linear(args.num_inputs, hidden_size), gain))
    self.encoder.append(activation_fn())
    for _ in range(1, args.num_hidden_layers):
      self.encoder.append(layer_init_xavier(nn.Linear(hidden_size, hidden_size), gain))
      self.encoder.append(activation_fn())
    self.encoder.append(layer_init_xavier(nn.Linear(hidden_size, z_size*2), gain))
    self.encoder.append(activation_fn())
    self.encoder.append(nn.BatchNorm1d(z_size*2))

    self.decoder = nn.Sequential()
    self.decoder.append(layer_init_xavier(nn.Linear(z_size, hidden_size), gain))
    self.decoder.append(activation_fn())
    for _ in range(1, args.num_hidden_layers):
      self.decoder.append(layer_init_xavier(nn.Linear(hidden_size, hidden_size), gain))
      self.encoder.append(activation_fn())
    self.decoder.append(layer_init_xavier(nn.Linear(hidden_size, args.num_inputs), gain))
    
    # Embeddings (NOTE: Any categoricals that don't have embeddings will be one-hot encoded)
    self.grain_type_embedding         = nn.Embedding(args.num_grain_types, args.grain_type_embed_size)
    self.adjunct_type_embedding       = nn.Embedding(args.num_adjunct_types, args.adjunct_type_embed_size) 
    self.hop_type_embedding           = nn.Embedding(args.num_hop_types, args.hop_type_embed_size)
    self.misc_type_embedding          = nn.Embedding(args.num_misc_types, args.misc_type_embed_size)
    self.microorganism_type_embedding = nn.Embedding(args.num_microorganism_types, args.microorganism_type_embed_size)
    
    # Post-network decoders (these are basically learned inverse embeddings)
    self.grain_type_decoder         = layer_init_xavier(nn.Linear(args.grain_type_embed_size, args.num_grain_types), gain)
    self.adjunct_type_decoder       = layer_init_xavier(nn.Linear(args.adjunct_type_embed_size, args.num_adjunct_types), gain)
    self.hop_type_decoder           = layer_init_xavier(nn.Linear(args.hop_type_embed_size, args.num_hop_types), gain)
    self.misc_type_decoder          = layer_init_xavier(nn.Linear(args.misc_type_embed_size, args.num_misc_types), gain)
    self.microorganism_type_decoder = layer_init_xavier(nn.Linear(args.microorganism_type_embed_size, args.num_microorganism_types), gain)
    
    self.gamma = args.beta_vae_gamma
    self.C_stop_iter = args.beta_vae_C_stop_iter
    self.C_max = torch.Tensor([args.max_beta_vae_capacity])
    
    self.args = args
  
  def forward(self, x, use_mean=False):
    # Simple top-level heads (high-level recipe parameters)
    x_toplvl = torch.cat((x['boil_time'].unsqueeze(1), x['mash_ph'].unsqueeze(1), x['sparge_temp'].unsqueeze(1)), dim=1) # (B, 3)
    
    # Mash step heads
    # NOTE: Data shape is (B, S=number_of_mash_step_slots) for the
    # following recipe tensors: {'mash_step_type_inds', 'mash_step_times', 'mash_step_avg_temps'}
    num_mash_step_types = self.args.num_mash_step_types
    enc_mash_step_type_onehot = F.one_hot(x['mash_step_type_inds'].long(), num_mash_step_types).float().flatten(1) # (B, S, num_mash_step_types) -> (B, S*num_mash_step_types) = [B, 24]
    x_mash_steps = torch.cat((enc_mash_step_type_onehot, x['mash_step_times'], x['mash_step_avg_temps']), dim=1) # (B, num_mash_step_types*S+S+S) = [B, 36=(24+6+6)]
    
    # Ferment stage heads
    # NOTE: Data shape is (B, S=2) for the following recipe tensors: {'ferment_stage_times', 'ferment_stage_temps'}
    x_ferment_stages = torch.cat((x['ferment_stage_times'], x['ferment_stage_temps']), dim=1) # (B, S+S)

    # Grain (malt bill) heads
    # NOTE: Data shape is (B, S=num_grain_slots) for the following recipe tensors: {'grain_core_type_inds', 'grain_amts'}
    enc_grain_type_embed = self.grain_type_embedding(x['grain_core_type_inds']).flatten(1) # (B, S, grain_type_embed_size) -> (B, S*grain_type_embed_size)
    x_grains = torch.cat((enc_grain_type_embed, x['grain_amts']), dim=1) # (B, S*grain_type_embed_size+S)
    
    # Adjunct heads
    # NOTE: Data shape is (B, S=num_adjunct_slots) for the following recipe tensors: {'adjunct_core_type_inds', 'adjunct_amts'}
    enc_adjunct_type_embed = self.adjunct_type_embedding(x['adjunct_core_type_inds']).flatten(1) # (B, S, adjunct_type_embed_size) -> (B, S*adjunct_type_embed_size)
    x_adjuncts = torch.cat((enc_adjunct_type_embed, x['adjunct_amts']), dim=1) # (B, S*adjunct_type_embed_size+S)
    
    # Hop heads
    # NOTE: Data shape is (B, S=num_hop_slots) for the following recipe tensors: 
    # {'hop_type_inds', 'hop_stage_type_inds', 'hop_times', 'hop_concentrations'}
    num_hop_stage_types = self.args.num_hop_stage_types
    enc_hop_type_embed = self.hop_type_embedding(x['hop_type_inds']).flatten(1) # (B, S, hop_type_embed_size)
    enc_hop_stage_type_onehot = F.one_hot(x['hop_stage_type_inds'].long(), num_hop_stage_types).float().flatten(1) # (B, S, num_hop_stage_types)
    x_hops = torch.cat((enc_hop_type_embed, enc_hop_stage_type_onehot, x['hop_times'], x['hop_concentrations']), dim=1) # (B, S*hop_type_embed_size + S*num_hop_stage_types + S + S)
    
    # Misc. heads
    # NOTE: Data shape is (B, S=num_misc_slots) for the following recipe tensors:
    # {'misc_type_inds', 'misc_stage_inds', 'misc_times', 'misc_amts'}
    num_misc_stage_types = self.args.num_misc_stage_types
    enc_misc_type_embed = self.misc_type_embedding(x['misc_type_inds']).flatten(1) # (B, S, misc_type_embed_size)
    enc_misc_stage_type_onehot = F.one_hot(x['misc_stage_inds'].long(), num_misc_stage_types).float().flatten(1) # (B, S, num_misc_stage_types)
    x_miscs = torch.cat((enc_misc_type_embed, enc_misc_stage_type_onehot, x['misc_times'], x['misc_amts']), dim=1) # (B, S*misc_type_embed_size + S*num_misc_stage_types + S + S)
    
    # Microorganism heads
    # NOTE: Data shape is (B, S=num_microorganism_slots) for the following recipe tensors:
    # {'mo_type_inds', 'mo_stage_inds'}
    num_mo_stage_types = self.args.num_mo_stage_types
    enc_mo_type_embed = self.microorganism_type_embedding(x['mo_type_inds']).flatten(1) # (B, S, microorganism_type_embed_size)
    enc_mo_stage_type_onehot = F.one_hot(x['mo_stage_inds'].long(), num_mo_stage_types).float().flatten(1) # (B, S, num_mo_stage_types)
    x_mos = torch.cat((enc_mo_type_embed, enc_mo_stage_type_onehot), dim=1) # (B, S*microorganism_type_embed_size + S*num_mo_stage_types)
    
    # Put all the recipe data together into a flattened tensor
    x_recipe = x
    x = torch.cat((x_toplvl, x_mash_steps, x_ferment_stages, x_grains, x_adjuncts, x_hops, x_miscs, x_mos), dim=1) # (B, num_inputs)
    
    # Encode to the latent distribution, sample (reparameterize trick) then decode
    mean, logvar = torch.chunk(self.encoder(x), 2, dim=-1) 
    z = mean if use_mean else reparameterize(mean, logvar)
    x_hat = self.decoder(z)
    
    # TODO: Move the loss calculations into its own function...
    # Maybe use a class/dict for holding intermediate values during the pipeline for easy passing between functions?
    
    # The decoded tensor is flat with a shape of (B, num_inputs), we'll need to break it apart to calculate the appropriate losses
    x_hat_toplvl, x_hat_mash_steps, x_hat_ferment_stages, x_hat_grains, x_hat_adjuncts, x_hat_hops, x_hat_miscs, x_hat_mos = torch.split(
      x_hat, [3, x_mash_steps.shape[1], x_ferment_stages.shape[1], x_grains.shape[1], x_adjuncts.shape[1], x_hops.shape[1], x_miscs.shape[1], x_mos.shape[1]], dim=1
    )
    
    # TODO: Simplify all this stuff into fewer losses: 
    # Group together all BCELogit and MSE losses into singluar tensors in both x and x_hat
    loss_toplvl = F.mse_loss(x_hat_toplvl, x_toplvl, reduction='sum')
    
    # Mash step decode and loss calculation
    num_mash_steps = self.args.num_mash_steps
    dec_mash_step_type_onehot, dec_mash_step_times, dec_mash_step_avg_temps = torch.split(
      x_hat_mash_steps, [enc_mash_step_type_onehot.shape[1], num_mash_steps, num_mash_steps], dim=1
    )
    loss_mash_steps = F.binary_cross_entropy_with_logits(dec_mash_step_type_onehot, enc_mash_step_type_onehot, reduction='sum') + \
      F.mse_loss(dec_mash_step_times, x_recipe['mash_step_times'], reduction='sum') + F.mse_loss(dec_mash_step_avg_temps, x_recipe['mash_step_avg_temps'], reduction='sum')
    
    # Ferment stages loss calculation
    loss_ferment_stages = F.mse_loss(x_hat_ferment_stages, x_ferment_stages, reduction='sum')
    
    # Grain slots decode and loss calculation
    num_grain_slots = self.args.num_grain_slots
    num_grain_types = self.args.num_grain_types
    grain_type_embed_size = self.args.grain_type_embed_size
    dec_grain_type_embed, dec_grain_amts = torch.split(x_hat_grains, [enc_grain_type_embed.shape[1], num_grain_slots], dim=1)
    dec_grain_type_logits = self.grain_type_decoder(dec_grain_type_embed.view(-1, num_grain_slots, grain_type_embed_size)) # (B, num_grain_slots, num_grain_types)
    enc_grain_type_onehot = F.one_hot(x_recipe['grain_core_type_inds'].long(), num_grain_types).float() # (B, num_grain_slots, num_grain_types)
    loss_grains = F.binary_cross_entropy_with_logits(dec_grain_type_logits, enc_grain_type_onehot, reduction='sum') + F.mse_loss(dec_grain_amts, x_recipe['grain_amts'], reduction='sum')
    
    # Adjunct slots decode and loss calculation
    num_adjunct_slots = self.args.num_adjunct_slots
    num_adjunct_types = self.args.num_adjunct_types
    adjunct_type_embed_size = self.args.adjunct_type_embed_size
    dec_adjunct_type_embed, dec_adjunct_amts = torch.split(x_hat_adjuncts, [enc_adjunct_type_embed.shape[1], num_adjunct_slots], dim=1)
    dec_adjunct_type_logits = self.adjunct_type_decoder(dec_adjunct_type_embed.view(-1, num_adjunct_slots, adjunct_type_embed_size)) # (B, num_adjunct_slots, num_adjunct_types)
    enc_grain_type_onehot = F.one_hot(x_recipe['adjunct_core_type_inds'].long(), num_adjunct_types).float() # (B, num_adjunct_slots, num_adjunct_types)
    loss_adjuncts = F.binary_cross_entropy_with_logits(dec_adjunct_type_logits, enc_grain_type_onehot, reduction='sum') + F.mse_loss(dec_adjunct_amts, x_recipe['adjunct_amts'], reduction='sum')
    
    # Hop slots decode and loss calculation
    num_hop_slots = self.args.num_hop_slots
    num_hop_types = self.args.num_hop_types
    hop_type_embed_size = self.args.hop_type_embed_size
    dec_hop_type_embed, dec_hop_stage_type_onehot, dec_hop_times, dec_hop_concentrations = torch.split(
      x_hat_hops, [enc_hop_type_embed.shape[1], enc_hop_stage_type_onehot.shape[1], num_hop_slots, num_hop_slots], dim=1
    )
    dec_hop_type_logits = self.hop_type_decoder(dec_hop_type_embed.view(-1, num_hop_slots, hop_type_embed_size)) # (B, num_hop_slots, num_hop_types)
    enc_hop_type_onehot = F.one_hot(x_recipe['hop_type_inds'].long(), num_hop_types).float() # (B, num_hop_slots, num_hop_types)
    loss_hops = F.binary_cross_entropy_with_logits(dec_hop_type_logits, enc_hop_type_onehot, reduction='sum') + \
      F.binary_cross_entropy_with_logits(dec_hop_stage_type_onehot, enc_hop_stage_type_onehot, reduction='sum') + \
      F.mse_loss(dec_hop_times, x_recipe['hop_times'], reduction='sum') + \
      F.mse_loss(dec_hop_concentrations, x_recipe['hop_concentrations'], reduction='sum')
    
    # Miscellaneous slots decode and loss calculation
    num_misc_slots = self.args.num_misc_slots
    num_misc_types = self.args.num_misc_types
    misc_type_embed_size = self.args.misc_type_embed_size
    dec_misc_type_embed, dec_misc_stage_type_onehot, dec_misc_times, dec_misc_amts = torch.split(
      x_hat_miscs, [enc_misc_type_embed.shape[1], enc_misc_stage_type_onehot.shape[1], num_misc_slots, num_misc_slots], dim=1
    )
    dec_misc_type_logits = self.misc_type_decoder(dec_misc_type_embed.view(-1, num_misc_slots, misc_type_embed_size)) # (B, num_misc_slots, num_misc_types)
    enc_misc_type_onehot = F.one_hot(x_recipe['misc_type_inds'].long(), num_misc_types).float() # (B, num_misc_slots, num_misc_types)
    loss_miscs = F.binary_cross_entropy_with_logits(dec_misc_type_logits, enc_misc_type_onehot, reduction='sum') + \
      F.binary_cross_entropy_with_logits(dec_misc_stage_type_onehot, enc_misc_stage_type_onehot, reduction='sum') + \
      F.mse_loss(dec_misc_times, x_recipe['misc_times'], reduction='sum') + \
      F.mse_loss(dec_misc_amts, x_recipe['misc_amts'], reduction='sum')
    
    # Microorganism slots decode and loss calculation
    num_mo_slots = self.args.num_microorganism_slots
    num_mo_types = self.args.num_microorganism_types
    mo_type_embed_size = self.args.microorganism_type_embed_size
    dec_mo_type_embed, dec_mo_stage_type_onehot = torch.split(
      x_hat_mos, [enc_mo_type_embed.shape[1], enc_mo_stage_type_onehot.shape[1]], dim=1
    )
    dec_mo_type_logits = self.microorganism_type_decoder(dec_mo_type_embed.view(-1, num_mo_slots, mo_type_embed_size)) # (B, num_mo_slots, num_mo_types)
    enc_mo_type_onehot = F.one_hot(x_recipe['mo_type_inds'].long(), num_mo_types).float() # (B, num_mo_slots, num_mo_types)
    loss_mos = F.binary_cross_entropy_with_logits(dec_mo_type_logits, enc_mo_type_onehot, reduction='sum') + \
      F.binary_cross_entropy_with_logits(dec_mo_stage_type_onehot, enc_mo_stage_type_onehot, reduction='sum')
    
    # Add up all our losses for reconstruction of the recipe
    reconst_loss = loss_toplvl + loss_mash_steps + loss_ferment_stages + loss_grains + loss_hops + loss_miscs + loss_mos
    
    
    #return x_hat, mean, logvar
    
    
  '''
  def calc_loss(self, x, x_hat, mean, logvar, num_iter, kl_weight=1.0):
    
    #one_hot_types = nn.functional.one_hot(x.long(), NUM_CORE_GRAIN_TYPES).float()
    #recons_loss = nn.functional.binary_cross_entropy_with_logits(x_hat, one_hot_types, reduction='sum')
    
    # Beta-VAE KL calculation is based on https://arxiv.org/pdf/1804.03599.pdf
    kl_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mean ** 2 - logvar.exp(), dim=1), dim=0)
    C = torch.clamp(self.C_max/self.C_stop_iter * num_iter, 0, self.C_max.data[0])
    loss = recons_loss + kl_weight * self.gamma * (kl_loss - C).abs()
    return loss
  '''
  
recipe_net = RecipeVAE(args)

In [11]:
for batch_idx, batch in enumerate(dataloader):
  if batch_idx == 1:
    recipe_net(batch)
    break

RuntimeError: result type Float can't be cast to the desired output type Long