In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# -------------------- IGNORE THIS BLOCK AT THE MOMENT --------------------


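# Assuming G, D, and F denote the generator, discriminator, and detector models.
# NOTE: in code, `F` is already bound to torch.nn.functional by the imports above,
# so give the detector model a different name (e.g. `detector`).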

# Detector loss components (simplified)
class DetectorLoss(nn.Module):
    """Placeholder for the detector loss L_Det = L_cls + L_loc (Eq. (5)).

    Intended components: a classification term (e.g. Focal Loss) and a
    localization term. Not implemented yet -- ``forward`` raises
    ``NotImplementedError`` so an accidental call fails with a clear message.
    """

    def __init__(self):
        super(DetectorLoss, self).__init__()
        # TODO: define loss components, e.g. Focal Loss for classification.

    def forward(self, predictions, targets):
        """Compute L_cls + L_loc for detector ``predictions`` vs. ``targets``.

        Raises:
            NotImplementedError: always, until the loss terms are written.
        """
        # TODO: compute the classification (L_cls) and localization (L_loc)
        # losses from Eq. (5) and return their sum.
        raise NotImplementedError(
            "DetectorLoss.forward is not implemented yet (L_cls + L_loc, Eq. (5))"
        )

# Generator and Discriminator loss components
class GANDetLoss(nn.Module):
    """Placeholder for the combined GAN + detection objective.

    Intended to compute L_G, L_D and L_Det (Eqs. (4), (6) and (5)) from HDR
    inputs, LDR images and detection ground truth, calling ``DetectorLoss``
    for L_Det. Not implemented yet -- ``forward`` raises
    ``NotImplementedError``.

    Args:
        alpha_det: weighting hyper-parameter (default 1.0); presumably for the
            detection-related term -- TODO confirm against the equations.
        alpha_non_det: weighting hyper-parameter (default 1.0); presumably for
            the non-detection term -- TODO confirm.
        beta: weighting hyper-parameter (default 0.8) -- TODO document.
        gamma: weighting hyper-parameter (default 10.0) -- TODO document.
    """

    def __init__(self, alpha_det=1.0, alpha_non_det=1.0, beta=0.8, gamma=10.0):
        super(GANDetLoss, self).__init__()
        self.alpha_det = alpha_det
        self.alpha_non_det = alpha_non_det
        self.beta = beta
        self.gamma = gamma
        # Additional components, e.g. a feature-matching loss, could be defined here.

    def forward(self, hdr_images, ldr_images, ground_truth):
        """Compute the combined loss (not implemented yet).

        Raises:
            NotImplementedError: always, until L_G, L_D and L_Det are written.
        """
        # TODO: calculate L_G, L_D and L_Det as described in Eqs. (4), (6) and
        # (5), calling DetectorLoss for L_Det, and return the combined loss.
        # The actual implementation depends on how G, D and the detector are
        # defined.
        raise NotImplementedError(
            "GANDetLoss.forward is not implemented yet (L_G, L_D, L_Det; Eqs. (4)-(6))"
        )

# Instantiate the loss
loss_fn = GANDetLoss()

# Example forward pass (simplified), left commented out: hdr_images,
# ldr_images and ground_truth are not defined before this cell, so running
# the call would raise NameError and break "Restart & Run All".
# hdr_images, ldr_images, ground_truth = your data loading logic here
# loss = loss_fn(hdr_images, ldr_images, ground_truth)


In [None]:
# -------------------- IGNORE THIS BLOCK AT THE MOMENT --------------------

def discriminator_loss(D, real_images, fake_images):
    """Least-squares (LSGAN) discriminator loss built on ``F.mse_loss``.

    NOTE(review): a second ``discriminator_loss`` is defined in the
    "BASIC GAN LOSSES" cell with the same name. Once that cell runs, it
    shadows this version. That version averages the two terms instead of
    summing them.

    Args:
        D: discriminator callable mapping an image batch to scores.
        real_images: batch of real images (target score 1).
        fake_images: batch of generated images (target score 0). Gradients
            flow through this tensor; detach it at the call site if only the
            discriminator should be updated.

    Returns:
        Scalar tensor ``mse(D(real), 1) + mse(D(fake), 0)``.
    """
    # Run D once per batch and reuse the scores for both the prediction and
    # the target's shape: avoids a second forward pass, and keeps prediction
    # and target consistent if D has stochastic layers such as dropout.
    real_scores = D(real_images)
    fake_scores = D(fake_images)
    real_loss = F.mse_loss(real_scores, torch.ones_like(real_scores))
    fake_loss = F.mse_loss(fake_scores, torch.zeros_like(fake_scores))
    return real_loss + fake_loss

def generator_loss(G, fake_images):
    """Least-squares (LSGAN) adversarial term: ``mse(G(fake_images), 1)``.

    NOTE(review): the first argument is applied to the *generated* images.
    For a standard LSGAN generator objective it should therefore be the
    discriminator; compare ``generator_adversarial_loss(D, fake_images)`` in
    the "BASIC GAN LOSSES" cell. Passing the generator here would feed its
    own outputs back through it.

    Args:
        G: callable that scores ``fake_images`` (see note above).
        fake_images: batch of generated images.

    Returns:
        Scalar tensor: mean squared distance of the scores from 1.
    """
    # Evaluate once and reuse for both the prediction and the target's shape.
    scores = G(fake_images)
    return F.mse_loss(scores, torch.ones_like(scores))

# Optional: L1 Content Loss to enforce similarity between generated and real LDR images
def l1_content_loss(fake_images, real_images):
    """Mean absolute (L1) difference between generated and real images.

    Optional content term that keeps generated LDR images close to the real
    ones pixel-wise. The two tensors are broadcast against each other before
    the element-wise difference is averaged.
    """
    residual = fake_images - real_images
    return residual.abs().mean()


# def detector_loss():
#   # TBD
#   pass

In [None]:
# -------------------- BASIC GAN LOSSES -------------------------

def generator_adversarial_loss(D, fake_images):
    """LSGAN generator objective: mean of ``(D(fake_images) - 1) ** 2``.

    Pushes the discriminator's scores on generated images toward 1.
    """
    fake_scores = D(fake_images)
    return fake_scores.sub(1).pow(2).mean()

def discriminator_loss(D, real_images, fake_images):
    """LSGAN discriminator objective, averaged over the real and fake terms.

    Returns ``0.5 * (mean((D(real) - 1)^2) + mean(D(fake)^2))``: real images
    are pushed toward score 1 and generated images toward score 0. Gradients
    flow through ``fake_images``; detach them at the call site if only the
    discriminator should be updated.
    """
    real_term = (D(real_images) - 1).pow(2).mean()
    fake_term = D(fake_images).pow(2).mean()
    return 0.5 * (real_term + fake_term)

In [None]:
# ------------------------- GENERATOR ARCHITECTURE -------------------------

class ConvBlock(nn.Module):
    """Conv2d -> InstanceNorm2d -> ReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size, int or tuple (default 3).
        stride: convolution stride (default 1).
        padding: convolution padding (default 1). ``None`` selects
            ``kernel_size // 2`` (per dimension for tuples). At stride 1 with
            odd kernels, this preserves the spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ConvBlock, self).__init__()
        if padding is None:
            # nn.Conv2d does not accept padding=None; interpret it as "same"
            # padding for odd kernel sizes.
            if isinstance(kernel_size, tuple):
                padding = tuple(k // 2 for k in kernel_size)
            else:
                padding = kernel_size // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.ins = nn.InstanceNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply convolution, instance normalization and ReLU to ``x``."""
        return self.relu(self.ins(self.conv(x)))

class AttentionModule(nn.Module):
    """Gate a feature map with a learned element-wise sigmoid mask.

    A 1x1 convolution (in_channels -> in_channels) produces scores. The scores
    are squashed with a sigmoid and multiplied element-wise with the input,
    so the output has the same shape as the input.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: currently unused; the output always has ``in_channels``
            channels. Kept so existing call sites keep working.
    """

    def __init__(self, in_channels, out_channels):
        super(AttentionModule, self).__init__()
        self.attention_score = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(conv1x1(x)) * x``."""
        score = self.sigmoid(self.attention_score(x))
        return torch.mul(score, x)

class Generator(nn.Module):
    """Work-in-progress attention-based encoder/decoder generator.

    Encoder layout:
      - Three ConvBlock (1x1 conv) + 2x max-pool stages
        (in_channels -> 32 -> 64 -> 128 channels).
      - Attention gates after the first two stages.
      - Four parallel convolutions (kernel sizes 3/5/7/9), concatenated to
        512 channels.
      - An attention gate, then a 2x nearest-neighbour upsample.

    NOTE(review): the decoder is unfinished. ``conv4``..``conv6`` are defined
    but not yet used, and ``att_1``/``att_2`` are computed but not consumed.
    ``forward`` currently returns the upsampled 512-channel attention
    features. When H and W are divisible by 8, these are at 1/4 of the input
    resolution.
    """

    def __init__(self, in_channels=3):
        super(Generator, self).__init__()
        # 1x1 convolutions with padding=0 keep the spatial size unchanged
        # (padding=None is not a valid nn.Conv2d padding value).
        # conv1 outputs 32 channels, as required by conv2 and attention1.
        self.conv1 = ConvBlock(in_channels, 32, kernel_size=1, stride=1, padding=0)
        self.conv2 = ConvBlock(32, 64, kernel_size=1, stride=1, padding=0)
        self.conv3 = ConvBlock(64, 128, kernel_size=1, stride=1, padding=0)
        self.conv4 = ConvBlock(512, 128, kernel_size=1, stride=1, padding=0)
        self.conv5 = ConvBlock(128, 64, kernel_size=1, stride=1, padding=0)
        self.conv6 = ConvBlock(64, 3, kernel_size=1, stride=1, padding=0)

        self.attention1 = AttentionModule(32, 32)
        self.attention2 = AttentionModule(64, 64)
        # 512 = 4 parallel branches x 128 channels (see forward).
        self.attention3 = AttentionModule(512, 512)

        # Parallel multi-scale branches; padding = k // 2 keeps spatial size.
        self.k3 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.k5 = nn.Conv2d(128, 128, kernel_size=5, stride=1, padding=2)
        self.k7 = nn.Conv2d(128, 128, kernel_size=7, stride=1, padding=3)
        self.k9 = nn.Conv2d(128, 128, kernel_size=9, stride=1, padding=4)

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest') # TESTING

    def forward(self, x):
        """Run the (currently partial) generator.

        Args:
            x: image batch, presumably (batch, in_channels, H, W) -- TODO confirm.

        Returns:
            The upsampled 512-channel attention features (decoder unfinished;
            see the class docstring).
        """
        x = self.conv1(x)
        x = self.maxpool(x)
        att_1 = self.attention1(x)  # TODO: not consumed yet (decoder unfinished)
        x = self.conv2(x)
        x = self.maxpool(x)
        att_2 = self.attention2(x)  # TODO: not consumed yet (decoder unfinished)
        x = self.conv3(x)
        x = self.maxpool(x)
        x_1 = self.k3(x)
        x_2 = self.k5(x)
        x_3 = self.k7(x)
        x_4 = self.k9(x)
        x = torch.cat((x_1, x_2, x_3, x_4), dim=1) # TESTING
        att_3 = self.attention3(x)
        x = self.upsample(att_3) # CONTINUE CODING HERE
        return x


In [None]:
# ------------------------- DISCRIMINATOR ARCHITECTURE -------------------------

class Discriminator(nn.Module):
    """Convolutional discriminator producing a spatial map of real/fake scores.

    Layout:
      - Four Conv2d + LeakyReLU(0.2) stages (64 -> 128 -> 256 -> 512
        channels), with InstanceNorm2d on all but the first stage.
      - A final 4x4 convolution to a single channel.

    There is no final activation, so the outputs are raw scores. This matches
    the least-squares losses in this notebook, which compare scores directly
    against 1 / 0. For an (N, C, 64, 64) input the output is (N, 1, 6, 6).
    """

    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            # Block 1: in_channels x H x W -> 64 x H/2 x W/2 (no normalization)
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            # Block 2: 64 x H/2 x W/2 -> 128 x H/4 x W/4
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            # Block 3: -> 256 x H/8 x W/8
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            # Block 4 (stride 1): -> 512 x (H/8 - 1) x (W/8 - 1)
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            # Output (stride 1): -> 1 x (H/8 - 2) x (W/8 - 2) score map
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, x):
        """Return the score map for the image batch ``x``."""
        return self.model(x)