In [None]:
import argparse
import collections.abc
import errno
import numpy as np
import os
import pandas as pd
import pathlib
import random
import scipy
import shutil
import sys
import time
import warnings
import SimpleITK as sitk

from datetime import datetime
from glob import glob
from importlib import import_module
from os import listdir
from os.path import abspath, dirname, exists, isdir, join
from pynvml import *

from scipy.ndimage import zoom
from scipy.ndimage.morphology import binary_dilation, generate_binary_structure

import torch
from torch import optim
from torch.autograd import Variable
from torch.backends import cudnn
from torch.nn import DataParallel
from torch.utils.data import Dataset, DataLoader

# Jupyter runs with the notebook's folder as cwd; put its parent (the project
# root) first on sys.path so the `models` / `configs` packages import.
RootPath = abspath('..')    # For jupyter
if RootPath not in sys.path:
    sys.path.insert(0, RootPath)
display(sys.path)

from models.res18_se import *
from configs.config03 import get_cfg_defaults

In [None]:
def parse_args(**kwargs):
    """Parse the command-line options of the detector training notebook.

    Keyword arguments are forwarded to ``ArgumentParser.parse_args``; inside
    Jupyter call ``parse_args(args=[])`` so the kernel's own argv is ignored.

    Returns:
        argparse.Namespace with a ``cfg`` attribute: the YAML configuration
        path, relative to the project root.
    """
    parser = argparse.ArgumentParser(description='Luna nodule detection by pyTorch')
    # Build the default with os.path.join so it resolves on Windows and POSIX
    # alike (a literal "configs\\..." is one odd filename on Linux/macOS).
    parser.add_argument("--cfg", type=str,
                        default=os.path.join("configs", "config03_win.yaml"),
                        help="Configuration")
    return parser.parse_args(**kwargs)

In [None]:
def makeNotExistDir(dirPath):
    """Create ``dirPath`` (with any missing parents) unless it already exists.

    Returns:
        The string "True" if the directory exists afterwards, otherwise the
        string "False".
    """
    if exists(dirPath):
        return "True"
    pathlib.Path(dirPath).mkdir(parents=True, exist_ok=True)
    return "True" if exists(dirPath) else "False"

In [None]:
# From: luna_detector/utils.py

def getFreeId(max_util=70):
    """Return the indices of lightly loaded NVIDIA GPUs as a comma-separated string.

    A GPU counts as free when the mean of its core and memory utilisation
    (as reported by NVML) is below ``max_util``.

    Args:
        max_util: utilisation threshold in percent (default 70, the value
            previously hard-coded).

    Returns:
        For example "0,2"; an empty string when no GPU qualifies.
    """
    import pynvml

    def _busy_ratio(index):
        # Mean of GPU-core and memory utilisation for one device.
        handle = pynvml.nvmlDeviceGetHandleByIndex(index)
        use = pynvml.nvmlDeviceGetUtilizationRates(handle)
        return 0.5 * (float(use.gpu) + float(use.memory))

    pynvml.nvmlInit()
    try:
        available = [str(i) for i in range(pynvml.nvmlDeviceGetCount())
                     if _busy_ratio(i) < max_util]
    finally:
        # Release NVML even if a device query fails.
        pynvml.nvmlShutdown()
    return ','.join(available)

def setgpu(gpuinput):
    """Select GPUs and expose them to CUDA via ``CUDA_VISIBLE_DEVICES``.

    Args:
        gpuinput: 'all' to use every GPU ``getFreeId`` reports as free, or a
            comma-separated list of device indices such as "0,1".

    Returns:
        The string written to ``CUDA_VISIBLE_DEVICES``.

    Raises:
        ValueError: if any requested GPU is not in the free list.
    """
    freeids = getFreeId()
    if gpuinput == 'all':
        gpus = freeids
    else:
        gpus = gpuinput
        # Compare whole ids: a substring test on the raw string would accept
        # GPU "1" whenever only GPU "10" is free.
        free_set = {g.strip() for g in freeids.split(',') if g.strip()}
        requested = [g.strip() for g in gpus.split(',') if g.strip()]
        busy = [g for g in requested if g not in free_set]
        if busy:
            raise ValueError('gpu ' + ','.join(busy) + ' is being used')
    os.environ['CUDA_VISIBLE_DEVICES'] = gpus
    return gpus

In [None]:
# From: data_loader.py

