In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
!pip install torchinfo
from torchinfo import summary



In [19]:
# -------------------- IGNORE THIS BLOCK AT THE MOMENT --------------------


# Assuming G, D, and F are your generator, discriminator, and detector models

# Detector loss components (simplified)
class DetectorLoss(nn.Module):
    def __init__(self):
        super(DetectorLoss, self).__init__()
        # Define loss components, e.g., Focal Loss for classification

    def forward(self, predictions, targets):
        # Compute classification and localization loss
        # Lcls and Lloc from Eq. (5)
        return lcls + lloc

# Generator and Discriminator loss components
class GANDetLoss(nn.Module):
    def __init__(self, alpha_det=1.0, alpha_non_det=1.0, beta=0.8, gamma=10.0):
        super(GANDetLoss, self).__init__()
        self.alpha_det = alpha_det
        self.alpha_non_det = alpha_non_det
        self.beta = beta
        self.gamma = gamma
        # Additional components, e.g., feature matching loss, could be defined here

    def forward(self, hdr_images, ldr_images, ground_truth):
        # Calculate LG, LD, and LDet as described in Eq. (4), (6), and (5)
        # This includes calling the DetectorLoss for LDet
        # Note: This is a simplified outline. Actual implementation will depend on how G, D, and F are defined
        return ltmo_det

# Instantiate the loss
loss_fn = GANDetLoss()

# Example forward pass (simplified)
# hdr_images, ldr_images, ground_truth = your data loading logic here

# loss = loss_fn(hdr_images, ldr_images, ground_truth)


In [20]:
# -------------------- IGNORE THIS BLOCK AT THE MOMENT --------------------

def discriminator_loss(D, real_images, fake_images):
  real_loss = F.mse_loss(D(real_images), torch.ones_like(D(real_images)))
  fake_loss = F.mse_loss(D(fake_images), torch.zeros_like(D(fake_images)))
  return real_loss + fake_loss

def generator_loss(G, fake_images):
  return F.mse_loss(G(fake_images), torch.ones_like(G(fake_images)))

# Optional: L1 Content Loss to enforce similarity between generated and real LDR images
def l1_content_loss(fake_images, real_images):
    return torch.mean(torch.abs(fake_images - real_images))


# def detector_loss():
#   # TBD
#   pass

In [21]:
# -------------------- BASIC GAN LOSSES -------------------------

def generator_adversarial_loss(D, fake_images):
    return torch.mean((D(fake_images) - 1) ** 2)

def discriminator_loss(D, real_images, fake_images):
    real_loss = torch.mean((D(real_images) - 1) ** 2)
    fake_loss = torch.mean(D(fake_images) ** 2)
    return (real_loss + fake_loss) / 2

In [31]:
# ------------------------- GENERATOR ARCHITECTURE -------------------------

class ConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    super(ConvBlock, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
    self.ins = nn.InstanceNorm2d(out_channels)
    self.relu = nn.ReLU()

  def forward(self, x):
      x = self.conv(x)
      x = self.ins(x)
      x = self.relu(x)
      return x

class AttentionModule(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(AttentionModule, self).__init__()
    self.attention_score = nn.Conv2d(in_channels, in_channels, kernel_size=1)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
      score = self.attention_score(x)
      score = self.sigmoid(score)
      attention_map = torch.mul(score, x)
      return attention_map

class Generator(nn.Module):
    def __init__(self, in_channels=3):
        super(Generator, self).__init__()
        self.conv1 = ConvBlock(in_channels, 32, kernel_size=1, stride=1, padding=0)
        self.conv2 = ConvBlock(32, 64, kernel_size=1, stride=1, padding=0)
        self.conv3 = ConvBlock(64, 128, kernel_size=1, stride=1, padding=0)
        self.conv4 = ConvBlock(512, 128, kernel_size=1, stride=1, padding=0)
        self.conv5 = ConvBlock(128, 64, kernel_size=1, stride=1, padding=0)
        self.conv6 = ConvBlock(64, 3, kernel_size=1, stride=1, padding=0)

        self.attention1 = AttentionModule(32, 32)
        self.attention2 = AttentionModule(64, 64)
        self.attention3 = AttentionModule(512, 512)

        self.k3 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.k5 = nn.Conv2d(128, 128, kernel_size=5, stride=1, padding=2)
        self.k7 = nn.Conv2d(128, 128, kernel_size=7, stride=1, padding=3)
        self.k9 = nn.Conv2d(128, 128, kernel_size=9, stride=1, padding=4)

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest') # TESTING

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        att_1 = self.attention1(x)
        x = self.conv2(x)
        x = self.maxpool(x)
        att_2 = self.attention2(x)
        x = self.conv3(x)
        x = self.maxpool(x)
        x_1 = self.k3(x)
        x_2 = self.k5(x)
        x_3 = self.k7(x)
        x_4 = self.k9(x)
        x = torch.cat((x_1, x_2, x_3, x_4), dim=1) # TESTING
        att_3 = self.attention3(x)
        x = self.upsample(att_3) # CONTINUE CODING HERE
        x = self.conv4(x)
        x = self.upsample(x)
        x = self.conv5(x)
        x = self.upsample(x)
        x = self.conv6(x)
        return x


In [49]:
# ------------------------- DISCRIMINATOR ARCHITECTURE -------------------------

class Discriminator(nn.Module):
  def __init__(self, in_channels=3):
    super(Discriminator, self).__init__()

    self.model = nn.Sequential(
        # Block 1 Input -> Channels x H x W
        nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(0.2, inplace=True),

        # Block 2 Input -> 64 * H/2 * W/2
        nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
        nn.InstanceNorm2d(128),
        nn.LeakyReLU(0.2, inplace=True),

        nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
        nn.InstanceNorm2d(256),
        nn.LeakyReLU(0.2, inplace=True),

        nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1),
        nn.InstanceNorm2d(512),
        nn.LeakyReLU(0.2, inplace=True),

        nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
    )

  def forward(self, x):
      return self.model(x)

In [50]:
# --------------------- PATCH DISCRIMINATOR -------------------

class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(PatchDiscriminator, self).__init__()

        # Define layers
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1)
        self.conv5 = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)

        # Instance normalization
        self.instance_norm = nn.InstanceNorm2d(512)

    def forward(self, x):
        # Forward pass through convolutional layers
        x = nn.LeakyReLU(0.2)(self.conv1(x))
        x = nn.LeakyReLU(0.2)(self.instance_norm(self.conv2(x)))
        x = nn.LeakyReLU(0.2)(self.instance_norm(self.conv3(x)))
        x = nn.LeakyReLU(0.2)(self.instance_norm(self.conv4(x)))
        x = self.conv5(x)

        return x

In [51]:
H = 256
W = 256
# batch_size = 1
in_channels =3
# Testing

gen_model = Generator(in_channels=3)
print(summary(gen_model, input_size=(batch_size, in_channels, H, W)))

