In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from cnnlearning import *
from learning_utils import *
from cell_no_cell import *
import os

# Prefer the GPU whenever PyTorch can see one; the models below are moved onto `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('Using device:', device, end="\n\n")

# Data and artefact locations, all relative to the notebook's working directory.
SHARED_CONES = os.path.join(".", "data", "ConesShare")
SHARED_VIDEOS_PATH = os.path.join(".", "data", "Shared_Videos")
OUTPUT_FOLDER = os.path.join(".", "data", "output")
TRAINED_MODEL_FOLDER = os.path.join(OUTPUT_FOLDER, "trained_models")

In [None]:
def plot_dataset_as_grid(dataset, title=None, batch_size=60000, nrow=50):
    """
    Plots the images of a dataset in a grid, one figure per loaded batch.

    Arguments:
        dataset: A dataset whose items are (image, label, ...) tuples. The images are
                 handed to torchvision.utils.make_grid unchanged, so they are expected
                 to be tensors laid out as CxHxW (assumption - confirm against the
                 dataset class).
        title: Plot title.
        batch_size: Maximum number of images drawn in a single figure.
        nrow: Number of images per grid row.
    """
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False
    )

    for batch in loader:
        images, labels = batch[0], batch[1]

        print("Images:", images.shape)
        print("Labels:", labels.shape)
        grid_img = torchvision.utils.make_grid(images, nrow=nrow)

        plt.figure(num=None, figsize=(70, 50), dpi=80, facecolor='w', edgecolor='k')
        plt.title(title)
        # Always hide grid lines over the image. `plt.grid(b=None)` only toggled the
        # current state, and the `b` keyword is deprecated since Matplotlib 3.5.
        plt.grid(False)
        # make_grid returns CxHxW; imshow expects HxWxC.
        plt.imshow(grid_img.permute(1, 2, 0))
        plt.show()

In [None]:
# Every listing is sorted: os.listdir order is arbitrary, and later cells pair files by
# position (raw_video_filenames[i] with csv_filenames[i], cone images with cone CSVs via zip).
video_dir_listing = sorted(os.listdir(SHARED_VIDEOS_PATH))

video_filenames = [f for f in video_dir_listing if f.endswith('avi') and 'OA790nm' in f]
marked_video_filenames = [os.path.join(SHARED_VIDEOS_PATH, f) for f in video_filenames if 'marked' in f]
raw_video_filenames    = [os.path.join(SHARED_VIDEOS_PATH, f) for f in video_filenames if 'marked' not in f]

csv_filenames = [os.path.join(SHARED_VIDEOS_PATH, f) for f in video_dir_listing if f.endswith('csv') and 'OA790nm' in f]

print("BLOOD CELLS")
print("-----------")
# Each heading now prints its own list (the raw and marked lists were swapped).
print("RAW VIDEOS:")
print(*raw_video_filenames, sep="\n")
print()

print("MARKED VIDEOS:")
print(*marked_video_filenames, sep="\n")
print()

print("CSV FILES:")
print(*csv_filenames, sep="\n")

cone_dir_listing = sorted(os.listdir(SHARED_CONES))
cone_images_filenames = [os.path.join(SHARED_CONES, f) for f in cone_dir_listing if f.endswith('tif')]
cone_csv_filenames = [os.path.join(SHARED_CONES, f) for f in cone_dir_listing if f.endswith('txt')]

print()
print("CONES")
print("-----")
print("TIFF:")
print(*cone_images_filenames, sep="\n")
print()

print("CSV:")
print(*cone_csv_filenames, sep="\n")

In [None]:
from scipy.spatial import Voronoi, voronoi_plot_2d

