# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

# --------------------------------------
# 1. Sub-modules: Gseg, Gatt, Ginit (conditional on target label)
# --------------------------------------
class WeatherCueSegmentationModule(nn.Module):
    """Encoder/decoder CNN that emits a per-pixel class distribution.

    Two stride-2 convolutions shrink the input by 4x, two 3x3 convolutions
    refine the features at that resolution, and two stride-2 transposed
    convolutions upsample back. A softmax over the channel axis closes the
    decoder, so the ``num_classes`` output channels sum to 1 at every pixel.
    The output matches the input's spatial size when height and width are
    divisible by 4.
    """

    def __init__(self, in_channels=3, num_classes=5):
        super().__init__()

        # Downsampling path: in_channels -> 64 -> 128.
        down_stages = [
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        ]
        self.encoder = nn.Sequential(*down_stages)

        # Bottleneck: two resolution-preserving 3x3 convolutions.
        bottleneck_stages = [
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        ]
        self.middle = nn.Sequential(*bottleneck_stages)

        # Upsampling path ending in class probabilities.
        up_stages = [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, num_classes, kernel_size=4, stride=2, padding=1),
            nn.Softmax(dim=1),
        ]
        self.decoder = nn.Sequential(*up_stages)

    def forward(self, x):
        """Return the (B, num_classes, H, W) per-pixel class probabilities for ``x``."""
        return self.decoder(self.middle(self.encoder(x)))

class AttentionModule(nn.Module):
    """Per-pixel attention head: 3x3 conv -> ReLU -> 1x1 conv -> sigmoid.

    Produces a single-channel map with values in (0, 1) at the same spatial
    resolution as the input.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the (B, 1, H, W) attention map for ``x``."""
        attention = self.net(x)
        return attention

class InitialTranslationModule(nn.Module):
    """Down/residual/up image translator operating on a conditioned input.

    Three stride-2 convolutions downsample by 8x while widening to
    ``feature_dim * 4`` channels, a stack of conv/InstanceNorm units processes
    the bottleneck, and three stride-2 transposed convolutions upsample back.
    The output passes through ``tanh`` and therefore lies in [-1, 1]. Output
    spatial size equals the input's when height and width are divisible by 8.

    Args:
        cond_channels: Number of channels of the (conditioned) input tensor.
        feature_dim: Channel width after the first downsampling convolution.
        out_channels: Number of channels of the produced image (default 3).
        num_res_blocks: Number of two-convolution units in the bottleneck
            stack (default 6).

    Raises:
        ValueError: If ``num_res_blocks`` is negative.
    """

    def __init__(self, cond_channels, feature_dim=64, out_channels=3, num_res_blocks=6):
        super().__init__()
        if num_res_blocks < 0:
            raise ValueError(f"num_res_blocks must be >= 0, got {num_res_blocks}")

        wide = feature_dim * 4

        # NOTE: normalisation is applied *after* the activation here; this
        # ordering is kept as-is so existing checkpoints stay compatible.
        self.down = nn.Sequential(
            nn.Conv2d(cond_channels, feature_dim, 4, 2, 1), nn.ReLU(True), nn.InstanceNorm2d(feature_dim),
            nn.Conv2d(feature_dim, feature_dim * 2, 4, 2, 1), nn.ReLU(True), nn.InstanceNorm2d(feature_dim * 2),
            nn.Conv2d(feature_dim * 2, wide, 4, 2, 1), nn.ReLU(True), nn.InstanceNorm2d(wide),
        )

        # Each unit is conv -> IN -> ReLU -> conv -> IN (5 modules). The units
        # are chained in one flat Sequential without per-unit skip connections.
        res = []
        for _ in range(num_res_blocks):
            res += [
                nn.Conv2d(wide, wide, 3, 1, 1), nn.InstanceNorm2d(wide), nn.ReLU(True),
                nn.Conv2d(wide, wide, 3, 1, 1), nn.InstanceNorm2d(wide),
            ]
        self.res_blocks = nn.Sequential(*res)

        self.up = nn.Sequential(
            nn.ConvTranspose2d(wide, feature_dim * 2, 4, 2, 1), nn.ReLU(True), nn.InstanceNorm2d(feature_dim * 2),
            nn.ConvTranspose2d(feature_dim * 2, feature_dim, 4, 2, 1), nn.ReLU(True), nn.InstanceNorm2d(feature_dim),
            nn.ConvTranspose2d(feature_dim, out_channels, 4, 2, 1), nn.Tanh(),
        )

    def forward(self, x):
        """Translate ``x`` of shape (B, cond_channels, H, W) into (B, out_channels, H, W)."""
        d = self.down(x)
        # A single skip connection wraps the whole bottleneck stack.
        r = self.res_blocks(d) + d
        return self.up(r)

