In [1]:
# Re-import edited modules (e.g. the local rtseg package) automatically before each cell runs.
%reload_ext autoreload
%autoreload 2
# notexbook extension; presumably only affects notebook styling — TODO confirm.
%reload_ext notexbook
%texify

In [2]:
# NOTE(review): np, edt, pathlib and skimage.io are not used in the cells shown here.
import numpy as np
import edt
import pathlib
from pathlib import Path
from skimage import io
import matplotlib.pyplot as plt
import torch
# Qt backend: figures open in separate interactive windows rather than inline.
%matplotlib qt

In [3]:
import os

# Root of the Omnipose-format training data. Set RTSEG_DATA_DIR to point at the data on
# another machine instead of editing this machine-specific default path.
DATA_DIR = Path(os.environ.get('RTSEG_DATA_DIR', '/home/pk/Documents/rtseg/data/cellseg/omnipose/'))
phase_dir = DATA_DIR / 'bacteria_train'   # phase-contrast images
labels_dir = DATA_DIR / 'train_masks'     # label masks
vf_dir = DATA_DIR / 'vf_train'            # precomputed vector fields (used with vf=True)

In [4]:
from rtseg.cellseg.dataloaders import PhaseContrast
from rtseg.cellseg.utils.transforms import transforms

In [5]:
# Training dataset: random-augmentation transforms; items unpack into (image, mask, vf).
dataset = PhaseContrast(
    phase_dir,
    labels_dir,
    vf_dir,
    vf=True,
    vf_delimiter='_vf_11',
    transforms=transforms['train'],
)
# Evaluation view without vector fields; items unpack into (image, mask).
# NOTE(review): reuses the training directories, so this is not a held-out split.
test_dataset = PhaseContrast(
    phase_dir,
    labels_dir,
    vf_dir,
    vf=False,
    transforms=transforms['eval'],
)

In [6]:
# Number of samples in the training dataset.
len(dataset)

249

In [7]:
# Built from the same directories as `dataset`, hence the same count.
len(test_dataset)

249

In [8]:
# First training item. The train transforms include RandomCrop and flips (see the printed
# pipeline), and no seed is set, so this sample likely differs between runs.
image, mask, vf = dataset[0]

<rtseg.cellseg.utils.transforms.AddDimension object at 0x72426236d6d0>
<rtseg.cellseg.utils.transforms.RandomCrop object at 0x72426236d710>
<rtseg.cellseg.utils.transforms.VerticalFlip object at 0x72426236d750>
<rtseg.cellseg.utils.transforms.HorizontalFlip object at 0x72426236d790>
<rtseg.cellseg.utils.transforms.ToFloat object at 0x72426236d7d0>


In [9]:
# Shapes of the image, mask and vector field of the training item.
tuple(arr.shape for arr in (image, mask, vf))

((1, 320, 320), (1, 320, 320), (2, 320, 320))

In [13]:
# Evaluation item 15. With vf=False only (image, mask) is returned; the eval pipeline
# (AddDimension, ToFloat) does not crop, so the image keeps its full size.
image_test, mask_test = test_dataset[15]

<rtseg.cellseg.utils.transforms.AddDimension object at 0x72426236d850>
<rtseg.cellseg.utils.transforms.ToFloat object at 0x72426236d890>


In [14]:
# Full-resolution shapes of the evaluation image and mask.
tuple(arr.shape for arr in (image_test, mask_test))

((1, 599, 579), (1, 599, 579))

In [83]:
# Visual check of training item 100 using the dataset's own plotting helper.
dataset.plot_item(100)

In [84]:
# Add a leading batch dimension so the single sample matches the batched (B, C, H, W)
# layout construct_mask is called with elsewhere in this notebook.
vf_tensor = torch.from_numpy(vf)[None, :]
# `> 0` turns the mask into a boolean foreground map, matching how the batched cells
# build their semantic input (mask_batch[...] > 0).
semantic = torch.from_numpy(mask)[None, :] > 0

In [85]:
# Both inputs now carry a leading batch dimension of 1.
(vf_tensor.size(), semantic.size())

(torch.Size([1, 2, 320, 320]), torch.Size([1, 1, 320, 320]))

In [86]:
from rtseg.cellseg.numerics.vf_to_masks import construct_mask

In [87]:
# Reconstruct a label image from the vector field and the mask.
# store_solutions=False presumably skips keeping intermediate integration results — TODO confirm.
labels = construct_mask(vf_tensor, semantic, store_solutions = False)

In [88]:
# The reconstructed labels keep the (1, 1, H, W) layout of the inputs.
labels.size()

torch.Size([1, 1, 320, 320])

In [89]:
# Label image reconstructed from the single training sample.
fig, ax = plt.subplots()
ax.imshow(labels[0, 0].numpy())
plt.show()

In [90]:

# `test_dataset` was built with vf=False, so its items carry no vector field and `vf_test`
# was never defined (NameError). Load the same evaluation index from a vector-field-enabled
# dataset with the same eval transforms, and take the mask from that same item so the
# mask and the field are guaranteed to belong together.
test_dataset_vf = PhaseContrast(
    phase_dir, labels_dir, vf_dir,
    vf=True, vf_delimiter='_vf_11', transforms=transforms['eval'],
)
_, mask_test_vf, vf_test = test_dataset_vf[15]

# Leading batch dimension for construct_mask; `> 0` gives a boolean foreground map,
# consistent with the batched cells.
vf_tensor_test = torch.from_numpy(vf_test)[None, :]
semantic_test = torch.from_numpy(mask_test_vf)[None, :] > 0

In [91]:
# Same reconstruction on the full-size evaluation sample.
labels_test = construct_mask(vf_tensor_test, semantic_test, store_solutions = False)

