In [1]:
import torch.nn as nn
import torch
import rawpy
from torch.utils.data import DataLoader
from utils.datasets import LabeledDataset

import torchvision.transforms as transforms

import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.contrib.handlers import ProgressBar
from ignite.metrics import FID, InceptionScore, RunningAverage

from torchinfo import summary

In [2]:
# Training data: pairs listed in the split file(s) below, read relative to root_dir.
root_dir = "dataset"
csv_files = ["dataset/Sony_train_list.txt"]  # append "dataset/Fuji_train_list.txt" to also use the Fuji split

batch_size = 6
input_size = (2844, 4248)  # (height, width) each sample is center-cropped to

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.CenterCrop(input_size),
    ]
)
dataset = LabeledDataset(root_dir, *csv_files, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)

# Sanity check: shape of the first element of the first sample.
print(dataset[0][0].shape)

torch.Size([1, 2844, 4248])


In [3]:
from unet.unet_model import UNet
from torch import optim

class Crop(nn.Module):
    """Split channel 0 of an image batch into ``scale_factor**2`` overlapping tiles.

    Each tile is ``H // (scale_factor - 1)`` x ``W // (scale_factor - 1)`` pixels.
    Consecutive tiles are offset by half a tile (integer division) along each axis.
    Tiles are stacked along the channel dimension in row-major order: the tile in
    tile-row ``row`` and tile-column ``col`` goes to channel ``row * scale_factor + col``.

    With ``scale_factor == 3`` (and H, W divisible by 4) the tiles exactly cover the
    input. For larger scale factors they cover only the top-left part of it
    (roughly 3/4 of each side for ``scale_factor == 5``).

    Args:
        scale_factor: number of tiles along each spatial axis; must be >= 3
            (smaller values would place tiles past the image border).
    """

    def __init__(self, scale_factor = 3, *args, **kwargs) -> None:
        super(Crop, self).__init__()
        if scale_factor < 3:
            raise ValueError(f"scale_factor must be >= 3, got {scale_factor}")
        self.scale_factor = scale_factor

    def forward(self, x):
        """Cut ``x`` into overlapping tiles.

        Args:
            x: tensor of shape (B, C, H, W); only channel 0 is used.

        Returns:
            Tensor of shape (B, scale_factor**2, H // (scale_factor - 1),
            W // (scale_factor - 1)) with the same dtype and device as ``x``.
        """
        sf = self.scale_factor
        tile_h = x.shape[2] // (sf - 1)
        tile_w = x.shape[3] // (sf - 1)
        step_h, step_w = tile_h // 2, tile_w // 2

        # Allocate wherever the input lives (e.g. each DataParallel replica's GPU)
        # rather than on a single global device.
        out = x.new_zeros((x.shape[0], sf * sf, tile_h, tile_w))
        for row in range(sf):
            # Advance the vertical offset for every tile row; the original loop never
            # moved it, so every row of tiles was cut from the top of the image.
            top = row * step_h
            for col in range(sf):
                left = col * step_w
                out[:, row * sf + col] = x[:, 0, top:top + tile_h, left:left + tile_w]
        return out
    
class Reconstruct(nn.Module):
    """Stitch ``scale_factor**2`` overlapping tiles back into a single-channel image.

    Intended to invert the overlapping tiling done by ``Crop``. Input channel
    ``row * scale_factor + col`` holds an (h, w) tile. It is placed at offset
    ``(row * (h // 2), col * (w // 2))`` in an output of size
    ``(h * (scale_factor - 1), w * (scale_factor - 1))``.

    Where tiles overlap their values are averaged, not summed. Stitching the tiles
    cut from an image therefore reproduces that image on every pixel some tile
    covers. Pixels no tile covers (which happens for ``scale_factor > 3``) are 0.

    Args:
        scale_factor: number of tiles along each spatial axis; must be >= 3
            (smaller values would place tiles past the output border).
    """

    def __init__(self, scale_factor = 3, *args, **kwargs) -> None:
        super(Reconstruct, self).__init__()
        if scale_factor < 3:
            raise ValueError(f"scale_factor must be >= 3, got {scale_factor}")
        self.scale_factor = scale_factor

    def forward(self, x):
        """Merge tiles into one image.

        Args:
            x: tensor of shape (B, scale_factor**2, h, w), one tile per channel.

        Returns:
            Tensor of shape (B, 1, h * (scale_factor - 1), w * (scale_factor - 1))
            on the same device as ``x``.
        """
        sf = self.scale_factor
        tile_h, tile_w = x.shape[2], x.shape[3]
        step_h, step_w = tile_h // 2, tile_w // 2
        out_h, out_w = tile_h * (sf - 1), tile_w * (sf - 1)

        # Allocate wherever the input lives (e.g. each DataParallel replica's GPU)
        # rather than on a single global device.
        summed = x.new_zeros((x.shape[0], 1, out_h, out_w))
        # Number of tiles covering each output pixel, used to average overlaps.
        coverage = x.new_zeros((1, 1, out_h, out_w))
        for row in range(sf):
            top = row * step_h
            for col in range(sf):
                left = col * step_w
                summed[:, 0, top:top + tile_h, left:left + tile_w] += x[:, row * sf + col]
                coverage[:, :, top:top + tile_h, left:left + tile_w] += 1
        # clamp(min=1) leaves uncovered pixels at 0 instead of dividing by zero.
        return summed / coverage.clamp(min=1)

