In [1]:

import copy
import random
from itertools import count

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

%matplotlib qt

In [2]:
class train_visual:
    """Matplotlib dashboard for training progress across several trials.

    The left panel shows the loss curve of every trial (earlier trials in blue,
    the active one in red). The right column has one panel per label with
    precision and recall curves (earlier trials green/blue, active one yellow/purple).

    Layout of ``self.data``:
        'losses'   -- trials x epochs (per trial)
        'label_pr' -- trials x labels x 2 (precision, recall) x epochs

    ``cur_trial`` is the 1-based number of the active trial, so the active
    trial's entries live at index ``cur_trial - 1``.
    """

    def __init__(self, epochs, labels, loss_rng):
        """
        Args:
            epochs: upper x-axis limit (in epochs) for the plots.
            labels: label names; one precision/recall panel is drawn per label.
            loss_rng: upper y-axis limit for the loss panel.
        """
        self.epochs = epochs
        self.cur_trial = 1
        self.labels = labels
        self.num_labels = len(labels)
        self.loss_rng = loss_rng
        # Seed one empty trial so that index ``cur_trial - 1 == 0`` exists.
        # Previously both lists started empty and plot()/report() raised IndexError.
        self.data = {
            'losses': [[]],
            'label_pr': [self._empty_label_pr()],
        }

    def _empty_label_pr(self):
        """Return an empty per-label [precision, recall] history for one trial."""
        return [[[], []] for _ in range(self.num_labels)]

    def set_data(self, inp, trial=None):
        """Load a previously recorded history.

        Args:
            inp: dict with 'losses' and 'label_pr' in the layout described on the class.
            trial: 1-based number of the active trial. Defaults to the number of
                trials in ``inp`` (i.e. the last one).
        """
        # Deep-copy so later report()/new_trial() calls don't silently mutate the
        # caller's dict, which other notebook cells may still read.
        self.data = copy.deepcopy(inp)
        self.cur_trial = len(inp['losses']) if trial is None else trial

    def _record(self, loss, lables_pr):
        """Append one epoch of metrics to the active trial without plotting."""
        trial_idx = self.cur_trial - 1
        self.data['losses'][trial_idx].append(loss)
        for label in range(self.num_labels):
            for pr in range(2):
                self.data['label_pr'][trial_idx][label][pr].append(lables_pr[label][pr])

    def report(self, loss, lables_pr):
        """Record one epoch for the active trial and redraw the current figure.

        Args:
            loss: loss value for this epoch.
            lables_pr: per-label ``[precision, recall]`` pairs (labels x 2).
                The misspelled name is kept so keyword callers keep working.
        """
        self._record(loss, lables_pr)
        plt.clf()
        self.plot()

    def new_trial(self):
        """Append an empty trial and make it the active one."""
        self.data['losses'].append([])
        self.data['label_pr'].append(self._empty_label_pr())
        self.cur_trial += 1

    def start(self):
        """Open a new 12x15-inch figure and draw the current history on it."""
        plt.figure(figsize=(12, 15))
        self.plot()

    def plot(self):
        """Draw the loss panel and one precision/recall panel per label on the current figure."""
        fig = plt.gcf()
        # One right-hand row per label (previously hard-coded to 7); keep at least one row.
        n_rows = max(self.num_labels, 1)
        gs = fig.add_gridspec(n_rows, 2, width_ratios=[0.8, 1], left=0.05, right=0.95,
                              height_ratios=[2] * n_rows, hspace=2)
        last = self.cur_trial - 1

        # Left: loss curves of all trials, spanning every row.
        ax1 = fig.add_subplot(gs[:, 0])
        ax1.set_xlabel('Epochs', labelpad=5)
        ax1.set_ylabel('Loss', labelpad=5)
        ax1.set_xlim(0, self.epochs)
        ax1.set_ylim(0, self.loss_rng)

        for trial in range(self.cur_trial):
            y = self.data['losses'][trial]
            ax1.plot(list(range(len(y))), y, color="red" if trial == last else "blue")

        # Right: one precision/recall panel per label, stacked vertically.
        for label in range(self.num_labels):
            ax2 = fig.add_subplot(gs[label, 1])
            ax2.set_xlabel(f'Epochs\n{self.labels[label]}', labelpad=0)
            ax2.set_ylabel('Precision\nrecall', labelpad=0)
            # Same epoch range as the loss panel (was a hard-coded 10).
            ax2.set_xlim(0, self.epochs)
            ax2.set_ylim(0, 1.0)

            for trial in range(self.cur_trial):
                is_last = trial == last
                p_color = "yellow" if is_last else "green"
                r_color = "purple" if is_last else "blue"
                precision_y = self.data['label_pr'][trial][label][0]
                recall_y = self.data['label_pr'][trial][label][1]
                x = list(range(len(recall_y)))
                ax2.plot(x, precision_y, color=p_color)
                ax2.plot(x, recall_y, color=r_color)

        plt.show()
    

