In [11]:
import numpy as np
import pandas as pd
import torch
import pytorch_lightning as pl
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
import logging
import tempfile, os
    
from src.datasets import CSVDataset
from src.methods.DCAE import DCAEImputer
from src.methods.DMF import DMFImputer

warnings.filterwarnings('ignore')
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)

pl.seed_everything(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (15, 10)
plt.rcParams['font.size'] = 10

Seed set to 0


In [12]:
def prepare_dataset(csv_path):

    df = pd.read_csv(csv_path)
    vals = df.iloc[:, 1:].to_numpy(dtype=float)
    pos = vals > 0
    vals[pos] = np.log2(vals[pos])
    df.iloc[:, 1:] = vals

    tmp = tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False)
    try:
        df.to_csv(tmp.name, index=False)
        tmp.close()
        dataset = CSVDataset(tmp.name)
    finally:
        try:
            os.unlink(tmp.name)
        except Exception:
            pass
    
    return dataset

def train_dcae_with_lambda(data_normalized, mask, lambda_mask, max_epochs=500):
    model = DCAEImputer(
        full_data_tensor=data_normalized,
        full_mask_tensor=mask,
        batch_size=256,
        ae_dim=256,
        mask_predictor_hidden_dim=128,
        lambda_mask=lambda_mask,
        num_encoder_blocks=3,
        num_decoder_blocks=3,
        dilation=3,
        learning_rate=1e-3,
    )
    
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='cuda',
        devices=1,
        enable_checkpointing=False,
        enable_model_summary=False,
        logger=False,
        enable_progress_bar=False,
        gradient_clip_val=1.0,
    )
    
    trainer.fit(model)
    return model

def train_dmf_with_lambda(data_normalized, mask, mask_weight, max_epochs=500):
    model = DMFImputer(
        full_data_tensor=data_normalized,
        full_mask_tensor=mask,
        batch_size=256,
        embedding_dim=64,
        hidden_dims=[256, 128],
        reconstruction_weight=1.0,
        mask_weight=mask_weight,
        lr=1e-3,
    )
    
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='cuda',
        devices=1,
        enable_checkpointing=False,
        enable_model_summary=False,
        logger=False,
        enable_progress_bar=False,
        gradient_clip_val=1.0,
    )
    
    trainer.fit(model)
    return model

def evaluate_model(model, data_normalized, mask):
    """评估模型损失"""
    import torch.nn.functional as F
    
    model.eval()
    with torch.no_grad():
        input_data = data_normalized * mask
        
        if isinstance(model, DCAEImputer):
            x_recon, mask_logits = model.model(input_data)
            lambda_mask = model.lambda_mask
        else: 
            x_recon, mask_logits = model.model(input_data)
            lambda_mask = model.mask_weight
        
        observed = mask > 0.5
        if observed.any():
            recon_loss = F.mse_loss(x_recon[observed], data_normalized[observed]).item()
        else:
            recon_loss = 0.0
        
        if lambda_mask > 0:
            mask_loss = F.binary_cross_entropy_with_logits(
                mask_logits, mask.float()
            ).item()
        else:
            mask_loss = 0.0
        
        total_loss = recon_loss + lambda_mask * mask_loss
    
    return recon_loss, mask_loss, total_loss

In [13]:
data_dir = Path("./data")
csv_files = list(data_dir.glob("*.csv"))
dataset_path = csv_files[0]

print(f"Loading dataset: {dataset_path.name}")
dataset = prepare_dataset(dataset_path)

data_normalized = dataset.data_normalized
mask = dataset.mask
print(f"Data shape: {data_normalized.shape}")
print(f"Mask coverage: {mask.mean():.2%}")

Loading dataset: Alzheimer.csv
Data shape: torch.Size([210, 1541])
Mask coverage: 7.86%


In [14]:
lambda_values = [0.0, 0.1, 0.2,0.3, 0.4, 0.5, 0.6, 0.7, 0.8,   0.9,1.0, 2.0]
max_epochs = 500

dcae_results = []
mask_tensor = torch.tensor(mask.numpy(), dtype=torch.float32)

