In [6]:
import os
import sys
import pickle
import random

sys.path.append(os.path.dirname(os.getcwd()))

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np


from sqlalchemy import create_engine
from db_scripts.brewbrain_db import Base, BREWBRAIN_DB_FILENAME, build_db_str
from file_utils import find_file_cwd_and_parent_dirs
from recipedataset import core_grain_labels, core_adjunct_labels, hop_labels, misc_labels, microorganism_labels
from recipedataset import RecipeDataset, RECIPE_DATASET_FILENAME, NUM_GRAIN_SLOTS, NUM_ADJUNCT_SLOTS, NUM_HOP_SLOTS, NUM_MISC_SLOTS, NUM_MICROORGANISM_SLOTS, NUM_FERMENT_STAGE_SLOTS, NUM_MASH_STEPS

# Single seed shared by every random number generator this notebook touches.
RANDOM_SEED = 42

# Ask cuDNN for deterministic kernels, then seed torch, numpy and Python's `random`.
# These calls are independent of each other, so their order does not matter.
torch.backends.cudnn.deterministic = True
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

def layer_init_ortho(layer, std=np.sqrt(2)):
  """Initialise ``layer`` in place with orthogonal weights and a zeroed bias.

  Args:
    layer: Module exposing ``weight`` and ``bias`` attributes (e.g. ``nn.Linear``).
    std: Forwarded to ``nn.init.orthogonal_`` as its gain (scaling) argument.

  Returns:
    The same ``layer`` object, so the call can be nested inside a constructor expression.
  """
  nn.init.orthogonal_(layer.weight, std)
  # Identity check: `bias` is None for bias-free layers, and `!= None` on a
  # tensor would be routed through the tensor comparison overloads.
  if layer.bias is not None:
    nn.init.constant_(layer.bias, 0.0)
  return layer

def layer_init_xavier(layer, gain):
  """Initialise ``layer`` in place with Xavier-normal weights and a zeroed bias.

  Args:
    layer: Module exposing ``weight`` and ``bias`` attributes (e.g. ``nn.Linear``).
    gain: Scaling factor forwarded to ``nn.init.xavier_normal_``.

  Returns:
    The same ``layer`` object, so the call can be nested inside a constructor expression.
  """
  nn.init.xavier_normal_(layer.weight, gain)
  # Identity check: `bias` is None for bias-free layers, and `!= None` on a
  # tensor would be routed through the tensor comparison overloads.
  if layer.bias is not None:
    nn.init.constant_(layer.bias, 0.0)
  return layer

def reparameterize(mu, logvar):
  """Draw a sample from N(mu, exp(logvar)) using the reparameterisation trick.

  Args:
    mu: Tensor of distribution means.
    logvar: Tensor of log-variances, broadcastable against ``mu``.

  Returns:
    ``mu + noise * sigma`` where ``sigma = exp(0.5 * logvar)`` and ``noise`` is
    standard-normal noise with the same shape/dtype as ``sigma``.
  """
  sigma = (0.5 * logvar).exp()
  noise = torch.randn_like(sigma)
  return mu + noise * sigma

In [7]:
# Number of recipes per mini-batch handed to the network.
BATCH_SIZE = 256

# Load the pickled dataset from the parent directory and wrap it in a shuffling dataloader.
# NOTE(review): pickle.load executes arbitrary code embedded in the file -- only
# point this at dataset files produced by this project.
with open("../" + RECIPE_DATASET_FILENAME, 'rb') as f:
  dataset = pickle.load(f)
# num_workers=0 keeps batch loading in the main process.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

In [8]:
# Width of the learned embedding vector for each ingredient family
# (consumed by RecipeNetArgs below).
GRAIN_TYPE_EMBED_SIZE         = 48
ADJUNCT_TYPE_EMBED_SIZE       = 64
HOP_TYPE_EMBED_SIZE           = 256
MISC_TYPE_EMBED_SIZE          = 128
MICROORGANISM_TYPE_EMBED_SIZE = 256

# Embedding labels
# Locate the database file (searching the cwd and its parent directories), open an
# engine on it and make sure every table declared on Base exists.
db_filepath = build_db_str(find_file_cwd_and_parent_dirs(BREWBRAIN_DB_FILENAME, os.getcwd()))
engine = create_engine(db_filepath, echo=False, future=True)
Base.metadata.create_all(engine)
# Label lists per ingredient family; later passed as `metadata` when the embedding
# tables are sent to TensorBoard.
# NOTE(review): presumably one label per embedding row, in index order -- confirm in recipedataset.
grain_type_embedding_labels = core_grain_labels(engine, dataset)
adjunct_type_embedding_labels = core_adjunct_labels(engine, dataset)
hop_type_embedding_labels = hop_labels(engine, dataset)
misc_type_embedding_labels = misc_labels(engine, dataset)
microorganism_type_embedding_labels = microorganism_labels(engine, dataset)

