In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
from plotutils import plot_images_as_grid

from generate_datasets import create_cell_and_no_cell_patches, create_dataset_from_patches

import image_processing
import video_session
import patch_extraction

from cnnlearning import CNN

from learning_utils import ImageDataset
from classificationutils import create_probability_map

from cnnlearning import TrainingTracker, train
import os
import collections

import scipy
import skimage
from skimage.morphology import binary_dilation as bd
from skimage.exposure import equalize_adapthist

import numpy as np
import torch
import cv2
import copy

# Marked sessions, split into the training and validation sets used throughout.
training_video_sessions = video_session.get_video_sessions(marked=True, validation=False)
validation_video_sessions = video_session.get_video_sessions(marked=True, validation=True)

# To restrict to shared videos: [vs for vs in sessions if 'shared-videos' in vs.video_file]
# List the video files that ended up in each split.
for split_label, split_sessions in (('training videos', training_video_sessions),
                                    ('validation videos', validation_video_sessions)):
    print(split_label)
    display([session.video_file for session in split_sessions])

## Patch extraction helpers

In [None]:
# Side-by-side comparison of the two point-generation helpers in patch_extraction:
# each panel scatters a session's first set of cell positions and the generated points.
plt.figure(figsize=(20, 10))
panel_specs = [
    (1, patch_extraction.get_parallel_points, 21, 'parallel points'),
    (5, patch_extraction.get_perpendicular_points, 9, 'perpendicular points'),
]
for panel, (session_idx, point_fn, distance, panel_title) in enumerate(panel_specs, start=1):
    vs = training_video_sessions[session_idx]
    points = vs.cell_positions[0]

    plt.subplot(1, 2, panel)
    plt.scatter(points[:, 0], points[:, 1])
    x, y = point_fn(points, distance)
    plt.scatter(x, y)
    plt.title(panel_title)

# Perpendicular search negative patch extraction

In [None]:
%load_ext autoreload
%autoreload 2
from video_session import get_video_sessions
from patch_extraction import SessionPatchExtractor
from generate_datasets import create_cell_and_no_cell_patches
from patch_extraction import SessionPatchExtractor as PE
from plotutils import plot_images_as_grid

video_sessions = get_video_sessions(marked=True, registered=True, validation=True)
vs = video_sessions[0]

negative_extraction_mode = patch_extraction.NegativeExtractionMode.PERPENDICULAR
patch_extractor = patch_extraction.SessionPatchExtractor(
    vs, 
    patch_size=21, 
    n_negatives_per_positive=32,
    limit_to_vessel_mask=False, 
    negative_extraction_mode = negative_extraction_mode,
    negative_extraction_radius=32, # Extraction radius here is the length of the line between the cell positions
    v=True,
    extraction_mode=PE.ALL_MODE
)

print(patch_extractor.cell_patches_oa790.shape)
print(patch_extractor.non_cell_patches_oa790.shape)
plot_images_as_grid(patch_extractor.cell_patches_oa790[:10])
plot_images_as_grid(patch_extractor.marked_cell_patches_oa790[:10])

plot_images_as_grid(patch_extractor.non_cell_patches_oa790[:10])
plot_images_as_grid(patch_extractor.marked_non_cell_patches_oa790[:10])
patch_extractor.visualize_patch_extraction(linewidth=2, s=100, frame_idx=0, figsize=(20, 20))

#### Limit to vessel mask

In [None]:
%load_ext autoreload
%autoreload 2
from video_session import get_video_sessions
from patch_extraction import SessionPatchExtractor
from generate_datasets import create_cell_and_no_cell_patches
from patch_extraction import SessionPatchExtractor as PE
from plotutils import plot_images_as_grid

video_sessions = get_video_sessions(marked=True, registered=True, validation=True)
vs = video_sessions[0]

negative_extraction_mode = patch_extraction.NegativeExtractionMode.PERPENDICULAR
patch_extractor = patch_extraction.SessionPatchExtractor(
    vs, 
    patch_size=25, 
    n_negatives_per_positive=32,
    limit_to_vessel_mask=True, 
    negative_extraction_mode = negative_extraction_mode,
    negative_extraction_radius=32, # Extraction radius here is the length of the line between the cell positions
    v=True,
    extraction_mode=PE.ALL_MODE
)

