# Projet de Cloud Computing

Génération des charges de travail pour les FaaS dans les grandes plateformes cloud : Cas de Microsoft Azure

Membres du groupe :
- ATANGANA Julien Patrick
- MARIA-MBOMO MBOA Marilyn
- MELIE DOUMTSOP Godsend
- NGUEN Kevina Anne
- TALLA CHENDJOU James

Sous la supervision de Pr Alain TCHANA

---

# Import libraries

In [None]:
import pandas as pd
import numpy as np

In [None]:
import logging
from abc import ABC, abstractmethod
from collections import namedtuple
from math import ceil

import statsmodels.api as sm
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

In [None]:
import torch
import torch.nn as nn
from typing import NamedTuple
from enum import Enum

# Download data

In [None]:
!wget https://azurecloudpublicdataset2.blob.core.windows.net/azurepublicdatasetv2/azurefunctions_dataset2019/azurefunctions-dataset2019.tar.xz

In [None]:
!tar xvf azurefunctions-dataset2019.tar.xz

# Path To Data

In [None]:
# Folder prefix prepended to every raw per-day invocation CSV name read below.
path_to_dataset_folder = "/content/"
# Prefix for the files this notebook writes and re-reads (data.csv and the
# saved models); the empty string means the current working directory.
data_path =""

# Arrival model

# Arranging dataset

In [None]:
# Per-(day, minute) accumulators; one entry is appended per minute processed.
functions = []            # per-minute string of invoked function hashes
apps = []                 # per-minute comma-joined list of active app hashes
owners = []               # declared but not filled in this cell
number_invocations = []   # declared but not filled in this cell
X_data = []               # per-minute feature rows (minute + day indicators)
y_data = []               # per-minute number of distinct active apps
hash_functions = []       # every function hash seen (deduplicated later)

# Only days 1 and 2 of the trace are processed here.
for day in range(1, 3):
  # File names carry a zero-padded two-digit day (d01 ... d09, d10, ...).
  inv_df_d = pd.read_csv(path_to_dataset_folder + 'invocations_per_function_md.anon.d0'+ str(day) +'.csv') if day < 10 else pd.read_csv(path_to_dataset_folder + 'invocations_per_function_md.anon.d'+ str(day) +'.csv')

  hash_functions.extend(inv_df_d['HashFunction'].unique().tolist())

  # Columns '1'..'1440' hold the invocation count for each minute of the day.
  for minute in range(1, 1441):
    # Rows (owner, app, function) with at least one invocation this minute.
    hashes = inv_df_d.loc[inv_df_d[str(minute)] > 0, ['HashOwner', 'HashApp', 'HashFunction']]
    # 1452 features: 1440 minute slots followed by 12 day slots.
    x_data = [0]*1452
    x_data[minute-1] = 1
    # Day slots are cumulative: day d sets the first d day slots to 1.
    x_data[1440:1440+day] = [1]*day
    X_data.append(x_data)

    app_func = hashes.groupby('HashApp')
    hash_apps = hashes['HashApp'].unique().tolist()
    hash_apps_size = len(hash_apps)
    # Functions of one app are comma-joined; app groups are separated by
    # ',|,' and the whole line ends with ',|'.
    # NOTE(review): unique() presumably collapses two apps whose joined
    # function strings are identical - confirm this is intended.
    functions.append(',|,'.join(map(str, app_func['HashFunction'].transform(lambda x: ','.join(map(str, x))).unique().tolist())) + ',|')

    # Debug output: prints the whole accumulated list at minute 1 of each day.
    if minute == 1:
      print(functions)

    apps.append(','.join(map(str, hash_apps)))
    y_data.append(hash_apps_size)




In [None]:
# One row per processed minute; feature columns are named '1'..'1452'.
data = pd.DataFrame(X_data, columns=[str(i) for i in range(1, 1453)])
# Attach the per-minute function string, app list and app count.
data['HashFunction'] = functions
data['HashApps'] = apps
data['number_of_apps'] = y_data
# Persist the table so later cells can reload it without rebuilding it.
data.to_csv(data_path + 'data.csv', index=False)

In [None]:
# Show the (rows, columns) shape of the assembled table as the cell output.
data.shape

In [None]:
# Deduplicate the function hashes collected across days.  The result is sorted
# because the position of each hash in this list becomes its one-hot index:
# plain set iteration order for strings can differ between Python processes,
# which would make a saved model unusable in a later session.
hash_functions = sorted(set(hash_functions))

In [None]:
# Show how many distinct function hashes were collected.
len(hash_functions)

# Load Data

In [None]:
# Reload the per-minute table written by the "Arranging dataset" section.
data = pd.read_csv(data_path + 'data.csv')

In [None]:
# Display the last rows of the table.
data.tail()

In [None]:
# Display the table as the cell output.
data

In [None]:
# Display the distinct per-minute function strings.
data['HashFunction'].unique()

