# In [None]:
# Import necessary libraries
import wandb
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import time
import argparse
import os
from matplotlib.lines import Line2D
import seaborn as sns
import matplotlib.font_manager
# matplotlib.font_manager._rebuild()


# Define the extract_and_interpolate function
def extract_and_interpolate(data_info, args):
    """Download one metric for several W&B runs and aggregate it across runs.

    Args:
        data_info: Mapping that provides 'run_id_list', 'wandb_project',
            'metric' and 'total_timesteps'.
        args: Namespace-like object with ``extract`` and ``print_freq``.

    Returns:
        A DataFrame with one column per run id plus three derived columns:
        'avg' (row-wise mean over runs), 'var' (square root of the row-wise
        variance, i.e. a standard deviation despite the column name) and
        'step' (the row index, not the logged ``_step`` value).
        ``None`` when ``args.extract`` is not 'wandb'.
    """
    if args.extract != 'wandb':
        # Only the W&B backend is implemented here; other modes yield None,
        # exactly as the implicit fall-through did before.
        return None

    # The script entry point binds `wandb_username` at module level. Fall back
    # to the environment so the function also works when the module is
    # imported instead of being run as a script.
    username = globals().get('wandb_username') or os.environ.get('WANDB_USER_NAME')

    # One API client is enough for every run.
    api = wandb.Api(timeout=30)

    start = time.time()
    run_scores = {}
    for run_id in data_info['run_id_list']:
        run = api.run(f"{username}/{data_info['wandb_project']}/{run_id}")
        metrics = run.scan_history(keys=['_step', data_info['metric']], page_size=10000)
        step_list = []
        score_list = []
        for metric_data in metrics:
            # Stop reading this run once it passes the requested step budget.
            if metric_data['_step'] > data_info['total_timesteps']:
                break
            step_list.append(metric_data['_step'])
            score_list.append(metric_data[data_info['metric']])
            j = len(score_list)
            if j % args.print_freq == 0:
                print(f'data_info={data_info}, scan number: {j}, current_step: {step_list[-1]}, time: {time.time() - start}')
        # Keep each run as its own Series: runs that logged a different number
        # of points are NaN-padded when the frame is built, instead of raising
        # a length-mismatch error on column assignment.
        run_scores[run_id] = pd.Series(score_list)
        print(f'Finish data_info: {data_info}, Total Time: {time.time()-start}')
        start = time.time()

    data = pd.DataFrame(run_scores)

    # Aggregate before adding the derived columns so they do not feed back
    # into the statistics.
    score_avg = data.mean(axis=1)
    score_var = data.var(axis=1)
    data['avg'] = score_avg
    # NOTE(review): with pandas' default ddof=1 a single run gives NaN here —
    # confirm whether 0 would be preferable for plotting.
    data['var'] = np.sqrt(score_var)
    data['step'] = data.index

    return data

# Helpers used by the main plotting script below.
def _curve_frame(avg):
    """Return a step/avg/var frame for one hard-coded accuracy curve.

    'step' holds the fixed rank values shared by every curve, 'avg' the
    given accuracies and 'var' a zero spread for every point.
    """
    frame = pd.DataFrame()
    frame['step'] = [1, 3, 5, 10, 20, 50, 100, 224]
    frame['avg'] = avg
    frame['var'] = [0] * len(avg)
    return frame


def _legend_handle(info):
    """Build the proxy line that represents one method in the shared legend."""
    return Line2D([0], [0], color=info['color'], linewidth=info['linewidth'], alpha=info['alpha_smooth'], label=info['label'], linestyle=info['linestyle'])


