In [None]:
# import module needed
import copy
import datetime
import matplotlib.markers as plt_markers


# define helper class
class CoachTeam(object):
    """Run repeated train / validate / test rounds for one network and record the results.

    Every round of :meth:`pipeline_helper` starts from a fresh deep copy of the
    network given to the constructor and a freshly built optimizer, trains and
    tests it, writes the recorded metrics to CSV files and saves the curves with
    ``savefig``.

    Besides the modules imported in this cell, the class relies on names defined
    elsewhere in the notebook: ``os``, ``np``, ``pd``, ``plt``, ``Variable``,
    ``train_net``, ``evaludate_acc``, ``dataloader_test`` and ``use_cuda``.
    """

    def __init__(self, net, optimizer, loss_func,
                 stage_names=('train', 'valid', 'test'), ind_names=('loss', 'acc')):
        """Store deep copies of ``net`` plus the training configuration.

        Args:
            net: network to experiment with; must have a ``name`` attribute.
                It is deep-copied, so the caller's object is never trained.
            optimizer: factory called as ``optimizer(net.parameters())`` to build
                a new optimizer for every round (e.g. a ``functools.partial``).
            loss_func: callable ``loss_func(predictions, targets)`` returning the loss.
            stage_names: stages for which results are kept and written to CSV.
            ind_names: indicators (metrics) recorded for every stage.

        Raises:
            AssertionError: if ``net`` has no ``name`` attribute.
        """
        # Validate before paying for the deep copies.
        assert hasattr(net, 'name'), ('The nn object you want to use should have '
                                      'an attribute called `name`.')
        self.net_base = copy.deepcopy(net)
        self.net = copy.deepcopy(self.net_base)
        self.optimizer = optimizer  # a factory: a new optimizer is built in each round
        self.loss_func = loss_func

        # Fresh lists: instances never share a default list or alias the caller's.
        self.stage_names = list(stage_names)
        self.ind_names = list(ind_names)
        self.results = self._empty_results()

    def _empty_results(self):
        """Return a fresh ``{indicator: {stage: []}}`` results container."""
        return {ind: {stage: [] for stage in self.stage_names} for ind in self.ind_names}

    @staticmethod
    def _to_scalar(value):
        """Reduce a numpy array to its first element (0-d arrays included); pass others through."""
        if isinstance(value, np.ndarray):
            return value[0] if value.ndim else value[()]
        return value

    def pipeline_helper(self, n_iters, digit_rangesss, sum_epochs=15, interval_valid=50,
                        dirname_prefix='outputs'):
        """Run ``n_iters`` independent rounds of train -> test -> save CSV -> plot.

        Args:
            n_iters: number of rounds; each one starts from a fresh copy of the
                base network and a fresh optimizer.
            digit_rangesss: sequence of digit ranges handed to ``train_net`` one
                at a time (see :meth:`train_helper`).
            sum_epochs: epoch budget per round, split across ``digit_rangesss``
                with integer division.
            interval_valid: forwarded to ``train_net``.
            dirname_prefix: prefix of the timestamped output directory, which is
                created if it does not exist yet.
        """
        self.dirname = dirname_prefix + '_' + datetime.datetime.now().strftime('%y%m%d-%H%M%S')
        os.makedirs(self.dirname, exist_ok=True)

        for i in range(n_iters):
            # Reset the network and the recorded results for each experiment.
            self.net = copy.deepcopy(self.net_base)
            self.results = self._empty_results()

            self.train_helper(self.net, self.loss_func, self.optimizer(self.net.parameters()),
                              digit_rangesss, sum_epochs=sum_epochs, interval_valid=interval_valid)
            self.test_helper(self.net)
            self.save(mode='file', netname=self.net.name, suffix=str(i))

            # Count the points actually recorded. A formula based on sum_epochs
            # overcounts whenever sum_epochs is not divisible by the number of
            # digit ranges, and the x-axis length would then not match the data.
            num_valid = len(self.results[self.ind_names[0]].get('train', []))
            self.plot_single(self.net.name, plot_test=False, suffix=str(i), num_valid=num_valid)
            self.plot_single(self.net.name, plot_test=True, suffix=str(i))
            print('Round {}: OK'.format(i))
        print('==================================================')
        # TODO: aggregate all rounds with plot_whole once its known bugs are fixed
        # (its draft is kept commented out after this class).
        print('Finish all the pipeline.')

    def train_helper(self, net, loss_func, optimizer, digit_rangesss, sum_epochs=15, interval_valid=50):
        """Train (and validate) ``net`` on each digit range in turn, storing results in RAM.

        ``train_net`` is called once per entry of ``digit_rangesss`` with
        ``sum_epochs // len(digit_rangesss)`` epochs, so any remainder of
        ``sum_epochs`` goes unused. Its first two return values are the loss and
        accuracy records, which :meth:`save` reads as ``record[stage]`` for the
        ``'train'`` and ``'valid'`` stages; the third is kept on ``self.n``.
        """
        for digit_range in digit_rangesss:
            raw_out = {ind: [] for ind in self.ind_names}
            raw_out['loss'], raw_out['acc'], self.n = train_net(
                net=net, loss_func=loss_func, optimizer=optimizer, digit_range=digit_range,
                sum_epochs=sum_epochs // len(digit_rangesss), interval_valid=interval_valid
            )

            # save in RAM: train_data and valid_data
            self.save(mode='dict-results', data=raw_out)
            print('Save data(train & validation): OK in RAM, for', digit_range)
        print('Train net: OK')

    def test_helper(self, net):
        """Evaluate ``net`` on every batch of ``dataloader_test`` and store the results in RAM.

        One loss value and one accuracy value are recorded per batch under the
        ``'test'`` stage. ``net`` is switched to eval mode and left there.
        """
        net.eval()

        raw_out = {ind: {'test': []} for ind in self.ind_names}
        for xs_t, ys_t in dataloader_test:
            if use_cuda:
                xs_t, ys_t = xs_t.cuda(), ys_t.cuda()
            xs_t, ys_t = Variable(xs_t), Variable(ys_t)
            preds_t = net(xs_t)
            # Use the loss configured on this instance, not a notebook global.
            loss_test_raw = self.loss_func(preds_t, ys_t)
            raw_out['loss']['test'].append(loss_test_raw.cpu().data.numpy() if use_cuda
                                           else loss_test_raw.data.numpy())
            raw_out['acc']['test'].append(evaludate_acc(preds_t, ys_t)[0])

        # save in RAM: test_data
        self.save(mode='dict-results', data=raw_out, stages=['test'])
        print('Save data(test): OK in RAM')

    def save(self, mode, data=None, netname=None, stages=('train', 'valid'), suffix=None):
        """Store results in RAM (``mode='dict-results'``) or write them to CSV (``mode='file'``).

        ``'dict-results'``: for every indicator in ``self.ind_names`` and every
        stage in ``stages``, append ``data[ind][stage]`` to
        ``self.results[ind][stage]``; numpy arrays are reduced to their first element.

        ``'file'``: for every stage in ``self.stage_names`` write
        ``<cwd>/<self.dirname>/<netname>-sheet-<stage>[-<suffix>].csv`` with one
        column per indicator; rows are appended if the file already exists.

        Raises:
            AssertionError: if ``data`` (RAM mode) or ``netname`` (file mode) is missing.
        """
        if 'dict-results' == mode:
            assert data is not None, 'Your should specify the data you want to save in RAM.'
            for ind in self.ind_names:
                for stage in stages:
                    items = [self._to_scalar(v) for v in data[ind][stage]]
                    self.results[ind].setdefault(stage, []).extend(items)

        elif 'file' == mode:
            assert netname is not None, 'You should specify the **NAME** of this neural network.'

            for stage in self.stage_names:
                frame = pd.DataFrame(data={ind: self.results[ind][stage] for ind in self.ind_names},
                                     columns=self.ind_names)
                if suffix:
                    basename = '{}-sheet-{}-{}.csv'.format(netname, stage, suffix)
                else:
                    basename = '{}-sheet-{}.csv'.format(netname, stage)
                filename = os.path.join(os.getcwd(), self.dirname, basename)

                # Append to an existing sheet instead of overwriting it.
                if os.path.isfile(filename):
                    frame = pd.concat([pd.read_csv(filename), frame])
                frame.to_csv(filename, index=False)

                print('Save data in file OK:', filename)

    def plot_single(self, netname, suffix=None, plot_test=True, num_valid=None):
        """Save the curves of the current round as an image in ``self.dirname``.

        Args:
            netname: network name used in the file name and plot titles.
            suffix: optional file-name suffix (e.g. the round index).
            plot_test: if true, plot the per-batch test results; otherwise plot
                the train and validation curves.
            num_valid: length of the train/validation x-axis; inferred from the
                recorded results when ``None``. Unused when ``plot_test`` is true.
        """
        # squeeze=False keeps a 2-D array of axes even with a single indicator.
        fig, axes_raw = plt.subplots(1, len(self.ind_names), figsize=(15, 5), squeeze=False)
        axes = dict(zip(self.ind_names, axes_raw[0]))
        try:
            if plot_test:
                for ind in self.ind_names:
                    values = self.results[ind]['test']
                    axes[ind].plot(np.arange(len(values)), values, label='{}_test'.format(ind))
                    axes[ind].set_xlabel('num_iter')
                    axes[ind].set_ylabel(ind)
                    axes[ind].legend()
                name = '{}-curves_test'.format(netname)
            else:
                for ind in self.ind_names:
                    for stage in ('train', 'valid'):
                        values = self.results[ind][stage]
                        n_points = len(values) if num_valid is None else num_valid
                        axes[ind].plot(np.arange(n_points), values,
                                       label='{}_{}'.format(ind, stage))
                    axes[ind].set_xlabel('num_iter')
                    axes[ind].set_ylabel(ind)
                    axes[ind].set_title('data on: {}'.format(netname))
                    axes[ind].legend()
                fig.tight_layout(w_pad=10)
                name = '{}-curves-train_validation'.format(netname)

            if suffix:
                name = '{}-{}'.format(name, suffix)
            fig.savefig(os.path.join(os.getcwd(), self.dirname, name))
        finally:
            # Figures are only written to disk (fig.show() warns on the non-GUI
            # backend), so close them; otherwise every round leaks open figures.
            plt.close(fig)
            
