In [None]:
# Jupyter Notebook settings

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))
%autosave 1
%matplotlib inline
%load_ext autoreload
%autoreload 2

# Plotting settings
import matplotlib.pyplot as plt
size=35
params = {'legend.fontsize': 'large',
          'figure.figsize': (20,8),
          'axes.labelsize': size,
          'axes.titlesize': size,
          'xtick.labelsize': size*0.75,
          'ytick.labelsize': size*0.75,
          'axes.titlepad': 25}
plt.rcParams.update(params)

# Accessing files

Every file under any folder in ./data is parsed and put into dictionaries that group videos of the same source.
Videos of the same source are considered videos coming from the same Subject, same Session, same OD/OS, same (x, y),
same type (Confocal, OA790, OA850).

The parsing is case insensitive with the following rules:

**unmarked videos**:
must not contain 'mask' or '_marked'

**marked videos**:
Must end with '_marked.\<<file_extension\>>'

**standard deviation images**:
Must end with
'_std.\<<file_extension\>>'

**vessel mask images**: 
Must end with
'_vessel_mask.\<<file_extension\>>'

**channel type**:
must contain one of 'OA790', 'OA850', 'Confocal' (case insensitive)

### Example to using VideoSession objects and the get_video_session() function
(useful for training and training)

In [None]:
from sharedvariables import get_video_sessions
from os.path import basename

video_sessions = get_video_sessions(should_have_marked_cells=True,
                                    should_be_registered=True)
for session in video_sessions:
        assert session.has_marked_video
        assert session.is_registered
        assert session.has_marked_cells
        print('-----------------------')
        print('Video file:', basename(session.video_file))
        print('Does video have a corresponding marked video?:', session.has_marked_video)
        print('Subject number:', session.subject_number)
        print('Session number:', session.session_number)
        print('Marked Video OA790:', basename(session.marked_video_oa790_file))
        print('Std dev image confocal:', basename(session.std_image_confocal_file))
        print('Std dev image OA850:', basename(session.std_image_oa850_file))
        print('Vessel mask OA850:', basename(session.vessel_mask_oa850_file))
        print('Vessel mask confocal:', basename(session.vessel_mask_confocal_file))
        print('Cell position csv files:', *[basename(f) for f in session.cell_position_csv_files], sep='\n')
        print()
        
print('Number of video sessions ', len(video_sessions))

# Reading frames from videos - (and cell positions for each frame)

You can get access to the frames of the video session.

In [None]:
import matplotlib.pyplot as plt
from sharedvariables import get_video_sessions
from plotutils import no_ticks

video_sessions = get_video_sessions(should_have_marked_cells=True,
                                    should_be_registered=True)
_, axes = plt.subplots(1, 2, figsize=(20, 7))
no_ticks(axes)

axes[0].imshow(session.frames_oa790[0], cmap='gray')
axes[0].set_title(f"First frame of {basename(session.video_oa790_file)}", fontsize=10)
axes[0].scatter(session.cell_positions[0][:, 0], session.cell_positions[0][:, 1], label='cell positions', s=10)
axes[0].legend()
    
axes[1].imshow(session.marked_frames_oa790[0], cmap='gray')
axes[1].set_title(f"First marked frame of {basename(session.marked_video_oa790_file)}", fontsize=10)
pass

# How to extract cell and no cell patches

## Using SessionPatchExtractor (Object oriented way) 

### Simple patch extraction

In [None]:
import matplotlib.pyplot as plt
plt.imshow(patch_extractor.marked_cell_patches_oa790[0])

In [None]:
%load_ext autoreload
%autoreload 2
from sharedvariables import get_video_sessions
from patchextraction import SessionPatchExtractor
from plotutils import plot_images_as_grid

video_sessions = get_video_sessions(should_have_marked_cells=True)
vs = video_sessions[5]

patch_extractor = SessionPatchExtractor(vs, patch_size=21, n_negatives_per_positive=1)

plot_images_as_grid(patch_extractor.cell_patches_oa790[:10])
plot_images_as_grid(patch_extractor.marked_cell_patches_oa790[:10])

plot_images_as_grid(patch_extractor.non_cell_patches_oa790[:10])
plot_images_as_grid(patch_extractor.marked_non_cell_patches_oa790[:10])

### Temporal patches

In [None]:
from sharedvariables import get_video_sessions
from patchextraction import SessionPatchExtractor
from plotutils import plot_images_as_grid

video_sessions = get_video_sessions(should_have_marked_cells=True)
vs = video_sessions[0]

patch_extractor = SessionPatchExtractor(vs, patch_size=21, temporal_width=1, n_negatives_per_positive=7)