class LungNodule3Ddetector(Dataset):
    """Dataset of preprocessed CT volumes for 3D lung-nodule detection.

    For each case id ``idx`` it reads ``<data_dir>/<idx>_clean.npy`` (image)
    and ``<data_dir>/<idx>_label.npy`` (one row per nodule; an all-zero file
    means "no nodule"). In the 'train'/'val' phases every nodule becomes one or
    more crop centres (larger nodules are repeated). In the 'test' phase whole
    volumes are padded and split into patches by ``split_comber``.

    Args:
        data_dir: directory holding the ``*_clean.npy`` / ``*_label.npy`` files.
        split_path: iterable of case ids. Despite the name it is used directly,
            not loaded from a file.
        config: dict-like model configuration (keys read in ``__init__``).
        phase: 'train', 'val' or 'test'.
        split_comber: patch splitter; required only for phase 'test'.
    """

    def __init__(self, data_dir, split_path, config, phase='train', split_comber=None):
        assert phase in ('train', 'val', 'test')
        self.phase = phase

        self.max_stride = config['max_stride']
        self.stride = config['stride']

        # Size thresholds for over-sampling, divided by config['reso']
        # (presumably converting physical size to voxels -- confirm units).
        sizelim = config['sizelim'] / config['reso']
        sizelim2 = config['sizelim2'] / config['reso']
        sizelim3 = config['sizelim3'] / config['reso']

        self.blacklist = config['blacklist']
        self.isScale = config['aug_scale']
        self.r_rand = config['r_rand_crop']
        self.augtype = config['augtype']
        self.pad_value = config['pad_value']

        self.split_comber = split_comber
        # split_path is already the list of case ids (np.load of a split file
        # is disabled in this port).
        idcs = split_path

        if phase != 'test':
            idcs = [f for f in idcs if f not in self.blacklist]

        self.filenames = [os.path.join(data_dir, '%s_clean.npy' % idx) for idx in idcs]

        labels = []
        for idx in idcs:
            l = np.load(os.path.join(data_dir, '%s_label.npy' % idx))
            if np.all(l == 0):
                # An all-zero label file marks a case without nodules.
                l = np.array([])
            labels.append(l)
        self.sample_bboxes = labels

        if self.phase != 'test':
            self.bboxes = []
            for i, l in enumerate(labels):
                for t in l:
                    # Entry = [case index, *label row]. The 4th label value is
                    # compared with the size limits, so bigger nodules are
                    # repeated up to 1 + 2 + 4 times and sampled more often.
                    entry = [np.concatenate([[i], t])]
                    if t[3] > sizelim:
                        self.bboxes += [entry]
                    if t[3] > sizelim2:
                        self.bboxes += [entry] * 2
                    if t[3] > sizelim3:
                        self.bboxes += [entry] * 4
            if self.bboxes:
                self.bboxes = np.concatenate(self.bboxes, axis=0)
            else:
                # No nodules in this split: keep an empty table (width assumes
                # 4-value label rows) instead of np.concatenate failing on [].
                self.bboxes = np.zeros((0, 5))

        self.crop = Crop(config)
        self.label_mapping = LabelMapping(config, self.phase)

    @staticmethod
    def _time_seed(t):
        """Map a timestamp to a seed in [0, 100000) using its fractional part.

        Replaces ``int(str(t % 1)[2:7])``. That expression raises ValueError
        whenever the fraction is below 1e-4, because ``str`` then switches to
        scientific notation (e.g. '5e-05').
        """
        return int((t % 1) * 1e5)

    def __getitem__(self, idx, split=None):
        """Return one sample; ``split`` is unused.

        train/val: ``(image_tensor, label_tensor, coord)`` for one crop.
        Indices past the nodule table (train only, see ``__len__``) produce
        random crops, half of them taken from a randomly chosen case.
        test: ``(patch_tensor, bboxes, coord_tensor, nzhw)`` for a whole
        padded volume split into patches by ``split_comber``.
        """
        # Reseed from the wall clock so successive calls draw different
        # augmentations.
        np.random.seed(self._time_seed(time.time()))

        isRandom = False
        isRandomImg = False
        if self.phase != 'test' and idx >= len(self.bboxes):
            isRandom = True
            idx = idx % len(self.bboxes)
            isRandomImg = np.random.randint(2)

        if self.phase != 'test':
            if not isRandomImg:
                bbox = self.bboxes[idx]
                filename = self.filenames[int(bbox[0])]
                imgs = np.load(filename)
                bboxes = self.sample_bboxes[int(bbox[0])]
                isScale = self.augtype['scale'] and (self.phase == 'train')
                sample, target, bboxes, coord = self.crop(imgs, bbox[1:], bboxes, isScale, isRandom)
                if self.phase == 'train' and not isRandom:
                    sample, target, bboxes, coord = augment(
                        sample, target, bboxes, coord,
                        ifflip=self.augtype['flip'],
                        ifrotate=self.augtype['rotate'],
                        ifswap=self.augtype['swap'])
            else:
                randimid = np.random.randint(len(self.filenames))
                filename = self.filenames[randimid]
                imgs = np.load(filename)
                bboxes = self.sample_bboxes[randimid]
                sample, target, bboxes, coord = self.crop(imgs, [], bboxes, isScale=False, isRand=True)
            label = self.label_mapping(sample.shape[1:], target, bboxes)
            # Shift/scale intensities (presumably 0-255) to roughly [-1, 1).
            sample = (sample.astype(np.float32) - 128) / 128
            return torch.from_numpy(sample), torch.from_numpy(label), coord
        else:
            imgs = np.load(self.filenames[idx])
            bboxes = self.sample_bboxes[idx]
            # Pad every spatial axis up to a multiple of the stride.
            nz, nh, nw = imgs.shape[1:]
            pz = int(np.ceil(float(nz) / self.stride)) * self.stride
            ph = int(np.ceil(float(nh) / self.stride)) * self.stride
            pw = int(np.ceil(float(nw) / self.stride)) * self.stride
            imgs = np.pad(
                imgs,
                [[0, 0], [0, pz - nz], [0, ph - nh], [0, pw - nw]],
                'constant',
                constant_values=self.pad_value
            )

            # Normalised [-0.5, 0.5] coordinate grid at output (stride) resolution.
            xx, yy, zz = np.meshgrid(
                np.linspace(-0.5, 0.5, int(imgs.shape[1] / self.stride)),
                np.linspace(-0.5, 0.5, int(imgs.shape[2] / self.stride)),
                np.linspace(-0.5, 0.5, int(imgs.shape[3] / self.stride)),
                indexing='ij'
            )
            coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, :]], 0).astype('float32')

            imgs, nzhw = self.split_comber.split(imgs)

            # Split the coordinate grid with the same geometry, scaled by stride.
            coord2, nzhw2 = self.split_comber.split(
                coord,
                side_len=int(self.split_comber.side_len / self.stride),
                max_stride=int(self.split_comber.max_stride / self.stride),
                margin=int(self.split_comber.margin / self.stride)
            )

            assert np.all(nzhw == nzhw2)
            imgs = (imgs.astype(np.float32) - 128) / 128
            return torch.from_numpy(imgs), bboxes, torch.from_numpy(coord2), np.array(nzhw)

    def __len__(self):
        """Train: nodule entries plus a share r_rand of random crops; val: nodule entries; test: cases."""
        if self.phase == 'train':
            return int(len(self.bboxes) / (1 - self.r_rand))
        elif self.phase == 'val':
            return len(self.bboxes)
        else:
            return len(self.sample_bboxes)
       
    
