In [None]:
import torch
import torch.nn as nn
import numpy as np
from mlframework.base.vistools import plots
from mlframework.torch import Trainer
from mlframework.torch import DataStreamer
from mlframework.torch import RegressionEvaluator as Evaluator

In [None]:
class DoubleConvolution(nn.Module):
    """Two 3x3 convolutions in sequence, each followed by a ReLU.

    Both convolutions use ``padding=1``, so height and width are preserved;
    only the channel count changes (``in_channels`` -> ``out_channels``).

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Shared settings: 3x3 kernel with 1-pixel padding keeps H and W unchanged.
        conv_settings = dict(kernel_size=3, padding=1)

        self.convolution1 = nn.Conv2d(in_channels, out_channels, **conv_settings)
        self.activation1 = nn.ReLU()

        self.convolution2 = nn.Conv2d(out_channels, out_channels, **conv_settings)
        self.activation2 = nn.ReLU()

    def forward(self, x):
        """Apply conv -> ReLU -> conv -> ReLU to ``x`` and return the result."""
        pipeline = (
            self.convolution1,
            self.activation1,
            self.convolution2,
            self.activation2,
        )
        for layer in pipeline:
            x = layer(x)
        return x

class DownSample(nn.Module):
    """Encoder stage: a double convolution followed by 2x2 max pooling.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by the double convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConvolution(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        """Return ``(pooled, features)``.

        ``features`` is the convolution output before pooling (kept so a
        decoder stage can reuse it as a skip connection); ``pooled`` is the
        same tensor after max pooling.
        """
        features = self.conv(x)
        return self.pool(features), features


class UpSample(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, double conv.

    Args:
        in_channels: number of channels of the tensor coming from the deeper
            level.
        out_channels: number of channels produced by the transposed
            convolution and by the stage as a whole. The skip tensor must also
            carry ``out_channels`` channels, because the convolution that
            follows the concatenation expects ``2 * out_channels`` inputs.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # kernel_size=2 with stride=2 doubles height and width.
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = DoubleConvolution(out_channels+out_channels, out_channels)

    def forward(self, x, skip):
        """Upsample ``x``, fuse it with ``skip`` along channels, and convolve.

        Args:
            x: tensor from the deeper level.
            skip: encoder feature map to concatenate with the upsampled ``x``.

        Returns:
            The output of the double convolution applied to the fused tensor;
            its height and width match those of ``skip``.
        """
        x = self.up(x)

        # Pooling an odd-sized map floors its size, so after doubling the
        # result can be one pixel smaller than the matching skip tensor.
        # Zero-pad the upsampled map to the skip's height/width so the
        # concatenation below works for any input size. When the sizes
        # already agree nothing is padded and the result is unchanged.
        diff_h = skip.size(-2) - x.size(-2)
        diff_w = skip.size(-1) - x.size(-1)
        if diff_h != 0 or diff_w != 0:
            pad_left, pad_top = diff_w // 2, diff_h // 2
            # F.pad orders the amounts as (left, right, top, bottom).
            x = nn.functional.pad(
                x, (pad_left, diff_w - pad_left, pad_top, diff_h - pad_top)
            )

        x = torch.cat([x, skip], dim=1)
        return self.conv(x)

class Unet(nn.Module):
    """U-Net with four encoder stages, a bottleneck, and four decoder stages.

    Channel widths double at every encoder stage (``filters``, ``2*filters``,
    ``4*filters``, ``8*filters``), reach ``16*filters`` in the bottleneck, and
    are halved again on the way up. A final 3x3 convolution maps the last
    decoder output to ``out_channels`` channels.

    Inputs whose height and width are multiples of 16 pass through the four
    pooling/upsampling pairs with matching skip-connection sizes.

    Args:
        filters: number of channels produced by the first encoder stage.
        in_channels: number of channels of the input tensor. Defaults to 1,
            which reproduces the previously hard-coded single-channel input.
        out_channels: number of channels of the prediction. Defaults to 1,
            which reproduces the previously hard-coded single-channel output.
    """

    def __init__(self, filters, in_channels=1, out_channels=1):
        super().__init__()
        num_c = filters

        # Encoder: each stage returns (pooled tensor, pre-pool skip tensor).
        self.down1 = DownSample(in_channels, num_c)
        self.down2 = DownSample(num_c, num_c*2)
        self.down3 = DownSample(num_c*2, num_c*4)
        self.down4 = DownSample(num_c*4, num_c*8)

        # Bottleneck at the lowest resolution.
        self.z = DoubleConvolution(num_c*8, num_c*16)

        # Decoder: each stage consumes the previous output plus one skip tensor.
        self.up1 = UpSample(num_c*16, num_c*8)
        self.up2 = UpSample(num_c*8, num_c*4)
        self.up3 = UpSample(num_c*4, num_c*2)
        self.up4 = UpSample(num_c*2, num_c*1)

        # Output head; padding=1 keeps the spatial size of the last decoder map.
        self.outputs = nn.Conv2d(num_c, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder and return the prediction."""
        x1, skip1 = self.down1(x)
        x2, skip2 = self.down2(x1)
        x3, skip3 = self.down3(x2)
        x4, skip4 = self.down4(x3)

        z = self.z(x4)

        # Skips are consumed in reverse order: deepest encoder stage first.
        y1 = self.up1(z, skip4)
        y2 = self.up2(y1, skip3)
        y3 = self.up3(y2, skip2)
        y4 = self.up4(y3, skip1)

        return self.outputs(y4)

## Prepare DataStreamer

In [None]:
# Locations of the input ("condition") and target ("solution") arrays.
TRAIN_DATA_X = './condition.npy'
TRAIN_DATA_Y = './solution.npy'

# Insert a singleton axis at position 1 of each loaded array (presumably the
# channel axis expected by the convolutions -- confirm against the stored
# array layout), then keep only the first 10 entries.
train_data_x = np.expand_dims(np.load(TRAIN_DATA_X), axis=1)[:10]
train_data_y = np.expand_dims(np.load(TRAIN_DATA_Y), axis=1)[:10]

# NOTE(review): split=(8, 2, 0) looks like a train/validation/test partition
# of the 10 retained samples -- confirm against the DataStreamer documentation.
datastreamer = DataStreamer(
    input=train_data_x,
    output=train_data_y,
    shuffle=True,
    batch_size=512,
    split=(8, 2, 0),
)

# Only the training and validation streams are used below; the test stream is
# still read and bound to the throwaway name `_`.
train_data = datastreamer.train_data
val_data = datastreamer.val_data
_ = datastreamer.test_data

## Trainer

In [None]:
# Width of the first U-Net stage and the output directory name handed to the Trainer.
num_filter = 64
output_dir = "output_" + "Unet"

model = Unet(num_filter)

# Configure training with a mean-squared-error loss on the
# training/validation streams prepared earlier.
my_trainer = Trainer(
    model=model,
    run_number=1,
    train_dataloader=train_data,
    validation_dataloader=val_data,
    n_epochs=1500,
    output_dir=output_dir,
    loss=nn.MSELoss(),
)

In [None]:
# Start training with the Trainer configured in the previous cell
# (n_epochs=1500, MSE loss).
my_trainer.train()