# realbogus for realfast

## Machine learning for VLA fast-transient classification (using Elasticsearch)

### By Umaa Rebbapragada and Casey Law

In [None]:
%matplotlib inline
import pylab as pl
import os.path
import numpy as np
import activegit, rflearn
from rtpipe.parsecands import read_candidates
import glob, logging
from IPython.display import Image, display

# Module logger: messages at INFO and above are emitted through a
# StreamHandler (which writes to sys.stderr by default).
_console = logging.StreamHandler()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(_console)

In [None]:
verbose = 0
random_seed = 1132014
np.random.seed(seed=random_seed)


def imageurl(imagename, baseurl='http://www.aoc.nrao.edu/~claw/plots/'):
    """Return the URL of ``imagename`` under ``baseurl``.

    URLs always use '/' separators, so the parts are joined with posixpath
    rather than os.path (which would insert a backslash on Windows).
    """
    import posixpath
    return posixpath.join(baseurl, imagename)


def serveimage(imagename, baseurl='http://www.aoc.nrao.edu/~claw/plots/', width=700):
    """Display the image ``imagename``, fetched from ``baseurl``, inline in the notebook.

    ``width`` is passed straight through to IPython's ``Image`` as the display width.
    """
    display(Image(url=imageurl(imagename, baseurl), width=width))

### Initialize the activegit repository that holds the classifier

In [None]:
# Local activegit repository holding the versioned classifier together with
# its training and testing data (read/written via ag.classifier,
# ag.training_data, ag.testing_data below).
# os.path.expanduser('~') equals $HOME when it is set, and unlike
# os.environ['HOME'] does not raise KeyError where HOME is unset (e.g. Windows).
agdir = os.path.join(os.path.expanduser('~'), 'code', 'realfast_al')
ag = activegit.ActiveGit(agdir)

### Read new (unlabeled) candidates and define (loc, prop)

In [None]:
# Query the elastic search index for candidates not yet labeled, then
# rebuild the observation list and the (loc, prop) candidate data from
# those records.
datalist = rflearn.elastic.indextodatalist(unlabeled=True)
(obslist,
 loc,
 prop) = rflearn.elastic.restorecands(datalist)

### Active learning loop

#### Take the least certain predictions and ask an expert to classify them. The labels are then fed back into the classifier to improve its predictions.

In [None]:
# set up train, test and unlabeled pools
clf = ag.classifier
train_pool, train_targets = ag.training_data
test_pool, test_targets = ag.testing_data
# Feature rows for the unlabeled candidates. The loop below indexes this and
# datalist with the same candidate index, so rows are assumed to line up
# with datalist/obslist order -- NOTE(review): confirm in rflearn.
unlabeled_pool_stat = rflearn.features.stat_features(prop)

# set up batches
nunlabeled = len(obslist)
cands_unlabeled_pool = np.arange(nunlabeled)  # candidate indices still awaiting a label
subset_threshold = 75  # above this pool size, only a fraction is sampled per pass
subset_perc = 0.666    # fraction of the pool sampled when above the threshold
n_jobs = 1
batch_size = 10        # candidates shown to the expert per pass
bi = 0  # batch index
nbatches = nunlabeled // batch_size  # floor division on both Python 2 and 3

logger.info("Train pool size: {0}".format(len(train_pool)))
logger.info("Test pool size: {0}".format(len(test_pool)))
logger.info("Unlabeled pool size: {0}".format(nunlabeled))
logger.info("Batch size: {0}".format(batch_size))

In [None]:
# START ACTIVE LEARNING LOOP
#
# Each pass samples a random subset of the still-unlabeled candidates, picks
# those the current classifier is least certain about (or simply the first
# few when no classifier exists yet), asks the expert to label them, moves
# them into the training pool and retrains. The data and classifier are
# committed to the activegit repo as a new "stat<N>" version when the user
# stops or the unlabeled pool runs out.


def _next_version_number(versions, prefix='stat'):
    """Return 1 + the largest N among version names of the exact form ``<prefix><N>``.

    Other names (e.g. ``'stat_old'``) are ignored; returns 1 when none match.
    str.lstrip is not used because it strips a *set* of characters, not a prefix.
    """
    numbers = [int(v[len(prefix):]) for v in versions
               if v.startswith(prefix) and v[len(prefix):].isdigit()]
    return max(numbers) + 1 if numbers else 1


def _ask_label():
    """Prompt until the expert enters 0 or 1, and return that value as an int.

    Non-numeric input re-prompts instead of raising ValueError, which would
    abort the loop and lose the labels already entered for this batch.
    """
    while True:
        try:
            label = int(raw_input("LABEL: Is this real? (0,1) "))
        except ValueError:
            label = None
        if label in (0, 1):
            return label
        logger.warning("Please enter in 0 or 1 only")


def _save_version(ag, train_pool, train_targets, test_pool, test_targets, clf):
    """Write train/test data and the classifier to ``ag``, then commit them as the next ``stat<N>`` version."""
    versn = _next_version_number(ag.versions)
    ag.write_testing_data([tuple(tr) for tr in test_pool], test_targets)
    ag.write_training_data([tuple(tr) for tr in train_pool], train_targets)
    ag.write_classifier(clf)
    ag.commit_version('stat{0}'.format(versn))