print(patch_extractor.cell_patches_oa790.shape)
print(patch_extractor.non_cell_patches_oa790.shape)
plot_images_as_grid(patch_extractor.cell_patches_oa790[:10])
plot_images_as_grid(patch_extractor.marked_cell_patches_oa790[:10])

plot_images_as_grid(patch_extractor.non_cell_patches_oa790[:10])
plot_images_as_grid(patch_extractor.marked_non_cell_patches_oa790[:10])
patch_extractor.visualize_patch_extraction(linewidth=2, s=100, frame_idx=0, figsize=(20, 20))

In [None]:
# seaborn drives the intensity-distribution plots in the cells below. The former
# penguins demo was dropped: sns.load_dataset downloads from the network, which
# breaks an offline Restart & Run All and is unrelated to this analysis.
import seaborn as sns

In [None]:
# Reference intensity statistics: the oa790 cell patches of the first
# non-validation session define the target mean/std that other sessions are
# rescaled to with match_target_distribution. Extract the patches once and
# reuse them for both statistics instead of re-extracting per statistic.
reference_cell_values = patch_extractor.with_session(
    video_session.get_video_sessions(validation=False)[0]
).cell_patches_oa790.flatten()
mean = reference_cell_values.mean()
std = reference_cell_values.std()

# newImage1 = (image1-mean1)*std2/std1 + mean2;

In [None]:
def match_target_distribution(values, target_mean, target_std):
    """Linearly rescale ``values`` so they have ``target_mean`` and ``target_std``.

    Computes ``(values - mean(values)) * target_std / std(values) + target_mean``
    (see https://www.mathworks.com/matlabcentral/answers/236286-image-normalization-same-mean-and-same-std).

    Args:
        values: array-like of intensities (any shape); converted with ``np.asarray``.
        target_mean: desired mean of the result.
        target_std: desired (population) standard deviation of the result.

    Returns:
        ndarray with the same shape as ``values``. A constant input has zero
        spread and cannot be rescaled, so it maps to an array filled with
        ``target_mean`` instead of NaN/inf from a division by zero.
    """
    values = np.asarray(values)
    source_std = values.std()
    if source_std == 0:
        return np.full(values.shape, target_mean, dtype=float)
    return (values - values.mean()) * target_std / source_std + target_mean

In [None]:
# Intensity range of the first oa790 frame of the first non-validation session
# (load the session list once rather than once per statistic).
first_frame_oa790 = video_session.get_video_sessions(validation=False)[0].frames_oa790[0]
first_frame_oa790.min(), first_frame_oa790.max()

In [None]:
from skimage import data
from skimage.exposure import match_histograms

reference = data.coffee()
image = data.chelsea()

# `multichannel` was deprecated in scikit-image 0.19 and later removed in favour
# of `channel_axis`; only fall back to it on versions that predate `channel_axis`.
try:
    matched = match_histograms(image, reference, channel_axis=-1)
except TypeError:
    matched = match_histograms(image, reference, multichannel=True)


def _show_triptych(panels, draw):
    """Draw three titled panels side by side with their axes hidden.

    Args:
        panels: sequence of ``(title, data, color)`` triples, one per column.
        draw: callable ``draw(ax, data, color)`` that renders ``data`` on ``ax``.
    """
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(8, 3),
                             sharex=True, sharey=True)
    for ax in axes:
        ax.set_axis_off()
    plt.rcParams['font.size'] = 34
    for ax, (title, panel_data, color) in zip(axes, panels):
        draw(ax, panel_data, color)
        ax.set_title(title, fontdict={'fontsize': 34})
    plt.tight_layout()
    plt.show()


triptych_panels = [('Source', image, 'r'), ('Reference', reference, 'g'), ('Matched', matched, 'b')]
# First the images themselves, then the flattened intensity histograms of the same arrays.
_show_triptych(triptych_panels, lambda ax, img, _color: ax.imshow(img))
_show_triptych(triptych_panels, lambda ax, img, color: sns.histplot(img.flatten(), ax=ax, color=color))

In [None]:
# Histogram-match the first oa790 frame of session 0 onto session 1's first frame
# and compare images and intensity histograms.
plt.rcParams['font.size'] = 35
non_validation_sessions = video_session.get_video_sessions(validation=False)
im1 = non_validation_sessions[0].frames_oa790[0]
im2 = non_validation_sessions[1].frames_oa790[0]
# Single-channel matching is the default, so the deprecated (and since removed)
# `multichannel=False` argument is unnecessary.
matched = match_histograms(im1, im2)

