In [1]:

import os

import numpy as np

from make_dataset import DatasetDigitalStaining

dir = r"D:\Matsusaka\data_mito\HeLa_Su9-mSG"
train_folders = ["2", "3", "4"]
val_folders = ["1"]
test_folders = ["5"]
img_folders = [os.path.join(dir, f) for f in train_folders]
train_datasets = [DatasetDigitalStaining(img_folders[i], augmentation=None) for i in range(len(train_folders))]
val_datasets = [DatasetDigitalStaining(img_folders[i], augmentation=None) for i in range(len(val_folders))]

for ph1,ph2, mito in train_datasets[0]:
    print(ph1.shape, ph2.shape, mito.shape)
    break

torch.Size([1, 1024, 1224]) torch.Size([1, 1024, 1224]) torch.Size([1, 1024, 1224])


In [2]:
for ph1,ph2, mito in train_datasets[0]:
    print(ph1.shape, ph2.shape, mito.shape)
    
    print(ph1.max(), ph2.max(), mito.max())
    break

torch.Size([1, 1024, 1224]) torch.Size([1, 1024, 1224]) torch.Size([1, 1024, 1224])
tensor(0.0119) tensor(0.3455) tensor(19787.)


In [3]:
import os
import json
import torch
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
from tqdm import tqdm
from PIL import Image # Import PIL

def scale_to_uint8(img_np, p_min, p_max):
    # 1. Clip the image to the global min/max range
    img_clipped = np.clip(img_np, p_min, p_max)
    
    # 2. Normalize to [0, 1]
    img_normalized = (img_clipped - p_min) / (p_max - p_min)
    
    # 3. Scale to [0, 255] and convert to uint8
    img_uint8 = (img_normalized * 255).astype(np.uint8)
    
    return img_uint8

def process_dataset(datasets_list, split_name, num_patches_per_image, patch_size, 
                    base_dir, prompt, ph1_stats, mito_stats):
    """
    Iterates over a list of datasets, extracts random patches,
    scales them to uint8 using GLOBAL stats, and saves as PNG.
    
    Args:
        ...
        ph1_stats (tuple): (global_p01, global_p99_9) for ph1
        mito_stats (tuple): (global_p01, global_p99_9) for mito
    """
    
    output_path = os.path.join(base_dir, split_name)
    os.makedirs(output_path, exist_ok=True)
    
    metadata = []
    patch_counter = 0
    
    ph1_p_min, ph1_p_max = ph1_stats
    mito_p_min, mito_p_max = mito_stats
    
    print(f"Processing {split_name} split...")
    
    for dataset in datasets_list:
        for ph1, ph2, mito in tqdm(dataset):
            
            if ph1.dim() != 3 or mito.dim() != 3:
                print(f"Skipping image: unexpected dimensions. Got {ph1.shape} and {mito.shape}")
                continue
                
            _, h, w = ph1.shape
            
            if ph1.shape[1:] != mito.shape[1:]:
                print(f"Skipping image: ph1 and mito shape mismatch. {ph1.shape[1:]} vs {mito.shape[1:]}")
                continue
                
            if h < patch_size or w < patch_size:
                print(f"Skipping image: image is smaller than patch size. {h}x{w} vs {patch_size}")
                continue

            for _ in range(num_patches_per_image):
                y = torch.randint(0, h - patch_size + 1, (1,)).item()
                x = torch.randint(0, w - patch_size + 1, (1,)).item()
                
                ph1_patch = ph1[:, y : y + patch_size, x : x + patch_size]
                mito_patch = mito[:, y : y + patch_size, x : x + patch_size]
                
                ph1_np = ph1_patch.squeeze(0).numpy()
                mito_np = mito_patch.squeeze(0).numpy()
                
                # --- Apply global scaling ---
                ph1_uint8 = scale_to_uint8(ph1_np, ph1_p_min, ph1_p_max)
                mito_uint8 = scale_to_uint8(mito_np, mito_p_min, mito_p_max)
                
                # --- Convert to PIL Image ---
                # 'L' mode is for 8-bit grayscale
                ph1_pil = Image.fromarray(ph1_uint8, mode='L')
                mito_pil = Image.fromarray(mito_uint8, mode='L')

                # 4. Define filenames with .png extension
                base_filename = f"{patch_counter:06d}"
                condition_filename = f"{base_filename}_condition.png"
                target_filename = f"{base_filename}_target.png"
                
                condition_path = os.path.join(output_path, condition_filename)
                target_path = os.path.join(output_path, target_filename)
                
                # 5. Save the images as PNGs
                ph1_pil.save(condition_path)
                mito_pil.save(target_path)
                
                # 6. Add entry to metadata
                metadata.append({
                    "file_name": target_filename,
                    "conditioning_image": condition_filename,
                    "text": prompt
                })
                
                patch_counter += 1

    # --- Save Metadata File ---
    metadata_path = os.path.join(output_path, "metadata.jsonl")
    with open(metadata_path, 'w') as f:
        for entry in metadata:
            f.write(json.dumps(entry) + "\n")
            
    print(f"\nFinished processing {split_name} split.")
    print(f"Saved {len(metadata)} patches to: {output_path}")
    print(f"Saved metadata file to: {metadata_path}\n")

    # --- Save Metadata File ---
    metadata_path = os.path.join(output_path, "metadata.jsonl")
    with open(metadata_path, 'w') as f:
        for entry in metadata:
            f.write(json.dumps(entry) + "\n")
            
    print(f"\nFinished processing {split_name} split.")
    print(f"Saved {len(metadata)} patches to: {output_path}")
    print(f"Saved metadata file to: {metadata_path}\n")

