# In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets.mfnet_dataset import MFNetDataset
from models.semantic_encoder import SemanticEncoder
from models.cross_modal_discriminator import CrossModalDiscriminator
import torch.optim as optim
import torch.nn as nn

def main(
    *,
    root: str = 'path/to/MFNet',
    batch_size: int = 2,
    num_epochs: int = 1,
    max_samples: int = 20,
) -> None:
    """Run adversarial training of the encoder against the discriminator.

    Each batch performs two updates. First the discriminator is trained to
    score (RGB latent, label latent) pairs as 1 and (RGB latent, shuffled RGB
    latent) pairs as 0. Then the encoder is trained so that the updated
    discriminator scores its (RGB latent, label latent) pairs as 0. After
    every epoch the losses of the last processed batch are printed.

    The defaults are the reduced settings used for a quick run; the larger
    configuration previously kept as commented-out code was ``batch_size=8``,
    ``num_epochs=2`` and a dataset built without ``max_samples``.

    Args:
        root: Dataset root directory passed to ``MFNetDataset``.
        batch_size: Number of samples per ``DataLoader`` batch.
        num_epochs: Number of passes over the training loader.
        max_samples: Forwarded to ``MFNetDataset`` as ``max_samples``.
    """
    # Configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    latent_dim = 512

    # Data: every image is resized to 256x256 and converted to a tensor.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    train_dataset = MFNetDataset(root=root, transform=transform, max_samples=max_samples)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Models
    encoder = SemanticEncoder(latent_dim=latent_dim).to(device)
    discriminator = CrossModalDiscriminator(latent_dim=latent_dim).to(device)

    # Optimizers
    optimizer_enc = optim.Adam(encoder.parameters(), lr=1e-4)
    optimizer_disc = optim.Adam(discriminator.parameters(), lr=1e-4)
    # NOTE(review): BCELoss requires inputs in [0, 1]; presumably the
    # discriminator ends in a sigmoid — confirm.
    criterion = nn.BCELoss()

    # Training loop
    for epoch in range(num_epochs):
        # Reset per epoch so the report below never reads unbound or stale
        # losses when the loader yields no batches.
        loss_disc = None
        loss_enc = None

        for batch in train_loader:
            rgb = batch['rgb'].to(device)
            label = batch['label'].to(device)

            z_rgb = encoder(rgb)
            # The label map is tiled to three channels so the same encoder can
            # consume it; this is a simple placeholder for a proper label
            # encoding. Assumes label is (batch, 1, H, W) — TODO confirm.
            z_label = encoder(label.float().repeat(1, 3, 1, 1))

            n_samples = rgb.size(0)
            real = torch.ones(n_samples, 1, device=device)
            fake = torch.zeros(n_samples, 1, device=device)

            # --- Discriminator step ---
            # Detach the latents: only the discriminator is updated here, so
            # backpropagating into the encoder would be wasted work and would
            # force the encoder graph to be retained for the step below.
            z_rgb_fixed = z_rgb.detach()
            z_label_fixed = z_label.detach()
            shuffle_idx = torch.randperm(n_samples, device=z_rgb_fixed.device)
            d_real = discriminator(z_rgb_fixed, z_label_fixed)
            # NOTE(review): negative pairs are RGB latents paired with shuffled
            # RGB latents (not shuffled label latents), and for small batches
            # the permutation may be the identity — confirm this is intended.
            d_fake = discriminator(z_rgb_fixed, z_rgb_fixed[shuffle_idx])
            loss_disc = criterion(d_real, real) + criterion(d_fake, fake)
            optimizer_disc.zero_grad()
            loss_disc.backward()
            optimizer_disc.step()

            # --- Encoder step (adversarial loss) ---
            # Re-score the pairs with the just-updated discriminator, this time
            # keeping the graph back to the encoder. Gradients that land on the
            # discriminator parameters here are cleared by
            # optimizer_disc.zero_grad() at the next discriminator step.
            d_real = discriminator(z_rgb, z_label)
            loss_enc = criterion(d_real, fake)
            optimizer_enc.zero_grad()
            loss_enc.backward()
            optimizer_enc.step()

        if loss_disc is None or loss_enc is None:
            print(f"Epoch {epoch+1}: no batches were produced by the loader")
            continue

        print(f"Epoch {epoch+1}: D_loss={loss_disc.item():.4f}, E_loss={loss_enc.item():.4f}")

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()