In [1]:
import os

import matplotlib.pyplot
import matplotlib.pyplot as plt
import numpy
import numpy as np
import pandas
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import umap
from matplotlib.colors import LogNorm
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import (adjusted_rand_score, completeness_score,
                             homogeneity_score, mean_absolute_error,
                             normalized_mutual_info_score, r2_score,
                             silhouette_score)
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch.multiprocessing as mp

mp.set_start_method("spawn", force=True)

In [2]:
%load_ext autoreload
%autoreload 2

from utils import *
from VAE_GDSC import GDSCDataset, GeneDrugVAE

In [3]:
cellline = pd.read_csv(
    "../dataset/gdsc/cellline_info.csv.gz",
    usecols=["COSMIC_ID", "SMILES", "NAME", "Z_score"],
)
exp = pd.read_csv("../dataset/gdsc/exp.csv.gz", index_col=0)

In [4]:
# 学習済みモデルの指定（ChemBERTa）
MODEL_NAME = "seyonec/ChemBERTa-zinc-base-v1"

# トークナイザーとモデルをロード
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

# GPUが使える場合はGPUに
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [5]:
smiles_list = cellline["SMILES"].unique().tolist()

inputs = tokenizer(smiles_list, padding=True, truncation=True, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    outputs = model(**inputs)
    embeddings = outputs.last_hidden_state[:, 0, :]  # [batch_size, hidden_size]

In [6]:
drugs = pd.DataFrame(embeddings.cpu().numpy())
drugs.index = smiles_list

In [7]:
exp.index = exp.index.astype(str)
cellline["COSMIC_ID"] = cellline["COSMIC_ID"].astype(str)
drugs.index = drugs.index.astype(str)

In [8]:
exp_values = exp.to_numpy(dtype=np.float32)
cid_to_row = {cid: i for i, cid in enumerate(exp.index)}
drug_vec = {s: drugs.loc[s].to_numpy(dtype=np.float32) for s in drugs.index}
mask = cellline["COSMIC_ID"].isin(exp.index) & cellline["SMILES"].isin(drugs.index)
cellline_small = cellline.loc[mask].reset_index(drop=True)
print(f"usable pairs: {len(cellline_small):,} / {len(cellline):,}")

usable pairs: 284,826 / 284,826


In [9]:
print("n_sample:", cellline_small.shape[0])

n_sample: 284826


In [10]:
# そのまま/変更案
batch_size = 16384
num_epochs = 100

# まずはこのセットで様子見（20–30 epoch）
rec_w = 0.2  # 再構成の重みをさらに下げる
lam = 3.0  # 10.0 → 3.0（目安: 2〜5）
beta_max = 0.7  # KLを抑える項は一旦オフ
gamma = 4  # C に寄せる圧を強化（いまの4倍）

C_max = 0.015  # 0.10 → 0.05
C_steps = 200  # 100 → 200

lr = 1e-3
weight_decay = 1e-3

In [11]:
ds = GDSCDataset(
    cellline_small,
    exp_values,
    cid_to_row,
    drug_vec,
    dtype=torch.float32,
    pin_memory=True,
    copy_arrays=False,
    #     materialize="none",    # ← これが重要！前展開を完全禁止
    #     ram_limit_gb=0.5       # （任意）auto を使う時の安全枠も小さくできる
)

In [12]:
val_ratio = 0.2
n_total = len(ds)
n_val = int(n_total * val_ratio)
n_train = n_total - n_val

# 再現性
g = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(ds, [n_train, n_val], generator=g)

In [13]:
train_dl = DataLoader(
    train_ds,
    batch_size=batch_size,  # GPU メモリに合わせて調整（8k〜32k 目安）
    shuffle=True,
    num_workers=1,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4,
    drop_last=True,
)

In [14]:
val_dl = DataLoader(
    val_ds,
    batch_size=batch_size,  # 余裕があれば少し大きめ
    shuffle=False,  # 検証はシャッフル不要
    num_workers=1,  # ここは少なめでもOK
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=2,
    drop_last=False,
)

In [15]:
gene_dim = exp_values.shape[1]
smiles_dim = 768

device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else (
        torch.device("mps")
        if torch.backends.mps.is_available()
        else torch.device("cpu")
    )
)
device_type = (
    "cuda" if device.type == "cuda" else ("mps" if device.type == "mps" else "cpu")
)

use_amp = device_type in ("cuda", "mps")

# GradScaler は CUDA のみ
scaler = torch.amp.GradScaler(enabled=(device_type == "cuda"))

In [None]:
os.makedirs("models/gdsc", exist_ok=True)

best_val = float("inf")
best_path = f"models/gdsc/gdsc_best.pt"

# --- logs / histories ---
train_hist, val_hist, elbo_val_hist = [], [], []
train_rec_hist, train_kl_hist, train_z_hist = [], [], []
val_rec_hist, val_kl_hist, val_z_hist = [], [], []
train_capgap_hist, val_capgap_hist, C_values = [], [], []

model = GeneDrugVAE(
    gene_dim, drug_dim=smiles_dim, proj_dim=256, hidden=512, latent=128
).to(device)

# 1) Optimizer を param group で作り直し
head_params = list(model.head.parameters())
base_params = [p for n, p in model.named_parameters() if not n.startswith("head.")]

opt = torch.optim.Adam(
    [
        {"params": base_params, "lr": lr, "weight_decay": 1e-3},
        {"params": head_params, "lr": lr * 0.5, "weight_decay": 5e-3},  # ← LR↓, WD↑
    ]
)


def _avg(sum_val, cnt):
    return sum_val / max(cnt, 1)


