In [4]:
import numpy as np
import lasagne
import time
from nbfinder import NotebookFinder
import sys
sys.meta_path.append(NotebookFinder())
from matplotlib import pyplot as plt
%matplotlib inline
from matplotlib import patches
from helper_fxns import early_stop
from build_hur_classif_network import build_classif_network
from data_loader import load_classification_dataset, load_detection_dataset
from print_n_plot import print_train_results,plot_learn_curve,print_val_results, plot_ims_with_boxes
from build_hur_detection_network import build_det_network

In [16]:
def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
    """Yield ``(inputs_batch, targets_batch)`` pairs of exactly ``batchsize`` rows.

    Only full minibatches are produced: any trailing samples that do not fill
    a complete batch are skipped, and nothing is yielded when ``inputs`` has
    fewer than ``batchsize`` rows.

    Parameters
    ----------
    inputs, targets : indexable sequences/arrays of equal length
    batchsize : int
        Rows per minibatch.
    shuffle : bool
        When true, rows are visited in an order drawn from ``np.random``
        (global RNG state); otherwise batches are contiguous slices in order.
    """
    assert len(inputs) == len(targets)
    n_samples = len(inputs)
    order = None
    if shuffle:
        order = np.arange(n_samples)
        np.random.shuffle(order)
    last_start = n_samples - batchsize
    for begin in range(0, last_start + 1, batchsize):
        end = begin + batchsize
        # Fancy-index with the shuffled positions, or take a plain slice.
        picker = order[begin:end] if shuffle else slice(begin, end)
        yield inputs[picker], targets[picker]

def train_one_epoch(x, y, batchsize, train_fn, val_fn):
    """Run one shuffled pass of minibatch training over ``(x, y)``.

    For each full minibatch from ``iterate_minibatches`` (a trailing partial
    batch is skipped), ``train_fn`` is called on the batch, then ``val_fn`` is
    called on the same batch to obtain an accuracy value.

    Parameters
    ----------
    x, y : array-like
        Training inputs and targets; must have the same length.
    batchsize : int
        Samples per minibatch.
    train_fn : callable
        ``train_fn(inputs, targets) -> loss``. Presumably a compiled update
        function that also modifies the network parameters -- confirm in the
        network builders.
    val_fn : callable
        ``val_fn(inputs, targets) -> (loss, accuracy)``.

    Returns
    -------
    tuple
        ``(summed_train_loss, summed_accuracy, n_batches)``. Divide the sums by
        ``n_batches`` for per-batch means; ``n_batches`` is 0 when ``x`` has
        fewer than ``batchsize`` samples.
    """
    train_err = 0
    train_acc = 0
    train_batches = 0
    for inputs, targets in iterate_minibatches(x, y, batchsize, shuffle=True):
        train_err += train_fn(inputs, targets)
        # NOTE(review): if train_fn updates the parameters, this accuracy is
        # measured *after* that update on the same batch.
        _, acc = val_fn(inputs, targets)
        train_acc += acc
        train_batches += 1
    return train_err, train_acc, train_batches

def val_one_epoch(x, y, batchsize, val_fn):
    """Evaluate ``val_fn`` over ``(x, y)`` in order, one full minibatch at a time.

    This function itself only calls ``val_fn``; any trailing partial batch
    is skipped.

    Parameters
    ----------
    x, y : array-like
        Evaluation inputs and targets; must have the same length.
    batchsize : int
        Samples per minibatch.
    val_fn : callable
        ``val_fn(inputs, targets) -> (loss, accuracy)``.

    Returns
    -------
    tuple
        ``(summed_loss, summed_accuracy, n_batches)``.
    """
    loss_sum = acc_sum = 0
    n_batches = 0
    for batch_x, batch_y in iterate_minibatches(x, y, batchsize):
        loss, accuracy = val_fn(batch_x, batch_y)
        loss_sum += loss
        acc_sum += accuracy
        n_batches += 1
    return loss_sum, acc_sum, n_batches
