In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import tifffile as tiff


In [None]:
def points_to_homogeneous(points):
    ''' Converts 2D points to homogeneous coordinates by appending w = 1.

    Args:
        points (array-like): Nx2 array (or nested list) of 2D points; N may be 0.
    Returns:
        array: Nx3 array of homogeneous coordinates. The dtype is the promotion
            of the input dtype with float32 (float64 input stays float64).
    Raises:
        ValueError: if ``points`` is not two-dimensional with exactly 2 columns.
    '''
    points = np.asarray(points)
    if points.ndim != 2 or points.shape[1] != 2:
        raise ValueError(f'expected an Nx2 array of points, got shape {points.shape}')
    N = points.shape[0]
    ones = np.ones((N, 1), dtype=np.float32)
    points_homogeneous = np.concatenate([points, ones], axis=1)
    assert(points_homogeneous.shape == (N, 3))
    return points_homogeneous

def get_Ai(xi_vector, xi_prime_vector):
    ''' Returns the A_i matrix discussed in the lecture for one correspondence.

    The two rows encode the cross-product constraint x_i' x (H x_i) = 0 of the
    direct linear transform (DLT). The third row of that constraint,
    [-y' x_i^T, x' x_i^T, 0^T], equals -(x'/w') row_0 - (y'/w') row_1, so it is
    redundant for w' != 0 and is omitted.

    Args:
        xi_vector (array): the x_i vector (source point) in homogeneous coordinates, shape (3,)
        xi_prime_vector (array): the x_i' vector (target point) in homogeneous coordinates, shape (3,)
    Returns:
        array: 2x9 matrix A_i with A_i @ H.flatten() == 0 (row-major H) for an
            exact correspondence x_i' ~ H x_i.
    '''
    assert(xi_vector.shape == (3,) and xi_prime_vector.shape == (3,))
    zero_vector = np.zeros((3,), dtype=np.float32)
    # components of the *target* point x_i'
    x_p, y_p, w_p = xi_prime_vector

    Ai = np.stack([
        np.concatenate([zero_vector, -w_p*xi_vector, y_p*xi_vector]),
        np.concatenate([w_p*xi_vector, zero_vector, -x_p*xi_vector]),
    ])
    assert(Ai.shape == (2, 9))
    return Ai

def get_A(points_source, points_target):
    ''' Stacks the per-correspondence DLT matrices A_i into the A matrix from the lecture.

    Args:
        points_source (array): Nx3 homogeneous points from the source image
        points_target (array): Nx3 homogeneous points from the target image
    Returns:
        array: (2N)x9 matrix; rows 2k and 2k+1 belong to correspondence k.
    '''
    n_points = points_source.shape[0]
    blocks = []
    for src_pt, tgt_pt in zip(points_source, points_target):
        blocks.append(get_Ai(src_pt, tgt_pt))
    A = np.concatenate(blocks)
    assert(A.shape == (2*n_points, 9))
    return A

def _dlt_normalisation(points):
    ''' Hartley conditioning for the DLT.

    Translates the (dehomogenised) points so their centroid is at the origin
    and scales them so their mean distance from the origin is sqrt(2).

    Args:
        points (array): Nx3 homogeneous points.
    Returns:
        tuple: (Nx3 normalised homogeneous points with w = 1, 3x3 matrix T that
            maps the original points onto the normalised ones).
    '''
    xy = points[:, :2] / points[:, 2:3]
    centroid = xy.mean(axis=0)
    mean_dist = np.linalg.norm(xy - centroid, axis=1).mean()
    scale = np.sqrt(2) / mean_dist if mean_dist > 0 else 1.0
    T = np.array([
        [scale, 0.0, -scale * centroid[0]],
        [0.0, scale, -scale * centroid[1]],
        [0.0, 0.0, 1.0],
    ])
    normalised = (T @ np.column_stack([xy, np.ones(len(xy))]).T).T
    return normalised, T


def get_homography(points_source, points_target):
    ''' Returns the homography H mapping source points onto target points.

    Uses the normalised direct linear transform. Both point sets are first
    conditioned with _dlt_normalisation. H is then estimated on the conditioned
    points as the right singular vector of A with the smallest singular value,
    and the conditioning is undone: H = inv(T_target) @ H_norm @ T_source.
    Without conditioning, pixel-scale coordinates (hundreds of pixels) make A
    badly conditioned and the SVD solution noise-sensitive.

    Args:
        points_source (array): Nx3 homogeneous points from source image
        points_target (array): Nx3 homogeneous points from target image
    Returns:
        array: 3x3 homography scaled so H[2, 2] == 1, with
            points_target ~ H @ points_source (up to scale).
    Raises:
        ValueError: if fewer than 4 correspondences are given (H has 8 degrees
            of freedom and each correspondence contributes 2 equations).
    '''
    n_points = points_source.shape[0]
    if n_points < 4:
        raise ValueError(f'a homography needs at least 4 correspondences, got {n_points}')

    src_norm, T_source = _dlt_normalisation(points_source)
    tgt_norm, T_target = _dlt_normalisation(points_target)

    A = get_A(src_norm, tgt_norm)
    _, _, vh = np.linalg.svd(A)
    H_norm = vh[-1].reshape(3, 3)  # null-space (least-squares) solution of A h = 0

    H = np.linalg.inv(T_target) @ H_norm @ T_source
    H = H / H[2, 2]
    return H