plot_images_as_grid(patch_extractor.temporal_cell_patches_oa790[:10], title='Temporal cell patches temporal width 1')
plot_images_as_grid(patch_extractor.temporal_marked_cell_patches_oa790[:10])

plot_images_as_grid(patch_extractor.temporal_non_cell_patches_oa790[:10], title='Temporal non cell patches temporal width 1')
plot_images_as_grid(patch_extractor.temporal_marked_non_cell_patches_oa790[:10])

# A higher temporal width will give patches with more channells
patch_extractor.temporal_width = 1
print(f'Temporal patches shape with temporal width = 1: {patch_extractor.temporal_cell_patches_oa790.shape}')
patch_extractor.temporal_width = 4
print(f'Temporal patches shape with temporal width = 4: {patch_extractor.temporal_cell_patches_oa790.shape}')
patch_extractor.temporal_width = 5
print(f'Temporal patches shape with temporal width = 5: {patch_extractor.temporal_cell_patches_oa790.shape}')
patch_extractor.temporal_width = 6
print(f'Temporal patches shape with temporal width = 6: {patch_extractor.temporal_cell_patches_oa790.shape}')
print(f'As temporal window becomes bigger notice that there are less patches.')

### Mixed channel patches
 
Mixed channel patches give patches with 3 channels, the first channel is confocal video patch, second channel is from the oa780 channel,
third channel is from the oa850 channel.

The confocal video and the oa790 channel have the capillaries at the same position. The oa850 video has a vertical displacement, the video is registered before extracting the patches.

In [None]:
from sharedvariables import get_video_sessions
from patchextraction import SessionPatchExtractor
from plotutils import plot_images_as_grid

video_sessions = get_video_sessions(should_have_marked_cells=True)
vs = video_sessions[0]

patch_extractor = SessionPatchExtractor(vs, patch_size=21)

plot_images_as_grid(patch_extractor.mixed_channel_cell_patches[:10])
plot_images_as_grid(patch_extractor.mixed_channel_marked_cell_patches[:10])

plot_images_as_grid(patch_extractor.mixed_channel_non_cell_patches[:10])
plot_images_as_grid(patch_extractor.mixed_channel_marked_non_cell_patches[:10])

In [None]:
patch_extractor.visualize_mixed_channel_patch_extraction(frame_idx=20)

# Create dataset convenience functions

Extract all patches (marked and unmarkded) for all the video session given.

If no video sessions given then all the video sessions automatically created from the videos with cell position csv in the data folder are used.

## Normal patches

In [None]:
from generate_datasets import create_cell_and_no_cell_patches, create_dataset_from_cell_and_no_cell_images
from imageprosessing import hist_match_images
from sharedvariables import get_video_sessions

reg_video_sessions = get_video_sessions(should_have_marked_cells=True, should_be_registered=True)

cell_images, non_cell_images, cell_images_marked, non_cell_images_marked =\
create_cell_and_no_cell_patches(
    video_sessions=reg_video_sessions,                                                                  
    n_negatives_per_positive=1,                                                                                                
    v=True,
    vv=False
)

plot_images_as_grid(cell_images[:10])
plot_images_as_grid(cell_images_marked[:10])

## Temporal patches

In [None]:
from generate_datasets import create_cell_and_no_cell_patches, create_dataset_from_cell_and_no_cell_images
from imageprosessing import hist_match_images
from sharedvariables import get_video_sessions

reg_video_sessions = get_video_sessions(should_have_marked_cells=True, should_be_registered=True)
cell_images, non_cell_images, cell_images_marked, non_cell_images_marked =\
create_cell_and_no_cell_patches(
    temporal_width=1,
    video_sessions=reg_video_sessions,                                                                                                                                                                
)

plot_images_as_grid(cell_images[:10], title='Temporal width 1')
plot_images_as_grid(cell_images_marked[:10])

## Mixed channel patches

In [None]:
from generate_datasets import create_cell_and_no_cell_patches, create_dataset_from_cell_and_no_cell_images
from imageprosessing import hist_match_images
from sharedvariables import get_video_sessions
from plotutils import plot_images_as_grid

reg_video_sessions = get_video_sessions(should_have_marked_cells=True, should_be_registered=True)
cell_images, non_cell_images, cell_images_marked, non_cell_images_marked =\
create_cell_and_no_cell_patches(
    mixed_channel_patches=True,
    video_sessions=reg_video_sessions,                                                                                                                                                                
)

plot_images_as_grid(cell_images[:10], title='Mixed channel patches')
plot_images_as_grid(cell_images_marked[:10])