def get_random_points_in_voronoi_diagram(centroids, x_max=200, y_max=200, rng=None):
    """Sample one random point on each finite ridge of the Voronoi diagram of `centroids`.

    Arguments:
        centroids: (N, 2) array of point coordinates. Column 0 of the sampled points is
                   bounded by `x_max`, column 1 by `y_max`.
        x_max: Inclusive upper bound for column 0 of the returned points (default 200).
        y_max: Inclusive upper bound for column 1 of the returned points (default 200).
        rng: Optional numpy Generator/RandomState for reproducible sampling. When None,
             the global np.random state is used, as before.

    Returns:
        (M, 2) array with one point per ridge whose two end vertices are finite and that
        lies inside [0, x_max] x [0, y_max]. M may be 0.
    """
    vor = Voronoi(centroids, qhull_options='Qbb Qc Qx', incremental=False)
    vor.close()

    # dtype=int keeps the array usable as an index even when there are no ridges.
    edges = np.asarray(vor.ridge_vertices, dtype=int).reshape(-1, 2)

    # One interpolation weight per ridge. Weights are drawn for every ridge, so the
    # global np.random stream is consumed exactly as before when no rng is given.
    if rng is None:
        t = np.random.rand(edges.shape[0])
    else:
        t = rng.random(edges.shape[0])

    # Ridges that extend to infinity carry -1 as a vertex index. Drop them whichever end
    # the -1 is on; indexing vor.vertices with -1 would silently pick the last vertex.
    finite = (edges[:, 0] != -1) & (edges[:, 1] != -1)
    edges = edges[finite]
    t = t[finite, np.newaxis]

    vertices_start = vor.vertices[edges[:, 0]]
    vertices_end = vor.vertices[edges[:, 1]]
    random_vertices = t * vertices_start + (1 - t) * vertices_end

    # Keep only points inside the [0, x_max] x [0, y_max] box (bounds inclusive).
    x, y = random_vertices[:, 0], random_vertices[:, 1]
    inside = (x >= 0) & (x <= x_max) & (y >= 0) & (y <= y_max)
    return random_vertices[inside]

#get_random_points_in_voronoi_diagram(positions)
# pass
# vor.vertices -> coordinates of the Voronoi vertices.
# vor.regions  -> list of lists of ints. Indices of the Voronoi vertices forming each Voronoi region.
#               The indices refer to vor.vertices; -1 marks a vertex at infinity.
# vor.point_region -> list of ints. For each input point vor.points[i], the index into vor.regions
#                     of the region that contains it.
# vor.ridge_vertices -> Indexes vor.vertices. Each entry holds the two vertex indices that form a ridge (edge),
#                  e.g. ridge 5 could be [3, 8], meaning it runs from vor.vertices[3] to vor.vertices[8].
#                  An index of -1 means that end of the ridge is at infinity.

# np.random.rand(len(vor.ridge_vertices))

# Cones

In [None]:
cone_image_size = (33, 33)

# Collect per-image patch arrays and concatenate once at the end (repeated np.append is
# quadratic). The empty float64 seed arrays keep the original dtype and make the result
# well-defined even when no files are found.
cone_patch_list = [np.zeros([0, cone_image_size[0], cone_image_size[1]])]
non_cone_patch_list = [np.zeros([0, cone_image_size[0], cone_image_size[1]])]

# Pair each TIFF with its coordinate file by sorted name; os.listdir order is arbitrary.
for image_filename, csv_filename in zip(sorted(cone_images_filenames), sorted(cone_csv_filenames)):
    image = plt.imread(image_filename)
    # presumably one cone centre per row (2 columns) - TODO confirm the column order
    positions = np.genfromtxt(csv_filename, delimiter=',')

    # Positives: patches centred on the annotated cone positions.
    cone_patch_list.append(extract_patches_at_positions(image,
                                                        positions,
                                                        patch_size=cone_image_size,
                                                        visualize_patches=False,
                                                        padding='valid'))
    # Negatives: patches centred on random points along the Voronoi ridges between cones.
    non_cone_patch_list.append(extract_patches_at_positions(image,
                                                            get_random_points_in_voronoi_diagram(positions),
                                                            patch_size=cone_image_size,
                                                            visualize_patches=False,
                                                            padding='valid'))

cones = np.concatenate(cone_patch_list, axis=0)
non_cones = np.concatenate(non_cone_patch_list, axis=0)

# Rotate every patch by 180 degrees in the image plane (axes 1 and 2).
# NOTE(review): the reason for this rotation is not visible here; confirm it is intended.
# The full image passed to get_frame_probability_map later is not rotated.
cones = np.rot90(cones, 2, axes=(1, 2))
non_cones = np.rot90(non_cones, 2, axes=(1, 2))

In [None]:
fig, axes = plt.subplots(2, 6, figsize=(20, 10))
# Top row: the first six cone patches; bottom row: the first six non-cone patches.
for row, patch_stack in enumerate((cones, non_cones)):
    for col in range(6):
        axes[row, col].imshow(patch_stack[col], cmap='gray')

In [None]:
# Divide by 255 for display - assumes the cone TIFFs hold 8-bit intensities (TODO confirm).
plot_images_as_grid(cones / 255)

In [None]:
# Same display scaling as the cone grid - assumes 8-bit source intensities (TODO confirm).
plot_images_as_grid(non_cones / 255)

