In [1]:
import torch
print(torch.cuda.is_available())  
print(torch.cuda.current_device())  
print(torch.cuda.get_device_name(torch.cuda.current_device())) 

True
0
NVIDIA GeForce RTX 4060 Laptop GPU


In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("CUDA is available. Using GPU.")
else:
    device = torch.device("cpu")
    print("CUDA is not available. Using CPU.")

CUDA is available. Using GPU.


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from pytorch_tabnet.pretraining import TabNetPretrainer
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler

import numpy as np


In [10]:
df_train = pd.DataFrame(np.random.rand(700, 42), columns=[f"feature_{i}" for i in range(42)])
df_test = pd.DataFrame(np.random.rand(300, 42), columns=[f"feature_{i}" for i in range(42)])

y_train = (np.random.rand(700) > 0.5).astype(np.float32)
y_test = (np.random.rand(300) > 0.5).astype(np.float32)


scaler = MinMaxScaler()
X_train = scaler.fit_transform(df_train.values)
X_test = scaler.transform(df_test.values)


# y_train = y_train.values.astype(np.float32)
# y_test = y_test.values.astype(np.float32)
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)
num_classes = 2  # Assuming binary classification
y_train = F.one_hot(y_train, num_classes=num_classes)
y_test = F.one_hot(y_test, num_classes=num_classes)
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)


torch.Size([700, 42]) torch.Size([700, 2]) torch.Size([300, 42]) torch.Size([300, 2])


RuntimeError: one_hot is only applicable to index tensor.

In [None]:
from torch.utils.data import TensorDataset, DataLoader
train_dataset = TensorDataset(X_train, y_train_onehot)
test_dataset = TensorDataset(X_test, y_test_onehot)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [None]:
tabnet_params = {
    "n_d": 512,
    "n_a": 512,
    "n_steps": 3,
    "n_shared": 2,
    "n_independent": 2,
    "gamma": 1.3,
    "epsilon": 1e-15,
    "momentum": 0.98,
    "mask_type": "sparsemax",
    "lambda_sparse": 1e-3,
    "device_name": "cuda" if torch.cuda.is_available() else "cpu"
}


unsupervised_model = TabNetPretrainer(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    **tabnet_params
)
 

unsupervised_model.fit(
    X_train,
    eval_set=[X_test],  
    pretraining_ratio=0.8,
    max_epochs=101,
    patience=10,
    batch_size=1024,
    virtual_batch_size=128,
    num_workers=0,
    drop_last=False
)



epoch 0  | loss: 270.00562| val_0_unsup_loss_numpy: 94.52517700195312|  0:00:00s
epoch 1  | loss: 638.96796| val_0_unsup_loss_numpy: 86.1489028930664|  0:00:00s
epoch 2  | loss: 244.51616| val_0_unsup_loss_numpy: 40.36521911621094|  0:00:00s
epoch 3  | loss: 193.12486| val_0_unsup_loss_numpy: 31.518159866333008|  0:00:00s
epoch 4  | loss: 102.24525| val_0_unsup_loss_numpy: 17.251129150390625|  0:00:00s
epoch 5  | loss: 57.67551| val_0_unsup_loss_numpy: 8.71133041381836|  0:00:00s
epoch 6  | loss: 46.38808| val_0_unsup_loss_numpy: 9.282190322875977|  0:00:00s
epoch 7  | loss: 34.91736| val_0_unsup_loss_numpy: 6.242280006408691|  0:00:00s
epoch 8  | loss: 22.04987| val_0_unsup_loss_numpy: 5.376999855041504|  0:00:00s
epoch 9  | loss: 18.68647| val_0_unsup_loss_numpy: 5.352340221405029|  0:00:00s
epoch 10 | loss: 14.43801| val_0_unsup_loss_numpy: 3.8629400730133057|  0:00:00s
epoch 11 | loss: 11.87551| val_0_unsup_loss_numpy: 3.0694000720977783|  0:00:01s
epoch 12 | loss: 10.00301| val_0_



In [None]:
# Truy cập vào mô hình TabNet bên trong
from torchinfo import summary

