In [1]:
%matplotlib widget
%matplotlib widget
import os
from pathlib import Path
import time
import torch
import numpy as np
import math
import gc
from functools import partial
from dataset_test import Dataset, load_dataframes_from_folder, reverse_normalization
from torch.utils.data import DataLoader
from transformer_zerostep import GPTConfig, GPT, warmup_cosine_lr
import argparse
import warnings
import matplotlib.pyplot as plt
import glob
from itertools import compress

# Publication-style figure defaults, applied in a single validated update:
# Type 42 (TrueType) fonts embedded in PDF/PS output, Times New Roman text with
# Computer Modern math, larger axis labels, gridded axes and no horizontal
# padding between the data and the x-axis limits.
plt.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': "Times New Roman",
    'mathtext.fontset': "cm",
    'axes.labelsize': 14,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'axes.grid': True,
    'axes.xmargin': 0,
})

In [2]:
# Overall settings
out_dir = "out"
batch_size = 1024
checkpoint_tag = "10pct"       # evaluate only checkpoints whose file name contains this tag
excluded_tag = "loss_check"    # ...and skip the auxiliary "loss_check" checkpoints


def find_checkpoints(directory, include, exclude, pattern="*"):
    """Return the paths in `directory` whose file name contains `include` but not `exclude`.

    The glob pattern is joined with ``os.path.join`` (instead of a hard-coded
    Windows backslash) so discovery works on every OS, and the result is sorted
    because ``glob.glob`` returns entries in arbitrary, filesystem-dependent order.

    Parameters
    ----------
    directory : str or os.PathLike
        Folder to search.
    include : str
        Substring that must appear in the file name.
    exclude : str
        Substring that must NOT appear in the file name.
    pattern : str, optional
        Glob pattern applied inside `directory` (default: every entry).

    Returns
    -------
    list of str
        Matching paths, sorted lexicographically.
    """
    candidates = sorted(glob.glob(os.path.join(directory, pattern)))
    return [
        path for path in candidates
        if include in os.path.basename(path) and exclude not in os.path.basename(path)
    ]


checkpoint_list = find_checkpoints(out_dir, checkpoint_tag, excluded_tag)
print(checkpoint_list)
# One average-RMSE slot per checkpoint, filled by the evaluation cell.
rmse_list = np.zeros(len(checkpoint_list))

['out\\ckpt_zerostep_sim_matlab_10pct_real_val_alt_h10.pt', 'out\\ckpt_zerostep_sim_matlab_10pct_real_val_alt_h20.pt', 'out\\ckpt_zerostep_sim_matlab_10pct_real_val_alt_h50.pt']


In [3]:
# Evaluate every selected checkpoint on the same set of full-length experiments
# and store each checkpoint's average per-experiment RMSE in `rmse_list`.

# Compute settings -- loop-invariant, so they are applied once, not per checkpoint.
cuda_device = "cuda:0"
no_cuda = False
threads = 10

# Evaluation data. Other folders used during development:
#   '../data/CL_experiments/test/inertia07_ki-0.0061-kp-11.8427'
#   '../data/CL_experiments/test/inertia04_ki-0.0061-kp-11.8427'
#   '../../../in-context-bldc/data/simulated/10_percent'
#   '../../../in-context-bldc/data/simulated/50_percent_longer_steps'
folder_path = '../data/CL_experiments/train/inertia13_ki-0.0061-kp-11.8427'
dataset_seq_len = 1000  # seq_len given to Dataset; only full experiments are used below

torch.set_num_threads(threads)
use_cuda = not no_cuda and torch.cuda.is_available()
device_name = cuda_device if use_cuda else "cpu"
device = torch.device(device_name)
device_type = 'cuda' if 'cuda' in device_name else 'cpu'  # for later use in torch.autocast
torch.set_float32_matmul_precision("high")


def strip_state_dict_prefixes(state_dict, prefixes=("_orig_mod.", "module.")):
    """Remove wrapper prefixes from state-dict keys, in place.

    These prefixes are typically added by ``torch.compile`` / (Distributed)DataParallel
    wrappers. Only the first matching prefix is removed from each key.
    Returns the same (mutated) dict for convenience.
    """
    for key in list(state_dict):
        for prefix in prefixes:
            if key.startswith(prefix):
                state_dict[key[len(prefix):]] = state_dict.pop(key)
                break
    return state_dict


