In [None]:
import ast
import json
import math
import os
import pickle
import random
import re
import string

import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader,random_split,TensorDataset
from tqdm import tqdm

In [None]:
import warnings
# Silence noisy warnings for the rest of the session: every RuntimeWarning,
# plus UserWarnings raised from modules whose name matches "sklearn".
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")

In [None]:
class StaticOneHotManager:
    """One-hot encoding of static categorical columns (one value per row).

    A separate ``OneHotEncoder`` is kept per column in ``self.encoders``.
    Missing values are represented by an all-zero vector, not by a category
    of their own.
    """

    def __init__(self):
        # column name -> fitted OneHotEncoder
        self.encoders = {}

    def fit(self, df, categorical_cols):
        """Fit one encoder per column on its non-null values, cast to ``str``.

        The category order is the order of first appearance in the column.
        """
        for col in categorical_cols:
            observed = df[col].dropna().astype(str)
            encoder = OneHotEncoder(
                categories=[observed.unique().tolist()],
                handle_unknown="ignore",
                sparse_output=False,
            )
            encoder.fit(observed.values.reshape(-1, 1))
            self.encoders[col] = encoder

    def transform(self, df, categorical_cols, return_numpy=True):
        """Encode the requested columns into one vector per row.

        Args:
            df: frame holding the columns; a ``"PATIENT"`` column, when
                present, is copied to the output unchanged.
            categorical_cols: columns to encode (each must have been fitted).
            return_numpy: store each row vector as a NumPy array when true,
                as a plain list otherwise.

        Returns:
            A new DataFrame with one cell per (row, column) holding the
            one-hot vector; rows that were null in ``df`` get all zeros.
        """
        result = pd.DataFrame()
        if "PATIENT" in df.columns:
            result["PATIENT"] = df["PATIENT"]

        for col in categorical_cols:
            column = df[col]
            onehot = self.encoders[col].transform(
                column.astype(str).values.reshape(-1, 1)
            )
            # Nulls were stringified above; force their rows back to all-zero.
            onehot[column.isna().values] = 0
            result[col] = list(onehot) if return_numpy else onehot.tolist()

        return result

    def inverse_transform(self, series, col):
        """Map each vector back to its category label.

        An all-zero vector decodes to ``np.nan``; otherwise the label at the
        arg-max position is returned. Returns a plain list.
        """
        labels = self.encoders[col].categories_[0]

        def _decode(vector):
            vector = np.array(vector, dtype=float)
            return np.nan if vector.sum() == 0 else labels[vector.argmax()]

        return [_decode(vector) for vector in series]

    def save(self, path="static_encoded.pkl"):
        """Persist the fitted encoders to ``path`` with joblib."""
        joblib.dump(self.encoders, path)

    def load(self, path="static_encoded.pkl"):
        """Replace the encoders with the ones stored at ``path``.

        joblib files are pickle-based: only load files from a trusted source.
        """
        self.encoders = joblib.load(path)

In [None]:
class TemporalOneHotManager:
    """One-hot encoding of temporal categorical columns (one sequence per row).

    A cell of a temporal column may be a Python list, a stringified list such
    as ``"['a', nan, 'b']"``, a single scalar, or a missing value. Every cell
    is normalised to a list of values; each value becomes a one-hot vector and
    a missing value becomes an all-zero vector.
    """

    # Bare ``nan`` tokens inside a stringified list are not valid Python
    # literals; they are rewritten to ``None`` before parsing. The word
    # boundaries keep values that merely contain "nan" (e.g. 'banana') intact.
    _NAN_TOKEN = re.compile(r"\bnan\b")

    def __init__(self):
        # column name -> fitted OneHotEncoder
        self.encoders = {}

    @staticmethod
    def _is_missing(value):
        """Return True only for scalar missing markers (None, NaN, NaT, NA).

        ``pd.isna`` on a list returns an array whose truth value is ambiguous,
        so containers are explicitly reported as "not missing".
        """
        return bool(pd.api.types.is_scalar(value) and pd.isna(value))

    @classmethod
    def _as_sequence(cls, cell):
        """Normalise one cell into a list of raw (possibly missing) values."""
        if cls._is_missing(cell):
            return []
        if isinstance(cell, str):
            try:
                cell = ast.literal_eval(cls._NAN_TOKEN.sub("None", cell))
            except (ValueError, SyntaxError):
                # Not a Python literal: treat the text itself as one value.
                cell = [cell]
        if not isinstance(cell, list):
            cell = [cell]
        return cell

    def fit(self, df, categorical_cols):
        """Fit one encoder per column on every non-missing sequence element.

        Elements are cast to ``str``; category order is order of first
        appearance across the column.
        """
        for col in categorical_cols:
            flat_values = [
                str(value)
                for cell in df[col]
                for value in self._as_sequence(cell)
                if not self._is_missing(value)
            ]
            categories = list(dict.fromkeys(flat_values))  # ordered unique

            enc = OneHotEncoder(
                sparse_output=False,
                handle_unknown="ignore",
                categories=[categories]
            )
            enc.fit(np.array(categories).reshape(-1, 1))
            self.encoders[col] = enc

    def transform(self, df, categorical_cols, return_numpy=True):
        """Encode every sequence element of the requested columns.

        Args:
            df: frame holding the columns; a ``"PATIENT"`` column, when
                present, is copied to the output unchanged.
            categorical_cols: columns to encode (each must have been fitted).
            return_numpy: accepted for signature parity with
                ``StaticOneHotManager.transform``; not used here, cells are
                always nested Python lists.

        Returns:
            A new DataFrame whose cells are lists of one-hot ``int`` lists,
            one per sequence element. Missing or unseen elements become
            all-zero vectors; a missing cell becomes an empty list.
        """
        out = pd.DataFrame()
        if "PATIENT" in df.columns:
            out["PATIENT"] = df["PATIENT"]

        for col in categorical_cols:
            cats = self.encoders[col].categories_[0].tolist()
            n_cats = len(cats)
            # Position lookup gives the same vectors as calling the encoder
            # (unknown values -> zeros) without one sklearn call per element.
            position = {str(cat): idx for idx, cat in enumerate(cats)}

            rows = []
            for cell in df[col]:
                encoded_seq = []
                if n_cats > 0:
                    for value in self._as_sequence(cell):
                        vector = [0] * n_cats
                        if not self._is_missing(value):
                            idx = position.get(str(value))
                            if idx is not None:
                                vector[idx] = 1
                        encoded_seq.append(vector)
                rows.append(encoded_seq)
            out[col] = rows

        return out

    def inverse_transform(self, series, col):
        """Decode one-hot vectors back to category labels.

        Each item of ``series`` is converted to a float array; an all-zero
        (or empty) item decodes to ``np.nan``, otherwise the label at the
        arg-max position is returned.

        NOTE(review): ``argmax`` is taken over the flattened item, so each
        item is presumably a single one-hot vector rather than a whole
        sequence of vectors - confirm against callers.
        """
        enc = self.encoders[col]
        cats = enc.categories_[0].tolist()
        results = []

        for row in series:
            arr = np.array(row, dtype=float)
            if arr.sum() == 0:
                results.append(np.nan)
            else:
                results.append(cats[arr.argmax()])

        return results

    def inverse_transform2(self, series, col):
        """Decode integer category indices back to labels.

        A NaN or ``0`` index decodes to ``np.nan``; any other index ``i``
        decodes to ``categories[int(i)]``.

        NOTE(review): position 0 is also the first fitted category, which can
        therefore never be returned here. If 0 is meant as a padding index
        with 1-based categories, the lookup would need ``int(i) - 1`` -
        confirm the index convention used by the caller.
        """
        enc = self.encoders[col]
        categories = enc.categories_[0]
        results = []

        for idx in series:
            if pd.isna(idx) or idx == 0:
                results.append(np.nan)
            else:
                results.append(categories[int(idx)])

        return results

    def save(self, path="temporal_encoded.pkl"):
        """Persist the fitted encoders to ``path`` with joblib."""
        joblib.dump(self.encoders, path)

    def load(self, path="temporal_encoded.pkl"):
        """Replace the encoders with the ones stored at ``path``.

        joblib files are pickle-based: only load files from a trusted source.
        """
        self.encoders = joblib.load(path)

