In [1]:
import torch
from torch import nn
from torch.nn import functional as F

from collections import OrderedDict


import matplotlib.pyplot as plt

from torch import Tensor

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

import numpy as np

import nibabel as nib

#from .utils import load_state_dict_from_url
from typing import Callable, Any, Optional, List

import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '6'
print(torch.cuda.get_device_name(0))

print('PyTorch Version:', torch.__version__)

print('CuDNN Version:', torch.backends.cudnn.version())

def gpu_usage():
    print('gpu usage (current/max): {:.2f} / {:.2f} GB'.format(torch.cuda.memory_allocated()*1e-9, torch.cuda.max_memory_allocated()*1e-9))

gpu_usage()

GeForce RTX 2080 Ti
PyTorch Version: 1.8.1
CuDNN Version: 8005
gpu usage (current/max): 0.00 / 0.00 GB


In [2]:
#Atrous Spatial Pyramid Pooling (Segmentation Network)
class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv3d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU()
        ]
        super(ASPPConv, self).__init__(*modules)


class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            #nn.AdaptiveAvgPool2d(1),
            nn.Conv3d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU())

    def forward(self, x):
        size = x.shape[-3:]
        x = F.adaptive_avg_pool3d(x,(1))
        for mod in self:
            x = mod(x)
        return F.interpolate(x, size=size, mode='nearest')#, align_corners=False)


class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super(ASPP, self).__init__()
        modules = []
        modules.append(nn.Sequential(
            nn.Conv3d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU()))

        rates = tuple(atrous_rates)
        for rate in rates:
            modules.append(ASPPConv(in_channels, out_channels, rate))

        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            nn.Conv3d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)   


In [3]:
#Mobile-Net with depth-separable convolutions and residual connections
class ResBlock(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, inputs):
        return self.module(inputs) + inputs
    
in_channels = torch.Tensor([1,16,24,24,32,32,32,64]).long()
mid_channels = torch.Tensor([32,96,144,144,192,192,192,384]).long()
out_channels = torch.Tensor([16,24,24,32,32,32,64,64]).long()
mid_stride = torch.Tensor([1,1,1,1,1,2,1,1])

net = []
net.append(nn.Identity())
for i in range(len(in_channels)):
    inc = int(in_channels[i]); midc = int(mid_channels[i]); outc = int(out_channels[i]); strd = int(mid_stride[i])
    layer = nn.Sequential(nn.Conv3d(inc,midc,1,bias=False),nn.BatchNorm3d(midc),nn.ReLU6(True),\
                    nn.Conv3d(midc,midc,3,stride=strd,padding=1,bias=False,groups=midc),nn.BatchNorm3d(midc),nn.ReLU6(True),\
                                   nn.Conv3d(midc,outc,1,bias=False),nn.BatchNorm3d(outc))
    if(i==0):
        layer[0] = nn.Conv3d(inc,midc,3,padding=1,stride=2,bias=False)
    if((inc==outc)&(strd==1)):
        net.append(ResBlock(layer))
    else:
        net.append(layer)

backbone = nn.Sequential(*net)

count = 0
# weight initialization
for m in backbone.modules():
    if isinstance(m, nn.Conv3d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out')
        count += 1
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.zeros_(m.bias)

print('#CNN layer',count)

#CNN layer 24


In [4]:
#complete model: MobileNet + ASPP + head (with a single skip connection)
from torch.utils.checkpoint import checkpoint
#newer model (one more stride, no groups in head)
aspp = ASPP(64,(2,4,8,16),128)#ASPP(64,(1,),128)#
num_classes = 26#14 # 25 for verse19 and 29 for verse20
head = nn.Sequential(nn.Conv3d(128+16, 64, 1, padding=0,groups=1, bias=False),nn.BatchNorm3d(64),nn.ReLU(),\
                     nn.Conv3d(64, 64, 3, groups=1,padding=1, bias=False),nn.BatchNorm3d(64),nn.ReLU(),\
                     nn.Conv3d(64, num_classes, 1))

In [11]:
class VerSe_iso15(Dataset):
    """VerSe Dataset already preprocessed: normalisation and orientation with 1.5mm spacing"""

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        
        self.img_path = []
        self.counter = 0
        for root, dirs, files in os.walk(self.root_dir):
            for filename in files:
                    path =  root + '/' + filename
                    if "msk" not in path and 'nii.gz' in path:
                        self.counter += 1
                        self.img_path.append(path)
        

    def __len__(self):
        return self.counter

    def __getitem__(self, idx):

        img_name = self.img_path[idx]
        seg_name = self.img_path[idx].split('.nii.gz')[0] + '_msk.nii.gz'
    
        image = nib.load(img_name).get_fdata()[:, :, :]
        label = nib.load(seg_name).get_fdata()[:, :, :]
        
        image = (2*(image-image.min())/ (image.max() - image.min()) - 1)
    
        sample = {'image': image, 'label': label}

        if self.transform:
            sample = self.transform(sample)
            
        image = torch.from_numpy(image).unsqueeze(0)
        label = torch.from_numpy(label).unsqueeze(0)
        
        sample = {'image': image, 'label': label}


        return sample

In [12]:
validation_dataset = VerSe_iso15(root_dir='/share/data_zoe1/hempe/data/VerSeV2/preprocessed19/iso_15_val', \
                                 transform= None)

In [13]:
cp = torch.load('mobile_aspp2_3d_verse_edge_iso_128_patch_4k.pth')
backbone.load_state_dict(cp['backbone'])
aspp.load_state_dict(cp['aspp'])
head.load_state_dict(cp['head'])

<All keys matched successfully>

In [17]:
head.eval()
head.cuda()
backbone.eval()
backbone.cuda()
aspp.eval()
aspp.cuda()

#pretty: 11, 15, 18, 20fx
#not so pretty: 7, 8
index = 11
sample = validation_dataset[index]
image = sample["image"]
label = sample["label"]
C,D,H,W = image.shape
print(image.shape)

import time
ts = time.time()    
with torch.no_grad():
    x1 = backbone[:2](pad_img.cuda().unsqueeze(0).float())
    y = aspp(backbone[2:](x1))
    y = torch.cat((x1,F.interpolate(y,scale_factor=2)),1)
    output = F.interpolate(checkpoint(head,y),scale_factor=2,mode='trilinear')
    print('total time:', time.time() - ts)

torch.Size([1, 327, 121, 103])
total time: 0.03352236747741699