#################### SOME BUGS EXIST IN CODE BELOW ##############################################################
#     def plot_whole(self, netname, num_valid, plot_test=False, plot_origin=True, plot_mean=False):
#         empty_markers = ['None', None, ' ', '']
#         plt_markers_list = list([mk for mk in plt_markers.MarkerStyle.markers.keys() if mk not in empty_markers])
#         plt_markers_gen = (mk for mk in plt_markers_list)
        
#         fig, axes_raw = plt.subplots(1, 2, figsize=(15, 5))
#         axes = {ind:axes_raw[i] for i,ind in enumerate(self.ind_names)}
#         alpha_dict = {'train':0.7 ,'valid':1.0, 'test':1.0}

#         ind_data = {}
#         for ind in self.ind_names:
#             for stage in stage_names:
#                 k = '{}_{}-mean'.format(ind, stage)
#                 ind_data[k] = []
                
#         if plot_test:
#             x_inds = np.arange(len(dataloader_test))
            
#             for i in range(n_iters):
#                 for stage in stage_names[-1:]:
#                     filename = '{}-sheet-{}-{}.csv'.format(netname, stage, i)
#                     filename = os.path.join(os.getcwd(), self.dirname, filename)
#                     df = pd.read_csv(filename)
#                     for ind in ind_names:
#                         # save data for the mean curve
#                         k = '{}_{}-mean'.format(ind, stage)
#                         ind_data[k].append(df[ind].values)
                        
