In [57]:
# Inference
import sys
sys.path.insert(1, 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/preprocessing')
sys.path.insert(1, 'H:/Projects/Kaggle/CZII-CryoET-Object-Identification/postprocessing')
import load
import augment
import os
import torch
import numpy as np
from monai.networks.nets import UNet
import metrics

In [2]:
root = load.get_root()

picks = load.get_picks_dict(root)

runs = os.listdir('H:/Projects/Kaggle/CZII-CryoET-Object-Identification/data/train/static/ExperimentRuns')
run = 'TS_6_4'

In [3]:
vol, coords, scales = load.get_run_volume_picks(root, run=run, level=0)
mask = load.get_picks_mask(vol.shape, picks, coords, int(scales[0]))
points = load.get_picks_mask(vol.shape, picks, coords, int(scales[0]), pts=True)


In [4]:
params = augment.aug_params
params["final_size"] = (104,104,104)
params["flip_prob"] = 0.0
params["patch_size"] = (104,104,104)
params["rot_prob"] = 0.0

In [37]:
from monai.networks.nets import UNet
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(np.unique(points))

sample = augment.random_augmentation(vol, 
                            mask,
                            points,
                            num_samples=1, 
                            aug_params=params,
                            save=False)
src = sample[0]["source"].unsqueeze(0).unsqueeze(0).to(device)
tgt = sample[0]["target"]
pts = sample[0]["points"]
print(src.shape)

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=7,
    channels=(64, 128, 256, 512),
    strides=(2, 2, 2),
    num_res_units=2,
    dropout=0.1,
).to(device)
model.load_state_dict(torch.load("UNet_v1-2.pth"))


model.eval()
pred = model(src).to('cpu')
pred_viewable = pred.argmax(1).squeeze()
src = src.to('cpu').squeeze()
prediction = torch.softmax(pred.squeeze(), dim=0)
print(prediction.shape)
pts.unique()

[0 1 2 3 4 5 6]
torch.Size([1, 1, 104, 104, 104])
torch.Size([7, 104, 104, 104])


metatensor([0., 4.])

In [69]:
def extract_points_from_target(target: torch.Tensor):
    target = np.array(target, dtype=np.int16)
    
    label_points = {label: [] for label in range(1, 7)}  # Initialize a dict for each label

    points = np.array(target.nonzero(), dtype=np.int16).T
    for point in points:
        label_points[target[tuple(point)]].append(list(point))
    return label_points

In [68]:
def extract_points_from_prediction(prediction, threshold = 0.):
    prediction = prediction.detach().numpy()
    
    labels = range(1,7)
    
    label_points = {label: [] for label in labels}  # Initialize a dict for each label

    for label in labels:
        channel = prediction[label]
        points = np.array(np.where(channel > threshold)).T
        for p in range(points.shape[0]):
            label_points[label].append(list(points[p]))

    return label_points

In [82]:
ref_pts = extract_points_from_target(pts.to(torch.int16))
cand_pts = extract_points_from_prediction(prediction=prediction, threshold=0.99999997)
print(ref_pts)
print(cand_pts)

