In [1]:
import math
import os
from collections import defaultdict

import numpy as np
import sklearn
from scipy import stats
from sklearn import metrics

In [2]:
def _pct(numerator, denominator):
    """Return ``numerator / denominator`` as a percentage rounded to 1 decimal.

    Returns -1 (the "empty bucket" sentinel this report already used) when
    ``denominator`` is zero, instead of raising ZeroDivisionError.
    """
    if denominator == 0:
        return -1
    return round(numerator * 100.0 / denominator, 1)


def breakdown_analysis(train_output, dev_output):
    """Print a question-answering accuracy breakdown for ``dev_output``.

    Both files are ';'-separated and only lines with exactly 8 fields are
    used: ``kb;feat;value;message;question;gold_value;q_feat;q_value``.
    ``kb`` is a whitespace-separated list of slot values, slot ``feat`` is
    updated to ``value``, ``question`` is the index of the slot asked about,
    ``gold_value`` the expected answer, and ``q_feat``/``q_value`` the
    predicted slot and value.

    Questions are split into "Question == Fact" (asking about the updated
    slot) and "Question != Fact" (asking about another slot), and each group
    is further split by whether the old and new values coincide. Any ratio
    whose denominator is zero is printed as -1.

    Args:
        train_output: path of the training output file.
            NOTE(review): its facts are counted but not used in the report.
        dev_output: path of the output file to analyse.

    Raises:
        AssertionError: if a line breaks the gold-value invariants (the gold
            answer must be the new value when asking about the updated slot,
            and the old value otherwise).
    """
    # Count (updated KB, feat, value, question, answer) tuples in training.
    # NOTE(review): train_samples is not used below.
    train_samples = defaultdict(int)
    with open(train_output) as f:
        for line in f:
            items = line.strip().split(';')
            if len(items) != 8:
                continue

            kb, feat, value, message, question, answer = items[:6]
            kb = kb.split()
            kb[int(feat)] = value
            kb = ''.join(kb)

            fact = '|'.join([kb, feat, value, question, answer])
            train_samples[fact] += 1

    num_questions = 0
    num_qef, num_qef_same_val, num_qef_diff_val = 0, 0, 0
    num_qnef, num_qnef_same_val, num_qnef_diff_val = 0, 0, 0

    correct_ans, correct_feat, correct_val = 0, 0, 0
    correct_qef, correct_qnef = 0, 0
    correct_qef_same_val, correct_qef_diff_val = 0, 0
    correct_qnef_same_val, correct_qnef_diff_val = 0, 0

    # Right slot but wrong value on the updated slot, and how many of those
    # answers repeated the pre-update value.
    # NOTE(review): collected but never printed.
    total_err = 0
    total_err_old_val = 0

    with open(dev_output) as f:
        for line in f:
            items = line.strip().split(';')
            if len(items) != 8:
                continue

            kb, feat, value, message, question, gold_value, q_feat, q_value = items

            # value of the questioned slot before the update
            old_value = kb.split()[int(question)]
            num_questions += 1

            feat_ok = q_feat == question
            val_ok = q_value == gold_value

            if feat_ok:
                correct_feat += 1
            if val_ok:
                correct_val += 1

            if question == feat:
                # asking about the updated slot: the gold answer is the new value
                num_qef += 1
                assert value == gold_value

                if old_value == gold_value:
                    num_qef_same_val += 1
                else:
                    num_qef_diff_val += 1

                if val_ok and feat_ok:
                    correct_qef += 1
                    correct_ans += 1
                    if q_value == old_value:
                        correct_qef_same_val += 1
                    else:
                        correct_qef_diff_val += 1
                elif feat_ok:
                    # right slot, wrong value
                    total_err += 1
                    if q_value == old_value:
                        total_err_old_val += 1
            else:
                # asking about an untouched slot: the gold answer is the old value
                num_qnef += 1
                assert gold_value == old_value

                if value == gold_value:
                    num_qnef_same_val += 1
                else:
                    num_qnef_diff_val += 1

                if val_ok and feat_ok:
                    correct_qnef += 1
                    correct_ans += 1
                    if q_value == value:
                        correct_qnef_same_val += 1
                    else:
                        correct_qnef_diff_val += 1

    def _report_bucket(header, n_correct, n_bucket, n_group):
        # share of the group that falls in this bucket, then accuracy within it
        print(header)
        print("..... percentage:", _pct(n_bucket, n_group), n_bucket, "/", n_group)
        print("..... acc:", _pct(n_correct, n_bucket))

    print('Average acc:', _pct(correct_ans, num_questions))
    print('Feat acc:', _pct(correct_feat, num_questions))
    print('Value acc:', _pct(correct_val, num_questions))

    print('Question == Fact:')
    print('Percentage:', _pct(num_qef, num_questions))
    print('Avg acc:', _pct(correct_qef, num_qef))
    _report_bucket(">>> old value == new_value", correct_qef_same_val, num_qef_same_val, num_qef)
    _report_bucket(">>> old value != new_value", correct_qef_diff_val, num_qef_diff_val, num_qef)

    print('Question != Fact')
    print('Percentage:', _pct(num_qnef, num_questions))
    print('Avg acc:', _pct(correct_qnef, num_qnef))
    _report_bucket(">>> old value == new_value", correct_qnef_same_val, num_qnef_same_val, num_qnef)
    _report_bucket(">>> old value != new_value", correct_qnef_diff_val, num_qnef_diff_val, num_qnef)

    print('\n\n')

