In [42]:
import argparse
import csv
import wandb
import numpy as np
import pandas as pd
from scipy.stats import rankdata
import glob

In [43]:
# --- Source A: wandb artifact table (passed to get_wandb_table; project is prefixed to run_name) ---
project = 'metric/chateval/'
run_name = 'run-7b7fxaaf-AMFMconvai2grade_en:v0'
table_name = 'AM-FM-convai2-grade_en.table.json'

# --- Source B: CSV stored in the local experiment folder (passed to get_table_from_exp) ---
run_id = '6x05lyco'

# --- Always required ---
gold_name = 'annotations.relevance'  # column holding the human (gold) scores to rank
metric_name = 'am_fm_scores'         # column holding the automatic metric scores to rank
dataset = 'convai'                   # prefix used in the output CSV filename

In [44]:
def get_wandb_table(run_name, table_name):
    """Download a table stored in a wandb artifact and return it as a DataFrame.

    Args:
        run_name: Artifact path handed to ``wandb.Api().artifact``. The caller
            in this notebook passes ``project + run_name``; presumably of the form
            '<entity>/<project>/<artifact>:<version>' — confirm against wandb.
        table_name: Name of the table entry inside the artifact.

    Returns:
        pandas.DataFrame built from the table's rows (``.data``) and column
        names (``.columns``).
    """
    wb_table = wandb.Api().artifact(run_name).get(table_name)
    return pd.DataFrame(wb_table.data, columns=wb_table.columns)

In [45]:
def get_table_from_exp(run_id, exp_dir='../chateval-exp'):
    """Load a run's results CSV from the local experiment folder.

    Searches for files matching ``<exp_dir>/<run_id>/*csv``. The default root is
    the original hard-coded ``../chateval-exp``, which the "check the symlink"
    message suggests is a symlink to the experiment outputs (assumption — verify).

    Args:
        run_id: Run identifier; used as the sub-directory name under ``exp_dir``.
        exp_dir: Root of the experiment folder. Defaults to the original location,
            so existing one-argument calls behave as before.

    Returns:
        DataFrame read from the first matching file in sorted filename order, or
        None (after printing a message) when nothing matches.
    """
    pattern = '{}/{}/*csv'.format(exp_dir, run_id)
    # glob.glob returns matches in filesystem-dependent order; sort them so the
    # chosen file is deterministic when a run folder holds several CSVs.
    matches = sorted(glob.glob(pattern))
    if not matches:
        print('Run not found, check the symlink')
        return None
    if len(matches) > 1:
        # Make the implicit choice visible instead of silently picking one.
        print('Multiple CSVs found for run {}, using {}'.format(run_id, matches[0]))
    return pd.read_csv(matches[0])
    

In [46]:
# Choose the data source explicitly instead of toggling commented-out code:
# True  -> download the wandb artifact table (project + run_name, table_name)
# False -> read the run's CSV from the local experiment folder (run_id)
USE_WANDB = True
if USE_WANDB:
    table = get_wandb_table(project + run_name, table_name)
else:
    table = get_table_from_exp(run_id)
    if table is None:
        # get_table_from_exp returns None when no CSV matches; stop here with a
        # clear error rather than failing later with an opaque TypeError.
        raise FileNotFoundError('No CSV found for run {}'.format(run_id))

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [48]:
# Rank the human (gold) scores; tied values share their average rank.
gold_scores = table[gold_name]
gold_ranks = rankdata(gold_scores, method='average')

In [50]:
# Rank the automatic metric scores the same way as the gold scores.
metric_scores = table[metric_name]
metric_ranks = rankdata(metric_scores, method='average')
# Signed per-row rank gap (gold minus metric); the absolute value is taken when the report is built.
metric_vs_gold = np.subtract(gold_ranks, metric_ranks)

In [52]:
# Per-dialogue comparison of gold vs. metric ranks, largest disagreement first.
metric_rank_col = '{}_rank'.format(metric_name)
vs_gold_col = '{}_vs_gold'.format(metric_name)
# Positional arrays (not Series) so the frame gets a fresh 0..n-1 index, exactly
# like the previous list(zip(...)) construction; to_csv later writes this index.
df = pd.DataFrame({
    'dialogue_id': table['dialogue_id'].to_numpy(),
    'gold_score': np.asarray(gold_scores),
    'gold_rank': gold_ranks,
    metric_name: np.asarray(metric_scores),
    metric_rank_col: metric_ranks,
    vs_gold_col: np.absolute(metric_vs_gold),
})
# mergesort is stable: rows with equal rank gaps keep their original table
# order instead of the arbitrary order produced by the default quicksort.
df = df.sort_values(by=[vs_gold_col], ascending=False, kind='mergesort')

In [53]:
# Write the sorted report (including the DataFrame index) to <dataset>_<metric>_ranks_sorted.csv.
output_path = '{}_{}_ranks_sorted.csv'.format(dataset, metric_name)
df.to_csv(output_path)