def get_affine(points_source, points_target):
    ''' Returns the affine transformation matrix M computed using least squares.

    Each correspondence (x, y) -> (x', y') contributes two linear equations in
    the six affine parameters:
        x' = a*x + b*y + c
        y' = d*x + e*y + f
    These are stacked into a (2N)x6 system and solved in the least-squares sense.

    Args:
        points_source (array): Nx3 homogeneous points from source image
        points_target (array): Nx3 homogeneous points from target image
    Returns:
        array: 3x3 float32 matrix [[a, b, c], [d, e, f], [0, 0, 1]] mapping
            source points onto target points.
    Raises:
        AssertionError: if the inputs are not both Nx3 with the same N.
        ValueError: if fewer than 3 correspondences are given. The affine has 6
            unknowns, so the system would be underdetermined and lstsq would
            silently return an arbitrary minimum-norm solution.
    '''
    N = points_source.shape[0]
    assert(points_source.shape == (N, 3))
    assert(points_target.shape == (N, 3))
    if N < 3:
        raise ValueError(f'an affine transform needs at least 3 correspondences, got {N}')

    # Dehomogenise (a no-op for points from points_to_homogeneous, where w == 1).
    # Solve in float64: pixel coordinates in the hundreds lose precision in float32.
    src = np.asarray(points_source[:, :2], dtype=np.float64) / points_source[:, 2:3]
    tgt = np.asarray(points_target[:, :2], dtype=np.float64) / points_target[:, 2:3]

    # Row 2i holds the x' equation and row 2i+1 the y' equation of correspondence i.
    A = np.zeros((2 * N, 6), dtype=np.float64)
    A[0::2, 0:2] = src
    A[0::2, 2] = 1.0
    A[1::2, 3:5] = src
    A[1::2, 5] = 1.0
    b = tgt.reshape(-1)   # [x0', y0', x1', y1', ...] matches the row interleaving

    # Least squares solution
    theta, _, _, _ = np.linalg.lstsq(A, b, rcond=None)

    # Construct affine matrix
    M = np.array([
        [theta[0], theta[1], theta[2]],
        [theta[3], theta[4], theta[5]],
        [0,        0,        1       ]
    ], dtype=np.float32)

    return M

In [None]:
# Analysis configuration:
#   mouse    - animal ID; selects the data folder data_proc/jm/<mouse>/
#   sat_perc - percentile used as the upper display limit (saturation) for images
mouse, sat_perc = 'jm065', 99.9

In [None]:
mouse_path = f'data_proc/jm/{mouse}/'

# Session folders are the sub-directories of mouse_path whose names END with '_a'.
# Sorted by name; later cells assume this order is chronological (TODO confirm).
subfolders = sorted(
    entry.path
    for entry in os.scandir(mouse_path)
    if entry.is_dir() and entry.name.endswith('_a')
)
print(subfolders)

In [None]:
# For every session, load the 1100 nm FOV T-series and keep its frame-averaged image.

wl = '1100nm'


def _pick_first(paths, what, parent):
    ''' Return the first of `paths` in sorted order.

    os.scandir yields entries in arbitrary order, so sorting makes the choice
    deterministic. Raises FileNotFoundError if `paths` is empty and prints a
    warning if there is more than one candidate.
    '''
    paths = sorted(paths)
    if not paths:
        raise FileNotFoundError(f'no {what} found in {parent}')
    if len(paths) > 1:
        print(f'WARNING: {len(paths)} {what} candidates in {parent}; using {paths[0]}')
    return paths[0]


fov_tseries_1100 = []

for p in subfolders:
    print(p)
    wl_path = os.path.join(p, 'fov', wl)
    # the T-series lives in a folder starting with 'TSeries' (there should only be one)
    tseries_folders = [f.path for f in os.scandir(wl_path) if f.is_dir() and f.name.startswith('TSeries')]
    tseries_folder = _pick_first(tseries_folders, 'TSeries folder', wl_path)

    # inside it, pick the '.tif' file
    tif_files = [f.path for f in os.scandir(tseries_folder) if f.is_file() and f.name.endswith('.tif')]
    tif_file = _pick_first(tif_files, '.tif file', tseries_folder)

    print(tif_file)

    # read the tiff stack, average across frames (axis 0) and keep the mean image
    # TODO: motion correction! (just use the suite2p script, probably the easiest)
    tiff_data = tiff.imread(tif_file)
    tiff_mean = tiff_data.mean(axis=0)
    fov_tseries_1100.append(tiff_mean)
    

