In [1]:
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR

Encoder and decoder


In [3]:
class PDN(nn.Module):
  """Convolutional encoder-decoder that reconstructs an RGB image at its input size.

  The encoder applies four Conv-BatchNorm-ReLU stages; the last three are each
  followed by a 2x2 max-pool, so the bottleneck has 256 channels at 1/8 of the
  input resolution. The decoder mirrors this with three stride-2 transposed
  convolutions (8x upsampling in total) and a final stride-1 transposed
  convolution that maps back to 3 channels, followed by a Sigmoid.

  Input:  tensor of shape (N, 3, H, W). H and W should be divisible by 8,
          otherwise the max-pools floor the size and the output will be
          slightly smaller than the input.
  Output: tensor of shape (N, 3, H, W) with values in (0, 1).
  """

  def __init__(self):
    super().__init__()

    # Encoder: 3 -> 32 -> 64 -> 128 -> 256 channels, spatial size divided by 8.
    self.encoder = nn.Sequential(
        # Stage 1 keeps the full resolution (no pooling here); the decoder's
        # final layer is stride-1 to match.
        nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),  # Input channels: 3 (RGB), output: 32 channels
        nn.BatchNorm2d(32),
        nn.ReLU(inplace=True),

        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),

        nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(128),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),

        nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(256),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
    )

    # Decoder: transposed convolutions. kernel_size=4, stride=2, padding=1
    # exactly doubles the spatial size; three of them undo the three pools.
    self.decoder = nn.Sequential(
        nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(128),
        nn.ReLU(inplace=True),

        nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),

        nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
        nn.BatchNorm2d(32),
        nn.ReLU(inplace=True),

        # Size-preserving output layer (kernel_size=3, stride=1, padding=1).
        # A fourth stride-2 layer here would produce an output twice as large
        # as the input, because the encoder only downsamples three times.
        nn.ConvTranspose2d(32, 3, kernel_size=3, stride=1, padding=1),
        nn.Sigmoid()  # Output layer between 0 and 1 for image reconstruction
    )

  def forward(self, x):
    """Encode `x` to the bottleneck feature map and decode it back to an image."""
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded


In [None]:

# # Example usage
# model = PDN()
# input_image = torch.randn(1, 3, 256, 256)  # Sample input image with batch size 1
# output = model(input_image)
# print(output.shape)  # Output: torch.Size([1, 3, 256, 256])


In [None]:
# Network to optimise, with an Adam optimiser over all of its parameters.
model = PDN()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Exponential decay: the learning rate is multiplied by `gamma` on every
# scheduler.step() call.
scheduler = ExponentialLR(optimizer=optimizer, gamma=0.97)

# Phase 1 (skeleton): feature-extractor training, 20 epochs.
for epoch in range(20):
  # ... training code for feature extractor
  scheduler.step()  # decay the learning rate once per epoch

# Between phases: adjust parameters for PDN training if needed
# ... modify model parameters or optimizer settings (e.g., new learning rate)

# Phase 2 (skeleton): PDN training, 10 epochs. The same scheduler is reused,
# so the decay continues from where phase 1 left off.
for epoch in range(10):
  # ... training code for PDN
  scheduler.step()  # keep decaying the learning rate


Feature Extractor ResNet model

In [3]:
class ResNet18FeatureExtractor(nn.Module):
    """Frozen, pretrained ResNet backbone that returns convolutional feature maps.

    The torchvision model is loaded with pretrained weights, every parameter is
    marked `requires_grad = False`, and the last two child modules (the global
    average pool and the fully connected classifier) are dropped, so `forward`
    returns the spatial feature map of the final residual stage.

    NOTE(review): the class is named "ResNet18" but the backbone loaded here is
    `torchvision.models.resnet50`. The loaded model is kept as-is; confirm
    which architecture is intended and align the name or the model accordingly.
    """

    def __init__(self):
        super().__init__()
        # `pretrained=True` is kept for compatibility with older torchvision;
        # recent releases prefer the `weights=` argument.
        backbone = torchvision.models.resnet50(pretrained=True)

        # Freeze the weights of the pre-trained model.
        for param in backbone.parameters():
            param.requires_grad = False

        # Keep every child module except the trailing avgpool and fc layers.
        self.features = nn.Sequential(*list(backbone.children())[:-2])

    def forward(self, x):
        """Run `x` through the truncated backbone and return its feature map."""
        return self.features(x)
