<a href="https://colab.research.google.com/github/link1697/crack_segmentation/blob/main/%E8%B4%A5%E5%8C%97_pytorch_2stage_resize.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [15]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tensorflow.keras.applications import ResNet50
from keras import metrics
import random

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from sklearn.model_selection import train_test_split
import cv2
import PIL
from PIL import Image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tqdm import tqdm
from keras import backend as K
import torch.nn.functional as F

In [16]:
# from google.colab import drive
# drive.mount('/content/drive')

In [17]:
import os
import os
import cv2
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

In [18]:
def ensure_uint8(image):
    if image.dtype != np.uint8:
        image = (image * 255).clip(0, 255).astype(np.uint8)
    return image

def apply_gaussian_blur(image, kernel_size=(5, 5)):
    # Assuming the input image is a float between 0 and 1
    image = ensure_uint8(image)  # Convert to uint8
    blurred = cv2.GaussianBlur(image, kernel_size, 0)
    return blurred / 255.0  # Re-normalize to float between 0 and 1 if necessary

def apply_sobel_filter(image, ksize=3):
    # Convert to grayscale
    image = ensure_uint8(image)  # Convert to uint8
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=ksize)
    sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=ksize)
    sobel = cv2.addWeighted(np.absolute(sobelx), 0.5, np.absolute(sobely), 0.5, 0)
    sobel = np.clip(sobel, 0, 255).astype(np.uint8)  # Ensure the result is uint8
    return cv2.cvtColor(sobel, cv2.COLOR_GRAY2BGR) / 255.0  # Convert back to BGR and normalize

def apply_gaussian_noise(image, noise_level=5):
    # Assuming the input image is a float between 0 and 1
    image = ensure_uint8(image)  # Convert to uint8
    gauss_noise = np.random.normal(0, noise_level, image.shape).astype(np.uint8)
    noisy_image = cv2.add(image, gauss_noise)
    return noisy_image / 255.0  # Re-normalize to float between 0 and 1 if necessary

def color_distort_smooth(image, hue_shift=87, saturation_scale=1.3, value_scale=1.7):
    # Assuming the input image is a float between 0 and 1
    image = ensure_uint8(image)  # Convert to uint8
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    h, s, v = cv2.split(hsv)
    h = (h + hue_shift) % 180
    s = np.clip(s * saturation_scale, 0, 255).astype(np.uint8)
    v = np.clip(v * value_scale, 0, 255).astype(np.uint8)
    final_hsv = cv2.merge((h, s, v))
    color_distorted_image = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
    return color_distorted_image / 255.0  # Re-normalize if necessary

augmentations = [apply_gaussian_blur, apply_sobel_filter, apply_gaussian_noise, color_distort_smooth]

Pre-traning dataset

In [19]:
# Custom Dataset for Crack and Background Pair Loading
class CrackBackgroundDataset(Dataset):
    def __init__(self, crack_dir, background_dir, img_size=(112, 112)):
        self.crack_dir = crack_dir
        self.background_dir = background_dir
        self.img_size = img_size

        # List all files available in the crack directory (assuming they have a matching filename in the background directory)
        # self.filenames = [f for f in os.listdir(crack_dir) if os.path.isfile(os.path.join(background_dir, f))]
        self.filenames = [f for f in os.listdir(crack_dir)]
         
    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        try:
        # Get the filename for this index
            filename = self.filenames[idx]
            # print(filename)
            # Read crack image and background image
            crack_img_path = os.path.join(self.crack_dir, filename)
            background_img_path = os.path.join(self.background_dir, filename)

            crack_img = cv2.imread(crack_img_path)
            # background_img = cv2.imread(background_img_path)

            # Resize to the desired size
            crack_img = cv2.resize(crack_img, self.img_size) / 255.0
            # background_img = cv2.resize(background_img, self.img_size) / 255.0

            selected_augmentations = random.sample(augmentations, 2)
            crack_left = selected_augmentations[0](crack_img)
            crack_right = selected_augmentations[1](crack_img)

            # Convert to PyTorch tensors and reorder dimensions (channels, height, width)
            crack_tensor = torch.tensor(crack_left.transpose(2, 0, 1), dtype=torch.float32)
            background_tensor = torch.tensor(crack_right.transpose(2, 0, 1), dtype=torch.float32)

            return crack_tensor, background_tensor
        except Exception as e:
            print(f"Error processing file {filename}: {str(e)}")
            raise

