
Commit

changing evaluation to use a perl script consistent with what the other research groups are doing
karpathy committed Jan 10, 2015
1 parent 02511eb commit 8336615
Showing 3 changed files with 203 additions and 46 deletions.
175 changes: 175 additions & 0 deletions eval/multi-bleu.perl
@@ -0,0 +1,175 @@
#!/usr/bin/perl -w

# $Id$
use strict;

my $lowercase = 0;
if ($ARGV[0] eq "-lc") {
    $lowercase = 1;
    shift;
}

my $stem = $ARGV[0];
if (!defined $stem) {
print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n";
print STDERR "Reads the references from reference or reference0, reference1, ...\n";
exit(1);
}

$stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0";

my @REF;
my $ref=0;
while(-e "$stem$ref") {
&add_to_ref("$stem$ref",\@REF);
$ref++;
}
&add_to_ref($stem,\@REF) if -e $stem;
die("ERROR: could not find reference file $stem") unless scalar @REF;

sub add_to_ref {
    my ($file,$REF) = @_;
    my $s=0;
    open(REF,$file) or die "Can't read $file";
    while(<REF>) {
        chop;
        push @{$$REF[$s++]}, $_;
    }
    close(REF);
}

my(@CORRECT,@TOTAL,$length_translation,$length_reference);
my $s=0;
while(<STDIN>) {
    chop;
    $_ = lc if $lowercase;
    my @WORD = split;
    my %REF_NGRAM = ();
    my $length_translation_this_sentence = scalar(@WORD);
    my ($closest_diff,$closest_length) = (9999,9999);
    foreach my $reference (@{$REF[$s]}) {
        # print "$s $_ <=> $reference\n";
        $reference = lc($reference) if $lowercase;
        my @WORD = split(' ',$reference);
        my $length = scalar(@WORD);
        my $diff = abs($length_translation_this_sentence-$length);
        if ($diff < $closest_diff) {
            $closest_diff = $diff;
            $closest_length = $length;
            # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n";
        } elsif ($diff == $closest_diff) {
            $closest_length = $length if $length < $closest_length;
            # from two references with the same closeness to me
            # take the *shorter* into account, not the "first" one.
        }
        for(my $n=1;$n<=4;$n++) {
            my %REF_NGRAM_N = ();
            for(my $start=0;$start<=$#WORD-($n-1);$start++) {
                my $ngram = "$n";
                for(my $w=0;$w<$n;$w++) {
                    $ngram .= " ".$WORD[$start+$w];
                }
                $REF_NGRAM_N{$ngram}++;
            }
            foreach my $ngram (keys %REF_NGRAM_N) {
                if (!defined($REF_NGRAM{$ngram}) ||
                    $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) {
                    $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram};
                    # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}<BR>\n";
                }
            }
        }
    }
    $length_translation += $length_translation_this_sentence;
    $length_reference += $closest_length;
    for(my $n=1;$n<=4;$n++) {
        my %T_NGRAM = ();
        for(my $start=0;$start<=$#WORD-($n-1);$start++) {
            my $ngram = "$n";
            for(my $w=0;$w<$n;$w++) {
                $ngram .= " ".$WORD[$start+$w];
            }
            $T_NGRAM{$ngram}++;
        }
        foreach my $ngram (keys %T_NGRAM) {
            $ngram =~ /^(\d+) /;
            my $n = $1;
            # my $corr = 0;
            # print "$i e $ngram $T_NGRAM{$ngram}<BR>\n";
            $TOTAL[$n] += $T_NGRAM{$ngram};
            if (defined($REF_NGRAM{$ngram})) {
                if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) {
                    $CORRECT[$n] += $T_NGRAM{$ngram};
                    # $corr = $T_NGRAM{$ngram};
                    # print "$i e correct1 $T_NGRAM{$ngram}<BR>\n";
                }
                else {
                    $CORRECT[$n] += $REF_NGRAM{$ngram};
                    # $corr = $REF_NGRAM{$ngram};
                    # print "$i e correct2 $REF_NGRAM{$ngram}<BR>\n";
                }
            }
            # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram};
            # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n"
        }
    }
    $s++;
}
my $brevity_penalty = 1;
my $bleu = 0;

my @bleu=();

for(my $n=1;$n<=4;$n++) {
    if (defined ($TOTAL[$n])){
        $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0;
        # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n";
    }else{
        $bleu[$n]=0;
    }
}

if ($length_reference==0){
printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n";
exit(1);
}

#if ($length_translation<$length_reference) {
# $brevity_penalty = exp(1-$length_reference/$length_translation);
#}

#$bleu = $brevity_penalty * exp((my_log( $bleu[1] ) +
# my_log( $bleu[2] ) +
# my_log( $bleu[3] ) +
# my_log( $bleu[4] ) ) / 4) ;

my $bleu_1 = $brevity_penalty * exp((my_log( $bleu[1] )));

my $bleu_2 = $brevity_penalty * exp((my_log( $bleu[1] ) +
                                     my_log( $bleu[2] ) ) / 2) ;

my $bleu_3 = $brevity_penalty * exp((my_log( $bleu[1] ) +
                                     my_log( $bleu[2] ) +
                                     my_log( $bleu[3] ) ) / 3) ;

my $bleu_4 = $brevity_penalty * exp((my_log( $bleu[1] ) +
                                     my_log( $bleu[2] ) +
                                     my_log( $bleu[3] ) +
                                     my_log( $bleu[4] ) ) / 4) ;

printf "BLEU = %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n",
    100*$bleu_1,
    100*$bleu_2,
    100*$bleu_3,
    100*$bleu_4,
    $brevity_penalty,
    $length_translation / $length_reference,
    $length_translation,
    $length_reference;

sub my_log {
    return -9999999999 unless $_[0];
    return log($_[0]);
}


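For context, the script above follows the usual multi-reference BLEU recipe: for each sentence it clips every candidate n-gram count by the maximum count of that n-gram across the references, accumulates clipped and total counts over the whole corpus, and reports BLEU-1 through BLEU-4 as geometric means of the n-gram precisions; note that the brevity-penalty lines are commented out in this copy, so BP stays at 1. The Python sketch below is not part of the commit (the names multi_bleu and ngram_counts are made up, and plain whitespace tokenization is assumed); it only mirrors that computation for readers who prefer Python:

# Rough sketch of what eval/multi-bleu.perl computes -- illustrative only, not part of the commit.
import math
from collections import Counter

def ngram_counts(tokens, n):
  # counts of n-grams of a single order n in one token list
  return Counter(tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1))

def multi_bleu(candidates, references_per_sent, max_n=4):
  # candidates: list of hypothesis strings; references_per_sent: list of lists of reference strings
  correct = [0] * (max_n + 1)
  total = [0] * (max_n + 1)
  for cand, refs in zip(candidates, references_per_sent):
    cand_toks = cand.split()
    ref_toks = [r.split() for r in refs]
    for n in range(1, max_n + 1):
      # clip each candidate n-gram count by its maximum count over all references
      max_ref = Counter()
      for rt in ref_toks:
        for ng, c in ngram_counts(rt, n).items():
          max_ref[ng] = max(max_ref[ng], c)
      for ng, c in ngram_counts(cand_toks, n).items():
        total[n] += c
        correct[n] += min(c, max_ref.get(ng, 0))
  bp = 1.0  # the committed script keeps BP at 1 (its brevity-penalty lines are commented out)
  def log_prec(k):
    p = float(correct[k]) / total[k] if total[k] else 0.0
    return math.log(p) if p > 0 else -9999999999
  # BLEU-n = BP * exp(mean of log precisions 1..n), reported on a 0-100 scale
  return [100 * bp * math.exp(sum(log_prec(k) for k in range(1, n + 1)) / n)
          for n in range(1, max_n + 1)]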
67 changes: 27 additions & 40 deletions eval_sentence_predictions.py
@@ -13,28 +13,6 @@
from imagernn.solver import Solver
from imagernn.imagernn_utils import decodeGenerator, eval_split

from nltk.align.bleu import BLEU

# UTILS needed for BLEU score evaluation
def BLEUscore(candidate, references, weights):
  p_ns = [BLEU.modified_precision(candidate, references, i) for i, _ in enumerate(weights, start=1)]
  if all([x > 0 for x in p_ns]):
    s = math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns))
    bp = BLEU.brevity_penalty(candidate, references)
    return bp * math.exp(s)
  else: # this is bad
    return 0

def evalCandidate(candidate, references):
"""
candidate is a single list of words, references is a list of lists of words
written by humans.
"""
b1 = BLEUscore(candidate, references, [1.0])
b2 = BLEUscore(candidate, references, [0.5, 0.5])
b3 = BLEUscore(candidate, references, [1/3.0, 1/3.0, 1/3.0])
return [b1,b2,b3]

def main(params):

  # load the checkpoint
@@ -61,13 +39,13 @@ def main(params):

  # iterate over all images in test set and predict sentences
  BatchGenerator = decodeGenerator(checkpoint_params)
  all_bleu_scores = []
  n = 0
  #for img in dp.iterImages(split = 'test', shuffle = True, max_images = max_images):
  all_references = []
  all_candidates = []
  for img in dp.iterImages(split = 'test', max_images = max_images):
    n+=1
    print 'image %d/%d:' % (n, max_images)
    references = [x['tokens'] for x in img['sentences']] # as list of lists of tokens
    references = [' '.join(x['tokens']) for x in img['sentences']] # as list of lists of tokens
    kwparams = { 'beam_size' : params['beam_size'] }
    Ys = BatchGenerator.predict([{'image':img}], model, checkpoint_params, **kwparams)

@@ -77,30 +55,39 @@

    # encode the human-provided references
    img_blob['references'] = []
    for gtwords in references:
      print 'GT: ' + ' '.join(gtwords)
      img_blob['references'].append({'text': ' '.join(gtwords)})
    for gtsent in references:
      print 'GT: ' + gtsent
      img_blob['references'].append({'text': gtsent})

    # now evaluate and encode the top prediction
    top_predictions = Ys[0] # take predictions for the first (and only) image we passed in
    top_prediction = top_predictions[0] # these are sorted with highest on top
    candidate = [ixtoword[ix] for ix in top_prediction[1]]
    print 'PRED: (%f) %s' % (top_prediction[0], ' '.join(candidate))
    bleu_scores = evalCandidate(candidate, references)
    print 'BLEU: B-1: %f B-2: %f B-3: %f' % tuple(bleu_scores)
    img_blob['candidate'] = {'text': ' '.join(candidate), 'logprob': top_prediction[0], 'bleu': bleu_scores}
    candidate = ' '.join([ixtoword[ix] for ix in top_prediction[1] if ix > 0]) # ix 0 is the END token, skip that
    print 'PRED: (%f) %s' % (top_prediction[0], candidate)

    # save for later eval
    all_references.append(references)
    all_candidates.append(candidate)

    all_bleu_scores.append(bleu_scores)
    img_blob['candidate'] = {'text': candidate, 'logprob': top_prediction[0]}
    blob['imgblobs'].append(img_blob)

  print 'final average bleu scores:'
  bleu_averages = [sum(x[i] for x in all_bleu_scores)*1.0/len(all_bleu_scores) for i in xrange(3)]
  blob['final_result'] = { 'bleu' : bleu_averages }
  print 'FINAL BLEU: B-1: %f B-2: %f B-3: %f' % tuple(bleu_averages)

  # use perl script to eval BLEU score for fair comparison to other research work
  # first write intermediate files
  print 'writing intermediate files into eval/'
  open('eval/output', 'w').write('\n'.join(all_candidates))
  for q in xrange(5):
    open('eval/reference'+`q`, 'w').write('\n'.join([x[q] for x in all_references]))
  # invoke the perl script to get BLEU scores
  print 'invoking eval/multi-bleu.perl script...'
  owd = os.getcwd()
  os.chdir('eval')
  os.system('./multi-bleu.perl reference < output')
  os.chdir(owd)

  # now also evaluate test split perplexity
  gtppl = eval_split('test', dp, model, checkpoint_params, misc, eval_max_images = max_images)
  print 'perplexity of ground truth words: %f' % (gtppl, )
  print 'perplexity of ground truth words based on dictionary of %d words: %f' % (len(ixtoword), gtppl)
  blob['gtppl'] = gtppl

  # dump result struct to file
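Note that after this change the BLEU numbers are only printed to stdout by the perl script; they are no longer stored in result_struct.json (hence the removals in visualize_result_struct.html below). One way to capture them programmatically, sketched here as a hypothetical helper that is not part of the commit (run_multi_bleu is a made-up name and the numbers in the comment are invented), is to parse the single output line whose format is fixed by the printf at the end of multi-bleu.perl:

# Hypothetical helper, not part of the commit: run the perl script and parse its output line.
import re
import subprocess

def run_multi_bleu(eval_dir='eval'):
  # equivalent to the os.system() call above, but capturing stdout instead of only printing it
  with open('%s/output' % eval_dir) as hyp:
    out = subprocess.check_output(['./multi-bleu.perl', 'reference'], stdin=hyp, cwd=eval_dir)
  # the script prints one line such as (values here are made up):
  # BLEU = 57.1/38.2/25.0/16.4 (BP=1.000, ratio=0.987, hyp_len=4270, ref_len=4327)
  m = re.match(r'BLEU = ([\d.]+)/([\d.]+)/([\d.]+)/([\d.]+)', out.decode('utf-8', 'ignore'))
  return [float(x) for x in m.groups()] if m else None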
7 changes: 1 addition & 6 deletions visualize_result_struct.html
@@ -59,9 +59,6 @@
  border-bottom: 1px solid #555;
  box-shadow: 0px 0px 4px 2px #555;
}
.bleu {
  font-family: Courier, monospace;
}
.logprob {
  font-family: Courier, monospace;
}
@@ -75,13 +72,12 @@
var current_img_i = 0;

function start() {
  loadDataset('result_struct.json');
  loadDataset('result_struct_b1.json');
}

function writeHeader() {
html = '<h2>Showing results for ' + db.checkpoint_params.dataset + ' on ' + db.imgblobs.length + ' images</h2>';
html += '<br>Eval params were: ' + JSON.stringify(db.params);
html += '<br>Final average BLEU was: ' + '(' + db.final_result.bleu[0].toFixed(2) + ',' + db.final_result.bleu[1].toFixed(2) + ',' + db.final_result.bleu[2].toFixed(2) + ')'
html += '<br>Final average perplexity of ground truth words: ' + db.gtppl.toFixed(2);
$("#blobsheader").html(html);
}
Expand Down Expand Up @@ -141,7 +137,6 @@

function insertAnnot(annot, dnew) {
  dnew.append('div').attr('class', 'atxt').text(annot.text);
  dnew.append('div').attr('class', 'bleu').text('BLEU: (' + annot.bleu[0].toFixed(2) + ',' + annot.bleu[1].toFixed(2) + ',' + annot.bleu[2].toFixed(2) + ')');
  dnew.append('div').attr('class', 'logprob').text('logprob: ' + annot.logprob.toFixed(2));
}

