In [1]:
import os
import glob
import time
import sys
import warnings
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import torch
from tqdm import tqdm
from ultralytics import YOLO
import zarr
from scipy.spatial import cKDTree
from collections import defaultdict

In [2]:
import lightning.pytorch as pl
from czii_helper import *
from dataset import *
from model2 import *
from datetime import datetime
import pytz
import sys
from typing import List, Tuple, Union
import zarr


In [10]:
class Model(pl.LightningModule):
    """LightningModule that wraps a MONAI ``UNet`` for volumetric segmentation.

    Every constructor argument is recorded with ``save_hyperparameters`` (and so
    is reachable through ``self.hparams``), then forwarded unchanged to ``UNet``.
    ``out_channels=7`` presumably means background plus six particle classes —
    confirm against the training setup.
    """

    def __init__(
        self, 
        spatial_dims: int = 3,
        in_channels: int = 1,
        out_channels: int = 7,
        channels: Union[Tuple[int, ...], List[int]] = (48, 64, 80, 80),
        strides: Union[Tuple[int, ...], List[int]] = (2, 2, 1),
        num_res_units: int = 1,
    ):
        super().__init__()
        self.save_hyperparameters()
        hp = self.hparams
        # The UNet is the only submodule, so checkpoint keys live under ``model.``.
        self.model = UNet(
            spatial_dims=hp.spatial_dims,
            in_channels=hp.in_channels,
            out_channels=hp.out_channels,
            channels=hp.channels,
            strides=hp.strides,
            num_res_units=hp.num_res_units,
        )

    def forward(self, x):
        """Return the raw (pre-softmax) UNet output for ``x``."""
        return self.model(x)

# UNet architecture settings (identical to the ``Model`` constructor defaults).
channels, strides_pattern, num_res_units = (48, 64, 80, 80), (2, 2, 1), 1
def extract_3d_patches_minimal_overlap(arrays: List[np.ndarray], patch_size: int) -> Tuple[List[np.ndarray], List[Tuple[int, int, int]]]:
    """Tile one or more same-shaped 3D arrays into cubic patches with minimal overlap.

    Start offsets along each axis come from ``calculate_patch_starts``, so the
    patches cover the whole volume and overlap only as much as needed.

    Args:
        arrays: Non-empty list of 3D arrays, all with the same shape.
        patch_size: Edge length of the cubic patches. It must be positive and no
            larger than the smallest dimension of the arrays.

    Returns:
        ``(patches, coordinates)``:
        - ``patches``: slices of the inputs, ordered array -> x -> y -> z.
        - ``coordinates``: the ``(x, y, z)`` start of each patch. Coordinates
          are repeated for every input array, so both lists stay aligned.

    Raises:
        ValueError: on an empty or non-list input, mismatched shapes, non-3D
            arrays, or an invalid ``patch_size``.
    """
    if not arrays or not isinstance(arrays, list):
        raise ValueError("Input must be a non-empty list of arrays")

    # Verify all arrays have the same shape
    shape = arrays[0].shape
    if not all(arr.shape == shape for arr in arrays):
        raise ValueError("All input arrays must have the same shape")
    if len(shape) != 3:
        raise ValueError(f"Input arrays must be 3D, got shape {shape}")
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    if patch_size > min(shape):
        # Equality is allowed: a patch may span a whole axis.
        raise ValueError(f"patch_size ({patch_size}) must not exceed the smallest dimension {min(shape)}")

    m, n, l = shape

    # Start offsets depend only on the shared shape, so compute them once.
    x_starts = calculate_patch_starts(m, patch_size)
    y_starts = calculate_patch_starts(n, patch_size)
    z_starts = calculate_patch_starts(l, patch_size)

    patches = []
    coordinates = []
    for arr in arrays:
        for x in x_starts:
            for y in y_starts:
                for z in z_starts:
                    patches.append(arr[x:x + patch_size, y:y + patch_size, z:z + patch_size])
                    coordinates.append((x, y, z))

    return patches, coordinates