def load_checkpoint_model(path, device):
    """Load checkpoint `path` and rebuild its GPT model on `device`, in eval mode.

    Returns ``(exp_data, model)``, where ``exp_data`` is the raw checkpoint dict
    (keys used here: "model_args", "model", "cfg").
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects (presumably
    # required for the stored "cfg" object) -- only load checkpoints you trust.
    exp_data = torch.load(Path(path), map_location=device, weights_only=False)
    model = GPT(GPTConfig(**exp_data["model_args"])).to(device)
    model.load_state_dict(strip_state_dict_prefixes(exp_data["model"]))
    # A freshly constructed nn.Module is in training mode; without eval() any
    # dropout layers would stay active and the evaluation would be stochastic.
    model.eval()
    return exp_data, model


def stack_experiments(dataset, n_experiments, device):
    """Stack the first `n_experiments` full experiments of `dataset` along a new leading axis.

    Returns ``(u_all, y_all)`` on `device`. ``torch.stack`` requires every
    experiment to have the same length.

    Raises
    ------
    ValueError
        If there is no experiment to stack.
    """
    if n_experiments == 0:
        raise ValueError("no experiments found to evaluate")
    pairs = [dataset.get_full_experiment(i) for i in range(n_experiments)]
    u_all = torch.stack([u.to(device) for u, _ in pairs], dim=0)
    y_all = torch.stack([y.to(device) for _, y in pairs], dim=0)
    return u_all, y_all


@torch.no_grad()
def predict_sliding_window(model, u, y_ref, horizon):
    """One-step-ahead prediction with a context of at most `horizon` samples.

    The prediction for step j is the last output of ``model`` fed with
    ``u[:, max(0, j - horizon + 1):j + 1, :]``. `y_ref` only supplies the output
    shape/dtype/device and the number of steps.

    Raises
    ------
    ValueError
        If `y_ref` has more than one channel: only channel 0 is predicted, so the
        remaining channels of the output buffer would be left uninitialized.
    """
    if y_ref.shape[-1] != 1:
        raise ValueError(f"expected a single output channel, got shape {tuple(y_ref.shape)}")
    y_pred = torch.empty_like(y_ref)
    for j in range(y_ref.shape[1]):
        start = max(0, j - horizon + 1)
        y_pred[:, j, 0] = model(u[:, start:j + 1, :])[:, -1, 0]
    return y_pred


def rmse_per_experiment(y_true, y_pred):
    """RMSE of each experiment (leading axis) over all remaining axes, as a float64 NumPy array."""
    sq_err = (y_true.double() - y_pred.double()) ** 2
    return sq_err.mean(dim=tuple(range(1, sq_err.dim()))).sqrt().cpu().numpy()


# The evaluation data does not depend on the checkpoint: load and stack it once.
dfs = load_dataframes_from_folder(folder_path)
dataset_exp = Dataset(dfs=dfs, seq_len=dataset_seq_len)
df_len = len(dfs)
u_norm_all, y_norm_all = stack_experiments(dataset_exp, df_len, device)

for model_idx, model_name in enumerate(checkpoint_list):
    exp_data, model = load_checkpoint_model(model_name, device)
    H = exp_data["cfg"].seq_len  # context length stored in the checkpoint config

    y_pred_norm_all = predict_sliding_window(model, u_norm_all, y_norm_all, H)
    with torch.no_grad():
        # Clones protect the shared normalized tensors in case
        # reverse_normalization operates in place (not visible from here).
        u_full_all, y_full_all, y_pred_all = reverse_normalization(
            u_norm_all.clone(), y_norm_all.clone(), y_pred_norm_all
        )

    rmse = rmse_per_experiment(y_full_all, y_pred_all)
    print("Average rmse: ", rmse.mean())
    rmse_list[model_idx] = rmse.mean()


    

number of parameters: 0.03M
Average rmse:  182.65465126037597
number of parameters: 0.03M
Average rmse:  169.3136015319824
number of parameters: 0.03M
Average rmse:  157.4822689819336


In [4]:
# Rank checkpoints by average RMSE (best first). rmse_list is NOT sorted in
# place: it must stay aligned with checkpoint_list so this cell gives the same
# ranking every time it is re-run.
top_k = 5  # number of best checkpoints to display

print(rmse_list)
print(checkpoint_list)

sorted_idxs = np.argsort(rmse_list)
print(sorted_idxs)
sorted_checkpoints = np.array(checkpoint_list)[sorted_idxs]
print(sorted_checkpoints[:top_k])
sorted_rmse = rmse_list[sorted_idxs]
print(sorted_rmse[:top_k])


##### 50pct
# 'out\\ckpt_zerostep_sim_matlab_50pct_mix_real_val_noise_h50.pt' 134.20639965
# 'out\\ckpt_zerostep_sim_matlab_50pct_real_val_alt_h50_lr0_5.pt' 139.40439419
# 'out\\ckpt_zerostep_sim_matlab_50pct_real_val_alt_h50.pt' 139.72465393
# 'out\\ckpt_zerostep_sim_matlab_50pct_real_val_alt_h10_lr0_1_half_param.pt' 142.50420593
# 'out\\ckpt_zerostep_sim_matlab_50pct_real_val_alt_h50_lr0_01.pt' 143.06018398







[182.65465126 169.31360153 157.48226898]
['out\\ckpt_zerostep_sim_matlab_10pct_real_val_alt_h10.pt', 'out\\ckpt_zerostep_sim_matlab_10pct_real_val_alt_h20.pt', 'out\\ckpt_zerostep_sim_matlab_10pct_real_val_alt_h50.pt']
[2 1 0]
['out\\ckpt_zerostep_sim_matlab_10pct_real_val_alt_h50.pt'
 'out\\ckpt_zerostep_sim_matlab_10pct_real_val_alt_h20.pt'
 'out\\ckpt_zerostep_sim_matlab_10pct_real_val_alt_h10.pt']
[157.48226898 169.31360153 182.65465126]
