In [1]:
import os
import re
import glob
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image, Grayscale
from torchvision.models import models
from torch.utils.data import DataLoader
from skimage.metrics import structural_similarity as ssim
from tqdm import tqdm

In [2]:
# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [3]:
class UNet(nn.Module):
    def __init__(
        self,
        in_channels=1,
        n_classes=1,
        depth=4,
        wf=3,
        padding=True,
        batch_norm=False,
        up_mode='upconv',
    ):
        """
        U-Net encoder/decoder, after
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        (Ronneberger et al., 2015)
        https://arxiv.org/abs/1505.04597

        Note: the defaults of this class (single output channel, depth=4,
        2**3 = 8 filters in the first level, padded convolutions) build a
        compact variant, not the configuration of the paper (which is deeper
        and wider and uses unpadded convolutions).

        Args:
            in_channels (int): number of input channels
            n_classes (int): number of output channels
            depth (int): number of resolution levels (conv blocks) in the
                         contracting path; the expanding path has depth - 1
                         blocks
            wf (int): number of filters in the first layer is 2**wf; the
                      count doubles at every deeper level
            padding (bool): if True, apply padding such that the input shape
                            is the same as the output.
                            This may introduce artifacts
            batch_norm (bool): Use BatchNorm after layers with an
                               activation function
            up_mode (str): one of 'upconv' or 'upsample'.
                           'upconv' will use transposed convolutions for
                           learned upsampling.
                           'upsample' will use bilinear upsampling.
        """
        super().__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth

        # Filter count at each resolution level: 2**wf, 2**(wf+1), ...
        level_channels = [2 ** (wf + level) for level in range(depth)]

        # Contracting path: one conv block per level.
        self.down_path = nn.ModuleList()
        prev_channels = in_channels
        for out_channels in level_channels:
            self.down_path.append(
                UNetConvBlock(prev_channels, out_channels, padding, batch_norm)
            )
            prev_channels = out_channels

        # Expanding path: walk back up through every level except the deepest.
        self.up_path = nn.ModuleList()
        for out_channels in reversed(level_channels[:-1]):
            self.up_path.append(
                UNetUpBlock(prev_channels, out_channels, up_mode, padding, batch_norm)
            )
            prev_channels = out_channels

        # 1x1 convolution mapping the final feature map to the output channels.
        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

    def forward(self, x):
        """Run the contracting path, then the expanding path with skip connections."""
        skips = []
        deepest = len(self.down_path) - 1
        for level, down in enumerate(self.down_path):
            x = down(x)
            if level == deepest:
                # The deepest block is neither stored as a skip nor pooled.
                continue
            skips.append(x)
            x = F.max_pool2d(x, 2)

        # Skips are consumed deepest-first, mirroring the order they were stored.
        for up in self.up_path:
            x = up(x, skips.pop())

        return self.last(x)


class UNetConvBlock(nn.Module):
    """Two stacked 3x3 convolutions, each followed by ReLU and, optionally, BatchNorm.

    Args:
        in_size (int): channels of the incoming feature map
        out_size (int): channels produced by both convolutions
        padding (bool): converted with int(); True pads each convolution by
                        one pixel, False applies no padding
        batch_norm (bool): if True, a BatchNorm2d follows every ReLU
    """

    def __init__(self, in_size, out_size, padding, batch_norm):
        super().__init__()
        pad = int(padding)

        layers = []
        # The first convolution maps in_size -> out_size, the second keeps out_size.
        for channels_in in (in_size, out_size):
            layers.append(nn.Conv2d(channels_in, out_size, kernel_size=3, padding=pad))
            layers.append(nn.ReLU())
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_size))

        # Layer order is conv, relu[, bn], conv, relu[, bn]; the Sequential
        # indices determine the parameter names in the state dict.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution stack to ``x``."""
        return self.block(x)


class UNetUpBlock(nn.Module):
    """One decoder step: upsample, concatenate the cropped skip tensor, convolve.

    Args:
        in_size (int): channels of the tensor coming from the deeper level;
                       also the channel count the concatenated tensor must
                       have, since it feeds a conv block expecting in_size
        out_size (int): channels after upsampling and after the conv block
        up_mode (str): 'upconv' (transposed convolution) or 'upsample'
                       (bilinear upsampling followed by a 1x1 convolution)
        padding (bool): forwarded to the conv block
        batch_norm (bool): forwarded to the conv block
    """

    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super().__init__()
        if up_mode == 'upsample':
            self.up = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_size, out_size, kernel_size=1),
            )
        elif up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)

        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        """Return the central (target_size[0], target_size[1]) window of a 4-D tensor."""
        _, _, height, width = layer.size()
        target_h = target_size[0]
        target_w = target_size[1]
        top = (height - target_h) // 2
        left = (width - target_w) // 2
        return layer[:, :, top:(top + target_h), left:(left + target_w)]

    def forward(self, x, bridge):
        """Upsample ``x``, join it with the cropped ``bridge`` along dim 1, and convolve."""
        upsampled = self.up(x)
        # Crop the skip tensor to the spatial size of the upsampled tensor.
        skip = self.center_crop(bridge, upsampled.shape[2:])
        merged = torch.cat([upsampled, skip], 1)
        return self.conv_block(merged)

