In [1]:
import re
import sys
import torch
import numpy as np
import pandas as pd

from tqdm import tqdm
from pathlib import Path
from collections import defaultdict

# Repository root: the notebook is expected to run from a direct
# sub-directory of it (e.g. notebooks/).
HOME = Path.cwd().parent
# Root of the experiment output folders whose logs are parsed below.
workspace = HOME / 'workspaces'

# Import from our code base. Only add the root once, so re-running this cell
# does not keep appending duplicate entries to sys.path.
if str(HOME) not in sys.path:
    sys.path.append(str(HOME))
from src.data import BatteryData, CyclingProtocol
from src.utils import import_config
from src.builders import (
    TRAIN_TEST_SPLITTERS,
    LABEL_ANNOTATORS,
    FEATURE_EXTRACTORS
)

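Reproduce Table 1: RMSE of every baseline and of BatLiNet on each dataset (the corresponding MAPE, in %, is computed in a later cell).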

In [2]:
# Helper functions
def format_scores(scores, show_max_min: bool = True):
    """Format repeated-run scores as ``'mean±std'``, rounded to integers.

    Args:
        scores: Sequence of numbers, e.g. one metric over several runs.
        show_max_min: Retained for backward compatibility only. The previous
            implementation computed the min and max but never included them,
            so both settings yield the same ``'mean±std'`` string.

    Returns:
        str: e.g. ``'162±7'``. ``std`` is ``np.std`` with its default
        ``ddof=0`` (population standard deviation).
    """
    return f'{np.mean(scores):.0f}±{np.std(scores):.0f}'

def format_df(res_dict, agg_fn):
    """Aggregate every leaf of a two-level dict and tabulate the result.

    ``res_dict[column][row]`` holds the raw values of one table cell; each is
    reduced with ``agg_fn``. The returned DataFrame has one column per outer
    key and one row per inner key.
    """
    aggregated = {}
    for column, rows in res_dict.items():
        aggregated[column] = {row: agg_fn(raw) for row, raw in rows.items()}
    return pd.DataFrame(aggregated)

def extract_scores(log_filename):
    """Parse RMSE, MAE and MAPE from the last non-blank line of a log file.

    That line must contain exactly three ``'<name>: <decimal>'`` fields. They
    are read positionally as RMSE, MAE, MAPE; the field names are not checked.

    Args:
        log_filename: Path to the log file.

    Returns:
        dict with float values under ``'RMSE'``, ``'MAE'`` and ``'MAPE'``.
        MAPE is multiplied by 100 (presumably converting a fraction to %).

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file has no non-blank line, or its last non-blank
            line does not hold exactly three decimal values.
    """
    with open(log_filename, 'r') as f:
        # Ignore blank lines so a trailing empty line does not hide the
        # summary line.
        lines = [line for line in f.read().splitlines() if line.strip()]
    if not lines:
        raise ValueError(f'{log_filename}: no non-blank line to parse')
    last_line = lines[-1]
    values = re.findall(r'[^:]+: (\d+\.\d+)', last_line)
    if len(values) != 3:
        raise ValueError(
            f'{log_filename}: expected 3 scores (RMSE, MAE, MAPE) on the last '
            f'line, found {len(values)}: {last_line!r}')
    rmse, mae, mape = (float(v) for v in values)
    return {'RMSE': rmse, 'MAE': mae, 'MAPE': mape * 100}

In [3]:
# Column (dataset) and row (method) keys of Table 1, in table order.
datasets_to_use = 'matr_1 matr_2 hust mix_20 mix_100'.split()
sklearn_baseline_names = (
    'dummy variance_model discharge_model full_model '
    'ridge pcr plsr svm rf'
).split()
nn_baseline_names = 'mlp cnn lstm'.split()

def collect_results(dataset):
    """Return one dataset's Table 1 column as a list of formatted scores.

    Rows follow ``sklearn_baseline_names``, then ``nn_baseline_names``, then
    BatLiNet ('Ours'). Values are read from the module-level
    ``sklearn_baselines``, ``nn_baselines`` and ``ours`` dicts, which must be
    populated before this is called; a missing entry raises ``KeyError``.
    """
    column = [sklearn_baselines[dataset][name] for name in sklearn_baseline_names]
    column += [nn_baselines[dataset][name] for name in nn_baseline_names]
    column.append(ours[dataset]['Ours'])
    return column

# RMSE of the scikit-learn baselines: a single run each, read from log.0.
sklearn_baselines = defaultdict(dict)
for method_path in workspace.glob('baselines/sklearn/*'):
    for dataset_res in method_path.glob('*'):
        score = extract_scores(dataset_res / 'log.0')['RMSE']
        sklearn_baselines[dataset_res.name][method_path.name] = f'{score:.0f}'

# RMSE of the neural-network baselines: 8 runs each (log.0 ... log.7),
# summarised as mean±std.
nn_baselines = defaultdict(dict)
for method_path in workspace.glob('baselines/nn_models/*'):
    for dataset_res in method_path.glob('*'):
        scores = [
            extract_scores(dataset_res / f'log.{i}')['RMSE'] for i in range(8)
        ]
        nn_baselines[dataset_res.name][method_path.name] = format_scores(scores)
# BatLiNet RMSE: 8 runs per dataset (log.0 ... log.7), summarised as mean±std.
# A dataset whose logs are missing or unparsable is shown as 'none' and
# reported. Any other error (a typo, KeyboardInterrupt, ...) is no longer
# swallowed by a bare ``except``.
ours = defaultdict(dict)

method = 'Ours'
for dataset_res in workspace.glob('ablation/diff_branch/batlinet/*'):
    dataset = dataset_res.name
    try:
        scores = [
            extract_scores(dataset_res / f'log.{i}')['RMSE'] for i in range(8)
        ]
    except (OSError, ValueError, IndexError) as err:
        print(f'Skipping {dataset_res}: {err}')
        ours[dataset][method] = 'none'
    else:
        ours[dataset][method] = format_scores(scores)

main_table_df = pd.DataFrame(
    {dataset: collect_results(dataset) for dataset in datasets_to_use},
    index=[*sklearn_baseline_names, *nn_baseline_names, 'Ours'],
)

# MATR-1 / MATR-2 RMSE of the Attia and Severson et al. baselines, filled in
# by hand. These overwrite the values collected from our own runs.
literature_rmse = {
    'variance_model': ['138', '196'],
    'discharge_model': ['86', '173'],
    'full_model': ['100', '214'],
    'ridge': ['125', '188'],
    'pcr': ['100', '176'],
    'plsr': ['97', '193'],
    'rf': ['140', '202'],
}
for baseline_name, matr_values in literature_rmse.items():
    main_table_df.loc[baseline_name, ['matr_1', 'matr_2']] = matr_values

# Display labels for Table 1 (``...'' is LaTeX-style quoting).
dataset_labels = {
    'mix_20': 'MIX-20',
    'hust': 'HUST',
    'matr_1': 'MATR-1',
    'matr_2': 'MATR-2',
    'mix_100': 'MIX-100'
}
method_labels = {
    'dummy': 'Training Mean',
    'variance_model': '``Variance\'\' Model',
    'discharge_model': '``Discharge\'\' Model',
    'full_model': '``Full\'\' Model',
    'ridge': 'Ridge Regression',
    'pcr': 'PCR',
    'plsr': 'PLSR',
    'svm': 'SVM Regression',
    'rf': 'Random Forest',
    'mlp': 'MLP',
    'cnn': 'CNN',
    'lstm': 'LSTM',
    'Ours': 'BatLiNet (ours)'
}
main_table_df = main_table_df.rename(
    columns=dataset_labels, index=method_labels
)[['MATR-1', 'MATR-2', 'HUST', 'MIX-100', 'MIX-20']]
main_table_df

Unnamed: 0,MATR-1,MATR-2,HUST,MIX-100,MIX-20
Training Mean,399,511,420,573,594
``Variance'' Model,138,196,398,521,600
``Discharge'' Model,86,173,322,1737,988653
``Full'' Model,100,214,335,331,437
Ridge Regression,125,188,1047,395,837
PCR,100,176,435,384,707
PLSR,97,193,431,371,482
SVM Regression,140,300,344,257,461
Random Forest,140,202,345,214,290
MLP,162±7,207±4,444±3,461±30,519±25


In [1]:
# Relative improvements
def rel_improve(x, y):
    """Return how much larger ``y`` is than ``x``, as a percentage of ``x``.

    Computed as ``(y - x) / x * 100``. With ``x`` = our RMSE and ``y`` = a
    baseline RMSE, this normalises by *our* error. NOTE(review):
    ``(y - x) / y * 100`` would instead be the error reduction relative to
    the baseline; confirm which definition the paper reports.
    """
    gap = y - x
    return gap / x * 100

# (x, y) pairs in Table 1 column order: MATR-1, MATR-2, HUST, MIX-100, MIX-20.
# Each y matches the lowest baseline RMSE visible in the Table 1 output above.
# Each x is assumed to be BatLiNet's RMSE, typed in by hand.
# NOTE(review): derive these from `main_table_df` so they cannot go stale.
scores = [
    (63, 86),    # MATR-1 vs ``Discharge'' Model
    (162, 173),  # MATR-2 vs ``Discharge'' Model
    (268, 322),  # HUST vs ``Discharge'' Model
    (168, 214),  # MIX-100 vs Random Forest
    (207, 290),  # MIX-20 vs Random Forest
]

print('\n'.join(
    f'{rel_improve(ours_rmse, baseline_rmse):.1f}%'
    for ours_rmse, baseline_rmse in scores
))

36.5%
6.8%
20.1%
27.4%
40.1%


In [4]:
# Same Table 1 row/column keys as the RMSE cell, repeated so this MAPE cell
# can run on its own.
datasets_to_use = ['matr_1', 'matr_2', 'hust', 'mix_20', 'mix_100']
sklearn_baseline_names = [
    'dummy', 'variance_model', 'discharge_model', 'full_model',
    'ridge', 'pcr', 'plsr', 'svm', 'rf',
]
nn_baseline_names = ['mlp', 'cnn', 'lstm']

def collect_results(dataset):
    """Return one dataset's table column: sklearn rows, NN rows, then 'Ours'.

    Looks values up in the module-level ``sklearn_baselines``,
    ``nn_baselines`` and ``ours`` dicts, which must already be populated.
    """
    sources = [
        (sklearn_baselines, sklearn_baseline_names),
        (nn_baselines, nn_baseline_names),
        (ours, ['Ours']),
    ]
    return [table[dataset][name] for table, names in sources for name in names]

# MAPE of the scikit-learn baselines: a single run each, read from log.0.
sklearn_baselines = defaultdict(dict)
for method_path in workspace.glob('baselines/sklearn/*'):
    method = method_path.name
    for dataset_res in method_path.glob('*'):
        mape = extract_scores(dataset_res / 'log.0')['MAPE']
        sklearn_baselines[dataset_res.name][method] = f'{mape:.0f}'

# MAPE of the neural-network baselines: 8 runs each (log.0 ... log.7),
# summarised as mean±std.
nn_baselines = defaultdict(dict)
for method_path in workspace.glob('baselines/nn_models/*'):
    method = method_path.name
    for dataset_res in method_path.glob('*'):
        run_logs = (dataset_res / f'log.{i}' for i in range(8))
        mapes = [extract_scores(log)['MAPE'] for log in run_logs]
        nn_baselines[dataset_res.name][method] = format_scores(mapes)
# BatLiNet MAPE: 8 runs per dataset (log.0 ... log.7), summarised as mean±std.
# A dataset whose logs are missing or unparsable is shown as 'none' and
# reported. Any other error (a typo, KeyboardInterrupt, ...) is no longer
# swallowed by a bare ``except``.
ours = defaultdict(dict)

method = 'Ours'
for dataset_res in workspace.glob('ablation/diff_branch/batlinet/*'):
    dataset = dataset_res.name
    try:
        scores = [
            extract_scores(dataset_res / f'log.{i}')['MAPE'] for i in range(8)
        ]
    except (OSError, ValueError, IndexError) as err:
        print(f'Skipping {dataset_res}: {err}')
        ours[dataset][method] = 'none'
    else:
        ours[dataset][method] = format_scores(scores)

# MAPE table: one column per dataset, rows in baseline order then BatLiNet.
main_table_df = pd.DataFrame(
    {dataset: collect_results(dataset) for dataset in datasets_to_use},
    index=[*sklearn_baseline_names, *nn_baseline_names, 'Ours'],
)

# Display labels for the MAPE table (``...'' is LaTeX-style quoting).
dataset_labels = {
    'mix_20': 'MIX-20',
    'hust': 'HUST',
    'matr_1': 'MATR-1',
    'matr_2': 'MATR-2',
    'mix_100': 'MIX-100'
}
method_labels = {
    'dummy': 'Training Mean',
    'variance_model': '``Variance\'\' Model',
    'discharge_model': '``Discharge\'\' Model',
    'full_model': '``Full\'\' Model',
    'ridge': 'Ridge Regression',
    'pcr': 'PCR',
    'plsr': 'PLSR',
    'svm': 'SVM Regression',
    'rf': 'Random Forest',
    'mlp': 'MLP',
    'cnn': 'CNN',
    'lstm': 'LSTM',
    'Ours': 'BatLiNet (ours)'
}
column_order = ['MATR-1', 'MATR-2', 'HUST', 'MIX-100', 'MIX-20']
main_table_df = main_table_df.rename(
    columns=dataset_labels, index=method_labels
)[column_order]
main_table_df

Unnamed: 0,MATR-1,MATR-2,HUST,MIX-100,MIX-20
Training Mean,28,36,18,59,102
``Variance'' Model,15,12,17,39,96
``Discharge'' Model,17,10,14,47,5951
``Full'' Model,16,100,14,22,54
Ridge Regression,11,10,36,30,150
PCR,11,14,19,28,60
PLSR,11,11,18,26,75
SVM Regression,15,18,16,18,51
Random Forest,17,14,16,15,31
MLP,12±0,11±0,18±1,28±1,53±2


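Reproduce Table 2: for each data source, the number of cells and of distinct materials, charge and discharge protocols, temperatures, form factors, and overall test settings.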

In [5]:
class BatteryStats:
    """Test conditions of one cell (or one group of identical cells).

    Stored under short attribute names: ``material``, ``cp`` (charge
    protocol), ``dp`` (discharge protocol), ``T`` (temperature) and ``ff``
    (form factor). ``repr`` joins all five with commas, rendering protocols
    through ``get_protocols``. ``get_stats`` uses that string as the key when
    counting distinct settings.
    """

    def __init__(self,
                 material,
                 charge_protocol,
                 discharge_protocol,
                 temperature,
                 form_factor):
        self.material, self.cp, self.dp = (
            material, charge_protocol, discharge_protocol)
        self.T, self.ff = temperature, form_factor

    def __repr__(self):
        parts = (
            self.material,
            get_protocols(self.cp),
            get_protocols(self.dp),
            self.T,
            self.ff,
        )
        return ','.join(str(part) for part in parts)

def get_protocols(protocols):
    """Render one cycling protocol, or a collection of them, as a string key.

    A lone ``CyclingProtocol`` is treated as a one-element list. Each element
    is either a ``CyclingProtocol``, rendered as ``str(to_dict())``, or an
    iterable of them, whose renderings are concatenated with no separator.
    The rendered elements are joined with ``'+'``.
    """
    if isinstance(protocols, CyclingProtocol):
        items = [protocols]
    else:
        items = protocols

    def render(item):
        if isinstance(item, CyclingProtocol):
            return str(item.to_dict())
        return ''.join(str(step.to_dict()) for step in item)

    return '+'.join([render(item) for item in items])

def get_stats(cell_stats, path):
    """Summarise a collection of ``BatteryStats`` records.

    Args:
        cell_stats: Records to summarise. They are iterated several times,
            so pass a list rather than a generator.
        path: Either the total cell count as an ``int``, or a directory whose
            entries are counted to obtain it.

    Returns:
        7-tuple: (total cells, #materials, #charge protocols,
        #discharge protocols, #temperatures, #form factors,
        #distinct settings), where a setting is the record's ``str`` form.
    """
    total_num = path if isinstance(path, int) else sum(1 for _ in path.glob('*'))

    def n_distinct(key):
        return len({key(record) for record in cell_stats})

    return (
        total_num,
        n_distinct(lambda r: str(r.material)),
        n_distinct(lambda r: get_protocols(r.cp)),
        n_distinct(lambda r: get_protocols(r.dp)),
        n_distinct(lambda r: str(r.T)),
        n_distinct(lambda r: str(r.ff)),
        n_distinct(str),
    )

In [6]:
# See https://calce.umd.edu/battery-data
# Every CALCE group here is an LCO prismatic cell charged at 0.5C, with the
# temperature recorded as 'unknown'; only the discharge rate differs.
def _calce_group(discharge_rate):
    return BatteryStats(
        'LCO', CyclingProtocol(rate_in_C=0.5),
        CyclingProtocol(rate_in_C=discharge_rate),
        'unknown', 'prismatic')

calce_cells = [
    _calce_group(0.5),  # CS2_33, CS2_34
    _calce_group(1),    # CS2_35, CS2_36, CS2_37, CS2_38
    _calce_group(0.5),  # CX2_16, CX2_33, CX2_35, CX2_34, CX2_36, CX2_37, CX2_38
]
calce_stats = get_stats(calce_cells, HOME / 'data/processed/CALCE')

In [7]:
batteries = list(HOME.glob('data/processed/HNEI/*'))
battery_data = [
    BatteryData.load(bat)
    for bat in tqdm(batteries, desc='Loading HNEI cells')
]
# Chemistry and temperature are fixed here rather than read from the data:
# 'NMC_LCO' at 25 (presumably °C).
hnei_cells = [
    BatteryStats('NMC_LCO', cell.charge_protocol, cell.discharge_protocol,
                 25, cell.form_factor)
    for cell in battery_data
]
hnei_stats = get_stats(hnei_cells, HOME / 'data/processed/HNEI')

Loading HNEI cells: 100%|██████████| 14/14 [00:02<00:00,  4.82it/s]


In [8]:
batteries = list(HOME.glob('data/processed/HUST/*'))
battery_data = [
    BatteryData.load(bat)
    for bat in tqdm(batteries, desc='Loading HUST cells')
]
# Chemistry and temperature are fixed here rather than read from the data:
# 'LFP' at 30 (presumably °C).
hust_cells = [
    BatteryStats('LFP', cell.charge_protocol, cell.discharge_protocol,
                 30, cell.form_factor)
    for cell in battery_data
]
hust_stats = get_stats(hust_cells, HOME / 'data/processed/HUST')

Loading HUST cells: 100%|██████████| 77/77 [00:34<00:00,  2.22it/s]


In [9]:
# Load every processed MATR cell once; the subsets below filter this list.
batteries = list(HOME.glob('data/processed/MATR/*'))
battery_data = [
    BatteryData.load(bat)
    for bat in tqdm(batteries, desc='Loading MATR cells')
]
# Cell ids (the part of cell_id after 'MATR_') that make up each MATR subset.
MATR_1_cells = [
    # NOTE(review): the first 41 ids (through 'b2c45') look like the training
    # split and the remaining 43 like the test split, inferred from the 41/42
    # train/test counts Table 3 reports for matr_1. Confirm against the
    # matr_1 splitter config.
    'b1c1',  'b1c3',  'b1c5',  'b1c7',  'b1c11', 'b1c15',
    'b1c17', 'b1c19', 'b1c21', 'b1c24', 'b1c26', 'b1c28',
    'b1c30', 'b1c32', 'b1c34', 'b1c36', 'b1c38', 'b1c40',
    'b1c42', 'b1c44', 'b2c0',  'b2c2',  'b2c4',  'b2c6',
    'b2c11', 'b2c13', 'b2c17', 'b2c19', 'b2c21', 'b2c23',
    'b2c25', 'b2c27', 'b2c29', 'b2c31', 'b2c33', 'b2c35',
    'b2c37', 'b2c39', 'b2c41', 'b2c43', 'b2c45',
    # Remaining 43 ids.
    'b1c0',  'b1c2',  'b1c4',  'b1c6',  'b1c9',  'b1c14',
    'b1c16', 'b1c18', 'b1c20', 'b1c23', 'b1c25', 'b1c27',
    'b1c29', 'b1c31', 'b1c33', 'b1c35', 'b1c37', 'b1c39',
    'b1c41', 'b1c43', 'b1c45', 'b2c1',  'b2c3',  'b2c5',
    'b2c10', 'b2c12', 'b2c14', 'b2c18', 'b2c20', 'b2c22',
    'b2c24', 'b2c26', 'b2c28', 'b2c30', 'b2c32', 'b2c34',
    'b2c36', 'b2c38', 'b2c40', 'b2c42', 'b2c44', 'b2c46',
    'b2c47']
MATR_2_cells = [
    # The first 41 ids are identical to the first 41 of MATR_1_cells.
    # NOTE(review): presumably the shared training split (Table 3 reports
    # 41/40 train/test cells for matr_2). Confirm.
    'b1c1',  'b1c3',  'b1c5',  'b1c7',  'b1c11', 'b1c15',
    'b1c17', 'b1c19', 'b1c21', 'b1c24', 'b1c26', 'b1c28',
    'b1c30', 'b1c32', 'b1c34', 'b1c36', 'b1c38', 'b1c40',
    'b1c42', 'b1c44', 'b2c0',  'b2c2',  'b2c4',  'b2c6',
    'b2c11', 'b2c13', 'b2c17', 'b2c19', 'b2c21', 'b2c23',
    'b2c25', 'b2c27', 'b2c29', 'b2c31', 'b2c33', 'b2c35',
    'b2c37', 'b2c39', 'b2c41', 'b2c43', 'b2c45',
    # Remaining 40 ids, all from batch 3 ('b3').
    'b3c0',  'b3c1',  'b3c3',  'b3c4',  'b3c5',  'b3c6',
    'b3c7',  'b3c8',  'b3c9',  'b3c10', 'b3c11', 'b3c12',
    'b3c13', 'b3c14', 'b3c15', 'b3c16', 'b3c17', 'b3c18',
    'b3c19', 'b3c20', 'b3c21', 'b3c22', 'b3c24', 'b3c25',
    'b3c26', 'b3c27', 'b3c28', 'b3c29', 'b3c30', 'b3c31',
    'b3c33', 'b3c34', 'b3c35', 'b3c36', 'b3c38', 'b3c39',
    'b3c40', 'b3c41', 'b3c44', 'b3c45']
# 178 ids drawn from batches 1-4, unsorted. Note that some MATR_1_cells ids
# (e.g. 'b1c0') do not appear here.
MATR_full_cells = [
    'b4c43', 'b1c7', 'b3c9', 'b2c6', 'b2c21', 'b4c27',
    'b2c13', 'b1c27', 'b3c11', 'b4c39', 'b2c33', 'b4c2',
    'b4c9', 'b1c44', 'b3c7', 'b1c37', 'b3c34', 'b3c27',
    'b2c36', 'b1c34', 'b2c5', 'b2c47', 'b1c35', 'b1c29',
    'b1c38', 'b2c1', 'b4c31', 'b2c43', 'b3c13', 'b4c5',
    'b4c8', 'b4c36', 'b1c13', 'b4c23', 'b4c29', 'b3c44',
    'b3c2', 'b2c22', 'b2c42', 'b3c33', 'b1c4', 'b3c16',
    'b3c24', 'b1c10', 'b4c32', 'b1c8', 'b1c17', 'b1c14',
    'b4c12', 'b3c15', 'b1c28', 'b3c23', 'b3c0', 'b4c40',
    'b3c36', 'b2c30', 'b1c15', 'b1c24', 'b2c0', 'b1c1',
    'b4c25', 'b1c2', 'b2c18', 'b3c28', 'b3c35', 'b1c21',
    'b1c5', 'b4c20', 'b1c9', 'b2c26', 'b4c44', 'b3c18',
    'b2c17', 'b1c33', 'b3c14', 'b4c15', 'b4c35', 'b2c3',
    'b2c12', 'b3c26', 'b4c18', 'b4c17', 'b1c26', 'b3c6',
    'b3c19', 'b1c16', 'b2c23', 'b1c39', 'b4c30', 'b2c4',
    'b1c12', 'b3c12', 'b4c34', 'b4c11', 'b3c31', 'b1c36',
    'b2c35', 'b3c38', 'b2c11', 'b2c38', 'b3c4', 'b4c38',
    'b3c43', 'b2c37', 'b4c4', 'b4c42', 'b4c0', 'b1c32',
    'b3c41', 'b3c22', 'b1c6', 'b3c32', 'b4c19', 'b4c22',
    'b4c33', 'b2c32', 'b1c43', 'b1c20', 'b4c24', 'b2c28',
    'b3c3', 'b1c45', 'b4c7', 'b1c19', 'b2c20', 'b2c31',
    'b4c10', 'b1c41', 'b2c24', 'b2c2', 'b2c10', 'b1c23',
    'b2c44', 'b2c25', 'b2c46', 'b3c42', 'b3c20', 'b1c3',
    'b1c11', 'b4c13', 'b1c22', 'b4c6', 'b3c8', 'b3c30',
    'b4c41', 'b4c14', 'b2c14', 'b2c41', 'b3c29', 'b3c1',
    'b4c28', 'b4c16', 'b2c45', 'b1c25', 'b1c31', 'b2c40',
    'b1c42', 'b4c26', 'b4c1', 'b2c27', 'b4c3', 'b1c30',
    'b4c37', 'b2c34', 'b1c40', 'b3c45', 'b2c19', 'b4c21',
    'b3c10', 'b3c39', 'b2c39', 'b3c21', 'b3c40', 'b3c5',
    'b2c29', 'b1c18', 'b3c25', 'b3c17']
def _matr_subset(keep):
    """Return LFP BatteryStats for loaded MATR cells whose short id passes ``keep``.

    The short id is the part of ``cell_id`` after ``'MATR_'`` (e.g.
    ``'b1c7'``). Chemistry ('LFP') and temperature (30, presumably °C) are
    fixed for every MATR cell rather than read from the data.
    """
    return [
        BatteryStats('LFP', cell.charge_protocol, cell.discharge_protocol,
                     30, cell.form_factor)
        for cell in battery_data
        if keep(cell.cell_id.split('MATR_')[1])
    ]


# Sets give O(1) membership tests instead of scanning the id lists.
_matr_1_ids = set(MATR_1_cells)
_matr_2_ids = set(MATR_2_cells)
_matr_full_ids = set(MATR_full_cells)

matr1_cells = _matr_subset(lambda cid: cid in _matr_1_ids)
matr2_cells = _matr_subset(lambda cid: cid in _matr_2_ids)
# Split by batch: batches 1-3 vs. batch 4 ('b4c*'). Presumably these are the
# Nature Energy (NE) and closed-loop-optimisation (CLO) cells. A prefix test
# replaces the previous substring test ('b4' in id), so the batch tag is only
# matched at the start of the id.
ne_cells = _matr_subset(lambda cid: not cid.startswith('b4'))
clo_cells = _matr_subset(lambda cid: cid.startswith('b4'))
matr_full_cells = _matr_subset(lambda cid: cid in _matr_full_ids)

# MATR-1/2/full report the size of their id list (not the number of cells
# found on disk); NE/CLO report the number of cells actually loaded.
matr1_stats = get_stats(matr1_cells, len(MATR_1_cells))
matr2_stats = get_stats(matr2_cells, len(MATR_2_cells))
ne_stats = get_stats(ne_cells, len(ne_cells))
clo_stats = get_stats(clo_cells, len(clo_cells))
matr_full_stats = get_stats(matr_full_cells, len(MATR_full_cells))

Loading MATR cells: 100%|██████████| 180/180 [01:28<00:00,  2.04it/s]


In [10]:
batteries = list(HOME.glob('data/processed/RWTH/*'))
battery_data = [
    BatteryData.load(bat)
    for bat in tqdm(batteries, desc='Loading RWTH cells')
]
# Chemistry and temperature are fixed here rather than read from the data:
# 'NMC' at 25 (presumably °C).
rwth_cells = [
    BatteryStats('NMC', cell.charge_protocol, cell.discharge_protocol,
                 25, cell.form_factor)
    for cell in battery_data
]
rwth_stats = get_stats(rwth_cells, HOME / 'data/processed/RWTH')

Loading RWTH cells: 100%|██████████| 48/48 [00:28<00:00,  1.70it/s]


In [11]:
batteries = list(HOME.glob('data/processed/UL_PUR/*'))
battery_data = [
    BatteryData.load(bat)
    for bat in tqdm(batteries, desc='Loading UL_PUR cells')
]
# Chemistry and temperature are fixed here rather than read from the data:
# 'NCA' at 25 (presumably °C).
ul_pur_cells = [
    BatteryStats('NCA', cell.charge_protocol, cell.discharge_protocol,
                 25, cell.form_factor)
    for cell in battery_data
]
ul_pur_stats = get_stats(ul_pur_cells, HOME / 'data/processed/UL_PUR')

Loading UL_PUR cells: 100%|██████████| 10/10 [00:00<00:00, 69.35it/s]


In [12]:
def determine_temp(cell):
    """Snap a cell's test temperature to the nearest of 15, 25 or 35.

    Uses the median of the first cycle's ``temperature_in_C`` readings and
    falls back to 25 when that median is NaN. On an exact tie the lower
    candidate wins, because ``min`` keeps the first minimum.

    NOTE(review): ``np.median`` returns NaN if *any* reading is NaN, so a
    single missing sample sends the cell to the 25 fallback. ``np.nanmedian``
    would ignore such gaps; confirm which behaviour the paper's table
    expects before changing it.
    """
    median_temp = np.median(cell.cycle_data[0].temperature_in_C)
    if not np.isnan(median_temp):
        candidates = [15, 25, 35]
        return min(candidates, key=lambda c: abs(c - median_temp))
    return 25

batteries = list(HOME.glob('data/processed/SNL/*'))
battery_data = [
    BatteryData.load(bat)
    for bat in tqdm(batteries, desc='Loading SNL cells')
]
# Unlike the other sources, SNL cells supply their own cathode material, and
# the temperature is estimated from the data via determine_temp.
snl_cells = [
    BatteryStats(cell.cathode_material, cell.charge_protocol,
                 cell.discharge_protocol, determine_temp(cell),
                 cell.form_factor)
    for cell in battery_data
]
snl_stats = get_stats(snl_cells, HOME / 'data/processed/SNL')

Loading SNL cells: 100%|██████████| 61/61 [00:04<00:00, 12.61it/s]


In [13]:
from functools import reduce

# Pool the records of every source to count distinct settings overall.
total_cells = [
    record
    for source in (
        calce_cells,
        hnei_cells,
        hust_cells,
        matr1_cells,
        matr2_cells,
        matr_full_cells,
        rwth_cells,
        snl_cells,
        ul_pur_cells,
    )
    for record in source
]
# Cell total: MATR contributes only matr_full_stats; matr1/matr2/ne/clo are
# not added. NOTE(review): total_cells also includes matr1/matr2 records, and
# some of their ids (e.g. 'b1c0') are absent from MATR_full_cells, so the
# settings count covers cells the total does not. Confirm this is intended.
total_batteries = sum(stats[0] for stats in (
    calce_stats,
    hnei_stats,
    hust_stats,
    matr_full_stats,
    rwth_stats,
    snl_stats,
    ul_pur_stats,
))
total_stats = get_stats(total_cells, total_batteries)

In [14]:
# Table 2: one row per data source (plus the overall total), in display order.
_settings_rows = {
    'calce': calce_stats,
    'hnei': hnei_stats,
    'hust': hust_stats,
    'matr1': matr1_stats,
    'matr2': matr2_stats,
    'ne': ne_stats,
    'clo': clo_stats,
    'matr_full': matr_full_stats,
    'rwth': rwth_stats,
    'snl': snl_stats,
    'ul_pur': ul_pur_stats,
    'total': total_stats,
}
settings_df = pd.DataFrame(
    list(_settings_rows.values()),
    index=list(_settings_rows.keys()),
    columns=[
        '#cells',
        '#materials',
        '#charge_protocols',
        '#discharge_protocols',
        '#temperatures',
        '#form_factors',
        '#different_settings',
    ],
)
settings_df

Unnamed: 0,#cells,#materials,#charge_protocols,#discharge_protocols,#temperatures,#form_factors,#different_settings
calce,13,1,1,2,1,1,2
hnei,14,1,1,1,1,1,1
hust,77,1,1,77,1,1,77
matr1,84,1,61,1,1,1,61
matr2,81,1,47,1,1,1,47
ne,135,1,69,1,1,1,69
clo,45,1,9,1,1,1,9
matr_full,178,1,78,1,1,1,78
rwth,48,1,1,1,1,1,1
snl,61,3,1,4,3,1,22


In [15]:
# Save Table 2 as data/settings.csv under the repository root.
settings_df.to_csv(HOME / 'data' / 'settings.csv')

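Reproduce Table 3: per-dataset cell counts, train/test split sizes, end-of-life threshold, number of early cycles, and cycle-life range, computed after dropping cells whose label is NaN.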

In [16]:
import os

fields = ['train_test_split', 'label', 'feature']

dataset_stats_dict = {}


def _collect_dataset_stats(config_path):
    """Compute one dataset's Table 3 statistics from its config file.

    Builds the splitter, feature extractor and label annotator named in the
    config, loads the train/test cells, drops cells whose label is NaN, and
    returns the counts and cycle-life extrema keyed by column name. Relative
    paths resolve against the current working directory, so call this from
    the repository root.
    """
    config = import_config(config_path, fields)
    splitter = TRAIN_TEST_SPLITTERS.build(config['train_test_split'])
    feature_extractor = FEATURE_EXTRACTORS.build(config['feature'])
    label_annotator = LABEL_ANNOTATORS.build(config['label'])

    train_list, test_list = splitter.split()
    pbar = tqdm(train_list, desc=f'Reading train data of {config_path.stem}')
    train_cells = [BatteryData.load(path) for path in pbar]
    pbar = tqdm(test_list, desc=f'Reading test data of {config_path.stem}')
    test_cells = [BatteryData.load(path) for path in pbar]

    train_features = feature_extractor(train_cells)
    # The test features are not reported below. The call is kept so the
    # extractor still runs on the same cells, in the same order, as before.
    # NOTE(review): drop it if the extractor has no side effects.
    feature_extractor(test_cells)
    train_labels = label_annotator(train_cells)
    test_labels = label_annotator(test_cells)

    # Keep only cells with a valid (non-NaN) label.
    train_mask = ~torch.isnan(train_labels)
    train_features = train_features[train_mask]
    train_labels = train_labels[train_mask]
    test_labels = test_labels[~torch.isnan(test_labels)]

    return {
        '#Cells': len(test_labels) + len(train_labels),
        '#Train Cells': len(train_labels),
        '#Test Cells': len(test_labels),
        'EoL Percentage (%)': config['label'].get('eol_soh', 0.8) * 100,
        # NOTE(review): assumes dimension 2 of the feature tensor indexes the
        # early cycles. Confirm against the feature extractor.
        '#Early Cycles': train_features.size(2),
        'Max Cycle Life': max(test_labels.max(), train_labels.max()),
        'Min Cycle Life': min(test_labels.min(), train_labels.min()),
        'Max Train Cycle Life': train_labels.max(),
        'Min Train Cycle Life': train_labels.min(),
        'Max Test Cycle Life': test_labels.max(),
        'Min Test Cycle Life': test_labels.min(),
    }


# The config glob is relative to the repository root, so work from HOME.
# Restoring the original directory in ``finally`` means a failure part-way
# through (missing file, bad config, Ctrl-C) no longer leaves the kernel in
# HOME, which would break later cells that resolve paths from the notebook
# directory.
_notebook_dir = Path.cwd()
os.chdir(str(HOME))
try:
    configs = Path('configs/ablation/diff_branch/batlinet').glob('*.yaml')
    for config_path in configs:
        dataset_stats_dict[config_path.stem] = _collect_dataset_stats(config_path)
finally:
    os.chdir(_notebook_dir)

Reading train data of matr_1: 100%|██████████| 41/41 [00:13<00:00,  3.11it/s]
Reading test data of matr_1: 100%|██████████| 42/42 [00:16<00:00,  2.56it/s]
Extracting features: 100%|██████████| 41/41 [00:02<00:00, 14.64it/s]
Extracting features: 100%|██████████| 42/42 [00:02<00:00, 14.92it/s]
Reading train data of matr_2: 100%|██████████| 41/41 [00:11<00:00,  3.58it/s]
Reading test data of matr_2: 100%|██████████| 40/40 [00:23<00:00,  1.73it/s]
Extracting features: 100%|██████████| 41/41 [00:02<00:00, 15.02it/s]
Extracting features: 100%|██████████| 40/40 [00:02<00:00, 18.10it/s]
Reading train data of mix_20: 100%|██████████| 256/256 [01:38<00:00,  2.60it/s]
Reading test data of mix_20: 100%|██████████| 147/147 [01:03<00:00,  2.31it/s]
Extracting features: 100%|██████████| 256/256 [00:08<00:00, 31.67it/s]
Extracting features: 100%|██████████| 147/147 [00:04<00:00, 31.85it/s]
Reading train data of mix_100: 100%|██████████| 205/205 [02:03<00:00,  1.66it/s]
Reading test data of mix_100: 10

In [17]:
# Table 3 columns in display order. The train-only and test-only extrema are
# kept in dataset_stats_dict but not shown here.
columns = [
    '#Cells', '#Train Cells', '#Test Cells', 'EoL Percentage (%)',
    '#Early Cycles', 'Max Cycle Life', 'Min Cycle Life',
]
per_dataset_values = {
    name: {col: int(stats[col]) for col in columns}
    for name, stats in dataset_stats_dict.items()
}
dataset_stats_df = pd.DataFrame(per_dataset_values)
# Display one row per dataset. '#Cells' counts cells left after NaN-label
# cells were dropped.
dataset_stats_df.rename(index={'#Cells': '#Cells after cleaning'}).T

Unnamed: 0,#Cells after cleaning,#Train Cells,#Test Cells,EoL Percentage (%),#Early Cycles,Max Cycle Life,Min Cycle Life
matr_1,83,41,42,80,100,2237,300
matr_2,81,41,40,80,100,2160,300
mix_20,354,207,147,90,20,2323,104
mix_100,342,205,137,80,100,2691,148
hust,77,55,22,80,100,2691,1144