In [None]:
# Collect the suite2p mean image (ops['meanImg'], 920 nm imaging) of every session.
fov_suite2p_920 = []
for p in subfolders:
    print(p)
    ops_file = os.path.join(p, 'suite2p', 'plane0', 'ops.npy')
    # ops.npy holds a pickled dict (hence allow_pickle + .item()); only load trusted files.
    ops = np.load(ops_file, allow_pickle=True).item()
    fov_suite2p_920.append(ops['meanImg'])

In [None]:
# Plot the mean 1100 nm FOV of every session side by side, saturated at the sat_perc percentile.
import matplotlib.pyplot as plt

ages = ['P8', 'P9', 'P10', 'P11', 'P12', 'P13']  # age label per session, assumed in sorted-subfolder order

# squeeze=False always returns a 2D array of axes, so a single session needs no special case.
fig, axs = plt.subplots(1, len(fov_tseries_1100), figsize=(30, 5), dpi=500, squeeze=False)
for i, (ax, img) in enumerate(zip(axs[0], fov_tseries_1100)):
    img_sat = np.clip(img, 0, np.percentile(img, sat_perc))
    ax.imshow(img_sat, cmap='gray')
    # fall back to the session index when there are more sessions than age labels
    ax.set_title(ages[i] if i < len(ages) else f'session {i}')
    ax.axis('off')

# set suptitle
fig.suptitle(f'Mouse {mouse} FOVs at 1100nm', fontsize=16);


In [None]:
# Plot the suite2p mean image (920 nm) of every session side by side, saturated at the sat_perc percentile.
import matplotlib.pyplot as plt

ages = ['P8', 'P9', 'P10', 'P11', 'P12', 'P13']  # age label per session, assumed in sorted-subfolder order

# squeeze=False always returns a 2D array of axes, so a single session needs no special case.
fig, axs = plt.subplots(1, len(fov_suite2p_920), figsize=(30, 5), dpi=500, squeeze=False)
for i, (ax, img) in enumerate(zip(axs[0], fov_suite2p_920)):
    img_sat = np.clip(img, 0, np.percentile(img, sat_perc))
    ax.imshow(img_sat, cmap='gray')
    # fall back to the session index when there are more sessions than age labels
    ax.set_title(ages[i] if i < len(ages) else f'session {i}')
    ax.axis('off')

# these are the suite2p 920 nm mean images, not the 1100 nm T-series
fig.suptitle(f'Mouse {mouse} FOVs at 920nm (suite2p mean)', fontsize=16);

In [None]:
# Channel aliases used by the registration / segmentation cells:
#   "r" -> 1100 nm T-series mean images, "g" -> suite2p 920 nm mean images.
all_imgs_r, all_imgs_g = fov_tseries_1100, fov_suite2p_920

In [None]:
import napari


In [None]:
# Optionally open napari to mark corresponding keypoints by hand.
# Every image gets its own, initially empty, Points layer; the layer names are
# what the save cell uses to build the CSV file names.
man_curate = False

if man_curate:
    viewer = napari.Viewer()  # a single viewer window

    for i, img_r in enumerate(all_imgs_r):
        viewer.add_image(img_r, name=f'Red channel image {i}', contrast_limits=[0, np.percentile(img_r, sat_perc)])
        viewer.add_points(name=f'Red channel keypoints {i}')

    for i, img_g in enumerate(all_imgs_g):
        # same percentile saturation as the red images, for consistent display
        viewer.add_image(img_g, name=f'Green channel image {i}', contrast_limits=[0, np.percentile(img_g, sat_perc)])
        viewer.add_points(name=f'Green channel keypoints {i}')

    napari.run()

In [None]:
# Save every keypoint layer to '<layer name>_keypoints.csv', one point per line.
# `viewer` only exists when the curation cell ran with man_curate = True; otherwise
# the CSV files saved in an earlier curation session are kept as they are.
if man_curate:
    for layer in viewer.layers:
        if isinstance(layer, napari.layers.Points):
            # convert only for writing; reassigning layer.data is unnecessary
            np.savetxt(f'{layer.name}_keypoints.csv', np.asarray(layer.data), delimiter=',')

In [None]:
# Load the saved red-channel keypoints of every session.
# ndmin=2 keeps each result N x 2 even when a file contains a single point
# (plain loadtxt would return a 1D array of shape (2,), which
# points_to_homogeneous cannot handle).
from numpy import loadtxt