In [None]:
# Display the distinct per-minute app lists.
data['HashApps'].unique()

In [None]:
# Number of day-indicator columns used for the arrival model features
# (1440 minute columns + 12 day columns = 1452).  The "Utilities" section
# later rebinds DAYS to 2 for the LSTM features.
DAYS = 12

In [None]:
# Features: columns '1'..str(1440 + DAYS); target: number of active apps.
X = data.loc[:, [str(i) for i in range(1, 1441+DAYS)]]    # for 12 days - 1452
y = data['number_of_apps']

In [None]:
# Convert to plain NumPy arrays for statsmodels.
X = X.to_numpy()
y = y.to_numpy()

# Indicator features as float64, app counts as int32.
X = X.astype('float64')
y = y.astype('int32')

In [None]:
# Show the shape of the feature matrix.
X.shape

In [None]:
# Show the shape of the target vector.
y.shape

In [None]:
# Chronological split: shuffle=False keeps row order, so the first 80% of the
# rows are used for training and the remaining 20% for testing.
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, random_state=10, shuffle=False)

In [None]:
# Show the dtype of the training targets.
y_train.dtype

In [None]:
# Arrival model: regularized Poisson GLM predicting the per-minute number of
# active apps from the minute/day indicator features.
poisson_fit = sm.GLM(y_train, X_train, family=sm.families.Poisson()).fit_regularized(alpha=1e-2)

In [None]:
# Persist the fitted arrival model; the "Generator" section reloads this file.
poisson_fit.save(data_path + 'arr_model.pkl')

In [None]:
def get_quantiles(predict_arr, args):
    """Return the 5th, 50th and 95th percentiles of Poisson draws.

    The predicted means in ``predict_arr`` are repeated
    ``args.npoisson_samps`` times along axis 0, one Poisson sample is drawn
    per repeated mean, and the percentiles are taken over axis 0.

    Args:
        predict_arr: array of Poisson means.
            NOTE(review): presumably shaped (1, N) so that axis 0 becomes
            the sample axis - confirm against callers.
        args: object exposing an integer ``npoisson_samps`` attribute.

    Returns:
        Tuple ``(p05, p50, p95)``.
    """
    tiled_means = np.repeat(predict_arr, args.npoisson_samps, axis=0)
    draws = np.random.poisson(tiled_means.astype(np.float64))
    lower, median, upper = (
        np.percentile(draws, level, axis=0) for level in (5, 50, 95)
    )
    return lower, median, upper

In [None]:
# Inspect feature index 1441 (the second day-indicator slot) of test row 7.
X_test[7][1441]

In [None]:
# Predicted mean number of active apps for every test minute.
pred = poisson_fit.predict(X_test)

In [None]:
# Display the predictions.
pred

In [None]:
# Display the observed counts for comparison with the predictions.
y_test

In [None]:
# Feature columns '1'..'1446' of the first row, kept for manual inspection.
m = data.loc[0, [str(i) for i in range(1, 1447)]]

In [None]:
# Display the first 143 minute indicators of that row as a list.
m[[str(i) for i in range(1, 144)]].to_list()

# Utilities

In [None]:
# Token that closes one application's batch of functions inside a trace line.
BOUND = "|"
# Target value used for padding positions; it equals the default
# ``ignore_index`` of torch.nn.CrossEntropyLoss, so padding adds no loss.
IGNORE_INDEX = -100
# Separator between items of one generated trace line.
ITEM_DATA_SEP = ','
# Separator between the timestamp and the items of a generated trace line.
TRACE_DATA_SEP = " "
# Number of day-indicator columns appended after the 1440 minute columns in
# the LSTM features.  This rebinds the value used for the arrival model.
DAYS = 2
# Emit a progress message every REPORT processed lines / timestamps.
REPORT = 1000

# Logger shared by the dataset, training, evaluation and generation code
# below, which all call ``logger.info``.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class ExampleKeys(Enum):
    """Keys we can use to extract values from an example."""
    # Input tensor of an example.
    INPUT = "input"
    # Target tensor of an example.
    TARGET = "target"
    # Optional output mask; the collator only handles it when it is present.
    OUT_MASK = "mask"

In [None]:
# Register the batch-boundary token as one more vocabulary entry.  The guard
# keeps this cell idempotent: a duplicate entry would receive an index beyond
# the number of distinct tokens and break the one-hot encoding.
if '|' not in hash_functions:
    hash_functions.append('|')

