In [None]:
# Directory holding the offline dependency files; a later pip install cell
# interpolates it as {deps_path} for --find-links and the requirements file.
deps_path = '/kaggle/input/czii-cryoet-dependencies'

In [None]:
! cp -r /kaggle/input/czii-cryoet-dependencies/asciitree-0.3.3/ asciitree-0.3.3/

In [None]:
# Build a wheel from the copied asciitree source tree; by default pip writes
# the resulting .whl file into the current working directory.
! pip wheel asciitree-0.3.3/asciitree-0.3.3/

In [None]:
# Install the locally built asciitree wheel (the file the `pip wheel` cell is
# expected to have produced in the working directory).
!pip install asciitree-0.3.3-py3-none-any.whl

In [None]:
# Install everything listed in requirements.txt from local files only:
# --no-index disables the package index and --find-links points pip at the
# directory stored in deps_path instead.
! pip install -q --no-index --find-links {deps_path} --requirement {deps_path}/requirements.txt

In [None]:
# Check that the pip install succeeded by printing the metadata pip has for monai
!pip show monai

In [None]:
from typing import List, Tuple, Union
import numpy as np
import torch
from monai.data import DataLoader, Dataset, CacheDataset, decollate_batch
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    Orientationd,
    AsDiscrete,
    RandFlipd,
    RandRotate90d,
    NormalizeIntensityd,
    RandCropByLabelClassesd,
)