{1: [], 2: [], 3: [], 4: [[34, 75, 94], [92, 73, 97]], 5: [], 6: []}
{1: [[7, 100, 90], [7, 100, 91], [8, 99, 90], [8, 99, 91], [8, 99, 92], [8, 100, 89], [8, 100, 90], [8, 100, 91], [8, 100, 92], [8, 101, 90], [8, 101, 91], [9, 99, 91], [9, 100, 90], [9, 100, 91], [9, 100, 92], [9, 101, 90], [9, 101, 91]], 2: [], 3: [[25, 96, 36], [26, 95, 36], [26, 96, 34], [26, 96, 35], [26, 96, 36], [26, 97, 36], [26, 98, 36], [27, 94, 36], [27, 95, 36], [27, 96, 34], [27, 96, 35], [27, 96, 36], [27, 96, 37], [27, 97, 36], [27, 98, 36], [28, 93, 36], [28, 94, 34], [28, 94, 35], [28, 94, 36], [28, 94, 37], [28, 95, 34], [28, 95, 35], [28, 95, 36], [28, 95, 37], [28, 96, 33], [28, 96, 34], [28, 96, 35], [28, 96, 36], [28, 96, 37], [28, 96, 38], [28, 97, 34], [28, 97, 35], [28, 97, 36], [28, 97, 37], [28, 98, 34], [28, 98, 35], [28, 98, 36], [28, 98, 37], [29, 94, 36], [29, 95, 36], [29, 96, 34], [29, 96, 35], [29, 96, 36], [29, 96, 37], [30, 94, 36], [30, 95, 36], [30, 96, 35], [30, 96, 36]], 4: [], 

In [86]:
from scipy.spatial import KDTree
def compute_metrics(reference_points, reference_radius, candidate_points):
    num_reference_particles = len(reference_points)
    num_candidate_particles = len(candidate_points)

    if len(reference_points) == 0:
        return 0, num_candidate_particles, 0

    if len(candidate_points) == 0:
        return 0, 0, num_reference_particles

    ref_tree = KDTree(reference_points)
    candidate_tree = KDTree(candidate_points)
    raw_matches = candidate_tree.query_ball_tree(ref_tree, r=reference_radius)
    matches_within_threshold = []
    for match in raw_matches:
        matches_within_threshold.extend(match)
    # Prevent submitting multiple matches per particle.
    # This won't be be strictly correct in the (extremely rare) case where true particles
    # are very close to each other.
    matches_within_threshold = set(matches_within_threshold)
    tp = int(len(matches_within_threshold))
    fp = int(num_candidate_particles - tp)
    fn = int(num_reference_particles - tp)
    return tp, fp, fn

In [71]:
radii = { 1:60,
          2:65,
          3:90,
          4:150,
          5:130,
          6:135 }

In [88]:
for label in range(1,7):
    print(radii[label])
    print(compute_metrics(ref_pts[label], radii[label] / 10, cand_pts[label]))

60
(0, 17, 0)
65
(0, 0, 0)
90
(0, 48, 0)
150
(0, 0, 2)
130
(0, 3, 0)
135
(0, 0, 0)


In [8]:
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact


print(f'# Particles Types Represented: {len(np.unique(tgt)) - 1}')
print(f'# Particles Types Predicted: {len(np.unique(prediction)) - 1}')



def plot_cross_section(i):
    vol1 = np.zeros(pts.shape)
    vol2 = pts

        
    plt.figure(figsize=(15, 5))
    alpha1 = 0.3
    alpha2 = 0.3
    

    # Slice at x-coordinate
    plt.subplot(131)
    plt.imshow(vol1[i, :, :], cmap="viridis", alpha=alpha1)
    plt.imshow(vol2[i, :, :], cmap="Reds", alpha=alpha2)  # Overlay mask with transparency
    plt.title(f'Slice at x={i}')

    # Slice at y-coordinate
    plt.subplot(132)
    plt.imshow(vol1[:, i, :], cmap="viridis", alpha=alpha1)
    plt.imshow(vol2[:, i, :], cmap="Reds", alpha=alpha2)
    plt.title(f'Slice at y={i}')

    # Slice at z-coordinate
    plt.subplot(133)
    plt.imshow(vol1[:, :, i], cmap="viridis", alpha=alpha1)
    plt.imshow(vol2[:, :, i], cmap="Reds", alpha=alpha2)
    plt.title(f'Slice at z={i}')

    plt.show()

# Interactive Slider for scrolling through slices
interact(plot_cross_section, i=(0, prediction.shape[0] - 1))

# Particles Types Represented: 4
# Particles Types Predicted: 6


interactive(children=(IntSlider(value=51, description='i', max=103), Output()), _dom_classes=('widget-interact…

<function __main__.plot_cross_section(i)>