In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import pathlib
import matplotlib.pyplot as plt
from pathlib import Path
%matplotlib qt5

#### Dataset and Dataloader

In [3]:
from rtseg.cellseg.dataloaders import PhaseContrast
from rtseg.cellseg.utils.transforms import transforms
from rtseg.cellseg.networks import model_dict
from torch.utils.data import DataLoader
from rtseg.cellseg.utils.tiling import get_tiler

In [4]:
#test_dir = Path('/home/pk/Documents/rtseg/data/cellseg/omnipose/')
test_dir = Path('/mnt/sda1/REALTIME/data/seg_unet/dual/')
test_transforms = transforms['eval']

test_ds = PhaseContrast(phase_dir=test_dir/Path('bacteria_test'),
                labels_dir=test_dir/Path('test_masks'),
                vf_dir=test_dir/Path('vf_test'),
                vf=False,
                labels_delimiter='_masks',
                vf_delimiter='_vf_11',
                transforms=test_transforms,
                phase_format='.png',
                labels_format='.png',
                vf_format='.npy'
            )

In [5]:
test_ds = PhaseContrast(phase_dir=test_dir/Path('phase'),
                labels_dir=test_dir/Path('mask'),
                vf_dir=test_dir/Path('vf11'),
                vf=False,
                labels_delimiter='',
                vf_delimiter='_vf_11',
                transforms=test_transforms,
                phase_format='.tif',
                labels_format='.tif',
                vf_format='.npy'
            )

##### Keep `batch_size=1` here: the cells below treat the first batch `a` as a single image and build their own 2-image batch from it

In [6]:
test_dl = DataLoader(test_ds, batch_size=1, pin_memory=False, drop_last=False, num_workers=2)

In [7]:
a, b = next(iter(test_dl))
print(a.shape, b.shape)

torch.Size([1, 1, 880, 1024]) torch.Size([1, 1, 880, 1024])


#### Model loading

In [8]:
#model_path = Path('/home/pk/Documents/rtseg/models/cellseg/checkpoints/2024-04-16_09-51-44/model_val.pt')

In [9]:
model_path = Path('/home/pk/Documents/rtseg/models/cellseg/checkpoints/2024-04-17_14-54-31_vf11MM_lr_3e-4_65535/model_val.pt')

In [10]:
model = model_dict['ResUnet']
model = model.parse(channels_by_scale=[1, 32, 64, 128, 256], num_outputs=[1, 2, 1],
                    upsample_type='transpose_conv', feature_fusion_type='concat',
                    skip_channel_seg=True)
model.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [11]:
model.eval()

