In [5]:
import os
import sys
sys.path.append("..")  # add the parent of the working directory so the project's `src` package resolves
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from dtaidistance import dtw
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from src.models.vaeconv import *
from src.train.TrainerVAE import Trainer
from src.data.preprocessing.pipeline import Pipeline
from src.data.datasets.universal_dataset import CVADataset
from src.data.preprocessing.splitter import select_test_inh
from src.utils.paths import get_project_path

In [2]:
# Build the preprocessing pipeline for cycles 1-4 of every inhibitor with
# feature normalisation enabled, then split off the named inhibitor as the
# validation set.
# NOTE(review): if `norm_feat=True` fits its scaling on `full_data` (i.e. before
# the split), statistics of the held-out inhibitor leak into the training
# features — confirm how Pipeline normalises.
HELD_OUT_INHIBITOR = "2-mercaptobenzimidazole"

prep = Pipeline(num_cycle=list(range(1, 5)), inhibitor_name="all", split="all", norm_feat=True)
df = prep.full_data
train, val = select_test_inh(df, HELD_OUT_INHIBITOR)

In [3]:
# Fix the global torch and NumPy RNGs so randomness drawn from them repeats
# across runs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cpu'

In [4]:
# Wrap the train / held-out splits as datasets and batch them for the Trainer.
BATCH_SIZE = 256

train_dataset = CVADataset(train)
val_dataset = CVADataset(val)

# Shuffle only the training stream. Validation order does not change a loss that
# is (presumably) averaged over the set. A shuffling DataLoader without an
# explicit generator draws a seed from the global torch RNG every epoch, which
# would couple the training shuffle sequence to the validation pass.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
# Convolutional VAE; its forward pass takes the curve and the feature vector.
# NOTE(review): seq_len=968 and desc_dim=41 are hard-coded — confirm they match
# the curve / feature columns produced by the pipeline above.
vae = VAE_CONV(in_channels=1, latent_dim=64, seq_len=968, desc_dim=41).to(device)
opt_vae = torch.optim.Adam(vae.parameters(), lr=5e-5)

# All artefacts of this experiment live under one project-relative sub-path.
project_root = get_project_path()
reports_dir = os.path.join(project_root, "reports", "vae_conv", "2_mercaptobenzimidazole")
checkpoint_path = os.path.join(project_root, "models", "vae_conv", "2_mercaptobenzimidazole", "best_model.pt")

trainer = Trainer(
    model=vae,
    loss_fn=vae_loss,
    epochs=500,
    optimizer=opt_vae,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    path_to_save_plots=reports_dir,
    path_to_save_models=checkpoint_path,
    path_to_save_tables=reports_dir,
    seed=42,
)

In [9]:
# Train for the configured number of epochs.
# NOTE(review): the recorded run below was stopped manually (KeyboardInterrupt
# during epoch 23). The in-memory `vae` used by the following cells therefore
# holds the weights from the interruption point. These are not necessarily the
# best checkpoint the Trainer may have written to `path_to_save_models`; confirm
# before reporting metrics.
trainer.train_model()

Epoch 000 — Train: 1480.0884, Val: 727.0574
Epoch 001 — Train: 746.9249, Val: 691.7502
Epoch 002 — Train: 441.8069, Val: 566.4323
Epoch 003 — Train: 325.1297, Val: 373.3966
Epoch 004 — Train: 273.3269, Val: 269.0974
Epoch 005 — Train: 247.7955, Val: 237.9243
Epoch 006 — Train: 233.9219, Val: 227.9016
Epoch 007 — Train: 226.1829, Val: 219.1330
Epoch 008 — Train: 221.4119, Val: 214.4221
Epoch 009 — Train: 217.6687, Val: 209.9887
Epoch 010 — Train: 214.6186, Val: 207.4826
Epoch 011 — Train: 211.6667, Val: 205.4667
Epoch 012 — Train: 209.2100, Val: 199.7524
Epoch 013 — Train: 205.7629, Val: 196.8253
Epoch 014 — Train: 202.9560, Val: 193.8811
Epoch 015 — Train: 200.0998, Val: 190.4793
Epoch 016 — Train: 197.0527, Val: 187.7634
Epoch 017 — Train: 194.6665, Val: 183.7951
Epoch 018 — Train: 191.0149, Val: 179.4144
Epoch 019 — Train: 188.0360, Val: 177.3799
Epoch 020 — Train: 184.1075, Val: 171.7623
Epoch 021 — Train: 180.7029, Val: 168.0047
Epoch 022 — Train: 175.9408, Val: 166.5780
Epoch 023 

KeyboardInterrupt: 

In [54]:
# Run the trained VAE over the validation set in a fixed order (shuffle=False),
# so row i of the outputs lines up with row i of `val`.
full_ds = CVADataset(val)  # built from `val`, despite the name
infer_loader = DataLoader(full_ds, batch_size=256, shuffle=False)

