In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import TimeSeriesTransformerModel
from torch_geometric.nn import GATConv
from lightly.models import SimCLR
from torch.utils.data import DataLoader
from torch_geometric.loader import NeighborLoader

import torch.nn.functional as F
import torch.distributions as dist


2025-02-02 16:33:52.358468: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738510432.384812 3769486 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738510432.393046 3769486 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-02 16:33:52.420019: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
class TDPSOM(nn.Module):
    """T-DPSOM-style model: VAE encoder/decoder, LSTM next-step predictor in latent
    space, and a grid of SOM embeddings used for nearest-node assignment.

    Args:
        input_size: Stored on the instance; not otherwise used in this class.
        latent_dim: Size of the latent space.
        som_dim: SOM grid shape ``[rows, cols]``; ``rows * cols`` embeddings are created.
        learning_rate: Initial Adam learning rate.
        decay_factor: Multiplicative LR decay applied by ``StepLR``.
        decay_steps: Number of ``scheduler.step()`` calls (one per ``train_step``)
            between decays.
        input_channels: Number of features per time step.
        alpha, beta, gamma, kappa, eta: Stored on the instance; not used by any loss
            in this class.
        theta: Weight of the VAE term ``recon + prior * KL`` in the total loss.
        dropout: Dropout probability used in the encoder.
        prior: Weight of the KL term inside the VAE term.
        lstm_dim: Hidden size of the LSTM predictor.

    NOTE(review): the SOM ``embeddings`` do not appear in ``loss_function``, so they
    receive no gradient from ``train_step`` and stay at their random initialization.
    ``next_z_logvar`` is produced but not used by the loss either.
    """

    def __init__(self, input_size, latent_dim=10, som_dim=[8, 8], learning_rate=1e-4, decay_factor=0.99,
                 decay_steps=2000, input_channels=98, alpha=10., beta=100., gamma=100., kappa=0.,
                 theta=1., eta=1., dropout=0.5, prior=0.001, lstm_dim=100):
        super(TDPSOM, self).__init__()

        self.input_size = input_size
        self.latent_dim = latent_dim
        self.som_dim = som_dim
        self.learning_rate = learning_rate
        self.decay_factor = decay_factor
        self.decay_steps = decay_steps
        self.input_channels = input_channels
        self.alpha = alpha
        self.beta = beta
        self.eta = eta
        self.gamma = gamma
        self.theta = theta
        self.kappa = kappa
        self.dropout = dropout
        self.prior = prior
        self.lstm_dim = lstm_dim

        # Encoder: input_channels -> 2 * latent_dim (mu and logvar concatenated)
        self.encoder = nn.Sequential(
            nn.Linear(input_channels, 500),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.BatchNorm1d(500),
            nn.Linear(500, 500),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.BatchNorm1d(500),
            nn.Linear(500, 2000),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.BatchNorm1d(2000),
            nn.Linear(2000, latent_dim * 2)  # mu and logvar
        )

        # Decoder: latent_dim -> input_channels
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 2000),
            nn.LeakyReLU(),
            nn.BatchNorm1d(2000),
            nn.Linear(2000, 500),
            nn.LeakyReLU(),
            nn.BatchNorm1d(500),
            nn.Linear(500, 500),
            nn.LeakyReLU(),
            nn.BatchNorm1d(500),
            nn.Linear(500, input_channels)
        )

        # LSTM over latent sequences plus a head producing next-step mu and logvar
        self.lstm = nn.LSTM(latent_dim, lstm_dim, batch_first=True)
        self.prediction_head = nn.Sequential(
            nn.Linear(lstm_dim, lstm_dim),
            nn.LeakyReLU(),
            nn.Linear(lstm_dim, latent_dim * 2)  # mu and logvar
        )

        # One embedding vector per SOM grid node, stored row-major
        self.embeddings = nn.Parameter(torch.randn(som_dim[0] * som_dim[1], latent_dim))

        # The model owns its optimizer and a step-wise LR schedule
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=decay_steps, gamma=decay_factor)

    def encode(self, x):
        """Encode ``(N, input_channels)`` rows into ``(mu, logvar)``, each ``(N, latent_dim)``."""
        h = self.encoder(x)
        mu, logvar = torch.chunk(h, 2, dim=-1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode latent samples into reconstructed inputs."""
        return self.decoder(z)

    def predict(self, z):
        """Run the LSTM over ``z`` of shape ``(batch, steps, latent_dim)``.

        Returns flattened ``(batch * steps, latent_dim)`` tensors ``(mu, logvar)``.
        Row ``b * steps + t`` is the prediction for step ``t + 1`` of sequence ``b``.
        """
        batch_size, step_size, _ = z.size()
        h, _ = self.lstm(z)
        h_flat = h.reshape(batch_size * step_size, self.lstm_dim)
        next_z = self.prediction_head(h_flat)
        next_z_mu, next_z_logvar = torch.chunk(next_z, 2, dim=-1)
        return next_z_mu, next_z_logvar

    def som_loss(self, z):
        """Return the SOM grid coordinates of each latent vector's nearest embedding.

        Despite the name, this computes no loss. ``z`` is ``(batch, steps, latent_dim)``.
        The result is a float tensor ``(batch, steps, 2)`` holding ``(row, col)``.
        """
        batch_size, step_size, _ = z.size()
        z_flat = z.reshape(-1, self.latent_dim)
        distances = torch.cdist(z_flat, self.embeddings)
        k = torch.argmin(distances, dim=-1)
        k_1 = k // self.som_dim[1]
        k_2 = k % self.som_dim[1]
        som_coords = torch.stack([k_1.float(), k_2.float()], dim=1)
        som_coords = som_coords.reshape(batch_size, step_size, 2)
        return som_coords

    def _forward_with_posterior(self, x):
        """Full forward pass that also returns the encoder posterior.

        Returns ``forward``'s four values followed by the flattened encoder ``mu``
        and ``logvar`` (``(batch * steps, latent_dim)``). These are the exact tensors
        ``z`` was sampled from, so the losses stay consistent with ``x_recon``.
        """
        batch_size, step_size, _ = x.size()
        # reshape (not view) so non-contiguous inputs are accepted
        x_flat = x.reshape(-1, self.input_channels)

        mu, logvar = self.encode(x_flat)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)

        z_seq = z.reshape(batch_size, step_size, self.latent_dim)
        next_z_mu, next_z_logvar = self.predict(z_seq)
        som_coords = self.som_loss(z_seq)

        return x_recon, z_seq, (next_z_mu, next_z_logvar), som_coords, mu, logvar

    def forward(self, x):
        """Forward pass for ``x`` of shape ``(batch, steps, input_channels)``.

        Returns:
            ``(x_recon, z_seq, (next_z_mu, next_z_logvar), som_coords)``, where
            ``x_recon`` is ``(batch * steps, input_channels)``, ``z_seq`` is
            ``(batch, steps, latent_dim)``, the predictions are flattened
            ``(batch * steps, latent_dim)``, and ``som_coords`` is ``(batch, steps, 2)``.
        """
        x_recon, z_seq, next_z, som_coords, _, _ = self._forward_with_posterior(x)
        return x_recon, z_seq, next_z, som_coords

    def loss_function(self, x, x_recon, mu, logvar, next_z_mu, next_z_logvar):
        """Total loss ``theta * (recon + prior * KL) + pred``, all summed.

        Args:
            x: Input, ``(batch, steps, input_channels)``. A 2-D ``x`` carries no
                sequence boundaries, so all its rows are treated as a single sequence.
            x_recon: Reconstruction, ``(batch * steps, input_channels)``.
            mu, logvar: Encoder posterior, ``(batch * steps, latent_dim)``.
            next_z_mu: LSTM predictions as returned by ``predict``.
            next_z_logvar: Unused; kept for interface compatibility.
        """
        recon_loss = F.mse_loss(x_recon, x.reshape(-1, self.input_channels), reduction='sum')

        # KL(q(z|x) || N(0, I))
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        # The LSTM output at step t predicts the latent at step t + 1. Pair
        # predictions[:, :-1] with targets[:, 1:]; the last step has no target.
        step_size = x.size(1) if x.dim() == 3 else mu.size(0)
        if step_size > 1:
            mu_seq = mu.reshape(-1, step_size, self.latent_dim)
            pred_seq = next_z_mu.reshape(-1, step_size, self.latent_dim)
            pred_loss = F.mse_loss(pred_seq[:, :-1], mu_seq[:, 1:], reduction='sum')
        else:
            pred_loss = mu.new_zeros(())

        return self.theta * (recon_loss + self.prior * kl_loss) + pred_loss

    def train_step(self, x):
        """Do one optimizer and scheduler step on batch ``x`` and return the loss as a float."""
        self.optimizer.zero_grad()
        # Reuse the posterior from the same pass that produced x_recon. A second
        # encode() would draw new dropout masks and update BatchNorm stats twice.
        x_recon, _, (next_z_mu, next_z_logvar), _, mu, logvar = self._forward_with_posterior(x)
        loss = self.loss_function(x, x_recon, mu, logvar, next_z_mu, next_z_logvar)
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()
        return loss.item()

In [None]:
# 5️⃣ **Complete ICU prediction model**
class MultiModalICUModel(nn.Module):
    """Multimodal ICU classifier fusing flat features, a time series and a graph.

    The time series and graph are each embedded to ``hidden_dim``, and the flat
    input is brought to ``hidden_dim``. The three embeddings are fused by
    ``MultimodalTransformer`` and classified by a two-layer MLP head.

    Args:
        ts_dim: Input size passed to ``TimeSeriesModel``.
        gnn_dim: Node-feature size passed to ``GATModel``.
        flat_dim: Feature size of the flat input.
        hidden_dim: Shared embedding size for all modalities.
        num_classes: Number of output logits.
    """

    def __init__(self, ts_dim, gnn_dim, flat_dim, hidden_dim, num_classes):
        super().__init__()

        self.ts_model = TimeSeriesModel(ts_dim, hidden_dim)
        self.gnn = GATModel(gnn_dim, hidden_dim, hidden_dim)
        self.contrastive = ContrastiveModel(hidden_dim)
        self.fusion = MultimodalTransformer(num_modalities=3, embedding_size=hidden_dim)

        # The fusion module is configured for hidden_dim, but flat_dim was
        # previously ignored. Project the flat input when its size differs. When
        # the sizes already match, keep the parameter-free identity path so
        # existing configurations and checkpoints behave exactly as before.
        if flat_dim == hidden_dim:
            self.flat_proj = nn.Identity()
        else:
            self.flat_proj = nn.Linear(flat_dim, hidden_dim)

        self.clf = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, flat, ts, graph_data):
        """Return class logits for one batch.

        ``graph_data`` must expose ``.x`` and ``.edge_index``.
        NOTE(review): whether the GAT output is per-node or per-sample (and so
        aligned with ``ts_embed``) depends on ``GATModel``, which is not shown
        here. Confirm before relying on the fusion shapes.
        """
        ts_embed = self.ts_model(ts)
        gnn_embed = self.gnn(graph_data.x, graph_data.edge_index)
        flat_embed = torch.relu(self.flat_proj(flat))

        # Contrastive representations (the same head is shared by both modalities)
        ts_embed = self.contrastive(ts_embed)
        gnn_embed = self.contrastive(gnn_embed)

        # Fusion
        fusion_embed = self.fusion([flat_embed, ts_embed, gnn_embed])

        # Classification
        return self.clf(fusion_embed)

In [None]:
def get_dataloaders(ts_data, flat_data, graph_data, batch_size=32, labels=None, shuffle=True):
    """Build a loader whose batches are ``[flat, ts, graph]`` (plus ``labels`` if given).

    Samples are formed by zipping the modalities in the order ``(flat, ts, graph)``.
    This differs from the argument order. When ``labels`` is provided, each sample
    becomes ``(flat, ts, graph, label)``.

    ``torch.utils.data.DataLoader``'s default collate cannot batch PyTorch
    Geometric ``Data`` objects. PyG's ``DataLoader`` is therefore used when it is
    installed: it merges ``Data`` into a ``Batch`` and collates tensors as usual.
    Without PyG, the graph entries cannot be PyG objects, so the plain PyTorch
    loader is sufficient.

    Note: ``zip`` stops at the shortest input, so modalities of unequal length
    silently lose trailing samples.
    """
    try:
        from torch_geometric.loader import DataLoader as loader_cls
    except ImportError:
        loader_cls = DataLoader

    modalities = [flat_data, ts_data, graph_data]
    if labels is not None:
        modalities.append(labels)
    dataset = list(zip(*modalities))
    return loader_cls(dataset, batch_size=batch_size, shuffle=shuffle)

In [None]:
def train(model, dataloader, epochs=10, lr=1e-3, device=None):
    """Train ``model`` with cross-entropy loss and Adam.

    Args:
        model: Module called as ``model(flat, ts, graph)`` that returns class logits.
        dataloader: Iterable of batches ``(flat, ts, graph)`` or
            ``(flat, ts, graph, labels)``. The first three elements must support
            ``.to(device)``; ``labels`` must be a tensor of class indices.
        epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate.
        device: Target device. Defaults to CUDA when available, otherwise CPU.

    Returns:
        A list with the mean batch loss of each epoch, which is also printed.
        The value is NaN for an epoch with no batches.

    Batches without labels fall back to the previous placeholder of random binary
    labels. A warning is emitted, because that loss carries no training signal.
    """
    import warnings

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    epoch_losses = []
    warned_random_labels = False
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0

        for batch in dataloader:
            flat = batch[0].to(device)
            ts = batch[1].to(device)
            graph = batch[2].to(device)

            if len(batch) > 3:
                labels = batch[3].to(device=device, dtype=torch.long)
            else:
                if not warned_random_labels:
                    warnings.warn(
                        "train(): batches carry no labels; using random binary labels as a "
                        "placeholder. Pass labels through the dataloader to train meaningfully.",
                        RuntimeWarning,
                    )
                    warned_random_labels = True
                labels = torch.randint(0, 2, (flat.shape[0],), dtype=torch.long, device=device)

            optimizer.zero_grad()
            outputs = model(flat, ts, graph)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Report the mean, so the number does not scale with the number of batches
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(mean_loss)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {mean_loss:.4f}")

    return epoch_losses
