In [1]:
import torch
import torch.nn as nn
import torch.optim as opts
import h5py as h5
import model
# import model2 as model
import cv2

from plasma.modules import *
from plasma.training import StandardDataset, Trainer, callbacks
from albumentations import ShiftScaleRotate, Compose, HorizontalFlip

# Data repository: augmentation pipeline and HDF5 dataset

In [2]:
# Augmentation pipeline for training samples. The Compose wrapper itself
# always runs (p=1); each transform inside fires with its own probability.
aug = Compose(
    [
        # Horizontal flip on half of the samples.
        HorizontalFlip(p=0.5),
        # Random shift / scale / rotate on 30% of the samples. shift_limit and
        # rotate_limit bound the shift and rotation; scale_limit is left at the
        # library default. Border pixels exposed by the warp use a constant fill.
        ShiftScaleRotate(
            p=0.3,
            shift_limit=0.1,
            rotate_limit=25,
            border_mode=cv2.BORDER_CONSTANT,
        ),
    ],
    p=1,
)

In [3]:
class Repo(StandardDataset):
    """Dataset that serves images from the ``"image"`` dataset of an HDF5 file.

    Every item is returned as an ``(input, target)`` pair whose two elements
    are the same array.

    Args:
        file: Path of the HDF5 file; it is opened read-only.
        train: When true, each sample is passed through the module-level
            ``aug`` pipeline before it is returned.
    """

    def __init__(self, file, train=True):
        super().__init__()

        # NOTE(review): the handle stays open for the lifetime of the object;
        # nothing in this class closes it explicitly.
        self.h5_file = h5.File(file, mode="r")
        self.train = train

    def get_len(self):
        """Return the number of entries in the ``"image"`` dataset."""
        return len(self.h5_file["image"])

    def get_item(self, idx):
        """Load sample ``idx``, optionally augment it, and return ``(img, img)``.

        The augmentation is applied to the image as stored, *before* the
        leading axis is added. albumentations reads arrays as (H, W[, C]); if
        the leading axis were added first it would see a one-pixel-high image
        whose "channels" are the real columns, so the flip would mirror the
        wrong axis and the affine warp would operate on a 1 x H strip.
        """
        # assumes each stored image is a 2-D (H, W) array — TODO confirm
        img = self.h5_file["image"][idx]

        if self.train:
            img = aug(image=img)["image"]

        # Add the leading (channel) axis only after augmentation.
        img = img[None]

        return img, img

# Callback: save target/prediction images during validation

In [4]:
from torchvision.utils import save_image

class RenderImage(callbacks.Callback):
    """Callback that writes target/prediction comparison images to ``check/``."""

    def on_validation_batch_end(self, batch, x, y, pred):
        """Save a comparison grid for every 50th validation batch.

        Args:
            batch: Batch counter (presumably the index of the validation
                batch); used for the 1-in-50 filter and as the file name.
            x: Unused here.
            y: Target tensor.
            pred: Prediction tensor; must have the same shape as ``y`` so the
                two can be stacked.
        """
        if batch % 50 != 0:
            return

        import os  # local import: keeps this cell self-contained

        # save_image does not create missing directories, so make sure the
        # output folder exists before the first write.
        os.makedirs("check", exist_ok=True)

        # Stack on a new dim 1 and merge it into dim 0 so that targets and
        # predictions alternate; with nrow=2 each grid row is one
        # (target, prediction) pair.
        img = torch.stack([y, pred], dim=1).flatten(start_dim=0, end_dim=1)

        # NOTE(review): the file name depends only on `batch`, so if the
        # counter restarts each epoch, later epochs overwrite earlier images.
        save_image(img, f"check/{batch}.png", normalize=True, nrow=2)

# Training: build the autoencoder and fit it

In [5]:
# Instantiate the encoder and move it to GPU 0. The call is left as the
# trailing expression so the notebook displays the module structure.
encoder = model.DenseCap()
encoder.to("cuda:0")

DenseCap(
  (con1): Sequential(
    (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (con2): Sequential(
    (0): Upsample(scale_factor=0.5, mode=bilinear)
    (1): DenseBlock(
      (skip): Identity()
      (con): Sequential(
        (0): BN_ReLU_Conv(
          (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): ReLU(inplace=True)
          (con): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU(inplace=True)
          (2): ChannelAttention(
            in_channels=16, ratio=0.5, axes=[2, 3], groups=32
            (attention): Sequential(
              (0): GlobalAverage(axes=[2, 3], keepdims=True)
              (1):

In [6]:
# Instantiate the decoder and move it to GPU 0. The call is left as the
# trailing expression so the notebook displays the module structure.
decoder = model.Decoder()
decoder.to("cuda:0")

Decoder(
  (con1): Sequential(
    (0): Upsample(scale_factor=2.0, mode=bilinear)
    (1): BN_ReLU_Conv(
      (norm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
      (con): Conv2d(1024, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (2): BN_ReLU_Conv(
      (norm): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
      (con): Conv2d(768, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
  )
  (con2): Sequential(
    (0): Upsample(scale_factor=2.0, mode=bilinear)
    (1): BN_ReLU_Conv(
      (norm): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
      (con): Conv2d(768, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (2): BN_ReLU_Conv(
      (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, 

In [7]:
# Full autoencoder: the encoder's output is fed straight into the decoder.
ae = nn.Sequential(encoder, decoder)

In [8]:
# Mean-squared-error criterion with PyTorch's default settings.
loss = nn.MSELoss()

In [9]:
# SGD with Nesterov momentum over every parameter of the autoencoder
# (encoder and decoder together).
opt = opts.SGD(
    ae.parameters(),
    lr=0.2,
    momentum=0.9,
    nesterov=True,
)

In [10]:
# Bundle model, optimizer and loss into the plasma Trainer.
# NOTE(review): presumably inputs and targets are placed on cuda:0 and targets
# are cast to float — confirm against plasma.training.Trainer.
trainer = Trainer(
    ae,
    opt,
    loss,
    x_device="cuda:0",
    y_type=torch.float,
    y_device="cuda:0",
)

In [11]:
# Callbacks handed to trainer.fit; they are constructed in the same order in
# which they appear in the list.
# NOTE(review): WarmRestart presumably drives a warm-restart learning-rate
# schedule (1e-5 looks like its lower bound) and saves under "ae_256x256" —
# confirm against plasma.training.callbacks.
warm_restart_cb = callbacks.WarmRestart(1e-5, model_name="ae_256x256")
# NOTE(review): append=True presumably keeps existing rows in train.csv.
csv_logger_cb = callbacks.CSVLogger("train.csv", append=True)
# Defined in this notebook.
render_cb = RenderImage()

cbs = [warm_restart_cb, csv_logger_cb, render_cb]

In [None]:
# Training data is augmented (train defaults to True); the test set is not.
train_set = Repo("data.h5")
test_set = Repo("test.h5", train=False)

trainer.fit(
    train_set,
    test=test_set,
    callbacks=cbs,
    batch_size=8,
    val_batch_size=8,
)

epoch 1


HBox(children=(FloatProgress(value=0.0, description='train', max=1505.0, style=ProgressStyle(description_width…