In [None]:
class FlavTensorMaker():
    """Build LSTM input and target tensors from sequences of tokens.

    A "flavor" here is a token of the vocabulary passed to the constructor
    (function hashes plus the boundary token).  Each token is mapped to an
    index by its position in that vocabulary, and each input step is the
    one-hot encoding of the token concatenated with the timestamp features
    of the line.
    """

    def __init__(self, hash_functs):    # vector containing hashes of functions
        """Create the token <-> index maps and derive the tensor widths.

        Args:
            hash_functs: iterable of vocabulary tokens; the position of a
                token is its one-hot index.
        """
        self.flav_idxs, self.idx_flavs = self.get_flav_idxs(hash_functs)
        nflavs = len(self.flav_idxs)
        # Input width: one-hot token + 1440 minute slots + DAYS day slots.
        self.ninput = nflavs + 1440 + DAYS
        self.noutput = nflavs

    def get_ninput(self):
        """Return the width of one input step."""
        return self.ninput

    def get_noutput(self):
        """Return the number of output classes (vocabulary size)."""
        return self.noutput

    @staticmethod
    def get_flav_idxs(hash_functs):
        """Return ``(token -> index, index -> token)`` maps.

        Indexes follow the iteration order of ``hash_functs``.  If a token
        occurs more than once, the forward map keeps its last index.
        """
        ordered = list(hash_functs)
        func_idx = {hash_f: idx for idx, hash_f in enumerate(ordered)}
        idx_func = dict(enumerate(ordered))
        return func_idx, idx_func

    def __one_hot_flav_line(self, hash_func):
        """One-hot-encode a token list as a LEN(LINE) x 1 x NDIMS tensor.

        Raises:
            KeyError: if a token is not part of the vocabulary.
        """
        ndims = len(self.flav_idxs)
        # Explicit long dtype so an empty token list still encodes cleanly.
        fvals = torch.tensor([self.flav_idxs[f] for f in hash_func],
                             dtype=torch.long)
        tensor = torch.nn.functional.one_hot(fvals, num_classes=ndims)
        return tensor.unsqueeze(dim=1)

    @staticmethod
    def one_hots_timestamp(minute_day):
        """Return the given timestamp feature vector as a tensor.

        ``minute_day`` is already an indicator vector (minute-of-day slots
        followed by day slots); it is converted unchanged, keeping the dtype
        NumPy infers for it.  This is public so other code can reuse it.
        """
        tensor = torch.from_numpy(np.array(minute_day))
        return tensor

    def encode_input(self, minute_day, line):
        """Encode a line of NFLAVS tokens as (NFLAVS-1) x 1 x NFEATURES.

        The last token is never part of the input.  The result is float32,
        which is the dtype ``nn.LSTM`` expects.

        Args:
            minute_day: timestamp feature vector shared by the whole line.
            line: list of tokens.

        Returns:
            Float tensor whose last dimension is the one-hot token followed
            by the timestamp features.
        """
        input_line = line[:-1]
        oneh_flavs = self.__one_hot_flav_line(input_line)
        # Timestamp features are encoded once, then tiled over the steps.
        oneh_ts = self.one_hots_timestamp(minute_day)
        oneh_ts_line = oneh_ts.repeat(len(input_line), 1, 1)
        return torch.cat([oneh_flavs.float(), oneh_ts_line.float()], dim=2)

    def encode_target(self, hash_func):
        """For a line of NFLAVS tokens, return the NFLAVS-1 target indexes.

        The first token is skipped: the target at step i is token i+1.
        """
        flav_idxs = [self.flav_idxs[flav] for flav in hash_func[1:]]
        return torch.LongTensor(flav_idxs)

    def replace_flav_input(self, my_input, new_flav):
        """Overwrite, in place, the token one-hot of ``my_input[0, 0]``.

        Only the first ``len(vocabulary)`` features are replaced; the
        timestamp features that follow are left untouched.
        """
        nhot_flav_feats = len(self.flav_idxs)
        fval = torch.tensor([self.flav_idxs[new_flav]])
        new_flav_tensor = torch.nn.functional.one_hot(
            fval, num_classes=nhot_flav_feats)[0]
        my_input[0, 0, :nhot_flav_feats] = new_flav_tensor

