In [1]:
import torch
import torch.nn as nn

class Encoder(torch.nn.Module):
    """Convolutional encoder that maps an image batch to flat feature vectors.

    The network is five identical stages (5x5 stride-2 convolution, 2x2
    max-pool, batch norm, ReLU) with 64, 128, 256, 512 and 1024 output
    channels, followed by a fully-connected layer with batch norm and ReLU.
    Because every stage contains both a stride-2 convolution and a stride-2
    pool, each stage shrinks the spatial size by a factor of about 4, i.e.
    about 1024 overall.

    Args:
        in_channels: Number of channels of the input images.
        out_channels: Width of the returned feature vector.
        image_length: Size of the first spatial axis of the input images.
        image_width: Size of the second spatial axis of the input images.

    The input is expected as ``(N, in_channels, image_length, image_width)``
    and the output is ``(N, out_channels)``. At least 683 pixels per side are
    needed for a non-empty feature map after the fifth stage; smaller sizes
    (including the 128 defaults) collapse to a zero-sized feature map.
    """

    @staticmethod
    def _downsampled_size(size, num_stages=5):
        """Return the spatial extent left after ``num_stages`` conv stages.

        Mirrors the layer arithmetic exactly instead of assuming the size is
        a multiple of 1024: the convolution (kernel 5, stride 2, padding 2)
        rounds up, the following 2x2 pool rounds down.
        """
        for _ in range(num_stages):
            size = (size + 2 * 2 - 5) // 2 + 1  # Conv2d(k=5, s=2, p=2)
            size //= 2                          # MaxPool2d(k=2, s=2)
        return size

    @staticmethod
    def _conv_stage(in_ch, out_ch):
        """Build one stage: 5x5 stride-2 conv, 2x2 max-pool, BNorm, ReLU."""
        # TODO: the max-pool on top of the strided convolution was flagged
        # for review by the original author; it is kept unchanged here.
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=2, padding=2),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def __init__(self, in_channels, out_channels, image_length=128, image_width=128):
        super(Encoder, self).__init__()
        # Each stage: C x H x W -> C' x ~(H/4) x ~(W/4)
        self.conv1 = self._conv_stage(in_channels, 64)
        self.conv2 = self._conv_stage(64, 128)
        self.conv3 = self._conv_stage(128, 256)
        self.conv4 = self._conv_stage(256, 512)
        self.conv5 = self._conv_stage(512, 1024)

        # Size the fully-connected layer from the true conv output so that
        # image sizes that are not multiples of 1024 also line up.
        flat_features = (
            1024
            * self._downsampled_size(image_length)
            * self._downsampled_size(image_width)
        )
        # out_channels fully-connected, BNorm, ReLu
        self.fc1 = nn.Sequential(
            nn.Linear(flat_features, out_channels),
            nn.BatchNorm1d(out_channels),
            nn.ReLU()
        ) # 1024 x H' x W' (flattened) -> out_channels

    def forward(self, x):
        """Encode ``x`` and return a ``(N, out_channels)`` tensor."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)

        # Flatten everything except the batch axis
        x = torch.flatten(x, 1)
        x = self.fc1(x)

        return x
    
class Decoder(torch.nn.Module):
    """Transposed-convolution decoder turning feature vectors into images.

    A fully-connected layer (with batch norm and ReLU) expands the input
    vector to ``1024 x (image_length // 1024) x (image_width // 1024)``.
    Five stages follow, each a 5x5 stride-2 transposed convolution, batch
    norm, x2 bilinear upsampling and ReLU, with 512, 256, 128, 64 and
    ``out_channels`` output channels. Every stage enlarges both spatial axes
    by 4 (x2 from the transposed convolution, x2 from the upsampling), so
    the output is 1024 times the seed size per axis.

    Args:
        in_channels: Width of the input feature vector.
        out_channels: Number of channels of the produced image.
        image_length: Target size of the first spatial axis.
        image_width: Target size of the second spatial axis.
    """

    @staticmethod
    def _up_stage(in_ch, out_ch):
        """Build one stage: 5x5 stride-2 transposed conv, BNorm, x2 upsample, ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(out_ch),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.ReLU()
        )

    def __init__(self, in_channels, out_channels, image_length=128, image_width=128):
        super().__init__()
        self.image_length = image_length
        self.image_width = image_width

        # Number of features in the 1024-channel seed map fed to conv1
        seed_features = 1024 * (image_length // 1024) * (image_width // 1024)

        # in_channels -> 1024 x (image_length/1024) x (image_width/1024), flattened
        self.fc1 = nn.Sequential(
            nn.Linear(in_channels, seed_features),
            nn.BatchNorm1d(seed_features),
            nn.ReLU()
        )

        # Each stage: C x H x W -> C' x 4H x 4W
        self.conv1 = self._up_stage(1024, 512)
        self.conv2 = self._up_stage(512, 256)
        self.conv3 = self._up_stage(256, 128)
        self.conv4 = self._up_stage(128, 64)
        self.conv5 = self._up_stage(64, out_channels)

    def forward(self, x):
        """Decode ``(N, in_channels)`` features into an image batch."""
        seed_h = self.image_length // 1024
        seed_w = self.image_width // 1024

        x = self.fc1(x)
        # Reinterpret the flat vector as a 1024-channel seed feature map
        x = x.view(x.size(0), 1024, seed_h, seed_w)
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = stage(x)

        return x

class Discriminator(torch.nn.Module):
    """Convolutional real/fake classifier producing one probability per image.

    ``conv1`` is a plain 5x5 stride-2 convolution with ReLU (spatial size /2).
    ``conv2`` to ``conv6`` each apply a 5x5 stride-2 convolution, a 2x2
    max-pool, batch norm and ReLU (spatial size /4 each), growing the channel
    count 32 -> 128 -> 256 -> 512 -> 1024 -> 2048. The flattened result goes
    through a 512-unit fully-connected layer (BNorm, ReLU) and a final
    1-unit layer with a sigmoid.

    The first fully-connected layer takes exactly 2048 inputs, so the feature
    map leaving ``conv6`` has to be 1x1 for the forward pass to succeed.

    Args:
        in_channels: Number of channels of the images to classify.
    """

    @staticmethod
    def _down_stage(in_ch, out_ch):
        """Build one stage: 5x5 stride-2 conv, 2x2 max-pool, BNorm, ReLU."""
        # TODO: the max-pool on top of the strided convolution was flagged
        # for review by the original author; it is kept unchanged here.
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=2, padding=2),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def __init__(self, in_channels):
        super().__init__()
        # 5x5 32 conv, ReLu: in_channels x H x W -> 32 x (H/2) x (W/2)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU()
        )
        # Each stage below: C x H x W -> C' x ~(H/4) x ~(W/4)
        self.conv2 = self._down_stage(32, 128)
        self.conv3 = self._down_stage(128, 256)
        self.conv4 = self._down_stage(256, 512)
        self.conv5 = self._down_stage(512, 1024)
        self.conv6 = self._down_stage(1024, 2048)

        # 512 fully-connected, BNorm, ReLu: 2048 -> 512
        self.fc1 = nn.Sequential(
            nn.Linear(2048 , 512),
            nn.BatchNorm1d(512),
            nn.ReLU()
        )
        # 1 fully-connected, sigmoid: 512 -> 1
        self.fc2 = nn.Sequential(
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return a ``(N, 1)`` tensor of sigmoid outputs for the batch ``x``."""
        conv_stages = (
            self.conv1, self.conv2, self.conv3,
            self.conv4, self.conv5, self.conv6,
        )
        for stage in conv_stages:
            x = stage(x)

        # Flatten to (N, features) before the fully-connected head
        x = x.view(x.size(0), -1)
        return self.fc2(self.fc1(x))

In [2]:
# Spatial size of the training crops; the width is 1.5x the length.
image_length = 2048
image_width = image_length * 3 // 2
# Channel counts: images reconstructed by the decoder / images fed to the encoder.
reconstruct_channel = 3
original_channel = 1
# Width of the encoder's output vector (and of the decoder's input).
out_channels = 8 * 8 * 256
batch_size = 2

class TestModel():
    """Smoke test: instantiating it pushes random tensors through all three networks.

    Reads the notebook-level size/channel constants, runs the encoder,
    decoder and discriminator once each on CUDA, prints the resulting shapes
    and finally prints a torchsummary report for every network.
    """

    def __init__(self) -> None:
        print(f'Batches: {batch_size}, Channels: {reconstruct_channel}, Image Length: {image_length}, Image Width: {image_width}')

        # Test Encoder
        encoder_in = torch.randn(batch_size, original_channel, image_length, image_width).to('cuda')
        test_encoder = Encoder(original_channel, out_channels, image_length, image_width).to('cuda')
        output = test_encoder(encoder_in)
        print(f'Test Encoder: {output.shape}')

        # Test Decoder
        decoder_in = torch.randn(batch_size, out_channels).to('cuda')
        test_decoder = Decoder(out_channels, reconstruct_channel, image_length, image_width).to('cuda')
        output = test_decoder(decoder_in)
        print(f'Test Decoder: {output.shape}')

        # Test Discriminator
        discriminator_in = torch.randn(batch_size, reconstruct_channel, image_length, image_width).to('cuda')
        test_discriminator = Discriminator(reconstruct_channel).to('cuda')
        output = test_discriminator(discriminator_in)
        print(f'Test Discriminator: {output.shape}')

        # Print results (at this point `output` holds the discriminator result)
        print(f'output: {output}')
        print('Test successful!')

        # Print model summary; imported here so torchsummary is only needed
        # when the smoke test is actually run
        from torchsummary import summary
        summary(test_encoder, input_size=(original_channel, image_length, image_width))
        summary(test_decoder, input_size=(out_channels,))
        summary(test_discriminator, input_size=(reconstruct_channel, image_length, image_width))

# Flip to True to run the TestModel smoke test (it moves everything to CUDA).
test = False
if test:
    TestModel()

In [3]:
from utils.datasets import LabeledDataset
from torch.utils.data import DataLoader
from torchvision import transforms

# Dataset root and the list files naming the training samples.
root_dir = "dataset"
csv_files = [
    "dataset/Sony_train_list.txt",
    # "dataset/Fuji_train_list.txt"
]

# Every sample is converted to a tensor and centre-cropped to this size.
input_size = (image_length, image_width)
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.CenterCrop(input_size)]
)

dataset = LabeledDataset(root_dir, *csv_files, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=8,
    shuffle=True,
)

# Sanity check: shapes of the first two entries of the first sample.
print(dataset[0][0].shape)
print(dataset[0][1].shape)

torch.Size([1, 2048, 3072])
torch.Size([3, 2048, 3072])


In [4]:
# dataset.prime_buffer()

In [5]:
# Set to True to fetch one batch from the dataloader and display it with im_show.
show_images = False

# Get a batch of images and show them
import matplotlib.pyplot as plt
import numpy as np
import torchvision

def im_show(images: torch.Tensor):
    """Display a batch of images as one grid in a 15x15 inch figure.

    The batch is tiled four images per row with 1 pixel of padding and
    normalised by ``make_grid`` before being handed to matplotlib.
    """
    grid = torchvision.utils.make_grid(images, nrow=4, padding=1, normalize=True)
    # matplotlib wants channels last, make_grid returns channels first
    grid_hwc = np.transpose(grid.to('cpu').detach().numpy(), (1, 2, 0))
    plt.imshow(grid_hwc)
    # Enlarge the current figure before showing it
    plt.gcf().set_size_inches(15, 15)
    plt.show()

if show_images:
    # Pull a single batch from the training loader; each batch is a 7-tuple
    # of which only the second entry is displayed
    dataiter = next(iter(dataloader))
    images_low, images, _, _, _, _, _ = dataiter
    im_show(images)

In [11]:
import torch.nn.functional as F

# θEnc, θDec, θDis ← initialize network parameters
encoder = Encoder(original_channel, out_channels, image_length, image_width).to('cuda')
decoder = Decoder(out_channels, reconstruct_channel, image_length, image_width).to('cuda')
discriminator = Discriminator(reconstruct_channel).to('cuda')

# Define optimizers (one Adam instance per network, shared hyper-parameters)
lr = 0.0002
beta1 = 0.5
beta2 = 0.999
optimizer_encoder = torch.optim.Adam(encoder.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_decoder = torch.optim.Adam(decoder.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

# Define loss functions
# NOTE(review): the training loop in this notebook computes its reconstruction
# term with F.mse_loss and does not call criterion_reconstruction; confirm
# whether KLDivLoss is really the intended reconstruction criterion.
criterion_reconstruction = torch.nn.KLDivLoss()
# Instantiate the module: binding the bare class would make
# criterion_adversarial(pred, target) build a BCELoss instead of evaluating it.
criterion_adversarial = torch.nn.BCELoss()

In [22]:
num_epochs = 10
from tensorboardX import SummaryWriter
import os

# Create a directory to store TensorBoard logs
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)

num_batches = len(dataloader)

try:
    for epoch in range(num_epochs):
        encoder.train()
        decoder.train()
        discriminator.train()

        # Running sums over the epoch:
        #   total_loss - encoder/decoder objective (recon + prior + adversarial)
        #   gan_loss   - discriminator loss (real + fake)
        total_loss = 0.0
        recon_loss = 0.0
        prior_loss = 0.0
        gan_loss = 0.0

        for batch_idx, (images_low, images, _, _, _, _, _) in enumerate(dataloader):
            images_low = images_low.to('cuda')
            images = images.to('cuda')

            # Forward pass through the autoencoder
            z = encoder(images_low)
            recon_images = decoder(z)

            # ---- Discriminator update ----
            # The reconstruction is detached so this step only touches the
            # discriminator: real images -> 1, reconstructions -> 0.
            optimizer_discriminator.zero_grad()
            dis_real = discriminator(images)
            dis_fake = discriminator(recon_images.detach())
            dis_real_loss = F.binary_cross_entropy(dis_real, torch.ones_like(dis_real))
            dis_fake_loss = F.binary_cross_entropy(dis_fake, torch.zeros_like(dis_fake))
            d_loss = dis_real_loss + dis_fake_loss
            d_loss.backward()
            optimizer_discriminator.step()

            # ---- Encoder / decoder update ----
            optimizer_encoder.zero_grad()
            optimizer_decoder.zero_grad()

            # Reconstruction loss
            r_loss = F.mse_loss(recon_images, images)

            # Prior loss (KL divergence). The encoder returns a single code z
            # rather than a mean / log-variance pair, so the posterior is
            # treated as N(z, I), for which KL(N(z, I) || N(0, I)) equals
            # 0.5 * ||z||^2. Summed over the latent axis, averaged over the batch.
            kl_loss = 0.5 * z.pow(2).sum(dim=1).mean()

            # Adversarial loss: the reconstruction should be scored as real.
            # Gradients that reach the discriminator here are discarded by
            # its zero_grad() at the start of the next iteration.
            dis_recon = discriminator(recon_images)
            g_loss = F.binary_cross_entropy(dis_recon, torch.ones_like(dis_recon))

            # Sum of the encoder/decoder losses (all weights are 1)
            loss = r_loss + kl_loss + g_loss

            # Backward pass and optimization
            loss.backward()
            optimizer_encoder.step()
            optimizer_decoder.step()

            # Accumulate losses
            total_loss += loss.item()
            recon_loss += r_loss.item()
            prior_loss += kl_loss.item()
            gan_loss += d_loss.item()

            # Send reconstructed images to TensorBoard
            if (batch_idx + 1) % 10 == 0:
                # Step counter that keeps increasing across epochs, so later
                # epochs do not overwrite the images logged by earlier ones
                global_step = epoch * num_batches + batch_idx

                # Reshape images for tensorboard visualization
                recon_images_grid = torchvision.utils.make_grid(
                    recon_images[:8].detach(), nrow=4, normalize=True, scale_each=True)
                images_grid = torchvision.utils.make_grid(
                    images[:8], nrow=4, normalize=True, scale_each=True)

                # Add images to TensorBoard
                writer.add_image('Reconstructed Images', recon_images_grid, global_step=global_step)
                writer.add_image('Original Images', images_grid, global_step=global_step)

        # Per-epoch averages, also written as TensorBoard scalars
        writer.add_scalar('Loss/total', total_loss / num_batches, epoch)
        writer.add_scalar('Loss/reconstruction', recon_loss / num_batches, epoch)
        writer.add_scalar('Loss/prior', prior_loss / num_batches, epoch)
        writer.add_scalar('Loss/discriminator', gan_loss / num_batches, epoch)

        # Print epoch summary
        print('Epoch [{}/{}], Loss: {:.4f}, Recon Loss: {:.4f}, GAN Loss: {:.4f}'
              .format(epoch + 1, num_epochs, total_loss / num_batches,
                      recon_loss / num_batches, gan_loss / num_batches))
finally:
    # Flush and release the event file even if training is interrupted
    writer.close()

TypeError: 'LazyModule' object is not callable