# Analyze summaries of GPT-3 prompting experiments

## Load libraries

In [62]:
import os

from typing import Tuple

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import seaborn as sns

In [2]:
# Root directory under which each experiment type has its own sub-directory.
RESULT_DIRECTORY = 'results/arithmetics'
# File name of the summary CSV read from an experiment sub-directory.
SUMMARY_FILENAME = 'summary.csv'
# Directory prefix used when exporting figures.
IMAGES = 'images/'

def get_summary(directory: str) -> pd.DataFrame:
    """Load the summary CSV of one experiment sub-directory.

    Args:
        directory: Name of the sub-directory below ``RESULT_DIRECTORY``.

    Returns:
        The contents of ``RESULT_DIRECTORY/<directory>/SUMMARY_FILENAME``
        parsed by ``pandas.read_csv``.
    """
    return pd.read_csv(os.path.join(RESULT_DIRECTORY, directory, SUMMARY_FILENAME))

# Summary of the 'base' experiment directory.
base = get_summary('base')
# The two 'prompt' summaries below are currently not loaded.
# prompts = get_summary('prompt')
# cross_tag_prompts = get_summary('prompt_across_tags')
# Summaries of the 'answer_to_prompt' experiment directories.
prompt_answers = get_summary('answer_to_prompt')
cross_tag_prompt_answers = get_summary('answer_to_prompt_across_tags')

In [3]:
base

Unnamed: 0,source,model,tag,helper_tag,n_shots,accuracy,n_correct,total
0,results/arithmetics/base/ada_2D+_0_100.jsonl,ada,2D+,2D+,0,0.01,1,100
1,results/arithmetics/base/ada_2D+_1_100.jsonl,ada,2D+,2D+,1,0.03,3,100
2,results/arithmetics/base/ada_2D+_2_100.jsonl,ada,2D+,2D+,2,0.00,0,100
3,results/arithmetics/base/ada_2D+_3_100.jsonl,ada,2D+,2D+,3,0.00,0,100
4,results/arithmetics/base/ada_2D-_0_100.jsonl,ada,2D-,2D-,0,0.00,0,100
...,...,...,...,...,...,...,...,...
132,results/arithmetics/base/davinci_5D-_1_100_4D+...,davinci,5D-,4D+,1,0.07,7,100
133,results/arithmetics/base/davinci_5D-_1_100_4D-...,davinci,5D-,4D-,1,0.13,13,100
134,results/arithmetics/base/davinci_5D-_1_100_5D+...,davinci,5D-,5D+,1,0.10,10,100
135,results/arithmetics/base/davinci_5D-_2_100.jsonl,davinci,5D-,5D-,2,0.13,13,100


In [4]:
# Standard error of a proportion, sqrt(p * (1 - p) / n), with p = accuracy and
# n = total; 'scaled_std' is the same quantity multiplied by 100.
accuracy_fraction = base['accuracy']
base['std'] = np.sqrt(accuracy_fraction * (1 - accuracy_fraction) / base['total'])
base['scaled_std'] = 100 * base['std']

In [5]:
# Keep only the rows produced by the davinci model.
base = base[base['model'] == 'davinci']
base

Unnamed: 0,source,model,tag,helper_tag,n_shots,accuracy,n_correct,total,std,scaled_std
7,results/arithmetics/base/davinci_1DC_0_100.jsonl,davinci,1DC,1DC,0,0.16,16,100,0.036661,3.666061
8,results/arithmetics/base/davinci_1DC_1_100.jsonl,davinci,1DC,1DC,1,0.32,32,100,0.046648,4.664762
9,results/arithmetics/base/davinci_1DC_1_100_2D+...,davinci,1DC,2D+,1,0.33,33,100,0.047021,4.702127
10,results/arithmetics/base/davinci_1DC_1_100_2D-...,davinci,1DC,2D-,1,0.24,24,100,0.042708,4.270831
11,results/arithmetics/base/davinci_1DC_1_100_2Dx...,davinci,1DC,2Dx,1,0.30,30,100,0.045826,4.582576
...,...,...,...,...,...,...,...,...,...,...
132,results/arithmetics/base/davinci_5D-_1_100_4D+...,davinci,5D-,4D+,1,0.07,7,100,0.025515,2.551470
133,results/arithmetics/base/davinci_5D-_1_100_4D-...,davinci,5D-,4D-,1,0.13,13,100,0.033630,3.363034
134,results/arithmetics/base/davinci_5D-_1_100_5D+...,davinci,5D-,5D+,1,0.10,10,100,0.030000,3.000000
135,results/arithmetics/base/davinci_5D-_2_100.jsonl,davinci,5D-,5D-,2,0.13,13,100,0.033630,3.363034


In [6]:
# Task tags in display order: composite, multiplication, then 2-5 digit
# addition followed by 2-5 digit subtraction.
cols = ['1DC', '2Dx'] + [
    f'{n_digits}D{sign}' for sign in '+-' for n_digits in range(2, 6)
]

In [59]:
# Split the task tags by their trailing operator character.
plus_modes = [tag for tag in cols if tag[-1:] == '+']
minus_modes = [tag for tag in cols if tag[-1:] == '-']
# Addition tags first, then subtraction tags.
normal_modes = [*plus_modes, *minus_modes]

## 0-shot

In [7]:
# 0-shot rows indexed by task tag, ordered as in `cols`.
zero_shot_accuracies = (
    base[base['n_shots'] == 0]
    .set_index('tag')[['n_correct', 'scaled_std']]
    .loc[cols]
)

In [8]:
zero_shot_accuracies

Unnamed: 0_level_0,n_correct,scaled_std
tag,Unnamed: 1_level_1,Unnamed: 2_level_1
1DC,16,3.666061
2Dx,35,4.769696
2D+,91,2.861818
3D+,49,4.999
4D+,17,3.756328
5D+,10,3.0
2D-,48,4.995998
3D-,30,4.582576
4D-,10,3.0
5D-,5,2.179449


In [10]:
# The tag's last character is the operation and its first character the
# number of digits (e.g. '3D+' -> '+', 3).
zero_shot_accuracies['mode'] = [tag[-1] for tag in zero_shot_accuracies.index]
zero_shot_accuracies['n_digits'] = [int(tag[0]) for tag in zero_shot_accuracies.index]

In [11]:
zero_shot_accuracies

Unnamed: 0_level_0,n_correct,scaled_std,mode,n_digits
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1DC,16,3.666061,C,1
2Dx,35,4.769696,x,2
2D+,91,2.861818,+,2
3D+,49,4.999,+,3
4D+,17,3.756328,+,4
5D+,10,3.0,+,5
2D-,48,4.995998,-,2
3D-,30,4.582576,-,3
4D-,10,3.0,-,4
5D-,5,2.179449,-,5


In [12]:
# Addition / subtraction rows only, with the columns the line plot needs.
zero_shot_accuracies_plotly = zero_shot_accuracies.loc[
    normal_modes, ['n_correct', 'mode', 'n_digits', 'scaled_std']]
zero_shot_accuracies_plotly

Unnamed: 0_level_0,n_correct,mode,n_digits,scaled_std
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
2D+,91,+,2,2.861818
3D+,49,+,3,4.999
4D+,17,+,4,3.756328
5D+,10,+,5,3.0
2D-,48,-,2,4.995998
3D-,30,-,3,4.582576
4D-,10,-,4,3.0
5D-,5,-,5,2.179449


In [13]:
# Table of scaled std with one row per digit count and one column per mode.
zero_shot_pivot_std = zero_shot_accuracies.loc[
    normal_modes, ['scaled_std', 'mode', 'n_digits']
].pivot(index='n_digits', columns='mode', values='scaled_std')
zero_shot_pivot_std

mode,+,-
n_digits,Unnamed: 1_level_1,Unnamed: 2_level_1
2,2.861818,4.995998
3,4.999,4.582576
4,3.756328,3.0
5,3.0,2.179449


