In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
from torch import nn
import torch
class DoubleConv(nn.Module):
    """Two consecutive Conv3d -> InstanceNorm3d -> GELU stages.

    Both convolutions use a 3x3x3 kernel with padding 1, so the spatial
    size of the input is preserved; only the channel count changes from
    ``in_channels`` to ``out_channels``.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by both stages.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv_stage(stage_in):
            # One convolution stage; every stage emits `out_channels` channels.
            return [
                nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.InstanceNorm3d(out_channels),
                nn.GELU(),
            ]

        # The attribute name and the layer order determine the state_dict keys
        # (double_conv.0.*, double_conv.3.*), so they are kept as they were.
        self.double_conv = nn.Sequential(*conv_stage(in_channels), *conv_stage(out_channels))

    def forward(self, x):
        """Run ``x`` through both convolution stages and return the result."""
        return self.double_conv(x)
class DownBlock(nn.Module):
    """Encoder stage: a DoubleConv followed by 2x max pooling.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = DoubleConv(in_channels, out_channels)
        self.down_sample = nn.MaxPool3d(2)

    def forward(self, x):
        """Return ``(pooled, features)``.

        ``features`` is the DoubleConv output before pooling (used as the
        skip connection); ``pooled`` is the same tensor after MaxPool3d(2).
        """
        features = self.double_conv(x)
        return self.down_sample(features), features
class UpBlock(nn.Module):
    """Decoder stage: upsample, concatenate with the skip tensor, DoubleConv.

    Args:
        in_channels: channel count of the concatenated tensor, i.e. channels
            of the upsampled input plus channels of the skip tensor.
        out_channels: number of channels produced by the DoubleConv.
        up_sample_mode: accepted for interface compatibility but not used;
            upsampling is always trilinear interpolation to the skip
            tensor's spatial size.
    """

    def __init__(self, in_channels, out_channels, up_sample_mode):
        super().__init__()
        self.double_conv = DoubleConv(in_channels, out_channels)

    def forward(self, down_input, skip_input):
        """Upsample ``down_input`` to ``skip_input``'s size, fuse and convolve."""
        # Resizing to the exact skip shape (rather than a fixed factor of 2)
        # keeps the concatenation valid when pooling rounded an odd size down.
        upsampled = nn.functional.interpolate(
            down_input,
            size=skip_input.shape[2:],
            mode='trilinear',
            align_corners=False,
        )
        fused = torch.cat([upsampled, skip_input], dim=1)
        return self.double_conv(fused)

class MiniUnet(nn.Module):
    """Three-level 3D U-Net for single-channel volumes.

    The encoder widens 1 -> 64 -> 128 -> 256 channels, the bottleneck
    produces 512 channels, and the decoder mirrors the encoder while
    consuming the skip tensors. A final 1x1x1 convolution maps the 64
    decoder channels to ``out_classes``.

    Args:
        out_classes: number of output channels of the final convolution.
        up_sample_mode: stored on the instance and forwarded to every
            UpBlock.
    """

    def __init__(self, out_classes = 1, up_sample_mode='trilinear'):
        super().__init__()
        self.up_sample_mode = up_sample_mode
        mode = self.up_sample_mode

        # Encoder. Attribute names and registration order are kept because
        # they define the state_dict keys expected by saved checkpoints.
        self.down_conv1 = DownBlock(1, 64)
        self.down_conv2 = DownBlock(64, 128)
        self.down_conv3 = DownBlock(128, 256)

        # Bottleneck
        self.double_conv = DoubleConv(256, 512)

        # Decoder: input channels = upsampled channels + skip channels.
        self.up_conv3 = UpBlock(512 + 256, 256, mode)
        self.up_conv2 = UpBlock(256 + 128, 128, mode)
        self.up_conv1 = UpBlock(128 + 64, 64, mode)

        # Per-voxel projection to the requested number of output channels.
        self.conv_last = nn.Conv3d(64, out_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the final conv output."""
        skips = []
        for encoder in (self.down_conv1, self.down_conv2, self.down_conv3):
            x, skip = encoder(x)
            skips.append(skip)

        x = self.double_conv(x)

        # Skips are consumed deepest-first, i.e. in reverse encoder order.
        for decoder in (self.up_conv3, self.up_conv2, self.up_conv1):
            x = decoder(x, skips.pop())

        return self.conv_last(x)


In [3]:
# Checkpoint file; only its "state_dict" entry is used below.
ckpt_path = 'model.ckpt'
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
# map_location='cpu' restores the tensors on the CPU regardless of the device
# they were saved from, so loading does not require that device to be present.
checkpoint = torch.load(ckpt_path, map_location='cpu')

# Checkpoint keys carry a 6-character wrapper prefix that MiniUnet's own
# parameter names do not have (presumably 'model.' -- TODO confirm against the
# checkpoint). Strip it only when it is actually present, so a key without the
# prefix keeps its real name instead of silently losing its first 6 characters;
# load_state_dict will then report such a key by its true name.
key_prefix = 'model.'
state_dict = {}
for key, tensor in checkpoint["state_dict"].items():
    key_new = key[len(key_prefix):] if key.startswith(key_prefix) else key
    state_dict[key_new] = tensor

In [5]:
# Build the network with its default arguments and load the renamed weights.
model_mu = MiniUnet()
# strict=True (the default, spelled out here) makes any missing or unexpected
# key raise; the returned key report is displayed as the cell output.
model_mu.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [6]:
# Use the GPU when one is available and fall back to the CPU otherwise,
# instead of raising on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# .to() returns the module, so the cell still displays the model layout.
model_mu.to(device)

MiniUnet(
  (down_conv1): DownBlock(
    (double_conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv3d(1, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): GELU(approximate='none')
        (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (4): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (5): GELU(approximate='none')
      )
    )
    (down_sample): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down_conv2): DownBlock(
    (double_conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): GELU(approximate='none')
        (3): Conv3d(128

In [7]:
from torchsummary import summary
# Per-sample input size handed to torchsummary: 1 channel, 26 x 18 x 23 volume.
summary_input_size = (1, 26, 18, 23)
summary(model_mu, summary_input_size)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 64, 26, 18, 23]           1,792
    InstanceNorm3d-2       [-1, 64, 26, 18, 23]               0
              GELU-3       [-1, 64, 26, 18, 23]               0
            Conv3d-4       [-1, 64, 26, 18, 23]         110,656
    InstanceNorm3d-5       [-1, 64, 26, 18, 23]               0
              GELU-6       [-1, 64, 26, 18, 23]               0
        DoubleConv-7       [-1, 64, 26, 18, 23]               0
         MaxPool3d-8        [-1, 64, 13, 9, 11]               0
         DownBlock-9  [[-1, 64, 13, 9, 11], [-1, 64, 26, 18, 23]]               0
           Conv3d-10       [-1, 128, 13, 9, 11]         221,312
   InstanceNorm3d-11       [-1, 128, 13, 9, 11]               0
             GELU-12       [-1, 128, 13, 9, 11]               0
           Conv3d-13       [-1, 128, 13, 9, 11]         442,496
   InstanceNorm3d-14 

  return F.conv3d(