In [33]:
# NOTE(review): the same .train file is passed as both train and dev output,
# so this breakdown is computed on the training file — confirm intended.
breakdown_analysis(
    train_output='../EGG/egg/zoo/common_ground/outputs/5feats_regularization/earlystop_wdecay_2.train',
    dev_output='../EGG/egg/zoo/common_ground/outputs/5feats_regularization/earlystop_wdecay_2.train',
)

Average acc: 83.4
Feat acc: 100.0
Value acc: 83.4
Question == Fact:
Percentage: 20.6
Avg acc: 19.7
>>> old value == new_value
..... percentage: 19.6 324 / 1652
..... acc: 100.0
>>> old value != new_value
..... percentage: 80.4 1328 / 1652
..... acc: 0.1
Question != Fact
Percentage: 79.3
Avg acc: 100.0
>>> old value == new_value
..... percentage: 19.7 1251 / 6348
..... acc: 100.0
>>> old value != new_value
..... percentage: 80.3 5097 / 6348
..... acc: 100.0





In [4]:
def compute_MI(output_file, log_base=None):
    """Print MI(message, fact) and H(fact) for a 10-field output file.

    Only ';'-separated lines with exactly 10 fields are used; the first four
    are ``kb;feat;value;message``. A fact is the updated KB (slot ``feat``
    set to ``value``, slots joined with spaces) plus ``feat`` and ``value``.
    H(fact) is computed as MI(fact, fact).

    Args:
        output_file: path to the output file.
        log_base: logarithm base for the printed values. ``None`` (default)
            keeps sklearn's natural-log units (nats); pass 2 for bits.
    """
    messages, facts = [], []
    with open(output_file) as f:
        for line in f:
            items = line.strip().split(';')
            if len(items) != 10:
                # only 10-field records are analysed
                continue

            kb, feat, value, message = items[:4]

            # updated KB: slot `feat` takes `value`
            new_kb = kb.split()
            new_kb[int(feat)] = value
            facts.append(' '.join(new_kb) + '-' + feat + '-' + value)
            messages.append(message)

    def _score(a, b):
        # mutual_info_score returns nats; optionally convert to another base
        score = sklearn.metrics.mutual_info_score(a, b)
        if log_base is not None:
            score = score / math.log(log_base)
        return round(score, 2)

    print("MI(message, facts):", _score(messages, facts))
    print("H(facts):", _score(facts, facts))