tabnet_model = unsupervised_model.network.to(device)

summary(tabnet_model, input_size=X_train.shape) 

Layer (type:depth-idx)                                       Output Shape              Param #
TabNetPretraining                                            [700, 42]                 --
├─EmbeddingGenerator: 1-1                                    [700, 42]                 --
├─TabNetEncoder: 1-2                                         [700, 512]                --
│    └─BatchNorm1d: 2-1                                      [700, 42]                 84
│    └─FeatTransformer: 2-2                                  [700, 1024]               4,202,496
│    │    └─GLU_Block: 3-1                                   [700, 1024]               2,191,360
│    └─ModuleList: 2-12                                      --                        (recursive)
│    │    └─FeatTransformer: 3-17                            --                        (recursive)
│    └─FeatTransformer: 2-6                                  --                        (recursive)
│    │    └─GLU_Block: 3-5                            

In [None]:
encoder = tabnet_model.encoder

print("\nEncoder Summary:")
print(encoder)




Encoder Summary:
TabNetEncoder(
  (initial_bn): BatchNorm1d(42, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  (initial_splitter): FeatTransformer(
    (shared): GLU_Block(
      (shared_layers): ModuleList(
        (0): Linear(in_features=42, out_features=2048, bias=False)
        (1): Linear(in_features=1024, out_features=2048, bias=False)
      )
      (glu_layers): ModuleList(
        (0): GLU_Layer(
          (fc): Linear(in_features=42, out_features=2048, bias=False)
          (bn): GBN(
            (bn): BatchNorm1d(2048, eps=1e-05, momentum=0.98, affine=True, track_running_stats=True)
          )
        )
        (1): GLU_Layer(
          (fc): Linear(in_features=1024, out_features=2048, bias=False)
          (bn): GBN(
            (bn): BatchNorm1d(2048, eps=1e-05, momentum=0.98, affine=True, track_running_stats=True)
          )
        )
      )
    )
    (specifics): GLU_Block(
      (glu_layers): ModuleList(
        (0-1): 2 x GLU_Layer(
          (fc

In [None]:
decoder = tabnet_model.decoder

print("\nDecoder Summary:")
print(decoder)


Decoder Summary:
TabNetDecoder(
  (feat_transformers): ModuleList(
    (0-2): 3 x FeatTransformer(
      (shared): GLU_Block(
        (shared_layers): ModuleList(
          (0): Linear(in_features=512, out_features=1024, bias=False)
        )
        (glu_layers): ModuleList(
          (0): GLU_Layer(
            (fc): Linear(in_features=512, out_features=1024, bias=False)
            (bn): GBN(
              (bn): BatchNorm1d(1024, eps=1e-05, momentum=0.98, affine=True, track_running_stats=True)
            )
          )
        )
      )
      (specifics): GLU_Block(
        (glu_layers): ModuleList(
          (0): GLU_Layer(
            (fc): Linear(in_features=512, out_features=1024, bias=False)
            (bn): GBN(
              (bn): BatchNorm1d(1024, eps=1e-05, momentum=0.98, affine=True, track_running_stats=True)
            )
          )
        )
      )
    )
  )
  (reconstruction_layer): Linear(in_features=512, out_features=42, bias=False)
)


In [None]:
sample_input = torch.tensor(X_train[:5]).to(device)  

try:
    result = tabnet_model.encoder(sample_input)
    if isinstance(result, tuple):
        print(f'TabNet encoder trả về {len(result)} giá trị.')
        for i, res in enumerate(result):
            print(f'Giá trị {i + 1} shape: {res.shape}')
    else:
        print('TabNet encoder chỉ trả về một giá trị.')
        print(f'Giá trị shape: {result.shape}')
except Exception as e:
    print(f'Đã xảy ra lỗi: {e}')

TabNet encoder trả về 2 giá trị.
Đã xảy ra lỗi: 'list' object has no attribute 'shape'


  sample_input = torch.tensor(X_train[:5]).to(device)


In [None]:
class Sampling(nn.Module):
    def __init__(self, seed=1337):
        super(Sampling, self).__init__()
        self.seed = seed

    def forward(self, inputs):
        z_mean, z_log_var = inputs
        batch = z_mean.size(0)
        dim = z_mean.size(1)
        # print(batch, dim)
        epsilon = torch.randn(batch, dim, generator=torch.Generator().manual_seed(self.seed)).to(device)
        return z_mean + torch.exp(0.5 * z_log_var) * epsilon

In [None]:
class VAE_Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(VAE_Encoder, self).__init__()
        self.tabnet_encoder = tabnet_model.encoder
        self.mlp = nn.Sequential(
            nn.Linear(512, 256),  
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 96),
            nn.ReLU(),
            nn.Linear(96, 96),
            nn.ReLU(),
            nn.Linear(96, 32),
            nn.ReLU(),
            nn.Linear(32, latent_dim)
        ).to(device)
        self.fc_mean = nn.Linear(latent_dim, latent_dim).to(device)
        self.fc_log_var = nn.Linear(latent_dim, latent_dim).to(device)
        self.sampling = Sampling().to(device)

    def forward(self, x):
        x = x.to(device)
        steps_output, _ = self.tabnet_encoder(x)
        encoded = steps_output[-1]
        # print("Shape of encoded tensor:", encoded.shape)
        encoded = self.mlp(encoded)
        z_mean = self.fc_mean(encoded)
        z_log_var = self.fc_log_var(encoded)
        z = self.sampling((z_mean, z_log_var))
        # print(f'Shape of z: {z.shape} - {z_log_var.shape} -{z_log_var.shape}')
        return z_mean, z_log_var, z


In [None]:
class VAE_Decoder(nn.Module):
    def __init__(self, latent_dim,encoded_dim, output_dim):
        super(VAE_Decoder, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(latent_dim, 32),   
            nn.ReLU(),
            nn.Linear(32, 96),
            nn.ReLU(),
            nn.Linear(96, 96),
            nn.ReLU(),
            nn.Linear(96, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, encoded_dim),  
        )
        self.tabnet_decoder = tabnet_model.decoder
        self.reshape = nn.Unflatten(1, (encoded_dim,))
        self.output_dim=output_dim


    def forward(self, z):
        x = F.relu(self.mlp(z))

        print("Shape before reshape:", x.shape)
        x = self.reshape(x)
        x = x[None, ...]

        print("Shape after reshape:", x.shape)
        # x = x.view(x.size(0), output_dim)
        
        output = self.tabnet_decoder(x)
        # print(output.shape)
        # print("Shape of output from tabnet_decoder:", output.shape)
        output = torch.softmax(output, dim=-1)  # Assuming the output is a probability distribution
        output = output.view(-1, self.output_dim)
        return output

In [None]:
def check_data_range(tensor, name):
    if not torch.all((tensor >= 0) & (tensor <= 1)):
        print(f"{name} contains values outside the range [0, 1]")
        print(f"{name} min: {tensor.min()}, max: {tensor.max()}")

In [None]:
class VAE_Tabnet_MLPS(nn.Module):
    def __init__(self, encoder, decoder, classifier):
        super(VAE_Tabnet_MLPS, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.classifier = classifier
        self.total_loss_tracker = []
        self.reconstruction_loss_tracker = []
        self.kl_loss_tracker = []
        self.classification_loss_tracker = []
        self.accuracy_tracker = []

    def forward(self, x):
        z_mean, z_log_var, z = self.encoder(x)
        reconstruction = self.decoder(z)
        classification_output = self.classifier(z)
        return reconstruction, z_mean, z_log_var, classification_output

    def train_step(self, data, labels, optimizer):
        labels = torch.argmax(labels, dim=1) 
        optimizer.zero_grad()
        # z_mean, z_log_var, z = self.encoder(data)
        # reconstruction = self.decoder(z)
        reconstruction, z_mean, z_log_var, classification_output = self.forward(data)
        # print('classifi',classification_output.shape)
        # print(check_data_range(data, 'data'))
        # print(check_data_range(reconstruction, 'reconstruction'))
        # reconstruction_loss = torch.mean(
        #     torch.sum(
        #         F.binary_cross_entropy(reconstruction, data, reduction='none'),
        #         dim=1
        #     )
        # )
        reconstruction_loss = torch.mean(
            torch.sum(
                F.binary_cross_entropy_with_logits(reconstruction, data, reduction='none'),
                dim=1
                # dim=(1, 2)
                )  
        )
        classification_loss = torch.mean(
            torch.sum(
                F.cross_entropy(classification_output, labels, reduction='none'),
                # dim=1
                # dim=(1, 2)
                )  
        )
        kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp(), dim=1)
        kl_loss = torch.mean(torch.sum(kl_loss))
        total_loss = reconstruction_loss + kl_loss + classification_loss
        total_loss.backward()
        optimizer.step()

        self.total_loss_tracker.append(total_loss.item())
        self.reconstruction_loss_tracker.append(reconstruction_loss.item())
        self.kl_loss_tracker.append(kl_loss.item())
        self.classification_loss_tracker.append(classification_loss.item())

        preds = torch.softmax(classification_output)
        correct = (preds == labels).float().sum()
        accuracy = correct / labels.size(0)
        self.accuracy_tracker.append(accuracy.item())

        return {
            "loss": total_loss.item(),
            "reconstruction_loss": reconstruction_loss.item(),
            "kl_loss": kl_loss.item(),
            "classification_loss": classification_loss.item(),
            "accuracy": accuracy.item()
        }

In [None]:
latent_dim = 64
encoded_dim = 512
output_dim = X_train.shape[1]
input_dim = X_train.shape[1]
print(input_dim)


42


In [None]:
class SimpleClassifier(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(SimpleClassifier, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, output_dim),
            nn.Softmax()  
        )

    def forward(self, x):
        # print('input: ',x.shape)
        output = self.fc(x)
        output = output.view(-1)
        # print('output',output.shape)
        return output

In [None]:
classifier = SimpleClassifier(latent_dim, output_dim=1).to(device)


In [None]:
def check_output(model, input_tensor):
    with torch.no_grad():  
        output = model(input_tensor)
        print(f"Input size: {input_tensor.size()}")
        print(f"Output size: {output.size()}")
        print(f"Output: {output}")

model = SimpleClassifier(latent_dim, output_dim=1)

input_tensor = torch.randn(32,latent_dim)  

check_output(model, input_tensor)

Input size: torch.Size([32, 64])
Output size: torch.Size([32])
Output: tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])


  return self._call_impl(*args, **kwargs)


In [None]:
vae_encoder = VAE_Encoder(latent_dim=latent_dim)
print("Encoder Summary:")
# vae_encoder.to(device)

summary(vae_encoder, input_size=(32, input_dim), device=device)

Encoder Summary:


Layer (type:depth-idx)                                       Output Shape              Param #
VAE_Encoder                                                  [32, 64]                  --
├─TabNetEncoder: 1-1                                         [32, 512]                 --
│    └─BatchNorm1d: 2-1                                      [32, 42]                  84
│    └─FeatTransformer: 2-2                                  [32, 1024]                4,202,496
│    │    └─GLU_Block: 3-1                                   [32, 1024]                2,191,360
│    └─ModuleList: 2-12                                      --                        (recursive)
│    │    └─FeatTransformer: 3-17                            --                        (recursive)
│    └─FeatTransformer: 2-6                                  --                        (recursive)
│    │    └─GLU_Block: 3-5                                   --                        (recursive)
│    └─ModuleList: 2-12                      

In [None]:
x = torch.randn(800, 42).to(device)
steps_output, _ = tabnet_model.encoder(x)
encoded = steps_output[-1]
print(f"Encoded shape: {encoded.shape}")

Encoded shape: torch.Size([800, 512])


In [None]:
import torch

x = torch.randn(800, 42).to(device)  # Đầu vào có kích thước (batch_size, features)

steps_output, _ = tabnet_model.encoder(x)
print("Shape of encoder output:", [output.shape for output in steps_output])

decoder_input = steps_output[-1]  
decoder_input = decoder_input[None, ...]
try:
    decoder_output = tabnet_model.decoder(decoder_input)
    print(f"Decoder shape: {decoder_output.shape}")
except ValueError as e:
    print(f"Error: {e}")


Shape of encoder output: [torch.Size([800, 512]), torch.Size([800, 512]), torch.Size([800, 512])]
Decoder shape: torch.Size([800, 42])


In [None]:
encoded_dim, output_dim

(512, 42)

In [None]:
vae_decoder = VAE_Decoder(latent_dim=latent_dim, encoded_dim=encoded_dim, output_dim=output_dim).to(device)
print("Decoder Summary:")
summary(vae_decoder, input_size=(32, latent_dim), device=device)

Decoder Summary:
Shape before reshape: torch.Size([32, 512])
Shape after reshape: torch.Size([1, 32, 512])


Layer (type:depth-idx)                                       Output Shape              Param #
VAE_Decoder                                                  [32, 42]                  --
├─Sequential: 1-1                                            [32, 512]                 --
│    └─Linear: 2-1                                           [32, 32]                  2,080
│    └─ReLU: 2-2                                             [32, 32]                  --
│    └─Linear: 2-3                                           [32, 96]                  3,168
│    └─ReLU: 2-4                                             [32, 96]                  --
│    └─Linear: 2-5                                           [32, 96]                  9,312
│    └─ReLU: 2-6                                             [32, 96]                  --
│    └─Linear: 2-7                                           [32, 128]                 12,416
│    └─ReLU: 2-8                                             [32, 128]            

In [None]:
vae = VAE_Tabnet_MLPS(encoder=vae_encoder, decoder=vae_decoder,classifier=classifier).to(device)
summary(vae, input_size=(32, input_dim), device=device)

Shape before reshape: torch.Size([32, 512])
Shape after reshape: torch.Size([1, 32, 512])


Layer (type:depth-idx)                                            Output Shape              Param #
VAE_Tabnet_MLPS                                                   [32, 42]                  --
├─VAE_Encoder: 1-1                                                [32, 64]                  --
│    └─TabNetEncoder: 2-1                                         [32, 512]                 --
│    │    └─BatchNorm1d: 3-1                                      [32, 42]                  84
│    │    └─FeatTransformer: 3-2                                  [32, 1024]                6,393,856
│    │    └─ModuleList: 3-12                                      --                        (recursive)
│    │    └─FeatTransformer: 3-6                                  --                        (recursive)
│    │    └─ModuleList: 3-12                                      --                        (recursive)
│    │    └─FeatTransformer: 3-6                                  --                        (recursive)
│ 

In [None]:
learning_rate = 0.0001
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)
num_epochs = 10

for epoch in range(num_epochs):
    vae.train()
    train_loss = 0
    rec_loss = 0
    kl_loss = 0
    classification_loss = 0
    accuracy = 0

    for batch_data, batch_labels in train_loader:
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)
        print(f"Batch data shape: {batch_data.shape}, Batch labels shape: {batch_labels.shape}")
        results = vae.train_step(batch_data, batch_labels, optimizer)
        
        train_loss += results["loss"]
        rec_loss += results["reconstruction_loss"]
        kl_loss += results["kl_loss"]
        classification_loss += results["classification_loss"]
        accuracy += results["accuracy"]

    train_loss /= len(train_loader)
    rec_loss /= len(train_loader)
    kl_loss /= len(train_loader)
    classification_loss /= len(train_loader)
    accuracy /= len(train_loader)

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss:.4f}, Reconstruction Loss: {rec_loss:.4f}, KL Loss: {kl_loss:.4f}, Classification Loss: {classification_loss:.4f}, Accuracy: {accuracy:.4f}")


Batch data shape: torch.Size([32, 42]), Batch labels shape: torch.Size([32])


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [None]:
for param in vae.encoder.parameters():
    param.requires_grad = False

In [None]:
vae_new = VAE_Tabnet_MLPS(vae.encoder, vae.decoder, vae.classifier).to(device)
for param in vae_new.encoder.parameters():
    param.requires_grad = False

optimizer = optim.Adam(filter(lambda p: p.requires_grad, vae_new.parameters()), lr=learning_rate)
for epoch in range(num_epochs):
    vae_new.train()
    train_loss = 0
    rec_loss = 0
    kl_loss = 0
    classification_loss = 0
    accuracy = 0

    for batch_data, batch_labels in train_loader:
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)
        results = vae.train_step(batch_data, batch_labels, optimizer)
        
        train_loss += results["loss"]
        rec_loss += results["reconstruction_loss"]
        kl_loss += results["kl_loss"]
        classification_loss += results["classification_loss"]
        accuracy += results["accuracy"]

    train_loss /= len(train_loader)
    rec_loss /= len(train_loader)
    kl_loss /= len(train_loader)
    classification_loss /= len(train_loader)
    accuracy /= len(train_loader)

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss:.4f}, Reconstruction Loss: {rec_loss:.4f}, KL Loss: {kl_loss:.4f}, Classification Loss: {classification_loss:.4f}, Accuracy: {accuracy:.4f}")

