# Quantum NLP Challenge

In [None]:
%pdb 0

In [None]:
%matplotlib inline

In [None]:
import sys

In [None]:
# !{sys.executable} -m pip install pandas
# !{sys.executable} -m pip install scikit-learn

## Data

In [None]:
import LovelyPlots.utils as lp
# Presumably switches inline figures to high-DPI ("retina") rendering.
lp.set_retina()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn as sk

# Global seed: seeds NumPy's legacy global RNG here and is passed explicitly
# to the splitters and the trainer further down.
RANDOM_SEED = 220811
np.random.seed(RANDOM_SEED)

In [None]:
import os
import warnings

# NOTE(review): silences *all* warnings, including pandas' SettingWithCopyWarning,
# which can hide real bugs; consider narrowing to specific categories.
warnings.filterwarnings("ignore")
# Let the HuggingFace tokenizers library (presumably used by the Bobcat parser) run in parallel.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [None]:
from lambeq import BobcatParser, DepCCGParser
from discopy import grammar

In [None]:
from lambeq import Rewriter
from lambeq import AtomicType, IQPAnsatz, remove_cups
from lambeq import TketModel, NumpyModel
from lambeq import QuantumTrainer, SPSAOptimizer
from lambeq import Dataset

In [None]:
from pytket.circuit.display import render_circuit_jupyter

In [None]:
from pytket.extensions.qiskit import AerBackend

In [None]:
from discopy.tensor import Diagram

In [None]:
# Each line of MC1.TXT is "sentence 1, sentence 2, label" (no header row).
df = pd.read_csv("MC1.TXT", header=None, sep=", ", engine="python")

In [None]:
# Name the three positional columns: the two sentences and their pair label.
df.columns = ["s1", "s2", "label"]

In [None]:
# Peek at the first rows of the raw dataset.
df.head()

In [None]:
# Collect the vocabulary and the set of sentence lengths (in space-separated
# tokens) over both sentence columns.
vocab = set()
lengths = set()

for row in df.itertuples(index=False):
    for text in (row.s1, row.s2):
        tokens = text.split(" ")
        lengths.add(len(tokens))
        vocab.update(tokens)  # set.update instead of a side-effect comprehension

In [None]:
# Vocabulary, its size, and the longest sentence (in tokens).
print(vocab, len(vocab), max(lengths))

## Lambeq tutorial

### Sentence input

In [None]:
# Example sentence for the lambeq walkthrough below.
sentence = "John walks in the park"

In [None]:
# Parse the example sentence into a string diagram with the Bobcat parser.
parser = BobcatParser()
diagram = parser.sentence2diagram(sentence)

In [None]:
# Render the parsed string diagram.
grammar.draw(diagram, figsize=(16, 4), fontsize=12)

### Diagram rewriting

In [None]:
# Rewrite rules for prepositional phrases and determiners, applied to the parsed diagram.
rewriter = Rewriter(["prepositional_phrase", "determiner"])
rewritten_diagram = rewriter(diagram)

In [None]:
# Diagram after the rewrite rules have been applied.
rewritten_diagram.draw(figsize=(16, 4), fontsize=12)

In [None]:
# Bring the rewritten diagram into normal form and draw it.
normalized_diagram = rewritten_diagram.normal_form()
normalized_diagram.draw(figsize=(16, 4), fontsize=12)

### Parametrization

In [None]:
# Atomic types that the ansatz below maps to qubits.
N = AtomicType.NOUN
S = AtomicType.SENTENCE

In [None]:
# Convert string diagram to qc - 1 qubit per atomic type, 2 IQP layers.
ansatz = IQPAnsatz({N: 1, S: 1}, n_layers=2)
discopy_circuit = ansatz(normalized_diagram)
discopy_circuit.draw(figsize=(40, 16), fontsize=12)

In [None]:
# Convert qc from DisCoPy to pytket format and render it as a gate-level circuit.
tket_circuit = discopy_circuit.to_tk()
render_circuit_jupyter(tket_circuit)

### Training: Quantum case

Inspecting `grammar.json`:

In [None]:
import json

In [None]:
# Bobcat's cached grammar file, resolved relative to the current user's home
# directory instead of a hard-coded /home/jovyan.
grammar_json_path = os.path.expanduser(
    os.path.join("~", ".cache", "lambeq", "bobcat", "bert", "grammar.json")
)

In [None]:
# Load the parser's grammar description for inspection. JSON is UTF-8 by
# specification, so don't depend on the platform's default encoding.
with open(grammar_json_path, "r", encoding="utf-8") as f:
    grammar_json = json.load(f)

The sentences in the dataset are built from the following grammatical categories. This may matter later, e.g. when choosing how many qubits the ansatz assigns to each atomic type:
- Noun
- Noun phrase
- Sentence
- Verb (transitive)
- Adjective

In [None]:
# (Re)create the parser so the training section does not depend on the tutorial
# cells above; preprocess_df reads this module-level `parser`.
parser = BobcatParser()

In [None]:
# Raw dataset again, before splitting.
df.head()

In [None]:
# (rows, columns) of the raw dataset.
df.shape

In [None]:
# How many pairs carry label 1 (class balance check).
df[df["label"] == 1].shape

In [None]:
from sklearn.model_selection import train_test_split, RepeatedStratifiedKFold

In [None]:
# Hold out 20% as the test set. Seeded and stratified on the label so the split
# is reproducible and both parts keep the same class balance.
df_train_val, df_test = train_test_split(
    df,
    test_size=0.2,
    shuffle=True,
    random_state=RANDOM_SEED,
    stratify=df["label"],
)

In [None]:
# Size of the train+validation pool (80% of the data).
df_train_val.shape

In [None]:
# 5-fold stratified CV repeated 5 times; seeded so the folds are reproducible.
rskf = RepeatedStratifiedKFold(n_splits=5, n_repeats=5, random_state=RANDOM_SEED)

In [None]:
# Columns available for splitting (s1, s2, label).
df_train_val.columns

In [None]:
# Materialise all 5 x 5 = 25 (train_idx, val_idx) folds; the indices are
# positional into df_train_val (they are used with .iloc below).
rskf_splits = list(rskf.split(df_train_val[["s1", "s2"]], y=df_train_val["label"]))

In [None]:
# Ansatz hyper-parameters used in the IQPAnsatz cell below:
# qs = qubits per NOUN wire, pn = n_single_qubit_params, d = n_layers.
qs, pn, d = (2, 3, 1)

In [None]:
# IQP ansatz: `qs` qubits per noun, a single qubit for the sentence type,
# `d` layers and `pn` single-qubit parameters.
ansatz = IQPAnsatz(
    {AtomicType.NOUN: qs, AtomicType.SENTENCE: 1},
    n_layers=d,
    n_single_qubit_params=pn,
)

In [None]:
# Sizes of the train / validation index arrays of the first fold.
rskf_splits[0][0].shape, rskf_splits[0][1].shape

In [None]:
# Use the first CV fold. `.copy()` makes independent frames: preprocess_df
# assigns columns and drops rows in place, which on plain `.iloc` slices
# triggers SettingWithCopy semantics.
train_idx, val_idx = rskf_splits[0]
df_train = df_train_val.iloc[train_idx].copy()
df_val = df_train_val.iloc[val_idx].copy()

In [None]:
# Class balance of the training fold: label 0 rows, then label 1 rows.
print(df_train[df_train["label"] == 0].shape)
print(df_train[df_train["label"] == 1].shape)

In [None]:
def preprocess_df(df, ansatz):
    """Parse, normalise and compile both sentences of every row of ``df`` in place.

    Adds the columns ``s1_diagram``/``s2_diagram`` (normal-form diagrams),
    ``label_v`` (one-hot label) and ``s1_circuit``/``s2_circuit`` (circuits
    produced by ``ansatz``). Rows for which the module-level ``parser`` did not
    produce a diagram for both sentences are dropped, so ``df`` may shrink.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with ``s1``, ``s2`` and ``label`` columns. It is modified in
        place, so pass an independent frame rather than a slice of another one.
    ansatz : callable
        Maps a cup-free diagram to a circuit, e.g. an ``IQPAnsatz``.

    Returns
    -------
    None
        Results are written into ``df``.
    """
    # With suppress_exceptions=True the parser yields a missing value
    # (presumably None) for sentences it cannot parse instead of raising.
    # Both sentences need a diagram for the row to be usable.
    df["s1_diagram"] = parser.sentences2diagrams(list(df["s1"].values), suppress_exceptions=True)
    df["s2_diagram"] = parser.sentences2diagrams(list(df["s2"].values), suppress_exceptions=True)
    # Only the parse results decide whether a row is kept.
    df.dropna(subset=["s1_diagram", "s2_diagram"], inplace=True)

    # Convert to normal form
    df["s1_diagram"] = df["s1_diagram"].apply(lambda d: d.normal_form())
    df["s2_diagram"] = df["s2_diagram"].apply(lambda d: d.normal_form())

    # One-hot label in the order of the models' output [similarity, 1 - similarity]:
    # label 1 -> [1, 0], label 0 -> [0, 1].
    df["label_v"] = df["label"].apply(lambda l: [0, 1] if l == 0 else [1, 0])

    # Create circuits (cups removed first to reduce circuit size)
    df["s1_circuit"] = df["s1_diagram"].apply(lambda d: ansatz(remove_cups(d)))
    df["s2_circuit"] = df["s2_diagram"].apply(lambda d: ansatz(remove_cups(d)))

In [None]:
# Parse and compile the training fold in place (rows that fail to parse are dropped).
preprocess_df(df_train, ansatz)

In [None]:
# Training fold with the new diagram / label / circuit columns.
df_train.head()

In [None]:
# First training row by position: after a random split, index label 40 may not be in df_train.
df_train["s1_diagram"].iloc[0].draw(figsize=(4, 3), fontsize=12)

In [None]:
# First training row by position: after a random split, index label 40 may not be in df_train.
df_train["s2_diagram"].iloc[0].draw(figsize=(4, 3), fontsize=12)

In [None]:
# First training row by position: after a random split, index label 40 may not be in df_train.
df_train["s1_circuit"].iloc[0].draw(figsize=(4, 3), fontsize=12)

In [None]:
# First training row by position: after a random split, index label 40 may not be in df_train.
df_train["s2_circuit"].iloc[0].draw(figsize=(4, 3), fontsize=12)

In [None]:
# pytket rendering of the first training row's s1 circuit (positional lookup, see above).
render_circuit_jupyter(df_train["s1_circuit"].iloc[0].to_tk())

In [None]:
# pytket rendering of the first training row's s2 circuit (positional lookup, see above).
render_circuit_jupyter(df_train["s2_circuit"].iloc[0].to_tk())

In [None]:
# Parse and compile the validation fold in place.
preprocess_df(df_val, ansatz)

In [None]:
# Parse and compile the held-out test split in place.
preprocess_df(df_test, ansatz)

In [None]:
# One [s1_circuit, s2_circuit] pair per row: the input format of the custom models' forward().
train_circuits = list(df_train[["s1_circuit", "s2_circuit"]].to_numpy())
val_circuits = list(df_val[["s1_circuit", "s2_circuit"]].to_numpy())
test_circuits = list(df_test[["s1_circuit", "s2_circuit"]].to_numpy())

In [None]:
# Circuit pairs of every split, in train / val / test order.
all_circuits = [*train_circuits, *val_circuits, *test_circuits]

In [None]:
# preprocess_df drops rows whose sentences failed to parse, so this check fails
# exactly when some pairs were lost during preprocessing.
assert len(all_circuits) == len(df), (
    f"{len(df) - len(all_circuits)} row(s) were dropped during preprocessing"
)

In [None]:
# One [s1_diagram, s2_diagram] pair per row, mirroring the circuit lists above.
train_diagrams = list(df_train[["s1_diagram", "s2_diagram"]].to_numpy())
val_diagrams = list(df_val[["s1_diagram", "s2_diagram"]].to_numpy())
test_diagrams = list(df_test[["s1_diagram", "s2_diagram"]].to_numpy())

In [None]:
# Diagram pairs of every split, in train / val / test order.
all_diagrams = [*train_diagrams, *val_diagrams, *test_diagrams]

In [None]:
# One-hot targets ([1, 0] for label 1, [0, 1] for label 0), one per pair.
train_labels = list(df_train["label_v"].to_numpy())
val_labels = list(df_val["label_v"].to_numpy())
test_labels = list(df_test["label_v"].to_numpy())

In [None]:
# Settings handed to TketModel: the Aer simulator, its default compilation pass
# (argument 2, presumably the optimisation level) and 2**13 = 8192 shots.
backend = AerBackend()
backend_config = dict(
    backend=backend,
    compilation=backend.default_compilation_pass(2),
    shots=2**13,
)

In [None]:
class CustomTketModel(TketModel):
    """TketModel that scores a *pair* of sentence circuits.

    Every input row holds two circuits, one per sentence. Both are evaluated,
    the first two entries of each output are compared by cosine similarity,
    and the row's prediction is ``[similarity, 1 - similarity]``.
    """

    def forward(self, x: list[[Diagram, Diagram]]) -> np.ndarray:
        """Return an ``(len(x), 2)`` array of ``[cos_sim, 1 - cos_sim]`` rows."""
        n_rows = len(x)
        # Unzip the (s1, s2) pairs into two parallel lists of circuits.
        pairs = [(first, second) for first, second in x]
        left = [first for first, _ in pairs]
        right = [second for _, second in pairs]

        # Evaluate s1 before s2 (same order as before; the backend draws seeds per call).
        u = self.get_diagram_output(left).reshape((n_rows, -1))[:, :2]
        v = self.get_diagram_output(right).reshape((n_rows, -1))[:, :2]

        # Row-wise cosine similarity.
        norms = np.sqrt(np.sum(u * u, axis=1)) * np.sqrt(np.sum(v * v, axis=1))
        similarity = np.sum(u * v, axis=1) / norms

        return np.array([similarity, np.ones_like(similarity) - similarity]).T

In [None]:
class CustomNumpyModel(NumpyModel):
    """NumpyModel counterpart of CustomTketModel for sentence-pair similarity.

    Each row of the input is a pair of circuits. The model evaluates both,
    takes the first two output entries of each, and predicts
    ``[cosine similarity, 1 - cosine similarity]`` per row.
    """

    def forward(self, x: list[[Diagram, Diagram]]) -> np.ndarray:
        """Return an ``(len(x), 2)`` array of ``[cos_sim, 1 - cos_sim]`` rows."""
        n_rows = len(x)
        # Split the (s1, s2) pairs into two parallel circuit lists.
        pairs = [(first, second) for first, second in x]
        left = [first for first, _ in pairs]
        right = [second for _, second in pairs]

        u = self.get_diagram_output(left).reshape((n_rows, -1))[:, :2]
        v = self.get_diagram_output(right).reshape((n_rows, -1))[:, :2]

        # Row-wise cosine similarity between the two sentence outputs.
        norms = np.sqrt(np.sum(u * u, axis=1)) * np.sqrt(np.sum(v * v, axis=1))
        similarity = np.sum(u * v, axis=1) / norms

        return np.array([similarity, np.ones_like(similarity) - similarity]).T

In [None]:
# Total number of individual circuits (two per sentence pair) given to the models below.
np.array(all_circuits).reshape(-1).shape

In [None]:
# Mini-batch size for the training Dataset and number of training epochs.
BATCH_SIZE = 32
EPOCHS = 500

In [None]:
# Training pairs in mini-batches of BATCH_SIZE (Dataset's default shuffle setting applies).
train_dataset = Dataset(train_circuits,
                        train_labels,
                        batch_size=BATCH_SIZE)

In [None]:
# Validation pairs in a fixed order; no batch_size given, so Dataset's default is used.
val_dataset = Dataset(val_circuits,
                      val_labels,
                      shuffle=False)

In [None]:
def loss(y_hat, y, eps=1e-9):
    """Mean cross-entropy between one-hot targets ``y`` and predictions ``y_hat``.

    ``y_hat`` is clipped to ``[eps, 1]`` first: the models output
    ``[s, 1 - s]`` with ``s`` a cosine similarity, so an exact 0 (or a slightly
    negative complement from rounding) would otherwise turn the loss into
    ``inf``/``nan`` and derail the optimiser.
    """
    y_hat = np.clip(np.asarray(y_hat, dtype=float), eps, 1.0)
    return -np.sum(np.asarray(y) * np.log(y_hat)) / len(y)


def acc(y_hat, y):
    """Fraction of rows whose rounded prediction matches the one-hot target.

    Each row contributes two element comparisons, hence the final ``/ 2``.
    """
    return np.sum(np.round(y_hat) == y) / len(y) / 2


eval_metrics = {"acc": acc}

In [None]:
# Build the Tket model from every circuit of every split (pairs flattened), so
# that every symbol that will ever be evaluated presumably gets a weight.
tket_model = CustomTketModel.from_diagrams(np.array(all_circuits).reshape(-1),
                                           backend_config=backend_config)

In [None]:
# NumPy-based counterpart of tket_model, JIT-compiled. Unlike the Tket model it
# is not given backend_config (the trailing commented argument below is unused).
npy_model = CustomNumpyModel.from_diagrams(np.array(all_circuits).reshape(-1),
                                           use_jit=True)
                                           #backend_config=backend_config)

In [None]:
# SPSA-based trainer for the Tket model. a/c/A are SPSA gain constants
# (presumably in Spall's notation); A is set to 1% of the epoch count.
trainer = QuantumTrainer(
    tket_model,
    optimizer=SPSAOptimizer,
    optim_hyperparams=dict(a=0.05, c=0.06, A=0.01 * EPOCHS),
    loss_function=loss,
    evaluate_functions=eval_metrics,
    epochs=EPOCHS,
    seed=RANDOM_SEED,
    verbose="text",
)

In [None]:
# Directory the trainer writes its checkpoints to.
trainer.log_dir

In [None]:
# (qs, pn, d) = (2, 3, 1)
# Train tket_model; train/validation histories are appended to the trainer's lists.
trainer.fit(train_dataset, val_dataset, evaluation_step=1, logging_step=1)

In [None]:
# (qs, pn, d) = (2, 3, 1)
# NOTE(review): duplicate of the previous cell. Running it presumably continues
# training the same tket_model and appends another run to the trainer's history lists.
trainer.fit(train_dataset, val_dataset, evaluation_step=1, logging_step=1)

In [None]:
def plot_training_curves(trainer, every=1):
    """Plot the loss and accuracy histories recorded by a lambeq trainer.

    Parameters
    ----------
    trainer : QuantumTrainer
        Trainer whose ``train_epoch_costs``, ``train_results['acc']``,
        ``val_costs`` and ``val_results['acc']`` lists are plotted. These are
        appended once per epoch (validation: once per evaluation step).
    every : int, default 1
        Plot only every ``every``-th point. The x-axis keeps the true epoch
        index, so subsampling does not compress the time axis.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ((ax_tl, ax_tr), (ax_bl, ax_br)) = plt.subplots(2, 2,
                                                         sharex=True,
                                                         sharey="row",
                                                         figsize=(12, 6))

    ax_tl.set_title('Training set')
    ax_tr.set_title('Development set')
    ax_bl.set_xlabel('Epochs')
    ax_br.set_xlabel('Epochs')
    ax_bl.set_ylabel('Accuracy')
    ax_tl.set_ylabel('Loss')

    colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
    for ax, values in ((ax_tl, trainer.train_epoch_costs),
                       (ax_bl, trainer.train_results['acc']),
                       (ax_tr, trainer.val_costs),
                       (ax_br, trainer.val_results['acc'])):
        epochs = range(0, len(values), every)
        ax.plot(epochs, values[::every], color=next(colours))
    return fig


plot_training_curves(trainer, every=2);

# val_circuits/val_labels are the validation split; the held-out test split
# (test_circuits/test_labels) is not scored here.
val_acc = acc(tket_model(val_circuits), val_labels)
print('Validation accuracy:', val_acc.item())

In [None]:
# (qs, pn, d) = (1, 3, 2)
# NOTE(review): this label records a manual re-run with edited ansatz settings.
# As written, the `qs, pn, d = (2, 3, 1)` cell is unchanged and nothing here
# rebuilds the ansatz or tket_model, so this call presumably continues training
# the same model.
trainer.fit(train_dataset, val_dataset, evaluation_step=1, logging_step=1)

In [None]:
# Loss (top) and accuracy (bottom) per epoch, on the training split (left) and
# the validation split (right), for all fit() calls made so far on `trainer`.
fig, ((ax_tl, ax_tr), (ax_bl, ax_br)) = plt.subplots(2, 2,
                                                     sharex=True,
                                                     sharey="row",
                                                     figsize=(12, 6))

ax_tl.set_title('Training set')
ax_tr.set_title('Development set')
ax_bl.set_xlabel('Epochs')
ax_br.set_xlabel('Epochs')
ax_bl.set_ylabel('Accuracy')
ax_tl.set_ylabel('Loss')

colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
ax_tl.plot(trainer.train_epoch_costs, color=next(colours))
ax_bl.plot(trainer.train_results['acc'], color=next(colours))
ax_tr.plot(trainer.val_costs, color=next(colours))
ax_br.plot(trainer.val_results['acc'], color=next(colours))

# val_circuits/val_labels are the validation split, not the held-out test split.
val_acc = acc(tket_model(val_circuits), val_labels)
print('Validation accuracy:', val_acc.item())

In [None]:
# (qs, pn, d) = (1, 3, 1)
# NOTE(review): label from a manual re-run with edited ansatz settings; as
# written, the ansatz and tket_model are not rebuilt, so this presumably
# continues training the same model.
trainer.fit(train_dataset, val_dataset, evaluation_step=1, logging_step=1)

In [None]:
# Loss (top) and accuracy (bottom) per epoch, on the training split (left) and
# the validation split (right), for all fit() calls made so far on `trainer`.
fig, ((ax_tl, ax_tr), (ax_bl, ax_br)) = plt.subplots(2, 2,
                                                     sharex=True,
                                                     sharey="row",
                                                     figsize=(12, 6))

ax_tl.set_title('Training set')
ax_tr.set_title('Development set')
ax_bl.set_xlabel('Epochs')
ax_br.set_xlabel('Epochs')
ax_bl.set_ylabel('Accuracy')
ax_tl.set_ylabel('Loss')

colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
ax_tl.plot(trainer.train_epoch_costs, color=next(colours))
ax_bl.plot(trainer.train_results['acc'], color=next(colours))
ax_tr.plot(trainer.val_costs, color=next(colours))
ax_br.plot(trainer.val_results['acc'], color=next(colours))

# val_circuits/val_labels are the validation split, not the held-out test split.
val_acc = acc(tket_model(val_circuits), val_labels)
print('Validation accuracy:', val_acc.item())

In [None]:
# (qs, pn, d) = (1, 1, 2)
# NOTE(review): label from a manual re-run with edited ansatz settings; as
# written, the ansatz and tket_model are not rebuilt, so this presumably
# continues training the same model.
trainer.fit(train_dataset, val_dataset, evaluation_step=1, logging_step=1)

In [None]:
# Loss (top) and accuracy (bottom) per epoch, on the training split (left) and
# the validation split (right), for all fit() calls made so far on `trainer`.
fig, ((ax_tl, ax_tr), (ax_bl, ax_br)) = plt.subplots(2, 2,
                                                     sharex=True,
                                                     sharey="row",
                                                     figsize=(12, 6))

ax_tl.set_title('Training set')
ax_tr.set_title('Development set')
ax_bl.set_xlabel('Epochs')
ax_br.set_xlabel('Epochs')
ax_bl.set_ylabel('Accuracy')
ax_tl.set_ylabel('Loss')

colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
ax_tl.plot(trainer.train_epoch_costs, color=next(colours))
ax_bl.plot(trainer.train_results['acc'], color=next(colours))
ax_tr.plot(trainer.val_costs, color=next(colours))
ax_br.plot(trainer.val_results['acc'], color=next(colours))

# val_circuits/val_labels are the validation split, not the held-out test split.
val_acc = acc(tket_model(val_circuits), val_labels)
print('Validation accuracy:', val_acc.item())

In [None]:
# (qs, pn, d) = (1, 1, 1)
# NOTE(review): label from a manual re-run with edited ansatz settings; as
# written, the ansatz and tket_model are not rebuilt, so this presumably
# continues training the same model.
trainer.fit(train_dataset, val_dataset, evaluation_step=1, logging_step=1)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# LovelyPlots style sheets; "use_tex" presumably needs a working LaTeX installation.
plt.style.use(["ipynb", "use_tex", "colors10-ls"])

In [None]:
# Loss (top) and accuracy (bottom) per epoch, on the training split (left) and
# the validation split (right), for all fit() calls made so far on `trainer`.
fig, ((ax_tl, ax_tr), (ax_bl, ax_br)) = plt.subplots(2, 2,
                                                     sharex=True,
                                                     sharey="row",
                                                     figsize=(12, 6))

ax_tl.set_title('Training set')
ax_tr.set_title('Development set')
ax_bl.set_xlabel('Epochs')
ax_br.set_xlabel('Epochs')
ax_bl.set_ylabel('Accuracy')
ax_tl.set_ylabel('Loss')

colours = iter(plt.rcParams['axes.prop_cycle'].by_key()['color'])
ax_tl.plot(trainer.train_epoch_costs, color=next(colours))
ax_bl.plot(trainer.train_results['acc'], color=next(colours))
ax_tr.plot(trainer.val_costs, color=next(colours))
ax_br.plot(trainer.val_results['acc'], color=next(colours))

# val_circuits/val_labels are the validation split, not the held-out test split.
val_acc = acc(tket_model(val_circuits), val_labels)
print('Validation accuracy:', val_acc.item())

In [None]:
# Where fit() saved its per-epoch checkpoints (presumably save_checkpoint(..., log_dir)).
log_dir = trainer.log_dir
log_dir

In [None]:
# Save the current tket_model state to ./checkpoints (presumably a lambeq checkpoint file).
tket_model.make_checkpoint("./checkpoints")

### Aside: What happens during the forward pass?

In [None]:
from discopy.quantum import Circuit

In [None]:
# First sentence pair: an array holding its [s1, s2] circuits.
all_circuits[0]

In [None]:
# all_diagrams[0] is an [s1, s2] pair (an array); draw the first sentence's diagram.
all_diagrams[0][0].draw(figsize=(16, 9), fontsize=12)

In [None]:
# all_circuits[0] is an [s1, s2] pair (an array); draw the first sentence's circuit.
all_circuits[0][0].draw(figsize=(16, 9), fontsize=12)

In [None]:
# pytket rendering of the first sentence's circuit (element 0 of the first pair).
render_circuit_jupyter(all_circuits[0][0].to_tk())

In [None]:
# Reproduce the Tket forward pass step by step on one circuit: `model` is only
# defined further down (and is a NumPy model), so use tket_model, and pass a
# single circuit rather than the [s1, s2] pair.
lambdified_diagram = tket_model._make_lambda(all_circuits[0][0])

In [None]:
# A callable: it is applied to the model weights below.
type(lambdified_diagram)

In [None]:
# One weight per symbol collected from all circuits.
tket_model.weights.shape

In [None]:
# Substitute the weights into the circuit and run it on the configured backend.
# A single circuit is passed (no `*` unpacking), so a single result comes back.
tensors = Circuit.eval(
    lambdified_diagram(*tket_model.weights),
    **tket_model.backend_config,
    seed=tket_model._randint(),
)

In [None]:
# The circuit with concrete weight values substituted for its symbols.
lambdified_diagram(*tket_model.weights)

In [None]:
# Raw (not yet normalised) result returned by Circuit.eval.
tensors

In [None]:
# Evaluating one circuit yields one result (not a list), so normalise its
# array directly; keep a leading batch axis of size 1.
np.array([tket_model._normalise_vector(tensors.array)])

## Questions/Notes

- What is getting measured?
- Domain = input, codomain = output, box = function
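- The model must receive __all__ circuits at initialisation, including those of every data split, presumably because its trainable weights are created from the symbols appearing in the diagrams it is given. This is similar to an embedding layer whose vocabulary must be fixed when it is constructed.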

### Debugging

In [None]:
# NumPy model used for debugging, built the same way as npy_model. It is not
# given backend_config (the trailing commented argument below is unused).
model = CustomNumpyModel.from_diagrams(np.array(all_circuits).reshape(-1),
                                       use_jit=True,)
                                       # backend_config=backend_config)

In [None]:
# Customize QuantumTrainer
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping
from datetime import datetime
from math import ceil
import os
import random
import socket
import sys
from typing import Any, Optional, Union
from typing import TYPE_CHECKING

from discopy import Tensor
from tqdm.notebook import tqdm, trange

if TYPE_CHECKING:
    from torch.utils.tensorboard import SummaryWriter

from lambeq.core.globals import VerbosityLevel
from lambeq.training.checkpoint import Checkpoint
from lambeq.training.dataset import Dataset
from lambeq.training.model import Model


def _import_tensorboard_writer() -> None:
    global SummaryWriter
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:  # pragma: no cover
        raise ImportError('tensorboard not found. Please install it using '
                          '`pip install tensorboard`.')


# Type alias for filesystem paths given as str or os.PathLike[str].
# NOTE(review): not referenced anywhere else in this notebook.
_StrPathT = Union[str, 'os.PathLike[str]']


class CustomQuantumTrainer(QuantumTrainer):
    """QuantumTrainer with an instrumented copy of the ``fit`` loop for debugging.

    The loop keeps the epoch / batch / validation / checkpoint structure, but
    can also print diagnostics: validation batch counts, raw validation
    batches, validation predictions and losses, and the checkpoint dictionary
    saved each epoch. These diagnostics always go to stdout and can be turned
    off with ``debug = False``. They are deliberately *not* routed through the
    TensorBoard writer, whose ``add_scalar`` expects ``(tag, scalar, step)``
    and cannot log batches, arrays or dicts.
    """

    # Print the extra diagnostics described above.
    debug: bool = True

    def fit(self,
            train_dataset: Dataset,
            val_dataset: Optional[Dataset] = None,
            evaluation_step: int = 1,
            logging_step: int = 1) -> None:
        """Fit the model on the training data and, optionally,
        evaluate it on the validation data.

        Parameters
        ----------
        train_dataset : :py:class:`Dataset`
            Dataset used for training.
        val_dataset : :py:class:`Dataset`, optional
            Validation dataset.
        evaluation_step : int, default: 1
            Sets the intervals at which the metrics are evaluated on the
            validation dataset.
        logging_step : int, default: 1
            Sets the intervals at which the training statistics are
            printed if `verbose = 'text'` (otherwise ignored).

        """
        if self.from_checkpoint:
            self._load_extra_chkpoint_info(self.checkpoint)

        def writer_helper(*args: Any) -> None:
            # Scalar metrics: TensorBoard when enabled, otherwise stdout.
            if self.use_tensorboard:
                self.writer.add_scalar(*args)
            else:
                print(*args)

        def debug_helper(*args: Any) -> None:
            # Free-form diagnostics: stdout only, never sent to TensorBoard.
            if self.debug:
                print(*args)

        # initialise progress bar
        step = self.start_step
        batches_per_epoch = ceil(len(train_dataset)/train_dataset.batch_size)
        status_bar = tqdm(total=float('inf'),
                          bar_format='{desc}',
                          desc=self._generate_stat_report(),
                          disable=(
                                self.verbose != VerbosityLevel.PROGRESS.value),
                          leave=True,
                          position=0)

        # start training loop
        for epoch in trange(self.start_epoch,
                            self.epochs,
                            desc='Epoch',
                            disable=(
                                self.verbose != VerbosityLevel.PROGRESS.value),
                            leave=False,
                            position=1):
            train_loss = 0.0
            with Tensor.backend(self.backend):
                for batch in tqdm(train_dataset,
                                  desc='Batch',
                                  total=batches_per_epoch,
                                  disable=(self.verbose
                                           != VerbosityLevel.PROGRESS.value),
                                  leave=False,
                                  position=2):
                    step += 1
                    x, y_label = batch
                    y_hat, loss = self.training_step(batch)
                    if (self.evaluate_on_train
                            and self.evaluate_functions is not None):
                        for metr, func in self.evaluate_functions.items():
                            res = func(y_hat, y_label)
                            metric = self._train_results_epoch[metr]
                            metric.append(len(x) * res)
                    # Weight by batch size so the epoch loss is a per-sample mean.
                    train_loss += len(batch[0]) * loss
                    writer_helper('train/step_loss', loss, step)
                    status_bar.set_description(
                            self._generate_stat_report(
                                train_loss=loss,
                                val_loss=(self.val_costs[-1] if self.val_costs
                                          else None)))
            train_loss /= len(train_dataset)
            self.train_epoch_costs.append(train_loss)
            writer_helper('train/epoch_loss', train_loss, epoch + 1)

            # evaluate on train
            if (self.evaluate_on_train
                    and self.evaluate_functions is not None):
                for name in self._train_results_epoch:
                    self.train_results[name].append(
                        sum(self._train_results_epoch[name])/len(train_dataset)
                    )
                    self._train_results_epoch[name] = []  # reset
                    writer_helper(
                        f'train/{name}', self.train_results[name][-1],
                        epoch+1)
                    if self.verbose == VerbosityLevel.PROGRESS.value:
                        status_bar.set_description(
                                self._generate_stat_report(
                                    train_loss=train_loss,
                                    val_loss=(self.val_costs[-1]
                                              if self.val_costs else None)))

            # evaluate metrics on validation data
            if val_dataset is not None:
                if epoch % evaluation_step == 0:
                    val_loss = 0.0
                    batches_per_validation = ceil(len(val_dataset)
                                                  / val_dataset.batch_size)
                    debug_helper('batches_per_validation', batches_per_validation, len(val_dataset), val_dataset.batch_size)
                    with Tensor.backend(self.backend):
                        disable_tqdm = (self.verbose
                                        != VerbosityLevel.PROGRESS.value)
                        for v_batch in tqdm(val_dataset,
                                            desc='Validation batch',
                                            total=batches_per_validation,
                                            disable=disable_tqdm,
                                            leave=False,
                                            position=2):
                            debug_helper("***", v_batch)
                            x_val, y_label_val = v_batch
                            debug_helper("***", x_val, y_label_val)
                            y_hat_val, cur_loss = self.validation_step(v_batch)
                            debug_helper("***", y_hat_val, cur_loss)
                            val_loss += cur_loss * len(x_val)
                            if self.evaluate_functions is not None:
                                for metr, func in (
                                        self.evaluate_functions.items()):
                                    res = func(y_hat_val, y_label_val)
                                    self._val_results_epoch[metr].append(
                                        len(x_val)*res)
                            status_bar.set_description(
                                    self._generate_stat_report(
                                        train_loss=train_loss,
                                        val_loss=val_loss))
                        val_loss /= len(val_dataset)
                        self.val_costs.append(val_loss)
                        status_bar.set_description(
                                self._generate_stat_report(
                                    train_loss=train_loss,
                                    val_loss=val_loss))
                        writer_helper('val/loss', val_loss, epoch+1)

                    if self.evaluate_functions is not None:
                        for name in self._val_results_epoch:
                            self.val_results[name].append(
                                sum(self._val_results_epoch[name])
                                / len(val_dataset))
                            self._val_results_epoch[name] = []  # reset
                            writer_helper(
                                f'val/{name}', self.val_results[name][-1],
                                epoch + 1)
                            status_bar.set_description(
                                    self._generate_stat_report(
                                        train_loss=train_loss,
                                        val_loss=val_loss))
            # save checkpoint info
            save_dict = {'epoch': epoch+1,
                         'model_weights': self.model.weights,
                         'model_symbols': self.model.symbols,
                         'train_costs': self.train_costs,
                         'train_epoch_costs': self.train_epoch_costs,
                         'train_results': self.train_results,
                         'val_costs': self.val_costs,
                         'val_results': self.val_results,
                         'random_state': random.getstate(),
                         'step': step}
            debug_helper(f"save_dict: {save_dict}")
            self.save_checkpoint(save_dict, self.log_dir)
            if self.verbose == VerbosityLevel.TEXT.value:
                if epoch == 0 or (epoch+1) % logging_step == 0:
                    space = (len(str(self.epochs))-len(str(epoch+1)) + 2) * ' '
                    prefix = f'Epoch {epoch+1}:' + space
                    print(prefix + self._generate_stat_report(
                            train_loss=train_loss,
                            val_loss=(self.val_costs[-1] if self.val_costs
                                      else None)),
                          file=sys.stderr)
        status_bar.close()
        if self.verbose == VerbosityLevel.TEXT.value:
            print('\nTraining completed!', file=sys.stderr)