In [76]:
import os
import shutil
import kagglehub

import numpy as np
import pandas as pd

from pathlib import Path
import matplotlib.pyplot as plt
import imageio.v2 as iio
from scipy import ndimage
from scipy.ndimage import zoom, rotate, shift

In [77]:
download_path = "./isbi-2012-chanllenge"

# Fetch the Kaggle copy of the ISBI 2012 dataset once and move it next to the
# notebook; later runs skip the download when the folder already exists.
if not os.path.exists(download_path):
    path = kagglehub.dataset_download("hamzamohiuddin/isbi-2012-challenge")
    os.makedirs(download_path, exist_ok=True)
    dest = os.path.join(download_path, os.path.basename(path))
    shutil.move(path, dest)

# NOTE(review): assumes the moved kagglehub version directory is named "1"
# (os.path.basename of the download path) — confirm if the dataset version changes.
img_root = os.path.join(download_path, "1", "unmodified-data", "train", "imgs")
lab_root = os.path.join(download_path, "1", "unmodified-data", "train", "labels")

img_paths = sorted(os.path.join(img_root, f) for f in os.listdir(img_root) if f.endswith(".png"))
lab_paths = sorted(os.path.join(lab_root, f) for f in os.listdir(lab_root) if f.endswith(".png"))

# zip() silently truncates to the shorter list, which would pair images with the
# wrong labels; fail loudly instead.
if len(img_paths) != len(lab_paths):
    raise ValueError(
        f"Found {len(img_paths)} images but {len(lab_paths)} labels; "
        "image/label pairing would be misaligned"
    )

# Pair each image with its label. Assumes image and label filenames sort in the
# same order within their respective folders.
pair = list(zip(img_paths, lab_paths))

print(f"Found {len(pair)} image-label pairs")

# Seeded shuffle. It is repeated 4x exactly as in the original run so the
# resulting train/val split is unchanged.
np.random.seed(1234)
for _ in range(4):
    np.random.shuffle(pair)

# ~78% train / ~22% validation, derived from the actual number of pairs rather
# than a hard-coded dataset size (for 30 pairs this is still 23 / 7).
train_frac = .78
split = int(len(pair) * train_frac)
train_paths = pair[:split]
val_paths = pair[split:]

Found 30 image-label pairs


In [None]:
# Reference implementation:
# https://www.kaggle.com/code/hamzamohiuddin/u-net-implementation-part-5#Training-(Optional)
dim_orig, pad = 512, 171            # unpadded tile side length; reflect padding added per border
padded_shape = 2 * pad + dim_orig   # 854: side length after load_img/load_lab padding
dim_out = 512                       # side length of the tiles finally written to disk

def load_img(path):
    """Read an image from disk and reflect-pad it.

    Returns a float32 array of shape (H + 2*pad, W + 2*pad, C). A 2-D
    (single-channel) file gets a trailing channel axis, so C == 1 in that case.
    The padding margin is removed again later by process()'s center crop.
    """
    raw = iio.imread(path)
    # Promote (H, W) to (H, W, 1) so downstream code can always index a channel axis.
    arr = raw[..., None] if raw.ndim == 2 else raw
    arr = arr.astype(np.float32)
    # Mirror both spatial borders by `pad` pixels; leave the channel axis alone.
    spatial = (pad, pad)
    return np.pad(arr, (spatial, spatial, (0, 0)), mode='reflect')

def load_lab(path):
    """Read a label mask with exactly the same preprocessing as load_img.

    The original body was a verbatim copy of load_img. Delegating guarantees that
    images and labels always receive identical channel handling, float32 casting
    and reflect padding, so the two stay pixel-aligned even if load_img changes.

    Returns a float32 array of shape (H + 2*pad, W + 2*pad, C).
    """
    return load_img(path)

