# Evaluate models against holdout set

Written for python 3.7, meant to run on 4GB CUDA-capable machine (iceberg)

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai.text import *
import fastprogress
bs = 32 # size of minibatch
# written with fastai v1.0.48

from collections import defaultdict

In [3]:
lang_model_path = Path('/data/fastai-models/language_model')
vocab = Vocab(pickle.load(open(lang_model_path/'itos.pkl','rb')))

In [4]:
sys.path.append("..") # Adds higher directory to python modules path.
from common import Label_DbFields, Labels
# needed for class_mask for accuracy_category

In [5]:
def accuracy_category(y_pred:Tensor, y_true:Tensor, class_idxes:Tensor, thresh:float=0.5, sigmoid:bool=True)->Rank0Tensor:
    "Gives accuracy when `y_pred` above threshold, only for the true column(s) among the `class_idxes` columns."
    # note this gives no penalty for marking other labels incorrectly among class_idxes columns
    # this is roughly comparable to accuracy for a single classifier with class_idxes' labels as the classes
    if sigmoid: y_pred = y_pred.sigmoid()
    correct_category_mask = y_true.index_select(1, class_idxes) > 0
    return (y_pred.index_select(1, class_idxes)[correct_category_mask] > thresh).float().mean()

In [6]:
# disable progress bars, see https://forums.fast.ai/t/default-to-completely-disable-progress-bar/40010
import fastai, fastprogress
fastprogress.fastprogress.NO_BAR = True
master_bar, progress_bar = fastprogress.force_console_behavior()
fastai.basic_train.master_bar, fastai.basic_train.progress_bar = master_bar, progress_bar
fastai.basic_data.master_bar, fastai.basic_data.progress_bar = master_bar, progress_bar
#dataclass.master_bar, dataclass.progress_bar = master_bar, progress_bar
fastai.text.master_bar, fastai.text.progress_bar = master_bar, progress_bar
fastai.text.data.master_bar, fastai.text.data.progress_bar = master_bar, progress_bar
fastai.core.master_bar, fastai.core.progress_bar = master_bar, progress_bar

In [7]:
# this contains the labels for each classifier
label_set = Labels.copy()
label_set['factinvestigative'] = ['investigative', 'noninvestigative', 'opinion', 'other']

In [8]:
def getTopPredPcts(preds, classes, clas_name):
    labelcounts = defaultdict(int)
    yhat = np.argmax(preds.numpy(), axis=1)
    for label in yhat:
        labelcounts[label] += 1   # leave index 0 for the overall counts to calculate next

    print("    ", preds[:5])
#        print(labelcounts[0])

    # translate from NN class index into class name, and make a full list of counts
    namedlabelcounts = defaultdict(int)
    for k,v in labelcounts.items():
        namedlabelcounts[classes[k]] = v

#    print("    ", namedlabelcounts)

    # label counts in order
    total = len(preds)
    counts_ordered = [float(namedlabelcounts[cn])/total for cn in label_set[clas_name]]
    return zip(label_set[clas_name], counts_ordered)

In [9]:
def getTopKWeightedPredPcts(preds, classes, clas_name):
    k = 3 if clas_name in ['label_category', 'station', 'supergroups'] else 2
    likelihoods, posns = preds.topk(k, dim=-1, sorted=False)

    # scale predictions so that top 3 likelihoods sum to 1
    norm_factors = 1. / likelihoods.sum(dim=-1)
    likelihoods = norm_factors * likelihoods.transpose(-1,0)
    likelihoods.transpose_(-1,0)

    # allocate their normalized likelihoods to the 3 categories for each snippet
    likelihoods_sums = defaultdict(float)
    # add up the likelihoods for each snippet for the top 3 column (class) positions
    for (snippet_lhs, snippet_posns) in zip(likelihoods, posns):
        for lh, posn in zip(snippet_lhs.tolist(), snippet_posns.tolist()):
            likelihoods_sums[posn] += lh

    # order the likelihoods for reporting, and sum up overall totals
    namedlabelsums = defaultdict(float)
    for k,v in likelihoods_sums.items():
        namedlabelsums[learn.data.train_ds.classes[k]] = v
        
    total = len(preds)
    summed_likelihoods_ordered = [namedlabelsums[cn]/total for cn in label_set[clas_name]]
    return zip(label_set[clas_name], summed_likelihoods_ordered)