all_keypoints_r = []
all_keypoints_g = []  # TODO: green-channel keypoints are not loaded yet

for i in range(len(all_imgs_r)):
    keypoints_r = loadtxt(f'Red channel keypoints {i}_keypoints.csv', delimiter=',', ndmin=2)
    all_keypoints_r.append(keypoints_r)

In [None]:
np.shape(all_imgs_r[0])  # shape of the first session's mean image

In [None]:
# Keypoints per session as (N, 2); consecutive sessions need equal N because points are paired by index.
[kp.shape for kp in all_keypoints_r]

In [None]:
# Show sessions 0 and 1 side by side and join each pair of corresponding keypoints with a dashed line.
# Keypoints are (row, col), so index 1 is the x (column) coordinate.
kp_0, kp_1 = all_keypoints_r[0], all_keypoints_r[1]
if len(kp_0) != len(kp_1):
    raise ValueError(f'keypoints are paired by index, but sessions 0 and 1 have {len(kp_0)} and {len(kp_1)} points')

x_offset = all_imgs_r[0].shape[1]  # the second image starts one image-width to the right
img_stack = np.concatenate((all_imgs_r[0], all_imgs_r[1]), axis=-1)
fig, ax = plt.subplots(dpi=500)
ax.imshow(img_stack, cmap='gray', vmin=0, vmax=np.percentile(img_stack, sat_perc))
ax.axvline(x=x_offset, color='white', linewidth=1)  # line separating images

for i, (p0, p1) in enumerate(zip(kp_0, kp_1)):
    color = f'C{i}'
    # colour/linestyle passed as keywords: a combined fmt string such as 'C12--'
    # may be mis-parsed by older matplotlib versions once i has two digits
    ax.plot([p0[1], p1[1] + x_offset], [p0[0], p1[0]], color=color, linestyle='--', linewidth=0.5)
    ax.scatter(p0[1], p0[0], edgecolors=color, facecolors='none', linewidth=0.5)  # point in first image
    ax.scatter(p1[1] + x_offset, p1[0], edgecolors=color, facecolors='none', linewidth=0.5)  # point in second image

ax.axis('off');


In [None]:
# Test the registration on the first pair: session 0 is the reference, session 1 is moved onto it.
pts_ref, pts_mov = all_keypoints_r[0], all_keypoints_r[1]
pts_ref_homogeneous, pts_mov_homogeneous = (points_to_homogeneous(p) for p in (pts_ref, pts_mov))

In [None]:
# Estimate the transform mapping moving keypoints onto reference keypoints and apply it to the moving points.
transform_type = 'affine'  # 'affine' or 'homography'

if transform_type == 'homography':
    A = get_homography(pts_mov_homogeneous, pts_ref_homogeneous)
elif transform_type == 'affine':
    A = get_affine(pts_mov_homogeneous, pts_ref_homogeneous)
else:
    # reject typos instead of silently falling back to an affine fit
    raise ValueError(f"unknown transform_type {transform_type!r}; expected 'affine' or 'homography'")

pts_mov_homogeneous_transformed = (A @ pts_mov_homogeneous.T).T
# divide by w to return to (row, col) pixel coordinates (w stays 1 for an affine)
pts_mov_reg = pts_mov_homogeneous_transformed[:, :2] / pts_mov_homogeneous_transformed[:, 2:3]

In [None]:
# Registered moving keypoints, in (row, col) pixel coordinates of the reference image
pts_mov_reg

In [None]:
# Reference, original moving and registered moving keypoints in image orientation
# (x = column, y = row, with the y axis flipped so row 0 is at the top).
n_rows, n_cols = all_imgs_r[0].shape[:2]
fig, ax = plt.subplots(figsize=(8, 8))
for pts, lbl in [(pts_ref, 'Reference points'),
                 (pts_mov, 'Moving points before registration'),
                 (pts_mov_reg, 'Registered moving points')]:
    ax.scatter(pts[:, 1], pts[:, 0], marker='x', label=lbl)
ax.legend()
ax.set_xlim(0, n_cols)
ax.set_ylim(n_rows, 0)

In [None]:
# Warp the moving image into the reference frame and show a two-colour overlay.
from skimage.transform import ProjectiveTransform, AffineTransform, warp

img_ref = all_imgs_r[0]
img_mov = all_imgs_r[1]

# A acts on (row, col, 1) vectors while skimage transforms act on (x, y, 1) = (col, row, 1).
# Conjugating with the row/col swap S converts between the two conventions (S is its own inverse).
S = np.array([
    [0, 1, 0],
    [1, 0, 0],
    [0, 0, 1]
])
A_img = S @ A @ S  # transform matrix in image (x, y) coordinates