def augment(sample, target, bboxes, coord, ifflip=True, ifrotate=True, ifswap=True):
    """Randomly rotate, axis-permute and mirror a training crop with its labels.

    Args:
        sample: 4-D crop, channel first; axes 1-3 are spatial.
        target: length-4 array. Entries 0-2 are positions along sample axes
            1-3 and entry 3 is the nodule size. It is kept consistent with the
            transformed crop.
        bboxes: array of rows in the same format as ``target`` (may be empty).
        coord: 4-D coordinate grid, transformed together with ``sample``.
        ifflip: mirror sample axes 2 and/or 3 at random (axis 1 never flips).
        ifrotate: rotate by a random angle in [-10, 10) degrees in the plane
            of sample axes 2-3. Up to three angles are tried; a rotation is
            applied only if the target stays more than its size from every
            border.
        ifswap: randomly permute the spatial axes (cubic crops only).

    Returns:
        (sample, target, bboxes, coord) after augmentation.
    """
    # `rotate` is not among the notebook's explicit imports; import it here so
    # the rotation branch cannot fail with NameError.
    from scipy.ndimage import rotate

    has_boxes = len(bboxes) > 0

    if ifrotate:
        size = np.array(sample.shape[2:4]).astype('float')
        for _ in range(3):  # at most three attempts
            angle1 = (np.random.rand() - 0.5) * 20
            rotmat = np.array([[np.cos(angle1 / 180 * np.pi), -np.sin(angle1 / 180 * np.pi)],
                               [np.sin(angle1 / 180 * np.pi), np.cos(angle1 / 180 * np.pi)]])
            newtarget = np.copy(target)
            # Rotate the in-plane target position about the crop centre.
            newtarget[1:3] = np.dot(rotmat, target[1:3] - size / 2) + size / 2
            if np.all(newtarget[:3] > target[3]) and np.all(newtarget[:3] < np.array(sample.shape[1:4]) - newtarget[3]):
                target = newtarget
                sample = rotate(sample, angle1, axes=(2, 3), reshape=False)
                coord = rotate(coord, angle1, axes=(2, 3), reshape=False)
                for box in bboxes:
                    box[1:3] = np.dot(rotmat, box[1:3] - size / 2) + size / 2
                break

    if ifswap:
        if sample.shape[1] == sample.shape[2] and sample.shape[1] == sample.shape[3]:
            axisorder = np.random.permutation(3)
            sample = np.transpose(sample, np.concatenate([[0], axisorder + 1]))
            coord = np.transpose(coord, np.concatenate([[0], axisorder + 1]))
            target[:3] = target[:3][axisorder]
            if has_boxes:
                bboxes[:, :3] = bboxes[:, :3][:, axisorder]

    if ifflip:
        # +1 keeps an axis, -1 mirrors it; the first spatial axis is never flipped.
        flipid = np.array([1, np.random.randint(2), np.random.randint(2)]) * 2 - 1
        sample = np.ascontiguousarray(sample[:, ::flipid[0], ::flipid[1], ::flipid[2]])
        coord = np.ascontiguousarray(coord[:, ::flipid[0], ::flipid[1], ::flipid[2]])
        for ax in range(3):
            if flipid[ax] == -1:
                target[ax] = np.array(sample.shape[ax + 1]) - target[ax]
                if has_boxes:
                    bboxes[:, ax] = np.array(sample.shape[ax + 1]) - bboxes[:, ax]
    return sample, target, bboxes, coord