Epoch 1/10, Loss: 53.1120, Reconstruction Loss: 30.4644, KL Loss: 0.0400, Classification Loss: 22.6076, Accuracy: 0.4992
Epoch 2/10, Loss: 53.0209, Reconstruction Loss: 30.4537, KL Loss: 0.0387, Classification Loss: 22.5284, Accuracy: 0.5002
Epoch 3/10, Loss: 52.9978, Reconstruction Loss: 30.4485, KL Loss: 0.0430, Classification Loss: 22.5062, Accuracy: 0.4990
Epoch 4/10, Loss: 52.9031, Reconstruction Loss: 30.4347, KL Loss: 0.0404, Classification Loss: 22.4280, Accuracy: 0.4998
Epoch 5/10, Loss: 52.9155, Reconstruction Loss: 30.4229, KL Loss: 0.0395, Classification Loss: 22.4531, Accuracy: 0.4998
Epoch 6/10, Loss: 52.8807, Reconstruction Loss: 30.4122, KL Loss: 0.0409, Classification Loss: 22.4276, Accuracy: 0.5000
Epoch 7/10, Loss: 52.8405, Reconstruction Loss: 30.4025, KL Loss: 0.0428, Classification Loss: 22.3952, Accuracy: 0.5002
Epoch 8/10, Loss: 52.8140, Reconstruction Loss: 30.3933, KL Loss: 0.0420, Classification Loss: 22.3787, Accuracy: 0.5000
Epoch 9/10, Loss: 52.7752, Recon

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleDiffusionModel(nn.Module):
    def __init__(self, latent_dim, time_steps=1000):
        super().__init__()
        self.time_steps = time_steps
        self.latent_dim = latent_dim
        
        # Tạo các beta_schedule tuyến tính
        beta = torch.linspace(0.0001, 0.02, time_steps)
        alpha = 1. - beta
        alpha_bar = torch.cumprod(alpha, dim=0)

        self.register_buffer('beta', beta)
        self.register_buffer('alpha', alpha)
        self.register_buffer('alpha_bar', alpha_bar)

        # Mạng neural đơn giản để dự đoán nhiễu
        self.model = nn.Sequential(
            nn.Linear(latent_dim + 1, 512),
            nn.ReLU(),
            nn.Linear(512, latent_dim)
        )

    def forward(self, z, t):
        noise = torch.randn_like(z)
        
        # Đảm bảo t là long type và shape phù hợp
        if isinstance(t, torch.Tensor):
            t = t.to(dtype=torch.long)
        else:
            t = torch.tensor([t], device=z.device, dtype=torch.long).expand(z.shape[0])

        sqrt_alpha_bar = torch.sqrt(self.alpha_bar[t])[:, None]
        sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bar[t])[:, None]
        noisy_z = sqrt_alpha_bar * z + sqrt_one_minus_alpha_bar * noise

        predicted_noise = self.model(torch.cat([noisy_z, t.unsqueeze(1)], dim=1))
        loss = F.mse_loss(predicted_noise, noise)
        return loss

    def sample(self, num_samples):
        z = torch.randn(num_samples, self.latent_dim).to(next(self.parameters()).device)
        for i in reversed(range(self.time_steps)):
            t = torch.full((num_samples,), i, device=z.device, dtype=torch.long)
            z = self.denoise_step(z, t)
        return z
    
    def denoise_step(self, z, t):
        timestep = t.item() if isinstance(t, torch.Tensor) else t
        t_batch = torch.full((z.shape[0],), timestep, device=z.device, dtype=torch.long)

        predicted_noise = self.model(torch.cat([z, t_batch.unsqueeze(1)], dim=1))

        alpha = self.alpha[timestep]
        alpha_bar = self.alpha_bar[timestep]
        beta = self.beta[timestep]

        z = (1 / torch.sqrt(alpha)) * (z - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * predicted_noise)
        if timestep > 0:
            noise = torch.randn_like(z)
            z += torch.sqrt(beta) * noise
        return z