recon_batches = []
latent_batches = []
vae.eval()
with torch.inference_mode():
    for batch in infer_loader:
        vah, desc = batch["vah"].to(device), batch["features"].to(device)
        vah_hat, mu, logvar = vae(vah, desc)
        # Use the posterior mean as the embedding. Sampling via reparameterize()
        # adds noise, so the embeddings would change from run to run.
        # Flatten each sample explicitly: squeeze() behaves differently when a
        # batch happens to contain a single item.
        batch_len = vah_hat.shape[0]
        latent_batches.append(mu.reshape(batch_len, -1).cpu().numpy())
        recon_batches.append(vah_hat.reshape(batch_len, -1).cpu().numpy())

# Stack into one row per validation curve and wrap as DataFrames.
vah_preds = pd.DataFrame(np.concatenate(recon_batches))  # expected [N, seq_len] — e.g. [N, 968]
embeddings = pd.DataFrame(np.concatenate(latent_batches))  # [N, latent_dim]

In [55]:
# Attach the labels of the *validation* rows the embeddings were computed from.
# `embeddings` has a 0..N-1 index in the same order as `val`. Assigning
# `df[...]` would align on the full dataset's index and pick unrelated rows.
val_labels = val.reset_index(drop=True)
assert len(val_labels) == len(embeddings), "embeddings must have one row per validation curve"
embeddings["ppm"] = val_labels["ppm"].to_numpy()
embeddings["Inhibitor"] = val_labels["Inhibitor"].to_numpy()

In [56]:
# Reconstructions and originals as float64 matrices, one curve per row.
# The originals are the validation rows without feature columns and "ppm".
gen_vah = np.array(vah_preds, dtype=np.float64)
val_curves = val.reset_index(drop=True).drop(columns=prep.desc.columns).drop(columns="ppm")
np_val = np.array(val_curves, dtype=np.float64)

# The DTW step pairs gen_vah[i] with np_val[i]; a row-count mismatch would mean
# reconstructions and originals are misaligned.
assert len(gen_vah) == len(np_val), (gen_vah.shape, np_val.shape)

In [57]:
def compute_dtw(original, reconstructed):
    """Dynamic Time Warping distance between two 1-D time series.

    Parameters
    ----------
    original, reconstructed : array-like
        The two series to compare (DTW does not require equal length).

    Returns
    -------
    float
        The distance reported by ``dtaidistance.dtw.distance_fast``.
    """
    # distance_fast hands the buffers to the C implementation, which expects
    # C-contiguous float64 ("double") arrays. Normalise the inputs so row views,
    # lists or float32 arrays are accepted as well.
    original = np.ascontiguousarray(original, dtype=np.float64)
    reconstructed = np.ascontiguousarray(reconstructed, dtype=np.float64)
    return dtw.distance_fast(original, reconstructed, use_c=True)

In [58]:
# DTW distance between each original validation curve and its reconstruction.
# Arguments are passed in compute_dtw's (original, reconstructed) order. The
# previous call had them swapped; DTW is symmetric, so the values are unchanged.
assert len(gen_vah) == len(np_val), "reconstructions and originals must align row-for-row"
dtw_m = [
    compute_dtw(original, reconstructed)
    for original, reconstructed in tqdm(zip(np_val, gen_vah), total=len(gen_vah), desc="DTW")
]

100%|██████████| 940/940 [00:03<00:00, 285.44it/s]


In [59]:
# One row per validation curve: its inhibitor label and its DTW distance.
metr = val.reset_index(drop=True)[["Inhibitor"]].assign(metrics=dtw_m)

In [60]:
# Mean DTW distance between original and reconstructed validation curves
# (a distance, so lower means closer reconstructions).
metr['metrics'].mean()

0.8731441625740993

In [61]:
# Persist the per-curve DTW table (columns Inhibitor, metrics) without the index.
# NOTE(review): this is a relative path, so the CSV lands in the kernel's working
# directory. The Trainer outputs go under get_project_path()/reports instead;
# confirm this is intended.
metr.to_csv(path_or_buf="vae_conv_benzimidazole.csv", index=False)

In [None]:
# Overlay one original validation curve with its VAE reconstruction.
# NOTE(review): the x-axis is the sample index within the curve. It keeps the
# original "Time, s" label, which is exact only if samples are 1 s apart — confirm.
n = 7  # row of the validation set to inspect

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(np_val[n], label="Original")
ax.plot(gen_vah[n], label="Reconstructed", linestyle="--", color="red")
ax.set_title(f"Validation curve {n}: DTW = {dtw_m[n]:.3f}", fontsize=16)
ax.tick_params(labelsize=16)
ax.set_xlabel("Time, s", fontsize=16, fontweight="bold", labelpad=20)
ax.set_ylabel("Current, A", fontsize=16, fontweight="bold", labelpad=20)
ax.legend(fontsize=14)
plt.show()