In [92]:
# Label image reconstructed from the evaluation sample.
fig, ax = plt.subplots()
ax.imshow(labels_test[0, 0].numpy())
plt.show()

### Batched data loading and label reconstruction

In [20]:
from torch.utils.data import DataLoader

In [48]:
# Samples are drawn in index order (no shuffling); an incomplete final batch is dropped.
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    drop_last=True,
    num_workers=2,
    pin_memory=True,
)

In [53]:
# Fresh iterator over the loader; take its first batch.
iter_data = iter(dataloader)
image_batch, mask_batch, vf_batch = next(iter_data)

In [57]:
# NOTE(review): advances the same iterator, so the cells below use the *second* batch;
# each re-run of this cell moves one batch further.
image_batch, mask_batch, vf_batch = next(iter_data)

In [58]:
# Batched tensors: images, masks, and two-channel vector fields.
tuple(t.shape for t in (image_batch, mask_batch, vf_batch))

(torch.Size([2, 1, 320, 320]),
 torch.Size([2, 1, 320, 320]),
 torch.Size([2, 2, 320, 320]))

In [59]:
def plot_batch(image_batch, mask_batch, vf_batch):
    """Plot every sample of a batch in its own 2x2 figure.

    Panels: phase-contrast image (channel 0, grayscale), mask (channel 0), and
    vector-field channels 0 and 1, labelled vf_x and vf_y.

    Args:
        image_batch: batch tensor whose first dimension is the batch size.
        mask_batch: batch tensor with the same batch size.
        vf_batch: batch tensor with at least two channels per sample.
    """
    batch_size = image_batch.shape[0]

    for i in range(batch_size):
        # detach()/cpu() so the helper also accepts tensors that require grad or live on a GPU.
        image = image_batch[i].detach().cpu().numpy()
        mask = mask_batch[i].detach().cpu().numpy()
        vf = vf_batch[i].detach().cpu().numpy()

        fig, ax = plt.subplots(nrows=2, ncols=2)
        ax[0, 0].imshow(image[0], cmap='gray')
        ax[0, 0].set_title('Phase contrast')
        ax[0, 1].imshow(mask[0])
        ax[0, 1].set_title('Mask')
        ax[1, 0].imshow(vf[0])
        ax[1, 0].set_title('vf_x')
        ax[1, 1].imshow(vf[1])
        ax[1, 1].set_title('vf_y')
        # Identify which batch item each figure window shows.
        fig.suptitle(f'Batch item {i}')
        plt.show()

In [60]:
# One 2x2 figure per item of the current batch.
plot_batch(image_batch, mask_batch, vf_batch)

In [26]:
# Reconstruct labels for the first item of the batch only, re-adding its batch dimension;
# `> 0` turns the mask into a boolean foreground map.
labels = construct_mask(vf_batch[0].unsqueeze(0), mask_batch[0].unsqueeze(0) > 0)

In [27]:
# Single-item result: (1, 1, H, W).
labels.size()

torch.Size([1, 1, 320, 320])

In [28]:
# Labels reconstructed for the first batch item.
fig, ax = plt.subplots()
ax.imshow(labels[0, 0].numpy())
plt.show()

In [29]:
# Input image of the first batch item, for comparison with its labels.
fig, ax = plt.subplots()
ax.imshow(image_batch[0, 0].numpy())
plt.show()

In [30]:
from rtseg.cellseg.numerics.interpolation.interpolate_vf import interpolate_vf
from rtseg.cellseg.numerics.integration.utils import init_values_mesh_batched, init_values_semantic
from rtseg.cellseg.numerics.integration.integrate_vf import ivp_solver
from rtseg.cellseg.numerics.cluster import cluster

### Batched vector-field integration

In [31]:
# Continuous field over the whole batch; returns a function of query points `p`
# (see the displayed repr in the next cell).
continuous_vf_batch = interpolate_vf(vf_batch, mode = 'bilinear_batched')

In [32]:
# Display the interpolant: it is a callable, not a tensor.
continuous_vf_batch

<function rtseg.cellseg.numerics.interpolation.interpolate_vf._vf_bilinear_batched.<locals>._vf(p)>

In [33]:
# Boolean foreground map for the batch; the comparison already returns a new tensor,
# so mask_batch itself is left untouched.
semantic_batch = mask_batch.gt(0)

In [34]:
# Same layout as mask_batch: (B, 1, H, W).
semantic_batch.size()

torch.Size([2, 1, 320, 320])

In [35]:
# Derive batch and spatial sizes from the data rather than hard-coding (2, 320, 320),
# so the mesh stays consistent with vf_batch if the batch size or crop size changes.
# NOTE(review): assumes the argument order is (batch, height, width); both spatial sizes
# were 320 in the original call, so this could not be distinguished — TODO confirm.
batch_size, _, height, width = vf_batch.shape
initial_values = init_values_mesh_batched(batch_size, height, width)

In [36]:
# Integrate the mesh points through the interpolated field and keep only the last
# entry of the solver output (presumably the final integration step).
DX, N_STEPS = 0.1, 10
solutions = ivp_solver(continuous_vf_batch, initial_values, dx=DX, n_steps=N_STEPS)[-1]

In [37]:
# Two channels per mesh point versus the one-channel foreground map.
(solutions.size(), semantic_batch.size())

(torch.Size([2, 2, 320, 320]), torch.Size([2, 1, 320, 320]))

In [38]:
# Channel 0 of the integrated points for the first batch item
# (the original `solutions[][0]` was a SyntaxError).
fig, ax = plt.subplots()
ax.imshow(solutions[0, 0].numpy())
ax.set_title('Integrated solutions: item 0, channel 0')
plt.show()

In [None]:
# Inspect the integrated result for the first batch item; shape (2, H, W) per the shapes cell above.
solutions[0]