def do_one_epoch(x, y, batchsize, train_fn, val_fn, x_val=None, y_val=None,
                 epoch=0, num_epochs=1):
    """Train for one epoch on ``(x, y)``, then optionally evaluate on ``(x_val, y_val)``.

    Progress is reported through ``print_train_results`` and, when validation
    data is supplied, ``print_val_results``. All inputs come from the
    arguments; nothing is read from module globals.

    Parameters
    ----------
    x, y : array-like
        Training inputs and targets.
    batchsize : int
        Training minibatch size.
    train_fn, val_fn : callable
        Passed to ``train_one_epoch``; ``val_fn(inputs, targets) -> (loss, acc)``.
    x_val, y_val : array-like, optional
        Validation data, evaluated as a single batch. Validation is skipped
        when either is None or empty.
    epoch, num_epochs : int, optional
        Only forwarded to ``print_train_results`` for the progress message.

    Returns
    -------
    tuple
        ``(train_err, train_acc, val_err, val_acc)`` as per-batch means. The
        validation entries are None when validation was skipped.

    Raises
    ------
    ValueError
        If ``x`` has fewer samples than ``batchsize`` (no full minibatch).
    """
    start_time = time.time()
    tr_err, tr_acc, tr_batches = train_one_epoch(x, y,
                                                 batchsize=batchsize,
                                                 train_fn=train_fn,
                                                 val_fn=val_fn)
    if tr_batches == 0:
        raise ValueError('training set has %d samples, fewer than batchsize=%d; '
                         'no full minibatch to train on' % (len(x), batchsize))
    # float() guards against integer division under Python 2.
    train_err = float(tr_err) / tr_batches
    train_acc = float(tr_acc) / tr_batches
    print_train_results(epoch, num_epochs, start_time, train_err, train_acc)

    val_err = val_acc = None
    if x_val is not None and y_val is not None and len(y_val) > 0:
        # The whole validation set is one batch, so v_batches == 1.
        v_err, v_acc, v_batches = val_one_epoch(x_val, y_val,
                                                batchsize=len(y_val),
                                                val_fn=val_fn)
        val_err = float(v_err) / v_batches
        val_acc = float(v_acc) / v_batches
        print_val_results(val_err, val_acc)
    return train_err, train_acc, val_err, val_acc
        
    

