In [1]:
import sys, os, time
import numpy as np
import pandas as pd
import scipy.io as sio
import pickle

In [2]:
# Accuracy metric
def acc(r_val, r_std):
    """Average accuracy (ACC) after the final training stage.

    ACC is the mean of the final row ``r_val[len(r_val)-1]`` of the result
    matrix, divided by the number of rows ``len(r_val)``.

    Parameters
    ----------
    r_val : indexable as ``r_val[row][col]``
        Result matrix. The selected final row must support ``.sum()``
        (e.g. a numpy array or pandas object).
    r_std : same layout as ``r_val``
        Standard deviations matching ``r_val`` entry by entry.

    Returns
    -------
    tuple of (float, float)
        ``(acc, acc_std)``, each rounded to 5 decimals and returned as plain
        Python floats. This keeps the printed tuple free of ``np.float64(...)``
        under NumPy >= 2.

    Raises
    ------
    ValueError
        If ``r_val`` is empty.
    """
    n_tasks = len(r_val)
    if n_tasks == 0:
        raise ValueError("acc() needs a non-empty result matrix")

    # NOTE(review): indexing with ``len - 1`` (not ``-1``) also works for a
    # pandas DataFrame, but there ``[]`` selects a *column* by label, not a
    # row. Confirm how the pickled matrices are laid out.
    final_row = r_val[n_tasks - 1]
    final_row_std = r_std[len(r_std) - 1]

    mean_acc = final_row.sum() / n_tasks

    # Propagate per-task std through the mean, treating the terms as
    # independent: std(mean) = sqrt(sum(std_i ** 2)) / n.
    mean_acc_std = np.sqrt((final_row_std ** 2).sum()) / n_tasks

    return float(round(mean_acc, 5)), float(round(mean_acc_std, 5))

# Backward Transfer metric
def bwt(r_val, r_std):
    """Backward transfer (BWT) over a task sequence.

    For every task ``i`` except the last, take the difference between
    ``r_val[t-1][i]`` (after the final training stage) and ``r_val[i][i]``
    (right after training task ``i``). Then average over those ``t - 1``
    tasks. A negative value means performance dropped (forgetting).

    Parameters
    ----------
    r_val : indexable as ``r_val[row][col]``
        Square result matrix with ``t = len(r_val)`` rows.
    r_std : same layout as ``r_val``
        Standard deviations matching ``r_val`` entry by entry.

    Returns
    -------
    tuple of (float, float)
        ``(bwt, bwt_std)``, rounded to 5 decimals, as plain Python floats.
        ``(nan, nan)`` when fewer than two tasks are given, since BWT is
        undefined there.
    """
    t = len(r_val)
    if t < 2:
        # Need at least one earlier task; avoid dividing by t - 1 <= 0.
        return float('nan'), float('nan')

    last = t - 1
    diffs = [r_val[last][i] - r_val[i][i] for i in range(last)]
    # Variance of each difference, treating the two std estimates as
    # independent: var(a - b) = std_a ** 2 + std_b ** 2.
    diff_variances = [r_std[last][i] ** 2 + r_std[i][i] ** 2 for i in range(last)]

    mean_bwt = np.sum(diffs) / last
    # Propagate through the mean: std(mean) = sqrt(sum(var_i)) / n.
    mean_bwt_std = np.sqrt(np.sum(diff_variances)) / last

    return float(round(mean_bwt, 5)), float(round(mean_bwt_std, 5))

In [3]:
# --- Configuration ---------------------------------------------------------
exp = 'minigrid-wallgap-doorkey-redbluedoor-crossing'
steps = '500000'                        # recorded for reference; not used in file names below
approaches = ['fine-tuning', 'ewc', 'blip', 'blip_ewc_1', 'blip_ewc_2', 'blip_spp_mask']
metrics_dir = './metrics/'
seeds = [123456, 789012, 345678]        # recorded for reference; not used in file names below
F_prior = 5e-18
ewc_lambda = 5000.0