beta_warmup_epochs = int(0.5 * num_epochs)
for ep in range(num_epochs):
    # === 追加：このepochのβ（0→1に線形） ===
    beta = beta_max * min(1.0, (ep + 1) / beta_warmup_epochs)

    model.train()
    C = C_max * (ep / C_steps) if ep < C_steps else C_max
    C_values.append(C)

    # ---- training ----
    m_total = n_batches = 0
    rec_sum = kl_sum = z_sum = capgap_sum = 0.0
    n_train_samples = 0

    for bidx, (xg, xd, y) in enumerate(
        tqdm(train_dl, desc=f"Train E{ep+1}", leave=False)
    ):
        if ep == 0 and bidx == 0:
            print(
                f"[LR] base={opt.param_groups[0]['lr']:.2e} | head={opt.param_groups[1]['lr']:.2e}"
            )

        xg = xg.to(device, non_blocking=True)
        xd = xd.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type=device_type, enabled=use_amp):
            recon, mu, lv, yhat = model(xg, xd)
            rec = F.mse_loss(recon, xg, reduction="mean")
            kl = 0.5 * (-1 - lv + mu.pow(2) + lv.exp()).sum(dim=1).mean()
            zloss = F.smooth_l1_loss(yhat, y, beta=0.5, reduction="mean")  # Huber
            gap_abs = torch.abs(kl - C)
            gap_sq = (kl - C) ** 2
            loss = rec_w * rec + beta * kl + gamma * gap_sq + lam * zloss

        if device_type == "cuda":
            prev_scale = scaler.get_scale()
            scaler.scale(loss).backward()
            scaler.unscale_(opt)  # ← これを追加してからクリップ
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            scaler.step(opt)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            opt.step()

        bs = xg.size(0)
        m_total += loss.item()
        n_batches += 1
        rec_sum += rec.item() * bs
        kl_sum += kl.item() * bs
        z_sum += zloss.item() * bs
        capgap_sum += gap_abs.item() * bs
        n_train_samples += bs

    train_loss = m_total / n_batches
    train_hist.append(train_loss)
    train_rec_hist.append(_avg(rec_sum, n_train_samples))
    train_kl_hist.append(_avg(kl_sum, n_train_samples))
    train_z_hist.append(_avg(z_sum, n_train_samples))
    train_capgap_hist.append(_avg(capgap_sum, n_train_samples))

    # ---- validation ----
    model.eval()
    val_loss_sum = 0.0
    n_val_batches = 0
    v_rec_sum = v_kl_sum = v_z_sum = v_capgap_sum = 0.0
    n_val_samples = 0
    elbo_val_sum = 0.0

    with torch.no_grad():
        for xg, xd, y in val_dl:
            xg, xd, y = xg.to(device), xd.to(device), y.to(device)
            recon, mu, lv, yhat = model(xg, xd)
            rec = F.mse_loss(recon, xg, reduction="mean")
            kl = 0.5 * (-1 - lv + mu.pow(2) + lv.exp()).sum(dim=1).mean()
            zloss = F.smooth_l1_loss(yhat, y, beta=0.5, reduction="mean")  # Huber
            gap_abs = torch.abs(kl - C)  # ← 追加：ログ用
            gap_sq = (kl - C) ** 2  # ← 追加：損失用（二乗）

            loss = rec_w * rec + beta * kl + gamma * gap_sq + lam * zloss
            elbo_loss = rec + beta * kl + lam * zloss

            bs = xg.size(0)
            val_loss_sum += loss.item()
            n_val_batches += 1
            v_rec_sum += rec.item() * bs
            v_kl_sum += kl.item() * bs
            v_z_sum += zloss.item() * bs
            v_capgap_sum += gap_abs.item() * bs  # ← 置換：capgap → gap_abs
            n_val_samples += bs
            elbo_val_sum += elbo_loss.item() * bs

    val_loss = val_loss_sum / n_val_batches
    val_hist.append(val_loss)
    val_rec_hist.append(_avg(v_rec_sum, n_val_samples))
    val_kl_hist.append(_avg(v_kl_sum, n_val_samples))
    val_z_hist.append(_avg(v_z_sum, n_val_samples))
    val_capgap_hist.append(_avg(v_capgap_sum, n_val_samples))
    elbo_val_hist.append(_avg(elbo_val_sum, n_val_samples))

    print(
        f"Epoch {ep+1}/{num_epochs} | Train {train_loss:.4f} | Val {val_loss:.4f} | Val(ELBO) {elbo_val_hist[-1]:.4f}"
    )

    print(
        f"  C={C:.3f} | "
        f"KL tr/val={train_kl_hist[-1]:.4f}/{val_kl_hist[-1]:.4f} | "
        f"gap(|KL-C|) tr/val={train_capgap_hist[-1]:.4f}/{val_capgap_hist[-1]:.4f} | "
        f"zloss tr/val={train_z_hist[-1]:.4f}/{val_z_hist[-1]:.4f} | "
        f"rec tr/val={train_rec_hist[-1]:.4f}/{val_rec_hist[-1]:.4f} | "
        f"beta={beta:.4f}, gamma={gamma:.4f}"
    )

    if (ep + 1) % 10 == 0:
        model.eval()
        eval_dl = torch.utils.data.DataLoader(
            val_ds,
            batch_size=8192,
            shuffle=False,
            num_workers=0,
            pin_memory=False,
            persistent_workers=False,
            drop_last=False,
        )
        ys, yhats = [], []
        with torch.inference_mode():
            for xg, xd, y in eval_dl:
                xg, xd = xg.to(device), xd.to(device)
                _, _, _, yhat = model(xg, xd)
                ys.append(y.cpu().numpy())
                yhats.append(yhat.cpu().numpy())
        import numpy as np
        from sklearn.metrics import mean_absolute_error, r2_score

        ys = np.concatenate(ys)
        yhats = np.concatenate(yhats)
        mae = mean_absolute_error(ys, yhats)
        r2 = r2_score(ys, yhats)
        print(f"  [Eval@E{ep+1}] MAE={mae:.4f}, R^2={r2:.4f}")
        print(f"  [Eval@E{ep+1}] yhat mean/std = {yhats.mean():.4f}/{yhats.std():.4f}")

    if val_loss < best_val - 1e-6:
        best_val = val_loss
        torch.save(model.state_dict(), best_path)
        print(f"[Save] Best model updated at epoch {ep+1}: {best_path}")

Train E1:   0%|          | 0/13 [00:00<?, ?it/s]

[LR] base=1.00e-03 | head=5.00e-04


                                                         

Epoch 1/100 | Train 1.9798 | Val 1.8846 | Val(ELBO) 2.6908
  C=0.000 | KL tr/val=0.1007/0.0425 | gap(|KL-C|) tr/val=0.1007/0.0425 | zloss tr/val=0.5609/0.5581 | rec tr/val=1.0340/1.0159 | beta=0.0140, gamma=4.0000
[Save] Best model updated at epoch 1: models/gdsc/gdsc_best.pt


                                                         

Epoch 2/100 | Train 1.8445 | Val 1.8108 | Val(ELBO) 2.5943
  C=0.000 | KL tr/val=0.0531/0.0711 | gap(|KL-C|) tr/val=0.0531/0.0711 | zloss tr/val=0.5432/0.5294 | rec tr/val=1.0088/1.0040 | beta=0.0280, gamma=4.0000
[Save] Best model updated at epoch 2: models/gdsc/gdsc_best.pt


                                                         

Epoch 3/100 | Train 1.7691 | Val 1.7397 | Val(ELBO) 2.5129
  C=0.000 | KL tr/val=0.0767/0.0840 | gap(|KL-C|) tr/val=0.0766/0.0839 | zloss tr/val=0.5138/0.5029 | rec tr/val=1.0019/1.0008 | beta=0.0420, gamma=4.0000
[Save] Best model updated at epoch 3: models/gdsc/gdsc_best.pt


                                                         

Epoch 4/100 | Train 1.7001 | Val 1.6780 | Val(ELBO) 2.4543
  C=0.000 | KL tr/val=0.0789/0.0778 | gap(|KL-C|) tr/val=0.0787/0.0776 | zloss tr/val=0.4902/0.4833 | rec tr/val=1.0001/0.9999 | beta=0.0560, gamma=4.0000
[Save] Best model updated at epoch 4: models/gdsc/gdsc_best.pt


                                                         

Epoch 5/100 | Train 1.6507 | Val 1.6440 | Val(ELBO) 2.4274
  C=0.000 | KL tr/val=0.0739/0.0660 | gap(|KL-C|) tr/val=0.0736/0.0657 | zloss tr/val=0.4746/0.4743 | rec tr/val=0.9998/0.9997 | beta=0.0700, gamma=4.0000
