In [None]:
import os
import sys
sys.path.append("../..")
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from dtaidistance import dtw
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from src.models.ganconv import *
from src.train.trainer.TrainerGAN import Trainer
from src.data.preprocessing.pipeline import Pipeline
from src.data.datasets.universal_dataset import CVADataset
from src.data.preprocessing.splitter import select_test_inh
from src.utils.paths import get_project_path

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
INHIBITOR_NAME = "2-mercaptobenzimidazole"

# Preprocess the whole dataset: every inhibitor ("all"), every split ("all"),
# cycles 1..4, with norm_feat enabled.
prep = Pipeline(
    num_cycle=list(range(1, 5)),
    inhibitor_name="all",
    split="all",
    norm_feat=True,
)
data = prep.full_data

# Split by inhibitor: `val` presumably holds the samples of INHIBITOR_NAME
# (the inhibitor the generator checkpoint below was trained for) — TODO confirm
# against select_test_inh.
train, val = select_test_inh(data, INHIBITOR_NAME)

In [None]:
# Rebuild the generator with the architecture hyper-parameters used for the
# checkpoint and place it on the selected device.
generator = Generator(in_dim=80, vah_dim=968, desc_dim=41).to(device)

# Per-inhibitor checkpoint of the best generator weights.
checkpoint_path = os.path.join(
    get_project_path(),
    "models",
    "gan_conv",
    INHIBITOR_NAME,
    "best_generator_model.pt",
)

# The tensors are deserialised onto the CPU; load_state_dict copies them into
# the parameters that already live on `device`.
generator.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

In [None]:
full_ds = CVADataset(val)
# batch_size=1 and shuffle=False keep the output rows aligned with the rows of `val`.
infer_loader = DataLoader(full_ds, batch_size=1, shuffle=False)

# Size of the noise vector fed to the generator. 80 matches `in_dim=80` used when
# the Generator was built (assumption: in_dim is the noise dimension — confirm
# against src.models.ganconv).
LATENT_DIM = 80
# Seed for the noise so that re-running this cell reproduces the same curves/metrics.
SEED = 42

gen_vah, real_cva = [], []

torch.manual_seed(SEED)

generator.eval()
with torch.no_grad():
    for val_batch in infer_loader:
        real_features = val_batch['features'].to(device)
        real_vah = val_batch['vah'].to(device)

        # Draw the noise on the CPU and then move it. The sequence of noise vectors
        # then depends only on SEED, not on whether `device` is CPU or CUDA.
        z = torch.randn(1, LATENT_DIM).to(device)

        gen = generator(z.unsqueeze(1), real_features.unsqueeze(1))
        # No .detach() needed: nothing is tracked under torch.no_grad().
        gen_vah.append(gen.cpu().numpy()[0])
        real_cva.append(real_vah.cpu().numpy()[0])

# Inhibitor labels in loader order (row i of `val` -> row i of the frames below).
inhibitors = val.reset_index(drop=True)["Inhibitor"]

df_vah = pd.DataFrame(gen_vah)
df_vah["Inhibitor"] = inhibitors

real_df_vah = pd.DataFrame(real_cva)
real_df_vah["Inhibitor"] = inhibitors

In [None]:
# Inspect the last real CVA curve collected by the inference loop. This reads the
# stored list instead of the loop variable `real_vah`, which only exists because
# it leaked out of the loop.
real_cva[-1]

In [None]:
def compute_dtw(original, reconstructed):
    """
    Dynamic Time Warping distance between two 1-D time series.

    Both inputs are first converted to C-contiguous float64 arrays. ``dtw.distance_fast``
    only accepts double-precision sequences, so integer or float32 inputs would
    otherwise raise instead of being compared.

    Parameters
    ----------
    original : array-like
        Reference series (e.g. a real CVA curve).
    reconstructed : array-like
        Series to compare against ``original`` (e.g. a generated curve).

    Returns
    -------
    float
        The DTW distance computed by ``dtaidistance.dtw.distance_fast`` with the C backend.
    """
    original = np.ascontiguousarray(original, dtype=np.float64)
    reconstructed = np.ascontiguousarray(reconstructed, dtype=np.float64)
    return dtw.distance_fast(original, reconstructed, use_c=True)

In [None]:
# Numeric curve matrices (one row per validation sample) without the label column,
# as float64 so they can be passed to the DTW routine.
np_vah = df_vah.drop(columns=["Inhibitor"]).to_numpy(dtype=np.float64)
np_valid = real_df_vah.drop(columns=["Inhibitor"]).to_numpy(dtype=np.float64)

In [None]:
# DTW distance between each generated curve and its real counterpart, row by row.
dtw_m = [
    compute_dtw(generated_curve, real_curve)
    for generated_curve, real_curve in tqdm(zip(np_vah, np_valid), total=len(df_vah))
]

In [None]:
# Per-sample DTW metric alongside the inhibitor each sample belongs to.
metr = val.reset_index(drop=True)[["Inhibitor"]].assign(metrics=dtw_m)

In [None]:
n = 15  # row (in `val` order) of the validation sample to inspect and export

assert -len(np_valid) <= n < len(np_valid), (
    f"n={n} is out of range for {len(np_valid)} validation samples"
)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(np_valid[n], label="Original")
ax.plot(np_vah[n], label="Reconstructed", linestyle='--', color='red')
ax.tick_params(axis="both", labelsize=16)
# NOTE(review): the x values are sample indices of the curve, not seconds; the
# "Time, s" label assumes a 1 s step between points — confirm the sampling interval.
ax.set_xlabel("Time, s", fontsize=16, fontweight='bold', labelpad=20)
ax.set_ylabel("Current, A", fontsize=16, fontweight='bold', labelpad=20)
# Title makes the figure self-describing: which inhibitor/sample and its DTW score.
ax.set_title(
    f"{metr['Inhibitor'].iloc[n]}, sample {n}, DTW = {metr['metrics'].iloc[n]:.4g}",
    fontsize=14,
)
ax.legend(fontsize=14)
plt.show()

In [None]:
# Export the inspected pair of curves (real vs generated) for sample `n`.
pd.DataFrame(dict(Orig=np_valid[n], Reconstructed=np_vah[n])).to_excel(
    "CVA_gan.xlsx", index=False
)