In [None]:
import numpy as np
from torch import nn
import torchvision as tv
import torch
import h5py
import os

In [None]:
# This notebook expects a CUDA-capable GPU; stop early with a clear message if none is visible.
assert torch.cuda.is_available(), "CUDA is not available"
# torch.cuda.device(0) is a context manager for switching the current CUDA device,
# not a device handle; tensors and modules need a torch.device for .to(device).
device = torch.device('cuda', 0)
print(torch.cuda.get_device_name(device))

In [None]:
# Root of the challenge dataset (adapt path).
DATA_DIR = 'challenge_dataset/'


def list_dataset_files(data_dir):
    """Return the path of every file under ``data_dir`` (recursively), sorted.

    ``os.walk`` yields entries in filesystem-dependent order; sorting makes
    ``data[0]`` etc. refer to the same file on every run.
    Returns an empty list if ``data_dir`` does not exist.
    """
    return sorted(
        # os.walk's `root` has no trailing separator for subdirectories, so the
        # path must be joined, not concatenated.
        os.path.join(root, name)
        for root, _dirs, filenames in os.walk(data_dir)
        for name in filenames
    )


# Open every file read-only so the dataset can never be modified by accident.
# NOTE(review): every file under DATA_DIR is assumed to be HDF5; filter by
# extension here if the directory contains anything else.
data = [h5py.File(path, 'r') for path in list_dataset_files(DATA_DIR)]

In [None]:
# Building blocks for U-Net
class Block(nn.Module):
    """Two unpadded 3x3x3 convolutions, each followed by a ReLU.

    Because neither convolution pads its input, each one trims one voxel from
    both ends of every spatial axis: the output is 4 voxels smaller than the
    input along each spatial dim.

    Args:
        in_ch: number of channels of the input.
        out_ch: number of channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Registration order (conv1, relu, conv2) fixes parameter init order and state_dict keys.
        self.conv1 = nn.Conv3d(in_ch, out_ch, 3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv3d(out_ch, out_ch, 3)

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        hidden = self.conv2(hidden)
        return self.relu(hidden)

In [None]:
class Encoder(nn.Module):
    """Contracting path of a 3D U-Net.

    Runs one ``Block`` per consecutive pair in ``channels``, with a 2x max-pool
    between blocks, and returns every block's output for use as skip connections.

    Args:
        channels: ``channels[0]`` is the input channel count; each following
            entry is the output channel count of one ``Block``.
    """

    def __init__(self, channels=(1, 64, 128, 256)):
        super().__init__()
        self.enc_blocks = nn.ModuleList(
            [Block(c_in, c_out) for c_in, c_out in zip(channels[:-1], channels[1:])]
        )
        self.pool = nn.MaxPool3d(2, stride=2)

    def forward(self, x):
        """Return the list of per-block feature maps, shallowest first."""
        ftrs = []
        for level, block in enumerate(self.enc_blocks):
            if level > 0:
                # Downsample only between blocks: a pool after the last block would be
                # discarded anyway, and it raises when the deepest map has a spatial dim < 2.
                x = self.pool(x)
            x = block(x)
            ftrs.append(x)
        return ftrs

In [None]:
encoder = Encoder()
assert data, "no dataset files were loaded - check the dataset path"
# Read the whole 'raw' dataset as float32 and add batch and channel axes:
# x has shape (1, 1, *raw.shape). Converting in numpy first replaces the legacy
# torch.Tensor(ndarray) constructor and works for any numeric on-disk dtype.
raw = np.asarray(data[0]['raw'], dtype=np.float32)
x = torch.from_numpy(raw[np.newaxis, np.newaxis])
x.shape

In [None]:
ftrs = encoder(x)
# One line per encoder level, in the order the encoder returned them.
level_shapes = [feature_map.shape for feature_map in ftrs]
for shape in level_shapes:
    print(shape)

In [None]:
class Decoder(nn.Module):
    """Expanding path of a 3D U-Net.

    For each consecutive pair in ``channels``: upsample with a stride-2
    transposed convolution, concatenate the center-cropped skip connection along
    the channel axis, then run a ``Block``.

    Args:
        channels: channel counts, deepest first; ``channels[0]`` must equal the
            channel count of the decoder input.
    """

    def __init__(self, channels=(256, 128, 64)):
        super().__init__()
        self.channels   = channels
        self.upconvs    = nn.ModuleList([nn.ConvTranspose3d(channels[i], channels[i+1], 2, 2) for i in range(len(channels)-1)])
        self.dec_blocks = nn.ModuleList([Block(channels[i], channels[i+1]) for i in range(len(channels)-1)])

    def forward(self, x, encoder_features):
        """Decode ``x`` using ``encoder_features`` as skip connections.

        Args:
            x: bottleneck feature map, batched with channels on dim 1
                (presumably (N, channels[0], D, H, W) - TODO confirm).
            encoder_features: skip connections ordered deepest first; element ``i``
                must have ``channels[i] - channels[i+1]`` channels so the
                concatenation matches the input width of ``dec_blocks[i]``, and must be
                at least as large as the upsampled ``x`` in every spatial dim.

        Returns:
            The output of the last decoder block.
        """
        for i in range(len(self.channels)-1):
            x        = self.upconvs[i](x)
            enc_ftrs = self.crop(encoder_features[i], x)
            x        = torch.cat([x, enc_ftrs], dim=1)
            x        = self.dec_blocks[i](x)
        return x

    @staticmethod
    def crop(enc_ftrs, x):
        """Center-crop ``enc_ftrs`` to the spatial size of ``x``.

        Batch and channel dims (0 and 1) are kept; every later dim is cropped.
        When a dim's excess is odd, the extra voxel is removed from the end.

        Raises:
            ValueError: if the tensors differ in rank, or ``enc_ftrs`` is smaller
                than ``x`` in some spatial dim.
        """
        # torchvision's CenterCrop only handles (h, w) and rejects a 3-element size,
        # so the depth/height/width crop is done explicitly with slicing.
        if enc_ftrs.dim() != x.dim():
            raise ValueError(
                f"rank mismatch: enc_ftrs is {enc_ftrs.dim()}-D, x is {x.dim()}-D"
            )
        spatial_slices = []
        for src, tgt in zip(enc_ftrs.shape[2:], x.shape[2:]):
            if tgt > src:
                raise ValueError(
                    f"cannot crop {tuple(enc_ftrs.shape[2:])} to larger {tuple(x.shape[2:])}"
                )
            start = (src - tgt) // 2
            spatial_slices.append(slice(start, start + tgt))
        return enc_ftrs[(slice(None), slice(None)) + tuple(spatial_slices)]

In [None]:
# Smoke-test the decoder on real encoder output rather than a hard-coded random
# shape: the deepest feature map is the bottleneck input, and the remaining maps
# (deepest first) are the skip connections.
decoder = Decoder()
t = ftrs[-1]
decoder(t, ftrs[::-1][1:]).shape