In [1]:
import os
import tifffile 
from sklearn.model_selection import train_test_split
from tqdm import tqdm_notebook as tqdm
from shutil import copyfile
import numpy as np
from torchvision.datasets import DatasetFolder
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [2]:
# Root directory that holds the data folders, the model checkpoint and the
# prediction output folder used by the cells below. Set the VENTRICLE_SEG_DIR
# environment variable to run this notebook on a different machine.
working_dir = os.environ.get(
    'VENTRICLE_SEG_DIR',
    '/media/jswaney/SSD EVO 860/organoid_phenotyping/ventricle_segmentation',
)

In [4]:
data_dir = 'data_pz'

# Count the entries in the raw data folder as a quick sanity check.
raw_dir = os.path.join(working_dir, data_dir)
files = os.listdir(raw_dir)
len(files)

607

In [5]:
def load_tiff_seg(path):
    """Read a TIFF image and pad it to a 3-channel, channels-last array.

    The image read by ``tifffile.imread`` goes into channel 0; channels 1 and 2
    are all-zero arrays of the same shape and dtype.

    Args:
        path: Path of the TIFF file to read.

    Returns:
        ``numpy.ndarray`` of shape ``(H, W, 3)`` for a 2-D input image.

    Note:
        Assumes a single-plane 2-D image; a multi-page stack makes the final
        transpose fail because the axis count no longer matches.
    """
    image = tifffile.imread(path)
    blank = np.zeros(image.shape, image.dtype)
    # Build channels-first (3, H, W), then move the channel axis last.
    channels_first = np.stack((image, blank, blank))
    return channels_first.transpose(1, 2, 0)

In [6]:
# Augmentation parameters; not referenced by any later cell in this notebook.
degrees = 45
scale = (0.8, 1.2)
size = 256

data_dir = 'V1_1_class'

# Some torchvision releases hand `extensions` straight to str.endswith, which
# rejects a list (TypeError); a tuple is accepted by every release.
dataset = DatasetFolder(os.path.join(working_dir, data_dir),
                        loader=load_tiff_seg,
                        extensions=('.tif',),
                        transform=transforms.ToTensor())

print(dataset)

Dataset DatasetFolder
    Number of datapoints: 607
    Root Location: /media/jswaney/SSD EVO 860/organoid_phenotyping/ventricle_segmentation/V1_1_class
    Transforms (if any): ToTensor()
    Target Transforms (if any): None


In [7]:
import torch

In [8]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
device

device(type='cuda')

In [9]:
# One image per step, in dataset order, so prediction index == dataset index.
batch_size = 1
n_workers = 0 if not use_cuda else 1
pin_memory = bool(use_cuda)

loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=n_workers,
    pin_memory=pin_memory,
)

In [10]:
import sys
sys.path.append('/home/jswaney/Pytorch-UNet/')
from unet.unet_model import UNet

In [11]:
model_path = 'unet_pz_200.pt'

# UNet(1, 1): the first argument is the input channel count (the repr shows
# Conv2d(1, 64) at the input); the second is presumably the number of output
# maps -- TODO confirm against unet_model.
model = UNet(1, 1)
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine,
# matching the CPU fallback chosen for `device` above.
state_dict = torch.load(os.path.join(working_dir, model_path), map_location=device)
model.load_state_dict(state_dict)
model.to(device)

UNet(
  (inc): inconv(
    (conv): double_conv(
      (conv): Sequential(
        (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace)
      )
    )
  )
  (down1): down(
    (mpconv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): double_conv(
        (conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-0

In [13]:
imgs = []
segs = []  # never filled: this dataset's loader provides no segmentation channel
outputs = []

# Inference mode: BatchNorm must use its stored running statistics. In training
# mode it normalizes with the current batch's statistics (a single image here)
# and keeps updating the running averages, even under no_grad.
model.eval()

with torch.no_grad():
    for x, _ in tqdm(loader):
        # Channel 0 holds the image; channels 1-2 are zero padding added by
        # load_tiff_seg. unsqueeze restores the channel axis for the model.
        img = x[:, 0].unsqueeze(1).to(device)

        output = model(img)

        # Tensors created under no_grad carry no graph, so no detach() needed.
        imgs.append(img.cpu().numpy())
        outputs.append(output.cpu().numpy())

HBox(children=(IntProgress(value=0, max=607), HTML(value='')))




In [None]:
output_path = 'V1_1_output'
output_dir = os.path.join(working_dir, output_path)

# exist_ok=True keeps this cell re-runnable once the folder exists.
os.makedirs(output_dir, exist_ok=True)

# Flatten the per-batch arrays so every prediction is written, whatever
# batch_size was used; with batch_size = 1 this is one file per loader step.
predictions = [batch[k, 0] for batch in outputs for k in range(batch.shape[0])]

for i, prediction in tqdm(enumerate(predictions), total=len(predictions)):
    filename = f'{i:04d}.tif'
    # NOTE(review): newer tifffile releases replace imsave/compress with
    # imwrite/compression -- left unchanged; confirm against installed version.
    tifffile.imsave(os.path.join(output_dir, filename), prediction, compress=1)

HBox(children=(IntProgress(value=0, max=607), HTML(value='')))