# Modules

In [100]:
import torch
from torch.utils.data import Dataset, DataLoader

import os
from tqdm import tqdm

import cv2 as cv
import numpy as np

import matplotlib.pyplot as plt

from h5py import File as h5File

In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# GLOBAL PARAMETERS

In [89]:
DEBUG = False

IMG_OUTPUT_DIM = 128

IMAGING_TYPE = 'BACKGROUND_BRIGHT_WORM_DARK'
CROP_SIZE = 400
WINDOWSIZE_MIN_INCLUDED = 120
WINDOWSIZE_MIN_INCLUDED2 = WINDOWSIZE_MIN_INCLUDED//2
THETA_MAX_DEVIATION = 360.0
GAMMA_MAX_DEVIATION = 0.2
# TODO: add scale changing which should change the bounds for x/y_min/max (sice size of included window should change)

if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

## Methods

In [90]:
def data_annotations_to_npz(fp_read_data, fp_read_annotations, fp_write_npz):
    file_data = h5File(fp_read_data)
    file_annotations = h5File(fp_read_annotations)
    
    data = file_data['data']
    nt, nx, ny = data.shape

    t_idxs = file_annotations['t_idx'][:]
    xs = file_annotations['x'][:]
    ys = file_annotations['y'][:]
    
    xs_idx = (xs*nx).astype(np.int64)
    ys_idx = (ys*ny).astype(np.int64)

    images = np.array([
        data[t] for t in t_idxs
    ])
    coords = np.stack((xs_idx, ys_idx)).T
    np.savez_compressed(
        fp_write_npz,
        images = images,
        ts = t_idxs,
        xs_idx = xs_idx, ys_idx = ys_idx,
        coords = coords,
    )
    file_data.close()
    file_annotations.close()
    return images, coords

# Data Classes

In [153]:
def rotate_xy(M, x_idx, y_idx):
    return np.matmul(
        M,
        np.array([x_idx, y_idx, 1])
    )