class RecipeNetArgs:
  """Bundle of every size and hyperparameter needed to build the recipe VAE.

  Slot counts come from the recipedataset module constants, category counts from
  the lookup tables of the given dataset, and embedding sizes from the notebook
  constants. The derived ``num_*_inputs`` properties describe how wide each
  section of the flattened network input is.
  """

  def __init__(self, dataset: RecipeDataset) -> None:
    # --- Recipe-specific constraints: how many slots/steps a recipe may hold ---
    self.num_mash_steps          = NUM_MASH_STEPS
    self.num_grain_slots         = NUM_GRAIN_SLOTS
    self.num_adjunct_slots       = NUM_ADJUNCT_SLOTS
    self.num_hop_slots           = NUM_HOP_SLOTS
    self.num_misc_slots          = NUM_MISC_SLOTS
    self.num_microorganism_slots = NUM_MICROORGANISM_SLOTS
    self.num_ferment_stage_slots = NUM_FERMENT_STAGE_SLOTS

    # --- Ingredient category counts (rows in the DB) ---
    # NOTE: All types include a "None" (i.e., empty) category
    self.num_grain_types         = len(dataset.core_grains_idx_to_dbid)  # (core) grains
    self.num_adjunct_types       = len(dataset.core_adjs_idx_to_dbid)    # (core) adjuncts
    self.num_hop_types           = len(dataset.hops_idx_to_dbid)         # hops
    self.num_misc_types          = len(dataset.miscs_idx_to_dbid)        # misc. ingredients
    self.num_microorganism_types = len(dataset.mos_idx_to_dbid)          # microorganisms

    # --- Step/stage category counts ---
    self.num_mash_step_types  = len(dataset.mash_step_idx_to_name)   # e.g., Infusion, Decoction, Temperature
    self.num_hop_stage_types  = len(dataset.hop_stage_idx_to_name)   # e.g., Mash, Boil, Primary, ...
    self.num_misc_stage_types = len(dataset.misc_stage_idx_to_name)  # e.g., Mash, Boil, Primary, ...
    self.num_mo_stage_types   = len(dataset.mo_stage_idx_to_name)    # e.g., Primary, Secondary

    # --- Embedding widths per ingredient family ---
    self.grain_type_embed_size         = GRAIN_TYPE_EMBED_SIZE
    self.adjunct_type_embed_size       = ADJUNCT_TYPE_EMBED_SIZE
    self.hop_type_embed_size           = HOP_TYPE_EMBED_SIZE
    self.misc_type_embed_size          = MISC_TYPE_EMBED_SIZE
    self.microorganism_type_embed_size = MICROORGANISM_TYPE_EMBED_SIZE

    # --- Network-specific hyperparameters/constraints ---
    self.hidden_layers = [2048]     # Width of each hidden layer, encoder order
    self.z_size = 128               # Latent-bottleneck dimension
    self.activation_fn = nn.ELU     # Activation class (instantiated per layer)
    # Make sure this corresponds to the activation function!
    self.gain = nn.init.calculate_gain('linear', None)

    # --- VAE-specific hyperparameters (capacity-controlled KL term) ---
    self.beta_vae_gamma = 1000
    self.max_beta_vae_capacity = 25
    self.beta_vae_C_stop_iter = 1e5

  @property
  def num_toplvl_inputs(self):
    """Width of the top-level section: boil_time + mash_ph + sparge_temp."""
    return 3

  @property
  def num_mash_step_inputs(self):
    """Width of the mash-step section.

    Each step contributes (one-hot step type + step_time + step_temp);
    ordering assumed [0: step 1, 1: step 2, etc.].
    """
    per_step = self.num_mash_step_types + 2
    return self.num_mash_steps * per_step

  @property
  def num_ferment_stage_inputs(self):
    """Width of the fermentation section: (step_time + step_temp) per stage.

    Ordering assumed [0: primary, 1: secondary].
    """
    per_stage = 2
    return self.num_ferment_stage_slots * per_stage

  @property
  def num_grain_slot_inputs(self):
    """Width of the grain/malt-bill section: (type embedding + amount) per slot, unordered."""
    per_slot = self.grain_type_embed_size + 1
    return self.num_grain_slots * per_slot

  @property
  def num_adjunct_slot_inputs(self):
    """Width of the adjunct section: (type embedding + amount) per slot, unordered."""
    per_slot = self.adjunct_type_embed_size + 1
    return self.num_adjunct_slots * per_slot

  @property
  def num_hop_slot_inputs(self):
    """Width of the hop section: (type embedding + one-hot stage + time + concentration) per slot, unordered."""
    per_slot = self.hop_type_embed_size + self.num_hop_stage_types + 2
    return self.num_hop_slots * per_slot

  @property
  def num_misc_slot_inputs(self):
    """Width of the misc. section: (type embedding + one-hot stage + time + amount) per slot, unordered."""
    per_slot = self.misc_type_embed_size + self.num_misc_stage_types + 2
    return self.num_misc_slots * per_slot

  @property
  def num_microorganism_slot_inputs(self):
    """Width of the microorganism section: (type embedding + one-hot stage) per slot, unordered."""
    per_slot = self.microorganism_type_embed_size + self.num_mo_stage_types
    return self.num_microorganism_slots * per_slot

  @property
  def num_inputs(self):
    """Determine the number of inputs to the network.
    Returns:
        int: The total number of network inputs.
    """
    section_widths = (
      self.num_toplvl_inputs,
      self.num_mash_step_inputs,
      self.num_ferment_stage_inputs,
      self.num_grain_slot_inputs,
      self.num_adjunct_slot_inputs,
      self.num_hop_slot_inputs,
      self.num_misc_slot_inputs,
      self.num_microorganism_slot_inputs,
    )
    return sum(section_widths)