[Save] Best model updated at epoch 5: models/gdsc/gdsc_best.pt


                                                         

Epoch 6/100 | Train 1.6231 | Val 1.6255 | Val(ELBO) 2.4085
  C=0.000 | KL tr/val=0.0625/0.0653 | gap(|KL-C|) tr/val=0.0621/0.0650 | zloss tr/val=0.4674/0.4678 | rec tr/val=0.9996/0.9997 | beta=0.0840, gamma=4.0000
[Save] Best model updated at epoch 6: models/gdsc/gdsc_best.pt


                                                         

Epoch 7/100 | Train 1.5961 | Val 1.5995 | Val(ELBO) 2.3876
  C=0.000 | KL tr/val=0.0549/0.0536 | gap(|KL-C|) tr/val=0.0544/0.0532 | zloss tr/val=0.4596/0.4609 | rec tr/val=0.9994/0.9997 | beta=0.0980, gamma=4.0000
[Save] Best model updated at epoch 7: models/gdsc/gdsc_best.pt


                                                         

Epoch 8/100 | Train 1.5809 | Val 1.5898 | Val(ELBO) 2.3838
  C=0.001 | KL tr/val=0.0510/0.0388 | gap(|KL-C|) tr/val=0.0505/0.0383 | zloss tr/val=0.4550/0.4599 | rec tr/val=0.9997/0.9997 | beta=0.1120, gamma=4.0000
[Save] Best model updated at epoch 8: models/gdsc/gdsc_best.pt


                                                         

Epoch 9/100 | Train 1.5687 | Val 1.5687 | Val(ELBO) 2.3626
  C=0.001 | KL tr/val=0.0470/0.0400 | gap(|KL-C|) tr/val=0.0464/0.0394 | zloss tr/val=0.4513/0.4526 | rec tr/val=0.9998/0.9997 | beta=0.1260, gamma=4.0000
[Save] Best model updated at epoch 9: models/gdsc/gdsc_best.pt


                                                          

Epoch 10/100 | Train 1.5456 | Val 1.5698 | Val(ELBO) 2.3611
  C=0.001 | KL tr/val=0.0459/0.0465 | gap(|KL-C|) tr/val=0.0453/0.0458 | zloss tr/val=0.4437/0.4516 | rec tr/val=0.9996/0.9997 | beta=0.1400, gamma=4.0000
  [Eval@E10] MAE=0.6602, R^2=0.2518
  [Eval@E10] yhat mean/std = 0.0840/0.4932


                                                          

Epoch 11/100 | Train 1.5365 | Val 1.5436 | Val(ELBO) 2.3354
  C=0.001 | KL tr/val=0.0441/0.0446 | gap(|KL-C|) tr/val=0.0433/0.0438 | zloss tr/val=0.4407/0.4429 | rec tr/val=0.9995/0.9997 | beta=0.1540, gamma=4.0000
[Save] Best model updated at epoch 11: models/gdsc/gdsc_best.pt


                                                          

Epoch 12/100 | Train 1.5213 | Val 1.5420 | Val(ELBO) 2.3339
  C=0.001 | KL tr/val=0.0422/0.0447 | gap(|KL-C|) tr/val=0.0414/0.0439 | zloss tr/val=0.4358/0.4422 | rec tr/val=0.9995/0.9997 | beta=0.1680, gamma=4.0000
[Save] Best model updated at epoch 12: models/gdsc/gdsc_best.pt


                                                          

Epoch 13/100 | Train 1.5119 | Val 1.5414 | Val(ELBO) 2.3385
  C=0.001 | KL tr/val=0.0414/0.0294 | gap(|KL-C|) tr/val=0.0405/0.0285 | zloss tr/val=0.4326/0.4445 | rec tr/val=0.9994/0.9997 | beta=0.1820, gamma=4.0000
[Save] Best model updated at epoch 13: models/gdsc/gdsc_best.pt


                                                          

Epoch 14/100 | Train 1.5031 | Val 1.5092 | Val(ELBO) 2.3030
  C=0.001 | KL tr/val=0.0373/0.0394 | gap(|KL-C|) tr/val=0.0364/0.0385 | zloss tr/val=0.4302/0.4319 | rec tr/val=0.9997/0.9997 | beta=0.1960, gamma=4.0000
[Save] Best model updated at epoch 14: models/gdsc/gdsc_best.pt


                                                          

Epoch 15/100 | Train 1.4893 | Val 1.5027 | Val(ELBO) 2.2985
  C=0.001 | KL tr/val=0.0374/0.0315 | gap(|KL-C|) tr/val=0.0364/0.0305 | zloss tr/val=0.4254/0.4307 | rec tr/val=0.9994/0.9997 | beta=0.2100, gamma=4.0000
[Save] Best model updated at epoch 15: models/gdsc/gdsc_best.pt


                                                          

Epoch 16/100 | Train 1.4780 | Val 1.4963 | Val(ELBO) 2.2924
  C=0.001 | KL tr/val=0.0356/0.0293 | gap(|KL-C|) tr/val=0.0345/0.0281 | zloss tr/val=0.4218/0.4287 | rec tr/val=0.9996/0.9997 | beta=0.2240, gamma=4.0000
[Save] Best model updated at epoch 16: models/gdsc/gdsc_best.pt


                                                          

Epoch 17/100 | Train 1.4750 | Val 1.4931 | Val(ELBO) 2.2892
  C=0.001 | KL tr/val=0.0335/0.0302 | gap(|KL-C|) tr/val=0.0323/0.0290 | zloss tr/val=0.4210/0.4274 | rec tr/val=0.9993/0.9997 | beta=0.2380, gamma=4.0000
[Save] Best model updated at epoch 17: models/gdsc/gdsc_best.pt


                                                          

Epoch 18/100 | Train 1.4681 | Val 1.4861 | Val(ELBO) 2.2817
  C=0.001 | KL tr/val=0.0326/0.0285 | gap(|KL-C|) tr/val=0.0314/0.0272 | zloss tr/val=0.4187/0.4249 | rec tr/val=0.9997/0.9997 | beta=0.2520, gamma=4.0000
[Save] Best model updated at epoch 18: models/gdsc/gdsc_best.pt


                                                          

Epoch 19/100 | Train 1.4531 | Val 1.4705 | Val(ELBO) 2.2673
  C=0.001 | KL tr/val=0.0317/0.0267 | gap(|KL-C|) tr/val=0.0303/0.0254 | zloss tr/val=0.4137/0.4202 | rec tr/val=0.9997/0.9997 | beta=0.2660, gamma=4.0000
[Save] Best model updated at epoch 19: models/gdsc/gdsc_best.pt


                                                          

Epoch 20/100 | Train 1.4468 | Val 1.4709 | Val(ELBO) 2.2664
  C=0.001 | KL tr/val=0.0301/0.0311 | gap(|KL-C|) tr/val=0.0286/0.0296 | zloss tr/val=0.4117/0.4193 | rec tr/val=0.9996/0.9997 | beta=0.2800, gamma=4.0000
  [Eval@E20] MAE=0.6246, R^2=0.3154
  [Eval@E20] yhat mean/std = -0.0226/0.5610


                                                          