In [None]:
class TraceLSTM(nn.Module):
    """LSTM followed by a linear read-out, used for sequence modeling."""

    def __init__(self, ninput, nhidden, noutput, nlayers):
        """Create the stacked LSTM and the fully-connected output layer.

        Args:
            ninput: width of one input step.
            nhidden: hidden size of every LSTM layer.
            noutput: number of logits produced per step.
            nlayers: number of stacked LSTM layers.
        """
        super().__init__()
        self.ninput, self.nhidden = ninput, nhidden
        self.noutput, self.nlayers = noutput, nlayers
        self.lstm = nn.LSTM(ninput, nhidden, nlayers)
        self.fc_out = nn.Linear(nhidden, noutput)
        self.hidden = self.init_hidden()

    def init_hidden(self, device=None, batch_size=1):
        """Return a fresh all-zero ``(h0, c0)`` state pair.

        Both tensors are NLAYERS x BATCHSIZE x NHIDDEN and are moved to
        ``device`` when one is given.
        """
        shape = (self.nlayers, batch_size, self.nhidden)
        state = (torch.zeros(shape), torch.zeros(shape))
        if device is None:
            return state
        return tuple(part.to(device) for part in state)

    def forward(self, minibatch):
        """Run LENGTH x BATCHSIZE x NINPUT through the LSTM and read-out.

        The stored hidden state is used as the initial state and replaced by
        the final one.  Returns LENGTH x BATCHSIZE x NOUTPUT logits.
        """
        features, self.hidden = self.lstm(minibatch, self.hidden)
        return self.fc_out(features)

    def save(self, outfn):
        """Write the model's state dict to ``outfn``."""
        torch.save(self.state_dict(), outfn)

    @classmethod
    def create_from_path(cls, filename, device=None):
        """Rebuild a model from a state-dict file written by ``save``.

        The layer sizes are recovered from the stored tensor shapes.  When
        ``device`` is given it is used as ``map_location`` while loading.
        """
        if device is None:
            state_dict = torch.load(filename)
        else:
            state_dict = torch.load(filename,
                                    map_location=torch.device(device))
        # fc_out.weight is NOUTPUT x NHIDDEN.
        noutput, nhidden = state_dict['fc_out.weight'].shape
        # LSTM layers have 4 values: ih/hh weights and ih/hh biases:
        nlayers = len(state_dict) // 4
        ninput = state_dict['lstm.weight_ih_l0'].shape[1]
        model = cls(ninput, nhidden, noutput, nlayers)
        model.load_state_dict(state_dict)
        return model

In [None]:
class LossStats():
    """Running totals of the loss during training or testing."""

    def __init__(self):
        """Initialize all our running totals to zero."""
        self.tot_loss = 0.0
        self.tot_examples = 0

    def update(self, num, loss):
        """Add ``num`` examples whose average loss was ``loss``.

        ``loss`` may be a Python number or a single-element tensor.  It is
        converted to a float so the totals never hold on to a tensor (and the
        autograd history attached to it).  Calls with ``num == 0`` are
        ignored.
        """
        if num == 0:
            return
        self.tot_loss += float(loss) * num
        self.tot_examples += num

    def get_tot_examples(self):
        """Return total number of examples processed since beginning."""
        return self.tot_examples

    def overall_loss(self):
        """Return the example-weighted average loss as a float.

        Returns NaN when no example has been recorded yet, instead of
        failing with a division by zero.
        """
        if self.tot_examples == 0:
            return float("nan")
        return self.tot_loss / self.tot_examples

In [None]:

class TrainArgs(NamedTuple):
    """Arguments to be used in training."""
    # Passed to the Adam optimizer as ``lr``.
    learn_rate: float
    # Passed to the Adam optimizer as ``weight_decay``.
    weight_decay: float
    # Number of passes over the training loader.
    max_iters: int


class TrainLSTM():
    """Drive the training loop of an LSTM through an evaluator."""

    def __init__(self, eval_lstm, net, train_args, trainloader):
        """Store the evaluator, the network, the settings and the loader."""
        self.trainloader = trainloader
        self.train_args = train_args
        self.net = net
        self.eval_lstm = eval_lstm

    def run_train_iteration(self, data, optimizer, criterion):
        """Perform one optimizer step on ``data``.

        Returns:
            ``(num, loss)`` exactly as produced by the evaluator's
            ``batch_forward``.
        """
        optimizer.zero_grad()
        outcome = self.eval_lstm.batch_forward(data, criterion)
        num, loss = outcome
        loss.backward()
        optimizer.step()
        return num, loss

    def iterate_models(self, optimizer, criterion):
        """Train for ``max_iters`` epochs, yielding each epoch's loss."""
        for epoch in range(self.train_args.max_iters):
            self.net.train()
            stats = LossStats()
            for batch in self.trainloader:
                stats.update(
                    *self.run_train_iteration(batch, optimizer, criterion))
            epoch_loss = stats.overall_loss()
            logger.info('Train loss, epoch [%d, %7d]: %.7f',
                        epoch, stats.get_tot_examples(), epoch_loss)
            yield epoch_loss

    def run(self, criterion):
        """Train the network with Adam and score it after every epoch.

        Returns:
            The trained network (the same object that was passed in).
        """
        optimizer = torch.optim.Adam(
            self.net.parameters(),
            lr=self.train_args.learn_rate,
            weight_decay=self.train_args.weight_decay,
        )
        logger.info("Optimizer: %s", optimizer)
        logger.info("Starting training")
        epoch_losses = self.iterate_models(optimizer, criterion)
        for iter_num, _ in enumerate(epoch_losses, 1):
            self.eval_lstm.get_test_score(iter_num, criterion)
        logger.info("Finished training")
        return self.net


def get_init_model(args, tmaker):
    """Build an untrained TraceLSTM sized for the given tensor maker.

    Input and output widths come from ``tmaker``; the hidden size and the
    number of layers come from ``args.nhidden`` and ``args.nlayers``.
    """
    return TraceLSTM(tmaker.get_ninput(), args.nhidden,
                     tmaker.get_noutput(), args.nlayers)

