In [1]:
import os
import glob
import json
import numpy as np

import copick
from copick.impl.filesystem import CopickRootFSSpec
import zarr

import ipywidgets as widgets
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import skimage

In [3]:
data_path='../input/czii-cryo-et-object-identification/'
train_path=data_path+'train/'
test_path=data_path+'test/'
experiment_runs=glob.glob(train_path+'static/ExperimentRuns/**')
experiment_runs=[os.path.split(e)[1] for e in experiment_runs]
experiment_runs

['TS_5_4', 'TS_69_2', 'TS_6_4', 'TS_6_6', 'TS_86_3', 'TS_73_6', 'TS_99_9']

In [4]:
config_blob = """{
    "name": "czii_cryoet_mlchallenge_2024",
    "description": "2024 CZII CryoET ML Challenge training data.",
    "version": "1.0.0",

    "pickable_objects": [
        {
            "name": "apo-ferritin",
            "is_particle": true,
            "pdb_id": "4V1W",
            "label": 1,
            "color": [  0, 117, 220, 128],
            "radius": 60,
            "map_threshold": 0.0418
        },
        {
            "name": "beta-amylase",
            "is_particle": true,
            "pdb_id": "1FA2",
            "label": 2,
            "color": [153,  63,   0, 128],
            "radius": 65,
            "map_threshold": 0.035
        },
        {
            "name": "beta-galactosidase",
            "is_particle": true,
            "pdb_id": "6X1Q",
            "label": 3,
            "color": [ 76,   0,  92, 128],
            "radius": 90,
            "map_threshold": 0.0578
        },
        {
            "name": "ribosome",
            "is_particle": true,
            "pdb_id": "6EK0",
            "label": 4,
            "color": [  0,  92,  49, 128],
            "radius": 150,
            "map_threshold": 0.0374
        },
        {
            "name": "thyroglobulin",
            "is_particle": true,
            "pdb_id": "6SCJ",
            "label": 5,
            "color": [ 43, 206,  72, 128],
            "radius": 130,
            "map_threshold": 0.0278
        },
        {
            "name": "virus-like-particle",
            "is_particle": true,
            "pdb_id": "6N4V",            
            "label": 6,
            "color": [255, 204, 153, 128],
            "radius": 135,
            "map_threshold": 0.201
        }
    ],

    "overlay_root": "XXXXXXXXXXtrain/overlay",

    "overlay_fs_args": {
        "auto_mkdir": true
    },

    "static_root": "XXXXXXXXXXtrain/static/"
}"""

config_blob=config_blob.replace('XXXXXXXXXX',data_path)
copick_config=json.loads(config_blob)

copick_config_path = "copick.config"
with open(copick_config_path, "w") as f:
    f.write(config_blob)

In [5]:
root = CopickRootFSSpec.from_file(copick_config_path)

In [6]:
import json

def get_particles(run,scale=10):
    path=train_path+'overlay/ExperimentRuns/'+run+'Picks/'
    path=train_path+'/overlay/ExperimentRuns/'+run+'/Picks/'
    json_files=glob.glob(path+'*.json')
    particles={}
    particles_dict={}
    for j in json_files:
        with open(j,'rt') as jf:
            data=json.load(jf)
            for j in data:
                particles_dict[data['pickable_object_name']]=np.array([[c['location']['z']/scale, c['location']['x']/scale,c['location']['y']/scale] for c in data['points']])
            particles[run]=particles_dict
    return particles

def get_points(run):
    path=train_path+'overlay/ExperimentRuns/'+run+'Picks/'
    path=train_path+'/overlay/ExperimentRuns/'+run+'/Picks/'
    json_files=glob.glob(path+'*.json')
    particles={}
    particles_dict={}
    for j in json_files:
        with open(j,'rt') as jf:
            data=json.load(jf)
            for j in data:
                particles_dict[data['pickable_object_name']]=data['points']
            particles[run]=particles_dict
    return particles

def get_radius(particle_type, scale=1):
    for po in copick_config['pickable_objects']:
        if po['name']==particle_type:
            return po['radius']/scale
    return None