# Raw frames side by side.
plt.subplot(1, 2, 1)
plt.imshow(im1, cmap='gray')
plt.subplot(1, 2, 2)
plt.imshow(im2, cmap='gray')

# Left: source (red), reference (green) and matched (blue) histograms overlaid;
# right: the reference histogram on its own.
plt.figure()
plt.subplot(1, 2, 1)
sns.histplot(im1.flatten(), color='r')
sns.histplot(im2.flatten(), color='g')
sns.histplot(matched.flatten(), color='b')
plt.subplot(1, 2, 2)
sns.histplot(im2.flatten(), color='g')

# Source frame next to its histogram-matched version.
plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(im1, cmap='gray')
plt.subplot(1, 2, 2)
plt.imshow(matched, cmap='gray')
print(matched.dtype, im1.dtype)

In [None]:
import matplotlib.pyplot as plt

# Per-session oa790 intensity histograms: cell patches (left), non-cell patches (right).
colors = ['b', 'g', 'r', 'c', 'm', 'y', "tab:brown", 'tab:pink', "tab:olive"]
fig, ax = plt.subplots(1, 2, figsize=(30, 20))
for i, vs in enumerate(video_session.get_video_sessions(validation=False)):
    # Build the per-session extractor once and reuse it for both patch sets.
    session_extractor = patch_extractor.with_session(vs)
    values = session_extractor.cell_patches_oa790.flatten()
    neg_values = session_extractor.non_cell_patches_oa790.flatten()
    color = colors[i % len(colors)]  # cycle rather than IndexError past 9 sessions
    sns.histplot(values, ax=ax[0], color=color)
    sns.histplot(neg_values, ax=ax[1], color=color)

ax[0].set_title('Cell patches (oa790)')
ax[1].set_title('Non-cell patches (oa790)')
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Per-session oa790 intensity histograms: cell patches (left), non-cell patches (right).
colors = ['b', 'g', 'r', 'c', 'm', 'y', "tab:brown", 'tab:pink', "tab:olive"]
fig, ax = plt.subplots(1, 2, figsize=(30, 20))
for i, vs in enumerate(video_session.get_video_sessions(validation=False)):
    # Build the per-session extractor once and reuse it for both patch sets.
    session_extractor = patch_extractor.with_session(vs)
    values = session_extractor.cell_patches_oa790.flatten()
    neg_values = session_extractor.non_cell_patches_oa790.flatten()
    color = colors[i % len(colors)]  # cycle rather than IndexError past 9 sessions
    sns.histplot(values, ax=ax[0], color=color)
    sns.histplot(neg_values, ax=ax[1], color=color)

ax[0].set_title('Cell patches (oa790)')
ax[1].set_title('Non-cell patches (oa790)')
plt.show()

In [None]:
from image_processing import SessionPreprocessor

# Load the session list once and take both the session and the reference from it.
non_validation_sessions = video_session.get_video_sessions(validation=False)
vs = non_validation_sessions[0]

sp = SessionPreprocessor(vs)
# Histogram-matching reference: first oa790 frame of a different session.
reference = non_validation_sessions[1].frames_oa790[0]

# TODO: not yet run — histogram-match each frame to `reference`, then scale to [0, 1]:
# sp.with_session(vs).with_preprocess([lambda x: match_histograms(x, reference), lambda x: x / 255]).map()

In [None]:

# Overlay the intensity histograms of the first three oa790 frames of `vs`.
plt.figure()
ax = plt.gca()
colors = ['r', 'g', 'b']
for frame_idx, color in enumerate(colors):
    sns.histplot(vs.frames_oa790[frame_idx].flatten(), ax=ax, color=color)

In [None]:

# Same overlay of the first three frames' histograms, drawn on a fresh figure.
colors = ['r', 'g', 'b']
ax = plt.figure().gca()
for i in range(len(colors)):
    frame_values = vs.frames_oa790[i].flatten()
    sns.histplot(frame_values, ax=ax, color=colors[i])

In [None]:
# Display the first oa790 frame of `vs` in grayscale.
first_frame = vs.frames_oa790[0]
plt.imshow(first_frame, cmap='gray')