In [None]:
# Sanity-check the extracted patch stacks before building the dataset.
for stat_label, stat_value in (
    ("Cones shape", cones.shape),
    ("Cones dtype", cones.dtype),
    ("Non cones shape", non_cones.shape),
    ("Non cones dtype", non_cones.dtype),
):
    print(stat_label, stat_value)

In [None]:
# Balance the classes: take the same number of cone and non-cone patches, bounded by the
# smaller set, so the label vector always has the same length as the stacked images.
n = min(cones.shape[0], non_cones.shape[0])

dataset = LabeledImageDataset(
        np.concatenate((cones[:n, ...],              non_cones[:n, ...]), axis=0),
        # int64 class indices for CrossEntropyLoss (`np.int` was removed in NumPy 1.24).
        np.concatenate((np.ones(n, dtype=np.int64), np.zeros(n, dtype=np.int64)), axis=0)
    )

trainset_size = int(len(dataset) * 0.80)
validset_size = len(dataset) - trainset_size

# Seeded split so the train/validation partition is reproducible across runs.
trainset, validset = torch.utils.data.random_split(
    dataset, (trainset_size, validset_size), generator=torch.Generator().manual_seed(0)
)

model = CNN(convolutional=
            nn.Sequential(
                # 1 input channel: the cone patches are 2-D (N x 33 x 33) crops.
                nn.Conv2d(1, 32, padding=2, kernel_size=5),
                nn.BatchNorm2d(32),
                # NOTE(review): unlike the two blocks below, this block has no ReLU - confirm intended.
                nn.MaxPool2d(kernel_size=(3, 3), stride=2),          # 33x33 -> 16x16

                nn.Conv2d(32, 32, padding=2, kernel_size=5),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=3, padding=1, stride=2),    # 16x16 -> 8x8

                nn.Conv2d(32, 64, padding=2, kernel_size=5),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=3, padding=1, stride=2),    # 8x8 -> 4x4
            ),
            dense=
            nn.Sequential(
                # 64 channels * 4 * 4 = 1024 features for 33x33 input patches.
                nn.Linear(1024, 64),
                nn.BatchNorm1d(64),
                # `nn.ReLU(64)` passed 64 as the `inplace` flag; a plain ReLU is what is meant.
                nn.ReLU(),
                nn.Linear(64, 32),
                nn.BatchNorm1d(32),
                # NOTE(review): no activation between this BatchNorm and the output layer - confirm intended.
                nn.Linear(32, 2),
                # No Softmax: CrossEntropyLoss expects raw logits.
            )).to(device)

params = collections.OrderedDict(
    optimizer=torch.optim.Adam(model.parameters(), lr=.001, weight_decay=5e-4),
    batch_size=1024 * 16,
    do_early_stop=True,  # Optional, default True
    early_stop_patience=80,
    learning_rate_scheduler_patience=100,
    epochs=2000,
    shuffle=True,
    trainset=trainset,
    validset=validset,
)

results = train(model, params, criterion=torch.nn.CrossEntropyLoss(), device=device)

In [None]:
# Load the first cone image as float32 scaled by 1/255; the probability-map cell below uses it.
# NOTE(review): the training patches in `cones` come from plt.imread without this /255 scaling
# (it is only applied when plotting). They were also rotated by 180 degrees. Confirm that
# LabeledImageDataset / get_frame_probability_map compensate; otherwise training and
# inference inputs differ.
sample_cone_image = plt.imread(cone_images_filenames[0]).astype(np.float32) / 255
# Grayscale colormap, matching the patch previews (matplotlib ignores cmap for RGB data).
plt.imshow(sample_cone_image, cmap='gray')
results.save("cone_model")

In [None]:
# Switch the trained cone classifier to inference mode (BatchNorm then uses its running statistics).
model = results.model.eval()

# Pass the patch size the cone network was trained on explicitly. Its dense layer expects
# the 1024 features produced by cone_image_size (33x33) patches, not the 19x19 cell patches.
probability_map = get_frame_probability_map(sample_cone_image,
                                            model,
                                            height=cone_image_size[0],
                                            width=cone_image_size[1])

# Blood Cells

In [None]:
def _read_video_frames(video_path):
    """Read every frame of a video file with OpenCV.

    Arguments:
        video_path: Path to the video file.

    Returns:
        A list with one array per decoded frame (empty if nothing could be read).
    """
    capture = cv2.VideoCapture(video_path)
    frames = []
    try:
        success, frame = capture.read()
        while success:
            frames.append(frame)
            success, frame = capture.read()
    finally:
        # Release the capture handle even if decoding fails part-way.
        capture.release()
    return frames


