In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader,random_split,TensorDataset
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
import math
import json
import joblib
import ast
import pickle
import os
import random
import string

In [None]:
# Silence every RuntimeWarning, and UserWarnings raised from sklearn modules,
# for the rest of the kernel session.
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")

In [None]:
# Read 'static_categorical.csv' (path relative to the working directory) into static_df.
static_df = pd.read_csv('static_categorical.csv')

In [None]:
class StaticOneHotManager:
    """Keeps one fitted ``OneHotEncoder`` per categorical column.

    Missing values are represented by an all-zero vector: they are dropped
    before fitting and explicitly zeroed after transforming.
    """

    def __init__(self):
        # column name -> fitted OneHotEncoder
        self.encoders = {}

    def fit(self, df, categorical_cols):
        """Fit an encoder for each column on its non-null values cast to str."""
        for col in categorical_cols:
            observed = df[col].dropna().astype(str)

            # Category order is the order of first appearance in the column.
            encoder = OneHotEncoder(
                sparse_output=False,
                handle_unknown="ignore",
                categories=[observed.unique().tolist()]
            )
            encoder.fit(observed.values.reshape(-1, 1))
            self.encoders[col] = encoder

    def transform(self, df, categorical_cols, return_numpy=True):
        """Return a frame holding one one-hot vector per cell.

        A "PATIENT" column, if present in ``df``, is copied first. Each cell of
        an encoded column is a 1-D numpy row when ``return_numpy`` is true and a
        plain list otherwise.
        """
        out = pd.DataFrame()
        if "PATIENT" in df.columns:
            out["PATIENT"] = df["PATIENT"]

        for col in categorical_cols:
            column = df[col]
            matrix = self.encoders[col].transform(
                column.astype(str).values.reshape(-1, 1)
            )
            # Rows that were NaN in the input must encode as all zeros.
            matrix[column.isna().values] = 0

            out[col] = list(matrix) if return_numpy else matrix.tolist()

        return out

    def inverse_transform(self, series, col):
        """Decode one-hot vectors to category labels; all-zero vectors give NaN."""
        labels = self.encoders[col].categories_[0]
        decoded = []
        for vector in series:
            vector = np.array(vector, dtype=float)
            decoded.append(labels[vector.argmax()] if vector.sum() != 0 else np.nan)
        return decoded

    def save(self, path="static_encoded.pkl"):
        """Persist the encoder dictionary with joblib."""
        joblib.dump(self.encoders, path)

    def load(self, path="static_encoded.pkl"):
        """Replace the encoder dictionary with the one stored at ``path``."""
        self.encoders = joblib.load(path)

In [None]:
# Fit one encoder per column of static_df (every column is passed as categorical),
# save the encoders to the default path, then encode the frame and pickle it.
static_manager = StaticOneHotManager()
static_manager.fit(static_df,static_df.columns)
static_manager.save()
static_one_hot = static_manager.transform(static_df,static_df.columns)
static_one_hot.to_pickle("static_one_hot.pkl")

In [None]:
# Preview the first rows of the encoded static frame.
static_one_hot.head()

In [None]:
# Round-trip check: reload the pickled one-hot frame and the saved encoders,
# then decode the "RACE" column back to labels.
loaded = pd.read_pickle("static_one_hot.pkl")
manager2 = StaticOneHotManager()
manager2.load()
decoded_race = manager2.inverse_transform(loaded["RACE"], "RACE")
decoded_race

In [None]:
# Read 'temporal_categorical.csv' (path relative to the working directory) into temporal_df.
temporal_df = pd.read_csv('temporal_categorical.csv')

In [None]:
class TemporalOneHotManager:
    """One-hot encodes columns whose cells hold a *sequence* of categories.

    A cell may be a Python list or its string form (e.g. "['a', nan, 'b']"
    as read back from CSV).  Every non-missing element is cast to ``str``; a
    missing element is encoded as an all-zero vector.
    """

    def __init__(self):
        # column name -> fitted OneHotEncoder
        self.encoders = {}

    @staticmethod
    def _parse_cell(text):
        """Parse a stringified list, reading bare ``nan`` tokens as None.

        Only real ``nan`` *names* in the parsed expression are replaced, so
        labels that merely contain the letters "nan" (e.g. 'Maintenance') are
        left intact.  Text that is not a Python literal is returned wrapped
        in a one-element list.
        """
        try:
            return ast.literal_eval(text)
        except (ValueError, SyntaxError):
            pass  # most likely a bare `nan`; retry with the names rewritten
        try:
            tree = ast.parse(text.strip(), mode="eval")
            for node in ast.walk(tree):
                for field, value in ast.iter_fields(node):
                    if isinstance(value, list):
                        for pos, item in enumerate(value):
                            if isinstance(item, ast.Name) and item.id == "nan":
                                value[pos] = ast.Constant(value=None)
                    elif isinstance(value, ast.Name) and value.id == "nan":
                        setattr(node, field, ast.Constant(value=None))
            return ast.literal_eval(tree)
        except (ValueError, SyntaxError):
            return [text]

    @classmethod
    def _as_sequence(cls, cell):
        """Normalise one cell to a list of elements ([] for a missing cell)."""
        if isinstance(cell, str):
            cell = cls._parse_cell(cell)
        elif not isinstance(cell, list):
            # Only test scalars: pd.isna on a list returns an array whose
            # truth value is ambiguous.
            if pd.api.types.is_scalar(cell) and pd.isna(cell):
                return []
        return cell if isinstance(cell, list) else [cell]

    def fit(self, df, categorical_cols):
        """Fit one encoder per column on every non-missing sequence element.

        Category order is the order of first appearance in the column.
        """
        for col in categorical_cols:
            flat_values = [
                str(v)
                for cell in df[col]
                for v in self._as_sequence(cell)
                if pd.notna(v)
            ]

            categories = pd.Series(flat_values, dtype=object).unique().tolist()
            enc = OneHotEncoder(
                sparse_output=False,
                handle_unknown="ignore",
                categories=[categories]
            )
            enc.fit(np.array(categories).reshape(-1, 1))
            self.encoders[col] = enc

    def transform(self, df, categorical_cols, return_numpy=True):
        """Encode each cell as a list of one-hot vectors (lists of 0/1 ints).

        Missing elements and categories unseen during ``fit`` become all-zero
        vectors; a missing cell becomes an empty list.  A "PATIENT" column,
        if present in ``df``, is copied first.  ``return_numpy`` is accepted
        for signature parity with ``StaticOneHotManager`` and is not used.
        """
        out = pd.DataFrame()
        if "PATIENT" in df.columns:
            out["PATIENT"] = df["PATIENT"]

        for col in categorical_cols:
            cats = self.encoders[col].categories_[0].tolist()
            n_cats = len(cats)
            # A dict lookup gives the same vectors as calling the encoder on
            # each value, without one sklearn call per element.
            position = {cat: idx for idx, cat in enumerate(cats)}

            rows = []
            for cell in df[col]:
                encoded_seq = []
                if n_cats > 0:
                    for v in self._as_sequence(cell):
                        vector = [0] * n_cats
                        if not pd.isna(v):
                            hit = position.get(str(v))
                            if hit is not None:
                                vector[hit] = 1
                        encoded_seq.append(vector)
                rows.append(encoded_seq)
            out[col] = rows

        return out

    def inverse_transform(self, series, col):
        """Decode one-hot data back to category labels.

        A 2-D row (one vector per timestep, as produced by ``transform``)
        yields a list with one label per timestep.  A 1-D row yields a single
        label.  All-zero vectors and empty rows decode to NaN.
        """
        cats = self.encoders[col].categories_[0].tolist()

        def decode_vector(vector):
            return np.nan if vector.sum() == 0 else cats[int(vector.argmax())]

        results = []
        for row in series:
            arr = np.array(row, dtype=float)
            if arr.ndim >= 2:
                # Decode per timestep; a flat argmax over the whole sequence
                # would index past the category list.
                steps = arr.reshape(-1, arr.shape[-1])
                results.append([decode_vector(step) for step in steps])
            else:
                results.append(decode_vector(arr))

        return results

    def save(self, path="temporal_encoded.pkl"):
        """Persist the encoder dictionary with joblib."""
        joblib.dump(self.encoders, path)

    def load(self, path="temporal_encoded.pkl"):
        """Replace the encoder dictionary with the one stored at ``path``."""
        self.encoders = joblib.load(path)

In [None]:
# Encode every column except the first one of temporal_df, save the encoders
# to the default path, and pickle the encoded frame.
cols = temporal_df.columns[1:]
temporal_manager = TemporalOneHotManager()
temporal_manager.fit(temporal_df,cols)
temporal_manager.save()
temporal_one_hot = temporal_manager.transform(temporal_df,cols)
temporal_one_hot.to_pickle("temporal_one_hot.pkl")

In [None]:
# Preview the first rows of the encoded temporal frame.
temporal_one_hot.head()

In [None]:
def flatten_static_data(input_path: str, output_path: str):
    """Flatten a pickled frame of per-cell vectors into one float32 matrix.

    Each column's cells are expanded into separate columns (prefixed with the
    source column name), the expansions are concatenated side by side, and
    the resulting 2-D float32 array is pickled to ``output_path``.
    """
    df_original = pd.read_pickle(input_path)

    # One expanded block per source column, in column order.
    expanded_blocks = [
        df_original[name].apply(pd.Series).add_prefix(f'{name}_')
        for name in df_original.columns
    ]
    numpy_array = pd.concat(expanded_blocks, axis=1).to_numpy(dtype=np.float32)

    with open(output_path, 'wb') as handle:
        pickle.dump(numpy_array, handle)

    print(f"Original shape: {df_original.shape}")
    print(f"New flattened NumPy array shape: {numpy_array.shape}")
    print(f"Saved flattened NumPy array to {output_path}\n")

In [None]:
# Flatten the pickled static one-hot frame into a single float32 matrix on disk.
STATIC_INPUT_PATH = "static_one_hot.pkl"
STATIC_OUTPUT_PATH = "static_one_hot_flat.pkl"
flatten_static_data(STATIC_INPUT_PATH, STATIC_OUTPUT_PATH)

In [None]:
def flatten_temporal_data(input_path: str, output_path: str) -> np.ndarray:
    """Flatten a pickled frame of per-patient sequences into per-timestep rows.

    Every cell of the input frame is expected to hold a sequence of one-hot
    vectors.  For each patient and each timestep ``t`` the ``t``-th vector of
    every column is concatenated (in column order) into one row.  The
    sequence length of a patient is taken from its first column.

    Args:
        input_path: Pickle file containing the DataFrame of sequences.
        output_path: Destination pickle for the flattened float32 array.

    Returns:
        The flattened array, shape (total timesteps, sum of vector widths).
        The same array is pickled to ``output_path``.
    """
    df_original = pd.read_pickle(input_path)

    all_timestamps_flat = []
    # Plain tuples avoid building a Series per row as ``iterrows`` does.
    for patient_row in df_original.itertuples(index=False, name=None):
        sequence_length = len(patient_row[0])
        for t in range(sequence_length):
            single_timestamp_parts = []
            for feature_sequence in patient_row:
                single_timestamp_parts.extend(feature_sequence[t])
            all_timestamps_flat.append(single_timestamp_parts)
    training_data = np.array(all_timestamps_flat, dtype=np.float32)

    with open(output_path, 'wb') as f:
        pickle.dump(training_data, f)

    print(f"Successfully created flattened training data.")
    print(f"  - Total patients processed: {len(df_original)}")
    print(f"  - Final training data shape: {training_data.shape}")
    print(f"  - Saved to {output_path}")

    # The signature promises an array; previously nothing was returned.
    return training_data

In [None]:
# Flatten the pickled temporal one-hot frame into one row per timestep on disk.
TEMPORAL_INPUT_PATH = "temporal_one_hot.pkl"
TEMPORAL_OUTPUT_PATH = "temporal_one_hot_flat.pkl"
flatten_temporal_data(TEMPORAL_INPUT_PATH, TEMPORAL_OUTPUT_PATH)

In [None]:
class OneHotDataset(Dataset):
    """Dataset over the rows of a numpy array, yielding 1-tuples of tensors."""

    def __init__(self, numpy_array):
        # torch.from_numpy shares memory with the source array (no copy).
        tensor = torch.from_numpy(numpy_array)
        self.data = tensor
        print(f"Created dataset with shape: {tensor.shape}")

    def __len__(self):
        """Number of rows along the first dimension."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Return the ``idx``-th row wrapped in a single-element tuple."""
        row = self.data[idx]
        return (row,)

In [None]:
# Build static train/validation DataLoaders from the locally written flat array
# (90/10 split, fixed random_state, batch size 64).
# NOTE(review): the next cell reassigns every name defined here using the
# Kaggle input path and different settings; only one of the two cells is needed.
STATIC_FILE_PATH = "static_one_hot_flat.pkl"
with open(STATIC_FILE_PATH, 'rb') as f:
    full_dataset = pickle.load(f)

train_data, val_data = train_test_split(
    full_dataset,
    test_size=0.1,
    random_state=123
)

print(f"\nData split into:")
print(f" - Training set shape: {train_data.shape}")
print(f" - Validation set shape: {val_data.shape}")
print("\n--- Creating DataLoaders ---")
static_train_dataset = OneHotDataset(train_data)
static_val_dataset = OneHotDataset(val_data)

static_train_dataloader = DataLoader(dataset=static_train_dataset, batch_size=64, shuffle=True)
static_val_dataloader = DataLoader(dataset=static_val_dataset, batch_size=64, shuffle=False)

In [None]:
# Build static train/validation DataLoaders from the Kaggle input copy of the
# flat array (95/5 split, fixed random_state, batch size 512, 8 workers).
# NOTE(review): this reassigns the names created by the previous cell.
STATIC_FILE_PATH = "/kaggle/input/static-data/static_one_hot_flat.pkl"
with open(STATIC_FILE_PATH, 'rb') as f:
    full_dataset = pickle.load(f)

train_data, val_data = train_test_split(
    full_dataset,
    test_size=0.05,
    random_state=123
)

print(f"\nData split into:")
print(f" - Training set shape: {train_data.shape}")
print(f" - Validation set shape: {val_data.shape}")
print("\n--- Creating DataLoaders ---")
static_train_dataset = OneHotDataset(train_data)
static_val_dataset = OneHotDataset(val_data)

# Only the training loader shuffles; validation order stays fixed.
static_train_dataloader = DataLoader(
    static_train_dataset,
    batch_size=512,
    shuffle=True,
    num_workers=8,     
    pin_memory=True    
)

static_val_dataloader = DataLoader(
    static_val_dataset,
    batch_size=512,
    shuffle=False,
    num_workers=8,
    pin_memory=True
)

In [None]:
# Sanity check: pull one batch; each batch is a 1-tuple, so element 0 is the tensor.
static_batch = next(iter(static_train_dataloader))
print(f"Shape of one static batch: {static_batch[0].shape}")

In [None]:
# Report whether CUDA is usable. The GPU name is only queried when CUDA is
# present, because torch.cuda.get_device_name raises on CPU-only machines.
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