In [7]:
#Annexed from copick source
#Creates 3D spherical annotations based on point data
def from_picks(points, 
               seg_volume,
               radius: float = 10.0, 
               label_value: int = 1,
               voxel_spacing: float = 10):
    """
    Paints picks into a segmentation volume as spheres.

    Parameters:
    -----------
    pick : copick.models.CopickPicks
        Copick object containing `points`, where each point has a `location` attribute with `x`, `y`, `z` coordinates.
    seg_volume : numpy.ndarray
        3D segmentation volume (numpy array) where the spheres are painted. Shape should be (Z, Y, X).
    radius : float, optional
        The radius of the spheres to be inserted in physical units (not voxel units). Default is 10.0.
    label_value : int, optional
        The integer value used to label the sphere regions in the segmentation volume. Default is 1.
    voxel_spacing : float, optional
        The spacing of voxels in the segmentation volume, used to scale the radius of the spheres. Default is 10.

    Returns:
    --------
    numpy.ndarray
        The modified segmentation volume with spheres inserted at pick locations.
    """
        
    def create_sphere(shape, center, radius, val):
        """Creates a 3D sphere within the given shape, centered at the given coordinates."""
        zc, yc, xc = center
        z, y, x = np.indices(shape)
        
        # Compute squared distance from the center
        distance_sq = (x - xc)**2 + (y - yc)**2 + (z - zc)**2
        
        # Create a mask for points within the sphere
        sphere = np.zeros(shape, dtype=np.float32)
        sphere[distance_sq <= radius**2] = val
        return sphere

    def get_relative_target_coordinates(center, delta, shape):
        """
        Calculate the low and high index bounds for placing a sphere within a 3D volume, 
        ensuring that the indices are clamped to the valid range of the volume dimensions.
        """

        low = max(int(np.floor(center) - delta), 0)
        high = min(int(np.ceil(center) + delta + 1), shape)

        return low, high

    # Adjust radius for voxel spacing
    radius_voxel = radius / voxel_spacing
    delta = int(np.ceil(radius_voxel))

    # Get volume dimensions
    vol_shape_x, vol_shape_y, vol_shape_z = seg_volume.shape

    # Paint each pick as a sphere
    for pick in points:
        
        # Adjust the pick's location for voxel spacing
        cx, cy, cz = pick['location']['z'] / voxel_spacing, pick['location']['y'] / voxel_spacing, pick['location']['x'] / voxel_spacing

        # Calculate subarray bounds, clamped to the valid volume dimensions
        xLow, xHigh = get_relative_target_coordinates(cx, delta, vol_shape_x)
        yLow, yHigh = get_relative_target_coordinates(cy, delta, vol_shape_y)
        zLow, zHigh = get_relative_target_coordinates(cz, delta, vol_shape_z)

        # Subarray shape
        subarray_shape = (xHigh - xLow, yHigh - yLow, zHigh - zLow)

        # Compute the local center of the sphere within the subarray
        local_center = (cx - xLow, cy - yLow, cz - zLow)

        # Create the sphere
        sphere = create_sphere(subarray_shape, local_center, radius_voxel, label_value)

        # Assign Sphere to Segmentation Target Volume
        seg_volume[xLow:xHigh, yLow:yHigh, zLow:zHigh] = np.maximum(seg_volume[xLow:xHigh, yLow:yHigh, zLow:zHigh], sphere)

    return seg_volume

In [8]:
def get_hex_color(c):#utility function to get hexadecimal HTML color code
    return f'#{c[0]:02x}{c[1]:02x}{c[2]:02x}'

In [9]:
axis='xy'

def process_run(): 
    #extract a tomogram and points from the selected expiriment run 
    global volume, particles, points, run_str, run
    with output:
        run_str=runs_dropdown.value
        run = root.get_run(run_str)
        tomogram = run.get_voxel_spacing(10).get_tomogram("denoised")
        group = zarr.open(tomogram.zarr())
        arrays = list(group.arrays())
        _, volume = arrays[0]
        particles=get_particles(run_str)
        points=get_points(run_str)

def on_run_change(change):
    #called when expiriment run is changed
    global axis
    with output:
        process_run()
        process_volume_and_mask()
        reset_widgets()
        update()