In [4]:

# --- Parameters to configure ---
BASE_OUTPUT_DIR = "..//controlnet_dataset_uint8" # New output dir
PATCH_SIZE = 512
NUM_PATCHES_PER_IMAGE = 1
PROMPT = "fluorescence microscopy image of mitochondria"

# Percentiles to use for normalization
# (0.1, 99.9) is a good start, ignores 0.2% of outliers
P_LOW = 0.1 
P_HIGH = 99.9

BASE_DATA_DIR = dir

# --- Your Dataset Initialization Code ---
train_folders = ["2", "3", "4"]
train_folders = ["1"]
val_folders = ["1"]
test_folders = ["5"]

train_img_folders = [os.path.join(BASE_DATA_DIR, f) for f in train_folders]
val_img_folders = [os.path.join(BASE_DATA_DIR, f) for f in val_folders]

print("Loading datasets...")
train_datasets = [DatasetDigitalStaining(folder, augmentation=None) for folder in train_img_folders]
val_datasets = [DatasetDigitalStaining(folder, augmentation=None) for folder in val_img_folders]
print("Dataset loading complete.")

# --- NEW: Calculate Global Statistics ---
print("\nCalculating global statistics from training set (this may take a moment)...")
ph1_all_pixels = []
mito_all_pixels = []

for dataset in train_datasets:
    for ph1, ph2, mito in tqdm(dataset):
        ph1_all_pixels.append(ph1.numpy().ravel())
        mito_all_pixels.append(mito.numpy().ravel())

# Concatenate all pixel values
ph1_global_dist = np.concatenate(ph1_all_pixels)
mito_global_dist = np.concatenate(mito_all_pixels)

# Calculate percentiles
ph1_stats = np.percentile(ph1_global_dist, [P_LOW, P_HIGH])
mito_stats = np.percentile(mito_global_dist, [P_LOW, P_HIGH])

print(f"Global ph1 stats ({P_LOW}%, {P_HIGH}%): {ph1_stats[0]:.4f}, {ph1_stats[1]:.4f}")
print(f"Global mito stats ({P_LOW}%, {P_HIGH}%): {mito_stats[0]:.2f}, {mito_stats[1]:.2f}\n")
# --- End Statistics Calculation ---


# --- Run the processing ---
# Process the training datasets
process_dataset(
    datasets_list=train_datasets,
    split_name="train",
    num_patches_per_image=NUM_PATCHES_PER_IMAGE,
    patch_size=PATCH_SIZE,
    base_dir=BASE_OUTPUT_DIR,
    prompt=PROMPT,
    ph1_stats=ph1_stats,   # Pass global stats
    mito_stats=mito_stats  # Pass global stats
)

# Process the validation datasets
# IMPORTANT: Use the *same stats from the training set*
process_dataset(
    datasets_list=val_datasets,
    split_name="val",
    num_patches_per_image=NUM_PATCHES_PER_IMAGE,
    patch_size=PATCH_SIZE,
    base_dir=BASE_OUTPUT_DIR,
    prompt=PROMPT,
    ph1_stats=ph1_stats,   # Pass global stats from train
    mito_stats=mito_stats  # Pass global stats from train
)

print("All processing complete.")
print(f"Your ControlNet uint8 dataset is ready in: {BASE_OUTPUT_DIR}")


Loading datasets...
Dataset loading complete.

Calculating global statistics from training set (this may take a moment)...


100%|██████████| 31/31 [00:00<00:00, 139.81it/s]


