In [21]:
import pickle
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np

from recipedataset import RecipeDataset, RECIPE_DATASET_FILENAME, NUM_GRAIN_SLOTS, NUM_ADJUNCT_SLOTS, NUM_HOP_SLOTS, NUM_MISC_SLOTS, NUM_MICROORGANISM_SLOTS, NUM_FERMENT_STAGE_SLOTS, NUM_MASH_STEPS

# Seed every RNG source the notebook uses (Python, NumPy, PyTorch) so runs are reproducible.
RANDOM_SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
  _seed_fn(RANDOM_SEED)
del _seed_fn
# Ask cuDNN to pick deterministic algorithms (may be slower).
torch.backends.cudnn.deterministic = True

def layer_init_ortho(layer, std=np.sqrt(2)):
  """Orthogonally initialise ``layer.weight`` (scaled by ``std``) and zero ``layer.bias``.

  Args:
    layer: A module exposing ``weight`` and ``bias`` tensors (e.g. ``nn.Linear``).
    std: Gain applied to the orthogonal matrix.

  Returns:
    The same ``layer`` object, so the call can be nested inside a constructor.
  """
  nn.init.orthogonal_(layer.weight, gain=std)
  nn.init.zeros_(layer.bias)
  return layer

def layer_init_xavier(layer, gain):
  """Xavier-normal initialise ``layer.weight`` with ``gain`` and zero ``layer.bias``.

  Args:
    layer: A module exposing ``weight`` and ``bias`` tensors (e.g. ``nn.Linear``).
    gain: Scaling factor passed to ``nn.init.xavier_normal_``.

  Returns:
    The same ``layer`` object, so the call can be nested inside a constructor.
  """
  nn.init.xavier_normal_(layer.weight, gain=gain)
  nn.init.zeros_(layer.bias)
  return layer

def reparameterize(mu, logvar):
  """Sample from N(mu, exp(logvar)) with the reparameterisation trick.

  The sample is formed as ``eps * sigma + mu`` with ``eps ~ N(0, I)`` drawn
  fresh, so gradients can flow back into ``mu`` and ``logvar``.

  Args:
    mu: Mean tensor.
    logvar: Log-variance tensor, same shape as ``mu``.

  Returns:
    A tensor shaped like ``mu`` containing one sample per element.
  """
  sigma = (0.5 * logvar).exp()
  noise = torch.randn_like(sigma)
  return noise * sigma + mu

In [4]:
BATCH_SIZE = 256

# Load the pickled dataset from the parent directory and wrap it in a shuffling DataLoader.
# NOTE: pickle.load can execute arbitrary code -- only point this at trusted files.
dataset_path = "../" + RECIPE_DATASET_FILENAME
with open(dataset_path, 'rb') as dataset_file:
  dataset = pickle.load(dataset_file)
dataloader = DataLoader(
  dataset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  num_workers=0,
)

In [22]:
# Peek at the field names of the second batch yielded by the DataLoader, then stop iterating.
for batch_idx, batch in enumerate(dataloader):
  if batch_idx != 1:
    continue
  print(batch.keys())
  break

In [16]:
# Widths of the learned embeddings for each ingredient type (consumed by RecipeNetArgs).
GRAIN_TYPE_EMBED_SIZE = 32
ADJUNCT_TYPE_EMBED_SIZE = 32
HOP_TYPE_EMBED_SIZE = 96
MISC_TYPE_EMBED_SIZE = 64
MICROORGANISM_TYPE_EMBED_SIZE = 64