In [None]:
def yield_trace_lines(trace_fn, days=None):    # trace_fn: whole dataframe
    """Yield ``(timestamp_features, tokens)`` for every row of the frame.

    Args:
        trace_fn: DataFrame with feature columns named '1'..str(1440 + days)
            and a 'HashFunction' column holding comma-separated tokens.
        days: number of day-indicator columns to read after the 1440 minute
            columns; defaults to the module-level ``DAYS``.

    Yields:
        Tuple of the row's feature values (list) and its token list.  Empty
        tokens are dropped: a minute without any invocation is stored as
        ',|' and would otherwise produce a '' token that is not part of the
        vocabulary.
    """
    if days is None:
        days = DAYS
    # Same column selection for every row, so build it once.
    feature_cols = [str(i) for i in range(1, 1441 + days)]
    for _, row in trace_fn.iterrows():
        hash_functs = [tok for tok in row['HashFunction'].split(',') if tok]
        yield row[feature_cols].to_list(), hash_functs

In [None]:
class FlavDataset(Dataset):
    """A dataset that can be used for flavor sequence modeling.

    All trace lines are encoded up front and concatenated into one long
    token stream, which is then cut into fixed-length sequences.
    """

    def __init__(self, seq_len, dataset_fn):    # dataset_fn = line
        """Encode the whole trace.

        Args:
            seq_len: Int, how long to make the sequences for each example.
            dataset_fn: DataFrame holding the trace (feature columns plus
                the 'HashFunction' column), as consumed by
                ``yield_trace_lines``.
        """
        self.seq_len = seq_len
        self.tmaker = FlavTensorMaker(hash_functions)   # takes hashes of functions
        trace_data = yield_trace_lines(dataset_fn)  # takes whole dataset
        # make one giant example, getitem() & len() will take pieces:
        self.all_inputs, self.all_targets = self.__make_example_tensor(
            trace_data)

        assert len(self.all_inputs) == len(self.all_targets)

    def __make_example_tensor(self, trace_data):    # trace_data = whole dataset
        """Go through lines in trace and create one big example tensor, where
        the first dimension is example number.

        Raises:
            ValueError: if the trace contains no line at all.
        """
        all_inputs, all_targets = [], []
        nlines = 0
        for nlines, line in enumerate(trace_data, 1):
            my_input, my_target = self.__make_example_from_line(line)
            all_inputs.append(my_input)
            all_targets.append(my_target)
            if nlines % REPORT == 0:
                logger.info("Read %s dataset lines", nlines)
        if nlines == 0:
            raise ValueError("Cannot build a dataset from an empty trace")
        logger.info("Read %s dataset lines", nlines)
        # Create a single vector for each of these by reshaping.  Inputs are
        # made float32 explicitly because that is what the LSTM consumes.
        all_inputs = torch.cat(all_inputs).float()
        all_targets = torch.cat(all_targets)
        all_inputs, all_targets = self.__reshape_data(
            all_inputs, all_targets)
        return all_inputs, all_targets

    def __reshape_data(self, all_inputs, all_targets):
        """Depending on the sequence length, reshape accordingly.  Also, pad
        the targets with IGNORE_INDEX (and the inputs with zeros) so that
        the stream divides evenly into sequences.
        """
        nflavs = len(all_inputs)
        nseqs = ceil(1.0 * nflavs / self.seq_len)
        padding_needed = nseqs * self.seq_len - nflavs
        fake_targets = (torch.ones(padding_needed, dtype=torch.long) *
                        IGNORE_INDEX)
        all_targets = torch.cat([all_targets, fake_targets])
        fake_input_shape = list(all_inputs.shape)
        fake_input_shape[0] = padding_needed
        fake_input = torch.zeros(fake_input_shape, dtype=all_inputs.dtype)
        all_inputs = torch.cat([all_inputs, fake_input])
        # After padding, reshape into sequences of seq_len:
        reshaped_inputs = all_inputs.reshape(-1, self.seq_len, 1,
                                             fake_input_shape[-1])
        reshaped_targets = all_targets.reshape(-1, self.seq_len)
        return reshaped_inputs, reshaped_targets

    def __make_example_from_line(self, line):
        """Unpack the line and make the example from it.
        """
        timestamp, flavs = line
        # Lines don't BEGIN with BOUND, so put it on:
        flavs = [BOUND] + flavs
        ex_input = self.tmaker.encode_input(timestamp, flavs)
        ex_target = self.tmaker.encode_target(flavs)
        return ex_input, ex_target

    def __len__(self):
        """Return number of sequences of length seq_len:
        """
        return len(self.all_targets)

    def __getitem__(self, idx):
        """Return an example from all our pre-made tensors."""
        ex_input = self.all_inputs[idx]
        ex_target = self.all_targets[idx]
        sample = {ExampleKeys.INPUT: ex_input,
                  ExampleKeys.TARGET: ex_target}
        return sample

