In [None]:
%load_ext autoreload
%autoreload 2
import collections
import os
import uuid
    
import matplotlib.pyplot as plt
import torch
from patch_extraction import NegativeExtractionMode
from patch_extraction import SessionPatchExtractor as PE
from patch_extraction import get_parallel_points, get_perpendicular_points
from video_session import get_video_sessions

from cnnlearning import CNN
from cnnlearning import TrainingTracker, train
from generate_datasets import create_cell_and_no_cell_patches, create_dataset_from_patches
from plotutils import no_ticks
from plotutils import plot_images_as_grid
from IPython.display import  display

training_video_sessions = get_video_sessions(marked=True, registered=True, validation=False)

# Blank lines set the CUDA report apart in the cell output.
print()
print('Using cuda:', torch.cuda.is_available())
print()

print('Training Videos')
display([session.video_file for session in training_video_sessions])

print('Validation Videos')
validation_video_sessions = get_video_sessions(marked=True, registered=True, validation=True)
[session.video_file for session in validation_video_sessions]

In [None]:
# Identifier embedded in the output directory name and in the results metadata.
notebook_uid = 6

# --- Hyperparameters ---

# Patch size used for the final center crop; extraction uses patch_size + translation_pixels.
patch_size = 23
# Random translation applied when building the datasets (0 disables it).
translation_pixels = 0
# How negative (non-cell) patch locations are chosen relative to marked cells.
negative_extraction_mode = NegativeExtractionMode.PERPENDICULAR
# Search radius passed to the negative-patch extraction and the point preview.
negative_search_radius = 7
# Negatives per positive (recorded as "negatives_per_positive" in the metadata).
npp = 4

to_grayscale = False
# Any value > 0 switches the output tag to the temporal-channel ('tc') variant.
temporal_width = 0
# Passed to patch extraction as mixed_channel_patches.
mixed_channel = False
# Passed to CNN(model_type=...).
model_type = 0
# NOTE(review): only used for naming/metadata in this notebook.
do_preprocessing = False
# Passed to patch extraction as limit_to_vessel_mask.
limit_extraction_to_vessel_mask = False

In [None]:
# Preview where negative (non-cell) samples would be placed around the marked
# cells of the first marked frame of the first training session.
vs = training_video_sessions[0]
frame_idx = next(iter(vs.cell_positions))
points = vs.cell_positions[frame_idx]

# Only the parallel/perpendicular modes have a point generator to preview. Other
# modes are reported instead of plotting undefined (or stale) x, y values.
negative_point_generators = {
    NegativeExtractionMode.PARALLEL: get_parallel_points,
    NegativeExtractionMode.PERPENDICULAR: get_perpendicular_points,
}

plt.figure(figsize=(10, 10))
no_ticks()
plt.scatter(points[:, 0], points[:, 1], label='cell positions')

point_generator = negative_point_generators.get(negative_extraction_mode)
if point_generator is not None:
    x, y = point_generator(points, negative_search_radius, npp)
    plt.scatter(x, y, label='negative sample positions')
else:
    print(f'No negative-point preview available for {negative_extraction_mode}')

plt.title(f'Frame {frame_idx}: cells vs. negatives ({negative_extraction_mode})')
plt.legend()
plt.show()

In [None]:
# Short tag describing the patch channel layout; used in the output path.
sc_mc = 'sc'
if mixed_channel:
    # The channel count is only known after patch extraction (a later cell), so it
    # cannot be read from `cell_images` here without failing on a fresh kernel.
    # TODO: append the channel count once it can be derived before extraction.
    sc_mc = 'mc'
if temporal_width > 0:
    sc_mc = 'tc'

# Short tag for the negative extraction mode. An unsupported mode fails loudly
# instead of leaving `negative_extraction_str` undefined (or stale from a prior run).
negative_extraction_strs = {
    NegativeExtractionMode.CIRCLE: 'circ',
    NegativeExtractionMode.RECTANGLE: 'rect',
    NegativeExtractionMode.PERPENDICULAR: 'perp',
    NegativeExtractionMode.PARALLEL: 'par',
}
if negative_extraction_mode not in negative_extraction_strs:
    raise ValueError(f'Unsupported negative extraction mode: {negative_extraction_mode}')