In [14]:
# Reference: https://stackoverflow.com/questions/69587547/continuous-error-band-with-plotly-express-in-python
def line(error_y_mode=None, **kwargs):
    """Wrap `plotly.express.line`, optionally drawing errors as filled bands.

    Args:
        error_y_mode: One of 'bar', 'bars', 'band', 'bands' or None. For the
            bar variants and None the call is forwarded to `px.line`
            unchanged; for the band variants each trace's `error_y` is drawn
            as a translucent filled area instead of error bars.
        **kwargs: Keyword arguments forwarded to `px.line`.

    Returns:
        The figure built by `px.line`, with band traces added when requested.

    Raises:
        ValueError: If `error_y_mode` is not an accepted value, or a band
            mode is requested without an `error_y` keyword argument.
    """
    ERROR_MODES = {'bar','band','bars','bands',None}
    if error_y_mode not in ERROR_MODES:
        raise ValueError(f"'error_y_mode' must be one of {ERROR_MODES}, received {repr(error_y_mode)}.")

    # Everything except the band modes is plain plotly express behaviour.
    if error_y_mode not in {'band', 'bands'}:
        return px.line(**kwargs)

    if 'error_y' not in kwargs:
        raise ValueError("If you provide argument 'error_y_mode' you must also provide 'error_y'.")

    # The figure with error bars is built only to read the per-trace error
    # arrays and colours; the returned figure is built without `error_y`.
    with_error_bars = px.line(**kwargs)
    fig = px.line(**{key: value for key, value in kwargs.items() if key != 'error_y'})

    for trace in with_error_bars.data:
        xs = list(trace['x'])
        upper = list(trace['y'] + trace['error_y']['array'])
        # Fall back to the symmetric error when no separate minus array exists.
        if trace['error_y']['arrayminus'] is None:
            lower = list(trace['y'] - trace['error_y']['array'])
        else:
            lower = list(trace['y'] - trace['error_y']['arrayminus'])

        # Convert the trace's '#rrggbb' colour into an rgba string with alpha .3.
        hex_digits = trace['line']['color'].lstrip('#')
        red, green, blue = (int(hex_digits[i:i+2], 16) for i in (0, 2, 4))
        band_color = f"rgba({red},{green},{blue},.3)"

        # Closed polygon: upper edge left-to-right, lower edge right-to-left.
        fig.add_trace(
            go.Scatter(
                x=xs + xs[::-1],
                y=upper + lower[::-1],
                fill='toself',
                fillcolor=band_color,
                line=dict(color='rgba(255,255,255,0)'),
                hoverinfo="skip",
                showlegend=False,
                legendgroup=trace['legendgroup'],
                xaxis=trace['xaxis'],
                yaxis=trace['yaxis'],
            )
        )

    # Reorder data as said here: https://stackoverflow.com/a/66854398/8849755
    # (each band is placed directly before the line it belongs to).
    half = len(fig.data) // 2
    fig.data = tuple(
        item
        for i in range(half)
        for item in (fig.data[i + half], fig.data[i])
    )
    return fig

In [33]:
# Human-readable legend label for each operator character.
mode_labels = {'+': 'Addition', '-': 'Subtraction'}
zero_shot_accuracies_plotly['Mode'] = zero_shot_accuracies_plotly['mode'].map(mode_labels)
zero_shot_accuracies_plotly[['Mode']]

Unnamed: 0_level_0,Mode
tag,Unnamed: 1_level_1
2D+,Addition
3D+,Addition
4D+,Addition
5D+,Addition
2D-,Subtraction
3D-,Subtraction
4D-,Subtraction
5D-,Subtraction


In [110]:
# 0-shot accuracy against number of digits: addition and subtraction as lines
# with error bands, composite and multiplication as single points with bars.
fig = line(
    data_frame=zero_shot_accuracies_plotly,
    x='n_digits',
    y='n_correct',
    text='n_correct',
    color='Mode',
    error_y='scaled_std',
    error_y_mode='band',
    title='Accuracy (%) v.s. number of digits (D)',
    markers='.',
)
fig.update_traces(textposition='top right')

fig.add_trace(
    go.Scatter(
        x=[1],
        y=[zero_shot_accuracies.loc['1DC', 'n_correct']],
        text=[zero_shot_accuracies.loc['1DC', 'n_correct']],
        textposition='middle right',
        mode="lines+markers+text",
        # Same colour as the Composite point in the K-shot figures.
        marker=dict(color='brown'),
        error_y=dict(
            type='data', # value of error bar given in data coordinates
            array=[zero_shot_accuracies.loc['1DC', 'scaled_std']],
            visible=True),
        name='Composite',
    ))
fig.add_trace(
    go.Scatter(
        x=[2],
        y=[zero_shot_accuracies.loc['2Dx', 'n_correct']],
        text=[zero_shot_accuracies.loc['2Dx', 'n_correct']],
        textposition='middle right',
        mode="lines+markers+text",
        # `fillcolor` only affects filled areas, so it never coloured this
        # single point; set the marker colour as the K-shot figures do.
        marker=dict(color='green'),
        error_y=dict(
            type='data', # value of error bar given in data coordinates
            array=[zero_shot_accuracies.loc['2Dx', 'scaled_std']],
            visible=True),
        name='Multiplication',
    ))


# The layout title below replaces the title passed to `line` above.
fig.update_layout(
    template='ggplot2',
    yaxis_title='Accuracy (%)',
    xaxis_title='Number of digits (D)',
    title='0-shot prompting accuracy (%) for D-digit arithmetics',
    legend_title='Arithmetics',
    hovermode="x",
    width=550,
    height=500,
)

fig.show()

In [58]:
# IMAGES already ends with '/', so join instead of concatenating '/...'.
fig.write_image(os.path.join(IMAGES, '0_shot.pdf'))

## K-shot (K = 0, 1, 2, 3)

In [70]:
base.query('n_shots == 1')[['tag', 'n_correct', 'scaled_std']].set_index('tag').loc[cols]

Unnamed: 0_level_0,n_correct,scaled_std
tag,Unnamed: 1_level_1,Unnamed: 2_level_1
1DC,32,4.664762
1DC,33,4.702127
1DC,24,4.270831
1DC,30,4.582576
1DC,25,4.330127
...,...,...
5D-,6,2.374868
5D-,9,2.861818
5D-,7,2.551470
5D-,13,3.363034