In [None]:
class CollateUtils():
    """Utility collate functions for the DataLoaders."""

    @staticmethod
    def batching_collator(batch):
        """Merge a list of examples into one minibatch example.

        Inputs, targets and (when the first example carries one) masks are
        concatenated along dimension 1, the batch dimension.

        Arguments: batch: sequence of example dicts from a Dataset
        Returns: dict with the same keys holding the batched tensors
        """
        with_masks = batch[0].get(ExampleKeys.OUT_MASK) is not None
        inputs, targets, masks = [], [], []
        for example in batch:
            inputs.append(example[ExampleKeys.INPUT])
            target = example[ExampleKeys.TARGET]
            # A flat target vector gets a trailing dimension of 1; anything
            # else keeps its own last dimension.
            last_dim = 1 if target.dim() == 1 else target.shape[-1]
            targets.append(target.reshape(-1, 1, last_dim))
            if with_masks:
                mask = example[ExampleKeys.OUT_MASK]
                masks.append(mask.reshape(-1, 1, target.shape[-1]))
        collated = {
            ExampleKeys.INPUT: torch.cat(inputs, dim=1),
            ExampleKeys.TARGET: torch.cat(targets, dim=1),
        }
        if with_masks:
            collated[ExampleKeys.OUT_MASK] = torch.cat(masks, dim=1)
        return collated

In [None]:
class Evaluator(ABC):
    """Class to run the forward pass and compute test scores."""

    def __init__(self, net, device_str, testloader):
        """Store the network and loader and place the network on the device.

        Args:
            net: network to evaluate.
            device_str: device specification accepted by ``torch.device``
                (e.g. "cpu", "cuda", "cuda:2").
            testloader: iterable of test minibatches.
        """
        self.net = net
        self.device = torch.device(device_str)
        self.testloader = testloader
        if self.device.type == "cuda":
            # Move to exactly the requested CUDA device, whatever its index.
            self.net = self.net.to(self.device)

    @abstractmethod
    def batch_forward(self, batch, criterion):
        """Override to pick the outputs for the batch, compute the loss."""

    def get_test_score(self, epoch, criterion):
        """Return the current network's average loss over the test loader.

        Runs in eval mode without gradients and logs the result, tagged
        with ``epoch`` when it is not None.
        """
        loss_stats = LossStats()
        with torch.no_grad():
            self.net.eval()
            for batch in self.testloader:
                num, loss = self.batch_forward(batch, criterion)
                loss_stats.update(num, loss)
        overall_loss = loss_stats.overall_loss()
        if epoch is not None:
            logger.info('Test loss, epoch [%d]: %.7f', epoch, overall_loss)
        else:
            logger.info('Test loss: %.7f', overall_loss)
        return overall_loss

In [None]:
def make_flav_dataloaders(args):
    """Build the train and test DataLoaders described by ``args``.

    Args:
        args: object exposing ``seq_len``, ``batch_size``, ``test_flavs``
            and, optionally, ``train_flavs``.

    Returns:
        ``(trainloader, testloader)``; ``trainloader`` is None when
        ``args`` has no ``train_flavs`` (or it is None).
    """
    # Only the missing attribute is tolerated: errors raised while building
    # the training dataset must surface instead of silently disabling it.
    train_flavs = getattr(args, "train_flavs", None)
    if train_flavs is None:
        trainloader = None
    else:
        trainset = FlavDataset(args.seq_len, train_flavs)
        trainloader = DataLoader(trainset, batch_size=args.batch_size,
                                 collate_fn=CollateUtils.batching_collator,
                                 shuffle=True)
    testset = FlavDataset(args.seq_len, args.test_flavs)
    testloader = DataLoader(testset, batch_size=args.batch_size,
                            collate_fn=CollateUtils.batching_collator,
                            shuffle=False)

    return trainloader, testloader


class EvaluateFlavLSTM(Evaluator):
    """Evaluator computing the loss of a flavor LSTM on a minibatch."""

    def batch_forward(self, batch, criterion):
        """Run the forward pass on one minibatch.

        The hidden state is reset for the batch size at hand, then logits
        and targets are flattened before being handed to ``criterion``.

        Returns:
            ``(num, loss)`` where ``num`` counts the targets different from
            IGNORE_INDEX and ``loss`` is the criterion's output.
        """
        raw_targets = batch[ExampleKeys.TARGET]
        num = (raw_targets != IGNORE_INDEX).sum().item()
        inputs = batch[ExampleKeys.INPUT].to(self.device)
        targets = raw_targets.to(self.device)
        # Dimension 1 of the collated input is the batch dimension.
        self.net.hidden = self.net.init_hidden(self.device, inputs.shape[1])
        logits = self.net(inputs)
        flat_logits = logits.reshape(-1, logits.shape[-1])
        loss = criterion(flat_logits, targets.reshape(-1))
        return num, loss

