# Label and Track Voids
Our first step in gathering the 3D positions of voids is to identify their location in each image of the tilt series, then track their movement across each image.

In [1]:
%matplotlib inline
from matplotlib import pyplot as plt
from rtdefects.drift import compute_drift_from_images
from rtdefects.segmentation.pytorch import PyTorchSemanticSegmenter
from rtdefects.analysis import analyze_defects, convert_to_per_particle, compile_void_tracks
from rtdefects.io import load_file
from collections import defaultdict
from skimage.transform import AffineTransform, warp
from pathlib import Path
from tqdm import tqdm
import trackpy as tp
import pandas as pd
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


TODO: Convert from pixels to nm

## Load the Images
Get the path and frame number of each image

In [2]:
# Collect the path and frame number of every image in the tilt series
images = []
for image in Path('images/').glob('tilt*.png'):
    images.append({
        'path': str(image),
        # File names have the form "tilt<frame>.png": drop the 4-character
        # prefix and the 4-character extension to get the frame number
        'frame': int(image.name[4:-4])
    })
# Sort by frame and renumber the index 0..N-1. Without the reset the index keeps
# the arbitrary glob order, and any later index-aligned operation (such as
# pd.concat(..., axis=1) with a freshly built frame) would pair rows incorrectly.
images = pd.DataFrame(images).sort_values('frame').reset_index(drop=True)
print(f'Loaded {len(images)} from tilt series')

Loaded 10 from tilt series


## Segment them
Use the latest model from the void segmentation approach. The procedure for analyzing a single image is to:

1. Load image from disk into a standard representation: grayscale represented as a floating point between 0-1
2. Convert image into the form needed by a particular model
3. Run segmentation to get the pixels for each void
4. Run analysis to get a summary of the positions, sizes, etc for each void

In [3]:
# Create the segmenter; no arguments are passed, so whatever defaults the class defines are used
segmenter = PyTorchSemanticSegmenter()
# Report the file name of the model in use so the run can be traced back to it
print(f'Loaded the {segmenter.model_path.name} segmentation model.')

Loaded the small_voids_031023.pth segmentation model.


In [4]:
# Segment and analyze every image, in the same order as the rows of `images`
frame_records = []
for image_path in tqdm(images['path']):
    # Load -> convert to the model's input form -> segment
    labeled_img = segmenter.perform_segmentation(
        segmenter.transform_standard_image(load_file(image_path))
    )

    # Summarize the voids and keep the labels alongside the summary
    record = analyze_defects(labeled_img)
    record['labeled_img'] = labeled_img
    frame_records.append(record)
results = pd.DataFrame(frame_records)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:34<00:00,  3.41s/it]


In [5]:
# `images` was sorted by frame, so its index may still carry the original glob
# order, whereas `results` has a fresh 0..N-1 index in sorted-frame order.
# pd.concat(axis=1) aligns on index labels, so join by row position instead to
# guarantee each frame is paired with its own segmentation results.
assert len(images) == len(results), 'Expected exactly one analysis result per image'
images = pd.concat([images.reset_index(drop=True), results.reset_index(drop=True)], axis=1)

We now have the locations and sizes of voids for each frame

## Use FFT-based Drift Correction
The drift between frames in a tilt series is large and FFTs provide a robust way to determine a drift between frames

In [6]:
def convert_labeled_image_to_mask(labeled_img: np.ndarray) -> np.ndarray:
    """Collapse a labeled image along its first axis into a single mask

    An output element is 1.0 where any entry along axis 0 is non-zero
    and 0.0 elsewhere.

    Args:
        labeled_img: Image to be converted. Presumably axis 0 indexes the
            individual voids (TODO confirm against the segmenter output)
    Returns:
        A float-valued mask with the first axis removed
    """
    occupied = labeled_img.any(axis=0)
    return occupied.astype(float)

Get the drift between all pairs of frames

In [7]:
# Number of unordered pairs of frames, N * (N - 1) / 2; one drift is measured per pair
n_pairs = (len(results) - 1) * len(results) // 2

In [8]:
# Design matrix of the least-squares problem: one row per frame pair, one column per frame
drifts_in = np.zeros((n_pairs, len(results)))
# Targets: one row per frame pair, two columns for the components of the measured drift
drifts_out = np.zeros((n_pairs, 2))

In [9]:
# Convert every labeled image to a mask once up front; doing it inside the
# inner loop repeats the same conversion for every pair a frame belongs to
masks = [convert_labeled_image_to_mask(labeled) for labeled in results['labeled_img']]

# Visit each unordered pair (i, j) with j < i and record its measured drift
pos = 0
for i, image_i in enumerate(masks):
    for j, image_j in enumerate(masks[:i]):
        drift = compute_drift_from_images(image_i, image_j)

        # Row `pos` encodes the equation: (offset of frame i) - (offset of frame j) = drift
        drifts_in[pos, i] = 1
        drifts_in[pos, j] = -1
        drifts_out[pos, :] = drift
        pos += 1

Estimate the drift using least squares