# --------------------------------------
# 2. Generator G combining modules, conditional on target label
# --------------------------------------
class Generator(nn.Module):
    """Label-conditioned generator combining segmentation, attention and translation.

    The target label is embedded into ``in_channels`` values, broadcast to a
    spatial map and concatenated with the image to condition ``Ginit``. The
    translated image is blended with the input through a per-pixel mask that
    is the product of the attention map and the maximum class probability of
    the segmentation output.

    Args:
        in_channels: Number of channels of the input image.
        seg_classes: Number of segmentation classes. The same value is used as
            the width of the one-hot target label.
    """

    def __init__(self, in_channels=3, seg_classes=5):
        super().__init__()
        # Kept explicitly so forward() does not have to reach into the
        # segmentation decoder's layer list to recover the label width.
        self.seg_classes = seg_classes
        # label embedding to spatial map
        self.label_emb = nn.Sequential(
            nn.Linear(seg_classes, in_channels), nn.ReLU(True)
        )
        # modules take concatenated [image, label_map]
        cond_channels = in_channels * 2
        self.Gseg = WeatherCueSegmentationModule(in_channels, seg_classes)
        self.Gatt = AttentionModule(in_channels)
        self.Ginit = InitialTranslationModule(cond_channels)

    def forward(self, x, target_label):
        """Translate ``x`` towards ``target_label``.

        Args:
            x: Image batch of shape (B, in_channels, H, W).
            target_label: Either class indices of shape (B,) or an already
                one-hot / soft label tensor of shape (B, seg_classes).

        Returns:
            Tuple ``(translated, seg_map, att_map)`` where ``translated`` is the
            blended image, ``seg_map`` is the (B, seg_classes, H, W)
            segmentation output and ``att_map`` is the (B, 1, H, W) attention map.
        """
        if target_label.dim() == 1:
            # one_hot only accepts int64 indices; cast so int32 labels work too.
            onehot = F.one_hot(target_label.long(), num_classes=self.seg_classes).float()
        else:
            onehot = target_label.float()

        emb = self.label_emb(onehot)                       # (B, in_channels)
        height, width = x.shape[-2:]
        label_map = emb[:, :, None, None].expand(-1, -1, height, width)
        cond_input = torch.cat([x, label_map], dim=1)      # (B, 2*in_channels, H, W)

        seg_map = self.Gseg(x)                             # (B, seg_classes, H, W)
        seg_mask, _ = torch.max(seg_map, 1, keepdim=True)  # (B, 1, H, W)
        att_map = self.Gatt(x)                             # (B, 1, H, W)
        blend = att_map * seg_mask                         # (B, 1, H, W)
        init = self.Ginit(cond_input)

        # The single-channel mask broadcasts over the image channels, so no
        # hard-coded channel repeat is required.
        translated = blend * init + (1 - blend) * x
        return translated, seg_map, att_map

# --------------------------------------
# 3. Discriminator D with class head
# --------------------------------------
class Discriminator(nn.Module):
    """Two-branch discriminator: a patch-wise real/fake head and a class head.

    Both branches read the raw image independently. ``adv`` stacks four
    stride-2 convolutions followed by a 1-channel convolution and returns a
    map of logits; ``cls`` stacks three stride-2 convolutions, global average
    pooling and a linear layer producing ``num_classes`` logits.
    """

    def __init__(self, in_channels=3, num_classes=5):
        super().__init__()

        # Adversarial (real/fake) branch.
        adv_widths = (in_channels, 64, 128, 256, 512)
        adv_layers = []
        for c_in, c_out in zip(adv_widths, adv_widths[1:]):
            adv_layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            adv_layers.append(nn.LeakyReLU(0.2, inplace=True))
        adv_layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
        self.adv = nn.Sequential(*adv_layers)

        # Classification branch.
        cls_widths = (in_channels, 64, 128, 256)
        cls_layers = []
        for c_in, c_out in zip(cls_widths, cls_widths[1:]):
            cls_layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            cls_layers.append(nn.LeakyReLU(0.2, inplace=True))
        cls_layers.extend([
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, num_classes),
        ])
        self.cls = nn.Sequential(*cls_layers)

    def forward(self, x):
        """Return ``(realism_logits, class_logits)`` for the image batch ``x``."""
        realism_logits = self.adv(x)
        class_logits = self.cls(x)
        return realism_logits, class_logits

