# gather results

We gather the training/dev/test data and the generated train/dev/test results for easier analysis.

## training set

In [117]:
import sacrebleu
import os
from termcolor import colored, cprint 

In [121]:

def gather_data(data_file):
    """
    Read a data file and group its src/tgt pairs by user id.

    Every non-blank line must have the form ``<src>￨<tgt>``. The first
    space-separated token of ``<tgt>`` is parsed as the integer user id and
    removed from the target sentence. The first space-separated token of
    ``<src>`` is removed as well (presumably the same user id -- TODO
    confirm against the data files).

    Args:
        data_file: path of the text file to read (decoded as UTF-8).

    Returns:
        A dict mapping each user id (int) to a list of
        ``{'src': <sentence>, 'tgt': <sentence>}`` dicts, in file order.

    Raises:
        ValueError: if a non-blank line does not contain exactly one ``￨``
            separator, or if the user id token is not an integer.
    """
    user_train_data = {}

    # The field separator is a non-ASCII character, so decode explicitly
    # instead of relying on the platform's default encoding.
    with open(data_file, 'r', encoding='utf-8') as fin:
        # Iterate the file lazily; there is no need to hold every line in memory.
        for line_no, raw_line in enumerate(fin, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue

            fields = line.split('￨')
            if len(fields) != 2:
                raise ValueError(
                    '{}:{}: expected exactly one separator, got {} field(s)'.format(
                        data_file, line_no, len(fields)))
            src, tgt = fields

            # Split off the leading token; the remainder is the sentence.
            # This equals re-joining split(' ')[1:] with single spaces.
            uid_token, _, tgt_sent = tgt.partition(' ')
            uid = int(uid_token)
            _, _, src_sent = src.partition(' ')

            user_train_data.setdefault(uid, []).append({
                'src': src_sent,
                'tgt': tgt_sent,
            })
    return user_train_data
            
def gather_decode_data(data_file,ref_data):
    """
    Load a decoded-output file and score it against the references.

    The file is parsed with gather_data, then compute_bleu is called with
    ref_data and the parsed output.

    Returns a dict with two keys:
      'data': the per-user pairs parsed from data_file
      'bleu': the score returned by compute_bleu
    """
    parsed_output=gather_data(data_file)
    corpus_score=compute_bleu(ref_data,parsed_output)
    result={}
    result['data']=parsed_output
    result['bleu']=corpus_score
    return result
# 4. showing results
def display_one_user(user_train_data,user_ref_data,user_decode_data_list,user):
    """
    A helper function to display the decoded results of one user.

    Prints, in this order:
      1. the 'bleu' value stored with every entry of user_decode_data_list,
      2. the BLEU of each entry restricted to this user (compute_user_bleu),
      3. for every reference pair of the user: the source, the reference
         target, and the decoded target of each entry,
      4. the user's training pairs.

    Args:
        user_train_data: dict user -> list of {'src','tgt'} training pairs.
        user_ref_data: dict user -> list of {'src','tgt'} reference pairs.
        user_decode_data_list: list of {'data': ..., 'bleu': ...} dicts, one
            per decoded output being compared.
        user: key of the user to display.

    Each decoded output gets its own colour. Colours repeat once there are
    more outputs than colours, so long lists no longer raise IndexError.
    """
    # termcolor colours: ['grey','red','green','yellow','blue','magenta','cyan','white']
    color_list=['green','red','blue','magenta','cyan','white']
    src_color='yellow'
    header_attrs=['reverse', 'blink']

    def model_color(index):
        # Wrap around so any number of decoded outputs can be displayed.
        return color_list[index % len(color_list)]

    # Overall BLEU of every decoded output.
    cprint('models\' bleu: ','red',attrs=header_attrs)
    for j,decoded in enumerate(user_decode_data_list):
        cprint(decoded['bleu'],model_color(j))

    # BLEU of every decoded output on this user's sentences only.
    cprint('user {}\'s bleu: '.format(user),'red',attrs=header_attrs)
    for j,decoded in enumerate(user_decode_data_list):
        score=compute_user_bleu(user_ref_data,decoded['data'],user)
        cprint(score,model_color(j))

    # Side-by-side comparison: source, reference, then each decoded target.
    cprint('user {}\'s decoded: '.format(user),'red',attrs=header_attrs)
    for i,ref_pair in enumerate(user_ref_data[user]):
        cprint(ref_pair['src'],src_color)
        print(ref_pair['tgt'])
        for j,decoded in enumerate(user_decode_data_list):
            # Decoded pairs are matched to reference pairs by position.
            cprint(decoded['data'][user][i]['tgt'],model_color(j))

    # The user's training pairs, for context.
    cprint('user {}\'s training data: '.format(user),'red',attrs=header_attrs)
    for pair in user_train_data[user]:
        cprint(pair['src'],src_color)
        print(pair['tgt'])
    
    
    
def gather_decode_file(output_file,mode,
                       ref_dir='../data/src_data_full_feat_tf_resplited_review_50k'):
    """
    Load a decoded-output file and score it against the matching references.

    Args:
        output_file: path of the decoded output to load.
        mode: data split name; the reference file read is
            '<ref_dir>/<mode>_ref.txt'.
        ref_dir: directory that holds the reference files. Defaults to the
            directory this notebook has always used.

    Returns:
        The dict produced by gather_decode_data: {'data': ..., 'bleu': ...}.

    Prints output_file once loading and scoring are done.
    """
    ref_file='{}/{}_ref.txt'.format(ref_dir,mode)
    ref_data=gather_data(ref_file)
    user_data_decode = gather_decode_data(output_file,ref_data)
    print(output_file)
    return user_data_decode

def compute_bleu(ref_data,output_data):
    """
    Corpus BLEU of output_data against ref_data over all users in ref_data.

    Both arguments map user -> list of {'src','tgt'} pairs. Targets are
    flattened user by user, in the iteration order of ref_data, into one
    reference list and one output list. Both lists go to
    sacrebleu.corpus_bleu as a single reference stream, and the resulting
    score is returned.

    Raises KeyError if a user of ref_data is missing from output_data.
    """
    references=[]
    hypotheses=[]
    for user,user_refs in ref_data.items():
        user_outputs=output_data[user]
        references.extend(pair['tgt'] for pair in user_refs)
        hypotheses.extend(pair['tgt'] for pair in user_outputs)

    return sacrebleu.corpus_bleu(hypotheses, [references]).score
def compute_user_bleu(ref_data,output_data,user):
    """
    BLEU of a single user's outputs against that user's references.

    ref_data and output_data map user -> list of {'src','tgt'} pairs; only
    the entries stored under `user` are scored. The targets are passed to
    sacrebleu.corpus_bleu with force=True (presumably to silence its
    "input looks tokenized" check -- confirm in the sacrebleu docs).

    Raises KeyError if `user` is missing from either dict.
    """
    user_refs=ref_data[user]
    user_outputs=output_data[user]
    references=[pair['tgt'] for pair in user_refs]
    hypotheses=[pair['tgt'] for pair in user_outputs]

    result=sacrebleu.corpus_bleu(hypotheses, [references],force=True)
    return result.score


In [92]:
# 1. We gather the data in the training set for each user.
# train_data is the path of the training reference file.
train_data='../data/src_data_full_feat_tf_resplited_review_50k/train_ref.txt'
# user_train_data: dict user id -> list of {'src','tgt'} pairs (see gather_data).
user_train_data = gather_data(train_data)

In [99]:
# 2. We gather the data in the test/dev set for each user.
# Each result is a dict user id -> list of {'src','tgt'} pairs (see gather_data).
# data_file is reused as a scratch variable for the path being loaded.

# "full" variants of the test and validation references.
data_file='../data/src_data_full_feat_tf_resplited_review_50k/test_full_ref.txt'
user_test_full_data = gather_data(data_file)
data_file='../data/src_data_full_feat_tf_resplited_review_50k/val_full_ref.txt'
user_val_full_data = gather_data(data_file)

# Non-"full" variants of the test and validation references.
data_file='../data/src_data_full_feat_tf_resplited_review_50k/test_ref.txt'
user_test_data = gather_data(data_file)
data_file='../data/src_data_full_feat_tf_resplited_review_50k/val_ref.txt'
user_val_data = gather_data(data_file)

In [108]:
# 3. We gather the results in test.
folder='s_tf_512R280'
save_prefix='model.best_' # ['model.best_','model.']
mode='test_full' # mode: [test,test_full,valid,valid_full]

# gather_decode_file takes (output_file, mode), so the decoded-output path is
# built here from folder/mode/save_prefix/num. The pattern matches the paths
# this cell printed when it was last run.
output_pattern='../../Speaker-pytorch/outputs/{}/{}.txt_Speaker_{}{}.txt'

num=18
decode_18=gather_decode_file(output_pattern.format(folder,mode,save_prefix,num),mode)
num=10
decode_10=gather_decode_file(output_pattern.format(folder,mode,save_prefix,num),mode)


../../Speaker-pytorch/outputs/s_tf_512R280/test_full.txt_Speaker_model.best_18.txt
../../Speaker-pytorch/outputs/s_tf_512R280/test_full.txt_Speaker_model.best_10.txt


In [126]:
# Load the Speaker model's decoded test_full output and score it against
# the test_full references; the result is {'data': ..., 'bleu': ...}.
mode='test_full' # mode: [test,test_full,valid,valid_full]
output_file='../../Speaker-pytorch/outputs/s_tf_512R280/test_full.txt_Speaker_model.best_18.txt'
s_tf_512R280_best_18=gather_decode_file(output_file,mode)



../../Speaker-pytorch/outputs/s_tf_512R280/test_full.txt_Speaker_model.best_18.txt


In [None]:
# Same statements as the Speaker-model cell above; running this reloads the
# same file and rebinds s_tf_512R280_best_18.
mode='test_full' # mode: [test,test_full,valid,valid_full]
output_file='../../Speaker-pytorch/outputs/s_tf_512R280/test_full.txt_Speaker_model.best_18.txt'
s_tf_512R280_best_18=gather_decode_file(output_file,mode)

In [125]:
# Load the DialoGPT decoded test_full output and score it against the
# test_full references; the result is {'data': ..., 'bleu': ...}.
mode='test_full'
output_file='../../DialoGPT/outputs/medium_10epochs_trial_2/test_full_ref.txt_step-12500.txt'
DialoGPT_medium_10epochs_trial_2_12500=gather_decode_file(output_file,mode)

../../DialoGPT/outputs/medium_10epochs_trial_2/test_full_ref.txt_step-12500.txt


In [129]:
# Display user 100: BLEU scores, decoded outputs next to the test_full
# references, and the user's training pairs.
# Entries of user_decode_data_list are coloured in list order
# (first: green, second: red).
display_one_user(user_train_data=user_train_data,
                 user_ref_data=user_test_full_data,
                 user_decode_data_list=[s_tf_512R280_best_18,DialoGPT_medium_10epochs_trial_2_12500],
                 user=100)