Epoch 21/100 | Train 1.4453 | Val 1.4596 | Val(ELBO) 2.2561
  C=0.002 | KL tr/val=0.0296/0.0299 | gap(|KL-C|) tr/val=0.0281/0.0284 | zloss tr/val=0.4111/0.4159 | rec tr/val=0.9996/0.9997 | beta=0.2940, gamma=4.0000
[Save] Best model updated at epoch 21: models/gdsc/gdsc_best.pt


                                                          

Epoch 22/100 | Train 1.4225 | Val 1.4515 | Val(ELBO) 2.2480
  C=0.002 | KL tr/val=0.0287/0.0273 | gap(|KL-C|) tr/val=0.0271/0.0257 | zloss tr/val=0.4036/0.4133 | rec tr/val=0.9995/0.9997 | beta=0.3080, gamma=4.0000
[Save] Best model updated at epoch 22: models/gdsc/gdsc_best.pt


                                                          

Epoch 23/100 | Train 1.4133 | Val 1.4850 | Val(ELBO) 2.2803
  C=0.002 | KL tr/val=0.0277/0.0335 | gap(|KL-C|) tr/val=0.0261/0.0319 | zloss tr/val=0.4006/0.4233 | rec tr/val=0.9996/0.9997 | beta=0.3220, gamma=4.0000


                                                          

Epoch 24/100 | Train 1.4177 | Val 1.4461 | Val(ELBO) 2.2430
  C=0.002 | KL tr/val=0.0271/0.0249 | gap(|KL-C|) tr/val=0.0254/0.0231 | zloss tr/val=0.4020/0.4117 | rec tr/val=0.9996/0.9997 | beta=0.3360, gamma=4.0000
[Save] Best model updated at epoch 24: models/gdsc/gdsc_best.pt


                                                          

Epoch 25/100 | Train 1.3976 | Val 1.4325 | Val(ELBO) 2.2298
  C=0.002 | KL tr/val=0.0265/0.0230 | gap(|KL-C|) tr/val=0.0247/0.0212 | zloss tr/val=0.3953/0.4073 | rec tr/val=0.9996/0.9997 | beta=0.3500, gamma=4.0000
[Save] Best model updated at epoch 25: models/gdsc/gdsc_best.pt


                                                          

Epoch 26/100 | Train 1.3896 | Val 1.4324 | Val(ELBO) 2.2306
  C=0.002 | KL tr/val=0.0260/0.0240 | gap(|KL-C|) tr/val=0.0242/0.0221 | zloss tr/val=0.3926/0.4074 | rec tr/val=0.9996/0.9997 | beta=0.3640, gamma=4.0000
[Save] Best model updated at epoch 26: models/gdsc/gdsc_best.pt


                                                          

Epoch 27/100 | Train 1.3858 | Val 1.4168 | Val(ELBO) 2.2142
  C=0.002 | KL tr/val=0.0249/0.0268 | gap(|KL-C|) tr/val=0.0230/0.0248 | zloss tr/val=0.3914/0.4015 | rec tr/val=0.9997/0.9997 | beta=0.3780, gamma=4.0000
[Save] Best model updated at epoch 27: models/gdsc/gdsc_best.pt


                                                          

Epoch 28/100 | Train 1.3696 | Val 1.4089 | Val(ELBO) 2.2057
  C=0.002 | KL tr/val=0.0247/0.0259 | gap(|KL-C|) tr/val=0.0227/0.0239 | zloss tr/val=0.3860/0.3986 | rec tr/val=0.9994/0.9997 | beta=0.3920, gamma=4.0000
[Save] Best model updated at epoch 28: models/gdsc/gdsc_best.pt


                                                          

Epoch 29/100 | Train 1.3621 | Val 1.4042 | Val(ELBO) 2.2017
  C=0.002 | KL tr/val=0.0241/0.0235 | gap(|KL-C|) tr/val=0.0220/0.0214 | zloss tr/val=0.3835/0.3975 | rec tr/val=0.9997/0.9997 | beta=0.4060, gamma=4.0000
[Save] Best model updated at epoch 29: models/gdsc/gdsc_best.pt


                                                          

Epoch 30/100 | Train 1.3551 | Val 1.4043 | Val(ELBO) 2.2018
  C=0.002 | KL tr/val=0.0233/0.0235 | gap(|KL-C|) tr/val=0.0211/0.0214 | zloss tr/val=0.3812/0.3974 | rec tr/val=0.9994/0.9997 | beta=0.4200, gamma=4.0000
  [Eval@E30] MAE=0.5995, R^2=0.3530
  [Eval@E30] yhat mean/std = 0.0952/0.6117


                                                          

Epoch 31/100 | Train 1.3495 | Val 1.4027 | Val(ELBO) 2.2009
  C=0.002 | KL tr/val=0.0230/0.0190 | gap(|KL-C|) tr/val=0.0208/0.0168 | zloss tr/val=0.3793/0.3976 | rec tr/val=0.9994/0.9997 | beta=0.4340, gamma=4.0000
[Save] Best model updated at epoch 31: models/gdsc/gdsc_best.pt


                                                          

Epoch 32/100 | Train 1.3470 | Val 1.4027 | Val(ELBO) 2.2016
  C=0.002 | KL tr/val=0.0224/0.0212 | gap(|KL-C|) tr/val=0.0201/0.0189 | zloss tr/val=0.3785/0.3975 | rec tr/val=0.9995/0.9997 | beta=0.4480, gamma=4.0000
[Save] Best model updated at epoch 32: models/gdsc/gdsc_best.pt


                                                          

Epoch 33/100 | Train 1.3336 | Val 1.3823 | Val(ELBO) 2.1804
  C=0.002 | KL tr/val=0.0220/0.0192 | gap(|KL-C|) tr/val=0.0196/0.0168 | zloss tr/val=0.3740/0.3906 | rec tr/val=0.9996/0.9997 | beta=0.4620, gamma=4.0000
[Save] Best model updated at epoch 33: models/gdsc/gdsc_best.pt


                                                          

Epoch 34/100 | Train 1.3268 | Val 1.3877 | Val(ELBO) 2.1864
  C=0.002 | KL tr/val=0.0219/0.0213 | gap(|KL-C|) tr/val=0.0195/0.0188 | zloss tr/val=0.3717/0.3922 | rec tr/val=0.9995/0.9997 | beta=0.4760, gamma=4.0000


                                                          

Epoch 35/100 | Train 1.3198 | Val 1.3671 | Val(ELBO) 2.1658
  C=0.003 | KL tr/val=0.0212/0.0213 | gap(|KL-C|) tr/val=0.0186/0.0188 | zloss tr/val=0.3694/0.3852 | rec tr/val=0.9996/0.9997 | beta=0.4900, gamma=4.0000
[Save] Best model updated at epoch 35: models/gdsc/gdsc_best.pt


                                                          