def reconstruct_array(patches: List[np.ndarray],
                      coordinates: List[Tuple[int, int, int]],
                      original_shape: Tuple[int, int, int],
                      dtype=np.int64) -> np.ndarray:
    """Paste patches back into a zero-initialised volume of ``original_shape``.

    Where patches overlap, a patch later in ``patches`` overwrites earlier ones;
    there is no averaging or voting. Voxels covered by no patch stay 0.

    Args:
        patches: 3D arrays to paste. Each patch may have its own (not
            necessarily cubic) shape.
        coordinates: ``(x, y, z)`` start index of each patch, aligned with
            ``patches``.
        original_shape: Shape of the output volume.
        dtype: dtype of the output volume. Defaults to ``np.int64``, as used for
            label masks.

    Returns:
        The reconstructed array. It is all zeros when ``patches`` is empty.

    Raises:
        ValueError: if ``patches`` and ``coordinates`` differ in length.
    """
    if len(patches) != len(coordinates):
        # zip() would silently drop the unmatched tail.
        raise ValueError(
            f"got {len(patches)} patches but {len(coordinates)} coordinates"
        )

    reconstructed = np.zeros(original_shape, dtype=dtype)

    for patch, (x, y, z) in zip(patches, coordinates):
        # Use each patch's own extent rather than assuming all match patches[0].
        dx, dy, dz = patch.shape
        reconstructed[x:x + dx, y:y + dy, z:z + dz] = patch

    return reconstructed