#                         # plot
#                         if plot_origin:
#                             axes[ind].plot(x_inds, df[ind], label='{}_{}-{}'.format(ind, stage, i), 
#                                            marker=next(plt_markers_gen), alpha=alpha_dict[stage])
#             if plot_mean:
#                 for ind in ind_names:
#                     for stage in stage_names[-1:]:
#                         k = '{}_{}-mean'.format(ind, stage)
#                         ind_data[k] = np.array(ind_data[k]).mean(axis=0)
#                         axes[ind].plot(x_inds, ind_data[k], label=k,
#                                        marker=next(plt_markers_gen), alpha=alpha_dict[stage])
        
#             for ind in ind_names:    
#                 axes[ind].legend()

#             plt.tight_layout(w_pad=10)

#             use_origin = 'Origin' if plot_origin else 'noOrigin'
#             use_mean = 'Mean' if plot_origin else 'noMean'
#             filename = '{}-curves-test-whole-{}{}'.format(netname, use_origin, use_mean)
#             filename = os.path.join(os.getcwd(), self.dirname, filename)
#             plt.savefig(filename)
                
#         if not plot_test:
#             # plot: train_validation
#             x_inds = np.arange(num_valid)
            
#             for i in range(n_iters):
#                 for stage in stage_names[:-1]:
#                     filename_base = '{}-sheet-{}-{}.csv'.format(netname, stage, i)
#                     filename = os.path.join(os.getcwd(), self.dirname, filename_base)
#                     assert os.path.isfile(filename), 'Lack file: '.format(filename_base)
#                     df = pd.read_csv(filename)
#                     for ind in self.ind_names:
#                         # save data for the mean curve
#                         k = '{}_{}-mean'.format(ind, stage)
#                         ind_data[k].append(df[ind].values)
        
#                         # plot
#                         if plot_origin:
#                             axes[ind].plot(x_inds, df[ind], label='{}_{}-{}'.format(ind, stage, i), 
#                                            marker=next(plt_markers_gen), alpha=alpha_dict[stage])

#             if plot_mean:
#                 for ind in self.ind_names:
#                     for stage in stage_names[:-1]:
#                         k = '{}_{}-mean'.format(ind, stage)
#                         ind_data[k] = np.array(ind_data[k]).mean(axis=0)
#                         axes[ind].plot(x_inds, ind_data[k], label=k,
#                                        marker=next(plt_markers_gen), alpha=alpha_dict[stage])

#             for ind in self.ind_names:    
#                 axes[ind].legend()
            
#             plt.tight_layout(w_pad=10)
            
#             use_origin = 'Origin' if plot_origin else 'noOrigin'
#             use_mean = 'Mean' if plot_origin else 'noMean'
#             filename = '{}-curves-train_validation-whole-{}{}'.format(netname, use_origin, use_mean)
#             filename = os.path.join(os.getcwd(), self.dirname, filename)
#             plt.savefig(filename)