In [1]:
# Hack to import helper packages
%cd /workspaces/segment_vasculature/models/

/workspaces/segment_vasculature/models


In [2]:
from helpers.loss_functions import DiceLoss
from helpers.train_test import train_and_test
from helpers.dataset_setup import train_test_split, augment_image, DataLoader, TRAIN_FOLDER, preprocess_image, preprocess_mask

from resnet import resnet50

import os
import torch
import numpy as np
from torch.utils.data.dataset import Dataset
import cv2

In [3]:
image_folder = f'{TRAIN_FOLDER}/kidney_2/images/'
image_files = [os.path.join(image_folder, img) for img in os.listdir(image_folder) if img.endswith(".tif")]
image_files.sort()
labels_folder = f'{TRAIN_FOLDER}/kidney_2/labels/'
label_files = [os.path.join(labels_folder, img) for img in os.listdir(labels_folder) if img.endswith(".tif")]
label_files.sort()

In [4]:
def preprocess_image(path):
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    # img = np.tile(img[...,None],[1,1,3])
    img = img.astype('float32')
    mx = np.max(img)
    if mx:
        img/=mx
    # img = np.transpose(img,(2,0,1))
    img_ten = torch.tensor(img)
    return img_ten

def preprocess_mask(path):
    
    msk = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    msk = msk.astype('float32')
    msk/=255.0
    msk_ten = torch.tensor(msk)
    
    return msk_ten

In [5]:
class SenNetDataset(Dataset):

    def __init__(self, image_files, mask_files, input_size=(16, 256, 256), augmentation_transforms=None):
        self.image_files=image_files
        self.mask_files=mask_files
        self.input_D = input_size[0]
        self.input_H = input_size[1]
        self.input_W = input_size[2]
        self.augmentation_transforms=augmentation_transforms
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        # Get 16 images instead of 1
        N = len(self.image_files)
        if idx + self.input_D > N:
            diff = idx + self.input_D - N
            start_idx = idx - diff
            image_paths = self.image_files[start_idx:N]
            mask_paths = self.mask_files[start_idx:N]
        elif idx <= self.input_D:
            diff = self.input_D - idx
            end_range = self.input_D + diff
            image_paths = self.image_files[idx:end_range]
            mask_paths = self.mask_files[idx:end_range]

        # Extract images into tensor
        images = torch.stack(list(map(preprocess_image, image_paths)), dim=0)
        masks = torch.stack(list(map(preprocess_mask, mask_paths)), dim=0)
        img_size = list(images.size())
        mask_size = list(masks.size())
        print(img_size)

        images = images.reshape((1, img_size[0], img_size[1], img_size[2]))
        masks = masks.reshape((1, mask_size[0], mask_size[1], mask_size[2]))

        # image, mask = self.__training_data_process__(img, mask)

        if self.augmentation_transforms:
            image, mask=self.augmentation_transforms(image, mask, self.input_size)

        # make sure it is tensor of shape: [1, z, y, x]
        assert images.shape == masks.shape, "img shape:{} is not equal to mask shape:{}".format(images.shape, masks.shape)
        return images, masks

    def __drop_invalid_range__(self, volume, label=None):
        """
        Cut off the invalid area
        """
        zero_value = volume[0, 0, 0]
        non_zeros_idx = np.where(volume != zero_value)
        
        [max_z, max_h, max_w] = np.max(np.array(non_zeros_idx), axis=1)
        [min_z, min_h, min_w] = np.min(np.array(non_zeros_idx), axis=1)
        
        if label is not None:
            return volume[min_z:max_z, min_h:max_h, min_w:max_w], label[min_z:max_z, min_h:max_h, min_w:max_w]
        else:
            return volume[min_z:max_z, min_h:max_h, min_w:max_w]


    def __random_center_crop__(self, data, label):
        from random import random
        """
        Random crop
        """
        target_indexs = np.where(label>0)
        [img_d, img_h, img_w] = data.shape
        [max_D, max_H, max_W] = np.max(np.array(target_indexs), axis=1)
        [min_D, min_H, min_W] = np.min(np.array(target_indexs), axis=1)
        [target_depth, target_height, target_width] = np.array([max_D, max_H, max_W]) - np.array([min_D, min_H, min_W])
        Z_min = int((min_D - target_depth*1.0/2) * random())
        Y_min = int((min_H - target_height*1.0/2) * random())
        X_min = int((min_W - target_width*1.0/2) * random())
        
        Z_max = int(img_d - ((img_d - (max_D + target_depth*1.0/2)) * random()))
        Y_max = int(img_h - ((img_h - (max_H + target_height*1.0/2)) * random()))
        X_max = int(img_w - ((img_w - (max_W + target_width*1.0/2)) * random()))
       
        Z_min = np.max([0, Z_min])
        Y_min = np.max([0, Y_min])
        X_min = np.max([0, X_min])

        Z_max = np.min([img_d, Z_max])
        Y_max = np.min([img_h, Y_max])
        X_max = np.min([img_w, X_max])
 
        Z_min = int(Z_min)
        Y_min = int(Y_min)
        X_min = int(X_min)
        
        Z_max = int(Z_max)
        Y_max = int(Y_max)
        X_max = int(X_max)

        return data[Z_min: Z_max, Y_min: Y_max, X_min: X_max], label[Z_min: Z_max, Y_min: Y_max, X_min: X_max]



    def __itensity_normalize_one_volume__(self, volume):
        """
        normalize the itensity of an nd volume based on the mean and std of nonzeor region
        inputs:
            volume: the input nd volume
        outputs:
            out: the normalized nd volume
        """
        
        pixels = volume[volume > 0]
        mean = pixels.mean()
        std  = pixels.std()
        out = (volume - mean)/std
        out_random = np.random.normal(0, 1, size = volume.shape)
        out[volume == 0] = out_random[volume == 0]
        return out

    def __resize_data__(self, data):
        """
        Resize the data to the input size
        """ 
        [depth, height, width] = data.shape
        scale = [self.input_D*1.0/depth, self.input_H*1.0/height, self.input_W*1.0/width]  
        # data = ndimage.interpolation.zoom(data, scale, order=0)
        raise NotImplementedError("Please add scipy as a dep")
        #return data


    def __crop_data__(self, data, label):
        """
        Random crop with different methods:
        """ 
        # random center crop
        data, label = self.__random_center_crop__ (data, label)
        
        return data, label

    def __training_data_process__(self, data, label): 
        # crop data according net input size
        data = data.get_data()
        label = label.get_data()
        
        # drop out the invalid range
        data, label = self.__drop_invalid_range__(data, label)
        
        # crop data
        data, label = self.__crop_data__(data, label) 

        # resize data
        data = self.__resize_data__(data)
        label = self.__resize_data__(label)

        # normalization datas
        data = self.__itensity_normalize_one_volume__(data)

        return data, label

