项目 LD-S2Com 是一个用于红外图像语义分割任务的深度学习研究原型，具体目标是跨模态对比学习，即让模型从 RGB 图像与其语义标签之间学到共通特征表达，用于提升红外图像的理解能力。

项目主要实现两个模块：

编码器（Encoder）：负责将图像编码成特征向量（即模型能理解的形式）

判别器（Discriminator）：判断两张图像（RGB 和 Label）是否是成对的，通过这个机制来强化编码器学习语义内容。

    核心训练脚本，训练编码器和判别器，用于验证最小训练流程。

In [1]:
import os
os.chdir(r'C:\Users\jessi\Desktop\5th\experiment\LD-S2Com')

In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets.mfnet_dataset import MFNetDataset
from models.semantic_encoder import SemanticEncoder
from models.cross_modal_discriminator import CrossModalDiscriminator
import torch.optim as optim
import torch.nn as nn

from config import root, max_samples, img_size, latent_dim, batch_size, num_epochs, lr


In [3]:
def main():
    # 配置
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    #latent_dim = 512
    #batch_size = batch_size
    #num_epochs = num_epochs
    #batch_size = 2  # 更小的 batch
    #num_epochs = 1  # 更少的 epoch

    # 数据
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x)  # 1通道变3通道
    ])
    #train_dataset = MFNetDataset(root='path/to/MFNet', transform=transform)
    train_dataset = MFNetDataset(root=root, transform=transform, max_samples=max_samples)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # 模型
    encoder = SemanticEncoder(latent_dim=latent_dim).to(device)
    discriminator = CrossModalDiscriminator(latent_dim=latent_dim).to(device)

    # 优化器
    optimizer_enc = optim.Adam(encoder.parameters(), lr=lr)
    optimizer_disc = optim.Adam(discriminator.parameters(), lr= lr)
    criterion = nn.BCELoss()

    # 训练循环
    for epoch in range(num_epochs):
        for batch in train_loader:
            rgb = batch['rgb'].to(device)
            # 假设 label 也可编码为特征
            label = batch['label'].to(device)
            z_rgb = encoder(rgb)
            #z_label = encoder(label.float().repeat(1,3,1,1))  # 简单处理，实际应有更合适的label编码
            
            #label = batch['label'].to(device)
            if label.shape[1] == 1:
                label_input = label.repeat(1, 3, 1, 1)
            else:
                label_input = label
            z_label = encoder(label_input.float())

            # 判别器训练
            real = torch.ones(rgb.size(0), 1).to(device)
            fake = torch.zeros(rgb.size(0), 1).to(device)
            d_real = discriminator(z_rgb, z_label)
            d_fake = discriminator(z_rgb, z_rgb[torch.randperm(z_rgb.size(0))])
            loss_disc = criterion(d_real, real) + criterion(d_fake, fake)
            optimizer_disc.zero_grad()
            loss_disc.backward(retain_graph=True)
            optimizer_disc.step()

            # 编码器训练（对抗损失）
            d_real = discriminator(z_rgb, z_label)
            loss_enc = criterion(d_real, fake)
            optimizer_enc.zero_grad()
            loss_enc.backward()
            optimizer_enc.step()

        print(f"Epoch {epoch+1}: D_loss={loss_disc.item():.4f}, E_loss={loss_enc.item():.4f}")

if __name__ == "__main__":
    main()

Epoch 1: D_loss=1.4162, E_loss=0.6583
