In [1]:
import os
import sys
sys.path.append("..")
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from src.models.vaeconv import *
from src.train.TrainerVAE import Trainer
from src.data.preprocessing.pipeline import Pipeline
from src.data.datasets.universal_dataset import CVADataset
from src.data.preprocessing.splitter import select_test_inh

In [49]:
prep = Pipeline(
    num_cycle=[1, 2, 3, 4], 
    inhibitor_name="all", 
    split="all",
    norm_feat=True
)

df = prep.full_data

train, val = select_test_inh(df, "benzimidazole")

In [50]:
torch.manual_seed(42)
np.random.seed(42)

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [51]:
train_dataset = CVADataset(train)
val_dataset = CVADataset(val)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=True)

In [52]:
vae = VAE_CONV(
    in_channels=1,
    latent_dim=64,
    seq_len=968,
    desc_dim=41
).to(device)

opt_vae = torch.optim.Adam(vae.parameters(), lr=5e-5)

trainer = Trainer(
    model=vae,
    loss_fn=vae_loss,
    epochs=500,
    optimizer=opt_vae,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    path_to_save_plots="/home/ccinfochem/Documents/krotkov/Electrochem_DT/reports/vae_conv/",
    path_to_save_models="/home/ccinfochem/Documents/krotkov/Electrochem_DT/models/vae_conv/best_model.pt",
    path_to_save_tables="/home/ccinfochem/Documents/krotkov/Electrochem_DT/reports/vae_conv/",
    seed=42
)

In [53]:
trainer.train_model()

Epoch 000 — Train: 1411.6191, Val: 1119.4047
Epoch 001 — Train: 800.8097, Val: 1066.9853
Epoch 002 — Train: 603.4816, Val: 930.0597
Epoch 003 — Train: 561.5359, Val: 710.1160
Epoch 004 — Train: 545.6323, Val: 614.9598
Epoch 005 — Train: 536.3527, Val: 605.1912
Epoch 006 — Train: 528.9720, Val: 601.7369
Epoch 007 — Train: 522.2543, Val: 597.7447
Epoch 008 — Train: 515.9712, Val: 597.9720
Epoch 009 — Train: 509.8378, Val: 582.8551
Epoch 010 — Train: 503.8189, Val: 580.5962
Epoch 011 — Train: 497.7583, Val: 570.5227
Epoch 012 — Train: 491.6927, Val: 562.1925
Epoch 013 — Train: 485.5428, Val: 559.9388
Epoch 014 — Train: 479.2911, Val: 550.8828
Epoch 015 — Train: 472.9406, Val: 543.3741
Epoch 016 — Train: 466.4266, Val: 539.5500
Epoch 017 — Train: 459.7910, Val: 527.8484
Epoch 018 — Train: 452.8947, Val: 515.3498
Epoch 019 — Train: 445.8513, Val: 509.0988
Epoch 020 — Train: 438.3106, Val: 495.8682
Epoch 021 — Train: 430.5961, Val: 489.9697
Epoch 022 — Train: 422.2937, Val: 477.9558
Epoch 02

In [54]:
full_ds = CVADataset(val)
infer_loader = DataLoader(full_ds, batch_size=1, shuffle=False)

vah_preds = []
embeddings = []
vae.eval()
with torch.inference_mode():
    for batch in infer_loader:
        vah, desc = batch["vah"].to(device), batch["features"].to(device)
        vah_hat, mu, logvar = vae(vah, desc)
        embed = vae.reparameterize(mu, logvar)
        
        embeddings.append(embed.squeeze().cpu().numpy())
        vah_preds.append(vah_hat.squeeze().cpu().numpy())

# объединяем и превращаем в DataFrame
vah_preds = pd.DataFrame(np.vstack(vah_preds))     # shape [N, 968]
embeddings = pd.DataFrame(np.vstack(embeddings))

In [55]:
embeddings["ppm"] = df["ppm"]
embeddings["Inhibitor"] = df["Inhibitor"]

In [56]:
gen_vah = np.array(vah_preds, dtype=np.float64)
np_val = np.array(val.reset_index(drop=True).drop(columns=prep.desc.columns).drop(columns="ppm"), dtype=np.float64)

In [57]:
from dtaidistance import dtw

def compute_dtw(original, reconstructed):
    """
    Dynamic Time Warping расстояние между временными рядами.
    """
    return dtw.distance_fast(original, reconstructed, use_c = True)

In [58]:
from tqdm import tqdm

dtw_m = []

for i in tqdm(range(len(gen_vah))):
    dtw_m.append(compute_dtw(gen_vah[i], np_val[i]))

100%|██████████| 940/940 [00:03<00:00, 285.44it/s]


In [59]:
metr = pd.DataFrame({
    "Inhibitor": val.reset_index(drop=True)["Inhibitor"],
    "metrics":dtw_m})

In [60]:
metr['metrics'].mean()

0.8731441625740993

In [61]:
metr.to_csv("vae_conv_benzimidazole.csv", index=False)

In [None]:
n = 7

plt.figure(figsize=(10, 5))
plt.plot(np_val[n], label="Original")
plt.plot(gen_vah[n], label="Reconstructed", linestyle='--', color='red')
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.xlabel("Time, s", fontsize=16, fontweight='bold', labelpad=20)
plt.ylabel("Current, A", fontsize=16, fontweight='bold', labelpad=20)
plt.legend(fontsize=14, )
plt.show()