In [None]:
# Build temporal train/validation DataLoaders from the Kaggle input copy of the
# flat per-timestep array (95/5 split, fixed random_state, batch size 1024).
# NOTE(review): full_dataset / train_data / val_data are reused from the static
# cells above and now refer to the temporal data.
TEMPORAL_FILE_PATH = "/kaggle/input/temporal-one-hot/temporal_one_hot_flat.pkl"
with open(TEMPORAL_FILE_PATH, 'rb') as f:
    full_dataset = pickle.load(f)

train_data, val_data = train_test_split(
    full_dataset,
    test_size=0.05,
    random_state=123
)

print(f"\nData split into:")
print(f" - Training set shape: {train_data.shape}")
print(f" - Validation set shape: {val_data.shape}")
print("\n--- Creating DataLoaders ---")
temporal_train_dataset = OneHotDataset(train_data)
temporal_val_dataset = OneHotDataset(val_data)

# Only the training loader shuffles; validation order stays fixed.
temporal_train_dataloader = DataLoader(
    temporal_train_dataset,
    batch_size=1024,
    shuffle=True,
    num_workers=8,     
    pin_memory=True    
)

temporal_val_dataloader = DataLoader(
    temporal_val_dataset,
    batch_size=1024,
    shuffle=False,
    num_workers=8,
    pin_memory=True
)

In [None]:
# Sanity check: pull one batch; each batch is a 1-tuple, so element 0 is the tensor.
temporal_batch = next(iter(temporal_train_dataloader))
print(f"Shape of one temporal batch: {temporal_batch[0].shape}")