In [3]:
# Label names; one precision/recall panel is drawn per entry.
labels = [
    "background",
    "apo-ferritin (easy)",
    "beta-amylase (impossible, NS)",
    "beta-galactosidase (hard)",
    "ribosome (easy)",
    "thyroglobulin (hard)",
    "virus-like-particle (easy)",
]

# Demo history of two trials:
#   'losses'   -> trials x epochs (per trial)
#   'label_pr' -> trials x labels x 2 (p & r) x epoch count (per trial)
# Every label shares the same toy precision/recall values within a trial.
data = {
    'losses': [
        [0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.98, 0.9, 0.8, 0.6, 0.5, 0.4],
        [0.99, 0.97, 0.99, 0.99, 0.99, 0.99, 0.98, 0.9, 0.85, 0.65, 0.55, 0.45, 0.35],
    ],
    'label_pr': [
        [[[0.0, 0.1], [0.1, 0.2]] for _label in labels],
        [[[0.3, 0.4], [0.4, 0.5]] for _label in labels],
    ],
}

In [4]:
# Dashboard for a 50-epoch run; loss axis spans [0, 1.0].
test_visual = train_visual(epochs=50, labels=labels, loss_rng=1.0)

In [5]:
# Load the two demo trials and mark trial 2 as the active one.
test_visual.set_data(inp=data, trial=2)

In [6]:
# Open a 12x15-inch figure and draw the history loaded by set_data().
test_visual.start()

In [9]:
# One epoch for the active trial: loss 0.2 and (precision, recall) = (0.8, 0.9) for every label.
test_visual.report(0.2, [[0.8, 0.9] for _label in range(len(labels))])


In [8]:
# Append an empty trial to the history and make it the active one.
# NOTE(review): the execution counts show this cell (In [8]) ran before the
# report() cell above (In [9]); Restart & Run All executes them in file order.
test_visual.new_trial()

In [64]:
# Stand-alone prototype of `train_visual.plot`, drawing directly from the
# notebook-level `data` and `labels` defined in the data cell above
# (the duplicate `labels` list that used to live here was identical and removed).
trial_cnt = 2
num_labels = len(labels)

fig = plt.figure(figsize=(12, 15))

gs = fig.add_gridspec(num_labels, 2, width_ratios=[0.8, 1], left=0.05, right=0.95,
                      height_ratios=[2] * num_labels, hspace=2)

# Left: loss curve of every trial; the most recent trial is highlighted in red.
ax1 = fig.add_subplot(gs[:, 0])
ax1.set_xlabel('Epochs', labelpad=5)
ax1.set_ylabel('Loss', labelpad=5)
ax1.set_xlim(0, 25)
ax1.set_ylim(0, 1.0)

for trial in range(trial_cnt):
    color = "red" if trial == trial_cnt - 1 else "blue"
    y = data['losses'][trial]
    # `color` was previously computed but never passed, so the default cycle was used.
    ax1.plot(list(range(len(y))), y, color=color)

# Right: one precision/recall panel per label, stacked vertically.
for label in range(num_labels):
    ax2 = fig.add_subplot(gs[label, 1])
    ax2.set_xlabel(f'Epochs\n{labels[label]}', labelpad=0)
    ax2.set_ylabel('Precision\nrecall', labelpad=0)
    ax2.set_xlim(0, 10)
    ax2.set_ylim(0, 1.0)

    for trial in range(trial_cnt):
        is_last = trial == trial_cnt - 1
        p_color = "yellow" if is_last else "green"
        r_color = "purple" if is_last else "blue"
        precision_y = data['label_pr'][trial][label][0]
        recall_y = data['label_pr'][trial][label][1]
        x = list(range(len(recall_y)))
        ax2.plot(x, precision_y, color=p_color)
        ax2.plot(x, recall_y, color=r_color)

plt.show()

In [44]:
# Scratch cell: the original `range(10).to` raised AttributeError (range objects
# have no `.to` attribute). Evaluating the range itself displays `range(0, 10)`.
range(10)