# Set up the device for training and testing
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# Define paths and initialize the dataset
# crack_dir = "/content/drive/MyDrive/CEE 598 DL/archive/cropped/CFD_crack_region/"
# background_dir = "/content/drive/MyDrive/CEE 598 DL/archive/cropped/CFD_background_region/"
crack_dir = './archive/cropped/CFD_crack_region/'
background_dir = './archive/cropped/CFD_background_region/'

img_size = (112, 112)

dataset = CrackBackgroundDataset(crack_dir, background_dir, img_size=img_size)

# Create the DataLoader
batch_size = 32
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)

for crack_batch, background_batch in train_loader:
    # Move batches to the specified device
    crack_batch = crack_batch.to(device)
    background_batch = background_batch.to(device)

    print("Crack Batch Shape (on GPU):", crack_batch.shape)
    print("Background Batch Shape (on GPU):", background_batch.shape)
    break  # Stop after the first batch to inspect

Using device: cpu
Crack Batch Shape (on GPU): torch.Size([32, 3, 112, 112])
Background Batch Shape (on GPU): torch.Size([32, 3, 112, 112])


U-net + MLP

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# U-Net Encoder Blocks
def conv_block(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )

# U-Net Encoder Class
class UNetEncoder(nn.Module):
    def __init__(self, in_channels=3, features=64):
        super(UNetEncoder, self).__init__()
        self.enc1 = conv_block(in_channels, features)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.enc2 = conv_block(features, features * 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.enc3 = conv_block(features * 2, features * 4)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.enc4 = conv_block(features * 4, features * 8)
        self.pool4 = nn.MaxPool2d(2, 2)

        # Bottleneck layer
        self.bottleneck = nn.Conv2d(features * 8, features * 16, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.enc1(x)
        x = self.pool1(x)
        x = self.enc2(x)
        x = self.pool2(x)
        x = self.enc3(x)
        x = self.pool3(x)
        x = self.enc4(x)
        x = self.pool4(x)
        x = self.bottleneck(x)

        # Flatten output
        x = x.view(x.size(0), -1)
        return x

# Projection Head Class
class ProjectionHead(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, output_dim=64):
        super(ProjectionHead, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Integrate U-Net Encoder and Projection Head
dummy_input = torch.randn(1, 3, 256, 256).to(device)
encoder = UNetEncoder(in_channels=3, features=32).to(device)
encoder_output = encoder(dummy_input)

# Flattened encoder output size
flattened_output = encoder_output.shape[1]

# Create the projection head with the correct input dimension
projection_head = ProjectionHead(input_dim=flattened_output).to(device)

# Test the forward pass with a dummy input
projected_output = projection_head(encoder_output)
print("Projected Output Shape:", projected_output.shape)


Using device: cpu
Projected Output Shape: torch.Size([1, 64])


train encoder+MLP

In [21]:
# Initialize the encoder and projection head
encoder = UNetEncoder(in_channels=3, features=32)
encoder_output = encoder(torch.randn(1, 3, 112, 112))
flattened_output = encoder_output.shape[1]
projection_head = ProjectionHead(input_dim=flattened_output)

# Set up the device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(device)
projection_head.to(device)

# Initialize optimizer
optimizer = optim.Adam(list(encoder.parameters()) + list(projection_head.parameters()), lr=1e-3)

# Define a contrastive loss class and ensure the loss calculation is on the appropriate device
def contrastive_loss(pos_crack, pos_bg, all_features, temperature=0.5):
    # all_features includes all crack and background features concatenated
    # Calculate positive similarity
    pos_sim = F.cosine_similarity(pos_crack, pos_bg, dim=-1).unsqueeze(1)

    # Calculate negative similarities
    # Exclude the current positive samples from the negative set
    neg_samples = torch.cat([all_features[:pos_crack.shape[0]-1], all_features[pos_crack.shape[0]+1:]])

    neg_sim = F.cosine_similarity(pos_crack, neg_samples, dim=-1)
    neg_sim_bg = F.cosine_similarity(pos_bg, neg_samples, dim=-1)

    # Combine negative similarities from both perspectives
    neg_sim_combined = torch.cat([neg_sim, neg_sim_bg], dim=0)

    # Exponential terms for the softmax contrastive formula
    exp_pos_sim = torch.exp(pos_sim / temperature)
    sum_exp_neg_sim = torch.sum(torch.exp(neg_sim_combined / temperature), dim=0, keepdim=True)

    # Contrastive loss calculation
    loss = -torch.log(exp_pos_sim / (sum_exp_neg_sim + 1e-9))
    return loss.mean()


num_epochs = 100

for epoch in range(num_epochs):
    running_loss = 0.0
    for crack_batch, background_batch in train_loader:
        crack_batch, background_batch = crack_batch.to(device), background_batch.to(device)

        # Get features for both crack and background images
        crack_features = projection_head(encoder(crack_batch))
        background_features = projection_head(encoder(background_batch))

        # Concatenate all features for negative sampling
        all_features = torch.cat([crack_features, background_features], dim=0)

        batch_loss = 0.0
        for i in range(crack_batch.size(0)):
            pos_crack = crack_features[i].unsqueeze(0)
            pos_bg = background_features[i].unsqueeze(0)

            # Calculate contrastive loss for each pair with all other as negatives
            loss = contrastive_loss(pos_crack, pos_bg, all_features)
            batch_loss += loss.item()

        batch_loss /= crack_batch.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += batch_loss

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {running_loss / len(train_loader):.4f}")


Epoch 1/100 - Loss: 4.9877
Epoch 2/100 - Loss: 4.7123
Epoch 3/100 - Loss: 4.7119


KeyboardInterrupt: 

save encoder (for reusing weights and bias in fine-tuning)

In [None]:
# After training with contrastive loss
torch.save(encoder.state_dict(), "./stage1_unet_encoder.pth")

Fine-tuning dataset

In [None]:
# Define a Dataset Class for Crack Images and Masks
class CrackSegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir):
        self.image_dir = image_dir
        self.mask_dir = mask_dir

        # List all image files in the directory (assuming each has a corresponding mask)
        # self.image_files = [f for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f))]
        self.image_files = [f for f in os.listdir(image_dir)]
    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Get the image file name and corresponding mask file path
        image_file = self.image_files[idx]
        image_path = os.path.join(self.image_dir, image_file)
        mask_path = os.path.join(self.mask_dir, image_file)  # Assuming masks have the same name

        # Load the image and mask
        image = cv2.imread(image_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)  # Load mask as grayscale

        # Normalize both to the range [0, 1] without resizing
        image = image / 255.0  # Normalize image
        mask = mask / 255.0  # Normalize mask

        # Convert to PyTorch tensors and change shape to [C, H, W]
        image_tensor = torch.tensor(image.transpose(2, 0, 1), dtype=torch.float32)
        mask_tensor = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)  # Add a channel dimension

        return image_tensor, mask_tensor


# Ensure we use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Example Usage
# Replace these paths with the appropriate directories containing your crack images and masks
image_dir = './archive/crack_segmentation_dataset/test/images/'
mask_dir = './archive/crack_segmentation_dataset/test/masks/'
dataset = CrackSegmentationDataset(image_dir, mask_dir)

# Check the total size of the dataset
total_size = len(dataset)
train_size = 300
test_size = 200

# Verify that the total matches your desired split
if train_size + test_size != total_size:
    print(f"Warning: Adjusting train/test split to match dataset length ({total_size})")
    train_size = int(0.6 * total_size)  # Example: 60% for training
    test_size = total_size - train_size

# Split the dataset into training and testing sets
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Create DataLoaders for training and testing
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# Move batches to the appropriate device (GPU/CPU) for training set
for image_batch, mask_batch in train_loader:
    image_batch = image_batch.to(device)
    mask_batch = mask_batch.to(device)
    print("Training Image Batch Shape:", image_batch.shape)
    print("Training Mask Batch Shape:", mask_batch.shape)
    break

# Move batches to the appropriate device (GPU/CPU) for testing set
for image_batch, mask_batch in test_loader:
    image_batch = image_batch.to(device)
    mask_batch = mask_batch.to(device)
    print("Testing Image Batch Shape:", image_batch.shape)
    print("Testing Mask Batch Shape:", mask_batch.shape)
    break


Training Image Batch Shape: torch.Size([16, 3, 448, 448])
Training Mask Batch Shape: torch.Size([16, 1, 448, 448])
Testing Image Batch Shape: torch.Size([16, 3, 448, 448])
Testing Mask Batch Shape: torch.Size([16, 1, 448, 448])


fine-tuning

用前面unet的encoder

In [None]:

# Residual Block with Debugging
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, num_filters):
        super(ResidualBlock, self).__init__()
        self.projection = None
        if in_channels != num_filters:
            self.projection = nn.Conv2d(in_channels, num_filters, kernel_size=1, stride=1)

        self.conv1 = nn.Conv2d(in_channels, num_filters, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_filters)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_filters)

    def forward(self, x):
        #print(f"Input shape to ResidualBlock: {x.shape}")
        shortcut = x
        if self.projection:
            shortcut = self.projection(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += shortcut
        return self.relu(out)

# Attention Gate
class AttentionGate(nn.Module):
    def __init__(self, skip_channels, gate_channels, out_channels):
        super(AttentionGate, self).__init__()
        # Match the number of channels correctly
        self.conv_gate = nn.Conv2d(gate_channels, out_channels, kernel_size=1)
        self.bn_gate = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv_skip = nn.Conv2d(skip_channels, out_channels, kernel_size=1)
        self.bn_skip = nn.BatchNorm2d(out_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, skip, gate):
         
        # Adjust the gate and skip to have the same dimensions
        gate_resized = self.conv_gate(gate)
        gate_resized = self.bn_gate(gate_resized)
        gate_resized = self.relu(gate_resized)
        #print(f"Shape of gate_resized in attentiongate: {gate_resized.shape}")
        skip_resized = self.conv_skip(skip)
        skip_resized = self.bn_skip(skip_resized)
        #print(f"Shape of skip_resized in attentiongate: {skip_resized.shape}")
        # Combine the skip and gate tensors
        attention = self.sigmoid(gate_resized + skip_resized)
        out = skip * attention
        #print(f"Shape of out in attentiongate: {skip_resized.shape}")
        # Multiply by the skip tensor
        return out


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels):
        super(DecoderBlock, self).__init__()
        
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.attention_gate = AttentionGate(skip_channels, out_channels, out_channels)
        self.residual_block = ResidualBlock(out_channels + skip_channels, out_channels)

    def forward(self, x, skip):
        #print(f"Shape of x before upconv: {x.shape}")
        x = self.upconv(x)
        #print(f"Shape of x after upconv: {x.shape}")

        # Apply the attention gate to align the skip connection channels
        skip = self.attention_gate(skip, x)
        #print(f"Shape of skip after attention gate: {skip.shape}")

        # Concatenate the skip connection and the upsampled feature map
        x = torch.cat([x, skip], dim=1)
        #print(f"Shape of x after concatenation: {x.shape}")

        # Pass through the residual block
        x = self.residual_block(x)
        #print(f"Shape of x after residual block: {x.shape}")
        return x



