In [1]:
import sys
import os

# Make the parent directory importable so the `particle_detection` imports
# in the next cell resolve when the notebook is run from its own folder.
# The membership check keeps the cell idempotent: re-running it does not
# append the same entry to sys.path again.
project_root = os.path.abspath("..")
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import os
import numpy as np
import cv2

import matplotlib.pyplot as plt
import matplotlib.patches as patches

import torch

from sklearn.mixture import GaussianMixture

from particle_detection.data.data_pipeline import create_dataloaders
from particle_detection.autoencoder.model import create_autoencoder
from particle_detection.utils.model_utils import load_model
from particle_detection.clustering.cluster_visualization import process_and_visualize_clusters, visualize_binary_clusters, compare_original_and_clusters
from particle_detection.utils.pca_preprocessing import apply_pca, plot_explained_variance

In [3]:
# Loader configuration. `data_dir` is relative to the notebook's folder;
# point it at a machine-local image directory if the data lives elsewhere.
data_dir = "../data"
image_size = (2048, 2048)
batch_size = 8

# create_dataloaders returns a pair that is unpacked as (train, test) loaders.
train_loader, test_loader = create_dataloaders(
    data_dir=data_dir,
    image_size=image_size,
    batch_size=batch_size,
)

In [5]:
# Sanity check: print the shape of every batch the training loader yields.
for image_batch in train_loader:
    print(image_batch.shape)

torch.Size([8, 1, 2048, 2048])


In [7]:
import os
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm  # For progress tracking

# Tile every image yielded by train_loader into square patches and write each
# patch to disk as an 8-bit TIFF with a sequential file name.

# Set output directory
patches_dir = "../patches_dataset"
os.makedirs(patches_dir, exist_ok=True)

# Define patch size (pixels per side) and the running patch counter
patch_size = 128
patch_idx = 0

# Loop through train_loader
for image_batch in tqdm(train_loader, desc="Processing Batches"):
    # Convert to NumPy. The batch axis is kept on purpose: squeezing axis 0
    # would turn a single-image batch (1, C, H, W) into (C, H, W), and the
    # channel check below would then read H instead of C and fail.
    image_batch = image_batch.cpu().numpy()

    if image_batch.ndim != 4:
        raise ValueError(f"Unexpected image shape: {image_batch.shape}")

    # Handle grayscale vs RGB images
    n_channels = image_batch.shape[1]
    if n_channels == 3:  # RGB format (B, C, H, W)
        image_batch = np.transpose(image_batch, (0, 2, 3, 1))  # (B, C, H, W) → (B, H, W, C)
        mode = "RGB"
    elif n_channels == 1:  # Grayscale format (B, 1, H, W)
        image_batch = image_batch.squeeze(1)  # (B, 1, H, W) → (B, H, W)
        mode = "L"
    else:
        raise ValueError(f"Unexpected image shape: {image_batch.shape}")

    # Named n_images (not batch_size) so the notebook-level `batch_size`
    # setting used to build the loaders is not silently overwritten.
    n_images, h, w = image_batch.shape[:3]

    # Extract patches
    for img_idx in range(n_images):  # Loop through batch images
        img = image_batch[img_idx]  # Get single image

        # Step over the image in patch_size strides. If h or w is not a
        # multiple of patch_size, the last row/column of patches is smaller.
        for y in range(0, h, patch_size):
            for x in range(0, w, patch_size):
                patch = img[y:y + patch_size, x:x + patch_size]

                # Convert to uint8 (0-255 range). Assumes pixel values are
                # scaled to [0, 1] — TODO confirm against the data pipeline.
                # Clipping first makes out-of-range values saturate instead
                # of wrapping around when cast to uint8.
                patch = np.clip(patch * 255, 0, 255).astype(np.uint8)

                # Save patch as .TIF
                patch_path = os.path.join(patches_dir, f"patch_{patch_idx:06d}.tif")
                Image.fromarray(patch, mode=mode).save(patch_path)

                patch_idx += 1

print(f"Successfully saved {patch_idx} patches in {patches_dir}")

Processing Batches: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.06s/it]

Successfully saved 2048 patches in ../patches_dataset