# Training Experiments

## Histogram matching

In [None]:
from generate_datasets import create_cell_and_no_cell_patches, create_dataset_from_cell_and_no_cell_images
from imageprosessing import hist_match_images
from sharedvariables import get_video_sessions
from plotutils import plot_images_as_grid

reg_video_sessions = get_video_sessions(should_have_marked_cells=True, should_be_registered=True)
cell_images, non_cell_images, cell_images_marked, non_cell_images_marked =\
create_cell_and_no_cell_patches(
    video_sessions=reg_video_sessions,                                                                                                                                                                
)

In [None]:
from imageprosessing import hist_match_images

def get_highest_contrast_frame(video_sessions):
    max_diff = 0
    max_diff_idx = 0
    for i, vs in enumerate(video_sessions):
        vs.mask_frames_oa790 = crop_mask(vs.mask_frames_oa790, 15)
        diff = vs.masked_frames_oa790[0].max() - vs.masked_frames_oa790[0].min()
        if diff > max_diff:
            max_diff_idx = i
            max_diff = diff
            
    highest_contrast_frame = video_sessions[max_diff_idx].masked_frames_oa790[0]
    highest_contrast_frame = highest_contrast_frame.filled(highest_contrast_frame.mean())
    
    return highest_contrast_frame

template_frame = get_highest_contrast_frame(reg_video_sessions)

hist_matched_cell_images = hist_match_images(cell_images, template_frame)
hist_matched_non_cell_images = hist_match_images(non_cell_images, template_frame)

In [None]:
import collections
import torch
from cnnlearning import CNN, train, TrainingTracker 

standardize_dataset = True
trainset, validset = create_dataset_from_cell_and_no_cell_images(hist_matched_cell_images, 
                                                                 hist_matched_non_cell_images,
                                                                 validset_ratio=0.2,
                                                                 standardize=True)

model = CNN(dataset_sample=trainset, output_classes=2).to('cuda')
train_params = collections.OrderedDict(
    optimizer=torch.optim.Adam(model.parameters(), lr=.001, weight_decay=0.01),
    batch_size=256,
    do_early_stop=True,  # Optional default True
    early_stop_patience=40,
    learning_rate_scheduler_patience=20,
    epochs=200,
    shuffle=True,
    # valid_untrunsformed_normals = valid_untrunsformed_normals,
    evaluation_epochs=5,
    trainset=trainset,
    validset=validset,
)
results: TrainingTracker = train(model,
                                 train_params,
                                 criterion=torch.nn.CrossEntropyLoss(),
                                 device='cuda')

In [None]:
from classificationutils import classify_labeled_dataset, classify_images

model = results.recorded_model
model.eval()

_, train_accuracy = classify_labeled_dataset(trainset, model)
_, valid_accuracy = classify_labeled_dataset(validset, model)
positive_accuracy = classify_images(cell_images, model, standardize_dataset=standardize_dataset).sum().item() / len(cell_images)
negative_accuracy = (1 - classify_images(non_cell_images, model, standardize_dataset=standardize_dataset)).sum().item() / len(non_cell_images)

print()
print(f'Model trained on {len(cell_images)} cell patches and {len(non_cell_images)} non cell patches.')
print()
print('Brief evaluation - best validation accuracy model')
print('----------------')
print(f'Epoch:\t', results.recorded_model_epoch)
print('Training accuracy:\t', f'{train_accuracy:.3f}')
print('Validation accuracy:\t', f'{valid_accuracy:.3f}')
print()
print('Positive accuracy:\t', f'{positive_accuracy:.3f}')
print('Negative accuracy:\t', f'{negative_accuracy:.3f}')

train_model = results.recorded_train_model
train_model.eval()

_, train_accuracy = classify_labeled_dataset(trainset, train_model)
_, valid_accuracy = classify_labeled_dataset(validset, train_model)
positive_accuracy = classify_images(cell_images, train_model, standardize_dataset=standardize_dataset).sum().item() / len(cell_images)
negative_accuracy = (1 - classify_images(non_cell_images, train_model, standardize_dataset=standardize_dataset)).sum().item() / len(non_cell_images)

print()
print('Brief evaluation - best training accuracy model')
print('----------------')
print(f'Epoch:\t', results.recorded_train_model_epoch)
print('Training accuracy:\t', f'{train_accuracy:.3f}')
print('Validation accuracy:\t', f'{valid_accuracy:.3f}')
print()
print('Positive accuracy:\t', f'{positive_accuracy:.3f}')
print('Negative accuracy:\t', f'{negative_accuracy:.3f}')