def _video_key(video_path):
    """Key a video by its file name, cut right after the first occurrence of 'nm'."""
    name = os.path.basename(video_path)
    return name[:name.index('nm')] + "nm"


# The filename lists are already full paths (built with os.path.join), so they are opened
# directly; prefixing SHARED_VIDEOS_PATH again produced a non-existent path. All frames are
# kept, rather than each frame overwriting the previous one.
raw_videos = {_video_key(video): _read_video_frames(video) for video in raw_video_filenames}
marked_videos = {_video_key(video): _read_video_frames(video) for video in marked_video_filenames}

In [None]:
height, width = 19, 19

# Cell / non-cell patches from the first two raw videos, each paired with the CSV at the same index.
(cell_images_1, non_cell_images_1), (cell_images_2, non_cell_images_2) = [
    get_cell_and_no_cell_patches_from_video(raw_video_filenames[i],
                                            csv_filenames[i],
                                            height=height,
                                            width=width,
                                            normalise=True,
                                            )
    for i in range(2)
]

# Stack both videos' patches and cast to float32 for the network.
cell_images = np.concatenate([cell_images_1, cell_images_2]).astype(np.float32)
non_cell_images = np.concatenate([non_cell_images_1, non_cell_images_2]).astype(np.float32)

print(cell_images.shape)
print(non_cell_images.shape)

In [None]:
fig, axes = plt.subplots(2, 6, figsize=(20, 10))
# Top row: the first six cell patches; bottom row: the first six non-cell patches.
for row, patch_stack in enumerate((cell_images, non_cell_images)):
    for col in range(6):
        axes[row, col].imshow(patch_stack[col], cmap='gray')

In [None]:
# Cell patches were extracted with normalise=True, so they are plotted without rescaling.
plot_images_as_grid(cell_images)

In [None]:
# Non-cell patches share the normalise=True extraction above, so no rescaling here either.
plot_images_as_grid(non_cell_images)

In [None]:
# Balance the classes: take the same number of cell and non-cell patches, bounded by the
# smaller set, so the label vector always has the same length as the stacked images.
n = min(cell_images.shape[0], non_cell_images.shape[0])

dataset = LabeledImageDataset(
        np.concatenate((cell_images[:n, ...],        non_cell_images[:n, ...]), axis=0),
        # int64 class indices for CrossEntropyLoss (`np.int` was removed in NumPy 1.24).
        np.concatenate((np.ones(n, dtype=np.int64), np.zeros(n, dtype=np.int64)), axis=0)
    )

trainset_size = int(len(dataset) * 0.80)
validset_size = len(dataset) - trainset_size

# Seeded split so the train/validation partition is reproducible across runs.
trainset, validset = torch.utils.data.random_split(
    dataset, (trainset_size, validset_size), generator=torch.Generator().manual_seed(0)
)

model = CNN(convolutional=
            nn.Sequential(
                # 3 input channels - presumably the colour channels of the video frames; confirm.
                nn.Conv2d(3, 32, padding=2, kernel_size=5),
                nn.BatchNorm2d(32),
                # NOTE(review): unlike the two blocks below, this block has no ReLU - confirm intended.
                nn.MaxPool2d(kernel_size=(3, 3), stride=2),          # 19x19 -> 9x9

                nn.Conv2d(32, 32, padding=2, kernel_size=5),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=3, padding=1, stride=2),    # 9x9 -> 5x5

                nn.Conv2d(32, 64, padding=2, kernel_size=5),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=3, padding=1, stride=2),    # 5x5 -> 3x3
            ),
            dense=
            nn.Sequential(
                # 64 channels * 3 * 3 = 576 features for 19x19 input patches.
                nn.Linear(576, 64),
                nn.BatchNorm1d(64),
                # `nn.ReLU(64)` passed 64 as the `inplace` flag; a plain ReLU is what is meant.
                nn.ReLU(),
                nn.Linear(64, 32),
                nn.BatchNorm1d(32),
                # NOTE(review): no activation between this BatchNorm and the output layer - confirm intended.
                nn.Linear(32, 2),
                # No Softmax: CrossEntropyLoss expects raw logits.
            )).to(device)

