In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

class Reshape(nn.Module):
    """Reshape every sample of a batch to a fixed shape, keeping the batch dim.

    For example ``Reshape(512, 7, 7)`` maps a ``(B, 512*7*7)`` tensor to
    ``(B, 512, 7, 7)``. ``Tensor.reshape`` is used, so a copy is made only
    when the input cannot be viewed in the requested shape.
    """

    def __init__(self, *shape):
        super().__init__()
        # Per-sample target shape: everything after the batch dimension.
        self.shape = shape

    def forward(self, x):
        batch = x.shape[0]
        return x.reshape((batch,) + self.shape)

class Generator(nn.Module):
    """Map a latent vector to an image with transposed convolutions.

    The latent vector is projected to a 512x7x7 feature map. Three stride-2
    transposed convolutions upsample it (7 -> 14 -> 28 -> 56), one stride-4
    transposed convolution takes it to 224x224, and ``tanh`` squashes the
    result to [-1, 1]. Input ``z`` has shape (B, latent_dim); the output has
    shape (B, output_channels, 224, 224). Because of ``BatchNorm1d``, a
    training-mode batch must contain more than one sample.

    Args:
        latent_dim: Size of the latent vector ``z``.
        output_channels: Number of channels in the generated image. The default
            ``None`` keeps the original behaviour of reading the module-level
            ``output_shape[2]``. NOTE(review): ``output_shape`` is presumably an
            (H, W, C) tuple defined in another notebook cell; confirm.
    """

    def __init__(self, latent_dim=100, output_channels=None):
        super(Generator, self).__init__()
        if output_channels is None:
            # Backward-compatible fallback to the notebook-level global.
            output_channels = output_shape[2]
        self.latent_dim = latent_dim
        self.net = nn.Sequential(
            # Project z and reshape it into a 512x7x7 feature map.
            nn.Linear(latent_dim, 512 * 7 * 7),
            nn.BatchNorm1d(512 * 7 * 7),
            nn.LeakyReLU(0.01, inplace=False),
            Reshape(512, 7, 7),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.01, inplace=False),
            # 14x14 -> 28x28
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.01, inplace=False),
            # 28x28 -> 56x56
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.01, inplace=False),
            # 56x56 -> 224x224 (kernel == stride, so output patches do not overlap)
            nn.ConvTranspose2d(64, output_channels, kernel_size=4, stride=4, padding=0),
            nn.Tanh(),
        )
        # Must run only after self.net exists. self.modules() walks the
        # registered submodules, so calling it before the Sequential was
        # assigned visited nothing and left PyTorch's default init in place.
        self._initialize_weights()

    def forward(self, z):
        """Generate a batch of images from latent vectors ``z`` of shape (B, latent_dim)."""
        return self.net(z)

    def _initialize_weights(self):
        """Apply DCGAN-style initialization to every submodule.

        Linear and transposed-conv weights are drawn from N(0, 0.02) and their
        biases are zeroed. Batch-norm scales are drawn from N(1, 0.02) and their
        shifts are zeroed.
        """
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose2d, nn.Linear)):
                init.normal_(m.weight, mean=0.0, std=0.02)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                init.normal_(m.weight, mean=1.0, std=0.02)
                init.constant_(m.bias, 0)

