# Results

In [1]:
from os import makedirs
from os.path import join

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.legend import Legend


# Render all figure text through LaTeX, loading the cmbright package in the
# preamble, and select a serif font family.
# NOTE(review): 'text.usetex' presumably needs a working LaTeX installation
# on the machine running the notebook — confirm.
plt.rcParams.update({
    'text.latex.preamble': r'\usepackage{cmbright}',
    'text.usetex': True,
    'font.family': 'serif',
})
# Shared seaborn 'deep' palette; the plotting functions below index into it.
color_palette = sns.color_palette('deep')
# Default grid appearance for every figure: dotted, light grey lines.
plt.rc('grid', linestyle="dotted", color='lightgrey')

# Directory the result CSVs are read from.
results_dir = '../rpaper'
# Directory the generated figures are saved to.
plots_dir = 'paper_plots'

# Maps run/architecture identifiers to the display names used in tables and
# plot annotations (the loaders apply it with DataFrame.replace).
# NOTE(review): the capitalisation of the ConvNeXt entries is inconsistent
# ('Convnext…' vs 'ConvNext…') — confirm the intended spelling.
arch_name = {
    'convnext-atto': 'ConvnextAtto',
    'convnext-tiny': 'ConvNextTiny',
    'convnextv2-atto': 'Convnextv2Atto',
    'convnextv2-nano': 'Convnextv2Nano',
    'convnextv2-tiny': 'Convnextv2Tiny',
    'densenet121': 'DenseNet121',
    'densenet161': 'DenseNet161',
    'mobilenetv3-small-075': 'MobileNetV3-Small-0.75',
    'mobilenetv3-large-100': 'MobileNetV3-Large-1.0',
    'mobilevitv2-050': 'MobileViTv2-0.5',
    'mobilevitv2-100': 'MobileViTv2-1.0',
    'mobilevitv2-200': 'MobileViTv2-2.0',

}

def format_pm(s: str) -> str:
    """Swap every '±' in a string for the LaTeX command ``\\pm``.

    Values that are not strings, and strings without a '±', are handed back
    unchanged, so the function can be used as a formatter on mixed columns.
    """
    needs_rewrite = isinstance(s, str) and '±' in s
    return s.replace('±', '\\pm') if needs_rewrite else s

# Make sure the figure output directory exists before any plot is saved.
makedirs(plots_dir, exist_ok=True)

### Architecture

In [2]:
def load_arch_df():
    """Build the architecture comparison table.

    Reads ``arch_batch-size-48/exp_mtst.csv`` under ``results_dir`` and
    inner-joins it, on the run name, with a hand-maintained table of parameter
    counts, MACs and a 'conv'/'tsfm' family tag. Runs present in only one of
    the two tables are dropped by the inner join. Run names are finally
    replaced by their display names from ``arch_name``.

    Returns:
        DataFrame with columns Arch, Seen, Unseen, HM, Params, MACs, CP.
    """
    raw = pd.read_csv(join(results_dir, 'arch_batch-size-48', 'exp_mtst.csv'))
    metrics = raw.assign(
        Arch=raw['run'],
        Unseen=raw['unseen'],
        Seen=raw['seen'],
        HM=raw['hm'],
    )[['Arch', 'Unseen', 'Seen', 'HM']].set_index('Arch')

    # (run name, Params, MACs, family tag)
    cost_rows = [
        # Efficient
        ('mobilenetv3-small-075',  1.02,  0.12, 'conv'),
        ('mobilevitv2-050',        1.11,  1.05, 'tsfm'),
        ('mobilenetv3-large-100',  4.20,  0.63, 'conv'),
        ('mobilevitv2-100',        4.39,  4.08, 'tsfm'),
        ('convnext-atto',          3.37,  1.62, 'conv'),
        # Large
        ('densenet121',           6.95,   8.33, 'conv'),
        ('densenet161',           26.47, 22.70, 'conv'),
        ('convnext-tiny',         27.82, 28.60, 'conv'),
        ('mobilevitv2-200',       17.42, 16.11, 'tsfm'),
    ]
    costs = pd.DataFrame(
        cost_rows, columns=['Arch', 'Params', 'MACs', 'CP']
    ).set_index('Arch')

    # Align both tables on the shared 'Arch' index, then turn it back into a
    # column so the display-name replacement reaches it.
    merged = (
        pd.concat([costs, metrics], axis=1, join="inner")
        .reset_index()
        .replace(arch_name)
    )
    return merged[['Arch', 'Seen',  'Unseen', 'HM', 'Params', 'MACs', 'CP']]