In [None]:
# 指定された次元を完全に覆うために最小限の重複でパッチの開始位置を計算する
def calculate_patch_starts(dimension_size: int, patch_size: int) -> List[int]:
    """
    Calculate the starting positions of patches along a single dimension
    so that the patches cover the entire dimension with minimal, evenly
    spread overlap.

    Parameters:
    -----------
    dimension_size : int
        Size of the dimension
    patch_size : int
        Size of the patch in this dimension (must be positive)

    Returns:
    --------
    List[int]
        Strictly increasing list of starting positions. The first entry is
        always 0 and, when the dimension is larger than the patch, the last
        entry is ``dimension_size - patch_size`` so the final patch ends
        exactly at the end of the dimension.

    Raises:
    -------
    ValueError
        If ``patch_size`` is not positive.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size ({patch_size}) must be a positive integer")

    # One patch anchored at 0 is enough when the dimension fits inside it.
    if dimension_size <= patch_size:
        return [0]

    # Ceiling division in pure integer arithmetic. The result is at least 2
    # here because dimension_size > patch_size.
    n_patches = -(-dimension_size // patch_size)

    # Spread the starts evenly over [0, dimension_size - patch_size].
    # Integer floor division is used instead of a float product: a float
    # result that lands just below a whole number would be truncated one
    # position short and could leave the end of the dimension uncovered.
    last_start = dimension_size - patch_size
    return [int(i * last_start // (n_patches - 1)) for i in range(n_patches)]

def extract_3d_patches_minimal_overlap(arrays: List[np.ndarray], patch_size: int) -> Tuple[List[np.ndarray], List[Tuple[int, int, int]]]:
    """
    Extract cubic 3D patches from several equally shaped arrays, using
    minimally overlapping patches that together cover each whole array.

    Parameters:
    -----------
    arrays : List[np.ndarray]
        Non-empty list of input arrays, each with the same 3D shape (m, n, l)
    patch_size : int
        Edge length of the cubic patches (a x a x a); must be positive and
        no larger than the smallest array dimension

    Returns:
    --------
    patches : List[np.ndarray]
        List of all patches from all input arrays, in input order. Each
        patch is produced by slicing its source array (for NumPy inputs
        that is a view, not a copy).
    coordinates : List[Tuple[int, int, int]]
        Starting coordinates (x, y, z) of each patch, index-aligned with
        ``patches`` (the same coordinates repeat once per input array)

    Raises:
    -------
    ValueError
        If ``arrays`` is not a non-empty list, the arrays differ in shape
        or are not 3D, or ``patch_size`` is not positive or exceeds the
        smallest dimension.
    """
    # Check the container type first: truth-testing a bare ndarray would raise
    # an unrelated "ambiguous truth value" error instead of this message.
    if not isinstance(arrays, list) or not arrays:
        raise ValueError("Input must be a non-empty list of arrays")

    # Verify all arrays have the same shape
    shape = arrays[0].shape
    if not all(arr.shape == shape for arr in arrays):
        raise ValueError("All input arrays must have the same shape")

    # Fail with a clear message instead of a tuple-unpacking error below
    if len(shape) != 3:
        raise ValueError(f"Input arrays must be 3D, got shape {tuple(shape)}")

    if patch_size <= 0:
        raise ValueError(f"patch_size ({patch_size}) must be a positive integer")

    # A patch exactly as large as the smallest dimension is still accepted
    if patch_size > min(shape):
        raise ValueError(f"patch_size ({patch_size}) must not be larger than smallest dimension {min(shape)}")

    dim_x, dim_y, dim_z = shape

    # Calculate starting positions for each dimension
    x_starts = calculate_patch_starts(dim_x, patch_size)
    y_starts = calculate_patch_starts(dim_y, patch_size)
    z_starts = calculate_patch_starts(dim_z, patch_size)

    # Every array shares the same shape, so the start grid is built once and
    # reused; x varies slowest and z fastest.
    starts = [(x, y, z) for x in x_starts for y in y_starts for z in z_starts]

    patches = []      # extracted patches
    coordinates = []  # start coordinate of each extracted patch

    # Extract patches from each array
    for arr in arrays:
        for x, y, z in starts:
            patches.append(arr[
                x:x + patch_size,
                y:y + patch_size,
                z:z + patch_size
            ])
        coordinates.extend(starts)

    return patches, coordinates


# 分割されたパッチとその開始座標から元の3D配列を再構築する
def reconstruct_array(patches: List[np.ndarray],
                     coordinates: List[Tuple[int, int, int]],
                     original_shape: Tuple[int, int, int],
                     dtype=np.int64) -> np.ndarray:
    """
    Reconstruct array from patches.

    Each patch is written into a zero-initialised array at its starting
    coordinate. Where patches overlap, the patch that comes later in the
    list overwrites the earlier one (no averaging or blending is done).

    Parameters:
    -----------
    patches : List[np.ndarray]
        List of patches to reconstruct from. Patches may differ in shape;
        each one is placed using its own extents.
    coordinates : List[Tuple[int, int, int]]
        Starting coordinates for each patch, index-aligned with ``patches``
    original_shape : Tuple[int, int, int]
        Shape of the original array
    dtype : data-type, optional
        dtype of the returned array. Defaults to ``np.int64``; patch values
        are cast to this dtype on assignment, so pass a float dtype to keep
        fractional values.

    Returns:
    --------
    np.ndarray
        Reconstructed array of shape ``original_shape``; regions not covered
        by any patch stay 0. An empty patch list yields an all-zero array.

    Raises:
    -------
    ValueError
        If ``patches`` and ``coordinates`` have different lengths.
    """
    # zip() would silently drop the surplus entries of the longer list
    if len(patches) != len(coordinates):
        raise ValueError(
            f"Got {len(patches)} patches but {len(coordinates)} coordinates; they must match"
        )

    reconstructed = np.zeros(original_shape, dtype=dtype)

    for patch, start in zip(patches, coordinates):
        # Use the patch's own extent along every axis rather than assuming a
        # cube whose edge equals the first patch's first dimension.
        region = tuple(
            slice(begin, begin + extent)
            for begin, extent in zip(start, patch.shape)
        )
        reconstructed[region] = patch  # later patches overwrite overlaps

    return reconstructed




In [None]:
import pandas as pd
import numpy as np

# 辞書をデータフレームに変換
def dict_to_df(coords_dict, experiment_name):
    """
    Convert a mapping of particle label to coordinates into a DataFrame.

    Parameters:
    -----------
    coords_dict : dict
        Maps each particle label to its coordinates: an array-like of rows
        whose first three columns are x, y, z. A label may have no
        coordinates at all, and a single (x, y, z) triple is accepted.
    experiment_name : str
        Value written to the 'experiment' column of every row

    Returns:
    --------
    pd.DataFrame
        One row per coordinate with columns 'experiment', 'particle_type',
        'x', 'y', 'z'. Empty (zero rows, same columns) when no label has
        any coordinates.
    """
    coord_blocks = []
    all_labels = []

    for label, coords in coords_dict.items():
        block = np.asarray(coords)
        if block.ndim != 2:
            if block.size == 0:
                # e.g. [] for a label with no coordinates: give it zero rows
                # with three columns so it can be stacked with the others.
                block = block.reshape(0, 3)
            else:
                # A single (x, y, z) triple becomes one row, which keeps the
                # number of labels equal to the number of stacked rows.
                block = np.atleast_2d(block)
        coord_blocks.append(block)
        all_labels.extend([label] * len(block))

    # np.vstack raises on an empty list, so fall back to a zero-row array
    if coord_blocks:
        all_coords = np.vstack(coord_blocks)
    else:
        all_coords = np.empty((0, 3))

    df = pd.DataFrame({
        # An explicit list also works for the zero-row case
        'experiment' : [experiment_name] * len(all_labels),
        'particle_type' : all_labels,
        'x' : all_coords[:, 0],
        'y' : all_coords[:, 1],
        'z' : all_coords[:, 2]
    })

    return df

In [None]:
# Directory holding the train_image_{name}.npy / train_label_{name}.npy files
# that the loading cell reads for every experiment name.
TRAIN_DATA_DIR = "/kaggle/input/create-numpy-dataset-exp-name"
# Not referenced in the visible cells; presumably the competition data root
# used later for inference (TODO confirm).
TEST_DATA_DIR = "/kaggle/input/czii-cryo-et-object-identification"

In [None]:
# Experiment names used for training and for validation
train_names = ['TS_5_4', 'TS_69_2', 'TS_6_6', 'TS_73_6', 'TS_86_3', 'TS_99_9']
valid_names = ['TS_6_4']

train_files, valid_files = [], []

# Both splits are read the same way from TRAIN_DATA_DIR: training names first,
# then validation names. Each entry pairs an image volume with its label volume.
for split_names, split_files in ((train_names, train_files), (valid_names, valid_files)):
    for name in split_names:
        image = np.load(f"{TRAIN_DATA_DIR}/train_image_{name}.npy")
        label = np.load(f"{TRAIN_DATA_DIR}/train_label_{name}.npy")
        split_files.append({"image": image, "label": label})

In [None]:
# Non-random transforms to be cached

# Deterministic preprocessing; its result is cached by CacheDataset below
non_random_transforms = Compose([
    # Add a leading channel dimension to the channel-less image and label arrays
    EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),
    # Normalize the intensity values of the image (labels are left untouched)
    NormalizeIntensityd(keys="image"),
    # Reorient image and label to the "RAS" axis convention
    Orientationd(keys=["image", "label"], axcodes="RAS")
])

# cache_rate=1.0 caches the preprocessed result of every training volume
raw_train_ds = CacheDataset(data=train_files, transform=non_random_transforms, cache_rate=1.0)

my_num_samples = 16
train_batch_size = 1

# The labels get a single channel above (class-index map, not one-hot). For such
# labels RandCropByLabelClassesd cannot infer the number of classes and raises
# unless `num_classes` is given, so derive it from the largest class id found
# in the training labels (ids are assumed to start at 0).
crop_num_classes = int(max(np.max(sample["label"]) for sample in train_files)) + 1

# Random transforms to be applied during training
random_transforms = Compose([
    # Randomly crop patches centred on voxels sampled per label class
    RandCropByLabelClassesd(
        keys=["image", "label"],
        label_key="label",
        spatial_size=[98, 98, 98], # spatial size of each cropped patch
        num_classes=crop_num_classes, # required for single-channel labels
        num_samples=my_num_samples # number of crops generated per volume
    ),
    # Randomly rotate image and label by multiples of 90 degrees in the plane of axes 0 and 2
    RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
    # Randomly flip image and label along spatial axis 0
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
])

train_ds = Dataset(data=raw_train_ds, transform=random_transforms)

# DataLoader remains the same
train_loader = DataLoader(
    train_ds, # training dataset
    batch_size=train_batch_size, # number of volumes per batch
    shuffle=True, # shuffle the volumes every epoch
    num_workers=4, # number of worker processes
    pin_memory=torch.cuda.is_available() # pin host memory only when CUDA is available
)

# Sanity-check the loader configuration
print(f"Number of workers: {train_loader.num_workers}")
print(f"Pin memory: {train_loader.pin_memory}")
print(f"Number of samples in raw_train_ds: {len(raw_train_ds)}")
print(f"Number of samples in train_ds: {len(train_ds)}")