class Discriminator(nn.Module):
    """FCN-8s segmentation network on a VGG16 backbone with an extra real/fake head.

    ``forward`` returns ``(seg_out, disc_out)``:

    * ``seg_out``: per-pixel class logits of shape (B, num_classes, H, W),
      always matching the input's spatial size.
    * ``disc_out``: real/fake logits computed from the conv7 features, of
      shape (B, 1, H // 32, W // 32).

    Args:
        num_classes: Number of segmentation classes.
        negative_slope: Slope of the LeakyReLU after conv6 and conv7.
    """

    def __init__(self, num_classes=2, negative_slope=0.01):
        super().__init__()
        # NOTE(review): `models` is not imported in this cell. It is presumably
        # `torchvision.models` from elsewhere in the notebook. `pretrained=True`
        # is deprecated in torchvision >= 0.13 in favour of `weights=...`.
        vgg = models.vgg16(pretrained=True)
        feats = list(vgg.features.children())
        # Split VGG16's conv stack at its pooling layers so the pool3 / pool4
        # outputs can feed the FCN-8s skip connections.
        self.pool3 = nn.Sequential(*feats[:17])    # -> 256 ch, stride 8
        self.pool4 = nn.Sequential(*feats[17:24])  # -> 512 ch, stride 16
        self.pool5 = nn.Sequential(*feats[24:])    # -> 512 ch, stride 32

        # Convolutional stand-ins for VGG's fc6/fc7. The 7x7 conv is padded
        # so the spatial size is preserved.
        self.conv6 = nn.Conv2d(512, 4096, 7, padding=3)
        self.act6 = nn.LeakyReLU(negative_slope, inplace=False)
        self.drop6 = nn.Dropout2d()

        self.conv7 = nn.Conv2d(4096, 4096, 1, padding=0)
        self.act7 = nn.LeakyReLU(negative_slope, inplace=False)
        self.drop7 = nn.Dropout2d()

        # 1x1 scoring convs for the coarse features and the two skip levels.
        self.seg_head = nn.Conv2d(4096, num_classes, 1)
        self.score_pool4 = nn.Conv2d(512, num_classes, 1)
        self.score_pool3 = nn.Conv2d(256, num_classes, 1)
        # up2: stride 32 -> 16. up4 also upsamples x2 (stride 16 -> 8) despite
        # its name. up8: stride 8 -> full resolution.
        self.up2 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1, bias=False)
        self.up4 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, padding=1, bias=False)
        self.up8 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=16, stride=8, padding=4, bias=False)

        # Real/fake head shares the conv7 features with the segmentation path.
        self.disc_head = nn.Conv2d(4096, 1, 1)

    @staticmethod
    def _fit_to(t, height, width):
        """Crop or replicate-pad ``t`` on the bottom/right to (height, width).

        Cropping keeps the original top-left alignment. Padding is needed when
        the input size is not a multiple of 32 (or 8 for the final map): the
        upsampled tensor is then smaller than its target, which previously
        crashed the skip addition or returned a smaller ``seg_out``.
        """
        t = t[:, :, :height, :width]
        pad_h = height - t.shape[2]
        pad_w = width - t.shape[3]
        if pad_h > 0 or pad_w > 0:
            # F.pad orders the pads as (left, right, top, bottom) for the last two dims.
            t = F.pad(t, (0, pad_w, 0, pad_h), mode="replicate")
        return t

    def forward(self, x):
        """Run the backbone and both heads.

        Args:
            x: Image batch of shape (B, 3, H, W) with H, W >= 32. Assumed to be
                normalized the way the pretrained VGG16 weights expect (confirm).

        Returns:
            ``(seg_out, disc_out)`` as described in the class docstring.
        """
        height, width = x.shape[2], x.shape[3]
        p3 = self.pool3(x)
        p4 = self.pool4(p3)
        p5 = self.pool5(p4)

        h = self.drop6(self.act6(self.conv6(p5)))
        h = self.drop7(self.act7(self.conv7(h)))

        # FCN-8s decoder: upsample the coarse scores and fuse them with the
        # skip scores at stride 16, then at stride 8.
        s = self.seg_head(h)
        s4 = self.score_pool4(p4)
        fuse4 = self._fit_to(self.up2(s), s4.shape[2], s4.shape[3]) + s4

        s3 = self.score_pool3(p3)
        fuse3 = self._fit_to(self.up4(fuse4), s3.shape[2], s3.shape[3]) + s3

        seg_out = self._fit_to(self.up8(fuse3), height, width)

        disc_out = self.disc_head(h)

        return seg_out, disc_out

class SGAN(nn.Module):
    """Container pairing a generator with a discriminator.

    Calling the module runs only the discriminator; ``generate_fake`` runs the
    generator. When both arguments are ``nn.Module`` instances they are
    registered as submodules, so ``parameters()`` and ``state_dict()`` cover
    both networks.
    """

    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator

    def forward(self, x):
        """Return the discriminator's output for ``x``."""
        disc = self.discriminator
        return disc(x)

    def generate_fake(self, x):
        """Return the generator's output for ``x`` (presumably a latent batch)."""
        gen = self.generator
        return gen(x)

# Shared criterion for the adversarial heads. It applies a sigmoid to raw
# logits and computes binary cross-entropy in one numerically stable op,
# averaged over all elements.
bce_loss = nn.BCEWithLogitsLoss(reduction="mean")

def discriminator_loss(disc_out_real=None, disc_out_fake=None, seg_out_labeled=None, labels_labeled=None, gamma=2.0):
    """Combine the adversarial BCE terms with a weighted segmentation cross-entropy.

    Each term contributes ``0.0`` when its inputs are ``None``.

    Args:
        disc_out_real: Discriminator logits for real images; the target is all ones.
        disc_out_fake: Discriminator logits for generated images; the target is all zeros.
        seg_out_labeled: Segmentation logits for labeled images, scored with
            ``F.cross_entropy`` against ``labels_labeled``.
        labels_labeled: Target class indices for ``seg_out_labeled``. The
            cross-entropy term is used only when both segmentation args are given.
        gamma: Weight of the cross-entropy term.

    Returns:
        ``loss_real + loss_fake + gamma * ce_loss``: a tensor, or the float
        ``0.0`` when no term is active.
    """
    real_term = 0.0
    if disc_out_real is not None:
        real_term = bce_loss(disc_out_real, torch.ones_like(disc_out_real))

    fake_term = 0.0
    if disc_out_fake is not None:
        fake_term = bce_loss(disc_out_fake, torch.zeros_like(disc_out_fake))

    seg_term = 0.0
    if not (seg_out_labeled is None or labels_labeled is None):
        seg_term = F.cross_entropy(seg_out_labeled, labels_labeled)

    return real_term + fake_term + gamma * seg_term

def generator_loss(disc_out_fake):
    """Non-saturating generator objective.

    Scores the discriminator's logits on generated samples against an
    all-ones ("real") target, so the loss falls as the discriminator is
    fooled.
    """
    real_target = torch.ones_like(disc_out_fake)
    return bce_loss(disc_out_fake, real_target)