def process_volume_and_mask():
    #create volume from tomogram data, do some minimal processing
    #create 3d mask data from points
    global mask, volume, vol
    with output:#necessary for debugging or else ipwidgets supress all exceptions
        vol = np.clip(volume, np.percentile(volume, 0.5), np.percentile(volume, 99.5))
        vol = (vol - vol.min()) / (vol.max() - vol.min())
        
        mask=np.zeros(vol.shape+(3,), dtype=np.uint8)
        
        for p in copick_config['pickable_objects']:#create colored mask
            particle_type=p['name']
            color=np.array(p['color'],dtype=np.uint8)[0:3].reshape(1,1,1,3)
            m=np.zeros_like(volume, dtype=np.uint8)
            m=from_picks(points=points[run_str][particle_type], 
                           seg_volume=m,
                           radius=get_radius(particle_type), 
                           label_value = 1,
                           voxel_spacing = 10)
            m=np.tile(m[:,:,:,None], (1,1,1,3))
            mask+=m*color
    
        #rotated views
        if axis=='xz':
            print(vol.shape, mask.shape)
            vol=np.transpose(np.swapaxes(vol, 0,2),(0,2,1))
            mask=np.transpose(np.swapaxes(mask, 0,2),(0,2,1,3))
            print(vol.shape, mask.shape)
    
        if axis=='zy':
            vol=np.swapaxes(vol, 0,1)
            mask=np.swapaxes(mask, 0,1)

def process_image():
    with output:
        n=slice_slider.value
        image=vol[n]
        image=image[:,:,None]
        image=np.tile(image,(1,1,3))
        
        a=alpha_slider.value
        msk=mask[n]/255.0
        image=image*(msk>0)*(1-a)+msk*a+image*(msk==0)#needed to overlay colorful mask over BW image without clipping
        #image=image.clip(0,1)
    return image
    
def update():#this function is called when slice number is changed
    global image
    with output:
        image=process_image()
        redraw_plot()

def reset_widgets():#this function is called when slice slider max value needs to be adjusted
    slice_slider.max=len(vol)-1


def on_value_change(change):#This is the ipwidgets callback
    update()

def prev_slide(instance):
    slice_slider.value=max(slice_slider.value-1,0)
    update()

def next_slide(instance):
    slice_slider.value=min(slice_slider.value+1,slice_slider.max-1)
    update()

def on_view_change(change):
    global axis
    with output:
        axis=view_radiobutton.value
        process_volume_and_mask()
        reset_widgets()
        update()

def redraw_plot():#redraw the image
    global fig, imshow, image
    with output:
        clear_output(True)
        fig=plt.figure(figsize=(10,10))
        imshow=plt.imshow(image)
        plt.xticks([]), plt.yticks([])
        plt.axis("off")
        plt.subplots_adjust(hspace=0, wspace=0)
        plt.show()

In [10]:
#slice number
slice_slider=widgets.IntSlider(
    value=0,
    min=0,
    max=1,
    step=1,
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d')
slice_slider.observe(on_value_change, names='value')

buttonPrev = widgets.Button(
    description='Previous',
    disabled=False,
    button_style='', 
    tooltip='Previous',)
buttonPrev.on_click(prev_slide)

buttonNext = widgets.Button(
    description='Next',
    disabled=False,
    button_style='', 
    tooltip='Next',)
buttonNext.on_click(next_slide)

#mask transparency
alpha_slider=widgets.FloatSlider(
    value=0.5,
    min=0.0,
    max=1.0,
    step=0.05,
    description='Transparency',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',)
alpha_slider.observe(on_value_change, names='value')

#Horizontal radiobuttons requires some mad html slillz
_style = widgets.HTML(
    "<style>.widget-radio-box {flex-direction: row !important;}.widget-radio-box"
    " label{margin:5px !important;width: 120px !important;}</style>",
    layout=widgets.Layout(display="none"),)
view_radiobutton=widgets.RadioButtons(
    options=['xy', 'xz', 'zy'],
    description='View:',
    disabled=False,)
h_radiobutton=widgets.HBox([view_radiobutton, _style])
view_radiobutton.observe(on_view_change, names='value')

#labels to show color codings for particles
labels=[]
for (n,c) in [(chr(0x2B24)+p['name'],get_hex_color(p['color'][0:3])) for p in copick_config['pickable_objects']]:
    l=widgets.Label(value=n, layout=widgets.Layout(height='50%'))
    l.style.text_color=c
    l.style.height=1
    labels.append(l)

labels=widgets.HBox(labels)

#to select experiment run
runs_dropdown=widgets.Dropdown(options=[(er,er) for er in experiment_runs],description='Run:',)
runs_dropdown.observe(on_run_change, names='value')


In [None]:
output = widgets.Output()
process_run()
process_volume_and_mask()
image=process_image()
update()
reset_widgets()

HBox1=widgets.HBox([buttonPrev, slice_slider,buttonNext])
HBox2=widgets.HBox([h_radiobutton, alpha_slider])
Box=widgets.VBox([runs_dropdown, HBox2, labels,HBox1])
display(Box, output)