In [None]:

import matplotlib.pyplot as plt
import math
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score,mean_absolute_percentage_error
import numpy as np
import os
import pickle
import yaml
import torch
import matplotlib.pyplot as plt
import random
from trainer import Trainer
from dataloader import load_dataset
from dataset import *
from utils import dict2obj
# Seed every RNG this notebook touches (NumPy, PyTorch, Python's `random`)
# with one shared value so re-runs are reproducible.
fix_seed = 2024
np.random.seed(fix_seed)
torch.manual_seed(fix_seed)
random.seed(fix_seed)



In [None]:

# Load the experiment configuration and convert it to an attribute-accessible object.
config_file = 'config/calce.yaml'
with open(config_file, encoding='utf-8') as cfg_fh:
    # NOTE(review): FullLoader is used; yaml.safe_load would be enough unless the
    # config relies on python/* tags — confirm before switching.
    cfg = dict2obj(yaml.load(cfg_fh, Loader=yaml.FullLoader))

# Checkpoint path is derived from the main hyper-parameters of this run.
cfg.state_dict_path = (
    'checkpoints/'
    f'batch_{cfg.batch_size}_lr_{cfg.lr}_epochs_{cfg.n_episodes}'
    f'_{cfg.size[0]}_{cfg.size[2]}_{cfg.dataset_name}'
)

# cfg.batch_size is passed for all three batch-size arguments
# (presumably train / val / test — confirm against load_dataset).
train_loader, val_loader, test_loader, scaler = load_dataset(
    cfg.dataset_dir, cfg.size, cfg.test, cfg.val,
    cfg.batch_size, cfg.batch_size, cfg.batch_size,
)

trainer = Trainer(cfg)
trainer.train(train_loader, val_loader)

In [None]:
# Evaluation settings for the held-out battery/batteries.
pred_keys = ['CS2_35']                # keys into the feature dictionary to evaluate
dataset_dir = 'data/CALCEDataset'     # directory containing feature.pkl
PRED_LEN = cfg.size[2]                # prediction length, taken from the model config
# stride: step between consecutive window starts.
# STD: divisor applied to both predictions and ground truth before scoring/plotting
#      (presumably a nominal-capacity normaliser giving SOH — confirm).
stride, STD = 1, 2.0

# Fresh trainer in test mode (presumably restores the checkpoint at
# cfg.state_dict_path — confirm in Trainer).
trainer = Trainer(cfg, test=True)

In [None]:
# Load the raw per-battery feature dictionary. The `with` statement closes the
# file on exit, so no explicit close() is needed.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open(os.path.join(dataset_dir, 'feature.pkl'), 'rb') as f:
    data_raw = pickle.load(f)

In [None]:
# Run the trained model over every evaluation battery and collect the
# de-normalised predictions, the matching ground truth and the last input
# step of each window.
seqs, preds, truth = [], [], []

for key in pred_keys:
    item = scaler.normalize(data_raw[key])

    # Window start positions every `stride` cycles. They begin at cfg.size[0]
    # (presumably the look-back length) and stop PRED_LEN cycles before the end.
    indexs = np.arange(cfg.size[0], len(item[0]) - PRED_LEN, stride)
    seq, pred = get_patchs_from_item(indexs, item, cfg.size[0], PRED_LEN)

    samples = torch.Tensor(seq).to(trainer.device)
    out, adj_matrix = trainer.predict(samples)

    # Undo the scaling, then apply the same STD divisor to model output and truth.
    out = scaler.inverse_transform(out)
    preds.append(out.cpu().squeeze().numpy() / STD)
    pred = scaler.inverse_transform(pred)
    truth.append(pred.squeeze() / STD)

    seqs.append(seq[:, -1])
# Per-battery regression metrics plus their mean across all evaluated batteries.
l = len(pred_keys)
mse = [mean_squared_error(truth[k], preds[k]) for k in range(l)]
mae = [mean_absolute_error(truth[k], preds[k]) for k in range(l)]
rmse = [math.sqrt(err) for err in mse]

print(f'mse: {mse},{np.mean(mse)}\nrmse: {rmse},{np.mean(rmse)}\nmae: {mae},{np.mean(mae)}\n')

In [None]:
# Plot ground truth (red) against the model's rolling predictions (blue, dashed),
# one panel per evaluated battery.
cols = len(pred_keys)
# squeeze=False always returns a 2-D array of Axes. Without it, a single battery
# (cols == 1, as with pred_keys = ['CS2_35']) yields a bare Axes and `axes[i]` fails.
fig, axes = plt.subplots(nrows=1, ncols=cols, figsize=(5 * cols, 4), dpi=100, squeeze=False)
axes = axes.ravel()
for i in range(cols):
    key = pred_keys[i]
    cell = data_raw[key][-1]  # last row of the raw feature matrix, drawn as SOH after / STD
    start = cfg.size[0]
    end = len(cell) - PRED_LEN
    axes[i].plot(np.linspace(1, len(cell), len(cell)), cell / STD,
                 color=(1, 0, 0), label='True')

    # Use the same grid that built `indexs` in the prediction cell, so every
    # prediction sits at its window start. A linspace that includes `end`
    # would stretch the spacing.
    # NOTE(review): ground truth uses 1-based cycle numbers; confirm whether the
    # prediction positions also need a +1 offset.
    pred_x = np.arange(start, end, stride)
    # The shade only depends on the first window start, which equals `start`.
    # Computing it here avoids relying on `indexs` left over from an earlier cell.
    axes[i].plot(pred_x, preds[i],
                 color=(0, 0, 1 - start * 0.0008), linestyle='--', label='Prediction')

    axes[i].set_title(f'Battery {key}')
    axes[i].set_xlabel('Number of Cycles')
    axes[i].set_ylabel('SOH')
    axes[i].legend()

fig.tight_layout()
plt.show()