In [None]:
class CategoricalAutoEncoder(nn.Module):
    """Autoencoder over a concatenation of one-hot encoded categorical features.

    The input vector is the concatenation of one one-hot segment per feature
    (segment widths given by ``input_dims``).  A shared MLP encoder maps it to
    a Tanh-bounded latent vector; one decoder head per feature outputs that
    feature's class logits.  A segment that is entirely zero is treated as
    "feature missing" by ``reconstruct`` and by the optional masks.

    Args:
        input_dims: Number of categories of each feature, in input order.
        lr: Learning rate of the optimizer.
        optimizer_type: "adam" or "rmsprop" (case-insensitive).
        use_scheduler: If True, attach ``ReduceLROnPlateau`` (factor 0.5,
            patience 3), stepped on the validation loss in ``fit``.
        filename: Path used by ``fit``, ``save_model`` and ``load_model``.
    """

    def __init__(self, input_dims, lr=1e-3, optimizer_type="adam", use_scheduler=False, filename=None):
        super().__init__()

        self.input_dims = input_dims
        self.filename = filename
        self.total_input_dim = sum(self.input_dims)
        self.num_features = len(self.input_dims)
        hidden_dim = max(256, self.total_input_dim)
        latent_dim = max(32, min(self.total_input_dim // 8, self.total_input_dim - 1))

        # ------ Encoder -----------
        self.encoder = nn.Sequential(
            nn.Linear(self.total_input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.2),

            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.LeakyReLU(0.2),

            nn.Linear(hidden_dim // 2, latent_dim),
            nn.Tanh()
        )

        # -------- Decoders: one head per feature ----------
        self.decoders = nn.ModuleList([
            nn.Sequential(
                nn.Linear(latent_dim, hidden_dim // 2),
                nn.BatchNorm1d(hidden_dim // 2),
                nn.LeakyReLU(0.2),

                nn.Linear(hidden_dim // 2, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),

                nn.Linear(hidden_dim, k)
            )
            for k in self.input_dims
        ])

        #--------- Loss & Optimizer ---------
        self.criterion = nn.CrossEntropyLoss()
        # Compare case-insensitively for both optimizers: the error message
        # advertises "Adam", which the exact-match test used to reject.
        optimizer_name = str(optimizer_type).lower()
        if optimizer_name == "adam":
            self.optimizer = optim.Adam(self.parameters(), lr=lr)
        elif optimizer_name == "rmsprop":
            self.optimizer = optim.RMSprop(self.parameters(), lr=lr)
        else:
            raise ValueError("optimizer_type must be Adam or RMSprop")

        self.scheduler = (optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min',
                                                              factor=0.5, patience=3)
                         if use_scheduler else None)

    # ---------- Forward / Encode / Decode ----------
    def forward(self, x):
        """Return a list with one logits tensor (batch, K_i) per feature."""
        z = self.encoder(x)
        return [head(z) for head in self.decoders]

    def encode(self, x):
        """Return the latent codes for ``x``.

        Switches the module to eval mode and runs without gradients.
        """
        self.eval()
        with torch.no_grad():
            z = self.encoder(x)
        return z

    def decode(self, z: torch.Tensor, mask=None, return_onehot=True):
        """Decode latent codes to per-feature class predictions.

        Switches the module to eval mode and runs without gradients.

        Args:
            z: Latent codes, one row per sample.
            mask: Optional (batch, num_features) tensor; the one-hot segment
                of feature ``i`` is multiplied by ``mask[:, i]``.
            return_onehot: If False, only the list of index tensors is returned.

        Returns:
            ``preds`` (list of index tensors), or ``(preds, onehot)`` where
            ``onehot`` is the concatenation of the one-hot segments.
        """
        self.eval()
        with torch.no_grad():
            logits_list = [head(z) for head in self.decoders]
            preds = [logits.argmax(dim=-1) for logits in logits_list]

            if not return_onehot:
                return preds

            bsz = z.shape[0]
            parts = []

            for i, (idxs, K) in enumerate(zip(preds, self.input_dims)):
                onehot = torch.zeros(bsz, K, device=z.device)
                onehot.scatter_(1, idxs.view(-1, 1), 1.0)

                if mask is not None:
                    # mask[:, i] has shape (batch_size,); broadcast over K.
                    onehot = onehot * mask[:, i].unsqueeze(1)

                parts.append(onehot)

            return preds, torch.cat(parts, dim=1)

    def predict_prob(self, x):
        """Return per-feature softmax probabilities (eval mode, no gradients)."""
        self.eval()
        with torch.no_grad():
            logits_list = self.forward(x)
            probs_list = [F.softmax(logits, dim=-1) for logits in logits_list]
        return probs_list

    # ---------- Reconstruct with mask for missingness ----------
    def reconstruct(self, x, return_onehot=False):
        """Predict every feature of ``x``, keeping missing features missing.

        Args:
            x: Concatenated one-hot input, shape (batch, sum(input_dims)).
            return_onehot: If True, also return the reconstruction as a
                concatenated one-hot tensor in which a feature whose input
                segment was all zero stays all zero.

        Returns:
            ``preds`` (list of index tensors), or ``(preds, onehot)``.
        """
        probs_list = self.predict_prob(x)
        preds = []
        bsz = x.shape[0]
        parts = []

        start = 0
        for i, K in enumerate(self.input_dims):
            end = start + K
            segment = x[:, start:end]

            # Rows whose input segment is not all zero, i.e. feature present.
            present = segment.sum(dim=1) != 0
            segment_pred = probs_list[i].argmax(dim=-1)
            preds.append(segment_pred)

            if return_onehot:
                onehot = torch.zeros(bsz, K, device=x.device)
                rows = present.nonzero(as_tuple=True)[0]
                # Assign through the index pair. `onehot[rows].scatter_(...)`
                # writes into a temporary copy and leaves `onehot` all zero.
                onehot[rows, segment_pred[rows]] = 1.0
                parts.append(onehot)

            start = end

        if return_onehot:
            return preds, torch.cat(parts, dim=1)
        return preds

    # ---------- Helper functions ----------
    def _require_filename(self):
        """Raise a clear error when no checkpoint path was configured."""
        if self.filename is None:
            raise ValueError("filename is None; pass filename=... when constructing the model")

    def _targets_from_onehot(self, x):
        """Convert concatenated one-hot input to class indices (batch, num_features).

        NOTE(review): an all-zero (missing) segment yields index 0, and ``fit``
        computes the training loss without a mask, so missing features are
        trained towards class 0.
        """
        parts = torch.split(x, self.input_dims, dim=1)
        return torch.stack([p.argmax(dim=1) for p in parts], dim=1).long()

    def _compute_loss(self, logits_list, targets, mask=None):
        """Sum the cross-entropy of every feature head.

        With ``mask`` given, only rows where ``mask[:, i] == 1`` contribute to
        feature ``i``; a feature with no such rows is skipped.
        """
        # Start from a tensor so the result supports `.item()` even when every
        # feature is skipped.
        loss = logits_list[0].new_zeros(()) if len(logits_list) > 0 else 0.0
        for i, logits in enumerate(logits_list):
            if mask is not None:
                present_idx = (mask[:, i] == 1).nonzero(as_tuple=True)[0]
                if len(present_idx) == 0:
                    continue  # feature missing in the whole batch
                loss = loss + self.criterion(logits[present_idx], targets[present_idx, i])
            else:
                loss = loss + self.criterion(logits, targets[:, i])
        return loss

    # ---------- Training ----------
    def fit(self, dataloader, epochs=10, val_dataloader=None, wrapper_model=None, mask_val=None):
        """Train the autoencoder and save the weights to ``self.filename``.

        Args:
            dataloader: Yields 1-tuples ``(x,)`` of concatenated one-hot batches.
            epochs: Number of passes over ``dataloader``.
            val_dataloader: Optional loader of the same form.  When given, the
                weights of the epoch with the lowest validation loss are saved;
                otherwise the final weights are saved.
            wrapper_model: Optional wrapper (e.g. ``nn.DataParallel``) around
                this module, used for the forward pass.
            mask_val: Optional presence mask for the whole validation set,
                aligned row by row with the order of ``val_dataloader``.

        Raises:
            ValueError: If no ``filename`` was configured.
        """
        # Fail before training rather than after it, when saving.
        self._require_filename()
        forward_model = wrapper_model if wrapper_model is not None else self
        device = next(self.parameters()).device
        best_val_loss = float('inf')
        best_weights = None

        for epoch in range(1, epochs + 1):
            self.train()
            epoch_losses = []

            for (x,) in tqdm(dataloader, desc=f"Epoch {epoch}/{epochs}", leave=False):
                x = x.to(device, non_blocking=True)
                y = self._targets_from_onehot(x)
                self.optimizer.zero_grad()
                logits_list = forward_model(x)
                loss = self._compute_loss(logits_list, y)
                if wrapper_model is not None:
                    loss = loss.mean()
                loss.backward()
                self.optimizer.step()
                epoch_losses.append(loss.item())

            train_loss = sum(epoch_losses) / max(1, len(epoch_losses))

            val_loss = None
            if val_dataloader is not None:
                self.eval()
                val_losses = []
                offset = 0  # row offset of the current batch inside mask_val
                with torch.no_grad():
                    for (vx,) in val_dataloader:
                        vx = vx.to(device, non_blocking=True)
                        vy = self._targets_from_onehot(vx)
                        vlogits_list = forward_model(vx)
                        batch_rows = vx.size(0)
                        if mask_val is not None:
                            # Use a running offset: `i * batch_size` is wrong
                            # for a final batch that is shorter than the rest.
                            batch_mask = mask_val[offset:offset + batch_rows].to(device)
                            vloss = self._compute_loss(vlogits_list, vy, mask=batch_mask)
                        else:
                            vloss = self._compute_loss(vlogits_list, vy)
                        offset += batch_rows
                        val_losses.append(vloss.item())
                val_loss = sum(val_losses) / max(1, len(val_losses))

                if self.scheduler is not None:
                    self.scheduler.step(val_loss)

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    # state_dict() returns references to the live tensors, so
                    # snapshot copies; otherwise later epochs overwrite them.
                    best_weights = {
                        name: tensor.detach().clone()
                        for name, tensor in self.state_dict().items()
                    }

            if val_loss is not None:
                print(f"Epoch {epoch:03d} | Train {train_loss:.7f} | Val {val_loss:.7f}")
            else:
                print(f"Epoch {epoch:03d} | Train {train_loss:.7f}")

        # Save best weights (or the final ones when no validation was run)
        if best_weights is not None:
            torch.save(best_weights, self.filename)
            print(f"Best model saved with val_loss={best_val_loss:.7f} as {self.filename}")
        else:
            self.save_model()

    # ---------- Save / Load ----------
    def save_model(self):
        """Save the current state dict to ``self.filename``."""
        self._require_filename()
        torch.save(self.state_dict(), self.filename)
        print(f"Model saved as {self.filename}")

    def load_model(self, map_location=None):
        """Load a state dict from ``self.filename`` and switch to eval mode."""
        self._require_filename()
        state = torch.load(self.filename, map_location=map_location)
        self.load_state_dict(state)
        self.eval()
        print(f"Model loaded from {self.filename}")


Training Static One Hot into Latent Embeddings

In [None]:
# Reload the encoded static frame from the Kaggle input directory and preview it.
static_one_hot = pd.read_pickle('/kaggle/input/static-data/static_one_hot.pkl')
static_one_hot.head()

In [None]:
# Width of each column's vector, measured on the first row of the frame.
input_dims = [len(static_one_hot.iloc[0][col]) for col in static_one_hot.columns]
input_dims

In [None]:
# Train the static autoencoder through a DataParallel wrapper.
# `device` is defined here so the cell runs on a fresh kernel; the notebook
# otherwise only defines it in a later cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_instance = CategoricalAutoEncoder(input_dims,use_scheduler=True,filename="static_categorical_encoder_decoder.pt")
model_wrapper = nn.DataParallel(model_instance)
model_wrapper = model_wrapper.to(device)
# fit() is called on the wrapped module; the wrapper is only used for the forward pass.
model_wrapper.module.fit(static_train_dataloader, epochs=30, val_dataloader=static_val_dataloader,wrapper_model=model_wrapper)

In [None]:
# Evaluate the saved static autoencoder on one validation batch.
model_instance = CategoricalAutoEncoder(input_dims, filename="static_categorical_encoder_decoder.pt")
model_instance = model_instance.to(device)
model_instance.load_model(map_location=device)
x_sample = next(iter(static_val_dataloader))[0].to(device)

# Presence mask: feature i is present in a row when its one-hot segment is non-zero.
mask_list = []
start = 0
for dim in input_dims:
    end = start + dim
    feature_segment = x_sample[:, start:end]
    feature_mask = (feature_segment.sum(dim=1) > 0).float()
    mask_list.append(feature_mask)
    start = end

mask = torch.stack(mask_list, dim=1).to(device)

z = model_instance.encode(x_sample)
preds, recon = model_instance.decode(z, mask=mask, return_onehot=True)

# Show one example row. The index is clamped so batches shorter than 65 rows
# do not raise, and the label names the row actually printed.
row_idx = min(64, x_sample.size(0) - 1)
print(f"Original (row {row_idx}):\n", x_sample[row_idx])
print(f"Reconstructed (row {row_idx}):\n", recon[row_idx])
num_features = len(input_dims)
start = 0
accuracies = []

for i, dim in enumerate(input_dims):
    end = start + dim
    original_segment = x_sample[:, start:end]
    recon_segment = recon[:, start:end]
    present_mask = mask[:, i].unsqueeze(1)
    present_rows = mask[:, i] > 0
    if present_rows.any():
        # Score only rows where the feature is present. Masked rows compare
        # 0 == 0 and would otherwise count as perfect matches.
        acc = (original_segment[present_rows] == recon_segment[present_rows]).float().mean().item()
    else:
        acc = float('nan')

    accuracies.append(acc)
    start = end
for i, acc in enumerate(accuracies):
    if not acc != acc:  # check for nan
        print(f"Feature {i} (length {input_dims[i]}): Bitwise accuracy {acc*100:.2f}%")
    else:
        print(f"Feature {i} (length {input_dims[i]}): No present values in batch")

In [None]:
# Load the saved static checkpoint into a fresh model (left on the default
# device) and encode one training batch.
new_model = CategoricalAutoEncoder(input_dims,filename="static_categorical_encoder_decoder.pt")
new_model.load_model()

sample_batch, = next(iter(static_train_dataloader))
encodings = new_model.encode(sample_batch)
print(encodings.shape)  # (batch_size, latent_dim)
encodings

In [None]:
# NOTE(review): this cell loads the *temporal* checkpoint, but `input_dims` and
# the batch still come from the static data at this point, and the temporal
# model is only trained in a later cell. Presumably a leftover copy of the
# previous cell -- confirm and move it after temporal training or remove it.
new_model = CategoricalAutoEncoder(input_dims,filename="temporal_categorical_encoder_decoder.pt")
new_model.load_model()

sample_batch, = next(iter(static_train_dataloader))
encodings = new_model.encode(sample_batch)
print(encodings.shape)  # (batch_size, latent_dim)
encodings

Training Temporal One Hot into Latent Embeddings



In [None]:
# Reload the encoded temporal frame from the working directory.
temporal_one_hot = pd.read_pickle('temporal_one_hot.pkl')

In [None]:
# Width of each column's one-hot vector, measured on the first timestep of the
# first row (raises if that row's sequence is empty).
input_dims = [len(temporal_one_hot.iloc[0][col][0]) for col in temporal_one_hot.columns]
input_dims

In [None]:
# NOTE(review): hard-coded widths that overwrite the value computed in the
# previous cell -- confirm they match the data the saved checkpoint was trained on.
input_dims = [76, 68, 126, 28, 99, 80]

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Train the temporal autoencoder through a DataParallel wrapper.
# fit() is called on the wrapped module; the wrapper is only used for the forward pass.
model_instance = CategoricalAutoEncoder(input_dims,use_scheduler=True,filename="temporal_categorical_encoder_decoder.pt")
model_wrapper = nn.DataParallel(model_instance)
model_wrapper = model_wrapper.to(device)
model_wrapper.module.fit(temporal_train_dataloader, epochs=50, val_dataloader=temporal_val_dataloader,wrapper_model=model_wrapper)

In [None]:
# Reconstruct one validation batch with the saved temporal autoencoder.
model_instance = CategoricalAutoEncoder(input_dims, filename="temporal_categorical_encoder_decoder.pt")
model_instance = model_instance.to(device)
model_instance.load_model(map_location=device)
x_sample = next(iter(temporal_val_dataloader))[0].to(device)
preds, recon = model_instance.reconstruct(x_sample,return_onehot=True)
# Only the first row is shown, so the labels say so (they used to claim five rows).
print("Original (first row):\n", x_sample[:1])
print("Reconstructed (first row):\n", recon[:1])

In [None]:
# Evaluate the saved temporal autoencoder on one training batch.
model_instance = CategoricalAutoEncoder(input_dims, filename="temporal_categorical_encoder_decoder.pt")
model_instance = model_instance.to(device)
model_instance.load_model(map_location=device)
x_sample = next(iter(temporal_train_dataloader))[0].to(device)

# Presence mask: feature i is present in a row when its one-hot segment is non-zero.
mask_list = []
start = 0
for dim in input_dims:
    end = start + dim
    feature_segment = x_sample[:, start:end]
    feature_mask = (feature_segment.sum(dim=1) > 0).float()
    mask_list.append(feature_mask)
    start = end

mask = torch.stack(mask_list, dim=1).to(device)

z = model_instance.encode(x_sample)
preds, recon = model_instance.decode(z, mask=mask, return_onehot=True)

# Show one example row. The index is clamped so batches shorter than 31 rows
# do not raise, and the label names the row actually printed.
row_idx = min(30, x_sample.size(0) - 1)
print(f"Original (row {row_idx}):\n", x_sample[row_idx])
print(f"Reconstructed (row {row_idx}):\n", recon[row_idx])
num_features = len(input_dims)
start = 0
accuracies = []

for i, dim in enumerate(input_dims):
    end = start + dim
    original_segment = x_sample[:, start:end]
    recon_segment = recon[:, start:end]
    present_mask = mask[:, i].unsqueeze(1)
    present_rows = mask[:, i] > 0
    if present_rows.any():
        # Score only rows where the feature is present. Masked rows compare
        # 0 == 0 and would otherwise count as perfect matches.
        acc = (original_segment[present_rows] == recon_segment[present_rows]).float().mean().item()
    else:
        acc = float('nan')

    accuracies.append(acc)
    start = end
for i, acc in enumerate(accuracies):
    if not acc != acc:  # check for nan
        print(f"Feature {i} (length {input_dims[i]}): Bitwise accuracy {acc*100:.2f}%")
    else:
        print(f"Feature {i} (length {input_dims[i]}): No present values in batch")

In [None]:
class StochasticNormalizer:
    """Maps discrete values to random points inside frequency-sized intervals.

    ``stochastic_normalize`` gives every distinct value a sub-interval of
    [0, 1] whose width equals the value's relative frequency (intervals are
    laid out in ascending value order) and replaces each occurrence with a
    uniform random draw from its interval.  ``stochastic_renormalize`` maps
    such draws back to the original values.  The interval table of each
    feature is kept in ``self.params[key]`` as ``{value: [low, high]}``.
    """

    def __init__(self):
        # key -> {value: [lower_bound, upper_bound]}
        self.params = {}

    def stochastic_normalize(self, X: torch.Tensor, key=None):
        """Fit the interval table on ``X`` and return the randomised encoding.

        NaN entries are ignored when building the table and stay NaN in the
        output.

        Args:
            X: Tensor of any shape.
            key: If given, the fitted table is stored under ``self.params[key]``.

        Returns:
            Float32 tensor with the shape of ``X``.
        """
        X = X.float()
        # Start from NaN so positions without a valid value are never left as
        # uninitialised memory.
        X_hat = torch.full_like(X, float("nan"), dtype=torch.float32)
        valid = ~torch.isnan(X)
        values = X[valid]
        n_valid = values.numel()
        params = {}

        if n_valid > 0:
            unique_vals, inverse, counts = torch.unique(
                values, return_inverse=True, return_counts=True
            )
            # Interval bounds come from exact cumulative counts.
            upper = torch.cumsum(counts.cpu().double(), dim=0) / n_valid
            lower = torch.cat([upper.new_zeros(1), upper[:-1]])
            width = upper - lower

            lower_dev = lower.to(device=X.device, dtype=torch.float32)
            width_dev = width.to(device=X.device, dtype=torch.float32)
            noise = torch.rand(n_valid, device=X.device)
            # One vectorised gather replaces a full-tensor mask per unique value.
            X_hat[valid] = noise * width_dev[inverse] + lower_dev[inverse]

            for val, low, high in zip(unique_vals.tolist(), lower.tolist(), upper.tolist()):
                params[val] = [low, high]

        if key is not None:
            self.params[key] = params
        return X_hat

    def stochastic_renormalize(self, X_hat: torch.Tensor, key=None):
        """Map normalised draws back to the values whose interval contains them.

        Entries that fall in no interval become 0.0; NaN entries stay NaN.

        Raises:
            ValueError: If ``key`` is None.
            KeyError: If no table is stored for ``key``.
        """
        if key is None:
            raise ValueError("key is required to select the stored interval table")
        params = self.params[key]

        X_hat = X_hat.float()
        X = torch.zeros_like(X_hat, dtype=torch.float32, device=X_hat.device)

        for val, (low, high) in params.items():
            mask = (X_hat >= low) & (X_hat < high)
            X[mask] = val

        # Intervals are half-open, so X_hat == 1.0 matches none of them;
        # assign it to the value whose interval ends at 1.
        mask = X_hat == 1.0
        for val, (low, high) in params.items():
            if abs(high - 1.0) < 1e-8:
                X[mask] = val

        # Keep missing entries missing so a normalise/renormalise round trip
        # does not turn NaN into 0.0.
        X[torch.isnan(X_hat)] = float("nan")
        return X

    def normalize_sample(self, X: torch.Tensor, key):
        """Encode ``X`` with the table already stored for ``key``.

        Values that are not in the table are encoded as 0.0.

        Raises:
            ValueError: If no table is stored for ``key``.
        """
        if key not in self.params:
            raise ValueError(f"No parameters found for key '{key}'. Load or train first.")

        params = self.params[key]
        X_hat = torch.zeros_like(X, dtype=torch.float32, device=X.device)
        for val, (low, high) in params.items():
            mask = X == val
            if mask.any():
                X_hat[mask] = torch.rand(mask.sum(), device=X.device) * (high - low) + low
        return X_hat

    def save_params(self, filepath):
        """Save all interval tables with ``torch.save``."""
        torch.save(self.params, filepath)

    def load_params(self, filepath):
        """Replace all interval tables with those stored at ``filepath``."""
        self.params = torch.load(filepath)


In [None]:
# Read the static numerical table from the Kaggle input directory.
static_numerical = pd.read_csv('/kaggle/input/numerical-data/static_numerical.csv')

In [None]:
# Stochastically normalise the AGE column, save the fitted interval table, and
# store the normalised values in a new "AGE_normalized" column.
X_age = torch.tensor(static_numerical["AGE"].values, dtype=torch.float32)
normalizer  = StochasticNormalizer()
X_hat_age = normalizer.stochastic_normalize(X_age,key = "AGE")
normalizer.save_params("static_params.pt")
static_numerical["AGE_normalized"] = X_hat_age.numpy()

In [None]:
# Read the temporal numerical table and parse every string cell as JSON,
# after rewriting "nan" to "null" so missing entries load as None.
# Non-string cells are left unchanged.
temporal_numerical = pd.read_csv('/kaggle/input/numerical-data/temporal_numerical.csv')
for col in temporal_numerical.columns:
    temporal_numerical[col] = temporal_numerical[col].apply(lambda x: json.loads(x.replace("nan", "null")) if isinstance(x, str) else x)

In [None]:
# Stochastically normalise every column of temporal_numerical in place.
# Per column: pool all non-missing values of all rows, fit/apply the
# normaliser once, then write the normalised values back into per-row lists,
# restoring NaN at the positions that were missing.
# NOTE(review): the columns are overwritten, so re-running this cell
# normalises already-normalised data.
normalizer = StochasticNormalizer()
for col_idx, col in enumerate(temporal_numerical.columns):
    all_values = []
    lengths_per_cell = []  # collected but not read anywhere in this cell

    # Pass 1: gather the valid values of every row into one flat list.
    for row in temporal_numerical[col]:
        if isinstance(row, list):
            clean_vals = [v for v in row if v is not None and not pd.isna(v)]
            
            # The first column is converted to successive differences
            # (first value kept as-is) before normalisation.
            if col_idx == 0 and len(clean_vals) > 1:  
                deltas = [clean_vals[0]] + [clean_vals[i] - clean_vals[i-1] for i in range(1, len(clean_vals))]
                all_values.extend(deltas)
            else:
                all_values.extend(clean_vals)
            
            lengths_per_cell.append(len(row))
        else:
            lengths_per_cell.append(0)

    X = torch.tensor(all_values, dtype=torch.float32)
    X_hat = normalizer.stochastic_normalize(X, key=col)

    # Pass 2: slice X_hat back into rows; `counter` walks the flat tensor in
    # the same order in which pass 1 filled it.
    normalized_col = []
    counter = 0
    for row in temporal_numerical[col]:
        if isinstance(row, list) and len(row) > 0:
            n_valid = len([v for v in row if v is not None and not pd.isna(v)])
            norm_cell = X_hat[counter:counter+n_valid].tolist()
            counter += n_valid

            # Re-insert NaN where the original entry was missing.
            norm_cell_with_nan = []
            idx_norm = 0
            for v in row:
                if v is None or pd.isna(v):
                    norm_cell_with_nan.append(float('nan'))
                else:
                    norm_cell_with_nan.append(norm_cell[idx_norm])
                    idx_norm += 1
            normalized_col.append(norm_cell_with_nan)
        else:
            # Non-list cells and empty lists both become an empty list.
            normalized_col.append([])

    temporal_numerical[col] = normalized_col

normalizer.save_params("temporal_params.pt")

# Prepare Numerical Data

In [None]:
def process_static_num(patient_row, feature_cols):
    """Build the static numerical vector of one patient and its presence mask.

    Args:
        patient_row: Mapping-like row indexed by column name.
        feature_cols: Names of the columns to read, in output order.

    Returns:
        ``(sn, sn_mask)``: two float32 tensors of length ``len(feature_cols)``.
        Missing values are stored as 0.0 in ``sn`` and flagged 0.0 in
        ``sn_mask``; present values are flagged 1.0.
    """
    values, masks = [], []

    for col in feature_cols:
        raw = patient_row[col]
        missing = pd.isna(raw)
        values.append(0.0 if missing else raw)
        masks.append(0.0 if missing else 1.0)

    sn = torch.tensor(values, dtype=torch.float32)
    sn_mask = torch.tensor(masks, dtype=torch.float32)
    return sn, sn_mask

In [None]:
def process_temporal_num(patient_row, feature_cols, date_col="DATE"):
    """Build the temporal numerical tensors of one patient.

    Each feature cell is a sequence (or its string literal) with one value
    per timestep; NaN marks a missing measurement.

    Returns:
        ``(tn, tn_mask, un, seq_len)`` where ``tn`` holds the values with NaN
        replaced by 0, ``tn_mask`` is 1.0 where a value was present (both
        [seq_len, num_features]), ``un`` holds the ``date_col`` values
        ([seq_len]) and ``seq_len`` is ``tn.shape[0]``.
    """
    def as_literal(cell):
        # Cells read back from text arrive as string literals.
        return ast.literal_eval(cell) if isinstance(cell, str) else cell

    value_columns = []
    mask_columns = []
    for col in feature_cols:
        column = np.array(as_literal(patient_row[col]), dtype=np.float32)
        present = ~np.isnan(column)
        column[~present] = 0

        value_columns.append(column)
        mask_columns.append(present.astype(np.float32))

    # Stack the features along the last axis -> [seq_len, num_features].
    stacked_values = np.stack(value_columns, axis=-1)
    stacked_masks = np.stack(mask_columns, axis=-1)

    time_values = np.array(as_literal(patient_row[date_col]), dtype=np.float32)

    tn = torch.tensor(stacked_values, dtype=torch.float32)
    tn_mask = torch.tensor(stacked_masks, dtype=torch.float32)
    un = torch.tensor(time_values, dtype=torch.float32)

    return tn, tn_mask, un, tn.shape[0]

In [None]:
# Build one value vector and one mask vector per patient from the normalised
# AGE column, then stack them into [num_patients, 1] tensors.
static_num_tensor = []
static_num_mask_tensor = []

for idx, row in static_numerical.iterrows():
    sn, sn_mask = process_static_num(row, feature_cols=['AGE_normalized'])
    static_num_tensor.append(sn)
    static_num_mask_tensor.append(sn_mask)

# The lists are replaced by the stacked tensors under the same names.
static_num_tensor = torch.stack(static_num_tensor, dim=0)       
static_num_mask_tensor = torch.stack(static_num_mask_tensor, dim=0)

In [None]:
# Build the per-patient temporal numerical tensors. Every column after the
# first is treated as a feature; "DATE" is read as the time column.
# The sequences are kept in lists rather than stacked.
tn_list = []
tn_mask_list = []
un_list = []
seq_len_num = []
feature_cols = temporal_numerical.columns[1:]
for idx, row in temporal_numerical.iterrows():
    tn, tn_mask, un, seq_len = process_temporal_num(row, feature_cols, date_col="DATE")
    tn_list.append(tn)
    tn_mask_list.append(tn_mask)
    un_list.append(un)
    seq_len_num.append(seq_len)

# Prepare Categorical Data

In [None]:
# Load the flattened static one-hot data from the Kaggle input directory.
static_categorical = pd.read_pickle('/kaggle/input/static-data/static_one_hot_flat.pkl')

In [None]:
# Load the flattened temporal one-hot data from the Kaggle input directory.
temporal_categorical = pd.read_pickle('/kaggle/input/temporal-one-hot/temporal_one_hot_flat.pkl')

In [None]:
# The first column of the temporal categorical CSV is used as the per-patient time data.
df = pd.read_csv('/kaggle/input/temporal-csv/temporal_categorical.csv')
temporal_times = df[df.columns[0]]

In [None]:
# Convert each patient's times to successive differences, normalise all
# differences jointly, and split the result back into per-patient lists.
parsed_rows = [ast.literal_eval(row) if isinstance(row, str) else row for row in temporal_times]

delta_rows = []
for row in parsed_rows:
    row = torch.tensor(row, dtype=torch.float32)
    delta = torch.empty_like(row)
    # The first entry keeps the raw value; later entries are differences to
    # the previous one. (Indexing [0] raises for an empty row.)
    delta[0] = row[0] 
    if len(row) > 1:
        delta[1:] = row[1:] - row[:-1]  
    delta_rows.append(delta)

all_deltas = torch.cat(delta_rows)

# NOTE(review): `normalizer` is rebound here; the earlier instances are replaced.
normalizer = StochasticNormalizer()
X_hat = normalizer.stochastic_normalize(all_deltas, key="DATE")

# Slice the flat result back into rows, in concatenation order.
normalized_times = []
counter = 0
for delta in delta_rows:
    length = len(delta)
    norm_row = X_hat[counter:counter+length].tolist()
    normalized_times.append(norm_row)
    counter += length

normalizer.save_params("categorical_times_params.pt")

In [None]:
# One float32 tensor of normalised time deltas per patient.
uc_list = []
for row in normalized_times:
    # Rows that arrive as string literals are parsed first.
    if isinstance(row, str):
        row = ast.literal_eval(row)
    tensor_row = torch.tensor(row, dtype=torch.float32)
    uc_list.append(tensor_row)

In [None]:
def process_static_cat(one_hot_row, model, input_dims, device="cpu"):
    """Encode one flattened static one-hot row and derive its per-field mask.

    Args:
        one_hot_row: 1-D concatenation of the one-hot segments of all fields.
            May be a NumPy array, a tensor, or any sequence accepted by
            ``torch.as_tensor`` (e.g. a list of numbers).
        model: object exposing ``encode(batch)``; it is called with the row
            expanded to a batch of one.
        input_dims: width of each field's one-hot segment, in row order.
        device: device on which the row and the mask are placed.

    Returns:
        (sc, sc_mask): the model encoding with the batch dimension removed,
        and a float32 vector with one entry per field — 0.0 when the field's
        segment is all zeros, 1.0 otherwise.
    """
    # as_tensor handles ndarray / list / tensor uniformly and guarantees the
    # float32 dtype regardless of the input's original dtype.
    one_hot_row = torch.as_tensor(one_hot_row, dtype=torch.float32).to(device)

    masks = []
    start = 0
    for dim in input_dims:
        segment = one_hot_row[start:start + dim]
        # An all-zero segment means no category is set for this field.
        masks.append(0.0 if torch.all(segment == 0) else 1.0)
        start += dim
    sc_mask = torch.tensor(masks, dtype=torch.float32, device=device)

    with torch.no_grad():
        sc = model.encode(one_hot_row.unsqueeze(0))
    sc = sc.squeeze(0)

    return sc, sc_mask

In [None]:
# Number of timestamps per row; stringified lists are parsed before counting.
seq_lens = [
    len(ast.literal_eval(timestamps) if isinstance(timestamps, str) else timestamps)
    for timestamps in temporal_times
]

In [None]:
def split_temporal_one_hot(flat_array, seq_lens):
    """Cut ``flat_array`` into consecutive slices of the given lengths.

    Args:
        flat_array: any sliceable container holding all sequences back to back.
        seq_lens: iterable with the length of each sequence, in order.

    Returns:
        A list with one slice of ``flat_array`` per entry of ``seq_lens``.
    """
    pieces = []
    offset = 0
    for length in seq_lens:
        stop = offset + length
        pieces.append(flat_array[offset:stop])
        offset = stop
    return pieces

In [None]:
def process_temporal_cat(flat_seq_list, input_dims, model, device="cuda"):
    """Encode every per-patient one-hot sequence and derive per-step field masks.

    Args:
        flat_seq_list: iterable of 2-D sequences (steps x concatenated one-hot
            width), one per patient.
        input_dims: width of each field's one-hot segment, in column order.
        model: module exposing ``encode``; it is moved to ``device`` and put
            in eval mode.
        device: target device. A CUDA device is replaced by the CPU when CUDA
            is not available, so the default works on CPU-only hosts.

    Returns:
        (tc_list, tc_mask_list, seq_len_list): the encoded sequence, the
        (steps x number of fields) float mask and the step count per patient.
        A mask entry is 1.0 when the field's segment has any non-zero value
        at that step, else 0.0.
    """
    device = torch.device(device)
    if device.type == "cuda" and not torch.cuda.is_available():
        # Fall back instead of crashing in model.to("cuda") on CPU-only hosts.
        device = torch.device("cpu")

    tc_list = []
    tc_mask_list = []
    seq_len_list = []

    model.to(device)
    model.eval()

    for seq in tqdm(flat_seq_list):
        seq_tensor = torch.as_tensor(seq, dtype=torch.float32, device=device)
        seq_len_list.append(seq_tensor.shape[0])

        masks = []
        start = 0
        for dim in input_dims:
            segment = seq_tensor[:, start:start + dim]
            # Field is "present" at a step when its segment is not all zeros.
            mask_segment = (segment.abs().sum(dim=1) != 0).float().unsqueeze(1)
            masks.append(mask_segment)
            start += dim
        masks = torch.cat(masks, dim=1)
        tc_mask_list.append(masks)

        with torch.no_grad():
            latent_seq = model.encode(seq_tensor)
        tc_list.append(latent_seq)

    return tc_list, tc_mask_list, seq_len_list

In [None]:
# Non-flattened static one-hot table; used below only to read the one-hot
# width of each column from its first row.
# NOTE(review): unpickling executes arbitrary code — only load files you trust.
static_one_hot = pd.read_pickle('/kaggle/input/static-data/static_one_hot.pkl')

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Encode every static one-hot row with the pre-trained static autoencoder.
static_cat_list, static_cat_mask_list = [], []

# One-hot width of each static column, read from the first row of the table.
first_static_row = static_one_hot.iloc[0]
static_input_dims = [len(first_static_row[col]) for col in static_one_hot.columns]

static_cat_model = CategoricalAutoEncoder(static_input_dims, filename='/kaggle/input/weights/static_categorical_encoder_decoder.pt')
static_cat_model.load_model()
static_cat_model.to(device)
static_cat_model.eval()

for vec in tqdm(static_categorical):
    encoded = process_static_cat(vec, static_cat_model, input_dims=static_input_dims, device=device)
    sc, sc_mask = encoded
    static_cat_list.append(sc)
    static_cat_mask_list.append(sc_mask)

In [None]:
# One-hot widths of the six temporal categorical fields, in column order.
temporal_input_dims = [76, 68, 126, 28, 99, 80]
flat_array = temporal_categorical
# Cut the flat one-hot data back into one sequence per patient.
flat_seq_list = split_temporal_one_hot(flat_array, seq_lens)
temporal_cat_model = CategoricalAutoEncoder(temporal_input_dims, filename="/kaggle/input/weights/temporal_categorical_encoder_decoder.pt")
temporal_cat_model.load_model()
temporal_cat_model.to(device)
temporal_cat_model.eval()
# Pass the selected device explicitly: the function's own default is "cuda",
# which would ignore the CPU fallback chosen above.
tc_list, tc_mask_list, seq_len_cat = process_temporal_cat(
    flat_seq_list, temporal_input_dims, temporal_cat_model, device=device
)

In [None]:
class EncoderDecoderDataset(Dataset):
    """Per-patient dataset holding static and variable-length temporal features.

    All inputs are indexed by patient and copied to the CPU. Item ``i`` is the
    12-tuple ``(sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask,
    seq_len_num, seq_len_cat)`` where the two lengths are the first dimension
    of ``tn`` and ``tc``.
    """

    def __init__(self, static_num_tensor, static_num_mask_tensor,
                 static_cat_list, static_cat_mask_list,
                 tn_list, tn_mask_list, un_list,
                 tc_list, tc_mask_list, uc_list):
        """Store CPU copies of every input.

        Raises:
            ValueError: if the inputs do not all contain the same number of
                patients (they would otherwise be paired up incorrectly).
        """
        self.static_num_tensor = static_num_tensor.cpu()
        self.static_num_mask_tensor = static_num_mask_tensor.cpu()
        self.static_cat_list = [t.cpu() for t in static_cat_list]
        self.static_cat_mask_list = [t.cpu() for t in static_cat_mask_list]
        self.tn_list = [t.cpu() for t in tn_list]
        self.tn_mask_list = [t.cpu() for t in tn_mask_list]
        # Time channels may arrive as lists/arrays; normalise to float32 tensors.
        self.un_list = [torch.as_tensor(u, dtype=torch.float32, device="cpu") for u in un_list]
        self.tc_list = [t.cpu() for t in tc_list]
        self.tc_mask_list = [t.cpu() for t in tc_mask_list]
        self.uc_list = [torch.as_tensor(u, dtype=torch.float32, device="cpu") for u in uc_list]

        self.num_patients = len(self.static_num_tensor)

        # The dataset length comes from one input only; make sure the others
        # agree so index i refers to the same patient everywhere.
        sizes = {
            "static_num_mask_tensor": len(self.static_num_mask_tensor),
            "static_cat_list": len(self.static_cat_list),
            "static_cat_mask_list": len(self.static_cat_mask_list),
            "tn_list": len(self.tn_list),
            "tn_mask_list": len(self.tn_mask_list),
            "un_list": len(self.un_list),
            "tc_list": len(self.tc_list),
            "tc_mask_list": len(self.tc_mask_list),
            "uc_list": len(self.uc_list),
        }
        mismatched = {name: n for name, n in sizes.items() if n != self.num_patients}
        if mismatched:
            raise ValueError(
                f"All inputs must contain {self.num_patients} patients "
                f"(len of static_num_tensor); mismatched: {mismatched}"
            )

    def __len__(self):
        return self.num_patients

    def __getitem__(self, idx):
        sn = self.static_num_tensor[idx]
        sn_mask = self.static_num_mask_tensor[idx]
        sc = self.static_cat_list[idx]
        sc_mask = self.static_cat_mask_list[idx]
        tn = self.tn_list[idx]
        tn_mask = self.tn_mask_list[idx]
        un = self.un_list[idx]
        tc = self.tc_list[idx]
        tc_mask = self.tc_mask_list[idx]
        uc = self.uc_list[idx]

        # Unpadded lengths of the two temporal sequences.
        seq_len_num = tn.size(0)
        seq_len_cat = tc.size(0)

        return (
            sn, sc, tn, tc, un, uc,
            sn_mask, sc_mask, tn_mask, tc_mask,
            seq_len_num, seq_len_cat
        )

In [None]:
def collate_fn(batch):
    """Merge dataset items into one batch, zero-padding the temporal fields.

    Args:
        batch: list of 12-tuples ``(sn, sc, tn, tc, un, uc, sn_mask, sc_mask,
            tn_mask, tc_mask, seq_len_num, seq_len_cat)``.

    Returns:
        A 12-tuple in the same order: static fields stacked along dim 0,
        temporal fields padded (batch first, value 0.0) to the longest
        sequence in the batch, and the two lengths as long tensors.
    """
    def _stack(items):
        return torch.stack(items, dim=0)

    def _pad(items):
        return pad_sequence(items, batch_first=True, padding_value=0.0)

    # Transpose "list of items" into "one tuple per field".
    (sn_items, sc_items,
     tn_items, tc_items,
     un_items, uc_items,
     sn_mask_items, sc_mask_items,
     tn_mask_items, tc_mask_items,
     len_num_items, len_cat_items) = zip(*batch)

    # Fixed-size static fields.
    sn, sc = _stack(sn_items), _stack(sc_items)
    sn_mask, sc_mask = _stack(sn_mask_items), _stack(sc_mask_items)

    # Variable-length numerical stream.
    tn, tn_mask, un = _pad(tn_items), _pad(tn_mask_items), _pad(un_items)

    # Variable-length categorical stream.
    tc, tc_mask, uc = _pad(tc_items), _pad(tc_mask_items), _pad(uc_items)

    seq_len_num = torch.tensor(len_num_items, dtype=torch.long)
    seq_len_cat = torch.tensor(len_cat_items, dtype=torch.long)

    return sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask, seq_len_num, seq_len_cat

In [None]:
# Longest numerical / categorical sequence over all patients.
GLOBAL_MAX_SEQ_NUM, GLOBAL_MAX_SEQ_CAT = max(seq_len_num), max(seq_len_cat)
for longest in (GLOBAL_MAX_SEQ_NUM, GLOBAL_MAX_SEQ_CAT):
    print(longest)

In [None]:
dataset = EncoderDecoderDataset(static_num_tensor, static_num_mask_tensor,
                                static_cat_list, static_cat_mask_list,
                                tn_list, tn_mask_list, un_list,
                                tc_list, tc_mask_list, uc_list)

num_patients = len(dataset)
val_ratio = 0.05
val_size = int(num_patients * val_ratio)
train_size = num_patients - val_size

# Seed the split so the same patients land in the validation set on every run.
# Training below resumes from a saved checkpoint; with an unseeded split each
# new session would reshuffle patients and leak earlier validation samples
# into training.
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

print(f"\nData split into:")
print(f" - Training set: {train_size} patients")
print(f" - Validation set: {val_size} patients")
print("\n--- Creating DataLoaders ---")

# --- DataLoaders ---
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=128,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True
)

print(f"Train batches: {len(train_loader)}, Validation batches: {len(val_loader)}")

In [None]:
# Pull one training batch and print, per field, the first sample's shape
# (or its value when the entry is a scalar).
first_batch = next(iter(train_loader))

# NOTE: this unpacking rebinds module-level names such as `seq_len_num` and
# `seq_len_cat` to batch tensors.
(sn, sc, tn, tc, un, uc,
 sn_mask, sc_mask, tn_mask, tc_mask,
 seq_len_num, seq_len_cat) = first_batch

names = ["sn", "sc", "tn", "tc", "un", "uc",
         "sn_mask", "sc_mask", "tn_mask", "tc_mask",
         "seq_len_num", "seq_len_cat"]

# First sample of every field, in the same order as `names`.
sample = tuple(
    field[0]
    for field in (sn, sc, tn, tc, un, uc,
                  sn_mask, sc_mask, tn_mask, tc_mask,
                  seq_len_num, seq_len_cat)
)

for name, value in zip(names, sample):
    if not isinstance(value, torch.Tensor):
        print(f"{name}: {value}")
    elif value.dim() == 0:
        print(f"{name}: {value.item()}")
    else:
        print(f"{name}: {tuple(value.shape)}")

In [None]:
class MLP(nn.Module):
    """Configurable feed-forward network.

    Each hidden width contributes ``Linear -> [BatchNorm1d] -> activation ->
    [Dropout]``; a final ``Linear`` maps to ``out_dim`` and is optionally
    followed by ``final_activation``.

    Args:
        in_dim: input feature width.
        hidden_dims: widths of the hidden layers, in order.
        out_dim: output feature width.
        dropout: dropout probability; no Dropout layer is added unless > 0.
        use_batch_norm: whether to insert BatchNorm1d after each hidden Linear.
        activation: one of 'lrelu', 'relu', 'gelu', 'swish', 'tanh', 'sigmoid'.
        final_activation: optional module appended after the last Linear.
    """

    def __init__(self, in_dim, hidden_dims=(512,256,256), out_dim=16, dropout=0.1,
                 use_batch_norm=True, activation='lrelu', final_activation=None):
        super().__init__()

        # Supported activations; the selected instance is reused in every
        # hidden block (an unknown name raises KeyError).
        supported = {
            'lrelu': nn.LeakyReLU(0.01),
            'relu': nn.ReLU(),
            'gelu': nn.GELU(),
            'swish': nn.SiLU(),
            'tanh': nn.Tanh(),
            'sigmoid': nn.Sigmoid()
        }
        hidden_act = supported[activation]

        modules = []
        widths = [in_dim, *hidden_dims]
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            if use_batch_norm:
                modules.append(nn.BatchNorm1d(fan_out))
            modules.append(hidden_act)
            if dropout > 0.0:
                modules.append(nn.Dropout(dropout))

        # Output projection from the last hidden width (or in_dim if none).
        modules.append(nn.Linear(widths[-1], out_dim))
        if final_activation is not None:
            modules.append(final_activation)

        self.net = nn.Sequential(*modules)

    def forward(self,x):
        """Apply the layer stack to ``x``."""
        return self.net(x)


In [None]:
class TemporalEncoder(nn.Module):
    """GRU sequence encoder with attention pooling over time.

    The GRU output at every step is scored by a linear layer; a softmax over
    the time axis turns the scores into weights, and the weighted sum of GRU
    outputs is projected to ``out_dim``.

    Args:
        in_dim: feature width of each time step.
        hidden_dim: GRU hidden size.
        out_dim: width of the returned encoding.
        depth: number of stacked GRU layers.
        use_post_mlp: project with a two-layer MLP instead of a single Linear.
    """

    def __init__(self,in_dim,hidden_dim=256,out_dim=128,depth=2,use_post_mlp=True):
        super().__init__()

        self.gru = nn.GRU(
            input_size = in_dim,
            hidden_size = hidden_dim,
            num_layers = depth,
            batch_first = True,
            # Inter-layer dropout only exists with more than one layer;
            # requesting it for depth == 1 just triggers a PyTorch warning.
            dropout = 0.1 if depth > 1 else 0.0
        )

        self.attn = nn.Linear(hidden_dim, 1)

        if use_post_mlp:
            self.fc = nn.Sequential(
                nn.Linear(hidden_dim,hidden_dim),
                nn.LeakyReLU(0.2),
                nn.Linear(hidden_dim,out_dim)
            )
        else:
            self.fc = nn.Linear(hidden_dim,out_dim)

    def forward(self, x, lengths=None):
        """Encode a batch of sequences.

        Args:
            x: tensor of shape (batch, seq_len, in_dim).
            lengths: optional per-sample number of valid steps. When given,
                attention is restricted to the first ``lengths[i]`` steps of
                sample ``i`` so padded steps do not contribute. Values below
                1 are treated as 1 to keep the softmax well-defined. When
                omitted, all steps are attended to (original behaviour).

        Returns:
            Tensor of shape (batch, out_dim).
        """
        # x: (batch, seq_len, in_dim)
        out, _ = self.gru(x)
        attn_scores = self.attn(out).squeeze(-1) # (batch,seq_len)
        if lengths is not None:
            lengths = torch.as_tensor(lengths, device=out.device).clamp(min=1)
            steps = torch.arange(out.size(1), device=out.device).unsqueeze(0)
            is_padding = steps >= lengths.unsqueeze(1)  # (batch, seq_len)
            # -inf scores receive exactly zero weight after the softmax.
            attn_scores = attn_scores.masked_fill(is_padding, float("-inf"))
        attn_weights = torch.softmax(attn_scores,dim=1)
        context = torch.sum(out * attn_weights.unsqueeze(-1),dim=1) # (batch, hidden_dim)
        return self.fc(context)

In [None]:
class FusionMLP(nn.Module):
    """MLP that fuses a concatenated feature vector into one representation.

    Each hidden width contributes ``Linear -> BatchNorm1d -> ReLU -> Dropout``;
    a final ``Linear`` maps to ``output_dim``.

    Args:
        input_dim: width of the input vector.
        hidden_dims: hidden layer widths; any sequence (list or tuple).
        output_dim: width of the output vector.
        dropout: dropout probability used after every hidden block.
    """

    def __init__(self, input_dim, hidden_dims=(128, 64), output_dim=64, dropout=0.1):
        super().__init__()
        layers = []
        # Copy into a fresh list so tuples are accepted and the caller's
        # sequence is never shared or modified.
        dims = [input_dim, *hidden_dims]
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i+1]))
            layers.append(nn.BatchNorm1d(dims[i+1]))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(dims[-1], output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.net(x)


In [None]:
class TemporalDecoder(nn.Module):
    """Autoregressive-in-state GRU decoder that unrolls a latent vector.

    The latent vector initialises the GRU hidden state and is also fed as the
    GRU input at every step. Three heads read each step's GRU output:

    * ``head_value``: ``embed_dim`` values through Tanh when
      ``is_categorical`` is True, otherwise ``num_features`` values through
      Sigmoid;
    * ``head_time``: one value through Sigmoid;
    * ``head_mask``: ``num_features`` raw logits.

    Args:
        latent_dim: width of the latent vector.
        hidden_dim: GRU hidden size.
        depth: number of stacked GRU layers.
        num_features: width of the mask head (and of the value head when not
            categorical).
        embed_dim: width of the value head when ``is_categorical`` is True.
        is_categorical: selects the value head configuration.
    """

    def __init__(self, latent_dim: int, hidden_dim = 512, depth = 2,
                 num_features=None, embed_dim=None, is_categorical=False):
        super().__init__()
        self.is_categorical = is_categorical
        # Produces one initial hidden vector per GRU layer, concatenated.
        self.fc_init_h = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, depth * hidden_dim)
        )

        self.gru = nn.GRU(
            input_size=latent_dim,
            hidden_size=hidden_dim,
            num_layers=depth,
            batch_first=True,
            dropout=0.1 if depth > 1 else 0.0
        )

        self.head_value = nn.Sequential(
            nn.Linear(hidden_dim, embed_dim if is_categorical else num_features),
            nn.Tanh() if is_categorical else nn.Sigmoid()
        )

        self.head_time = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 8),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim // 8, 1),
            nn.Sigmoid()
        )

        self.head_mask = nn.Sequential(
            nn.Linear(hidden_dim  , num_features)
        )

    def forward(self, e, max_seq_len, mask_threshold = 0.5):
        """Unroll up to ``max_seq_len`` steps from latent ``e``.

        A sequence becomes inactive after a step in which the sigmoid of every
        mask logit is below ``mask_threshold``; its hidden state is frozen from
        then on. Unrolling stops early once no sequence is active.

        Args:
            e: latent batch of shape (batch, latent_dim).
            max_seq_len: maximum number of steps to generate.
            mask_threshold: stop threshold applied to mask probabilities.

        Returns:
            (values, times, mask_logits, seq_lengths): the three head outputs
            stacked along dim 1 (one slice per executed step) and, per
            sequence, the number of steps during which it was active.
        """
        batch_size = e.size(0)
        device = e.device
        h_0_flat = self.fc_init_h(e)
        # fc_init_h yields (batch, depth * hidden) while the GRU expects
        # (depth, batch, hidden). Split per sample first, then move the layer
        # axis to the front; viewing directly as (depth, batch, hidden) would
        # mix hidden vectors of different samples whenever batch > 1.
        # NOTE(review): this changes the initial state for batch > 1, so
        # checkpoints trained with the old layout may need fine-tuning.
        h_0 = (
            h_0_flat.view(batch_size, self.gru.num_layers, self.gru.hidden_size)
            .transpose(0, 1)
            .contiguous()
        )

        tn_hat_list, u_hat_list, mask_hat_list = [], [], []
        h_t = h_0

        seq_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
        active_sequences = torch.ones(batch_size, dtype=torch.bool, device=device)

        # The GRU input is the latent itself at every step (loop-invariant).
        gru_input = e.unsqueeze(1)

        for _ in range(max_seq_len):
            if not active_sequences.any():
                break

            gru_out, h_t_new = self.gru(gru_input, h_t)

            # Only active sequences advance their hidden state.
            h_t = torch.where(active_sequences.view(1, -1, 1), h_t_new, h_t)

            step_features = gru_out.squeeze(1)
            tn_hat_step = self.head_value(step_features)
            u_hat_step = self.head_time(step_features)
            mask_hat_step = self.head_mask(step_features)

            tn_hat_list.append(tn_hat_step)
            u_hat_list.append(u_hat_step)
            mask_hat_list.append(mask_hat_step)

            # The current step counts for every sequence still active.
            seq_lengths += active_sequences.long()

            stop_condition = (torch.sigmoid(mask_hat_step) < mask_threshold).all(dim=-1)
            active_sequences = active_sequences & ~stop_condition

        tn_hat = torch.stack(tn_hat_list, dim=1)
        u_hat = torch.stack(u_hat_list, dim=1)
        mask_hat = torch.stack(mask_hat_list, dim=1)

        return tn_hat, u_hat, mask_hat, seq_lengths