In [None]:
class CategoricalAutoEncoder(nn.Module):
    """Autoencoder for a row of concatenated one-hot categorical features.

    The input is one vector of width ``sum(input_dims)`` in which feature
    ``i`` occupies ``input_dims[i]`` consecutive columns. A shared encoder
    maps it to a Tanh-bounded latent vector and one decoder head per feature
    emits the class logits of that feature. An all-zero segment is treated as
    a missing feature by :meth:`reconstruct`.

    Args:
        input_dims: number of categories of each feature, in column order.
        lr: learning rate of the optimizer created in the constructor.
        optimizer_type: ``"adam"`` or ``"rmsprop"`` (case-insensitive).
        use_scheduler: when true, a ``ReduceLROnPlateau`` scheduler is created
            and stepped on the validation loss during :meth:`fit`.
        filename: path used by :meth:`fit`, :meth:`save_model` and
            :meth:`load_model` for the state dict.

    Raises:
        ValueError: if ``optimizer_type`` is not a supported name.
    """

    def __init__(self, input_dims, lr=1e-3, optimizer_type="adam", use_scheduler=False, filename=None):
        super().__init__()

        self.input_dims = input_dims
        self.filename = filename
        self.total_input_dim = sum(self.input_dims)
        self.num_features = len(self.input_dims)
        hidden_dim = max(256, self.total_input_dim)
        latent_dim = max(32, min(self.total_input_dim // 8, self.total_input_dim - 1))

        # ------ Encoder -----------
        self.encoder = nn.Sequential(
            nn.Linear(self.total_input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.2),

            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.LeakyReLU(0.2),

            nn.Linear(hidden_dim // 2, latent_dim),
            nn.Tanh()
        )

        # -------- Decoders: one head per feature ----------
        self.decoders = nn.ModuleList([
            nn.Sequential(
                nn.Linear(latent_dim, hidden_dim // 2),
                nn.BatchNorm1d(hidden_dim // 2),
                nn.LeakyReLU(0.2),

                nn.Linear(hidden_dim // 2, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),

                nn.Linear(hidden_dim, k)
            )
            for k in self.input_dims
        ])

        #--------- Loss & Optimizer ---------
        self.criterion = nn.CrossEntropyLoss()
        # Compare case-insensitively for both names ("Adam" used to be
        # rejected even though the error message spells it that way).
        optimizer_name = optimizer_type.lower()
        if optimizer_name == "adam":
            self.optimizer = optim.Adam(self.parameters(), lr=lr)
        elif optimizer_name == "rmsprop":
            self.optimizer = optim.RMSprop(self.parameters(), lr=lr)
        else:
            raise ValueError("optimizer_type must be Adam or RMSprop")

        self.scheduler = (optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min',
                                                              factor=0.5, patience=3)
                         if use_scheduler else None)

    # ---------- Forward / Encode / Decode ----------
    def forward(self, x):
        """Return the list of per-feature logits for input ``x``."""
        z = self.encoder(x)
        logits_list = [head(z) for head in self.decoders]
        return logits_list

    def encode(self, x):
        """Return the latent vectors of ``x`` (switches to eval mode, no grad)."""
        self.eval()
        with torch.no_grad():
            z = self.encoder(x)
        return z

    def decode(self, z: torch.Tensor, mask=None, return_onehot=True):
        """Decode latent vectors into predicted class indices.

        Args:
            z: latent vectors, one row per sample.
            mask: optional per-feature multiplier, indexed as ``mask[:, i]``
                for feature ``i``; it is broadcast over that feature's
                one-hot columns (so 0 blanks the feature out).
            return_onehot: when false only the index list is returned.

        Returns:
            ``preds`` (list with one index tensor per feature), or
            ``(preds, onehot)`` where ``onehot`` is the concatenation of the
            per-feature one-hot blocks.
        """
        self.eval()
        with torch.no_grad():
            logits_list = [head(z) for head in self.decoders]
            preds = [logits.argmax(dim=-1) for logits in logits_list]

            if not return_onehot:
                return preds

            bsz = z.shape[0]
            parts = []

            for i, (idxs, K) in enumerate(zip(preds, self.input_dims)):
                onehot = torch.zeros(bsz, K, device=z.device)
                onehot.scatter_(1, idxs.view(-1, 1), 1.0)

                if mask is not None:
                    # mask[:, i] has shape (batch_size,)
                    onehot = onehot * mask[:, i].unsqueeze(1)  # broadcast to (batch_size, K)

                parts.append(onehot)

            return preds, torch.cat(parts, dim=1)

    def predict_prob(self, x):
        """Return per-feature softmax probabilities (eval mode, no grad)."""
        self.eval()
        with torch.no_grad():
            logits_list = self.forward(x)
            probs_list = [F.softmax(logits, dim=-1) for logits in logits_list]
        return probs_list

    # ---------- Reconstruct with mask for missingness ----------
    def reconstruct(self, x, return_onehot=False):
        """Auto-encode ``x`` and return the predicted class of every feature.

        Args:
            x: concatenated one-hot input, one row per sample.
            return_onehot: also return the prediction re-encoded as one-hot.

        Returns:
            ``preds`` (list with one index tensor per feature), or
            ``(preds, onehot)``. In ``onehot`` a feature whose input segment
            was all zeros (missing) is left all zeros for that row.
        """
        probs_list = self.predict_prob(x)
        preds = []
        bsz = x.shape[0]
        parts = []

        start = 0
        for i, K in enumerate(self.input_dims):
            end = start + K
            segment = x[:, start:end]

            # Rows whose segment is all-zero carry no value for this feature.
            missing = (segment.sum(dim=1) == 0)
            segment_pred = probs_list[i].argmax(dim=-1)
            preds.append(segment_pred)

            if return_onehot:
                onehot = torch.zeros(bsz, K, device=x.device)
                present_rows = (~missing).nonzero(as_tuple=True)[0]
                # Index assignment writes into `onehot` itself. The previous
                # `onehot[rows].scatter_(...)` modified a temporary copy made
                # by advanced indexing, so the result was always all zeros.
                onehot[present_rows, segment_pred[present_rows]] = 1.0
                parts.append(onehot)

            start = end

        if return_onehot:
            return preds, torch.cat(parts, dim=1)
        return preds

    # ---------- Helper functions ----------
    def _targets_from_onehot(self, x):
        """Convert concatenated one-hot rows to a (batch, num_features) index tensor.

        An all-zero segment yields index 0 (arg-max of zeros).
        """
        parts = torch.split(x, self.input_dims, dim=1)
        return torch.stack([p.argmax(dim=1) for p in parts], dim=1).long()

    def _compute_loss(self, logits_list, targets, mask=None):
        """Sum the cross-entropy of every feature head.

        Args:
            logits_list: per-feature logits.
            targets: class indices, indexed as ``targets[:, i]``.
            mask: optional presence mask indexed as ``mask[:, i]``; only rows
                where it equals 1 contribute to feature ``i``, and a feature
                with no such row is skipped.

        Returns:
            A scalar tensor (zero when every feature was skipped, so callers
            can always use ``.item()``).
        """
        if not logits_list:
            return torch.zeros(())
        loss = logits_list[0].new_zeros(())
        if mask is not None:
            # Keep mask indexing on the same device as the logits.
            mask = torch.as_tensor(mask, device=logits_list[0].device)

        for i, logits in enumerate(logits_list):
            if mask is not None:
                present_idx = (mask[:, i] == 1).nonzero(as_tuple=True)[0]
                if len(present_idx) == 0:
                    continue  # skip missing features
                loss = loss + self.criterion(logits[present_idx], targets[present_idx, i])
            else:
                loss = loss + self.criterion(logits, targets[:, i])
        return loss

    def _require_filename(self):
        """Return ``self.filename`` or raise a clear error when it is unset."""
        if self.filename is None:
            raise ValueError(
                "CategoricalAutoEncoder.filename is not set; pass filename=... to the constructor"
            )
        return self.filename

    # ---------- Training ----------
    def fit(self, dataloader, epochs=10, val_dataloader=None, wrapper_model=None, mask_val=None):
        """Train the autoencoder and save the resulting weights.

        Args:
            dataloader: yields 1-tuples ``(x,)`` of concatenated one-hot rows.
            epochs: number of passes over ``dataloader``.
            val_dataloader: optional loader in the same format; enables the
                scheduler step and best-checkpoint tracking.
            wrapper_model: optional module called instead of ``self`` for the
                forward pass (presumably a DataParallel-style wrapper around
                this model - confirm).
            mask_val: optional presence mask for the validation set, one row
                per validation sample in loader order (the loader must not
                shuffle for the rows to line up).

        Side effects:
            Writes the state dict to ``self.filename``: the weights of the
            epoch with the lowest validation loss when validation data was
            given, otherwise the final weights. The live model is left with
            its final weights.

        Raises:
            ValueError: if ``self.filename`` is not set (checked up front so
                that a full training run is not lost at save time).

        NOTE(review): the training loss is computed without a mask, so a
        missing (all-zero) segment is trained towards class 0, while the
        validation loss can exclude it through ``mask_val``.
        """
        self._require_filename()
        forward_model = wrapper_model if wrapper_model is not None else self
        device = next(self.parameters()).device
        best_val_loss = float('inf')
        best_weights = None

        for epoch in range(1, epochs + 1):
            self.train()
            epoch_losses = []

            for (x,) in tqdm(dataloader, desc=f"Epoch {epoch}/{epochs}", leave=False):
                x = x.to(device, non_blocking=True)
                y = self._targets_from_onehot(x)
                self.optimizer.zero_grad()
                logits_list = forward_model(x)
                loss = self._compute_loss(logits_list, y)
                if wrapper_model is not None:
                    loss = loss.mean()
                loss.backward()
                self.optimizer.step()
                epoch_losses.append(loss.item())

            train_loss = sum(epoch_losses) / max(1, len(epoch_losses))

            val_loss = None
            if val_dataloader is not None:
                self.eval()
                val_losses = []
                offset = 0  # running row offset into mask_val
                with torch.no_grad():
                    for (vx,) in val_dataloader:
                        vx = vx.to(device, non_blocking=True)
                        vy = self._targets_from_onehot(vx)
                        vlogits_list = forward_model(vx)
                        batch_mask = None
                        if mask_val is not None:
                            # A running offset stays correct when the last
                            # batch is shorter; i * batch_size did not.
                            batch_mask = mask_val[offset:offset + vx.size(0)]
                        offset += vx.size(0)
                        vloss = self._compute_loss(vlogits_list, vy, mask=batch_mask)
                        val_losses.append(vloss.item())
                val_loss = sum(val_losses) / max(1, len(val_losses))

                if self.scheduler is not None:
                    self.scheduler.step(val_loss)

                # Snapshot the best weights. state_dict() returns references
                # to the live tensors, so they must be cloned or later epochs
                # would silently overwrite the "best" checkpoint.
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_weights = {k: v.detach().clone() for k, v in self.state_dict().items()}

            message = f"Epoch {epoch:03d} | Train {train_loss:.7f}"
            if val_loss is not None:
                message += f" | Val {val_loss:.7f}"
            print(message)

        # Save best weights
        if best_weights is not None:
            torch.save(best_weights, self.filename)
            print(f"Best model saved with val_loss={best_val_loss:.7f} as {self.filename}")
        else:
            self.save_model()

    # ---------- Save / Load ----------
    def save_model(self):
        """Write the current state dict to ``self.filename``."""
        torch.save(self.state_dict(), self._require_filename())
        print(f"Model saved as {self.filename}")

    def load_model(self, map_location=None):
        """Load the state dict from ``self.filename`` and switch to eval mode."""
        state = torch.load(self._require_filename(), map_location=map_location)
        self.load_state_dict(state)
        self.eval()
        print(f"Model loaded from {self.filename}")


In [None]:
class StochasticNormalizer:
    """Map discrete values to random points of [0, 1] and back.

    Every distinct value of a tensor is given a sub-interval of [0, 1] whose
    width equals the value's relative frequency; each occurrence is replaced
    by a uniform draw from that interval. The intervals are kept per ``key``
    in ``self.params`` as ``{value: [low, high]}`` so the mapping can be
    inverted or re-applied later.
    """

    def __init__(self):
        # key -> {value (float): [lower_bound, upper_bound]}
        self.params = {}

    def stochastic_normalize(self, X: torch.Tensor, key=None):
        """Learn the intervals from ``X`` and return the randomised tensor.

        Args:
            X: tensor of discrete values (cast to float32).
            key: when given, the learned intervals are stored under it.

        Returns:
            A float32 tensor shaped like ``X``. Positions that match no
            distinct value (NaN inputs never compare equal) are NaN.
        """
        X = X.float()
        unique_vals, counts = torch.unique(X, return_counts=True)
        N = X.numel()
        # Start from NaN instead of uninitialised memory so that unmatched
        # positions are deterministic.
        X_hat = torch.full_like(X, float("nan"))
        lower_bound = 0.0
        params = {}

        for val, count in zip(unique_vals, counts):
            ratio = count.item() / N
            upper_bound = lower_bound + ratio
            mask = X == val
            X_hat[mask] = torch.rand(int(mask.sum()), device=X.device) * (upper_bound - lower_bound) + lower_bound
            params[val.item()] = [lower_bound, upper_bound]
            lower_bound = upper_bound

        if key is not None:
            self.params[key] = params
        return X_hat

    def stochastic_renormalize(self, X_hat: torch.Tensor, key=None):
        """Invert :meth:`stochastic_normalize` using the intervals of ``key``.

        Each element is mapped to the value whose interval ``[low, high)``
        contains it; an element equal to 1.0 is mapped to the value whose
        upper bound is 1.0. Elements outside every interval stay 0.

        Raises:
            ValueError: if ``key`` is None (there is nothing to invert with).
            KeyError: if no intervals were stored under ``key``.
        """
        if key is None:
            raise ValueError("stochastic_renormalize requires the key the parameters were stored under.")
        params = self.params[key]

        X_hat = X_hat.float()
        X = torch.zeros_like(X_hat, dtype=torch.float32, device=X_hat.device)

        for val, (low, high) in params.items():
            mask = (X_hat >= low) & (X_hat < high)
            X[mask] = val

        # Handle edge case where X_hat == 1.0 (excluded by the half-open test)
        mask = X_hat == 1.0
        for val, (low, high) in params.items():
            if abs(high - 1.0) < 1e-8:
                X[mask] = val
        return X

    def normalize_sample(self, X: torch.Tensor, key):
        """Randomise ``X`` with intervals that were learned earlier for ``key``.

        Values that have no stored interval are left at 0.

        Raises:
            ValueError: if no intervals are stored under ``key``.
        """
        if key not in self.params:
            raise ValueError(f"No parameters found for key '{key}'. Load or train first.")

        params = self.params[key]
        X_hat = torch.zeros_like(X, dtype=torch.float32, device=X.device)
        for val, (low, high) in params.items():
            mask = X == val
            if mask.any():
                X_hat[mask] = torch.rand(int(mask.sum()), device=X.device) * (high - low) + low
        return X_hat

    def save_params(self, filepath):
        """Store all learned intervals at ``filepath`` with ``torch.save``."""
        torch.save(self.params, filepath)

    def load_params(self, filepath):
        """Replace the learned intervals with those stored at ``filepath``."""
        self.params = torch.load(filepath)


In [None]:
class MLP(nn.Module):
    """Feed-forward stack: ``[Linear -> (BatchNorm) -> activation -> (Dropout)]*`` then a final Linear.

    Args:
        in_dim: width of the input.
        hidden_dims: widths of the hidden layers, in order.
        out_dim: width of the output layer.
        dropout: dropout probability after each hidden layer (none if <= 0).
        use_batch_norm: insert ``BatchNorm1d`` after each hidden Linear.
        activation: one of ``'lrelu'``, ``'relu'``, ``'gelu'``, ``'swish'``,
            ``'tanh'``, ``'sigmoid'``; any other name raises ``KeyError``.
        final_activation: optional module appended after the output layer.
    """

    def __init__(self, in_dim, hidden_dims=(512,256,256), out_dim=16, dropout=0.1,
                 use_batch_norm=True, activation='lrelu', final_activation=None):
        super().__init__()

        available = {
            'lrelu': nn.LeakyReLU(0.01),
            'relu': nn.ReLU(),
            'gelu': nn.GELU(),
            'swish': nn.SiLU(),
            'tanh': nn.Tanh(),
            'sigmoid': nn.Sigmoid()
        }
        # A single activation instance is shared by every hidden layer.
        act = available[activation]

        widths = [in_dim, *hidden_dims]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            if use_batch_norm:
                modules.append(nn.BatchNorm1d(fan_out))
            modules.append(act)
            if dropout > 0.0:
                modules.append(nn.Dropout(dropout))

        modules.append(nn.Linear(widths[-1], out_dim))
        if final_activation is not None:
            modules.append(final_activation)

        self.net = nn.Sequential(*modules)

    def forward(self,x):
        """Apply the stack to ``x``."""
        return self.net(x)


In [None]:
class TemporalEncoder(nn.Module):
    """GRU encoder with attention pooling over time.

    The sequence is run through a (multi-layer) GRU, every step gets a scalar
    attention score, and the softmax-weighted sum of the GRU outputs is
    projected to ``out_dim``.

    Args:
        in_dim: number of features per time step.
        hidden_dim: GRU hidden size.
        out_dim: width of the returned embedding.
        depth: number of stacked GRU layers.
        use_post_mlp: project with a two-layer MLP when true, with a single
            Linear layer otherwise.
    """

    def __init__(self,in_dim,hidden_dim=256,out_dim=128,depth=2,use_post_mlp=True):
        super().__init__()

        self.gru = nn.GRU(
            input_size = in_dim,
            hidden_size = hidden_dim,
            num_layers = depth,
            batch_first = True,
            # Inter-layer dropout only exists with more than one layer; passing
            # a non-zero value for a single layer just triggers a UserWarning.
            dropout = 0.1 if depth > 1 else 0.0
        )

        self.attn = nn.Linear(hidden_dim, 1)

        if use_post_mlp:
            self.fc = nn.Sequential(
                nn.Linear(hidden_dim,hidden_dim),
                nn.LeakyReLU(0.2),
                nn.Linear(hidden_dim,out_dim)
            )
        else:
            self.fc = nn.Linear(hidden_dim,out_dim)

    def forward(self,x):
        """Encode ``x`` of shape (batch, seq_len, in_dim) into (batch, out_dim).

        NOTE(review): every time step takes part in the attention softmax, so
        padded steps (if the caller pads sequences) influence the embedding.
        """
        # x: (batch, seq_len, in_dim)
        out, _ = self.gru(x)
        attn_scores = self.attn(out).squeeze(-1) # (batch,seq_len)
        attn_weights = torch.softmax(attn_scores,dim=1)
        context = torch.sum(out * attn_weights.unsqueeze(-1),dim=1) # (batch, hidden_dim)
        return self.fc(context)

In [None]:
class FusionMLP(nn.Module):
    """MLP that fuses concatenated branch embeddings into one latent vector.

    Layout: ``[Linear -> BatchNorm1d -> ReLU -> Dropout]`` per hidden width,
    followed by a final Linear to ``output_dim``.

    Args:
        input_dim: width of the concatenated input.
        hidden_dims: hidden layer widths; any sequence (list or tuple) works.
        output_dim: width of the fused output.
        dropout: dropout probability after each hidden layer.
    """

    def __init__(self, input_dim, hidden_dims=[128, 64], output_dim=64, dropout=0.1):
        super().__init__()
        layers = []
        # Build a fresh list: accepts tuples (list + tuple raises TypeError)
        # and never mutates the shared default argument.
        dims = [input_dim, *hidden_dims]
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i+1]))
            layers.append(nn.BatchNorm1d(dims[i+1]))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(dims[-1], output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the fusion stack to ``x``."""
        return self.net(x)


In [None]:
class TemporalDecoder(nn.Module):
    """Autoregressive-length GRU decoder that unrolls a latent vector into a sequence.

    The latent vector initialises the GRU hidden state and is also fed as the
    input at every step. Three heads read each GRU output: a value head, a
    scalar "time" head (Sigmoid) and a mask head (raw logits). A sequence is
    considered finished once every mask probability of a step is below the
    threshold.

    Args:
        latent_dim: width of the latent vector ``e``.
        hidden_dim: GRU hidden size.
        depth: number of stacked GRU layers.
        num_features: width of the mask head, and of the value head when
            ``is_categorical`` is false.
        embed_dim: width of the value head when ``is_categorical`` is true.
        is_categorical: value head is ``Linear -> Tanh`` of width ``embed_dim``
            when true, ``Linear -> Sigmoid`` of width ``num_features`` otherwise.

    Raises:
        ValueError: if ``num_features`` is missing, or ``embed_dim`` is
            missing for a categorical decoder.
    """

    def __init__(self, latent_dim: int, hidden_dim = 512, depth = 2,
                 num_features=None, embed_dim=None, is_categorical=False):
        super().__init__()
        if num_features is None:
            raise ValueError("num_features is required")
        if is_categorical and embed_dim is None:
            raise ValueError("embed_dim is required when is_categorical=True")

        self.is_categorical = is_categorical
        self.fc_init_h = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, depth * hidden_dim)
        )

        self.gru = nn.GRU(
            input_size=latent_dim,
            hidden_size=hidden_dim,
            num_layers=depth,
            batch_first=True,
            dropout=0.1 if depth > 1 else 0.0
        )

        self.head_value = nn.Sequential(
            nn.Linear(hidden_dim, embed_dim if is_categorical else num_features),
            nn.Tanh() if is_categorical else nn.Sigmoid()
        )

        self.head_time = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 8),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim // 8, 1),
            nn.Sigmoid()
        )

        self.head_mask = nn.Sequential(
            nn.Linear(hidden_dim  , num_features)
        )

    def forward(self, e, max_seq_len, mask_threshold = 0.5):
        """Unroll ``e`` for at most ``max_seq_len`` steps.

        Args:
            e: latent vectors of shape (batch, latent_dim).
            max_seq_len: upper bound on the number of generated steps.
            mask_threshold: a sequence stops after a step whose mask
                probabilities are all below this value.

        Returns:
            ``(values, times, mask_logits, seq_lengths)``. The first three are
            stacked over the generated steps on dim 1 (the loop ends early
            once every sequence has stopped, so that dim may be shorter than
            ``max_seq_len`` and is 0 when nothing is generated).
            ``seq_lengths`` is a long tensor counting the steps each sequence
            was active for; being an integer count it carries no gradient.
        """
        batch_size = e.size(0)
        device = e.device
        num_layers, hidden_size = self.gru.num_layers, self.gru.hidden_size

        # fc_init_h yields (batch, depth * hidden). Split it per sample first
        # and then move the layer axis to the front, as nn.GRU expects
        # (num_layers, batch, hidden). Viewing it directly as
        # (num_layers, batch, hidden) would mix hidden states across samples.
        h_t = (self.fc_init_h(e)
               .view(batch_size, num_layers, hidden_size)
               .transpose(0, 1)
               .contiguous())

        tn_hat_list, u_hat_list, mask_hat_list = [], [], []

        seq_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
        active_sequences = torch.ones(batch_size, dtype=torch.bool, device=device)

        # The same latent vector is the GRU input at every step.
        gru_input = e.unsqueeze(1)

        for _ in range(max_seq_len):
            if not active_sequences.any():
                break

            gru_out, h_t_new = self.gru(gru_input, h_t)

            # Finished sequences keep their previous hidden state.
            h_t = torch.where(active_sequences.view(1, -1, 1), h_t_new, h_t)

            step_out = gru_out.squeeze(1)
            tn_hat_step = self.head_value(step_out)
            u_hat_step = self.head_time(step_out)
            mask_hat_step = self.head_mask(step_out)

            tn_hat_list.append(tn_hat_step)
            u_hat_list.append(u_hat_step)
            mask_hat_list.append(mask_hat_step)

            seq_lengths += active_sequences.long()

            stop_condition = (torch.sigmoid(mask_hat_step) < mask_threshold).all(dim=-1)
            active_sequences = active_sequences & ~stop_condition

        if not tn_hat_list:
            # Nothing generated (max_seq_len <= 0 or empty batch):
            # torch.stack cannot handle an empty list.
            tn_hat = e.new_zeros(batch_size, 0, self.head_value[0].out_features)
            u_hat = e.new_zeros(batch_size, 0, 1)
            mask_hat = e.new_zeros(batch_size, 0, self.head_mask[0].out_features)
            return tn_hat, u_hat, mask_hat, seq_lengths

        tn_hat = torch.stack(tn_hat_list, dim=1)
        u_hat = torch.stack(u_hat_list, dim=1)
        mask_hat = torch.stack(mask_hat_list, dim=1)

        return tn_hat, u_hat, mask_hat, seq_lengths

In [None]:
class EncoderDecoder(nn.Module):
    def __init__(self,sn_dim,sce_latent_dim,tn_dim,tce_latent_dim,sc_dim,tc_dim,latent_dim=512):
        """Build the encoder branches, the fusion network and the decoders.

        Args:
            sn_dim: width of the static numeric input (also of its mask).
            sce_latent_dim: width of the static categorical embedding.
            tn_dim: per-step width of the temporal numeric input (and mask).
            tce_latent_dim: per-step width of the temporal categorical embedding.
            sc_dim: width of the static categorical mask.
            tc_dim: per-step width of the temporal categorical mask.
            latent_dim: width of the fused latent vector.
        """
        super().__init__()
        self.latent_dim = latent_dim

        # Output width of each encoder branch; their concatenation feeds the fusion MLP.
        static_width = 128
        temporal_width = 256
        static_mask_width = 4
        temporal_mask_width = 64

        # --- value encoders (temporal inputs get one extra per-step column) ---
        self.static_encoder = MLP(sn_dim + sce_latent_dim, out_dim=static_width)
        self.temporal_num_encoder = TemporalEncoder(tn_dim + 1, out_dim=temporal_width)
        self.temporal_cat_encoder = TemporalEncoder(tce_latent_dim + 1, out_dim=temporal_width)

        # --- mask encoders ---
        self.static_mask = MLP(in_dim = sn_dim + sc_dim,hidden_dims=(32,16,8),out_dim=static_mask_width)
        self.temporal_num_mask_encoder = TemporalEncoder(tn_dim,hidden_dim=128,out_dim=temporal_mask_width)
        self.temporal_cat_mask_encoder = TemporalEncoder(tc_dim,hidden_dim=128,out_dim=temporal_mask_width)

        # --- fusion of all six branch embeddings ---
        fusion_dim = static_width + 2 * temporal_width + static_mask_width + 2 * temporal_mask_width
        self.fusion = FusionMLP(fusion_dim, hidden_dims=[2048, 1024], output_dim=latent_dim)

        # --- decoders ---
        self.static_decoder_num = MLP(latent_dim, out_dim=sn_dim,final_activation=nn.Sigmoid())
        self.static_decoder_cat = MLP(latent_dim,out_dim=sce_latent_dim,final_activation=nn.Tanh())
        self.static_decoder_mask = MLP(latent_dim, out_dim=sn_dim + sc_dim)
        self.temporal_decoder_num = TemporalDecoder(latent_dim, num_features=tn_dim, is_categorical=False)
        self.temporal_decoder_cat = TemporalDecoder(latent_dim, num_features=tc_dim, embed_dim=tce_latent_dim,is_categorical=True)

    def encode(self,sn,sc,tn,tc,un,uc,sn_mask,sc_mask,tn_mask,tc_mask):
        """Encode one batch into the fused latent representation.

        ``sn``/``sc`` are concatenated on the last dim for the static branch.
        ``un``/``uc`` get a trailing dim and are appended to ``tn``/``tc`` as
        one extra per-step column (presumably a time stamp - confirm). The
        four masks are encoded by their own branches, and all six embeddings
        are concatenated and passed through the fusion network.
        """
        static_e = self.static_encoder(torch.cat((sn, sc), dim=-1))
        temporal_num_e = self.temporal_num_encoder(torch.cat((tn, un.unsqueeze(-1)), dim=-1))
        temporal_cat_e = self.temporal_cat_encoder(torch.cat((tc, uc.unsqueeze(-1)), dim=-1))
        static_mask_e = self.static_mask(torch.cat((sn_mask, sc_mask), dim=-1))
        temporal_num_mask_e = self.temporal_num_mask_encoder(tn_mask)
        temporal_cat_mask_e = self.temporal_cat_mask_encoder(tc_mask)

        branches = (
            static_e,
            temporal_num_e,
            temporal_cat_e,
            static_mask_e,
            temporal_num_mask_e,
            temporal_cat_mask_e,
        )
        return self.fusion(torch.cat(branches, dim=-1))


    def decode(self, e, max_seq_len_num=100,max_seq_len_cat=300):
        """Decode latent vectors into every output of the model.

        The static mask prediction is split on its last dim: the first
        ``sn_hat.shape[-1]`` columns belong to the numeric features, the rest
        to the categorical ones.

        Returns:
            A 12-tuple ``(sn_hat, sc_hat, tn_hat, tc_hat, un_hat, uc_hat,
            sn_mask_hat, sc_mask_hat, tn_mask_hat, tc_mask_hat, seq_len_num,
            seq_len_cat)``.
        """
        # Static part
        sn_hat = self.static_decoder_num(e)
        sc_hat = self.static_decoder_cat(e)
        static_mask_hat = self.static_decoder_mask(e)
        n_static_num = sn_hat.shape[-1]
        sn_mask_hat = static_mask_hat[..., :n_static_num]
        sc_mask_hat = static_mask_hat[..., n_static_num:]

        # Temporal part: each decoder returns (values, times, mask, lengths)
        tn_hat, un_hat, tn_mask_hat, seq_len_num = self.temporal_decoder_num(e, max_seq_len_num)
        tc_hat, uc_hat, tc_mask_hat, seq_len_cat = self.temporal_decoder_cat(e, max_seq_len_cat)

        return (
            sn_hat, sc_hat, tn_hat, tc_hat, un_hat, uc_hat,
            sn_mask_hat, sc_mask_hat, tn_mask_hat, tc_mask_hat,
            seq_len_num, seq_len_cat,
        )

    def forward(self,sn,sc,tn,tc,un,uc,sn_mask,sc_mask,tn_mask,tc_mask,max_seq_len_num=None,max_seq_len_cat=None):
        """Encode the batch and decode it again.

        When a maximum sequence length is not given, the length of the
        corresponding input (dim 1 of ``tn`` / ``tc``) is used. Returns the
        tuple produced by ``decode``.
        """
        e = self.encode(sn,sc,tn,tc,un,uc,sn_mask,sc_mask,tn_mask,tc_mask)
        num_steps = tn.size(1) if max_seq_len_num is None else max_seq_len_num
        cat_steps = tc.size(1) if max_seq_len_cat is None else max_seq_len_cat
        return self.decode(e, num_steps, cat_steps)

    def get_encoding(self, sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask, as_numpy: bool = False):
        """Return the latent encoding of a batch without tracking gradients.

        Switches the model to eval mode, moves every tensor argument to the
        device of the model's first parameter and calls ``encode``. With
        ``as_numpy`` the result is returned as a NumPy array on the CPU,
        otherwise as a tensor.
        """
        self.eval()
        device = next(self.parameters()).device
        raw_inputs = (sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask)
        with torch.no_grad():
            moved = tuple(t.to(device) if torch.is_tensor(t) else t for t in raw_inputs)
            e = self.encode(*moved)
        return e.detach().cpu().numpy() if as_numpy else e
    
    def generate_decoding(self, e=None, batch_size=1, max_seq_len_num=100, max_seq_len_cat=300, device="cpu", mask_threshold=0.5):
        """Decode latent vectors into thresholded masks and variable-length sequences.

        Args:
            e: latent vectors to decode. Required: this method does not sample
                latents itself.
            batch_size: unused; kept for backward compatibility.
            max_seq_len_num: step limit for the temporal numeric decoder.
            max_seq_len_cat: step limit for the temporal categorical decoder.
            device: unused; kept for backward compatibility. ``e`` is moved to
                the device of the model's parameters instead.
            mask_threshold: probability above which a mask entry is True.

        Returns:
            ``(sn_hat, sc_hat, sn_mask_hat, sc_mask_hat, tn_hat, tc_hat,
            un_hat, uc_hat, tn_mask_hat, tc_mask_hat)``. The masks are boolean;
            the six temporal outputs are lists with one tensor per sample, cut
            to that sample's predicted sequence length.

        Raises:
            ValueError: if ``e`` is None.

        NOTE(review): ``mask_threshold`` is only applied to the returned
        masks; ``decode`` is called without it, so the sequence lengths come
        from the temporal decoders' own default threshold.
        """
        if e is None:
            raise ValueError("generate_decoding requires latent vectors `e`; got None.")

        self.eval()
        if torch.is_tensor(e):
            # Same convention as get_encoding: follow the model's device.
            e = e.to(next(self.parameters()).device)

        with torch.no_grad():
            (sn_hat, sc_hat, tn_hat, tc_hat, un_hat, uc_hat, sn_mask_hat, sc_mask_hat, tn_mask_hat, tc_mask_hat, seq_len_num, seq_len_cat) = self.decode(e, max_seq_len_num, max_seq_len_cat)

            # Apply sigmoid + threshold for all masks
            sn_mask_hat = torch.sigmoid(sn_mask_hat) > mask_threshold
            sc_mask_hat = torch.sigmoid(sc_mask_hat) > mask_threshold
            tn_mask_hat = torch.sigmoid(tn_mask_hat) > mask_threshold
            tc_mask_hat = torch.sigmoid(tc_mask_hat) > mask_threshold

            # Slice sequences according to predicted lengths
            tn_hat = [tn_hat[i, :seq_len_num[i]] for i in range(tn_hat.size(0))]
            tc_hat = [tc_hat[i, :seq_len_cat[i]] for i in range(tc_hat.size(0))]
            un_hat = [un_hat[i, :seq_len_num[i]] for i in range(un_hat.size(0))]
            uc_hat = [uc_hat[i, :seq_len_cat[i]] for i in range(uc_hat.size(0))]
            tn_mask_hat = [tn_mask_hat[i, :seq_len_num[i]] for i in range(tn_mask_hat.size(0))]
            tc_mask_hat = [tc_mask_hat[i, :seq_len_cat[i]] for i in range(tc_mask_hat.size(0))]

        return (sn_hat, sc_hat, sn_mask_hat, sc_mask_hat, tn_hat, tc_hat, un_hat, uc_hat, tn_mask_hat, tc_mask_hat)

    def compute_loss(
        self, sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask,
        sn_hat, sc_hat, sn_mask_hat, sc_mask_hat, tn_hat, tc_hat, un_hat, uc_hat,
        tn_mask_hat, tc_mask_hat, pred_seq_len_num, pred_seq_len_cat, true_seq_len_num, true_seq_len_cat,
        lambda_mse=1.0, lambda_len=0.1
    ):
        """Compute the total reconstruction loss and its components.

        Targets of the temporal branches are cut to the number of steps that
        were actually generated (dim 1 of ``tn_hat`` / ``tc_hat``), and the
        temporal terms only count steps below each sample's *predicted*
        sequence length.

        Components (keys of the returned dict, in this order):
            ``sn_mask``, ``sc_mask``, ``tn_mask``, ``tc_mask``: BCE-with-logits
            on the masks; ``sn``, ``sc``, ``tn``, ``tc``, ``un``, ``uc``:
            squared errors (``sn``/``tn`` additionally weighted by their
            masks); ``len_num``, ``len_cat``: MSE between predicted and true
            sequence lengths.

        Args:
            lambda_mse: weight of the six squared-error terms.
            lambda_len: weight of the two length terms.
            (sequence lengths may be tensors, Python numbers or lists.)

        Returns:
            ``(total_loss, losses)``: the weighted sum and the component dict.

        NOTE(review): if the predicted lengths are integer step counts (as
        produced by the temporal decoders in this file) the length terms
        cannot propagate gradients - confirm that this is intended.
        """
        eps = 1e-8
        device = sn.device

        def _as_lengths(value):
            # Accept tensors, numbers and lists; always return at least 1-D.
            if not torch.is_tensor(value):
                value = torch.tensor(value, device=device, dtype=torch.float32)
            return value.unsqueeze(0) if value.dim() == 0 else value

        def _step_mask(n_steps, lengths):
            # (batch, n_steps, 1) float mask: 1 where step index < length.
            steps = torch.arange(n_steps, device=device).unsqueeze(0)
            return (steps < lengths.unsqueeze(1)).unsqueeze(-1).float()

        def _weighted_mean(per_element, weights):
            return (per_element * weights).sum() / (weights.sum() + eps)

        def _bce(logits, target):
            return F.binary_cross_entropy_with_logits(logits, target, reduction='none')

        def _sq_err(pred, target):
            return F.mse_loss(pred, target, reduction='none')

        pred_seq_len_num = _as_lengths(pred_seq_len_num)
        pred_seq_len_cat = _as_lengths(pred_seq_len_cat)
        true_seq_len_num = _as_lengths(true_seq_len_num)
        true_seq_len_cat = _as_lengths(true_seq_len_cat)

        # Align targets with the number of steps that were generated.
        n_num, n_cat = tn_hat.size(1), tc_hat.size(1)
        tn_target = tn[:, :n_num, :]
        un_target = un[:, :n_num].unsqueeze(-1)
        tn_mask_target = tn_mask[:, :n_num, :]
        tc_target = tc[:, :n_cat, :]
        uc_target = uc[:, :n_cat].unsqueeze(-1)
        tc_mask_target = tc_mask[:, :n_cat, :]

        valid_num = _step_mask(n_num, pred_seq_len_num)
        valid_cat = _step_mask(n_cat, pred_seq_len_cat)

        losses = {}

        # --- mask terms ---
        losses["sn_mask"] = _bce(sn_mask_hat, sn_mask).mean()
        losses["sc_mask"] = _bce(sc_mask_hat, sc_mask).mean()
        losses["tn_mask"] = _weighted_mean(_bce(tn_mask_hat, tn_mask_target).mean(dim=-1), valid_num[..., 0])
        losses["tc_mask"] = _weighted_mean(_bce(tc_mask_hat, tc_mask_target).mean(dim=-1), valid_cat[..., 0])

        # --- value terms ---
        losses["sn"] = _weighted_mean(_sq_err(sn_hat, sn), sn_mask)
        losses["sc"] = _sq_err(sc_hat, sc).mean()
        losses["tn"] = _weighted_mean(_sq_err(tn_hat, tn_target), tn_mask_target * valid_num)
        # The weights are per step (last dim of size 1), so the feature dim is
        # summed, not averaged, for the categorical embedding.
        losses["tc"] = _weighted_mean(_sq_err(tc_hat, tc_target), valid_cat)
        losses["un"] = _weighted_mean(_sq_err(un_hat, un_target), valid_num)
        losses["uc"] = _weighted_mean(_sq_err(uc_hat, uc_target), valid_cat)

        # --- length terms ---
        losses["len_num"] = F.mse_loss(pred_seq_len_num.float(), true_seq_len_num.float())
        losses["len_cat"] = F.mse_loss(pred_seq_len_cat.float(), true_seq_len_cat.float())

        mask_terms = losses["sn_mask"] + losses["sc_mask"] + losses["tn_mask"] + losses["tc_mask"]
        value_terms = losses["sn"] + losses["sc"] + losses["tn"] + losses["tc"] + losses["un"] + losses["uc"]
        length_terms = losses["len_num"] + losses["len_cat"]
        total_loss = mask_terms + lambda_mse * value_terms + lambda_len * length_terms
        return total_loss, losses

    def fit(self, train_dataloader, val_dataloader=None, epochs=20, lr=1e-3, optimizer="adam",
            lambda_mse=1.0, lambda_len=0.1, device="cpu",
            scheduler_patience=5, scheduler_factor=0.1,resume_from=None):
        """Train the model, optionally validating and checkpointing the best epoch.

        Args:
            train_dataloader: Iterable of 12-element batches
                ``(sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask,
                true_seq_len_num, true_seq_len_cat)``; must support ``len()``.
            val_dataloader: Optional iterable with the same batch layout. When
                given, the validation loss drives the LR scheduler and the
                best epoch is written to ``best_encoder_decoder_ckpt.pt``.
            epochs: Number of the last epoch to run (epochs are 1-based).
            lr: Initial learning rate.
            optimizer: ``"adam"`` or ``"rmsprop"`` (case-insensitive).
            lambda_mse: Weight forwarded to ``compute_loss``.
            lambda_len: Weight forwarded to ``compute_loss``.
            device: Currently ignored; see the note in the body.
            scheduler_patience: ``patience`` of ``ReduceLROnPlateau``.
            scheduler_factor: ``factor`` of ``ReduceLROnPlateau``.
            resume_from: Optional checkpoint path to resume from.

        Raises:
            ValueError: If ``optimizer`` is not one of the supported names.
        """
        # NOTE(review): the `device` argument has never been honoured - CUDA is
        # auto-selected whenever it is available. Kept as-is so existing callers
        # that rely on the auto-detection do not silently move to the CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)
        print(f"Training on: {device}")
        if optimizer.lower() == "adam":
            opt = optim.Adam(self.parameters(), lr=lr)
        elif optimizer.lower() == "rmsprop":
            opt = optim.RMSprop(self.parameters(), lr=lr)
        else:
            raise ValueError("Optimizer must be 'adam' or 'rmsprop'")
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', patience=scheduler_patience, factor=scheduler_factor)
        start_epoch = 1
        # NOTE(review): the best validation loss is not stored in checkpoints, so
        # after a resume the first validated epoch always overwrites the "best" file.
        best_val_loss = float('inf')

        if resume_from is not None:
            # Map onto the active device so a checkpoint written on a GPU host
            # can be resumed on a CPU-only host (and vice versa).
            start_epoch = self.load_checkpoint(resume_from, optimizer=opt, scheduler=scheduler,
                                               map_location=device)

        def _forward_and_loss(batch):
            # Shared by the train and validation loops: move tensors to the
            # device, run the forward pass and return (total_loss, components).
            (sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask,
             true_seq_len_num, true_seq_len_cat) = [
                x.to(device) if torch.is_tensor(x) else x for x in batch
            ]
            max_seq_len_num = tn.size(1)
            max_seq_len_cat = tc.size(1)
            (sn_hat, sc_hat, tn_hat, tc_hat, un_hat, uc_hat,
             sn_mask_hat, sc_mask_hat, tn_mask_hat, tc_mask_hat,
             pred_seq_len_num, pred_seq_len_cat) = self(
                sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask,
                max_seq_len_num, max_seq_len_cat
            )
            return self.compute_loss(
                sn, sc, tn, tc, un, uc, sn_mask, sc_mask, tn_mask, tc_mask,
                sn_hat, sc_hat, sn_mask_hat, sc_mask_hat, tn_hat, tc_hat, un_hat, uc_hat,
                tn_mask_hat, tc_mask_hat, pred_seq_len_num, pred_seq_len_cat,
                true_seq_len_num, true_seq_len_cat, lambda_mse=lambda_mse, lambda_len=lambda_len
            )

        # Stays None when the loop body never executes (e.g. resuming a run
        # that already reached `epochs`); used to guard the final save below.
        last_epoch = None

        for epoch in range(start_epoch, epochs + 1):
            self.train()
            total_train_loss = 0.0
            train_loss_components = {}
            for batch in tqdm(train_dataloader, desc=f"Epoch {epoch}/{epochs} [Train]", leave=False):
                loss, losses_dict = _forward_and_loss(batch)
                opt.zero_grad()
                loss.backward()
                opt.step()
                total_train_loss += loss.item()
                for k, v in losses_dict.items():
                    train_loss_components[k] = train_loss_components.get(k, 0.0) + v.item()
            avg_train_loss = total_train_loss / len(train_dataloader)
            avg_train_components = {k: v / len(train_dataloader) for k, v in train_loss_components.items()}
            avg_val_loss = None
            val_loss_components = {}
            if val_dataloader is not None:
                self.eval()
                total_val_loss = 0.0
                with torch.no_grad():
                    for batch in tqdm(val_dataloader, desc=f"Epoch {epoch}/{epochs} [Val]", leave=False):
                        loss, losses_dict = _forward_and_loss(batch)
                        total_val_loss += loss.item()
                        for k, v in losses_dict.items():
                            val_loss_components[k] = val_loss_components.get(k, 0.0) + v.item()
                avg_val_loss = total_val_loss / len(val_dataloader)
                avg_val_components = {k: v / len(val_dataloader) for k, v in val_loss_components.items()}
                scheduler.step(avg_val_loss)
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    self.save_checkpoint(opt, scheduler, epoch, filename='best_encoder_decoder_ckpt.pt')

            if avg_val_loss is not None:
                print(f"Epoch {epoch}/{epochs} - Train Loss: {avg_train_loss:.7f} | Val Loss: {avg_val_loss:.7f} | Train LenNum: {avg_train_components['len_num']:.6f} | Train LenCat: {avg_train_components['len_cat']:.6f} | Val LenNum: {avg_val_components['len_num']:.6f} | Val LenCat: {avg_val_components['len_cat']:.6f}")
            else:
                # Without validation data the scheduler follows the training loss.
                scheduler.step(avg_train_loss)
                print(f"Epoch {epoch}/{epochs} - Train Loss: {avg_train_loss:.7f} | Train LenNum: {avg_train_components['len_num']:.6f} | Train LenCat: {avg_train_components['len_cat']:.6f}")

            last_epoch = epoch

        if last_epoch is None:
            # Nothing was trained, so there is no new state worth saving; the
            # original code crashed here because `epoch` was never bound.
            print(f"No epochs to run (start epoch {start_epoch} > epochs {epochs}); nothing saved.")
            return

        self.save_checkpoint(opt, scheduler, last_epoch)

    def save_checkpoint(self, optimizer, scheduler, epoch, filename="encoder_decoder_ckpt.pt"):
        """Write model, optimizer and (optional) scheduler state to ``filename``.

        Args:
            optimizer: Optimizer whose ``state_dict()`` is stored.
            scheduler: LR scheduler whose ``state_dict()`` is stored, or ``None``
                to store ``None`` in its place.
            epoch: Epoch number recorded in the checkpoint.
            filename: Destination path passed to ``torch.save``.
        """
        model_state = self.state_dict()
        optimizer_state = optimizer.state_dict()
        scheduler_state = None
        if scheduler is not None:
            scheduler_state = scheduler.state_dict()

        torch.save(
            {
                "model_state": model_state,
                "optimizer_state": optimizer_state,
                "scheduler_state": scheduler_state,
                "epoch": epoch,
            },
            filename,
        )
        print(f"Checkpoint saved at epoch {epoch} -> {filename}")

    def load_checkpoint(self, filename="encoder_decoder_ckpt.pt", optimizer=None, scheduler=None, map_location=None):
        """Restore state written by ``save_checkpoint`` and return the next epoch.

        Args:
            filename: Checkpoint path passed to ``torch.load``.
            optimizer: If given, its state is restored from the checkpoint.
            scheduler: If given and the checkpoint holds a scheduler state,
                that state is restored.
            map_location: Forwarded to ``torch.load``.

        Returns:
            The stored epoch plus one, i.e. the epoch at which to continue.
        """
        state = torch.load(filename, map_location=map_location)
        self.load_state_dict(state["model_state"])

        if optimizer is not None:
            optimizer.load_state_dict(state["optimizer_state"])

        # The scheduler entry is only looked up when a scheduler was supplied.
        if scheduler is not None:
            saved_scheduler = state["scheduler_state"]
            if saved_scheduler is not None:
                scheduler.load_state_dict(saved_scheduler)

        start_epoch = state["epoch"] + 1
        print(f"Resumed from checkpoint {filename}, starting at epoch {start_epoch}")
        return start_epoch


In [None]:
class Generator(nn.Module):
    """MLP mapping a latent vector to a vector of size ``encoder_state_dim``.

    Each hidden width contributes ``Linear -> BatchNorm1d -> LeakyReLU(0.2) ->
    Dropout``; dropout is 0.2 for the first half of the hidden layers and 0.1
    afterwards. A head of width ``2 * encoder_state_dim`` then projects down
    to ``encoder_state_dim``.
    """

    def __init__(self, encoder_state_dim, latent_dim=256, hidden_dims=None):
        """Build the network.

        Args:
            encoder_state_dim: Size of the produced vectors.
            latent_dim: Size of the input noise vectors.
            hidden_dims: Hidden layer widths; defaults to
                ``[512, 1024, 2048, 1024, 512]``.
        """
        super().__init__()

        widths = [512, 1024, 2048, 1024, 512] if hidden_dims is None else hidden_dims
        halfway = len(widths) // 2

        # The module order (and therefore the Sequential indices) must not
        # change, otherwise saved state_dict keys stop matching.
        stack = []
        in_features = latent_dim
        for idx, out_features in enumerate(widths):
            drop_p = 0.2 if idx < halfway else 0.1
            stack += [
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features),
                nn.LeakyReLU(0.2),
                nn.Dropout(drop_p),
            ]
            in_features = out_features

        head_width = encoder_state_dim * 2
        stack += [
            nn.Linear(in_features, head_width),
            nn.BatchNorm1d(head_width),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
            nn.Linear(head_width, encoder_state_dim),
        ]

        self.model = nn.Sequential(*stack)

    def forward(self, z):
        """Run the latent batch ``z`` through the sequential stack."""
        return self.model(z)

In [None]:
class Discriminator(nn.Module):
    """MLP critic producing one unbounded score per input vector.

    Each hidden width contributes ``Linear -> LayerNorm -> LeakyReLU(0.2) ->
    Dropout``; dropout is 0.3 for the first half of the hidden layers and 0.2
    afterwards. The head is ``Linear(., 64) -> LeakyReLU -> Dropout(0.1) ->
    Linear(64, 1)`` with no final activation.
    """

    def __init__(self, encoder_state_dim, hidden_dims=None):
        """Build the network.

        Args:
            encoder_state_dim: Size of the input vectors.
            hidden_dims: Hidden layer widths; defaults to
                ``[256, 512, 1024, 2048, 1024, 512, 256, 128]``.
        """
        super().__init__()

        widths = [256, 512, 1024, 2048, 1024, 512, 256, 128] if hidden_dims is None else hidden_dims
        halfway = len(widths) // 2

        # Keep the module order stable so saved state_dict keys keep matching.
        stack = []
        in_features = encoder_state_dim
        for idx, out_features in enumerate(widths):
            drop_p = 0.3 if idx < halfway else 0.2
            stack += [
                nn.Linear(in_features, out_features),
                nn.LayerNorm(out_features),
                nn.LeakyReLU(0.2),
                nn.Dropout(drop_p),
            ]
            in_features = out_features

        stack += [
            nn.Linear(in_features, 64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
            nn.Linear(64, 1),
        ]

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the critic score for each row of ``x``."""
        return self.model(x)

In [None]:
def compute_mmd(x, y, sigma=None):
    """Gaussian-kernel MMD estimate between two sample sets.

    Both sets are standardised with the mean/std of their concatenation, then
    the estimate ``mean(k(x,x)) + mean(k(y,y)) - 2*mean(k(x,y))`` is returned
    (means run over all pairs, diagonal included).

    Args:
        x: Tensor of samples (rows), or a list of tensors that is stacked.
        y: Same as ``x`` for the second sample set.
        sigma: Kernel bandwidth. When ``None`` the median pairwise distance of
            the standardised, pooled samples is used (1.0 if that median is 0).

    Returns:
        The estimate as a Python float.
    """
    x = torch.stack(x) if isinstance(x, list) else x
    y = torch.stack(y) if isinstance(y, list) else y
    x = x.to(torch.float32)
    y = y.to(torch.float32)

    # Standardise with statistics of the pooled samples so no single feature
    # dominates the distances; the epsilon keeps constant features finite.
    pooled = torch.cat([x, y], dim=0)
    centre = pooled.mean(dim=0, keepdim=True)
    scale = pooled.std(dim=0, keepdim=True) + 1e-6
    x_std = (x - centre) / scale
    y_std = (y - centre) / scale

    if sigma is None:
        # Median heuristic over all pairwise distances (self-pairs included).
        everything = torch.cat([x_std, y_std], dim=0)
        sigma = torch.median(torch.cdist(everything, everything, p=2)).item()
        if sigma == 0:
            sigma = 1.0

    def mean_kernel(a, b):
        sq_dist = torch.cdist(a, b, p=2) ** 2
        return torch.exp(-sq_dist / (2 * sigma ** 2)).mean()

    within_x = mean_kernel(x_std, x_std)
    within_y = mean_kernel(y_std, y_std)
    across = mean_kernel(x_std, y_std)

    return (within_x + within_y - 2 * across).item()


In [None]:
class WGANGP:
    """Wasserstein GAN with gradient penalty over fixed-size vectors.

    Bundles a ``Generator``/``Discriminator`` pair with their Adam optimizers
    and ``ReduceLROnPlateau`` schedulers, and provides training, sampling and
    checkpointing.
    """

    def __init__(self, encoder_state_dim, latent_dim=128,
                 generator_hidden_dims=None, discriminator_hidden_dims=None,
                 lr_generator=1e-4, lr_discriminator=1e-4,
                 lambda_gp=10.0, n_critic=5, device=None,
                 plateau_factor=0.5, plateau_patience=10):
        """Create the networks, optimizers and schedulers.

        Args:
            encoder_state_dim: Size of the real/generated vectors.
            latent_dim: Size of the generator's noise input.
            generator_hidden_dims: Hidden widths forwarded to ``Generator``.
            discriminator_hidden_dims: Hidden widths forwarded to ``Discriminator``.
            lr_generator: Adam learning rate of the generator.
            lr_discriminator: Adam learning rate of the discriminator.
            lambda_gp: Weight of the gradient penalty.
            n_critic: The generator is updated on every ``n_critic``-th batch.
            device: Torch device; CUDA is chosen when available if omitted.
            plateau_factor: ``factor`` of both ``ReduceLROnPlateau`` schedulers.
            plateau_patience: ``patience`` of both schedulers.
        """
        self.encoder_state_dim = encoder_state_dim
        self.latent_dim = latent_dim
        self.lambda_gp = lambda_gp
        self.n_critic = n_critic
        self.device = device or (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))

        self.generator = Generator(encoder_state_dim, latent_dim, generator_hidden_dims).to(self.device)
        self.discriminator = Discriminator(encoder_state_dim, discriminator_hidden_dims).to(self.device)

        self.optimizer_G = optim.Adam(self.generator.parameters(), lr=lr_generator, betas=(0.5, 0.9))
        self.optimizer_D = optim.Adam(self.discriminator.parameters(), lr=lr_discriminator, betas=(0.5, 0.9))

        # ReduceLROnPlateau scheduler
        self.scheduler_G = ReduceLROnPlateau(self.optimizer_G, mode='min', factor=plateau_factor,
                                             patience=plateau_patience)
        self.scheduler_D = ReduceLROnPlateau(self.optimizer_D, mode='min', factor=plateau_factor,
                                             patience=plateau_patience)

        self.start_epoch = 1
        self.best_mmd = float("inf")

    def gradient_penalty(self, real_samples, fake_samples):
        """Return ``lambda_gp * mean((||grad D(x_hat)||_2 - 1)^2)``.

        ``x_hat`` is a per-sample random interpolation between the real and
        the (detached) fake samples.
        """
        real_samples = real_samples.to(self.device).float()
        fake_samples = fake_samples.to(self.device).float()
        batch_size = real_samples.size(0)
        # One mixing coefficient per sample, broadcast across its features.
        epsilon = torch.rand(batch_size, 1, device=self.device).expand_as(real_samples)
        interpolated = epsilon * real_samples + (1 - epsilon) * fake_samples.detach()
        interpolated.requires_grad_(True)

        d_interpolated = self.discriminator(interpolated)
        gradients = torch.autograd.grad(
            outputs=d_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(d_interpolated, device=self.device),
            create_graph=True,  # the penalty itself must be differentiable
            retain_graph=True,
            only_inputs=True
        )[0]
        gradient_penalty = self.lambda_gp * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    def discriminator_loss(self, real_samples, fake_samples):
        """Return ``(d_loss, wasserstein_distance, gradient_penalty)``.

        ``wasserstein_distance`` is ``mean(D(real)) - mean(D(fake))`` and
        ``d_loss`` is its negation plus the gradient penalty.
        """
        d_real = self.discriminator(real_samples)
        d_fake = self.discriminator(fake_samples.detach())
        wasserstein_distance = d_real.mean() - d_fake.mean()
        gp = self.gradient_penalty(real_samples, fake_samples)
        d_loss = -wasserstein_distance + gp
        return d_loss, wasserstein_distance, gp

    def generator_loss(self, fake_samples):
        """Return ``-mean(D(fake_samples))``."""
        d_fake = self.discriminator(fake_samples)
        g_loss = -d_fake.mean()
        return g_loss

    def generate_samples(self, batch_size):
        """Draw ``batch_size`` standard-normal latents and run the generator.

        The generator is used in whatever train/eval mode it is currently in.
        """
        z = torch.randn(batch_size, self.latent_dim, device=self.device)
        fake_samples = self.generator(z).float()
        return fake_samples

    def fit(self, train_dataloader, epochs=100, resume_from=None,
            val_dataloader=None, verbose=True):
        """Train the GAN and return the per-epoch history.

        Args:
            train_dataloader: Iterable yielding 1-tuples ``(real_samples,)``.
            epochs: Number of the last epoch to run (epochs are 1-based and
                start at ``self.start_epoch``).
            resume_from: Optional checkpoint path loaded before training.
            val_dataloader: Optional iterable with the same batch layout. When
                given, validation MMD selects the best checkpoint
                (``best_gan.pt``) and drives both LR schedulers; otherwise the
                epoch with the largest training Wasserstein estimate is kept.
            verbose: Show progress bars and per-epoch validation summaries.

        Returns:
            Dict of lists with keys ``train_d_loss``, ``train_g_loss``,
            ``train_wd``, ``train_gp``, ``val_g_loss``, ``val_wd``, ``val_mmd``.
        """
        if resume_from:
            self.load_checkpoint(resume_from)
            print(f"Resumed training from checkpoint: {resume_from}")

        self.generator.train()
        self.discriminator.train()

        history = {
            "train_d_loss": [], "train_g_loss": [], "train_wd": [], "train_gp": [],
            "val_g_loss": [], "val_wd": [], "val_mmd": []
        }

        # Stays None when the loop never runs (e.g. resuming a finished run).
        last_epoch = None

        for epoch in range(self.start_epoch, epochs + 1):
            epoch_d_loss = epoch_g_loss = epoch_wd = epoch_gp = 0.0
            batches = 0
            g_steps = 0
            progress_bar = tqdm(train_dataloader, desc=f'Epoch {epoch}/{epochs}') if verbose else train_dataloader

            for i, (real_samples,) in enumerate(progress_bar):
                batches += 1
                real_samples = real_samples.to(self.device).float()
                bsz = real_samples.size(0)

                # ---- Train Discriminator ----
                self.optimizer_D.zero_grad()
                fake_samples = self.generate_samples(bsz)
                d_loss, wd, gp = self.discriminator_loss(real_samples, fake_samples)
                d_loss.backward()
                self.optimizer_D.step()

                epoch_d_loss += d_loss.item()
                epoch_wd += wd.item()
                epoch_gp += gp.item()

                # ---- Train Generator every n_critic ----
                if i % self.n_critic == 0:
                    self.optimizer_G.zero_grad()
                    fake_samples = self.generate_samples(bsz)
                    g_loss = self.generator_loss(fake_samples)
                    g_loss.backward()
                    self.optimizer_G.step()
                    epoch_g_loss += g_loss.item()
                    g_steps += 1

            avg_d_loss = epoch_d_loss / batches
            # Divide by the number of generator updates actually performed.
            # `batches // n_critic` under-counts them (e.g. 7 batches with
            # n_critic=5 performs 2 updates, at i=0 and i=5).
            avg_g_loss = epoch_g_loss / max(1, g_steps)
            avg_wd = epoch_wd / batches
            avg_gp = epoch_gp / batches

            history["train_d_loss"].append(avg_d_loss)
            history["train_g_loss"].append(avg_g_loss)
            history["train_wd"].append(avg_wd)
            history["train_gp"].append(avg_gp)

            # ---- Validation ----
            if val_dataloader is not None:
                self.generator.eval()
                self.discriminator.eval()
                real_embeddings, fake_embeddings = [], []

                with torch.no_grad():
                    for real_samples, in val_dataloader:
                        real_samples = real_samples.to(self.device).float()
                        fake_samples = self.generate_samples(real_samples.size(0))
                        real_embeddings.append(real_samples)
                        fake_embeddings.append(fake_samples)

                real_embeddings = torch.cat(real_embeddings, dim=0)
                fake_embeddings = torch.cat(fake_embeddings, dim=0)
                current_mmd = compute_mmd(real_embeddings, fake_embeddings)
                history["val_mmd"].append(current_mmd)

                val_g_loss = val_wd = 0.0
                val_batches = 0
                with torch.no_grad():
                    for real_samples, in val_dataloader:
                        val_batches += 1
                        real_samples = real_samples.to(self.device).float()
                        bsz = real_samples.size(0)
                        fake_samples = self.generate_samples(bsz)

                        d_real = self.discriminator(real_samples)
                        d_fake = self.discriminator(fake_samples)

                        wd = d_real.mean() - d_fake.mean()
                        g_loss = self.generator_loss(fake_samples)
                        val_g_loss += g_loss.item()
                        val_wd += wd.item()

                avg_val_g_loss = val_g_loss / val_batches
                avg_val_wd = val_wd / val_batches

                history["val_g_loss"].append(avg_val_g_loss)
                history["val_wd"].append(avg_val_wd)

                if verbose:
                    print(f"[Epoch {epoch}] Train D: {avg_d_loss:.7f}, G: {avg_g_loss:.7f}, WD: {avg_wd:.7f}, GP: {avg_gp:.7f} | "
                          f"Val WD: {avg_val_wd:.7f}, MMD: {current_mmd:.7f}")

                # Save best model
                if current_mmd < self.best_mmd:
                    self.best_mmd = current_mmd
                    self.save_checkpoint("best_gan.pt", epoch, history, is_best=True)

                self.generator.train()
                self.discriminator.train()

                # ---- Step scheduler using validation MMD ----
                self.scheduler_G.step(current_mmd)
                self.scheduler_D.step(current_mmd)

            else:
                # No validation data: keep the epoch with the largest training
                # Wasserstein estimate. NOTE(review): the "best" message printed
                # by save_checkpoint still reports best_mmd (inf) in this mode.
                if avg_wd > getattr(self, "best_wd", float("-inf")):
                    self.best_wd = avg_wd
                    self.save_checkpoint("best_gan.pt", epoch, history, is_best=True)

            last_epoch = epoch

        if last_epoch is None:
            # Nothing was trained; the original crashed here on the unbound
            # loop variable. There is no new state to write.
            print(f"No epochs to run (start epoch {self.start_epoch} > epochs {epochs}); nothing saved.")
            return history

        self.save_checkpoint("final_gan.pt", last_epoch, history)
        print("Training completed!")
        return history

    def save_checkpoint(self, filename, epoch, history, is_best=False):
        """Write networks, optimizers, schedulers and bookkeeping to ``filename``.

        The stored ``"epoch"`` is ``epoch + 1``, i.e. the epoch to resume at.
        """
        state = {
            "epoch": epoch + 1,
            "generator_state": self.generator.state_dict(),
            "discriminator_state": self.discriminator.state_dict(),
            "optimizer_G": self.optimizer_G.state_dict(),
            "optimizer_D": self.optimizer_D.state_dict(),
            # Scheduler counters are stored so a resumed run keeps its plateau
            # bookkeeping instead of starting the patience window afresh.
            "scheduler_G": self.scheduler_G.state_dict(),
            "scheduler_D": self.scheduler_D.state_dict(),
            "latent_dim": self.latent_dim,
            "encoder_state_dim": self.encoder_state_dim,
            "best_mmd": self.best_mmd,
            "history": history
        }
        torch.save(state, filename)
        if is_best:
            print(f"Best model saved at {filename} (MMD: {self.best_mmd:.7f})")
        else:
            print(f"Checkpoint saved at {filename}")

    def load_checkpoint(self, filename, map_location=None):
        """Restore state from ``filename`` and return the raw checkpoint dict.

        Sets ``self.start_epoch`` and ``self.best_mmd``. Scheduler states are
        optional so checkpoints written before they were stored still load.
        """
        checkpoint = torch.load(filename, map_location=map_location or self.device)
        self.generator.load_state_dict(checkpoint["generator_state"])
        self.discriminator.load_state_dict(checkpoint["discriminator_state"])
        self.optimizer_G.load_state_dict(checkpoint["optimizer_G"])
        self.optimizer_D.load_state_dict(checkpoint["optimizer_D"])
        for key, scheduler in (("scheduler_G", self.scheduler_G), ("scheduler_D", self.scheduler_D)):
            saved_scheduler = checkpoint.get(key)
            if saved_scheduler is not None:
                scheduler.load_state_dict(saved_scheduler)
        self.generator.to(self.device)
        self.discriminator.to(self.device)
        self.start_epoch = checkpoint["epoch"]
        self.best_mmd = checkpoint.get("best_mmd", float("inf"))
        print(f"Checkpoint loaded: {filename} (resuming at epoch {self.start_epoch})")
        return checkpoint


In [None]:
def generate_synthetic_dataset(generator, decoder, total_samples, batch_size=128, device=None):
    """Sample latent states and decode them into synthetic records.

    Args:
        generator: Object exposing ``generate_samples(batch_size=...)``. If it
            has a ``generator`` attribute that is an ``nn.Module`` (or is one
            itself), that module is put in eval mode while sampling.
        decoder: Module exposing ``generate_decoding(latents, batch_size=...)``
            that returns the 10-tuple ``(sn, sc, sn_mask, sc_mask, tn, tc, un,
            uc, tn_mask, tc_mask)``.
        total_samples: Number of records to generate.
        batch_size: Maximum number of records per decoding call.
        device: Torch device; CUDA is chosen when available if omitted.

    Returns:
        ``(sn, sc, sn_mask, sc_mask, tn, tc, un, uc, tn_mask, tc_mask)`` where
        the first four are tensors concatenated along dim 0 and the remaining
        six are Python lists accumulated across batches.
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    decoder.to(device).eval()

    # Sampling must run in eval mode: in train mode BatchNorm uses (and updates)
    # per-batch statistics and Dropout stays active, so outputs would depend on
    # the batch composition. The previous mode is restored afterwards.
    sampler = getattr(generator, "generator", generator)
    restore_train = isinstance(sampler, nn.Module) and sampler.training
    if restore_train:
        sampler.eval()

    def _collect(bucket, item):
        # Per-sequence outputs arrive as lists (one entry per sample) and are
        # flattened into the bucket; anything else is stored as a single entry.
        if isinstance(item, list):
            bucket.extend(item)
        else:
            bucket.append(item)

    sn_list, sc_list, sn_mask_list, sc_mask_list = [], [], [], []
    tn_list, tc_list, un_list, uc_list = [], [], [], []
    tn_mask_list, tc_mask_list = [], []

    try:
        with torch.no_grad():
            for start in tqdm(range(0, total_samples, batch_size)):
                current_batch = min(batch_size, total_samples - start)
                latent_samples = generator.generate_samples(batch_size=current_batch).to(device)
                decoded = decoder.generate_decoding(latent_samples, batch_size=current_batch)
                sn_hat, sc_hat, sn_mask_hat, sc_mask_hat, tn_hat, tc_hat, un_hat, uc_hat, tn_mask_hat, tc_mask_hat = decoded

                sn_list.append(sn_hat)
                sc_list.append(sc_hat)
                sn_mask_list.append(sn_mask_hat)
                sc_mask_list.append(sc_mask_hat)

                for bucket, item in (
                    (tn_list, tn_hat), (tc_list, tc_hat),
                    (un_list, un_hat), (uc_list, uc_hat),
                    (tn_mask_list, tn_mask_hat), (tc_mask_list, tc_mask_hat),
                ):
                    _collect(bucket, item)
    finally:
        if restore_train:
            sampler.train()

    return (
        torch.cat(sn_list, dim=0),
        torch.cat(sc_list, dim=0),
        torch.cat(sn_mask_list, dim=0),
        torch.cat(sc_mask_list, dim=0),
        tn_list,
        tc_list,
        un_list,
        uc_list,
        tn_mask_list,
        tc_mask_list
    )


In [None]:
# Notebook-wide device: CUDA when available, otherwise CPU. Later cells use it
# for `map_location` and for moving models/tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Restore the trained GAN and the encoder/decoder from ./weights, then sample
# 70,000 synthetic records in batches of 64.
# NOTE(review): the *_dim values below must match the configuration the
# checkpoints were trained with - confirm against the training notebook/run.
generator = WGANGP(encoder_state_dim=256,latent_dim=256)
generator.load_checkpoint(filename="weights/best_gan.pt",map_location=device)
sn_dim = 1           
sce_latent_dim = 32 
tn_dim = 25               
tce_latent_dim = 59         
sc_dim = 3
tc_dim = 6                     
decoder = EncoderDecoder(sn_dim=sn_dim,sce_latent_dim=sce_latent_dim,tn_dim=tn_dim,
            tce_latent_dim=tce_latent_dim,sc_dim=sc_dim,tc_dim=tc_dim,latent_dim=256)
decoder.load_checkpoint(filename="weights/best_encoder_decoder_ckpt.pt",map_location=device)
decoder.to(device)
# device=None lets generate_synthetic_dataset pick CUDA/CPU itself (same rule as `device` above).
sn_hat, sc_hat, sn_mask_hat, sc_mask_hat, tn_hat, tc_hat, un_hat, uc_hat, tn_mask_hat, tc_mask_hat = generate_synthetic_dataset(generator, decoder, total_samples=70000, batch_size=64, device=None)

## Generate Numerical Data

In [None]:
def denormalize_generated_data(sn_hat, tn_hat, un_hat, sn_mask_hat, tn_mask_hat,
                               normalizer_sn, normalizer_tn, temporal_columns, use_mask=True):
    """Map generated numerical outputs back to their original scale.

    Args:
        sn_hat: Static numerical outputs (tensor, or list of tensors that is
            concatenated along dim 0); renormalized with key ``"AGE"``.
        tn_hat: List of per-sequence temporal feature tensors (rows = steps).
        un_hat: List of per-sequence time-delta tensors.
        sn_mask_hat: Optional mask for the static value (0 means missing).
        tn_mask_hat: Optional list of per-sequence masks matching ``tn_hat``.
        normalizer_sn: Object with ``stochastic_renormalize(x, key=...)``.
        normalizer_tn: Same interface, used for temporal features and time.
        temporal_columns: ``[time_key, feature_1, ..., feature_F]``; the first
            entry is the key for ``un_hat`` and the rest name the columns of
            ``tn_hat`` in order.
        use_mask: When True, masked-out entries are replaced by NaN.

    Returns:
        ``(static_df, temporal_df)``: ``static_df`` has a single ``AGE``
        column; ``temporal_df`` has ``TIME`` (cumulative sum of the
        renormalized deltas, restarted for every sequence) followed by one
        column per temporal feature.
    """
    if isinstance(sn_hat, list):
        sn_hat = torch.cat(sn_hat, dim=0)

    sn = normalizer_sn.stochastic_renormalize(sn_hat, key="AGE").tolist()
    # A column-shaped result turns into one-element lists; unwrap them here so
    # AGE holds plain numbers whether or not masking is applied (previously
    # this only happened inside the masked branch).
    sn = [x[0] if isinstance(x, list) else x for x in sn]

    tn_flat = torch.cat(tn_hat, dim=0)
    un_flat = torch.cat(un_hat, dim=0)
    lengths = [x.shape[0] for x in tn_hat]

    tn_denorm = torch.zeros_like(tn_flat)
    for f, key in enumerate(tqdm(temporal_columns[1:], desc="Denormalizing temporal features", mininterval=0.1, leave=False)):
        tn_denorm[:, f] = normalizer_tn.stochastic_renormalize(tn_flat[:, f], key=key)

    un_denorm = normalizer_tn.stochastic_renormalize(un_flat, key=temporal_columns[0])
    un_time = []
    start = 0
    # Deltas become absolute times by cumulative sum, restarted per sequence.
    for L in tqdm(lengths, desc="Building cumulative time", mininterval=0.1, leave=False):
        end = start + L
        un_time.append(torch.cumsum(un_denorm[start:end], dim=0))
        start = end

    if use_mask:
        if sn_mask_hat is not None:
            # reshape(-1) instead of squeeze(): squeeze() collapses a single
            # sample to a 0-dim tensor whose tolist() is not iterable.
            sn_mask_values = sn_mask_hat.reshape(-1).tolist()
            sn = [float('nan') if m == 0 else x for x, m in zip(sn, sn_mask_values)]
        if tn_mask_hat is not None:
            tn_mask_flat = torch.cat(tn_mask_hat, dim=0)
            tn_denorm = tn_denorm.masked_fill(tn_mask_flat == 0, float('nan'))

    static_df = pd.DataFrame({
        "AGE": sn
    })

    temporal_df = pd.DataFrame(
        tn_denorm.cpu().numpy(),
        columns=temporal_columns[1:]
    )
    temporal_df.insert(0, "TIME", torch.cat(un_time).cpu().numpy())

    return static_df, temporal_df


In [None]:
# Load the fitted normalizer parameters for the static and temporal numerical
# streams, then convert the generated tensors back to real-valued tables.
normalizer_sn = StochasticNormalizer()
normalizer_sn.load_params("weights/static_params.pt")

normalizer_tn = StochasticNormalizer()
normalizer_tn.load_params("weights/temporal_params.pt")

# First entry is the time key; the remaining entries name the temporal feature
# columns in order (denormalize_generated_data uses temporal_columns[0] for
# time and temporal_columns[1:] for features).
# NOTE(review): 'FEV1/â€‹FVC' looks like a mis-decoded zero-width space inside
# "FEV1/FVC". It is passed to the normalizer as a lookup key, so it is left
# unchanged - confirm it matches the key stored in temporal_params.pt.
temporal_columns = ['DATE', 'Body Height', 'Body Mass Index', 'Body Weight', 'Calcium',
                   'Carbon Dioxide', 'Chloride', 'Creatinine',
                   'DXA [T-score] Bone density', 'Diastolic Blood Pressure',
                   'Egg white IgE Ab in Serum', 'Estimated Glomerular Filtration Rate',
                   'FEV1/â€‹FVC', 'Glucose', 'HIV status',
                   'Hemoglobin A1c/Hemoglobin.total in Blood',
                   'High Density Lipoprotein Cholesterol',
                   'Low Density Lipoprotein Cholesterol', 'Microalbumin Creatine Ratio',
                   'Oral temperature', 'Potassium', 'Sodium', 'Systolic Blood Pressure',
                    'Total Cholesterol', 'Triglycerides', 'Urea Nitrogen']

static_df,temporal_df = denormalize_generated_data(sn_hat, tn_hat, un_hat,sn_mask_hat,tn_mask_hat,
                                            normalizer_sn, normalizer_tn,temporal_columns)

In [None]:
# Preview the denormalized static table (single AGE column).
static_df.head()

In [None]:
# Persist the static numerical table to the working directory (no index column).
static_df.to_csv('static_numerical.csv',index=False)

In [None]:
# Preview the denormalized temporal table (TIME followed by the feature columns).
temporal_df.head()

In [None]:
# Persist the temporal numerical table to the working directory (no index column).
temporal_df.to_csv('temporal_numerical.csv',index=False)

## Generate Categorical Data

In [None]:
# One-hot width of each static categorical column, listed in the same order as
# static_categorical_cols (the slicing cell below zips the two lists together).
static_input_dims = [5, 2, 21]
static_categorical_cols = ["RACE", "GENDER", "ETHNICITY"]

In [None]:
# Build the static categorical autoencoder, move it to the notebook device and
# load its weights (presumably from the `filename` given to the constructor -
# confirm in CategoricalAutoEncoder.load_model).
static_model_instance = CategoricalAutoEncoder(static_input_dims, filename="weights/static_categorical_encoder_decoder.pt")
static_model_instance = static_model_instance.to(device)
static_model_instance.load_model(map_location=device)

In [None]:
# Restore the fitted per-column one-hot encoders used to map one-hot vectors
# back to category labels.
# NOTE(review): the .pkl extension suggests pickle - only load trusted files.
static_manager = StaticOneHotManager()
static_manager.load("weights/static_encoded.pkl")

In [None]:
# Decode the generated static categorical latents; only the second return value
# (the one-hot reconstruction requested via return_onehot=True) is kept.
_, sc_onehot = static_model_instance.decode(sc_hat, mask=sc_mask_hat, return_onehot=True)

In [None]:
# Split the concatenated one-hot matrix into one list-valued column per
# categorical feature, using the widths in static_input_dims.
start = 0
onehot_lists = {}
# Move the whole matrix to host memory once; the previous per-row `.cpu()`
# issued one device-to-host copy per sample and column.
sc_onehot_cpu = sc_onehot.detach().cpu()
for col_name, dim in zip(static_categorical_cols, static_input_dims):
    end = start + dim
    # tolist() on the 2-D slice yields one list of `dim` values per row.
    onehot_lists[col_name] = sc_onehot_cpu[:, start:end].tolist()
    start = end

onehot_df = pd.DataFrame(onehot_lists)
onehot_df.head()

In [None]:
# Map every one-hot column back to its category labels with the fitted encoders.
decoded_dict = {}
for col_name in static_categorical_cols:
    decoded_dict[col_name] = static_manager.inverse_transform(onehot_df[col_name], col_name)

decoded_df = pd.DataFrame(decoded_dict)
decoded_df.head()

In [None]:
# Persist the decoded static categorical table to the working directory (no index column).
decoded_df.to_csv('static_categorical.csv',index=False)

In [None]:
# One-hot width of each temporal categorical column, listed in the same order
# as temporal_categorical_cols (the batching cell below zips the two lists).
temporal_input_dims = [76, 68, 126, 28, 99, 80]
temporal_categorical_cols = ['CAREPLAN', 'REASON', 'CONDITIONS', 'ENCOUNTER_TYPE',
       'MEDICINE', 'PROCEDURES']

In [None]:
# Build the temporal categorical autoencoder, move it to the notebook device
# and load its weights (presumably from the `filename` given to the
# constructor - confirm in CategoricalAutoEncoder.load_model).
temporal_model_instance = CategoricalAutoEncoder(temporal_input_dims, filename="weights/temporal_categorical_encoder_decoder.pt")
temporal_model_instance = temporal_model_instance.to(device)
temporal_model_instance.load_model(map_location=device)

In [None]:
# Restore the fitted temporal encoders used to map category indices back to labels.
# NOTE(review): the .pkl extension suggests pickle - only load trusted files.
temporal_manager = TemporalOneHotManager()
temporal_manager.load("weights/temporal_encoded.pkl")

In [None]:
# Flatten the per-sequence categorical latents and masks into single tensors on
# the notebook device, then decode them. `recon` is the one-hot reconstruction
# requested via return_onehot=True; it is consumed by the batching cell below.
tc_hat_tensor = torch.cat(tc_hat, dim=0).to(device)
tc_mask_hat_tensor = torch.cat(tc_mask_hat, dim=0).to(device)
preds, recon = temporal_model_instance.decode(tc_hat_tensor, mask=tc_mask_hat_tensor, return_onehot=True)

In [None]:
import torch
import pandas as pd
from tqdm import tqdm
import os, gc
import glob

# Convert the one-hot reconstruction to per-column category indices in chunks,
# spilling each chunk to parquet so the full table is only assembled once.
BATCH_SIZE = 1024
SAVE_DIR = "onehot_parts"
os.makedirs(SAVE_DIR, exist_ok=True)

# Paths written by THIS run, in write order. Re-globbing the directory would
# also pick up stale files left by an earlier, longer run, and a lexicographic
# sort misorders batches once the index outgrows the 4-digit padding.
files = []

for batch_idx, chunk in enumerate(tqdm(torch.split(recon, BATCH_SIZE), desc="Processing batches")):
    start = 0
    batch_dict = {}
    for col_name, dim in zip(temporal_categorical_cols, temporal_input_dims):
        end = start + dim
        # `recon` may live on the GPU (its inputs were moved to `device`), and
        # Tensor.numpy() only works on CPU tensors without grad.
        col_array = chunk[:, start:end].detach().cpu().numpy()
        batch_dict[col_name] = col_array.argmax(axis=1)
        start = end
    batch_path = f"{SAVE_DIR}/onehot_batch_{batch_idx:04d}.parquet"
    pd.DataFrame(batch_dict).to_parquet(batch_path, index=False)
    files.append(batch_path)

gc.collect()

onehot_df = pd.concat((pd.read_parquet(f) for f in files), ignore_index=True)
print("Done. One-hot DataFrame shape:", onehot_df.shape)

In [None]:
# Preview the per-column category indices produced by the batching cell above.
onehot_df.head()

In [None]:
# Map each column of category indices back to labels with the temporal manager.
decoded_temporal = pd.DataFrame()
for col in temporal_categorical_cols:
    decoded_temporal[col] = temporal_manager.inverse_transform2(onehot_df[col], col)
decoded_temporal.head()

In [None]:
def denormalize_uc(uc_hat, normalizer_uc):
    """Turn generated time deltas into cumulative times, one run per sequence.

    Args:
        uc_hat: List of per-sequence delta tensors, or a single tensor that is
            treated as one sequence.
        normalizer_uc: Object with ``stochastic_renormalize(x, key=...)``; it
            is called once on the concatenated deltas with ``key='DATE'``.

    Returns:
        One tensor holding the per-sequence cumulative sums of the
        renormalized deltas, concatenated along dim 0.
    """
    if isinstance(uc_hat, list):
        seq_lengths = [seq.shape[0] for seq in uc_hat]
        flat_deltas = torch.cat(uc_hat, dim=0)
    else:
        seq_lengths = [uc_hat.shape[0]]
        flat_deltas = uc_hat

    deltas = normalizer_uc.stochastic_renormalize(flat_deltas, key='DATE')

    # Sequence boundaries inside the flat tensor: offsets[i]..offsets[i + 1].
    offsets = [0]
    for n in seq_lengths:
        offsets.append(offsets[-1] + n)

    cumulative = [
        torch.cumsum(deltas[lo:hi], dim=0)
        for lo, hi in zip(offsets[:-1], offsets[1:])
    ]
    return torch.cat(cumulative, dim=0)

In [None]:
# Load the normalizer fitted on the categorical-stream time deltas and turn the
# generated deltas into per-sequence cumulative times.
normalizer_uc = StochasticNormalizer()
normalizer_uc.load_params("weights/categorical_times_params.pt")
uc_time = denormalize_uc(uc_hat,normalizer_uc)

In [None]:
# Prepend the cumulative time as the DATE column.
# NOTE: DataFrame.insert raises if "DATE" already exists, so this cell cannot
# be re-run without first rebuilding decoded_temporal.
decoded_temporal.insert(0, "DATE", uc_time.cpu().numpy())
decoded_temporal.head()

In [None]:
# Persist the decoded temporal categorical table to the working directory (no index column).
decoded_temporal.to_csv('temporal_categorical.csv',index=False)