In [None]:
# Thay đoạn lỗi này:
# latent_dim = vae.encoder[-1].out_features

# Bằng đoạn này:
with torch.no_grad():
    vae_new.encoder.eval()
    dummy_input = torch.randn(1, input_dim).to(device)  # Thay input_dim theo đúng dữ liệu của bạn
    z_mean, z_log_var, _ = vae_new.encoder(dummy_input)
    latent_dim = z_mean.shape[1]

In [None]:
# latent_dim = vae.encoder[-1].out_features  # Kích thước latent z
diffusion_model = SimpleDiffusionModel(latent_dim=latent_dim).to(device)
diffusion_optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=1e-3)

In [None]:
from tqdm import tqdm
def train_diffusion(vae, diffusion_model, dataloader, optimizer, device, time_steps=1000, epochs=20):
    diffusion_model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_data, _ in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            batch_data = batch_data.to(device)
            with torch.no_grad():
                z_mean, z_log_var, z = vae.encoder(batch_data)

            t = torch.randint(0, time_steps, (z.shape[0],), device=device).long()
            loss = diffusion_model(z, t)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Diffusion Loss: {total_loss/len(dataloader):.4f}")

In [None]:
def evaluate_diffusion_with_classifier(vae, diffusion_model, classifier, test_loader, device, time_steps=1000):
    vae.eval()
    diffusion_model.eval()
    classifier.eval()
    
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_data, batch_labels in test_loader:
            batch_data = batch_data.to(device)
            batch_labels = batch_labels.to(device)

            # Lấy z từ encoder
            z_mean, z_log_var, z = vae.encoder(batch_data)

            # Forward diffusion
            t_forward = time_steps - 1
            sqrt_alpha_bar = torch.sqrt(diffusion_model.alpha_bar[t_forward])
            sqrt_one_minus_alpha_bar = torch.sqrt(1 - diffusion_model.alpha_bar[t_forward])
            noisy_z = sqrt_alpha_bar * z + sqrt_one_minus_alpha_bar * torch.randn_like(z)

            # Reverse diffusion (hoàn nhiễu)
            z_recovered = noisy_z
            for t in reversed(range(time_steps)):
                z_recovered = diffusion_model.denoise_step(z_recovered, t)

            # Phân loại trên z đã hoàn nhiễu
            logits = classifier(z_recovered)

            # Kiểm tra shape của logits
            print("Logits shape:", logits.shape)  # Debug

            if len(logits.shape) == 1:
                # Trường hợp: binary classification với output shape [batch_size]
                preds = (torch.sigmoid(logits) > 0.5).float()
            elif len(logits.shape) == 2:
                # Trường hợp: multi-class classification
                preds = torch.argmax(logits, dim=1)
            else:
                raise ValueError(f"Unexpected logits shape: {logits.shape}")

            # Cập nhật accuracy
            if len(batch_labels.shape) == 2 and batch_labels.shape[1] == 1:
                batch_labels = batch_labels.squeeze(1)  # về shape [batch_size]

            correct += (preds == batch_labels).sum().item()
            total += batch_labels.size(0)

    accuracy = correct / total
    print(f"Accuracy on recovered z: {accuracy:.4f}")

