This notebook generates the evaluation results for Jelinek–Mercer (JM) and Dirichlet-prior (DP) smoothing, TF-IDF, BM25 tuning, SDM tuning, and PLM.

In [None]:
%matplotlib inline

import sys 
import os 

nb_dir = os.getcwd()
# Put the notebook's working directory on the import path so the project-local
# packages imported below (plotlib, phdconf) resolve.
if not any(entry == nb_dir for entry in sys.path):
    sys.path.append(nb_dir)

from plotlib.loaders import *
from plotlib.plotters import *
from phdconf.config import *
from phdconf import config 

# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and 3.8 removed the
# old names, so prefer the new name and fall back to the old one on older matplotlib.
plt.style.use('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available else 'seaborn-white')

In [None]:
# Load the AUS topics and derive the `broad` / `specific` query groups
# (presumably a split by query type; they are used as `filtered=` for the PLM runs at the end).
queries = load_queries(AUS_TOPIC_PATH)
broad, specific = load_query_types(queries)

In [None]:
# Qrels and relevance levels are looked up by position alongside each index name.
index_names = ['filtered-phrasestop']
qrel_paths = [config.AUS_QREL_PATH] * 2  # SIGIR entry (config.SIGIR_QREL_PATH) disabled
rel_levels = [config.AUS_REL_LEVEL] * 2  # SIGIR entry (config.SIGIR_REL_LEVEL) disabled
display_names = ['AUS', 'FIL_STOP']  # 'SIGIR' disabled

In [None]:
mu = 2000
# Dirichlet-prior unigram baseline at a single mu (sweep start and end are both mu).
# base_df is loaded without per_query, base_qry with per_query=True.
base_df = load_1d_dfs(
    ['filtered-phrasestop'],
    qrel_paths,
    os.path.join(BASE_DIR, 'preprocessing', 'dirichlet_prior'),
    'case-topics-{0}-unigram_dir_mu_{1:.2f}.run',
    rel_levels,
    mu, mu, 1,
)[0][0]
base_qry = load_1d_dfs(
    ['filtered-phrasestop'],
    qrel_paths,
    os.path.join(BASE_DIR, 'preprocessing', 'dirichlet_prior'),
    'case-topics-{0}-unigram_dir_mu_{1:.2f}.run',
    rel_levels,
    mu, mu, 1,
    per_query=True,
)[0][0]

In [None]:
jm_path = os.path.join(BASE_DIR, 'jelinek_mercer')

# Jelinek-Mercer lambda sweep: start, end and step.
lambda_start, lambda_end, increment = 0.0, 1.0, 0.05

In [None]:
# JM results for every lambda in the sweep.
jm_dfs = load_1d_dfs(
    index_names, qrel_paths, jm_path,
    'case-topics-{0}-unigram_jm_lambda_{1:.2f}.run',
    rel_levels, lambda_start, lambda_end, increment,
)

In [None]:
# Tuning curve of each rerank metric against lambda.
jm_fig = plot_tune_1d_comp(
    [''], RERANK_METRICS, jm_dfs,
    lambda_start, lambda_end, increment,
)
# jm_fig.savefig('figures/jm-tune.pdf')

## Dirichlet Prior 

In [None]:
dir_path = os.path.join(BASE_DIR, 'preprocessing', 'dirichlet_prior')

# Dirichlet-prior mu sweep: start, end and step.
mu_start, mu_end, mu_increment = 300.0, 3000.0, 50.0

In [None]:
# Dirichlet-prior results for every mu in the sweep.
dir_dfs = load_1d_dfs(
    index_names, qrel_paths, dir_path,
    'case-topics-{0}-unigram_dir_mu_{1:.2f}.run',
    rel_levels, mu_start, mu_end, mu_increment,
)

In [None]:
# Tuning curve of each rerank metric against mu.
dir_fig = plot_tune_1d_comp(
    ['ausnl'], RERANK_METRICS, dir_dfs,
    mu_start, mu_end, mu_increment,
)
# dir_fig.savefig('figures/dir-tune.pdf')

# Load BM25 

In [None]:
bm25_path = 'bm25'

def load_bm25_dfs(index_names, qrel_paths, results_path, rel_levels,
                  name='{0}-filtered-phrasestop-unigram_bm25_k1_{1:.2f}_b_{2:.2f}.run',
                  per_query=False, filtered=None,
                  k1_start=1.2, k1_end=3.05, b_start=0.05, b_end=1.05, increment=0.05):
    """Evaluate every run of a BM25 (k1, b) grid search with ``to_trec_df``.

    Parameters
    ----------
    index_names : list of str
        Run-name prefixes, substituted as ``{0}`` in ``name``.
    qrel_paths : list of str, or str
        Qrels file for each entry of ``index_names`` (matched by position). A single
        path string is applied to every index.
    results_path : str
        Directory containing the run files.
    rel_levels : sequence
        Relevance level for each entry of ``index_names`` (matched by position).
    name : str
        Run-file name template: ``{0}`` index name, ``{1}`` k1, ``{2}`` b.
    per_query, filtered :
        Forwarded unchanged to ``to_trec_df``.
    k1_start, k1_end, b_start, b_end, increment : float
        Grid bounds; values come from ``np.arange`` (end exclusive).

    Returns
    -------
    list
        ``dfs[i][k][j]`` is the evaluation for index ``i`` at the k-th k1 value and
        the j-th b value.
    """
    # A bare string would otherwise be indexed character by character below.
    if isinstance(qrel_paths, str):
        qrel_paths = [qrel_paths] * len(index_names)

    k1_values = np.arange(k1_start, k1_end, increment)
    b_values = np.arange(b_start, b_end, increment)

    dfs = []
    for i, index_name in enumerate(index_names):
        grid = []
        for k1 in k1_values:
            row = []
            for b in b_values:
                run_path = os.path.join(results_path, name.format(index_name, k1, b))
                row.append(to_trec_df(qrel_paths[i], run_path, rel_levels[i],
                                      per_query=per_query, filtered=filtered))
            grid.append(row)
        dfs.append(grid)

    return dfs

In [None]:
# BM25 (k1, b) grid for the AUS collection.
bm25_dfs = load_bm25_dfs(
    ['case-topics'],
    qrel_paths,
    os.path.join(BASE_DIR, bm25_path),
    rel_levels,
)

# Cross-validation results

In [None]:
# Cross-validation folds for the AUS topics (presumably train/test splits, judging by the name).
tt_folds = read_folds(AUS_FOLDS)

In [None]:
# Per-query results for every tuned model, as input to cross-validation.
jm_qry_dfs = load_1d_dfs(
    index_names, qrel_paths, jm_path,
    'case-topics-{0}-unigram_jm_lambda_{1:.2f}.run',
    rel_levels, lambda_start, lambda_end, increment,
    per_query=True,
)
dir_qry_dfs = load_1d_dfs(
    index_names, qrel_paths, dir_path,
    'case-topics-{0}-unigram_dir_mu_{1:.2f}.run',
    rel_levels, mu_start, mu_end, mu_increment,
    per_query=True,
)
# TF-IDF variants have no parameter, hence the single-point sweep (1, 1, 1).
tf_df = load_1d_dfs(
    index_names, qrel_paths, os.path.join(BASE_DIR, 'tfidf'),
    'case-topics-filtered-phrasestop-unigram_tfidf.run',
    rel_levels, 1, 1, 1,
    per_query=True,
)
tfln_df = load_1d_dfs(
    index_names, qrel_paths, os.path.join(BASE_DIR, 'tfidf'),
    'case-topics-filtered-phrasestop-unigram_tfidfln.run',
    rel_levels, 1, 1, 1,
    per_query=True,
)

In [None]:
# Per-query BM25 grid for cross-validation.
bm25_query_dfs = load_bm25_dfs(
    ['case-topics'],
    qrel_paths,
    os.path.join(BASE_DIR, bm25_path),
    rel_levels,
    per_query=True,
)

In [None]:
# One row per traditional model, cross-validated over the AUS folds; base_qry (the
# mu=2000 Dirichlet per-query run) is passed to cross_validation as the baseline.
cv_df = pd.DataFrame(columns=config.METRIC_NAMES)

# Raw string: '\s' is an invalid escape sequence (SyntaxWarning on Python 3.12+); the value is unchanged.
for ab, runs in zip(
        ['JM', 'DP', 'BM25', 'TFIDF', r'TFIDF\subscript{norm}'],
        [jm_qry_dfs[0], dir_qry_dfs[0],
         list(chain.from_iterable(bm25_query_dfs[0])),  # flatten the k1 x b grid into one run list
         tf_df[0], tfln_df[0]]):
    cross = cross_validation(runs, tt_folds, config.METRIC_NAMES, base_qry)
    cv_df.loc[ab] = cross[0]

In [None]:
# write_table('tables/ausnl-traditional', bold_max(cv_df).drop('unjudged@20',axis='columns').rename(config.METRIC_NAMES, axis='columns').to_latex(escape=False))

In [None]:
# Same models for the SIGIR collection.
# Qrel paths and relevance levels are passed as one-element lists, one entry per index
# name, matching how the AUS `qrel_paths` / `rel_levels` lists are passed. A bare path
# string would be indexed character by character (load_bm25_dfs does qrel_paths[i]).
# The SIGIR relevance level is used instead of the AUS levels held in `rel_levels`.
sigir_qrel_paths = [config.SIGIR_QREL_PATH]
sigir_rel_levels = [config.SIGIR_REL_LEVEL]
# Value passed as `filtered` to every SIGIR loader.
sigir_filtered = [1, 3, 9, 19, 23, 25, 31, 37, 39, 45, 51, 97]

jm_qry_dfs = load_1d_dfs(['sigir-stop'], sigir_qrel_paths, jm_path,
                         'sigir-topic-topics-{0}-unigram_jm_lambda_{1:.2f}.run', sigir_rel_levels,
                         lambda_start, lambda_end, increment, per_query=True, filtered=sigir_filtered)
dir_qry_dfs = load_1d_dfs(['sigir-stop'], sigir_qrel_paths, os.path.join(BASE_DIR, 'dirichlet_prior'),
                          'sigir-topic-topics-{0}-unigram_dir_mu_{1:.2f}.run', sigir_rel_levels,
                          mu_start, mu_end, mu_increment, per_query=True, filtered=sigir_filtered)
tf_df = load_1d_dfs(['sigir-stop'], sigir_qrel_paths, os.path.join(BASE_DIR, 'tfidf'),
                    'sigir-topic-topics-sigir-stop-unigram_tfidf.run', sigir_rel_levels,
                    1, 1, 1, per_query=True, filtered=sigir_filtered)
tfln_df = load_1d_dfs(['sigir-stop'], sigir_qrel_paths, os.path.join(BASE_DIR, 'tfidf'),
                      'sigir-topic-topics-sigir-stop-unigram_tfidfln.run', sigir_rel_levels,
                      1, 1, 1, per_query=True, filtered=sigir_filtered)

In [None]:
# Per-query SIGIR BM25 grid. load_bm25_dfs indexes qrel_paths[i] and rel_levels[i], so both
# are one-element lists (a bare path string would be indexed character by character),
# using the SIGIR relevance level rather than the AUS `rel_levels`.
bm25_query_dfs = load_bm25_dfs(
    ['sigir-topic-topics'],
    [config.SIGIR_QREL_PATH],
    os.path.join(BASE_DIR, bm25_path),
    [config.SIGIR_REL_LEVEL],
    name='{0}-sigir-stop-unigram_bm25_k1_{1:.2f}_b_{2:.2f}.run',
    per_query=True,
    filtered=[1, 3, 9, 19, 23, 25, 31, 37, 39, 45, 51, 97],
)

In [None]:
# SIGIR per-query baseline: Dirichlet prior at mu=1500. Qrels and relevance levels are
# one-element lists (one per index name) using the SIGIR relevance level.
base_sig = load_1d_dfs(
    ['sigir-stop'],
    [config.SIGIR_QREL_PATH],
    os.path.join(BASE_DIR, 'dirichlet_prior'),
    'sigir-topic-topics-{0}-unigram_dir_mu_{1:.2f}.run',
    [config.SIGIR_REL_LEVEL],
    1500, 1500, 1,
    per_query=True,
    filtered=[1, 3, 9, 19, 23, 25, 31, 37, 39, 45, 51, 97],
)[0][0]

#sig_folds = [[[1, 3, 9, 19, 23, 25, 31, 37, 39, 45, 51, 97], [1, 3, 9, 19, 23, 25, 31, 37, 39, 45, 51, 97]]]
sig_folds = read_folds('sigir-folds.txt')

In [None]:
# SIGIR counterpart of the AUS table: each model cross-validated over the SIGIR folds,
# with base_sig passed as the baseline.
cv_df = pd.DataFrame(columns=config.METRIC_NAMES)
# Raw string: '\s' is an invalid escape sequence (SyntaxWarning on Python 3.12+); the value is unchanged.
for ab, runs in zip(
        ['JM', 'DP', 'BM25', 'TFIDF', r'TFIDF\subscript{norm}'],
        [jm_qry_dfs[0], dir_qry_dfs[0],
         list(chain.from_iterable(bm25_query_dfs[0])),  # flatten the k1 x b grid into one run list
         tf_df[0], tfln_df[0]]):
    cross = cross_validation(runs, sig_folds, config.METRIC_NAMES, base_sig)
    cv_df.loc[ab] = cross[0]

# write_table('tables/sigir-traditional', bold_max(cv_df).drop('unjudged@20',axis='columns').rename(config.METRIC_NAMES, axis='columns').to_latex(escape=False))

In [None]:
# Drop references to the large per-query run collections so they can be freed.
# The BM25 grid lives in `bm25_query_dfs`; the previous `bm25_qry_dfs` was a typo that
# created a new name and left the grid alive.
jm_qry_dfs = None
dir_qry_dfs = None
bm25_query_dfs = None

# BM25 max results

In [None]:
# get max run

def create_max_df(bm25_dfs, k1_start=1.2, k1_end=3.05, b_start=0.05, b_end=1.05, increment=0.05):
    """Find, for each measure, the BM25 (k1, b) setting with the highest score.

    Parameters
    ----------
    bm25_dfs : nested list
        ``bm25_dfs[k][j]`` is the aggregate evaluation (indexable by measure name, with an
        ``.index`` of measure names) at the k-th k1 value and j-th b value, as in one entry
        of ``load_bm25_dfs(...)``.
    k1_start, k1_end, b_start, b_end, increment : float
        Grid the runs were generated on; the values come from ``np.arange`` (end exclusive).

    Returns
    -------
    pandas.DataFrame
        One row per measure with columns ``b``, ``k1`` and ``score``. On ties the first
        setting wins (lowest k1, then lowest b).
    """
    k1_values = np.arange(k1_start, k1_end, increment)
    b_values = np.arange(b_start, b_end, increment)

    best = {}
    for i, row in enumerate(bm25_dfs):
        for j, scores in enumerate(row):
            for measure in scores.index:
                value = scores[measure]
                # Strict '>' keeps the earliest grid point on ties.
                if measure not in best or value > best[measure][0]:
                    best[measure] = (value, i, j)

    max_for = {measure: {'b': b_values[j], 'k1': k1_values[i], 'score': value}
               for measure, (value, i, j) in best.items()}
    return pd.DataFrame.from_dict(max_for, orient='index')

In [None]:
# Best (k1, b) setting and score for each measure on the AUS BM25 grid.
bm25_max_df = create_max_df(bm25_dfs[0])

In [None]:
# Inspect the per-measure maxima (columns b, k1, score).
bm25_max_df

In [None]:
# LaTeX table of the best BM25 setting, restricted to the configured metrics
# and relabelled with their display names.
print(
    bm25_max_df
    .loc[bm25_max_df.index.isin(config.METRIC_NAMES)]
    .rename(index=config.METRIC_NAMES)
    .round(4)
    .to_latex()
)

In [None]:
def plot_bm25(metric_names, dfs, ylims=None, k1_start=1.2, k1_end=3.05, b_start=0.05, b_end=1.05,
              increment=0.05):
    """Plot one 3-D surface per metric over the BM25 (k1, b) grid.

    Parameters
    ----------
    metric_names : dict
        Maps metric key -> z-axis label; one panel per entry.
    dfs : nested list
        ``dfs[k][j][metric]`` is the score at the k-th k1 value and j-th b value
        (one entry of ``load_bm25_dfs(...)``).
    ylims : optional
        Accepted for backward compatibility; currently ignored.
    k1_start, k1_end, b_start, b_end, increment : float
        Grid the runs were generated on (``np.arange``, end exclusive).

    Returns
    -------
    matplotlib Figure
    """
    n_metrics = len(metric_names)
    # Half the metrics per row (plus one when odd), with enough rows for every metric.
    # squeeze=False keeps `axs` 2-D; a 1x1 or 1xN grid would otherwise come back as a
    # bare Axes / 1-D array and axs[row, col] would fail.
    ncols = max(1, n_metrics // 2 + n_metrics % 2)
    nrows = max(1, -(-n_metrics // ncols))
    fig, axs = plt.subplots(nrows, ncols, subplot_kw=dict(projection='3d'), squeeze=False)
    fig.set_size_inches(16, 8)

    # k1 varies down the rows and b across the columns, matching dfs[k1_idx][b_idx].
    k1_grid, b_grid = np.meshgrid(np.arange(k1_start, k1_end, increment),
                                  np.arange(b_start, b_end, increment), indexing='ij')

    for ax, m in zip(axs.flat, metric_names):
        scores = np.array([[cell[m] for cell in row] for row in dfs])
        ax.plot_surface(k1_grid, b_grid, scores, cmap=cm.gray)
        ax.tick_params(labelsize=10)
        ax.zaxis.set_major_formatter(FormatStrFormatter('%.3f'))
        ax.set_zlabel(metric_names[m], fontsize=14, rotation='vertical')

    # Remove panels left without a metric.
    for ax in axs.flat[n_metrics:]:
        fig.delaxes(ax)

    return fig

# Surfaces for the AUS BM25 grid; ylims is passed but plot_bm25 does not apply it.
bm25_fig = plot_bm25(RERANK_METRICS, bm25_dfs[0], ylims=ALL_YLIMS)

In [None]:
# bm25_fig.savefig('figures/ausnl-bm25-tune.pdf')

In [None]:
def plot_single_bm25(dfs, metric):
    """Draw one large 3-D surface of a single metric over the BM25 (k1, b) grid.

    Parameters
    ----------
    dfs : nested list
        ``dfs[k][j][key]`` is the score at the k-th k1 value and j-th b value.
    metric : dict
        ``{metric key: axis label}``; only the first entry is used.

    Returns
    -------
    matplotlib Figure
    """
    fig, ax = plt.subplots(1, 1, subplot_kw=dict(projection='3d'))
    fig.set_size_inches(16, 8)

    # Grid: k1 in [1.2, 3.05) and b in [0.05, 1.05), both in steps of 0.05.
    k1_values = np.arange(1.2, 3.05, 0.05)
    b_values = np.arange(0.05, 1.05, 0.05)
    n_b, n_k1 = len(dfs[0]), len(dfs)
    k1_grid = np.repeat(k1_values[:, np.newaxis], n_b, axis=1)
    b_grid = np.tile(b_values, (n_k1, 1))

    metric_key = next(iter(metric))
    metric_label = metric[metric_key]
    heights = np.array([[cell[metric_key] for cell in row] for row in dfs])

    ax.plot_surface(k1_grid, b_grid, heights, cmap=cm.gray)

    for axis_name, pad in (('x', 3), ('y', 3), ('z', 13)):
        ax.tick_params(axis=axis_name, labelsize=25, pad=pad)
    ax.set_ylabel('b', fontsize=35, rotation='horizontal', labelpad=25)
    ax.set_xlabel('k1', fontsize=35, rotation='horizontal', labelpad=25)
    ax.zaxis.set_major_formatter(FormatStrFormatter('%.3f'))
    ax.set_zlabel(metric_label, fontsize=30, rotation='vertical', labelpad=27)

    return fig

# Single RBP@0.80 surface for the AUS BM25 grid.
bm25_fig = plot_single_bm25(bm25_dfs[0], {'rbp@0.80': 'RBP'})
# bm25_fig.savefig('figures/ausnl-bm25-tune.pdf')

In [None]:
# Replace the AUS grid with the SIGIR BM25 grid for the surface plot below.
bm25_dfs = load_bm25_dfs(
    ['sigir-topic-topics'],
    [config.SIGIR_QREL_PATH],
    os.path.join(BASE_DIR, bm25_path),
    rel_levels=[config.SIGIR_REL_LEVEL],
    name='{0}-sigir-stop-unigram_bm25_k1_{1:.2f}_b_{2:.2f}.run',
)

In [None]:
# Same RBP@0.80 surface for the SIGIR BM25 grid.
bm25_fig = plot_single_bm25(bm25_dfs[0], {'rbp@0.80': 'RBP'})
# bm25_fig.savefig('figures/sigir-bm25-tune.pdf')

# SDM 

In [None]:
def all_combinations(n, k, increment: float):
    """Enumerate k-part weight vectors whose parts are multiples of ``increment``.

    Walks every composition of ``n`` into ``k`` non-negative integer parts, using the
    Nijenhuis-Wilf NEXCOM successor step, and keeps each composition whose parts are all
    multiples of ``round(increment * 100)``. Kept compositions are scaled by 1/100, so
    with ``n=100`` each weight vector sums to 1.0 (up to float rounding).

    Parameters
    ----------
    n : int
        Total to distribute, in hundredths.
    k : int
        Number of parts (weights) per vector.
    increment : float
        Weight granularity, e.g. 0.05.

    Returns
    -------
    list of list of float
        Weight vectors in generation order: the first is ``[n/100, 0, ..., 0]`` and the
        last ``[0, ..., 0, n/100]``. A vector with a part equal to ``n`` is always kept,
        even when ``n`` itself is not a multiple of the increment.
    """
    ret = []
    inp = [0] * (k + 1)  # 1-based composition buffer; inp[0] is unused
    # round(), not int(): increment * 100 can land just below an integer
    # (0.29 * 100 == 28.999999999999996), which int() would truncate to 28.
    inc = round(increment * 100)
    mtc = False  # "more to come": False only before the first composition
    t = n
    h = 0
    while True:
        if mtc:
            # NEXCOM successor: h moves to the first non-zero part, whose value t is
            # cleared; t - 1 goes to the first part and one unit to part h + 1.
            if t > 1:
                h = 0
            h += 1
            t = inp[h]
            inp[h] = 0
            inp[1] = t - 1
            inp[h + 1] += 1
        else:
            # First composition: all of n in the first part.
            inp[1] = n
            for i in range(2, k + 1):
                inp[i] = 0

        # Keep the composition only if every part is a multiple of `inc`; a part equal
        # to n (all others then 0) ends the scan and is kept.
        allVal = True
        for i in range(1, k + 1):
            if inp[i] == n:
                break

            if inp[i] % inc != 0:
                allVal = False
                break

        if allVal:
            ret.append([float(inp[i]) / 100.0 for i in range(1, k + 1)])

        # The final composition has all of n in the last part.
        mtc = inp[k] != n
        if not mtc:
            break

    return ret

In [None]:
# Aggregate SDM runs for windows 1..20, tuned for each metric, with and without smoothing.
sdm_measures, smooth_measures = [], []
metric_names = ['ERR', 'RBP']
for m in metric_names:
    sdm_measures.append(load_dfs(
        config.AUS_QREL_PATH, config.AUS_REL_LEVEL, '',
        [os.path.join(BASE_DIR, 'grid-search', 'sdm-{}-window-{}-combine-max.run'.format(x, m))
         for x in range(1, 21)],
    ))
    smooth_measures.append(load_dfs(
        config.AUS_QREL_PATH, config.AUS_REL_LEVEL, '',
        [os.path.join(BASE_DIR, 'grid-search', 'sdm-{}-smooth-window-{}-combine-max.run'.format(x, m))
         for x in range(1, 21)],
    ))

In [None]:
def plot_different(dfs, smooth_dfs, base, metric_names, start, end, increment, names=None, styles=None,
                   y_lims=None, legend_x=0.96, legend_y=0.46):
    """Line-plot a swept parameter for SDM, smoothed SDM and a flat baseline, one panel per metric.

    Parameters
    ----------
    dfs, smooth_dfs : list of lists
        One list per metric (matched by position with ``metric_names``); each holds one
        evaluation per parameter value, indexable by metric key.
    base : indexable by metric key
        Baseline evaluation, drawn as a horizontal line.
    metric_names : dict
        Maps metric key -> y-axis label.
    start, end, increment :
        x values are ``np.arange(start, end + increment, increment)``.
    names : list, optional
        Legend entries (no legend when empty).
    styles : list, optional
        Line style per panel, by position.
    y_lims : list, optional
        Upper y-limit per axis, by position.
    legend_x, legend_y : float
        Legend anchor in figure coordinates.

    For ``'rbp@<p>'`` metrics the band from the score up to score + ``'rbp-res@<p>'``
    is shaded.

    Returns
    -------
    matplotlib Figure
    """
    names = names or []
    styles = styles or []
    y_lims = y_lims or []

    n_metrics = len(metric_names)
    if n_metrics == 2:
        nrows, ncols = 1, 2
    else:
        ncols = max(1, n_metrics // 2 + n_metrics % 2)
        nrows = max(1, -(-n_metrics // ncols))
    # squeeze=False keeps `axs` 2-D: plt.subplots(1, 2) would otherwise return a 1-D
    # array and axs[row, col] would raise IndexError.
    fig, axs = plt.subplots(nrows, ncols, squeeze=False)
    fig.set_size_inches(16, 4)
    x = np.arange(start, end + increment, increment)

    plotted = 0
    for i, (ax, df, smooth_df, m) in enumerate(zip(axs.flat, dfs, smooth_dfs, metric_names)):
        s = styles[i] if i < len(styles) else None
        n_points = len(df)

        ax.plot(x, [base[m]] * n_points, linestyle=s)
        ax.plot(x, [row[m] for row in df], linestyle=s)
        ax.plot(x, [row[m] for row in smooth_df], linestyle=s)
        if m.startswith('rbp@'):
            es = 'rbp-res@' + m[4:]
            ax.fill_between(x, [base[m]] * n_points, [base[es] + base[m]] * n_points, alpha=0.3)
            ax.fill_between(x, [row[m] for row in df], [row[es] + row[m] for row in df], alpha=0.3)
            ax.fill_between(x, [row[m] for row in smooth_df], [row[es] + row[m] for row in smooth_df],
                            alpha=0.3)

        ax.set_ylabel(metric_names[m], fontsize=30)
        ax.tick_params(labelsize=15)
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.4f'))
        plotted += 1

    all_axes = fig.get_axes()
    for i, top in enumerate(y_lims):
        all_axes[i].set_ylim(top=top)

    # Remove panels left without a metric.
    for ax in axs.flat[plotted:]:
        fig.delaxes(ax)

    if names:
        fig.legend(names, bbox_to_anchor=[legend_x, legend_y], frameon=True, ncol=2,
                   prop={"size": 15}).get_frame().set_edgecolor('black')

    fig.tight_layout()
    return fig
    

In [None]:
# Window-size curves for SDM and smoothed SDM against the baseline.
# Raw string: '\m' is an invalid escape sequence (SyntaxWarning on Python 3.12+); the label is unchanged.
sdm_max_fig = plot_different(sdm_measures, smooth_measures, base_df, RERANK_METRICS, 1, 20, 1,
                             names=['$R$', 'SDM', r'SDM$_{\mathrm{smooth}}$'], legend_y=0.9, legend_x=0.99)
# sdm_max_fig.savefig('figures/sdm-window.pdf')

In [None]:
sdm_index = ['filtered-phrasestop']#, 'sigir']
sdm_path = 'sdm_rerank'
sdm_dir_mu = [1050]#, 1350]

def load_sdm(index_names, qrel_paths, str_format, results_path, sdm_dir_mu, increment, window,
             per_query=False, clip=True, rel_levels=None, min_first_weight=0.6):
    """Evaluate the SDM rerank run for every weight triple on the grid.

    Weight triples come from ``all_combinations(100, 3, increment)``. The run file for
    index ``i`` and triple ``(w0, w1, w2)`` is
    ``str_format.format(index_names[i], sdm_dir_mu[i], w0, w1, w2, window)`` under
    ``results_path``.

    Parameters
    ----------
    index_names, qrel_paths, sdm_dir_mu : lists matched by position.
    str_format : str
        Run-file name template (fields described above).
    results_path : str
        Directory containing the run files.
    increment : float
        Weight granularity passed to ``all_combinations``.
    window : int or str
        Substituted as field ``{5}`` of ``str_format``.
    per_query : bool
        Forwarded to ``to_trec_df``.
    clip : bool
        When True, only triples whose first weight exceeds ``min_first_weight`` are loaded.
    rel_levels : list, optional
        Relevance level per index (by position). Defaults to the notebook-level
        ``rel_levels``, which this function previously read implicitly.
    min_first_weight : float
        Clip threshold (previously hard-coded as 0.6).

    Returns
    -------
    list
        One list per index of ``(weights, evaluation)`` tuples.
    """
    if rel_levels is None:
        # Backward compatibility: fall back to the notebook-level `rel_levels`.
        rel_levels = globals()['rel_levels']

    combs = all_combinations(100, 3, increment)
    dfs = []
    for i, ind in enumerate(index_names):
        temp = []
        for comb in combs:
            if not clip or comb[0] > min_first_weight:
                run_file = str_format.format(ind, sdm_dir_mu[i], comb[0], comb[1], comb[2], window)
                temp.append((comb, to_trec_df(qrel_paths[i], os.path.join(results_path, run_file),
                                             rel_levels[i], per_query=per_query)))
        dfs.append(temp)

    return dfs

# sdm_format = '{0}-sdm-dir-mu-{1:.2f}-weights-{2:.2f}-{3:.2f}-{4:.2f}-window-{5}.run'

# sdm_dfs = load_sdm(sdm_index, [config.AUS_QREL_PATH], 'case-topics-{0}-sdm_rerank-dir-mu-{1:.2f}-weights-{2:.2f}-{3:.2f}-{4:.2f}-window-{5}.run', os.path.join(BASE_DIR, sdm_path), sdm_dir_mu, 0.05, 11)

In [None]:
mu = 1050
# Re-base on the Dirichlet-prior run at mu=1050 (the mu used by the SDM runs in sdm_dir_mu).
base_df = load_1d_dfs(
    ['filtered-phrasestop'],
    qrel_paths,
    os.path.join(BASE_DIR, 'preprocessing', 'dirichlet_prior'),
    'case-topics-{0}-unigram_dir_mu_{1:.2f}.run',
    rel_levels,
    mu, mu, 1,
)[0][0]
base_qry = load_1d_dfs(
    ['filtered-phrasestop'],
    qrel_paths,
    os.path.join(BASE_DIR, 'preprocessing', 'dirichlet_prior'),
    'case-topics-{0}-unigram_dir_mu_{1:.2f}.run',
    rel_levels,
    mu, mu, 1,
    per_query=True,
)[0][0]

In [None]:
# Per-query versions of the SDM / smoothed-SDM window runs (windows 1..20) for each metric.
sdm_measures, smooth_measures = [], []
for m in metric_names:
    sdm_measures.append(load_dfs(
        config.AUS_QREL_PATH, config.AUS_REL_LEVEL, '',
        [os.path.join(BASE_DIR, 'grid-search', 'sdm-{}-window-{}-combine-max.run'.format(x, m))
         for x in range(1, 21)],
        per_query=True,
    ))
    smooth_measures.append(load_dfs(
        config.AUS_QREL_PATH, config.AUS_REL_LEVEL, '',
        [os.path.join(BASE_DIR, 'grid-search', 'sdm-{}-smooth-window-{}-combine-max.run'.format(x, m))
         for x in range(1, 21)],
        per_query=True,
    ))

In [None]:
# Load and evaluate the PLM runs over a (sigma, mu) grid.
def load_plm(ind, qrel_path, str_format, results_path, rel_level, per_query=False,
             sigmas=range(10, 510, 10), mus=range(300, 2050, 50)):
    """Evaluate every PLM run on a (sigma, mu) grid with ``to_trec_df``.

    Parameters
    ----------
    ind : str
        Index name, substituted as ``{0}`` in ``str_format``.
    qrel_path : str
        Qrels file.
    str_format : str
        Run-file name template: ``{0}`` index, ``{1}`` mu, ``{2}`` sigma.
    results_path : str
        Directory containing the run files.
    rel_level :
        Relevance level forwarded to ``to_trec_df``.
    per_query : bool
        Forwarded to ``to_trec_df``.
    sigmas, mus : iterables of numbers
        Grid values (previously hard-coded as range(10, 510, 10) and range(300, 2050, 50)).

    Returns
    -------
    list
        One inner list per sigma, each holding ``(sigma, mu, evaluation)`` tuples, one per mu.
    """
    return [
        [(s, mu, to_trec_df(qrel_path, os.path.join(results_path, str_format.format(ind, mu, s)),
                            rel_level, per_query=per_query))
         for mu in mus]
        for s in sigmas
    ]

# Per-query PLM results over the full (sigma, mu) grid, flattened into one list of runs.
plm_qry_dfs = load_plm(
    index_names[0],
    qrel_paths[0],
    'case-topics-{0}-plm-dir-mu-{1:.2f}-sigma-{2:.2f}.run',
    os.path.join(BASE_DIR, 'plm'),
    '1',
    per_query=True,
)
flattened = [entry[2] for sigma_row in plm_qry_dfs for entry in sigma_row]

In [None]:
# Cross-validate PLM on the AUS folds. Start a fresh table: in notebook order `cv_df`
# still holds the SIGIR results, and its METRIC_NAMES columns need not match the
# RERANK_METRICS that cross_validation is given here.
cv_df = pd.DataFrame(columns=RERANK_METRICS)
cross = cross_validation(flattened, tt_folds, RERANK_METRICS, base_qry)
cv_df.loc['PLM'] = cross[0]

In [None]:
# Baseline row R: the current `base_df` scores, stored as 4-decimal strings.
cv_df.loc['$R$'] = ['{:.4f}'.format(base_df[m]) for m in RERANK_METRICS]
# Raw string: '\s' is an invalid escape sequence (SyntaxWarning on Python 3.12+); the value is unchanged.
# NOTE(review): no visible cell adds 'SDM' rows to cv_df, so reindex leaves them as NaN.
cv_df = cv_df.reindex(['$R$', 'PLM', 'SDM', r'SDM\subscript{smooth}'])
# write_table('tables/ausnl-bbow', bold_max(cv_df).rename(config.METRIC_NAMES, axis='columns').to_latex(escape=False))

In [None]:
def plot_qry_diff(runs, qry_df, metrics):
    """Box-plot the per-query difference between the best run for each metric and a baseline.

    For the i-th metric key of ``metrics``, ``runs[i]`` is a list of per-query runs. The run
    with the highest mean for that metric is selected; ties go to the first run, and a run
    must beat a mean of 0.0 to replace index 0. ``qry_df`` is subtracted from the selected
    scores, and one box is drawn per metric with a dashed line at zero.

    Parameters
    ----------
    runs : list of lists of per-query frames, one inner list per metric
    qry_df : per-query baseline frame with a column for each metric key
    metrics : dict mapping metric key -> box label

    Returns
    -------
    The object returned by the final ``axhline`` call (the dashed zero line).
    """
    print(len(runs), len(metrics))

    best_scores = pd.DataFrame()
    for i, metric_key in enumerate(metrics):
        candidates = runs[i]
        top_mean, top_idx = 0.0, 0
        for idx, candidate in enumerate(candidates):
            mean_score = candidate[metric_key].mean()
            if mean_score > top_mean:
                top_mean, top_idx = mean_score, idx
        best_scores[metric_key] = candidates[top_idx][metric_key]

    gains = best_scores.sort_index() - qry_df
    ax = gains[metrics.keys()].rename(metrics, axis='columns').plot.box(
        fontsize=15,
        boxprops=dict(linestyle='-', linewidth=2),
        medianprops=dict(linestyle='-', linewidth=2),
        color=dict(boxes='black', whiskers='black', medians='b', caps='r'),
        figsize=(16, 4),
    )
    return ax.axhline(y=0, xmin=0.0, xmax=1.0, linestyle='--', linewidth=1.0, color='grey')

# NOTE(review): `metrics` (METRIC_NAMES without 'unjudged@20') is not used below;
# the plot is limited to ERR and RBP through the literal dict instead.
metrics = copy.copy(config.METRIC_NAMES)
del metrics['unjudged@20']
# Per-query differences between the best smoothed-SDM window run and the baseline.
sdm_qry_diff = plot_qry_diff(smooth_measures, base_qry, {'err@20': 'ERR', 'rbp@0.80': 'RBP'})
# sdm_qry_diff.get_figure().savefig('figures/ausnl-sdm-qry-diff.pdf')

In [None]:
# Smoothed SDM runs for every weight triple (no clipping) at window 8.
sdm_dfs = load_sdm(
    sdm_index,
    [config.AUS_QREL_PATH],
    'case-topics-{0}-sdm_rerank-smooth-dir-mu-{1:.2f}-weights-{2:.2f}-{3:.2f}-{4:.2f}-window-{5}.run',
    os.path.join(BASE_DIR, sdm_path),
    sdm_dir_mu,
    0.05,
    8,
    clip=False,
)

In [None]:
def plot_sdm_1d(dfs, index_names, metric_names):
    """Scatter SDM scores against the first two SDM weights, one 3-D panel per metric.

    Parameters
    ----------
    dfs : list of (weights, evaluation) tuples
        As produced by one entry of ``load_sdm(...)``; ``weights[0]`` and ``weights[1]``
        give the x and y coordinates and ``evaluation[m]`` the z value.
    index_names :
        Unused; kept for backward compatibility.
    metric_names : dict
        Maps metric key -> z-axis label.

    Returns
    -------
    matplotlib Figure
    """
    n_metrics = len(metric_names)
    # Half the metrics per row, with enough rows for all of them. (The previous layout
    # of n//2 - 1 rows only fitted n == 6.) squeeze=False keeps `axs` 2-D even for a
    # single row or column.
    ncols = max(1, n_metrics // 2)
    nrows = max(1, -(-n_metrics // ncols))
    fig, axs = plt.subplots(nrows, ncols, subplot_kw=dict(projection='3d'), squeeze=False)
    fig.set_size_inches(16, 8)

    xs = np.array([item[0][0] for item in dfs])
    ys = np.array([item[0][1] for item in dfs])

    for ax, m in zip(axs.flat, metric_names):
        zs = np.array([item[1][m] for item in dfs])
        ax.scatter(xs, ys, zs)

        ax.tick_params(labelsize=10)
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        ax.xaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        ax.zaxis.set_major_formatter(FormatStrFormatter('%.4f'))

        ax.set_zlabel(metric_names[m], fontsize=20, rotation=90)
        ax.view_init(20, 120)

    # Remove panels left without a metric.
    for ax in axs.flat[n_metrics:]:
        fig.delaxes(ax)

    fig.tight_layout()
    return fig

# def plot_sdm(dfs, index_names, metric_names):
    
#     fig, axs = plt.subplots(len(metric_names), len(index_names), subplot_kw=dict(projection='3d'))
#     fig.set_size_inches(16, 8)
#     xs = np.array([x[0][0] for x in dfs])
#     ys = np.array([x[0][1] for x in dfs])

#     verts = list(zip(xs, ys))
#     cnt = 0 
#     for j, m in enumerate(dfs[i][0][1].index):
#         if m in metric_names:
#             zs = np.array([x[1][m] for x in dfs])
# #                 print(sorted(list(zip(zs, verts))))

#             axs[cnt, i].scatter(xs, ys, zs)

#             axs[cnt, i].tick_params(labelsize=10)
#             axs[cnt, i].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
#             axs[cnt, i].xaxis.set_major_formatter(FormatStrFormatter('%.1f'))
#             axs[cnt, i].zaxis.set_major_formatter(FormatStrFormatter('%.2f'))

#             if i == 0:
#                 axs[cnt, i].set_zlabel(metric_names[m],fontsize=20)

#             axs[cnt, i].view_init(30, 120)
#             cnt += 1 
            
#     fig.tight_layout()

# Metrics for the SDM weight plot: every configured metric except recall_100.
rerank_metrics = copy.copy(config.METRIC_NAMES)
del rerank_metrics['recall_100']

# display_names[0] fills plot_sdm_1d's index_names parameter, which the function does not use.
sdm_fig = plot_sdm_1d(sdm_dfs[0], display_names[0], rerank_metrics)
# sdm_fig.savefig('figures/aus-sdm-tune.pdf')

In [None]:
# SDM with fixed weights 0.80/0.10/0.10 and mu 1050, swept over window sizes 6..20.
sdm_1d_dfs = load_1d_dfs(
    index_names[:1],
    qrel_paths,
    os.path.join(BASE_DIR, sdm_path),
    'case-topics-{0}-sdm_rerank-dir-mu-1050.00-weights-0.80-0.10-0.10-window-{1}.run',
    rel_levels,
    6, 20, 1,
)

In [None]:
# Tuning curve of each rerank metric against the SDM window size (6..20).
sdm_1d_fig = plot_tune_1d_comp(display_names, RERANK_METRICS, sdm_1d_dfs, 6, 20, 1)

# PLM

In [None]:
# PLM sigma sweep (mu fixed at 2400 in the run-file names), evaluated on all queries
# and separately on the broad and specific query groups.
sigma_start, sigma_end, sigma_inc = 25, 300, 25

plm_dfs = load_1d_dfs(
    index_names[:1], qrel_paths, os.path.join(BASE_DIR, 'plm'),
    'case-topics-{0}-plm-dir-mu-2400.00-sigma-{1:.2f}.run',
    rel_levels, sigma_start, sigma_end, sigma_inc,
)
plm_b_dfs = load_1d_dfs(
    index_names[:1], qrel_paths, os.path.join(BASE_DIR, 'plm'),
    'case-topics-{0}-plm-dir-mu-2400.00-sigma-{1:.2f}.run',
    rel_levels, sigma_start, sigma_end, sigma_inc,
    filtered=broad,
)
plm_s_dfs = load_1d_dfs(
    index_names[:1], qrel_paths, os.path.join(BASE_DIR, 'plm'),
    'case-topics-{0}-plm-dir-mu-2400.00-sigma-{1:.2f}.run',
    rel_levels, sigma_start, sigma_end, sigma_inc,
    filtered=specific,
)

In [None]:
# Tuning curve of each rerank metric against the PLM sigma.
plm_fig = plot_tune_1d_comp(
    display_names, RERANK_METRICS, plm_dfs,
    sigma_start, sigma_end, sigma_inc,
    ylims=YLIMS,
)
# plm_fig.savefig('figures/plm-tuning.pdf')