In [2]:
from vocabularies import VocabType
from config import Config
from interactive_predict import InteractivePredictor
from model_base import Code2VecModelBase
from tensorflow_model import Code2VecModel as model
import tensorflow as tf
import numpy as np
import time
from typing import Dict, Optional, List, Iterable
from collections import Counter
from functools import partial

from path_context_reader import PathContextReader, ModelInputTensorsFormer, ReaderInputTensors, EstimatorAction
from common import common
from vocabularies import VocabType
from config import Config
from model_base import Code2VecModelBase, ModelEvaluationResults, ModelPredictionResults

In [8]:
log('Starting training')
start_time = time.time()

batch_num = 0
sum_loss = 0
multi_batch_start_time = time.time()
num_batches_to_save_and_eval = max(int(self.config.train_steps_per_epoch * self.config.SAVE_EVERY_EPOCHS), 1)

train_reader = PathContextReader(vocabs=self.vocabs,
                                    model_input_tensors_former=_TFTrainModelInputTensorsFormer(),
                                    config=self.config, estimator_action=EstimatorAction.Train)
input_iterator = tf.compat.v1.data.make_initializable_iterator(train_reader.get_dataset())
input_iterator_reset_op = input_iterator.initializer
input_tensors = input_iterator.get_next()

optimizer, train_loss = self._build_tf_training_graph(input_tensors)
saver = tf.compat.v1.train.Saver(max_to_keep=self.config.MAX_TO_KEEP)

self.log('Number of trainable params: {}'.format(
    np.sum([np.prod(v.get_shape().as_list()) for v in tf.compat.v1.trainable_variables()])))
for variable in tf.compat.v1.trainable_variables():
    self.log("variable name: {} -- shape: {} -- #params: {}".format(
        variable.name, variable.get_shape(), np.prod(variable.get_shape().as_list())))

self._initialize_session_variables()

if self.config.MODEL_LOAD_PATH:
    self._load_inner_model(self.sess)

self.sess.run(input_iterator_reset_op)
time.sleep(1)
self.log('Started reader...')
# run evaluation in a loop until iterator is exhausted.
try:
    while True:
        # Each iteration = batch. We iterate as long as the tf iterator (reader) yields batches.
        batch_num += 1

        # Actual training for the current batch.
        _, batch_loss = self.sess.run([optimizer, train_loss])

        sum_loss += batch_loss
        if batch_num % self.config.NUM_BATCHES_TO_LOG_PROGRESS == 0:
            self._trace_training(sum_loss, batch_num, multi_batch_start_time)
            # Uri: the "shuffle_batch/random_shuffle_queue_Size:0" op does not exist since the migration to the new reader.
            # self.log('Number of waiting examples in queue: %d' % self.sess.run(
            #    "shuffle_batch/random_shuffle_queue_Size:0"))
            sum_loss = 0
            multi_batch_start_time = time.time()
        if batch_num % num_batches_to_save_and_eval == 0:
            epoch_num = int((batch_num / num_batches_to_save_and_eval) * self.config.SAVE_EVERY_EPOCHS)
            model_save_path = self.config.MODEL_SAVE_PATH + '_iter' + str(epoch_num)
            self.save(model_save_path)
            self.log('Saved after %d epochs in: %s' % (epoch_num, model_save_path))
            evaluation_results = self.evaluate()
            evaluation_results_str = (str(evaluation_results).replace('topk', 'top{}'.format(
                self.config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION)))
            self.log('After {nr_epochs} epochs -- {evaluation_results}'.format(
                nr_epochs=epoch_num,
                evaluation_results=evaluation_results_str
            ))
except tf.errors.OutOfRangeError:
    pass  # The reader iterator is exhausted and have no more batches to produce.

self.log('Done training')

if self.config.MODEL_SAVE_PATH:
    self._save_inner_model(self.config.MODEL_SAVE_PATH)
    self.log('Model saved in file: %s' % self.config.MODEL_SAVE_PATH)

elapsed = int(time.time() - start_time)
self.log("Training time: %sH:%sM:%sS\n" % ((elapsed // 60 // 60), (elapsed // 60) % 60, elapsed % 60))

NameError: name 'self' is not defined