In [None]:
%load_ext autoreload
%autoreload 2
import os
import torch
import torch.nn as nn
import torch.optim as optim
os.environ['CUDA_VISIBLE_DEVICES']='5,6'
from unet3d import UNet
import tqdm
from torch.utils.data import DataLoader, Dataset
from datasets import LiverDatasetRandom, LiverDatasetFixed
import progressbar

In [None]:
# --- Data -------------------------------------------------------------------
# NOTE(review): the training and validation datasets are constructed with the
# same class and the same (default) arguments -- confirm that the train/val
# split really happens inside LiverDatasetFixed.
trndataset, valdataset = LiverDatasetFixed(), LiverDatasetFixed()
trnloader = DataLoader(dataset=trndataset, batch_size=4, shuffle=True, num_workers=2)
valloader = DataLoader(dataset=valdataset, batch_size=2, shuffle=True, num_workers=2)

# --- Model ------------------------------------------------------------------
# The two-class UNet is replicated over every GPU left visible by
# CUDA_VISIBLE_DEVICES and then moved to the CUDA device.
device = torch.device('cuda')
n_gpu = torch.cuda.device_count()
net = nn.DataParallel(UNet(n_class=2), device_ids=[*range(n_gpu)]).to(device)

# --- Optimisation -----------------------------------------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=5e-5)

## training

In [None]:
# Train for one epoch over `trnloader`, report the mean loss and checkpoint
# the weights of the wrapped (non-DataParallel) model.

net.train()  # make sure dropout / batch-norm layers are in training mode
for param in net.parameters():
    param.requires_grad = True

lenloader = len(trnloader)
print('Total : ', lenloader)
tr_loss = 0
epoch = 0
for step, batch in enumerate(trnloader):
    images, labels = batch['image'], batch['label']
    inputs = images.to(device, dtype = torch.float)
    labels = labels.to(device)

    # Gradients accumulate in `.grad` across backward() calls; without this
    # reset every step would apply the sum of all previous gradients.
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    if step % 5 == 0:
        print('Step {} loss'.format(step) + ': {}'.format(loss.item()))
    tr_loss += loss.item()
    # Drop references to the batch tensors before the next batch is moved
    # to the device.
    del inputs, labels, outputs

# max(..., 1) avoids a ZeroDivisionError when the loader is empty.
epoch_loss = tr_loss / max(lenloader, 1)
print('Training loss: {:.4f}'.format(epoch_loss))

# Freeze the parameters again once the epoch is finished.
for param in net.parameters():
    param.requires_grad = False

# torch.save does not create missing directories.
os.makedirs('./models', exist_ok=True)
model_save_file = os.path.join('./models', 'model_epoch{}.bin'.format(epoch))
# `net` is an nn.DataParallel wrapper; save the underlying module so the
# checkpoint keys carry no "module." prefix.
torch.save(net.module.state_dict(), model_save_file)


## data

In [None]:
Each DataLoader worker loads one image at a time.
For a while, the worker draws patches from its current image.
When the patches from that image are used up, the worker switches to the next image.

In [None]:
import numpy as np
from torch.utils.data import DataLoader, Dataset
class OrderTestDataset(Dataset):
    """Toy dataset used to observe the order in which indices are requested.

    It reports a fixed length of 50; every lookup prints the requested index
    and returns a 5x5 array of zeros.
    """

    def __getitem__(self, idx):
        # The printed index is the whole point: it shows which sample was
        # asked for, and in which order.
        print(idx)
        dummy_sample = np.zeros((5, 5))
        return dummy_sample

    def __len__(self):
        return 50

    
# Build an unshuffled two-worker loader over the toy dataset and look at the
# first batch only; the indices printed by the dataset show the fetch order.
trndset = OrderTestDataset()
trnloader = DataLoader(dataset=trndset, batch_size=2, shuffle=False, num_workers=2)

step, batch = next(enumerate(trnloader))
print(batch.shape)

In [None]:
Each image has a patch count.
A worker won't load a new image until all patches have been extracted from its current image.

In [None]:
## make patchlocs.npy file

In [None]:
import numpy as np
# For every image id draw 50 random patch centres; each coordinate is
# 200 plus a random integer in [0, 100). The draws are unseeded, so the
# generated file differs from run to run.
listids = [102, 103, 104, 105, 107, 108]
patchlocdict = {
    imgid: [
        [
            200 + np.random.randint(100),
            200 + np.random.randint(100),
            200 + np.random.randint(100),
        ]
        for _ in range(50)
    ]
    for imgid in listids
}
# np.save stores the dict as a pickled 0-d object array.
np.save('patchlocs.npy', patchlocdict)