In [None]:
# From: data_loader.py

class Crop(object):
    """Cut a fixed-size patch (optionally rescaled) around a target from a volume.

    Config keys read: 'crop_size' (three spatial sizes), 'bound_size' (margin
    in voxels), 'stride' (resolution of the returned ``coord`` grid) and
    'pad_value' (fill for regions outside the volume).
    """

    def __init__(self, config):
        self.crop_size = config['crop_size']
        self.bound_size = config['bound_size']
        self.stride = config['stride']
        self.pad_value = config['pad_value']

    def __call__(self, imgs, target, bboxes, isScale=False, isRand=False):
        """Crop ``imgs`` and express the annotations in crop coordinates.

        Args:
            imgs: 4-D volume, channel first; axes 1-3 are spatial.
            target: length-4 row (positions along axes 1-3, then a size) to
                centre the crop on; ignored when ``isRand``.
            bboxes: rows in the same format; shifted/scaled copies are returned.
            isScale: cut a patch of crop_size / scale and zoom it by a random
                ``scale``, so the output still has shape crop_size.
            isRand: choose a random location instead of centring on ``target``.

        Returns:
            (crop, target, bboxes, coord). ``crop`` has shape
            (channels,) + crop_size. ``target`` is all-NaN when ``isRand``.
            ``coord`` is a (3, *crop_size/stride) grid locating the patch in the
            source volume, normalised so the volume spans [-0.5, 0.5].
        """
        if isScale:
            # Zoom factor within [0.75, 1.25] (the range always contains 1),
            # narrowed via radiusLim: a nodule of size < 8 is never shrunk and
            # one of size > 100 is never enlarged.
            radiusLim = [8., 100.]
            scaleLim = [0.75, 1.25]
            scaleRange = [np.min([np.max([(radiusLim[0] / target[3]), scaleLim[0]]), 1]),
                          np.max([np.min([(radiusLim[1] / target[3]), scaleLim[1]]), 1])]
            scale = np.random.rand() * (scaleRange[1] - scaleRange[0]) + scaleRange[0]
            crop_size = (np.array(self.crop_size).astype('float') / scale).astype('int')
        else:
            crop_size = self.crop_size
        bound_size = self.bound_size
        target = np.copy(target)
        bboxes = np.copy(bboxes)
        if isRand:
            # Random crops carry no target.
            target = np.array([np.nan, np.nan, np.nan, np.nan])

        start = []
        for i in range(3):
            if not isRand:
                # Start range that keeps the whole target plus margin inside the crop.
                r = target[3] / 2
                s = np.floor(target[i] - r) + 1 - bound_size
                e = np.ceil(target[i] + r) + 1 + bound_size - crop_size[i]
            else:
                s = np.max([imgs.shape[i + 1] - crop_size[i] / 2, imgs.shape[i + 1] / 2 + bound_size])
                e = np.min([crop_size[i] / 2, imgs.shape[i + 1] / 2 - bound_size])
            if s > e:
                start.append(int(np.random.randint(e, s)))
            else:
                # Range is empty: centre on the target with a small jitter.
                start.append(int(target[i] - crop_size[i] / 2 + np.random.randint(-bound_size / 2, bound_size / 2)))

        normstart = np.array(start).astype('float32') / np.array(imgs.shape[1:]) - 0.5
        normsize = np.array(crop_size).astype('float32') / np.array(imgs.shape[1:])
        xx, yy, zz = np.meshgrid(
            np.linspace(normstart[0], normstart[0] + normsize[0], int(self.crop_size[0] / self.stride)),
            np.linspace(normstart[1], normstart[1] + normsize[1], int(self.crop_size[1] / self.stride)),
            np.linspace(normstart[2], normstart[2] + normsize[2], int(self.crop_size[2] / self.stride)),
            indexing='ij')
        coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, :]], 0).astype('float32')

        # Slice the part inside the volume and pad the rest with pad_value.
        pad = [[0, 0]]
        for i in range(3):
            leftpad = max(0, -start[i])
            rightpad = max(0, start[i] + crop_size[i] - imgs.shape[i + 1])
            pad.append([leftpad, rightpad])
        crop = imgs[:,
                    max(start[0], 0):min(start[0] + crop_size[0], imgs.shape[1]),
                    max(start[1], 0):min(start[1] + crop_size[1], imgs.shape[2]),
                    max(start[2], 0):min(start[2] + crop_size[2], imgs.shape[3])]
        crop = np.pad(crop, pad, 'constant', constant_values=self.pad_value)
        for i in range(3):
            target[i] = target[i] - start[i]
        for i in range(len(bboxes)):
            for j in range(3):
                bboxes[i][j] = bboxes[i][j] - start[j]

        if isScale:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                crop = zoom(crop, [1, scale, scale, scale], order=1)
            # zoom rounds each axis independently, so bring every axis back to
            # exactly self.crop_size: trim any excess, then pad any shortfall.
            crop = crop[:, :self.crop_size[0], :self.crop_size[1], :self.crop_size[2]]
            pad2 = [[0, 0]] + [[0, max(0, self.crop_size[i] - crop.shape[i + 1])] for i in range(3)]
            if any(p[1] > 0 for p in pad2):
                crop = np.pad(crop, pad2, 'constant', constant_values=self.pad_value)
            for i in range(4):
                target[i] = target[i] * scale
            for i in range(len(bboxes)):
                for j in range(4):
                    bboxes[i][j] = bboxes[i][j] * scale
        return crop, target, bboxes, coord