transform = AffineTransform(matrix=A_img) if transform_type == 'affine' else ProjectiveTransform(matrix=A_img)

# `transform` maps moving -> reference; warp needs output (reference) -> input (moving), hence .inverse
img_mov_reg = warp(img_mov, inverse_map=transform.inverse, output_shape=img_ref.shape)


def _saturate(img, perc):
    '''Clip at the `perc` percentile and rescale to [0, 1] for RGB display (all zeros if that percentile is <= 0).'''
    vmax = np.percentile(img, perc)
    if vmax <= 0:
        return np.zeros_like(img, dtype=np.float32)
    return np.clip(img, 0, vmax) / vmax


# green = reference, red = registered moving; overlapping structure appears yellow
img_overlay = np.zeros((img_ref.shape[0], img_ref.shape[1], 3), dtype=np.float32)
img_overlay[..., 0] = _saturate(img_mov_reg, sat_perc)  # red channel - registered moving
img_overlay[..., 1] = _saturate(img_ref, sat_perc)      # green channel - reference
plt.figure(dpi=500, figsize=(8, 8))
plt.imshow(img_overlay)
plt.title('Overlay of reference (green) and registered moving (red) images')
plt.axis('off');


In [None]:
from cellpose import models
from tqdm import tqdm

# Segment every mean image with the Cellpose 'cpsam' pretrained model (gpu=True requested).
model = models.CellposeModel(gpu=True, pretrained_model='cpsam')

all_mask0 = []
all_num_labels = []

# Segmentation settings; tune e.g. flow_threshold, cellprob_threshold, diameter.
eval_kwargs = dict(
    diameter=None,
    flow_threshold=0.4,
    cellprob_threshold=0.0,
    resample=True,
    normalize=True,
    # other options you might want to adjust:
    # invert=False, rescale=None, augment=False, tile_overlap=0.1, min_size=15
)

for img in tqdm(all_imgs_r):
    # model.eval gets a one-element list of images, so masks[0] is this image's label mask
    # (0 = background, ROIs labelled from 1, as used by the centroid/area cells).
    masks, flows, styles = model.eval([img], **eval_kwargs)
    mask0 = masks[0]

    num_labels = mask0.max()
    print("Detected", num_labels, "objects")

    all_mask0.append(mask0)
    all_num_labels.append(num_labels)


In [None]:
# Plot each image with the outline of every segmented ROI on top.
def _label_outlines(mask):
    '''Boolean image: True on ROI pixels (label > 0) that have a 4-neighbour with a different label.'''
    edge = np.zeros(mask.shape, dtype=bool)
    vert = mask[:-1, :] != mask[1:, :]
    horiz = mask[:, :-1] != mask[:, 1:]
    edge[:-1, :] |= vert
    edge[1:, :] |= vert
    edge[:, :-1] |= horiz
    edge[:, 1:] |= horiz
    return edge & (mask > 0)


def plot_segmentation_overlay(all_img, all_mask):
    ''' Plot every image (saturated at the global `sat_perc` percentile) with ROI outlines in C0.

    Args:
        all_img (list of 2D arrays): images to display, one panel each.
        all_mask (list of 2D int arrays): label masks matching `all_img`
            (0 = background, ROIs labelled from 1).

    Outlines are computed per pixel from the label image. Calling `ax.contour`
    on the raw label image draws lines only at a few automatically chosen
    label values. An ROI is then outlined only if such a level lies between
    its label and its neighbour's, so many ROIs (e.g. low-numbered ones next
    to background) get no outline while others get several stacked lines.
    '''
    from matplotlib.colors import ListedColormap

    outline_cmap = ListedColormap(['C0'])
    # squeeze=False always returns a 2D array of axes, so one image needs no special case
    fig, axs = plt.subplots(1, len(all_img), figsize=(8*len(all_img), 8), dpi=500, squeeze=False)
    for i, (ax, img, mask) in enumerate(zip(axs[0], all_img, all_mask)):
        ax.imshow(img, cmap='gray', vmin=0, vmax=np.percentile(img, sat_perc))
        outline = _label_outlines(np.asarray(mask))
        # masked pixels are transparent, so only outline pixels are drawn
        ax.imshow(np.ma.masked_where(~outline, outline), cmap=outline_cmap, interpolation='none')
        ax.set_title(f'Image {i} with segmentation contours')
        ax.axis('off')

In [None]:
plot_segmentation_overlay(all_imgs_r, all_mask0)

In [None]:
# Centroid (row, col) of every ROI in every session; used to track ROIs across sessions.
# Row k of all_cent[d] is the centroid of label k + 1 in all_mask0[d].
from scipy.ndimage import center_of_mass
all_cent = []
for mask0 in all_mask0:
    num_labels = int(mask0.max())
    if num_labels == 0:
        centroids = []
    else:
        # one labelled call instead of a full-image pass per label
        centroids = center_of_mass(np.ones(mask0.shape), labels=mask0, index=np.arange(1, num_labels + 1))
    # reshape keeps shape (0, 2) when no ROI was found, so the distance code still broadcasts
    all_cent.append(np.asarray(centroids, dtype=float).reshape(-1, 2))


