In [1]:
import torch.nn as nn
import torch
import rawpy
import numpy as np
from torch.utils.data import DataLoader
from utils.datasets import LabeledDataset

import torchvision.transforms as transforms

import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.contrib.handlers import ProgressBar
from ignite.metrics import FID, InceptionScore, RunningAverage

from torch.profiler import profile, record_function, ProfilerActivity

from torchinfo import summary

In [2]:
# --- Dataset / loader configuration -------------------------------------
root_dir = "dataset"
sony_csv_files = ["dataset/Sony_train_list.txt"]
fuji_csv_files =  ["dataset/Fuji_train_list.txt"]

batch_size = 4
# Side length of the square crop taken from the packed short-exposure input
# (the summary cells and random_crop both read this value).
input_size = 512

# Only tensor conversion happens inside the dataset; cropping is done later
# in the training step.
pre_crop_transform = transforms.Compose([
    transforms.ToTensor()
])

sony_dataset = LabeledDataset(root_dir, *sony_csv_files, transform=pre_crop_transform)
sony_dataloader = idist.auto_dataloader(sony_dataset, batch_size=batch_size, num_workers=8, shuffle=True, drop_last=True, prefetch_factor=1)

# Sanity check: fetch the first sample ONCE and print the shapes of its first
# two entries, instead of running the dataset's __getitem__ twice.
first_sample = sony_dataset[0]
print(first_sample[0].shape)
print(first_sample[1].shape)

2023-06-10 09:58:05,365 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<utils.datasets.Labe': 
	{'batch_size': 4, 'num_workers': 8, 'shuffle': True, 'drop_last': True, 'prefetch_factor': 1, 'pin_memory': True}


torch.Size([1, 2848, 4256])
torch.Size([3, 2848, 4256])


In [3]:
# sony_dataset.prime_buffer()

In [4]:
# from unet.unet_model import UNet
from torch import optim 
from ignite.handlers.param_scheduler import LRScheduler

