In [1]:
import os
import json
import pickle
import random
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from processor import *
from copy import deepcopy
from collections import defaultdict, Counter
from extract import load_logs

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


#### Predict label for each annotated utterance and compute precision and recall
> 1 -> utterance has the top score \
0 -> otherwise

In [2]:
def eval(_chains):
    true_pos, false_pos, false_neg = 0, 0, 0
    true_pos_samples, false_pos_samples, false_neg_samples = defaultdict(list), defaultdict(list), defaultdict(list)
    

    for img, img_chains in _chains.items():

        best_utterance = {'Message_Text': '', 'Score': 0, 'Message_Referent': None}
        current_round = -1
        n_ref = []

        for utterance in img_chains:
            
            if current_round != utterance['Round_Nr']:
                
                # was selected as positive...
                if best_utterance['Message_Referent'] is not None:
                    
                    # ... true?
                    if best_utterance['Message_Referent'] == img:
                        true_pos += 1
                        true_pos_samples[img].append(best_utterance)
                    
                    # ... false?
                    else:   
                        false_pos += 1
                        false_pos_samples[img].append(best_utterance)
                    
                    # and were there other (gold) positives?
                    if n_ref:
                        for utt in n_ref:
#                           # yes? false negative
                            if utt['Message_Text'] != best_utterance['Message_Text']:
                                false_neg_samples[img].append(utt)
                                false_neg += 1
                        n_ref = []
    
                best_utterance = {'Message_Text': '', 'Score': 0, 'Bert_Score': 0, 'Meteor_Score': 0, 'Message_Referent': None}
                current_round = utterance['Round_Nr']

            # new top-scoring utterance?
            if utterance['In_Segment'] and utterance['Score'] >= best_utterance['Score']:
                best_utterance = utterance
            
#             if utterance['In_Segment'] and utterance['Bert_Score'] * (1 + utterance['Meteor_Score']) >= best_utterance['Bert_Score'] * (1 + best_utterance['Meteor_Score']):
#                 best_utterance = utterance
            
#             if utterance['In_Segment'] and utterance['Meteor_Score'] * (1 + utterance['Bert_Score']) >= best_utterance['Meteor_Score'] * (1 + best_utterance['Bert_Score']):
#                 best_utterance = utterance
            
            # should be a positive
            if utterance['Message_Referent'] == img:
                n_ref.append(utterance)
    
    precision = true_pos / (true_pos + false_pos)
    recall = true_pos / (true_pos + false_neg)
    
    return precision, recall, true_pos_samples, false_pos_samples, false_neg_samples

In [24]:
for subdir, dirs, files in os.walk(r'chains/'):
    for filename in files:
        filepath = subdir + os.sep + filename

        if not filepath.endswith('.dict'):
            continue
            
        with open(filepath, 'rb') as f:
            chains_ = pickle.load(file=f)
        
        print(filename.split('.')[0][4:])
        
        P, R, _, _, _ = eval(chains_)
        
        print('Precision: {:.2f}'.format(P))
        print('Recall: {:.2f}'.format(R))
        print()

re_nostopwords_keep1
Precision: 0.76
Recall: 0.53

re+vg_nostopwords
Precision: 0.74
Recall: 0.52

re_nostopwords
Precision: 0.74
Recall: 0.52

f1_keep1
Precision: 0.74
Recall: 0.52

re+vg_nostopwords_nocaptionwords
Precision: 0.64
Recall: 0.38

pr_nostopwords_keep1
Precision: 0.80
Recall: 0.56

pr+vg_keep1
Precision: 0.78
Recall: 0.54

f1+vg_nostopwords_keep1
Precision: 0.79
Recall: 0.55

re+vg
Precision: 0.74
Recall: 0.52

pr+vg_nostopwords_keep1
Precision: 0.82
Recall: 0.57

f1+vg_keep1
Precision: 0.77
Recall: 0.54

re_nostopwords_nocaptionwords
Precision: 0.47
Recall: 0.22

pr_keep1
Precision: 0.76
Recall: 0.52

re
Precision: 0.70
Recall: 0.49

re+vg_nostopwords_keep1
Precision: 0.78
Recall: 0.54

re+vg_keep1
Precision: 0.75
Recall: 0.52

f1_nostopwords_keep1
Precision: 0.78
Recall: 0.55

re_keep1
Precision: 0.72
Recall: 0.50



In [22]:
def print_errors(chains_path, out_path=None):
    
    with open(chains_path, 'rb') as f:
        _chains = pickle.load(file=f)
        
    _, _, tp, fp, fn = eval(_chains)
    
    if out_path:
        file = open(out_path, 'w')
    else:
        file = None
    
    for img in fp:
        print('target: {}\n'.format(img), file=file)
    
        print('False Positives', file=file)
        for c in fp[img]:
            print('{} {} {}     {:.02f} {} <-- {}'.format(c['Game_ID'], c['Round_Nr'], c['Message_Speaker'], c['Score'], c['Tokens'], c['Message_Text']), file=file)

        print('\nFalse Negatives', file=file)
        for c in fn[img]:
            print('{} {} {}     {:.02f} {} <-- {}'.format(c['Game_ID'], c['Round_Nr'], c['Message_Speaker'], c['Score'], c['Tokens'], c['Message_Text']), file=file)

        print('\nTrue Positives', file=file)
        for c in tp[img]:
            print('{} {} {}     {:.02f} {} <-- {}'.format(c['Game_ID'], c['Round_Nr'], c['Message_Speaker'], c['Score'], c['Tokens'], c['Message_Text']), file=file)
        print('\n', file=file)
        
    if out_path:
        file.close()

---

In [23]:
print_errors('chains/dev_pr+vg_nostopwords_keep1.dict', out_path='out_pr+vg_nostopwords_keep1.txt')