In [None]:
# From: data_loader.py

class LabelMapping(object):
    """Turn a crop's annotations into the detector's anchor-grid label tensor.

    The output has shape (*input_size/stride, n_anchors, 5). Channel 0 is 1
    (positive), -1 (negative) or 0 (ignored). Channels 1-4 of the single
    positive cell hold its regression targets (dz, dh, dw, dd).
    """

    def __init__(self, config, phase):
        self.stride = np.array(config['stride'])
        self.num_neg = int(config['num_neg'])
        self.th_neg = config['th_neg']
        self.anchors = np.asarray(config['anchors'])
        self.phase = phase

        # Positive IoU threshold per phase. No th_pos is set for 'test', so
        # __call__ cannot be used in that phase.
        if phase == 'train':
            self.th_pos = config['th_pos_train']
        elif phase == 'val':
            self.th_pos = config['th_pos_val']

    def __call__(self, input_size, target, bboxes):
        """Build the label tensor for one crop.

        Args:
            input_size: spatial crop shape; each entry must divide by the stride.
            target: length-4 row (three positions, then size) of the nodule the
                crop was centred on; all-NaN for a random crop.
            bboxes: every annotated nodule in crop coordinates (same row format).

        Returns:
            float32 array of shape (*input_size/stride, len(anchors), 5).
        """
        stride = self.stride
        num_neg = self.num_neg
        th_neg = self.th_neg
        anchors = self.anchors
        th_pos = self.th_pos
        struct = generate_binary_structure(3, 1)

        output_size = []
        for i in range(3):
            assert (input_size[i] % stride == 0)
            output_size.append(int(input_size[i] / stride))

        label = np.zeros(output_size + [len(anchors), 5], np.float32)
        # Centres of the output cells in input (voxel) coordinates.
        offset = ((stride.astype('float')) - 1) / 2
        oz = np.arange(offset, offset + stride * (output_size[0] - 1) + 1, stride)
        oh = np.arange(offset, offset + stride * (output_size[1] - 1) + 1, stride)
        ow = np.arange(offset, offset + stride * (output_size[2] - 1) + 1, stride)

        # Mark cells whose anchor overlaps any nodule (IoU >= th_neg), dilated
        # by one cell. NOTE(review): the dilation is re-applied to the whole
        # channel for every bbox, so marks from earlier bboxes grow further.
        for bbox in bboxes:
            for i, anchor in enumerate(anchors):
                iz, ih, iw = select_samples(bbox, anchor, th_neg, oz, oh, ow)
                label[iz, ih, iw, i, 0] = 1
                label[:, :, :, i, 0] = binary_dilation(label[:, :, :, i, 0].astype('bool'), structure=struct,
                                                       iterations=1).astype('float32')

        # Marked cells -> 0 (ignored); every other cell -> -1 (negative).
        label = label - 1

        if self.phase == 'train' and self.num_neg > 0:
            # Keep a random subset of at most num_neg negatives; ignore the rest.
            neg_z, neg_h, neg_w, neg_a = np.where(label[:, :, :, :, 0] == -1)
            neg_idcs = random.sample(range(len(neg_z)), min(num_neg, len(neg_z)))
            neg_z, neg_h, neg_w, neg_a = neg_z[neg_idcs], neg_h[neg_idcs], neg_w[neg_idcs], neg_a[neg_idcs]
            label[:, :, :, :, 0] = 0
            label[neg_z, neg_h, neg_w, neg_a, 0] = -1

        if np.isnan(target[0]):
            return label

        # Candidate positive (cell, anchor) pairs for the centred nodule.
        iz, ih, iw, ia = [], [], [], []
        for i, anchor in enumerate(anchors):
            iiz, iih, iiw = select_samples(target, anchor, th_pos, oz, oh, ow)
            iz.append(iiz)
            ih.append(iih)
            iw.append(iiw)
            ia.append(i * np.ones((len(iiz),), np.int64))
        iz = np.concatenate(iz, 0)
        ih = np.concatenate(ih, 0)
        iw = np.concatenate(iw, 0)
        ia = np.concatenate(ia, 0)

        if len(iz) == 0:
            # No anchor reaches th_pos: use the cell nearest to the target,
            # clamped into the grid (a target near the far edge would otherwise
            # index past it), and the anchor closest in log-size.
            pos = [min(max(0, int(np.round((target[i] - offset) / stride))), output_size[i] - 1)
                   for i in range(3)]
            pos.append(np.argmin(np.abs(np.log(target[3] / anchors))))
        else:
            idx = random.randrange(len(iz))
            pos = [iz[idx], ih[idx], iw[idx], ia[idx]]
        dz = (target[0] - oz[pos[0]]) / anchors[pos[3]]
        dh = (target[1] - oh[pos[1]]) / anchors[pos[3]]
        dw = (target[2] - ow[pos[2]]) / anchors[pos[3]]
        dd = np.log(target[3] / anchors[pos[3]])
        label[pos[0], pos[1], pos[2], pos[3], :] = [1, dz, dh, dw, dd]
        return label