In [None]:
from scipy.spatial import cKDTree
from scipy.optimize import linear_sum_assignment

In [None]:
# Slope plot of matched ROI areas between two sessions.
def plot_area_comparison(areas_ref, areas_mov, norm=False):
    ''' Reference areas at x = 0, moving areas at x = 1, one grey line per matched cell.

    Args:
        areas_ref (array-like): areas of the matched cells in the reference session (µm²).
        areas_mov (array-like): areas of the same cells in the moving session,
            in the same order as `areas_ref`.
        norm (bool): if True, divide both by the reference area, so every
            reference point sits at 1 and the moving point shows the fold change.
    '''
    areas_ref = np.asarray(areas_ref, dtype=float)
    areas_mov = np.asarray(areas_mov, dtype=float)
    if norm:
        areas_mov = areas_mov / areas_ref
        areas_ref = areas_ref / areas_ref

    plt.figure(figsize=(2, 3), dpi=500)
    plt.scatter(np.zeros_like(areas_ref), areas_ref, s=1, c='C0', label='Reference areas')
    plt.scatter(np.ones_like(areas_mov), areas_mov, s=1, c='C2', label='Moving areas')
    for a_ref, a_mov in zip(areas_ref, areas_mov):
        plt.plot([0, 1], [a_ref, a_mov], 'grey', linewidth=0.5, alpha=0.3)
    plt.xticks([0, 1], ['Reference', 'Moving'])
    plt.ylabel('Normalised area (au)' if norm else 'Area (µm²)')
    plt.title('Areas for matched cells')


In [None]:
# Track ROIs across consecutive sessions. For each pair (i, j = i + 1):
#   1. fit an affine from the manual keypoints of day j onto day i,
#   2. map day-j ROI centroids into day-i coordinates,
#   3. match centroids one-to-one (scipy linear_sum_assignment) within match_threshold,
#   4. record matched ROI indices and areas in per-cell trajectory tables.
match_threshold = 10.0  # pixels
show_plot = False

px_size = 0.7 # um per pixel


def _register_points(A, pts):
    '''Apply the 3x3 transform A to Nx2 (row, col) points and return the Nx2 registered points.'''
    pts_transformed = (A @ points_to_homogeneous(pts).T).T
    return pts_transformed[:, :2] / pts_transformed[:, 2:3]


def _match_centroids(cent_ref, cent_mov_reg, threshold):
    '''One-to-one matching that minimises total centroid distance, keeping only pairs closer than threshold.

    Returns (matched_ref_indices, matched_mov_indices), aligned element-wise.
    '''
    dist_mat = np.linalg.norm(cent_ref[:, np.newaxis, :] - cent_mov_reg[np.newaxis, :, :], axis=2)
    row_ind, col_ind = linear_sum_assignment(dist_mat)
    valid_matches = dist_mat[row_ind, col_ind] < threshold
    return row_ind[valid_matches], col_ind[valid_matches]


def _roi_areas_um2(mask, roi_indices, px_size):
    '''Areas in µm² of the ROIs with 0-based indices roi_indices (mask labels start from 1).'''
    # one bincount over the whole mask instead of one full-image comparison per ROI
    pixel_counts = np.bincount(np.asarray(mask, dtype=np.int64).ravel())
    return pixel_counts[np.asarray(roi_indices, dtype=int) + 1] * (px_size ** 2)