# Build the size/hyperparameter bundle from the loaded dataset.
args = RecipeNetArgs(dataset)
# Bare expression: shown as the cell output (total width of the flattened network input).
args.num_inputs

14091

In [9]:
class RecipeNetData:
  """Plain attribute bag; the head encoder and foot decoder attach their tensors to it dynamically."""

  def __init__(self) -> None:
    """Create an empty container with no attributes set."""
  
class RecipeNetHeadEncoder(nn.Module):
  """Pre-network "heads": converts a batch of recipe tensors into one flat input row per recipe.

  Ingredient type indices are looked up in learned embeddings; every other
  categorical is one-hot encoded. The intermediate encodings are stored on the
  returned data object so they can be reused later (e.g. as loss targets).
  """

  def __init__(self, args) -> None:
    super().__init__()
    # Embeddings (NOTE: Any categoricals that don't have embeddings will be one-hot encoded)
    self.grain_type_embedding = nn.Embedding(
      args.num_grain_types, args.grain_type_embed_size)
    self.adjunct_type_embedding = nn.Embedding(
      args.num_adjunct_types, args.adjunct_type_embed_size)
    self.hop_type_embedding = nn.Embedding(
      args.num_hop_types, args.hop_type_embed_size)
    self.misc_type_embedding = nn.Embedding(
      args.num_misc_types, args.misc_type_embed_size)
    self.microorganism_type_embedding = nn.Embedding(
      args.num_microorganism_types, args.microorganism_type_embed_size)
    self.args = args

  @staticmethod
  def _one_hot(indices, num_classes):
    """Float one-hot encoding of integer class indices (appends a class dimension)."""
    return F.one_hot(indices.long(), num_classes).float()

  def forward(self, x):
    """Encode one batch.

    Args:
      x: Mapping from field name to tensor. Per the notes below, scalar recipe
        fields are indexed per recipe (B,) and slot fields are (B, S).

    Returns:
      (flat, heads): ``flat`` is the concatenation of all sections along dim 1;
      ``heads`` is a RecipeNetData carrying every section and intermediate encoding.
    """
    cfg = self.args
    one_hot = self._one_hot
    heads = RecipeNetData()

    # --- Top-level recipe parameters: stacked side by side -> (B, 3) ---
    heads.x_toplvl = torch.stack((x['boil_time'], x['mash_ph'], x['sparge_temp']), dim=1)

    # --- Mash steps ---
    # Inputs are (B, S=number_of_mash_steps): 'mash_step_type_inds', 'mash_step_times', 'mash_step_avg_temps'
    heads.enc_mash_step_type_onehot = one_hot(
      x['mash_step_type_inds'], cfg.num_mash_step_types
    ).flatten(1)  # (B, S, num_mash_step_types) -> (B, S*num_mash_step_types)
    heads.x_mash_steps = torch.cat(
      (heads.enc_mash_step_type_onehot, x['mash_step_times'], x['mash_step_avg_temps']),
      dim=1,
    )  # (B, S*num_mash_step_types + S + S)

    # --- Fermentation stages ---
    # Inputs are (B, S): 'ferment_stage_times', 'ferment_stage_temps'
    heads.x_ferment_stages = torch.cat(
      (x['ferment_stage_times'], x['ferment_stage_temps']), dim=1
    )  # (B, S + S)

    # --- Grains (malt bill) ---
    # Inputs are (B, S=num_grain_slots): 'grain_core_type_inds', 'grain_amts'
    grain_inds = x['grain_core_type_inds']
    heads.enc_grain_type_embed = self.grain_type_embedding(grain_inds).flatten(1)  # (B, S*grain_type_embed_size)
    heads.enc_grain_type_onehot = one_hot(grain_inds, cfg.num_grain_types)         # (B, S, num_grain_types)
    heads.x_grains = torch.cat((heads.enc_grain_type_embed, x['grain_amts']), dim=1)  # (B, S*grain_type_embed_size + S)

    # --- Adjuncts ---
    # Inputs are (B, S=num_adjunct_slots): 'adjunct_core_type_inds', 'adjunct_amts'
    adjunct_inds = x['adjunct_core_type_inds']
    heads.enc_adjunct_type_embed = self.adjunct_type_embedding(adjunct_inds).flatten(1)  # (B, S*adjunct_type_embed_size)
    heads.enc_adjunct_type_onehot = one_hot(adjunct_inds, cfg.num_adjunct_types)         # (B, S, num_adjunct_types)
    heads.x_adjuncts = torch.cat((heads.enc_adjunct_type_embed, x['adjunct_amts']), dim=1)  # (B, S*adjunct_type_embed_size + S)

    # --- Hops ---
    # Inputs are (B, S=num_hop_slots): 'hop_type_inds', 'hop_stage_type_inds', 'hop_times', 'hop_concentrations'
    hop_inds = x['hop_type_inds']
    heads.enc_hop_type_embed = self.hop_type_embedding(hop_inds).flatten(1)  # (B, S*hop_type_embed_size)
    heads.enc_hop_type_onehot = one_hot(hop_inds, cfg.num_hop_types)         # (B, S, num_hop_types)
    heads.enc_hop_stage_type_onehot = one_hot(
      x['hop_stage_type_inds'], cfg.num_hop_stage_types
    ).flatten(1)  # (B, S*num_hop_stage_types)
    heads.x_hops = torch.cat(
      (heads.enc_hop_type_embed, heads.enc_hop_stage_type_onehot, x['hop_times'], x['hop_concentrations']),
      dim=1,
    )  # (B, S*hop_type_embed_size + S*num_hop_stage_types + S + S)

    # --- Misc. ingredients ---
    # Inputs are (B, S=num_misc_slots): 'misc_type_inds', 'misc_stage_inds', 'misc_times', 'misc_amts'
    misc_inds = x['misc_type_inds']
    heads.enc_misc_type_embed = self.misc_type_embedding(misc_inds).flatten(1)  # (B, S*misc_type_embed_size)
    heads.enc_misc_type_onehot = one_hot(misc_inds, cfg.num_misc_types)         # (B, S, num_misc_types)
    heads.enc_misc_stage_type_onehot = one_hot(
      x['misc_stage_inds'], cfg.num_misc_stage_types
    ).flatten(1)  # (B, S*num_misc_stage_types)
    heads.x_miscs = torch.cat(
      (heads.enc_misc_type_embed, heads.enc_misc_stage_type_onehot, x['misc_times'], x['misc_amts']),
      dim=1,
    )  # (B, S*misc_type_embed_size + S*num_misc_stage_types + S + S)

    # --- Microorganisms ---
    # Inputs are (B, S=num_microorganism_slots): 'mo_type_inds', 'mo_stage_inds'
    mo_inds = x['mo_type_inds']
    heads.enc_mo_type_embed = self.microorganism_type_embedding(mo_inds).flatten(1)  # (B, S*microorganism_type_embed_size)
    heads.enc_mo_type_onehot = one_hot(mo_inds, cfg.num_microorganism_types)         # (B, S, num_mo_types)
    heads.enc_mo_stage_type_onehot = one_hot(
      x['mo_stage_inds'], cfg.num_mo_stage_types
    ).flatten(1)  # (B, S*num_mo_stage_types)
    heads.x_mos = torch.cat(
      (heads.enc_mo_type_embed, heads.enc_mo_stage_type_onehot), dim=1
    )  # (B, S*microorganism_type_embed_size + S*num_mo_stage_types)

    # Put all the recipe data together into a flattened tensor -> (B, num_inputs)
    sections = (
      heads.x_toplvl, heads.x_mash_steps, heads.x_ferment_stages, heads.x_grains,
      heads.x_adjuncts, heads.x_hops, heads.x_miscs, heads.x_mos,
    )
    flat = torch.cat(sections, dim=1)
    return flat, heads

