<a href="https://colab.research.google.com/github/jemelike/recipe-personalization/blob/setup_improvements/recipe_gen/notebooks/example_train.py" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import json
import torch
import numpy as np
import pickle
import torch.utils.data as data
import torch.nn as nn

from functools import partial
from tqdm import tqdm
from itertools import chain
from datetime import datetime
import torch
import torch.nn.init as init

In [None]:

def run_epoch(device, model, sampler, loss_compute, print_every, max_len,
              clip=None, teacher_forcing=False, max_name_len=15, **tensor_kwargs):
    """
    Run a single epoch and return its perplexity.

    Arguments:
        device {torch.device} -- Torch device on which to store/process data
        model {nn.Module} -- Model to be trained/run
        sampler {BatchSampler} -- Data sampler; must expose `epoch_batches()`,
            `n_batches` and `renew_indices()`
        loss_compute {funct} -- Function to compute loss for each batch. It is called
            with the step log-probs and the step targets, plus the keywords
            `name_outputs`, `name_targets`, `norm`, `model` and `clip`, and must
            return a `(loss, name_loss)` pair
        print_every {int} -- Log loss every k iterations
        max_len {int} -- Maximum length / number of steps to unroll and predict

    Keyword Arguments:
        clip {float} -- Gradient clipping maximum, forwarded to `loss_compute`
            (default: {None})
        teacher_forcing {bool} -- Whether to do teacher-forcing in training (default: {False})
        max_name_len {int} -- Maximum # timesteps to unroll to predict name (default: {15})
        **tensor_kwargs {torch.Tensor} -- Named base tensors. The values are passed
            positionally (in keyword order) to `get_batch_information_general` and
            its outputs are re-keyed with the same names. The keys read here are
            'steps_tensor', 'name_tensor', 'calorie_level_tensor', 'ingr_tensor'
            and 'ingr_mask_tensor'.

    Returns:
        float -- Epoch perplexity: exp(total loss / number of non-pad step tokens)
    """
    start = datetime.now()
    total_tokens = 0
    total_name_tokens = 0
    total_loss = 0.0
    total_name_loss = 0.0
    print_tokens = 0

    # Split the keyword tensors into parallel tuples of names and tensors
    tensor_names, base_tensors = zip(*tensor_kwargs.items())

    # Iterate through batches in the epoch (1-based so `i % print_every` logs on the k-th batch)
    for i, batch in enumerate(tqdm(sampler.epoch_batches(), total=sampler.n_batches), 1):
        # Each batch is a (users, items) pair; only the items are used below
        _batch_users, items = [t.to(device) for t in batch]

        # Fill out batch information, keyed by the caller-supplied tensor names
        batch_map = dict(zip(
            tensor_names,
            get_batch_information_general(items, *base_tensors)
        ))

        # Logistics
        steps_tensor = batch_map['steps_tensor']
        name_targets = batch_map['name_tensor']
        this_batch_size = steps_tensor.size(0)
        this_batch_num_tokens = (steps_tensor != PAD_INDEX).data.sum().item()
        this_batch_num_name_tokens = (name_targets != PAD_INDEX).data.sum().item()

        # Batch first
        # Comparing out(token[t-1]) to token[t]: inputs drop the last token,
        # targets (below) drop the first.
        (log_probs, _), (name_log_probs, _) = model.forward(
            device=device, inputs=(
                batch_map['calorie_level_tensor'],
                batch_map['name_tensor'],
                batch_map['ingr_tensor']
            ),
            ingr_masks=batch_map['ingr_mask_tensor'],
            targets=steps_tensor[:, :-1],
            max_len=max_len-1,
            start_token=START_INDEX,
            teacher_forcing=teacher_forcing,
            name_targets=name_targets[:, :-1],
            max_name_len=max_name_len-1,
            visualize=False
        )
        loss, name_loss = loss_compute(
            log_probs, steps_tensor[:, 1:],
            name_outputs=name_log_probs,
            name_targets=name_targets[:, 1:],
            norm=this_batch_size,
            model=model,
            clip=clip
        )

        total_loss += loss
        total_name_loss += name_loss

        # Logging
        total_tokens += this_batch_num_tokens
        total_name_tokens += this_batch_num_name_tokens
        print_tokens += this_batch_num_tokens

        if model.training and i % print_every == 0:
            # `timedelta.seconds` is a truncated integer component: it is 0 for
            # sub-second windows (division by zero) and drops the fractional part.
            # Use the full duration and guard the degenerate zero case.
            elapsed_seconds = (datetime.now() - start).total_seconds()
            if elapsed_seconds > 0:
                tokens_per_second = print_tokens / elapsed_seconds
            else:
                tokens_per_second = float('inf')
            print("Epoch Step: {} LM Loss: {:.5f}; {}; Tokens/s: {:.3f}".format(
                i,
                loss / this_batch_size,
                'Name Loss: {:.5f}'.format(name_loss / this_batch_size) if name_loss else '',
                tokens_per_second
            ))
            # Restart the throughput window
            start = datetime.now()
            print_tokens = 0

        # Drop references to the batch outputs before the next iteration
        del log_probs, name_log_probs

    # Reshuffle the sampler
    sampler.renew_indices()

    if total_name_tokens > 0:
        print('\nName Perplexity: {}'.format(np.exp(total_name_loss / float(total_name_tokens))))

    return np.exp(total_loss / float(total_tokens))


In [None]:

# Wall-clock start of the script
start = datetime.now()
USE_CUDA, DEVICE = get_device()

# Filters
MAX_NAME = 15
MAX_INGR = 5
MAX_INGR_TOK = 20
MAX_STEP_TOK = 256

# Reproducibility: seed NumPy and both the CPU and CUDA torch generators
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Args (hard-coded here in place of command-line arguments)
data_dir = ""
batch_size = 25
vocab_emb_dim = 300
calorie_emb_dim = 5
# NOTE(review): `args` is not defined in the visible cells — confirm it exists
# before this cell runs, or replace with a literal embedding size.
ingr_emb_dim = args.ingr_emb_size
hidden_size = 256
n_layers = 2
dropout = .2
num_epochs = 50
# Learning rate must be numeric; it was previously the string "1e-3"
lr = 1e-3
print_every = 500
exp_name = "baseline"
# NOTE(review): "modle" looks like a misspelling of "model" — left unchanged
# because it is a runtime path; confirm the intended folder name.
save_folder = "modle/baseline"
lr_annealing_rate = 0.9
clip = None
ingr_gru = False
ingr_emb = True
decode_name = False
shared_proj = False
n_teacher_forcing = None
checkpoint_loc = None