def calculate_patch_starts(dimension_size: int, patch_size: int) -> List[int]:
    """Start offsets of ``patch_size``-long windows covering ``dimension_size``.

    Uses the fewest windows that cover the axis,
    ``n = ceil(dimension_size / patch_size)``, spread evenly:
    - the first window starts at 0;
    - the last window ends exactly at ``dimension_size``;
    - the i-th window starts at ``floor(i * (dimension_size - patch_size) / (n - 1))``.

    Exact integer arithmetic is used, so rounding error can never leave the last
    voxels of the axis uncovered.

    Args:
        dimension_size: Length of the axis to cover.
        patch_size: Window length. It must be positive.

    Returns:
        Strictly increasing start offsets. This is ``[0]`` when the axis fits in
        a single window.

    Raises:
        ValueError: if ``patch_size`` is not positive.
    """
    dimension_size = int(dimension_size)
    patch_size = int(patch_size)
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    if dimension_size <= patch_size:
        return [0]

    n_patches = -(-dimension_size // patch_size)  # integer ceil division
    span = dimension_size - patch_size            # range of valid start offsets
    gaps = n_patches - 1                          # >= 1 because dimension_size > patch_size

    positions: List[int] = []
    for i in range(n_patches):
        pos = (i * span) // gaps
        if not positions or positions[-1] != pos:  # starts are non-decreasing; skip repeats
            positions.append(pos)

    return positions
import pandas as pd

def dict_to_df(coord_dict, experiment_name):
    """Flatten ``{particle_type: coords}`` into one DataFrame row per coordinate.

    Args:
        coord_dict: Maps a particle-type label to its coordinates, given as an
            ``(N, >=3)`` array-like whose first three columns are x, y and z.
            A flat 1-D sequence is read as consecutive (x, y, z) triples.
            Empty entries contribute no rows.
        experiment_name: Value written to the ``experiment`` column of every row.

    Returns:
        DataFrame with columns ``experiment, particle_type, x, y, z``. When there
        are no coordinates at all, it is empty but keeps those columns.
    """
    columns = ['experiment', 'particle_type', 'x', 'y', 'z']
    coord_blocks = []
    all_labels = []

    for label, coords in coord_dict.items():
        block = np.asarray(coords)
        if block.ndim == 1:
            # A single point or an empty list: make it (N, 3) so len() counts points.
            block = block.reshape(-1, 3)
        coord_blocks.append(block)
        all_labels.extend([label] * len(block))

    if not coord_blocks:
        # np.vstack([]) raises; "no particles" is a valid result, not an error.
        return pd.DataFrame(columns=columns)

    all_coords = np.vstack(coord_blocks)

    return pd.DataFrame({
        'experiment': experiment_name,
        'particle_type': all_labels,
        'x': all_coords[:, 0],
        'y': all_coords[:, 1],
        'z': all_coords[:, 2],
    })
from typing import List, Tuple, Union
import numpy as np
import torch
from monai.data import DataLoader, Dataset, CacheDataset, decollate_batch
from monai.transforms import (
    Compose, 
    EnsureChannelFirstd, 
    Orientationd,  
    AsDiscrete,  
    RandFlipd, 
    RandRotate90d, 
    NormalizeIntensityd,
    RandCropByLabelClassesd,
)
import json

TRAIN_DATA_DIR = "../input/mask"
copick_config_path = f"{TRAIN_DATA_DIR}/copick.config"

# Read the copick project config and repoint its static root at the
# competition's training "static" directory.
with open(copick_config_path) as config_in:
    copick_config = json.load(config_in)
copick_config['static_root'] = '../input/czii-cryo-et-object-identification/train/static'

# Write the patched config to the working directory and open the project from it.
copick_test_config_path = 'copick_test.config'
with open(copick_test_config_path, 'w') as config_out:
    json.dump(copick_config, config_out)
import copick

root = copick.from_file(copick_test_config_path)

# copick identifiers and tomogram selection.
copick_user_name, copick_segmentation_name = "copickUtils", "paintedPicks"
voxel_size, tomo_type = 10, "denoised"

# Inference-time preprocessing, applied in order:
#   1. add a leading channel axis;
#   2. normalize intensity;
#   3. reorient to RAS.
inference_transforms = Compose(
    [
        EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
        NormalizeIntensityd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
    ]
)
import cc3d

# Label id -> particle name for the six particle classes (label 0 is not listed).
id_to_name = dict(
    enumerate(
        [
            "apo-ferritin",
            "beta-amylase",
            "beta-galactosidase",
            "ribosome",
            "thyroglobulin",
            "virus-like-particle",
        ],
        start=1,
    )
)

# Connected components with this many voxels or fewer are discarded.
BLOB_THRESHOLD = 200
# Probability cut-off passed to the TTA ensemble.
CERTAINTY_THRESHOLD = 0.05

# Label ids to extract particles for, in ascending order.
classes = list(id_to_name)
import torch
import numpy as np
import pandas as pd
import cc3d
from monai.data import CacheDataset
from monai.transforms import Compose, EnsureType
from torch import nn
from tqdm import tqdm
from monai.networks.nets import UNet
from monai.losses import TverskyLoss
from monai.metrics import DiceMetric

def load_models(model_paths, device='cuda'):
    """Load one ``Model`` per checkpoint path, ready for inference.

    Each checkpoint must contain a ``'state_dict'`` for the fixed UNet
    architecture built here. ``load_state_dict`` is strict by default, so any
    mismatch raises an error.

    Args:
        model_paths: Iterable of checkpoint file paths.
        device: Device the loaded models are moved to. Defaults to ``'cuda'``.

    Returns:
        List of models in eval mode, in the same order as ``model_paths``.
    """
    # Must match the architecture stored in the checkpoints.
    channels = (48, 64, 80, 80)
    strides_pattern = (2, 2, 1)
    num_res_units = 1

    models = []
    for model_path in model_paths:
        model = Model(channels=channels, strides=strides_pattern, num_res_units=num_res_units)
        # Load tensors on CPU first, so checkpoints saved on any device can be
        # read; then move to the target device.
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        weights = torch.load(model_path, map_location='cpu')['state_dict']
        model.load_state_dict(weights)
        model.to(device)
        model.eval()
        models.append(model)
    return models


# Checkpoints to run inference with. Each must hold a 'state_dict' compatible
# with the architecture that load_models builds.
# Disabled alternative: '../model/UNet-Model-val_metric0.450.ckpt'
model_paths = ['../model/model_3dunet.ckpt']

models = load_models(model_paths)
def ensemble_prediction_tta(models, input_tensor, threshold=0.5):
    """Segment one patch with every model in ``models`` under flip test-time augmentation.

    Procedure:
    1. Each model sees four views of ``input_tensor``: the original, plus flips
       along dims 2, 3 and 4.
    2. Each output is flipped back.
    3. Each output is turned into class probabilities with a softmax over the
       class axis of the first batch item.
    4. The probabilities are averaged over all (model, view) pairs.

    The label chosen per voxel is the index returned by ``max`` over the boolean
    mask ``avg_probs > threshold``. Under torch's first-maximum tie-breaking,
    that is the lowest class index above ``threshold``, or 0 when none is.

    Args:
        models: Non-empty iterable of callables mapping a batch to logits.
        input_tensor: 5-D batch ``(batch, channel, *spatial)``. Only the first
            batch item of each output is used.
        threshold: Cut-off applied to the averaged class probabilities.

    Returns:
        Integer label tensor with the spatial shape of the model output.

    Raises:
        ValueError: if ``models`` is empty.
    """
    models = list(models)
    if not models:
        raise ValueError("models must contain at least one model")

    # Flip TTA: each view is the set of spatial dims to flip (None = identity).
    # A flip is its own inverse, so the same dims undo it on the output.
    # Only flips are used; a rot90 view would also need its inverse applied to the output.
    flip_views = (None, [2], [3], [4])

    probs_list = []
    with torch.no_grad():
        for net in models:
            for dims in flip_views:
                view = input_tensor if dims is None else torch.flip(input_tensor, dims=dims)
                output = net(view)
                if dims is not None:
                    output = torch.flip(output, dims=dims)
                probs_list.append(torch.softmax(output[0], dim=0))

    avg_probs = torch.mean(torch.stack(probs_list), dim=0)
    thresh_probs = avg_probs > threshold
    _, max_classes = thresh_probs.max(dim=0)
    return max_classes


# Edge length of the cubic tomogram patches fed to the UNet.
PATCH_SIZE = 96
# Scale applied to cc3d centroids (voxel indices) to get output coordinates.
# Presumably the exact voxel spacing in Å of the `voxel_size` tomograms — confirm.
VOXEL_TO_ANGSTROM = 10.012444


def _mask_to_locations(mask):
    """Return ``{particle_name: (N, 3) xyz array}`` of large-blob centroids for each id in ``classes``."""
    location = {}
    for c in classes:
        cc = cc3d.connected_components(mask == c)
        stats = cc3d.statistics(cc)
        # Row 0 of the statistics belongs to the background label, hence [1:].
        zyx = stats['centroids'][1:] * VOXEL_TO_ANGSTROM
        zyx_large = zyx[stats['voxel_counts'][1:] > BLOB_THRESHOLD]
        location[id_to_name[c]] = np.ascontiguousarray(zyx_large[:, ::-1])
    return location


sub = []  # NOTE(review): not used in the visible cells
location_frames = []
with torch.no_grad():
    for run in root.runs:
        # NOTE(review): copick warns that get_tomogram is deprecated (see cell output).
        tomo = run.get_voxel_spacing(voxel_size).get_tomogram(tomo_type).numpy()
        tomo_patches, coordinates = extract_3d_patches_minimal_overlap([tomo], PATCH_SIZE)
        tomo_ds = CacheDataset(
            data=[{"image": img} for img in tomo_patches],
            transform=inference_transforms,
            cache_rate=1.0,
        )
        pred_masks = []
        for i in tqdm(range(len(tomo_ds))):
            input_tensor = tomo_ds[i]['image'].unsqueeze(0).to("cuda")
            # All models are ensembled inside this call, so it runs once per patch.
            max_classes = ensemble_prediction_tta(models, input_tensor, threshold=CERTAINTY_THRESHOLD)
            pred_masks.append(max_classes.cpu().numpy())
        reconstructed_mask = reconstruct_array(pred_masks, coordinates, tomo.shape)
        location_frames.append(dict_to_df(_mask_to_locations(reconstructed_mask), run.name))

if location_frames:
    # Each run's frame is indexed from 0; ignore_index gives the result a unique index.
    location_df = pd.concat(location_frames, ignore_index=True)
else:
    location_df = pd.DataFrame(columns=['experiment', 'particle_type', 'x', 'y', 'z'])
location_df.insert(loc=0, column='id', value=np.arange(len(location_df)))

get_tomogram is deprecated, use get_tomograms instead. Results may be incomplete
Loading dataset: 100%|██████████| 98/98 [00:00<00:00, 394.63it/s]
100%|██████████| 98/98 [00:05<00:00, 19.40it/s]
Loading dataset: 100%|██████████| 98/98 [00:00<00:00, 366.41it/s]
100%|██████████| 98/98 [00:04<00:00, 19.74it/s]
Loading dataset: 100%|██████████| 98/98 [00:00<00:00, 352.36it/s]
100%|██████████| 98/98 [00:05<00:00, 19.11it/s]
Loading dataset: 100%|██████████| 98/98 [00:00<00:00, 374.99it/s]
100%|██████████| 98/98 [00:05<00:00, 19.11it/s]
Loading dataset: 100%|██████████| 98/98 [00:00<00:00, 364.47it/s]
100%|██████████| 98/98 [00:05<00:00, 19.51it/s]
Loading dataset: 100%|██████████| 98/98 [00:00<00:00, 401.14it/s]
100%|██████████| 98/98 [00:05<00:00, 19.58it/s]
Loading dataset: 100%|██████████| 98/98 [00:00<00:00, 354.63it/s]
100%|██████████| 98/98 [00:05<00:00, 19.57it/s]


In [12]:
# Show the predicted particle table: one row per detected blob, with a sequential `id` column.
location_df

Unnamed: 0,id,experiment,particle_type,x,y,z
0,0,TS_5_4,thyroglobulin,4552.042678,175.278626,221.391600
1,1,TS_5_4,thyroglobulin,3592.242408,324.003760,219.067449
2,2,TS_5_4,thyroglobulin,5627.594689,4866.047784,267.392371
3,3,TS_5_4,thyroglobulin,2055.706996,503.311447,409.365607
4,4,TS_5_4,thyroglobulin,2162.278628,3096.197053,435.081619
...,...,...,...,...,...,...
55,343,TS_99_9,thyroglobulin,4481.115315,2714.801127,1105.844673
56,344,TS_99_9,thyroglobulin,4366.031959,4131.991443,1104.551584
57,345,TS_99_9,thyroglobulin,3744.348943,4572.760194,1092.200641
58,346,TS_99_9,thyroglobulin,5193.082416,6274.865713,1083.910652


In [11]:
from helpers import *
# Score the predictions, presumably against the overlay (ground-truth) picks
# of the training runs. The cell output shows compute_lb returning a
# per-class table and a score.
valid_dir = '../input/czii-cryo-et-object-identification/train'
overlay_runs_dir = valid_dir + '/overlay/ExperimentRuns'
compute_lb(location_df, overlay_runs_dir)

['TS_5_4', 'TS_69_2', 'TS_6_4', 'TS_6_6', 'TS_73_6', 'TS_86_3', 'TS_99_9']
 TS_99_9 virus-like-particle


(         particle_type    P    T  hit  miss   fp  precision    recall  \
 0         apo-ferritin    0  375    0   375    0   0.000000  0.000000   
 1         beta-amylase    0   87    0    87    0   0.000000  0.000000   
 2   beta-galactosidase    0  112    0   112    0   0.000000  0.000000   
 3             ribosome    0  331    0   331    0   0.000000  0.000000   
 4        thyroglobulin  348  251  173    78  175   0.497126  0.689243   
 5  virus-like-particle    0  113    0   113    0   0.000000  0.000000   
 
     f-beta4  weight  
 0  0.000000       1  
 1  0.000000       0  
 2  0.000000       2  
 3  0.000000       1  
 4  0.673923       2  
 5  0.000000       1  ,
 0.19254943040460917)