In [8]:
import sys
sys.path.append('../../image_processing/')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
    
# plt.style.use("seaborn")
# sns.set(font_scale=1)

from pathlib import Path
import os
import gc
import functools 
from shutil import copyfile
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from apex import amp
import dill

import os, glob
import torch
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter

from resize import *
# from unet3d.model import ResidualUNet3D
# from unet3d.losses import *
# from unet3d.utils import create_feature_maps
# from unet3d.buildingblocks import Encoder, Decoder, FinalConv, ExtResNetBlock, SingleConv

# from fastai.vision import *


In [9]:
# Every split lives under one dataset root (relative to this notebook).
data_path = Path('./')
dataset_root = data_path / 'manuscript_1_datasets' / 'first_tx_allmets_0-0.5cc'

train_img_path = dataset_root / 'training' / 'fake_skull_stripped_1x1x3'
train_mask_path = dataset_root / 'training' / 'fake_mets_masks_1x1x3'
valid_img_path = dataset_root / 'validation' / 'skull_stripped_1x1x3'
valid_mask_path = dataset_root / 'validation' / 'mets_masks_1x1x3'
test_img_path = dataset_root / 'testing' / 'skull_stripped_1x1x3'
test_mask_path = dataset_root / 'testing' / 'mets_masks_1x1x3'

# NOTE: despite its name, `cuda1` is GPU index 3 (the name is kept because later
# cells use it). Fall back to the CPU when that GPU does not exist so the
# visualisation cells can still run.
cuda1 = torch.device('cuda:3' if torch.cuda.device_count() > 3 else 'cpu')

In [10]:
def _list_npy(directory):
    """Return sorted string paths of the ``.npy`` files directly inside ``directory``.

    Filtering on regular files with a ``.npy`` suffix skips entries such as
    ``.ipynb_checkpoints`` that ``os.listdir`` would also return and that
    ``np.load`` cannot read.
    """
    return sorted(str(p) for p in Path(directory).iterdir()
                  if p.is_file() and p.suffix == '.npy')


def _case_name(path):
    """Return the first two '_'-separated fields of the file name (used to pair image and mask)."""
    return '_'.join(Path(path).name.split('_')[0:2])


train_img_files = _list_npy(train_img_path)
train_mask_files = _list_npy(train_mask_path)
valid_img_files = _list_npy(valid_img_path)
valid_mask_files = _list_npy(valid_mask_path)
test_img_files = _list_npy(test_img_path)
test_mask_files = _list_npy(test_mask_path)

# All three splits combined; sorting full paths keeps each split contiguous.
img_files = sorted(train_img_files + valid_img_files + test_img_files)
mask_files = sorted(train_mask_files + valid_mask_files + test_mask_files)
img_names = [_case_name(f) for f in img_files]
mask_names = [_case_name(f) for f in mask_files]

# Images and masks are paired purely by sorted position, so check that they line up.
assert len(img_files) == len(mask_files), f'{len(img_files)} images vs {len(mask_files)} masks'
_mismatches = [(a, b) for a, b in zip(img_names, mask_names) if a != b]
assert not _mismatches, f'image/mask name mismatches (first 5): {_mismatches[:5]}'

In [11]:
def read_and_crop_to_tensor(file, target_d=None, target_h=None, target_w=None):
    """Load a 3-D ``.npy`` volume and return it as a float tensor sized by ``xyz_pad``.

    Args:
        file: path to a ``.npy`` file holding a 3-D array (unpacked as ``d, h, w``).
        target_d, target_h, target_w: requested extent along each axis; ``None``
            keeps the volume's own size on that axis.

    Returns:
        ``torch.float`` tensor built from ``xyz_pad(volume, target_d, target_h, target_w)``.
        NOTE(review): ``xyz_pad`` comes from the local ``resize`` module; whether it
        also crops oversize axes (as this function's name suggests) is not visible here.
    """
    # Read the file once and reuse the array for both the shape and the padding.
    volume = np.load(file)
    d, h, w = volume.shape
    target_d = d if target_d is None else target_d
    target_h = h if target_h is None else target_h
    target_w = w if target_w is None else target_w
    return torch.from_numpy(xyz_pad(volume, target_d, target_h, target_w)).type(torch.float)

class MetDataSet(Dataset):
    """Paired image/mask dataset over ``.npy`` volumes.

    Item ``i`` is ``(img, mask)``. Both are loaded from ``img_files[i]`` and
    ``mask_files[i]`` with ``read_and_crop_to_tensor`` using ``target_shape``,
    given a leading channel axis via ``unsqueeze(0)``, and moved to ``device``.

    Args:
        img_files: sized, indexable sequence of image paths.
        mask_files: mask paths in the same order as ``img_files``.
        target_shape: ``(d, h, w)`` forwarded to ``read_and_crop_to_tensor``;
            defaults to the previously hard-coded ``(64, 256, 256)``.
        device: destination for returned tensors; ``None`` uses the module-level
            ``cuda1`` device, looked up at access time.

    Raises:
        ValueError: if ``img_files`` and ``mask_files`` differ in length, since
            items are paired purely by position.
    """

    def __init__(self, img_files, mask_files, target_shape=(64, 256, 256), device=None):
        if len(img_files) != len(mask_files):
            raise ValueError(
                f'img_files ({len(img_files)}) and mask_files ({len(mask_files)}) '
                'must have the same length')
        self.img_files = img_files
        self.mask_files = mask_files
        self.target_shape = tuple(target_shape)
        self.device = device
        self.img_names = [self._case_name(f) for f in self.img_files]
        self.mask_names = [self._case_name(f) for f in self.mask_files]

    @staticmethod
    def _case_name(path):
        """Return the first two '_'-separated fields of the path's last component."""
        return '_'.join(path.split('/')[-1].split('_')[0:2])

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        # NOTE(review): moving to CUDA inside a Dataset fails with
        # DataLoader(num_workers>0); consider moving batches in the training loop instead.
        device = cuda1 if self.device is None else self.device
        img = read_and_crop_to_tensor(self.img_files[idx], *self.target_shape).unsqueeze(0)
        mask = read_and_crop_to_tensor(self.mask_files[idx], *self.target_shape).unsqueeze(0)
        return img.to(device), mask.to(device)

    def get_name(self, idx):
        """Return ``(image_name, mask_name)`` for item ``idx``."""
        return self.img_names[idx], self.mask_names[idx]
    