blip_ewc_lambda_1 = 1000.0
blip_ewc_lambda_2 = 2500.0
fisher_term = 'ft'

spp_lambda = 0.5
initial_prune_percent = 30.0
scheduler = False
prune_higher = False

task_state = 3

# Experiment details: (date, experiment name, [(task index, environment id), ...])
experiments = [
    ('2023-11-13', exp, [
        (0, 'MiniGrid-WallGapS6-v0'),
        (1, 'MiniGrid-DoorKey-6x6-v0'),
        (2, 'MiniGrid-RedBlueDoors-6x6-v0'),
        (3, 'MiniGrid-SimpleCrossingS9N1-v0'),
    ]),
]


def experiment_file_stem(date, experiment, approach):
    """Build the metrics-file prefix for one (date, experiment, approach) run.

    Hyper-parameters are read from the configuration globals above.
    Raises ValueError for an unrecognised approach, so a name left over from
    a previous loop iteration is never silently reused.
    """
    if approach in ('fine-tuning', 'ft-fix'):
        return '{}_{}_{}_tr_{}'.format(date, experiment, approach, task_state)
    if approach == 'ewc':
        return '{}_{}_{}_lamb_{}_tr_{}'.format(date, experiment, approach, ewc_lambda, task_state)
    if approach == 'blip':
        return '{}_{}_{}_F_prior_{}_tr_{}'.format(date, experiment, approach, F_prior, task_state)
    if approach in ('blip_ewc_1', 'blip_ewc_2'):
        # Both variants share the 'blip_ewc' prefix and differ only in lambda.
        lamb = blip_ewc_lambda_1 if approach == 'blip_ewc_1' else blip_ewc_lambda_2
        return '{}_{}_{}_F_prior_{}_lamb_{}_F_term_{}_tr_{}'.format(
            date, experiment, 'blip_ewc', F_prior, lamb, fisher_term, task_state)
    if approach == 'blip_spp':
        return '{}_{}_{}_F_prior_{}_spp_lamb_{}_tr_{}'.format(
            date, experiment, approach, F_prior, spp_lambda, task_state)
    if approach == 'blip_spp_mask':
        return '{}_{}_{}_F_prior_{}_spp_lamb_{}_prune_{}_scheduler_{}_prune_higher_{}_tr_{}'.format(
            date, experiment, approach, F_prior, spp_lambda, initial_prune_percent,
            scheduler, prune_higher, task_state)
    raise ValueError('Unknown approach: {!r}'.format(approach))


# --- Evaluate every approach ----------------------------------------------
for date, experiment, tasks_sequence in experiments:

    print('Experiment:', experiment, '\n')
    for approach in approaches:
        exp_name = experiment_file_stem(date, experiment, approach)

        r_val_file = os.path.join(metrics_dir, exp_name + "_final_r_val.pkl")
        r_std_file = os.path.join(metrics_dir, exp_name + "_final_r_std.pkl")

        # pd.read_pickle unpickles arbitrary objects and can execute code:
        # only point metrics_dir at trusted, locally produced files.
        r_val = pd.read_pickle(r_val_file)
        r_std = pd.read_pickle(r_std_file)

        print("Approach:", approach)
        print("ACC: {}".format(acc(r_val, r_std)))
        print("BWT: {}\n".format(bwt(r_val, r_std)))

Experiment: minigrid-wallgap-doorkey-redbluedoor-crossing 

Approach: fine-tuning
ACC: (0.29201, 0.098)
BWT: (-0.56598, 0.13079)

Approach: ewc
ACC: (0.61541, 0.11525)
BWT: (0.17525, 0.16087)

Approach: blip
ACC: (0.46271, 0.16051)
BWT: (0.0, 0.21832)

Approach: blip_ewc_1
ACC: (0.44815, 0.16383)
BWT: (-4e-05, 0.21835)

Approach: blip_ewc_2
ACC: (0.51849, 0.10407)
BWT: (-3e-05, 0.06876)

Approach: blip_spp_mask
ACC: (0.54105, 0.1462)
BWT: (-2e-05, 0.1672)