In [None]:
# Grayscale view of the first oa790 frame.
frame_cmap = 'gray'
plt.imshow(vs.frames_oa790[0], cmap=frame_cmap)

In [None]:
import matplotlib.pyplot as plt

# Per-session oa790 intensity histograms: cell patches (left), non-cell patches (right).
colors = ['b', 'g', 'r', 'c', 'm', 'y', "tab:brown", 'tab:pink', "tab:olive"]
fig, ax = plt.subplots(1, 2, figsize=(30, 20))
for i, vs in enumerate(video_session.get_video_sessions(validation=False)):
    # Build the per-session extractor once and reuse it for both patch sets.
    session_extractor = patch_extractor.with_session(vs)
    values = session_extractor.cell_patches_oa790.flatten()
    neg_values = session_extractor.non_cell_patches_oa790.flatten()
    color = colors[i % len(colors)]  # cycle rather than IndexError past 9 sessions
    sns.histplot(values, ax=ax[0], color=color)
    sns.histplot(neg_values, ax=ax[1], color=color)

ax[0].set_title('Cell patches (oa790)')
ax[1].set_title('Non-cell patches (oa790)')
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Same per-session histograms after rescaling each session's patch intensities to
# the reference `mean`/`std` with match_target_distribution.
colors = ['b', 'g', 'r', 'c', 'm', 'y', "tab:brown", 'tab:pink', "tab:olive"]
fig, ax = plt.subplots(1, 2, figsize=(30, 20))
for i, vs in enumerate(video_session.get_video_sessions(validation=False)):
    # Build the per-session extractor once and reuse it for both patch sets.
    session_extractor = patch_extractor.with_session(vs)
    values = session_extractor.cell_patches_oa790.flatten()
    neg_values = session_extractor.non_cell_patches_oa790.flatten()
    color = colors[i % len(colors)]  # cycle rather than IndexError past 9 sessions
    sns.histplot(match_target_distribution(values, mean, std), ax=ax[0], color=color)
    sns.histplot(match_target_distribution(neg_values, mean, std), ax=ax[1], color=color)

ax[0].set_title('Cell patches (oa790), normalized')
ax[1].set_title('Non-cell patches (oa790), normalized')
plt.show()

In [None]:
# Intensity distribution of the current extractor's oa790 cell patches.
# NOTE(review): this cell previously drew the identical plot twice; if a comparison
# was intended (e.g. with non_cell_patches_oa790), confirm and add it explicitly.
sns.displot(patch_extractor.cell_patches_oa790.flatten())

# Circle search negative patch extraction

In [None]:
%load_ext autoreload
%autoreload 2
from video_session import get_video_sessions
from patch_extraction import SessionPatchExtractor
from generate_datasets import create_cell_and_no_cell_patches
from patch_extraction import SessionPatchExtractor as PE
from plotutils import plot_images_as_grid

video_sessions = get_video_sessions(marked=True, registered=True, validation=True)
vs = video_sessions[0]

patch_extractor = SessionPatchExtractor(
    vs, 
    patch_size=21, 
    n_negatives_per_positive=32,
    limit_to_vessel_mask=False, 
    extraction_mode=PE.ALL_MODE)

print(patch_extractor.cell_patches_oa790.shape)
print(patch_extractor.non_cell_patches_oa790.shape)

plot_images_as_grid(patch_extractor.cell_patches_oa790[:10])
plot_images_as_grid(patch_extractor.marked_cell_patches_oa790[:10])

plot_images_as_grid(patch_extractor.non_cell_patches_oa790[:10])
plot_images_as_grid(patch_extractor.marked_non_cell_patches_oa790[:10])
patch_extractor.visualize_patch_extraction(linewidth=2, s=100, frame_idx=0, figsize=(20, 20))

#### Limit to vessel mask

In [None]:
%load_ext autoreload
%autoreload 2
from video_session import get_video_sessions
from patch_extraction import SessionPatchExtractor
from generate_datasets import create_cell_and_no_cell_patches
from patch_extraction import SessionPatchExtractor as PE
from plotutils import plot_images_as_grid

video_sessions = get_video_sessions(marked=True, registered=True, validation=True)
vs = video_sessions[0]

patch_extractor = SessionPatchExtractor(
    vs, 
    patch_size=21, 
    n_negatives_per_positive=32,
    limit_to_vessel_mask=True, 
    extraction_mode=PE.ALL_MODE)

