In [48]:
import torch
from torch import nn
import torchvision.transforms as t
import torch.nn.functional as F
from dataset import FkDataset

In [143]:
class Downsample:
    def __init__(self, size, mode="bicubic"):
        self.size = size
        self.mode = mode
    
    def __call__(self, x):
        return torch.nn.functional.interpolate(x, self.size)

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)
    
class Unflatten(nn.Module):
    def __init__(self, size=256, h=None, w=None):
        self.size = size
        self.h = h if h is not None else 1
        self.w = w if w is not None else 1
        super(Unflatten, self).__init__()
        
    def forward(self, input):
        return input.view(input.size(0), self.size, self.w, self.h)

class Elu(nn.Module):
    def forward(self, x):
        return nn.functional.elu(x)
    
class UNetConvBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, padding, batch_norm):
        super(UNetConvBlock, self).__init__()

        # add convolution 1
        self.add_module("conv1",
                        nn.Conv2d(in_channels=in_channels,
                                  out_channels=out_channels,
                                  kernel_size=3,
                                  padding=int(padding)))
        self.add_module("relu1", nn.ReLU())

        # add batchnorm 1
        if batch_norm:
            self.add_module("batchnorm1", nn.BatchNorm2d(out_channels))

        # add convolution 2
        self.add_module("conv2",
                        nn.Conv2d(in_channels=out_channels,
                                  out_channels=out_channels,
                                  kernel_size=3,
                                  padding=int(padding)))
        self.add_module("relu2", nn.ReLU())

        # add batchnorm 2
        if batch_norm:
            self.add_module("batchnorm2", nn.BatchNorm2d(out_channels))

    def forward(self, x):
        return super().forward(x)


class UNetUpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, padding, batch_norm):
        super(UNetUpBlock, self).__init__()

        # upsample
#         self.up = nn.ConvTranspose2d(in_channels=in_channels,
#                                      out_channels=out_channels,
#                                      kernel_size=2,
#                                      stride=2)
        self.up = nn.Upsample(mode="bilinear",
                              align_corner=True,
                              scale_factor=3)

        # add convolutions block
        self.conv_block = UNetConvBlock(in_channels=in_channels,
                                        out_channels=out_channels,
                                        padding=padding,
                                        batch_norm=batch_norm)

    def center_crop(self, layer, target_size):
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[:, :, diff_y:(diff_y + target_size[0]), diff_x:(diff_x + target_size[1])]

    def forward(self, x, skip_connection):
        up = self.up(x)
        crop1 = self.center_crop(skip_connection, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)

        return out

    
class UNet(nn.Module):
    def __init__(self, filters, hidden_dim):
        """
        Implementation of
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        (Ronneberger et al., 2015)
        https://arxiv.org/abs/1505.04597
        Args:
            hyperpara,s (dh.learning.Hyperparams): todo
            unet_hyperparams (dh.learning.ConvNetHyperparams): stores the hyperparameters of the downsampling branch of the unet
        """
        super(UNet, self).__init__()
        
        depth = len(filters)
        
        # downsampling
        self.downsample = nn.ModuleList()
        in_channels = 3  # v, w, u
        for i in range(depth):
            print(filters[i])
            self.downsample.append(UNetConvBlock(in_channels=in_channels,
                                                 out_channels=filters[i],
                                                 padding=1,
                                                 batch_norm=False))
            in_channels = filters[i]
        
#         # latent
#             self.latent_cont = nn.Conv2d(in_channels, in_channels, 1, 1)
#             self.flatten = Flatten()
#             self.latent_in = nn.Linear(32, hidden_dim)
#             self.unflatten = Unflatten(32, 1, 1)
#             self.latent_out = nn.Linear(hidden_dim, 32)
        
        # upsample
        self.upsample = nn.ModuleList()
        out_filter = [3] + filters
        for i in reversed(range(1, depth)):
            print(out_filter[i])
            self.upsample.append(UNetUpBlock(in_channels=in_channels,
                                             out_channels=out_filter[i],
                                             padding=1,
                                             batch_norm=False))
            in_channels = out_filter[i]
        return

    def forward(self, X, Y=None):
        skip_connections = []
        for i, down in enumerate(self.downsample):
            print("down", X.shape)
            X = down(X)
            if i != len(self.downsample) - 1:
                skip_connections.append(X)
                X = F.max_pool2d(X, 4)
                
#         z = self.flatten(X)
#         print(z.shape)
#         z = self.latent_in(z)
#         print(z.shape)
#         z = self.latent_out(z)
#         print(z.shape)
#         z = self.unflatten(z)
#         print(z.shape)

        for i, up in enumerate(self.upsample):
            print("up", X.shape)
            X = up(X, skip_connections[-i - 1])
        print("output", X.shape)
        return X
    
    def parameters_count(self):
        return sum(p.numel() for p in self.parameters())

In [144]:
if __name__ == "__main__":
    ## HYPERPARAMS
    root = "/home/ep119/repos/fenton_karma_jax/data/train_dev_set/"
    epochs = 100000
    device = torch.device("cuda")
    input_size = 256
    hidden_dim = 128
    loss_coeff = {
        "mse": 10000.,
        "kld": 1.,
        "grad": 0.
    }

    net = UNet([8, 16, 32, 64, 128], 128)
    print(net)
    print(net.parameters_count())
    fkset = FkDataset(root, 1, 0, 1, transforms=t.Compose([Downsample((input_size, input_size))]), squeeze=True)
    net(fkset[0].unsqueeze(0)).shape

8
16
32
64
128
64
32
16
8
UNet(
  (downsample): ModuleList(
    (0): UNetConvBlock(
      (conv1): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu1): ReLU()
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu2): ReLU()
    )
    (1): UNetConvBlock(
      (conv1): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu1): ReLU()
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu2): ReLU()
    )
    (2): UNetConvBlock(
      (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu1): ReLU()
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu2): ReLU()
    )
    (3): UNetConvBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu1): ReLU()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu2): ReLU()
    )
 