In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import pathlib
import matplotlib.pyplot as plt
from pathlib import Path
%matplotlib qt5

In [3]:
from rtseg.cellseg.dataloaders import PhaseContrast
from rtseg.cellseg.utils.transforms import transforms
from rtseg.cellseg.networks import model_dict
from torch.utils.data import DataLoader
from rtseg.cellseg.utils.tiling import get_tiler
from rtseg.cellseg.numerics.vf_to_masks import construct_mask, construct_masks_batch
from skimage.io import imread

In [4]:
device='cuda:0'

In [5]:
model_path = Path('/home/pk/Documents/rtseg/models/cellseg/checkpoints/2024-07-22_14-44-13/model_val.pt')

In [6]:
model = model_dict['ResUnet']
model = model.parse(channels_by_scale=[1, 32, 64, 128, 256], num_outputs=[1, 2, 1],
                    upsample_type='transpose_conv', feature_fusion_type='concat',
                    skip_channel_seg=True)
model.load_state_dict(torch.load(model_path))

tiler = get_tiler("dynamic_overlap")
wrapped_model = tiler(model, device=device)
wrapped_model = wrapped_model.to(device)

def run_segnet(image):
    with torch.inference_mode():
        pred_semantic, pred_vf = wrapped_model(image)
    return pred_semantic, pred_vf

# plot the outputs of one image
def plot_inference(pred_semantic, pred_vf, device='cpu'):
    fig, ax = plt.subplots(nrows=1, ncols=3)
    ax[0].imshow(pred_semantic[0][0].cpu().numpy())
    ax[0].set_title("Semantic")
    ax[1].imshow(pred_vf[0][0].cpu().numpy())
    ax[1].set_title("vf_x")
    ax[2].imshow(pred_vf[0][1].cpu().numpy())
    ax[2].set_title("vf_y")
    plt.show()

def tensorize_image(image, device='cuda:0'):
    image_tensor = torch.from_numpy(image).float() / 65535.0
    image_tensor = image_tensor[None, None, :].to(device) # add dimension to play well with (N, C, H, W)
    return image_tensor
    
def segment(image, device='cuda:0'):
    pred_semantic, pred_vf = run_segnet(tensorize_image(image, device=device))
    segmentation_mask = construct_masks_batch(pred_vf, pred_semantic, device=device, store_solutions=False, fast=True)
    return segmentation_mask[0][0].cpu().numpy()

In [11]:
phase_img = imread('/home/pk/Documents/rtseg/data/timelapse/_1/Default/img_channel000_position000_time000000000_z000.tif')
#phase_img = imread('/mnt/sda1/REALTIME/data/seg_unet/dual/phase/img0049.tif')

In [12]:
plt.figure()
plt.imshow(phase_img, cmap='gray')
plt.title(f"{phase_img.shape}")
plt.show()

In [13]:
mask = segment(phase_img)

In [14]:
plt.figure()
plt.imshow(mask)
plt.show()

In [26]:
from rtseg.cellseg.utils.transforms import train_transform

In [27]:
train_dir = Path('/mnt/sda1/REALTIME/data/seg_unet/dual/')
train_transforms = train_transform

train_ds = PhaseContrast(phase_dir=train_dir/Path('phase'),
                labels_dir=train_dir/Path('mask'),
                vf_dir=train_dir/Path('vf11'),
                vf_at_runtime=True, # Vf are computed on the fly
                labels_delimiter='',
                vf_delimiter='_vf_11',
                transforms=train_transform,
                phase_format='.tif',
                labels_format='.tif',
                vf_format='.npy'
            )
train_dl = DataLoader(train_ds, batch_size=6)

In [28]:
train_ds.transforms

<rtseg.cellseg.utils.transforms.Compose at 0x7b60841081d0>

In [29]:
train_iter = iter(train_dl)

In [30]:
a = next(train_iter)

In [34]:
a[0].shape, a[1].shape, a[2].shape

(torch.Size([6, 1, 320, 320]),
 torch.Size([6, 1, 320, 320]),
 torch.Size([6, 2, 320, 320]))

In [19]:
from rtseg.cellseg.utils.transforms import train_transform

In [21]:
t = train_ds[0]

TypeError: Compose.__call__() takes 2 positional arguments but 3 were given

In [133]:
t[0].shape, t[1].shape

((880, 1024), (880, 1024))

In [134]:
fig, ax = plt.subplots(nrows=1, ncols=2)
ax[0].imshow(t[0], cmap='gray')
ax[1].imshow(t[1])
plt.show()

In [135]:
import torchvision.transforms.functional  as TF

In [136]:
t1 = train_transform(t)