class ConvBlock(nn.Module):
    """Two stacked convolutions, each followed by an in-place LeakyReLU(0.2).

    Args:
        in_channel: number of channels of the input feature map.
        out_channel: number of channels produced by both convolutions.
        kernel: kernel size used by both convolutions.
        stride: stride used by both convolutions.
        padding: zero padding used by both convolutions.
    """

    def __init__(self, in_channel, out_channel, kernel=3, stride=1, padding=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel, stride=stride, padding=padding)
        # Attribute names and creation order are unchanged so state_dict keys
        # and parameter initialisation order stay the same.
        self.conv1_1 = nn.Conv2d(in_channel, out_channel, **conv_kwargs)
        self.lrelu1_1 = nn.LeakyReLU(0.2, inplace=True)
        self.conv1_2 = nn.Conv2d(out_channel, out_channel, **conv_kwargs)
        self.lrelu1_2 = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Apply conv -> LeakyReLU -> conv -> LeakyReLU and return the result."""
        hidden = self.lrelu1_1(self.conv1_1(x))
        return self.lrelu1_2(self.conv1_2(hidden))
    
class UpConcatBlock(nn.Module):
    """Decoder stage: upsample, align to the skip tensor, concatenate, convolve.

    Args:
        in_channel: channels of the tensor being upsampled; the transposed
            convolution halves this, and the following ConvBlock takes
            ``in_channel`` channels again after concatenation.
        out_channel: channels produced by the final ConvBlock.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(in_channel, in_channel // 2, kernel_size=2, stride=2)
        self.conv_block = ConvBlock(in_channel, out_channel)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the size of ``x2`` and fuse the two.

        Args:
            x1: tensor from the deeper stage (dims 2 and 3 are treated as H and W).
            x2: skip tensor; presumably carries ``in_channel // 2`` channels so
                the concatenation matches the ConvBlock input — verify against caller.
        """
        upsampled = self.deconv(x1)

        # Size gap between the skip tensor and the upsampled tensor along dims 2 and 3.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        before_w, before_h = gap_w // 2, gap_h // 2

        # F.pad takes (left, right, top, bottom) for the last two dims; any odd
        # remainder goes to the right/bottom side.
        upsampled = torch.nn.functional.pad(
            upsampled, [before_w, gap_w - before_w, before_h, gap_h - before_h]
        )
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv_block(fused)

class UNet(nn.Module):
    """Five-level U-Net generator whose output passes through a 2x pixel shuffle.

    Args:
        in_feat: channels of the input tensor.
        out_feat: channels produced by the final 1x1 convolution, *before* the
            pixel shuffle (which divides channels by 4 and doubles H and W).
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()

        # Encoder: ConvBlock followed by 2x max-pooling, widths 32 -> 512.
        self.down1 = ConvBlock(in_feat, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = ConvBlock(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = ConvBlock(64, 128)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = ConvBlock(128, 256)
        self.pool4 = nn.MaxPool2d(2)
        self.down5 = ConvBlock(256, 512)

        # Decoder: each stage upsamples and fuses the matching encoder output.
        self.up5 = UpConcatBlock(512, 256)
        self.up4 = UpConcatBlock(256, 128)
        self.up3 = UpConcatBlock(128, 64)
        self.up2 = UpConcatBlock(64, 32)

        # 1x1 projection to the requested number of channels.
        self.conv10 = nn.Conv2d(32, out_feat, 1)

    def forward(self, x):
        """Run the encoder/decoder and return the pixel-shuffled output."""
        enc1 = self.down1(x)
        enc2 = self.down2(self.pool1(enc1))
        enc3 = self.down3(self.pool2(enc2))
        enc4 = self.down4(self.pool3(enc3))
        dec = self.down5(self.pool4(enc4))

        # Walk back up, pairing every decoder stage with its skip connection.
        for stage, skip in (
            (self.up5, enc4),
            (self.up4, enc3),
            (self.up3, enc2),
            (self.up2, enc1),
        ):
            dec = stage(dec, skip)

        projected = self.conv10(dec)
        return torch.nn.functional.pixel_shuffle(projected, 2)
    
class UNet_D(nn.Module):
    """U-Net shaped discriminator with two outputs.

    The encoder bottleneck is flattened into a single per-sample logit, and
    the decoder produces a one-channel map at the input resolution.

    Args:
        in_feat: channels of the input image.

    Note:
        ``fc1`` is sized for a 4096 x 8 x 8 bottleneck. With seven 2x poolings
        and size-preserving ConvBlocks this corresponds to a 1024 x 1024 input.
    """

    def __init__(self, in_feat):
        super().__init__()

        # Encoder: eight ConvBlocks separated by seven 2x max-poolings.
        self.down1 = ConvBlock(in_feat, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = ConvBlock(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = ConvBlock(64, 128)
        self.pool3 = nn.MaxPool2d(2)
        self.down4 = ConvBlock(128, 256)
        self.pool4 = nn.MaxPool2d(2)
        self.down5 = ConvBlock(256, 512)
        self.pool5 = nn.MaxPool2d(2)
        self.down6 = ConvBlock(512, 1024)
        self.pool6 = nn.MaxPool2d(2)
        self.down7 = ConvBlock(1024, 2048)
        self.pool7 = nn.MaxPool2d(2)
        self.down8 = ConvBlock(2048, 4096)

        # Per-sample head on the flattened bottleneck.
        self.fc1 = nn.Linear(4096*8*8, 1)

        # Decoder: mirrors the encoder back to the input resolution.
        self.up8 = UpConcatBlock(4096, 2048)
        self.up7 = UpConcatBlock(2048, 1024)
        self.up6 = UpConcatBlock(1024, 512)
        self.up5 = UpConcatBlock(512, 256)
        self.up4 = UpConcatBlock(256, 128)
        self.up3 = UpConcatBlock(128, 64)
        self.up2 = UpConcatBlock(64, 32)

        # One-channel per-pixel map.
        self.conv10 = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        """Return ``(real_fake, out)``: the bottleneck logit and the decoder map."""
        enc1 = self.down1(x)
        enc2 = self.down2(self.pool1(enc1))
        enc3 = self.down3(self.pool2(enc2))
        enc4 = self.down4(self.pool3(enc3))
        enc5 = self.down5(self.pool4(enc4))
        enc6 = self.down6(self.pool5(enc5))
        enc7 = self.down7(self.pool6(enc6))
        bottleneck = self.down8(self.pool7(enc7))

        # Encoder head: flatten everything except the batch dimension.
        real_fake = self.fc1(torch.flatten(bottleneck, 1))

        # Decoder head: climb back up through the skip connections.
        dec = bottleneck
        for stage, skip in (
            (self.up8, enc7),
            (self.up7, enc6),
            (self.up6, enc5),
            (self.up5, enc4),
            (self.up4, enc3),
            (self.up3, enc2),
            (self.up2, enc1),
        ):
            dec = stage(dec, skip)

        return real_fake, self.conv10(dec)

In [5]:
# Networks: a 4-channel-in / 12-channel-out generator and a 3-channel-in
# discriminator, each handed to ignite's auto_model helper.
netG = idist.auto_model(UNet(in_feat=4, out_feat=12))
netD = idist.auto_model(UNet_D(in_feat=3))

# One Adam optimizer per network, same learning rate for both.
optimizerG = idist.auto_optim(optim.Adam(params=netG.parameters(), lr=1e-4))
optimizerD = idist.auto_optim(optim.Adam(params=netD.parameters(), lr=1e-4))

# `loss` is the L1 reconstruction term; `criterion` is the adversarial term.
loss = nn.L1Loss()
criterion = nn.BCEWithLogitsLoss()
# lr_scheduler = LRScheduler(optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8))

2023-06-10 09:58:07,769 ignite.distributed.auto.auto_model INFO: Apply torch DataParallel on model
2023-06-10 09:58:10,226 ignite.distributed.auto.auto_model INFO: Apply torch DataParallel on model


In [6]:
# Random 4-channel batch at the crop resolution, used only to trace the
# generator's layer shapes and parameter counts.
input_data = torch.randn((batch_size, 4, input_size, input_size))
summary(netG, input_data=input_data)

Layer (type:depth-idx)                   Output Shape              Param #
DataParallel                             [4, 3, 1024, 1024]        --
├─UNet: 1-1                              [2, 3, 1024, 1024]        7,760,748
├─UNet: 1-4                              --                        (recursive)
│    └─ConvBlock: 2-1                    [2, 32, 512, 512]         10,432
│    └─ConvBlock: 2-26                   --                        (recursive)
│    │    └─Conv2d: 3-1                  [2, 32, 512, 512]         1,184
├─UNet: 1-3                              [2, 3, 1024, 1024]        --
├─UNet: 1-4                              --                        (recursive)
│    └─ConvBlock: 2-3                    [2, 32, 512, 512]         --
│    └─ConvBlock: 2-26                   --                        (recursive)
│    │    └─Conv2d: 3-2                  [2, 32, 512, 512]         --
│    │    └─LeakyReLU: 3-3               [2, 32, 512, 512]         --
│    │    └─Conv2d: 3-4            

In [7]:
# Random 3-channel batch at twice the crop resolution, used only to trace the
# discriminator's layer shapes and parameter counts.
input_data = torch.randn((batch_size, 3, input_size*2, input_size*2))
summary(netD, input_data=input_data)

Layer (type:depth-idx)                   Output Shape              Param #
DataParallel                             [4, 1]                    --
├─UNet_D: 1-1                            [2, 1]                    497,994,466
├─UNet_D: 1-4                            --                        (recursive)
│    └─ConvBlock: 2-1                    [2, 32, 1024, 1024]       10,144
│    └─ConvBlock: 2-8                    --                        (recursive)
│    │    └─Conv2d: 3-1                  [2, 32, 1024, 1024]       896
├─UNet_D: 1-3                            [2, 1]                    --
├─UNet_D: 1-4                            --                        (recursive)
│    └─ConvBlock: 2-3                    [2, 32, 1024, 1024]       --
│    └─ConvBlock: 2-8                    --                        (recursive)
│    │    └─Conv2d: 3-2                  [2, 32, 1024, 1024]       --
│    │    └─LeakyReLU: 3-3               [2, 32, 1024, 1024]       --
│    │    └─Conv2d: 3-4            

In [8]:
def random_crop(image_short, image_long, size):
    """Cut the same random square window out of a paired short/long batch.

    The window is ``size`` x ``size`` on ``image_short`` and the corresponding
    ``2*size`` x ``2*size`` window (offsets doubled) on ``image_long``.

    Args:
        image_short: 4-D array/tensor; dims 2 and 3 are cropped.
        image_long: 4-D array/tensor indexed at twice the offsets and size of
            ``image_short`` along dims 2 and 3.
        size: side length of the crop on ``image_short``.

    Returns:
        Tuple ``(image_short_crop, image_long_crop)``.

    Raises:
        ValueError: if ``size`` exceeds the height or width of ``image_short``.
    """
    H = image_short.shape[2]
    W = image_short.shape[3]
    ps = size
    if ps > H or ps > W:
        raise ValueError(
            "crop size {} exceeds image_short spatial size {}x{}".format(ps, H, W)
        )
    # randint's upper bound is exclusive, so add 1: this makes the last valid
    # offset reachable and lets an image of exactly `size` be returned whole
    # instead of raising.
    xx = np.random.randint(0, W - ps + 1)
    yy = np.random.randint(0, H - ps + 1)
    image_short = image_short[:, :, yy:yy + ps, xx:xx + ps]
    image_long = image_long[:, :, yy * 2:(yy + ps) * 2, xx * 2:(xx + ps) * 2]
    return image_short, image_long

def pack_sony_raw(batch, device=None, black_level=512, white_level=16383):
    """Normalise a raw mosaic batch and pack each 2x2 cell into four channels.

    The batch is shifted by ``black_level``, clamped at zero, divided by
    ``white_level - black_level`` and then split into the four sub-lattices
    of the 2x2 pattern, concatenated along dim 1 in the order
    (even row, even col), (even row, odd col), (odd row, odd col),
    (odd row, even col).

    Args:
        batch: 4-D tensor; dims 2 and 3 are subsampled with stride 2.
        device: kept for backward compatibility and ignored. The clamp runs on
            ``batch``'s own device, so no separate zero tensor is needed.
        black_level: value subtracted before clamping (default 512).
        white_level: value mapped to 1.0 after normalisation (default 16383).

    Returns:
        Tensor with four times the channels and half the height and width of
        ``batch``.
    """
    # clamp(min=0) replaces maximum(x, zeros_on_device): same result, but with
    # no per-call tensor allocation/copy and no risk of a device mismatch.
    batch = torch.clamp(batch - black_level, min=0) / (white_level - black_level)
    H = batch.shape[2]
    W = batch.shape[3]

    out = torch.cat((batch[:, :, 0:H:2, 0:W:2],
                     batch[:, :, 0:H:2, 1:W:2],
                     batch[:, :, 1:H:2, 1:W:2],
                     batch[:, :, 1:H:2, 0:W:2]), dim=1)
    return out

In [9]:
# Target values used to build the BCE-with-logits labels below.
real_label = 1
fake_label = 0

def training_step(engine, batch):
    """Run one GAN update (discriminator, then generator) on a single batch.

    Args:
        engine: the ignite Engine invoking this step (not used in the body).
        batch: an 8-element sequence; only the first two entries (short and
            long exposure images) and the fifth (exposure ratio) are used.

    Returns:
        Dict of Python floats with the generator/discriminator losses and
        their individual components for this iteration.
    """
    netG.train()
    netD.train()

    short, long, ratio, cam_model, exposure_ratio, _, _, _ = batch

    short = short.to(idist.device())
    long = long.to(idist.device())

    # Black-level subtract, normalise and pack the raw mosaic into 4 channels.
    short = pack_sony_raw(short)

    # Scale the long-exposure target by 65535 (presumably 16-bit data -> [0, 1]; TODO confirm).
    long = long / 65535.0
    # Multiply each sample of the packed input by its own exposure ratio.
    short = short * exposure_ratio.float().to(idist.device()).view(-1, 1, 1, 1)
    # Take an aligned random crop: input_size on `short`, 2*input_size on `long`.
    short, long = random_crop(short, long, input_size)

    # Train Discriminator with ground truth data
    netD.zero_grad()
    b_size = long.size(0)
    label = torch.full((b_size,), real_label, dtype=torch.float, device=idist.device())

    # netD returns a per-sample logit (encoder head) and a per-pixel map (decoder head).
    D_real_enc_out, D_real_dec_out = netD(long)
    D_real_enc_out = D_real_enc_out.view(-1)
    errD_real_enc = criterion(D_real_enc_out, label)
    # The per-sample label is broadcast to the shape of the decoder map.
    errD_real_dec = criterion(D_real_dec_out, label.view(-1, 1, 1, 1).expand_as(D_real_dec_out))
    errD_real = errD_real_enc + errD_real_dec
    errD_real.backward()

    # Train with all-fake batch
    fake = netG(short)
    # `label` is reused in place; it now holds the fake target.
    label.fill_(fake_label)

    # detach() keeps the discriminator loss from back-propagating into netG.
    D_fake_enc_out, D_fake_dec_out = netD(fake.detach())
    D_fake_enc_out = D_fake_enc_out.view(-1)
    errD_fake_enc = criterion(D_fake_enc_out, label)
    errD_fake_dec = criterion(D_fake_dec_out, label.view(-1, 1, 1, 1).expand_as(D_fake_dec_out))
    errD_fake = errD_fake_enc + errD_fake_dec
    # Gradients from the real and fake passes accumulate before the single step below.
    errD_fake.backward()

    errD = errD_real + errD_fake
    optimizerD.step()

    # Train G
    netG.zero_grad()
    label.fill_(real_label)  # fake labels are real for generator cost

    # NOTE(review): G_D_enc_out is computed but does not enter the generator loss below.
    G_D_enc_out, G_D_dec_out = netD(fake)
    
    # Generator objective: equal-weight mix of L1 reconstruction and the
    # adversarial loss on the discriminator's decoder map.
    errG_l1 = loss(fake, long)
    errG_dec = criterion(G_D_dec_out, label.view(-1, 1, 1, 1).expand_as(G_D_dec_out))
    errG = 0.5*errG_l1 + 0.5*errG_dec
    # NOTE(review): this backward runs through netD as well; only optimizerG steps
    # here, and netD.zero_grad() at the top of the next call clears netD's grads.
    errG.backward()

    optimizerG.step()

    # Plain floats so the metrics and progress bar can consume them.
    return {
        "Loss_G" : errG.item(),
        "Loss_D" : errD.item(),
        "D_real_enc": errD_real_enc.mean().item(),
        "D_real_dec": errD_real_dec.mean().item(),
        "D_fake_enc": errD_fake_enc.mean().item(),
        "D_fake_dec": errD_fake_dec.mean().item(),
        "D_G_L1": errG_l1.item(),
        "D_G_dec": errG_dec.mean().item(),
    }

In [10]:
trainer = Engine(training_step)

# Keys of the dict returned by training_step; each one gets a running-average
# metric, a progress-bar column and a per-epoch history list.
losses_key = ["Loss_G","Loss_D","D_real_enc","D_real_dec","D_fake_enc","D_fake_dec","D_G_L1","D_G_dec"]
losses = dict((key, []) for key in losses_key)

for k in losses_key:
    # Bind the key through a default argument so every lambda keeps its own
    # metric name instead of the loop variable's final value.
    RunningAverage(output_transform=lambda x, key=k: x[key]).attach(trainer, k)
ProgressBar().attach(trainer, metric_names=list(losses_key))

@trainer.on(Events.EPOCH_COMPLETED)
def store_losses(engine):
    """Print Loss_G and append the last iteration's outputs to ``losses``."""
    last_output = engine.state.output
    print(last_output["Loss_G"])
    for name in losses_key:
        losses[name].append(last_output[name])

  from tqdm.autonotebook import tqdm


In [11]:
# Total number of passes over the Sony dataloader.
num_epoch = 100


def training(*args):
    """Run the trainer for ``num_epoch`` epochs.

    Positional arguments supplied by the caller are accepted and ignored.
    """
    trainer.run(sony_dataloader, max_epochs=num_epoch)


# Launch the training function through ignite's Parallel context (NCCL backend).
with idist.Parallel(backend="nccl") as launcher:
    launcher.run(training)

2023-06-10 09:58:14,160 ignite.distributed.launcher.Parallel INFO: Initialized processing group with backend: 'nccl'
2023-06-10 09:58:14,162 ignite.distributed.launcher.Parallel INFO: - Run '<function training at 0x7f3137fa8c10>' in 1 processes
Epoch [1/100]: [388/466]  83%|████████▎ , Loss_G=0.401, Loss_D=2.84, D_real_enc=0.722, D_real_dec=0.71, D_fake_enc=0.732, D_fake_dec=0.676, D_G_L1=0.0901, D_G_dec=0.713 [20:19<04:08] 

In [None]:
# Bundle both networks and both optimizers into a single checkpoint.
# NOTE(review): 'loss' stores the `criterion` object itself rather than a
# loss value, and 'epoch' stores the configured `num_epoch`.
checkpoint = dict([
    ('epoch', num_epoch),
    ('model_state_dict', netG.state_dict()),
    ('optimizer_state_dict', optimizerG.state_dict()),
    ('loss', criterion),
    ('modelD_state_dict', netD.state_dict()),
    ('optimizerD_state_dict', optimizerD.state_dict()),
])
# The file name embeds torch's initial RNG seed.
checkpoint_path = 'model_seed_{}.pt'.format(torch.random.initial_seed())
torch.save(checkpoint, checkpoint_path)