# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

# --------------------------------------
# 1. Sub-modules: Gseg, Gatt, Ginit
# --------------------------------------
class WeatherCueSegmentationModule(nn.Module):
    """Convolutional encoder-decoder producing per-pixel class probabilities.

    The encoder halves the spatial size twice (two stride-2 convolutions),
    the middle stage refines features at that resolution, and the decoder
    doubles the size twice with transposed convolutions.  There are no skip
    connections between encoder and decoder.  The last layer is a softmax
    over the channel axis, so the output is a probability map rather than
    raw logits.

    Args:
        in_channels: Number of channels of the input image.
        num_classes: Number of output classes (channels of the result).
    """

    def __init__(self, in_channels=3, num_classes=4):  # e.g., fog, rain, snow, clear
        super().__init__()

        def conv_bn_relu(c_in, c_out, kernel, stride):
            # Conv -> BatchNorm -> ReLU triple, always with padding 1.
            return [
                nn.Conv2d(c_in, c_out, kernel, stride, 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Layers are unpacked flat into each Sequential so that the
        # state_dict keys stay "encoder.0", "encoder.1", ...
        self.encoder = nn.Sequential(
            *conv_bn_relu(in_channels, 64, 4, 2),
            *conv_bn_relu(64, 128, 4, 2),
        )
        self.middle = nn.Sequential(
            *conv_bn_relu(128, 128, 3, 1),
            *conv_bn_relu(128, 128, 3, 1),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, num_classes, 4, 2, 1),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return class probabilities of shape (B, num_classes, H, W).

        The spatial size matches the input when H and W are divisible by 4.
        """
        return self.decoder(self.middle(self.encoder(x)))

class AttentionModule(nn.Module):
    """Two-layer convolutional head that emits a single-channel attention map.

    A 3x3 convolution with ReLU is followed by a 1x1 convolution and a
    sigmoid, so every output value lies in (0, 1) and the spatial size of
    the input is preserved.

    Args:
        in_channels: Number of channels of the input image.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, 64, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the attention map of shape (B, 1, H, W)."""
        return self.net(x)

class InitialTranslationModule(nn.Module):
    """Down-sample / transform / up-sample network producing a translated image.

    Three stride-2 convolutions reduce the resolution by 8, a stack of six
    conv-norm-ReLU-conv-norm groups transforms the features, and three
    stride-2 transposed convolutions restore the resolution.  The output
    passes through ``tanh`` and therefore lies in [-1, 1].

    Note that the six groups are chained sequentially; the only skip
    connection is a single one around the whole stack (see ``forward``).

    Args:
        in_channels: Channels of the input (and of the output) image.
        feature_dim: Channel width after the first down-sampling convolution;
            later stages use 2x and 4x this width.
    """

    def __init__(self, in_channels=3, feature_dim=64):
        super().__init__()
        c1, c2, c4 = feature_dim, feature_dim * 2, feature_dim * 4

        # Down-sampling path: Conv -> ReLU -> InstanceNorm, three times.
        down_layers = []
        for c_in, c_out in ((in_channels, c1), (c1, c2), (c2, c4)):
            down_layers.extend([
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.ReLU(inplace=True),
                nn.InstanceNorm2d(c_out),
            ])
        self.down = nn.Sequential(*down_layers)

        # Transformation stack at the bottleneck width (flat, 5 modules per group).
        body_layers = []
        for _ in range(6):
            body_layers.extend([
                nn.Conv2d(c4, c4, 3, 1, 1),
                nn.InstanceNorm2d(c4),
                nn.ReLU(inplace=True),
                nn.Conv2d(c4, c4, 3, 1, 1),
                nn.InstanceNorm2d(c4),
            ])
        self.res_blocks = nn.Sequential(*body_layers)

        # Up-sampling path mirrors the down-sampling one and ends in tanh.
        up_layers = []
        for c_in, c_out in ((c4, c2), (c2, c1)):
            up_layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1),
                nn.ReLU(inplace=True),
                nn.InstanceNorm2d(c_out),
            ])
        up_layers.extend([
            nn.ConvTranspose2d(c1, in_channels, 4, 2, 1),
            nn.Tanh(),
        ])
        self.up = nn.Sequential(*up_layers)

    def forward(self, x):
        """Return the translated image of shape (B, in_channels, H, W).

        The spatial size matches the input when H and W are divisible by 8.
        """
        bottleneck = self.down(x)
        # Single global skip connection around the whole transformation stack.
        transformed = self.res_blocks(bottleneck) + bottleneck
        return self.up(transformed)

# --------------------------------------
# 2. Generator G combining modules
# --------------------------------------
class Generator(nn.Module):
    """Generator that blends a translated image with the input via a soft mask.

    Three sub-networks run on the same input:

    * ``Gseg``  -- per-pixel class probabilities,
    * ``Gatt``  -- a single-channel attention map in (0, 1),
    * ``Ginit`` -- a full translated image.

    The blend weight ``T`` is the attention map multiplied by the largest
    class probability at each pixel; the output is
    ``T * Ginit(x) + (1 - T) * x``.

    Args:
        in_channels: Channels of the input (and output) image.
        seg_classes: Number of classes predicted by ``Gseg``.
    """

    def __init__(self, in_channels=3, seg_classes=4):
        super().__init__()
        self.Gseg = WeatherCueSegmentationModule(in_channels, seg_classes)
        self.Gatt = AttentionModule(in_channels)
        self.Ginit = InitialTranslationModule(in_channels)

    def forward(self, x):
        """Translate ``x`` and return ``(out, seg_map, att_map)``.

        Returns:
            out: Blended image with the same shape as ``x``.
            seg_map: Output of ``Gseg``, shape (B, seg_classes, H, W).
            att_map: Output of ``Gatt``, shape (B, 1, H, W).
        """
        seg_map = self.Gseg(x)  # (B, C, H, W) class probabilities
        # Collapse the class axis to one confidence value per pixel
        # (the largest class probability, whichever class it belongs to).
        seg_mask, _ = torch.max(seg_map, dim=1, keepdim=True)  # (B,1,H,W)
        att_map = self.Gatt(x)  # (B,1,H,W)
        T = att_map * seg_mask  # (B,1,H,W)

        init_out = self.Ginit(x)  # same shape as x
        # T has a singleton channel axis, so broadcasting applies it to every
        # image channel.  This replaces a hard-coded repeat to 3 channels,
        # which produced a wrongly shaped output whenever in_channels != 3.
        out = T * init_out + (1 - T) * x
        return out, seg_map, att_map

# --------------------------------------
# 3. Discriminator D (PatchGAN)
# --------------------------------------
class Discriminator(nn.Module):
    """Fully convolutional patch discriminator returning one logit per patch.

    Four stride-2 convolutions (each followed by LeakyReLU with slope 0.2)
    widen the channels 64 -> 128 -> 256 -> 512 while dividing the resolution
    by 16.  A final stride-1 convolution maps to a single channel.  No
    sigmoid is applied, so the output is raw logits.

    Args:
        in_channels: Number of channels of the input image.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        widths = (in_channels, 64, 128, 256, 512)
        stack = []
        for c_in, c_out in zip(widths, widths[1:]):
            stack.append(nn.Conv2d(c_in, c_out, 4, 2, 1))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        # Kernel 4 / stride 1 / padding 1 shrinks each spatial side by one.
        stack.append(nn.Conv2d(widths[-1], 1, 4, 1, 1))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return patch logits of shape (B, 1, H/16 - 1, W/16 - 1)."""
        patch_logits = self.model(x)
        return patch_logits

# --------------------------------------
# 4. Losses: adversarial, cycle, perceptual, classification
# --------------------------------------
# Adversarial loss. BCEWithLogitsLoss applies the sigmoid itself, so it must
# be fed raw discriminator outputs (logits), not probabilities.
adv_criterion = nn.BCEWithLogitsLoss()
# L1 reconstruction loss, used for the cycle-consistency term.
l1_criterion = nn.L1Loss()
# Cross-entropy for segmentation/classification.
# NOTE(review): nn.CrossEntropyLoss expects unnormalised logits, whereas
# WeatherCueSegmentationModule ends in nn.Softmax -- confirm which tensor is
# meant to be passed here before wiring this criterion into training.
seg_criterion = nn.CrossEntropyLoss()

# Perceptual loss: VGG19 features
class PerceptualLoss(nn.Module):
    """Weighted sum of feature-space MSE terms computed with frozen VGG19 features.

    Both inputs are passed through ``vgg19().features`` layer by layer; after
    each layer whose index is listed in ``layers`` the mean squared error
    between the two activations is added to the loss with the matching weight.

    Args:
        layers: Indices into ``vgg19().features`` at which activations are
            compared.  Any sequence of ints is accepted.
        weights: One weight per entry of ``layers``; defaults to 1.0 each.

    NOTE(review): ``pretrained=True`` is the legacy torchvision argument; it
    is kept because the installed torchvision version is not visible here.
    NOTE(review): inputs are fed to VGG as-is -- confirm whether ImageNet
    normalisation is expected by the caller.
    """

    def __init__(self, layers=(0, 5, 10, 19), weights=None):
        super().__init__()
        vgg = models.vgg19(pretrained=True).features.eval()
        for p in vgg.parameters():
            p.requires_grad = False
        # The tapped activations are saved by mse_loss for the backward pass;
        # an in-place ReLU right after a tapped conv would overwrite them and
        # make autograd raise "modified by an inplace operation".
        self._disable_inplace(vgg)
        self.vgg_layers = vgg
        # Copy into a fresh list so a shared default is never mutated.
        self.layers = list(layers)
        self.weights = weights or [1.0] * len(self.layers)

    @staticmethod
    def _disable_inplace(module):
        """Switch every in-place sub-module of ``module`` to out-of-place mode."""
        for sub in module.modules():
            if getattr(sub, "inplace", False):
                sub.inplace = False

    def forward(self, x, y):
        """Return the perceptual distance between ``x`` and ``y``.

        Returns a scalar tensor, or the float ``0.0`` when no listed layer
        index exists in the feature stack.
        """
        loss = 0.0
        # Nothing after the deepest tapped layer contributes to the loss.
        last_tap = max(self.layers, default=-1)
        xi, yi = x, y
        for idx, layer in enumerate(self.vgg_layers):
            if idx > last_tap:
                break
            xi = layer(xi)
            yi = layer(yi)
            if idx in self.layers:
                w = self.weights[self.layers.index(idx)]
                loss += w * F.mse_loss(xi, yi)
        return loss

# --------------------------------------
# 5. Training loop skeleton
# --------------------------------------
def _discriminator_loss(D, real, fake):
    """Return the adversarial loss of ``D`` for a real batch and a detached fake batch."""
    real_pred = D(real)
    fake_pred = D(fake)
    return (adv_criterion(real_pred, torch.ones_like(real_pred))
            + adv_criterion(fake_pred, torch.zeros_like(fake_pred)))


def train(
    G, F, D_X, D_Y,
    loader_X, loader_Y,
    optim_G, optim_F, optim_D_X, optim_D_Y,
    device,
    epochs=100,
    lambda_cycle=10.0,
    lambda_perc=None,
):
    """Run the adversarial + cycle-consistency training loop for two generators.

    ``G`` maps domain X to Y and ``F`` maps Y back to X; ``D_X`` judges images
    in domain X and ``D_Y`` images in domain Y.  Each iteration updates both
    discriminators first and then both generators.

    Args:
        G, F: Generators.  Each must return a 3-tuple whose first element is
            the translated image.  Inside this function the parameter ``F``
            shadows the module-level ``torch.nn.functional`` alias.
        D_X, D_Y: Discriminators returning raw logits.
        loader_X, loader_Y: Iterables of batches for each domain.  They are
            zipped, so an epoch ends with the shorter one.
            NOTE(review): each item is moved with ``.to(device)`` directly, so
            loaders are assumed to yield tensors, not (image, label) pairs.
        optim_G, optim_F, optim_D_X, optim_D_Y: Optimisers for the networks.
        device: Device the batches and the perceptual network are moved to.
        epochs: Number of passes over the zipped loaders.
        lambda_cycle: Weight of the L1 cycle-consistency term.
        lambda_perc: Weight of the perceptual term.  When ``None`` it is
            ``1 - lambda_cycle`` if ``lambda_cycle <= 1`` (a convex blend of
            the two terms) and ``1.0`` otherwise, so the default
            ``lambda_cycle=10`` can no longer yield a negative weight that
            would reward perceptual dissimilarity.
    """
    if lambda_perc is None:
        lambda_perc = 1.0 - lambda_cycle if lambda_cycle <= 1.0 else 1.0

    perceptual = PerceptualLoss().to(device)
    for epoch in range(epochs):
        for i, (x, y) in enumerate(zip(loader_X, loader_Y)):
            x = x.to(device)
            y = y.to(device)

            # Translate once; the same fakes feed the discriminator updates
            # (detached) and the generator update (with gradients).
            fake_y, _, _ = G(x)  # X -> Y
            fake_x, _, _ = F(y)  # Y -> X

            # ---------------------
            # 1) Update D_X: real x vs F(y)
            # ---------------------
            optim_D_X.zero_grad()
            loss_DX = _discriminator_loss(D_X, x, fake_x.detach())
            loss_DX.backward()
            optim_D_X.step()

            # ---------------------
            # 2) Update D_Y: real y vs G(x)
            # ---------------------
            optim_D_Y.zero_grad()
            loss_DY = _discriminator_loss(D_Y, y, fake_y.detach())
            loss_DY.backward()
            optim_D_Y.step()

            # ---------------------
            # 3) Update Generators G and F
            # ---------------------
            optim_G.zero_grad()
            optim_F.zero_grad()
            # Adversarial: each generator tries to make its fake look real to
            # the discriminator of the target domain.
            pred_fakeY = D_Y(fake_y)
            loss_G_adv = adv_criterion(pred_fakeY, torch.ones_like(pred_fakeY))
            pred_fakeX = D_X(fake_x)
            loss_F_adv = adv_criterion(pred_fakeX, torch.ones_like(pred_fakeX))
            # Cycle
            rec_x, _, _ = F(fake_y)
            rec_y, _, _ = G(fake_x)
            loss_cycle = l1_criterion(rec_x, x) + l1_criterion(rec_y, y)
            # Perceptual
            loss_perc = perceptual(x, rec_x) + perceptual(y, rec_y)
            # Total cycle
            loss_cycle_total = lambda_cycle * loss_cycle + lambda_perc * loss_perc
            # Classification loss (if classifier is integrated in D)
            # Assume D outputs logits for class as well (not shown)

            # Total G+F loss
            loss_GFTotal = loss_G_adv + loss_F_adv + loss_cycle_total
            loss_GFTotal.backward()
            optim_G.step()
            optim_F.step()

            if i % 100 == 0:
                print(f"Epoch [{epoch}/{epochs}] Iter [{i}] "
                      f"LossD_X: {loss_DX.item():.4f}, LossD_Y: {loss_DY.item():.4f}, "
                      f"LossG: {loss_GFTotal.item():.4f}")

    print("Training Complete!")
