In [1]:
import torch.nn as nn
import torch
import rawpy
from torch.utils.data import DataLoader
from utils.datasets import LabeledDataset

import torchvision.transforms as transforms

import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.contrib.handlers import ProgressBar

from torchinfo import summary

In [2]:
root_dir = "dataset"
# One training list per camera, each stored directly under the dataset root.
csv_files = [f"{root_dir}/{camera}_train_list.txt" for camera in ("Sony", "Fuji")]

# (height, width) every sample is center-cropped to after tensor conversion.
input_size = (3024, 2016)

transform = transforms.Compose([transforms.ToTensor(), transforms.CenterCrop(input_size)])
dataset = LabeledDataset(root_dir, *csv_files, transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=2, num_workers=8)

# Sanity check: shape of the first item of the first sample.
print(dataset[0][0].shape)

torch.Size([1, 3024, 2016])


In [3]:
from unet.unet_model import UNet
from torch import optim

model = idist.auto_model(UNet(1, 1).half())

# Adam keeps its moment estimates in the parameters' dtype, i.e. fp16 here.
# The default eps=1e-8 is below the smallest fp16 subnormal (~6e-8) and rounds
# to 0, so wherever the second moment underflows to 0 the update divides by 0
# (inf/NaN). Use an fp16-representable eps while the parameters are half precision.
param_dtype = next(model.parameters()).dtype
adam_eps = 1e-4 if param_dtype == torch.float16 else 1e-8
optimizer = optim.Adam(model.parameters(), lr=1e-3, eps=adam_eps)

loss = nn.MSELoss()

2023-06-02 08:21:20,370 ignite.distributed.auto.auto_model INFO: Apply torch DataParallel on model


In [4]:
# Random half-precision batch shaped like a training batch:
# 2 samples, 1 channel, spatial dims taken from ``input_size``.
input_data = torch.randn(2, 1, *input_size).to(torch.float16)
summary(model, input_data=input_data)

Layer (type:depth-idx)                             Output Shape              Param #
DataParallel                                       [2, 1, 3024, 2016]        --
├─UNet: 1-1                                        [1, 1, 3024, 2016]        31,036,481
├─UNet: 1-4                                        --                        (recursive)
│    └─DoubleConv: 2-1                             [1, 64, 3024, 2016]       37,696
│    └─DoubleConv: 2-23                            --                        (recursive)
│    │    └─Sequential: 3-1                        [1, 64, 3024, 2016]       37,696
│    │    └─Sequential: 3-26                       --                        (recursive)
├─UNet: 1-3                                        [1, 1, 3024, 2016]        --
├─UNet: 1-4                                        --                        (recursive)
│    └─DoubleConv: 2-3                             [1, 64, 3024, 2016]       --
│    └─DoubleConv: 2-23                            --          

In [5]:
def training_step(engine, batch):
    """Run one optimisation step of ``model`` on a single batch.

    Args:
        engine: the ignite ``Engine`` driving the loop (not used here).
        batch: a 5-item sequence; only the first two items (``short`` input and
            ``long`` target tensors) are used, the rest are ignored.

    Returns:
        dict mapping ``"Loss_G"`` to the scalar loss value of this step.
    """
    model.train()
    optimizer.zero_grad()
    short, long, _cam_model, _, _ = batch

    # Follow the model's actual parameter dtype instead of hard-coding ``.half()``,
    # so inputs stay consistent if the model's precision is changed where it is built.
    param_dtype = next(model.parameters()).dtype
    device = idist.device()
    short = short.to(device=device, dtype=param_dtype)
    long = long.to(device=device, dtype=param_dtype)

    output = model(short)
    g_loss = loss(output, long)
    # NOTE(review): with a pure-fp16 model there is no loss scaling, so small
    # gradients may underflow; consider fp32 weights + torch.cuda.amp instead.
    g_loss.backward()
    optimizer.step()
    return {"Loss_G": g_loss.item()}

In [6]:
trainer = Engine(training_step)
# ``metric_names`` only reads ``trainer.state.metrics``, and no metric named
# "Loss_G" is attached, so the loss was never displayed. ``training_step``
# returns ``{"Loss_G": ...}`` as the engine output, so show that output directly.
ProgressBar().attach(trainer, output_transform=lambda output: output)

  from tqdm.autonotebook import tqdm


In [7]:
def training(*_launcher_args):
    """Train ``trainer`` on ``dataloader`` for 10 epochs.

    ``parallel.run`` calls this with extra positional arguments (ignite passes
    the local process rank first); they are not needed here and are ignored.
    """
    trainer.run(dataloader, max_epochs=10)


with idist.Parallel(backend="nccl") as parallel:
    parallel.run(training)

2023-06-02 08:21:24,125 ignite.distributed.launcher.Parallel INFO: Initialized processing group with backend: 'nccl'
2023-06-02 08:21:24,126 ignite.distributed.launcher.Parallel INFO: - Run '<function training at 0x7fcbb5234a60>' in 1 processes
Epoch [1/10]: [699/1760]  40%|███▉       [10:31<15:41]

In [None]:
# ``idist.auto_model`` may have wrapped the network (DataParallel above), which
# prefixes every state_dict key with "module."; save the bare network so the
# checkpoint loads directly into ``UNet(1, 1)``.
model_to_save = getattr(model, "module", model)
# ``trainer.state.output`` is the dict returned by ``training_step`` on the last
# iteration; it is None if the trainer never ran.
last_output = trainer.state.output or {}

torch.save(
    {
        # Epoch actually reached, not the requested max_epochs (training may
        # have been interrupted).
        'epoch': trainer.state.epoch,
        'model_state_dict': model_to_save.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Last training loss value, instead of the nn.MSELoss module object.
        'loss': last_output.get("Loss_G"),
    },
    'model_seed_{}.pt'.format(torch.random.initial_seed()),
)