<rtseg.cellseg.utils.transforms.changedToPIL object at 0x7e849e41d010>
<rtseg.cellseg.utils.transforms.RandomCrop object at 0x7e849e41d390>
<rtseg.cellseg.utils.transforms.RandomRotation object at 0x7e849e69c690>
10.957586288452148
<rtseg.cellseg.utils.transforms.RandomAffine object at 0x7e849e6360d0>
<rtseg.cellseg.utils.transforms.VerticalFlip object at 0x7e849e668dd0>
<rtseg.cellseg.utils.transforms.HorizontalFlip object at 0x7e849e3ffad0>
<rtseg.cellseg.utils.transforms.AddVectorField object at 0x7e849e60e290>
<rtseg.cellseg.utils.transforms.ToFloat object at 0x7e849e68ed90>


In [137]:
t1[0].shape, t1[1].shape, t1[2].shape

(torch.Size([1, 320, 320]),
 torch.Size([1, 320, 320]),
 torch.Size([2, 320, 320]))

In [144]:
plt.figure()
plt.imshow(t1[2][0])
plt.show()

In [81]:
from skimage.measure import label

In [41]:
from rtseg.cellseg.numerics.sdf_vf import sdf_vector_field

In [42]:
a = sdf_vector_field(torch.tensor(t1[1]), 21)

RuntimeError: Could not infer dtype of Image

In [33]:
a.shape

torch.Size([2, 880, 1024])

In [34]:
fig, ax = plt.subplots(nrows=1, ncols=2)
ax[0].imshow(a[0].numpy())
ax[1].imshow(a[1].numpy())
plt.show()

In [None]:
class RandomRotateScale:

    def __init__(self):
        pass

    def __call__(self, image, mask , vf = None):
        pass

In [43]:
test_dir = Path('/mnt/sda1/REALTIME/data/seg_unet/dual/')
test_transforms = transforms['train']

In [44]:
test_ds = PhaseContrast(phase_dir=test_dir/Path('phase'),
                labels_dir=test_dir/Path('mask'),
                vf_dir=test_dir/Path('vf11'),
                vf=False,
                labels_delimiter='',
                vf_delimiter='_vf_11',
                transforms=test_transforms,
                phase_format='.tif',
                labels_format='.tif',
                vf_format='.npy'
            )

In [45]:
test_dl = DataLoader(test_ds, batch_size=1, pin_memory=False, drop_last=False, num_workers=2)

In [46]:
test_iter = iter(test_dl)

In [49]:
phase, mask = next(test_iter)

print(phase.shape, mask.shape)

ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/pk/Documents/rtseg/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/pk/Documents/rtseg/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pk/Documents/rtseg/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/home/pk/Documents/rtseg/rtseg/cellseg/dataloaders.py", line 110, in __getitem__
    return self._getitem(idx)
           ^^^^^^^^^^^^^^^^^^
  File "/home/pk/Documents/rtseg/rtseg/cellseg/dataloaders.py", line 105, in _get_image_mask
    image, mask = self.transforms(image, mask)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/pk/Documents/rtseg/rtseg/cellseg/utils/transforms.py", line 55, in __call__
    image, mask = layer(image, mask)
    ^^^^^^^^^^^
ValueError: too many values to unpack (expected 2)


In [39]:
vf.shape

torch.Size([1, 2, 320, 320])

In [40]:
def plot(phase, mask):
    fig, ax = plt.subplots(nrows=1, ncols=2)
    ax[0].imshow(phase[0][0].numpy(), cmap='gray')
    ax[1].imshow(mask[0][0].numpy())
    plt.show()

In [41]:
plot(phase, mask)

In [58]:
import torchvision.transforms.functional as TF

In [59]:
phase.dtype

torch.float32

In [60]:
from skimage.util import random_noise
from skimage.exposure import adjust_gamma, rescale_intensity
import random

In [61]:
phase_np = phase.numpy()
gamma_factor = random.uniform(0.7, 1.4)
print(gamma_factor)
brightness_fator = random.uniform(-2500.0, 5000.0)
print(brightness_fator)
phase_adjusted = adjust_gamma(phase_np, gamma=gamma_factor)
phase_adjusted += brightness_fator
phase_adjust = rescale_intensity(phase_adjusted, in_range='image', out_range='uint16')

0.8210256985135094
1467.453322170123


In [62]:
fig, ax = plt.subplots(nrows=1, ncols=2)
im = ax[0].imshow(phase_np[0])
fig.colorbar(im, ax=ax[0])
im = ax[1].imshow(phase_adjust[0])
fig.colorbar(im, ax=ax[1])
plt.show()

TypeError: Invalid shape (1, 320, 320) for image data