In [None]:
import torch
from itertools import cycle
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

def _infinite_batches(loader):
    """Yield batches from ``loader`` forever, re-iterating it on every pass.

    ``itertools.cycle`` saves a copy of each element from the first pass and
    then replays those copies. On a DataLoader that means the batches are never
    reshuffled or re-augmented, and all of them stay in memory. Re-iterating
    the loader avoids both problems.

    Raises:
        ValueError: if ``loader`` yields nothing in a pass. Without this check
            the loop would spin forever without producing a batch.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("loader yielded no batches")


def _train_one_epoch(model, generator, discriminator, labeled_batches,
                     unlabeled_loader, disc_optimizer, gen_optimizer,
                     discriminator_loss, generator_loss, gamma, clip_norm,
                     device):
    """Run one pass over ``unlabeled_loader`` and return the mean (D loss, G loss).

    For each unlabeled batch, one labeled batch is drawn from
    ``labeled_batches``. One discriminator/segmenter step is taken, then one
    generator step. Returns ``(nan, nan)`` when ``unlabeled_loader`` yields no
    batches.
    """
    model.train()
    total_d = 0.0
    total_g = 0.0
    n_batches = 0

    for imgs in unlabeled_loader:
        imgs = imgs.to(device)
        imgs_lab, masks_lab = next(labeled_batches)
        imgs_lab = imgs_lab.to(device)
        # Masks arrive with a singleton channel dim; losses want (N, H, W) int64.
        masks_lab = masks_lab.squeeze(1).long().to(device)

        noise = torch.randn(imgs.size(0), generator.latent_dim, device=device)
        fake_imgs = generator(noise)

        # Discriminator / segmenter step. fake_imgs is detached so this
        # backward pass never reaches the generator's graph.
        _, disc_out_real = model(imgs)
        seg_lab, _ = model(imgs_lab)
        _, disc_out_fake = model(fake_imgs.detach())

        d_loss = discriminator_loss(
            disc_out_real=disc_out_real,
            disc_out_fake=disc_out_fake,
            seg_out_labeled=seg_lab,
            labels_labeled=masks_lab,
            gamma=gamma
        )

        disc_optimizer.zero_grad()
        # No retain_graph: the generator step below does a fresh forward pass
        # through `model`, so this graph is not needed again.
        d_loss.backward()
        # NOTE(review): the gradients above come from `model` forward passes.
        # Clipping `discriminator.parameters()` only affects them if
        # `discriminator` shares parameters with `model` -- confirm.
        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), clip_norm)
        disc_optimizer.step()

        # Generator step: re-score the (non-detached) fakes with the updated model.
        _, disc_out_fake2 = model(fake_imgs)
        g_loss = generator_loss(disc_out_fake2)

        gen_optimizer.zero_grad()
        g_loss.backward()
        torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_norm)
        gen_optimizer.step()

        total_d += d_loss.item()
        total_g += g_loss.item()
        n_batches += 1

    # Count batches instead of using len(loader): the count is exact, and an
    # empty loader no longer raises ZeroDivisionError.
    if n_batches == 0:
        return float("nan"), float("nan")
    return total_d / n_batches, total_g / n_batches


def _show_samples(model, generator, test_loader, device, num_samples=4):
    """Plot test inputs, generator samples and predicted masks in three rows.

    Uses the first batch of ``test_loader``, truncated to ``num_samples``.
    """
    model.eval()
    with torch.no_grad():
        imgs, _ = next(iter(test_loader))
        imgs = imgs[:num_samples].to(device)
        noise = torch.randn(imgs.size(0), generator.latent_dim, device=device)
        fake_imgs = generator(noise)
        seg_out, _ = model(imgs)

        # Class indices -> [0, 1] grey levels. Otherwise imshow clips every
        # class index above 1 to white. Binary masks (2 classes) are unchanged.
        num_classes = seg_out.size(1)
        preds = torch.argmax(seg_out, 1, keepdim=True).float()
        preds = preds / max(num_classes - 1, 1)

        # Assumes images are normalised to [-1, 1] -- TODO confirm against the
        # data pipeline. Clamp so imshow receives valid float RGB.
        imgs_vis = ((imgs + 1) / 2.0).clamp(0.0, 1.0)
        fake_imgs_vis = ((fake_imgs + 1) / 2.0).clamp(0.0, 1.0)
        grid_in = make_grid(imgs_vis, nrow=4)
        grid_fake = make_grid(fake_imgs_vis, nrow=4)
        grid_out = make_grid(preds.expand(-1, 3, -1, -1), nrow=4)

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 9))
    ax1.imshow(grid_in.permute(1, 2, 0).cpu())
    ax1.set_title('Original Image (Input)')
    ax1.axis('off')

    ax2.imshow(grid_fake.permute(1, 2, 0).cpu())
    ax2.set_title('Generated Image (Fake)')
    ax2.axis('off')

    ax3.imshow(grid_out.permute(1, 2, 0).cpu())
    ax3.set_title('Segmentation Output')
    ax3.axis('off')

    plt.show()
    # Free the figure. Without this, non-inline backends accumulate one open
    # figure per epoch.
    plt.close(fig)


def train_gan_segmentation(
    model, generator, discriminator,
    labeled_loader, unlabeled_loader, test_loader,
    disc_optimizer, gen_optimizer,
    discriminator_loss, generator_loss,
    num_epochs=100, gamma=2.0, clip_norm=1.0, device='cuda',
    detect_anomaly=False,
):
    """Training loop for a semi-supervised segmentation GAN.

    Each epoch does three things:

    - Runs one pass over ``unlabeled_loader``. For every unlabeled batch it
      draws one labeled batch, re-iterating ``labeled_loader`` as needed.
    - Takes one discriminator/segmenter step and one generator step per batch.
    - Prints the mean losses and shows sample inputs, fakes and predicted masks.

    Args:
        model: Module called as ``seg_out, disc_out = model(images)``.
        generator: Module mapping ``(N, generator.latent_dim)`` noise to images.
        discriminator: Module whose parameters are gradient-clipped after the
            discriminator step.
        labeled_loader: Yields ``(images, masks)`` pairs.
        unlabeled_loader: Yields image tensors; defines the epoch length.
        test_loader: Yields ``(images, masks)``; only the images of the first
            batch are used for visualisation.
        disc_optimizer, gen_optimizer: Optimizers for the two update steps.
        discriminator_loss: Callable taking keyword arguments ``disc_out_real``,
            ``disc_out_fake``, ``seg_out_labeled``, ``labels_labeled`` and
            ``gamma``.
        generator_loss: Callable taking the model's discriminator output on fakes.
        num_epochs: Number of epochs.
        gamma: Forwarded to ``discriminator_loss``.
        clip_norm: Max gradient norm for both clipping steps.
        device: Device that batches and noise are moved to.
        detect_anomaly: If True, run training inside
            ``torch.autograd.detect_anomaly()``. This is a debugging aid that
            slows backward passes; the global anomaly setting is restored
            afterwards.

    Raises:
        ValueError: if ``labeled_loader`` yields no batches while training.
    """
    import contextlib

    labeled_batches = _infinite_batches(labeled_loader)
    anomaly_ctx = (torch.autograd.detect_anomaly() if detect_anomaly
                   else contextlib.nullcontext())

    with anomaly_ctx:
        for epoch in range(1, num_epochs + 1):
            avg_d, avg_g = _train_one_epoch(
                model, generator, discriminator, labeled_batches,
                unlabeled_loader, disc_optimizer, gen_optimizer,
                discriminator_loss, generator_loss, gamma, clip_norm, device,
            )
            print(f"Epoch {epoch}/{num_epochs} | D Loss: {avg_d:.4f} | G Loss: {avg_g:.4f}")
            _show_samples(model, generator, test_loader, device)

    print("Training completed.")