Epoch 36/100 | Train 1.3076 | Val 1.3742 | Val(ELBO) 2.1727
  C=0.003 | KL tr/val=0.0212/0.0181 | gap(|KL-C|) tr/val=0.0186/0.0154 | zloss tr/val=0.3652/0.3879 | rec tr/val=0.9995/0.9997 | beta=0.5040, gamma=4.0000


                                                          

Epoch 37/100 | Train 1.2980 | Val 1.3651 | Val(ELBO) 2.1636
  C=0.003 | KL tr/val=0.0205/0.0210 | gap(|KL-C|) tr/val=0.0178/0.0183 | zloss tr/val=0.3620/0.3843 | rec tr/val=0.9997/0.9997 | beta=0.5180, gamma=4.0000
[Save] Best model updated at epoch 37: models/gdsc/gdsc_best.pt


                                                          

Epoch 38/100 | Train 1.2950 | Val 1.3615 | Val(ELBO) 2.1594
  C=0.003 | KL tr/val=0.0204/0.0188 | gap(|KL-C|) tr/val=0.0176/0.0160 | zloss tr/val=0.3610/0.3832 | rec tr/val=0.9996/0.9997 | beta=0.5320, gamma=4.0000
[Save] Best model updated at epoch 38: models/gdsc/gdsc_best.pt


                                                          

Epoch 39/100 | Train 1.2910 | Val 1.3521 | Val(ELBO) 2.1504
  C=0.003 | KL tr/val=0.0203/0.0178 | gap(|KL-C|) tr/val=0.0175/0.0150 | zloss tr/val=0.3596/0.3803 | rec tr/val=0.9996/0.9997 | beta=0.5460, gamma=4.0000
[Save] Best model updated at epoch 39: models/gdsc/gdsc_best.pt


                                                          

Epoch 40/100 | Train 1.2936 | Val 1.3549 | Val(ELBO) 2.1535
  C=0.003 | KL tr/val=0.0203/0.0191 | gap(|KL-C|) tr/val=0.0174/0.0161 | zloss tr/val=0.3604/0.3811 | rec tr/val=0.9996/0.9997 | beta=0.5600, gamma=4.0000
  [Eval@E40] MAE=0.5820, R^2=0.3908
  [Eval@E40] yhat mean/std = 0.0174/0.6415


                                                          

Epoch 41/100 | Train 1.2771 | Val 1.3532 | Val(ELBO) 2.1503
  C=0.003 | KL tr/val=0.0198/0.0185 | gap(|KL-C|) tr/val=0.0168/0.0155 | zloss tr/val=0.3549/0.3800 | rec tr/val=0.9997/0.9997 | beta=0.5740, gamma=4.0000


                                                          

Epoch 42/100 | Train 1.2709 | Val 1.3479 | Val(ELBO) 2.1467
  C=0.003 | KL tr/val=0.0192/0.0183 | gap(|KL-C|) tr/val=0.0161/0.0152 | zloss tr/val=0.3529/0.3787 | rec tr/val=0.9997/0.9997 | beta=0.5880, gamma=4.0000
[Save] Best model updated at epoch 42: models/gdsc/gdsc_best.pt


                                                          

Epoch 43/100 | Train 1.2636 | Val 1.3351 | Val(ELBO) 2.1324
  C=0.003 | KL tr/val=0.0192/0.0175 | gap(|KL-C|) tr/val=0.0160/0.0144 | zloss tr/val=0.3504/0.3741 | rec tr/val=0.9997/0.9997 | beta=0.6020, gamma=4.0000
[Save] Best model updated at epoch 43: models/gdsc/gdsc_best.pt


                                                          

Epoch 44/100 | Train 1.2574 | Val 1.3257 | Val(ELBO) 2.1239
  C=0.003 | KL tr/val=0.0190/0.0176 | gap(|KL-C|) tr/val=0.0158/0.0144 | zloss tr/val=0.3483/0.3711 | rec tr/val=0.9995/0.9997 | beta=0.6160, gamma=4.0000
[Save] Best model updated at epoch 44: models/gdsc/gdsc_best.pt


                                                          

Epoch 45/100 | Train 1.2563 | Val 1.3418 | Val(ELBO) 2.1399
  C=0.003 | KL tr/val=0.0188/0.0169 | gap(|KL-C|) tr/val=0.0155/0.0136 | zloss tr/val=0.3479/0.3765 | rec tr/val=0.9992/0.9997 | beta=0.6300, gamma=4.0000


                                                          

Epoch 46/100 | Train 1.2500 | Val 1.3262 | Val(ELBO) 2.1245
  C=0.003 | KL tr/val=0.0185/0.0178 | gap(|KL-C|) tr/val=0.0151/0.0144 | zloss tr/val=0.3457/0.3711 | rec tr/val=1.0000/0.9997 | beta=0.6440, gamma=4.0000


                                                          

Epoch 47/100 | Train 1.2410 | Val 1.3196 | Val(ELBO) 2.1178
  C=0.003 | KL tr/val=0.0186/0.0163 | gap(|KL-C|) tr/val=0.0151/0.0129 | zloss tr/val=0.3427/0.3691 | rec tr/val=0.9995/0.9997 | beta=0.6580, gamma=4.0000
[Save] Best model updated at epoch 47: models/gdsc/gdsc_best.pt


                                                          

Epoch 48/100 | Train 1.2350 | Val 1.3285 | Val(ELBO) 2.1261
  C=0.004 | KL tr/val=0.0182/0.0155 | gap(|KL-C|) tr/val=0.0147/0.0119 | zloss tr/val=0.3407/0.3720 | rec tr/val=0.9994/0.9997 | beta=0.6720, gamma=4.0000


                                                          

Epoch 49/100 | Train 1.2298 | Val 1.3221 | Val(ELBO) 2.1202
  C=0.004 | KL tr/val=0.0178/0.0172 | gap(|KL-C|) tr/val=0.0142/0.0136 | zloss tr/val=0.3390/0.3696 | rec tr/val=0.9998/0.9997 | beta=0.6860, gamma=4.0000


                                                          

Epoch 50/100 | Train 1.2241 | Val 1.3454 | Val(ELBO) 2.1444
  C=0.004 | KL tr/val=0.0178/0.0148 | gap(|KL-C|) tr/val=0.0142/0.0111 | zloss tr/val=0.3370/0.3781 | rec tr/val=0.9996/0.9997 | beta=0.7000, gamma=4.0000
  [Eval@E50] MAE=0.5796, R^2=0.4034
  [Eval@E50] yhat mean/std = -0.0817/0.6393


                                                          

Epoch 51/100 | Train 1.2346 | Val 1.3134 | Val(ELBO) 2.1118
  C=0.004 | KL tr/val=0.0180/0.0161 | gap(|KL-C|) tr/val=0.0143/0.0123 | zloss tr/val=0.3404/0.3670 | rec tr/val=0.9996/0.9997 | beta=0.7000, gamma=4.0000
[Save] Best model updated at epoch 51: models/gdsc/gdsc_best.pt


                                                          

Epoch 52/100 | Train 1.2103 | Val 1.3088 | Val(ELBO) 2.1071
  C=0.004 | KL tr/val=0.0175/0.0151 | gap(|KL-C|) tr/val=0.0137/0.0113 | zloss tr/val=0.3324/0.3656 | rec tr/val=0.9998/0.9997 | beta=0.7000, gamma=4.0000