class RecipeNetArgs:
  """Bundle of hyperparameters and vocabulary sizes for the recipe networks.

  Args:
    dataset: The loaded recipe dataset. Only the lengths of its
      ``*_idx_to_dbid`` / ``*_idx_to_name`` lookup tables are read.
  """

  def __init__(self, dataset) -> None:
    # --- Recipe layout: number of slots in each recipe section ---
    self.num_mash_steps = NUM_MASH_STEPS
    self.num_grain_slots = NUM_GRAIN_SLOTS
    self.num_adjunct_slots = NUM_ADJUNCT_SLOTS
    self.num_hop_slots = NUM_HOP_SLOTS
    self.num_misc_slots = NUM_MISC_SLOTS
    self.num_microorganism_slots = NUM_MICROORGANISM_SLOTS
    self.num_ferment_stage_slots = NUM_FERMENT_STAGE_SLOTS

    # --- Ingredient vocabularies (rows in the DB); each includes a "None"/empty category ---
    self.num_grain_types = len(dataset.core_grains_idx_to_dbid)  # core grain types
    self.num_adjunct_types = len(dataset.core_adjs_idx_to_dbid)  # core adjunct types
    self.num_hop_types = len(dataset.hops_idx_to_dbid)
    self.num_misc_types = len(dataset.miscs_idx_to_dbid)
    self.num_microorganism_types = len(dataset.mos_idx_to_dbid)

    # --- Step/stage vocabularies ---
    self.num_mash_step_types = len(dataset.mash_step_idx_to_name)    # e.g. Infusion, Decoction, Temperature
    self.num_hop_stage_types = len(dataset.hop_stage_idx_to_name)    # e.g. Mash, Boil, Primary, ...
    self.num_misc_stage_types = len(dataset.misc_stage_idx_to_name)  # e.g. Mash, Boil, Primary, ...
    self.num_mo_stage_types = len(dataset.mo_stage_idx_to_name)      # e.g. Primary, Secondary

    # --- Learned embedding widths for the ingredient types ---
    self.grain_type_embed_size = GRAIN_TYPE_EMBED_SIZE
    self.adjunct_type_embed_size = ADJUNCT_TYPE_EMBED_SIZE
    self.hop_type_embed_size = HOP_TYPE_EMBED_SIZE
    self.misc_type_embed_size = MISC_TYPE_EMBED_SIZE
    self.microorganism_type_embed_size = MICROORGANISM_TYPE_EMBED_SIZE

    # --- Network shape ---
    self.num_hidden_layers = 1
    self.hidden_size = 1024
    self.z_size = 64  # latent bottleneck width
    self.activation_fn = nn.ELU
    # Init gain for the Linear layers -- keep this in sync with activation_fn.
    self.gain = nn.init.calculate_gain('linear', None)
    # Must be computed last: it reads the sizes assigned above.
    self.num_inputs = self.calc_num_inputs()

  def calc_num_inputs(self):
    """Compute the width of the flattened network input.

    Each recipe section contributes (features per slot) * (number of slots);
    categorical step/stage types count as one-hot widths and ingredient types
    as embedding widths.

    Returns:
        int: The total number of network inputs.
    """
    section_widths = [
      # boil_time, mash_ph, sparge_temp
      3,
      # mash steps (ordered): step-type one-hot + time + temperature
      self.num_mash_steps * (self.num_mash_step_types + 2),
      # ferment stages (ordered: primary, secondary): time + temperature
      self.num_ferment_stage_slots * 2,
      # grain/malt bill (unordered): type embedding + amount
      self.num_grain_slots * (self.grain_type_embed_size + 1),
      # adjuncts (unordered): type embedding + amount
      self.num_adjunct_slots * (self.adjunct_type_embed_size + 1),
      # hops (unordered): type embedding + stage one-hot + time + concentration
      self.num_hop_slots * (self.hop_type_embed_size + self.num_hop_stage_types + 2),
      # misc. (unordered): type embedding + stage one-hot + time + amount
      self.num_misc_slots * (self.misc_type_embed_size + self.num_misc_stage_types + 2),
      # microorganisms (unordered): type embedding + stage one-hot
      self.num_microorganism_slots * (self.microorganism_type_embed_size + self.num_mo_stage_types),
    ]
    return sum(section_widths)

# Build the hyperparameter bundle from the loaded dataset and display the flattened input width.
args = RecipeNetArgs(dataset=dataset)
args.num_inputs

5899