class RecipeNetFootDecoder(nn.Module):
  """Post-network "foots": slices the flat decoder output back into per-section tensors.

  The flat tensor is split using the same section widths that make up the
  network input. Ingredient-type embeddings found inside each section are
  projected to per-slot type logits by bias-free linear layers.
  """

  def __init__(self, args: RecipeNetArgs) -> None:
    super().__init__()
    init_gain = args.gain

    def type_decoder(embed_size, num_types):
      # Bias-free projection from an embedding vector to per-type logits
      return layer_init_xavier(nn.Linear(embed_size, num_types, bias=False), init_gain)

    self.grain_type_decoder         = type_decoder(args.grain_type_embed_size, args.num_grain_types)
    self.adjunct_type_decoder       = type_decoder(args.adjunct_type_embed_size, args.num_adjunct_types)
    self.hop_type_decoder           = type_decoder(args.hop_type_embed_size, args.num_hop_types)
    self.misc_type_decoder          = type_decoder(args.misc_type_embed_size, args.num_misc_types)
    self.microorganism_type_decoder = type_decoder(args.microorganism_type_embed_size, args.num_microorganism_types)

    # Section widths, in the order the sections appear in the flat tensor:
    # [Top-level recipe attributes, Mash steps, Fermentation stages, Grains, Adjuncts, Hops, Misc, Microorganisms]
    self.split_sizes = [
      args.num_toplvl_inputs,
      args.num_mash_step_inputs,
      args.num_ferment_stage_inputs,
      args.num_grain_slot_inputs,
      args.num_adjunct_slot_inputs,
      args.num_hop_slot_inputs,
      args.num_misc_slot_inputs,
      args.num_microorganism_slot_inputs,
    ]
    self.args = args

  @staticmethod
  def _slot_logits(decoder, flat_embed, num_slots, embed_size):
    """Reshape a flat (B, slots*embed) block to (B, slots, embed) and project it to type logits."""
    return decoder(flat_embed.view(-1, num_slots, embed_size))

  def forward(self, x_hat):
    """Split the decoded (B, num_inputs) tensor into sections and decode type logits.

    Args:
      x_hat: Flat output of the decoder network, split along dim 1.

    Returns:
      RecipeNetData holding the raw sections (``x_hat_*``) and decoded pieces (``dec_*``).
    """
    cfg = self.args
    foots = RecipeNetData()

    # The decoded tensor is flat, break it apart so losses can eventually be
    # calculated for each head of original data fed to the encoder
    (
      foots.x_hat_toplvl,
      foots.x_hat_mash_steps,
      foots.x_hat_ferment_stages,
      foots.x_hat_grains,
      foots.x_hat_adjuncts,
      foots.x_hat_hops,
      foots.x_hat_miscs,
      foots.x_hat_mos,
    ) = torch.split(x_hat, self.split_sizes, dim=1)

    # --- Mash steps: [one-hot step types | times | average temps] ---
    n_mash = cfg.num_mash_steps
    mash_sizes = [n_mash * cfg.num_mash_step_types, n_mash, n_mash]
    (
      foots.dec_mash_step_type_onehot,
      foots.dec_mash_step_times,
      foots.dec_mash_step_avg_temps,
    ) = torch.split(foots.x_hat_mash_steps, mash_sizes, dim=1)

    # --- Grain slots: [type embeddings | amounts] ---
    n_grain, grain_dim = cfg.num_grain_slots, cfg.grain_type_embed_size
    foots.dec_grain_type_embed, foots.dec_grain_amts = torch.split(
      foots.x_hat_grains, [n_grain * grain_dim, n_grain], dim=1
    )
    foots.dec_grain_type_logits = self._slot_logits(
      self.grain_type_decoder, foots.dec_grain_type_embed, n_grain, grain_dim
    )  # (B, num_grain_slots, num_grain_types)

    # --- Adjunct slots: [type embeddings | amounts] ---
    n_adjunct, adjunct_dim = cfg.num_adjunct_slots, cfg.adjunct_type_embed_size
    adjunct_embed, foots.dec_adjunct_amts = torch.split(
      foots.x_hat_adjuncts, [n_adjunct * adjunct_dim, n_adjunct], dim=1
    )
    foots.dec_adjunct_type_logits = self._slot_logits(
      self.adjunct_type_decoder, adjunct_embed, n_adjunct, adjunct_dim
    )  # (B, num_adjunct_slots, num_adjunct_types)

    # --- Hop slots: [type embeddings | one-hot stages | times | concentrations] ---
    n_hop, hop_dim = cfg.num_hop_slots, cfg.hop_type_embed_size
    hop_sizes = [n_hop * hop_dim, n_hop * cfg.num_hop_stage_types, n_hop, n_hop]
    (
      hop_embed,
      foots.dec_hop_stage_type_onehot,
      foots.dec_hop_times,
      foots.dec_hop_concentrations,
    ) = torch.split(foots.x_hat_hops, hop_sizes, dim=1)
    foots.dec_hop_type_logits = self._slot_logits(
      self.hop_type_decoder, hop_embed, n_hop, hop_dim
    )  # (B, num_hop_slots, num_hop_types)

    # --- Miscellaneous slots: [type embeddings | one-hot stages | times | amounts] ---
    n_misc, misc_dim = cfg.num_misc_slots, cfg.misc_type_embed_size
    misc_sizes = [n_misc * misc_dim, n_misc * cfg.num_misc_stage_types, n_misc, n_misc]
    (
      misc_embed,
      foots.dec_misc_stage_type_onehot,
      foots.dec_misc_times,
      foots.dec_misc_amts,
    ) = torch.split(foots.x_hat_miscs, misc_sizes, dim=1)
    foots.dec_misc_type_logits = self._slot_logits(
      self.misc_type_decoder, misc_embed, n_misc, misc_dim
    )  # (B, num_misc_slots, num_misc_types)

    # --- Microorganism slots: [type embeddings | one-hot stages] ---
    n_mo, mo_dim = cfg.num_microorganism_slots, cfg.microorganism_type_embed_size
    mo_sizes = [n_mo * mo_dim, n_mo * cfg.num_mo_stage_types]
    mo_embed, foots.dec_mo_stage_type_onehot = torch.split(foots.x_hat_mos, mo_sizes, dim=1)
    foots.dec_mo_type_logits = self._slot_logits(
      self.microorganism_type_decoder, mo_embed, n_mo, mo_dim
    )  # (B, num_mo_slots, num_mo_types)

    return foots
    