[Save] Best model updated at epoch 52: models/gdsc/gdsc_best.pt


                                                          

Epoch 53/100 | Train 1.2133 | Val 1.3206 | Val(ELBO) 2.1190
  C=0.004 | KL tr/val=0.0174/0.0172 | gap(|KL-C|) tr/val=0.0135/0.0133 | zloss tr/val=0.3335/0.3691 | rec tr/val=0.9997/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 54/100 | Train 1.2065 | Val 1.3137 | Val(ELBO) 2.1110
  C=0.004 | KL tr/val=0.0178/0.0138 | gap(|KL-C|) tr/val=0.0139/0.0098 | zloss tr/val=0.3311/0.3672 | rec tr/val=0.9996/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 55/100 | Train 1.1997 | Val 1.3081 | Val(ELBO) 2.1061
  C=0.004 | KL tr/val=0.0169/0.0144 | gap(|KL-C|) tr/val=0.0129/0.0103 | zloss tr/val=0.3291/0.3654 | rec tr/val=0.9995/0.9997 | beta=0.7000, gamma=4.0000
[Save] Best model updated at epoch 55: models/gdsc/gdsc_best.pt


                                                          

Epoch 56/100 | Train 1.1922 | Val 1.2988 | Val(ELBO) 2.0974
  C=0.004 | KL tr/val=0.0170/0.0150 | gap(|KL-C|) tr/val=0.0128/0.0108 | zloss tr/val=0.3266/0.3624 | rec tr/val=0.9997/0.9997 | beta=0.7000, gamma=4.0000
[Save] Best model updated at epoch 56: models/gdsc/gdsc_best.pt


                                                          

Epoch 57/100 | Train 1.1906 | Val 1.3096 | Val(ELBO) 2.1076
  C=0.004 | KL tr/val=0.0167/0.0149 | gap(|KL-C|) tr/val=0.0125/0.0107 | zloss tr/val=0.3261/0.3658 | rec tr/val=0.9994/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 58/100 | Train 1.1927 | Val 1.3033 | Val(ELBO) 2.1005
  C=0.004 | KL tr/val=0.0167/0.0154 | gap(|KL-C|) tr/val=0.0124/0.0112 | zloss tr/val=0.3268/0.3633 | rec tr/val=0.9997/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 59/100 | Train 1.1798 | Val 1.3017 | Val(ELBO) 2.1003
  C=0.004 | KL tr/val=0.0166/0.0150 | gap(|KL-C|) tr/val=0.0123/0.0107 | zloss tr/val=0.3225/0.3634 | rec tr/val=0.9997/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 60/100 | Train 1.1764 | Val 1.3007 | Val(ELBO) 2.0983
  C=0.004 | KL tr/val=0.0171/0.0148 | gap(|KL-C|) tr/val=0.0127/0.0104 | zloss tr/val=0.3213/0.3627 | rec tr/val=0.9996/0.9997 | beta=0.7000, gamma=4.0000
  [Eval@E60] MAE=0.5611, R^2=0.4256
  [Eval@E60] yhat mean/std = 0.0998/0.6924


                                                          

Epoch 61/100 | Train 1.1848 | Val 1.3006 | Val(ELBO) 2.0994
  C=0.004 | KL tr/val=0.0167/0.0146 | gap(|KL-C|) tr/val=0.0122/0.0101 | zloss tr/val=0.3242/0.3631 | rec tr/val=0.9998/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 62/100 | Train 1.1649 | Val 1.2985 | Val(ELBO) 2.0963
  C=0.005 | KL tr/val=0.0163/0.0139 | gap(|KL-C|) tr/val=0.0117/0.0094 | zloss tr/val=0.3177/0.3623 | rec tr/val=0.9993/0.9997 | beta=0.7000, gamma=4.0000
[Save] Best model updated at epoch 62: models/gdsc/gdsc_best.pt


                                                          

Epoch 63/100 | Train 1.1702 | Val 1.2984 | Val(ELBO) 2.0974
  C=0.005 | KL tr/val=0.0160/0.0145 | gap(|KL-C|) tr/val=0.0113/0.0098 | zloss tr/val=0.3195/0.3625 | rec tr/val=0.9997/0.9997 | beta=0.7000, gamma=4.0000
[Save] Best model updated at epoch 63: models/gdsc/gdsc_best.pt


                                                          

Epoch 64/100 | Train 1.1607 | Val 1.2895 | Val(ELBO) 2.0882
  C=0.005 | KL tr/val=0.0164/0.0135 | gap(|KL-C|) tr/val=0.0117/0.0088 | zloss tr/val=0.3162/0.3597 | rec tr/val=0.9997/0.9997 | beta=0.7000, gamma=4.0000
[Save] Best model updated at epoch 64: models/gdsc/gdsc_best.pt


                                                          

Epoch 65/100 | Train 1.1620 | Val 1.2903 | Val(ELBO) 2.0890
  C=0.005 | KL tr/val=0.0163/0.0130 | gap(|KL-C|) tr/val=0.0115/0.0082 | zloss tr/val=0.3167/0.3601 | rec tr/val=0.9998/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 66/100 | Train 1.1562 | Val 1.2909 | Val(ELBO) 2.0894
  C=0.005 | KL tr/val=0.0164/0.0137 | gap(|KL-C|) tr/val=0.0116/0.0088 | zloss tr/val=0.3147/0.3600 | rec tr/val=0.9998/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 67/100 | Train 1.1430 | Val 1.2836 | Val(ELBO) 2.0812
  C=0.005 | KL tr/val=0.0159/0.0134 | gap(|KL-C|) tr/val=0.0109/0.0084 | zloss tr/val=0.3105/0.3574 | rec tr/val=0.9996/0.9997 | beta=0.7000, gamma=4.0000
[Save] Best model updated at epoch 67: models/gdsc/gdsc_best.pt


                                                          

Epoch 68/100 | Train 1.1616 | Val 1.3019 | Val(ELBO) 2.1005
  C=0.005 | KL tr/val=0.0171/0.0137 | gap(|KL-C|) tr/val=0.0121/0.0087 | zloss tr/val=0.3164/0.3637 | rec tr/val=0.9994/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 69/100 | Train 1.1437 | Val 1.2798 | Val(ELBO) 2.0781
  C=0.005 | KL tr/val=0.0162/0.0132 | gap(|KL-C|) tr/val=0.0111/0.0081 | zloss tr/val=0.3106/0.3564 | rec tr/val=0.9998/0.9997 | beta=0.7000, gamma=4.0000
[Save] Best model updated at epoch 69: models/gdsc/gdsc_best.pt


                                                          

Epoch 70/100 | Train 1.1335 | Val 1.2875 | Val(ELBO) 2.0849
  C=0.005 | KL tr/val=0.0169/0.0130 | gap(|KL-C|) tr/val=0.0117/0.0078 | zloss tr/val=0.3070/0.3587 | rec tr/val=0.9997/0.9997 | beta=0.7000, gamma=4.0000
  [Eval@E70] MAE=0.5568, R^2=0.4351
  [Eval@E70] yhat mean/std = 0.0580/0.6885


                                                          