In [6]:
train_image_files, val_image_files, train_mask_files, val_mask_files = train_test_split(
    image_files, label_files, test_size=0.1, random_state=42)

train_dataset = SenNetDataset(train_image_files, train_mask_files, input_size=(8, 256, 256))
val_dataset = SenNetDataset(val_image_files, val_mask_files, input_size=(8, 256, 256))

train_dataloader= DataLoader(train_dataset, batch_size=1, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)

dataloaders = {
    'train': train_dataloader,
    'val': val_dataloader
}

In [7]:
kw = {'sample_input_D': 8, 'sample_input_H': 256, 'sample_input_W': 256, 'num_seg_classes': 1}
model = resnet50(**kw)
st_dict = torch.load('tencent_model/trails/models/resnet_50_epoch_110_batch_0.pth.tar')

new_state_dict = dict()
for k, v in st_dict["state_dict"].items():
        # Strip module. from start of keys
        new_key = k.strip("module").strip(".")
        if "num_batches_track" in new_key:
            continue

        new_state_dict[new_key] = v
for k in st_dict["state_dict"].keys():
    print(k)

  m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')


module.conv1.weight
module.bn1.weight
module.bn1.bias
module.bn1.running_mean
module.bn1.running_var
module.bn1.num_batches_tracked
module.layer1.0.conv1.weight
module.layer1.0.bn1.weight
module.layer1.0.bn1.bias
module.layer1.0.bn1.running_mean
module.layer1.0.bn1.running_var
module.layer1.0.bn1.num_batches_tracked
module.layer1.0.conv2.weight
module.layer1.0.bn2.weight
module.layer1.0.bn2.bias
module.layer1.0.bn2.running_mean
module.layer1.0.bn2.running_var
module.layer1.0.bn2.num_batches_tracked
module.layer1.0.conv3.weight
module.layer1.0.bn3.weight
module.layer1.0.bn3.bias
module.layer1.0.bn3.running_mean
module.layer1.0.bn3.running_var
module.layer1.0.bn3.num_batches_tracked
module.layer1.0.downsample.0.weight
module.layer1.0.downsample.1.weight
module.layer1.0.downsample.1.bias
module.layer1.0.downsample.1.running_mean
module.layer1.0.downsample.1.running_var
module.layer1.0.downsample.1.num_batches_tracked
module.layer1.1.conv1.weight
module.layer1.1.bn1.weight
module.layer1.1.

In [8]:
epochs = 5
def train():
    kw = {'sample_input_D': 1, 'sample_input_H': 256, 'sample_input_W': 256, 'num_seg_classes': 2}
    model = resnet50(**kw)
    st_dict = torch.load('tencent_model/trails/models/resnet_50_epoch_110_batch_0.pth.tar')

    new_state_dict = dict()
    for k, v in st_dict["state_dict"].items():
            # Strip module. from start of keys
            new_key = k.strip("module").strip(".")
            if "num_batches_track" in new_key:
                continue

            new_state_dict[new_key] = v

    model.load_state_dict(new_state_dict)
    optimizer = torch.optim.Adam(model.parameters())
    criterion = DiceLoss()
    trained_model, train_epoch_losses, test_epoch_losses = train_and_test(model, dataloaders, optimizer, criterion, num_epochs=epochs, show_images=True)
    return trained_model, train_epoch_losses, test_epoch_losses


trained_model, train_epoch_losses, test_epoch_losses = train()
#torch.save(trained_model.state_dict(), 'lower_learning_rate_100.pth')

Epoch 1/5
----------
[16, 1041, 1511]
torch.Size([1, 1, 16, 1041, 1511])
torch.Size([1, 2, 4, 262, 378])
torch.Size([1, 1, 16, 1041, 1511])


RuntimeError: The size of tensor a (792288) must match the size of tensor b (25167216) at non-singleton dimension 0