In [11]:
import os
import skimage.io
import cv2
from matplotlib import pyplot as plt
import numpy as np
import skimage.filters
import skimage.morphology

# Paths
# NOTE(review): these are absolute, machine-specific Windows paths; adjust them before
# running the notebook on another machine.
image_dir = r'C:\Users\k54739\Today_data\segmentation\test_img'  # Input image directory
mask_dir = r'C:\Users\k54739\Today_data\segmentation\test_mask'  # Output mask directory

# Create the output directory up front; exist_ok=True makes re-running this cell harmless.
os.makedirs(mask_dir, exist_ok=True)  # Ensure mask directory exists

def process_image(image_path):
    """
    Segment a single image with bilateral smoothing + Otsu thresholding.

    Args:
        image_path: Path of the image file to read with ``skimage.io.imread``.

    Returns:
        tuple: ``(img_gray_8bit, filtered_mask)`` where ``img_gray_8bit`` is the
        min-max normalised uint8 grayscale image and ``filtered_mask`` is a boolean
        array that is True where the smoothed image exceeds the Otsu threshold,
        with connected components smaller than 10000 pixels removed.
    """
    # Load the image. skimage returns colour images in RGB(A) channel order,
    # not OpenCV's BGR, so the RGB* conversion codes must be used below.
    image = skimage.io.imread(image_path)

    # Convert to grayscale if necessary
    if image.ndim == 3:
        channels = image.shape[2]
        if channels == 1:
            # Single-channel image stored as (H, W, 1): just drop the channel axis
            img_gray = np.ascontiguousarray(image[:, :, 0])
        elif channels == 4:
            # RGBA: ignore alpha while converting
            img_gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
        else:
            img_gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    else:
        img_gray = image

    # Convert to 8-bit format (stretch the full intensity range to 0..255)
    img_gray_8bit = cv2.normalize(img_gray, None, 0, 255, cv2.NORM_MINMAX).astype('uint8')

    # Edge-preserving blur: neighbourhood diameter 50, sigmaColor 40, sigmaSpace 40
    blur = cv2.bilateralFilter(img_gray_8bit, 50, 40, 40)

    # Apply Otsu thresholding; pixels brighter than the threshold are foreground
    t = skimage.filters.threshold_otsu(blur)
    binary_mask = blur > t

    # Remove small objects (connected components below 10000 pixels)
    filtered_mask = skimage.morphology.remove_small_objects(binary_mask, 10000)

    return img_gray_8bit, filtered_mask

def save_mask_as_16bit_tiff(mask, save_path):
    """
    Save a binary mask as a 16-bit image with background as black (0)
    and segmented object as white (65535).

    Args:
        mask: Array-like mask; every non-zero / True element is treated as object.
        save_path: Destination path handed to ``skimage.io.imsave`` (callers in this
            notebook pass a ``.tiff`` path).
    """
    # Binarise first so that non-boolean masks (e.g. label images or 0/255 masks)
    # cannot overflow uint16 when scaled: Background -> 0, Object -> 65535.
    foreground = np.asarray(mask) != 0
    mask_16bit = foreground.astype(np.uint16) * np.uint16(65535)

    # check_contrast=False silences the low-contrast warning for two-valued images
    skimage.io.imsave(save_path, mask_16bit, check_contrast=False)  # Save as 16-bit TIFF

# Run the segmentation over every TIFF found in the input directory
for image_name in os.listdir(image_dir):
    image_path = os.path.join(image_dir, image_name)

    # Only .tif / .tiff files (any letter case) are processed; everything else is ignored
    is_tiff = image_name.lower().endswith(('.tiff', '.tif'))
    if is_tiff:
        # Segment the image (the 8-bit grayscale copy is returned but not needed here)
        img_gray_8bit, mask = process_image(image_path)

        # Write the mask as "<original stem>_mask.tiff" inside the mask directory
        stem, _ext = os.path.splitext(image_name)
        mask_path = os.path.join(mask_dir, stem + '_mask.tiff')
        save_mask_as_16bit_tiff(mask, mask_path)


In [12]:
import numpy as np
import skimage.io