In [None]:
device

device(type='cuda')

In [None]:
# latent_dim = vae.encoder[-1].out_features
diffusion_model = SimpleDiffusionModel(latent_dim=latent_dim).to(device)
diffusion_optimizer = optim.Adam(diffusion_model.parameters(), lr=1e-3)

train_diffusion(vae_new, diffusion_model, train_loader, diffusion_optimizer, device)

evaluate_diffusion_with_classifier(vae, diffusion_model, vae.classifier, test_loader, device)

Epoch 1/20: 100%|██████████| 22/22 [00:00<00:00, 92.62it/s]


Epoch 1, Diffusion Loss: 42.7793


Epoch 2/20: 100%|██████████| 22/22 [00:00<00:00, 111.39it/s]


Epoch 2, Diffusion Loss: 4.0952


Epoch 3/20: 100%|██████████| 22/22 [00:00<00:00, 108.44it/s]


Epoch 3, Diffusion Loss: 1.3671


Epoch 4/20: 100%|██████████| 22/22 [00:00<00:00, 109.31it/s]


Epoch 4, Diffusion Loss: 1.0407


Epoch 5/20: 100%|██████████| 22/22 [00:00<00:00, 108.88it/s]


Epoch 5, Diffusion Loss: 0.9676


Epoch 6/20: 100%|██████████| 22/22 [00:00<00:00, 107.51it/s]