negative_extraction_str = negative_extraction_strs[negative_extraction_mode]

use_vessel_mask_str = '-uv-' if limit_extraction_to_vessel_mask else ''
do_preprocessing_str = '-pr-' if do_preprocessing else ''

output_meta = {
    "uuid": str(uuid.uuid4()),
    "notebook_uid": notebook_uid,
    "hyperparameters": {
        "patch_size": patch_size,
        "translation_pixels": translation_pixels,
        "negative_extraction_str": negative_extraction_str,
        "negatives_per_positive": npp,
        "negative_extraction_mode": negative_extraction_mode,
        "negative_search_radius": negative_search_radius,
        "temporal_width": temporal_width,
        "to_grayscale": to_grayscale,
        # Recorded because it changes the channel tag in the output path.
        "mixed_channel": mixed_channel,
        "model_type": model_type,
        "do_preprocessing": do_preprocessing,
        "limit_extraction_to_vessel_mask": limit_extraction_to_vessel_mask,
    }
}

output_path = os.path.join(
    'tmp-res',
    f'_uid{notebook_uid}-{sc_mc}-npp{npp}-tp{translation_pixels}'
    f'-ps{patch_size}-mt{model_type}-rad{negative_search_radius}'
    f'{do_preprocessing_str}{use_vessel_mask_str}-{negative_extraction_str}'
)
print('The results will be saved in: \n', output_path)
display(output_meta)

### Visually check registration on training videos

In [None]:
# Load each training session's vessel masks and plot its registration for a visual check.
registration_plot_kwargs = dict(figsize=(15, 10), fontsize=10, linewidth=5)
for session in training_video_sessions:
    session.load_vessel_masks(True)
    session.visualize_registration(**registration_plot_kwargs)

### Visually check registration on validation videos

In [None]:
# Load each validation session's vessel masks and plot its registration for a visual check.
registration_plot_kwargs = dict(figsize=(15, 10), fontsize=10, linewidth=5)
for session in validation_video_sessions:
    session.load_vessel_masks(True)
    session.visualize_registration(**registration_plot_kwargs)

In [None]:
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['axes.titlesize'] = 10

# For every training session: print the frame-stack shapes, then show the first
# frame of each channel next to the marked oa790 frame with its cell positions.
for session in training_video_sessions:
    channel_frames = [
        ('oa790', session.frames_oa790),
        ('oa850', session.frames_oa850),
        ('confocal', session.frames_confocal),
    ]
    for _, frames in channel_frames:
        print(frames.shape)

    _, axes = plt.subplots(1, 4, figsize=(25, 7))
    no_ticks(axes)
    for ax, (channel_name, frames) in zip(axes, channel_frames):
        ax.imshow(frames[0])
        ax.set_title(channel_name)

    first_marked_frame_idx = list(session.cell_positions)[0]
    axes[3].imshow(session.marked_frames_oa790[0])
    cell_positions = session.cell_positions[first_marked_frame_idx]
    axes[3].scatter(cell_positions[..., 0], cell_positions[..., 1])
    axes[3].set_title('marked oa790')
    plt.show()
    print('-----------------------')

In [None]:
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['axes.titlesize'] = 10

# For every validation session: print the frame-stack shapes, then show the first
# frame of each channel next to the marked oa790 frame with its cell positions.
for session in validation_video_sessions:
    channel_frames = [
        ('oa790', session.frames_oa790),
        ('oa850', session.frames_oa850),
        ('confocal', session.frames_confocal),
    ]
    for _, frames in channel_frames:
        print(frames.shape)

    _, axes = plt.subplots(1, 4, figsize=(40, 10))
    no_ticks(axes)
    for ax, (channel_name, frames) in zip(axes, channel_frames):
        ax.imshow(frames[0])
        ax.set_title(channel_name)

    first_marked_frame_idx = list(session.cell_positions)[0]
    axes[3].imshow(session.marked_frames_oa790[0])
    cell_positions = session.cell_positions[first_marked_frame_idx]
    axes[3].scatter(cell_positions[..., 0], cell_positions[..., 1])
    axes[3].set_title('marked oa790')
    plt.show()
    print('-----------------------')