class RecipeNet(nn.Module):
  """Variational autoencoder over flattened recipes.

  Pipeline: head encoder (embeddings / one-hots -> flat vector) -> MLP encoder
  (-> mean and log-variance of the latent) -> reparameterised sample -> MLP
  decoder (-> flat vector) -> foot decoder (flat vector -> per-section tensors).
  """

  def __init__(self, args) -> None:
    """Build the encoder/decoder stacks described by ``args``.

    Args:
      args: Object providing ``hidden_layers``, ``z_size``, ``activation_fn``,
        ``gain``, ``num_inputs`` and the ``beta_vae_*`` / ``max_beta_vae_capacity``
        hyperparameters (plus whatever the head encoder and foot decoder read).

    Raises:
      AssertionError: If the layer sizes are not positive or the latent is not
        smaller than the input.
    """
    super().__init__()
    
    hidden_layers = args.hidden_layers
    z_size = args.z_size
    activation_fn = args.activation_fn
    gain = args.gain
    
    assert all([num_hidden > 0 for num_hidden in hidden_layers])
    assert args.num_inputs >= 1
    assert len(hidden_layers) >= 1
    assert z_size >= 1 and z_size < args.num_inputs

    # Encoder: num_inputs -> hidden_layers... -> 2*z_size (mean and logvar halves),
    # followed by an activation and a batch norm over the 2*z_size outputs.
    self.encoder = nn.Sequential()
    self.encoder.append(layer_init_xavier(nn.Linear(args.num_inputs, hidden_layers[0]), gain))
    self.encoder.append(activation_fn())
    prev_hidden_size = hidden_layers[0]
    for hidden_size in hidden_layers[1:]:
      self.encoder.append(layer_init_xavier(nn.Linear(prev_hidden_size, hidden_size), gain))
      self.encoder.append(activation_fn())
      prev_hidden_size = hidden_size
    self.encoder.append(layer_init_xavier(nn.Linear(prev_hidden_size, z_size*2), gain))
    self.encoder.append(activation_fn())
    self.encoder.append(nn.BatchNorm1d(z_size*2))

    # Decoder: z_size -> hidden_layers reversed -> num_inputs (no activation on the output).
    self.decoder = nn.Sequential()
    self.decoder.append(layer_init_xavier(nn.Linear(z_size, hidden_layers[-1]), gain))
    self.decoder.append(activation_fn())
    prev_hidden_size = hidden_layers[-1]
    for hidden_size in reversed(hidden_layers[:-1]):
      self.decoder.append(layer_init_xavier(nn.Linear(prev_hidden_size, hidden_size), gain))
      # The activation belongs to the decoder stack. Appending it to the encoder
      # (as before) left the decoder's hidden layers purely linear and added
      # stray activations after the encoder's batch norm.
      self.decoder.append(activation_fn())
      prev_hidden_size = hidden_size
    self.decoder.append(layer_init_xavier(nn.Linear(hidden_layers[0], args.num_inputs), gain))
    
    # Pre-net Encoder (Network 'Heads')
    self.head_encoder = RecipeNetHeadEncoder(args)
    # Post-net Decoder (Network 'Foots')
    self.foot_decoder = RecipeNetFootDecoder(args)

    # Capacity-controlled KL hyperparameters (see loss_fn).
    self.gamma = args.beta_vae_gamma
    self.C_stop_iter = args.beta_vae_C_stop_iter
    # NOTE: plain tensor attribute (not a registered buffer), so it is not moved
    # by `.to(device)` and is not part of the state dict.
    self.C_max = torch.Tensor([args.max_beta_vae_capacity])
    
    self.args = args
  
  def encode(self, input: torch.Tensor):
    """Encode a batch into the parameters of the latent distribution.

    Args:
      input: Batch passed straight to the head encoder (indexed there by field name).

    Returns:
      (heads, mean, logvar): the head encoder's data object and the two halves
      of the encoder output, split along the last dimension.
    """
    # Start by breaking the given x apart into all the various heads/embeddings 
    # and concatenate them into a value that can be fed to the encoder network
    x, heads = self.head_encoder(input)
    # Encode to the latent distribution mean and log-variance
    mean, logvar = torch.chunk(self.encoder(x), 2, dim=-1) 
    return heads, mean, logvar
  
  def decode(self, z: torch.Tensor):
    """Decode latent vectors into the per-section reconstruction object.

    Args:
      z: Latent vectors fed to the decoder network.

    Returns:
      The foot decoder's data object built from the flat decoder output.
    """
    # Decode to the flattened output
    x_hat = self.decoder(z)
    # We need to perform the reverse process on the output from the decoder network:
    # Break apart the output into matching segments similar to the heads (foots!) for use in later loss calculations
    foots = self.foot_decoder(x_hat)
    return foots
    
  def forward(self, input: torch.Tensor, use_mean=False):
    """Run a full encode -> sample -> decode pass.

    Args:
      input: Batch passed to ``encode``.
      use_mean: When True the latent mean is decoded directly instead of a random sample.

    Returns:
      (heads, foots, mean, logvar)
    """
    heads, mean, logvar = self.encode(input)
    # Sample (reparameterize trick) the final latent vector (z)
    z = mean if use_mean else reparameterize(mean, logvar)
    foots = self.decode(z)

    return heads, foots, mean, logvar
  
  def loss_fn(self, input, heads, foots, mean, logvar, num_iter, kl_weight=1.0):
    """Total training loss: summed reconstruction losses plus a capacity-controlled KL term.

    Continuous fields use MSE; one-hot encoded categoricals and the decoded
    type logits use binary cross-entropy with logits. All use 'sum' reduction.

    Args:
      input: The original batch (continuous targets are read from it by key).
      heads: Data object returned by the head encoder (targets).
      foots: Data object returned by the foot decoder (predictions).
      mean: Latent means.
      logvar: Latent log-variances.
      num_iter: Current training iteration, used to anneal the capacity C.
      kl_weight: Extra multiplier on the KL term.

    Returns:
      (loss, C): the scalar loss and the current capacity tensor.
    """
    REDUCTION = 'sum'
    # TODO: Simplify all this stuff into fewer losses: 
    # Group together all BCELogit and MSE losses into singular tensors in both x and x_hat
    loss_toplvl = F.mse_loss(foots.x_hat_toplvl, heads.x_toplvl, reduction=REDUCTION)
    loss_mash_steps = F.binary_cross_entropy_with_logits(foots.dec_mash_step_type_onehot, heads.enc_mash_step_type_onehot, reduction=REDUCTION) + \
      F.mse_loss(foots.dec_mash_step_times, input['mash_step_times'], reduction=REDUCTION) + \
      F.mse_loss(foots.dec_mash_step_avg_temps, input['mash_step_avg_temps'], reduction=REDUCTION)
    loss_ferment_stages = F.mse_loss(foots.x_hat_ferment_stages, heads.x_ferment_stages, reduction=REDUCTION)
    loss_grains = F.binary_cross_entropy_with_logits(foots.dec_grain_type_logits, heads.enc_grain_type_onehot, reduction=REDUCTION) + \
      F.mse_loss(foots.dec_grain_amts, input['grain_amts'], reduction=REDUCTION)
    loss_adjuncts = F.binary_cross_entropy_with_logits(foots.dec_adjunct_type_logits, heads.enc_adjunct_type_onehot, reduction=REDUCTION) + \
      F.mse_loss(foots.dec_adjunct_amts, input['adjunct_amts'], reduction=REDUCTION)
    loss_hops = F.binary_cross_entropy_with_logits(foots.dec_hop_type_logits, heads.enc_hop_type_onehot, reduction=REDUCTION) + \
      F.binary_cross_entropy_with_logits(foots.dec_hop_stage_type_onehot, heads.enc_hop_stage_type_onehot, reduction=REDUCTION) + \
      F.mse_loss(foots.dec_hop_times, input['hop_times'], reduction=REDUCTION) + \
      F.mse_loss(foots.dec_hop_concentrations, input['hop_concentrations'], reduction=REDUCTION)
    loss_miscs = F.binary_cross_entropy_with_logits(foots.dec_misc_type_logits, heads.enc_misc_type_onehot, reduction=REDUCTION) + \
      F.binary_cross_entropy_with_logits(foots.dec_misc_stage_type_onehot, heads.enc_misc_stage_type_onehot, reduction=REDUCTION) + \
      F.mse_loss(foots.dec_misc_times, input['misc_times'], reduction=REDUCTION) + \
      F.mse_loss(foots.dec_misc_amts, input['misc_amts'], reduction=REDUCTION)
    loss_mos = F.binary_cross_entropy_with_logits(foots.dec_mo_type_logits, heads.enc_mo_type_onehot, reduction=REDUCTION) + \
      F.binary_cross_entropy_with_logits(foots.dec_mo_stage_type_onehot, heads.enc_mo_stage_type_onehot, reduction=REDUCTION)

    # Add up all our losses for reconstruction of the recipe
    reconst_loss = loss_toplvl + loss_mash_steps + loss_ferment_stages + loss_grains + loss_adjuncts + loss_hops + loss_miscs + loss_mos

    # Beta-VAE KL calculation is based on https://arxiv.org/pdf/1804.03599.pdf
    # KL is summed over latent dimensions and averaged over the batch.
    kl_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mean ** 2 - logvar.exp(), dim=1), dim=0)
    # Capacity C grows linearly with num_iter and is capped at C_max once C_stop_iter is reached.
    C = torch.clamp(self.C_max/self.C_stop_iter * num_iter, 0, self.C_max.data[0])
    loss = reconst_loss + kl_weight * self.gamma * (kl_loss - C).abs()
    return loss, C
    
  