In [None]:
class EncoderDecoder(nn.Module):
    """Autoencoder over static and temporal patient features.

    ``encode`` fuses six partial encodings (static values, numerical sequence,
    categorical sequence, and one mask encoding for each of the three) into a
    single latent vector of width ``latent_dim``. ``decode`` reconstructs
    static values, both temporal streams with their time channels, and mask
    logits from that latent vector.
    """

    def __init__(self,sn_dim,sce_latent_dim,tn_dim,tce_latent_dim,sc_dim,tc_dim,latent_dim=512):
        """
        Args:
            sn_dim: width of the static numerical input (and of its mask).
            sce_latent_dim: width of the static categorical embedding.
            tn_dim: per-step width of the temporal numerical input and mask.
            tce_latent_dim: per-step width of the temporal categorical embedding.
            sc_dim: width of the static categorical mask.
            tc_dim: per-step width of the temporal categorical mask.
            latent_dim: width of the fused latent vector.
        """
        super().__init__()
        self.latent_dim = latent_dim
        static_input_dim = sn_dim + sce_latent_dim
        self.static_encoder = MLP(static_input_dim, out_dim=128)
        # "+ 1": the time channel is appended to every temporal step.
        self.temporal_num_encoder = TemporalEncoder(tn_dim + 1, out_dim=256)
        self.temporal_cat_encoder = TemporalEncoder(tce_latent_dim + 1, out_dim=256)
        self.static_mask = MLP(in_dim = sn_dim + sc_dim,hidden_dims=(32,16,8),out_dim=4)
        self.temporal_num_mask_encoder = TemporalEncoder(tn_dim,hidden_dim=128,out_dim=64)
        self.temporal_cat_mask_encoder = TemporalEncoder(tc_dim,hidden_dim=128,out_dim=64)
        # Sum of the six encoder output widths above.
        fusion_dim = 128 + 256 + 256 + 4 + 64 + 64
        self.fusion = FusionMLP(fusion_dim, hidden_dims=[2048, 1024], output_dim=latent_dim)
        self.static_decoder_num = MLP(latent_dim, out_dim=sn_dim,final_activation=nn.Sigmoid())
        self.static_decoder_cat = MLP(latent_dim,out_dim=sce_latent_dim,final_activation=nn.Tanh())
        self.static_decoder_mask = MLP(latent_dim, out_dim=sn_dim + sc_dim)
        self.temporal_decoder_num = TemporalDecoder(latent_dim, num_features=tn_dim, is_categorical=False)
        self.temporal_decoder_cat = TemporalDecoder(latent_dim, num_features=tc_dim, embed_dim=tce_latent_dim,is_categorical=True)

    def encode(self,sn,sc,tn,tc,un,uc,sn_mask,sc_mask,tn_mask,tc_mask):
        """Return the fused latent vector for a batch.

        ``un`` / ``uc`` are per-step time channels; they get a trailing
        feature axis and are concatenated to ``tn`` / ``tc`` before encoding.
        """
        static_in = torch.cat([sn,sc],dim=-1)
        static_e = self.static_encoder(static_in)
        un = un.unsqueeze(-1)
        temporal_num_in = torch.cat([tn,un],dim=-1)
        temporal_num_e = self.temporal_num_encoder(temporal_num_in)
        uc = uc.unsqueeze(-1)
        temporal_cat_in = torch.cat([tc,uc],dim=-1)
        temporal_cat_e = self.temporal_cat_encoder(temporal_cat_in)
        static_mask_in = torch.cat([sn_mask,sc_mask],dim=-1)
        static_mask_e = self.static_mask(static_mask_in)
        temporal_num_mask_e = self.temporal_num_mask_encoder(tn_mask)
        temporal_cat_mask_e = self.temporal_cat_mask_encoder(tc_mask)
        e = torch.cat([static_e,temporal_num_e,temporal_cat_e,static_mask_e,temporal_num_mask_e,temporal_cat_mask_e],dim=-1)
        e = self.fusion(e)
        return e


    def decode(self, e, max_seq_len_num=100,max_seq_len_cat=300, mask_threshold=0.5):
        """Reconstruct all outputs from latent ``e``.

        Args:
            e: latent batch.
            max_seq_len_num: step limit for the numerical temporal decoder.
            max_seq_len_cat: step limit for the categorical temporal decoder.
            mask_threshold: stop threshold forwarded to both temporal decoders.

        Returns:
            ``(sn_hat, sc_hat, tn_hat, tc_hat, un_hat, uc_hat, sn_mask_hat,
            sc_mask_hat, tn_mask_hat, tc_mask_hat, seq_len_num, seq_len_cat)``;
            all mask outputs are raw logits.
        """
        sn_hat = self.static_decoder_num(e)
        sc_hat = self.static_decoder_cat(e)
        static_mask_hat = self.static_decoder_mask(e)
        # The static mask head emits [numerical | categorical] logits.
        sn_dim = sn_hat.shape[-1]
        sn_mask_hat = static_mask_hat[..., :sn_dim]
        sc_mask_hat = static_mask_hat[..., sn_dim:]
        tn_hat, un_hat, tn_mask_hat, seq_len_num = self.temporal_decoder_num(e, max_seq_len_num, mask_threshold)
        tc_hat, uc_hat, tc_mask_hat, seq_len_cat = self.temporal_decoder_cat(e, max_seq_len_cat, mask_threshold)
        return sn_hat,sc_hat,tn_hat,tc_hat,un_hat,uc_hat,sn_mask_hat,sc_mask_hat,tn_mask_hat,tc_mask_hat,seq_len_num,seq_len_cat

    def forward(self,sn,sc,tn,tc,un,uc,sn_mask,sc_mask,tn_mask,tc_mask,max_seq_len_num=None,max_seq_len_cat=None):
        """Encode then decode; step limits default to the padded input lengths."""
        e = self.encode(sn,sc,tn,tc,un,uc,sn_mask,sc_mask,tn_mask,tc_mask)
        if max_seq_len_num is None:
            max_seq_len_num = tn.size(1)
        if max_seq_len_cat is None:
            max_seq_len_cat = tc.size(1)
        return self.decode(e, max_seq_len_num,max_seq_len_cat)

    def get_encoding(self, sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask, as_numpy: bool = False):
        """Encode a batch in eval mode without gradients.

        Tensor inputs are moved to the model's device first. Returns the
        latent tensor, or a NumPy array on the CPU when ``as_numpy`` is True.
        """
        self.eval()
        device = next(self.parameters()).device
        with torch.no_grad():
            inputs = [sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask]
            inputs = [x.to(device) if torch.is_tensor(x) else x for x in inputs]
            sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask = inputs
            e = self.encode(sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask)
            if as_numpy:
                return e.detach().cpu().numpy()
            return e

    def generate_decoding(self, e=None, batch_size=1, max_seq_len_num=100, max_seq_len_cat=300, device="cpu", mask_threshold=0.5):
        """Decode latent vectors into thresholded masks and trimmed sequences.

        Args:
            e: latent batch; required. It is moved to the model's device.
            batch_size: accepted for backward compatibility; not used.
            max_seq_len_num: step limit for the numerical temporal decoder.
            max_seq_len_cat: step limit for the categorical temporal decoder.
            device: accepted for backward compatibility; not used (decoding
                always runs on the device the model lives on).
            mask_threshold: threshold for all mask probabilities, also used as
                the stop threshold of the temporal decoders.

        Returns:
            ``(sn_hat, sc_hat, sn_mask_hat, sc_mask_hat, tn_hat, tc_hat,
            un_hat, uc_hat, tn_mask_hat, tc_mask_hat)``; masks are boolean and
            every temporal entry is a per-sample list trimmed to the predicted
            length.

        Raises:
            ValueError: if ``e`` is None.
        """
        if e is None:
            raise ValueError("generate_decoding requires a latent tensor `e`")
        self.eval()
        model_device = next(self.parameters()).device
        with torch.no_grad():
            e = e.to(model_device)
            (sn_hat, sc_hat, tn_hat, tc_hat, un_hat, uc_hat, sn_mask_hat, sc_mask_hat, tn_mask_hat, tc_mask_hat, seq_len_num, seq_len_cat) = self.decode(e, max_seq_len_num, max_seq_len_cat, mask_threshold)

            # Apply sigmoid + threshold for all masks
            sn_mask_hat = torch.sigmoid(sn_mask_hat) > mask_threshold
            sc_mask_hat = torch.sigmoid(sc_mask_hat) > mask_threshold
            tn_mask_hat = torch.sigmoid(tn_mask_hat) > mask_threshold
            tc_mask_hat = torch.sigmoid(tc_mask_hat) > mask_threshold

            # Slice sequences according to predicted lengths
            tn_hat = [tn_hat[i, :seq_len_num[i]] for i in range(tn_hat.size(0))]
            tc_hat = [tc_hat[i, :seq_len_cat[i]] for i in range(tc_hat.size(0))]
            un_hat = [un_hat[i, :seq_len_num[i]] for i in range(un_hat.size(0))]
            uc_hat = [uc_hat[i, :seq_len_cat[i]] for i in range(uc_hat.size(0))]
            tn_mask_hat = [tn_mask_hat[i, :seq_len_num[i]] for i in range(tn_mask_hat.size(0))]
            tc_mask_hat = [tc_mask_hat[i, :seq_len_cat[i]] for i in range(tc_mask_hat.size(0))]

        return (sn_hat, sc_hat, sn_mask_hat, sc_mask_hat, tn_hat, tc_hat, un_hat, uc_hat, tn_mask_hat, tc_mask_hat)

    @staticmethod
    def _as_length_tensor(value, device):
        """Coerce a sequence-length argument to a tensor with at least one dim."""
        if not torch.is_tensor(value):
            value = torch.tensor(value, device=device, dtype=torch.float32)
        if value.dim() == 0:
            value = value.unsqueeze(0)
        return value

    def compute_loss(
        self, sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask,
        sn_hat, sc_hat, sn_mask_hat, sc_mask_hat, tn_hat, tc_hat, un_hat, uc_hat,
        tn_mask_hat, tc_mask_hat, pred_seq_len_num, pred_seq_len_cat, true_seq_len_num, true_seq_len_cat,
        lambda_mse=1.0, lambda_len=0.1
    ):
        """Compute the total reconstruction loss and its components.

        Mask terms use BCE-with-logits, value/time terms use MSE, and length
        terms use MSE between predicted and true lengths. Targets are cut to
        the number of decoded steps; temporal terms are averaged over the
        steps below each sample's *predicted* length.

        NOTE(review): the step mask is built from the predicted lengths, not
        the true ones, and the predicted lengths are integer step counts, so
        the length terms presumably carry no gradient — confirm this is the
        intended training signal.

        Returns:
            (total_loss, losses): the weighted sum
            ``masks + lambda_mse * values + lambda_len * lengths`` and a dict
            with every individual term.
        """
        losses = {}
        bce = nn.BCEWithLogitsLoss(reduction='none')
        mse = nn.MSELoss(reduction='none')
        epsilon = 1e-8  # guards the masked averages against division by zero
        device = sn.device

        pred_seq_len_num = self._as_length_tensor(pred_seq_len_num, device)
        pred_seq_len_cat = self._as_length_tensor(pred_seq_len_cat, device)
        true_seq_len_num = self._as_length_tensor(true_seq_len_num, device)
        true_seq_len_cat = self._as_length_tensor(true_seq_len_cat, device)

        # The decoders may have stopped early; align targets to decoded steps.
        pred_len_num = tn_hat.size(1)
        pred_len_cat = tc_hat.size(1)
        tn_sliced = tn[:, :pred_len_num, :]
        un_sliced = un[:, :pred_len_num].unsqueeze(-1)
        tn_mask_sliced = tn_mask[:, :pred_len_num, :]
        tc_sliced = tc[:, :pred_len_cat, :]
        uc_sliced = uc[:, :pred_len_cat].unsqueeze(-1)
        tc_mask_sliced = tc_mask[:, :pred_len_cat, :]

        # (batch, steps, 1) indicators of steps below the predicted length.
        seq_mask_num = torch.arange(pred_len_num, device=device).unsqueeze(0) < pred_seq_len_num.unsqueeze(1)
        seq_mask_num = seq_mask_num.unsqueeze(-1).float()
        seq_mask_cat = torch.arange(pred_len_cat, device=device).unsqueeze(0) < pred_seq_len_cat.unsqueeze(1)
        seq_mask_cat = seq_mask_cat.unsqueeze(-1).float()

        # Mask (presence) terms.
        losses["sn_mask"] = bce(sn_mask_hat, sn_mask).mean()
        losses["sc_mask"] = bce(sc_mask_hat, sc_mask).mean()
        tn_mask_loss = bce(tn_mask_hat, tn_mask_sliced).mean(dim=-1)
        losses["tn_mask"] = (tn_mask_loss * seq_mask_num[..., 0]).sum() / (seq_mask_num[..., 0].sum() + epsilon)
        tc_mask_loss = bce(tc_mask_hat, tc_mask_sliced).mean(dim=-1)
        losses["tc_mask"] = (tc_mask_loss * seq_mask_cat[..., 0]).sum() / (seq_mask_cat[..., 0].sum() + epsilon)

        # Value and time terms; numerical values only count where observed.
        losses["sn"] = (mse(sn_hat, sn) * sn_mask).sum() / (sn_mask.sum() + epsilon)
        losses["sc"] = mse(sc_hat, sc).mean()
        losses["tn"] = (mse(tn_hat, tn_sliced) * tn_mask_sliced * seq_mask_num).sum() / ((tn_mask_sliced * seq_mask_num).sum() + epsilon)
        losses["tc"] = (mse(tc_hat, tc_sliced) * seq_mask_cat).sum() / (seq_mask_cat.sum() + epsilon)
        losses["un"] = (mse(un_hat, un_sliced) * seq_mask_num).sum() / (seq_mask_num.sum() + epsilon)
        losses["uc"] = (mse(uc_hat, uc_sliced) * seq_mask_cat).sum() / (seq_mask_cat.sum() + epsilon)

        # Length terms.
        losses["len_num"] = F.mse_loss(pred_seq_len_num.float(), true_seq_len_num.float())
        losses["len_cat"] = F.mse_loss(pred_seq_len_cat.float(), true_seq_len_cat.float())
        total_loss = (
            losses["sn_mask"] + losses["sc_mask"] +
            losses["tn_mask"] + losses["tc_mask"] +
            lambda_mse * (losses["sn"] + losses["sc"] + losses["tn"] + losses["tc"] + losses["un"] + losses["uc"]) +
            lambda_len * (losses["len_num"] + losses["len_cat"])
        )
        return total_loss, losses

    def fit(self, train_dataloader, val_dataloader=None, epochs=20, lr=1e-3, optimizer="adam",
            lambda_mse=1.0, lambda_len=0.1, device=None,
            scheduler_patience=5, scheduler_factor=0.1,resume_from=None):
        """Train the model and checkpoint it.

        Args:
            train_dataloader: yields 12-item batches as produced by ``collate_fn``.
            val_dataloader: optional loader used for validation each epoch.
            epochs: last epoch number to run (epochs are 1-based).
            lr: learning rate.
            optimizer: 'adam' or 'rmsprop'.
            lambda_mse: weight of the value/time terms.
            lambda_len: weight of the length terms.
            device: training device. ``None`` picks CUDA when available, else
                CPU. An explicit CUDA request falls back to CPU when CUDA is
                unavailable.
            scheduler_patience: ReduceLROnPlateau patience.
            scheduler_factor: ReduceLROnPlateau factor.
            resume_from: optional checkpoint path to resume from.

        The best model by validation loss is saved to
        'best_encoder_decoder_ckpt.pt'; a final checkpoint is always written
        with ``save_checkpoint``'s default filename.

        Raises:
            ValueError: for an unsupported ``optimizer`` name.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            device = torch.device(device)
            if device.type == "cuda" and not torch.cuda.is_available():
                device = torch.device("cpu")
        self.to(device)
        print(f"Training on: {device}")
        if optimizer.lower() == "adam":
            opt = optim.Adam(self.parameters(), lr=lr)
        elif optimizer.lower() == "rmsprop":
            opt = optim.RMSprop(self.parameters(), lr=lr)
        else:
            raise ValueError("Optimizer must be 'adam' or 'rmsprop'")
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', patience=scheduler_patience, factor=scheduler_factor)
        start_epoch = 1
        # NOTE(review): the best validation loss is not stored in checkpoints,
        # so after a resume the first validated epoch always overwrites the
        # "best" checkpoint.
        best_val_loss = float('inf')

        if resume_from is not None:
            # map_location lets a checkpoint saved on GPU be resumed on CPU.
            start_epoch = self.load_checkpoint(resume_from, optimizer=opt, scheduler=scheduler, map_location=device)

        # Last completed epoch; stays valid even when the loop body never runs
        # (e.g. the resumed checkpoint already reached `epochs`).
        last_epoch = start_epoch - 1

        for epoch in range(start_epoch, epochs + 1):
            self.train()
            total_train_loss = 0.0
            train_loss_components = {}
            for batch in tqdm(train_dataloader, desc=f"Epoch {epoch}/{epochs} [Train]", leave=False):
                sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask, true_seq_len_num, true_seq_len_cat = [
                    x.to(device) if torch.is_tensor(x) else x for x in batch
                ]
                # Decode at most as many steps as the padded targets contain.
                max_seq_len_num = tn.size(1)
                max_seq_len_cat = tc.size(1)
                (sn_hat,sc_hat,tn_hat,tc_hat,un_hat,uc_hat,sn_mask_hat,sc_mask_hat,tn_mask_hat,tc_mask_hat,pred_seq_len_num,pred_seq_len_cat) = self(
                    sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask, max_seq_len_num, max_seq_len_cat
                )
                loss, losses_dict = self.compute_loss(
                    sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask,
                    sn_hat, sc_hat, sn_mask_hat, sc_mask_hat, tn_hat, tc_hat, un_hat, uc_hat,
                    tn_mask_hat, tc_mask_hat, pred_seq_len_num, pred_seq_len_cat,
                    true_seq_len_num, true_seq_len_cat, lambda_mse=lambda_mse, lambda_len=lambda_len
                )
                opt.zero_grad()
                loss.backward()
                opt.step()
                total_train_loss += loss.item()
                for k, v in losses_dict.items():
                    train_loss_components[k] = train_loss_components.get(k, 0.0) + v.item()
            avg_train_loss = total_train_loss / len(train_dataloader)
            avg_train_components = {k: v / len(train_dataloader) for k, v in train_loss_components.items()}
            avg_val_loss = None
            val_loss_components = {}
            if val_dataloader is not None:
                self.eval()
                total_val_loss = 0.0
                with torch.no_grad():
                    for batch in tqdm(val_dataloader, desc=f"Epoch {epoch}/{epochs} [Val]", leave=False):
                        sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask, true_seq_len_num, true_seq_len_cat = [
                            x.to(device) if torch.is_tensor(x) else x for x in batch
                        ]
                        max_seq_len_num = tn.size(1)
                        max_seq_len_cat = tc.size(1)
                        (sn_hat,sc_hat,tn_hat,tc_hat,un_hat,uc_hat,sn_mask_hat,sc_mask_hat,tn_mask_hat,tc_mask_hat,pred_seq_len_num,pred_seq_len_cat) = self(
                            sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask, max_seq_len_num, max_seq_len_cat)
                        loss, losses_dict = self.compute_loss(
                            sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask,
                            sn_hat, sc_hat, sn_mask_hat, sc_mask_hat, tn_hat, tc_hat, un_hat, uc_hat,
                            tn_mask_hat, tc_mask_hat, pred_seq_len_num, pred_seq_len_cat,
                            true_seq_len_num, true_seq_len_cat, lambda_mse=lambda_mse, lambda_len=lambda_len
                        )
                        total_val_loss += loss.item()
                        for k, v in losses_dict.items():
                            val_loss_components[k] = val_loss_components.get(k, 0.0) + v.item()
                avg_val_loss = total_val_loss / len(val_dataloader)
                avg_val_components = {k: v / len(val_dataloader) for k, v in val_loss_components.items()}
                scheduler.step(avg_val_loss)
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    self.save_checkpoint(opt, scheduler, epoch, filename='best_encoder_decoder_ckpt.pt')

            if avg_val_loss is not None:
                print(f"Epoch {epoch}/{epochs} - Train Loss: {avg_train_loss:.7f} | Val Loss: {avg_val_loss:.7f} | Train LenNum: {avg_train_components['len_num']:.6f} | Train LenCat: {avg_train_components['len_cat']:.6f} | Val LenNum: {avg_val_components['len_num']:.6f} | Val LenCat: {avg_val_components['len_cat']:.6f}")
            else:
                # Without validation data the scheduler follows the training loss.
                scheduler.step(avg_train_loss)
                print(f"Epoch {epoch}/{epochs} - Train Loss: {avg_train_loss:.7f} | Train LenNum: {avg_train_components['len_num']:.6f} | Train LenCat: {avg_train_components['len_cat']:.6f}")

            last_epoch = epoch

        self.save_checkpoint(opt, scheduler, last_epoch)

    def save_checkpoint(self, optimizer, scheduler, epoch, filename="encoder_decoder_ckpt.pt"):
        """Write model, optimizer and scheduler state plus ``epoch`` to ``filename``."""
        checkpoint = {
            "model_state": self.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict() if scheduler is not None else None,
            "epoch": epoch
        }
        torch.save(checkpoint, filename)
        print(f"Checkpoint saved at epoch {epoch} -> {filename}")

    def load_checkpoint(self, filename="encoder_decoder_ckpt.pt", optimizer=None, scheduler=None, map_location=None):
        """Restore state from ``filename`` and return the epoch to continue at.

        Optimizer and scheduler state are restored only when the respective
        object is passed. ``map_location`` is forwarded to ``torch.load``.
        """
        checkpoint = torch.load(filename, map_location=map_location)
        self.load_state_dict(checkpoint["model_state"])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
        if scheduler is not None and checkpoint["scheduler_state"] is not None:
            scheduler.load_state_dict(checkpoint["scheduler_state"])
        # The stored epoch is the last completed one.
        start_epoch = checkpoint["epoch"] + 1
        print(f"Resumed from checkpoint {filename}, starting at epoch {start_epoch}")
        return start_epoch


In [None]:
# Model dimensions are read off the prepared tensors.
sn_dim = static_num_tensor.shape[1]
sce_latent_dim = static_cat_list[0].shape[0]
tn_dim = tn_list[0].shape[1:][-1]
tce_latent_dim = tc_list[0].shape[1]
sc_dim = static_cat_mask_list[0].shape[-1]
tc_dim = tc_mask_list[0].shape[-1]

model = EncoderDecoder(sn_dim=sn_dim,sce_latent_dim=sce_latent_dim,tn_dim=tn_dim,
        tce_latent_dim=tce_latent_dim,sc_dim=sc_dim,tc_dim=tc_dim,latent_dim=256)
model.to(device)

# Resume only when a checkpoint from an earlier run exists; on a fresh kernel
# or first run the file is absent and training must start from scratch.
RESUME_CKPT = 'best_encoder_decoder_ckpt.pt'
model.fit(train_loader, val_loader, epochs=50, lr=1e-4, optimizer="adam", lambda_mse=2.0,
          lambda_len=1.0,device="cuda",
          resume_from=RESUME_CKPT if os.path.exists(RESUME_CKPT) else None)

In [None]:
def encode_dataloader(encoder, dataloader, device=None, as_numpy=False):
    """Encode every batch of ``dataloader`` and collect the results on the CPU.

    Args:
        encoder: model exposing ``get_encoding``; switched to eval mode and
            moved to ``device``.
        dataloader: iterable of 12-item batches (ten model inputs followed by
            the numerical and categorical sequence lengths).
        device: target device; CUDA when available, else CPU, if omitted.
        as_numpy: return NumPy arrays instead of tensors.

    Returns:
        (embeddings, info): all encodings concatenated along dim 0 and a dict
        with the concatenated ``"seq_len_num"`` and ``"seq_len_cat"`` values.
    """
    encoder.eval()
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder.to(device)

    encoded_chunks, num_len_chunks, cat_len_chunks = [], [], []

    with torch.no_grad():
        for batch in dataloader:
            moved = [item.to(device) if torch.is_tensor(item) else item for item in batch]
            (sn, sc, tn, tc, un, uc,
             sn_mask, sc_mask, tn_mask, tc_mask,
             seq_len_num, seq_len_cat) = moved

            latent = encoder.get_encoding(
                sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask, as_numpy=False
            )

            # Keep results on the CPU so accelerator memory is not accumulated.
            encoded_chunks.append(latent.cpu())
            num_len_chunks.append(seq_len_num.cpu())
            cat_len_chunks.append(seq_len_cat.cpu())

    embeddings = torch.cat(encoded_chunks, dim=0)
    info = {
        "seq_len_num": torch.cat(num_len_chunks, dim=0),
        "seq_len_cat": torch.cat(cat_len_chunks, dim=0),
    }

    if not as_numpy:
        return embeddings, info
    return embeddings.numpy(), {
        "seq_len_num": info["seq_len_num"].numpy(),
        "seq_len_cat": info["seq_len_cat"].numpy()
    }

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Same dimensions as used for training so the checkpoint weights fit.
sn_dim = static_num_tensor.shape[1]
sce_latent_dim = static_cat_list[0].shape[0]
tn_dim = tn_list[0].shape[1:][-1]
tce_latent_dim = tc_list[0].shape[1]
sc_dim = static_cat_mask_list[0].shape[-1]
tc_dim = tc_mask_list[0].shape[-1]
encoder = EncoderDecoder(sn_dim=sn_dim,sce_latent_dim=sce_latent_dim,tn_dim=tn_dim,
            tce_latent_dim=tce_latent_dim,sc_dim=sc_dim,tc_dim=tc_dim,latent_dim=256)
# map_location makes a checkpoint written on a GPU loadable on a CPU-only host.
encoder.load_checkpoint(filename="best_encoder_decoder_ckpt.pt", map_location=device)
encoder.to(device)
# NOTE: train_loader shuffles, so the training embeddings come out in shuffled
# order; `info_train` is collected in the same pass and stays aligned with them.
emb_train, info_train = encode_dataloader(encoder, train_loader, device=device)
emb_val, info_val = encode_dataloader(encoder, val_loader, device=device)

In [None]:
# Persist embeddings and their sequence-length info for both splits.
for filename, payload in (
    ("encoder_embeddings_train.pt", emb_train),
    ("encoder_embeddings_val.pt", emb_val),
    ("encoder_embeddings_info_train.pt", info_train),
    ("encoder_embeddings_info_val.pt", info_val),
):
    torch.save(payload, filename)
print("Saved embeddings: train:", emb_train.shape, " val:", emb_val.shape)

In [None]:
# Train and validation embeddings joined along the sample axis, saved as one file.
all_embeddings = torch.cat((emb_train, emb_val), dim=0)
torch.save(obj=all_embeddings, f="encoder_embeddings_all.pt")

In [None]:
# Build the GAN training data: every encoder embedding (train + val) wrapped in
# a TensorDataset, with 5% held out for validation.
# NOTE(review): this rebinds `train_dataset`, `val_dataset`, `train_loader` and
# `val_loader`, which previously referred to the encoder-decoder data. Cells
# above that use those names (e.g. the embedding extraction) will see the GAN
# loaders if re-run after this cell.
# NOTE(review): the split is unseeded, so it differs between runs.
all_embeddings = torch.cat([emb_train, emb_val], dim=0)
gan_dataset = TensorDataset(all_embeddings)
dataset_size = len(gan_dataset)
split_size = int(0.05 * dataset_size)
train_dataset, val_dataset = random_split(gan_dataset, [dataset_size - split_size, split_size])
# drop_last=True keeps every training batch at the full batch size.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True,num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, drop_last=False,num_workers=4)
print(f"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}")

In [None]:
# Peek at one batch of embeddings; the loader yields a one-element sequence
# whose only entry is the embedding tensor.
first_batch = next(iter(train_loader))
embedding_batch = first_batch[0]

print("Shape of first batch:", embedding_batch.shape)
print("First 5 entries in the first batch:")
print(embedding_batch[:5])

In [None]:
# Mean and standard deviation over every value in the inspected batch.
batch = torch.stack(first_batch)
overall_mean, overall_std = batch.mean().item(), batch.std().item()
print("Overall mean:", overall_mean)
print("Overall std:", overall_std)

In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to an ``encoder_state_dim``-sized vector.

    The network is a stack of ``Linear -> BatchNorm1d -> LeakyReLU(0.2) -> Dropout`` stages,
    one per entry of ``hidden_dims``, followed by a head that widens to
    ``2 * encoder_state_dim`` and then projects down to ``encoder_state_dim`` with a plain
    ``Linear`` layer (no activation on the output).

    Args:
        encoder_state_dim: Width of the generated vectors.
        latent_dim: Width of the input noise vector.
        hidden_dims: Widths of the hidden stages; defaults to
            ``[512, 1024, 2048, 1024, 512]`` when ``None``.
    """

    def __init__(self,encoder_state_dim,latent_dim = 256,hidden_dims=None):
        super().__init__()

        widths = [512, 1024, 2048, 1024, 512] if hidden_dims is None else hidden_dims
        # Stages in the first half of the stack get the heavier dropout rate.
        midpoint = len(widths) // 2
        head_dim = encoder_state_dim * 2

        modules = []
        in_features = latent_dim
        for depth, out_features in enumerate(widths):
            drop_p = 0.2 if depth < midpoint else 0.1
            modules += [
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features),
                nn.LeakyReLU(0.2),
                nn.Dropout(drop_p),
            ]
            in_features = out_features

        # Output head: widen, normalise, then project to the target width.
        modules += [
            nn.Linear(in_features, head_dim),
            nn.BatchNorm1d(head_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
            nn.Linear(head_dim, encoder_state_dim),
        ]

        self.model = nn.Sequential(*modules)

    def forward(self,z):
        """Return the generated vectors for the noise batch ``z``."""
        return self.model(z)

In [None]:
class Discriminator(nn.Module):
    """Fully connected critic that maps an ``encoder_state_dim``-sized vector to one raw score.

    Each hidden stage is ``Linear -> LayerNorm -> LeakyReLU(0.2) -> Dropout``; the head is
    ``Linear(., 64) -> LeakyReLU(0.2) -> Dropout(0.1) -> Linear(64, 1)``. No activation is
    applied to the final output.

    Args:
        encoder_state_dim: Width of the input vectors.
        hidden_dims: Widths of the hidden stages; defaults to
            ``[256, 512, 1024, 2048, 1024, 512, 256, 128]`` when ``None``.
    """

    def __init__(self,encoder_state_dim,hidden_dims = None):
        super().__init__()

        widths = [256, 512, 1024, 2048, 1024, 512, 256, 128] if hidden_dims is None else hidden_dims
        # Stages in the first half of the stack get the heavier dropout rate.
        midpoint = len(widths) // 2

        modules = []
        in_features = encoder_state_dim
        for depth, out_features in enumerate(widths):
            drop_p = 0.3 if depth < midpoint else 0.2
            modules += [
                nn.Linear(in_features, out_features),
                nn.LayerNorm(out_features),
                nn.LeakyReLU(0.2),
                nn.Dropout(drop_p),
            ]
            in_features = out_features

        # Scoring head: compress to 64 features, then emit a single unbounded score.
        modules += [
            nn.Linear(in_features, 64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
            nn.Linear(64, 1),
        ]

        self.model = nn.Sequential(*modules)

    def forward(self,x):
        """Return one score per row of ``x``."""
        return self.model(x)

In [None]:
def compute_mmd(x, y, sigma=None):
    """Gaussian-kernel MMD statistic between two sample sets, returned as a Python float.

    Both inputs are cast to float32 and standardised per feature using the mean and standard
    deviation of the pooled samples before any distance is computed. The returned value is
    ``mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y))`` where the means run over all pairs,
    including each sample paired with itself.

    Args:
        x: Tensor of samples (rows), or a list of tensors that is stacked into one.
        y: Same as ``x`` for the second sample set.
        sigma: Kernel bandwidth. When ``None`` it is set to the median of all pairwise
            Euclidean distances among the pooled, standardised samples; a median of
            exactly 0 is replaced by 1.0.

    Returns:
        The MMD value as a float.
    """

    def _to_float_tensor(samples):
        # Lists of tensors are stacked along a new leading axis first.
        stacked = torch.stack(samples) if isinstance(samples, list) else samples
        return stacked.to(torch.float32)

    x = _to_float_tensor(x)
    y = _to_float_tensor(y)

    # Standardise with pooled statistics; the epsilon guards against zero-variance features.
    pooled = torch.cat([x, y], dim=0)
    center = pooled.mean(dim=0, keepdim=True)
    scale = pooled.std(dim=0, keepdim=True) + 1e-6
    x_norm = (x - center) / scale
    y_norm = (y - center) / scale

    if sigma is None:
        # Median heuristic over every pairwise distance in the pooled set.
        pooled_norm = torch.cat([x_norm, y_norm], dim=0)
        sigma = torch.median(torch.cdist(pooled_norm, pooled_norm, p=2)).item()
        if sigma == 0:
            sigma = 1.0

    def _mean_kernel(a, b):
        squared = torch.cdist(a, b, p=2) ** 2
        return torch.exp(-squared / (2 * sigma ** 2)).mean()

    within_x = _mean_kernel(x_norm, x_norm)
    within_y = _mean_kernel(y_norm, y_norm)
    across = _mean_kernel(x_norm, y_norm)

    mmd = within_x + within_y - 2 * across
    return mmd.item()


In [None]:
class WGANGP:
    """Wasserstein GAN with gradient penalty operating on fixed-width vectors.

    Bundles a ``Generator`` / ``Discriminator`` pair with their Adam optimizers and
    ``ReduceLROnPlateau`` schedulers, and provides the training loop (``fit``) together with
    checkpoint save/load helpers.

    Args:
        encoder_state_dim: Width of the vectors produced by the generator and scored by the
            discriminator.
        latent_dim: Width of the standard-normal noise vector fed to the generator.
        generator_hidden_dims: Forwarded to ``Generator`` (``None`` uses its default).
        discriminator_hidden_dims: Forwarded to ``Discriminator`` (``None`` uses its default).
        lr_generator: Adam learning rate for the generator.
        lr_discriminator: Adam learning rate for the discriminator.
        lambda_gp: Multiplier applied to the gradient-penalty term.
        n_critic: The generator is updated on batches whose index within the epoch is a
            multiple of this value; the discriminator is updated on every batch.
        device: Target device; defaults to CUDA when available, otherwise CPU.
        plateau_factor: ``factor`` passed to both ``ReduceLROnPlateau`` schedulers.
        plateau_patience: ``patience`` passed to both ``ReduceLROnPlateau`` schedulers.
    """

    def __init__(self, encoder_state_dim, latent_dim=128,
                 generator_hidden_dims=None, discriminator_hidden_dims=None,
                 lr_generator=1e-4, lr_discriminator=1e-4,
                 lambda_gp=10.0, n_critic=5, device=None,
                 plateau_factor=0.5, plateau_patience=10):

        self.encoder_state_dim = encoder_state_dim
        self.latent_dim = latent_dim
        self.lambda_gp = lambda_gp
        self.n_critic = n_critic
        self.device = device or (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))

        self.generator = Generator(encoder_state_dim, latent_dim, generator_hidden_dims).to(self.device)
        self.discriminator = Discriminator(encoder_state_dim, discriminator_hidden_dims).to(self.device)

        self.optimizer_G = optim.Adam(self.generator.parameters(), lr=lr_generator, betas=(0.5, 0.9))
        self.optimizer_D = optim.Adam(self.discriminator.parameters(), lr=lr_discriminator, betas=(0.5, 0.9))

        # ReduceLROnPlateau schedulers; `fit` steps them with the validation MMD (lower is better).
        self.scheduler_G = ReduceLROnPlateau(self.optimizer_G, mode='min', factor=plateau_factor,
                                             patience=plateau_patience)
        self.scheduler_D = ReduceLROnPlateau(self.optimizer_D, mode='min', factor=plateau_factor,
                                             patience=plateau_patience)

        self.start_epoch = 1
        # Best validation MMD seen so far (used when a validation loader is supplied).
        self.best_mmd = float("inf")
        # Best training Wasserstein estimate seen so far (used when no validation loader is supplied).
        self.best_wd = float("-inf")

    def gradient_penalty(self, real_samples, fake_samples):
        """Return ``lambda_gp * mean((||grad_D(x_hat)||_2 - 1)^2)`` on random interpolates.

        ``x_hat`` is a per-sample random convex combination of ``real_samples`` and the
        detached ``fake_samples``. The result keeps its autograd graph so it can be
        back-propagated into the discriminator.
        """
        real_samples = real_samples.to(self.device).float()
        fake_samples = fake_samples.to(self.device).float()
        batch_size = real_samples.size(0)
        # One mixing coefficient per sample, broadcast across that sample's features.
        epsilon = torch.rand(batch_size, 1, device=self.device).expand_as(real_samples)
        interpolated = epsilon * real_samples + (1 - epsilon) * fake_samples.detach()
        interpolated.requires_grad_(True)

        d_interpolated = self.discriminator(interpolated)
        gradients = torch.autograd.grad(
            outputs=d_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(d_interpolated, device=self.device),
            create_graph=True,  # the penalty itself must be differentiable w.r.t. D's weights
            retain_graph=True,
            only_inputs=True
        )[0]
        gradient_penalty = self.lambda_gp * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    def discriminator_loss(self, real_samples, fake_samples):
        """Return ``(d_loss, wasserstein_distance, gp)`` for one critic update.

        ``wasserstein_distance`` is ``mean(D(real)) - mean(D(fake))`` and
        ``d_loss = -wasserstein_distance + gp``. Fake samples are detached, so no gradient
        reaches the generator.
        """
        d_real = self.discriminator(real_samples)
        d_fake = self.discriminator(fake_samples.detach())
        wasserstein_distance = d_real.mean() - d_fake.mean()
        gp = self.gradient_penalty(real_samples, fake_samples)
        d_loss = -wasserstein_distance + gp
        return d_loss, wasserstein_distance, gp

    def generator_loss(self, fake_samples):
        """Return the generator loss ``-mean(D(fake_samples))``."""
        d_fake = self.discriminator(fake_samples)
        g_loss = -d_fake.mean()
        return g_loss

    def generate_samples(self, batch_size):
        """Draw ``batch_size`` standard-normal noise vectors and pass them through the generator."""
        z = torch.randn(batch_size, self.latent_dim, device=self.device)
        fake_samples = self.generator(z).float()
        return fake_samples

    def _train_one_epoch(self, train_dataloader, epoch, epochs, verbose):
        """Run one pass over ``train_dataloader``; return ``(avg_d_loss, avg_g_loss, avg_wd, avg_gp)``.

        Discriminator metrics are averaged over all batches; the generator loss is averaged
        over the number of generator updates that were actually performed.

        Raises:
            ValueError: If ``train_dataloader`` yields no batches.
        """
        epoch_d_loss = epoch_g_loss = epoch_wd = epoch_gp = 0.0
        batches = 0
        generator_steps = 0
        progress_bar = tqdm(train_dataloader, desc=f'Epoch {epoch}/{epochs}') if verbose else train_dataloader

        for i, (real_samples,) in enumerate(progress_bar):
            batches += 1
            real_samples = real_samples.to(self.device).float()
            bsz = real_samples.size(0)

            # ---- Train Discriminator ----
            self.optimizer_D.zero_grad()
            # The critic update only ever uses detached fakes, so skip building the
            # generator's autograd graph for this forward pass.
            with torch.no_grad():
                fake_samples = self.generate_samples(bsz)
            d_loss, wd, gp = self.discriminator_loss(real_samples, fake_samples)
            d_loss.backward()
            self.optimizer_D.step()

            epoch_d_loss += d_loss.item()
            epoch_wd += wd.item()
            epoch_gp += gp.item()

            # ---- Train Generator every n_critic ----
            if i % self.n_critic == 0:
                self.optimizer_G.zero_grad()
                fake_samples = self.generate_samples(bsz)
                g_loss = self.generator_loss(fake_samples)
                g_loss.backward()
                self.optimizer_G.step()
                epoch_g_loss += g_loss.item()
                generator_steps += 1

        if batches == 0:
            raise ValueError("train_dataloader produced no batches; nothing to train on.")

        # Divide by the real number of generator updates: they fire at i = 0, n_critic, ...
        # which is ceil(batches / n_critic), not batches // n_critic.
        avg_g_loss = epoch_g_loss / max(1, generator_steps)
        return epoch_d_loss / batches, avg_g_loss, epoch_wd / batches, epoch_gp / batches

    def _validate(self, val_dataloader):
        """Evaluate on ``val_dataloader``; return ``(mmd, avg_val_g_loss, avg_val_wd)``.

        Both networks are switched to eval mode for the pass and back to train mode
        afterwards. One batch of fakes is generated per real batch and used both for the
        MMD computation and for the loss averages.

        Raises:
            ValueError: If ``val_dataloader`` yields no batches.
        """
        self.generator.eval()
        self.discriminator.eval()
        real_embeddings, fake_embeddings = [], []
        val_g_loss = val_wd = 0.0
        val_batches = 0

        try:
            with torch.no_grad():
                for real_samples, in val_dataloader:
                    val_batches += 1
                    real_samples = real_samples.to(self.device).float()
                    fake_samples = self.generate_samples(real_samples.size(0))
                    real_embeddings.append(real_samples)
                    fake_embeddings.append(fake_samples)

                    d_real = self.discriminator(real_samples)
                    d_fake = self.discriminator(fake_samples)
                    val_wd += (d_real.mean() - d_fake.mean()).item()
                    # In eval mode the generator loss is exactly -mean(D(fake)); reuse d_fake
                    # instead of scoring the same batch a second time.
                    val_g_loss += (-d_fake.mean()).item()

            if val_batches == 0:
                raise ValueError("val_dataloader produced no batches; cannot validate.")

            current_mmd = compute_mmd(torch.cat(real_embeddings, dim=0),
                                      torch.cat(fake_embeddings, dim=0))
        finally:
            self.generator.train()
            self.discriminator.train()

        return current_mmd, val_g_loss / val_batches, val_wd / val_batches

    def fit(self, train_dataloader, epochs=100, resume_from=None, 
            val_dataloader=None, verbose=True):
        """Train the GAN and return the per-epoch metric history.

        Args:
            train_dataloader: Iterable yielding 1-tuples ``(real_samples,)``.
            epochs: Index of the last epoch to run (training covers
                ``start_epoch .. epochs`` inclusive).
            resume_from: Optional checkpoint path loaded via ``load_checkpoint`` before training.
            val_dataloader: Optional iterable of 1-tuples. When given, the validation MMD
                selects the best checkpoint and drives both LR schedulers; otherwise the
                highest training Wasserstein estimate selects the best checkpoint.
            verbose: Show a tqdm progress bar and per-epoch summary (summary only with validation).

        Returns:
            Dict of lists keyed by ``train_d_loss``, ``train_g_loss``, ``train_wd``,
            ``train_gp``, ``val_g_loss``, ``val_wd`` and ``val_mmd``.

        Side effects:
            Writes ``best_gan.pt`` whenever the selection metric improves and
            ``final_gan.pt`` when training ends.
        """

        if resume_from:
            self.load_checkpoint(resume_from)
            print(f"Resumed training from checkpoint: {resume_from}")

        self.generator.train()
        self.discriminator.train()

        history = {
            "train_d_loss": [], "train_g_loss": [], "train_wd": [], "train_gp": [],
            "val_g_loss": [], "val_wd": [], "val_mmd": []
        }

        # Last epoch that actually completed. Initialised so the final checkpoint can still be
        # written when start_epoch > epochs and the loop body never executes.
        last_epoch = self.start_epoch - 1

        for epoch in range(self.start_epoch, epochs + 1):
            avg_d_loss, avg_g_loss, avg_wd, avg_gp = self._train_one_epoch(
                train_dataloader, epoch, epochs, verbose)

            history["train_d_loss"].append(avg_d_loss)
            history["train_g_loss"].append(avg_g_loss)
            history["train_wd"].append(avg_wd)
            history["train_gp"].append(avg_gp)

            # ---- Validation ----
            if val_dataloader is not None:
                current_mmd, avg_val_g_loss, avg_val_wd = self._validate(val_dataloader)

                history["val_mmd"].append(current_mmd)
                history["val_g_loss"].append(avg_val_g_loss)
                history["val_wd"].append(avg_val_wd)

                if verbose:
                    print(f"[Epoch {epoch}] Train D: {avg_d_loss:.7f}, G: {avg_g_loss:.7f}, WD: {avg_wd:.7f}, GP: {avg_gp:.7f} | "
                          f"Val WD: {avg_val_wd:.7f}, MMD: {current_mmd:.7f}")

                # Save best model
                if current_mmd < self.best_mmd:
                    self.best_mmd = current_mmd
                    self.save_checkpoint("best_gan.pt", epoch, history, is_best=True)

                # ---- Step scheduler using validation MMD ----
                self.scheduler_G.step(current_mmd)
                self.scheduler_D.step(current_mmd)

            else:
                # Without validation data, keep the checkpoint with the largest training WD estimate.
                if avg_wd > self.best_wd:
                    self.best_wd = avg_wd
                    self.save_checkpoint("best_gan.pt", epoch, history, is_best=True)

            last_epoch = epoch

        self.save_checkpoint("final_gan.pt", last_epoch, history)
        print("Training completed!")
        return history

    def save_checkpoint(self, filename, epoch, history, is_best=False):
        """Write model/optimizer state and bookkeeping to ``filename`` via ``torch.save``.

        The stored ``"epoch"`` is ``epoch + 1``, i.e. the epoch at which a resumed run
        should start. ``is_best`` only changes the message that is printed.
        """
        state = {
            "epoch": epoch + 1,
            "generator_state": self.generator.state_dict(),
            "discriminator_state": self.discriminator.state_dict(),
            "optimizer_G": self.optimizer_G.state_dict(),
            "optimizer_D": self.optimizer_D.state_dict(),
            "latent_dim": self.latent_dim,
            "encoder_state_dim": self.encoder_state_dim,
            "best_mmd": self.best_mmd,
            "best_wd": self.best_wd,
            "history": history
        }
        torch.save(state, filename)
        if is_best:
            print(f"Best model saved at {filename} (MMD: {self.best_mmd:.7f})")
        else:
            print(f"Checkpoint saved at {filename}")

    def load_checkpoint(self, filename, map_location=None):
        """Restore networks, optimizers and bookkeeping from ``filename``; return the raw checkpoint.

        ``map_location`` defaults to this instance's device. ``best_mmd`` and ``best_wd``
        fall back to their initial values when the checkpoint does not contain them.
        Scheduler state is not part of the checkpoint and is therefore not restored.
        """
        checkpoint = torch.load(filename, map_location=map_location or self.device)
        self.generator.load_state_dict(checkpoint["generator_state"])
        self.discriminator.load_state_dict(checkpoint["discriminator_state"])
        self.optimizer_G.load_state_dict(checkpoint["optimizer_G"])
        self.optimizer_D.load_state_dict(checkpoint["optimizer_D"])
        self.generator.to(self.device)
        self.discriminator.to(self.device)
        self.start_epoch = checkpoint["epoch"]
        self.best_mmd = checkpoint.get("best_mmd", float("inf"))
        self.best_wd = checkpoint.get("best_wd", float("-inf"))
        print(f"Checkpoint loaded: {filename} (resuming at epoch {self.start_epoch})")
        return checkpoint


In [None]:
# Train the WGAN-GP on the pooled encoder embeddings: 256-wide noise in, 256-wide vectors out,
# with the held-out embedding split used for validation.
wgan = WGANGP(latent_dim=256, encoder_state_dim=256)
history = wgan.fit(train_loader, epochs=200, val_dataloader=val_loader, verbose=True)