HELLO
torch.Size([1, 32, 256, 256])
Layer (type:depth-idx)                   Output Shape              Param #
Generator                                [1, 3, 256, 256]          --
├─ConvBlock: 1-1                         [1, 32, 256, 256]         --
│    └─Conv2d: 2-1                       [1, 32, 256, 256]         128
│    └─InstanceNorm2d: 2-2               [1, 32, 256, 256]         --
│    └─ReLU: 2-3                         [1, 32, 256, 256]         --
├─MaxPool2d: 1-2                         [1, 32, 128, 128]         --
├─AttentionModule: 1-3                   [1, 32, 128, 128]         --
│    └─Conv2d: 2-4                       [1, 32, 128, 128]         1,056
│    └─Sigmoid: 2-5                      [1, 32, 128, 128]         --
├─ConvBlock: 1-4                         [1, 64, 128, 128]         --
│    └─Conv2d: 2-6                       [1, 64, 128, 128]         2,112
│    └─InstanceNorm2d: 2-7               [1, 64, 128, 128]         --
│    └─ReLU: 2-8                         [

In [52]:
patch_disc_model = PatchDiscriminator(in_channels=3)
print(summary(patch_disc_model, input_size=(batch_size, in_channels, H, W)))

Layer (type:depth-idx)                   Output Shape              Param #
PatchDiscriminator                       [1, 1, 30, 30]            --
├─Conv2d: 1-1                            [1, 64, 128, 128]         3,136
├─Conv2d: 1-2                            [1, 128, 64, 64]          131,200
├─InstanceNorm2d: 1-3                    [1, 128, 64, 64]          --
├─Conv2d: 1-4                            [1, 256, 32, 32]          524,544
├─InstanceNorm2d: 1-5                    [1, 256, 32, 32]          --
├─Conv2d: 1-6                            [1, 512, 31, 31]          2,097,664
├─InstanceNorm2d: 1-7                    [1, 512, 31, 31]          --
├─Conv2d: 1-8                            [1, 1, 30, 30]            8,193
Total params: 2,764,737
Trainable params: 2,764,737
Non-trainable params: 0
Total mult-adds (G): 3.15
Input size (MB): 0.79
Forward/backward pass size (MB): 18.62
Params size (MB): 11.06
Estimated Total Size (MB): 30.47


In [53]:
disc_model = Discriminator(in_channels=3)
print(summary(disc_model, input_size=(batch_size, in_channels, H, W)))

Layer (type:depth-idx)                   Output Shape              Param #
Discriminator                            [1, 1, 30, 30]            --
├─Sequential: 1-1                        [1, 1, 30, 30]            --
│    └─Conv2d: 2-1                       [1, 64, 128, 128]         3,136
│    └─LeakyReLU: 2-2                    [1, 64, 128, 128]         --
│    └─Conv2d: 2-3                       [1, 128, 64, 64]          131,200
│    └─InstanceNorm2d: 2-4               [1, 128, 64, 64]          --
│    └─LeakyReLU: 2-5                    [1, 128, 64, 64]          --
│    └─Conv2d: 2-6                       [1, 256, 32, 32]          524,544
│    └─InstanceNorm2d: 2-7               [1, 256, 32, 32]          --
│    └─LeakyReLU: 2-8                    [1, 256, 32, 32]          --
│    └─Conv2d: 2-9                       [1, 512, 31, 31]          2,097,664
│    └─InstanceNorm2d: 2-10              [1, 512, 31, 31]          --
│    └─LeakyReLU: 2-11                   [1, 512, 31, 31]        

In [54]:
# ------------------ TRAINING -------------------

# Instantiate Generator and Discriminator
generator = Generator()
discriminator = Discriminator()

# Define loss function and optimizers
criterion = nn.BCEWithLogitsLoss()
generator_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

num_epochs = 100
for epoch in range(num_epochs):
    for batch_idx, (real_images, _) in enumerate(train_loader):
        batch_size = real_images.size(0)

        # Train Discriminator
        discriminator_optimizer.zero_grad()

        # Train with real images
        real_labels = torch.ones(batch_size, 1)
        output = discriminator(real_images)
        real_loss = criterion(output, real_labels)

        # Train with fake images
        fake_labels = torch.zeros(batch_size, 1)
        # noise = torch.randn(batch_size, input_dim, 1, 1)
        fake_images = generator(hdr_images) # NEED HDR IMAGES
        output = discriminator(fake_images.detach())
        fake_loss = criterion(output, fake_labels)

        discriminator_loss = real_loss + fake_loss
        discriminator_loss.backward()
        discriminator_optimizer.step()

        # Train Generator
        generator_optimizer.zero_grad()

        # noise = torch.randn(batch_size, noise_dim, 1, 1)
        fake_images = generator(hdr_images) # NEED HDR IMAGES
        output = discriminator(fake_images)
        generator_loss = criterion(output, real_labels)

        generator_loss.backward()
        generator_optimizer.step()

        # Print training losses
        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], "
                  f"Discriminator Loss: {discriminator_loss.item():.4f}, "
                  f"Generator Loss: {generator_loss.item():.4f}")