Epoch 71/100 | Train 1.1282 | Val 1.2814 | Val(ELBO) 2.0798
  C=0.005 | KL tr/val=0.0143/0.0130 | gap(|KL-C|) tr/val=0.0090/0.0077 | zloss tr/val=0.3060/0.3570 | rec tr/val=0.9997/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 72/100 | Train 1.1198 | Val 1.2825 | Val(ELBO) 2.0813
  C=0.005 | KL tr/val=0.0137/0.0134 | gap(|KL-C|) tr/val=0.0084/0.0081 | zloss tr/val=0.3033/0.3574 | rec tr/val=0.9998/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 73/100 | Train 1.1274 | Val 1.2804 | Val(ELBO) 2.0788
  C=0.005 | KL tr/val=0.0166/0.0132 | gap(|KL-C|) tr/val=0.0112/0.0078 | zloss tr/val=0.3051/0.3566 | rec tr/val=0.9996/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 74/100 | Train 1.1244 | Val 1.2823 | Val(ELBO) 2.0806
  C=0.005 | KL tr/val=0.0169/0.0123 | gap(|KL-C|) tr/val=0.0114/0.0068 | zloss tr/val=0.3040/0.3574 | rec tr/val=0.9997/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 75/100 | Train 1.1174 | Val 1.2829 | Val(ELBO) 2.0826
  C=0.006 | KL tr/val=0.0165/0.0132 | gap(|KL-C|) tr/val=0.0110/0.0077 | zloss tr/val=0.3018/0.3579 | rec tr/val=0.9995/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 76/100 | Train 1.1242 | Val 1.2905 | Val(ELBO) 2.0884
  C=0.006 | KL tr/val=0.0171/0.0124 | gap(|KL-C|) tr/val=0.0115/0.0067 | zloss tr/val=0.3039/0.3600 | rec tr/val=0.9994/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 77/100 | Train 1.1174 | Val 1.2776 | Val(ELBO) 2.0761
  C=0.006 | KL tr/val=0.0163/0.0125 | gap(|KL-C|) tr/val=0.0106/0.0068 | zloss tr/val=0.3019/0.3559 | rec tr/val=0.9996/0.9997 | beta=0.7000, gamma=4.0000
[Save] Best model updated at epoch 77: models/gdsc/gdsc_best.pt


                                                          

Epoch 78/100 | Train 1.1100 | Val 1.2805 | Val(ELBO) 2.0792
  C=0.006 | KL tr/val=0.0154/0.0136 | gap(|KL-C|) tr/val=0.0096/0.0078 | zloss tr/val=0.2996/0.3567 | rec tr/val=0.9997/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 79/100 | Train 1.1046 | Val 1.2884 | Val(ELBO) 2.0873
  C=0.006 | KL tr/val=0.0163/0.0126 | gap(|KL-C|) tr/val=0.0104/0.0068 | zloss tr/val=0.2976/0.3596 | rec tr/val=0.9995/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 80/100 | Train 1.1078 | Val 1.2680 | Val(ELBO) 2.0670
  C=0.006 | KL tr/val=0.0156/0.0126 | gap(|KL-C|) tr/val=0.0096/0.0066 | zloss tr/val=0.2989/0.3528 | rec tr/val=0.9996/0.9997 | beta=0.7000, gamma=4.0000
  [Eval@E80] MAE=0.5505, R^2=0.4480
  [Eval@E80] yhat mean/std = 0.0347/0.7047
[Save] Best model updated at epoch 80: models/gdsc/gdsc_best.pt


                                                          

Epoch 81/100 | Train 1.0956 | Val 1.2856 | Val(ELBO) 2.0839
  C=0.006 | KL tr/val=0.0177/0.0123 | gap(|KL-C|) tr/val=0.0117/0.0063 | zloss tr/val=0.2942/0.3585 | rec tr/val=0.9998/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 82/100 | Train 1.0924 | Val 1.2701 | Val(ELBO) 2.0693
  C=0.006 | KL tr/val=0.0134/0.0129 | gap(|KL-C|) tr/val=0.0073/0.0068 | zloss tr/val=0.2943/0.3535 | rec tr/val=0.9994/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 83/100 | Train 1.0949 | Val 1.2744 | Val(ELBO) 2.0731
  C=0.006 | KL tr/val=0.0141/0.0129 | gap(|KL-C|) tr/val=0.0079/0.0068 | zloss tr/val=0.2950/0.3548 | rec tr/val=0.9995/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 84/100 | Train 1.0925 | Val 1.2800 | Val(ELBO) 2.0781
  C=0.006 | KL tr/val=0.0177/0.0116 | gap(|KL-C|) tr/val=0.0115/0.0054 | zloss tr/val=0.2932/0.3568 | rec tr/val=0.9998/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 85/100 | Train 1.0851 | Val 1.2728 | Val(ELBO) 2.0716
  C=0.006 | KL tr/val=0.0169/0.0119 | gap(|KL-C|) tr/val=0.0106/0.0056 | zloss tr/val=0.2910/0.3545 | rec tr/val=0.9994/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 86/100 | Train 1.0841 | Val 1.2881 | Val(ELBO) 2.0863
  C=0.006 | KL tr/val=0.0182/0.0123 | gap(|KL-C|) tr/val=0.0118/0.0059 | zloss tr/val=0.2903/0.3593 | rec tr/val=0.9995/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 87/100 | Train 1.0782 | Val 1.2762 | Val(ELBO) 2.0750
  C=0.006 | KL tr/val=0.0146/0.0123 | gap(|KL-C|) tr/val=0.0082/0.0058 | zloss tr/val=0.2893/0.3556 | rec tr/val=0.9993/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 88/100 | Train 1.0764 | Val 1.2848 | Val(ELBO) 2.0841
  C=0.007 | KL tr/val=0.0129/0.0128 | gap(|KL-C|) tr/val=0.0064/0.0063 | zloss tr/val=0.2891/0.3585 | rec tr/val=0.9997/0.9997 | beta=0.7000, gamma=4.0000


                                                          

Epoch 89/100 | Train 1.0767 | Val 1.2752 | Val(ELBO) 2.0731
  C=0.007 | KL tr/val=0.0150/0.0120 | gap(|KL-C|) tr/val=0.0084/0.0054 | zloss tr/val=0.2886/0.3550 | rec tr/val=0.9995/0.9997 | beta=0.7000, gamma=4.0000


Train E90:  38%|███▊      | 5/13 [00:03<00:04,  1.72it/s]

In [None]:
plot_all_curves(
    train_hist,
    val_hist,
    elbo_val_hist,
    train_rec_hist,
    val_rec_hist,
    train_kl_hist,
    val_kl_hist,
    train_z_hist,
    val_z_hist,
    C_values,
    train_capgap_hist,
    val_capgap_hist,
)

In [None]:
model.eval()

eval_dl = torch.utils.data.DataLoader(
    val_ds,  # 元の val_ds でOK
    batch_size=8192,  # 好きなサイズ
    shuffle=False,
    num_workers=0,  # ★シングルプロセス
    pin_memory=False,  # ★不要
    persistent_workers=False,  # ★絶対 False
    drop_last=False,
)