In [10]:
# Solve for one offset per frame from the pairwise differences. Every row of
# `drifts_in` holds one +1 and one -1, so the system only constrains differences
# between frames (adding a constant to every offset changes nothing) and is
# rank-deficient; lstsq resolves the ambiguity by returning the minimum-norm solution.
drifts, _, _, _ = np.linalg.lstsq(drifts_in, drifts_out, rcond=None)

Plot them in a progressive series

In [None]:
fig, ax = plt.subplots(figsize=(3.5, 3.5))

for labeled, shift in zip(results['labeled_img'], drifts):
    # Flatten the labels into one mask, then translate it by this frame's fitted drift
    mask = convert_labeled_image_to_mask(labeled)
    shifted = warp(mask, AffineTransform(translation=shift))

    # Start from a white RGB canvas and subtract the mask from the first two
    # channels, which leaves the masked pixels blue
    overlay = np.full((*mask.shape, 3), 255, dtype=np.uint8)
    overlay[:, :, :2] -= (shifted * 255).astype(np.uint8)[:, :, None]

    # Draw semi-transparently so successive frames accumulate on the same axes
    ax.imshow(overlay, cmap='Blues', alpha=0.4)

# Remove the ticks from both axes
ax.set_xticks([])
ax.set_yticks([])

Adjust the positions using the computed drift

In [None]:
# Row k of `drifts` belongs to the k-th frame in ascending frame order, because
# the segmentation results were produced by iterating over the images sorted by
# frame. Look the drift up by that ordering rather than by `frame - 1`, which is
# only correct when frames are numbered 1..N without gaps (a frame numbered 0
# would silently wrap around to the last row).
drift_by_frame = dict(zip(sorted(images['frame']), drifts))
images['positions-no-drift'] = images.apply(
    lambda x: np.subtract(x['positions'], drift_by_frame[x['frame']]), axis=1
)

## Run the Particle Tracking
We use [trackpy](https://soft-matter.github.io/trackpy/dev/), which expects each row in the dataframe to be a particle rather than a frame

In [None]:
# Expand each frame into per-void rows, using the drift-corrected positions
per_frame_particles = list(convert_to_per_particle(images, position_col='positions-no-drift'))

# Stack all frames and keep only rows whose `touches_side` flag is false
particles = pd.concat(per_frame_particles).query('not touches_side')
particles.head(5)

Run the tracking, using a wide search range for the drift of a single void and a short memory (2 frames) for voids being lost between frames.

Rationale: We are only looking for a few easy-to-track particles to use when determining the tilt axis

In [None]:
# Link voids between frames; each row gains a "particle" column holding its track ID
tracks = tp.link_df(particles, search_range=20, memory=2)

# Compare the number of distinct tracks against the number of labelled voids
n_unique = tracks['particle'].nunique()
n_labelled = len(particles)
print(f'Found a total of {n_unique} unique particles out of {n_labelled} labelled.')

The output is the void in each frame assigned with a global ID, "particle"

In [None]:
# Show the linked table, including the "particle" ID assigned to each void
tracks

We'll next produce a summary in which each row groups all observations of the same particle

In [None]:
# Collapse the per-observation table into one row per tracked void
void_tracks = compile_void_tracks(tracks)

# Preview the five voids that were followed through the most frames
(
    void_tracks
    .sort_values(by='total_frames', ascending=False)
    .head(5)
)

Plot the tracks for the voids

In [None]:
fig, ax = plt.subplots(figsize=(3, 3))

# Remove the ticks from both axes
ax.set_xticks([])
ax.set_yticks([])

# Draw one line per void that was tracked through more than 6 frames
long_tracks = void_tracks[void_tracks['total_frames'] > 6]
for path_xy in long_tracks['positions']:
    ax.plot(path_xy[:, 0], path_xy[:, 1])

fig.tight_layout()

## Save for later use
Let's save a few things separately.

First, the data for each frame

In [None]:
# Drop the label arrays before serializing, then write one JSON record per line.
# NOTE(review): this saves `results`, which does not include columns added only to
# `images` (e.g. 'positions-no-drift') - confirm that is the intended table.
results.drop(columns=['labeled_img']).to_json('frame-data.json', orient='records', lines=True)

Then the summary of void tracks, in full detail

In [None]:
# Write every column of the per-void summary, one JSON record per line
void_tracks.to_json('track-data.json', orient='records', lines=True)

Now a CSV of the 2D coordinates of voids that are tracked across many frames

In [None]:
# Accumulate one list per output column, then build the table in one step
coord_columns = defaultdict(list)
long_lived = void_tracks.query('total_frames >= 8')
for track_id, track in long_lived.iterrows():
    first_frame = track['start_frame']
    for offset, (x, y) in enumerate(track['positions']):
        coord_columns['id'].append(track_id)
        # Assumes positions are stored for consecutive frames starting at
        # `start_frame` - TODO confirm for tracks that skip a frame
        coord_columns['frame'].append(offset + first_frame)
        coord_columns['x'].append(x)
        coord_columns['y'].append(y)
    # Radii are added once per track, in the same order as the positions
    coord_columns['r'].extend(track['radii'])
tracked_coords = pd.DataFrame(coord_columns)

In [None]:
# Write the per-observation coordinates without the DataFrame index column
tracked_coords.to_csv('void-2d-coordinates.csv', index=False)