Global ph1 stats (0.1%, 99.9%): -0.0079, 0.0092
Global mito stats (0.1%, 99.9%): 0.00, 25959.00

Processing train split...


100%|██████████| 31/31 [00:02<00:00, 14.24it/s]



Finished processing train split.
Saved 31 patches to: ..//controlnet_dataset_uint8\train
Saved metadata file to: ..//controlnet_dataset_uint8\train\metadata.jsonl


Finished processing train split.
Saved 31 patches to: ..//controlnet_dataset_uint8\train
Saved metadata file to: ..//controlnet_dataset_uint8\train\metadata.jsonl

Processing val split...


100%|██████████| 31/31 [00:02<00:00, 14.58it/s]


Finished processing val split.
Saved 31 patches to: ..//controlnet_dataset_uint8\val
Saved metadata file to: ..//controlnet_dataset_uint8\val\metadata.jsonl


Finished processing val split.
Saved 31 patches to: ..//controlnet_dataset_uint8\val
Saved metadata file to: ..//controlnet_dataset_uint8\val\metadata.jsonl

All processing complete.
Your ControlNet uint8 dataset is ready in: ..//controlnet_dataset_uint8





In [5]:
# accelerate launch train_controlnet.py --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" --output_dir="D://Matsusaka//sd_outputs//mito_cn" --train_data_dir="C://Users//Matsusaka//PycharmProjects//DDPM//controlnet_dataset_uint8//train" --resolution=512 --learning_rate=1e-5 --train_batch_size=4 --num_train_epochs=100 --report_to="wandb"

In [6]:
# code from https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md for out data
import json
import cv2
import numpy as np

from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, folder):
        self.data = []
        self.folder = folder
        fname = os.path.join(folder, "metadata.jsonl")
        with open(fname, 'rt') as f:
            for line in f:
                self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        image_f = item['file_name']
        condition_f = item['conditioning_image']
        prompt = item['text']

        image = cv2.imread(os.path.join(self.folder, image_f))
        condition_img = cv2.imread(os.path.join(self.folder, condition_f))

        # Do not forget that OpenCV read images in BGR order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        condition_img = cv2.cvtColor(condition_img, cv2.COLOR_BGR2RGB)

        # Normalize source images to [0, 1].
        condition_img = condition_img.astype(np.float16) / 255.0

        # Normalize target images to [-1, 1].
        image = (image.astype(np.float16) / 127.5) - 1.0

        return dict(jpg=image, txt=prompt, hint=condition_img)

In [7]:
folder = "C://Users//Matsusaka//PycharmProjects//DDPM//controlnet_dataset_uint8//train"
dataset = MyDataset(folder)
print(len(dataset))

item = dataset[0]
jpg = item['jpg']
txt = item['txt']
hint = item['hint']
print(txt)
print(jpg.shape)
print(hint.shape)

31
fluorescence microscopy image of mitochondria
(512, 512, 3)
(512, 512, 3)


In [10]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict


model_f = "D://Matsusaka//sd_outputs//control_v11f1p_sd15_depth.yaml"
# Configs
resume_path = "D://Matsusaka//sd_outputs//control_v11f1p_sd15_depth.pth"
batch_size = 4
logger_freq = 300
learning_rate = 1e-5
sd_locked = True
only_mid_control = False

# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
model = create_model(model_f).cpu()
# model.load_state_dict(load_state_dict(resume_path, location='cpu'))
model.learning_rate = learning_rate
model.sd_locked = sd_locked
model.only_mid_control = only_mid_control

# Misc
dataset = MyDataset(folder)
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
logger = ImageLogger(batch_frequency=logger_freq)
trainer = pl.Trainer(precision=16, callbacks=[])


# Train!
trainer.fit(model, dataloader)

ControlLDM: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


Using 16bit Automatic Mixed Precision (AMP)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Loaded model config from [D://Matsusaka//sd_outputs//control_v11f1p_sd15_depth.yaml]



  | Name              | Type               | Params | Mode 
-----------------------------------------------------------------
0 | model             | DiffusionWrapper   | 859 M  | train
1 | first_stage_model | AutoencoderKL      | 83.7 M | eval 
2 | cond_stage_model  | FrozenCLIPEmbedder | 123 M  | eval 
3 | control_model     | ControlNet         | 361 M  | train
-----------------------------------------------------------------
1.2 B     Trainable params
206 M     Non-trainable params
1.4 B     Total params
5,710.058 Total estimated model params size (MB)
1266      Modules in train mode
365       Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


SystemExit: 1