In [15]:
def analyze_output(output_file, show_success=False, show_receiver=False):
    """Print failed samples and the sender's message -> fact table for a run.

    Only ';'-separated lines with exactly 8 fields are used:
    ``kb;feat;value;message;question;gold_value;q_feat;q_value``.
    A sample fails when the predicted slot ``q_feat`` differs from
    ``question`` or the predicted value ``q_value`` differs from
    ``gold_value``. Each distinct sample is printed once with its frequency.

    Args:
        output_file: path to the output file.
        show_success: also print the table of successful samples
            (previously a commented-out block).
        show_receiver: also print the answers grouped by message
            (previously a commented-out block).
    """
    fail = defaultdict(int)
    success = defaultdict(int)
    sender_symbols = defaultdict(lambda: defaultdict(int))
    receiver_symbols = defaultdict(lambda: defaultdict(int))

    with open(output_file) as f:
        for line in f:
            items = line.strip().split(';')
            if len(items) != 8:
                continue

            kb, feat, value, message, question, gold_value, q_feat, q_value = items

            # updated KB (slot `feat` takes `value`) and original KB, both compacted
            new_kb = kb.split()
            new_kb[int(feat)] = value
            new_kb = ''.join(new_kb)
            kb = ''.join(kb.split())

            fact = new_kb + ' ' + feat + ' ' + value
            answer = kb + '|' + new_kb + '|' + q_feat + '|' + q_value
            sender_symbols[message][fact] += 1
            receiver_symbols[message][answer] += 1

            key = '\t'.join([kb, new_kb, message, question, q_feat, gold_value, q_value])
            is_correct = q_value == gold_value and q_feat == question
            (success if is_correct else fail)[key] += 1

    header = "oldKB\tnewKB\tmessage\tquest\tq_feat\tval\tq_val\tfreq"

    print('FAIL')
    print(header)
    for case in sorted(fail):
        print(case, "\t", fail[case])

    if show_success:
        print()
        print("SUCCESS")
        print(header)
        for case in sorted(success):
            print(case, "\t", success[case])

    print()
    print("Messages sent by FA:")
    print("Number of symbols:", len(sender_symbols))
    for sym in sorted(sender_symbols):
        print()
        for fact in sender_symbols[sym]:
            print(sym, fact, sender_symbols[sym][fact])

    if show_receiver:
        print()
        print("Messages received by QA:")
        print("Number of symbols:", len(receiver_symbols))
        for sym in sorted(receiver_symbols):
            print()
            for ans in receiver_symbols[sym]:
                print(sym, ans)


In [16]:
# Failure table and sender message/fact table for the unregularized run 3.
analyze_output(
    output_file='../EGG/egg/zoo/common_ground/outputs/5feats_regularization/no_regularization_3.train',
)

FAIL
oldKB	newKB	message	quest	q_feat	val	q_val	freq
00112	02112	580	1	1	2	0 	 1
00123	00113	379	0	0	0	1 	 1
00240	00240	600	0	0	0	4 	 1
02201	00201	147	1	1	0	3 	 1
04223	03223	147	1	1	3	4 	 1
04244	04204	147	3	3	0	4 	 1
10213	10213	379	2	2	2	3 	 1
23204	23201	147	4	4	1	4 	 1
24034	34034	596	0	0	3	4 	 1
30434	34434	596	1	1	4	0 	 1
33004	31004	506	2	2	0	4 	 1
40312	10312	580	0	0	1	4 	 1
40414	40412	580	4	4	2	4 	 1
41142	41342	506	2	2	3	1 	 1
41200	41200	506	2	2	2	4 	 1
41401	41402	506	4	4	2	1 	 1
41404	21404	506	0	0	2	4 	 1
41411	11411	506	0	0	1	4 	 1
42002	22002	580	0	0	2	4 	 1
42240	32240	600	0	0	3	4 	 1
43024	23024	826	0	0	2	4 	 1
43221	43221	826	2	2	2	0 	 1

Messages sent by FA:
Number of symbols: 31

111 44111 4 1 2
111 44111 0 4 3

13 01104 4 4 1
13 20004 3 0 1
13 10004 3 0 2
13 24004 4 4 1
13 24000 1 4 1