# Load the architecture table; the bare name on the last line displays it.
arch_df = load_arch_df()
arch_df

FileNotFoundError: [Errno 2] No such file or directory: '../rpaper/arch_batch-size-48/exp_mtst.csv'

In [None]:
def _annotate_archs(ax, xs, ys, names):
    """Label every scatter point with its architecture name and pad the x-axis.

    Args:
        ax: Matplotlib axes that already holds the scatter plot.
        xs: x coordinates of the points, in data units.
        ys: y coordinates of the points, in data units.
        names: Architecture display names, one per point.
    """
    for x, y, name in zip(xs, ys, names):
        # This one label hangs below its point; all others sit above theirs
        # (NOTE(review): presumably to avoid overlapping a neighbouring label).
        va = 'top' if name == 'MobileNetV3-Large-1.0' else 'bottom'
        ax.text(x + 0.5, y, f'{name}', ha='left', va=va, fontsize='small')
    # Push the right limit out by 35% of the current span
    # (presumably to leave room for the labels of the right-most points).
    xmin, xmax = ax.get_xlim()
    ax.set_xlim(xmin, xmax + 0.35 * (xmax - xmin))
    ax.tick_params(axis='both', which='major', labelsize='small')


def generate_arch_plot(df: pd.DataFrame):
    """Scatter HM against parameter count (left) and MACs (right).

    Points are coloured by the 'CP' family tag ('conv' or 'tsfm') and labelled
    with the architecture name. The figure is saved as ``params-macs.pdf`` in
    ``plots_dir`` and then shown.

    Args:
        df: Table with 'Arch', 'Params', 'MACs', 'CP', 'Seen', 'Unseen' and
            'HM' columns. The metric columns are read through ``.str[:5]``,
            i.e. they are expected to be strings whose first five characters
            parse as a float (presumably 'mean±std' — TODO confirm).
    """
    df = df.copy()
    for col in ('Unseen', 'Seen', 'HM'):
        df[col] = df[col].str[:5].astype(float)

    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 5))

    cp_color = {'conv': color_palette[0], 'tsfm': color_palette[1]}
    colors = [cp_color[cp] for cp in df['CP']]

    df.plot.scatter('Params', 'HM', ax=ax0, c=colors)
    df.plot.scatter('MACs', 'HM', ax=ax1, c=colors)

    _annotate_archs(ax0, df['Params'], df['HM'], df['Arch'])
    _annotate_archs(ax1, df['MACs'], df['HM'], df['Arch'])

    ax0.set_ylabel('AUC-ROC HM')
    ax0.set_xlabel('Params (M)')
    ax1.set_xlabel('MACs (G)')
    # The right panel plots the same HM values, so its y tick labels and
    # y-axis label are hidden.
    ax1.set_yticklabels([])
    ax1.yaxis.label.set_visible(False)

    conv_patch = mpatches.Patch(color=cp_color['conv'], label='ConvNet')
    tsfm_patch = mpatches.Patch(color=cp_color['tsfm'], label='Transformer')
    ax1.legend(handles=[conv_patch, tsfm_patch], loc='lower right')

    # grid() without arguments toggles visibility; request it explicitly.
    ax0.grid(True)
    ax1.grid(True)

    fig.tight_layout()
    fig.savefig(join(plots_dir, "params-macs.pdf"))
    plt.show()

# Render and save the Params/MACs vs. HM figure.
generate_arch_plot(arch_df)

### Resolution

In [None]:
def load_resolution_df():
    """Load the resolution sweep results.

    Reads ``resolution/exp_mtst.csv`` under ``results_dir``. Each run name is
    split on '_' into exactly two fields: the architecture identifier and the
    integer resolution. Architecture identifiers are replaced by their display
    names from ``arch_name``.

    Returns:
        DataFrame with columns Arch, Resolution, Seen, Unseen, HM.
    """
    raw = pd.read_csv(join(results_dir, 'resolution', 'exp_mtst.csv'))

    run_fields = raw['run'].str.split('_', expand=True)
    # Raises if a run name does not split into exactly two fields.
    run_fields.columns = ['arch', 'resolution']

    table = pd.DataFrame({
        'Arch': run_fields['arch'].astype(str),
        'Resolution': run_fields['resolution'].astype(int),
        'Seen': raw['seen'],
        'Unseen': raw['unseen'],
        'HM': raw['hm'],
    })
    return table.replace(arch_name)