print(patch_extractor.cell_patches_oa790.shape)
print(patch_extractor.non_cell_patches_oa790.shape)

plot_images_as_grid(patch_extractor.cell_patches_oa790[:10])
plot_images_as_grid(patch_extractor.marked_cell_patches_oa790[:10])

plot_images_as_grid(patch_extractor.non_cell_patches_oa790[:10])
plot_images_as_grid(patch_extractor.marked_non_cell_patches_oa790[:10])
patch_extractor.visualize_patch_extraction(linewidth=2, s=100, frame_idx=0, figsize=(20, 20))

# Temporal patches

Temporal patches also include the patches at the same positions in the previous and next frames of the oa790 channel.

All the spatial patch extraction methods described earlier apply here as well.

In [None]:
from video_session import get_video_sessions
from patch_extraction import SessionPatchExtractor
from plotutils import plot_images_as_grid

video_sessions = get_video_sessions(marked=True, validation=True)
vs = video_sessions[0]

patch_extractor = SessionPatchExtractor(vs, patch_size=21, temporal_width=1, n_negatives_per_positive=7)

plot_images_as_grid(patch_extractor.temporal_cell_patches_oa790[:10], title='Temporal cell patches temporal width 1')
plot_images_as_grid(patch_extractor.temporal_marked_cell_patches_oa790[:10])

plot_images_as_grid(patch_extractor.temporal_non_cell_patches_oa790[:10], title='Temporal non cell patches temporal width 1')
plot_images_as_grid(patch_extractor.temporal_marked_non_cell_patches_oa790[:10])

# A higher temporal width gives patches with more channels.
for temporal_width in (1, 4, 5, 6):
    patch_extractor.temporal_width = temporal_width
    print(f'Temporal patches shape with temporal width = {temporal_width}: {patch_extractor.temporal_cell_patches_oa790.shape}')
print("As the temporal window becomes bigger notice that there are fewer patches because we can't use some frames at the beginning and the end")

# Restore the width used for the grids above before visualizing the extraction.
patch_extractor.temporal_width = 1
patch_extractor.visualize_temporal_patch_extraction(figsize=(20, 20))

# Mixed channel patches
 
Mixed channel patches have 3 channels: the first channel is the confocal video patch, the second is from the oa790 channel,
and the third is from the oa850 channel.

The confocal video and the oa790 channel show the capillaries at the same position. The oa850 video is vertically displaced, so the oa850 frames are registered before the patches are extracted.

#### Registration process

Vessel masks are created for the 790nm and 850nm videos and then registered vertically by maximising Dice's coefficient, $D(A, B) = \frac{2|A \cap B|}{|A| + |B|}$,
a similarity measure commonly used to evaluate segmentation results.

In [None]:
from video_session import get_video_sessions
from patch_extraction import SessionPatchExtractor
from plotutils import plot_images_as_grid

# Second marked, registered session; plot the 790nm/850nm registration described above.
video_sessions = get_video_sessions(marked=True, registered=True)
vs = video_sessions[1]

# Trailing semicolon suppresses the call's return value in the cell output.
vs.visualize_registration();

In [None]:
patch_extractor = SessionPatchExtractor(vs, patch_size=21)

# First ten mixed-channel patches: cell (plain, marked), then non-cell (plain, marked).
n_shown = 10
for grid_attr in ('mixed_channel_cell_patches',
                  'mixed_channel_marked_cell_patches',
                  'mixed_channel_non_cell_patches'):
    plot_images_as_grid(getattr(patch_extractor, grid_attr)[:n_shown])
plot_images_as_grid(patch_extractor.mixed_channel_marked_non_cell_patches[:n_shown])

#### Showing the patches for each channel

The first image shows the patches extracted from the confocal video.
The second shows the patches extracted from the 790nm video.
The third shows the patches extracted from the 850nm video.

In [None]:
# One panel per channel: 0 = confocal, 1 = 790nm, 2 = 850nm (order as listed in
# the markdown above). The former 150x150-inch canvas rendered at ~15000 px per
# side (hundreds of MB); a 3:1 canvas matches the 1x3 layout at a sane size.
_, axes = plt.subplots(1, 3, figsize=(60, 20))
for channel, ax in enumerate(axes):
    patch_extractor.visualize_mixed_channel_patch_extraction(frame_idx=20, channel=channel, ax=ax)