In [None]:
# Extraction settings shared by the training and validation sessions; only the
# sessions and the extraction mode differ between the two calls below.
common_extraction_kwargs = dict(
    limit_to_vessel_mask=limit_extraction_to_vessel_mask,
    mixed_channel_patches=mixed_channel,
    temporal_width=temporal_width,
    negative_extraction_mode=negative_extraction_mode,
    negative_patch_search_radius=negative_search_radius,
    n_negatives_per_positive=npp,
    # Presumably oversized so the random translation applied when building the
    # datasets can be center-cropped back to `patch_size`.
    patch_size=patch_size + translation_pixels,
    v=False,
    vv=False,
)

(
    cell_images,
    non_cell_images,
    cell_images_marked,
    non_cell_images_marked,
) = create_cell_and_no_cell_patches(
    video_sessions=training_video_sessions,
    extraction_mode=PE.ALL_MODE,
    **common_extraction_kwargs,
)

(
    valid_cell_images,
    valid_non_cell_images,
    valid_cell_images_marked,
    valid_non_cell_images_marked,
) = create_cell_and_no_cell_patches(
    video_sessions=validation_video_sessions,
    extraction_mode=PE.VALIDATION_MODE,
    **common_extraction_kwargs,
)

# Preview the first 10 patches of every extracted set.
preview_sets = (
    cell_images,
    cell_images_marked,
    non_cell_images,
    non_cell_images_marked,
    valid_cell_images,
    valid_cell_images_marked,
    valid_non_cell_images,
    valid_non_cell_images_marked,
)
for patch_set in preview_sets:
    plot_images_as_grid(patch_set[:10])

In [None]:
# Class balance (negatives per positive) and patch-array shapes for each split.
balance_rows = (
    ('Negatives per positive train', cell_images, non_cell_images),
    ('Negatives per positive valid', valid_cell_images, valid_non_cell_images),
)
for row_label, positives, negatives in balance_rows:
    print(row_label, len(negatives) / len(positives))

shape_rows = (
    ('Shape train', cell_images, non_cell_images),
    ('Shape valid', valid_cell_images, valid_non_cell_images),
)
for row_label, positives, negatives in shape_rows:
    print(row_label, positives.shape, negatives.shape)

In [None]:
# Augmentation / normalisation settings shared by the raw and the marked datasets.
# No standardize_mean/standardize_std are passed: no such statistics are computed
# before this cell, so both datasets rely on create_dataset_from_patches' defaults
# (the marked dataset already did).
dataset_kwargs = dict(
    random_translation_pixels=translation_pixels,
    random_rotation_degrees=0,
    center_crop_patch_size=patch_size,
    # Validation patches are supplied explicitly below; presumably this tiny ratio
    # keeps (almost) all training patches in the training set — TODO confirm.
    validset_ratio=0.0000001,
    to_grayscale=to_grayscale,
    standardize=True,
    v=True,
)

trainset, validset = create_dataset_from_patches(
    cell_images,
    non_cell_images,
    valid_cell_patches=valid_cell_images,
    valid_non_cell_patches=valid_non_cell_images,
    **dataset_kwargs,
)

trainset_marked, validset_marked = create_dataset_from_patches(
    cell_images_marked,
    non_cell_images_marked,
    valid_cell_patches=valid_cell_images_marked,
    valid_non_cell_patches=valid_non_cell_images_marked,
    **dataset_kwargs,
)

# One-sample loaders used only for the inspection cells that follow.
loader = torch.utils.data.DataLoader(trainset, batch_size=1)
loader_marked = torch.utils.data.DataLoader(trainset_marked, batch_size=1)

In [None]:
import numpy as np

# Walk the (unmarked) training set one sample at a time, recording each sample's
# min/max and its flattened values for the statistics cell below.
mins, maxs, values = [], [], []
for sample_img, _ in loader:
    mins.append(sample_img.min().item())
    maxs.append(sample_img.max().item())
    values.append(sample_img.flatten().numpy())

np.min(mins), np.max(maxs)

In [None]:
import patch_extraction  # NOTE(review): unused in this cell; consider moving to the imports cell or dropping it