ResUnet(
  (down_layers): Sequential(
    (0): ResConvBlock(
      (block): Sequential(
        (0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU(inplace=True)
        (2): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): ReLU(inplace=True)
        (5): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (skip_connection): Sequential(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (nonlinearity): ReLU(inplace=True)
    )
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ResConvBlock(
      (block): Sequential(
        (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU(inplace=True)
        (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), paddin

#### Adding tiling and merging wrapper around the model.

In [12]:
device = 'cuda:0'

In [13]:
tiler = get_tiler("dynamic_overlap")

In [14]:
wrapped_model = tiler(model, device=device)
wrapped_model = wrapped_model.to(device)

In [15]:
a = a.to(device)

In [16]:
a.shape

torch.Size([1, 1, 880, 1024])

In [17]:
double_a = torch.cat([a, a], dim=0)

In [18]:
double_a.shape

torch.Size([2, 1, 880, 1024])

In [19]:
plt.figure()
plt.imshow(a.cpu()[0][0], cmap='gray')
plt.show()

In [20]:
def infer(a, model=None, synchronize=True):
    """Run the (tiled) model on a batch without autograd tracking.

    Args:
        a: input batch, passed straight to the model (in this notebook an
            ``(N, 1, H, W)`` tensor already on ``device``).
        model: callable returning ``(pred_semantic, pred_vf)``. Defaults to the
            notebook-level ``wrapped_model`` so existing ``infer(a)`` calls keep
            working; pass a model explicitly to avoid the hidden global.
        synchronize: if True and the outputs live on a CUDA device, block until
            all queued GPU work has finished before returning. CUDA kernels run
            asynchronously, so without this ``%timeit infer(...)`` may measure
            only kernel launches rather than the actual computation.

    Returns:
        Tuple ``(pred_semantic, pred_vf)`` exactly as produced by the model.
    """
    if model is None:
        model = wrapped_model  # notebook global defined in the tiling section
    with torch.inference_mode():
        pred_semantic, pred_vf = model(a)
    if synchronize and pred_semantic.is_cuda:
        torch.cuda.synchronize(pred_semantic.device)
    return pred_semantic, pred_vf

In [21]:
pred_semantic, pred_vf = infer(a)

In [21]:
%timeit pred_semantic, pred_vf = infer(a)

32.5 ms ± 606 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [22]:
pred_semantic.shape, pred_vf.shape

(torch.Size([1, 1, 880, 1024]), torch.Size([1, 2, 880, 1024]))

In [23]:
%timeit pred_semantic_double, pred_vf_double = infer(double_a)

64.8 ms ± 689 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [24]:
pred_semantic_double, pred_vf_double = infer(double_a)

In [25]:
pred_semantic_double.shape, pred_vf_double.shape

(torch.Size([2, 1, 880, 1024]), torch.Size([2, 2, 880, 1024]))

In [25]:
mask_constructed = construct_masks_batch(pred_vf, pred_semantic, device='cuda:0', store_solutions=False, fast=True)

In [26]:
plt.figure()
plt.imshow(mask_constructed[0][0])
plt.show()

In [23]:
def plot_inference(pred_semantic, pred_vf, device='cpu', batch_index=0):
    """Show the semantic map and both vector-field channels side by side.

    Args:
        pred_semantic: semantic prediction; the panel shows
            ``pred_semantic[batch_index][0]``.
        pred_vf: vector-field prediction; channel 0 is shown as ``vf_x`` and
            channel 1 as ``vf_y`` for element ``batch_index``.
        device: unused; kept only so existing calls that pass it keep working.
            Tensors are always copied to the CPU for plotting.
        batch_index: which element of the batch to plot (default 0, the only
            element earlier versions could show).
    """
    panels = (
        (pred_semantic[batch_index][0], "Semantic"),
        (pred_vf[batch_index][0], "vf_x"),
        (pred_vf[batch_index][1], "vf_y"),
    )
    fig, axes = plt.subplots(nrows=1, ncols=len(panels))
    for ax, (image, title) in zip(axes, panels):
        # detach() so tensors that still require grad can be converted too.
        ax.imshow(image.detach().cpu().numpy())
        ax.set_title(title)
    plt.show()

In [24]:
plot_inference(pred_semantic, pred_vf)

In [25]:
plot_inference(pred_semantic_double[0][None, :], pred_vf_double[0][None,:])
plot_inference(pred_semantic_double[1][None, :], pred_vf_double[1][None,:])

NameError: name 'pred_semantic_double' is not defined

### Read from file and infer

In [24]:
from rtseg.cellseg.numerics.vf_to_masks import construct_mask, construct_masks_batch
from skimage.io import imread

In [47]:
image = imread(Path('/home/pk/Documents/waveletCode/data/img_000000000_phase.tiff')).astype('float32')

In [48]:
image_tensor = torch.from_numpy(image)[None, None, :].to('cuda:0') / 35000.0

In [49]:
image_tensor.shape

torch.Size([1, 1, 1404, 3200])

In [50]:
image_tensor_batch = torch.concatenate([image_tensor, image_tensor, image_tensor], axis=0)

In [51]:
image_tensor_batch.shape

torch.Size([3, 1, 1404, 3200])

In [52]:
%timeit pred_semantic, pred_vf = infer(image_tensor_batch)

367 ms ± 4.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [53]:
pred_semantic, pred_vf = infer(image_tensor_batch)

In [54]:
pred_semantic.shape, pred_vf.shape

(torch.Size([3, 1, 1404, 3200]), torch.Size([3, 2, 1404, 3200]))

In [55]:
plot_inference(pred_semantic, pred_vf)

In [56]:
%timeit mask_constructed = construct_masks_batch(pred_vf, pred_semantic, device='cuda:0', store_solutions=False, fast=True)

994 ms ± 9.92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [57]:
mask_constructed = construct_masks_batch(pred_vf, pred_semantic, device='cuda:0', store_solutions=False, fast=True)

In [58]:
plt.figure()
plt.imshow(mask_constructed[0][0])
plt.show()

In [70]:
save_image_path = Path('/home/pk/Documents/waveletCode/data/img_000000000_mask.tiff')

In [71]:
from skimage.io import imsave

In [72]:
imsave(save_image_path, mask_constructed[0][0].numpy(), plugin='tifffile')

### Reconstruction of masks in batches

In [45]:
from rtseg.cellseg.numerics.vf_to_masks import construct_mask, construct_masks_batch

In [33]:
pred_vf.device, pred_semantic.device

(device(type='cuda', index=0), device(type='cuda', index=0))

In [34]:
mask_constructed = construct_mask(pred_vf, pred_semantic, device='cuda:0', store_solutions=False, fast=True)

In [35]:
mask_constructed.shape

torch.Size([1, 1, 880, 1024])

In [36]:
mask_constructed.dtype

torch.float32

In [37]:
mask_constructed = construct_masks_batch(pred_vf_double, pred_semantic_double, device='cuda:0', store_solutions=False, fast=True)

In [38]:
mask_constructed.shape

torch.Size([2, 1, 880, 1024])

In [39]:
plt.figure()
plt.imshow(mask_constructed[0][0])
plt.show()

In [37]:
plt.figure()
plt.imshow(mask_constructed[1][0])
plt.show()

In [38]:
pred_vf.device

device(type='cuda', index=0)

In [39]:
pred_semantic.device

device(type='cuda', index=0)

In [40]:
mask_constructed.shape

torch.Size([2, 1, 880, 1024])

### Evals