def select_samples(bbox, anchor, th, oz, oh, ow):
    """Find grid cells whose anchor cube overlaps ``bbox`` with IoU >= ``th``.

    Both the bbox and the anchors are axis-aligned cubes: a centre and a side
    length.

    Args:
        bbox: array (z, h, w, d) -- cube centre and side length.
        anchor: anchor cube side length.
        th: IoU threshold.
        oz, oh, ow: 1-D arrays of grid-cell centre coordinates per axis.

    Returns:
        Three equal-length integer index arrays (iz, ih, iw) into oz/oh/ow.
    """
    z, h, w, d = bbox
    max_overlap = min(d, anchor)
    min_overlap = np.power(max(d, anchor), 3) * th / max_overlap / max_overlap
    if min_overlap > max_overlap:
        # Even perfect alignment cannot reach the threshold.
        return tuple(np.zeros((0,), np.int64) for _ in range(3))

    # Per-axis pre-filter: keep only centres close enough to possibly qualify.
    half_diff = 0.5 * np.abs(d - anchor)
    slack = max_overlap - min_overlap
    hits = []
    for centre, grid in ((z, oz), (h, oh), (w, ow)):
        lo = centre - half_diff - slack
        hi = centre + half_diff + slack
        hits.append(np.where((grid >= lo) & (grid <= hi))[0])
    if any(len(axis_hits) == 0 for axis_hits in hits):
        return tuple(np.zeros((0,), np.int64) for _ in range(3))

    # Cartesian product of the surviving indices, in (z, h, w) row-major order.
    mesh_z, mesh_h, mesh_w = np.meshgrid(*hits, indexing='ij')
    iz, ih, iw = mesh_z.ravel(), mesh_h.ravel(), mesh_w.ravel()
    centers = np.stack([oz[iz], oh[ih], ow[iw]], axis=1)

    anchor_lo = centers - anchor / 2
    anchor_hi = centers + anchor / 2
    box_lo = (bbox[:3] - d / 2).reshape((1, -1))
    box_hi = (bbox[:3] + d / 2).reshape((1, -1))

    overlap = np.maximum(0, np.minimum(anchor_hi, box_hi) - np.maximum(anchor_lo, box_lo))
    intersection = overlap[:, 0] * overlap[:, 1] * overlap[:, 2]
    union = anchor * anchor * anchor + d * d * d - intersection
    keep = intersection / union >= th
    return iz[keep], ih[keep], iw[keep]