In [79]:
def augment(img, label, max_angle=45.0, max_shift=15.0, brightness=0.12):
    """Apply one random flip/rotate/shift/brightness augmentation to an image/label pair.

    Geometric ops are applied jointly by stacking image and label on the channel
    axis, so both receive the identical transform and stay aligned. Brightness
    jitter is applied to the image only. Random draws use the global np.random
    state in a fixed order (flip, angle, row shift, column shift, brightness).

    Args:
        img: (H, W, 1) array. Exactly one channel is assumed, because the result
            is split back apart as channels 0:1 and 1:2.
        label: (H, W, 1) array with the same spatial shape as img.
        max_angle: rotation angle (degrees) drawn uniformly from [0, max_angle].
        max_shift: per-axis translation (pixels) drawn uniformly from
            [-max_shift, max_shift].
        brightness: the image is multiplied by a factor drawn uniformly from
            [1 - brightness, 1 + brightness].

    Returns:
        (img_out, label_out), each of shape (H, W, 1); img_out is clipped to [0, 255].
    """
    stacked = np.concatenate([img, label], axis=-1)  # (H, W, 2)

    if np.random.rand() > 0.5:
        stacked = np.fliplr(stacked)

    # NOTE(review): the angle range is one-sided ([0, max_angle]); a symmetric
    # range may have been intended — confirm against the reference notebook.
    angle = np.random.uniform(0, max_angle)
    stacked = rotate(stacked, angle, reshape=False, order=1, mode='constant', cval=0)

    h_shift = np.random.uniform(-max_shift, max_shift)
    w_shift = np.random.uniform(-max_shift, max_shift)
    stacked = shift(stacked, [h_shift, w_shift, 0], order=1, mode='constant', cval=0)

    img_out, label_out = stacked[..., 0:1], stacked[..., 1:2]

    # Brightness jitter. The reference implementation used a range of 0.12.
    factor = np.random.uniform(1.0 - brightness, 1.0 + brightness)
    img_out = np.clip(img_out * factor, 0, 255)

    return img_out, label_out

In [80]:
def center_crop(img, crop_h, crop_w):
    """Return the centered crop_h x crop_w window of an (H, W, C) array.

    Offsets use floor division, so when the size difference is odd the extra
    pixel is left on the bottom/right. Basic slicing is used, so the result is
    a view of `img`, not a copy.
    """
    top = (img.shape[0] - crop_h) // 2
    left = (img.shape[1] - crop_w) // 2
    rows = slice(top, top + crop_h)
    cols = slice(left, left + crop_w)
    return img[rows, cols, :]

def process(image, label):
    """Turn a padded image/label pair into final uint8 training tiles.

    Steps:
      1. reshape both inputs to (padded_shape, padded_shape, 1);
      2. center-crop away the reflect padding to (dim_orig, dim_orig);
      3. resize both to (dim_out, dim_out);
      4. binarize the label to {0, 255};
      5. round and clip the image to [0, 255], then cast it to uint8.

    Returns:
        (image, label) as uint8 arrays of shape (dim_out, dim_out, 1).
    """
    image = image.reshape((padded_shape, padded_shape, 1))
    label = label.reshape((padded_shape, padded_shape, 1))

    image = center_crop(image, dim_orig, dim_orig)
    label = center_crop(label, dim_orig, dim_orig)

    zoom_factor = dim_out / dim_orig
    # Resize the image as well as the label so the pair stays aligned when
    # dim_out != dim_orig. At factor 1 the image needs no resampling.
    if zoom_factor != 1:
        image = zoom(image, (zoom_factor, zoom_factor, 1), order=1)
    label = zoom(label, (zoom_factor, zoom_factor, 1), order=1)

    # Interpolation leaves intermediate label values; threshold at the middle of 0..255.
    mid = 255 // 2
    label = (label >= mid).astype(np.uint8) * 255

    # Spline resampling upstream (e.g. ndimage.map_coordinates with its default
    # order=3) can overshoot [0, 255]. Casting such floats straight to uint8 wraps
    # or garbles pixels, so round to nearest and clip before the cast.
    image = np.clip(np.rint(image), 0, 255).astype(np.uint8)

    return image, label

