In [1]:
import torch
import pickle
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

from collections import defaultdict

HOME = Path('/pvc4/battery/nmi_configs')

def mape(x, y):
    return torch.mean((x - y).abs() / y)

In [2]:
workspace =  HOME / 'workspaces/transfer'
preds = defaultdict(dict)
for path in workspace.glob('*'):
    mat, count = path.name.split('_')
    preds_ = []
    for pred_file in path.glob('preds*.pkl'):
        with open(pred_file, 'rb') as f:
            pred = pickle.load(f)
        preds_.append(pred)
    preds[mat][count] = preds_


In [22]:
for file in Path('/pvc4/battery/nmi_configs/workspaces/transfer_pos/NMC_1').glob('*.pkl'):
    with open(file, 'rb') as f:
        y = pickle.load(f)
        for x in ['intercd_preds', 'intracd_preds', 'finetune_preds']:
            score = mape(y[x], y['labels'])
            print(x, score.item(), end=' ')
        print(mape((y['intercd_preds'] + y['intracd_preds']) / 2, y['labels']))
        print('')


intercd_preds 0.37164250016212463 intracd_preds 0.3420901596546173 finetune_preds 0.4568028450012207 tensor(0.3155)

intercd_preds 0.32916706800460815 intracd_preds 0.39218273758888245 finetune_preds 2.6607401371002197 tensor(0.3539)

intercd_preds 0.2726101279258728 intracd_preds 0.24864502251148224 finetune_preds 0.5202293992042542 tensor(0.2391)

intercd_preds 0.9147775173187256 intracd_preds 0.5264322757720947 finetune_preds 0.6676366925239563 tensor(0.7204)

intercd_preds 0.5200692415237427 intracd_preds 0.2866385877132416 finetune_preds 0.744026243686676 tensor(0.3890)

intercd_preds 0.47455620765686035 intracd_preds 0.2832162380218506 finetune_preds 0.8278024196624756 tensor(0.3719)

intercd_preds 0.33706405758857727 intracd_preds 0.3518770635128021 finetune_preds 0.37586989998817444 tensor(0.3417)

intercd_preds 0.3246169686317444 intracd_preds 0.3028452694416046 finetune_preds 0.5715575218200684 tensor(0.2998)



In [4]:
preds

defaultdict(dict,
            {'LCO': {'16': [{'intercd_preds': tensor([1266.1785, 1165.8464,  769.0946,  496.0692,  420.4820,  810.2390,
                         877.5687,  444.2342,  271.3041,  271.3532,  267.4570,  295.4696,
                         261.0262,  236.5822,  279.5909,  257.1804,  204.7635,  211.1976,
                         206.2965,  214.4874,  326.5588]),
                'intracd_preds': tensor([551.5922, 578.0704, 589.7775, 540.2132, 449.1240, 741.3839, 611.7087,
                        553.2409, 268.4343, 251.2568, 261.4382, 291.9220, 239.9379, 250.1857,
                        260.8864, 228.4672, 141.0196, 172.8393, 181.1322, 201.2227, 416.9851]),
                'finetune_preds': tensor([506.0107, 512.0403, 430.8390, 535.8584, 376.6316, 806.2074, 576.5170,
                        564.4268, 263.1222, 282.2976, 267.2672, 229.6733, 240.5640, 214.5345,
                        267.0035, 231.0332, 188.3853, 209.0527, 203.7384, 194.2435, 333.4554]),
                'lab

In [5]:
mape(preds['LCO']['16'][0]['intercd_preds'], preds['LCO']['16'][0]['labels'])
mape(preds['LCO']['16'][0]['intracd_preds'], preds['LCO']['16'][0]['labels'])

KeyError: 'intercd_preds'

In [3]:
colors = ['#A9BED9', '#E3F0F6', '#E8A19D']
train_counts = {
    'LCO': [16],
    # 'NMC': [1, 2, 4, 8, 16],
    # 'NCA': [1, 2, 4, 8]
}
for material, counts in train_counts.items():
    means, stds, mins, maxs = [], [], [], []
    for count in counts:
        scores = {
            'BatLiNet': [],
            'Finetune': [],
            'Vanilla': []
        }
        for seed in range(8):
            seed_preds = preds[material][str(count)][seed]
            # scores['BatLiNet'].append(mape(
            #     (seed_preds['intercd_preds'] + seed_preds['intracd_preds']) / 2,
            #     seed_preds['labels']
            # ))
            scores['BatLiNet'].append(mape(seed_preds['intercd_preds'], seed_preds['labels']))
            scores['Finetune'].append(mape(seed_preds['finetune_preds'], seed_preds['labels']))
            scores['Vanilla'].append(mape(seed_preds['intracd_preds'], seed_preds['labels']))
        means.append([np.mean(scores[key]) for key in ['Vanilla', 'Finetune', 'BatLiNet']])
        stds.append([np.std(scores[key]) for key in ['Vanilla', 'Finetune', 'BatLiNet']])
        mins.append([np.min(scores[key]) for key in ['Vanilla', 'Finetune', 'BatLiNet']])
        maxs.append([np.max(scores[key]) for key in ['Vanilla', 'Finetune', 'BatLiNet']])
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    x = np.arange(len(counts))
    ax.bar(
        x,
        [i[0] for i in means],
        width=0.2,
        color=colors[0],
        yerr=[i[0] for i in stds],
        error_kw=dict(capsize=3, capthick=1, lw=1),
        # error bar color is darker than grey
        ecolor='black')
    ax.bar(
        x+0.2,
        [i[1] for i in means],
        width=0.2,
        color=colors[1],
        yerr=[i[1] for i in stds],
        error_kw=dict(capsize=3, capthick=1, lw=1),
        # error bar color is darker than grey
        ecolor='black')
    ax.bar(
        x+0.4,
        [i[2] for i in means],
        width=0.2,
        color=colors[2],
        yerr=[i[2] for i in stds],
        error_kw=dict(capsize=3, capthick=1, lw=1),
        # error bar color is darker than grey
        ecolor='black')
    plt.ylim([0, 1.2])
    plt.show()
    # fig.savefig(
    #     HOME / f'workspaces/transfer/{material}.svg',
    #     bbox_inches='tight',
    #     pad_inches=0.1
    # )

IndexError: list index out of range