# Read one saved mask back from disk to sanity-check its contents
mask_path = r'C:\Users\k54739\Today_data\segmentation\test_mask\cond10_ds_41_E10-T01_mask.tiff'  # Replace with your mask path
mask = skimage.io.imread(mask_path)

# Distinct pixel values in the mask and how many pixels carry each of them
unique_values, counts = np.unique(mask, return_counts=True)

# Report the values, the number of distinct values, and the per-value pixel counts
for label, value in (
    ("Unique Pixel Values:", unique_values),
    ("Number of Unique Pixel Values:", len(unique_values)),
    ("Counts of Each Pixel Value:", counts),
):
    print(label, value)


Unique Pixel Values: [    0 65535]
Number of Unique Pixel Values: 2
Counts of Each Pixel Value: [1416998 2801918]


In [2]:
# torch is imported here because the notebook's shared import cell appears further down;
# without it this cell raises NameError when the notebook is run top-to-bottom on a fresh kernel.
import torch

# Load the U-Net from torch.hub (3 input channels, 1 output channel, 32 initial features)
# with its pretrained weights; uses the local hub cache when available.
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=3, out_channels=1, init_features=32, pretrained=True)

Using cache found in C:\Users\k54739/.cache\torch\hub\mateuszbuda_brain-segmentation-pytorch_master


In [3]:
# Print the module hierarchy of the loaded model to inspect its layers
print(model)