def _draw_panel(ax, curves, image_info, title, y_override=None):
    """Draw every curve of one subplot and configure its axes.

    Args:
        ax: Target matplotlib axes.
        curves: Iterable of ``(info, data)`` pairs, drawn in the given order;
            ``info`` is a per-method style dict, ``data`` a frame with
            'step', 'avg' and 'var' columns.
        image_info: Figure-wide settings (labels, limits, font size, grid).
        title: Title of this subplot.
        y_override: Optional ``(low, high, tick_step)`` applied right after
            the generic y-limits and before the generic ticks are set.
    """
    ax.grid(visible=True, which='major', linestyle=image_info['grid_linestyle'], linewidth=image_info['grid_linewidth'])
    ax.grid(visible=False, which='minor')
    ax.minorticks_on()

    for info, data in curves:
        ax.plot(data['step'], data['avg'], marker=info['marker'], markersize=info['markersize'], linestyle=info['linestyle'], color=info['color'], linewidth=info['linewidth'], alpha=info['alpha_smooth'])
        ax.fill_between(data['step'], data['avg'] - data['var'], data['avg'] + data['var'], color=info['color'], alpha=info['fill_in_alpha'])

    ax.set_xscale('log',base=2)
    ax.set_xlabel(image_info['xlabel_name'], fontsize=image_info['fontsize'])
    ax.set_ylabel(image_info['ylabel_name'][0], fontsize=image_info['fontsize'])
    # NOTE(review): x_min/x_max read like log2 exponents while the axis itself
    # is in raw rank units on a log scale — confirm the intended limits.
    ax.set_xlim(image_info['x_min'] - image_info['x_boundary_shift'] / 2, image_info['x_max'] + image_info['x_boundary_shift'])
    ax.set_ylim(image_info['y_min'][0] - image_info['y_boundary_shift'] / 2, image_info['y_max'][0] + image_info['y_boundary_shift'])
    if y_override is not None:
        y_low, y_high, y_tick_step = y_override
        ax.set_ylim(y_low, y_high)
        # NOTE(review): these ticks are replaced by the generic set_yticks
        # call below; the call order is kept so the rendered figure does not
        # change — confirm which y-range is actually wanted.
        ax.set_yticks(np.arange(y_low, y_high, y_tick_step))
    # Ticks sit at powers of two and are labelled with their exponent.
    ax.set_xticks(ticks=[1, 2, 4, 8, 16, 32, 64, 128, 256], labels=['0', '1', '2', '3', '4', '5', '6', '7', '8'])
    ax.set_yticks(np.arange(image_info['y_min'][0], image_info['y_max'][0] + 0.1, image_info['y_step'][0]))
    ax.tick_params(axis='both', which='major', labelsize=image_info['fontsize'])
    ax.set_title(title, fontsize=image_info['fontsize'])