# --------------------------------------
# 4. Losses remain unchanged
# --------------------------------------
# Shared loss objects referenced by train().
# Real/fake loss; takes raw (pre-sigmoid) discriminator outputs.
adv_criterion = nn.BCEWithLogitsLoss()
# Mean absolute error, used for the cycle reconstruction term.
l1_criterion = nn.L1Loss()
# Classification loss on the discriminator class-head logits.
ce_criterion  = nn.CrossEntropyLoss()

class PerceptualLoss(nn.Module):
    """Weighted sum of MSE distances between VGG-19 feature maps of two images.

    Args:
        layers: Indices into ``vgg19().features`` whose outputs are compared.
        weights: Per-layer weights aligned with ``layers``; defaults to 1.0 each.
    """

    def __init__(self, layers=(0, 5, 10, 19), weights=None):
        super().__init__()
        vgg = models.vgg19(pretrained=True).features.eval()
        # The feature extractor is fixed; only the inputs receive gradients.
        for p in vgg.parameters():
            p.requires_grad = False
        self.vgg = vgg
        # Copy into fresh lists so a caller's sequence is never shared or mutated.
        self.layers = list(layers)
        self.ws = list(weights) if weights else [1.0] * len(self.layers)

    def forward(self, x, y):
        """Return the weighted feature-space MSE between ``x`` and ``y``.

        Returns the integer 0 when no layer is tapped.
        """
        loss = 0
        xi, yi = x, y
        # Layers after the last tapped index cannot contribute to the loss.
        last_tap = max(self.layers, default=-1)
        for i, layer in enumerate(self.vgg):
            if i > last_tap:
                break
            if isinstance(layer, nn.ReLU):
                # torchvision's VGG uses in-place ReLU. The tapped activations
                # are saved by mse_loss for its backward pass, so overwriting
                # them in place makes backward() fail; apply ReLU out of place.
                xi, yi = F.relu(xi), F.relu(yi)
            else:
                xi, yi = layer(xi), layer(yi)
            if i in self.layers:
                loss += self.ws[self.layers.index(i)] * F.mse_loss(xi, yi)
        return loss

# Training loop unchanged, but G and F calls updated to pass labels
# --------------------------------------
# In training: 
#   fake_y,_,_ = G(x, target_label=y_lbl)
#   fake_x,_,_ = F(y, target_label=x_lbl)
# --------------------------------------
from torch.optim.lr_scheduler import LambdaLR