Epoch 6, Diffusion Loss: 0.9269


Epoch 7/20: 100%|██████████| 22/22 [00:00<00:00, 105.55it/s]


Epoch 7, Diffusion Loss: 0.8888


Epoch 8/20: 100%|██████████| 22/22 [00:00<00:00, 105.90it/s]


Epoch 8, Diffusion Loss: 0.8463


Epoch 9/20: 100%|██████████| 22/22 [00:00<00:00, 108.87it/s]


Epoch 9, Diffusion Loss: 0.8085


Epoch 10/20: 100%|██████████| 22/22 [00:00<00:00, 101.10it/s]


Epoch 10, Diffusion Loss: 0.7707


Epoch 11/20: 100%|██████████| 22/22 [00:00<00:00, 100.04it/s]


Epoch 11, Diffusion Loss: 0.7214


Epoch 12/20: 100%|██████████| 22/22 [00:00<00:00, 118.32it/s]


Epoch 12, Diffusion Loss: 0.6964


Epoch 13/20: 100%|██████████| 22/22 [00:00<00:00, 116.84it/s]


Epoch 13, Diffusion Loss: 0.6584


Epoch 14/20: 100%|██████████| 22/22 [00:00<00:00, 119.06it/s]


Epoch 14, Diffusion Loss: 0.6505