# Load the resolution sweep; the bare name on the last line displays it.
resolution_df = load_resolution_df()
resolution_df

In [None]:
def generate_resolution_plot(df: pd.DataFrame):
    """Plot HM, Unseen and Seen against input resolution for each architecture.

    One colour per architecture; HM is drawn solid, Unseen dotted and Seen
    dashed. Every architecture gets its own small legend box, laid out side by
    side starting from the lower-right corner. Only the first three
    architectures found in ``df`` are drawn. The figure is saved as
    ``resolution.pdf`` in ``plots_dir``.

    Args:
        df: Table with 'Arch', 'Resolution', 'Seen', 'Unseen' and 'HM'
            columns. The metric columns are read through ``.str[:5]``, i.e.
            they are expected to be strings whose first five characters parse
            as a float (presumably 'mean±std' — TODO confirm).
    """
    df = df.copy()
    for col in ('Unseen', 'Seen', 'HM'):
        df[col] = df[col].str[:5].astype(float)

    fig, ax = plt.subplots()

    # Ordered on purpose (a tuple, not a set): the legend below pairs the last
    # plotted lines with these metric names position by position, so the
    # plotting order has to be deterministic and match the labels.
    metrics = (
        ('HM', 1.0, 'solid'),
        ('Unseen', 0.5, 'dotted'),
        ('Seen', 0.5, 'dashed'),
    )
    metric_names = [name for name, _, _ in metrics]

    archs = df['Arch'].unique()
    for i, (arch, color) in enumerate(zip(archs, color_palette[:3])):
        arch_rows = df[df['Arch'] == arch]
        for metric, alpha, linestyle in metrics:
            arch_rows.plot(
                x='Resolution', y=metric,
                ax=ax,
                marker='.', linestyle=linestyle,
                alpha=alpha, color=color,
                # The per-architecture legend is built by hand below.
                legend=False,
            )

        # The lines just added for this architecture are the last ones on
        # the axes, in the same order as `metrics`.
        leg = Legend(ax, ax.get_lines()[-len(metrics):],
                     metric_names,
                     title=arch,
                     fontsize='x-small',
                     edgecolor=color,
                     loc='lower right')
        # Shift each successive legend box further left (in axes coordinates)
        # so the boxes sit next to each other.
        bb = leg.get_bbox_to_anchor().transformed(ax.transAxes.inverted())
        x_offset = i * 0.18
        bb.x0 -= x_offset
        bb.x1 -= x_offset
        leg.set_bbox_to_anchor(bb, transform=ax.transAxes)
        ax.add_artist(leg)

    ax.set_xticks(df['Resolution'].unique())
    ax.set_xlabel('X-Ray Resolution')
    ax.set_ylabel('AUC-ROC')

    ax.tick_params(axis='both', which='major', labelsize='small')
    ax.tick_params(axis='both', which='minor', labelsize='small')

    # grid() without arguments toggles visibility; request it explicitly.
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(join(plots_dir, 'resolution.pdf'))


# Render and save the resolution figure.
generate_resolution_plot(resolution_df)

### Subpopulation Age

In [None]:
def load_age_df():
    """Load the per-age-decade results of the population-shift experiment.

    Reads ``shift_pop/exp_mtst.csv`` under ``results_dir`` and keeps only the
    runs whose name starts with 'age_decade'.

    Returns:
        DataFrame with columns Decade (int), Seen, Unseen, HM.
    """
    df = pd.read_csv(join(results_dir, 'shift_pop', 'exp_mtst.csv'))
    # Copy the filtered rows so the column assignments below write to an
    # independent frame instead of a slice of the full table (which triggers
    # pandas' SettingWithCopyWarning).
    df = df[df.run.str.startswith('age_decade')].copy()
    # The run name is split on '_'; the second field minus its first six
    # characters is parsed as the integer decade.
    df['Decade'] = df['run'].str.split('_', expand=True)[1].str[6:].astype(int)
    df['Unseen'] = df['unseen']
    df['Seen'] = df['seen']
    df['HM'] = df['hm']
    df = df[['Decade', 'Seen',  'Unseen', 'HM']]
    return df

# Load the age-decade table; the bare name on the last line displays it.
age_df = load_age_df()
age_df