class AnnotatedDataLoader(Dataset):
    # Constructor
    def __init__(self, images, coordinates, factor_augmentations = 1):
        super().__init__()
        self.images = images
        self.nrecords, self.nx, self.ny = self.images.shape
        self.coordinates = coordinates
        self.factor_augmentations = factor_augmentations
        self.n = self.nrecords * self.factor_augmentations
        return
    # Length
    def __len__(self):
        return self.n
    # Get Item
    def __getitem__(self, i):
        if self.factor_augmentations == 1:
            return self.images[i].copy(), self.coordinates[i].copy()
        # Sample
        nx, ny = self.nx, self.ny
        idx = i//self.factor_augmentations
        img_processed = self.images[idx]
        x_idx, y_idx = self.coordinates[idx]
        # Gamma
        gamma = 1.0 + (np.random.rand()-0.5)*2*GAMMA_MAX_DEVIATION
        img_processed = ( (img_processed/255)**gamma * 255 ).astype(np.uint8)
        # TODO add passing conditions and handle missing annotation inside the frame
        for _ in range(3):
            # Rotation
            theta = np.random.rand()*THETA_MAX_DEVIATION
            M_rotation = cv.getRotationMatrix2D((nx//2, ny//2), theta, 1.0)
            assert IMAGING_TYPE == 'BACKGROUND_BRIGHT_WORM_DARK', \
                "`borderValue` for rotation is set only for dark worm on bright background."
            img_processed = cv.warpAffine(
                img_processed,
                M_rotation,
                (self.nx, self.ny),
                borderValue=255
            )
            # Offset
            coords_new = rotate_xy(M_rotation, x_idx, y_idx).astype(np.int32)
            # Annotation Outside after Rotation
            if np.any(coords_new >= np.array([nx, ny])) or np.any(coords_new < np.zeros(2)):
                if DEBUG:
                    print("Rotation@annotation outside image: {}-{} , shape:({},{})".format(
                        *coords_new,
                        nx, ny
                    ))
                continue
            if DEBUG:
                print(f"COORD NEW: {coords_new}")
            x_idx_new, y_idx_new = coords_new
            x_idx_min = max(
                x_idx_new + WINDOWSIZE_MIN_INCLUDED2 - CROP_SIZE,
                0
            )
            x_idx_max = min(
                x_idx_new - WINDOWSIZE_MIN_INCLUDED2,
                nx - CROP_SIZE
            )
            y_idx_min = max(
                y_idx_new + WINDOWSIZE_MIN_INCLUDED2 - CROP_SIZE,
                0
            )
            y_idx_max = min(
                y_idx_new - WINDOWSIZE_MIN_INCLUDED2,
                ny - CROP_SIZE
            )
            # Empty Cropping Area
            if x_idx_max < x_idx_min or y_idx_max < y_idx_min:
                if DEBUG:
                    print("Empty crop area: {}-{} , {}-{}".format(
                        x_idx_min, x_idx_max,
                        y_idx_min, y_idx_max
                    ))
                continue
            if DEBUG:
                print(f"{x_idx_min}-{x_idx_max} , {y_idx_min}-{y_idx_max}")
            crop_topleft = np.random.randint(
                (x_idx_min, y_idx_min),
                (x_idx_max+1,y_idx_max+1)
            )  # so `x` can be used for index `i` in slicing and `y` for column indexing
            if DEBUG:
                print(f"{crop_topleft}")
            # Apply
            imin, jmin = crop_topleft
            imax, jmax = crop_topleft+CROP_SIZE
            img_processed = img_processed[imin:imax, jmin:jmax]
            coords_new -= crop_topleft[::-1]
            # Return
            return img_processed, coords_new
        print("DEBUG: Attempts to Augment failed!")
        return self.images[idx].copy(), self.coordinates[idx].copy()
def collate_fn_3d_input(data):
    images, coords = zip(*data)
    coords = np.array(coords)
    images = np.repeat(
        np.array(images)[:,None,:,:], 3, axis=1
    )
    images_channeled = torch.tensor( images, dtype=torch.float32 )
    coords = torch.tensor( coords, dtype=torch.float32 )
    return images_channeled, coords
def collate_fn_heatmap(data):
    images, coords = zip(*data)
    _img = images[0]
    images = np.array(images)[:,None,:,:]
    images = torch.tensor( images, dtype=torch.float32 )
    heatmaps = []
    for j,i in coords:
        img_annotated = np.zeros_like( _img, dtype=np.float32 )
        img_annotated[i,j] = 100.0
        img_annotated = cv.GaussianBlur(img_annotated,(11,11), 0)
        img_annotated = cv.resize(img_annotated, (IMG_OUTPUT_DIM, IMG_OUTPUT_DIM))
        img_annotated /= img_annotated.max()
        heatmaps.append(img_annotated.flatten())
    heatmaps = torch.tensor( heatmaps, dtype=torch.float32 )
    return images, heatmaps

# Convert Data

In [154]:
fp_data = "../data/20230414_N2_pharyngeal_and_pumping_000003.h5"
fp_annotations = "../data/20230414_N2_pharyngeal_and_pumping_000003_annotations.h5"
fp_labelled_data = "../data/labelled_data.npz"

In [155]:
if os.path.exists(fp_labelled_data):
    with np.load(fp_labelled_data) as in_file:
        images = in_file['images']
        coords = in_file['coords']
else:
    images, coords = data_annotations_to_npz(fp_data, fp_annotations, fp_labelled_data)

# Data Loader

In [156]:
dataset = AnnotatedDataLoader(
    images, coords,
    factor_augmentations = 1
)

In [157]:
dataloader = DataLoader(
    dataset,
    batch_size=32, shuffle=True,
    collate_fn=collate_fn_heatmap
)

# Model

In [158]:
class Model(nn.Module):
    def __init__(self, input_image_shape = (400, 400)):
        super().__init__()
        self.input_image_shape = input_image_shape
        self.input_nx, self.input_ny = self.input_image_shape
        # Convolutions
        self.conv1_nchannels = 1
        self.conv1_nconvs = 4
        self.conv1_convsize = 25
        self.conv1 = nn.Conv2d(
            self.conv1_nchannels,
            self.conv1_nconvs,
            self.conv1_convsize
        )
        self.conv1_activation = nn.ReLU()
        self.conv1_npooling = 5
        self.conv1_poolingstride = 2
        self.conv1_pooling = nn.MaxPool2d(
            self.conv1_npooling,
            stride=self.conv1_poolingstride
        )
        # TODO: add max_pooling layers
        # https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
        # Flatten
        self.flatten = nn.Flatten()
        # Denses
        self.linear1 = nn.Linear(
            in_features=234256,  # TODO calculate this based on parameters above, e.g. self.conv1_convsize, ...
            out_features=64
        )
        self.linear1_activation = nn.ReLU()
        self.dense = nn.Linear(
            in_features=64,
            out_features=IMG_OUTPUT_DIM**2
        )
        self.to_probability = nn.Sigmoid()
        return

    def forward(self, x):
        # Convolutions
        x = self.conv1_activation(self.conv1(x))
        x = self.conv1_pooling(x)
        # Flattern
        x = self.flatten(x)
        # Dense
        x = self.linear1_activation(
            self.linear1(x)
        )
        x = self.dense(x)
        return self.to_probability(x)

# Training

In [159]:
model = Model().to(DEVICE)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
logs = []

In [161]:
epochs = tqdm( range(1), desc=f'Loss: {0.0:>7.3f}' )
for i_epoch in epochs:
    losses_epoch = []
    steps = tqdm(dataloader, desc=f'Epoch Steps - Loss: {0.0:>7.3f}')
    for x_train, y_train in steps:
        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        y_train_pred = model(x_train)

        # Compute the loss and its gradients
        loss = loss_fn(y_train_pred, y_train)
        loss.backward()
        loss_value = loss.cpu().item()

        # Adjust learning weights
        optimizer.step()
        
        # Log
        losses_epoch.append(loss_value)
        steps.set_description(
            'Epoch Steps - Loss: {:>7.3f}'.format(loss_value)
        )
    logs.append([
        np.mean(losses_epoch),
        losses_epoch.copy()
    ])
    # Report
    epochs.set_description(
        'Loss: {:>7.3f}'.format( logs[-1][0] )
    )

Loss:   0.000:   0%|                                                                                                                                                                  | 0/1 [00:00<?, ?it/s]
Epoch Steps - Loss:   0.000:   0%|                                                                                                                                                    | 0/7 [00:00<?, ?it/s][A
Epoch Steps - Loss:  20.546:   0%|                                                                                                                                                    | 0/7 [00:03<?, ?it/s][A
Epoch Steps - Loss:  20.546:  14%|████████████████████                                                                                                                        | 1/7 [00:03<00:22,  3.73s/it][A
Epoch Steps - Loss:  21.752:  14%|████████████████████                                                                                                                     