# Flav Main

In [None]:
# LSTM data: the 1440 + DAYS feature columns plus the per-minute token string.
X = data.loc[:, [str(i) for i in range(1, 1441+DAYS)] + ['HashFunction']]

# Same chronological 80/20 split as for the arrival model; y is the target
# array built in the "Load Data" section.
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, random_state=10, shuffle=False)

In [None]:
# Settings for LSTM training; converted below to an attribute-style object.
# "device" is fixed to the first CUDA GPU.
# "range_start" is not read by any code visible in this notebook.
args={
    "train_flavs" : X_train,
    "test_flavs" : X_test,
    "device" : "cuda:0" , 
    "seq_len" : 500,
    "batch_size" : 100,
    "range_start" : 0,
    "max_iters" : 10,
    "lr" : 5e-3,
    "weight_decay" : 1e-5, 
    "nlayers" : 2,
    "nhidden" :  200,
    "model_save_fn" : data_path + 'func_model.pth'
}

# Convert Dict to object
args = namedtuple("Args", args.keys())(*args.values())

In [None]:
# Size the network from the vocabulary, build the loaders, then train with
# cross-entropy; the test loss is computed after every epoch.
tmaker = FlavTensorMaker(hash_functions)
net = get_init_model(args, tmaker)
trainloader, testloader = make_flav_dataloaders(args)
train_args = TrainArgs(args.lr, args.weight_decay, args.max_iters)
eval_lstm = EvaluateFlavLSTM(net, args.device, testloader)
train_run = TrainLSTM(eval_lstm, net, train_args, trainloader)
criterion = torch.nn.CrossEntropyLoss()
trained_net = train_run.run(criterion)

# Save the state dict (not the pickled module) so that the file can be read
# back with TraceLSTM.create_from_path.
trained_net.save(args.model_save_fn)

In [None]:
 # Display the feature vector of the first test row.
 X_test[[str(i) for i in range(1, 1441+DAYS)]].to_numpy()[0]

# Generator

In [None]:
# Settings for trace generation; this rebinds the training ``args``.
# "minute_day" holds one feature row (1440 + DAYS values) per timestamp to
# generate, taken from the test split.
args={
    "arrival_model_pkl" : data_path + "arr_model.pkl",
    "device" : "cpu" ,
    "flav_model" : data_path + "func_model.pth",
    "out_flavs_fn" : "tmp.flavs",
    "minute_day": X_test[[str(i) for i in range(1, 1441+DAYS)]].to_numpy()
}

# Convert Dict to object
args = namedtuple("Args", args.keys())(*args.values())

In [None]:
class GenLSTM():
    """An evaluator that only does forward pass (no loss calculation)."""

    def __init__(self, net, device):
        """Keep the network and the device its hidden state lives on."""
        self.net = net
        self.device = device

    def init_hidden(self):
        """Reset the network's hidden state on ``self.device``."""
        self.net.hidden = self.net.init_hidden(self.device)

    def forward(self, my_input):
        """Return the network's output for ``my_input``.

        Gradients are disabled: generation never back-propagates, and the
        hidden state carried from call to call would otherwise keep an
        ever-growing autograd history alive.
        """
        with torch.no_grad():
            outputs = self.net(my_input)
        return outputs

In [None]:
def make_arrival_vector(x):
    """Return ``x`` unchanged (identity placeholder)."""
    return x

def encode_trace_line(timestamp, item_lst):
    """Format one output line: timestamp, separator, then joined items.

    Items are joined with ITEM_DATA_SEP and separated from the timestamp by
    TRACE_DATA_SEP.
    """
    return f"{timestamp}{TRACE_DATA_SEP}{ITEM_DATA_SEP.join(item_lst)}"