params = collections.OrderedDict(
    optimizer=torch.optim.Adam(model.parameters(), lr=.001, weight_decay=5e-4),
    batch_size=1024 * 16,
    do_early_stop=True,  # Optional, default True
    early_stop_patience=80,
    learning_rate_scheduler_patience=100,
    epochs=2000,
    shuffle=True,
    trainset=trainset,
    validset=validset,
)

results = train(model, params, criterion=torch.nn.CrossEntropyLoss(), device=device)

In [None]:
# Persist the trained cell classifier. Presumably this writes cell_trained.pt, the file the
# reload cell below reads - confirm against the results object's save().
results.save("cell_trained")

In [None]:
def classify(images, model, device=None, batch_size=50000):
    """ Classify images.

    Arguments:
        images     -- NxHxWxC or NxHxW array. The images
        model      -- The model to do the prediction; its output is treated as
                      per-class scores of shape (batch, num_classes)
        device     -- Device to run on. Defaults to the device of the model's
                      parameters (CPU if the model has none)
        batch_size -- Number of images passed to the model at once

    Returns:
        A float tensor of N predictions: the argmax class index for each image.
        The model's train/eval mode is restored before returning.
    """
    if len(images.shape) == 3:
        # Add channel dimension if grayscale
        images = images[..., None]

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    loader = torch.utils.data.DataLoader(
        ImageDataset(images),
        batch_size=batch_size,
    )

    predictions = torch.zeros(images.shape[0])
    # Inference must use BatchNorm running statistics and must not update them; switch to
    # eval mode and put the caller's mode back afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            start = 0
            for batch in loader:
                # argmax over logits equals argmax over softmax probabilities.
                pred = torch.argmax(model(batch.to(device)), dim=1)
                predictions[start:start + pred.shape[0]] = pred.cpu()
                start += pred.shape[0]
    finally:
        model.train(was_training)

    return predictions

# Fraction of cell patches predicted as class 1 and of non-cell patches predicted as class 0.
# NOTE(review): both sets include the training samples, so these figures are optimistic.
# `device` is passed explicitly so this also runs on a CPU-only machine.
print("Positive accuracy", classify(cell_images, results.model, device=device).sum() / cell_images.shape[0])
print("Negative accuracy", 1 - classify(non_cell_images, results.model, device=device).sum() / non_cell_images.shape[0])

In [None]:
# Rebuild the cell-classifier architecture and load the saved weights. Keep the module list
# identical to the training cell: nn.Sequential keys its submodules by position, and
# load_state_dict matches parameters by those keys.
model = CNN(convolutional=
            nn.Sequential(
                nn.Conv2d(3, 32, padding=2, kernel_size=5),
                nn.BatchNorm2d(32),
                nn.MaxPool2d(kernel_size=(3, 3), stride=2),          # 19x19 -> 9x9

                nn.Conv2d(32, 32, padding=2, kernel_size=5),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=3, padding=1, stride=2),    # 9x9 -> 5x5

                nn.Conv2d(32, 64, padding=2, kernel_size=5),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=3, padding=1, stride=2),    # 5x5 -> 3x3
            ),
            dense=
            nn.Sequential(
                nn.Linear(576, 64),
                nn.BatchNorm1d(64),
                # `nn.ReLU(64)` passed 64 as the `inplace` flag; ReLU has no parameters, so
                # this does not affect the state-dict keys.
                nn.ReLU(),
                nn.Linear(64, 32),
                nn.BatchNorm1d(32),
                nn.Linear(32, 2),
            )).to(device)

# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
# Assumes results.save("cell_trained") produced cell_trained.pt - confirm.
model.load_state_dict(torch.load('cell_trained.pt', map_location=device))
model = model.eval()

In [None]:
# Decode the first raw video and preview its first frame.
frames = get_frames_from_video(raw_video_filenames[0])
print(frames.dtype)
sample_frame = frames[0]
plt.imshow(sample_frame)

# Same 19x19 patch size the cell classifier was trained on.
# NOTE(review): training patches used normalise=True; confirm get_frame_probability_map
# normalises the frame the same way.
height, width = 19, 19
probability_map = get_frame_probability_map(sample_frame, model, height=height, width=width)

In [None]:
# Binary mask of the locations whose probability-map value exceeds 0.5.
plt.imshow(probability_map > 0.5,cmap="gray")

In [None]:
# Raw display of the same thresholded mask as the image above (inspection output).
probability_map > 0.5

In [None]:
# Display the architecture of the model held by `results` (the trained cell model), not the reloaded copy in `model`.
results.model