ys, yhats = [], []
with torch.inference_mode():
    for xg, xd, y in eval_dl:
        xg, xd = xg.to(device), xd.to(device)
        _, _, _, yhat = model(xg, xd)
        ys.append(y.cpu().numpy())
        yhats.append(yhat.cpu().numpy())
ys = np.concatenate(ys)
yhats = np.concatenate(yhats)
mae = mean_absolute_error(ys, yhats)
r2 = r2_score(ys, yhats)

plt.figure()
plt.scatter(ys, yhats, s=8, alpha=0.4)
lims = [min(ys.min(), yhats.min()), max(ys.max(), yhats.max())]
plt.plot(lims, lims)
plt.xlabel("True y")
plt.ylabel("Predicted ŷ")
plt.title(f"Parity Plot (Val)  MAE={mae:.3f}, R²={r2:.3f}")
plt.tight_layout()
plt.show()

In [None]:
import numpy as np
from IPython.display import Audio

# 440Hz の正弦波を 1 秒再生
sr = 22050
t = np.linspace(0, 1, sr)
x = np.sin(2 * np.pi * 440 * t)
Audio(x, rate=sr, autoplay=True)

# How to eval from loading the model

In [None]:
# # 1. Build the model and load the trained weights
# model = GeneDrugVAE(gene_dim, drug_dim=768, proj_dim=256, hidden=512, latent=128).to(device)
# model.load_state_dict(torch.load("models/gdsc/gdsc_last.pt", map_location=device))
# model.eval()  # set to evaluation mode

# # 2. Extract latent representations (μ) for all samples
# mus = []
# with torch.no_grad():  # disable gradient computation
#     for xg, xd, _ in dl:   # ignore y (IC50), only need xg and xd
#         xg = xg.to(device, non_blocking=True)
#         xd = xd.to(device, non_blocking=True)
#         mu, lv = model.encode(xg, xd)   # encode gene + drug into latent space
#         mus.append(mu.cpu().numpy())    # collect μ on CPU

# # Concatenate all latent vectors into one array
# Z_mu = np.concatenate(mus, axis=0)
# print("Latent shape:", Z_mu.shape)

In [None]:
# # ---------- 1) 全サンプルの μ をバッチで抽出 ----------
# model.eval()
# mus = []
# with torch.no_grad():
#     # 推論用 DataLoader（shuffle=False）
#     infer_dl = DataLoader(ds, batch_size=8192, shuffle=False, num_workers=0)
#     for xg, xd, y in tqdm(infer_dl, desc="Encode (μ)", leave=False):
#         xg = xg.to(device, non_blocking=True)
#         xd = xd.to(device, non_blocking=True)
#         mu, logvar = model.encode(xg, xd)  # GeneDrugVAE で encode を実装している前提
#         mus.append(mu.cpu().numpy())

# Z_mu = np.concatenate(mus, axis=0)  # (N, latent)

In [None]:
# # 可視化用メタ：薬ラベル & Z_score（元の ds に対応する順序を確保）
# # cellline_small の順序が ds と一致している前提（Dataset 内でそのまま index を使っている実装ならOK）
# meta_df = cellline_small.reset_index(drop=True).copy()
# # NAME が無い場合は SMILES を代用
# drug_label = meta_df["NAME"] if "NAME" in meta_df.columns else meta_df["SMILES"]
# zscore = meta_df["Z_score"].values
# cellline = meta_df["COSMIC_ID"]

In [None]:
# X0 = StandardScaler(with_mean=True, with_std=True).fit_transform(Z_mu)
# X = (
#     PCA(n_components=min(100, X0.shape[1]), random_state=42)
#     .fit_transform(X0)
#     .astype("float32")
# )

# umap2d = umap.UMAP(
#     n_neighbors=40,
#     min_dist=0.05,
#     metric="cosine",
#     init="spectral",
#     densmap=True,
#     random_state=42,
#     low_memory=True,
#     verbose=True,
# ).fit_transform(X)

In [None]:
# plt.figure()
# sc = plt.scatter(umap2d[:, 0], umap2d[:, 1], s=6, c=zscore, alpha=0.7)
# cb = plt.colorbar(sc)
# cb.set_label("Z_score")
# plt.xlabel("UMAP 1")
# plt.ylabel("UMAP 2")
# plt.title("UMAP colored by Z_score")
# plt.tight_layout()
# plt.show()

In [None]:
# # 入力: umap2d (N,2), drug_label (N,), zscore (N,)
# labs = pd.Series(drug_label).astype(str).str.strip().values
# z = np.asarray(zscore, dtype=float)

# # 座標で集約（丸めて同一点判定）
# xy = np.round(umap2d, 6)
# df = pd.DataFrame({"x": xy[:, 0], "y": xy[:, 1], "drug": labs, "z": z})

# # 各座標で「z が最小の薬」を代表に採用（効きが強い薬）
# idx = df.groupby(["x", "y"])["z"].idxmin()
# rep = df.loc[idx].reset_index(drop=True)  # x,y,drug,z が 1行/座標

# # 可視化（Top-20 + Other）
# vc = rep["drug"].value_counts()
# top_k = 20
# top = list(vc.head(top_k).index)

# cmap = plt.colormaps.get_cmap("tab20")
# colors = {d: cmap(i / (top_k - 1)) for i, d in enumerate(top)}
# other_color = (0.5, 0.5, 0.5, 0.25)

# plt.figure(figsize=(9, 7))
# # Otherを下地に
# m_other = ~rep["drug"].isin(top)
# plt.scatter(
#     rep.loc[m_other, "x"],
#     rep.loc[m_other, "y"],
#     s=6,
#     alpha=0.30,
#     color=other_color,
#     rasterized=True,
#     zorder=1,
#     label="Other",
#     edgecolors="none",
# )
# # Top-20 を重ねる
# for d in top:
#     m = rep["drug"].values == d
#     if m.any():
#         plt.scatter(
#             rep.loc[m, "x"],
#             rep.loc[m, "y"],
#             s=12,
#             alpha=0.95,
#             color=colors[d],
#             label=d,
#             rasterized=True,
#             zorder=3,
#             edgecolors="none",
#         )

# plt.xlabel("UMAP 1")
# plt.ylabel("UMAP 2")
# plt.title("UMAP – winner-takes-all by drug (min Z_score)")
# plt.legend(
#     bbox_to_anchor=(1.02, 1), loc="upper left", fontsize=8, frameon=False, ncol=2
# )
# plt.tight_layout()
# plt.show()

In [None]:
# labs = pd.Series(drug_label).astype(str).str.strip().values
# drug = "Cisplatin"  # 例

# m = labs == drug
# plt.figure(figsize=(7, 6))
# plt.hexbin(umap2d[m, 0], umap2d[m, 1], gridsize=80, mincnt=1, norm=LogNorm())
# cb = plt.colorbar()
# cb.set_label("count (log)")
# plt.title(f"UMAP density – {drug}")
# plt.xlabel("UMAP 1")
# plt.ylabel("UMAP 2")
# plt.tight_layout()
# plt.show()