In [None]:
class LiverDataset(Dataset):
    """Patch dataset over LiTS volumes that keeps one volume in memory at a time.

    Patches are served from the currently loaded volume until ``nperimage`` of
    them have been handed out; only then is another volume read from disk.
    The volume to read is derived from the requested index
    (``pidx // nperimage``), while the patch inside the volume is selected by
    an internal counter, so an instance is stateful. With ``num_workers > 0``
    every DataLoader worker operates on its own copy of the dataset and
    therefore keeps its own current volume and counter.

    Args:
        lits_id_list: LiTS case ids. Files are read from
            ``<lits_root>/volume-<id>.nii`` and
            ``<lits_root>/segmentation-<id>.nii``.
        nperimage: number of patches served per volume before switching.
        patchlocs: mapping ``case id -> sequence of 3-element patch centres``
            holding at least ``nperimage`` centres per case id.
        imgpshape: 3-element shape of the image patch.
        lblpshape: 3-element shape of the label patch.
        lits_root: directory that contains the ``.nii`` files.
    """

    def __init__(self, lits_id_list, nperimage, patchlocs, imgpshape, lblpshape,
                 lits_root='/mnt/data/LiverCT/Parenchyma/LITS/train'):
        self.lits_id_list = lits_id_list
        self.nperimage = nperimage
        self.notover = False  # not read anywhere in this class; kept as-is
        self.image = None
        self.label = None
        self.patchlocs = patchlocs
        self.current_pidx = 0      # patches already served from the current volume
        self.current_imgid = None  # case id of the volume held in memory
        self.imgpshape = imgpshape
        self.lblpshape = lblpshape
        self.lits_root = lits_root

    def __len__(self):
        """Total number of patches: ``nperimage`` for each listed volume."""
        return self.nperimage * len(self.lits_id_list)

    @staticmethod
    def _read_volume(path):
        """Read the file at ``path`` with SimpleITK and return it as a numpy array."""
        # Imported here because the notebook's import cell never imports
        # SimpleITK; `sitk` was an undefined name in the original code.
        import SimpleITK as sitk
        return sitk.GetArrayFromImage(sitk.ReadImage(path))

    @staticmethod
    def _crop(volume, center, shape):
        """Return the block of ``shape`` around ``center``.

        Along each axis the block spans ``[c - s // 2, c + (s - s // 2))``, so
        odd sizes extend one voxel further on the high side.
        """
        # NOTE(review): no bounds check is done -- a centre too close to the
        # border yields a wrapped or truncated patch. Confirm that patchlocs
        # always leaves enough margin.
        window = tuple(
            slice(c - s // 2, c + (s - s // 2)) for c, s in zip(center, shape)
        )
        return volume[window]

    def _load_new_image(self, pidx):
        """Load the volume/segmentation pair that patch index ``pidx`` belongs to."""
        # Consecutive runs of `nperimage` indices map to one volume.
        imgidx = self.lits_id_list[pidx // self.nperimage]
        img_path = os.path.join(self.lits_root, 'volume-' + str(imgidx) + '.nii')
        label_path = os.path.join(self.lits_root, 'segmentation-' + str(imgidx) + '.nii')
        self.image = self._read_volume(img_path)
        self.label = self._read_volume(label_path)
        self.current_imgid = imgidx
        self.current_pidx = 0

    def next_patch_for_current_image(self):
        """Return the next ``(image_patch, label_patch)`` of the loaded volume.

        Raises:
            RuntimeError: if no volume has been loaded yet.
        """
        if self.image is None or self.label is None:
            raise RuntimeError('no image loaded; call _load_new_image first')
        # NOTE(review): SimpleITK arrays are indexed (z, y, x); confirm the
        # centres stored in patchlocs use the same axis order.
        ploc = self.patchlocs[self.current_imgid][self.current_pidx]
        self.current_pidx += 1
        imgpatch = self._crop(self.image, ploc, self.imgpshape)
        lblpatch = self._crop(self.label, ploc, self.lblpshape)
        return imgpatch, lblpatch

    def __getitem__(self, pidx):
        """Return the next patch pair, switching volume once the current one is used up.

        Raises:
            IndexError: if ``pidx`` is outside ``[0, len(self))``.
        """
        if not 0 <= pidx < len(self):
            raise IndexError('patch index {} out of range'.format(pidx))
        # Reload only when nothing is loaded yet or the current volume has
        # already served all of its patches.
        if self.image is None or self.current_pidx >= self.nperimage:
            self._load_new_image(pidx)
        return self.next_patch_for_current_image()

litsids = [102, 103, 104, 105, 107, 108]
# patchlocs.npy stores a pickled dict (image id -> patch centres) wrapped in a
# 0-d object array; .item() unwraps it. allow_pickle=True runs the pickle
# machinery, so only load files generated by this notebook.
patchlocs = np.load('patchlocs.npy', allow_pickle = True).item()
# LiverDataset requires the image and label patch shapes; omitting them, as
# the original call did, raises a TypeError.
# NOTE(review): the notebook never states the patch shapes -- the values below
# are placeholders; replace them with the sizes the UNet expects.
IMG_PATCH_SHAPE = (64, 64, 64)
LBL_PATCH_SHAPE = (64, 64, 64)
trndset = LiverDataset(litsids, 50, patchlocs, IMG_PATCH_SHAPE, LBL_PATCH_SHAPE)
trndset