# ImprovedUNet Model
class ImprovedUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=32):
        super(ImprovedUNet, self).__init__()
        self.enc1 = conv_block(in_channels, features)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.enc2 = conv_block(features, features * 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.enc3 = conv_block(features * 2, features * 4)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.enc4 = conv_block(features * 4, features * 8)
        self.pool4 = nn.MaxPool2d(2, 2)

        # Bottleneck layer
        self.bottleneck = nn.Conv2d(features * 8, features * 16, kernel_size=3, padding=1)   #padding =1 keeps feature map height & width unchanged

        self.dec4 = DecoderBlock(features * 16, features * 8, features * 8)
        self.dec3 = DecoderBlock(features * 8, features * 4, features * 4)
        self.dec2 = DecoderBlock(features * 4, features * 2, features * 2)
        self.dec1 = DecoderBlock(features * 2, features, features)

        self.final = nn.Conv2d(features, out_channels, kernel_size=1)

    def forward(self, x):
        enc1 = self.enc1(x)
        #print(f"Shape of enc1: {enc1.shape}")
        p1 = self.pool1(enc1)
        enc2 = self.enc2(p1)
        #print(f"Shape of enc2: {enc2.shape}")
        p2 = self.pool2(enc2)
        enc3 = self.enc3(p2)
        #print(f"Shape of enc3: {enc3.shape}")
        p3 = self.pool3(enc3)
        enc4 = self.enc4(p3)
        #print(f"Shape of enc4: {enc4.shape}")
        p4 = self.pool4(enc4)

        bottleneck = self.bottleneck(p4)
        #print(f"Shape of bottleneck: {bottleneck.shape}")  #  expect 512, [batch_size, num_channels, height, width]-[16, 512, 7, 7]

        dec4 = self.dec4(bottleneck, enc4)
        #print(f"Shape of dec4: {dec4.shape}")
        dec3 = self.dec3(dec4, enc3)
        #print(f"Shape of dec3: {dec3.shape}")
        dec2 = self.dec2(dec3, enc2)
        #print(f"Shape of dec2: {dec2.shape}")
        dec1 = self.dec1(dec2, enc1)
        #print(f"Shape of dec1: {dec1.shape}")

        final_output = torch.sigmoid(self.final(dec1))
        #print(f"Shape of final output: {final_output.shape}")
        return final_output



# Model instantiation and device selection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImprovedUNet(in_channels=3, out_channels=1, features=32)
model.to(device)


ImprovedUNet(
  (enc1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (enc2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=Fa

如果把decoder block 也直接放金ImprovedUnet

In [None]:
import torch
import torch.nn as nn
class ImprovedUNet(nn.Module):
    @staticmethod
    def conv_block(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def __init__(self, in_channels=3, out_channels=1, features=64):
        super(ImprovedUNet, self).__init__()
        # Encoder
      

        self.enc1 = conv_block(in_channels, features)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.enc2 = conv_block(features, features * 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.enc3 = conv_block(features * 2, features * 4)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.enc4 = conv_block(features * 4, features * 8)
        self.pool4 = nn.MaxPool2d(2, 2)

        # Bottleneck layer
        self.bottleneck = nn.Conv2d(features * 8, features * 8, kernel_size=3, padding=1)

        # Decoder upsampling layers
        self.upconv4 = nn.ConvTranspose2d(features * 8, features * 8, kernel_size=2, stride=2)
        self.upconv3 = nn.ConvTranspose2d(features * 8, features * 4, kernel_size=2, stride=2)
        self.upconv2 = nn.ConvTranspose2d(features * 4, features * 2, kernel_size=2, stride=2)
        self.upconv1 = nn.ConvTranspose2d(features * 2, features, kernel_size=2, stride=2)

        # Residual blocks (ensure correct input sizes after concatenation)
        self.residual4 = ResidualBlock(features * 8 + features * 8, features * 8)
        self.residual3 = ResidualBlock(features * 4 + features * 8, features * 4)
        self.residual2 = ResidualBlock(features * 2 + features * 4, features * 2)
        self.residual1 = ResidualBlock(features + features * 2, features)

        # Final output layer
        self.final = nn.Conv2d(features, out_channels, kernel_size=1)

    @staticmethod
    def attention_gate(input_tensor, gate_tensor, num_filters):
    # Resize gate tensor to match the input_tensor shape using a 1x1 convolution
        gate_resized = layers.Conv2D(num_filters, 1, padding='same')(gate_tensor)
        gate_resized = layers.BatchNormalization()(gate_resized)
        gate_resized = layers.ReLU()(gate_resized)

        # Add the gate to the input_tensor
        x = layers.add([input_tensor, gate_resized])
        attention = layers.Conv2D(num_filters, 1, padding='same', activation='sigmoid')(x)
        x = layers.multiply([input_tensor, attention])

        return x

    @staticmethod
    def decoder_block(input_tensor, skip_features, num_filters):
        x = layers.Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input_tensor)
        attention_skipped = attention_gate(skip_features, x, num_filters)  # Apply attention gate
        x = layers.concatenate([x, attention_skipped])
        x = conv_block(x, num_filters)
        return x

    def forward(self, x, features=64):
        try:
            # Encoder
            enc1 = self.enc1(x)
            print(f"Shape of enc1: {enc1.shape}")
            p1 = self.pool1(enc1)
            enc2 = self.enc2(p1)
            print(f"Shape of enc2: {enc2.shape}")
            p2 = self.pool2(enc2)
            enc3 = self.enc3(p2)
            print(f"Shape of enc3: {enc3.shape}")
            p3 = self.pool3(enc3)
            enc4 = self.enc4(p3)
            print(f"Shape of enc4: {enc4.shape}")
            p4 = self.pool4(enc4)

            # Bottleneck
            bottleneck = self.bottleneck(p4)
            print(f"Shape of bottleneck: {bottleneck.shape}")

            # # Decoder Stage 1
            # up4 = self.upconv4(bottleneck)
            # print(f"Shape of up4: {up4.shape}")
            # merge4 = torch.cat([up4, enc4], dim=1)
            # print(f"Shape of merge4: {merge4.shape}")
            # dec4 = self.residual4(merge4)
            # print(f"Shape of dec4: {dec4.shape}")

            # # Decoder Stage 2
            # up3 = self.upconv3(dec4)
            # print(f"Shape of up3: {up3.shape}")
            # merge3 = torch.cat([up3, enc3], dim=1)
            # print(f"Shape of merge3: {merge3.shape}")
            # dec3 = self.residual3(merge3)
            # print(f"Shape of dec3: {dec3.shape}")

            # # Decoder Stage 3
            # up2 = self.upconv2(dec3)
            # print(f"Shape of up2: {up2.shape}")
            # merge2 = torch.cat([up2, enc2], dim=1)
            # print(f"Shape of merge2: {merge2.shape}")
            # dec2 = self.residual2(merge2)
            # print(f"Shape of dec2: {dec2.shape}")

            # # Decoder Stage 4
            # up1 = self.upconv1(dec2)
            # print(f"Shape of up1: {up1.shape}")
            # merge1 = torch.cat([up1, enc1], dim=1)
            # print(f"Shape of merge1: {merge1.shape}")
            # dec1 = self.residual1(merge1)
            # print(f"Shape of dec1: {dec1.shape}")

            # final_output = torch.sigmoid(self.final(dec1))
            # print(f"Shape of final output: {final_output.shape}")
            # return final_output
        
            up4 = self.decoder_block(bottleneck, enc4, features * 8)
            up3 = self.decoder_block(up4, enc3, features * 4)
            up2 = self.decoder_block(up3, enc2, features * 2)
            up1 = self.decoder_block(up2, enc1, features * 1)
            # Final layer
            outputs = layers.Conv2D(1, (1,1), padding="same", activation="sigmoid")(up1)
            return outputs

        except Exception as e:
            print(f"Error encountered: {e}")
            raise e



class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Adjust the projection layer to align channels correctly
        if in_channels != out_channels:
            # Projection layer with correct input-to-output mapping
            self.projection = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.projection = nn.Identity()  # Pass-through if the channels match

    def forward(self, x):
        print(f"Input shape to ResidualBlock: {x.shape}")
        # Adjust the skip connection via the projection layer if necessary
        # residual = self.projection(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Add the skip connection
        out += residual
        print(f"Output shape from ResidualBlock: {out.shape}")
        return self.relu(out)




# Model instantiation and device selection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImprovedUNet(in_channels=3, out_channels=1, features=32)
model.to(device)

ImprovedUNet(
  (enc1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (enc2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=Fa

fine-tuning train

In [None]:
# Initialize the improved U-Net model
model = ImprovedUNet(in_channels=3, out_channels=1, features=64)
 
# Load pre-trained weights into the encoder section
# pretrained_weights = torch.load( "/content/drive/MyDrive/CEE 598 DL/pretrained_unet_encoder.pth")
pretrained_weights = torch.load("./stage1_unet_encoder.pth")

# Map pre-trained encoder weights to fine-tuning model's encoder layers
model_state = model.state_dict()

# Adjust this mapping according to your encoder's specific layer names
encoder_weights_mapping = {
    'enc1.0.weight': 'enc1.conv.weight',
    'enc1.0.bias': 'enc1.conv.bias',
    'enc2.0.weight': 'enc2.conv.weight',
    # Add mappings for other layers (adjust names based on your pre-trained encoder)
}

# Update the model's state dictionary with the pre-trained weights
for pre_key, fine_key in encoder_weights_mapping.items():
    if pre_key in pretrained_weights and fine_key in model_state:
        model_state[fine_key] = pretrained_weights[pre_key]

# Load the updated state dictionary back into the model
model.load_state_dict(model_state)
 
# Move the model to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Training function
def train_model(model, train_loader, criterion, optimizer, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, masks in train_loader:
            # Move images and masks to the GPU
            images, masks = images.to(device), masks.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {running_loss / len(train_loader):.4f}")

# Testing/Prediction function
def evaluate_model(model, test_loader):
    model.eval()
    predictions = []
    with torch.no_grad():
        for images, _ in test_loader:  # Ignore the ground-truth masks here
            # Move images to the GPU
            images = images.to(device)
            outputs = model(images)

            # Apply thresholding to get binary predictions
            predicted_masks = (outputs > 0.5).float()

            # Store predictions (back to CPU if necessary)
            predictions.extend(predicted_masks.cpu())

    return predictions

# Define a loss function and optimizer
criterion = nn.BCELoss()  # Binary Cross-Entropy Loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Example usage
# Fine-tune the model
num_epochs = 10  # Adjust as needed
train_model(model, train_loader, criterion, optimizer, num_epochs)

# Evaluate the model on the test dataset
# predicted_masks = evaluate_model(model, test_loader)

# Example: Print the shape of the first predicted mask
if predicted_masks:
    print("First Predicted Mask Shape:", predicted_masks[0].shape)
 

Shape of enc1: torch.Size([16, 64, 448, 448])
Shape of enc2: torch.Size([16, 128, 224, 224])
Shape of enc3: torch.Size([16, 256, 112, 112])
Shape of enc4: torch.Size([16, 512, 56, 56])
Shape of bottleneck: torch.Size([16, 512, 28, 28])
Error encountered: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.


RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

In [None]:
# Set up the device for training and testing
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the model and move it to the device
model = ImprovedUNet(in_channels=3, out_channels=1, features=32)
model.to(device)

# Define a loss function and optimizer
criterion = nn.BCELoss()  # Binary Cross-Entropy Loss
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training function
def train_model(model, train_loader, criterion, optimizer, num_epochs=10):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {running_loss / len(train_loader):.4f}")

# Testing/Prediction function
def evaluate_model(model, test_loader):
    model.eval()
    predictions = []
    with torch.no_grad():
        for images, _ in test_loader:  # Ignore the ground-truth masks here
            images = images.to(device)
            outputs = model(images)

            # Apply thresholding to get binary predictions
            predicted_masks = (outputs > 0.5).float()

            # Store predictions (back to CPU if necessary)
            predictions.extend(predicted_masks.cpu())

    return predictions

# Example usage
# Fine-tune the model
num_epochs = 10  # Adjust as needed
train_model(model, train_loader, criterion, optimizer, num_epochs)

# Evaluate the model on the test dataset
predicted_masks = evaluate_model(model, test_loader)

# Example: Print the shape of the first predicted mask
if predicted_masks:
    print("First Predicted Mask Shape:", predicted_masks[0].shape)


Shape of enc1: torch.Size([16, 32, 112, 112])
Shape of enc2: torch.Size([16, 64, 56, 56])
Shape of enc3: torch.Size([16, 128, 28, 28])
Shape of enc4: torch.Size([16, 256, 14, 14])
Shape of bottleneck: torch.Size([16, 256, 7, 7])
Input shape to ResidualBlock: torch.Size([16, 512, 14, 14])
Shape of dec4: torch.Size([16, 256, 14, 14])


RuntimeError: Given groups=1, weight of size [128, 256, 1, 1], expected input[16, 128, 28, 28] to have 256 channels, but got 128 channels instead

improved unet for fine tuning?

encoder + MLP?

contrastive_loss