# Problem Statement

The `fast.ai` library has a callback that tracks the history of training metrics. However, that history is only reported to the console or a Jupyter widget, and there is no callback that stores the results in CSV format. In this notebook, the author proposes a callback similar to [CSVLogger from the Keras library](https://github.com/keras-team/keras/blob/master/keras/callbacks.py#L1135) that saves the tracked metrics to a persistent file.

In [1]:
%reload_ext autoreload

In [2]:
%autoreload 2

In [47]:
from fastai import *
from fastai.torch_core import *
from fastai.vision import *
from fastai.metrics import *
from torchvision.models import resnet18

In [53]:
@dataclass
class CSVLogger(LearnerCallback):
    "A `LearnerCallback` that saves the history of tracked metrics to a CSV file."
    filename:str='history.csv'

    def __post_init__(self):
        self.path = Path(self.filename)
        self.file = None

    @property
    def header(self):
        return self.learn.recorder.names
    
    def read_logged_file(self):
        return pd.read_csv(self.path)

    def on_train_begin(self, metrics_names:StrList, **kwargs:Any)->None:
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self.file = self.path.open('w')
        self.file.write(','.join(self.header) + '\n')

    def on_epoch_end(self, epoch:int, smooth_loss:Tensor, last_metrics:MetricsList, **kwargs:Any)->bool:
        self.write_stats([epoch, smooth_loss] + last_metrics)

    def on_train_end(self, **kwargs:Any)->None:
        self.file.flush()
        self.file.close()

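    def write_stats(self, stats:TensorOrNumList)->None:
        "Write one CSV row: ints (the epoch) as-is, other values with six decimals."
        # zip with the header so that at most one value per column is written
        stats = [str(stat) if isinstance(stat, int) else f'{stat:.6f}'
                 for _,stat in zip(self.header,stats)]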
        str_stats = ','.join(stats)
        self.file.write(str_stats + '\n')
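
The callback above follows a simple lifecycle: open the file and write the header when training begins, append one formatted row per epoch, and close the file when training ends. As a minimal sketch of that lifecycle outside `fast.ai`, the hypothetical `SimpleCSVLogger` class below uses only the standard library. The class name, its constructor arguments, and the sample numbers are assumptions made for illustration; unlike the callback above, it also flushes after every epoch.

```python
import csv
import tempfile
from pathlib import Path
from typing import Sequence, Union

Number = Union[int, float]


class SimpleCSVLogger:
    """Hypothetical, framework-free sketch of the begin / epoch-end / end lifecycle."""

    def __init__(self, filename: Union[str, Path], header: Sequence[str]) -> None:
        self.path = Path(filename)
        self.header = list(header)
        self.file = None
        self.writer = None

    def on_train_begin(self) -> None:
        # Create the parent directory and start a fresh file with a header row.
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self.file = self.path.open("w", newline="")
        self.writer = csv.writer(self.file, lineterminator="\n")
        self.writer.writerow(self.header)

    def on_epoch_end(self, stats: Sequence[Number]) -> None:
        # Integers (the epoch index) are written as-is, other values with six decimals.
        row = [str(s) if isinstance(s, int) else f"{float(s):.6f}" for s in stats]
        self.writer.writerow(row)
        self.file.flush()  # keeps earlier rows on disk if training is interrupted

    def on_train_end(self) -> None:
        self.file.close()


# Usage with made-up numbers, writing into a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    logger = SimpleCSVLogger(Path(tmp) / "logs" / "history.csv",
                             ["epoch", "train loss", "valid loss"])
    logger.on_train_begin()
    logger.on_epoch_end([1, 0.5, 0.25])
    logger.on_train_end()
    with logger.path.open(newline="") as f:
        rows = list(csv.reader(f))

print(rows)
```

Flushing on every epoch is a design choice of this sketch: it trades a little I/O for a log that stays readable if training stops early.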

## Example

Let's train an MNIST classifier and track its metrics.

In [40]:
path = untar_data(URLs.MNIST_TINY)

Downloading http://files.fast.ai/data/examples/mnist_tiny



In [48]:
data = ImageDataBunch.from_folder(path)

In [57]:
learn = ConvLearner(data, resnet18, metrics=[accuracy, error_rate])

In [58]:
cb = CSVLogger(learn)

In [59]:
learn.fit(3, callbacks=[cb])


Total time: 00:02
epoch  train loss  valid loss  accuracy  error_rate
1      0.551399    0.374805    0.828326  0.171674    (00:00)
2      0.354117    0.291277    0.882690  0.117310    (00:00)
3      0.287861    0.302602    0.882690  0.117310    (00:00)



In [61]:
log_df = cb.read_logged_file()

In [None]:
assert cb.path.exists()
assert not log_df.empty
# `log_df.columns` is a pandas Index; convert it to a list so `==` yields a single bool
assert learn.recorder.names == list(log_df.columns)

In [62]:
learn.recorder.names

['epoch', 'train loss', 'valid loss', 'accuracy', 'error_rate']
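
Because `log_df.columns` is a pandas `Index` rather than a list, comparing it with a list using `==` is element-wise and does not produce a single boolean. As a minimal sketch, the same header check can be done on plain lists with the standard library alone. The helper `read_csv_header` and the stand-in file are assumptions for illustration; in the notebook the path would be `cb.path`.

```python
import csv
import tempfile
from pathlib import Path

# Column names as printed by `learn.recorder.names` above.
expected_names = ['epoch', 'train loss', 'valid loss', 'accuracy', 'error_rate']


def read_csv_header(path):
    """Return the first row of a CSV file as a plain list of strings."""
    with Path(path).open(newline='') as f:
        return next(csv.reader(f))


# Stand-in file containing only a header row (hypothetical; replaces `cb.path` here).
with tempfile.TemporaryDirectory() as tmp:
    stand_in = Path(tmp) / 'history.csv'
    stand_in.write_text(','.join(expected_names) + '\n')
    header = read_csv_header(stand_in)

# Comparing two lists gives a single boolean, so it is safe inside `assert`.
assert header == expected_names
```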