In [None]:
def generate_age_plot(df: pd.DataFrame):
    """Plot HM, Unseen and Seen against the age decade.

    All three series share one colour; HM is drawn solid, Unseen dotted and
    Seen dashed (the latter two half-transparent). The figure is saved as
    ``subpop_age.pdf`` in ``plots_dir``.

    Args:
        df: Table with 'Decade', 'Seen', 'Unseen' and 'HM' columns. The
            metric columns are read through ``.str[:5]``, i.e. they are
            expected to be strings whose first five characters parse as a
            float (presumably 'mean±std' — TODO confirm).
    """
    df = df.copy()
    for col in ('Unseen', 'Seen', 'HM'):
        df[col] = df[col].str[:5].astype(float)

    fig, ax = plt.subplots()

    # (column, extra line style) in drawing order.
    series_styles = (
        ('HM', {'linestyle': 'solid'}),
        ('Unseen', {'linestyle': 'dotted', 'alpha': 0.5}),
        ('Seen', {'linestyle': 'dashed', 'alpha': 0.5}),
    )
    for metric, style in series_styles:
        df.plot(
            x='Decade', y=metric,
            ax=ax,
            marker='.', color=color_palette[0],
            **style,
        )

    # Tick positions 2..8 are labelled with age ranges.
    # NOTE(review): assumes the Decade values in the data are exactly 2..8 —
    # confirm against the CSV.
    labels = ['[10, 20]', '[21, 30]', '[31, 40]', '[41, 50]',
              '[51, 60]', '[61, 70]', '[71, 80]']
    ax.set_xticks(range(2, 9), labels)
    ax.set_xlabel('Decade')
    # NOTE(review): the axis also carries the Seen and Unseen series; the
    # resolution plot labels the same quantity 'AUC-ROC' — confirm the label.
    ax.set_ylabel('HM')

    ax.tick_params(axis='both', which='major', labelsize='small')
    ax.tick_params(axis='both', which='minor', labelsize='small')

    # grid() without arguments toggles visibility; request it explicitly.
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(join(plots_dir, 'subpop_age.pdf'))


# Render and save the age-decade figure.
generate_age_plot(age_df)

### From Generalized FSL to Standard FSL

In [None]:
# NOTE(review): this rebinds the module-level `results_dir` set in the first
# cell. The loaders read it at call time, so any loader (re-)run after this
# cell reads from this directory instead of the original one.
results_dir = '../../Meta-CXR-dev-run/rpaper'

def load_gfsl_df():
    """Load the generalized few-shot results.

    Reads ``gfsl/exp_mtst.csv`` under ``results_dir``. Each run name is split
    on '_' into exactly three fields; a fixed-length prefix is dropped from
    every field and the remainder is parsed as an integer.

    Returns:
        DataFrame with columns n-way, n-unseen, k-shot, Seen, Unseen, HM.
    """
    raw = pd.read_csv(join(results_dir, 'gfsl', 'exp_mtst.csv'))

    run_fields = raw['run'].str.split('_', expand=True)
    # Raises if a run name does not split into exactly three fields.
    run_fields.columns = ['n-way', 'n-unseen', 'k-shot']

    # Number of leading characters to drop from each field before the
    # integer value starts.
    prefix_len = {'n-way': 5, 'n-unseen': 7, 'k-shot': 6}
    table = pd.DataFrame({
        field: run_fields[field].str[skip:].astype(int)
        for field, skip in prefix_len.items()
    })
    table['Seen'] = raw['seen']
    table['Unseen'] = raw['unseen']
    table['HM'] = raw['hm']
    return table

# Load the generalized-FSL table and export it, relative to the working
# directory, as Markdown and as LaTeX; the bare name on the last line
# displays the table.
gfsl_df = load_gfsl_df()
with open('gfsl.md', 'w') as f:
    f.write(gfsl_df.to_markdown(index=False) + '\n')
with open('gfsl.tex', 'w') as f:
    # One formatter per column: format_pm rewrites '±' as '\pm' in string
    # cells and leaves every other value untouched.
    formatters = [format_pm] * gfsl_df.shape[1]
    f.write(gfsl_df.to_latex(index=False, formatters=formatters))
gfsl_df