# Instantiate the VAE from the hyperparameter bundle and optimise all of its
# parameters with Adam.
recipe_net = RecipeNet(args)
optimizer  = torch.optim.Adam(recipe_net.parameters(), lr=1e-3, betas=(0.9, 0.999))

In [10]:
# Drop the learning rate by 10x when the monitored loss has not improved for
# `patience` scheduler steps (the training loop steps the scheduler once per batch).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=3000, eps=1e-5)
# Re-running this cell resets the learning rate and the step counter.
optimizer.param_groups[0]['lr'] = 1e-3 # Learning Rate
# Global batch counter; read by the monitoring hooks and the loss's capacity schedule.
global_step = 1

In [12]:
import time
from torch.utils.tensorboard import SummaryWriter

# Each run logs to its own timestamped directory under ./runs.
run_dir = os.path.join("runs", f"recipe_vae_{int(time.time())}")
os.makedirs(run_dir, exist_ok=True)
writer = SummaryWriter(run_dir)
# Record every instance attribute of `args` as a markdown table in TensorBoard.
writer.add_text(
  "hyperparameters",
  "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()]))
)

# Monitor the recipe network using hooks and tensorboard
MONITOR_UPDATE_STEPS = 50

def _histogram_hook(tag, use_weights=False):
  """Build a forward hook that logs a histogram under `tag` every MONITOR_UPDATE_STEPS steps.

  The hook logs the hooked layer's weights when `use_weights` is True, otherwise
  the layer's output. `global_step` and `writer` are looked up when the hook
  fires, so the hook always sees their current notebook values.
  """
  def _hook(layer, input, output):
    if global_step % MONITOR_UPDATE_STEPS != 0:
      return None
    values = layer.weight if use_weights else output
    writer.add_histogram(tag, values.flatten(), global_step, bins='auto')
    return None  # Returning None leaves the layer's output untouched
  return _hook

