In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HGTConv, global_mean_pool

In [3]:

import os
import torch
from torch_geometric.data import Dataset

class HeteroGraphDataset(Dataset):
    """Dataset over a directory of pre-built graphs, one ``.pt`` file per sample.

    The file list is read once at construction and sorted by file name (plain string
    order, so ``10.pt`` sorts before ``2.pt``); index ``i`` always maps to the same file.
    """

    def __init__(self, root):
        super().__init__(root)
        self.root = root
        self.file_list = sorted(f for f in os.listdir(root) if f.endswith('.pt'))

    # PyG's Dataset API is len()/get(); providing them keeps PyG helpers that call them
    # directly consistent with len(ds) / ds[i] below.
    def len(self):
        return len(self.file_list)

    def get(self, idx):
        file_path = os.path.join(self.root, self.file_list[idx])
        # The samples are graph objects (later cells call .metadata(), .node_types, .y on them),
        # not plain tensors, so they need the full unpickler: torch>=2.6 defaults to
        # weights_only=True and would refuse them. Only point this at trusted data, since
        # unpickling can execute arbitrary code.
        return torch.load(file_path, weights_only=False)

    def __len__(self):
        return self.len()

    def __getitem__(self, idx):
        return self.get(idx)

The following cell builds the domain-generalization (DG) splits: each load condition's dataset is split 70/15/15 into train/validation/test subsets with a fixed seed.

In [None]:
from torch.utils.data import random_split
# `torch_geometric.data.DataLoader` is a deprecated alias; the loader lives in `torch_geometric.loader`.
from torch_geometric.loader import DataLoader

# Modify to your paths. Raw strings keep the Windows separators literal (in a normal string,
# "\X" only survives because it is an unrecognised escape, which newer Pythons warn about).
dataset1 = HeteroGraphDataset(r"G:\XXXX\XXXXX\XXXXXX")
dataset2 = HeteroGraphDataset(r"G:\XXXX\XXXXX\XXXXXX2")
dataset3 = HeteroGraphDataset(r"G:\XXXX\XXXXX\XXXXXX3")


def _dg_split(dataset, batch_size=8, seed=42):
    """Split `dataset` 70/15/15 into train/val/test and wrap each part in a DataLoader.

    Returns ``(lengths, subsets, loaders)``, each a 3-tuple in train/val/test order. A fresh
    generator seeded with `seed` is used per call, so every domain is split reproducibly.
    Only the training loader shuffles.
    """
    n = len(dataset)
    train_len = int(0.7 * n)
    val_len = int(0.15 * n)
    lengths = (train_len, val_len, n - train_len - val_len)
    subsets = tuple(random_split(dataset, list(lengths), generator=torch.Generator().manual_seed(seed)))
    loaders = (
        DataLoader(subsets[0], batch_size=batch_size, shuffle=True),
        DataLoader(subsets[1], batch_size=batch_size),
        DataLoader(subsets[2], batch_size=batch_size),
    )
    return lengths, subsets, loaders


((train_len1, val_len1, test_len1),
 (train_set1, val_set1, test_set1),
 (train_loader1, val_loader1, test_loader1)) = _dg_split(dataset1)

((train_len2, val_len2, test_len2),
 (train_set2, val_set2, test_set2),
 (train_loader2, val_loader2, test_loader2)) = _dg_split(dataset2)

((train_len3, val_len3, test_len3),
 (train_set3, val_set3, test_set3),
 (train_loader3, val_loader3, test_loader3)) = _dg_split(dataset3)


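The following cell builds the domain-adaptation (DA) splits: the same 70/15/15 partition as above, with samples filtered by the source/target label sets of the selected task.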

In [None]:
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader 
# Load the datasets. Raw strings keep the Windows separators literal (in a normal string,
# "\X" only survives because it is an unrecognised escape, which newer Pythons warn about).
dataset1 = HeteroGraphDataset(r"G:\XXXX\XXXXX\XXXXXX")  # Source domain
dataset2 = HeteroGraphDataset(r"G:\XXXX\XXXXX\XXXXXX2")  # Target domain
dataset3 = HeteroGraphDataset(r"G:\XXXX\XXXXX\XXXXXX3")  # Target domain

# Define the source-domain and target-domain label sets for each transfer task.
# 'source': classes kept in the training/validation data; 'target': classes kept in the test
# data (all 16 classes, 0-15, for every task below).
# 'intersection_ratio' is not read by any of the visible cells.
# NOTE(review): for T1B-T3B and T9B-T11B the stored ratio equals len(source) / len(target),
# but for T4B-T8B it does not (e.g. T4B lists 13 source classes -> 13/16 = 0.8125, stored
# 0.6875; T8B lists 5 -> 0.3125, stored 0.125). Either those ratios or those source lists look
# wrong — confirm against the experiment design before reporting them.
label_mapping = {
    'T0B': {
        'source': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        'target': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        'intersection_ratio': 1
    },
    'T1B': {
        'source': [0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        'target': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        'intersection_ratio': 0.875
    },
    'T2B': {
        'source': [0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        'target': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        'intersection_ratio': 0.8125
    },
    'T3B': {
        'source': [0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15],
        'target': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,15],
        'intersection_ratio': 0.8125
    },
    'T4B': {
        'source': [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15],
        'target': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        'intersection_ratio': 0.6875
    },
    'T5B': {
        'source': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15],
        'target': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        'intersection_ratio': 0.6875
    },
    'T6B': {
        'source': [0, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        'target': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        'intersection_ratio': 0.5
    },
    'T7B': {
        'source': [0, 9, 10, 11, 12, 13, 14, 15],
        'target': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        'intersection_ratio': 0.3125
    },
    'T8B': {
        'source': [0, 12, 13, 14, 15],
        'target': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        'intersection_ratio': 0.125
    },
    'T9B': {
        'source': [0, 5],
        'target': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        'intersection_ratio': 0.125
    },
    'T10B': {
        'source': [0, 11],
        'target': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        'intersection_ratio': 0.125
    },
    'T11B': {
        'source': [0, 15],
        'target': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        'intersection_ratio': 0.125
    }
}
# Accumulators for the filtered source-domain (train/val) and target-domain (test) graphs;
# they are filled by the label-filtering loops further down.
source_domain_data_train = []
source_domain_data_val = []
target_domain_data_test = []

# ====== Select the label-filtering task (any key of label_mapping, e.g. 'T1B'; currently 'T11B') ======
task_name = 'T11B'
label_filter = label_mapping[task_name]
source_labels = label_filter['source']
target_labels = label_filter['target']

# ====== Data Set Partitioning Ratio Settings ======
def get_split_lengths(dataset, train_ratio=0.7, val_ratio=0.15):
    """Return ``(train_len, val_len, test_len)`` for a ratio-based split of `dataset`.

    ``train_len = int(train_ratio * n)`` and ``val_len = int(val_ratio * n)`` (both floored);
    the test split takes the remainder, so the three lengths always sum to ``len(dataset)``.
    The defaults give the 70/15/15 split used throughout this notebook.
    """
    n = len(dataset)
    train_len = int(train_ratio * n)
    val_len = int(val_ratio * n)
    return train_len, val_len, n - train_len - val_len

train_len1, val_len1, test_len1 = get_split_lengths(dataset1)
train_len2, val_len2, test_len2 = get_split_lengths(dataset2)
train_len3, val_len3, test_len3 = get_split_lengths(dataset3)

# ====== Dataset Partitioning (same lengths and seed as the DG cell, so the index splits match) ======
train_set1, val_set1, test_set1 = random_split(dataset1, [train_len1, val_len1, test_len1], generator=torch.Generator().manual_seed(42))
train_set2, val_set2, test_set2 = random_split(dataset2, [train_len2, val_len2, test_len2], generator=torch.Generator().manual_seed(42))
train_set3, val_set3, test_set3 = random_split(dataset3, [train_len3, val_len3, test_len3], generator=torch.Generator().manual_seed(42))


def _filter_by_label(dataset, indices, allowed_labels):
    """Return the graphs ``dataset[i]`` (i in `indices`) whose scalar label ``y`` is in `allowed_labels`.

    Each graph is loaded from disk once; the previous comprehension called ``dataset[idx]``
    twice per sample (once for the test, once for the kept value).
    """
    allowed = set(allowed_labels)
    selected = []
    for idx in indices:
        graph = dataset[idx]
        if graph.y.item() in allowed:
            selected.append(graph)
    return selected


# ====== Label filtering: keep graphs by their class label y only ======
# NOTE(review): dataset2/dataset3 are commented as target domains, yet their train/val splits
# also feed the source loaders and dataset1's test split feeds the target loader. Confirm
# this is the intended DA protocol.
for dataset, subset in zip([dataset1, dataset2, dataset3], [train_set1, train_set2, train_set3]):
    source_domain_data_train.extend(_filter_by_label(dataset, subset.indices, source_labels))

for dataset, subset in zip([dataset1, dataset2, dataset3], [val_set1, val_set2, val_set3]):
    source_domain_data_val.extend(_filter_by_label(dataset, subset.indices, source_labels))

# Target domain: test splits of all three loads, restricted to target_labels.
for dataset, subset in zip([dataset1, dataset2, dataset3], [test_set1, test_set2, test_set3]):
    target_domain_data_test.extend(_filter_by_label(dataset, subset.indices, target_labels))

# ====== Build DataLoaders (mini-batches of graphs) ======
train_loader_source = DataLoader(source_domain_data_train, batch_size=8, shuffle=True, drop_last=True)
# Evaluation loaders keep the final partial batch: drop_last=True silently excluded up to
# 7 validation/test graphs from every evaluation.
val_loader_source   = DataLoader(source_domain_data_val, batch_size=8)
test_loader_target  = DataLoader(target_domain_data_test, batch_size=8)

# ====== Print Information Confirmation ======
print(f"Current Task: {task_name}")
print(f"Tag filtering is enabled: Source={source_labels} | Target={target_labels}")
print(f"Source Domain Train Size: {len(source_domain_data_train)}")
print(f"Source Domain Validation Size: {len(source_domain_data_val)}")
print(f"Target Domain Test Size: {len(target_domain_data_test)}")

Visualization of the constructed heterogeneous graph (first sample of dataset 1)

In [None]:
import torch
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from torch_geometric.data import HeteroData

sample = dataset1[0]

G = nx.Graph()
node_colors = {}
node_labels = {}

# Color and Legend Name
color_palette = {
    'Displacement': '#1f77b4',
    'Acceleration': '#ff7f0e',
    'Acoustic': '#2ca02c',
    'Telemetric stress': '#d62728',
    'Patch stress': '#9467bd',
    'Supernode': '#8c564b',
    'bottom_layer': '#e377c2',
    'global_supernode': '#8c564b'
}

edge_color_palette = {
    'Frequency edges': '#17becf',
    'Frequency-Hop edges': '#bcbd22',
    'Modality edges': '#e377c2',
    'Modality supernode edges': '#d8b79a',
    'Space edges': '#7f7f7f'
}

# Map raw node-type / relation names to legend names
node_type_mapping = {
    'displacement': 'Displacement',
    'acceleration': 'Acceleration',
    'sound': 'Acoustic',
    'blade_stress': 'Telemetric stress',
    'casing_stress': 'Patch stress',
}
edge_type_mapping = {
    'freq_edge': 'Frequency edges',
    'hop_edge': 'Frequency-Hop edges',
    'modality_edge': 'Modality edges',
    'supernode_edge': 'Modality supernode edges',
    'inter_supernode_edge': 'Space edges'
}

super_nodes_set = set()
modality_supernodes = []

# Add nodes.
# NOTE(review): the last node of every node type is assumed to be that modality's supernode;
# confirm against the graph-construction code.
for node_type in sample.node_types:
    num_nodes = sample[node_type].x.size(0)
    mapped_type = node_type_mapping.get(node_type, node_type)
    for i in range(num_nodes):
        node_id = f"{node_type}_{i}"
        is_super = (i == num_nodes - 1)
        # Node types missing from the palette are drawn gray instead of raising KeyError.
        node_colors[node_id] = color_palette['Supernode'] if is_super else color_palette.get(mapped_type, 'gray')
        node_labels[node_id] = f"{mapped_type}_super" if is_super else mapped_type
        G.add_node(node_id)
        if is_super:
            super_nodes_set.add(node_id)
            modality_supernodes.append(node_id)

# Add edges. Each edge stores its own colour: nx.Graph merges duplicate and reversed edges and
# iterates edges in adjacency order, so a colour list appended per input edge (as before) had
# surplus entries and drifted out of alignment with G.edges().
for edge_type in sample.edge_types:
    src_type, rel_type, dst_type = edge_type
    edge_index = sample[edge_type].edge_index
    color = edge_color_palette.get(edge_type_mapping.get(rel_type, rel_type), 'black')
    for s, d in zip(edge_index[0].tolist(), edge_index[1].tolist()):
        G.add_edge(f"{src_type}_{s}", f"{dst_type}_{d}", rel_type=rel_type, color=color)

# Fully connect the modality supernodes ("Space edges").
for i, u in enumerate(modality_supernodes):
    for v in modality_supernodes[i + 1:]:
        G.add_edge(u, v, rel_type="inter_supernode_edge", color=edge_color_palette['Space edges'])

# Read the colours back in exactly the order draw_networkx_edges iterates G's edges.
edge_colors = [attr['color'] for _, _, attr in G.edges(data=True)]

plt.figure(figsize=(6, 4))
pos = nx.spring_layout(G, seed=42)

node_color_list = [node_colors[n] for n in G.nodes()]
nx.draw_networkx_nodes(G, pos, node_color=node_color_list, node_size=40)
nx.draw_networkx_edges(G, pos, edge_color=edge_colors, width=0.6)

# Redraw supernode-related edges on top so they stay visible.
supernode_edges = [(u, v) for u, v, attr in G.edges(data=True) if attr.get('rel_type') == 'supernode_edge']
inter_supernode_edges = [(u, v) for u, v, attr in G.edges(data=True) if attr.get('rel_type') == 'inter_supernode_edge']

nx.draw_networkx_edges(G, pos, edgelist=supernode_edges, edge_color=edge_color_palette['Modality supernode edges'], width=0.6)
nx.draw_networkx_edges(G, pos, edgelist=inter_supernode_edges, edge_color=edge_color_palette['Space edges'], width=0.6)

filtered_node_legend = [
    mlines.Line2D([], [], color='white', marker='o', markerfacecolor=c, markersize=6, label=nt)
    for nt, c in color_palette.items()
    if nt not in ['bottom_layer', 'global_supernode']
]
edge_legend = [
    mlines.Line2D([], [], color=c, label=et, linewidth=1.2)
    for et, c in edge_color_palette.items()
]

plt.legend(
    handles=filtered_node_legend + edge_legend,
    loc='upper center',
    bbox_to_anchor=(0.5, 0.07),
    ncol=3,
    fontsize=14,
    frameon=False,
    handlelength=1.0,
    handletextpad=0.4,
    columnspacing=0.8
)

plt.axis('off')
plt.tight_layout(pad=0)
plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
plt.savefig("graph_5supernodes_interconnected.png", dpi=1800, bbox_inches='tight', pad_inches=0)
plt.savefig("graph_5supernodes_interconnected.jpg", dpi=1800, bbox_inches='tight', pad_inches=0)
plt.savefig("graph_5supernodes_interconnected.tiff", dpi=1800, bbox_inches='tight', pad_inches=0)
plt.show()


Model definition: augmentation helpers, attention modules, and the HGT-based classifier

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HGTConv, global_mean_pool

# DropEdge
def drop_edge(edge_index, drop_prob=0.2, min_keep_ratio=0.7):
    """Randomly keep a subset of the columns (edges) of `edge_index`.

    Args:
        edge_index: (2, num_edges) tensor of edges.
        drop_prob: requested fraction of edges to drop.
        min_keep_ratio: lower bound on the kept fraction. The effective keep ratio is
            ``max(1 - drop_prob, min_keep_ratio)``; the default 0.7 reproduces the original
            hard-coded cap of dropping at most 30% of the edges.
    Returns:
        A (2, int(num_edges * keep_ratio)) tensor with a random subset of the edges in random
        order. An input with no edges is returned unchanged.
    """
    num_edges = edge_index.size(1)
    if num_edges == 0:
        return edge_index

    keep_ratio = max(1.0 - drop_prob, min_keep_ratio)
    num_keep = int(num_edges * keep_ratio)

    perm = torch.randperm(num_edges, device=edge_index.device)[:num_keep]
    return edge_index[:, perm]

# DropFeature
def drop_feature(x, drop_prob=0.2, attention_weight=None):
    """Zero individual entries of `x` at random (survivors are not rescaled).

    Without `attention_weight`, each entry is dropped with probability `drop_prob`.
    With it (one score per row of `x`), row i uses the lower drop rate
    ``drop_prob * (1 - sigmoid(attention_weight[i]))``, so higher-scored rows are kept more often.
    """
    if attention_weight is None:
        threshold = drop_prob
    else:
        # Keep high-weight rows: the per-row drop rate shrinks as the score grows.
        row_score = torch.sigmoid(attention_weight).unsqueeze(1).expand_as(x)
        threshold = drop_prob * (1 - row_score)
    keep = torch.rand_like(x) > threshold
    return x * keep.float()

# Modal Attention
class ModalityAttention(nn.Module):
    """Gated re-weighting of node features by a learnable per-modality weight.

    ``forward(x, modality_ids)`` returns the pair
    ``(proj(x) * softmax(w)[modality_ids][:, None] * sigmoid(gate_proj(x)), softmax(w))``
    where ``w`` holds one learnable score per modality.
    """

    def __init__(self, hidden_dim, num_modalities):
        super().__init__()
        # Creation order (att_weight, proj, gate_proj) fixes both the state_dict keys and
        # the RNG draws at initialisation; keep it.
        self.att_weight = nn.Parameter(torch.randn(num_modalities))
        self.proj = nn.Linear(hidden_dim, hidden_dim)
        self.gate_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, modality_ids):
        modality_probs = F.softmax(self.att_weight, dim=0)
        gate = torch.sigmoid(self.gate_proj(x))
        per_node_weight = modality_probs[modality_ids].unsqueeze(1)
        gated = self.proj(x) * per_node_weight * gate
        return gated, modality_probs

# Frequency Attention
class FrequencyAttention(nn.Module):
    """Gated re-weighting of node features by a learnable per-frequency-band weight.

    ``forward(x, freq_ids)`` returns the pair
    ``(proj(x) * softmax(w)[freq_ids][:, None] * sigmoid(gate_proj(x)), softmax(w))``
    where ``w`` holds one learnable score per frequency band.
    """

    def __init__(self, hidden_dim, num_freqs):
        super().__init__()
        # Registration order determines the state_dict keys and initialisation RNG draws.
        self.att_weight = nn.Parameter(torch.randn(num_freqs))
        self.proj = nn.Linear(hidden_dim, hidden_dim)
        self.gate_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, freq_ids):
        band_probs = F.softmax(self.att_weight, dim=0)
        scaled = self.proj(x) * band_probs[freq_ids].unsqueeze(1)
        return scaled * torch.sigmoid(self.gate_proj(x)), band_probs

# Model Structure
class HGTWithContrastive(nn.Module):
    """Two-layer HGT encoder with GRU layer fusion, modality/frequency attention and an MLP head.

    Args:
        input_dim: node-feature width; every node type gets its own Linear(input_dim, hidden_dim)
            followed by BatchNorm1d.
        hidden_dim: width of the embeddings, both HGTConv layers, the GRU and the classifier.
        output_dim: number of classes.
        metadata: ``(node_types, edge_types)`` as returned by ``HeteroData.metadata()``.
        num_modalities / num_freqs: sizes of the learnable modality / frequency weight vectors;
            the ``modality_ids`` / ``freq_ids`` passed to ``forward`` index into them.
        heads: attention heads of each HGTConv.
        dropout: dropout inside the classifier head.
        dropedge_prob / dropfeat_prob: DropEdge / DropFeature rates used when ``training=True``.
        dropedge_relations: relation names (middle element of an edge type) that DropEdge is
            applied to before the second HGT layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, metadata,
                 num_modalities, num_freqs, heads=4, dropout=0.3,
                 dropedge_prob=0.2, dropfeat_prob=0.2,
                 dropedge_relations=('modality', 'supmod')):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.dropedge_prob = dropedge_prob
        self.dropfeat_prob = dropfeat_prob
        # NOTE(review): the visualization cell names relations 'freq_edge', 'modality_edge',
        # 'supernode_edge', ...; if those are the real names, nothing matches the defaults and
        # DropEdge never fires. Check against dataset[0].edge_types.
        self.dropedge_relations = tuple(dropedge_relations)
        self.metadata = metadata

        # Module creation order is kept: it fixes state_dict keys and initialisation RNG draws.
        self.embeddings = nn.ModuleDict({
            ntype: nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim)
            ) for ntype in metadata[0]
        })

        self.hgt1 = HGTConv(hidden_dim, hidden_dim, metadata=metadata, heads=heads)
        self.hgt2 = HGTConv(hidden_dim, hidden_dim, metadata=metadata, heads=heads)

        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.mod_att = ModalityAttention(hidden_dim, num_modalities)
        self.freq_att = FrequencyAttention(hidden_dim, num_freqs)
        self.pool = global_mean_pool

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim)
        )

    def _embed(self, x_dict, modality_ids, freq_ids, training):
        """Embed every node type once and, if `training`, apply attention-guided DropFeature.

        The previous version ran each embedding twice per forward pass (once only to obtain
        the attention weights), which updated the BatchNorm running statistics twice per batch.
        """
        x_embed = {ntype: self.embeddings[ntype](x) for ntype, x in x_dict.items()}
        if not training:
            return x_embed
        # Global modality / frequency weights (softmax of the learnable scores). They do not
        # depend on node features; detached so DropFeature does not back-propagate into them.
        alpha = F.softmax(self.mod_att.att_weight, dim=0).detach()
        gamma = F.softmax(self.freq_att.att_weight, dim=0).detach()
        return {
            ntype: drop_feature(
                h,
                self.dropfeat_prob,
                attention_weight=(alpha[modality_ids[ntype]] + gamma[freq_ids[ntype]]) / 2
            )
            for ntype, h in x_embed.items()
        }

    def _drop_edges(self, edge_index_dict):
        """Apply DropEdge to the relations in ``self.dropedge_relations``; keep other edge types."""
        return {
            k: drop_edge(v, self.dropedge_prob) if k[1] in self.dropedge_relations else v
            for k, v in edge_index_dict.items()
        }

    def _fuse_layers(self, x1, x2):
        """Per node type, run the GRU over [HGT1, HGT2] outputs and keep the last hidden state."""
        return {
            ntype: self.gru(torch.stack([x1[ntype], x2[ntype]], dim=1))[0][:, -1, :]
            for ntype in x1
        }

    def _attend(self, x_fused, batch_dict, modality_ids, freq_ids):
        """Concatenate all node types and add the modality and frequency attention residuals.

        Returns:
            (attended node features, matching concatenated graph-assignment vector).
        """
        ntypes = list(x_fused)
        x_all = torch.cat([x_fused[t] for t in ntypes], dim=0)
        all_mod = torch.cat([modality_ids[t] for t in ntypes], dim=0)
        all_freq = torch.cat([freq_ids[t] for t in ntypes], dim=0)
        all_batch = torch.cat([batch_dict[t] for t in ntypes], dim=0)
        x_all = x_all + self.mod_att(x_all, all_mod)[0] + self.freq_att(x_all, all_freq)[0]
        return x_all, all_batch

    def forward(self, x_dict, edge_index_dict, batch_dict, modality_ids, freq_ids, training=True):
        """Classify a batch of heterogeneous graphs.

        Args:
            x_dict: node features per node type.
            edge_index_dict: edge indices per (src, relation, dst) edge type.
            batch_dict: per node type, the graph index of every node (used for pooling).
            modality_ids / freq_ids: per node type, integer ids into the modality / frequency
                weight vectors.
            training: enables DropFeature (before HGT1) and DropEdge (before HGT2). This is
                independent of ``nn.Module.train()`` / ``eval()``.
        Returns:
            Logits of shape (num_graphs, output_dim).
        """
        x_embed = self._embed(x_dict, modality_ids, freq_ids, training)
        x1 = self.hgt1(x_embed, edge_index_dict)
        edges_for_hgt2 = self._drop_edges(edge_index_dict) if training else edge_index_dict
        x2 = self.hgt2(x1, edges_for_hgt2)

        x_all, all_batch = self._attend(self._fuse_layers(x1, x2), batch_dict, modality_ids, freq_ids)
        x_pool = self.pool(x_all, all_batch)
        return self.classifier(x_pool)

    def extract_feature(self, x_dict, edge_index_dict, batch_dict, modality_ids, freq_ids, training=True):
        """Alias of ``forward``: note that this returns classifier logits, not pooled embeddings."""
        return self.forward(x_dict, edge_index_dict, batch_dict, modality_ids, freq_ids, training=training)

    def extract_multi_stage_features(self, x_dict, edge_index_dict, batch_dict, modality_ids, freq_ids, training=False):
        """Return intermediate representations keyed by stage name, without augmentation.

        Keys: "HGT1", "HGT2", "GRU", "Attention" (node level, node types concatenated),
        "Pool" (graph level) and "Classifier" (output of the classifier's first Linear,
        before ReLU). `training` is accepted for signature compatibility and ignored.
        """
        stage_outputs = {}

        embed_dict = {ntype: self.embeddings[ntype](x_dict[ntype]) for ntype in x_dict}

        x1 = self.hgt1(embed_dict, edge_index_dict)
        stage_outputs["HGT1"] = torch.cat(list(x1.values()), dim=0)

        x2 = self.hgt2(x1, edge_index_dict)
        stage_outputs["HGT2"] = torch.cat(list(x2.values()), dim=0)

        x_gru = self._fuse_layers(x1, x2)
        stage_outputs["GRU"] = torch.cat(list(x_gru.values()), dim=0)

        x_all_att, all_batch = self._attend(x_gru, batch_dict, modality_ids, freq_ids)
        stage_outputs["Attention"] = x_all_att

        x_pool = self.pool(x_all_att, all_batch)
        stage_outputs["Pool"] = x_pool

        stage_outputs["Classifier"] = self.classifier[0](x_pool)

        return stage_outputs


In [None]:
import copy
import torch_geometric

def graph_augment(data, drop_edge_prob=0.2, drop_feat_prob=0.2):
    """Return a randomly augmented deep copy of a heterogeneous graph (or batch).

    Every edge of every edge type is kept independently with probability
    ``1 - drop_edge_prob``, and every node-feature entry is zeroed with probability
    ``drop_feat_prob``. The input object is not modified.

    In PyG, ``data.edge_index_dict`` / ``data.x_dict`` are collected into a fresh dict on each
    access, so writing into them (as an earlier version did) discarded the augmentation. The
    results are therefore stored back through ``data[edge_type]`` / ``data[node_type]``.
    NOTE(review): per-edge attributes (e.g. ``edge_attr``), if present, are not filtered
    alongside ``edge_index``.
    """
    data = copy.deepcopy(data)

    # DropEdge
    for edge_type, edge_index in list(data.edge_index_dict.items()):
        keep = torch.rand(edge_index.size(1), device=edge_index.device) > drop_edge_prob
        data[edge_type].edge_index = edge_index[:, keep]

    # DropFeature
    for node_type, x in list(data.x_dict.items()):
        keep = torch.rand_like(x) > drop_feat_prob
        data[node_type].x = x * keep.float()

    return data


In [None]:
import torch.nn.functional as F

def contrastive_loss(z1, z2, temperature=0.5):
    """NT-Xent (SimCLR) contrastive loss between two augmented views.

    Args:
        z1, z2: (batch_size, dim) embeddings of the same graphs under two augmentations;
            row i of z1 and row i of z2 form the positive pair.
        temperature: softmax temperature.
    Returns:
        Scalar loss averaged over all 2 * batch_size anchors.
    """
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    batch_size = z1.size(0)
    representations = torch.cat([z1, z2], dim=0)  # (2B, D)
    # Rows are unit-norm, so the dot product equals the cosine similarity.
    similarity_matrix = representations @ representations.T

    # Exclude each sample's similarity with itself from the softmax.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z1.device)
    similarity_matrix = similarity_matrix.masked_fill(self_mask, -9e15)

    # The positive of anchor i (view 1) is i + B (view 2), and vice versa. The previous targets
    # (arange(B) twice) pointed at the masked self-similarity, making the loss ~1e16.
    idx = torch.arange(batch_size, device=z1.device)
    labels = torch.cat([idx + batch_size, idx], dim=0)

    return F.cross_entropy(similarity_matrix / temperature, labels)


In [None]:
# Definition of the MMD Loss Function
def compute_mmd(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Multi-kernel Gaussian Maximum Mean Discrepancy (biased MMD^2 estimate).

    Args:
        source: (n_s, d) features; target: (n_t, d) features. Both must be non-empty;
            n_s and n_t may differ.
        kernel_mul: ratio between consecutive kernel bandwidths.
        kernel_num: number of Gaussian kernels, centred (geometrically) on the base bandwidth.
        fix_sigma: fixed base bandwidth. If falsy (None or 0), the base bandwidth is the mean
            squared distance over all off-diagonal pairs of the pooled sample.
    Returns:
        Scalar tensor: mean(K_ss) + mean(K_tt) - mean(K_st) - mean(K_ts).
    """
    n_source = source.size(0)
    total = torch.cat([source, target], dim=0)
    n_total = total.size(0)
    l2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)

    if fix_sigma:
        bandwidth = fix_sigma
    else:
        # Average over the n_total * (n_total - 1) off-diagonal pairs of the *pooled* sample.
        # The previous denominator used the source size only (and was 0 for one source row).
        bandwidth = l2_distance.detach().sum() / max(n_total * n_total - n_total, 1)
        bandwidth = torch.clamp(bandwidth, min=1e-12)  # all points identical -> avoid 0/0
    bandwidth = bandwidth / (kernel_mul ** (kernel_num // 2))
    kernels = sum(torch.exp(-l2_distance / (bandwidth * kernel_mul ** i)) for i in range(kernel_num))

    # Block means instead of mean(XX + YY - XY - YX): identical for equal sizes, and valid
    # when source and target differ in size (the old element-wise sum could not broadcast).
    xx = kernels[:n_source, :n_source].mean()
    yy = kernels[n_source:, n_source:].mean()
    xy = kernels[:n_source, n_source:].mean()
    yx = kernels[n_source:, :n_source].mean()
    return xx + yy - xy - yx


In [None]:
# Definition of the SupCon Contrast Loss Function
def supcon_loss(features, labels, temperature=0.07):
    """Supervised contrastive loss (Khosla et al., 2020, "L_out" form).

    Args:
        features: (N, D) embeddings; L2-normalised internally.
        labels: (N,) integer labels; samples sharing a label are positives of each other.
        temperature: softmax temperature.
    Returns:
        Scalar loss averaged over anchors that have at least one positive (0 if none do).
    """
    device = features.device
    n = features.size(0)
    labels = labels.contiguous().view(-1, 1)
    # Self-pairs are neither positives nor part of the denominator. The previous version
    # kept the diagonal in the positive mask, rewarding each anchor for matching itself.
    not_self = 1.0 - torch.eye(n, device=device)
    positive_mask = torch.eq(labels, labels.T).float().to(device) * not_self
    positives_per_anchor = positive_mask.sum(1)
    has_positive = positives_per_anchor > 0
    if not bool(has_positive.any()):
        return features.sum() * 0.0  # keeps dtype/device and the autograd graph

    features = F.normalize(features, dim=1)
    logits = torch.matmul(features, features.T) / temperature
    logits = logits - logits.max(dim=1, keepdim=True).values.detach()  # numerical stability
    exp_logits = torch.exp(logits) * not_self
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)

    sum_log_prob_pos = (positive_mask * log_prob).sum(1)
    mean_log_prob_pos = sum_log_prob_pos[has_positive] / positives_per_anchor[has_positive]
    return -mean_log_prob_pos.mean()


In [None]:
# Training loop for one epoch: supervised classification + self-supervised contrastive term
# (+ MMD in DA mode when the batch carries domain ids).
def _forward_inputs(batch):
    """Keyword arguments for HGTWithContrastive.forward / extract_feature built from a batch."""
    return dict(
        x_dict=batch.x_dict,
        edge_index_dict=batch.edge_index_dict,
        batch_dict=batch.batch_dict,
        modality_ids={k: batch[k].modality_ids for k in batch.node_types},
        freq_ids={k: batch[k].freq_ids for k in batch.node_types},
    )


def train_one_epoch(model, loader, optimizer, criterion_cls,
                    contrastive_weight=0.5, mmd_weight=0.0,
                    mode='DG', use_supcon=False):
    """Run one optimisation pass over `loader`.

    Per batch the loss is
        criterion_cls(logits, y) + contrastive_weight * L_cl  [+ mmd_weight * MMD]
    where L_cl is NT-Xent (or SupCon if `use_supcon`) between two ``graph_augment`` views.
    The MMD term is added only when ``mode == "DA"`` and the batch has a ``domain_ids``
    attribute; it compares graphs with domain id 0 against all other graphs.
    Note: ``extract_feature`` returns classifier logits, so L_cl and MMD act on logits.

    Returns:
        (mean loss per batch, training accuracy); (0.0, 0.0) for an empty loader.
    """
    model.train()
    device = next(model.parameters()).device
    total_loss, correct, total, num_batches = 0.0, 0, 0, 0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # Primary task output
        logits = model(**_forward_inputs(batch), training=True)

        # Two augmented views for the contrastive term
        aug1 = graph_augment(batch, drop_edge_prob=0.2, drop_feat_prob=0.2)
        aug2 = graph_augment(batch, drop_edge_prob=0.2, drop_feat_prob=0.2)
        z1 = model.extract_feature(**_forward_inputs(aug1), training=True)
        z2 = model.extract_feature(**_forward_inputs(aug2), training=True)

        if use_supcon:
            # z1 and z2 are concatenated, so repeat the labels to match the merged rows.
            labels = batch.y.repeat(2)
            loss_cl = supcon_loss(torch.cat([z1, z2], dim=0), labels)
        else:
            loss_cl = contrastive_loss(z1, z2)

        loss = criterion_cls(logits, batch.y) + contrastive_weight * loss_cl

        if mode == "DA" and hasattr(batch, "domain_ids"):
            z_feat = model.extract_feature(**_forward_inputs(batch), training=False)
            # z_feat has one row per graph, so it needs one domain id per graph. The previous
            # code concatenated per-node-type `batch[k].domain_ids` (one id per node), which
            # cannot index z_feat.
            # NOTE(review): assumes domain_ids is stored at graph level when samples are built.
            domain_ids = batch.domain_ids.view(-1)
            if domain_ids.numel() != z_feat.size(0):
                raise ValueError(
                    f"expected one domain id per graph ({z_feat.size(0)}), got {domain_ids.numel()}"
                )
            source = z_feat[domain_ids == 0]
            target = z_feat[domain_ids != 0]
            if len(source) > 0 and len(target) > 0:
                loss = loss + mmd_weight * compute_mmd(source, target)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        pred = logits.argmax(dim=1)
        correct += (pred == batch.y).sum().item()
        total += batch.y.size(0)

    if num_batches == 0:
        return 0.0, 0.0
    return total_loss / num_batches, correct / total




# Validation function (main classification task only, no augmentation)
@torch.no_grad()
def evaluate(model, loader, criterion):
    """Return ``(mean loss per sample, accuracy)`` of `model` on `loader`.

    Batch losses are weighted by batch size, so a smaller final batch counts in proportion to
    its samples. This assumes `criterion` averages over the batch, as
    ``nn.CrossEntropyLoss()`` does by default. Both values are 0.0 for an empty loader.
    """
    model.eval()
    device = model.classifier[0].weight.device
    total_loss, correct, total = 0.0, 0, 0
    for batch in loader:
        batch = batch.to(device)
        logits = model(
            x_dict=batch.x_dict,
            edge_index_dict=batch.edge_index_dict,
            batch_dict=batch.batch_dict,
            modality_ids={k: batch[k].modality_ids for k in batch.node_types},
            freq_ids={k: batch[k].freq_ids for k in batch.node_types},
            training=False
        )
        batch_size = batch.y.size(0)
        total_loss += criterion(logits, batch.y).item() * batch_size
        correct += (logits.argmax(dim=1) == batch.y).sum().item()
        total += batch_size
    if total == 0:
        return 0.0, 0.0
    return total_loss / total, correct / total


In [None]:

def train_model(model, train_loader, val_loader, num_epochs=50, lr=1e-3,
                contrastive_weight=0.5, patience=20,
                mode='DG',              # Mode switching ("DG" / "DA")
                use_supcon=False,       # Use SupCon instead of NT-Xent for the contrastive term
                mmd_weight=1.0,         # MMD loss weight (effective only in DA mode)
                save_path="your_model_new.pth",
                max_consecutive_failures=3,
                restore_best=True):
    """Train `model` with LR-on-plateau scheduling, best-checkpointing and early stopping.

    Each epoch runs ``train_one_epoch`` followed by ``evaluate`` on `val_loader`. Validation
    accuracy drives ``ReduceLROnPlateau`` (LR halved after 5 epochs without improvement).
    Every improvement writes ``model.state_dict()`` to `save_path`, and training stops after
    `patience` epochs without improvement.

    Args:
        max_consecutive_failures: an epoch whose training raises is reported and skipped, but
            this many failures in a row re-raise the error. Previously every failure was
            swallowed, so a systematic bug produced an "empty" run with no checkpoint.
        restore_best: load the best validation weights back into `model` before returning;
            otherwise the weights of the last epoch are returned.
    Returns:
        The same `model` object.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    # `verbose=` is deprecated for PyTorch LR schedulers; LR reductions are printed below instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)
    criterion_cls = nn.CrossEntropyLoss()
    best_val_acc = 0
    best_state = None
    patience_counter = 0
    consecutive_failures = 0

    for epoch in range(1, num_epochs + 1):
        try:
            train_loss, train_acc = train_one_epoch(
                model,
                train_loader,
                optimizer,
                criterion_cls,
                contrastive_weight=contrastive_weight,
                mmd_weight=mmd_weight,
                mode=mode,
                use_supcon=use_supcon
            )
        except Exception as e:
            consecutive_failures += 1
            print(f"[Epoch {epoch:02d}] ⚠️ Training failed: {e}")
            if consecutive_failures >= max_consecutive_failures:
                raise
            continue
        consecutive_failures = 0

        val_loss, val_acc = evaluate(model, val_loader, criterion_cls)
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_acc)
        lr_after = optimizer.param_groups[0]['lr']

        print(f"[Epoch {epoch:02d}] ✅ Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
        if lr_after < lr_before:
            print(f"Learning rate reduced to {lr_after:.2e}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), save_path)
            # In-memory copy so the best weights can be restored without re-reading the file.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            print(f"📦 Best model saved at epoch {epoch} with Val Acc: {val_acc:.4f}")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"🛑 Early stopping at epoch {epoch}.")
                break

    print(f"🎯 Best Validation Accuracy: {best_val_acc:.4f}")
    if restore_best and best_state is not None:
        model.load_state_dict(best_state)
    return model



In [None]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🚀 Using device: {device}")

# input_dim must equal the node-feature width of every node type (one Linear(input_dim,
# hidden_dim) per node type); output_dim is the number of classes (labels 0-15).
# NOTE(review): every modality_ids / freq_ids value must be < num_modalities / num_freqs.
model = HGTWithContrastive(
    input_dim=28,
    hidden_dim=64,
    output_dim=16,
    metadata=dataset1[0].metadata(),
    num_modalities=5,
    num_freqs=7,
    heads=4,
    dropout=0.5,
    dropedge_prob=0.2,
    dropfeat_prob=0.2
).to(device)

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"📦 Model initialized with {total_params:,} trainable parameters.")



In [None]:
# Domain-generalization run: train on load 3's training split, with its validation split
# driving LR scheduling, checkpointing and early stopping. mmd_weight has no effect in 'DG' mode.
trained_model = train_model(
    model,
    train_loader=train_loader3,
    val_loader=val_loader3,
    mode='DG',
    use_supcon=False,
    num_epochs=50,
    patience=15,
    lr=1e-3,
    contrastive_weight=0.5,
    mmd_weight=1.0,
)



In [None]:
import torch
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from torch_geometric.loader import DataLoader  


def plot_confusion_matrix(cm, title, filename, dpi=1200, formats=("png", "jpg", "tiff"),
                          class_labels=None):
    """Draw a confusion matrix as an annotated heatmap, save it, and show it.

    Args:
        cm: 2-D integer array; rows are true labels, columns are predicted labels.
        title: accepted for callers but not drawn (the figure is exported untitled).
        filename: output path without extension; one file is written per entry of `formats`.
        dpi: resolution of the saved files.
        formats: file extensions to write.
        class_labels: tick labels for both axes; defaults to 0..n-1 along each axis.
    """
    xticks = class_labels if class_labels is not None else np.arange(cm.shape[1])
    yticks = class_labels if class_labels is not None else np.arange(cm.shape[0])
    # rc_context scopes the font change to this figure. Assigning plt.rcParams directly
    # (as before) switched the font of every later plot in the session.
    with plt.rc_context({'font.family': 'Times New Roman'}):
        fig, ax = plt.subplots(figsize=(3.5, 3))
        sns.heatmap(cm, annot=True, fmt='d', cmap='YlGnBu', cbar=False,
                    xticklabels=xticks,
                    yticklabels=yticks,
                    annot_kws={"fontsize": 9},  # cell-annotation font size
                    ax=ax)
        ax.set_xlabel("Predicted label", fontsize=10, fontname='Times New Roman')
        ax.set_ylabel("True label", fontsize=10, fontname='Times New Roman')
        fig.tight_layout()
        for ext in formats:
            fig.savefig(f"{filename}.{ext}", dpi=dpi, bbox_inches='tight')
        plt.show()

@torch.no_grad()
def test_model_with_cm(model, loader, title, filename, device='cpu', num_classes=None):
    """Evaluate `model` on `loader`, print accuracy, and plot/save the confusion matrix.

    Args:
        model: classifier called with the batched graph fields and ``training=False``.
        loader: iterable of batches with ``y`` labels.
        title: printed before the accuracy and passed to ``plot_confusion_matrix``.
        filename: output path (without extension) for the confusion-matrix figures.
        device: device the model and batches are moved to (the model stays there afterwards).
        num_classes: confusion-matrix size; defaults to the number of model outputs.
    Returns:
        (accuracy, confusion matrix of shape (num_classes, num_classes)).
    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.to(device).eval()
    correct, total = 0, 0
    y_true, y_pred = [], []

    for batch in loader:
        batch = batch.to(device)
        logits = model(
            x_dict=batch.x_dict,
            edge_index_dict=batch.edge_index_dict,
            batch_dict=batch.batch_dict,
            modality_ids={k: batch[k].modality_ids for k in batch.node_types},
            freq_ids={k: batch[k].freq_ids for k in batch.node_types},
            training=False
        )
        if num_classes is None:
            num_classes = logits.size(1)
        pred = logits.argmax(dim=1)
        correct += (pred == batch.y).sum().item()
        total += batch.y.size(0)

        y_true.extend(batch.y.detach().cpu().tolist())
        y_pred.extend(pred.detach().cpu().tolist())

    if total == 0:
        raise ValueError(f"{title}: loader yielded no samples")

    acc = correct / total
    print(f"{title} ✅ Test Accuracy: {acc:.4f}")

    # Pin the label set so row/column i is always class i. Without it, sklearn keeps only the
    # classes that occur, and the 0..k-1 tick labels of the plot would name the wrong classes
    # whenever a class is absent from this split.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
    plot_confusion_matrix(cm, title, filename)

    return acc, cm


In [None]:
import torch
from torch_geometric.loader import DataLoader  

# Restore the best checkpoint written by train_model. It is a plain state_dict of tensors,
# so the restricted weights_only loader is sufficient (and safe).
model.load_state_dict(torch.load("your_model_new.pth", map_location='cpu', weights_only=True))

# Loads 1 and 2 are evaluated in full; load 3 only on its held-out test split, because the
# training cell above trains on load 3's training split.
acc1, cm1 = test_model_with_cm(model, DataLoader(dataset1), "Test Load 1", "T1_confusion_load1")
acc2, cm2 = test_model_with_cm(model, DataLoader(dataset2), "Test Load 2", "T1_confusion_load2")
acc3, cm3 = test_model_with_cm(model, test_loader3, "Test Load 3", "T1_confusion_load3")
