In [1]:
#export 
from pathlib import Path

import pandas as pd
import psutil

from loop.callbacks import Callback, Order

In [2]:
#export
MB = 1024 ** 2
GB = 1024 * MB

In [3]:
#export
class MemoryUsage(Callback):
    """A debugging callback used to track the amount of CPU memory used during training."""
    
    order = Order.Internal()
    
    def __init__(self, filename: str='memory.csv', units: int=GB):
        self.filename = filename
        self.units = units
    
    def training_started(self, **kwargs):
        self.iter = 0
        self._stream = self.open()
        self.write('index,mem_percent,mem_free,mem_available,mem_used')
        
    def training_ended(self, **kwargs):
        self.close()
    
    def batch_ended(self, **kwargs):
        self.iter += 1
        mem = psutil.virtual_memory()
        record = [self.iter, mem.percent, mem.available, mem.used]
        self.write(','.join([str(x) for x in record]))
        
    def plot(self, **fig_kwargs):
        mem = pd.read_csv(self.filename)
        index = mem.columns.str.startswith('mem')
        mem[mem.columns[index]] /= self.units
        f, ax = plt.subplots(2, 1, **fig_kwargs)
        ax1, ax2 = ax.flat
        unit_name = ('GB' if self.units == GB else 
                     'MB' if self.units == MB else 
                     '')
        self.plot_memory_percentage(ax1, mem)
        self.plot_memory_usage(ax2, mem, unit_name)
        
    @staticmethod
    def plot_memory_percentage(ax, mem):
        mem.plot(x='index', y='mem_percent', ax=ax)
        ax.set_title('Memory usage during training', fontsize=20)
        ax.set_xlabel('Batch Index', fontsize=16)
        ax.set_ylabel('Percentage', fontsize=16)

    @staticmethod
    def plot_memory_usage(ax, mem, y_label):
        mem.plot(x='index', y=['mem_available', 'mem_used'], ax=ax)
        ax.set_xlabel('Batch Index', fontsize=16)
        ax.set_ylabel(y_label, fontsize=16)
    
    def open(self):
        self.close()
        return Path(self.filename).open('w')
        
    def write(self, msg):
        self._stream.write(msg + '\n')
        self._stream.flush()
    
    def close(self):
        if hasattr(self, '_stream') and self._stream is not None:
            self._stream.close()
            self._stream.flush()
            self._stream = None

In [4]:
from pathlib import Path

from loop.training import Loop
from loop.metrics import accuracy
from loop.testing import get_mnist
from loop.modules import fc_network


trn_ds, val_ds = get_mnist(flat=True)
loop = Loop(fc_network(784, [100, 10]), cbs=[MemoryUsage()])
loop.fit_datasets(trn_ds, val_ds, epochs=3, batch_size=100)
loop.cb['memory_usage'].plot()

Epoch:    1 | train_loss=-52784.4358, valid_loss=-65107.9847
Epoch:    2 | train_loss=-228083.7342, valid_loss=-234273.1725
Epoch:    3 | train_loss=-503976.5155, valid_loss=-510168.7209


Error! Training loop was interupted with un-expected exception
--------------------------------------------------------------
  File "/home/ck/code/loop/dev/loop/training.py", line 80, in train
    self.cb.training_ended(phases=phases)

  File "/home/ck/code/loop/dev/loop/callbacks.py", line 197, in training_ended
    def training_ended(self, **kwargs): self('training_ended', **kwargs)

  File "/home/ck/code/loop/dev/loop/callbacks.py", line 231, in __call__
    method(**kwargs)

  File "<ipython-input-3-fcca88589cf8>", line 17, in training_ended
    self.close()

  File "<ipython-input-3-fcca88589cf8>", line 61, in close
    self._stream.flush()

I/O operation on closed file.
--------------------------------------------------------------


NameError: name 'plt' is not defined