# Only the encoder is monitored; the decoder, head encoder and foot decoder have no hooks.
# Encoder layout: [Linear, act] * num_hidden_layers, Linear, act, BatchNorm1d
encoder_children = list(recipe_net.encoder.named_children())
num_hidden_layers = len(args.hidden_layers)

# Distributions of outputs after the first layer+activation
first_actfn = encoder_children[1][1]
first_actfn.register_forward_hook(_histogram_hook("dists/outputs/encoder_first_actfn"))

# Distribution of outputs after the last layer+activation (before batchnorm)
last_actfn = encoder_children[1 + num_hidden_layers*2][1]
last_actfn.register_forward_hook(_histogram_hook("dists/outputs/encoder_last_actfn"))

# Distribution of outputs after the encoder (last layer is a batchnorm1D)
batchnorm = encoder_children[2 + num_hidden_layers*2][1]
batchnorm.register_forward_hook(_histogram_hook("dists/outputs/encoder_batchnorm"))

# Distributions of weights of the first and last linear layers
first_layer = encoder_children[0][1]
first_layer.register_forward_hook(_histogram_hook("dists/weights/encoder_first_layer", use_weights=True))
last_layer = encoder_children[num_hidden_layers*2][1]
last_layer.register_forward_hook(_histogram_hook("dists/weights/encoder_last_layer", use_weights=True))