UNet(
  (encoder1): Sequential(
    (enc1conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc1norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc1relu1): ReLU(inplace=True)
    (enc1conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc1norm2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc1relu2): ReLU(inplace=True)
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (encoder2): Sequential(
    (enc2conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc2norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc2relu1): ReLU(inplace=True)
    (enc2conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (enc2norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tra

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [None]:
# Define dataset
class SegmentationDataset(Dataset):
    """
    Paired image/mask dataset for binary segmentation.

    Images are the ``.tif`` / ``.tiff`` files (any letter case) found in ``image_dir``.
    The mask for ``<stem>.<ext>`` is read from ``<mask_dir>/<stem>_mask.tiff``.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the ``*_mask.tiff`` files.
        transform: Optional callable applied to the image tensor only (not the mask).
        in_channels: Optional number of channels the consumer expects. When given and
            the image is single-channel, the channel is repeated to this count
            (e.g. 3 for a network built with ``in_channels=3``). ``None`` (default)
            leaves the channel count unchanged.
    """

    def __init__(self, image_dir, mask_dir, transform=None, in_channels=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # Case-insensitive extension match, sorted so that index -> file is stable across runs
        self.image_names = sorted(
            f for f in os.listdir(image_dir) if f.lower().endswith(('.tiff', '.tif'))
        )
        self.transform = transform
        self.in_channels = in_channels

    def __len__(self):
        """Number of images found in ``image_dir``."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """
        Load one sample.

        Returns:
            tuple: ``(image, mask)`` float32 tensors; the image is channel-first and
            the mask has a leading channel axis with values 0.0 / 1.0.
        """
        image_name = self.image_names[idx]
        image_path = os.path.join(self.image_dir, image_name)
        mask_path = os.path.join(self.mask_dir, os.path.splitext(image_name)[0] + '_mask.tiff')  # Match naming

        # Load image and mask
        raw_image = skimage.io.imread(image_path)
        raw_mask = skimage.io.imread(mask_path)

        # Normalize image to [0, 1] using the full range of its integer dtype
        # (65535 for uint16, 255 for uint8). Non-integer images keep the original
        # 16-bit scale assumption.
        if np.issubdtype(raw_image.dtype, np.integer):
            scale = float(np.iinfo(raw_image.dtype).max)
        else:
            scale = 65535.0
        image = raw_image.astype(np.float32) / scale

        # Binarise the mask: any non-zero pixel is foreground (1.0), regardless of bit depth
        mask = (raw_mask > 0).astype(np.float32)

        # Add channel dimension for grayscale images
        if image.ndim == 2:
            image = image[np.newaxis, ...]
        else:
            image = np.transpose(image, (2, 0, 1))  # For channel-last colour images

        # Optionally replicate a single grayscale channel to the requested channel count
        if self.in_channels is not None and self.in_channels > 1 and image.shape[0] == 1:
            image = np.repeat(image, self.in_channels, axis=0)

        mask = mask[np.newaxis, ...]  # Add channel dimension to mask

        # Convert once to tensors (avoids re-wrapping an existing tensor with torch.tensor)
        image = torch.from_numpy(np.ascontiguousarray(image))
        mask = torch.from_numpy(np.ascontiguousarray(mask))

        # Apply transformations if any (image only)
        if self.transform:
            image = self.transform(image)

        return image.float(), mask.float()
    
# Define transformations applied to each image tensor (if needed)
# Shift/scale image values with mean 0.5 and std 0.5
_normalize_step = transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize values
transform = transforms.Compose([_normalize_step])

In [None]:
# Paths
image_dir = r'C:\path\to\images'
mask_dir = r'C:\path\to\masks'

# Subset is imported here because it is not part of the notebook's shared import cell
from torch.utils.data import Subset

# One dataset over every image/mask pair; the train/validation split is made on its indices
full_dataset = SegmentationDataset(image_dir, mask_dir, transform=transform)

# Split dataset into train and validation (fixed seed for a reproducible split)
all_images = full_dataset.image_names
train_indices, val_indices = train_test_split(
    list(range(len(all_images))), test_size=0.2, random_state=42
)
train_images = [all_images[i] for i in train_indices]
val_images = [all_images[i] for i in val_indices]

# Train and validation datasets are disjoint views of the full dataset, so validation
# loss is measured on images the model was never trained on
train_dataset = Subset(full_dataset, train_indices)
val_dataset = Subset(full_dataset, val_indices)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)


In [None]:

# Load pretrained U-Net model (3 input channels, 1 output channel, 32 initial features)
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
                       in_channels=3, out_channels=1, init_features=32, pretrained=True)

# Move model to GPU if available (done before the optimizer is constructed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Define loss and optimizer
# This hub U-Net applies a sigmoid at the end of forward(), so its outputs are already
# probabilities in [0, 1]. BCEWithLogitsLoss would apply a second sigmoid, so plain
# Binary Cross-Entropy on probabilities is the matching loss.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)



In [None]:
# Training loop: one optimisation pass and one validation pass per epoch
num_epochs = 25
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0.0

    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        # Clear gradients left over from the previous step, then forward / backward / update
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Accumulate the per-batch loss for the epoch average
        train_loss += loss.item()

    # ---- validation pass (no gradient tracking, eval-mode layers) ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item()

    # Epoch summary: losses are averaged over the number of batches
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}")


In [None]:

# Save model weights (state_dict only)
torch.save(model.state_dict(), 'unet_segmentation_model.pth')

# Visualize results on the first validation batch
model.eval()
images, masks = next(iter(val_loader))
images, masks = images.to(device), masks.to(device)

# Inference only: skip gradient tracking to save memory
with torch.no_grad():
    outputs = model(images)

# This hub U-Net already applies a sigmoid in forward(), so outputs are probabilities.
# Threshold them directly: sigmoid(p) >= 0.5 for every p in [0, 1], which would mark
# (almost) every pixel as foreground.
predicted_masks = (outputs > 0.5).float()

for i in range(len(images)):
    # Channel-last copy of the input for display
    image_np = images[i].cpu().permute(1, 2, 0).numpy()
    if image_np.shape[-1] == 1:
        image_np = image_np[..., 0]  # single channel -> plain 2-D grayscale image

    # Rescale to [0, 1] for display; normalised inputs may contain negative values
    # that imshow would otherwise clip
    lo, hi = image_np.min(), image_np.max()
    image_np = (image_np - lo) / (hi - lo) if hi > lo else np.zeros_like(image_np)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].set_title("Input Image")
    axes[0].imshow(image_np, cmap='gray')
    axes[1].set_title("Ground Truth Mask")
    axes[1].imshow(masks[i].cpu().squeeze(0).numpy(), cmap='gray')
    axes[2].set_title("Predicted Mask")
    axes[2].imshow(predicted_masks[i].cpu().squeeze(0).numpy(), cmap='gray')
    plt.show()
    plt.close(fig)  # release the figure so repeated runs do not accumulate open figures