def show_single_pair(img, mask, index):
    """Show slice ``index`` (first axis) of an image volume next to its mask.

    Args:
        img, mask: volumes indexed along their first axis; torch tensors on any
            device, numpy arrays, or other array-likes are accepted.
        index: integer slice position (a 0-dim integer tensor also works).
    """
    def _to_host(slice_2d):
        # matplotlib cannot read CUDA memory, so tensor slices are copied to the host.
        if torch.is_tensor(slice_2d):
            return slice_2d.detach().cpu().numpy()
        return np.asarray(slice_2d)

    slice_idx = int(index)
    fig, axes = plt.subplots(1, 2)
    # Slice first, then copy: only one 2-D plane leaves the device.
    axes[0].imshow(_to_host(img[slice_idx]))
    axes[0].set_title(f'image, slice {slice_idx}')
    axes[1].imshow(_to_host(mask[slice_idx]))
    axes[1].set_title(f'mask, slice {slice_idx}')
    plt.show()
    # This helper is called once per slice in loops; release each figure explicitly.
    plt.close(fig)

In [12]:
# aug_img_files = [file for file in img_files if '/aug' in file]
# aug_mask_files = [file for file in mask_files if '/aug' in file]

### Two slices containing metastases for each case

In [13]:
# NOTE(review): despite the name, this dataset spans the training, validation
# and testing files, because img_files/mask_files concatenate all three splits.
train_dataset = MetDataSet(img_files=img_files, mask_files=mask_files)

In [15]:
# For every case, show the first two slices (axis 0) whose mask has a non-zero voxel.
for i in range(len(train_dataset)):
    # Fetch the item once (each access re-reads both files); drop the channel axis and
    # move to the CPU so plotting works regardless of which show_single_pair is defined.
    img_batch, mask_batch = train_dataset[i]
    img, mask = img_batch[0].cpu(), mask_batch[0].cpu()
    print(train_dataset.get_name(i)[0])
    # Indices of slices containing at least one mask voxel; derived from the
    # tensor's own shape instead of a hard-coded 256*256 plane size.
    met_slices = mask.flatten(1).ne(0).any(dim=1).nonzero().flatten()
    if len(met_slices) == 0:
        print('  empty mask, skipped')
        continue
    for slice_idx in met_slices[:2].tolist():
        show_single_pair(img, mask, slice_idx)

### Specific patient: every slice of one case

In [22]:
def show_single_pair(img, mask, index):
    """Show slice ``index`` (first axis) of an image volume next to its mask.

    NOTE(review): this cell redefines ``show_single_pair`` from an earlier cell;
    keep the two definitions identical or delete one of them.

    Args:
        img, mask: volumes indexed along their first axis; torch tensors on any
            device, numpy arrays, or other array-likes are accepted.
        index: integer slice position (a 0-dim integer tensor also works).
    """
    def _to_host(slice_2d):
        # matplotlib cannot read CUDA memory, so tensor slices are copied to the host.
        if torch.is_tensor(slice_2d):
            return slice_2d.detach().cpu().numpy()
        return np.asarray(slice_2d)

    slice_idx = int(index)
    fig, axes = plt.subplots(1, 2)
    axes[0].imshow(_to_host(img[slice_idx]))
    axes[0].set_title(f'image, slice {slice_idx}')
    axes[1].imshow(_to_host(mask[slice_idx]))
    axes[1].set_title(f'mask, slice {slice_idx}')
    plt.show()
    # This helper is called once per slice in loops; release each figure explicitly.
    plt.close(fig)

In [26]:
# Display the path of the training image at index 1 (the case paged through slice by slice below).
train_img_files[1]

'manuscript_1_datasets/first_tx_allmets_0-0.5cc/training/fake_skull_stripped_1x1x3/BrainMets-UCSF-00274_19990927.npy48.npy'

In [None]:
# Load training case 1 once and page through every slice. The slice count comes
# from this volume itself, not from an `img` variable left over from an earlier cell.
case_img = read_and_crop_to_tensor(train_img_files[1], 64, 256, 256)
case_mask = read_and_crop_to_tensor(train_mask_files[1], 64, 256, 256)
for slice_idx in range(case_img.shape[0]):
    show_single_pair(case_img, case_mask, slice_idx)