In [None]:
def generate_gfsl_k_shot(df, k_shot, ax: Axes):
    """Draw one k-shot panel: HM, Unseen and Seen against n-unseen per n-way.

    For each n-way in (3, 4, 5) the three series are drawn in one colour: HM
    solid, Unseen dotted and Seen dashed (the latter two half-transparent).
    Panel decoration depends on ``k_shot``: only the 1-shot panel keeps its y
    tick labels and gets a y-axis label, only the 5-shot panel gets an x-axis
    label, and only the 15-shot panel gets the per-n-way legends.

    Args:
        df: Rows of a single k-shot setting with 'n-way' and 'n-unseen'
            columns and numeric 'HM', 'Unseen' and 'Seen' columns.
        k_shot: The k-shot value of this panel; used for the title and to
            select the decorations described above.
        ax: Axes to draw on.
    """
    n_ways = (3, 4, 5)
    colors = color_palette[:len(n_ways)]

    # (column, extra line style) in drawing order; the legend labels below
    # follow the same order.
    series_styles = (
        ('HM', {'linestyle': 'solid'}),
        ('Unseen', {'linestyle': 'dotted', 'alpha': 0.5}),
        ('Seen', {'linestyle': 'dashed', 'alpha': 0.5}),
    )

    for n_way, color in zip(n_ways, colors):
        subdf = df[df['n-way'] == n_way]

        for metric, style in series_styles:
            subdf.plot(x='n-unseen', y=metric,
                       ax=ax,
                       marker='.', color=color,
                       # Legends are built by hand below where wanted.
                       legend=False,
                       **style)

        if k_shot == 15:
            # The three lines just added for this n-way are the last ones on
            # the axes.
            leg = Legend(ax, ax.get_lines()[-3:],
                        ['HM', 'Unseen', 'Seen'],
                        title=f'{n_way}-way',
                        fontsize='x-small',
                        edgecolor=color,
                        loc='lower right')
            # Offset the box horizontally (in axes coordinates): 5-way stays
            # in the corner, 4-way and 3-way move left by 0.25 and 0.5.
            bb = leg.get_bbox_to_anchor().transformed(ax.transAxes.inverted())
            x_offset = (n_way - 5) * 0.25
            bb.x0 += x_offset
            bb.x1 += x_offset
            leg.set_bbox_to_anchor(bb, transform=ax.transAxes)
            ax.add_artist(leg)

    # Axes-level decoration is applied once, after all series are drawn,
    # rather than once per n-way.
    if k_shot != 1:
        ax.set_yticklabels([])
    else:
        ax.set_ylabel('AUC-ROC')

    if k_shot != 5:
        ax.set_xlabel('')
    else:
        ax.set_xlabel('$n$-unseen')

    ax.set_title(f'at least {k_shot}-shot')
    ax.tick_params(axis='both', which='major', labelsize='small')
    ax.tick_params(axis='both', which='minor', labelsize='small')
    # grid() without arguments toggles visibility, so repeated calls can
    # switch the grid back off; request it explicitly.
    ax.grid(True)
    # Fixed y range so that all panels are directly comparable.
    ax.set_ylim(55, 90)


def generate_gfsl_group_by_k_shot(df):
    """Export one LaTeX table per k-shot value and draw the three-panel figure.

    The table is grouped by 'k-shot'; groups are paired in order with the
    three panels of a 1x3 figure. For every pair, the group (without its
    'k-shot' column) is written to ``gfsl_<kk>-shot.tex`` in the working
    directory and printed, its metric columns are converted to floats, and
    ``generate_gfsl_k_shot`` draws it on the panel. The figure is saved as
    ``task_complexity.pdf`` in ``plots_dir`` and then shown.

    Args:
        df: Table with 'k-shot', 'n-way', 'n-unseen', 'Seen', 'Unseen' and
            'HM' columns. The metric columns are read through ``.str[:5]``,
            i.e. they are expected to be strings whose first five characters
            parse as a float (presumably 'mean±std' — TODO confirm).
    """
    table = df.copy()

    fig, panels = plt.subplots(1, 3, figsize=(10, 5))

    for (k_shot, group), panel in zip(table.groupby('k-shot'), panels):
        group = group.drop(columns=['k-shot'])

        with open(f'gfsl_{k_shot:02d}-shot.tex', 'w') as tex_file:
            # One formatter per column; missing cells are written as a
            # 12-space placeholder.
            column_formatters = [format_pm] * group.shape[1]
            latex = group.fillna(' '*12).to_latex(
                index=False, formatters=column_formatters
            )
            tex_file.write(latex)
        print(group)

        for metric in ('Unseen', 'Seen', 'HM'):
            group[metric] = group[metric].str[:5].astype(float)
        generate_gfsl_k_shot(group, k_shot, panel)

    plt.tight_layout()
    plt.savefig(join(plots_dir, "task_complexity.pdf"))
    plt.show()

# Export the per-k-shot LaTeX tables and render the task-complexity figure.
generate_gfsl_group_by_k_shot(gfsl_df)