def _plot_pair_diagnostics(cent_ref, cent_mov_reg, matched_ref_indices, matched_mov_indices,
                           areas_ref, areas_mov, img_shape):
    '''Diagnostic figures for one session pair: centroids before/after matching and matched-area statistics.'''
    n_rows, n_cols = img_shape[0], img_shape[1]

    # scatter of reference vs registered moving centroids
    plt.figure(figsize=(8, 8), dpi=500)
    plt.scatter(cent_ref[:, 1], cent_ref[:, 0], marker='x', c='C0', s=10, alpha=0.5, label='Reference centroids')
    plt.scatter(cent_mov_reg[:, 1], cent_mov_reg[:, 0], marker='x', c='C2', s=10, alpha=0.5, label='Registered moving centroids')
    plt.legend()
    plt.xlim(0, n_cols)
    plt.ylim(n_rows, 0)

    # matched centroids opaque, unmatched ones faded (alpha=0.3)
    plt.figure(figsize=(8, 8), dpi=500)
    plt.scatter(cent_ref[:, 1], cent_ref[:, 0], marker='x', c='C0', s=10, alpha=0.3, label='Reference centroids')
    plt.scatter(cent_mov_reg[:, 1], cent_mov_reg[:, 0], marker='x', c='C2', s=10, alpha=0.3, label='Registered moving centroids')
    plt.scatter(cent_ref[matched_ref_indices, 1], cent_ref[matched_ref_indices, 0], c='C0', marker='x', s=10, alpha=1.0, label='Matched reference centroids')
    plt.scatter(cent_mov_reg[matched_mov_indices, 1], cent_mov_reg[matched_mov_indices, 0], c='C2', marker='x', s=10, alpha=1.0, label='Matched registered moving centroids')
    plt.legend()
    plt.xlim(0, n_cols)
    plt.ylim(n_rows, 0)

    if len(areas_ref) == 0:
        # the histogram bin limits below would fail on empty input
        print('No matched cells for this pair; skipping area plots.')
        plt.show()
        return

    fig, axs = plt.subplots(2, 1, figsize=(6, 8), dpi=500)
    # both histograms share the same bins
    bins = np.linspace(min(areas_ref.min(), areas_mov.min()), max(areas_ref.max(), areas_mov.max()), 20)
    axs[0].hist(areas_ref, bins=bins, alpha=0.7, label='Reference areas', color='C0')
    axs[0].hist(areas_mov, bins=bins, alpha=0.7, label='Moving areas', color='C2')
    axs[0].set_title('Histogram of ROI areas (matched cells)')
    axs[0].set_xlabel('Area (µm²)')
    axs[0].set_ylabel('Count')
    axs[0].legend()

    area_diff = areas_mov - areas_ref
    bins = np.linspace(area_diff.min(), area_diff.max(), 20)
    axs[1].hist(area_diff, bins=bins, alpha=0.7, color='C3')
    axs[1].axvline(0, color='k', linestyle='--')
    axs[1].set_title('Histogram of ROI area differences (Moving - Reference)')
    axs[1].set_xlabel('Area difference (µm²)')
    axs[1].set_ylabel('Count')
    plt.tight_layout()
    plt.show()

    plot_area_comparison(areas_ref, areas_mov)
    # now plot normalised to reference
    plot_area_comparison(areas_ref, areas_mov, norm=True)


n_days = len(all_imgs_r)
# traj_indices[c, d]: 0-based ROI index of tracked cell c on day d; traj_areas[c, d]: its area (µm²).
# NaN marks days on which the cell is not tracked. Trajectories start only from cells matched
# between days 0 and 1. Initialised empty so later cells still run when there are < 2 days.
traj_indices = np.full((0, n_days), np.nan)
traj_areas = np.full((0, n_days), np.nan)

for i in range(n_days - 1):
    j = i + 1

    # transform mapping day-j keypoints onto day-i keypoints (always affine here, independent of transform_type)
    pts_ref = all_keypoints_r[i]
    pts_mov = all_keypoints_r[j]
    A = get_affine(points_to_homogeneous(pts_mov), points_to_homogeneous(pts_ref))

    # register day-j centroids into day-i coordinates and match them
    cent_ref = all_cent[i]
    cent_mov = all_cent[j]
    cent_mov_reg = _register_points(A, cent_mov)
    matched_ref_indices, matched_mov_indices = _match_centroids(cent_ref, cent_mov_reg, match_threshold)

    ##### AREAS CALCULATION (µm²) for the matched cells, in match order
    areas_ref = _roi_areas_um2(all_mask0[i], matched_ref_indices, px_size)
    areas_mov = _roi_areas_um2(all_mask0[j], matched_mov_indices, px_size)

    ##### TRAJECTORY ACROSS ALL DAYS CALCULATION
    if i == 0:
        n_match_d0 = len(matched_ref_indices)
        traj_indices = np.full((n_match_d0, n_days), np.nan)
        traj_indices[:, 0] = matched_ref_indices
        traj_indices[:, 1] = matched_mov_indices
        traj_areas = np.full((n_match_d0, n_days), np.nan)
        traj_areas[:, 0] = areas_ref
        traj_areas[:, 1] = areas_mov
    else:
        # Extend a trajectory when its day-i ROI was matched into day j. Each ROI index occurs
        # at most once per column because linear_sum_assignment matches one-to-one.
        for ref_idx, mov_idx, area_mov in zip(matched_ref_indices, matched_mov_indices, areas_mov):
            rows = np.flatnonzero(traj_indices[:, i] == ref_idx)
            if rows.size:
                traj_indices[rows[0], j] = mov_idx
                traj_areas[rows[0], j] = area_mov

    ########## PLOTTING
    if show_plot:
        _plot_pair_diagnostics(cent_ref, cent_mov_reg, matched_ref_indices, matched_mov_indices,
                               areas_ref, areas_mov, all_imgs_r[0].shape)