In [None]:
def train(data_loader, net, loss, epoch, optimizer, get_lr, save_dir):
    """Run one training epoch and print summary statistics.

    Args:
        data_loader: iterable of (data, target, coord) batches; every item
            must provide ``.cuda()``.
        net: model called as ``net(data, coord)``.
        loss: callable ``loss(output, target)`` returning a mutable sequence.
            Element 0 is the loss tensor to back-propagate; elements 1-9 are
            the numeric statistics printed below.
        epoch: epoch number, passed to ``get_lr`` and printed.
        optimizer: optimizer whose param-group learning rates are overwritten.
        get_lr: schedule mapping ``epoch`` to a learning rate.
        save_dir: not used in this function.
    """
    tic = time.time()

    net.train()
    lr = get_lr(epoch)
    for group in optimizer.param_groups:
        group['lr'] = lr

    history = []
    for step, (data, target, coord) in enumerate(data_loader):
        data, target, coord = [Variable(t.cuda()) for t in (data, target, coord)]

        loss_output = loss(net(data, coord), target)
        optimizer.zero_grad()
        loss_output[0].backward()
        optimizer.step()

        # Store a plain Python number instead of the loss tensor.
        loss_output[0] = loss_output[0].item()
        history.append(loss_output)
        print("finished iteration {} with loss {}.".format(step, loss_output[0]))

    toc = time.time()

    stats = np.asarray(history, np.float32)
    print('\nEpoch %03d (lr %.5f)' % (epoch, lr))
    tpr = 100.0 * np.sum(stats[:, 6]) / np.sum(stats[:, 7])
    tnr = 100.0 * np.sum(stats[:, 8]) / np.sum(stats[:, 9])
    n_pos = np.sum(stats[:, 7])
    n_neg = np.sum(stats[:, 9])
    print('Train:      tpr %3.2f, tnr %3.2f, total pos %d, total neg %d, time %3.2f' % (
        tpr, tnr, n_pos, n_neg, toc - tic))
    mean_losses = tuple(np.mean(stats[:, col]) for col in range(6))
    print('loss %2.4f, classify loss %2.4f, regress loss %2.4f, %2.4f, %2.4f, %2.4f' % mean_losses)

In [None]:
def validate(data_loader, net, loss):
    """Evaluate ``net`` on ``data_loader``, print metrics, return the mean loss.

    The forward passes run under ``torch.no_grad()``, so no autograd graph
    (and none of its activation memory) is kept during evaluation.

    Args:
        data_loader: iterable of (data, target, coord) batches; every item
            must provide ``.cuda()``.
        net: model called as ``net(data, coord)``.
        loss: callable ``loss(output, target, train=False)`` returning a mutable
            sequence. Element 0 is the loss tensor; elements 1-9 are the
            statistics printed below (tpr/tnr use columns 6-9).

    Returns:
        Mean of the first loss component over all batches.
    """
    start_time = time.time()

    net.eval()

    metrics = []
    with torch.no_grad():
        for data, target, coord in data_loader:
            data = Variable(data.cuda())
            target = Variable(target.cuda())
            coord = Variable(coord.cuda())

            output = net(data, coord)
            loss_output = loss(output, target, train=False)

            loss_output[0] = loss_output[0].item()
            metrics.append(loss_output)
    end_time = time.time()

    metrics = np.asarray(metrics, np.float32)
    print('\nValidation: tpr %3.2f, tnr %3.8f, total pos %d, total neg %d, time %3.2f' % (
        100.0 * np.sum(metrics[:, 6]) / np.sum(metrics[:, 7]),
        100.0 * np.sum(metrics[:, 8]) / np.sum(metrics[:, 9]),
        np.sum(metrics[:, 7]),
        np.sum(metrics[:, 9]),
        end_time - start_time))
    print('loss %2.4f, classify loss %2.4f, regress loss %2.4f, %2.4f, %2.4f, %2.4f' % (
        np.mean(metrics[:, 0]),
        np.mean(metrics[:, 1]),
        np.mean(metrics[:, 2]),
        np.mean(metrics[:, 3]),
        np.mean(metrics[:, 4]),
        np.mean(metrics[:, 5])))
    return np.mean(metrics[:, 0])


In [None]:
# Main
# Empty argv: inside Jupyter sys.argv holds the kernel's own arguments.
args = parse_args(args=[])
print(args)

path_root = RootPath
path_conf = join(path_root, args.cfg)

# Defaults overridden by the YAML file, then frozen against further edits
# (presumably a yacs-style CfgNode).
CFG = get_cfg_defaults()
CFG.merge_from_file(path_conf)
CFG.freeze()
print('\n' + str(CFG))

In [None]:
# Resolve the checkpoint, save and data directories under the project root
# and create any that are missing.
path_ckpt = join(path_root, CFG.TRAIN.DIR_CKPT)
path_save = join(path_root, CFG.TRAIN.DIR_SAVE)
path_data = join(path_root, CFG.TRAIN.DIR_DATA)

for _dir in (path_ckpt, path_save, path_data):
    isExist = makeNotExistDir(_dir)
    print("'{}' exist? {}".format(_dir, isExist))

In [None]:
# Load Model
# get_model() returns (model config dict, network, loss module, get_pbb helper).
model = import_module(CFG.TRAIN.MODEL)
(config, net, loss, get_pbb) = model.get_model()
print(type(config))
display(config)

In [None]:
# Resume
start_epoch = CFG.TRAIN.START_EPOCH

if CFG.TRAIN.RESUME:
    # NOTE(review): checkpoints are read from path_ckpt, but the training loop
    # writes them to path_save -- confirm both config entries point to the
    # same place.
    abspath_ckpt = join(path_ckpt, 'detector_' + CFG.TRAIN.RESUME)
    if not exists(abspath_ckpt):
        print("\n'{}' not exist.".format(abspath_ckpt))
    else:
        print("\nLoad checkpoint: '{}'".format(abspath_ckpt))
        # torch.load unpickles the file: only load checkpoints you trust.
        checkpoint = torch.load(abspath_ckpt)
        start_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['state_dict'])

