<a href="https://colab.research.google.com/github/dyna478/p1/blob/main/Diff.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import ks_2samp, chi2_contingency, entropy, wasserstein_distance
from scipy.spatial.distance import jensenshannon
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mutual_info_score, mean_squared_error, accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.metrics import classification_report, confusion_matrix
import datetime
import warnings
import re  # Ajoutez cette ligne

warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import ks_2samp, chi2_contingency, entropy, wasserstein_distance
from scipy.spatial.distance import jensenshannon
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mutual_info_score, mean_squared_error, accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

# Pour torch et les fonctionnalités de diffusion
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Pour le traitement des données
from sklearn.preprocessing import StandardScaler, OneHotEncoder

# Pour les métriques d'évaluation supplémentaires
from sklearn.metrics import classification_report, confusion_matrix

# Pour la visualisation des distributions
import matplotlib.pyplot as plt
import seaborn as sns

# Pour les calculs statistiques
from scipy import stats

# Pour le parallélisme (optionnel, mais peut accélérer certains calculs)
from joblib import Parallel, delayed

# Pour la manipulation de dates (si vous avez des données temporelles)
import datetime

# Pour la gestion des avertissements
import warnings
warnings.filterwarnings('ignore')

In [None]:
class EnhancedDiffusionModel(nn.Module):
    """Denoising network for sequences mixing continuous and discrete features.

    The input is projected to ``hidden_dim``, run through a 2-layer LSTM and
    ``num_layers`` time-conditioned MLP + attention stages, offset by a class
    embedding and finally projected to ``output_dim``.  The first
    ``num_numeric_features`` output columns form the continuous head; the
    remaining columns form the discrete head.

    Relies on ``SinusoidalPositionEmbeddings``, ``ResidualBlock`` and
    ``AttentionLayer``, which must be defined before the model is instantiated.
    """

    # Class-level default so the debug prints keep working even on instances
    # created without running ``__init__``.
    verbose = True

    def __init__(self, input_dim, output_dim, hidden_dim=128, time_dim=128, seq_length=10,
                 num_layers=4, num_heads=8, num_numeric_features=0,
                 alphas_cumprod=None, alphas=None, betas=None, num_classes=1000,
                 feature_types=None, feature_dims=None, feature_specific_params=None,
                 dynamic_thresholding_ratio=0.995, interval_min=-1, interval_max=1,
                 verbose=True):
        """Build the network.

        Args:
            input_dim: Width of the last axis of the input tensor.
            output_dim: Width of the last axis produced by the output layer.
            hidden_dim: Width of the hidden representation.
            time_dim: Width of the time embedding.
            seq_length: Stored on the instance; not used by the layers built here.
            num_layers: Number of MLP + attention stages.
            num_heads: Attention heads per ``AttentionLayer``.
            num_numeric_features: Number of leading output columns treated as
                continuous.
            alphas_cumprod, alphas, betas: Optional schedules, registered as
                float32 buffers when given.
            num_classes: Number of conditioning classes; index ``num_classes``
                itself is reserved for "no class".
            feature_types: One entry per feature, ``'continuous'`` or anything
                else for discrete.  Defaults to all continuous.
            feature_dims: Column width of each feature.  Defaults to all 1.
            feature_specific_params: Optional mapping ``"numeric_<i>"`` ->
                dict with ``interval_min`` / ``interval_max`` /
                ``dynamic_thresholding_ratio`` overrides.
            dynamic_thresholding_ratio: Quantile used by dynamic thresholding.
            interval_min, interval_max: Clamp interval for continuous outputs.
            verbose: When true (default, the historical behaviour) shape and
                loss diagnostics are printed.

        Raises:
            AssertionError: If ``feature_types`` and ``feature_dims`` differ in
                length or ``feature_dims`` does not sum to ``input_dim``.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.time_dim = time_dim
        self.seq_length = seq_length
        self.num_numeric_features = num_numeric_features
        self.num_classes = num_classes
        self.verbose = verbose
        self.feature_types = feature_types or ['continuous'] * input_dim
        self.feature_dims = feature_dims or [1] * input_dim
        assert len(self.feature_types) == len(self.feature_dims), "feature_types and feature_dims must have the same length"
        assert sum(self.feature_dims) == input_dim, f"Sum of feature_dims ({sum(self.feature_dims)}) must equal input_dim ({input_dim})"

        # Global post-processing settings.  The ratio used to be overwritten
        # with a hard-coded 0.995 right after assignment, which silently
        # ignored the constructor argument.
        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
        self.interval_min = interval_min
        self.interval_max = interval_max

        # Fallbacks used when a feature has no entry in feature_specific_params.
        self.default_interval_min = interval_min
        self.default_interval_max = interval_max
        self.default_dynamic_thresholding_ratio = dynamic_thresholding_ratio

        # Per-feature overrides, keyed by "numeric_<index>".
        self.feature_specific_params = feature_specific_params or {}

        # Schedules are buffers so they follow the module across devices and
        # are saved in the state dict; they are only registered when supplied.
        if alphas_cumprod is not None:
            self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
        if alphas is not None:
            self.register_buffer('alphas', torch.tensor(alphas, dtype=torch.float32))
        if betas is not None:
            self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.Mish(),
            nn.LayerNorm(time_dim),
            nn.Linear(time_dim, time_dim),
        )

        self.input_upscale = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Mish()
        )

        self.rnn = nn.LSTM(hidden_dim, hidden_dim, num_layers=2, batch_first=True)

        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim + time_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.Mish(),
                ResidualBlock(hidden_dim)
            ) for _ in range(num_layers)
        ])

        self.attention_layers = nn.ModuleList([
            AttentionLayer(hidden_dim, num_heads=num_heads)
            for _ in range(num_layers)
        ])

        self.output_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, output_dim)
        )

        # One extra row: index ``num_classes`` encodes "unconditional".
        self.class_embedding = nn.Embedding(num_classes + 1, hidden_dim)

        if self.verbose:
            print(f"Model initialized with input_dim: {input_dim}, output_dim: {output_dim}")
            # Report the resolved lists rather than the raw (possibly None) arguments.
            print(f"Feature types: {self.feature_types}")
            print(f"Feature dimensions: {self.feature_dims}")

    def forward(self, x, t, c=None):
        """Run the network and return ``(output_cont, output_disc)``.

        Args:
            x: 3-D tensor, or a ``(continuous, discrete)`` tuple of tensors
                that is concatenated along the last axis.
            t: Time input handed to ``time_mlp``; presumably one value per
                batch element -- TODO confirm against SinusoidalPositionEmbeddings.
            c: Optional class indices; ``None`` or ``-1`` entries select the
                reserved "no class" embedding.

        Returns:
            Tuple of the continuous head (clamped and dynamically thresholded)
            and the discrete head (one slice per discrete feature, concatenated).
        """
        if isinstance(x, tuple):
            x_cont, x_disc = x
            if self.verbose:
                print(f"Forward input - cont shape: {x_cont.shape}, disc shape: {x_disc.shape}")
            x = torch.cat([x_cont, x_disc], dim=-1)

        # Unpacking doubles as a check that the input is 3-D.
        _, seq_len, _ = x.shape

        # Broadcast the time embedding over the sequence axis.
        t_emb = self.time_mlp(t)
        t_emb = t_emb.unsqueeze(1).expand(-1, seq_len, -1)

        x = self.input_upscale(x)
        x, _ = self.rnn(x)

        for layer, attention in zip(self.layers, self.attention_layers):
            x = layer(torch.cat([x, t_emb], dim=-1))
            x = attention(x)

        # Class conditioning: missing / -1 labels map to the extra embedding row.
        if c is None:
            c = torch.full((x.shape[0],), self.num_classes, device=x.device)
        else:
            c = torch.where(c == -1, self.num_classes, c)

        if c.dim() == 1:
            c = c.unsqueeze(1).expand(-1, x.size(1))

        x = x + self.class_embedding(c)

        output = self.output_layer(x)
        output_cont = output[:, :, :self.num_numeric_features]
        output_disc = output[:, :, self.num_numeric_features:]

        # Per-feature clamp + dynamic thresholding.  Columns are collected and
        # concatenated once instead of rebuilding the tensor on every iteration.
        cont_columns = []
        for i in range(self.num_numeric_features):
            params = self.feature_specific_params.get(f"numeric_{i}", {})
            interval_min = params.get('interval_min', self.default_interval_min)
            interval_max = params.get('interval_max', self.default_interval_max)
            dt_ratio = params.get('dynamic_thresholding_ratio', self.default_dynamic_thresholding_ratio)

            column = torch.clamp(output_cont[:, :, i:i + 1], interval_min, interval_max)
            cont_columns.append(self.dynamic_thresholding(column, dt_ratio))
        if cont_columns:
            output_cont = torch.cat(cont_columns, dim=2)

        # NOTE(review): this global pass runs on top of the per-feature pass, so
        # a per-feature interval wider than the global one is narrowed again --
        # confirm this is intended.
        output_cont = torch.clamp(output_cont, self.interval_min, self.interval_max)
        output_cont = self.dynamic_thresholding(output_cont, self.dynamic_thresholding_ratio)

        # Discrete head: one slice per discrete feature, zero-padded if the
        # output layer is narrower than the declared feature widths.
        output_disc_list = []
        start_idx = 0
        for feat_dim in self.feature_dims[self.num_numeric_features:]:
            end_idx = start_idx + feat_dim
            feat_output = output_disc[:, :, start_idx:end_idx]
            if feat_output.shape[-1] != feat_dim:
                if self.verbose:
                    print(f"Adjusting output for discrete feature with dim {feat_dim}")
                feat_output = F.pad(feat_output, (0, feat_dim - feat_output.shape[-1]))
            output_disc_list.append(feat_output)
            start_idx = end_idx
        if output_disc_list:
            output_disc = torch.cat(output_disc_list, dim=-1)
        else:
            # torch.cat rejects an empty list; no discrete features means an
            # empty discrete head.
            output_disc = output_disc[:, :, :0]

        if self.verbose:
            print(f"Forward method output shapes - cont: {output_cont.shape}, disc: {output_disc.shape}")
        return output_cont, output_disc

    def dynamic_thresholding(self, x, p):
        """Clamp each position to its ``p``-quantile magnitude and rescale.

        For every (batch, sequence) position the ``p``-quantile ``s`` of the
        absolute values across the feature axis is computed, floored at 1.0;
        values are clamped to ``[-s, s]`` and divided by ``s``.

        Args:
            x: 3-D tensor ``(batch, seq, features)``.
            p: Quantile in ``[0, 1]``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch_size, seq_len, feature_dim = x.shape
        if x.numel() == 0:
            # torch.quantile raises on empty input (e.g. no continuous features).
            return x

        x_flat = x.reshape(batch_size * seq_len, feature_dim)

        s = torch.quantile(torch.abs(x_flat), p, dim=1)
        s = torch.clamp(s, min=1.0).unsqueeze(1)

        x_thresholded = torch.clamp(x_flat, -s, s) / s

        return x_thresholded.reshape(batch_size, seq_len, feature_dim)

    def to(self, device):
        """Move the module to ``device`` and return ``self``.

        The schedule buffers are re-assigned only when they were registered.
        """
        super().to(device)
        if hasattr(self, 'alphas_cumprod'):
            self.alphas_cumprod = self.alphas_cumprod.to(device)
        if hasattr(self, 'alphas'):
            self.alphas = self.alphas.to(device)
        if hasattr(self, 'betas'):
            self.betas = self.betas.to(device)
        return self

    def compute_score(self, x, t, c=None):
        """Combine the network output and ``x`` into a score tensor.

        Both heads are transformed as
        ``(output - x / sqrt(lambda_t)) / sqrt(lambda_t)`` with ``lambda_t``
        interpolated from ``alphas_cumprod`` and concatenated along the last
        axis.  Requires the ``alphas_cumprod`` buffer.
        NOTE(review): confirm this matches the intended score parameterisation.
        """
        x_cont, x_disc = x[:, :, :self.num_numeric_features], x[:, :, self.num_numeric_features:]

        output_cont, output_disc = self(x, t, c)

        lambda_t = self.interpolate_alphas_cumprod(t).unsqueeze(1).unsqueeze(1)
        inv_sqrt_lambda = 1 / torch.sqrt(lambda_t)

        score_cont = inv_sqrt_lambda * (output_cont - x_cont * inv_sqrt_lambda)
        score_disc = inv_sqrt_lambda * (output_disc - x_disc * inv_sqrt_lambda)

        return torch.cat([score_cont, score_disc], dim=-1).float()

    def compute_loss(self, x_0, x_t, t):
        """Sum of per-feature losses between the prediction on ``x_t`` and ``x_0``.

        Continuous features contribute an MSE between the continuous head and
        ``x_0 - x_t``; discrete features contribute a cross-entropy between
        the discrete head (as logits) and the class index recovered from
        ``x_0``.

        Assumes each discrete feature occupies ``feat_dim`` one-hot columns of
        ``x_0`` -- TODO confirm against the data preparation code.

        Args:
            x_0: Clean data, 3-D tensor laid out in ``feature_dims`` order.
            x_t: Noised data, tensor or ``(continuous, discrete)`` tuple.
            t: Time input forwarded to the network.

        Returns:
            Scalar loss tensor.
        """
        if self.verbose:
            print(f"Compute loss input shapes - x_0: {x_0.shape}, x_t: {type(x_t)}, t: {t.shape}")
        predicted_noise_cont, predicted_noise_disc = self(x_t, t)
        if self.verbose:
            print(f"Predicted noise shapes - cont: {predicted_noise_cont.shape}, disc: {predicted_noise_disc.shape}")

        if isinstance(x_t, tuple):
            x_t_cont, _ = x_t
        else:
            x_t_cont = x_t[:, :, :self.num_numeric_features]

        # Tensor accumulator so ``.item()`` works even with no features.
        total_loss = predicted_noise_cont.new_zeros(())

        # x_0 is indexed over all features, whereas each head is indexed only
        # over its own kind.  Slicing the discrete head with the x_0 offset
        # used to select the wrong (often empty) columns.
        x0_idx = 0
        cont_idx = 0
        disc_idx = 0
        for feat_type, feat_dim in zip(self.feature_types, self.feature_dims):
            x0_slice = x_0[:, :, x0_idx:x0_idx + feat_dim]
            if feat_type == 'continuous':
                pred = predicted_noise_cont[:, :, cont_idx:cont_idx + feat_dim]
                target = x0_slice - x_t_cont[:, :, cont_idx:cont_idx + feat_dim]
                total_loss = total_loss + F.mse_loss(pred, target)
                cont_idx += feat_dim
            else:  # discrete
                logits = predicted_noise_disc[:, :, disc_idx:disc_idx + feat_dim]
                if logits.shape[-1] < feat_dim:
                    # Defensive: keep the logit width equal to the class count.
                    logits = F.pad(logits, (0, feat_dim - logits.shape[-1]))
                # cross_entropy expects one class index per position, not the
                # flattened one-hot block.
                target = x0_slice.argmax(dim=-1)
                total_loss = total_loss + F.cross_entropy(
                    logits.reshape(-1, feat_dim), target.reshape(-1)
                )
                disc_idx += feat_dim
            x0_idx += feat_dim

        if self.verbose:
            print(f"Computed loss: {total_loss.item()}")
        return total_loss

    def interpolate_alphas_cumprod(self, t):
        """Linearly interpolate the ``alphas_cumprod`` buffer at ``t``."""
        return self._interpolate_tensor(self.alphas_cumprod, t)

    def interpolate_alphas(self, t):
        """Linearly interpolate the ``alphas`` buffer at ``t``."""
        return self._interpolate_tensor(self.alphas, t)

    def interpolate_betas(self, t):
        """Linearly interpolate the ``betas`` buffer at ``t``."""
        return self._interpolate_tensor(self.betas, t)

    def _interpolate_tensor(self, tensor, t):
        """Linearly interpolate ``tensor`` at (possibly fractional) indices ``t``.

        Indices are clamped to the valid range, so a ``t`` beyond either end
        returns the boundary value instead of raising an ``IndexError``.
        """
        if not t.is_floating_point():
            t = t.float()
        last = tensor.shape[0] - 1
        low = torch.floor(t).long().clamp(0, last)
        high = torch.ceil(t).long().clamp(0, last)
        w = t - low.float()
        return (1 - w) * tensor[low] + w * tensor[high]

    @staticmethod
    def get_timestep_embedding(timesteps, embedding_dim):
        """Return sinusoidal embeddings of shape ``(len(timesteps), embedding_dim)``.

        The first half of each row holds sines, the second half cosines; an
        odd ``embedding_dim`` is zero-padded by one column.

        Raises:
            AssertionError: If ``timesteps`` is not 1-D.
        """
        # Imported locally: the cell defining this class does not import math.
        import math

        assert len(timesteps.shape) == 1, "Timesteps should be a 1-D tensor"
        half_dim = embedding_dim // 2
        # max(..., 1) avoids a division by zero when half_dim == 1.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -scale
        )
        angles = timesteps.float()[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = F.pad(emb, (0, 1, 0, 0))
        return emb

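# NOTE: EnhancedDiffusionModel relies on the helper classes AttentionLayer, ResidualBlock and SinusoidalPositionEmbeddings; they must be defined (AttentionLayer and ResidualBlock appear in a later cell) before the model is instantiated.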

In [None]:
class EnhancedDiffusionModel(nn.Module):
    """Denoising network for sequences mixing continuous and discrete features.

    The input is projected to ``hidden_dim``, run through a 2-layer LSTM and
    ``num_layers`` time-conditioned MLP + attention stages, offset by a class
    embedding and finally projected to ``output_dim``.  The first
    ``num_numeric_features`` output columns form the continuous head; the
    remaining columns form the discrete head.

    Relies on ``SinusoidalPositionEmbeddings``, ``ResidualBlock`` and
    ``AttentionLayer``, which must be defined before the model is instantiated.
    """

    # Class-level default so the debug prints keep working even on instances
    # created without running ``__init__``.
    verbose = True

    def __init__(self, input_dim, output_dim, hidden_dim=128, time_dim=128, seq_length=10,
                 num_layers=4, num_heads=8, num_numeric_features=0,
                 alphas_cumprod=None, alphas=None, betas=None, num_classes=1000,
                 feature_types=None, feature_dims=None, feature_specific_params=None,
                 dynamic_thresholding_ratio=0.995, interval_min=-1, interval_max=1,
                 verbose=True):
        """Build the network.

        Args:
            input_dim: Width of the last axis of the input tensor.
            output_dim: Width of the last axis produced by the output layer.
            hidden_dim: Width of the hidden representation.
            time_dim: Width of the time embedding.
            seq_length: Stored on the instance; not used by the layers built here.
            num_layers: Number of MLP + attention stages.
            num_heads: Attention heads per ``AttentionLayer``.
            num_numeric_features: Number of leading output columns treated as
                continuous.
            alphas_cumprod, alphas, betas: Optional schedules, registered as
                float32 buffers when given.
            num_classes: Number of conditioning classes; index ``num_classes``
                itself is reserved for "no class".
            feature_types: One entry per feature, ``'continuous'`` or anything
                else for discrete.  Defaults to all continuous.
            feature_dims: Column width of each feature.  Defaults to all 1.
            feature_specific_params: Optional mapping ``"numeric_<i>"`` ->
                dict with ``interval_min`` / ``interval_max`` /
                ``dynamic_thresholding_ratio`` overrides.
            dynamic_thresholding_ratio: Quantile used by dynamic thresholding.
            interval_min, interval_max: Clamp interval for continuous outputs.
            verbose: When true (default, the historical behaviour) shape and
                loss diagnostics are printed.

        Raises:
            AssertionError: If ``feature_types`` and ``feature_dims`` differ in
                length or ``feature_dims`` does not sum to ``input_dim``.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.time_dim = time_dim
        self.seq_length = seq_length
        self.num_numeric_features = num_numeric_features
        self.num_classes = num_classes
        self.verbose = verbose
        self.feature_types = feature_types or ['continuous'] * input_dim
        self.feature_dims = feature_dims or [1] * input_dim
        assert len(self.feature_types) == len(self.feature_dims), "feature_types and feature_dims must have the same length"
        assert sum(self.feature_dims) == input_dim, f"Sum of feature_dims ({sum(self.feature_dims)}) must equal input_dim ({input_dim})"

        # Global post-processing settings.  The ratio used to be overwritten
        # with a hard-coded 0.995 right after assignment, which silently
        # ignored the constructor argument.
        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
        self.interval_min = interval_min
        self.interval_max = interval_max

        # Fallbacks used when a feature has no entry in feature_specific_params.
        self.default_interval_min = interval_min
        self.default_interval_max = interval_max
        self.default_dynamic_thresholding_ratio = dynamic_thresholding_ratio

        # Per-feature overrides, keyed by "numeric_<index>".
        self.feature_specific_params = feature_specific_params or {}

        # Schedules are buffers so they follow the module across devices and
        # are saved in the state dict; they are only registered when supplied.
        if alphas_cumprod is not None:
            self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
        if alphas is not None:
            self.register_buffer('alphas', torch.tensor(alphas, dtype=torch.float32))
        if betas is not None:
            self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.Mish(),
            nn.LayerNorm(time_dim),
            nn.Linear(time_dim, time_dim),
        )

        self.input_upscale = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Mish()
        )

        self.rnn = nn.LSTM(hidden_dim, hidden_dim, num_layers=2, batch_first=True)

        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim + time_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.Mish(),
                ResidualBlock(hidden_dim)
            ) for _ in range(num_layers)
        ])

        self.attention_layers = nn.ModuleList([
            AttentionLayer(hidden_dim, num_heads=num_heads)
            for _ in range(num_layers)
        ])

        self.output_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, output_dim)
        )

        # One extra row: index ``num_classes`` encodes "unconditional".
        self.class_embedding = nn.Embedding(num_classes + 1, hidden_dim)

        if self.verbose:
            print(f"Model initialized with input_dim: {input_dim}, output_dim: {output_dim}")
            # Report the resolved lists rather than the raw (possibly None) arguments.
            print(f"Feature types: {self.feature_types}")
            print(f"Feature dimensions: {self.feature_dims}")

    def forward(self, x, t, c=None):
        """Run the network and return ``(output_cont, output_disc)``.

        Args:
            x: 3-D tensor, or a ``(continuous, discrete)`` tuple of tensors
                that is concatenated along the last axis.
            t: Time input handed to ``time_mlp``; presumably one value per
                batch element -- TODO confirm against SinusoidalPositionEmbeddings.
            c: Optional class indices; ``None`` or ``-1`` entries select the
                reserved "no class" embedding.

        Returns:
            Tuple of the continuous head (clamped and dynamically thresholded)
            and the discrete head (one slice per discrete feature, concatenated).
        """
        if isinstance(x, tuple):
            x_cont, x_disc = x
            if self.verbose:
                print(f"Forward input - cont shape: {x_cont.shape}, disc shape: {x_disc.shape}")
            x = torch.cat([x_cont, x_disc], dim=-1)

        # Unpacking doubles as a check that the input is 3-D.
        _, seq_len, _ = x.shape

        # Broadcast the time embedding over the sequence axis.
        t_emb = self.time_mlp(t)
        t_emb = t_emb.unsqueeze(1).expand(-1, seq_len, -1)

        x = self.input_upscale(x)
        x, _ = self.rnn(x)

        for layer, attention in zip(self.layers, self.attention_layers):
            x = layer(torch.cat([x, t_emb], dim=-1))
            x = attention(x)

        # Class conditioning: missing / -1 labels map to the extra embedding row.
        if c is None:
            c = torch.full((x.shape[0],), self.num_classes, device=x.device)
        else:
            c = torch.where(c == -1, self.num_classes, c)

        if c.dim() == 1:
            c = c.unsqueeze(1).expand(-1, x.size(1))

        x = x + self.class_embedding(c)

        output = self.output_layer(x)
        output_cont = output[:, :, :self.num_numeric_features]
        output_disc = output[:, :, self.num_numeric_features:]

        # Per-feature clamp + dynamic thresholding.  Columns are collected and
        # concatenated once instead of rebuilding the tensor on every iteration.
        cont_columns = []
        for i in range(self.num_numeric_features):
            params = self.feature_specific_params.get(f"numeric_{i}", {})
            interval_min = params.get('interval_min', self.default_interval_min)
            interval_max = params.get('interval_max', self.default_interval_max)
            dt_ratio = params.get('dynamic_thresholding_ratio', self.default_dynamic_thresholding_ratio)

            column = torch.clamp(output_cont[:, :, i:i + 1], interval_min, interval_max)
            cont_columns.append(self.dynamic_thresholding(column, dt_ratio))
        if cont_columns:
            output_cont = torch.cat(cont_columns, dim=2)

        # NOTE(review): this global pass runs on top of the per-feature pass, so
        # a per-feature interval wider than the global one is narrowed again --
        # confirm this is intended.
        output_cont = torch.clamp(output_cont, self.interval_min, self.interval_max)
        output_cont = self.dynamic_thresholding(output_cont, self.dynamic_thresholding_ratio)

        # Discrete head: one slice per discrete feature, zero-padded if the
        # output layer is narrower than the declared feature widths.
        output_disc_list = []
        start_idx = 0
        for feat_dim in self.feature_dims[self.num_numeric_features:]:
            end_idx = start_idx + feat_dim
            feat_output = output_disc[:, :, start_idx:end_idx]
            if feat_output.shape[-1] != feat_dim:
                if self.verbose:
                    print(f"Adjusting output for discrete feature with dim {feat_dim}")
                feat_output = F.pad(feat_output, (0, feat_dim - feat_output.shape[-1]))
            output_disc_list.append(feat_output)
            start_idx = end_idx
        if output_disc_list:
            output_disc = torch.cat(output_disc_list, dim=-1)
        else:
            # torch.cat rejects an empty list; no discrete features means an
            # empty discrete head.
            output_disc = output_disc[:, :, :0]

        if self.verbose:
            print(f"Forward method output shapes - cont: {output_cont.shape}, disc: {output_disc.shape}")
        return output_cont, output_disc

    def dynamic_thresholding(self, x, p):
        """Clamp each position to its ``p``-quantile magnitude and rescale.

        For every (batch, sequence) position the ``p``-quantile ``s`` of the
        absolute values across the feature axis is computed, floored at 1.0;
        values are clamped to ``[-s, s]`` and divided by ``s``.

        Args:
            x: 3-D tensor ``(batch, seq, features)``.
            p: Quantile in ``[0, 1]``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch_size, seq_len, feature_dim = x.shape
        if x.numel() == 0:
            # torch.quantile raises on empty input (e.g. no continuous features).
            return x

        x_flat = x.reshape(batch_size * seq_len, feature_dim)

        s = torch.quantile(torch.abs(x_flat), p, dim=1)
        s = torch.clamp(s, min=1.0).unsqueeze(1)

        x_thresholded = torch.clamp(x_flat, -s, s) / s

        return x_thresholded.reshape(batch_size, seq_len, feature_dim)

    def to(self, device):
        """Move the module to ``device`` and return ``self``.

        The schedule buffers are re-assigned only when they were registered.
        """
        super().to(device)
        if hasattr(self, 'alphas_cumprod'):
            self.alphas_cumprod = self.alphas_cumprod.to(device)
        if hasattr(self, 'alphas'):
            self.alphas = self.alphas.to(device)
        if hasattr(self, 'betas'):
            self.betas = self.betas.to(device)
        return self

    def compute_score(self, x, t, c=None):
        """Combine the network output and ``x`` into a score tensor.

        Both heads are transformed as
        ``(output - x / sqrt(lambda_t)) / sqrt(lambda_t)`` with ``lambda_t``
        interpolated from ``alphas_cumprod`` and concatenated along the last
        axis.  Requires the ``alphas_cumprod`` buffer.
        NOTE(review): confirm this matches the intended score parameterisation.
        """
        x_cont, x_disc = x[:, :, :self.num_numeric_features], x[:, :, self.num_numeric_features:]

        output_cont, output_disc = self(x, t, c)

        lambda_t = self.interpolate_alphas_cumprod(t).unsqueeze(1).unsqueeze(1)
        inv_sqrt_lambda = 1 / torch.sqrt(lambda_t)

        score_cont = inv_sqrt_lambda * (output_cont - x_cont * inv_sqrt_lambda)
        score_disc = inv_sqrt_lambda * (output_disc - x_disc * inv_sqrt_lambda)

        return torch.cat([score_cont, score_disc], dim=-1).float()

    def compute_loss(self, x_0, x_t, t):
        """Sum of per-feature losses between the prediction on ``x_t`` and ``x_0``.

        Continuous features contribute an MSE between the continuous head and
        ``x_0 - x_t``; discrete features contribute a cross-entropy between
        the discrete head (as logits) and the class index recovered from
        ``x_0``.

        This method previously did not parse: the continuous branch had been
        overwritten by a mis-indented discrete branch that referenced the
        undefined names ``end_idx`` and ``loss``.

        Assumes each discrete feature occupies ``feat_dim`` one-hot columns of
        ``x_0`` -- TODO confirm against the data preparation code.

        Args:
            x_0: Clean data, 3-D tensor laid out in ``feature_dims`` order.
            x_t: Noised data, tensor or ``(continuous, discrete)`` tuple.
            t: Time input forwarded to the network.

        Returns:
            Scalar loss tensor.
        """
        if self.verbose:
            print(f"Compute loss input shapes - x_0: {x_0.shape}, x_t: {type(x_t)}, t: {t.shape}")
        predicted_noise_cont, predicted_noise_disc = self(x_t, t)
        if self.verbose:
            print(f"Predicted noise shapes - cont: {predicted_noise_cont.shape}, disc: {predicted_noise_disc.shape}")

        if isinstance(x_t, tuple):
            x_t_cont, _ = x_t
        else:
            x_t_cont = x_t[:, :, :self.num_numeric_features]

        # Tensor accumulator so ``.item()`` works even with no features.
        total_loss = predicted_noise_cont.new_zeros(())

        # x_0 is indexed over all features, whereas each head is indexed only
        # over its own kind, so three independent offsets are tracked.
        x0_idx = 0
        cont_idx = 0
        disc_idx = 0
        for feat_type, feat_dim in zip(self.feature_types, self.feature_dims):
            x0_slice = x_0[:, :, x0_idx:x0_idx + feat_dim]
            if feat_type == 'continuous':
                pred = predicted_noise_cont[:, :, cont_idx:cont_idx + feat_dim]
                target = x0_slice - x_t_cont[:, :, cont_idx:cont_idx + feat_dim]
                total_loss = total_loss + F.mse_loss(pred, target)
                cont_idx += feat_dim
            else:  # discrete
                logits = predicted_noise_disc[:, :, disc_idx:disc_idx + feat_dim]
                if logits.shape[-1] < feat_dim:
                    # Defensive: keep the logit width equal to the class count.
                    logits = F.pad(logits, (0, feat_dim - logits.shape[-1]))
                # cross_entropy expects one class index per position, not the
                # flattened one-hot block.
                target = x0_slice.argmax(dim=-1)
                total_loss = total_loss + F.cross_entropy(
                    logits.reshape(-1, feat_dim), target.reshape(-1)
                )
                disc_idx += feat_dim
            x0_idx += feat_dim

        if self.verbose:
            print(f"Computed loss: {total_loss.item()}")
        return total_loss

    def interpolate_alphas_cumprod(self, t):
        """Linearly interpolate the ``alphas_cumprod`` buffer at ``t``."""
        return self._interpolate_tensor(self.alphas_cumprod, t)

    def interpolate_alphas(self, t):
        """Linearly interpolate the ``alphas`` buffer at ``t``."""
        return self._interpolate_tensor(self.alphas, t)

    def interpolate_betas(self, t):
        """Linearly interpolate the ``betas`` buffer at ``t``."""
        return self._interpolate_tensor(self.betas, t)

    def _interpolate_tensor(self, tensor, t):
        """Linearly interpolate ``tensor`` at (possibly fractional) indices ``t``.

        Indices are clamped to the valid range, so a ``t`` beyond either end
        returns the boundary value instead of raising an ``IndexError``.
        """
        if not t.is_floating_point():
            t = t.float()
        last = tensor.shape[0] - 1
        low = torch.floor(t).long().clamp(0, last)
        high = torch.ceil(t).long().clamp(0, last)
        w = t - low.float()
        return (1 - w) * tensor[low] + w * tensor[high]

    @staticmethod
    def get_timestep_embedding(timesteps, embedding_dim):
        """Return sinusoidal embeddings of shape ``(len(timesteps), embedding_dim)``.

        The first half of each row holds sines, the second half cosines; an
        odd ``embedding_dim`` is zero-padded by one column.

        Raises:
            AssertionError: If ``timesteps`` is not 1-D.
        """
        # Imported locally: the cell defining this class does not import math.
        import math

        assert len(timesteps.shape) == 1, "Timesteps should be a 1-D tensor"
        half_dim = embedding_dim // 2
        # max(..., 1) avoids a division by zero when half_dim == 1.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -scale
        )
        angles = timesteps.float()[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = F.pad(emb, (0, 1, 0, 0))
        return emb

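# NOTE: EnhancedDiffusionModel relies on the helper classes AttentionLayer, ResidualBlock and SinusoidalPositionEmbeddings; they must be defined (AttentionLayer and ResidualBlock appear in the next cell) before the model is instantiated.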

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from scipy.stats import wasserstein_distance
from scipy.spatial.distance import jensenshannon
import math
import scipy

class AttentionLayer(nn.Module):
    """Self-attention stage followed by a feed-forward stage.

    Each stage adds its result back onto its input and applies a LayerNorm
    afterwards.  The feed-forward stage widens to ``4 * dim`` with a Mish
    activation before projecting back to ``dim``.
    """

    def __init__(self, dim, num_heads=8):
        """Create the attention, normalisation and feed-forward sublayers.

        Args:
            dim: Feature width of the input and output.
            num_heads: Number of attention heads.
        """
        super().__init__()
        ff_width = 4 * dim
        self.attention = nn.MultiheadAttention(dim, num_heads=num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_width),
            nn.Mish(),
            nn.Linear(ff_width, dim),
        )

    def forward(self, x):
        """Apply self-attention then the feed-forward stage; shape is preserved."""
        # Query, key and value are all the input itself; weights are discarded.
        attended = self.attention(x, x, x)[0]
        after_attention = self.norm1(x + attended)
        return self.norm2(after_attention + self.ff(after_attention))

class ResidualBlock(nn.Module):
    """Two-layer MLP (Linear -> LayerNorm -> Mish -> Linear) with a skip connection."""

    def __init__(self, dim):
        """Build the inner MLP; input and output both have width ``dim``."""
        super().__init__()
        stages = [
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.Mish(),
            nn.Linear(dim, dim),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return the input plus the MLP's transformation of it."""
        residual = self.block(x)
        return x + residual

class DiffusionModel(nn.Module):
    """Sequence denoiser: LSTM + self-attention with time and class conditioning.

    ``forward`` lifts every position of the input to ``hidden_dim``, runs the
    sequence through an LSTM and an ``AttentionLayer``, adds a class
    embedding, concatenates a sinusoidal time embedding and maps the result
    to ``output_dim`` values per position.  The first
    ``num_numeric_features`` outputs are returned unchanged; the remaining
    ones are passed through a Gumbel-softmax.

    Args:
        input_dim: Number of features per sequence position.
        output_dim: Number of values predicted per sequence position.
        hidden_dim: Width of the LSTM / attention trunk.
        time_dim: Width of the timestep embedding.
        seq_length: Stored on the instance; not used by the layers built here.
        num_numeric_features: Count of leading continuous features.
        alphas_cumprod, alphas, betas: Optional noise-schedule arrays.  Each
            one that is given is registered as a float32 buffer of the same
            name, so it follows the module through ``.to()`` / ``.cuda()``
            and no custom ``to()`` override is required.
        num_classes: Number of class labels; index ``num_classes`` of the
            class embedding is reserved for "no label".
    """

    def __init__(self, input_dim, output_dim, hidden_dim=256, time_dim=256, seq_length=30, num_numeric_features=0, alphas_cumprod=None, alphas=None, betas=None, num_classes=1000):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.time_dim = time_dim
        self.seq_length = seq_length
        self.num_numeric_features = num_numeric_features
        self.num_classes = num_classes

        # Only register the schedules that were actually supplied.
        schedules = (('alphas_cumprod', alphas_cumprod), ('alphas', alphas), ('betas', betas))
        for buffer_name, values in schedules:
            if values is not None:
                self.register_buffer(buffer_name, self._schedule_tensor(values))

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        )

        self.input_upscale = nn.Linear(input_dim, hidden_dim)
        self.rnn = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.attention = AttentionLayer(hidden_dim)

        self.output_layer = nn.Sequential(
            nn.Linear(hidden_dim + time_dim, hidden_dim),
            nn.SiLU(),
            ResidualBlock(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            ResidualBlock(hidden_dim),
            nn.Linear(hidden_dim, output_dim)
        )

        # One extra row (index ``num_classes``) stands for "unconditional".
        self.class_embedding = nn.Embedding(num_classes + 1, hidden_dim)

    @staticmethod
    def _schedule_tensor(values):
        """Return an independent float32 copy of *values* (array-like or tensor)."""
        if isinstance(values, torch.Tensor):
            return values.detach().clone().to(torch.float32)
        return torch.tensor(values, dtype=torch.float32)

    def forward(self, x, t, c=None):
        """Predict continuous and discrete outputs for noisy input ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_dim)``.
            t: 1-D tensor with one timestep per batch element.
            c: Optional integer class labels, either ``(batch,)`` or
                ``(batch, seq_len)``.  ``None`` or entries equal to ``-1``
                select the unconditional embedding.

        Returns:
            Tuple ``(output_cont, output_disc)``: the first
            ``num_numeric_features`` output channels, and the Gumbel-softmax
            of the remaining channels.
        """
        _, seq_len, _ = x.shape

        # Time embedding, broadcast over the sequence axis.
        t_emb = self.time_mlp(t)
        t_emb = t_emb.unsqueeze(1).expand(-1, seq_len, -1)

        # Process input
        x = self.input_upscale(x)
        x, _ = self.rnn(x)
        x = self.attention(x)

        if c is None:
            # For unconditional generation, use the last embedding (num_classes)
            c = torch.full((x.shape[0],), self.num_classes, dtype=torch.long, device=x.device)
        else:
            c = torch.where(c == -1, torch.full_like(c, self.num_classes), c)

        if c.dim() == 1:
            c = c.unsqueeze(1).expand(-1, x.size(1))

        x = x + self.class_embedding(c)

        # Combine with time embedding
        x = torch.cat([x, t_emb], dim=-1)

        output = self.output_layer(x)
        output_cont = output[:, :, :self.num_numeric_features]
        output_disc = output[:, :, self.num_numeric_features:]

        # Apply Gumbel-Softmax to discrete output
        output_disc = F.gumbel_softmax(output_disc, tau=1.0, hard=False, dim=-1)

        return output_cont, output_disc

    def compute_score(self, x, t, c=None):
        """Turn the network outputs into a score estimate for ``x`` at time ``t``.

        The network is invoked through ``self(...)``; this class has no
        ``model`` attribute.

        Returns:
            Float tensor shaped like ``x`` with continuous scores first and
            discrete scores after them.
        """
        n_num = self.num_numeric_features
        x_cont, x_disc = x[:, :, :n_num], x[:, :, n_num:]

        output_cont, output_disc = self(x, t, c)  # c can be None here

        lambda_t = self.interpolate_alphas_cumprod(t).unsqueeze(1).unsqueeze(1)
        sqrt_lambda = torch.sqrt(lambda_t)

        score_cont = (1 / sqrt_lambda) * (output_cont - x_cont / sqrt_lambda)
        score_disc = (1 / sqrt_lambda) * (output_disc - x_disc / sqrt_lambda)

        return torch.cat([score_cont, score_disc], dim=-1).float()

    def compute_loss(self, x_0, x_t, t):
        """Return a scalar training loss for one noisy batch.

        ``forward`` returns a ``(continuous, discrete)`` tuple, so the two
        parts are scored separately and added:

        * continuous channels: MSE against ``x_0 - x_t``, averaged per
          feature and summed over features;
        * discrete channels: cross-entropy between the predicted
          probabilities and the clean discrete block of ``x_0``.

        NOTE(review): the discrete term assumes the discrete columns of
        ``x_0`` are one-hot (or probability) encoded — confirm against the
        preprocessing used by the caller.

        Args:
            x_0: Clean batch, ``(batch, seq_len, input_dim)``.
            x_t: Noisy batch of the same shape.
            t: 1-D tensor of timesteps, one per batch element.
        """
        pred_cont, pred_disc = self(x_t, t)
        n_num = self.num_numeric_features

        loss = x_t.new_zeros(())
        if pred_cont.shape[-1] > 0:
            target = x_0[:, :, :n_num] - x_t[:, :, :n_num]
            per_feature = F.mse_loss(pred_cont, target, reduction='none').mean(dim=(0, 1))
            loss = loss + per_feature.sum()
        if pred_disc.shape[-1] > 0:
            # Clamp before the log: softmax outputs can underflow to 0.
            log_probs = torch.log(pred_disc.clamp_min(1e-10))
            loss = loss - (x_0[:, :, n_num:] * log_probs).sum(dim=-1).mean()

        return loss

    def interpolate_alphas_cumprod(self, t):
        """Linearly interpolate ``alphas_cumprod`` at (possibly fractional) ``t``."""
        return self._interpolate_tensor(self.alphas_cumprod, t)

    def interpolate_alphas(self, t):
        """Linearly interpolate ``alphas`` at (possibly fractional) ``t``."""
        return self._interpolate_tensor(self.alphas, t)

    def interpolate_betas(self, t):
        """Linearly interpolate ``betas`` at (possibly fractional) ``t``."""
        return self._interpolate_tensor(self.betas, t)

    def _interpolate_tensor(self, tensor, t):
        """Linear interpolation of a 1-D schedule at timesteps ``t``.

        Integer timesteps are accepted as well.  Indices are clamped to the
        valid range so an out-of-range ``t`` returns the nearest end value
        instead of raising ``IndexError``.
        """
        t = t if t.is_floating_point() else t.float()
        last = tensor.shape[0] - 1
        low = torch.floor(t).long().clamp(0, last)
        high = torch.ceil(t).long().clamp(0, last)
        w = (t - low.to(t.dtype)).clamp(0, 1)
        return (1 - w) * tensor[low] + w * tensor[high]

    @staticmethod
    def get_timestep_embedding(timesteps, embedding_dim):
        """Return ``[sin | cos]`` sinusoidal embeddings for 1-D ``timesteps``.

        The result has shape ``(len(timesteps), embedding_dim)``; an odd
        ``embedding_dim`` gets one trailing zero column.
        """
        assert len(timesteps.shape) == 1, "Timesteps should be a 1-D tensor"
        half_dim = embedding_dim // 2
        # Clamp the divisor: with a single frequency ``half_dim - 1`` is 0.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
        emb = emb.to(device=timesteps.device)
        emb = timesteps.float()[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
        return emb
def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
    """Draw a Gumbel-softmax sample from ``logits``.

    Args:
        logits: Unnormalised log-probabilities.
        tau: Temperature dividing the perturbed logits.
        hard: When true, return a one-hot sample that still carries the
            gradient of the soft sample (straight-through estimator).
        eps: Small constant guarding both logarithms.
        dim: Axis over which the softmax is taken.

    Returns:
        Tensor shaped like ``logits``.
    """
    target_device = logits.device

    # Gumbel(0, 1) noise via the inverse-CDF trick on uniform samples.
    uniform = torch.rand(logits.shape, device=target_device)
    gumbel_noise = -torch.log(-torch.log(uniform + eps) + eps)

    soft_sample = ((logits + gumbel_noise) / tau).softmax(dim)
    if not hard:
        return soft_sample

    winners = soft_sample.max(dim, keepdim=True)[1]
    one_hot = torch.zeros_like(logits, device=target_device).scatter_(dim, winners, 1.0)
    # Forward value is the one-hot; gradient flows through the soft sample.
    return one_hot - soft_sample.detach() + soft_sample

class SinusoidalPositionEmbeddings(nn.Module):
    """Map a 1-D tensor of timesteps to ``[sin | cos]`` sinusoidal features.

    Args:
        dim: Width of the returned embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Return embeddings of shape ``(len(time), dim)``.

        An odd ``dim`` is zero-padded by one column so the output width
        always equals ``dim``.
        """
        device = time.device
        half_dim = self.dim // 2
        # Clamp the divisor: for dim 2 or 3 ``half_dim - 1`` is 0 and the
        # plain division would raise ZeroDivisionError.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        angles = time[:, None] * freqs[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            embeddings = torch.nn.functional.pad(embeddings, (0, 1))
        return embeddings


class FinDiffusionPlus:

    def __init__(self, n_timesteps=200, n_noise=200, learning_rate=0.001, seq_length=30, epsilon=1e-8, p_uncond=0.1):
        self.n_timesteps = n_timesteps
        self.n_noise = n_noise
        self.learning_rate = learning_rate
        self.seq_length = seq_length
        self.epsilon = epsilon
        self.p_uncond = p_uncond
        self.feature_types = None
        self.target_column = 'target'

        # Set up device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Use cosine beta schedule
        self.betas = self.cosine_beta_schedule(n_timesteps).to(self.device)
        self.alphas = (1. - self.betas).to(self.device)
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0).to(self.device)


        # Move to GPU if available
        try:
            self.betas = self.betas.to(self.device)
            self.alphas = self.alphas.to(self.device)
            self.alphas_cumprod = self.alphas_cumprod.to(self.device)
            print("Successfully moved tensors to GPU")
        except RuntimeError as e:
            print(f"Error moving tensors to GPU: {e}")
            print("Falling back to CPU")
            self.device = torch.device('cpu')

        # Define noise schedule
        self.noise_schedule = torch.linspace(1, 0.01, n_timesteps).to(self.device)

        print(f"alphas_cumprod shape: {self.alphas_cumprod.shape}")

        self.model = None
        self.numeric_features = None
        self.categorical_features = None
        self.input_dim = None

    def cosine_beta_schedule(self,timesteps, s=0.008):
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0.0001, 0.9999)
    def preprocess_data(self, data_path):
        """Load a CSV file and turn it into scaled, one-hot encoded sequences.

        Int64/float64 columns are standardised with a ``StandardScaler``;
        object columns are one-hot encoded (one encoder per column) and
        appended after the remaining columns.  The fitted scaler/encoders and
        the derived feature metadata are stored on the instance.

        Args:
            data_path: Path of the CSV file; it must contain
                ``self.target_column``.

        Returns:
            Tuple ``(X_seq, y_encoded)``: float32 array of windows produced
            by ``create_sequences`` and the integer codes from
            ``pd.factorize`` of the target column.
        """
        frame = pd.read_csv(data_path)
        y = frame[self.target_column]
        X = frame.drop(columns=[self.target_column])

        self.numeric_features = X.select_dtypes(include=['int64', 'float64']).columns
        self.categorical_features = X.select_dtypes(include=['object']).columns

        # Standardise the numeric columns in place.
        self.numeric_scaler = StandardScaler()
        scaled_values = self.numeric_scaler.fit_transform(X[self.numeric_features])
        X[self.numeric_features] = scaled_values

        # Replace each categorical column by its one-hot block (appended last).
        self.categorical_encoders = {}
        for column in self.categorical_features:
            one_hot = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
            dense = one_hot.fit_transform(X[[column]])
            block_names = [f"{column}_{cat}" for cat in one_hot.categories_[0]]
            block = pd.DataFrame(dense, columns=block_names)
            X = pd.concat([X.drop(columns=[column]), block], axis=1)
            self.categorical_encoders[column] = one_hot

        # Per-feature metadata: numeric features first, then categorical ones.
        n_numeric = len(self.numeric_features)
        category_counts = [len(enc.categories_[0]) for enc in self.categorical_encoders.values()]
        self.feature_types = ['continuous'] * n_numeric + ['discrete'] * len(category_counts)
        self.feature_dims = [1] * n_numeric + category_counts

        # Set input dimension
        self.input_dim = sum(self.feature_dims)

        # Create sequences
        X_seq = self.create_sequences(X.values, self.seq_length)
        y_encoded = pd.factorize(y)[0]

        self.num_numeric_features = n_numeric

        # Force the time axis to exactly seq_length (pad with zeros or crop).
        window_len = X_seq.shape[1]
        if window_len < self.seq_length:
            padding = ((0, 0), (0, self.seq_length - window_len), (0, 0))
            X_seq = np.pad(X_seq, padding, mode='constant')
        elif window_len > self.seq_length:
            X_seq = X_seq[:, :self.seq_length, :]

        print(f"Preprocessed data shape: {X_seq.shape}")
        print(f"Number of numeric features: {len(self.numeric_features)}")
        print(f"Number of categorical features: {len(self.categorical_features)}")
        print(f"Feature types: {self.feature_types}")
        print(f"Feature dimensions: {self.feature_dims}")

        return X_seq.astype(np.float32), y_encoded


    def compute_score(self, x, t, c=None):
        cond_score_cont, cond_score_disc = self.model(x, t, c)

        if c is None:
            # Si c est None, cela signifie que nous calculons le score inconditionnel
            uncond_score_cont, uncond_score_disc = cond_score_cont, cond_score_disc
        else:
            # Créer un tenseur de -1 avec la même forme que c pour la génération inconditionnelle
            uncond_c = torch.full_like(c, -1)
            uncond_score_cont, uncond_score_disc = self.model(x, t, uncond_c)

        return (cond_score_cont, cond_score_disc), (uncond_score_cont, uncond_score_disc)

    def create_sequences(self, data, seq_length):
        n_samples, n_features = data.shape
        if n_samples < seq_length:
            # If we have fewer samples than the sequence length, pad with zeros
            pad_size = seq_length - n_samples
            padded_data = np.pad(data, ((0, pad_size), (0, 0)), mode='constant')
            return padded_data.reshape(1, seq_length, n_features)
        else:
            n_seq = n_samples - seq_length + 1
            sequences = np.zeros((n_seq, seq_length, n_features))
            for i in range(n_seq):
                sequences[i] = data[i:i+seq_length]
            return sequences


    def forward_process(self, X, t):
        X_cont = X[:, :, :self.num_numeric_features]
        X_disc = X[:, :, self.num_numeric_features:]

        noise_cont = torch.randn_like(X_cont, device=self.device)
        noise_disc = torch.randn_like(X_disc, device=self.device)

        alpha_cumprod = self.alphas_cumprod[t].unsqueeze(1).unsqueeze(1)

        X_cont_noisy = torch.sqrt(alpha_cumprod) * X_cont + torch.sqrt(1 - alpha_cumprod) * noise_cont

        X_disc_noisy_list = []
        for i, feat_dim in enumerate(self.feature_dims[self.num_numeric_features:]):
            X_disc_feat = X_disc[:, :, i:i+feat_dim]
            noise_disc_feat = noise_disc[:, :, i:i+feat_dim]
            X_disc_noisy_feat = self.gumbel_softmax(X_disc_feat + noise_disc_feat, tau=1, hard=False)
            X_disc_noisy_list.append(X_disc_noisy_feat)

        X_disc_noisy = torch.cat(X_disc_noisy_list, dim=-1)

        return X_cont_noisy, X_disc_noisy

    def interpolate_alphas_cumprod(self, t):
        # Interpolate alphas_cumprod for float timesteps
        low = torch.floor(t).long()
        high = torch.ceil(t).long()
        w = t - low.float()
        return (1 - w) * self.alphas_cumprod[low] + w * self.alphas_cumprod[high]

    def reverse_process(self, X, t_tensor, score):
        with torch.no_grad():
            noise_level = self.noise_schedule[t_tensor].unsqueeze(1).unsqueeze(2)

            alpha = self.model.interpolate_alphas(t_tensor).unsqueeze(1).unsqueeze(1)
            alpha_cumprod = self.model.interpolate_alphas_cumprod(t_tensor).unsqueeze(1).unsqueeze(1)
            beta = self.model.interpolate_betas(t_tensor).unsqueeze(1).unsqueeze(1)

            # Compute adaptive step size
            step_size = 1.0 / (torch.norm(score, dim=-1, keepdim=True).clamp(min=self.epsilon))

            # Annealed Langevin Dynamics update
            noise = torch.randn_like(X)
            X_new = X + step_size * ((1 - alpha) / torch.sqrt(1 - alpha_cumprod)) * score + \
                torch.sqrt(2 * step_size * noise_level) * noise

        return X_new

    def improved_loss_function(self, X, X_noisy, predicted_noise, t):
        # Split into continuous and discrete parts
        X_cont, X_disc = X[:, :, :self.num_numeric_features], X[:, :, self.num_numeric_features:]
        X_noisy_cont, X_noisy_disc = X_noisy[:, :, :self.num_numeric_features], X_noisy[:, :, self.num_numeric_features:]
        pred_noise_cont, pred_noise_disc = predicted_noise[:, :, :self.num_numeric_features], predicted_noise[:, :, self.num_numeric_features:]

        # Continuous loss
        cont_loss = F.mse_loss(pred_noise_cont, X_cont - X_noisy_cont)

        # Discrete loss using KL divergence
        disc_loss = F.kl_div(
            F.log_softmax(pred_noise_disc, dim=-1),
            F.softmax(X_disc, dim=-1),
            reduction='batchmean'
        )

        # Combine losses with adaptive weighting
        alpha = self.alphas_cumprod[t]
        loss = alpha * cont_loss + (1 - alpha) * disc_loss

        return loss

    """def train(self, X, y):
        print(f"X shape: {X.shape}")
        print(f"y shape: {y.shape}")
        X_tensor = torch.from_numpy(X).float().to(self.device).requires_grad_(True)
        print(f"X_tensor shape: {X_tensor.shape}")
        y_tensor = torch.from_numpy(y).long().to(self.device)""""""

        if self.model is None:
            num_classes = len(np.unique(y))
            self.model = DiffusionModel(
                input_dim=X.shape[2],
                output_dim=X.shape[2],
                hidden_dim=512,
                time_dim=512,
                seq_length=self.seq_length,
                num_numeric_features=len(self.numeric_features),
                alphas_cumprod=self.alphas_cumprod.cpu().numpy(),
                alphas=self.alphas.cpu().numpy(),
                betas=self.betas.cpu().numpy(),
                num_classes=num_classes).to(self.device)""""""
        self.model = EnhancedDiffusionModel(
            input_dim=self.input_dim,
            output_dim=self.input_dim,
            hidden_dim=256,
            time_dim=256,
            seq_length=self.seq_length,
            num_layers=4,
            num_heads=8,
            num_numeric_features=len(self.numeric_features),
            alphas_cumprod=self.alphas_cumprod.cpu().numpy(),
            alphas=self.alphas.cpu().numpy(),
            betas=self.betas.cpu().numpy(),
            num_classes=len(np.unique(y))
        ).to(self.device)


        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.n_timesteps)
        try:


            for epoch in range(self.n_timesteps):
                self.optimizer.zero_grad()

                # Sample a batch
                batch_size = min(128, X.shape[0])  # Adjust batch size as needed
                batch_indices = torch.randint(0, X.shape[0], (batch_size,))
                X_batch = X_tensor[batch_indices]
                y_batch = y_tensor[batch_indices]

                t = torch.randint(0, self.n_timesteps, (batch_size,), device=self.device)

                # Randomly decide whether to use conditional or unconditional training
                use_uncond = torch.rand(batch_size, device=self.device) < self.p_uncond
                c = torch.where(use_uncond, torch.full_like(y_batch, -1), y_batch)  # -1 represents unconditional

                # Add noise
                noise = torch.randn_like(X_batch)
                X_noisy = self.add_noise(X_batch, t)

                # Forward process
                X_dilated = self.forward_process(X_noisy, t)

                # Split into continuous and discrete parts
                X_cont = X_batch[:, :, :self.num_numeric_features]
                X_disc = X_batch[:, :, self.num_numeric_features:]
                X_dilated_cont = X_dilated[:, :, :self.num_numeric_features]
                X_dilated_disc = X_dilated[:, :, self.num_numeric_features:]

                # Compute scores
                score_cont, score_disc = self.model(X_dilated, t, c)
                print(f"score_cont shape: {score_cont.shape}, score_disc shape: {score_disc.shape}")
                # Compute losses
                loss_cont = self.score_matching_loss_continuous(
                    X_cont, X_dilated_cont, score_cont, t
                )
                loss_disc = self.score_matching_loss_discrete(
                    X_disc, X_dilated_disc, score_disc, t
                )

                loss = loss_cont + loss_disc

                loss.backward()
                self.optimizer.step()
                self.scheduler.step()

                if (epoch + 1) % 100 == 0:
                    print(f"Epoch {epoch+1}/{self.n_timesteps}, Loss: {loss.item():.4f}, Loss Cont: {loss_cont.item():.4f}, Loss Disc: {loss_disc.item():.4f}")
        except RuntimeError as e:
            print(f"Error occurred: {e}")
            print(f"Error type: {type(e).__name__}")
            print(f"Error args: {e.args}")
            raise

        return self.model"""

    def train(self, X, y):
        """Fit the diffusion model on preprocessed sequences.

        One random mini-batch is drawn per iteration for ``self.n_timesteps``
        iterations.  The model is created on first use; an existing
        ``self.model`` is trained further.

        Args:
            X: Float array of shape ``(n_sequences, seq_length, n_features)``.
            y: Integer label array; it is indexed with the same random
                indices as ``X``.

        Returns:
            The trained model (``self.model``).

        Raises:
            RuntimeError: Re-raised after its details have been printed.
        """
        print(f"X shape: {X.shape}")
        print(f"y shape: {y.shape}")
        X_tensor = torch.from_numpy(X).float().to(self.device).requires_grad_(True)
        print(f"X_tensor shape: {X_tensor.shape}")
        y_tensor = torch.from_numpy(y).long().to(self.device)
        feature_specific_params = {
            'numeric_0': {'interval_min': 0, 'interval_max': 1, 'dynamic_thresholding_ratio': 0.99},  # liveness
            'numeric_1': {'interval_min': 0, 'interval_max': 1, 'dynamic_thresholding_ratio': 0.99},  # speechiness
            'numeric_2': {'interval_min': 0, 'interval_max': 1, 'dynamic_thresholding_ratio': 0.98},  # instrumentalness
            'numeric_3': {'dynamic_thresholding_ratio': 0.97}  # mode
        }

        if self.model is None:
            num_classes = len(np.unique(y))
            self.model = EnhancedDiffusionModel(
                input_dim=X.shape[2],
                output_dim=X.shape[2],
                hidden_dim=128,
                time_dim=128,
                seq_length=self.seq_length,
                num_layers=4,
                num_heads=8,
                num_numeric_features=len(self.numeric_features),
                alphas_cumprod=self.alphas_cumprod.cpu().numpy(),
                alphas=self.alphas.cpu().numpy(),
                betas=self.betas.cpu().numpy(),
                num_classes=num_classes,
                feature_types=self.feature_types,
                feature_dims=self.feature_dims,
                feature_specific_params=feature_specific_params,
                dynamic_thresholding_ratio=0.995,  # default value
                interval_min=-1,  # default value
                interval_max=1  # default value
            ).to(self.device)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=self.learning_rate)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.n_timesteps)

        # The batch size does not change between iterations.
        batch_size = min(128, X.shape[0])

        try:
            for epoch in range(self.n_timesteps):
                self.optimizer.zero_grad()

                # Sample a batch; .clone() gives the step its own copy of the data.
                batch_indices = torch.randint(0, X.shape[0], (batch_size,))
                X_batch = X_tensor[batch_indices].clone()
                y_batch = y_tensor[batch_indices].clone()

                t = torch.randint(0, self.n_timesteps, (batch_size,), device=self.device)

                # Randomly decide whether to use conditional or unconditional training
                use_uncond = torch.rand(batch_size, device=self.device) < self.p_uncond
                c = torch.where(use_uncond, torch.full_like(y_batch, -1), y_batch)  # -1 represents unconditional

                # Add noise
                X_noisy = self.add_noise(X_batch, t)

                # Forward process.  It may hand back a (continuous, discrete)
                # pair; the model and the loss expect one tensor, so re-join
                # the parts along the feature axis.
                X_dilated = self.forward_process(X_noisy, t)
                if isinstance(X_dilated, (tuple, list)):
                    X_dilated = torch.cat(list(X_dilated), dim=-1)

                # Compute scores
                # NOTE(review): these scores are not used by the loss below, so
                # the labels in ``c`` never reach compute_loss — confirm whether
                # compute_loss should receive ``c``.
                (cond_score_cont, cond_score_disc), (uncond_score_cont, uncond_score_disc) = self.compute_score(X_dilated, t, c)

                # Compute loss
                loss = self.model.compute_loss(X_batch, X_dilated, t)

                loss.backward()
                # Gradient sanity check
                for name, param in self.model.named_parameters():
                    if param.grad is not None:
                        if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
                            print(f"Warning: NaN or Inf detected in gradients of {name}")

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                self.optimizer.step()
                self.scheduler.step()

                if (epoch + 1) % 100 == 0:
                    print(f"Epoch {epoch+1}/{self.n_timesteps}, Loss: {loss.item():.4f}")

        except RuntimeError as e:
            print(f"Error occurred: {e}")
            print(f"Error type: {type(e).__name__}")
            print(f"Error args: {e.args}")
            raise

        return self.model
    def initialize_noise_schedule(self):
        self.noise_schedule = torch.linspace(1, 0.01, self.n_timesteps).to(self.device)

    def score_matching_loss_continuous(self, X_cont, X_cont_dilated, score_pred_cont, t):

        lambda_t = self.model.interpolate_alphas_cumprod(t).unsqueeze(1).unsqueeze(1)

        # Reshape tensors to combine batch and sequence dimensions
        batch_size, seq_len, num_features = X_cont.shape
        X_cont_flat = X_cont.reshape(-1, num_features)
        X_cont_dilated_flat = X_cont_dilated.reshape(-1, num_features)
        score_pred_cont_flat = score_pred_cont.reshape(-1, num_features)

        # Expand lambda_t to match the flattened shape
        lambda_t_expanded = lambda_t.expand(batch_size, seq_len, 1).reshape(-1, 1)

        # Compute the true score
        true_score = (X_cont_flat - torch.sqrt(lambda_t_expanded) * X_cont_dilated_flat) / (1 - lambda_t_expanded)

        # Compute the loss
        loss = F.mse_loss(score_pred_cont_flat, true_score)
        print(f"X_cont shape: {X_cont.shape}, X_cont_dilated shape: {X_cont_dilated.shape}")
        print(f"score_pred_cont shape: {score_pred_cont.shape}, t shape: {t.shape}")
        return loss

    def score_matching_loss_discrete(self, X_disc, X_disc_dilated, score_pred_disc, t):
        lambda_t = self.model.interpolate_alphas_cumprod(t).unsqueeze(1).unsqueeze(1)

        # Reshape tensors to combine batch and sequence dimensions
        batch_size, seq_len, num_features = X_disc.shape
        X_disc_flat = X_disc.reshape(-1, num_features)
        X_disc_dilated_flat = X_disc_dilated.reshape(-1, num_features)
        score_pred_disc_flat = score_pred_disc.reshape(-1, num_features)

        # Expand lambda_t to match the flattened shape
        lambda_t_expanded = lambda_t.expand(batch_size, seq_len, 1).reshape(-1, 1)

        # Compute the true score
        true_score = (X_disc_flat - X_disc_dilated_flat) / (1 - lambda_t_expanded)

        # Compute the loss using binary cross-entropy for each feature independently
        loss = F.binary_cross_entropy_with_logits(score_pred_disc_flat, true_score, reduction='none')

        # Sum over features and average over batch and sequence
        loss = loss.sum(dim=-1).mean()

        return loss

    def generate(self, n_samples, class_labels=None, guidance_strength=1):
        X = torch.randn(n_samples, self.seq_length, self.input_dim, device=self.device)

        if class_labels is None:
            class_labels = torch.randint(0, self.model.num_classes, (n_samples,), device=self.device)
        else:
            class_labels = torch.tensor(class_labels, device=self.device)

        for t in reversed(range(self.n_timesteps)):
            t_tensor = torch.full((n_samples,), t, device=self.device)

            (cond_score_cont, cond_score_disc), (uncond_score_cont, uncond_score_disc) = self.compute_score(X, t_tensor, class_labels)

            # Apply classifier-free guidance
            guided_score_cont = uncond_score_cont + guidance_strength * (cond_score_cont - uncond_score_cont)
            guided_score_disc = uncond_score_disc + guidance_strength * (cond_score_disc - uncond_score_disc)

            # Combine continuous and discrete scores
            guided_score = torch.cat([guided_score_cont, guided_score_disc], dim=-1)

            X = self.reverse_process(X, t_tensor, guided_score)

        X_np = X[:, -1, :].detach().cpu().numpy()

        # Split back into continuous and discrete parts
        X_cont_np = X_np[:, :self.num_numeric_features]
        X_disc_np = X_np[:, self.num_numeric_features:]

        # Inverse transform numerical features
        X_cont_inv = self.numeric_scaler.inverse_transform(X_cont_np)

        X_disc_inv = []
        start_idx = 0
        for cat_feature, encoder in self.categorical_encoders.items():
            end_idx = start_idx + len(encoder.categories_[0])

            # Créer un vecteur one-hot à partir des indices
            X_disc_cat = np.zeros((n_samples, len(encoder.categories_[0])))
            cat_indices = np.argmax(X_disc_np[:, start_idx:end_idx], axis=1)
            X_disc_cat[np.arange(n_samples), cat_indices] = 1

            # Inverse transform
            X_disc_inv.append(encoder.inverse_transform(X_disc_cat))
            start_idx = end_idx

        # Combine continuous and discrete data
        X_generated = np.column_stack([X_cont_inv] + X_disc_inv)

        # Convert to DataFrame
        columns = list(self.numeric_features) + list(self.categorical_features)
        X_generated_df = pd.DataFrame(X_generated, columns=columns)

        return X_generated_df
    """
    def generate(self, n_samples, class_labels=None, guidance_strength=0.0):
        X = torch.randn(n_samples, self.seq_length, self.input_dim, device=self.device)

        if class_labels is None:
            class_labels = torch.randint(0, self.model.num_classes, (n_samples,), device=self.device)
        else:
            class_labels = torch.tensor(class_labels, device=self.device)

        for t in reversed(range(self.n_timesteps)):
            t_tensor = torch.full((n_samples,), t, device=self.device)

            (cond_score_cont, cond_score_disc), (uncond_score_cont, uncond_score_disc) = self.compute_score(X, t_tensor, class_labels)

            # Apply classifier-free guidance
            guided_score_cont = uncond_score_cont + guidance_strength * (cond_score_cont - uncond_score_cont)
            guided_score_disc = uncond_score_disc + guidance_strength * (cond_score_disc - uncond_score_disc)

            # Combine continuous and discrete scores
            guided_score = torch.cat([guided_score_cont, guided_score_disc], dim=-1)

            X = self.reverse_process(X, t_tensor, guided_score)

        X_np = X[:, -1, :].detach().cpu().numpy()

        # Split back into continuous and discrete parts
        X_cont_np = X_np[:, :self.num_numeric_features]
        X_disc_np = X_np[:, self.num_numeric_features:]

        # Inverse transform numerical features
        X_cont_inv = self.numeric_scaler.inverse_transform(X_cont_np)

        # Convert softmax probabilities to categories for discrete features
        X_disc_inv = []
        start_idx = 0
        for cat_feature, encoder in self.categorical_encoders.items():
            end_idx = start_idx + len(encoder.categories_[0])
            X_disc_inv.append(encoder.inverse_transform(X_disc_np[:, start_idx:end_idx]))
            start_idx = end_idx

        # Combine continuous and discrete data
        X_generated = np.column_stack([X_cont_inv] + X_disc_inv)

        # Convert to DataFrame
        columns = list(self.numeric_features) + list(self.categorical_features)
        """

        #X_generated_df = pd.DataFrame(X_generated, columns=columns)

        #return X_generated_df"""

        #X_generated_df = pd.DataFrame(X_generated, columns=columns)

        #return X_generated_df"""

    def impute(self, X):
        """Fill the missing cells of ``X`` by running the reverse diffusion.

        Missing cells start from Gaussian noise and are refined with one
        unconditional ``reverse_process`` step per timestep; observed cells
        are reset to their original (preprocessed) values after every step.

        NOTE(review): the result is reshaped with ``squeeze(0)`` and indexed
        by ``X.index``, which lines up only when ``create_sequences`` yields
        a single window (``len(X) <= self.seq_length``) — confirm the
        intended usage for longer frames.

        Args:
            X: DataFrame containing ``self.numeric_features`` and
                ``self.categorical_features``; NaN marks a missing cell.

        Returns:
            A copy of ``X`` in which only the originally missing cells have
            been replaced.
        """
        X_imputed = X.copy()

        print(f"Original X shape: {X.shape}")

        # Temporary placeholders so the scaler / encoders accept the rows.
        X_num = X_imputed[self.numeric_features].fillna(0)
        X_cat = X_imputed[self.categorical_features].fillna('missing')

        # Transform numeric features
        X_num_processed = self.numeric_scaler.transform(X_num)

        # Transform categorical features (one block per feature)
        cat_blocks = [
            self.categorical_encoders[feature].transform(X_cat[[feature]])
            for feature in self.categorical_features
        ]

        # Concatenating the list directly also works without categorical columns.
        X_processed = np.column_stack([X_num_processed] + cat_blocks)

        print(f"X_processed shape: {X_processed.shape}")

        # Missing-value mask in the *processed* layout: numeric columns
        # first, then each categorical flag repeated across its one-hot block.
        mask_blocks = [X[self.numeric_features].isna().to_numpy()]
        for feature, block in zip(self.categorical_features, cat_blocks):
            feature_missing = X[feature].isna().to_numpy()[:, None]
            mask_blocks.append(np.repeat(feature_missing, block.shape[1], axis=1))
        mask_processed = np.column_stack(mask_blocks)

        # Window data and mask identically so they stay aligned cell by cell.
        window = min(self.seq_length, X_processed.shape[0])
        X_seq = self.create_sequences(X_processed, window)
        mask_seq = self.create_sequences(mask_processed, window).astype(bool)

        print(f"X_seq shape: {X_seq.shape}")

        X_known = torch.from_numpy(X_seq).float().to(self.device)
        expanded_mask = torch.from_numpy(mask_seq).to(self.device)
        print(f"Expanded mask shape: {expanded_mask.shape}")

        # Initialize the missing cells with random noise
        X_imputed_tensor = torch.where(expanded_mask, torch.randn_like(X_known), X_known)
        print(f"X_imputed_tensor shape: {X_imputed_tensor.shape}")

        with torch.no_grad():
            for t in reversed(range(self.n_timesteps)):
                t_tensor = torch.full((X_imputed_tensor.shape[0],), t, device=self.device)

                # reverse_process needs a score; use the unconditional one.
                (score_cont, score_disc), _ = self.compute_score(X_imputed_tensor, t_tensor)
                score = torch.cat([score_cont, score_disc], dim=-1)
                X_imputed_tensor = self.reverse_process(X_imputed_tensor, t_tensor, score)

                # Keep original non-missing values
                X_imputed_tensor = torch.where(expanded_mask, X_imputed_tensor, X_known)

        # Convert back to numpy and original scale
        X_imputed_np = X_imputed_tensor.squeeze(0).cpu().numpy()  # Remove the batch dimension
        # custom_inverse_transform is defined outside this section; presumably
        # it maps the processed matrix back to the original columns — TODO confirm.
        X_imputed_original = self.custom_inverse_transform(X_imputed_np)

        # Create a DataFrame with imputed values
        X_imputed_df = pd.DataFrame(X_imputed_original, columns=X.columns, index=X.index)

        # Update only the missing values in the original DataFrame
        for col in X.columns:
            mask = X[col].isnull()
            X_imputed.loc[mask, col] = X_imputed_df.loc[mask, col]

        return X_imputed

    def calculate_advanced_metrics(self, real_data, generated_data):
        """Compute per-column similarity metrics between real and generated data.

        For every numeric column of ``real_data`` a two-sample
        Kolmogorov-Smirnov test is run. For every ``object``/``category``
        column a chi-square test on the 2 x k table of category counts and the
        total variation distance between the category frequencies are
        computed. For every column the Jensen-Shannon distance between the
        relative frequencies of the exact values is computed (for continuous
        columns with few repeated values this compares exact-value
        frequencies, not binned densities).

        Args:
            real_data: DataFrame of real samples.
            generated_data: DataFrame of synthetic samples containing the same
                column names as ``real_data``.

        Returns:
            dict mapping ``'<metric>_<column>'`` to a float, with metric names
            ``KS_statistic``, ``KS_p_value``, ``Chi2_statistic``,
            ``Chi2_p_value``, ``Categorical_imbalance`` and ``JS_divergence``.
            A metric is omitted for a column when it has no non-null data to
            compare.

        Raises:
            KeyError: if ``generated_data`` lacks a column of ``real_data``.
        """
        def aligned_counts(col, normalize):
            # Put both value_counts on the same, deterministically ordered
            # category axis so element-wise arithmetic compares like with like.
            real_counts = real_data[col].value_counts(normalize=normalize)
            gen_counts = generated_data[col].value_counts(normalize=normalize)
            categories = sorted(set(real_counts.index) | set(gen_counts.index), key=str)
            return (
                real_counts.reindex(categories, fill_value=0),
                gen_counts.reindex(categories, fill_value=0),
            )

        metrics = {}
        categorical_columns = real_data.select_dtypes(include=['object', 'category']).columns

        # KS test for continuous features
        for col in real_data.select_dtypes(include='number').columns:
            real_values = real_data[col].dropna()
            gen_values = pd.to_numeric(generated_data[col], errors='coerce').dropna()
            if real_values.empty or gen_values.empty:
                continue  # ks_2samp rejects empty samples
            ks_statistic, p_value = ks_2samp(real_values, gen_values)
            metrics[f'KS_statistic_{col}'] = ks_statistic
            metrics[f'KS_p_value_{col}'] = p_value

        for col in categorical_columns:
            # Chi-square test on raw counts
            real_counts, gen_counts = aligned_counts(col, normalize=False)
            if len(real_counts) and real_counts.sum() > 0 and gen_counts.sum() > 0:
                table = np.array([real_counts.to_numpy(), gen_counts.to_numpy()])
                chi2, p_value, _, _ = chi2_contingency(table)
                metrics[f'Chi2_statistic_{col}'] = chi2
                metrics[f'Chi2_p_value_{col}'] = p_value

            # Total variation distance. The frequencies must be aligned first:
            # subtracting unaligned Series yields NaN for categories present
            # on one side only, and those terms would vanish from the sum.
            real_prob, gen_prob = aligned_counts(col, normalize=True)
            imbalance = (real_prob - gen_prob).abs().sum() / 2
            metrics[f'Categorical_imbalance_{col}'] = imbalance

        # Jensen-Shannon distance for all features
        for col in real_data.columns:
            real_prob, gen_prob = aligned_counts(col, normalize=True)
            if real_prob.sum() == 0 or gen_prob.sum() == 0:
                continue  # one side has no non-null values
            js_div = jensenshannon(real_prob.to_numpy(), gen_prob.to_numpy())
            metrics[f'JS_divergence_{col}'] = js_div

        return metrics

    def safe_column_name(self, column_name):
        """Return ``column_name`` as text safe to use in titles and file names.

        Every character that is not a word character, whitespace or ``-`` is
        replaced by ``_``. The argument is converted with ``str()`` first so
        non-string labels (e.g. integer column names) are accepted.
        """
        return re.sub(r'[^\w\s-]', '_', str(column_name))

    def generate_comparison_plots(self, real_data, generated_data):
        """Save PNG figures comparing real and synthetic data.

        Files written to the working directory:

        * ``distribution_comparison_<column>.png`` for every column of
          ``real_data``: overlaid KDE curves for numeric columns, side-by-side
          relative-frequency bars otherwise.
        * ``correlation_heatmap_comparison.png``: correlation matrices of the
          numeric columns, real next to synthetic.
        * ``pairplot_comparison.png``: one pair plot of up to five numeric
          columns with real and synthetic rows distinguished by colour.

        The correlation and pair plots are skipped when ``real_data`` has no
        numeric column.

        Args:
            real_data: DataFrame of real samples.
            generated_data: DataFrame of synthetic samples with the same
                column names.

        Returns:
            None.
        """
        def ordered_union(first, second):
            # Stable category order; fall back to string order for mixed types.
            merged = set(first) | set(second)
            try:
                return sorted(merged)
            except TypeError:
                return sorted(merged, key=str)

        def numeric_view(frame, columns):
            # Generated frames may carry numbers in object columns.
            return frame[columns].apply(pd.to_numeric, errors='coerce')

        numeric_cols = list(real_data.select_dtypes(include='number').columns)
        numeric_set = set(numeric_cols)

        for col in real_data.columns:
            safe_col = self.safe_column_name(str(col))
            plt.figure(figsize=(10, 6))
            if col in numeric_set:
                # `fill` is the current name of the deprecated `shade` option.
                sns.kdeplot(real_data[col], label='Real', fill=True)
                sns.kdeplot(pd.to_numeric(generated_data[col], errors='coerce'),
                            label='Synthetic', fill=True)
                plt.title(f'Distribution Comparison: {safe_col}')
            else:
                real_counts = real_data[col].value_counts(normalize=True)
                gen_counts = generated_data[col].value_counts(normalize=True)
                all_categories = ordered_union(real_counts.index, gen_counts.index)
                real_counts = real_counts.reindex(all_categories, fill_value=0)
                gen_counts = gen_counts.reindex(all_categories, fill_value=0)
                width = 0.35
                positions = np.arange(len(all_categories))
                plt.bar(positions - width / 2, real_counts, width, label='Real')
                plt.bar(positions + width / 2, gen_counts, width, label='Synthetic')
                plt.title(f'Categorical Distribution Comparison: {safe_col}')
                plt.xticks(positions,
                           [self.safe_column_name(str(cat)) for cat in all_categories],
                           rotation=45, ha='right')
            plt.legend()
            plt.tight_layout()
            plt.savefig(f'distribution_comparison_{safe_col}.png')
            plt.close()

        if not numeric_cols:
            return

        # Correlation heatmap comparison (numeric columns only: DataFrame.corr
        # cannot handle string columns).
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
        sns.heatmap(real_data[numeric_cols].corr(), ax=ax1, cmap='coolwarm', vmin=-1, vmax=1)
        ax1.set_title('Real Data Correlation')
        sns.heatmap(numeric_view(generated_data, numeric_cols).corr(), ax=ax2,
                    cmap='coolwarm', vmin=-1, vmax=1)
        ax2.set_title('Synthetic Data Correlation')
        plt.tight_layout()
        plt.savefig('correlation_heatmap_comparison.png')
        plt.close(fig)

        # Pairplot comparison (for a subset of features if there are many).
        # pairplot is a figure-level function: it creates its own figure and
        # cannot draw into an existing Axes, so both datasets share one grid
        # and are separated by hue instead.
        selected_features = numeric_cols[:5]
        hue_name = 'Dataset'
        while hue_name in selected_features:
            hue_name += '_'
        combined = pd.concat(
            [
                real_data[selected_features].assign(**{hue_name: 'Real'}),
                numeric_view(generated_data, selected_features).assign(**{hue_name: 'Synthetic'}),
            ],
            ignore_index=True,
        )
        grid = sns.pairplot(combined, vars=selected_features, hue=hue_name)
        plt.suptitle('Real vs Synthetic Data Pairplot', y=1.02)
        grid.savefig('pairplot_comparison.png')
        plt.close()



    def evaluate(self, X_real, X_generated):
        """Print distribution-distance metrics and save comparison figures.

        Printed metrics:

        * Wasserstein distance between the flattened encoded real and
          generated matrices (numeric scaler plus categorical encoders).
        * KL divergence per numeric feature, from 30-bin histograms computed
          on bin edges shared by both samples.
        * Jensen-Shannon distance per categorical feature.

        Figures written to the working directory (each skipped when the
        corresponding feature list is empty): ``numeric_distributions.png``,
        ``categorical_distributions.png`` and ``correlation_heatmaps.png``.

        Args:
            X_real: DataFrame of real samples containing every column in
                ``self.numeric_features`` and ``self.categorical_features``.
            X_generated: DataFrame of synthetic samples with the same columns.

        Returns:
            None.
        """
        numeric_features = list(self.numeric_features)
        categorical_features = list(self.categorical_features)

        def encode(frame):
            # Scaled numeric block first, then one encoded block per categorical feature.
            blocks = [self.numeric_scaler.transform(frame[self.numeric_features])]
            blocks.extend(
                self.categorical_encoders[feature].transform(frame[[feature]])
                for feature in categorical_features
            )
            return np.column_stack(blocks)

        def as_float(series):
            # Generated frames may carry numbers in object columns.
            return pd.to_numeric(series, errors='coerce').dropna().to_numpy(dtype=float)

        def aligned_distributions(feature):
            # Relative frequencies of both frames on one shared category axis.
            real_dist = X_real[feature].value_counts(normalize=True)
            gen_dist = X_generated[feature].value_counts(normalize=True)
            merged = set(real_dist.index) | set(gen_dist.index)
            try:
                categories = sorted(merged)
            except TypeError:
                categories = sorted(merged, key=str)
            return (
                categories,
                real_dist.reindex(categories, fill_value=0),
                gen_dist.reindex(categories, fill_value=0),
            )

        X_real_flat = encode(X_real).flatten()
        X_generated_flat = encode(X_generated).flatten()

        w_distance = wasserstein_distance(X_real_flat, X_generated_flat)
        print(f"Wasserstein distance: {w_distance}")

        # KL divergence per numeric feature. Both histograms must use the same
        # bin edges, otherwise bin k of one sample and bin k of the other cover
        # different value ranges and the divergence is meaningless.
        for col in numeric_features:
            real_values = as_float(X_real[col])
            gen_values = as_float(X_generated[col])
            if real_values.size == 0 or gen_values.size == 0:
                print(f"KL divergence for {col}: skipped (no numeric data)")
                continue
            edges = np.histogram_bin_edges(np.concatenate([real_values, gen_values]), bins=30)
            real_hist, _ = np.histogram(real_values, bins=edges, density=True)
            gen_hist, _ = np.histogram(gen_values, bins=edges, density=True)
            # Small constant avoids division by zero in empty bins.
            kl_div = entropy(real_hist + 1e-10, gen_hist + 1e-10)
            print(f"KL divergence for {col}: {kl_div}")

        for col in categorical_features:
            _, real_dist, gen_dist = aligned_distributions(col)
            js_div = jensenshannon(real_dist, gen_dist)
            print(f"Jensen-Shannon divergence for {col}: {js_div}")

        # Plots are produced once, after all metrics, not once per categorical feature.

        # Numeric feature distributions
        if numeric_features:
            fig, axs = plt.subplots(len(numeric_features), 1,
                                    figsize=(10, 5 * len(numeric_features)), squeeze=False)
            for ax, col in zip(axs[:, 0], numeric_features):
                clean_col = self.safe_column_name(str(col))
                sns.histplot(pd.to_numeric(X_real[col], errors='coerce'), kde=True,
                             color='blue', alpha=0.5, ax=ax, label='Réel')
                sns.histplot(pd.to_numeric(X_generated[col], errors='coerce'), kde=True,
                             color='red', alpha=0.5, ax=ax, label='Synthétique')
                ax.set_title(f'Distribution de {clean_col}')
                ax.legend()
            plt.tight_layout()
            plt.savefig('numeric_distributions.png')
            plt.close(fig)

        # Categorical feature distributions
        if categorical_features:
            fig, axs = plt.subplots(len(categorical_features), 1,
                                    figsize=(10, 5 * len(categorical_features)), squeeze=False)
            width = 0.35
            for ax, col in zip(axs[:, 0], categorical_features):
                clean_col = self.safe_column_name(str(col))
                all_categories, real_counts, gen_counts = aligned_distributions(col)
                positions = np.arange(len(all_categories))
                ax.bar(positions - width / 2, real_counts, width, label='Réel', alpha=0.5)
                ax.bar(positions + width / 2, gen_counts, width, label='Synthétique', alpha=0.5)
                ax.set_title(f'Distribution de {clean_col}')
                ax.set_xticks(positions)
                ax.set_xticklabels([self.safe_column_name(str(cat)) for cat in all_categories],
                                   rotation=45, ha='right')
                ax.legend()
            plt.tight_layout()
            plt.savefig('categorical_distributions.png')
            plt.close(fig)

        # Correlation heatmaps
        if numeric_features:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
            real_corr = X_real[numeric_features].apply(pd.to_numeric, errors='coerce').corr()
            gen_corr = X_generated[numeric_features].apply(pd.to_numeric, errors='coerce').corr()
            sns.heatmap(real_corr, ax=ax1, cmap='coolwarm', annot=True)
            ax1.set_title('Corrélations - Données Réelles')
            sns.heatmap(gen_corr, ax=ax2, cmap='coolwarm', annot=True)
            ax2.set_title('Corrélations - Données Synthétiques')
            plt.tight_layout()
            plt.savefig('correlation_heatmaps.png')
            plt.close(fig)

    """def calculate_advanced_metrics(self, X_real, X_generated):
        metrics = {}

        # KS test for continuous features
        for col in X_real.select_dtypes(include=['float64', 'int64']).columns:
            try:
                ks_statistic, p_value = ks_2samp(X_real[col], X_generated[col])
                metrics[f'KS_statistic_{col}'] = ks_statistic
                metrics[f'KS_p_value_{col}'] = p_value
            except Exception as e:
                print(f"Error calculating KS test for column {col}: {str(e)}")

        # Chi-square test for categorical features
        for col in X_real.select_dtypes(include=['object', 'category']).columns:
            try:
                real_counts = X_real[col].value_counts()
                gen_counts = X_generated[col].value_counts()
                all_categories = set(real_counts.index) | set(gen_counts.index)
                real_counts = real_counts.reindex(all_categories, fill_value=0)
                gen_counts = gen_counts.reindex(all_categories, fill_value=0)
                chi2, p_value, _, _ = chi2_contingency([real_counts, gen_counts])
                metrics[f'Chi2_statistic_{col}'] = chi2
                metrics[f'Chi2_p_value_{col}'] = p_value
            except Exception as e:
                print(f"Error calculating Chi-square test for column {col}: {str(e)}")

        # Categorical imbalance metric and Jensen-Shannon Divergence
        for col in X_real.columns:
            try:
                real_prob = X_real[col].value_counts(normalize=True)
                gen_prob = X_generated[col].value_counts(normalize=True)
                all_categories = set(real_prob.index) | set(gen_prob.index)
                real_prob = real_prob.reindex(all_categories, fill_value=0)
                gen_prob = gen_prob.reindex(all_categories, fill_value=0)

                imbalance = np.sum(np.abs(real_prob - gen_prob)) / 2
                metrics[f'Categorical_imbalance_{col}'] = imbalance

                js_div = jensenshannon(real_prob, gen_prob)
                metrics[f'JS_divergence_{col}'] = js_div
            except Exception as e:
                print(f"Error calculating imbalance and JS divergence for column {col}: {str(e)}")

        return metrics

    def utility_evaluation(self, X_real, y_real, X_generated):
        from sklearn.model_selection import train_test_split
        from sklearn.linear_model import LogisticRegression
        from sklearn.metrics import accuracy_score

        # Split real data
        X_real_train, X_real_test, y_real_train, y_real_test = train_test_split(X_real, y_real, test_size=0.2, random_state=42)

        # Generate synthetic target for X_generated
        y_generated = self.generate_synthetic_target(X_generated)

        # Split generated data
        X_gen_train, X_gen_test, y_gen_train, y_gen_test = train_test_split(X_generated, y_generated, test_size=0.2, random_state=42)

        # Train and evaluate on real data
        model_real = LogisticRegression()
        model_real.fit(X_real_train, y_real_train)
        real_accuracy = accuracy_score(y_real_test, model_real.predict(X_real_test))

        # Train and evaluate on generated data
        model_gen = LogisticRegression()
        model_gen.fit(X_gen_train, y_gen_train)
        gen_accuracy = accuracy_score(y_gen_test, model_gen.predict(X_gen_test))

        return real_accuracy, gen_accuracy

    def generate_synthetic_target(self, X_generated):
        # Cette méthode doit être implémentée pour générer des cibles synthétiques
        # correspondant à X_generated. Vous pouvez utiliser votre modèle de diffusion
        # ou une autre méthode pour générer ces cibles.
        # Pour l'instant, nous allons simplement générer des valeurs aléatoires
        return np.random.randint(0, 2, size=len(X_generated))"""
    """
    def add_noise(self, x, t):
        noise = torch.randn_like(x)
        sqrt_alphas_cumprod_t = torch.sqrt(self.alphas_cumprod[t]).unsqueeze(-1).unsqueeze(-1)
        sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1 - self.alphas_cumprod[t]).unsqueeze(-1).unsqueeze(-1)

        return sqrt_alphas_cumprod_t * x + sqrt_one_minus_alphas_cumprod_t * noise"""
    def add_noise(self, x, t):
        """Return a noised copy of ``x`` at timestep(s) ``t``.

        Computes ``sqrt(a) * x + sqrt(1 - a) * eps`` where
        ``a = self.alphas_cumprod[t]`` and ``eps`` is standard Gaussian noise
        of the same shape as ``x``.

        Args:
            x: Tensor whose leading axis indexes the batch; any rank.
            t: Index (or tensor of indices, one per batch entry) into
                ``self.alphas_cumprod``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        noise = torch.randn_like(x)
        alpha_bar_t = self.alphas_cumprod[t]
        # One coefficient per batch entry, broadcast over every remaining axis
        # of x. For 3-D input this is (-1, 1, 1); deriving it from the rank
        # keeps 2-D or 4-D input from being broadcast into the wrong shape.
        coeff_shape = (-1,) + (1,) * (x.dim() - 1)
        sqrt_alphas_cumprod_t = torch.sqrt(alpha_bar_t).reshape(coeff_shape)
        sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1 - alpha_bar_t).reshape(coeff_shape)
        return sqrt_alphas_cumprod_t * x + sqrt_one_minus_alphas_cumprod_t * noise

    def custom_inverse_transform(self, X):
        """Decode a 2-D encoded array back to original feature values.

        The first ``len(self.numeric_features)`` columns go through
        ``self.numeric_scaler.inverse_transform``. The remaining columns are
        consumed left to right, one slice per entry of
        ``self.categorical_encoders`` (in dict order); each slice is as wide as
        that encoder's first ``categories_`` entry and is decoded with the
        encoder's ``inverse_transform``.

        Args:
            X: 2-D array laid out as numeric columns followed by the encoded
                categorical blocks.

        Returns:
            Array produced by column-stacking the decoded numeric block and
            the decoded categorical blocks.
        """
        split_at = len(self.numeric_features)
        encoded_categories = X[:, split_at:]

        # Numeric block first, then each categorical block in encoder order.
        decoded_blocks = [self.numeric_scaler.inverse_transform(X[:, :split_at])]

        offset = 0
        for encoder in self.categorical_encoders.values():
            width = len(encoder.categories_[0])
            block = encoded_categories[:, offset:offset + width]
            decoded_blocks.append(encoder.inverse_transform(block))
            offset += width

        return np.column_stack(decoded_blocks)


    @staticmethod
    def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
        def sample_gumbel(shape, eps=1e-10, device=None):
            U = torch.rand(shape, device=device)
            return -torch.log(-torch.log(U + eps) + eps)

        device = logits.device
        gumbels = sample_gumbel(logits.shape, eps=eps, device=device)
        gumbels = (logits + gumbels) / tau
        y_soft = gumbels.softmax(dim)

        if hard:
            index = y_soft.max(dim, keepdim=True)[1]
            y_hard = torch.zeros_like(logits, device=device).scatter_(dim, index, 1.0)
            ret = y_hard - y_soft.detach() + y_soft
        else:
            ret = y_soft
        return ret



In [None]:
# Main execution
# Trains the diffusion model on the CSV at `data_path`, generates synthetic
# rows, evaluates them against the held-out split and writes res.csv plus
# several PNG figures. The names bound here (model, synthetic_data, ...) are
# module-level and are read again by the plotting cell further down.
if __name__ == "__main__":

    # Anomaly detection is a debugging aid for autograd.
    torch.autograd.set_detect_anomaly(True)
    # NOTE(review): the two profiler objects below are constructed but never
    # entered with `with`, so they presumably have no effect here - confirm
    # and either wrap the training call or drop them.
    torch.autograd.profiler.profile(enabled=True)
    torch.autograd.profiler.emit_nvtx(enabled=True)

    import matplotlib.pyplot as plt
    import seaborn as sns

    data_path = "/notebooks/data.csv"
    seq_length = 30  # Define the sequence length
    # NOTE(review): `dt` is not used again in this cell; preprocess_data is
    # given the path and presumably reads the file itself - verify before removing.
    dt = pd.read_csv(data_path)
    model = FinDiffusionPlus(n_timesteps=200, n_noise=200, learning_rate=0.001, seq_length=seq_length)
    X, y = model.preprocess_data(data_path)
    # Split the data
    # Positional 80/20 split without shuffling.
    train_size = int(0.8 * len(X))
    X_train, X_test = X[:train_size], X[train_size:]
    y_train, y_test = y[:train_size], y[train_size:]

    # Train the model
    model.train(X_train, y_train)

    # Generate synthetic data
    # NOTE(review): the visible tail of generate() has its `return` inside a
    # string literal / comment; confirm it actually returns a DataFrame,
    # otherwise .head() below fails.
    num_synthetic_samples = 1000
    synthetic_data = model.generate(num_synthetic_samples)

    print(f"Generated {num_synthetic_samples} synthetic samples.")
    print(synthetic_data.head())

    # Evaluate the generated data
    # X_test is indexed as a 3-D array; only the last position of each
    # sequence is decoded back to original feature values.
    X_test_df = pd.DataFrame(model.custom_inverse_transform(X_test[:, -1, :]),
                             columns=list(model.numeric_features) + list(model.categorical_features))
    model.evaluate(X_test_df, synthetic_data)

    # NOTE(review): these two helpers are not defined in the visible part of
    # the file (the later cell defines plot_numeric_comparisons /
    # plot_categorical_comparisons with a different signature) - verify they exist.
    plot_numeric_distributions(model, synthetic_data)
    plot_categorical_distributions(model, synthetic_data)


    # Export synthetic data to CSV
    synthetic_data.to_csv("res.csv", index=False)
    print("\nSynthetic data exported to res.csv")
    # The string literal below holds disabled plotting code; as a bare
    # expression statement it is evaluated and discarded.
    """
    # Generate plots
    # 1. Distribution of numerical features
    fig, axes = plt.subplots(2, 4, figsize=(20, 10))
    for i, feature in enumerate(model.numeric_features):
        sns.histplot(synthetic_data[feature], ax=axes[i//4, i%4], kde=True)
        axes[i//4, i%4].set_title(feature)
    plt.tight_layout()
    plt.savefig("numerical_distributions.png")
    plt.close()

    # 2. Countplot of categorical features
    fig, axes = plt.subplots(4, 4, figsize=(20, 20))
    for i, feature in enumerate(model.categorical_features):
        sns.countplot(data=synthetic_data, x=feature, ax=axes[i//4, i%4])
        axes[i//4, i%4].set_title(feature)
        axes[i//4, i%4].tick_params(axis='x', rotation=90)
    plt.tight_layout()
    plt.savefig("categorical_counts.png")
    plt.close()"""

    # 3. Correlation heatmap of numerical features
    plt.figure(figsize=(12, 10))
    sns.heatmap(synthetic_data[model.numeric_features].corr(), annot=True, cmap='coolwarm')
    plt.title("Correlation Heatmap of Numerical Features")
    plt.savefig("correlation_heatmap_beta.png")
    plt.close()

    # 4. Pairplot of numerical features
    # pairplot creates its own figure, which is then the current figure for savefig.
    sns.pairplot(synthetic_data[model.numeric_features])
    plt.savefig("pairplot.png")
    plt.close()

    #Scatter plot
    # These names are only referenced by the disabled scatter-plot code below.

    continuous_feature1 = 'age'
    continuous_feature2 = 'Mutual_Funds'
    categorical_feature = 'gender'

    """
    # Generate synthetic data with different guidance strengths
    guidance_strengths = [0.0, 1.0, 2.0, 3.0]
    for strength in guidance_strengths:
        synthetic_data = model.generate(num_synthetic_samples, guidance_strength=strength)

        # Create scatter plot
        model.plot_scatter(X_test_df, synthetic_data, continuous_feature1, continuous_feature2, categorical_feature, strength)

        print(f"Generated scatter plot for guidance strength {strength}")

    print("\nPlots generated and saved as PNG files.")"""
"""
    # Imputation example
    X_with_missing = X_test_df.copy()
    X_with_missing.iloc[np.random.rand(*X_with_missing.shape) < 0.2] == np.nan
    #X_imputed = model.impute(X_with_missing)
    print("\nImputation example:")
    print(X_imputed.head())"""

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import re

# Raw CSV loaded as the "real" side of the comparison plots produced at the end of this cell.
real_data = pd.read_csv("/notebooks/data.csv")
def clean_column_name(name):
    """Return ``str(name)`` with unsafe characters replaced by underscores.

    Word characters, whitespace and ``-`` are kept; every other character
    becomes ``_``.
    """
    unsafe_chars = re.compile(r'[^\w\s-]')
    label = str(name)
    return unsafe_chars.sub('_', label)

def plot_numeric_comparisons(model, real_data, synthetic_data,
                             output_path="numeric_comparisons_beta.png"):
    """Save a grid of overlaid histograms, one per numeric feature.

    Each panel shows the real values in blue and the synthetic values in red,
    both with a KDE curve. Panels are laid out in up to three columns and
    unused grid cells are removed.

    Args:
        model: Object exposing ``numeric_features``, the column names to plot.
        real_data: DataFrame of real samples.
        synthetic_data: DataFrame of synthetic samples with the same columns.
        output_path: Destination of the PNG file.

    Returns:
        None. Prints a message and returns early when there is no numeric
        feature.
    """
    features = list(model.numeric_features)
    if not features:
        print("Aucune caractéristique numérique à afficher.")
        return

    num_cols = min(3, len(features))
    num_rows = (len(features) - 1) // num_cols + 1

    # squeeze=False always yields a 2-D array of axes, whatever the grid size.
    fig, axes = plt.subplots(num_rows, num_cols,
                             figsize=(6 * num_cols, 5 * num_rows), squeeze=False)
    panels = list(axes.flat)

    for ax, feature in zip(panels, features):
        clean_feature = clean_column_name(feature)

        sns.histplot(real_data[feature], color="blue", label="Réel", kde=True,
                     ax=ax, alpha=0.5)
        sns.histplot(synthetic_data[feature], color="red", label="Synthétique", kde=True,
                     ax=ax, alpha=0.5)

        ax.set_title(f"Distribution de {clean_feature}")
        ax.legend()

    # Drop the grid cells left over when the feature count is not a multiple of num_cols.
    for ax in panels[len(features):]:
        fig.delaxes(ax)

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close(fig)
    # Report the path that was actually written.
    print(f"Le graphique des comparaisons numériques a été sauvegardé sous '{output_path}'")

def plot_categorical_comparisons(model, real_data, synthetic_data,
                                 output_path="categorical_comparisons_beta.png"):
    """Save a grid of side-by-side bar charts, one per categorical feature.

    Each panel compares the relative frequency of every category in the real
    and the synthetic data. Panels are laid out in up to three columns and
    unused grid cells are removed.

    Args:
        model: Object exposing ``categorical_features``, the column names to plot.
        real_data: DataFrame of real samples.
        synthetic_data: DataFrame of synthetic samples with the same columns.
        output_path: Destination of the PNG file.

    Returns:
        None. Prints a message and returns early when there is no categorical
        feature.
    """
    cat_features = list(model.categorical_features)
    if not cat_features:
        print("Aucune caractéristique catégorique à afficher.")
        return

    num_cols = min(3, len(cat_features))
    num_rows = (len(cat_features) - 1) // num_cols + 1

    # squeeze=False always yields a 2-D array of axes, whatever the grid size.
    fig, axes = plt.subplots(num_rows, num_cols,
                             figsize=(6 * num_cols, 5 * num_rows), squeeze=False)
    panels = list(axes.flat)
    width = 0.35

    for ax, feature in zip(panels, cat_features):
        clean_feature = clean_column_name(feature)

        real_counts = real_data[feature].value_counts(normalize=True)
        synth_counts = synthetic_data[feature].value_counts(normalize=True)

        # Give both series the same index; mixed-type categories cannot be
        # compared directly, so fall back to ordering by their string form.
        merged = set(real_counts.index) | set(synth_counts.index)
        try:
            all_categories = sorted(merged)
        except TypeError:
            all_categories = sorted(merged, key=str)
        real_counts = real_counts.reindex(all_categories, fill_value=0)
        synth_counts = synth_counts.reindex(all_categories, fill_value=0)

        x = np.arange(len(all_categories))

        ax.bar(x - width / 2, real_counts, width, label='Réel', alpha=0.7)
        ax.bar(x + width / 2, synth_counts, width, label='Synthétique', alpha=0.7)

        ax.set_title(f"Distribution de {clean_feature}")
        ax.set_xticks(x)
        ax.set_xticklabels([clean_column_name(cat) for cat in all_categories],
                           rotation=45, ha='right')
        ax.legend()

    # Drop the grid cells left over when the feature count is not a multiple of num_cols.
    for ax in panels[len(cat_features):]:
        fig.delaxes(ax)

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close(fig)
    # Report the path that was actually written.
    print(f"Le graphique des comparaisons catégoriques a été sauvegardé sous '{output_path}'")

# Usage: relies on `model` and `synthetic_data` bound by the main-execution cell.
plot_numeric_comparisons(model, real_data, synthetic_data)
plot_categorical_comparisons(model, real_data, synthetic_data)