147 14223 3 2 2
147 01123 4 3 1
147 30223 4 3 1
147 04032 4 2 2
147 34002 2 0 3
147 43043 2 0 1
147 20243 3 4 2
147 00104 3 0 1
147 03103 0 0 1
147 13204 2 2 2
147 30223 1 0 2
1

379 32333 2 3 3
379 00143 1 0 1
379 20313 4 3 1
379 40213 1 0 2
379 12203 1 2 1
379 30033 0 3 1
379 12013 0 1 3
379 22213 4 3 1
379 02343 2 3 1
379 32023 4 3 3
379 32423 0 3 1
379 32443 1 2 1
379 32113 4 3 2
379 42403 1 2 2
379 32233 4 3 1
379 02143 3 4 1
379 02313 1 2 2
379 30343 0 3 2
379 32003 0 3 2
379 22413 3 1 2
379 40303 4 3 2
379 20323 4 3 1
379 42033 4 3 2
379 00133 2 1 1
379 30313 4 3 2
379 10143 1 0 3
379 10033 2 0 1
379 00133 3 3 1
379 10323 2 3 3
379 10343 3 4 2
379 30003 4 3 1
379 30413 1 0 3
379 02423 2 4 1
379 40313 0 4 1
379 22233 1 2 1
379 12103 1 2 2
379 02133 0 0 4
379 22430 4 0 1
379 10443 3 4 2
379 12423 4 3 1
379 32103 1 2 1
379 20133 2 1 1
379 02000 1 2 2
379 00333 3 3 2
379 32113 1 2 1
379 22043 4 3 2
379 02033 0 0 1
379 42443 1 2 1
379 10123 0 1 2
379 20233 4 3 1
379 12133 4 3 1
379 44413 1 4 1
379 20333 1 0 2
379 30303 4 3 1
379 22403 0 2 1
379 10023 4 3 1
379 10233 1 0 3
379 12043 2 0 3
379 32133 2 1 1
379 10143 2 1 1
379 20213 1 0 1
379 02033 3 3 1
379 4200

512 30441 3 4 1
512 33442 4 2 1
512 00412 3 1 1
512 33112 4 2 1
512 31230 4 0 1
512 30424 2 4 3
512 30232 3 3 1
512 20434 3 3 2
512 03314 4 4 2
512 33122 4 2 1
512 30211 4 1 1
512 20404 3 0 2
512 11200 4 0 2
512 31420 4 0 2
512 14441 0 1 1
512 13234 4 4 1
512 01214 3 1 2
512 30011 3 1 2
512 30111 3 1 1
512 30400 3 0 3
512 01312 4 2 1
512 14421 4 1 1
512 40004 3 0 1
512 11042 4 2 1
512 30121 3 2 1
512 33032 4 2 1
512 01014 0 0 1
512 00432 3 3 1
512 33440 1 3 1
512 31440 4 0 1
512 01420 4 0 1
512 23444 4 4 1
512 01342 4 2 2
512 30102 4 2 1
512 11032 4 2 1
512 30420 3 2 1
512 33111 0 3 1
512 44410 0 4 1
512 01424 4 4 1
512 30430 3 3 1
512 33441 4 1 1
512 01400 4 0 1
512 21232 0 2 1
512 40444 3 4 1
512 30012 3 1 1
512 33132 4 2 1
512 44011 4 1 1
512 13232 4 2 1
512 14441 3 4 1
512 33410 4 0 1

516 01124 2 1 1
516 24424 0 2 1
516 44412 2 4 1
516 14124 3 2 1
516 44402 2 4 1

547 20110 3 1 1
547 23110 4 0 2
547 22210 3 1 2
547 12010 3 1 1
547 22120 4 0 2
547 40430 2 4 1
547 00114 3 1 2
547 10