print('Start epoch: {}'.format(start_epoch))

In [None]:
# GPU 
gpus = setgpu(CFG.TRAIN.GPU)
net = net.cuda()
loss = loss.cuda()
cudnn.benchmark = True  # cuDNN benchmark mode (autotunes convolution algorithms)
net = DataParallel(net)

# Number of comma-separated ids in the CUDA_VISIBLE_DEVICES string.
n_gpu = gpus.count(',') + 1
print('Using GPUs:', n_gpu)

In [None]:
# train and test list
# Case ids are the *_label.npy file names without the suffix. join/basename
# keep the path handling portable (the old "\\" pattern matched only on
# Windows). Sorting makes the train/valid split independent of the order in
# which the filesystem lists files.
label_files = sorted(glob(join(path_data, "*_label.npy")))
name_series = pd.Series(label_files, dtype=object).apply(
    lambda p: os.path.basename(p)[:-len("_label.npy")]
)

name_list = name_series.values.tolist()
assert name_list, "no *_label.npy files found in '{}'".format(path_data)

frac_for_train = CFG.TRAIN.FRAC_FOR_TRAIN
len_train = int(len(name_list) * frac_for_train)

luna_train = name_list[:len_train]
luna_valid = name_list[len_train:]
assert luna_train and luna_valid, (
    "FRAC_FOR_TRAIN={} leaves an empty train or valid split".format(frac_for_train))

print('Train dset:', luna_train)
print('Valid dset:', luna_valid)

In [None]:
trainSet = LungNodule3Ddetector(path_data, luna_train, config, "train")
trainSize = len(trainSet)
print('Dataset size:', trainSize)

# Sanity check: fetch the first sample and report its shapes.
imgs, label, coord = trainSet[0]
print('\nImage:', imgs.shape)
print('Label:', label.shape)
print('coord:', coord.shape)

In [None]:
batch_size = CFG.TRAIN.BATCH_SIZE
num_workers = CFG.TRAIN.NUM_WORKERS

train_loader = DataLoader(trainSet, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=True)

# Pull one batch to check the collated tensor types and sizes.
imgs, lbls, cords = next(iter(train_loader))
print('\nImage type:', type(imgs))
print('      size: ', imgs.size())

print('\nLabel type:', type(lbls))
print('      size: ', lbls.size())

print('\nCoord type:', type(cords))
print('      size: ', cords.size())

In [None]:
# Validation set: same batch settings as training but without shuffling.
validSet = LungNodule3Ddetector(path_data, luna_valid, config, phase='val')
valid_loader = DataLoader(validSet, batch_size=batch_size, shuffle=False,
                          num_workers=num_workers, pin_memory=True)

In [None]:
# Best validation loss so far. Starting at +inf guarantees the first epoch is
# checkpointed; the previous sentinel (1000) never saved a model if losses
# stayed at or above it.
bestLoss = float('inf')

torch.manual_seed(0)
torch.cuda.set_device(0)

NumEpochs = CFG.TRAIN.NUM_EPOCHS
LR = float(CFG.TRAIN.LEARNING_RATE)
mo = float(CFG.TRAIN.MOMENTUM)
wd = float(CFG.TRAIN.WEIGHT_DECAY)

optimizer = torch.optim.SGD(net.parameters(), LR, momentum=mo, weight_decay=wd)

In [None]:
def get_lr(epoch):
    """Step learning-rate schedule relative to the global ``LR``.

    Full ``LR`` up to 20% of ``NumEpochs``, then 0.1x up to 40%, 0.05x up to
    60%, and 0.01x afterwards (boundaries inclusive).
    """
    if epoch <= NumEpochs * 0.2:
        return LR
    if epoch <= NumEpochs * 0.4:
        return 0.1 * LR
    if epoch <= NumEpochs * 0.6:
        return 0.05 * LR
    return 0.01 * LR

for epoch in range(start_epoch, NumEpochs + 1):
    train(train_loader, net, loss, epoch, optimizer, get_lr, path_save)
    print("finished epoch {}".format(epoch))

    valiloss = validate(valid_loader, net, loss)

    # Checkpoint only when the validation loss improves.
    if bestLoss > valiloss:
        bestLoss = valiloss
        # Save the unwrapped (non-DataParallel) weights, moved to the CPU.
        state_dict = net.module.state_dict()
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].cpu()

        torch.save({
            'epoch': epoch + 1,  # the resume cell starts from this epoch
            'save_dir': path_save,
            'state_dict': state_dict,
            'args': args},
            os.path.join(path_save, 'detector_%03d.ckpt' % epoch))
        print("\nSave model on epoch %d" % epoch)

In [None]:
# Training finished.
print("done.")