# Tiles per spatial axis; the UNet sees the scale_factor**2 tiles as its input/output channels.
# NOTE(review): Crop/Reconstruct size tiles as H / (scale_factor - 1) and step by half a tile.
# That tiles the whole image only when scale_factor == 3. With 5, tiles span only about the
# top-left 3/4 of each side (2131 of 2844 rows here), and the rest of the output stays 0.
scale_factor = 5
n_tiles = scale_factor ** 2

net = nn.Sequential(
    Crop(scale_factor),
    UNet(n_tiles, n_tiles),
    Reconstruct(scale_factor),
)

model = idist.auto_model(net)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss = nn.MSELoss()

2023-06-02 11:04:49,651 ignite.distributed.auto.auto_model INFO: Apply torch DataParallel on model


In [4]:
# Push one random full-size batch through the model to list per-layer output shapes and parameter counts.
input_data = torch.randn((batch_size, 1) + tuple(input_size))
summary(model, input_data=input_data)

Layer (type:depth-idx)                                  Output Shape              Param #
DataParallel                                            [6, 1, 2844, 4248]        --
├─Sequential: 1-1                                       [3, 1, 2844, 4248]        31,051,865
├─Sequential: 1-4                                       --                        (recursive)
│    └─Crop: 2-1                                        [3, 25, 711, 1062]        --
├─Sequential: 1-3                                       [3, 1, 2844, 4248]        --
├─Sequential: 1-4                                       --                        (recursive)
│    └─Crop: 2-2                                        [3, 25, 711, 1062]        --
│    └─UNet: 2-3                                        [3, 25, 711, 1062]        31,051,865
│    └─UNet: 2-8                                        --                        (recursive)
│    │    └─DoubleConv: 3-1                             [3, 64, 711, 1062]        51,520
│    │    └─D

In [5]:
def training_step(engine, batch):
    """Run one optimisation step of the global ``model`` and report its loss.

    ``batch`` must unpack into exactly five items. Only the first (model input) and
    second (target) are used. They are compared with the global ``loss`` criterion.

    Returns:
        dict with key "Loss_G" holding the step's loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    inputs, target, _cam_model, _, _ = batch
    device = idist.device()
    inputs = inputs.to(device)
    target = target.to(device)

    prediction = model(inputs)
    step_loss = loss(prediction, target)
    step_loss.backward()
    optimizer.step()

    return {"Loss_G": step_loss.item()}

In [6]:
trainer = Engine(training_step)

# Running average of the per-step loss, shown on the progress bar as 'Loss_G'.
RunningAverage(output_transform=lambda step_output: step_output["Loss_G"]).attach(trainer, "Loss_G")
ProgressBar().attach(trainer, metric_names=["Loss_G"])

# Raw per-iteration losses, kept for inspection/plotting after training.
G_losses = []


def store_losses(engine):
    """Append the current iteration's raw 'Loss_G' value to ``G_losses``."""
    G_losses.append(engine.state.output["Loss_G"])


trainer.add_event_handler(Events.ITERATION_COMPLETED, store_losses)

  from tqdm.autonotebook import tqdm


In [7]:
def training(*args):
    """Entry point for ``parallel.run``: train ``trainer`` on ``dataloader`` for 10 epochs.

    Positional arguments passed by the launcher (e.g. the local rank) are accepted and ignored.
    """
    epochs = 10
    trainer.run(dataloader, max_epochs=epochs)


# NOTE(review): `model` was wrapped by `idist.auto_model` in an earlier cell, before this
# context initialised the process group. Compare the "Apply torch DataParallel" log with the
# "Initialized processing group" log. ignite's examples call auto_model/auto_dataloader
# inside the function handed to `parallel.run`; consider moving them into `training`.
with idist.Parallel(backend="nccl") as parallel:
    parallel.run(training)

2023-06-02 11:04:54,281 ignite.distributed.launcher.Parallel INFO: Initialized processing group with backend: 'nccl'
2023-06-02 11:04:54,282 ignite.distributed.launcher.Parallel INFO: - Run '<function training at 0x7f20952b8dc0>' in 1 processes
Epoch [1/10]: [96/311]  31%|███       , Loss_G=5.76e+6 [01:31<03:22]

In [None]:
# Save a training checkpoint named after this run's initial torch seed.
# - 'epoch' comes from the trainer, so an interrupted run is not labelled as 10 finished epochs.
# - 'loss' is the last recorded training loss value (None if no step ran). Previously this
#   pickled the nn.MSELoss criterion object.
# - When `model` is DataParallel-wrapped (as logged when it was built), state_dict keys carry a
#   'module.' prefix. Load into a wrapped model or strip the prefix.
torch.save(
    {
        "epoch": trainer.state.epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": G_losses[-1] if G_losses else None,
    },
    'model_seed_{}.pt'.format(torch.random.initial_seed()),
)