In [None]:
# Overview of the trajectory tables (rows = tracked cells, columns = days; NaN = not tracked, left blank).
plt.matshow(traj_indices, cmap='viridis', aspect='auto')
plt.title('ROI index per day')
plt.colorbar()
plt.matshow(traj_areas, cmap='viridis', aspect='auto')
plt.title('ROI area per day (µm²)')
plt.colorbar();

In [None]:
# x_stack (day index of every tracked-cell observation) is only created in the area-trend
# cell below; referencing it here raises NameError on a fresh kernel. Inspect it there.
traj_areas.shape  # (n_tracked_cells, n_days)

In [None]:
# Area trends for the cells tracked on every day: raw trajectories, a pooled linear fit of
# area vs day index (raw and normalised to day 0), and the normalised trajectories.
from scipy.stats import linregress

rng = np.random.default_rng(0)  # seeded so the horizontal jitter is reproducible

valid_area_mask = ~np.isnan(traj_areas).any(axis=1)
traj_areas_valid = traj_areas[valid_area_mask, :]
n_cells, n_days = traj_areas_valid.shape
print(f'{n_cells} cells tracked across all {n_days} days')


def _plot_trajectories(area_mat, ylabel, title):
    '''One line plus markers per tracked cell (row of area_mat) across day index.'''
    plt.figure(figsize=(4, 8), dpi=300)
    days = np.arange(area_mat.shape[1])
    for k, row in enumerate(area_mat):
        plt.plot(row, alpha=0.5, color=f'C{k}', linewidth=0.5)
        plt.scatter(days, row, s=3, color=f'C{k}')
    plt.xlabel('Day index')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.show()


def _plot_linear_trend(x, y, slope, intercept, r, ylabel, title):
    '''Jittered scatter of the pooled observations with the fitted regression line.'''
    plt.figure(figsize=(4, 8), dpi=300)
    plt.scatter(x + rng.normal(scale=0.05, size=y.shape), y, s=3, alpha=1)
    x_line = np.unique(x)  # draw the line once per day, not once per observation
    plt.plot(x_line, intercept + slope * x_line, 'k--', alpha=0.5, label='Fitted line')
    plt.xlabel('Day index')
    plt.ylabel(ylabel)
    plt.title(f'{title} (R={r:.2f})')
    plt.legend()
    plt.show()


_plot_trajectories(traj_areas_valid, 'Area (µm²)', 'Cell area trajectories for cells tracked across all days')

if n_cells == 0:
    print('No cell was tracked across all days; skipping the trend analysis.')
else:
    # pool all observations: x = day index, y = area
    x_traj = np.arange(n_days)
    x_stack = np.tile(x_traj, (n_cells, 1))
    x = x_stack.flatten()
    y = traj_areas_valid.flatten()

    slope, intercept, r_value, p_value, std_err = linregress(x, y)
    print(f'Linear fit slope: {slope}, intercept: {intercept}, R^2: {r_value**2}, p-value: {p_value}')
    r = np.corrcoef(x, y)[0, 1]
    print(f'Correlation coefficient: {r}')
    _plot_linear_trend(x, y, slope, intercept, r, 'Area (µm²)', 'Area vs Day index for tracked cells')

    # fold change of each cell's area relative to its day-0 area
    norm_areas_mat = traj_areas_valid / traj_areas_valid[:, 0:1]
    y_norm = norm_areas_mat.flatten()

    slope_norm, intercept_norm, r_value_norm, p_value_norm, std_err_norm = linregress(x, y_norm)
    print(f'Normalised Linear fit slope: {slope_norm}, intercept: {intercept_norm}, R^2: {r_value_norm**2}, p-value: {p_value_norm}')
    r_norm = np.corrcoef(x, y_norm)[0, 1]
    print(f'Normalised Correlation coefficient: {r_norm}')
    _plot_linear_trend(x, y_norm, slope_norm, intercept_norm, r_norm, 'Normalised area (au)',
                       'Normalised Area vs Day index for tracked cells')

    # now normalised to first day
    _plot_trajectories(norm_areas_mat, 'Normalised area (au)',
                       'Normalised cell area trajectories for cells tracked across all days')


In [None]:
# Raw area trajectories (line + circle markers) of the cells tracked on every day.
valid_area_mask = ~np.isnan(traj_areas).any(axis=1)
traj_areas_valid = traj_areas[valid_area_mask, :]

plt.figure(figsize=(4, 8))
for cell_areas in traj_areas_valid:
    plt.plot(cell_areas, '-o', alpha=0.5)
plt.xlabel('Day index')
plt.ylabel('Area (µm²)')
plt.title('Cell area trajectories for cells tracked across all days')   