In [79]:
def get_plotly_data(n_shots: int) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Build the same-task accuracy tables for one shot count.

    Args:
        n_shots: Number of shots to select from `base`.

    Returns:
        A pair ``(data, data_plotly)``. ``data`` is indexed by tag in the
        order of `cols` and has the columns ``n_correct``, ``scaled_std``,
        ``mode`` and ``n_digits``. ``data_plotly`` holds only the tags in
        `normal_modes`, plus a ``Mode`` column with the legend label.
    """
    # Always restrict to same-task rows (tag == helper_tag). Without the
    # filter, any cross-task rows would give duplicate tags after `.loc[cols]`
    # and scalar lookups such as `data.loc['1DC', 'n_correct']` would break.
    data = base.query('tag == helper_tag & n_shots == @n_shots')[[
        'tag', 'n_correct', 'scaled_std']].set_index('tag').loc[cols]
    # Derive both columns from this frame's own index instead of relying on
    # the unrelated global `zero_shot_accuracies` having a matching index.
    data['mode'] = [tag[-1] for tag in data.index]
    data['n_digits'] = [int(tag[0]) for tag in data.index]

    data_plotly = data[['n_correct', 'mode', 'n_digits', 'scaled_std']].loc[normal_modes]
    data_plotly['Mode'] = data_plotly['mode'].map({'+': 'Addition', '-': 'Subtraction'})

    return data, data_plotly

In [80]:
# Map each shot count K in 0..3 to its (data, data_plotly) pair.
k_shot_accuracies = {
    shot_count: get_plotly_data(shot_count) for shot_count in range(4)
}

In [357]:
# One figure per shot count: addition / subtraction as lines with error
# bands, composite and multiplication as single points with error bars.
one_step_figs = {}

for n_shots in range(4):
    data, data_plotly = k_shot_accuracies[n_shots]

    figure = one_step_figs[n_shots] = line(
        data_frame=data_plotly,
        x='n_digits',
        y='n_correct',
        text='n_correct',
        color='Mode',
        error_y='scaled_std',
        error_y_mode='band',
        title='Accuracy (%) v.s. number of digits (D)',
        markers='.',
    )
    figure.update_traces(textposition='top right')

    # (x position, tag, legend name, marker colour, label position)
    for point_x, point_tag, point_name, point_color, point_text_position in (
        (1, '1DC', 'Composite', 'brown', 'middle right'),
        (2, '2Dx', 'Multiplication', 'green', 'middle left'),
    ):
        point_value = data.loc[point_tag, 'n_correct']
        figure.add_trace(
            go.Scatter(
                x=[point_x],
                y=[point_value],
                text=[point_value],
                textposition=point_text_position,
                mode="lines+markers+text",
                marker=dict(color=point_color),
                error_y=dict(
                    type='data',  # error bar length given in data coordinates
                    array=[data.loc[point_tag, 'scaled_std']],
                    visible=True),
                name=point_name,
            ))

    # This layout title replaces the one passed to `line` above.
    figure.update_layout(
        template='ggplot2',
        yaxis_title='Accuracy (%)',
        xaxis_title='Number of digits (D)',
        title=f'{n_shots}-shot prompting accuracy (%) for D-digit arithmetics',
        legend_title='Arithmetics',
        hovermode="x",
        width=550,
        height=500,
    )

In [358]:
for n_shots, fig in one_step_figs.items():
    fig.show()

In [359]:
len(one_step_figs)

4

In [361]:
# Combine the per-shot figures into one row of subplots sharing the y axis.
grand_fig = make_subplots(rows=1, cols=len(one_step_figs), shared_yaxes=True,
                          x_title='Number of digits (D)',
                          y_title='Accuracy (%)',
                          horizontal_spacing=0.01,
                          subplot_titles=[f'{n_shots}-shot' for n_shots in one_step_figs])

for index, (n_shots, fig) in enumerate(one_step_figs.items()):
    # NOTE: this and the `showlegend` assignment below modify the figures in
    # `one_step_figs`, so later `fig.show()` calls on them are affected.
    fig.update_layout(xaxis_range=[0.7, 5.6])
    for trace in fig['data']:
        # Only the first panel contributes legend entries.
        if index != 0:
            trace['showlegend'] = False
        # `append_trace` is deprecated; `add_trace` takes the same row/col.
        grand_fig.add_trace(trace, row=1, col=index + 1)

xaxis_range = [0.7, 5.5]

grand_fig.update_layout(
    title='K-shot (K=0,1,2,3) same-task prompting accuracy (%) for D-digit arithmetics',
    legend_title='Arithmetics',
    hovermode="x",
    width=1050,
    height=500,
)
# Apply the same x range to every subplot, however many panels there are.
grand_fig.update_xaxes(range=xaxis_range)
grand_fig.show()

In [165]:
# IMAGES already ends with '/', so join instead of concatenating '/...'.
grand_fig.write_image(os.path.join(IMAGES, 'k_shot_same_task.pdf'))

## Diminishing returns from adding more shots

In [171]:
# Empty tag-by-shot tables (columns K = 1, 2, 3) to be filled in below.
accuracy = pd.DataFrame(index=pd.Index(cols, name='tag'), columns=[1, 2, 3])
accuracy_std = pd.DataFrame(index=pd.Index(cols, name='tag'), columns=[1, 2, 3])

In [175]:
# Marginal gain of the K-th shot: accuracy at K minus accuracy at K - 1, with
# the two standard deviations combined in quadrature.
# The loop used to iterate over `df.columns`, but no `df` is defined in this
# notebook; the columns being filled belong to `accuracy`.
for n_shots in accuracy.columns:
    key = int(n_shots)
    current = k_shot_accuracies[key][0]
    previous = k_shot_accuracies[key - 1][0]
    accuracy[n_shots] = current['n_correct'] - previous['n_correct']
    accuracy_std[n_shots] = np.sqrt(current['scaled_std']**2 + previous['scaled_std']**2)

In [176]:
accuracy

Unnamed: 0_level_0,1,2,3
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1DC,16,-2,3
2Dx,17,-11,5
2D+,5,-5,-3
3D+,27,-1,2
4D+,20,-9,-1
5D+,-2,-1,0
2D-,1,-1,-3
3D-,15,2,-3
4D-,15,-1,-2
5D-,7,1,-6


In [177]:
accuracy_std

Unnamed: 0_level_0,1,2,3
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1DC,5.932959,6.539113,6.565821
2Dx,6.907243,7.010706,7.002143
2D+,3.468429,3.468429,4.330127
3D+,6.574952,6.08194,6.038212
4D+,6.117189,6.593178,6.31427
5D+,4.04475,3.724245,3.608324
2D-,7.067531,7.067531,7.050532
3D-,6.763875,7.046985,7.039176
4D-,5.267827,6.08194,5.94979
5D-,3.9128,4.676537,4.221374


In [180]:
# Side-by-side table: marginal returns (columns 1, 2, 3) and their stds
# (columns '1_std', '2_std', '3_std'), aligned on the tag index.
accuracy_merged = pd.merge(
    accuracy,
    accuracy_std,
    left_index=True,
    right_index=True,
    suffixes=(None, '_std'),
)
accuracy_merged

Unnamed: 0_level_0,1,2,3,1_std,2_std,3_std
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
1DC,16,-2,3,5.932959,6.539113,6.565821
2Dx,17,-11,5,6.907243,7.010706,7.002143
2D+,5,-5,-3,3.468429,3.468429,4.330127
3D+,27,-1,2,6.574952,6.08194,6.038212
4D+,20,-9,-1,6.117189,6.593178,6.31427
5D+,-2,-1,0,4.04475,3.724245,3.608324
2D-,1,-1,-3,7.067531,7.067531,7.050532
3D-,15,2,-3,6.763875,7.046985,7.039176
4D-,15,-1,-2,5.267827,6.08194,5.94979
5D-,7,1,-6,3.9128,4.676537,4.221374


In [267]:
# Reshape the wide table into long form: one row per (tag, n_shots).
data = accuracy_merged.reset_index()

accuracy = data.melt(id_vars=['tag'], value_vars=[1, 2, 3])
accuracy = accuracy.rename(columns={'variable': 'n_shots', 'value': 'n_correct'})

std_columns = [f'{n_shots}_std' for n_shots in range(1, 4)]
accuracy_std = data.melt(id_vars=['tag'], value_vars=std_columns)
accuracy_std = accuracy_std.rename(columns={'variable': 'n_shots', 'value': 'scaled_std'})
# '1_std' -> 1, so the key matches the integer n_shots of `accuracy`.
accuracy_std['n_shots'] = accuracy_std['n_shots'].str[0].astype(int)

display(accuracy)
display(accuracy_std)

merge_cols = ['tag', 'n_shots']
merged = accuracy.merge(accuracy_std, on=merge_cols)
display(merged)

Unnamed: 0,tag,n_shots,n_correct
0,1DC,1,16
1,2Dx,1,17
2,2D+,1,5
3,3D+,1,27
4,4D+,1,20
5,5D+,1,-2
6,2D-,1,1
7,3D-,1,15
8,4D-,1,15
9,5D-,1,7


Unnamed: 0,tag,n_shots,scaled_std
0,1DC,1,5.932959
1,2Dx,1,6.907243
2,2D+,1,3.468429
3,3D+,1,6.574952
4,4D+,1,6.117189
5,5D+,1,4.04475
6,2D-,1,7.067531
7,3D-,1,6.763875
8,4D-,1,5.267827
9,5D-,1,3.9128


Unnamed: 0,tag,n_shots,n_correct,scaled_std
0,1DC,1,16,5.932959
1,2Dx,1,17,6.907243
2,2D+,1,5,3.468429
3,3D+,1,27,6.574952
4,4D+,1,20,6.117189
5,5D+,1,-2,4.04475
6,2D-,1,1,7.067531
7,3D-,1,15,6.763875
8,4D-,1,15,5.267827
9,5D-,1,7,3.9128


In [268]:
# One marginal-return figure per task family (addition, subtraction, complex).
figs = {}

# Work on a copy: scaling `merged` in place would shrink its std column again
# every time this cell is re-run.
data = merged.copy()
# NOTE(review): the bands are drawn at 0.3x the propagated std, presumably to
# keep overlapping bands readable - confirm and state it in the figure caption.
data['scaled_std'] *= 0.3

tags = {
    'Addition': [tag for tag in cols if tag.endswith('+')],
    'Subtraction': [tag for tag in cols if tag.endswith('-')],
    'Complex': [tag for tag in cols if tag.endswith('C') or tag.endswith('x')],
}

for mode, focus_tags in tags.items():
    figs[mode] = line(
        data_frame=data.query('tag in @focus_tags'),
        x='n_shots',
        y='n_correct',
        text='n_correct',
        color='tag',
        error_y='scaled_std',
        error_y_mode='band',
        title='Marginal return (%) v.s. number of shots (K=1,2,3)',
        markers='.',
    )
    # Zero line: points below it mean the extra shot reduced accuracy.
    figs[mode].add_hline(y=0, line_width=2, line_dash='dash', line_color='black', opacity=0.5)
    figs[mode].update_traces(textposition='top right')

    # This layout title replaces the one passed to `line` above.
    figs[mode].update_layout(
        template='ggplot2',
        yaxis_title='Marginal return (%)',
        xaxis_title='Number of shots (K)',
        title='Marginal return (%) from adding the K-th shot (K=1,2,3)',
        legend_title='Arithmetics',
        hovermode="x",
        width=550,
        height=500,
        # One tick per integer shot count.
        xaxis=dict(
            tickmode='linear',
            tick0=1,
            dtick=1
        )
    )

In [269]:
for mode, fig in figs.items():
    fig.show()

In [285]:
# Combine the per-family marginal-return figures into one row of subplots.
grand_fig = make_subplots(rows=1, cols=len(figs), shared_yaxes=True,
                          x_title='Number of shots (K)',
                          y_title='Marginal return on accuracy (%)',
                          horizontal_spacing=0.01,
                          subplot_titles=['Addition (2-5D+)', 'Subtraction (2-5D-)', 'Complex (1DC, 2Dx)'])

for index, (mode, fig) in enumerate(figs.items()):
    for trace in fig['data']:
        # `append_trace` is deprecated; `add_trace` takes the same row/col.
        grand_fig.add_trace(trace, row=1, col=index + 1)

# Zero line: points below it mean the extra shot reduced accuracy.
grand_fig.add_hline(y=0, line_width=2, line_dash='dash', line_color='black', opacity=0.5)

# One tick per integer shot count.
xaxis_config = dict(
    tickmode='linear',
    tick0=1,
    dtick=1
)

xaxis_range = [0.8, 3.2]

grand_fig.update_layout(
    title='Marginal return on accuracy (%) from adding the K-th same-task prompt (K=1,2,3)',
    legend_title='Arithmetics',
    hovermode="x",
    width=900,
    height=500,
)
# Apply the tick settings and range to every subplot's x axis at once.
grand_fig.update_xaxes(range=xaxis_range, **xaxis_config)
grand_fig.show()

In [286]:
# IMAGES already ends with '/', so join instead of concatenating '/...'.
grand_fig.write_image(os.path.join(IMAGES, 'k_shot_marginal_return.pdf'))

## K-shot cross-task prompts

In [292]:
# 1-shot accuracy as a target-task (rows) by helper-task (columns) matrix,
# with both axes ordered as in `cols`.
accuracy_pivot = (
    base[base['n_shots'] == 1]
    .pivot(index='tag', columns='helper_tag', values='n_correct')
    .loc[cols, cols]
)
accuracy_pivot

helper_tag,1DC,2Dx,2D+,3D+,4D+,5D+,2D-,3D-,4D-,5D-
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
1DC,32,30,33,25,26,31,24,24,29,23
2Dx,37,52,36,47,50,49,38,43,41,40
2D+,93,96,96,97,95,97,97,98,95,99
3D+,62,76,63,76,79,72,55,65,73,78
4D+,18,25,20,30,37,25,23,27,33,27
5D+,6,8,8,10,9,8,4,9,9,11
2D-,50,49,49,50,47,49,49,48,46,47
3D-,48,46,48,48,45,45,46,45,43,41
4D-,15,18,15,22,22,17,22,18,25,19
5D-,8,8,7,6,7,10,6,9,13,12


In [293]:
# Matching target-task by helper-task matrix of the scaled std values.
accuracy_std_pivot = (
    base[base['n_shots'] == 1]
    .pivot(index='tag', columns='helper_tag', values='scaled_std')
    .loc[cols, cols]
)
accuracy_std_pivot

helper_tag,1DC,2Dx,2D+,3D+,4D+,5D+,2D-,3D-,4D-,5D-
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
1DC,4.664762,4.582576,4.702127,4.330127,4.386342,4.624932,4.270831,4.270831,4.537621,4.208325
2Dx,4.828043,4.995998,4.8,4.990992,5.0,4.999,4.853864,4.950758,4.918333,4.898979
2D+,2.55147,1.959592,1.959592,1.705872,2.179449,1.705872,1.705872,1.4,2.179449,0.994987
3D+,4.853864,4.270831,4.828043,4.270831,4.073082,4.489989,4.974937,4.769696,4.439595,4.142463
4D+,3.841875,4.330127,4.0,4.582576,4.828043,4.330127,4.208325,4.439595,4.702127,4.439595
5D+,2.374868,2.712932,2.712932,3.0,2.861818,2.712932,1.959592,2.861818,2.861818,3.128898
2D-,5.0,4.999,4.999,5.0,4.990992,4.999,4.999,4.995998,4.983974,4.990992
3D-,4.995998,4.983974,4.995998,4.995998,4.974937,4.974937,4.983974,4.974937,4.950758,4.918333
4D-,3.570714,3.841875,3.570714,4.142463,4.142463,3.756328,4.142463,3.841875,4.330127,3.923009
5D-,2.712932,2.712932,2.55147,2.374868,2.55147,3.0,2.374868,2.861818,3.363034,3.249615


In [297]:
accuracy_pivot.style.background_gradient(cmap='coolwarm', axis=1)

helper_tag,1DC,2Dx,2D+,3D+,4D+,5D+,2D-,3D-,4D-,5D-
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
1DC,32,30,33,25,26,31,24,24,29,23
2Dx,37,52,36,47,50,49,38,43,41,40
2D+,93,96,96,97,95,97,97,98,95,99
3D+,62,76,63,76,79,72,55,65,73,78
4D+,18,25,20,30,37,25,23,27,33,27
5D+,6,8,8,10,9,8,4,9,9,11
2D-,50,49,49,50,47,49,49,48,46,47
3D-,48,46,48,48,45,45,46,45,43,41
4D-,15,18,15,22,22,17,22,18,25,19
5D-,8,8,7,6,7,10,6,9,13,12


In [309]:
import imgkit
# `Styler.render()` is deprecated in favour of `Styler.to_html()`.
html = accuracy_pivot.style.background_gradient(cmap='coolwarm', axis=1).to_html()
# IMAGES already ends with '/', so join instead of concatenating '/...'.
imgkit.from_string(html, os.path.join(IMAGES, 'cross_task_heatmap_taskwise.jpg'))


this method is deprecated in favour of `Styler.to_html()`



Loading page (1/2)


True

In [303]:
# Heatmap of 1-shot accuracy: target task on the y axis, helper task on the x axis.
fig = px.imshow(accuracy_pivot, text_auto=True)
fig.update_layout(
    xaxis_title='Helper task',
    yaxis_title='Target task',
    # A heatmap has a colour bar rather than a legend, so `legend_title` had
    # no visible effect; label the colour bar instead.
    coloraxis_colorbar=dict(title=dict(text='Accuracy (%)')),
    title='Accuracy (%) from 1-shot cross-task prompts',
    height=500,
    width=500
)

In [304]:
# IMAGES already ends with '/', so join instead of concatenating '/...'.
fig.write_image(os.path.join(IMAGES, 'cross_task_1_shot_heatmap.pdf'))

# Results from helpful prompts

## Baseline

In [353]:
# Standard error of a proportion, sqrt(p * (1 - p) / n), and its x100 version.
prompt_answers['std'] = np.sqrt(
    prompt_answers['accuracy'] * (1 - prompt_answers['accuracy']) / prompt_answers['total'])
prompt_answers['scaled_std'] = prompt_answers['std'] * 100
# First character of the tag is the digit count, last one the operation.
prompt_answers['n_digits'] = prompt_answers['tag'].str[0].astype(int)
prompt_answers['mode'] = prompt_answers['tag'].str[-1]
prompt_answers['Mode'] = prompt_answers['mode'].map({
    '+': 'Addition',
    '-': 'Subtraction',
    'x': 'Multiplication',
    'C': 'Composite',
})
prompt_answers

Unnamed: 0,source,model,tag,helper_tag,n_shots,accuracy,n_correct,total,std,scaled_std,n_digits,mode,Mode
0,results/arithmetics/answer_to_prompt/davinci_1...,davinci,1DC,1DC,0,0.52,52,100,0.04996,4.995998,1,C,Composite
1,results/arithmetics/answer_to_prompt/davinci_1...,davinci,1DC,1DC,1,0.46,46,100,0.04984,4.983974,1,C,Composite
2,results/arithmetics/answer_to_prompt/davinci_2...,davinci,2D+,2D+,0,0.99,99,100,0.00995,0.994987,2,+,Addition
3,results/arithmetics/answer_to_prompt/davinci_2...,davinci,2D+,2D+,1,0.95,95,100,0.021794,2.179449,2,+,Addition
4,results/arithmetics/answer_to_prompt/davinci_2...,davinci,2D-,2D-,0,0.48,48,100,0.04996,4.995998,2,-,Subtraction
5,results/arithmetics/answer_to_prompt/davinci_2...,davinci,2D-,2D-,1,0.46,46,100,0.04984,4.983974,2,-,Subtraction
6,results/arithmetics/answer_to_prompt/davinci_2...,davinci,2Dx,2Dx,0,0.41,41,100,0.049183,4.918333,2,x,Multiplication
7,results/arithmetics/answer_to_prompt/davinci_2...,davinci,2Dx,2Dx,1,0.46,46,100,0.04984,4.983974,2,x,Multiplication
8,results/arithmetics/answer_to_prompt/davinci_3...,davinci,3D+,3D+,0,0.64,64,100,0.048,4.8,3,+,Addition
9,results/arithmetics/answer_to_prompt/davinci_3...,davinci,3D+,3D+,1,0.75,75,100,0.043301,4.330127,3,+,Addition


In [363]:
# One figure per shot count (0 and 1) for the two-step prompting results.
two_step_figs = {}

for n_shots in range(2):
    data = prompt_answers[prompt_answers['n_shots'] == n_shots]
    # Only addition and subtraction are drawn as lines with bands.
    add_sub_rows = data[data['tag'].str[-1].isin(['+', '-'])]

    figure = two_step_figs[n_shots] = line(
        data_frame=add_sub_rows,
        x='n_digits',
        y='n_correct',
        text='n_correct',
        color='Mode',
        error_y='scaled_std',
        error_y_mode='band',
        title='Accuracy (%) v.s. number of digits (D)',
        markers='.',
    )
    figure.update_traces(textposition='top right')

    # Index by tag so the single-point tasks can be looked up by label.
    data = data.set_index('tag')

    # (x position, tag, legend name, marker colour, label position)
    for point_x, point_tag, point_name, point_color, point_text_position in (
        (1, '1DC', 'Composite', 'brown', 'middle right'),
        (2, '2Dx', 'Multiplication', 'green', 'middle left'),
    ):
        point_value = data.loc[point_tag, 'n_correct']
        figure.add_trace(
            go.Scatter(
                x=[point_x],
                y=[point_value],
                text=[point_value],
                textposition=point_text_position,
                mode="lines+markers+text",
                marker=dict(color=point_color),
                error_y=dict(
                    type='data',  # error bar length given in data coordinates
                    array=[data.loc[point_tag, 'scaled_std']],
                    visible=True),
                name=point_name,
            ))

    # This layout title replaces the one passed to `line` above.
    figure.update_layout(
        template='ggplot2',
        yaxis_title='Accuracy (%)',
        xaxis_title='Number of digits (D)',
        title=f'{n_shots}-shot 2-step prompting accuracy (%) for D-digit arithmetics',
        legend_title='Arithmetics',
        hovermode="x",
        width=550,
        height=500,
    )

In [364]:
for n_shots, fig in two_step_figs.items():
    fig.show()

In [365]:
# Combine the two-step figures into one row of subplots sharing the y axis.
grand_fig = make_subplots(rows=1, cols=len(two_step_figs), shared_yaxes=True,
                          x_title='Number of digits (D)',
                          y_title='Accuracy (%)',
                          horizontal_spacing=0.01,
                          subplot_titles=[f'{n_shots}-shot' for n_shots in two_step_figs])

for index, (n_shots, fig) in enumerate(two_step_figs.items()):
    for trace in fig['data']:
        # Only the first panel contributes legend entries.
        # NOTE: this modifies the traces of the figures in `two_step_figs`.
        if index != 0:
            trace['showlegend'] = False
        # `append_trace` is deprecated; `add_trace` takes the same row/col.
        grand_fig.add_trace(trace, row=1, col=index + 1)

xaxis_range = [0.7, 5.5]

grand_fig.update_layout(
    # Only K=0 and K=1 are plotted here, so the title must not claim K=0..3.
    title='K-shot (K=0,1) same-task 2-step prompting accuracy (%)<br><sup>for D-digit arithmetics</sup>',
    legend_title='Arithmetics',
    hovermode="x",
    width=700,
    height=500,
)
# Apply the same x range to every subplot.
grand_fig.update_xaxes(range=xaxis_range)
grand_fig.show()

In [332]:
# IMAGES already ends with '/', so join instead of concatenating '/...'.
grand_fig.write_image(os.path.join(IMAGES, '2_step_same_task.pdf'))

In [333]:
# TODO: Add a top row showing 0-shot and 1-shot perf for 1-step prompting

## Ablation studies: Impact of 2-step same-task prompting over 1-step prompting (0,1-shot)

In [362]:
for n_shots, fig in one_step_figs.items():
    fig.show()

In [366]:
for n_shots, fig in two_step_figs.items():
    fig.show()

In [371]:
base

Unnamed: 0,source,model,tag,helper_tag,n_shots,accuracy,n_correct,total,std,scaled_std
7,results/arithmetics/base/davinci_1DC_0_100.jsonl,davinci,1DC,1DC,0,0.16,16,100,0.036661,3.666061
8,results/arithmetics/base/davinci_1DC_1_100.jsonl,davinci,1DC,1DC,1,0.32,32,100,0.046648,4.664762
9,results/arithmetics/base/davinci_1DC_1_100_2D+...,davinci,1DC,2D+,1,0.33,33,100,0.047021,4.702127
10,results/arithmetics/base/davinci_1DC_1_100_2D-...,davinci,1DC,2D-,1,0.24,24,100,0.042708,4.270831
11,results/arithmetics/base/davinci_1DC_1_100_2Dx...,davinci,1DC,2Dx,1,0.30,30,100,0.045826,4.582576
...,...,...,...,...,...,...,...,...,...,...
132,results/arithmetics/base/davinci_5D-_1_100_4D+...,davinci,5D-,4D+,1,0.07,7,100,0.025515,2.551470
133,results/arithmetics/base/davinci_5D-_1_100_4D-...,davinci,5D-,4D-,1,0.13,13,100,0.033630,3.363034
134,results/arithmetics/base/davinci_5D-_1_100_5D+...,davinci,5D-,5D+,1,0.10,10,100,0.030000,3.000000
135,results/arithmetics/base/davinci_5D-_2_100.jsonl,davinci,5D-,5D-,2,0.13,13,100,0.033630,3.363034


In [375]:
# Shot counts for which both one-step and two-step results are compared.
target_n_shots = [0, 1]

# One-step results: same-task rows of `base` for the target shot counts.
one_step = base.loc[
    (base['tag'] == base['helper_tag']) & base['n_shots'].isin(target_n_shots),
    ['tag', 'n_shots', 'n_correct', 'scaled_std'],
]
# First character of the tag is the digit count, last one the operation.
one_step['n_digits'] = one_step['tag'].str[0].astype(int)
one_step['mode'] = one_step['tag'].str[-1]
one_step['Mode'] = one_step['mode'].map({
    '+': 'Addition',
    '-': 'Subtraction',
    'x': 'Multiplication',
    'C': 'Composite',
})
display(one_step)

# Two-step results: the same selection from `prompt_answers`, which already
# carries the n_digits / mode / Mode columns.
two_step = prompt_answers.loc[
    (prompt_answers['tag'] == prompt_answers['helper_tag'])
    & prompt_answers['n_shots'].isin(target_n_shots),
    ['tag', 'n_shots', 'n_correct', 'scaled_std', 'n_digits', 'mode', 'Mode'],
]
display(two_step)

Unnamed: 0,tag,n_shots,n_correct,scaled_std,n_digits,mode,Mode
7,1DC,0,16,3.666061,1,C,Composite
8,1DC,1,32,4.664762,1,C,Composite
20,2D+,0,91,2.861818,2,+,Addition
21,2D+,1,96,1.959592,2,+,Addition
33,2D-,0,48,4.995998,2,-,Subtraction
34,2D-,1,49,4.999,2,-,Subtraction
46,2Dx,0,35,4.769696,2,x,Multiplication
47,2Dx,1,52,4.995998,2,x,Multiplication
59,3D+,0,49,4.999,3,+,Addition
60,3D+,1,76,4.270831,3,+,Addition


Unnamed: 0,tag,n_shots,n_correct,scaled_std,n_digits,mode,Mode
0,1DC,0,52,4.995998,1,C,Composite
1,1DC,1,46,4.983974,1,C,Composite
2,2D+,0,99,0.994987,2,+,Addition
3,2D+,1,95,2.179449,2,+,Addition
4,2D-,0,48,4.995998,2,-,Subtraction
5,2D-,1,46,4.983974,2,-,Subtraction
6,2Dx,0,41,4.918333,2,x,Multiplication
7,2Dx,1,46,4.983974,2,x,Multiplication
8,3D+,0,64,4.8,3,+,Addition
9,3D+,1,75,4.330127,3,+,Addition


In [378]:
# Per shot count: two-step minus one-step accuracy for every tag, with the
# two standard deviations combined in quadrature.
diffs = {}
content_cols = ['tag', 'n_correct', 'scaled_std']
helper_cols = ['tag', 'n_digits', 'mode', 'Mode']

for n_shots in target_n_shots:
    one_step_rows = one_step[one_step['n_shots'] == n_shots]
    two_step_rows = two_step[two_step['n_shots'] == n_shots]

    diff = one_step_rows[content_cols].merge(
        two_step_rows[content_cols],
        on='tag', suffixes=('_one_step', '_two_step'))
    diff['marginal_return'] = diff['n_correct_two_step'] - diff['n_correct_one_step']
    diff['marginal_return_std'] = np.sqrt(
        diff['scaled_std_two_step']**2 + diff['scaled_std_one_step']**2)
    # Re-attach the descriptive columns needed for plotting.
    diffs[n_shots] = diff.merge(one_step_rows[helper_cols], on='tag')

for n_shots, diff in diffs.items():
    display(diff)

Unnamed: 0,tag,n_correct_one_step,scaled_std_one_step,n_correct_two_step,scaled_std_two_step,marginal_return,marginal_return_std,n_digits,mode,Mode
0,1DC,16,3.666061,52,4.995998,36,6.196773,1,C,Composite
1,2D+,91,2.861818,99,0.994987,8,3.029851,2,+,Addition
2,2D-,48,4.995998,48,4.995998,0,7.065409,2,-,Subtraction
3,2Dx,35,4.769696,41,4.918333,6,6.851277,2,x,Multiplication
4,3D+,49,4.999,64,4.8,15,6.930368,3,+,Addition
5,3D-,30,4.582576,41,4.918333,11,6.722351,3,-,Subtraction
6,4D+,17,3.756328,27,4.439595,10,5.815497,4,+,Addition
7,4D-,10,3.0,15,3.570714,5,4.66369,4,-,Subtraction
8,5D+,10,3.0,12,3.249615,2,4.422669,5,+,Addition
9,5D-,5,2.179449,6,2.374868,1,3.223352,5,-,Subtraction


Unnamed: 0,tag,n_correct_one_step,scaled_std_one_step,n_correct_two_step,scaled_std_two_step,marginal_return,marginal_return_std,n_digits,mode,Mode
0,1DC,32,4.664762,46,4.983974,14,6.826419,1,C,Composite
1,2D+,96,1.959592,95,2.179449,-1,2.93087,2,+,Addition
2,2D-,49,4.999,46,4.983974,-3,7.059037,2,-,Subtraction
3,2Dx,52,4.995998,46,4.983974,-6,7.056912,2,x,Multiplication
4,3D+,76,4.270831,75,4.330127,-1,6.08194,3,+,Addition
5,3D-,45,4.974937,36,4.8,-9,6.913031,3,-,Subtraction
6,4D+,37,4.828043,26,4.386342,-11,6.523036,4,+,Addition
7,4D-,25,4.330127,19,3.923009,-6,5.842944,4,-,Subtraction
8,5D+,8,2.712932,11,3.128898,3,4.141256,5,+,Addition
9,5D-,12,3.249615,9,2.861818,-3,4.330127,5,-,Subtraction


In [399]:
# One figure per shot count showing the change in accuracy from two-step
# prompting relative to one-step prompting.
diff_figs = {}


for n_shots, diff in diffs.items():
    # `set_index` returns a new frame, so the scaling below leaves `diffs` intact.
    data = diff.set_index('tag')
    # NOTE(review): the bands are drawn at 0.3x the propagated std, presumably
    # to keep them readable - confirm and state it in the figure caption.
    data['marginal_return_std'] *= 0.3
    display(data)

    diff_figs[n_shots] = line(
        data_frame=data.loc[normal_modes],
        x='n_digits',
        y='marginal_return',
        text='marginal_return',
        color='Mode',
        error_y='marginal_return_std',
        error_y_mode='band',
        markers='.',
    )
    # Zero line: points below it mean two-step prompting reduced accuracy.
    diff_figs[n_shots].add_hline(y=0, line_width=2, line_dash='dash', line_color='black', opacity=0.5)
    diff_figs[n_shots].update_traces(textposition='top right')

    diff_figs[n_shots].add_trace(
        go.Scatter(
            x=[1],
            y=[data.loc['1DC', 'marginal_return']],
            text=[data.loc['1DC', 'marginal_return']],
            textposition='middle right',
            mode="lines+markers+text",
            marker=dict(color='brown'),
            error_y=dict(
                type='data', # value of error bar given in data coordinates
                array=[data.loc['1DC', 'marginal_return_std']],
                visible=True),
            name='Composite',
        ))
    diff_figs[n_shots].add_trace(
        go.Scatter(
            x=[2],
            y=[data.loc['2Dx', 'marginal_return']],
            text=[data.loc['2Dx', 'marginal_return']],
            textposition='middle left',
            mode="lines+markers+text",
            marker=dict(color='green'),
            error_y=dict(
                type='data', # value of error bar given in data coordinates
                array=[data.loc['2Dx', 'marginal_return_std']],
                visible=True),
            name='Multiplication',
        ))


    diff_figs[n_shots].update_layout(
        template='ggplot2',
        # The y axis shows the accuracy difference, not the accuracy itself.
        yaxis_title='Marginal return (%)',
        xaxis_title='Number of digits (D)',
        title=f'Marginal return on {n_shots}-shot accuracy (%) from two-step same-task prompting',
        legend_title='Arithmetics',
        hovermode="x",
        width=550,
        height=500,
    )

Unnamed: 0_level_0,n_correct_one_step,scaled_std_one_step,n_correct_two_step,scaled_std_two_step,marginal_return,marginal_return_std,n_digits,mode,Mode
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
1DC,16,3.666061,52,4.995998,36,1.859032,1,C,Composite
2D+,91,2.861818,99,0.994987,8,0.908955,2,+,Addition
2D-,48,4.995998,48,4.995998,0,2.119623,2,-,Subtraction
2Dx,35,4.769696,41,4.918333,6,2.055383,2,x,Multiplication
3D+,49,4.999,64,4.8,15,2.07911,3,+,Addition
3D-,30,4.582576,41,4.918333,11,2.016705,3,-,Subtraction
4D+,17,3.756328,27,4.439595,10,1.744649,4,+,Addition
4D-,10,3.0,15,3.570714,5,1.399107,4,-,Subtraction
5D+,10,3.0,12,3.249615,2,1.326801,5,+,Addition
5D-,5,2.179449,6,2.374868,1,0.967006,5,-,Subtraction


Unnamed: 0_level_0,n_correct_one_step,scaled_std_one_step,n_correct_two_step,scaled_std_two_step,marginal_return,marginal_return_std,n_digits,mode,Mode
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
1DC,32,4.664762,46,4.983974,14,2.047926,1,C,Composite
2D+,96,1.959592,95,2.179449,-1,0.879261,2,+,Addition
2D-,49,4.999,46,4.983974,-3,2.117711,2,-,Subtraction
2Dx,52,4.995998,46,4.983974,-6,2.117073,2,x,Multiplication
3D+,76,4.270831,75,4.330127,-1,1.824582,3,+,Addition
3D-,45,4.974937,36,4.8,-9,2.073909,3,-,Subtraction
4D+,37,4.828043,26,4.386342,-11,1.956911,4,+,Addition
4D-,25,4.330127,19,3.923009,-6,1.752883,4,-,Subtraction
5D+,8,2.712932,11,3.128898,3,1.242377,5,+,Addition
5D-,12,3.249615,9,2.861818,-3,1.299038,5,-,Subtraction


In [400]:
# Render each marginal-return figure in turn; only the figures are needed, not their keys.
for fig in diff_figs.values():
    fig.show()

In [415]:
def _add_subplot_traces(grand_fig, fig, row, col, show_legend):
    """Copy every trace of `fig` into the (row, col) cell of `grand_fig`.

    Legend entries are hidden on the copies inside `grand_fig` rather than on
    the traces of `fig`, so the source figure is left unmodified.
    """
    for trace in fig['data']:
        # add_trace(row=, col=) is the supported replacement for the deprecated append_trace.
        grand_fig.add_trace(trace, row=row, col=col)
    if not show_legend:
        grand_fig.update_traces(showlegend=False, row=row, col=col)


def _target_figs(figs):
    """Return the figures whose n_shots key is in `target_n_shots`, in dict order."""
    return [fig for n_shots, fig in figs.items() if n_shots in target_n_shots]


# 2 x 3 grid: one row per n_shots setting; columns are 1-step, 2-step and marginal return.
grand_fig = make_subplots(rows=2, cols=3, shared_yaxes=True,
                          x_title='Number of digits (D)',
                          horizontal_spacing=0.01,
                          vertical_spacing=0.08,
                          subplot_titles=['0-shot 1-step', '0-shot 2-step', '0-shot 2nd-step marginal return',
                                          '1-shot 1-step', '1-shot 2-step', '1-shot 2nd-step marginal return'])

# Rows are numbered over the *kept* figures only, so a skipped n_shots value
# can no longer shift later figures onto the wrong (or a missing) row.
for row, fig in enumerate(_target_figs(one_step_figs), start=1):
    _add_subplot_traces(grand_fig, fig, row=row, col=1, show_legend=False)

for row, fig in enumerate(_target_figs(two_step_figs), start=1):
    _add_subplot_traces(grand_fig, fig, row=row, col=2, show_legend=False)

for row, fig in enumerate(diff_figs.values(), start=1):
    # Only the first marginal-return panel keeps its legend entries.
    _add_subplot_traces(grand_fig, fig, row=row, col=3, show_legend=(row == 1))
    # One zero reference line per panel (previously one was added per trace).
    grand_fig.add_hline(y=0, line_width=2, line_dash='dash', line_color='black', opacity=0.5,
                        row=row, col=3)

grand_fig.update_yaxes(title_text='0-shot accuracy (%)', row=1, col=1)
grand_fig.update_yaxes(title_text='1-shot accuracy (%)', row=2, col=1)

xaxis_range = [0.7, 5.5]

# Same x-range on every subplot, without enumerating xaxis1..xaxis6 by hand.
grand_fig.update_xaxes(range=xaxis_range)
grand_fig.update_layout(
    title='Marginal return on 0/1-shot accuracy (%) of 2nd step of 2-step same-task prompting<br><sup>for D-digit arithmetics</sup>',
    legend_title='Arithmetics',
    hovermode="x",
    width=1000,
    height=900,
)
grand_fig.show()

In [416]:
# IMAGES already ends with a path separator; join instead of concatenating
# '/...' so the output path does not contain a doubled slash.
grand_fig.write_image(os.path.join(IMAGES, '2_step_marginal_return.pdf'))

## Cross-task 2-step prompting

In [423]:
# Stack the cross-task 2-step results with the 1-shot same-task 2-step results
# (rows where tag == helper_tag) into one long table ordered by (tag, helper_tag).
index_cols = ['tag', 'helper_tag']
merge_cols = ['tag', 'helper_tag', 'n_correct']

cross_task_rows = cross_tag_prompt_answers[merge_cols]
same_task_rows = prompt_answers.query('tag == helper_tag & n_shots == 1')[merge_cols]
merged = pd.concat([cross_task_rows, same_task_rows], axis=0)
merged = merged.sort_values(by=index_cols)

# Binomial standard error of the accuracy, expressed in percent.
# NOTE(review): assumes every run has exactly 100 samples -- verify against the `total` column.
success_rate = merged['n_correct'] / 100
merged['scaled_std'] = np.sqrt(success_rate * (1 - success_rate) / 100) * 100
merged

Unnamed: 0,tag,helper_tag,n_correct,scaled_std
1,1DC,1DC,46,4.983974
0,1DC,2D+,45,4.974937
1,1DC,2D-,57,4.950758
2,1DC,2Dx,49,4.999000
3,1DC,3D+,51,4.999000
...,...,...,...,...
86,5D-,3D-,9,2.861818
87,5D-,4D+,8,2.712932
88,5D-,4D-,6,2.374868
89,5D-,5D+,6,2.374868


In [424]:
# Target task (rows) x helper task (columns) table of correct answers,
# with both axes arranged in the order given by `cols`.
pivot = merged.pivot(
    index='tag',
    columns='helper_tag',
    values='n_correct',
).loc[cols, cols]
pivot

helper_tag,1DC,2Dx,2D+,3D+,4D+,5D+,2D-,3D-,4D-,5D-
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
1DC,46,49,45,51,45,44,57,45,43,46
2Dx,32,46,32,45,44,42,31,33,47,45
2D+,96,96,95,92,92,92,97,98,97,97
3D+,71,75,74,75,73,67,70,65,78,79
4D+,31,28,30,32,26,29,26,28,35,30
5D+,11,7,11,10,11,11,10,13,14,12
2D-,43,48,49,47,48,49,46,50,49,47
3D-,39,40,43,43,42,42,38,36,40,41
4D-,20,16,19,17,18,16,16,12,19,14
5D-,7,5,7,5,8,6,8,9,6,9


In [346]:
# axis=1 applies the colour gradient within each row, i.e. per target task.
pivot.style.background_gradient(axis=1, cmap='coolwarm')

helper_tag,1DC,2Dx,2D+,3D+,4D+,5D+,2D-,3D-,4D-,5D-
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
1DC,46,49,45,51,45,44,57,45,43,46
2Dx,32,46,32,45,44,42,31,33,47,45
2D+,96,96,95,92,92,92,97,98,97,97
3D+,71,75,74,75,73,67,70,65,78,79
4D+,31,28,30,32,26,29,26,28,35,30
5D+,11,7,11,10,11,11,10,13,14,12
2D-,43,48,49,47,48,49,46,50,49,47
3D-,39,40,43,43,42,42,38,36,40,41
4D-,20,16,19,17,18,16,16,12,19,14
5D-,7,5,7,5,8,6,8,9,6,9


In [348]:
# Styler.render() is deprecated in favour of Styler.to_html() (see the warning
# this cell used to print), so export the styled table with to_html().
html = pivot.style.background_gradient(cmap='coolwarm', axis=1).to_html()
# IMAGES already ends with a path separator; join to avoid a doubled slash.
pdfkit.from_string(html, os.path.join(IMAGES, '2_step_cross_task_heatmap_taskwise.pdf'))


this method is deprecated in favour of `Styler.to_html()`



True

In [350]:
# Heatmap of the 2-step cross-task table: rows are target tasks, columns are helper tasks.
fig = px.imshow(pivot, text_auto=True)
fig.update_layout(
    xaxis_title='Helper task',
    yaxis_title='Target task',
    # An imshow figure has a colour bar rather than a legend, so the label has
    # to be set on the colour axis; `legend_title` had no visible effect here.
    coloraxis_colorbar_title_text='Accuracy (%)',
    title='Accuracy (%) from 2-step 1-shot cross-task prompts',
    height=500,
    width=500
)

In [351]:
# IMAGES already ends with a path separator; join instead of concatenating
# '/...' so the output path does not contain a doubled slash.
fig.write_image(os.path.join(IMAGES, '2_step_cross_task_1_shot_heatmap.pdf'))

## Impact of different sampling strategies

In [425]:
# Re-display the 2-step table built from `merged` (rows: target `tag`,
# columns: `helper_tag`, values: `n_correct`) as the starting point of this section.
pivot

helper_tag,1DC,2Dx,2D+,3D+,4D+,5D+,2D-,3D-,4D-,5D-
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
1DC,46,49,45,51,45,44,57,45,43,46
2Dx,32,46,32,45,44,42,31,33,47,45
2D+,96,96,95,92,92,92,97,98,97,97
3D+,71,75,74,75,73,67,70,65,78,79
4D+,31,28,30,32,26,29,26,28,35,30
5D+,11,7,11,10,11,11,10,13,14,12
2D-,43,48,49,47,48,49,46,50,49,47
3D-,39,40,43,43,42,42,38,36,40,41
4D-,20,16,19,17,18,16,16,12,19,14
5D-,7,5,7,5,8,6,8,9,6,9


In [426]:
# Same target x helper layout as `pivot`, but holding the scaled standard
# error of each cell instead of the number of correct answers.
pivot_std = merged.pivot(
    index='tag',
    columns='helper_tag',
    values='scaled_std',
).loc[cols, cols]
pivot_std

helper_tag,1DC,2Dx,2D+,3D+,4D+,5D+,2D-,3D-,4D-,5D-
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
1DC,4.983974,4.999,4.974937,4.999,4.974937,4.963869,4.950758,4.974937,4.950758,4.983974
2Dx,4.664762,4.983974,4.664762,4.974937,4.963869,4.935585,4.624932,4.702127,4.990992,4.974937
2D+,1.959592,1.959592,2.179449,2.712932,2.712932,2.712932,1.705872,1.4,1.705872,1.705872
3D+,4.537621,4.330127,4.386342,4.330127,4.439595,4.702127,4.582576,4.769696,4.142463,4.073082
4D+,4.624932,4.489989,4.582576,4.664762,4.386342,4.537621,4.386342,4.489989,4.769696,4.582576
5D+,3.128898,2.55147,3.128898,3.0,3.128898,3.128898,3.0,3.363034,3.46987,3.249615
2D-,4.950758,4.995998,4.999,4.990992,4.995998,4.999,4.983974,5.0,4.999,4.990992
3D-,4.877499,4.898979,4.950758,4.950758,4.935585,4.935585,4.853864,4.8,4.898979,4.918333
4D-,4.0,3.666061,3.923009,3.756328,3.841875,3.666061,3.666061,3.249615,3.923009,3.46987
5D-,2.55147,2.179449,2.55147,2.179449,2.712932,2.374868,2.712932,2.861818,2.374868,2.861818


In [428]:
# Table from the 1-step cross-task prompting analysis (defined earlier in the
# notebook); it is used below to rank the helper tasks for each target task.
accuracy_pivot

helper_tag,1DC,2Dx,2D+,3D+,4D+,5D+,2D-,3D-,4D-,5D-
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
1DC,32,30,33,25,26,31,24,24,29,23
2Dx,37,52,36,47,50,49,38,43,41,40
2D+,93,96,96,97,95,97,97,98,95,99
3D+,62,76,63,76,79,72,55,65,73,78
4D+,18,25,20,30,37,25,23,27,33,27
5D+,6,8,8,10,9,8,4,9,9,11
2D-,50,49,49,50,47,49,49,48,46,47
3D-,48,46,48,48,45,45,46,45,43,41
4D-,15,18,15,22,22,17,22,18,25,19
5D-,8,8,7,6,7,10,6,9,13,12


In [431]:
# For every target task, rank all helper tasks by their value in
# `accuracy_pivot`, lowest first: a list of (accuracy, helper_tag) pairs.
# Ties on accuracy are ordered by helper_tag, as with any tuple sort.
cross_task_results = {}
for tag in cols:
    cross_task_results[tag] = [
        (accuracy_pivot.loc[tag, helper_tag], helper_tag)
        for helper_tag in cols
    ]
    cross_task_results[tag].sort()

In [434]:
def get_helpers(tag, mode, top_k=1):
    """Return the `top_k` best or worst helper tags for target task `tag`.

    Helpers are read from `cross_task_results[tag]`, a list of
    (accuracy, helper_tag) pairs sorted by ascending accuracy.

    Args:
        tag: Target task whose helper ranking is looked up.
        mode: 'best' selects the highest-ranked helpers; any other value
            selects the lowest-ranked ones.
        top_k: Number of helpers to return. A non-positive value yields an
            empty list.

    Returns:
        A list of helper tags in ascending-accuracy order, so for 'best' the
        strongest helper comes last. It holds at most `top_k` entries.
    """
    if top_k <= 0:
        # Guard the slice: `ranked[-0:]` is the whole list, not an empty one.
        return []
    ranked = cross_task_results[tag]
    selected = ranked[-top_k:] if mode == 'best' else ranked[:top_k]
    return [helper_tag for _, helper_tag in selected]

In [432]:
# Empty frame that shares the target-task index of `pivot`; the columns for
# each helper-selection strategy are added to it afterwards.
results = pd.DataFrame(index=pivot.index)

In [440]:
def _score_with_single_helper(mode):
    """Per target task, the `pivot` value for the single helper chosen by `mode`."""
    return pivot.apply(lambda row: row[get_helpers(row.name, mode)[0]], axis=1)


def _mean_score_over_helpers(mode, top_k):
    """Per target task, the mean `pivot` value over the `top_k` helpers chosen by `mode`."""
    return pivot.apply(
        lambda row: np.mean(row[get_helpers(row.name, mode, top_k)]), axis=1)


# Helpers are chosen via get_helpers (ranking held in `cross_task_results`);
# the scores reported in each column are read from the 2-step `pivot` table.
results['best'] = _score_with_single_helper('best')
results['worst'] = _score_with_single_helper('worst')
results['top-3'] = _mean_score_over_helpers('best', 3)
results['top-5'] = _mean_score_over_helpers('best', 5)
results['bottom-3'] = _mean_score_over_helpers('worst', 3)
results['bottom-5'] = _mean_score_over_helpers('worst', 5)
# Average over every helper: use the actual number of helper columns instead
# of a hard-coded 10 so the column stays correct if tasks are added or removed.
results['uniform'] = _mean_score_over_helpers('best', len(pivot.columns))
# Same-task prompting: the diagonal of `pivot`.
results['same'] = pivot.apply(lambda row: row[row.name], axis=1)

In [442]:
# axis=1 applies the colour gradient within each row, so strategies are
# compared per target task rather than across tasks.
results.style.background_gradient(axis=1, cmap='coolwarm')

Unnamed: 0_level_0,best,worst,top-3,top-5,bottom-3,bottom-5,uniform,same
tag,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
1DC,45,46,45.0,45.4,49.333333,48.8,47.1,46
2Dx,46,32,44.0,42.0,31.666667,37.4,39.7,46
2D+,97,96,95.666667,95.2,95.0,95.2,95.2,95
3D+,73,70,75.666667,76.0,71.666667,69.4,72.7,75
4D+,26,31,31.0,30.2,29.0,28.8,29.5,26
5D+,12,10,12.0,12.0,10.666667,10.0,11.0,11
2D-,47,49,46.333333,46.6,48.0,48.6,47.6,46
3D-,43,41,41.666667,40.6,39.0,40.2,40.4,36
4D-,19,20,18.0,16.8,18.333333,16.6,16.7,19
5D-,6,8,7.0,7.0,6.666667,7.0,7.0,9
