In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GraphSAGE
from minisom import MiniSom  
import numpy as np

In [2]:
import sys
sys.path.append('/home/mei/nas/docker/thesis')
from dataloader.ts_reader import LSTMTSDataset,collate_fn
from dataloader.pyg_reader import GraphDataset,get_graph_dataloader

In [3]:
class MultiModalICUModel(nn.Module):
    """Fuse flat, time-series and graph embeddings, map them onto a SOM, then classify.

    Pipeline: MLP (flat) + bidirectional LSTM (time series) + GraphSAGE (graph)
    -> concatenation -> MiniSom best-matching unit -> one-hot -> MLP classifier.

    NOTE: the SOM lookup goes through NumPy on a detached copy of the fused
    embedding, so no gradient reaches ``fc_flat``, ``lstm``, ``gnn`` or
    ``gnn_fc`` from the classification loss; only ``clf`` receives gradients.
    Nothing in this class trains the SOM codebook either.

    Args:
        flat_input_dim: width of the flat feature vector.
        ts_input_dim: number of features per time step.
        hidden_dim: embedding width of each branch (the LSTM yields 2x this
            because it is bidirectional).
        gnn_input_dim: node feature width fed to GraphSAGE.
        gnn_hidden_dim: GraphSAGE hidden width.
        som_dim: side length of the square SOM grid.
        num_classes: number of output logits.
    """

    def __init__(self, flat_input_dim, ts_input_dim, hidden_dim,
                 gnn_input_dim, gnn_hidden_dim, som_dim, num_classes):
        super().__init__()

        # 1. Flat data -> MLP
        self.fc_flat = nn.Sequential(
            nn.Linear(flat_input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # 2. Time series -> bidirectional LSTM (embedding width is 2 * hidden_dim)
        self.lstm = nn.LSTM(input_size=ts_input_dim, hidden_size=hidden_dim,
                            num_layers=2, batch_first=True, bidirectional=True)

        # 3. Graph data -> GNN, then project node embeddings to hidden_dim
        self.gnn = GraphSAGE(gnn_input_dim, gnn_hidden_dim, num_layers=2)
        self.gnn_fc = nn.Linear(gnn_hidden_dim, hidden_dim)

        # 4. Self-Organizing Map over the fused embedding.
        # flat (hidden_dim) + LSTM (2 * hidden_dim) + graph (hidden_dim)
        self.som_dim = som_dim
        self.fusion_dim = hidden_dim * 4
        self.som = MiniSom(som_dim, som_dim, input_len=self.fusion_dim,
                           sigma=0.3, learning_rate=0.5)

        # 5. Classification head over the one-hot encoded winning SOM unit
        self.clf = nn.Sequential(
            nn.Linear(som_dim * som_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, flat, ts, graph_data, seq_lengths=None, node_idx=None):
        """Run all three branches and classify the SOM cell of the fused embedding.

        Args:
            flat: (batch_size, flat_input_dim)
            ts: (batch_size, seq_len, ts_input_dim), possibly zero-padded.
            graph_data: object exposing ``x`` and ``edge_index`` (PyG Data).
            seq_lengths: optional true length of every sequence in ``ts``.
                When given, the sequences are packed so padded steps do not
                influence the LSTM embedding.
            node_idx: optional index (anything usable for tensor row indexing)
                selecting the graph nodes that belong to the samples of this
                batch, in batch order.

        Returns:
            output: (batch_size, num_classes) logits.
            fusion: (batch_size, 4 * hidden_dim) fused embedding.
            som_output: (batch_size, 2) float grid coordinates of the winning
                SOM unit of every sample.

        Raises:
            ValueError: if the number of graph embeddings does not match the
                batch size after applying ``node_idx``.
        """
        # 1. Flat data -> embedding
        flat_embed = self.fc_flat(flat)

        # 2. Time series -> LSTM. Use the final hidden states of the top layer
        # rather than lstm_out[:, -1, :]: the latter reads padded steps and its
        # backward half has only seen a single time step.
        if seq_lengths is not None:
            lengths = torch.as_tensor(seq_lengths, dtype=torch.int64).cpu()
            lstm_in = nn.utils.rnn.pack_padded_sequence(
                ts, lengths, batch_first=True, enforce_sorted=False
            )
        else:
            lstm_in = ts
        _, (h_n, _) = self.lstm(lstm_in)
        # h_n[-2] / h_n[-1]: forward / backward direction of the last layer
        lstm_embed = torch.cat([h_n[-2], h_n[-1]], dim=-1)

        # 3. Graph -> GNN (one embedding per node)
        gnn_embed = self.gnn(graph_data.x, graph_data.edge_index)
        gnn_embed = self.gnn_fc(gnn_embed)
        if node_idx is not None:
            gnn_embed = gnn_embed[node_idx]
        if gnn_embed.shape[0] != flat_embed.shape[0]:
            raise ValueError(
                f"graph branch produced {gnn_embed.shape[0]} embeddings for a batch of "
                f"{flat_embed.shape[0]} samples; pass node_idx to select the batch's nodes"
            )

        # 4. Multi-modal fusion
        fusion = torch.cat([flat_embed, lstm_embed, gnn_embed], dim=-1)

        # 5. SOM: best-matching-unit coordinates per sample (non-differentiable)
        som_input = fusion.detach().cpu().numpy()
        winners = np.array([self.som.winner(x) for x in som_input], dtype=np.int64)
        winners = torch.from_numpy(winners.reshape(-1, 2)).to(flat.device)
        som_output = winners.to(torch.float)

        # 6. Classification on the one-hot encoding of the winning unit, which
        # has the som_dim * som_dim width the classifier was built for.
        unit_index = winners[:, 0] * self.som_dim + winners[:, 1]
        som_onehot = F.one_hot(unit_index, num_classes=self.som_dim * self.som_dim)
        output = self.clf(som_onehot.to(torch.float))

        return output, fusion, som_output


In [4]:
# Directory handed to LSTMTSDataset below; the path points at the "train" HDF split.
data_dir = "/home/mei/nas/docker/thesis/data/hdf/train"


In [5]:
# Settings handed to GraphDataset when the graph loader is built.
# NOTE(review): "data_dir" here is the *val* split while the standalone
# `data_dir` variable used for LSTMTSDataset is the *train* split - confirm
# that mixing the two splits is intended.
config = dict(
    data_dir="/home/mei/nas/docker/thesis/data/hdf/val",
    graph_dir="/home/mei/nas/docker/thesis/data/graphs",
    mode="k_closest",
    k=3,
)

In [6]:
from torch.utils.data import Dataset, DataLoader

# --- Time-series + flat features (input to the LSTM / MLP branches) ---
lstm_dataset = LSTMTSDataset(data_dir)
lstm_loader = DataLoader(
    lstm_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)

# --- Graph features (input to the GNN branch) ---
# NOTE(review): the recorded output of this cell ends in an IndexError
# ("... valid indices (got 'str')") right after the precomputed graph is
# loaded; it originates in the graph-loading code, which is not visible here.
graph_dataset = GraphDataset(config)
graph_loader = get_graph_dataloader(
    graph_dataset,
    batch_size=32,
    shuffle=True,
)


==> Loading precomputed graph from /home/mei/nas/docker/thesis/data/graphs/diagnosis_graph_k_closest_k3.pt


IndexError: Only slices (':'), list, tuples, torch.tensor and np.ndarray of dtype long or bool are valid indices (got 'str')

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model
model = MultiModalICUModel(
    flat_input_dim=104, ts_input_dim=163, hidden_dim=128,
    gnn_input_dim=128, gnn_hidden_dim=128, som_dim=10, num_classes=2
).to(device)

# Optimizer & loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Fetch the graph once and reuse it for every batch (the GNN runs on the whole graph).
# NOTE(review): the GNN yields one embedding per node; verify that the nodes
# line up with the samples of each time-series batch (e.g. via `ids`).
graph_data = next(iter(graph_loader)).to(device)

# Training loop
model.train()
for epoch in range(10):
    epoch_loss = 0.0
    num_batches = 0

    for batch in lstm_loader:
        (seqs_padded, seq_lengths, flats), labels, ids = batch

        # Move the batch to the target device
        seqs_padded, seq_lengths = seqs_padded.to(device), seq_lengths.to(device)
        flats, labels = flats.to(device), labels.to(device)

        # Forward pass. forward's positional order is (flat, ts, graph_data);
        # passing seq_lengths as a third positional argument raised a TypeError.
        outputs, fusion, som_output = model(flats, seqs_padded, graph_data)

        # Loss
        loss = loss_fn(outputs, labels)

        # Gradient step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch instead of only the last batch's loss;
    # max(..., 1) keeps an empty loader from crashing the log line.
    mean_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")


In [None]:
import matplotlib.pyplot as plt

# `som_output` and `labels` are the values left over from the last training
# batch. They are copied into new names so this cell can be re-run: rebinding
# `som_output` to a NumPy array would break the `.cpu()` call on a second run.
som_coords = som_output.detach().cpu().numpy()
batch_labels = labels.detach().cpu().numpy()

fig, ax = plt.subplots()
ax.scatter(som_coords[:, 0], som_coords[:, 1], c=batch_labels, cmap="coolwarm")
ax.set_title("SOM Feature Space")
ax.set_xlabel("SOM output [:, 0]")
ax.set_ylabel("SOM output [:, 1]")
plt.show()