while nunlabeled > 0:  # while still examples in unlabeled pool

    # choose a subset of the unlabeled pool to classify
    if nunlabeled > subset_threshold:
        subset_size = int(np.floor(subset_perc * nunlabeled))
    else:
        subset_size = nunlabeled
    cands_rand_subset = np.random.choice(cands_unlabeled_pool, subset_size, replace=False)
    subset_pool = unlabeled_pool_stat[cands_rand_subset, :]

    if clf:
        # get performance stats for classifier on validation set
        test_preds = clf.predict(test_pool)
        acc, fpr, fnr = rflearn.features.calc_acc_fpr_fnr(test_targets, test_preds)
        logger.info("BATCH {0} (acc, fpr, fnr): ({1}, {2}, {3})\n".format(bi, acc, fpr, fnr))

        # classify that subset
        subset_pool_probs = clf.predict_proba(subset_pool)

        # most uncertain = smallest gap between the two class probabilities
        batch_subset_indices = np.argsort(abs(subset_pool_probs[:, 0] - subset_pool_probs[:, 1]))[:batch_size]
        batch_subset_probs = subset_pool_probs[batch_subset_indices, :]

        # map positions within the subset back to candidate indices
        batch_cand_indices = cands_rand_subset[batch_subset_indices]
    else:
        # need to initialize train/test/clf with first pass
        batch_cand_indices = cands_rand_subset[:batch_size]

    # Present each candidate to the user. Iterate over the actual batch: near
    # the end of the pool it can hold fewer than batch_size candidates.
    batch_cand_targets = []
    modified_data = []
    for ci, candi in enumerate(batch_cand_indices):
        logger.info('SNR = {0}'.format(unlabeled_pool_stat[candi, 0]))
        if clf:
            logger.info("RDF Probs=({0},{1})".format(batch_subset_probs[ci, 0], batch_subset_probs[ci, 1]))
        serveimage(datalist[candi]['candidate_png'])

        batch_cand_targets.append(_ask_label())
        datalist[candi]['labeled'] = 1
        modified_data.append(datalist[candi])

    # update elastic search index to show that a batch has been classified
    # NOTE(review): this push is disabled; modified_data is collected for it.
    # rflearn.elastic.pushdata(modified_data)

    # augment train_pool, train_targets
    if len(train_pool) and len(test_pool):
        train_pool = np.vstack((train_pool, unlabeled_pool_stat[batch_cand_indices, :]))
        train_targets = np.concatenate((train_targets, batch_cand_targets))
    else:
        # first time through. split sample to train/test
        # (// keeps the slice index an int on Python 3 as well)
        half = len(batch_cand_indices) // 2
        train_pool = unlabeled_pool_stat[batch_cand_indices[:half], :]
        train_targets = batch_cand_targets[:half]
        test_pool = unlabeled_pool_stat[batch_cand_indices[half:], :]
        test_targets = batch_cand_targets[half:]

    # Remove the labeled candidates by *value*. np.delete would treat them as
    # positions, which stop matching candidate indices after the first batch.
    cands_unlabeled_pool = np.setdiff1d(cands_unlabeled_pool, batch_cand_indices)
    nunlabeled = cands_unlabeled_pool.shape[0]
    bi += 1

    logger.info("Train pool size: {0}".format(len(train_pool)))
    logger.info("Test pool size: {0}".format(len(test_pool)))
    logger.info("Unlabeled pool size: {0}".format(nunlabeled))
    logger.info("Batch size: {0}".format(batch_size))

    # re-train
    logger.info("Retraining classifier...")
    clf = rflearn.sklearn_utils.train_random_forest(train_pool, train_targets, n_jobs=n_jobs,
                                                    verbose=verbose, n_estimators=300)

    # Stop (and save) when asked, or when nothing is left to label. Otherwise
    # the loop would end with the new labels never committed.
    if nunlabeled == 0:
        logger.info("Unlabeled pool exhausted.")
        stop = True
    else:
        stop = raw_input("Continue? (y,n): ") == 'n'

    if stop:
        logger.info("Saving train, targets, and classifier to next version name...")
        _save_version(ag, train_pool, train_targets, test_pool, test_targets, clf)
        break
    logger.info('Continuing training...')
    

### Evaluate: performance of each saved classifier version on its test set

In [None]:
# Log each committed version's classifier performance on that version's own
# test set.
for version in ag.versions:
    # NOTE: ag is switched to each version in turn and not restored afterwards.
    ag.set_version(version)
    clf = ag.classifier
    test_pool, test_targets = ag.testing_data
    if clf:
        test_preds = clf.predict(test_pool)
        acc, fpr, fnr = rflearn.features.calc_acc_fpr_fnr(test_targets, test_preds)
        # Placeholders {2}-{4} hold acc/fpr/fnr; {1} is the test-set size.
        logger.info('Version {0}, {1} test cands (acc, fpr, fnr): ({2}, {3}, {4})'
                    .format(version, len(test_pool), acc, fpr, fnr))