In [1]:
import torch
import torch.nn as nn

class Encoder(torch.nn.Module):
    """Convolutional encoder mapping an image batch to a flat feature vector.

    Architecture: five down-sampling blocks, each
    ``Conv2d(5x5, stride 2) -> MaxPool2d(2) -> BatchNorm2d -> ReLU`` with channel
    widths 64, 128, 256, 512, 1024, followed by ``Linear -> BatchNorm1d -> ReLU``.

    Because every block applies a stride-2 conv *and* a 2x2 max-pool, each block
    shrinks height and width by ~4, so the last feature map is roughly
    ``1024 x (image_length / 1024) x (image_width / 1024)``. Inputs smaller than
    683 pixels along either dimension shrink to an empty map and ``forward`` fails.

    NOTE(review): the layer comments ("5x5 64 conv. downsampling") describe one
    strided-conv downsampling per block; the extra MaxPool2d (see its TODO)
    doubles that. Confirm which is intended.

    Args:
        in_channels: number of channels of the input images.
        out_channels: length of the output feature vector.
        image_length: input height (dimension 2 of the input tensor).
        image_width: input width (dimension 3 of the input tensor).
    """

    def __init__(self, in_channels, out_channels, image_length=128, image_width=128):
        super(Encoder, self).__init__()
        # in_channels x L x W       -> 64 x ~L/4 x ~W/4
        self.conv1 = self._down_block(in_channels, 64)
        # 64 x ~L/4 x ~W/4          -> 128 x ~L/16 x ~W/16
        self.conv2 = self._down_block(64, 128)
        # 128 x ~L/16 x ~W/16       -> 256 x ~L/64 x ~W/64
        self.conv3 = self._down_block(128, 256)
        # 256 x ~L/64 x ~W/64       -> 512 x ~L/256 x ~W/256
        self.conv4 = self._down_block(256, 512)
        # 512 x ~L/256 x ~W/256     -> 1024 x ~L/1024 x ~W/1024
        self.conv5 = self._down_block(512, 1024)
        # Flattened conv5 output -> out_channels, BNorm, ReLU.
        # The input size is computed layer by layer so it matches the real conv5
        # output even when the image sizes are not multiples of 1024.
        self.fc1 = nn.Sequential(
            nn.Linear(self._flattened_size(image_length, image_width), out_channels),
            nn.BatchNorm1d(out_channels),
            nn.ReLU()
        )

    @staticmethod
    def _down_block(in_ch, out_ch):
        """5x5 stride-2 conv, 2x2 max-pool, BatchNorm2d, ReLU (H and W shrink by ~4)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=2, padding=2),
            nn.MaxPool2d(kernel_size=2, stride=2), # TODO: Check if this is correct
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    @staticmethod
    def _flattened_size(image_length, image_width, channels=1024, num_blocks=5):
        """Exact number of features conv5 emits for an image_length x image_width input.

        Example: 2844 rows shrink to a 3-row map (not 2844 // 1024 == 2).
        Returns 0 when the input is too small for the block stack.
        """
        def shrink(n):
            for _ in range(num_blocks):
                n = (n + 1) // 2  # Conv2d(k=5, s=2, p=2): floor((n - 1) / 2) + 1
                n = n // 2        # MaxPool2d(k=2, s=2): floor(n / 2)
            return n
        return channels * shrink(image_length) * shrink(image_width)

    def forward(self, x):
        """Encode a (batch, in_channels, image_length, image_width) tensor to (batch, out_channels)."""
        for block in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = block(x)

        # Flatten (batch, 1024, h, w) -> (batch, 1024 * h * w) for the fc layer.
        x = x.view(x.size(0), -1)
        return self.fc1(x)
    
class Decoder(torch.nn.Module):
    """Convolutional decoder mapping a flat latent vector to an image batch.

    A fully-connected layer (BatchNorm1d, ReLU) produces a 1024-channel seed map of
    size (image_length // 1024) x (image_width // 1024). Five up-sampling stages
    follow, each ``ConvTranspose2d(5x5, stride 2, output_padding 1)`` (exactly
    doubles H and W) -> BatchNorm2d -> 2x bilinear Upsample -> ReLU, so each stage
    scales the spatial size by 4 and the channel count goes
    1024 -> 512 -> 256 -> 128 -> 64 -> out_channels.

    Output shape: (batch, out_channels, 1024 * (image_length // 1024),
    1024 * (image_width // 1024)). This equals the requested image size only when
    both dimensions are multiples of 1024.

    Args:
        in_channels: length of the input latent vector.
        out_channels: number of channels of the generated images.
        image_length: target image height.
        image_width: target image width.
    """

    def __init__(self, in_channels, out_channels, image_length=128, image_width=128):
        super().__init__()
        self.image_length = image_length
        self.image_width = image_width

        # Latent vector -> flattened 1024 x (L // 1024) x (W // 1024) seed map.
        seed_features = 1024 * (image_length // 1024) * (image_width // 1024)
        self.fc1 = nn.Sequential(
            nn.Linear(in_channels, seed_features),
            nn.BatchNorm1d(seed_features),
            nn.ReLU()
        )
        self.conv1 = self._up_block(1024, 512)
        self.conv2 = self._up_block(512, 256)
        self.conv3 = self._up_block(256, 128)
        self.conv4 = self._up_block(128, 64)
        self.conv5 = self._up_block(64, out_channels)

    @staticmethod
    def _up_block(in_ch, out_ch):
        """ConvTranspose2d (x2) -> BatchNorm2d -> bilinear Upsample (x2) -> ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(out_ch),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.ReLU()
        )

    def forward(self, x):
        """Decode a (batch, in_channels) tensor into a batch of images."""
        seed = self.fc1(x)
        h = seed.view(seed.size(0), 1024, self.image_length // 1024, self.image_width // 1024)
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            h = stage(h)
        return h

class Discriminator(torch.nn.Module):
    """Convolutional discriminator mapping an image batch to one sigmoid score per image.

    Architecture: ``Conv2d(5x5, stride 2) -> ReLU`` (32 channels), five blocks of
    ``Conv2d(5x5, stride 2) -> MaxPool2d(2) -> BatchNorm2d -> ReLU`` (128, 256, 512,
    1024, 2048 channels), then ``Linear(512) -> BatchNorm1d -> ReLU`` and
    ``Linear(1) -> Sigmoid``. Overall H and W shrink by ~2048.

    The first fc layer's input size is derived from ``image_length``/``image_width``.
    The defaults (2048 x 3072) give a 1x1 final map, i.e. 2048 input features.

    NOTE(review): the VAE/GAN pseudocode in this notebook uses an intermediate
    discriminator layer (Dis_l) for its likelihood term; ``forward`` only returns
    the final score.

    Args:
        in_channels: number of channels of the input images.
        image_length: input height (dimension 2 of the input tensor).
        image_width: input width (dimension 3 of the input tensor).
    """

    def __init__(self, in_channels, image_length=2048, image_width=3072):
        super(Discriminator, self).__init__()
        # 5x5 32 conv. (stride 2), ReLU: in_channels x L x W -> 32 x ~L/2 x ~W/2
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU()
        )
        # Each block below divides H and W by ~4.
        self.conv2 = self._down_block(32, 128)     # -> 128 x ~L/8 x ~W/8
        self.conv3 = self._down_block(128, 256)    # -> 256 x ~L/32 x ~W/32
        self.conv4 = self._down_block(256, 512)    # -> 512 x ~L/128 x ~W/128
        self.conv5 = self._down_block(512, 1024)   # -> 1024 x ~L/512 x ~W/512
        self.conv6 = self._down_block(1024, 2048)  # -> 2048 x ~L/2048 x ~W/2048

        # 512 fully-connected, BNorm, ReLu: flattened conv6 output -> 512
        self.fc1 = nn.Sequential(
            nn.Linear(self._flattened_size(image_length, image_width), 512),
            nn.BatchNorm1d(512),
            nn.ReLU()
        )
        # 1 fully-connected, sigmoid: 512 -> 1
        self.fc2 = nn.Sequential(
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    @staticmethod
    def _down_block(in_ch, out_ch):
        """5x5 stride-2 conv, 2x2 max-pool, BatchNorm2d, ReLU (H and W shrink by ~4)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=2, padding=2),
            nn.MaxPool2d(kernel_size=2, stride=2), # TODO: Check if this is correct
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    @staticmethod
    def _flattened_size(image_length, image_width, channels=2048, num_blocks=5):
        """Exact number of features conv6 emits for an image_length x image_width input."""
        def shrink(n):
            n = (n + 1) // 2  # conv1: Conv2d(k=5, s=2, p=2)
            for _ in range(num_blocks):
                n = (n + 1) // 2  # strided Conv2d(k=5, s=2, p=2)
                n = n // 2        # MaxPool2d(k=2, s=2)
            return n
        return channels * shrink(image_length) * shrink(image_width)

    def forward(self, x):
        """Score a (batch, in_channels, H, W) tensor; returns (batch, 1) values in (0, 1)."""
        for block in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6):
            x = block(x)

        # Flatten (batch, 2048, h, w) -> (batch, 2048 * h * w) for the fc layers.
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return self.fc2(x)

In [2]:
class TestModel():
    """Smoke test for Encoder, Decoder and Discriminator on random inputs.

    All checks run in ``__init__``. Each model is built for a 2048 x 3072 image and
    fed a random batch. Its output shape is printed and asserted. Finally, a
    torchsummary summary of each model is printed.

    Args:
        device: torch device to run on; defaults to 'cuda' when available, else 'cpu'.
    """

    def __init__(self, device=None) -> None:
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Define test parameters
        image_length = 2048
        image_width = int(image_length * 1.5)
        in_channels = 3
        out_channels = 8*8*256
        batch_size = 16

        print(f'Batches: {batch_size}, Channels: {in_channels}, Image Length: {image_length}, Image Width: {image_width}')

        test_encoder = Encoder(in_channels, out_channels, image_length, image_width).to(device)
        test_decoder = Decoder(out_channels, in_channels, image_length, image_width).to(device)
        test_discriminator = Discriminator(in_channels).to(device)

        # Forward passes only: no_grad stops autograd from keeping every activation of
        # these very large batches alive for a backward pass that never happens.
        with torch.no_grad():
            # Test Encoder
            x = torch.randn(batch_size, in_channels, image_length, image_width, device=device)
            output = test_encoder(x)
            print(f'Test Encoder: {output.shape}')
            assert output.shape == (batch_size, out_channels)

            # Test Decoder
            x = torch.randn(batch_size, out_channels, device=device)
            output = test_decoder(x)
            print(f'Test Decoder: {output.shape}')
            assert output.shape == (batch_size, in_channels, image_length, image_width)

            # Test Discriminator
            x = torch.randn(batch_size, in_channels, image_length, image_width, device=device)
            output = test_discriminator(x)
            print(f'Test Discriminator: {output.shape}')
            assert output.shape == (batch_size, 1)

        # Print results
        print(f'output: {output}')
        print('Test successful!')

        # Print model summary; torchsummary only accepts 'cuda' or 'cpu' as device.
        from torchsummary import summary
        summary_device = torch.device(device).type
        with torch.no_grad():
            summary(test_encoder, input_size=(in_channels, image_length, image_width), device=summary_device)
            summary(test_decoder, input_size=(out_channels,), device=summary_device)
            summary(test_discriminator, input_size=(in_channels, image_length, image_width), device=summary_device)

# Set to False to skip the expensive model smoke test (16 random images of 2048 x 3072).
test = True
if test:
    TestModel()

Batches: 16, Channels: 3, Image Length: 2048, Image Width: 3072
Test Encoder: torch.Size([16, 16384])
Test Decoder: torch.Size([16, 3, 2048, 3072])
Test Discriminator: torch.Size([16, 1])
output: tensor([[0.5779],
        [0.3704],
        [0.5226],
        [0.6457],
        [0.5071],
        [0.5853],
        [0.5079],
        [0.5327],
        [0.4873],
        [0.4141],
        [0.4710],
        [0.4283],
        [0.4938],
        [0.4936],
        [0.3423],
        [0.3072]], device='cuda:0', grad_fn=<SigmoidBackward0>)
Test successful!
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1       [-1, 64, 1024, 1536]           4,864
         MaxPool2d-2         [-1, 64, 512, 768]               0
       BatchNorm2d-3         [-1, 64, 512, 768]             128
              ReLU-4         [-1, 64, 512, 768]               0
            Conv2d-5        [-1, 128, 256, 384]         204,928
    

In [3]:
from utils.datasets import LabeledDataset
from torch.utils.data import DataLoader
from torchvision import transforms

root_dir = "dataset"
# List files passed to LabeledDataset (format defined in utils.datasets);
# uncomment the Fuji entry to include that split.
csv_files = [
    "dataset/Sony_train_list.txt",
    # "dataset/Fuji_train_list.txt"
]

batch_size = 6
# (height, width) of the centre crop applied to every image.
# NOTE(review): 2844 is not a multiple of 1024, so this size does not round-trip through
# the Decoder defined above (it rebuilds 1024 * (2844 // 1024) = 2048 rows).
input_size = (2844, 4248)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.CenterCrop(input_size)
])
dataset = LabeledDataset(root_dir, *csv_files, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=8, shuffle=True)

# Fetch the first sample once and reuse it. Every dataset[i] access runs __getitem__
# again (presumably re-reading and re-transforming the images from disk).
first_sample = dataset[0]
print(first_sample[0].shape)
print(first_sample[1].shape)

torch.Size([1, 2844, 4248])
torch.Size([3, 2844, 4248])


In [4]:
# Project method from utils.datasets (not shown here).
# NOTE(review): presumably preloads samples into an in-memory buffer before iteration — confirm in LabeledDataset.
dataset.prime_buffer()

In [5]:
# Set to False to skip plotting a sample training batch.
show_images = True
if show_images:
    # Plotting imports stay local so the cell needs no plotting libraries when disabled.
    import matplotlib.pyplot as plt
    import numpy as np
    import torchvision

    def imshow(images: torch.Tensor, nrow: int = 3) -> None:
        """Show a batch of images as one min-max normalized grid.

        Args:
            images: image batch accepted by ``torchvision.utils.make_grid``,
                e.g. (batch, channels, height, width).
            nrow: number of images per grid row (3 lays out a batch of 6 as 2 x 3).
        """
        grid = torchvision.utils.make_grid(images, nrow=nrow, padding=1, normalize=True)
        # make_grid returns (C, H, W); matplotlib expects (H, W, C).
        grid = np.transpose(grid.detach().cpu().numpy(), (1, 2, 0))
        fig, ax = plt.subplots(figsize=(15, 15))
        ax.imshow(grid)
        ax.set_title('Sample training batch')
        ax.axis('off')
        plt.show()

    # Draw one batch. NOTE(review): the first two fields are presumably the low-light
    # input and the target image; the remaining fields are ignored here.
    batch = next(iter(dataloader))
    images_low, images, *_ = batch
    imshow(images)

In [None]:
# Define model hyper-parameters.
# Each Encoder/Decoder block rescales H and W by 4, so a consistent round trip needs both
# image sizes to be multiples of 1024. The Discriminator built here has a 2048-feature fc
# layer that needs a 1x1 final feature map. With image_length = 128, the Encoder's fc layer
# had 1024 * (128 // 1024) * (192 // 1024) = 0 input features, so no forward pass could
# succeed. 2048 x 3072 is the configuration TestModel exercises.
image_length = 2048
image_width = int(image_length * 1.5)
in_channels = 3
out_channels = 8*8*256
# NOTE(review): this rebinds batch_size (6 in the data-loading cell); the existing
# `dataloader` still yields batches of 6.
batch_size = 2

device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = Encoder(in_channels, out_channels, image_length, image_width).to(device)
decoder = Decoder(out_channels, in_channels, image_length, image_width).to(device)
discriminator = Discriminator(in_channels).to(device)
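# Training procedure (VAE/GAN, Larsen et al. 2016, Algorithm 1).
# ($a \overset{+}{\leftarrow} b$ means $a \leftarrow a + b$, i.e. one gradient-descent step.)
#
# $\theta_{Enc}, \theta_{Dec}, \theta_{Dis} \leftarrow$ initialize network parameters
# repeat
#     $X \leftarrow$ random mini-batch from dataset
#     $Z \leftarrow \mathrm{Enc}(X)$
#     $\mathcal{L}_{prior} \leftarrow D_{KL}(q(Z \mid X) \,\|\, p(Z))$
#     $\tilde{X} \leftarrow \mathrm{Dec}(Z)$
#     $\mathcal{L}^{Dis_l}_{llike} \leftarrow -\mathbb{E}_{q(Z \mid X)}[p(\mathrm{Dis}_l(X) \mid Z)]$
#     $Z_p \leftarrow$ samples from prior $\mathcal{N}(0, I)$
#     $X_p \leftarrow \mathrm{Dec}(Z_p)$
#     $\mathcal{L}_{GAN} \leftarrow \log(\mathrm{Dis}(X)) + \log(1 - \mathrm{Dis}(\tilde{X})) + \log(1 - \mathrm{Dis}(X_p))$
#     // Update parameters according to gradients
#     $\theta_{Enc} \overset{+}{\leftarrow} -\nabla_{\theta_{Enc}}(\mathcal{L}_{prior} + \mathcal{L}^{Dis_l}_{llike})$
#     $\theta_{Dec} \overset{+}{\leftarrow} -\nabla_{\theta_{Dec}}(\gamma \mathcal{L}^{Dis_l}_{llike} - \mathcal{L}_{GAN})$
#     $\theta_{Dis} \overset{+}{\leftarrow} -\nabla_{\theta_{Dis}} \mathcal{L}_{GAN}$
# until deadline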