# Value range of the extracted training cell patches, before any dataset transforms.
cell_patch_min, cell_patch_max = cell_images.min(), cell_images.max()
cell_patch_min, cell_patch_max

In [None]:
# Mean and standard deviation over every recorded training pixel value.
pixel_values = np.asarray(values)
pixel_values.mean(), pixel_values.std()

In [None]:
def show_first_sample(data_loader, ax, title):
    """Plot the first channel of the first sample of `data_loader` on `ax` and print its stats.

    Only the first batch is drawn. An empty loader leaves the panel empty (with a
    title saying so) instead of raising StopIteration.
    """
    batch = next(iter(data_loader), None)
    if batch is None:
        ax.set_title(f'{title} (empty)')
        return
    img, lbl = batch
    print('Dataset min, max', img.min().item(), img.max().item())
    print('Dataset shape', img.shape)
    idx = 0
    # Move dim 0 to the end and keep only index 0 of it; presumably the sample is
    # (channels, height, width), so this shows the first channel — TODO confirm.
    ax.imshow(img[idx].permute(1, 2, 0)[..., 0].squeeze(), cmap='gray')
    ax.set_title(title)
    print('Image min max', img[idx].min().item(), img[idx].max().item())
    print('label', lbl[idx].item())


# Side-by-side: first sample of the marked training set vs. the unmarked one.
fig, axes = plt.subplots(1, 2)
fig.suptitle('Frame index 0')
show_first_sample(loader_marked, axes[0], 'marked')
show_first_sample(loader, axes[1], 'unmarked')
plt.show()

In [None]:
# Run the project's CUDA self-test before starting training.
from test_cuda import test_cuda

test_cuda()

In [None]:
# Seed torch so weight initialisation (and shuffling, assuming train() builds its
# loaders with torch's default generator) is repeatable across runs.
training_seed = 0
torch.manual_seed(training_seed)

# Fall back to the CPU when no GPU is present instead of failing on .to('cuda').
# NOTE(review): confirm that train() supports device='cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = CNN(dataset_sample=trainset, model_type=model_type, output_classes=2).to(device)
model.train()
train_params = collections.OrderedDict(
    lr=.001,
    weight_decay=.001,

    batch_size=512,
    do_early_stop=True,  # Optional default True

    early_stop_patience=30,
    learning_rate_scheduler_patience=10,

    epochs=250,
    shuffle=True,
    evaluation_epochs=5,

    trainset=trainset,
    validset=validset,
)

results: TrainingTracker = train(model,
                                 train_params,
                                 criterion=torch.nn.CrossEntropyLoss(),
                                 device=device)

In [None]:
# Re-print class balance and patch shapes next to the training results.
split_patches = {
    'train': (cell_images, non_cell_images),
    'valid': (valid_cell_images, valid_non_cell_images),
}
ratio_labels = {'train': 'Negatives per positive train', 'valid': 'Negatives per positive valid'}
shape_labels = {'train': 'Shape train', 'valid': 'Shape valid'}

for split, (positives, negatives) in split_patches.items():
    print(ratio_labels[split], len(negatives) / len(positives))
for split, (positives, negatives) in split_patches.items():
    print(shape_labels[split], positives.shape, negatives.shape)

In [None]:
# Save the training tracker under `output_path`, then echo the training patch shapes.
results.save(output_path, v=True)

train_patch_shapes = (cell_images.shape, non_cell_images.shape)
print(*train_patch_shapes)

# Load saved training results

In [None]:
print("Output file", output_path)

print('Loading results...')
results_file = os.path.join(output_path, 'results.pkl')
# NOTE(review): a .pkl file is presumably unpickled here — only load trusted outputs.
results = TrainingTracker.from_file(results_file)
print('Done')

# (heading, recorded-model key, classification-results key) for each report.
reports = (
    ('Best balanced validation performance', 'best_valid_balanced_accuracy', 'valid_classification_results'),
    ('Best balanced training performance', 'best_train_balanced_accuracy', 'train_classification_results'),
)
for heading, model_key, results_key in reports:
    print(heading)
    display(results.recorded_models[model_key][results_key])