In [81]:
def get_deformed(img, label, grid_size=5, sigma=10.0):
    """Apply one random smooth elastic deformation to an image/label pair.

    A coarse grid_size x grid_size grid of random displacements is drawn
    independently for x and y (normal, mean 0, std `sigma` pixels), using the
    global np.random state with x first, then y. Each grid is upsampled to a
    dense per-pixel displacement field by spline interpolation. The same field
    then resamples both image and label, so they stay aligned.

    Args:
        img: (H, W, C) array. It is squeezed to 2-D before resampling, so
            presumably C == 1 — confirm for multi-channel inputs.
        label: array with the same spatial shape as img.
        grid_size: number of control points along each axis of the coarse grid.
        sigma: standard deviation (pixels) of the control-point displacements.

    Returns:
        (deformed_img, deformed_label), each with a trailing channel axis (H, W, 1).
        The image is resampled with map_coordinates' default cubic spline and can
        overshoot the input value range. The label uses linear interpolation.
        Samples falling outside the input are filled with 0 (map_coordinates'
        default constant mode).
    """
    h, w = img.shape[:2]
    img, label = np.squeeze(img), np.squeeze(label)

    # Fractional coarse-grid coordinates for every output pixel. Rows span
    # [0, grid_size - 1] over h pixels and columns span it over w pixels.
    rows = np.linspace(0, grid_size - 1, num=h)
    cols = np.linspace(0, grid_size - 1, num=w)
    coarse_coords = np.meshgrid(rows, cols, indexing='ij')  # two (h, w) arrays

    # Random control-point displacements in pixels, one grid per axis.
    grid_x = np.random.normal(0, sigma, size=(grid_size, grid_size))
    grid_y = np.random.normal(0, sigma, size=(grid_size, grid_size))

    # Upsample the coarse grids to dense (h, w) displacement fields.
    dx = ndimage.map_coordinates(grid_x, coarse_coords)
    dy = ndimage.map_coordinates(grid_y, coarse_coords)

    # Output pixel (y, x) samples the input at (y + dy, x + dx).
    yy, xx = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
    sample_coords = [yy + dy, xx + dx]

    deformed_img = ndimage.map_coordinates(img, sample_coords)
    deformed_label = ndimage.map_coordinates(label, sample_coords, order=1)

    return deformed_img[..., None], deformed_label[..., None]

In [82]:
import time 


train_examples = 100          # total tiles to write: every original pair, then augmented ones
img_format = ".tif"
out_img = "data_isbi/train/images/"
out_labs = "data_isbi/train/labels/"

# Start from empty output folders so tiles from a previous run never mix in.
for folder in (out_img, out_labs):
    if os.path.exists(folder):
        shutil.rmtree(folder)
for folder in (out_img, out_labs):
    os.makedirs(folder)


def _save_tile_pair(index, image, label):
    """Write one image/label tile pair under the same zero-padded 5-digit name."""
    stem = f"{index:0>5}{img_format}"
    iio.imwrite(out_img + stem, image.squeeze())
    iio.imwrite(out_labs + stem, label.squeeze())


count = 1

print("Writing original images...")
for img_file, lab_file in train_paths:
    tile_img, tile_lab = process(load_img(img_file), load_lab(lab_file))
    _save_tile_pair(count, tile_img, tile_lab)
    if count % 20 == 0:
        print(count)
    count += 1

print("Writing augmented images...")
while count <= train_examples:
    # Pick a random training pair, augment it, deform it elastically, then crop it.
    src_img, src_lab = train_paths[np.random.randint(0, len(train_paths))]
    aug_img, aug_lab = augment(load_img(src_img), load_lab(src_lab))
    tile_img, tile_lab = process(*get_deformed(aug_img, aug_lab))
    _save_tile_pair(count, tile_img, tile_lab)
    if count % 20 == 0:
        print(count)
    count += 1

print("Data generation complete.")

Writing original images...
20
Writing augmented images...
40
60
80
100
Data generation complete.