## Obtain statistics on holdout set

In [15]:
# dir structure: labeled_data/holdout_set/groupname-test.tsv
# load our labeled data into a TextClasDataBunch

# Use selected learners
model_path = Path('/data/fastai-models')
modeldir = model_path/"selected"

holdout_path = Path("../labeled_data/holdout_set")
for groupfilepath in holdout_path.ls():
    groupname = str(groupfilepath.parts[-1])
    groupname = groupname[:groupname.find('-test.tsv')]
    print('\nProcessing holdout set for ' + groupname + '\n')

    test_df = pd.read_csv(holdout_path/(groupname+'-test.tsv'), header=None, delimiter='\t', names=['label','text'])
    print(test_df[:5])
    
    # settings for single vs multi-label learners
    k = 3 if groupname in ['label_category', 'station', 'supergroups'] else 2
    topkaccuracy = partial(top_k_accuracy, k=k)
    metrics = [accuracy, topkaccuracy]
    label_delim = None
    if groupname == 'multilabel':
        ### TEMP: something is wrong with this learner.
        continue
        label_delim = ','
        
    if groupname != 'label_usforeign':
        continue
        
    test_databunch = TextClasDataBunch.from_df(model_path/"fold_0", test_df, test_df, vocab=vocab, 
                                               text_cols='text', label_cols='label', bs=bs,
                                               label_delim=label_delim)

    try:
        # load exported classifier with no data
        learn = load_learner(modeldir, fname=groupname + '_clas_fine_tuned.pkl',)    
    except:
        print('  - no learner found or unable to load. skipping.')
        continue

    if groupname == 'multilabel':
        class_mask = [c in Labels['label_category'] for c in learn.data.train_ds.classes]
        class_idxes = sorted([learn.data.train_ds.classes.index(c) for c in Labels['label_category']])
        label_category_accuracy = partial(accuracy_category, class_idxes=tensor(class_idxes).cuda(), thresh=0.5)
        metrics = [accuracy_thresh, # accuracy per label at default threshold of 0.5
                   label_category_accuracy] # "accuracy" among true category labels only
        
    # print validation results: [loss function, metric 1, metric 2 ...]
    print('\n  ', learn.validate(test_databunch.valid_dl, metrics=metrics))
    
    if groupname in label_set:
        # get the actual predictions
        learn.data.valid_dl = test_databunch.valid_dl
        preds, true_y = learn.get_preds(ds_type=DatasetType.Valid) # don't need ordered = True since we don't cross reference anywhere

        # find percents of the test set based upon both top k and top 1 to compare against actual
        print(" - Out of ", len(preds), "snippets")

        pcts = getTopPredPcts(preds, learn.data.train_ds.classes, groupname)
        print("    - Top Preds %'s:")
        for pct in pcts:
            print("      %20s %f" % (pct[0], pct[1]))
        pcts = getTopKWeightedPredPcts(preds, learn.data.train_ds.classes, groupname)
        print("    - Top-k summed Preds %'s:")
        for pct in pcts:
            print("      %20s %.04f" % (pct[0], pct[1]))


Processing holdout set for label_category

            label                                               text
0             ads  To avoid long-term injury,seek immediate medic...
1  current_events  The outcry can't be ignored and thousands of p...
2             ads  Touch and be touched. Now bring the world a to...
3    science_tech  1.3 billion dollars pulled together from publi...
4      government  > As the head of security for TSA,Hogan's base...

Processing holdout set for label_investigative

              label                                               text
0     investigative  The outcry can't be ignored and thousands of p...
1  noninvestigative  1.3 billion dollars pulled together from publi...
2  noninvestigative  > As the head of security for TSA,Hogan's base...
3  noninvestigative  They are weakening or ease in a--or easing wit...
4     investigative  > It was the start of a brutal and bloody thre...

Processing holdout set for factinvestigative

              label 