600 34124 2 1 1
600 02022 3 2 1
600 20340 0 2 2
600 24124 2 1 2
600 23210 2 2 2
600 14344 0 1 1
600 01040 3 4 2
600 04020 2 0 1
600 32140 0 3 2
600 10320 2 3 1
600 02220 4 0 1
600 43240 3 4 1
600 23140 1 3 1
600 23200 2 2 2
600 33030 2 0 2
600 04034 2 0 1
600 40020 2 0 2
600 03120 2 1 1
600 33044 0 3 1
600 04344 3 4 3
600 14200 3 0 1
600 01241 4 1 1
600 43340 4 0 3
600 14200 2 2 2
600 01210 4 0 1
600 23210 3 1 1
600 00214 2 2 1
600 00320 2 3 2
600 22020 0 2 1
600 12440 2 4 3
600 10240 3 4 2
600 33440 3 4 1
600 30010 3 1 1
600 23010 0 2 2
600 23440 2 4 1
600 14240 2 2 3
600 24310 0 2 1
600 02210 3 1 1
600 44310 0 4 1
600 01140 4 0 2
600 44044 0 4 2
600 00444 3 4 2
600 44044 4 4 2
600 10244 4 4 3
600 03040 2 0 1
600 44220 3 2 1
600 13244 2 2 2
600 31020 0 3 3
600 24322 2 3 1
600 24420 2 4 1
600 32200 0 3 1
600 30440 0 3 1
600 30241 4 1 1
600 04240 1 4 1
600 22040 0 2 1
600 32220 3 2 1
600 00140 2 1 1
600 34120 0 3 1
600 00024 3 2 1
600 03241 2 2 1
600 24220 3 2 1
600 01044 2 0 2
600 3303

916 14122 1 4 1
916 14124 1 4 1
916 14122 4 2 1
916 14422 1 4 1
916 14322 1 4 1
916 14024 1 4 1
916 14421 1 4 1
916 14441 1 4 1

938 24330 0 2 2
938 02341 0 0 2
938 00331 3 3 2
938 23330 0 2 1
938 04331 4 1 3
938 02301 3 0 1
938 00341 0 0 1
938 04331 0 0 1
938 04131 0 0 1
938 40330 0 4 1
938 24331 1 4 1
938 02031 0 0 1
938 02030 0 0 2
938 00334 2 3 1
938 24331 0 2 1
938 02330 0 0 1
938 00034 2 0 1
938 00132 3 3 2
938 14331 0 1 1
938 02130 0 0 2
938 00134 3 3 1
938 30031 0 3 1
938 20100 0 2 1
938 01330 0 0 2
938 00331 0 0 1
938 14131 0 1 2
938 02030 3 3 1
938 00331 2 3 1
938 04131 4 1 1
938 04332 4 2 1
938 02032 3 3 1
938 24302 4 2 1
938 02101 0 0 2
938 02111 0 0 1
938 00030 0 0 1
938 20131 0 2 1
938 22300 0 2 1
938 30130 0 3 1
938 02302 0 0 1
938 00131 1 0 1
938 00332 4 2 1
938 24341 0 2 3
938 20030 4 0 1
938 02301 0 0 1
938 20300 0 2 1
938 24130 0 2 2
938 20134 3 3 1
938 00031 0 0 1
938 20430 0 2 1
938 20130 0 2 3
938 04331 1 4 1
938 04311 4 1 1
938 00111 0 0 1
938 00311 0 0 1
938 003

In [17]:
def mi(x, y, base=2):
    """Mutual information between two label sequences, rounded to 2 decimals.

    ``sklearn.metrics.mutual_info_score`` returns nats; the result is
    converted to logarithm base ``base`` (bits by default). ``mi(x, x)`` is
    therefore the entropy of ``x`` in that base.

    Uses ``math.log`` so this helper does not depend on ``numpy`` having been
    imported by a later cell.
    """
    return round(sklearn.metrics.mutual_info_score(x, y) / math.log(base), 2)

In [35]:
# compute mutual information
# For each early-stopped run, print the MI (in bits, via `mi`) between the
# sender's messages and several views of the updated KB, followed by the
# entropies of those views (`mi(x, x)`). numpy belongs in the imports cell.

output_dir = '../EGG/egg/zoo/common_ground/outputs/5feats_regularization'
num_feats = 5
num_runs = 5

