In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import pandas as pd
import seaborn as sns

In [None]:
# Plot styling.
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases reject the old names. Keep each name if this install still
# ships it, otherwise use the renamed equivalent.
_styles = [
    name if name in plt.style.available
    else name.replace('seaborn-', 'seaborn-v0_8-', 1)
    for name in ('seaborn-white', 'seaborn-paper')
]
plt.style.use(_styles)
plt.rc('font', family='serif')
sns.set_palette(['#9e0059', '#6da7de', '#ee266d', '#dee000', '#eb861e'])
sns.set_context('paper', font_scale=1.3)    # Single-column figure.

In [None]:
# Per-epoch training logs (CSV) for the full model and the two ablations,
# read in the same order as the names they are bound to.
train_all, train_no_ref, train_no_fragment = (
    pd.read_csv(log_path)
    for log_path in ('train_all.log', 'train_no_ref.log',
                     'train_no_fragment.log')
)

In [None]:
# Validation-loss curves for the full model and each feature ablation.
# Raw per-epoch values are drawn as light markers. A centred rolling mean
# (window of 5, min_periods=1) is drawn as a darker line with the legend label.
width = 7
height = width / 1.618    # Golden-ratio aspect.
max_epoch = 20            # Right edge of the x-axis.
fig, ax = plt.subplots(figsize=(width, height))

runs = [
    (train_all, 'GLEAMS', ('#9e0059', '#e8cad3')),
    (train_no_ref, 'GLEAMS minus ref spectra features',
     ('#6da7de', '#dbe5f1')),
    (train_no_fragment, 'GLEAMS minus fragment features',
     ('#ee266d', '#fad2d8')),
]
for data, label, (line_color, marker_color) in runs:
    epochs = data['epoch'] + 1
    loss = data['val_loss_2']
    # Smooth over the full series first, so the centred window at the right
    # edge still sees later epochs. Then drop points past the axis limit:
    # the markers use clip_on=False and would otherwise be drawn outside the
    # axes and stretch the y autoscaling.
    loss_smooth = loss.rolling(window=5, min_periods=1, center=True).mean()
    visible = epochs <= max_epoch
    ax.scatter(epochs[visible], loss[visible], marker='o', c=marker_color,
               clip_on=False)
    ax.plot(epochs[visible], loss_smooth[visible], label=label,
            c=line_color)

ax.set_xlim(0, max_epoch)
# NOTE(review): x values come from the 'epoch' column but the axis is
# labelled 'Iteration' — confirm the intended label.
ax.set_xlabel('Iteration')
ax.set_ylabel('Validation loss')

ax.xaxis.set_major_locator(mticker.MultipleLocator(5))

ax.legend()

sns.despine(ax=ax)

fig.savefig('ablation.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)