KL_WEIGHT  = 1.0
NUM_EPOCHS = 10
EMBEDDING_LOG_STEPS = 100  # How often (in global steps) the embedding tables are sent to tensorboard
outlier_ids = []

# (tensorboard tag, embedding module, row labels) for every learned embedding table.
# All tags share the "embeddings/" prefix so they group together.
embedding_monitors = (
  ("embeddings/grain_type", recipe_net.head_encoder.grain_type_embedding, grain_type_embedding_labels),
  ("embeddings/adjunct_type", recipe_net.head_encoder.adjunct_type_embedding, adjunct_type_embedding_labels),
  ("embeddings/hop_type", recipe_net.head_encoder.hop_type_embedding, hop_type_embedding_labels),
  ("embeddings/misc_type", recipe_net.head_encoder.misc_type_embedding, misc_type_embedding_labels),
  ("embeddings/microorganism_type", recipe_net.head_encoder.microorganism_type_embedding, microorganism_type_embedding_labels),
)

try:
  for i in range(NUM_EPOCHS):
    epoch_loss = 0.0
    num_batches = 0
    for batch_idx, batch in enumerate(dataloader):
      heads, foots, mean, logvar = recipe_net(batch)
      loss, C = recipe_net.loss_fn(batch, heads, foots, mean, logvar, global_step, KL_WEIGHT)

      loss_value = loss.item()
      epoch_loss += loss_value
      num_batches += 1
      writer.add_scalar("charts/total_loss", loss_value, global_step)
      if global_step % EMBEDDING_LOG_STEPS == 0:
        # Send the head encoder's embeddings to tensorboard. Passing global_step
        # stores each snapshot separately instead of re-targeting the same
        # step directory on every call.
        for tag, embedding, labels in embedding_monitors:
          writer.add_embedding(embedding.weight, labels, global_step=global_step, tag=tag)

      # Debugging aid for outlier batches (disabled):
      #if loss.item() > 3.5e4:
      #  print(batch['dbid'])

      optimizer.zero_grad()
      loss.backward()
      nn.utils.clip_grad_norm_(recipe_net.parameters(), 100.0)
      optimizer.step()
      # The plateau scheduler is stepped per batch on the batch loss
      scheduler.step(loss_value)
      global_step += 1

      print('\r', "Global Step:", global_step, "Loss:", np.around(loss_value, 5), "lr:", optimizer.param_groups[0]['lr'], "C:", np.around(C.item(), 1), "\t\t", end='')

    # Guard against an empty dataloader (no batches => average of 0.0 rather than a NameError)
    avg_epoch_loss = epoch_loss / max(num_batches, 1)
    print("\r\n", f"Avg Epoch #{i+1} loss: {np.around(avg_epoch_loss,5)}\t\t\t\t")
finally:
  # Make sure everything logged so far reaches disk, even when training is
  # interrupted from the keyboard.
  writer.flush()
  

 Global Step: 1119 Loss: 12388.30859 lr: 0.001 C: 0.3 		
 Avg Epoch #1 loss: 152734.90841				
 Global Step: 2237 Loss: 11778.50391 lr: 0.001 C: 0.6 		
 Avg Epoch #2 loss: 25335.70338				
 Global Step: 3355 Loss: 9916.01465 lr: 0.001 C: 0.8 			
 Avg Epoch #3 loss: 23874.46143				
 Global Step: 3706 Loss: 20651.43555 lr: 0.001 C: 0.9 		

KeyboardInterrupt: 