for i in range(num_runs):
    output_file = os.path.join(output_dir, 'earlystop_' + str(i + 1) + '.train')

    print('Run:', i + 1, output_file)

    messages, new_kbs = [], []
    fact_feats, fact_values = [], []
    facts = []
    kb_feats = []
    kb_facts = []
    # per-slot values of the updated KB (not used in the report below)
    feat_values = defaultdict(list)

    with open(output_file) as f:
        for line in f:
            items = line.strip().split(';')
            if len(items) != 8:
                continue

            kb, feat, value, message = items[:4]

            # updated KB: slot `feat` takes `value`
            new_kb = kb.split()
            new_kb[int(feat)] = value
            new_kb_str = ''.join(new_kb)
            fact = feat + '-' + value
            kb_fact = new_kb_str + '-' + fact
            kb_feat = new_kb_str + '-' + feat

            new_kbs.append(new_kb_str)
            messages.append(message)
            fact_feats.append(feat)
            fact_values.append(value)
            facts.append(fact)
            kb_facts.append(kb_fact)
            kb_feats.append(kb_feat)
            for j in range(num_feats):
                feat_values[j].append(new_kb[j])

    print('Number of symbols:', len(set(messages)))
    # sorted: iteration order of a set of str changes between kernel sessions
    # (hash randomization), which made this line non-reproducible
    print(sorted(set(messages)))
    print()

    MI_message_updated_feat = mi(messages, fact_feats)
    H_updated_feat = mi(fact_feats, fact_feats)
    print("MI(message, updated_feat):", MI_message_updated_feat, 'entropy:', H_updated_feat)

    MI_message_updated_value = mi(messages, fact_values)
    H_updated_value = mi(fact_values, fact_values)
    print("MI(message, updated_value):", MI_message_updated_value, 'entropy:', H_updated_value)

    MI_message_new_kb = mi(messages, new_kbs)
    print("MI(message, new KB):", MI_message_new_kb)

    MI_message_fact = mi(messages, facts)
    print("MI(message, facts):", MI_message_fact)

    MI_message_kb_feat = mi(messages, kb_feats)
    print("MI(message, KB-feats):", MI_message_kb_feat)

    MI_message_kb_fact = mi(messages, kb_facts)
    print("MI(message, KB-facts):", MI_message_kb_fact)

    H_message = mi(messages, messages)
    print("H(message):", H_message)
    H_new_kb = mi(new_kbs, new_kbs)
    print("H(new_kb):", H_new_kb)
    H_facts = mi(facts, facts)
    print("H(facts):", H_facts)

    H_kb_feats = mi(kb_feats, kb_feats)
    print("H(kb_feats):", H_kb_feats)
    H_kb_facts = mi(kb_facts, kb_facts)
    print("H(kb_facts):", H_kb_facts)

    print()
            

Run: 1 ../EGG/egg/zoo/common_ground/outputs/5feats_regularization/earlystop_1.train
Number of symbols: 1
{'640'}

MI(message, updated_feat): -0.0 entropy: 2.32
MI(message, updated_value): -0.0 entropy: 2.32
MI(message, new KB): 0.0
MI(message, facts): 0.0
MI(message, KB-feats): 0.0
MI(message, KB-facts): 0.0
H(message): 0.0
H(new_kb): 11.3
H(facts): 4.64
H(kb_feats): 12.5
H(kb_facts): 12.5

Run: 2 ../EGG/egg/zoo/common_ground/outputs/5feats_regularization/earlystop_2.train
Number of symbols: 2
{'867', '685'}

MI(message, updated_feat): 0.01 entropy: 2.32
MI(message, updated_value): 0.02 entropy: 2.32
MI(message, new KB): 0.88
MI(message, facts): 0.11
MI(message, KB-feats): 0.97
MI(message, KB-facts): 0.97
H(message): 0.97
H(new_kb): 11.3
H(facts): 4.64
H(kb_feats): 12.5
H(kb_facts): 12.5

Run: 3 ../EGG/egg/zoo/common_ground/outputs/5feats_regularization/earlystop_3.train
Number of symbols: 2
{'897', '90'}

MI(message, updated_feat): 0.01 entropy: 2.32
MI(message, updated_value): 0.01 e