In [9]:
# Capturing pattern for digit runs; splitting on it keeps the digit runs in the result.
numbers = re.compile(r'(\d+)')

def numericalSort(value):
    """Sort key that orders embedded numbers by value rather than character by character.

    The string is split into alternating non-digit / digit pieces; because the
    pattern is a capturing group, the digit runs land at the odd positions and
    are converted to int.

    Example: "img10.png" -> ['img', 10, '.png']
    """
    return [
        int(piece) if position % 2 else piece
        for position, piece in enumerate(numbers.split(value))
    ]

# Preprocessing pipeline shared by every call. It is built once here instead of
# on each invocation, and uses the Grayscale transform class from
# torchvision.transforms (already imported as `transforms`).
_PREPROCESS_TRANSFORM = transforms.Compose([
    transforms.Resize((256, 256)),  # Adjust size as needed
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale with one channel
    transforms.ToTensor(),
    # Add other necessary transformations (normalization, etc.)
])


# Function to load and preprocess the image
def preprocess_image(image_path):
    """Load one image file and turn it into a model-ready batch of size one.

    The image is resized to 256x256, converted to a single grayscale channel
    and converted to a tensor; a leading batch dimension is then added and the
    result is moved to the global ``device``.

    Args:
        image_path: path of the image file to open with PIL.

    Returns:
        torch.Tensor: tensor of shape (1, 1, 256, 256) on ``device``.

    Raises:
        Whatever ``PIL.Image.open`` raises for a missing or unreadable file.
    """
    # The context manager releases the file handle as soon as the pixel data
    # has been consumed by the transform pipeline.
    with Image.open(image_path) as image:
        tensor = _PREPROCESS_TRANSFORM(image)
    return tensor.unsqueeze(0).to(device)  # Add batch dimension and move to device

# Load your trained model
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary
# code. Only load checkpoints from a trusted source (recent PyTorch versions
# accept weights_only=True for plain state dicts).
model = UNet().to(device)
model.load_state_dict(torch.load('VPAT_model_checkpoint_6h.pth', map_location=device))
model.eval()

# List all image files in the directory
# NOTE(review): this cell previously also assigned image_files = list(df['Path'])
# and overwrote it on the very next statement, so that assignment was dead code
# and has been dropped. `df` is still loaded in case other cells rely on it.
df = pd.read_csv('dataset.csv')
images_directory = r'D:\CV_Project\results\integrals'
image_files = sorted(glob.glob(os.path.join(images_directory, '*.png')), key=numericalSort)

# Create a DataLoader for batch processing; it only groups the path strings
# into batches, the images themselves are loaded inside the loop below.
batch_size = 16  # Change batch size according to memory constraints
data_loader = DataLoader(image_files, batch_size=batch_size, shuffle=False)

# Output directory to save edited images
output_directory = r'D:\OneDrive - Johannes Kepler Universität Linz\Artificial_Intelligence\Computer_Vision\predictions'
os.makedirs(output_directory, exist_ok=True)

# Running index used to name the outputs (output_00000.png, output_00001.png, ...)
counter = 0

# Loop through the DataLoader and process images using the model
for batch in tqdm(data_loader):
    input_images = torch.cat([preprocess_image(img_path) for img_path in batch])

    with torch.no_grad():
        output_images = model(input_images)

    # Move the whole batch to the host once: (N, C, H, W) -> (N, H, W, C),
    # then clip to the [0, 1] range accepted for float images.
    output_batch = np.clip(output_images.permute(0, 2, 3, 1).cpu().numpy(), 0, 1)

    # Check image dimensions and adjust if needed (identical for every image in the batch)
    n_channels = output_batch.shape[-1]
    if n_channels == 1:  # Convert single channel to 3 channels (grayscale to RGB)
        output_batch = np.repeat(output_batch, 3, axis=-1)
    elif n_channels not in (3, 4):
        raise ValueError("Invalid image shape")

    # Save each output image of the batch
    for output_image in output_batch:
        image_name = f"output_{counter:05d}.png"
        image_path = os.path.join(output_directory, image_name)
        plt.imsave(image_path, output_image)
        counter += 1

100%|██████████████████████████████████████████████████████████████████████████████| 2006/2006 [39:42<00:00,  1.19s/it]