print("\n" + "="*60)
print("DCAE Lambda Benchmark")
print("="*60)
for lambda_mask in tqdm(lambda_values, desc="DCAE"):
    print(f"\nTraining DCAE with lambda_mask={lambda_mask}")
    
    model = train_dcae_with_lambda(
        data_normalized, 
        mask_tensor,
        lambda_mask=lambda_mask,
        max_epochs=max_epochs
    )
    
    recon_loss, mask_loss, total_loss = evaluate_model(model, data_normalized, mask_tensor)
    
    dcae_results.append({
        'lambda': lambda_mask,
        'recon_loss': recon_loss,
        'mask_loss': mask_loss,
        'total_loss': total_loss
    })
    
    print(f"  Recon Loss: {recon_loss:.4f}")
    print(f"  Mask Loss: {mask_loss:.4f}")
    print(f"  Total Loss: {total_loss:.4f}")

dcae_df = pd.DataFrame(dcae_results)


print("\n" + "="*60)
print("DCAE RESULTS")
print("="*60)
display(dcae_df)



DCAE Lambda Benchmark


DCAE:   0%|          | 0/12 [00:00<?, ?it/s]


Training DCAE with lambda_mask=0.0


DCAE:   8%|▊         | 1/12 [00:37<06:53, 37.61s/it]

  Recon Loss: 0.0492
  Mask Loss: 0.0000
  Total Loss: 0.0492

Training DCAE with lambda_mask=0.1


DCAE:  17%|█▋        | 2/12 [01:16<06:24, 38.45s/it]

  Recon Loss: 0.0448
  Mask Loss: 0.2466
  Total Loss: 0.0695

Training DCAE with lambda_mask=0.2


DCAE:  25%|██▌       | 3/12 [01:55<05:48, 38.71s/it]

  Recon Loss: 0.0426
  Mask Loss: 0.1890
  Total Loss: 0.0804

Training DCAE with lambda_mask=0.3


DCAE:  33%|███▎      | 4/12 [02:35<05:13, 39.14s/it]

  Recon Loss: 0.0267
  Mask Loss: 0.1347
  Total Loss: 0.0671

Training DCAE with lambda_mask=0.4


DCAE:  42%|████▏     | 5/12 [03:13<04:32, 38.86s/it]

  Recon Loss: 0.0654
  Mask Loss: 0.2215
  Total Loss: 0.1540

Training DCAE with lambda_mask=0.5


DCAE:  50%|█████     | 6/12 [03:51<03:50, 38.46s/it]

  Recon Loss: 0.0401
  Mask Loss: 0.1743
  Total Loss: 0.1273

Training DCAE with lambda_mask=0.6


DCAE:  58%|█████▊    | 7/12 [04:31<03:14, 38.80s/it]

  Recon Loss: 0.0574
  Mask Loss: 0.1933
  Total Loss: 0.1735

Training DCAE with lambda_mask=0.7


DCAE:  67%|██████▋   | 8/12 [05:09<02:35, 38.77s/it]

  Recon Loss: 0.0506
  Mask Loss: 0.1731
  Total Loss: 0.1718

Training DCAE with lambda_mask=0.8


DCAE:  75%|███████▌  | 9/12 [05:47<01:55, 38.60s/it]

  Recon Loss: 0.0478
  Mask Loss: 0.2529
  Total Loss: 0.2501

Training DCAE with lambda_mask=0.9


DCAE:  83%|████████▎ | 10/12 [06:26<01:17, 38.64s/it]

  Recon Loss: 0.0504
  Mask Loss: 0.1034
  Total Loss: 0.1434

Training DCAE with lambda_mask=1.0


DCAE:  92%|█████████▏| 11/12 [07:05<00:38, 38.56s/it]

  Recon Loss: 0.0252
  Mask Loss: 0.1239
  Total Loss: 0.1491

Training DCAE with lambda_mask=2.0


DCAE: 100%|██████████| 12/12 [07:42<00:00, 38.57s/it]

  Recon Loss: 0.1418
  Mask Loss: 0.2367
  Total Loss: 0.6152

DCAE RESULTS





Unnamed: 0,lambda,recon_loss,mask_loss,total_loss
0,0.0,0.049237,0.0,0.049237
1,0.1,0.04485,0.246608,0.069511
2,0.2,0.042561,0.188996,0.08036
3,0.3,0.026699,0.134705,0.06711
4,0.4,0.065377,0.221547,0.153996
5,0.5,0.040122,0.17433,0.127287
6,0.6,0.057447,0.193344,0.173453
7,0.7,0.050637,0.173075,0.17179
8,0.8,0.047832,0.252894,0.250147
9,0.9,0.050356,0.103426,0.143439