def train(
    G, F, D_X, D_Y,
    loader_X, loader_Y,
    optim_G, optim_F, optim_D_X, optim_D_Y,
    device, epochs=100,
    lambda_cycle=0.8,
    initial_lr=2e-4,
    decay_start_step=1000
):
    """Jointly train two label-conditioned generators and two discriminators.

    Per step: update ``D_X`` on real X vs ``F(Y)``, update ``D_Y`` on real Y vs
    ``G(X)``, then update both generators with adversarial, cycle
    (L1 + perceptual) and classification losses.

    Args:
        G: Generator X -> Y, called as ``G(img, target_label=...)`` and
            returning a 3-tuple whose first element is the translated image.
        F: Generator Y -> X with the same call convention. Inside this
            function the name shadows the ``torch.nn.functional`` alias.
        D_X, D_Y: Discriminators returning ``(adv_logits, class_logits)``.
        loader_X, loader_Y: Iterables yielding ``(image, label)`` batches.
        optim_G, optim_F, optim_D_X, optim_D_Y: Optimizers for each network.
        device: Device that batches and the perceptual network are moved to.
        epochs: Number of passes over the paired loaders.
        lambda_cycle: Weight of the L1 term in the cycle loss; the perceptual
            term gets ``1 - lambda_cycle``.
        initial_lr: Not used by this function (the optimizers carry their own
            learning rate); kept for interface compatibility.
        decay_start_step: Step index at which the learning-rate multiplier
            starts decaying linearly to 0 at the final step.
    """
    perceptual = PerceptualLoss().to(device)

    # zip() stops at the shorter loader, so the schedule must be based on the
    # shorter one as well; otherwise the decay would never reach zero.
    steps_per_epoch = len(loader_X)
    try:
        steps_per_epoch = min(steps_per_epoch, len(loader_Y))
    except TypeError:
        pass  # loader_Y exposes no length; fall back to len(loader_X)
    total_steps = epochs * steps_per_epoch

    # Guard against total_steps == decay_start_step, which would otherwise
    # divide by zero on the last scheduler step.
    decay_span = max(1, total_steps - decay_start_step)

    def lr_lambda(step):
        # Constant multiplier, then linear decay after decay_start_step.
        if step < decay_start_step:
            return 1.0
        return max(0.0, float(total_steps - step) / decay_span)

    sched_G = LambdaLR(optim_G, lr_lambda)
    sched_F = LambdaLR(optim_F, lr_lambda)
    sched_DX = LambdaLR(optim_D_X, lr_lambda)
    sched_DY = LambdaLR(optim_D_Y, lr_lambda)

    step = 0
    for epoch in range(epochs):
        for (x, x_lbl), (y, y_lbl) in zip(loader_X, loader_Y):
            x, y = x.to(device), y.to(device)
            x_lbl, y_lbl = x_lbl.to(device), y_lbl.to(device)

            # The two loaders can yield different batch sizes (e.g. on the
            # last batch). Labels of one domain condition images of the other,
            # so both sides are trimmed to a common size.
            n = min(x.size(0), y.size(0))
            x, x_lbl = x[:n], x_lbl[:n]
            y, y_lbl = y[:n], y_lbl[:n]
            step += 1

            # D_X: real vs fake X=F(Y)
            optim_D_X.zero_grad()
            real_adv_X, real_cls_X = D_X(x)
            # No generator graph is needed for a discriminator update.
            with torch.no_grad():
                fake_x, _, _ = F(y, target_label=x_lbl)
            fake_adv_X, _ = D_X(fake_x)
            loss_DX_adv = (
                adv_criterion(real_adv_X, torch.ones_like(real_adv_X))
                + adv_criterion(fake_adv_X, torch.zeros_like(fake_adv_X))
            )
            loss_DX_cls = ce_criterion(real_cls_X, x_lbl)
            (loss_DX_adv + loss_DX_cls).backward()
            optim_D_X.step()
            sched_DX.step()

            # D_Y: real vs fake Y=G(X)
            optim_D_Y.zero_grad()
            real_adv_Y, real_cls_Y = D_Y(y)
            with torch.no_grad():
                fake_y, _, _ = G(x, target_label=y_lbl)
            fake_adv_Y, _ = D_Y(fake_y)
            loss_DY_adv = (
                adv_criterion(real_adv_Y, torch.ones_like(real_adv_Y))
                + adv_criterion(fake_adv_Y, torch.zeros_like(fake_adv_Y))
            )
            loss_DY_cls = ce_criterion(real_cls_Y, y_lbl)
            (loss_DY_adv + loss_DY_cls).backward()
            optim_D_Y.step()
            sched_DY.step()

            # Generators G and F
            optim_G.zero_grad()
            optim_F.zero_grad()
            # adversarial
            fake_y, _, _ = G(x, target_label=y_lbl)
            fake_adv_Y, fake_cls_Y_gen = D_Y(fake_y)
            loss_G_adv = adv_criterion(fake_adv_Y, torch.ones_like(fake_adv_Y))
            fake_x, _, _ = F(y, target_label=x_lbl)
            fake_adv_X, fake_cls_X_gen = D_X(fake_x)
            loss_F_adv = adv_criterion(fake_adv_X, torch.ones_like(fake_adv_X))
            # cycle
            rec_x, _, _ = F(fake_y, target_label=x_lbl)
            rec_y, _, _ = G(fake_x, target_label=y_lbl)
            loss_L1 = l1_criterion(rec_x, x) + l1_criterion(rec_y, y)
            loss_perc = perceptual(x, rec_x) + perceptual(y, rec_y)
            loss_cycle = lambda_cycle * loss_L1 + (1 - lambda_cycle) * loss_perc
            # classification on generated
            loss_cls_G = ce_criterion(fake_cls_Y_gen, y_lbl)
            loss_cls_F = ce_criterion(fake_cls_X_gen, x_lbl)
            # total
            loss_G = loss_G_adv + loss_F_adv + loss_cycle + loss_cls_G + loss_cls_F
            # This backward also leaves gradients on the discriminators; they
            # are cleared by the zero_grad() at the start of the next step.
            loss_G.backward()
            optim_G.step()
            optim_F.step()
            sched_G.step()
            sched_F.step()

            if step % 100 == 0:
                print(f"Step {step}/{total_steps} | D_X_adv: {loss_DX_adv.item():.3f}, D_Y_adv: {loss_DY_adv.item():.3f}, G_adv: {loss_G_adv.item():.3f}, cycle: {loss_cycle.item():.3f}")

    print("Training Completed.")