In [7]:
# TODO: add logging.
# TODO: save run info under a unique name (e.g. run number or date).
# TODO: record the weights over the course of the updates.
def train(datasets, num_epochs, mode='classification', save_weights=False, save_plots=False,
          frac_of_datasets=1, batchsize=128, network_kwargs=None, inmem_class_network=None,
          load_path=None):
    """Build a network for ``mode`` and train it for ``num_epochs`` epochs.

    Each epoch trains on the training split, then evaluates the whole
    validation split as one batch. Learning curves (and, in detection mode,
    predicted boxes) are plotted every 10th epoch.

    Parameters
    ----------
    datasets : sequence of 6 arrays
        ``(X_train, y_train, X_val, y_val, X_test, y_test)``. The test split is
        unpacked but not used here.
    num_epochs : int
        Number of passes over the training set; must be >= 1.
    mode : {'classification', 'detection'}
    save_weights : bool
        If true, all network parameters are written to ``'<mode>.npz'`` after
        epochs 0, 10, 20, ... Each write overwrites the same file.
    save_plots : bool
        Forwarded to the plotting helpers.
    frac_of_datasets : float
        Currently unused; kept for interface compatibility.
    batchsize : int or None
        Minibatch size. ``None``, or a value larger than the training set,
        means the whole training set is used as one batch.
    network_kwargs : dict or None
        Extra keyword arguments for the network builder.
    inmem_class_network : optional
        Detection mode only: presumably an already-built classification
        network to build the detector on (confirm with ``build_det_network``).
    load_path : optional
        Detection mode only: weights path passed to ``build_classif_network``.
        Used when ``inmem_class_network`` is not given.

    Returns
    -------
    tuple
        ``(network, train_err, train_acc, val_err, val_acc)``. The metrics are
        per-batch means from the final epoch.

    Raises
    ------
    ValueError
        If ``num_epochs < 1`` or ``mode`` is not recognised.
    TypeError
        In detection mode when neither ``inmem_class_network`` nor
        ``load_path`` is given.
    """
    if network_kwargs is None:
        network_kwargs = {}
    if num_epochs < 1:
        raise ValueError('num_epochs must be >= 1, got %r' % (num_epochs,))

    X_train, y_train, X_val, y_val, X_test, y_test = datasets

    if batchsize is None or X_train.shape[0] < batchsize:
        batchsize = X_train.shape[0]

    train_fn, val_fn, network, box_fn = _build_networks(
        mode, network_kwargs, inmem_class_network, load_path)

    print("Starting training...")

    train_errs, train_accs, val_errs, val_accs = [], [], [], []
    for epoch in range(num_epochs):
        start_time = time.time()
        tr_err, tr_acc, tr_batches = train_one_epoch(X_train, y_train,
                                                     batchsize=batchsize,
                                                     train_fn=train_fn,
                                                     val_fn=val_fn)
        # batchsize is clamped above, so tr_batches >= 1 for a non-empty X_train.
        train_errs.append(float(tr_err) / tr_batches)
        train_accs.append(float(tr_acc) / tr_batches)
        print_train_results(epoch, num_epochs, start_time, train_errs[-1], train_accs[-1])

        # The whole validation set is evaluated as a single batch.
        v_err, v_acc, v_batches = val_one_epoch(X_val, y_val,
                                                batchsize=y_val.shape[0],
                                                val_fn=val_fn)
        val_errs.append(float(v_err) / v_batches)
        val_accs.append(float(v_acc) / v_batches)
        print_val_results(val_errs[-1], val_accs[-1])

        if (epoch + 1) % 10 == 0:
            # NOTE(review): assumes plot_learn_curve takes a `mode` keyword. The
            # original passed mode positionally after a keyword (a SyntaxError).
            plot_learn_curve(train_errs, val_errs, 'err', save_plots=save_plots, mode=mode)
            plot_learn_curve(train_accs, val_accs, 'acc', save_plots=save_plots, mode=mode)
            if mode == 'detection':
                pred_boxes, gt_boxes = box_fn(X_val, y_val)
                plot_ims_with_boxes(X_val, pred_boxes, gt_boxes, epoch=epoch,
                                    save_plots=save_plots)
                # TODO: plot weights or updates.

        if save_weights and epoch % 10 == 0:
            np.savez('%s.npz' % (mode), *lasagne.layers.get_all_param_values(network))

    # Returned after the loop so that all num_epochs epochs actually run.
    return network, train_errs[-1], train_accs[-1], val_errs[-1], val_accs[-1]


def _build_networks(mode, network_kwargs, inmem_class_network, load_path):
    """Return ``(train_fn, val_fn, network, box_fn)`` for ``mode``; box_fn is None for classification."""
    if mode == 'classification':
        train_fn, val_fn, network = build_classif_network(**network_kwargs)
        return train_fn, val_fn, network, None
    if mode == 'detection':
        if inmem_class_network:
            class_network = inmem_class_network
        elif load_path:
            _, _, class_network = build_classif_network(load=True, load_path=load_path)
        else:
            raise TypeError('must specify either inmem_class_network or load_path for the weights!')
        train_fn, val_fn, network, box_fn = build_det_network(class_network, **network_kwargs)
        return train_fn, val_fn, network, box_fn
    raise ValueError("mode must be 'classification' or 'detection', got %r" % (mode,))

In [10]:

if __name__ == "__main__":
    # Smoke run on a small slice of the data (num_ims=40) for 2 epochs.
    dataset = load_classification_dataset(num_ims=40)
    n = train(
        dataset,
        num_epochs=2,
        batchsize=24,
        network_kwargs=dict(num_filters=10, num_fc_units=128, learning_rate=0.001),
    )