In [None]:
class Generator():
    """Generate a trace of function tokens, timestamp by timestamp.

    For every timestamp the arrival model gives the number of batches to
    produce, and the LSTM is sampled auto-regressively until that many
    batch boundaries (BOUND tokens) have been emitted.
    """

    def __init__(self, device, arrival_mdl, flav_lstm):
        """
        Args:
            device: device on which LSTM inputs are placed.
            arrival_mdl: fitted arrival model exposing ``predict``.
            flav_lstm: forward-only LSTM evaluator exposing ``init_hidden``
                and ``forward``.
        """
        self.device = device
        self.arrival_mdl = arrival_mdl
        self.flav_to_idxs, self.idx_to_flavs = FlavTensorMaker.get_flav_idxs(hash_functions)

        self.flav_tmaker = FlavTensorMaker(hash_functions)

        self.flav_lstm = flav_lstm

    def __get_narrivals(self, timestamp):
        """Sample the number of batches arriving at ``timestamp``.

        The arrival model may have been fit on more day-indicator columns
        than the LSTM features carry; the missing trailing columns are
        zero-filled so the feature vector matches the model's width.
        """
        feats = np.asarray(timestamp, dtype=np.float64)
        params = getattr(self.arrival_mdl, "params", None)
        if params is not None and feats.shape[0] < len(params):
            feats = np.pad(feats, (0, len(params) - feats.shape[0]))
        pred_mean = self.arrival_mdl.predict([feats])[0]
        narrivals = np.random.poisson(pred_mean)
        return narrivals

    def __init_flav_input(self, timestamp):
        """Initialize input tensor to a BOUND at given timestamp."""
        flav_lst = [BOUND, BOUND]  # second BOUND ignored by tmaker
        my_input = self.flav_tmaker.encode_input(timestamp, flav_lst)
        my_input = my_input.to(self.device)
        return my_input

    def __adjust_flav_input(self, my_input, prev_flav):
        """Replace the flavor part of the input only."""
        self.flav_tmaker.replace_flav_input(my_input, prev_flav)
        return my_input

    def __sample_flav(self, output):
        """Sample a token from the LSTM output; return (token, index)."""
        probs = torch.softmax(output.reshape(-1), dim=0)
        flav_idx = torch.multinomial(probs, 1).item()
        # Also get flavor string itself:
        flav_flav = self.idx_to_flavs[flav_idx]
        return flav_flav, flav_idx

    def __generate_flavs(self, timestamp, target_nbatches):
        """Auto-regressively generate target_nbatches batches of flavors, at
        given timestamp.  Returns the sampled tokens and their indexes.
        """
        my_input = self.__init_flav_input(timestamp)
        flav_flavs = []
        flav_idxs = []
        nseen_batches = 0
        prev_flav = BOUND
        while True:
            output = self.flav_lstm.forward(my_input)
            next_flav, next_idx = self.__sample_flav(output)
            # Shouldn't happen, but skip if it does (an empty batch):
            if next_flav == BOUND and prev_flav == BOUND:
                continue
            flav_flavs.append(next_flav)
            flav_idxs.append(next_idx)
            # Increment number batches seen on each bound:
            if next_flav == BOUND:
                nseen_batches += 1
                if nseen_batches == target_nbatches:
                    break
            # Otherwise, get next input and continue:
            my_input = self.__adjust_flav_input(my_input, next_flav)
            prev_flav = next_flav
        return flav_flavs, flav_idxs

    def __generate_batches(self, timestamp, nbatches):
        """Generate the token indexes for a single row/timestamp."""
        _, flav_idxs = self.__generate_flavs(timestamp, nbatches)
        return flav_idxs

    def __output_flavs(self, timestamp, flav_idxs, out_flavs_file):
        """Make and output the flavor line."""
        flav_lst = [self.idx_to_flavs[i] for i in flav_idxs]
        flav_out = encode_trace_line(timestamp, flav_lst)
        print(flav_out, file=out_flavs_file)

    def __output_zero_arrivals(self, timestamp, out_flavs_file):
        """Output a line carrying the timestamp and no item."""
        flav_out = encode_trace_line(timestamp, [])
        print(flav_out, file=out_flavs_file)

    def __call__(self, out_flavs_file):
        """Write one line per row of the global ``args.minute_day`` to
        ``out_flavs_file``.
        """
        self.flav_lstm.init_hidden()
        for ntimestamps, timestamp in enumerate(args.minute_day):
            nbatches = self.__get_narrivals(timestamp)
            if nbatches == 0:
                self.__output_zero_arrivals(timestamp, out_flavs_file)
                continue
            flav_idxs = self.__generate_batches(timestamp, nbatches)
            self.__output_flavs(timestamp, flav_idxs, out_flavs_file)
            if ntimestamps > 0 and ntimestamps % REPORT == 0:
                # The timestamp is a feature vector, so only the running
                # count is logged.
                logger.info("Generated %d output lines", ntimestamps)


def get_lstm_eval(device, model_fn):
    """Initialize the generation LSTM evaluator.

    The network trained earlier in this notebook (``trained_net``) is
    reused; ``model_fn`` is accepted for interface compatibility but not
    read.  The network is moved to ``device`` so that it lives on the same
    device as the inputs the generator creates.  Note that this moves
    ``trained_net`` itself.
    """
    net = trained_net.to(device)
    return GenLSTM(net, device)


def main(args):
    """Load the models and write the generated trace.

    Reads the arrival model from ``args.arrival_model_pkl``, wraps the LSTM
    for ``args.device`` and writes the trace to ``args.out_flavs_fn``.
    """
    logger.info("Reading models")
    arrival_mdl = sm.load(args.arrival_model_pkl)
    lstm_generator = Generator(
        args.device,
        arrival_mdl,
        get_lstm_eval(args.device, args.flav_model),
    )
    with open(args.out_flavs_fn, "w") as flavs_file:
        logger.info("Running generation")
        lstm_generator(flavs_file)

In [None]:
# Run the generation with the settings defined in the "Generator" section.
main(args)