Epoch 15/20: 100%|██████████| 22/22 [00:00<00:00, 117.38it/s]


Epoch 15, Diffusion Loss: 0.6098


Epoch 16/20: 100%|██████████| 22/22 [00:00<00:00, 112.84it/s]


Epoch 16, Diffusion Loss: 0.5598


Epoch 17/20: 100%|██████████| 22/22 [00:00<00:00, 112.00it/s]


Epoch 17, Diffusion Loss: 0.5508


Epoch 18/20: 100%|██████████| 22/22 [00:00<00:00, 122.61it/s]


Epoch 18, Diffusion Loss: 0.5294


Epoch 19/20: 100%|██████████| 22/22 [00:00<00:00, 112.93it/s]


Epoch 19, Diffusion Loss: 0.5032


Epoch 20/20: 100%|██████████| 22/22 [00:00<00:00, 110.51it/s]


Epoch 20, Diffusion Loss: 0.4683
Logits shape: torch.Size([32])
Logits shape: torch.Size([32])
Logits shape: torch.Size([32])
Logits shape: torch.Size([32])
Logits shape: torch.Size([32])
Logits shape: torch.Size([32])
Logits shape: torch.Size([32])
Logits shape: torch.Size([32])
Logits shape: torch.Size([32])
Logits shape: torch.Size([12])
Accuracy on recovered z: 0.5433