# Main function
if __name__=='__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--extract", type=str, default='wandb', choices=['wandb', 'csv'])
    parser.add_argument("--wandb_project", type=str, default='LoR_VP_tiny_imagenet')
    parser.add_argument("--print_freq", type=int, default=10000)
    # parse_known_args tolerates extra arguments that are not declared above.
    args, unknown = parser.parse_known_args()

    # `wandb_username` is read as a module-level name by extract_and_interpolate.
    wandb_username=os.environ.get('WANDB_USER_NAME')
    wandb_key=os.environ.get('WANDB_API_KEY')
    wandb.login(key=wandb_key)

    # Define image information
    image_info ={
        'save_title': 'rank_deficient',
        'image_title': 'ResNet50,Tiny-ImageNet',
        'x_min': 0,
        'x_max': 8,
        'x_step': 1,
        'x_boundary_shift': 1,
        'y_min': [20],  # Different y_min for each plot
        'y_max': [83],  # Different y_max for each plot
        'y_step': [20],  # Different y_step for each plot
        'y_boundary_shift': 5,
        'xlabel_name': r'$\log_2(\text{Rank})$',
        'ylabel_name': ['Accuracy'],  # Different y_labels for each plot
        'width': 20,  # Double the width for two plots
        'height': 8,
        'fontsize': 45,
        'mark_last_point': 1,
        'use_seaborn': 1,
        'use_times_newroman': 0,
        'grid_color': "gray",
        'grid_linestyle': '-',
        'grid_linewidth': 5, # 2
        'grid_alpha': 0.2,
        'legend_alpha': 1,
        'x_axis_shift': 0.5,
        'step_denominator': 20
    }

    colors = ['blue', 'red', 'orange', 'purple', 'green',
          'olive', 'brown', 'magenta', 'cyan', 'crimson', 'gray', 'black']
    # Per-method style settings; the W&B fields are only needed when curves
    # are fetched through extract_and_interpolate.
    data_info = {
        'lor_vp': {
            'wandb_project': 'LoR_VP_tiny_imagenet',
            'run_id_list': ['xeij68xh'],
            'metric': 'test_acc',
            'total_timesteps': 10000,
            'window_len_smooth': 1, # 100
            'min_window_len_smooth': 1,
            'linewidth': 6,
            'linestyle': 'solid',
            'marker': '*',
            'markersize': 30,
            'markevery': 1,
            'alpha_smooth': 1,
            'fill_in_alpha': 0.2,
            'color': 'green',
            'label': 'LoR-VP'
        },
        'autovp': {
            'wandb_project': 'LoR_VP_tiny_imagenet',
            'run_id_list': ['l93wpzta', 'rf92gl2g'],
            'metric': 'test_acc',
            'total_timesteps': 40000,
            'window_len_smooth': 1, # 100
            'min_window_len_smooth': 1,
            'linewidth': 6,
            'linestyle': 'solid',
            'marker': 'o',
            'markersize': 25,
            'markevery': 1,
            'alpha_smooth': 1,
            'fill_in_alpha': 0.2,
            'color': 'red',
            'label': 'AutoVP',
        },
        'ilm_vp': {
            'wandb_project': 'LoR_VP_tiny_imagenet',
            'run_id_list': ['eo19a6p4', 'lzq07yj2'],
            'metric': 'test_acc',
            'total_timesteps': 80000,
            'window_len_smooth': 1, # 100
            'min_window_len_smooth': 1,
            'linewidth': 6,
            'linestyle': 'solid',
            'marker': 'o',
            'markersize': 25,
            'markevery': 1,
            'alpha_smooth': 1,
            'fill_in_alpha': 0.2,
            'color': 'orange',
            'label': 'ILM-VP',
        },
    }

    # Hard-coded accuracy curves, one panel per subplot. The first panel
    # carries an extra y-range that is applied before the generic ticks.
    panels = [
        {
            'title': 'ResNet50-P,TinyImageNet',
            'y_override': (73, 79, 5),
            'avg': {
                'lor_vp': [76.82, 77.34, 78.25, 78.57, 78.41, 78.36, 78.12, 78.38],
                'autovp': [73.68, 74.82, 75.46, 75.52, 75.51, 75.47, 75.79, 75.44],
                'ilm_vp': [19.52, 24.12, 26.46, 27.52, 27.93, 27.77, 27.88, 27.95],
            },
        },
        {
            'title': 'ViT-B/16,Tiny-ImageNet',
            'y_override': None,
            'avg': {
                'lor_vp': [81.57, 82.17, 82.59, 82.78, 82.76, 82.91, 82.84, 82.78],
                'autovp': [80.51, 80.72, 80.89, 81.52, 81.52, 81.64, 81.59, 81.59],
                'ilm_vp': [20.68, 24.82, 25.46, 26.42, 26.51, 26.47, 26.79, 26.44],
            },
        },
    ]
    # Drawing order inside every panel; the legend uses the same order.
    plot_order = ['autovp', 'ilm_vp', 'lor_vp']

    # Create subplots
    fig, axes = plt.subplots(1, 2, figsize=(image_info['width'], image_info['height']))

    # NOTE(review): the theme and font settings are applied after the axes
    # already exist — confirm this ordering is intended.
    if image_info['use_seaborn']:
        sns.set_theme()
    if image_info['use_times_newroman']:
        plt.rcParams['font.family'] = 'serif'
        plt.rcParams['font.serif'] = 'Times New Roman'

    # One legend entry per method, shared by both panels.
    legend_handles = [_legend_handle(data_info[name]) for name in plot_order]

    # Accuracy-vs-rank panels, one per backbone.
    for ax, panel in zip(axes, panels):
        curves = [(data_info[name], _curve_frame(panel['avg'][name])) for name in plot_order]
        _draw_panel(ax, curves, image_info, panel['title'], y_override=panel['y_override'])

    legend=fig.legend(handles=legend_handles, fontsize=image_info['fontsize'], loc='lower center', bbox_to_anchor=(0.5, -0.2), ncol=3, frameon=False)

    plt.tight_layout()
    save_path = f"ablations/rank_deficient/{image_info['save_title']}.pdf"
    # savefig does not create missing directories, so make sure the target
    # folder exists before writing.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path, bbox_inches='tight')
    plt.show()
    plt.clf()
    # Report the path that was actually written.
    print(f"Finished plot, the figure is saved in {save_path}")