In [None]:
class RecipeVAE(nn.Module):
  """Variational autoencoder over batches of recipe tensors.

  ``forward`` takes the dict of per-recipe tensors produced by the DataLoader
  and flattens it into a single (B, args.num_inputs) vector. Continuous fields
  are used as-is, ingredient types go through learned embeddings, and
  step/stage type indices are one-hot encoded. The vector is then encoded
  into the mean and log-variance of a latent of size ``args.z_size``.

  Args:
    args: A ``RecipeNetArgs``-like object providing layer sizes, the
      activation class, the init gain, and every vocabulary/slot size.
  """

  def __init__(self, args) -> None:
    super().__init__()

    hidden_size = args.hidden_size
    z_size = args.z_size
    activation_fn = args.activation_fn
    gain = args.gain

    assert args.num_inputs >= 1
    assert args.num_hidden_layers >= 1
    assert hidden_size >= 1
    assert z_size >= 1 and z_size < args.num_inputs

    # Encoder: num_inputs -> hidden layers -> 2*z_size, i.e. the concatenated
    # (mean, logvar) that forward() splits with torch.chunk.
    # NOTE(review): the encoder output passes through an activation and
    # BatchNorm1d before being split into (mean, logvar). This is kept from the
    # original; confirm it is the intended latent parameterisation.
    self.encoder = self._build_mlp(args.num_inputs, hidden_size, z_size*2, args.num_hidden_layers, activation_fn, gain)
    self.encoder.append(activation_fn())
    self.encoder.append(nn.BatchNorm1d(z_size*2))

    # Decoder mirrors the encoder: z_size -> hidden layers -> num_inputs.
    # Building both stacks with the same helper keeps each stack's layers on
    # the correct module.
    self.decoder = self._build_mlp(z_size, hidden_size, args.num_inputs, args.num_hidden_layers, activation_fn, gain)

    # Embeddings (NOTE: Any categoricals that don't have embeddings will be one-hot encoded)
    self.grain_type_embedding         = nn.Embedding(args.num_grain_types, args.grain_type_embed_size)
    self.adjunct_type_embedding       = nn.Embedding(args.num_adjunct_types, args.adjunct_type_embed_size)
    self.hop_type_embedding           = nn.Embedding(args.num_hop_types, args.hop_type_embed_size)
    self.misc_type_embedding          = nn.Embedding(args.num_misc_types, args.misc_type_embed_size)
    self.microorganism_type_embedding = nn.Embedding(args.num_microorganism_types, args.microorganism_type_embed_size)

    # Post-network decoders (these are basically learned inverse embeddings)
    self.grain_type_decoder         = layer_init_xavier(nn.Linear(args.grain_type_embed_size, args.num_grain_types), gain)
    self.adjunct_type_decoder       = layer_init_xavier(nn.Linear(args.adjunct_type_embed_size, args.num_adjunct_types), gain)
    self.hop_type_decoder           = layer_init_xavier(nn.Linear(args.hop_type_embed_size, args.num_hop_types), gain)
    self.misc_type_decoder          = layer_init_xavier(nn.Linear(args.misc_type_embed_size, args.num_misc_types), gain)
    self.microorganism_type_decoder = layer_init_xavier(nn.Linear(args.microorganism_type_embed_size, args.num_microorganism_types), gain)

    # One-hot widths used in forward()
    self.num_mash_step_types  = args.num_mash_step_types
    self.num_hop_stage_types  = args.num_hop_stage_types
    self.num_misc_stage_types = args.num_misc_stage_types
    self.num_mo_stage_types   = args.num_mo_stage_types

  @staticmethod
  def _build_mlp(in_size, hidden_size, out_size, num_hidden_layers, activation_fn, gain):
    """Build (Linear -> activation) x num_hidden_layers followed by a final Linear.

    Every Linear is Xavier-initialised with ``gain``. The final Linear has no
    activation; callers append whatever head they need.

    Returns:
      nn.Sequential: The assembled stack.
    """
    mlp = nn.Sequential()
    mlp.append(layer_init_xavier(nn.Linear(in_size, hidden_size), gain))
    mlp.append(activation_fn())
    for _ in range(1, num_hidden_layers):
      mlp.append(layer_init_xavier(nn.Linear(hidden_size, hidden_size), gain))
      mlp.append(activation_fn())
    mlp.append(layer_init_xavier(nn.Linear(hidden_size, out_size), gain))
    return mlp

  def _flatten_recipe(self, x):
    """Flatten a batch dict of recipe tensors into one (B, num_inputs) tensor.

    Index tensors are cast to int64 before use because ``F.one_hot`` requires
    int64 input (``nn.Embedding`` accepts it as well). The resulting one-hot
    vectors are cast to float so they concatenate cleanly with the
    floating-point features and embeddings.

    Args:
      x (dict[str, torch.Tensor]): Batch from the recipe DataLoader.

    Returns:
      torch.Tensor: Concatenation of every recipe section along dim 1.
    """
    # Simple top-level heads (high-level recipe parameters): (B, 3)
    x_simple_toplvl = torch.cat((x['boil_time'].unsqueeze(1), x['mash_ph'].unsqueeze(1), x['sparge_temp'].unsqueeze(1)), dim=1)

    # Mash step heads
    # NOTE: Data shape is (B, S=number_of_mash_step_slots) for the
    # following recipe tensors: {'mash_step_type_inds', 'mash_step_times', 'mash_step_avg_temps'}
    mash_step_type_onehot = F.one_hot(x['mash_step_type_inds'].long(), self.num_mash_step_types).float()                # (B, S, num_mash_step_types)
    x_mash_steps = torch.cat((mash_step_type_onehot.flatten(1), x['mash_step_times'], x['mash_step_avg_temps']), dim=1) # (B, num_mash_step_types*S+S+S)

    # Ferment stage heads
    # NOTE: Data shape is (B, S=2) for the following recipe tensors: {'ferment_stage_times', 'ferment_stage_temps'}
    x_ferment_stages = torch.cat((x['ferment_stage_times'], x['ferment_stage_temps']), dim=1) # (B, 4=(S+S))

    # Grain (malt bill) heads
    # NOTE: Data shape is (B, S=num_grain_slots) for the following recipe tensors: {'grain_core_type_inds', 'grain_amts'}
    grain_type_embed = self.grain_type_embedding(x['grain_core_type_inds'].long()) # (B, S, grain_type_embed_size)
    x_grains = torch.cat((grain_type_embed.flatten(1), x['grain_amts']), dim=1)    # (B, S*grain_type_embed_size+S)

    # Adjunct heads
    # NOTE: Data shape is (B, S=num_adjunct_slots) for the following recipe tensors: {'adjunct_core_type_inds', 'adjunct_amts'}
    adjunct_type_embed = self.adjunct_type_embedding(x['adjunct_core_type_inds'].long()) # (B, S, adjunct_type_embed_size)
    x_adjuncts = torch.cat((adjunct_type_embed.flatten(1), x['adjunct_amts']), dim=1)

    # Hop heads
    # NOTE: Data shape is (B, S=num_hop_slots) for the following recipe tensors:
    # {'hop_type_inds', 'hop_stage_type_inds', 'hop_times', 'hop_concentrations'}
    hop_type_embed = self.hop_type_embedding(x['hop_type_inds'].long())                                 # (B, S, hop_type_embed_size)
    hop_stage_type_onehot = F.one_hot(x['hop_stage_type_inds'].long(), self.num_hop_stage_types).float() # (B, S, num_hop_stage_types)
    x_hops = torch.cat((hop_type_embed.flatten(1), hop_stage_type_onehot.flatten(1), x['hop_times'], x['hop_concentrations']), dim=1)

    # Misc. heads
    # NOTE: Data shape is (B, S=num_misc_slots) for the following recipe tensors:
    # {'misc_type_inds', 'misc_stage_inds', 'misc_times', 'misc_amts'}
    misc_type_embed = self.misc_type_embedding(x['misc_type_inds'].long())                                # (B, S, misc_type_embed_size)
    misc_stage_type_onehot = F.one_hot(x['misc_stage_inds'].long(), self.num_misc_stage_types).float()    # (B, S, num_misc_stage_types)
    x_miscs = torch.cat((misc_type_embed.flatten(1), misc_stage_type_onehot.flatten(1), x['misc_times'], x['misc_amts']), dim=1)

    # Microorganism heads
    # NOTE: Data shape is (B, S=num_microorganism_slots) for the following recipe tensors:
    # {'mo_type_inds', 'mo_stage_inds'}
    mo_type_embed = self.microorganism_type_embedding(x['mo_type_inds'].long())               # (B, S, microorganism_type_embed_size)
    mo_stage_onehot = F.one_hot(x['mo_stage_inds'].long(), self.num_mo_stage_types).float()   # (B, S, num_mo_stage_types)
    x_microorganisms = torch.cat((mo_type_embed.flatten(1), mo_stage_onehot.flatten(1)), dim=1)

    # Put all the recipe data together into a flattened tensor
    return torch.cat((x_simple_toplvl, x_mash_steps, x_ferment_stages, x_grains, x_adjuncts, x_hops, x_miscs, x_microorganisms), dim=1) # (B, num_inputs)

  def forward(self, x):
    """Encode a batch of recipes (decoding is still unimplemented).

    Args:
      x (dict[str, torch.Tensor]): Batch from the recipe DataLoader; the keys
        read are those used in ``_flatten_recipe``.

    Returns:
      None for now -- the decode step below is still a TODO.
    """
    x = self._flatten_recipe(x) # (B, num_inputs)

    # Encode into the latent mean and log-variance
    mean, logvar = torch.chunk(self.encoder(x), 2, dim=-1)

    # Decode
    # TODO: sample z = reparameterize(mean, logvar), run self.decoder(z), and split
    # its output back into per-field heads (presumably via the *_type_decoder layers).
  
