In [None]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
from plotutils import plot_images_as_grid

from generate_datasets import create_cell_and_no_cell_patches, create_dataset_from_patches
from imageprosessing import hist_match_images, enhance_motion_contrast, normalize_data, gaussian_blur_stack
from imageprosessing import enhance_motion_contrast_de_castro, enhance_motion_contrast_j_tam, SessionPreprocessor
from sharedvariables import get_video_sessions
from video_session import VideoSession
from cnnlearning import CNN
from patchextraction import SessionPatchExtractor as PE
from learningutils import ImageDataset
from classificationutils import create_probability_map

from cnnlearning import TrainingTracker, train
import os
import collections

import scipy
import skimage
from skimage.morphology import binary_dilation as bd
from skimage.exposure import equalize_adapthist

import numpy as np
import torch
import cv2
import copy

from patchextraction import extract_patches, SessionPatchExtractor
from patchextraction import SessionPatchExtractor as PE
from imageprosessing import ImageRegistrator

# Every session that has a marked video and is registered; the cell displays their video paths.
# (To restrict to the shared videos, filter on `'shared-videos' in session.video_file`.)
video_sessions = get_video_sessions(marked=True, registered=True)
video_files = [session.video_file for session in video_sessions]
video_files

# Accessing files

Every file under any folder in ./data is parsed and put into dictionaries that group videos of the same source.
Videos are considered to come from the same source when they share the same subject, session, eye (OD/OS), (x, y) location,
and type (Confocal, OA790, OA850).

The parsing is case-insensitive and follows these rules:

**unmarked videos**:
must not contain 'mask' or '_marked'

**marked videos**:
must end with `_marked.<file_extension>`

**standard deviation images**:
must end with `_std.<file_extension>`

**vessel mask images**:
must end with `_vessel_mask.<file_extension>`

**channel type**:
must contain one of 'OA790', 'OA850', 'Confocal' (case-insensitive)

### Example of using `VideoSession` objects and the `get_video_sessions()` function
(useful for training and testing)

### Get all marked and registered videos for training

In [None]:
from sharedvariables import get_video_sessions
from os.path import basename

# Sanity-check every marked + registered session and print a short summary of each.
train_uids = []
video_sessions = get_video_sessions(marked=True, registered=True)
for session in video_sessions:
    # Each session must be marked, registered, annotated, and listed only once.
    assert session.has_marked_video
    assert session.is_registered
    assert session.has_marked_cells
    assert session.uid not in train_uids
    train_uids.append(session.uid)

    print('-----------------------')
    print('Video file:', basename(session.video_file))
    print('Uid', session.uid)
    print('Does video have a corresponding marked video?:', session.has_marked_video)
    print('Subject number:', session.subject_number)
    print('Session number:', session.session_number)
    print('Marked Video OA790:', basename(session.marked_video_oa790_file))
    print('Std dev image confocal:', basename(session.std_image_confocal_file))
    print('Std dev image OA850:', basename(session.std_image_oa850_file))
    print('Vessel mask OA850:', basename(session.vessel_mask_oa850_file))
    print('Vessel mask confocal:', basename(session.vessel_mask_confocal_file))
    print('Cell position csv files:', *[basename(f) for f in session.cell_position_csv_files], sep='\n')
    print()

print('Number of video sessions ', len(video_sessions))

In [None]:
from sharedvariables import get_video_sessions
from os.path import basename

# Validation sessions must be marked, registered, annotated, unique, and disjoint
# from the (non-validation) training sessions.
train_uids = [vs.uid for vs in get_video_sessions(marked=True, validation=False)]
valid_uids = []
video_sessions = get_video_sessions(marked=True, registered=True, validation=True)
for session in video_sessions:
    assert session.has_marked_video
    assert session.is_registered
    assert session.has_marked_cells
    assert session.uid not in valid_uids, f'Warning, Duplicate video session {session.uid}'
    assert session.uid not in train_uids, 'Warning, validation video exists in training videos as well'

    valid_uids.append(session.uid)
    print('-----------------------')
    print('Video file:', basename(session.video_file))
    print('Uid', session.uid)
    print('Does video have a corresponding marked video?:', session.has_marked_video)
    print('Subject number:', session.subject_number)
    print('Session number:', session.session_number)
    print('Marked Video OA790:', basename(session.marked_video_oa790_file))
    print('Std dev image confocal:', basename(session.std_image_confocal_file))
    print('Std dev image OA850:', basename(session.std_image_oa850_file))
    print('Vessel mask OA850:', basename(session.vessel_mask_oa850_file))
    print('Vessel mask confocal:', basename(session.vessel_mask_confocal_file))
    print('Cell position csv files:', *[basename(f) for f in session.cell_position_csv_files], sep='\n')
    print()

print('Number of video sessions ', len(video_sessions))

# Reading frames from videos - (and cell positions for each frame)

A video session gives access to the frames of its videos and, for marked sessions, to the annotated cell positions of each frame.

In [None]:
import matplotlib.pyplot as plt
from sharedvariables import get_video_sessions
from plotutils import no_ticks

from os.path import basename  # previously only imported in an earlier cell

video_sessions = get_video_sessions(marked=True, registered=True)
# Select the session explicitly. The cell used to rely on the `session` loop variable
# leaked from a previous cell, which breaks on Restart & Run All.
session = video_sessions[0]

_, axes = plt.subplots(1, 2, figsize=(20, 7))
no_ticks(axes)

# First OA790 frame with the annotated cell positions of frame 0 overlaid.
# NOTE(review): assumes frame 0 has annotated cell positions (cell_positions is indexed by frame) — confirm.
axes[0].imshow(session.frames_oa790[0], cmap='gray')
axes[0].set_title(f"First frame of {basename(session.video_oa790_file)}", fontsize=10)
axes[0].scatter(session.cell_positions[0][:, 0], session.cell_positions[0][:, 1], label='cell positions', s=10)
axes[0].legend()

# First frame of the corresponding marked video.
axes[1].imshow(session.marked_frames_oa790[0], cmap='gray')
axes[1].set_title(f"First marked frame of {basename(session.marked_video_oa790_file)}", fontsize=10)
plt.show()

# How to extract cell and no cell patches

# Using SessionPatchExtractor (Object oriented way) 

## Simple patch extraction

#### Circle search negative patch extraction

In [None]:
%load_ext autoreload
%autoreload 2
from sharedvariables import get_video_sessions
from patchextraction import SessionPatchExtractor
from generate_datasets import create_cell_and_no_cell_patches
from patchextraction import SessionPatchExtractor as PE
from plotutils import plot_images_as_grid

# Use the first validation session as the running example.
video_sessions = get_video_sessions(marked=True, registered=True, validation=True)
vs = video_sessions[0]

# Negative patches use the extractor's default search (the circle search, per the heading);
# 32 negatives are requested per cell patch, without the vessel-mask restriction,
# and ALL_MODE extracts both training and validation patches.
patch_extractor = SessionPatchExtractor(
    vs, 
    patch_size=21, 
    n_negatives_per_positive=32,
    use_vessel_mask=False, 
    extraction_mode=PE.ALL_MODE)

print(patch_extractor.cell_patches_oa790.shape)
print(patch_extractor.non_cell_patches_oa790.shape)

# First ten cell patches, then the same patches taken from the marked video.
plot_images_as_grid(patch_extractor.cell_patches_oa790[:10])
plot_images_as_grid(patch_extractor.marked_cell_patches_oa790[:10])

# First ten non-cell patches, then their marked-video counterparts.
plot_images_as_grid(patch_extractor.non_cell_patches_oa790[:10])
plot_images_as_grid(patch_extractor.marked_non_cell_patches_oa790[:10])
# Show where the patches were taken from on frame 0.
patch_extractor.visualize_patch_extraction(linewidth=2, s=100, frame_idx=0)

#### Rectangle search negative patch extraction

In [None]:
%load_ext autoreload
%autoreload 2
from sharedvariables import get_video_sessions
from patchextraction import SessionPatchExtractor
from generate_datasets import create_cell_and_no_cell_patches
from patchextraction import SessionPatchExtractor as PE
from plotutils import plot_images_as_grid

# Same session as the circle-search example.
video_sessions = get_video_sessions(marked=True, registered=True, validation=True)
vs = video_sessions[0]

# Same settings as above, but negatives are searched for with the RECTANGLE mode.
# NOTE(review): 33 is presumably the size of the search area around each cell, in pixels — confirm units.
patch_extractor = SessionPatchExtractor(
    vs, 
    patch_size=21, 
    n_negatives_per_positive=32,
    use_vessel_mask=False, 
    negative_extraction_mode=PE.RECTANGLE,
    negative_patch_extraction_radius=33,
    extraction_mode=PE.ALL_MODE)

print(patch_extractor.cell_patches_oa790.shape)
print(patch_extractor.non_cell_patches_oa790.shape)

plot_images_as_grid(patch_extractor.cell_patches_oa790[:10])
plot_images_as_grid(patch_extractor.marked_cell_patches_oa790[:10])

plot_images_as_grid(patch_extractor.non_cell_patches_oa790[:10])
plot_images_as_grid(patch_extractor.marked_non_cell_patches_oa790[:10])
patch_extractor.visualize_patch_extraction(linewidth=2, s=100, frame_idx=0)

### Restricting negatives within vessel mask

In [None]:
# Reuses `vs` from the previous cell. With use_vessel_mask=True the negative (non-cell)
# patches are restricted to the vessel mask (see the section heading).
patch_extractor = SessionPatchExtractor(
    vs, 
    patch_size=21, 
    n_negatives_per_positive=32,
    use_vessel_mask=True)

plot_images_as_grid(patch_extractor.cell_patches_oa790[:10])
plot_images_as_grid(patch_extractor.marked_cell_patches_oa790[:10])

plot_images_as_grid(patch_extractor.non_cell_patches_oa790[:10])
plot_images_as_grid(patch_extractor.marked_non_cell_patches_oa790[:10])
patch_extractor.visualize_patch_extraction(linewidth=2, s=100, frame_idx=0)

## Temporal patches

Temporal patches also include the patches at the same positions in the previous and next frames.

In [None]:
from sharedvariables import get_video_sessions
from patchextraction import SessionPatchExtractor
from plotutils import plot_images_as_grid

# Temporal patches of the first validation session, with a temporal width of 1.
video_sessions = get_video_sessions(marked=True, validation=True)
vs = video_sessions[0]

patch_extractor = SessionPatchExtractor(vs, patch_size=21, temporal_width=1, n_negatives_per_positive=7)

plot_images_as_grid(patch_extractor.temporal_cell_patches_oa790[:10], title='Temporal cell patches temporal width 1')
plot_images_as_grid(patch_extractor.temporal_marked_cell_patches_oa790[:10])

plot_images_as_grid(patch_extractor.temporal_non_cell_patches_oa790[:10], title='Temporal non cell patches temporal width 1')
plot_images_as_grid(patch_extractor.temporal_marked_non_cell_patches_oa790[:10])

# A higher temporal width gives patches with more channels.
for width in (1, 4, 5, 6):
    patch_extractor.temporal_width = width
    print(f'Temporal patches shape with temporal width = {width}: {patch_extractor.temporal_cell_patches_oa790.shape}')
print('As temporal window becomes bigger notice that there are less patches.')

# Restore width 1 before visualising the extraction.
patch_extractor.temporal_width = 1
patch_extractor.visualize_temporal_patch_extraction()

## Mixed channel patches
 
Mixed channel patches have 3 channels: the first channel comes from the confocal video, the second from the OA790 channel,
and the third from the OA850 channel.

The confocal video and the OA790 channel have the capillaries at the same position. The OA850 video is vertically displaced, so it is registered before the patches are extracted.

#### Registration process

The vessel masks of the 790nm and 850nm videos are created and then registered vertically by maximising Dice's coefficient,
a similarity measure commonly used to evaluate segmentation results.

In [None]:
from sharedvariables import get_video_sessions
from patchextraction import SessionPatchExtractor
from plotutils import plot_images_as_grid

# Visualise the vertical vessel-mask registration of the second registered session.
video_sessions = get_video_sessions(marked=True, registered=True)
vs = video_sessions[1]
vs.visualize_registration();

In [None]:
# Mixed channel patches of the session registered above (channels: confocal, OA790, OA850).
patch_extractor = SessionPatchExtractor(vs, patch_size=21)

plot_images_as_grid(patch_extractor.mixed_channel_cell_patches[:10])
plot_images_as_grid(patch_extractor.mixed_channel_marked_cell_patches[:10])

plot_images_as_grid(patch_extractor.mixed_channel_non_cell_patches[:10])
plot_images_as_grid(patch_extractor.mixed_channel_marked_non_cell_patches[:10])

#### Showing the patches for each channel

The first image is the patches extracted from the confocal video.
Second is the patches extracted from the 790nm video.
Third is the patches extracted from the 850nm video.

In [None]:
# One panel per channel: 0 = confocal, 1 = OA790, 2 = OA850.
# The previous figsize=(150, 150) asked matplotlib for a 150 x 150 inch canvas,
# which is very slow to render and memory-hungry; a 3:1 figure fits a 1x3 layout.
_, axes = plt.subplots(1, 3, figsize=(30, 10))
for channel, ax in enumerate(axes):
    patch_extractor.visualize_mixed_channel_patch_extraction(frame_idx=20, channel=channel, ax=ax)

## Changing Patch Extractor mode - Training, Validation, and All patches

Each session has one frame assigned for validation.
Usually this is the frame with the most cell positions, but it can be changed with `vs.validation_frame_idx = index`.

The patch extractor has modes for extracting only the training or only the validation patches; the default mode extracts all the patches.

In [None]:
%load_ext autoreload
%autoreload 2
from sharedvariables import get_video_sessions
from patchextraction import SessionPatchExtractor as PE
from plotutils import plot_images_as_grid, no_ticks
import matplotlib.pyplot as plt

# NOTE(review): other cells call get_video_sessions(marked=..., registered=...); confirm that
# `should_have_marked_cells` / `should_be_registered` are still accepted keyword names.
video_sessions = get_video_sessions(should_have_marked_cells=True)
vs = video_sessions[2]

# VALIDATION_MODE only extracts patches from the session's validation frame.
patch_extractor = PE(vs, patch_size=21, extraction_mode=PE.VALIDATION_MODE)

# Show the validation frame with its annotated cell positions.
_, ax = plt.subplots()
print('Validation frame index', vs.validation_frame_idx)
ax.imshow(vs.frames_oa790[vs.validation_frame_idx], cmap='gray')
ax.set_title(f'Validation frame {vs.validation_frame_idx}')
cell_positions = vs.cell_positions[vs.validation_frame_idx]
ax.scatter(cell_positions[:, 0], cell_positions[:, 1])
# NOTE(review): called without axes here but with `axes` in an earlier cell — confirm no_ticks() has a default.
no_ticks()

plot_images_as_grid(patch_extractor.cell_patches_oa790, title=f'Validation patches. {patch_extractor.cell_patches_oa790.shape}')
patch_extractor.visualize_patch_extraction()

In [None]:
# Switching the mode on the same extractor changes which patches its properties return.
patch_extractor.extraction_mode = PE.TRAINING_MODE

plot_images_as_grid(patch_extractor.cell_patches_oa790[:10], title=f'Training patches {patch_extractor.cell_patches_oa790.shape}')
patch_extractor.visualize_patch_extraction()

# Back to the default mode (all patches) to compare the patch count.
patch_extractor.extraction_mode = PE.ALL_MODE
print('All mode patches:', patch_extractor.cell_patches_oa790.shape)

# Create dataset convenience functions

Extract all patches (marked and unmarked) for all the given video sessions.

If no video sessions are given, all the video sessions created automatically from the videos that have a cell-position CSV in the data folder are used.

## Normal patches

In [None]:
from generate_datasets import create_cell_and_no_cell_patches, create_dataset_from_cell_and_no_cell_images
from imageprosessing import hist_match_images
from sharedvariables import get_video_sessions

reg_video_sessions = get_video_sessions(should_have_marked_cells=True, should_be_registered=True)

# One non-cell patch per cell patch, from every registered session with marked cells.
(cell_images, non_cell_images,
 cell_images_marked, non_cell_images_marked) = create_cell_and_no_cell_patches(
    video_sessions=reg_video_sessions,
    n_negatives_per_positive=1,
    v=True,
    vv=False,
)

# First ten cell patches and the same patches from the marked videos.
plot_images_as_grid(cell_images[:10])
plot_images_as_grid(cell_images_marked[:10])

## Temporal patches

In [None]:
from generate_datasets import create_cell_and_no_cell_patches, create_dataset_from_cell_and_no_cell_images
from imageprosessing import hist_match_images
from sharedvariables import get_video_sessions

reg_video_sessions = get_video_sessions(should_have_marked_cells=True, should_be_registered=True)

# Temporal patches (temporal width 1) for every registered session with marked cells.
(cell_images, non_cell_images,
 cell_images_marked, non_cell_images_marked) = create_cell_and_no_cell_patches(
    video_sessions=reg_video_sessions,
    temporal_width=1,
)

plot_images_as_grid(cell_images[:10], title='Temporal width 1')
plot_images_as_grid(cell_images_marked[:10])

## Mixed channel patches

In [None]:
from generate_datasets import create_cell_and_no_cell_patches, create_dataset_from_cell_and_no_cell_images
from imageprosessing import hist_match_images
from sharedvariables import get_video_sessions
from plotutils import plot_images_as_grid

reg_video_sessions = get_video_sessions(should_have_marked_cells=True, should_be_registered=True)

# Mixed channel patches (confocal, OA790, OA850) for every registered session with marked cells.
(cell_images, non_cell_images,
 cell_images_marked, non_cell_images_marked) = create_cell_and_no_cell_patches(
    video_sessions=reg_video_sessions,
    mixed_channel_patches=True,
)

plot_images_as_grid(cell_images[:10], title='Mixed channel patches')
plot_images_as_grid(cell_images_marked[:10])

## Different Modes - Training, Validation, All

In [None]:
from generate_datasets import create_cell_and_no_cell_patches, create_dataset_from_cell_and_no_cell_images
from imageprosessing import hist_match_images
from sharedvariables import get_video_sessions
from patchextraction import SessionPatchExtractor as SE

video_sessions = get_video_sessions(should_have_marked_cells=True, should_be_registered=True)

# ALL_MODE extracts both training and validation patches.
# Fix: the keyword is `extraction_mode`, as used with SessionPatchExtractor and in the
# validation-mode cell. `extraction_method` is not that keyword, so it either raised a
# TypeError or was silently ignored.
(cell_images, non_cell_images,
 cell_images_marked, non_cell_images_marked) = create_cell_and_no_cell_patches(
    video_sessions=video_sessions,
    n_negatives_per_positive=1,
    extraction_mode=SE.ALL_MODE,  # Default mode
    v=True,
    vv=False,
)

plot_images_as_grid(cell_images[:10], title=f'All mode patches. {cell_images.shape}')
plot_images_as_grid(cell_images_marked[:10])

In [None]:
from generate_datasets import create_cell_and_no_cell_patches, create_dataset_from_cell_and_no_cell_images
from imageprosessing import hist_match_images
from sharedvariables import get_video_sessions
from patchextraction import SessionPatchExtractor as SE

video_sessions = get_video_sessions(marked=True, registered=True, validation=True)

# Validation-mode patches: only each session's validation frame is used, and
# negatives are restricted to the vessel mask (32 per cell patch).
(cell_images, non_cell_images,
 cell_images_marked, non_cell_images_marked) = create_cell_and_no_cell_patches(
    video_sessions=video_sessions,
    use_vessel_mask=True,
    patch_size=21,
    n_negatives_per_positive=32,
    extraction_mode=SE.VALIDATION_MODE,
    v=True,
    vv=False,
)

plot_images_as_grid(cell_images[:10], title=f'Validation mode patches. {cell_images.shape}')
plot_images_as_grid(cell_images_marked[:10])

In [None]:
# Total number of annotated cell positions over all annotated frames of the `video_sessions`
# from the previous cell, shown next to the number of extracted cell / non-cell patches.
summage = 0
for vs in video_sessions:
    # cell_positions maps frame index -> array of positions; concatenate and count the rows.
    summage += len(np.concatenate([vs.cell_positions[idx] for idx in vs.cell_positions], axis=0))
summage, len(cell_images), len(non_cell_images)

In [None]:
# NOTE(review): hard-coded numbers, presumably copied from an earlier output (e.g. a patch
# count divided by an annotation count). Compute them from variables instead — TODO confirm.
5674 / 623

In [None]:
from generate_datasets import create_cell_and_no_cell_patches, create_dataset_from_cell_and_no_cell_images
from imageprosessing import hist_match_images
from sharedvariables import get_video_sessions
from patchextraction import SessionPatchExtractor as SE

video_sessions = get_video_sessions(should_have_marked_cells=True, should_be_registered=True)

# TRAINING_MODE presumably excludes each session's validation frame.
# Fix: the keyword is `extraction_mode`, as used with SessionPatchExtractor and in the
# validation-mode cell. `extraction_method` is not that keyword, so the requested mode
# either raised a TypeError or was silently ignored.
(cell_images, non_cell_images,
 cell_images_marked, non_cell_images_marked) = create_cell_and_no_cell_patches(
    video_sessions=video_sessions,
    n_negatives_per_positive=1,
    extraction_mode=SE.TRAINING_MODE,
    v=True,
    vv=False,
)

plot_images_as_grid(cell_images[:10], title=f'Training mode patches. {cell_images.shape}')
plot_images_as_grid(cell_images_marked[:10])

# Training Experiments

## Histogram matching

In [None]:
from generate_datasets import create_cell_and_no_cell_patches, create_dataset_from_cell_and_no_cell_images
from imageprosessing import hist_match_images
from sharedvariables import get_video_sessions
from plotutils import plot_images_as_grid

reg_video_sessions = get_video_sessions(should_have_marked_cells=True, should_be_registered=True)

# Default patch extraction for every registered session with marked cells.
(cell_images, non_cell_images,
 cell_images_marked, non_cell_images_marked) = create_cell_and_no_cell_patches(video_sessions=reg_video_sessions)

In [None]:
from imageprosessing import hist_match_images

def get_highest_contrast_frame(video_sessions):
    """Return the highest-contrast first masked OA790 frame among the given sessions.

    Each session's OA790 mask is first replaced by ``crop_mask(mask, 15)``. Contrast is then
    measured as ``max - min`` of the session's first masked OA790 frame. The frame of the
    session with the largest range is returned, with its masked pixels filled by the
    frame's mean, for use as a histogram-matching template.

    NOTE(review): ``crop_mask`` is not imported anywhere in this notebook; it must come from
    a hidden/earlier kernel state. Add the import explicitly.
    NOTE(review): this assigns ``vs.mask_frames_oa790`` on every session, so calling it again
    (e.g. re-running the cell) crops the already-cropped mask. Confirm crop_mask is idempotent.

    Args:
        video_sessions: sequence of video sessions; must be non-empty (index 0 is used).

    Returns:
        The template frame with masked pixels filled.
    """
    max_diff = 0
    max_diff_idx = 0
    for i, vs in enumerate(video_sessions):
        # Mutates the session in place (see note above).
        vs.mask_frames_oa790 = crop_mask(vs.mask_frames_oa790, 15)
        # Intensity range of the first masked OA790 frame.
        diff = vs.masked_frames_oa790[0].max() - vs.masked_frames_oa790[0].min()
        # Strict '>' keeps the earliest session on ties; index 0 is used if no range exceeds 0.
        if diff > max_diff:
            max_diff_idx = i
            max_diff = diff
            
    highest_contrast_frame = video_sessions[max_diff_idx].masked_frames_oa790[0]
    # Presumably a numpy masked array: .filled() replaces masked pixels with the mean
    # of the unmasked ones and returns a plain array.
    highest_contrast_frame = highest_contrast_frame.filled(highest_contrast_frame.mean())
    
    return highest_contrast_frame

# Histogram-match all cell and non-cell patches to one template: the highest-contrast frame.
template_frame = get_highest_contrast_frame(reg_video_sessions)

hist_matched_cell_images, hist_matched_non_cell_images = [
    hist_match_images(patches, template_frame) for patches in (cell_images, non_cell_images)
]

In [None]:
# First ten histogram-matched cell and non-cell patches.
plot_images_as_grid(hist_matched_cell_images[:10])
plot_images_as_grid(hist_matched_non_cell_images[:10])

In [None]:
import collections
import torch
from cnnlearning import CNN, train, TrainingTracker 

# Single standardization switch: it is used both to build the datasets here and when the
# evaluation cell classifies raw patches. Previously a literal True was passed to the
# dataset builder, so flipping this flag desynchronised training and evaluation.
standardize_dataset = True
trainset, validset = create_dataset_from_cell_and_no_cell_images(hist_matched_cell_images,
                                                                 hist_matched_non_cell_images,
                                                                 validset_ratio=0.2,
                                                                 standardize=standardize_dataset)

# Two output classes: non cell (0) and cell (1).
model = CNN(dataset_sample=trainset, output_classes=2).to('cuda')
train_params = collections.OrderedDict(
    optimizer=torch.optim.Adam(model.parameters(), lr=.001, weight_decay=0.01),
    batch_size=256,
    do_early_stop=True,  # Optional default True
    early_stop_patience=40,
    learning_rate_scheduler_patience=20,
    epochs=200,
    shuffle=True,
    evaluation_epochs=5,
    trainset=trainset,
    validset=validset,
)
results: TrainingTracker = train(model,
                                 train_params,
                                 criterion=torch.nn.CrossEntropyLoss(),
                                 device='cuda')

In [None]:
from classificationutils import classify_labeled_dataset, classify_images

def _evaluate_and_report(eval_model, title, epoch, header=None):
    """Evaluate a recorded model and print a short accuracy summary.

    Accuracies are computed on `trainset`, `validset`, and on all histogram-matched
    cell / non-cell patches (the same preprocessing the model was trained on).

    Args:
        eval_model: model to evaluate; it is switched to eval mode.
        title: heading printed above the summary.
        epoch: epoch at which `eval_model` was recorded.
        header: optional line printed (followed by a blank line) before `title`.

    Returns:
        (train_accuracy, valid_accuracy, positive_accuracy, negative_accuracy)
    """
    eval_model.eval()

    _, train_acc = classify_labeled_dataset(trainset, eval_model)
    _, valid_acc = classify_labeled_dataset(validset, eval_model)
    # The model was trained on histogram-matched patches, so evaluate on those as well.
    # Evaluating on the raw `cell_images` / `non_cell_images` fed it a different distribution.
    # NOTE(review): assumes classify_images returns 1 for "cell" and 0 for "non cell" — confirm.
    pos_acc = classify_images(hist_matched_cell_images, eval_model,
                              standardize_dataset=standardize_dataset).sum().item() / len(hist_matched_cell_images)
    neg_acc = (1 - classify_images(hist_matched_non_cell_images, eval_model,
                                   standardize_dataset=standardize_dataset)).sum().item() / len(hist_matched_non_cell_images)

    print()
    if header is not None:
        print(header)
        print()
    print(title)
    print('----------------')
    print('Epoch:\t', epoch)
    print('Training accuracy:\t', f'{train_acc:.3f}')
    print('Validation accuracy:\t', f'{valid_acc:.3f}')
    print()
    print('Positive accuracy:\t', f'{pos_acc:.3f}')
    print('Negative accuracy:\t', f'{neg_acc:.3f}')

    return train_acc, valid_acc, pos_acc, neg_acc


# Model snapshot with the best validation accuracy.
model = results.recorded_model
train_accuracy, valid_accuracy, positive_accuracy, negative_accuracy = _evaluate_and_report(
    model,
    'Brief evaluation - best validation accuracy model',
    results.recorded_model_epoch,
    header=f'Model trained on {len(cell_images)} cell patches and {len(non_cell_images)} non cell patches.',
)

# Model snapshot with the best training accuracy.
train_model = results.recorded_train_model
train_accuracy, valid_accuracy, positive_accuracy, negative_accuracy = _evaluate_and_report(
    train_model,
    'Brief evaluation - best training accuracy model',
    results.recorded_train_model_epoch,
)

## Motion contrast enhanced

In [None]:
%load_ext autoreload
%autoreload 2
from plotutils import plot_images_as_grid
from generate_datasets import create_cell_and_no_cell_patches, create_dataset_from_cell_and_no_cell_images
from imageprosessing import hist_match_images
from sharedvariables import get_video_sessions
from video_session import SessionPreprocessor
from imageprosessing import enhance_motion_contrast
import numpy as np

video_sessions = get_video_sessions(should_have_marked_cells=True, should_be_registered=True)[:3]


def _motion_contrast(frames):
    """Motion-contrast enhancement without adaptive histogram equalisation, normalised, sigma 0.75."""
    return enhance_motion_contrast(frames, adapt_hist=False, normalize=True, sigma=0.75)


def _to_uint8(frames):
    """Scale by 255 and cast to uint8 (presumably frames are in [0, 1] after normalisation)."""
    return np.uint8(frames * 255)


# Register both steps with a preprocessor for each of the first three sessions and apply them.
for vs in video_sessions:
    pr = SessionPreprocessor(vs, [_motion_contrast, _to_uint8])
    pr.apply_preprocessing()

In [None]:
# 29 px patches from the motion-contrast-enhanced sessions, two negatives per positive;
# the marked-video patches are discarded.
cell_images, non_cell_images, _, _ = create_cell_and_no_cell_patches(
    video_sessions=video_sessions,
    n_negatives_per_positive=2,
    patch_size=29,
    v=True,
    vv=True
)

# NOTE(review): plots every patch, not just the first few as other cells do — may be slow for large sets.
plot_images_as_grid(cell_images)
plot_images_as_grid(non_cell_images)

In [None]:
import collections
import torch
from cnnlearning import CNN, train, TrainingTracker 
from generate_datasets import create_dataset_from_cell_and_no_cell_images
import torchvision
import PIL


# Standardized datasets with data augmentation. The patches were extracted at 29 px but the
# dataset uses patch_size=21; presumably the margin is there so augmentation (e.g. translation)
# happens before cropping to 21 px — TODO confirm.
trainset, validset = create_dataset_from_cell_and_no_cell_images(cell_images, 
                                                                 non_cell_images, 
                                                                 standardize=True,
                                                                 
                                                                 apply_data_augmentation_transformations=True,
#                                                                  translation_pixels=2,
                                                                 patch_size=21)

print('Trainset size', len(trainset), 'Validset size', len(validset))

In [None]:
import matplotlib.pyplot as plt
# batch_size == len(validset): the loader yields the whole validation set as a single batch.
loader = torch.utils.data.DataLoader(validset, batch_size=len(validset))
for img, lbl in loader:
    idx = 0
    sample = img[idx]
    print('Dataset min, max', img.min().item(), img.max().item())
    print('Dataset shape', img.shape)
    plt.imshow(sample.squeeze(), cmap='gray')
    plt.suptitle(f'Frame index {idx}')
    print('Image min max', sample.min().item(), sample.max().item())
    print('label', lbl[idx].item())
    break

In [None]:
# Two-class CNN (non cell / cell) trained on the motion-contrast dataset with an unweighted loss.
model = CNN(dataset_sample=trainset, output_classes=2).to('cuda')
model.train()
train_params = collections.OrderedDict(
    optimizer=torch.optim.Adam(model.parameters(), lr=.001, weight_decay=.001),
    
    batch_size=512,
    do_early_stop=True,  # Optional default True
    early_stop_patience=30,
    learning_rate_scheduler_patience=10,
    epochs=250,
    shuffle=True,
    evaluation_epochs=5,
    
    trainset=trainset,
    validset=validset,
)

results: TrainingTracker = train(model,
                                 train_params,
                                 criterion=torch.nn.CrossEntropyLoss(),
#                                  criterion=torch.nn.CrossEntropyLoss(torch.tensor([.2, 1.]).cuda()),
                                 device='cuda')

## Single channel training

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from sharedvariables import get_video_sessions
from plotutils import plot_images_as_grid
from generate_datasets import create_cell_and_no_cell_patches

# First three registered sessions with marked cells.
# Fix: removed a stray trailing `\` line continuation that joined this line with the next one;
# it only worked because the next line happened to be blank.
video_sessions = get_video_sessions(should_be_registered=True, should_have_marked_cells=True)[:3]

# Single-channel 21 px patches without temporal context, one negative per positive.
cell_images, non_cell_images, cell_images_marked, non_cell_images_marked = \
    create_cell_and_no_cell_patches(patch_size=21,
                                    temporal_width=0,

                                    mixed_channel_patches=False,

                                    video_sessions=video_sessions,
                                    n_negatives_per_positive=1,
                                    v=False, vv=False)

plot_images_as_grid(cell_images[:10])
plot_images_as_grid(cell_images_marked[:10])
plot_images_as_grid(non_cell_images[:10])
plot_images_as_grid(non_cell_images_marked[:10])
cell_images.shape, non_cell_images.shape

In [None]:
import collections
import torch
from cnnlearning import CNN, train, TrainingTracker 
from generate_datasets import create_dataset_from_cell_and_no_cell_images
import torchvision
import PIL

# From the mixed channels we pick only the oa780 and oa850 channels
trainset, validset = create_dataset_from_cell_and_no_cell_images(cell_images, 
                                                                 non_cell_images, 
                                                                 standardize=True,
                                                                 
#                                                                  apply_data_augmentation_transformations=True,
#                                                                  translation_pixels=2,
                                                                 patch_size=21,
                                                                )
# positive_dataset = LabeledImageDataset(cell_images_marked,     np.ones(len(cell_images), dtype=np.int))
# negative_dataset = LabeledImageDataset(non_cell_images_marked, np.zeros(len(non_cell_images), dtype=np.int))


print('Trainset size', len(trainset), 'Validset size', len(validset))

In [None]:
import matplotlib.pyplot as plt
# Inspect one validation sample. batch_size == len(validset), so the loader yields a single batch;
# `break` and `.item()` make this consistent with the earlier inspection cell (plain numbers, one pass).
loader = torch.utils.data.DataLoader(validset, batch_size=len(validset))
for img, lbl in loader:
    print('Dataset min, max', img.min().item(), img.max().item())
    print('Dataset shape', img.shape)
    idx = -5  # fifth sample from the end
    plt.imshow(img[idx].squeeze(), cmap='gray')
    print('Image min max', img[idx].min().item(), img[idx].max().item())
    print('label', lbl[idx].item())
    break

In [None]:
# Two-class CNN trained on the single-channel dataset.
model = CNN(dataset_sample=trainset, output_classes=2).to('cuda')
model.train()
train_params = collections.OrderedDict(
    optimizer=torch.optim.Adam(model.parameters(), lr=.001, weight_decay=.001),
    
    batch_size=512,
    do_early_stop=True,  # Optional default True
    early_stop_patience=30,
    learning_rate_scheduler_patience=10,
    epochs=250,
    shuffle=True,
    evaluation_epochs=5,
    
    trainset=trainset,
    validset=validset,
)

# Class-weighted loss: class 0 (non cell) is weighted 0.2 and class 1 (cell) 1.0,
# i.e. errors on non-cell patches count less.
results: TrainingTracker = train(model,
                                 train_params,
                                 criterion=torch.nn.CrossEntropyLoss(torch.tensor([.2, 1.]).cuda()),
                                 device='cuda')

In [None]:
# Display the model's module structure (torch module repr).
model

In [None]:
from classificationutils import classify_labeled_dataset, classify_images
from learningutils import LabeledImageDataset, ImageDataset
from imageprosessing import center_crop_images
import numpy as np
# Model snapshot with the best validation accuracy.
model = results.recorded_model
model.eval()

# Labeled datasets of every cell (label 1) / non-cell (label 0) patch, center-cropped to 21 px.
# Fix: `np.int` was deprecated in NumPy 1.20 and removed in 1.24 (AttributeError).
# It was an alias of the builtin `int`, so `dtype=int` keeps the same dtype.
# NOTE(review): these two datasets are not used below — evaluate on them or drop them.
positive_dataset = LabeledImageDataset(center_crop_images(cell_images, 21), np.ones(len(cell_images), dtype=int))
negative_dataset = LabeledImageDataset(center_crop_images(non_cell_images, 21), np.zeros(len(non_cell_images), dtype=int))

# The positive / negative accuracies reported below are measured on the training set only.
_, train_accuracy, positive_accuracy, negative_accuracy = classify_labeled_dataset(
    trainset, model, ret_pos_and_neg_acc=True)
_, valid_accuracy = classify_labeled_dataset(validset, model)

print()
print(f'Model trained on {len(cell_images)} cell patches and {len(non_cell_images)} non cell patches.')
print()
print('Brief evaluation - best validation accuracy model')
print('----------------')
print('Epoch:\t', results.recorded_model_epoch)
print('Training accuracy:\t', f'{train_accuracy:.3f}')
print('Validation accuracy:\t', f'{valid_accuracy:.3f}')
print()
print('Positive accuracy:\t', f'{positive_accuracy:.3f}')
print('Negative accuracy:\t', f'{negative_accuracy:.3f}')


# Training Convenience Function

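`train_model_demo` creates the dataset and trains a model with the given parameters. It can instead load the data and/or a trained model from the cache (see the `try_load_data_from_cache` and `try_load_model_from_cache` arguments).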

In [None]:
from train_model import train_model_demo
from sharedvariables import get_video_sessions
import collections
%load_ext autoreload
%autoreload 2
    
video_sessions = get_video_sessions(should_be_registered=True, should_have_marked_cells=True)

# Unlike the manual training cell, lr / weight_decay are plain numbers here; the optimizer is
# presumably constructed inside train_model_demo (TODO confirm).
train_params = collections.OrderedDict(
    lr=.001,
    weight_decay=.001,
    
    batch_size=512,
    do_early_stop=True,  # Optional default True
    
    early_stop_patience=30,
    learning_rate_scheduler_patience=10,
    
    n_negatives_per_positive=5,
    
    epochs=250,
    shuffle=True,
    evaluation_epochs=5,
)

model, results = train_model_demo(
        video_sessions=video_sessions[:2], # The video sessions the data will be created from

        patch_size=27,
        temporal_width=0,
    
        mixed_channel_patches=True,
        drop_confocal_channel=True,
    
        do_hist_match=False,
        n_negatives_per_positive=5,

        standardize_dataset=True,
        apply_data_augmentation_to_dataset=False,
        valid_ratio=0.2,
    
        try_load_data_from_cache=False, # If true, attempts to load data from cache. If false, creates new data (overwriting old data if it exists)
        try_load_model_from_cache=False, # If true, attempts to load the model from cache. If false, creates a new one
    
        train_params=train_params, # The training parameters; defaults are used if not specified
        additional_displays=None,
        v=True, vv=True,
)

In [None]:
import copy
# Independent copy of the trained model, so later in-place changes to `model` do not affect it.
model_tmp = copy.deepcopy(model)

## Create probability map from mixed channel patches

In [None]:
from cnnlearning import TrainingTracker
import os
from sharedvariables import CACHED_MODELS_FOLDER, get_video_sessions
import matplotlib.pyplot as plt

video_sessions = get_video_sessions(should_be_registered=True, should_have_marked_cells=True)
vs = video_sessions[0]


# Side by side: confocal vessel mask, OA850 vessel mask, and the OA850 mask after registration.
plt.subplot(131)
plt.imshow(vs.vessel_mask_confocal)
plt.subplot(132)
plt.imshow(vs.vessel_mask_oa850)
plt.subplot(133)
plt.imshow(vs.registered_vessel_mask_oa850)

In [None]:
import scipy
import skimage
from skimage.morphology import binary_dilation as bd
# `scipy.ndimage.morphology` is a deprecated namespace; the function lives in scipy.ndimage.
from scipy import ndimage

# Dilate the registered OA850 vessel mask a few times to widen the region we extract patches from.
n_dilations = 2
final_image = vs.registered_vessel_mask_oa850
for _ in range(n_dilations):
    final_image = bd(final_image)
plt.imshow(final_image)

# Sanity check: skimage and scipy both default to a cross-shaped (connectivity 1) structuring
# element, so N skimage dilations should match scipy with iterations=N. The previous check
# compared against iterations=15 and produced an element-wise array instead of a single bool.
np.array_equal(ndimage.binary_dilation(vs.registered_vessel_mask_oa850, iterations=n_dilations), final_image)

In [None]:
from patchextraction import extract_patches
from imageprosessing import ImageRegistator
import matplotlib.pyplot as plt
import numpy as np
import cv2

video_sessions = get_video_sessions(should_be_registered=True, should_have_marked_cells=True)
vs = video_sessions[0]
padding = cv2.BORDER_REPLICATE

# Keep pixels inside the dilated vessel mask AND inside the valid region of the registered OA850 frame.
# (`np.bool8` is a deprecated alias removed in NumPy 2.0; astype(bool) is the portable spelling.)
mask = np.asarray(final_image).astype(bool) & np.asarray(vs.registered_mask_frames_oa850[0]).astype(bool)
# Flat (row-major) indices of the selected pixels; used to pick the matching patches.
vessel_pixel_indices = np.flatnonzero(mask)

patches_oa790 = extract_patches(vs.frames_oa790[0], patch_size=27, padding=padding)[vessel_pixel_indices]
patches_oa850 = extract_patches(vs.registered_frames_oa850[0], patch_size=27, padding=padding)[vessel_pixel_indices]

# Stack the two single-channel patch sets into 2-channel patches (channel 0: OA790, channel 1: OA850).
# Indexing the trailing channel axis explicitly (instead of squeeze) keeps the patch axis even for N == 1.
patches = np.empty_like(patches_oa790, shape=(*patches_oa850.shape[:-1], 2))
patches[..., 0] = patches_oa790[..., 0]
patches[..., 1] = patches_oa850[..., 0]
plt.imshow(mask)

In [None]:
import torch
from learningutils import ImageDataset
@torch.no_grad()
def get_label_probability(images, model, standardize=True, to_grayscale=False, n_output_classes=2, device='cuda',
                          batch_size=1024 * 3):
    """ Make a prediction for the images giving probabilities for each label.

    Arguments:
        images           -- NxHxWxC or NxHxW. The images (NxHxW input gets a trailing channel axis).
        model            -- The model to do the prediction. It is moved to `device` and put in eval mode.
        standardize      -- Forwarded to `ImageDataset`.
        to_grayscale     -- Forwarded to `ImageDataset`.
        n_output_classes -- Number of output columns; must match the width of the model's output.
        device           -- Device the model and each batch are moved to.
        batch_size       -- Number of images per inference batch.

    Returns:
        A float32 CPU tensor of shape (N, n_output_classes) with the softmax probability per label
        for each image, in input order.
    """
    model = model.to(device).eval()

    if len(images.shape) == 3:
        # Add channel dimension when images are single channel grayscale
        # i.e (Nx100x123 -> Nx100x123x1)
        images = images[..., None]

    image_dataset = ImageDataset(images, standardize=standardize, to_grayscale=to_grayscale)
    loader = torch.utils.data.DataLoader(image_dataset, batch_size=batch_size)

    predictions = torch.empty((len(image_dataset), n_output_classes), dtype=torch.float32)
    start = 0
    for batch in loader:
        probabilities = torch.nn.functional.softmax(model(batch.to(device)), dim=1)
        # Results are collected on the CPU so large image sets do not accumulate on the device.
        predictions[start:start + len(probabilities)] = probabilities.cpu()
        start += len(probabilities)

    return predictions

device = 'cuda'
# get_label_probability moves the model to `device` and switches it to eval mode itself, so the
# former `model.train()` call here had no effect on the prediction and has been dropped.
model = results.recorded_model.to(device)
label_probabilities = get_label_probability(patches, model, standardize=True,
                                            to_grayscale=False, device=device)

# Write the probability of class 1 (the label used for cell patches) back onto the selected
# vessel pixels; all other pixels stay 0.
probability_map = np.zeros(vs.frames_oa790[0].shape[:2], dtype=np.float32)
rows, cols = np.unravel_index(vessel_pixel_indices, probability_map.shape[:2])
probability_map[rows, cols] = label_probabilities[:, 1].numpy()
plt.imshow(probability_map, cmap='hot')
plt.scatter(vs.cell_positions[0][..., 0], vs.cell_positions[0][..., 1])

In [None]:
# Probability map with the default colormap and without the cell-position overlay.
plt.imshow(probability_map)

In [None]:
# Scratch arithmetic (3136); presumably the flattened size of a 64-channel 7x7 feature map — TODO confirm.
64 * 7 * 7

In [None]:
# Push a single patch through the convolutional part of the model to inspect its output shape.
image_dataset = ImageDataset(patches)
loader = torch.utils.data.DataLoader(
    image_dataset,
    batch_size=1,
)
model.eval()
for images in loader:
    print(images.shape)
    print(model.convolutional(images.to(device)).shape)
    break  # only the first batch is needed

### Picking Negative Points


In [None]:
from sharedvariables import get_video_sessions
import matplotlib.pyplot as plt
import matplotlib
from nearest_neighbors import get_nearest_neighbor, get_nearest_neighbor_distances
# https://matplotlib.org/api/_as_gen/matplotlib.patches.Circle.html

video_sessions = get_video_sessions(should_be_registered=True, should_have_marked_cells=True)
vs = video_sessions[5]

# Scatter the first entry of cell_positions (columns 0/1 as x/y) and draw a dotted red circle of
# radius r = 19 around its first point.
_, ax = plt.subplots(figsize=(11, 11))
cx, cy = vs.cell_positions[0][0]
r = 19
ax.set_aspect(1)
ax.scatter(vs.cell_positions[0][..., 0], vs.cell_positions[0][..., 1])
ax.add_artist(matplotlib.patches.Circle((cx, cy), r, fill=False, edgecolor='r', linestyle=':'))

In [None]:
def get_random_points_on_circles(points, n_points_per_circle=1, ret_radii=False):
    from nearest_neighbors import get_nearest_neighbor, get_nearest_neighbor_distances
   
    assert 1 <= n_points_per_circle <= 7, f'Points per circle must be between 1 and 7 not {n_points_per_circle}'
  
    neighbor_distances, nearest_neighbor_idxs = get_nearest_neighbor(points, 2)
    
    dist_flat = neighbor_distances.flatten()
    dist_flat = np.delete(dist_flat, np.where(dist_flat > (dist_flat.mean() + 0 * dist_flat.std()))[0])
    mean_distance = dist_flat.mean()

    nnp = n_points_per_circle
    uniform_angle_displacements = np.array([0, np.math.pi, np.math.pi * 0.5, np.math.pi * 3 * 0.5,  np.math.pi * 0.25, 
                                            3 * np.math.pi * 0.25, 5 * np.math.pi * 0.25, 7 * np.math.pi * 0.25]).squeeze()
    c = 0
    rxs = np.empty(len(points) * nnp)
    rys = np.empty(len(points) * nnp)
    radii = np.empty(len(points))
    for centre_point_idx, (distances, neighbor_idxs) in enumerate(zip(neighbor_distances, nearest_neighbor_idxs)):
        centre_point = points[centre_point_idx]
        closest_point_1, closest_point_2 = points[neighbor_idxs]

        r = (np.min(distances) *1.3) / 2
        radii[centre_point_idx] = r
        cx, cy = centre_point

        angle = np.random.rand() * np.math.pi * 2;
        random_angles = np.array(angle + uniform_angle_displacements).squeeze()

        rx = np.array([cx + np.cos(random_angles[:nnp]) * r]).squeeze()
        ry = np.array([cy + np.sin(random_angles[:nnp]) * r]).squeeze()

        rxs[c:c + nnp] = rx
        rys[c:c + nnp] = ry

        c += nnp
    
    if ret_radii:
        return rxs, rys, radii
    else:
        return rxs, rys

In [None]:
# Compute everything this plot needs here instead of relying on names leaked from other cells
# (`points`, `rxs`, `rys` and `radii` were not defined at this point of the notebook).
points = vs.cell_positions[0]
rxs, rys, radii = get_random_points_on_circles(points, n_points_per_circle=3, ret_radii=True)

_, ax = plt.subplots(figsize=(11, 11))
ax.set_aspect(1)
ax.scatter(points[..., 0], points[..., 1], c='b')
ax.scatter(rxs, rys, c='r')

# Dashed circle of the sampling radius around every cell.
for (cx, cy), r in zip(points, radii):
    ax.add_artist(matplotlib.patches.Circle((cx, cy), r, fill=False, edgecolor='r', linestyle='--'))

In [None]:
points = vs.cell_positions[0]
neighbor_distances, nearest_neighbor_idxs = get_nearest_neighbor(points, 2)

_, ax = plt.subplots(figsize=(11, 11))
ax.set_aspect(1)
ax.scatter(points[..., 0], points[..., 1])

# Highlight one cell together with its two nearest neighbours
# (previously referenced undefined names `closest_point_1_idx` / `centre_point`).
centre_point_idx = 6
centre_point = points[centre_point_idx]
closest_point_1, closest_point_2 = points[nearest_neighbor_idxs[centre_point_idx]]
ax.scatter(centre_point[0], centre_point[1], label='centre point')
ax.scatter(closest_point_1[0], closest_point_1[1], label='other point')
ax.scatter(closest_point_2[0], closest_point_2[1], label='other point')
ax.legend()

# Angular offsets (relative to a random start angle) at which candidate points are placed;
# `np.pi` replaces the `np.math` alias removed in NumPy 2.0.
nnp = 3
uniform_angle_displacements = np.array([0, np.pi, np.pi * 0.5, np.pi * 1.5,
                                        np.pi * 0.25, np.pi * 0.75, np.pi * 1.25, np.pi * 1.75])

rxs = np.empty(len(points) * nnp)
rys = np.empty(len(points) * nnp)

c = 0
for centre_point_idx, (distances, neighbor_idxs) in enumerate(zip(neighbor_distances, nearest_neighbor_idxs)):
    centre_point = points[centre_point_idx]

    # Radius: 65% of the distance to the closest neighbour.
    r = (np.min(distances) * 1.3) / 2
    cx, cy = centre_point

    angle = np.random.rand() * np.pi * 2
    random_point_angles = angle + uniform_angle_displacements

    rx = cx + np.cos(random_point_angles[:nnp]) * r
    ry = cy + np.sin(random_point_angles[:nnp]) * r

    rxs[c:c + nnp] = rx
    rys[c:c + nnp] = ry
    c += nnp

    ax.scatter(rx, ry, c='r')
    ax.add_artist(matplotlib.patches.Circle((cx, cy), r, fill=False, edgecolor='b', linestyle=':'))

In [None]:
# Sampled candidate points (red) on top of the cell positions (blue).
_, ax = plt.subplots(figsize=(11, 11))
ax.scatter(rxs, rys, c='r')
ax.scatter(points[..., 0], points[..., 1], c='b')

In [None]:
# Candidate angles: the random start `angle` plus four quarter-turn offsets.
random_point_angles = angle + np.array([0, np.pi, np.pi * 0.5, np.pi * 1.5])
# x coordinates on the circle of radius r around cx (same formula as `rx` in the sampling loop);
# the old expression added pi * r instead of scaling the cosine by r.
cx + np.cos(random_point_angles[:4]) * r

In [None]:
# Scratch: shape of the cosine of the first two candidate angles.
np.cos(random_point_angles[:2]).shape

In [None]:
# Nearest-neighbour distances with every value above the mean discarded
# (`distances_no_outliers` was never defined anywhere in the notebook).
all_neighbor_distances = neighbor_distances.flatten()
distances_no_outliers = all_neighbor_distances[all_neighbor_distances <= all_neighbor_distances.mean()]
plt.plot(distances_no_outliers)

In [None]:
import numpy as np


In [None]:
# Number of nearest-neighbour distances before and after dropping values above the mean.
all_neighbor_distances = neighbor_distances.flatten()
all_neighbor_distances.shape, all_neighbor_distances[all_neighbor_distances <= all_neighbor_distances.mean()].shape

In [None]:
# Scratch: displays the 1-tuple (cx,) using whichever cell last assigned `cx`.
cx, 

In [None]:
import os
from cnnlearning import TrainingTracker

results_file_path = os.path.join('tmp', 'results_file.pkl')
# Restore a previously saved training run.
results = TrainingTracker.from_file(results_file_path)

In [None]:
import os
import pickle
results_path = os.path.join(
    'cache', 'models',
    'blood_cell_classifier_va_0.847_ps_27_tw_0_mc_true_hm_false_npp_1_st_true_da_false',
    'results.pkl',
)
# NOTE: pickle.load can execute arbitrary code — only load result files produced locally.
with open(results_path, 'rb') as results_fh:
    results = pickle.load(results_fh)

In [None]:
from generate_datasets import get_cell_and_no_cell_patches
from sharedvariables import get_video_sessions
video_sessions = get_video_sessions(should_be_registered=True, should_have_marked_cells=True)

# Build (or load from cache) the cell / non-cell patch datasets: 27x27 mixed-channel patches,
# one negative per positive, standardized, no augmentation. The three trailing outputs are unused here.
trainset, validset, cell_images, non_cell_images, _, _, _ =\
get_cell_and_no_cell_patches(
        video_sessions=video_sessions,
        patch_size=27,
        temporal_width=0,
        mixed_channel_patches=True,
    
        do_hist_match=False,
        n_negatives_per_positive=1,

        standardize_dataset=True,
        apply_data_augmentation_to_dataset=False,

        try_load_from_cache=True,

        v=True, vv=True
)

In [None]:
import numpy as np
from classificationutils import classify_labeled_dataset, classify_images
from learningutils import LabeledImageDataset, ImageDataset

def _print_brief_evaluation(title, eval_model, epoch):
    """Evaluate `eval_model` and print a short accuracy report headed by `title`.

    Uses the notebook-level `trainset`, `validset`, `positive_dataset` and `negative_dataset`.
    Returns (train_accuracy, valid_accuracy, positive_accuracy, negative_accuracy).
    """
    eval_model.eval()

    _, train_acc = classify_labeled_dataset(trainset, eval_model)
    _, valid_acc = classify_labeled_dataset(validset, eval_model)
    _, positive_acc = classify_labeled_dataset(positive_dataset, eval_model)
    _, negative_acc = classify_labeled_dataset(negative_dataset, eval_model)

    print()
    print(f'Brief evaluation - {title}')
    print('----------------')
    print('Epoch:\t', epoch)
    print('Training accuracy:\t', f'{train_acc:.3f}')
    print('Validation accuracy:\t', f'{valid_acc:.3f}')
    print()
    print('Positive accuracy:\t', f'{positive_acc:.3f}')
    print('Negative accuracy:\t', f'{negative_acc:.3f}')
    return train_acc, valid_acc, positive_acc, negative_acc


model = results.recorded_model

# Label 1 = cell, 0 = non cell. Channel 0 is dropped ([..., 1:]), presumably to match the channels
# the model was trained on — TODO confirm. `dtype=int` replaces `np.int` (removed in NumPy 1.24).
positive_dataset = LabeledImageDataset(cell_images[..., 1:], np.ones(len(cell_images), dtype=int))
negative_dataset = LabeledImageDataset(non_cell_images[..., 1:], np.zeros(len(non_cell_images), dtype=int))

print()
print(f'Model trained on {len(cell_images)} cell patches and {len(non_cell_images)} non cell patches.')
train_accuracy, valid_accuracy, positive_accuracy, negative_accuracy = _print_brief_evaluation(
    'best validation accuracy model', model, results.recorded_model_epoch)

train_model = results.recorded_train_model
train_accuracy, valid_accuracy, positive_accuracy, negative_accuracy = _print_brief_evaluation(
    'best training accuracy model', train_model, results.recorded_train_model_epoch)

# Flow quantification

In [None]:
from classificationutils import create_probability_map
from patchextraction import extract_patches
from imageprosessing import std_image

frame = vs.frames_oa790[0]
# Only a cell's last expression is displayed, so the shape has to be printed explicitly.
print(extract_patches(frame).shape)

# Standard deviation images of the OA850 and confocal videos (method='j_tam', sigma=0, adapt_hist=True).
std_image_oa850    = std_image(vs.frames_oa850, vs.mask_frames_oa850, sigma=0, method='j_tam', adapt_hist=True)
std_image_confocal = std_image(vs.frames_confocal, vs.mask_frames_confocal, sigma=0, method='j_tam', adapt_hist=True)

plt.subplot(121)
plt.imshow(std_image_oa850)
plt.subplot(122)
plt.imshow(std_image_confocal)

In [None]:
from imageprosessing import ImageRegistator
from vesseldetection import create_vessel_mask

confocal_mask = create_vessel_mask(std_image_confocal)
oa850_mask = create_vessel_mask(std_image_oa850)

# `label=` on imshow is never shown (there is no legend), so the names go into subplot titles.
plt.subplot(121)
plt.imshow(confocal_mask)
plt.title('Confocal mask')
plt.subplot(122)
plt.imshow(oa850_mask)
plt.title('oa850 mask');

In [None]:
from sharedvariables import get_video_sessions
from matplotlib import pyplot as plt

video_sessions = get_video_sessions(should_have_marked_cells=True, should_be_registered=True)
vs = video_sessions[0]

# Draw both channels side by side; drawing both on the same axes hid the OA790 frame completely.
_, axes = plt.subplots(1, 2)
axes[0].imshow(vs.frames_oa790[0], cmap='gray')
axes[0].set_title('OA790')
axes[1].imshow(vs.registered_frames_oa850[0], cmap='gray')
axes[1].set_title('OA850 (registered)');

In [None]:
from sharedvariables import get_video_sessions
from matplotlib import pyplot as plt

video_sessions = get_video_sessions(should_have_marked_cells=True, should_be_registered=True)
vs = video_sessions[0]

# Show the confocal vessel mask, then clear its file path and cached value and show it again.
# NOTE(review): this relies on `vessel_mask_confocal` being recomputed when set to None
# (e.g. a lazy property); if it is a plain attribute, plt.imshow(None) will fail — confirm.
plt.subplot(121)
plt.imshow(vs.vessel_mask_confocal)
vs.vessel_mask_confocal_file = ''
vs.vessel_mask_confocal = None
plt.subplot(122)
plt.imshow(vs.vessel_mask_confocal)

In [None]:
from vesseldetection import create_vessel_image

In [None]:
ir = ImageRegistator(source=oa850_mask, target=confocal_mask)
ir.register_vertically()

# `label=` passed to imshow is never displayed (no legend); use subplot titles instead.
plt.subplot(121)
plt.imshow(confocal_mask)
plt.title('Confocal mask')
plt.subplot(122)
plt.imshow(ir.registered_source)
plt.title('oa850 mask (registered)');

In [None]:
img = frame
patch_size = (21, 21)
patch_height, patch_width = patch_size

kernel_height, kernel_width = patch_height, patch_width

inp = torch.from_numpy(img)

# Bring the input to 1 x C x H x W.
if len(inp.shape) == 3:
    inp = inp.permute(-1, 0, 1)
elif len(inp.shape) == 2:
    inp = inp[None, ...]
inp = inp[None, ...]

print("Inp.shape", inp.shape)
# Slide a kernel_height x kernel_width window with stride 1 over the spatial dims (2 and 3).
# The former extra `.unfold(0, 0, 1)` appended a 7th dimension, which made the permute below fail.
patches = inp.unfold(2, kernel_height, 1).unfold(3, kernel_width, 1)
# Shape -> 1 x C x H' x W' x Hpatch x Wpatch  (H' = H - kernel_height + 1, W' = W - kernel_width + 1)
print("Patches shape 1", patches.shape)

# -> H' x W' x C x Hpatch x Wpatch x 1, then drop the trailing batch axis.
patches = patches.permute(2, 3, 1, -2, -1, 0)[..., 0]
print(patches.shape)

### Applying preprocessing to session frames

In [None]:
import PIL.Image
# NOTE(review): `video_sessions` is reloaded here but `vs` is not reassigned, so the images saved
# below come from whichever session `vs` referred to before this cell — confirm this is intended.
video_sessions = get_video_sessions(marked=True, validation=True)

# Save the first OA790 frame, the frame multiplied by its mask, and the mask itself (0/255) as PNGs.
PIL.Image.fromarray(vs.frames_oa790[0]).save('frame_790_registered.png')
PIL.Image.fromarray(vs.frames_oa790[0] * vs.mask_frames_oa790[0]).save('frame_790_registered_masked.png')
PIL.Image.fromarray(vs.mask_frames_oa790[0].astype(np.uint8) * 255).save('frame_790_registered_mask.png')

In [None]:
from sharedvariables import get_video_sessions
from plotutils import *
from imageprosessing import SessionPreprocessor, enhance_motion_contrast, normalize_data


video_sessions = get_video_sessions(marked=True, registered=True)
vs = video_sessions[3]

# Pipeline: motion contrast enhancement (sigma=1.2), then rescale to 0..255 and cast to uint8.
preprocessor = SessionPreprocessor(vs, [
    lambda frames: enhance_motion_contrast(frames, sigma=1.2),
    lambda frames: normalize_data(frames, (0, 255)).astype(np.uint8)
])

plt  # no-op expression
plt.subplot(121)
plt.imshow(vs.frames_oa790[0])
plt.scatter(vs.cell_positions[0][:, 0], vs.cell_positions[0][:, 1], s=5)
plt.title('Before applying preprocessing', fontsize=15)

# Changes the session's frames; the right panel shows the processed first frame.
preprocessor.apply_preprocessing()

plt.subplot(122)
plt.imshow(vs.frames_oa790[0])
plt.scatter(vs.cell_positions[0][:, 0], vs.cell_positions[0][:, 1],  s=5)

plt.title('After applying preprocessing', fontsize=15)
pass  # suppress the repr of the last expression

In [None]:
# NOTE(review): applies the preprocessing to `vs` again, on top of the pass from the previous
# cell; presumably each re-run stacks another pass — confirm this is intended.
preprocessor.apply_preprocessing()

In [None]:
# NOTE(review): nothing is applied between the two panels, so both show the *current* (already
# preprocessed) frame despite the "Before"/"After" titles.
plt.figure(figsize=(75, 75))
plt.subplot(121)
plt.imshow(vs.frames_oa790[0], cmap='gray')
plt.scatter(vs.cell_positions[0][:, 0], vs.cell_positions[0][:, 1], s=5)
plt.title('Before applying preprocessing', fontsize=15)



plt.subplot(122)
plt.imshow(vs.frames_oa790[0], cmap='gray')
plt.scatter(vs.cell_positions[0][:, 0], vs.cell_positions[0][:, 1],  s=5)

plt.title('After applying preprocessing', fontsize=15)
pass  # suppress the repr of the last expression

### Apply motion contrast enhancement to all video sessions

In [None]:
from sharedvariables import get_video_sessions
from imageprosessing import SessionPreprocessor, enhance_motion_contrast
import tqdm
video_sessions_enhanced = get_video_sessions(should_have_marked_cells=True, should_be_registered=True)

# Apply motion contrast enhancement (sigma=1) to every session.
# NOTE: the loop variable rebinds the global `vs` to the last session.
for vs in tqdm.tqdm(video_sessions_enhanced):
    preprocessor = SessionPreprocessor(vs, lambda frames: enhance_motion_contrast(frames, sigma=1))
    preprocessor.apply_preprocessing()

In [None]:
from generate_datasets import create_cell_and_no_cell_patches
# Cell / non-cell patches from the enhanced sessions, plus the corresponding patches from the marked video.
cell_images_enhanced, non_cell_images_enhanced, cell_images_marked, non_cell_images_marked =\
create_cell_and_no_cell_patches(
    video_sessions=video_sessions_enhanced
)

In [None]:
from plotutils import plot_images_as_grid
# First ten patches of each set: enhanced / marked cells, then enhanced / marked non-cells.
for image_stack in (cell_images_enhanced, cell_images_marked,
                    non_cell_images_enhanced, non_cell_images_marked):
    plot_images_as_grid(image_stack[:10])

In [None]:
import collections
import torch
from cnnlearning import CNN, train, TrainingTracker 
from generate_datasets import create_dataset_from_cell_and_no_cell_images

trainset, validset = create_dataset_from_cell_and_no_cell_images(cell_images_enhanced,
                                                                 non_cell_images_enhanced,
                                                                 standardize=True)
model = CNN(dataset_sample=trainset, output_classes=2).to('cuda')
train_params = collections.OrderedDict(
    epochs=250,
    lr=.001,

    weight_decay=0.01,
    batch_size='all', # can be a number or None/'all' to train all trainset at once.
    do_early_stop=True,  # Optional default True
    early_stop_patience=60, # How many epochs with no validation accuracy improvement before stopping early
    learning_rate_scheduler_patience=20, # How many epochs with no validation accuracy improvement before lowering learning rate
    # Renamed from `evaluate_epochs`: every other training cell here uses `evaluation_epochs`,
    # so the misspelt key was presumably ignored by `train` — confirm against its implementation.
    evaluation_epochs=10,

    trainset=trainset,
    validset=validset,
    shuffle=True)
results: TrainingTracker = train(model,
                                 train_params,
                                 criterion=torch.nn.CrossEntropyLoss(),
                                 device='cuda')

In [None]:
from sharedvariables import get_video_sessions
from imageprosessing import SessionPreprocessor, enhance_motion_contrast, normalize_data, frame_differencing
import matplotlib.pyplot as plt
import numpy as np

video_sessions = get_video_sessions(should_have_marked_video=True)
vs = video_sessions[0]


# Pipeline applied to the OA790 frames: frame differencing, rescale to uint8 0..255,
# then motion contrast enhancement.
preprocessor = SessionPreprocessor(vs, [
    lambda frames: frame_differencing(frames, sigma=1.2),
    lambda frames: np.uint8(normalize_data(frames, (0, 255))),
    lambda frames: enhance_motion_contrast(frames, sigma=1.2)]
                                  )

plt.subplot(121)
plt.imshow(vs.frames_oa790[0])
plt.title('Before applying preprocessing', fontsize=15)

preprocessor.apply_preprocessing_to_oa790()

plt.subplot(122)
plt.imshow(vs.frames_oa790[0])
plt.title('After applying preprocessing', fontsize=15)
pass  # suppress the repr of the last expression

In [None]:
# Imported here because the cell that imports `filters` comes later in the notebook.
from skimage import filters

video_sessions = get_video_sessions(should_have_marked_video=True)
vs = video_sessions[0]
# First masked OA790 frame after a Gaussian blur with skimage's default sigma.
plt.imshow(filters.gaussian(vs.masked_frames_oa790[0]))

In [None]:
# NOTE(review): `blur_images` is defined in a later cell; this only works after that cell has run.
blur_images(vs.masked_frames_oa790).max()

In [None]:
from sharedvariables import get_video_sessions
from imageprosessing import SessionPreprocessor, equalize_adapthist_images
import matplotlib.pyplot as plt
import numpy as np
from skimage import filters

# First session that has a marked video.
video_sessions = get_video_sessions(should_have_marked_video=True)
vs = video_sessions[0]

def blur_images(frames, sigma=1):
    """Gaussian-blur each frame of a stack independently.

    Arguments:
        frames -- stack of frames indexed along the first axis.
        sigma  -- standard deviation passed to skimage.filters.gaussian.

    Returns:
        float64 array shaped like `frames` (np.empty_like keeps array subclasses such as
        masked arrays). skimage.filters.gaussian converts integer input to float, presumably
        in [0, 1] by default — hence the `* 255` step in the pipeline that uses this.
    """
    output = np.empty_like(frames, dtype=np.float64)
    for index in range(len(frames)):
        output[index] = filters.gaussian(frames[index], sigma)
    return output

# Pipeline for the OA790 frames: Gaussian blur (sigma=2), back to uint8, adaptive histogram
# equalisation, then unsharp masking.
preprocessor = SessionPreprocessor(vs, [
    lambda frames: blur_images(frames, sigma=2),
    lambda frames: np.uint8(frames * 255),  # blurred frames are floats, presumably in [0, 1]
    equalize_adapthist_images,
    lambda frames: np.ma.array([filters.unsharp_mask(f, radius=5, amount=1, preserve_range=True) for f in frames], dtype=np.uint8),
])

_, axes = plt.subplots(1, 2, figsize=(60, 70))
axes[0].imshow(vs.frames_oa790[0], cmap='gray')
axes[0].set_title('Before applying preprocessing', fontsize=15)

preprocessor.apply_preprocessing_to_oa790()

axes[1].imshow(vs.frames_oa790[0], cmap='gray')
axes[1].set_title('After applying preprocessing', fontsize=15)
pass  # suppress the repr of the last expression

In [None]:
# Scratch: prints the repr of two lambdas (they are never called).
for f in [
    lambda frames: frame_differencing(frames, sigma=1.2),
    lambda frames: enhance_motion_contrast(frames, sigma=1.2),
                                       ]:
    print(f)

In [None]:
from sharedvariables import get_video_sessions
from imageprosessing import SessionPreprocessor, enhance_motion_contrast
import mahotas as mh
video_sessions = get_video_sessions(should_have_marked_video=True)
vs = video_sessions[0]

# Temporal mean of the masked OA790 frames, smoothed with a sigma=2 Gaussian (mahotas).
plt.imshow(mh.gaussian_filter(vs.masked_frames_oa790.mean(axis=0), sigma=2))

In [None]:
string = 'hey'
import collections
# `List` was never imported (NameError); the check intended here is against the builtin `list`.
print(isinstance(string, list))
# A str is not a list, but it can still be iterated character by character.
for s in string:
    print(s)

In [None]:
string = 'Hey'
# Split the string into its characters.
[*string]

In [None]:
# Compute the frame differencing once and reuse it for both the plot and the max.
differenced_frames = frame_differencing(vs.masked_frames_oa790, 2)
plt.imshow(differenced_frames[3])
differenced_frames.max()

In [None]:
video_sessions = get_video_sessions()

# First session whose video file path contains 'shared-videos' (IndexError if there is none).
vs = [vs for vs in video_sessions if 'shared-videos' in vs.video_file][0]
print(vs)
plt.imshow(vs.frames_oa790[0])
plt.show()
extractor = SessionPatchExtractor(vs, patch_size=21)  # not used further in this cell
preprocessor = SessionPreprocessor(vs, lambda frames, masks: enhance_motion_contrast(frames, masks, sigma=0.125, mask_crop_pixels=0))
preprocessor.apply_preprocessing_to_oa790()
plt.imshow(vs.frames_oa790[0])

In [None]:
# Frame 23 of the current session's OA790 video.
plt.imshow(vs.frames_oa790[23])

In [None]:
from sharedvariables import get_video_sessions
from imageprosessing import SessionPreprocessor, enhance_motion_contrast
from patchextraction import SessionPatchExtractor

video_sessions = get_video_sessions(should_have_marked_video=True)
vs = video_sessions[0]

extractor = SessionPatchExtractor(vs, patch_size=21)
preprocessor = SessionPreprocessor(vs, lambda frames, masks: enhance_motion_contrast(frames, masks, sigma=1.2, mask_crop_pixels=0))

# Patches extracted from the raw frames.
cell_patches_before_preprocessing = extractor.cell_patches_oa790
non_cell_patches_before_preprocessing = extractor.non_cell_patches_oa790

plt.subplot(121)
plt.imshow(vs.frames_oa790[0])
plt.title('Before applying preprocessing', fontsize=15)
plt.scatter(vs.cell_positions[0][:, 0], vs.cell_positions[0][:, 1])

preprocessor.apply_preprocessing()

# Reset the extractor's patches so the next accesses re-extract from the preprocessed frames.
extractor._reset_patches()
cell_patches_after_preprocessing = extractor.cell_patches_oa790
non_cell_patches_after_preprocessing = extractor.non_cell_patches_oa790

plt.subplot(122)
plt.imshow(vs.frames_oa790[0])
plt.title('After applying preprocessing', fontsize=15)
pass  # suppress the repr of the last expression

In [None]:
# First ten cell / non-cell patches, before and then after preprocessing.
for patch_stack in (cell_patches_before_preprocessing, non_cell_patches_before_preprocessing,
                    cell_patches_after_preprocessing, non_cell_patches_after_preprocessing):
    plot_images_as_grid(patch_stack[:10])

In [None]:
# NOTE(review): `patch_extractor` is created in a later cell; this cell depends on that having run.
# Shows channel 1 (labelled oa850 here) of the first ten mixed-channel cell / non-cell patches.
print('Channel oa850')
plot_images_as_grid(patch_extractor.mixed_channel_cell_patches[:10][..., 1])
# plot_images_as_grid(patch_extractor.mixed_channel_marked_cell_patches[:10][..., 1])

plot_images_as_grid(patch_extractor.mixed_channel_non_cell_patches[:10][..., 1])
# plot_images_as_grid(patch_extractor.mixed_channel_marked_non_cell_patches[:10][..., 1])

In [None]:
 # Shape of temporal marked cell patch 1 with axis 2 moved to the front.
 # NOTE(review): the leading space is presumably tolerated by IPython's cell dedent; plain Python
 # would raise IndentationError. `patch_extractor` is created in a later cell.
 patch_extractor.temporal_marked_cell_patches_oa790[1].transpose(2, 0, 1).shape

In [None]:
from sharedvariables import get_video_sessions
from patchextraction import SessionPatchExtractor
from plotutils import plot_images_as_grid, no_ticks
import numpy as np

video_sessions = get_video_sessions(should_have_marked_video=True)
vs = video_sessions[0]

# Created with patch_size=37 and then overridden to 31.
patch_extractor = SessionPatchExtractor(vs, patch_size=37)
patch_extractor.patch_size =31
print(patch_extractor.temporal_cell_patches_oa790.max())
print(patch_extractor.temporal_marked_cell_patches_oa790.max())

# One subplot per slice along axis 2 of temporal marked cell patch 29 (three panels).
_, axes = plt.subplots(1, 3)
no_ticks(axes)
for ax, channel_patch in zip(axes, patch_extractor.temporal_marked_cell_patches_oa790[29].transpose(2, 0, 1)):
    ax.imshow(channel_patch, cmap='gray')
    
# _, axes = plt.subplots(1, 3)
# no_ticks(axes)
# for ax, channel_patch in zip(axes, patch_extractor.temporal_non_cell_patches_oa790[0].transpose(2, 0, 1)):
#     ax.imshow(channel_patch, cmap='gray')

In [None]:
 # Display the temporal non-cell patches (leading space presumably tolerated by IPython's cell dedent).
 patch_extractor.temporal_non_cell_patches_oa790

In [None]:
# Shape of temporal marked cell patch 1 with axis 2 moved to the front.
patch_extractor.temporal_marked_cell_patches_oa790[1].transpose(2, 0, 1).shape

In [None]:
# Plot every extracted OA790 patch set: cells, marked cells, non-cells, marked non-cells.
# getattr keeps each property access right before its own plot, as in separate calls.
for patch_attribute in ('cell_patches_oa790', 'marked_cell_patches_oa790',
                        'non_cell_patches_oa790', 'marked_non_cell_patches_oa790'):
    plot_images_as_grid(getattr(patch_extractor, patch_attribute))

### How to get the standard deviation images and vessel masks

In [None]:
# `session` was never defined in this notebook (NameError); use the most recently loaded session `vs`.
session = vs

# (attribute name, panel title). An empty `<attribute>_file` string means the file was not found,
# in which case the panel is left empty.
panels = [
    ('std_image_oa850', 'Std image oa850'),
    ('std_image_confocal', 'Std image confocal'),
    ('vessel_mask_oa850', 'Vessel mask oa850'),
    ('vessel_mask_confocal', 'Vessel mask confocal'),
]

fig, axes = plt.subplots(1, 4, figsize=(40, 10))
for ax, (attribute, title) in zip(axes, panels):
    if getattr(session, f'{attribute}_file') != '':
        ax.imshow(getattr(session, attribute))
        ax.set_title(title, fontsize=20)

### Voronoi

In [None]:
from sharedvariables import get_video_sessions
from patchextraction import extract_patches
from scipy.spatial import Voronoi, voronoi_plot_2d
import numpy as np
vs = get_video_sessions(should_be_registered=True, should_have_marked_cells=True)[2]
# Voronoi diagram of the first entry of cell_positions.
points = vs.cell_positions[0]
vor = Voronoi(points)
fig = voronoi_plot_2d(vor)
def get_random_points_in_voronoi_diagram(centroids, limits):
    """Sample one uniformly random point on every finite edge (ridge) of the Voronoi diagram.

    Arguments:
        centroids -- (N, 2) array of (x, y) points the diagram is built from.
        limits    -- (height, width); points outside 0 <= x <= width or 0 <= y <= height are dropped.

    Returns:
        (M, 2) array of (x, y) points, at most one per finite ridge.
    """
    vor = Voronoi(centroids, qhull_options='Qbb Qc Qx', incremental=False)
    vor.close()

    edges = np.asarray(vor.ridge_vertices).reshape(-1, 2)
    # A vertex index of -1 marks a ridge that extends to infinity. It may appear in either column,
    # so drop those ridges *before* indexing (index -1 would silently select the last vertex).
    edges = edges[(edges >= 0).all(axis=1)]

    vertices_start = vor.vertices[edges[:, 0]]
    vertices_end = vor.vertices[edges[:, 1]]

    # Uniformly random position along each ridge.
    t = np.random.rand(len(edges))[:, np.newaxis]
    random_vertices = t * vertices_start + (1 - t) * vertices_end

    height, width = limits[0], limits[1]
    inside = ((random_vertices[:, 0] >= 0) & (random_vertices[:, 0] <= width)
              & (random_vertices[:, 1] >= 0) & (random_vertices[:, 1] <= height))
    return random_vertices[inside]
# random_points = get_random_points_in_voronoi_diagram(points, vs.frames_oa790[0].shape[:2])
# Inverted confocal vessel mask shown as a 0/255 image (assumes a boolean mask, so ~ is a logical not).
plt.imshow(np.uint8(~vs.vessel_mask_confocal)*255, cmap='gray')
# plt.scatter(random_points[:, 0], random_points[:, 1], c='r', s=50)

In [None]:
# NOTE(review): multiplying by 0 always produces an all-zero image — looks like a scratch check.
plt.imshow(np.uint8(~vs.vessel_mask_confocal)*0)

In [None]:

# Shapes of the first OA790 mask frame, all OA790 frames, and the masked OA790 frames.
vs.mask_frames_oa790[0].shape, vs.frames_oa790.shape, vs.masked_frames_oa790.shape

In [None]:
from videoutils import get_frames_from_video

# Shapes of the OA790 mask video and the OA790 video, read directly from their files.
get_frames_from_video(vs.mask_video_oa790_file).shape, get_frames_from_video(vs.video_oa790_file).shape

In [None]:
# np.ma.masked_array(self.frames_oa790, ~self.mask_frames_oa790)
# Mask frames and frames must have matching shapes for the masked array above to line up.
vs.mask_frames_oa790.shape, vs.frames_oa790.shape