### Guided Super-Resolution

In [None]:
import numpy as np
import os

data_dir = r"..\data\high_dose"
# Uncomment one of these instead to process a different dose level:
#data_dir = r"..\data\low_dose"
#data_dir = r"..\data\no_dose"

# Load every input array for the selected dose level from data_dir, in order.
fluorescence, fluorescence_tritc, maldi_ihc, lipid = (
    np.load(os.path.join(data_dir, filename))
    for filename in ("fluorescence.npy", "fluorescence_tritc.npy", "maldi_ihc.npy", "lipid.npy")
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_msssim import ssim
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

# Basic U-Net building block: two (Conv3x3 -> BatchNorm -> ReLU) stages.
class DoubleConv(nn.Module):
    """Two stacked 3x3 conv / batch-norm / ReLU stages.

    Padding of 1 keeps the spatial size unchanged; only the channel count
    changes, from ``in_channels`` to ``out_channels`` (after the first conv).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # The first stage maps in_channels -> out_channels, the second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return out

# Encoder stage: halve the resolution, then refine features.
class Down(nn.Module):
    """2x2 max-pooling (halves H and W, rounding down) followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x):
        return self.conv(self.maxpool(x))

# Decoder stage: upsample, merge with the matching skip connection, refine.
class Up(nn.Module):
    """Upsample by 2, concatenate with a skip connection, then apply DoubleConv.

    ``in_channels`` is the channel count *after* concatenation (upsampled
    decoder features + skip features). With ``bilinear=False`` the upsampling
    is a learned 2x2 transposed convolution mapping ``in_channels // 2`` to
    ``in_channels // 2`` channels, so in that mode the incoming decoder
    features must have ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            upsampler = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            half = in_channels // 2
            upsampler = nn.ConvTranspose2d(half, half, kernel_size=2, stride=2)
        self.up = upsampler
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Fuse upsampled decoder features ``x1`` with encoder skip features ``x2``.

        After upsampling, ``x1`` is padded (or cropped, for negative
        differences) so that its H and W match ``x2``. This matters when an
        encoder level had an odd spatial size.
        """
        upsampled = self.up(x1)
        dy = x2.size(2) - upsampled.size(2)
        dx = x2.size(3) - upsampled.size(3)
        top, left = dy // 2, dx // 2
        upsampled = F.pad(upsampled, [left, dx - left, top, dy - top])
        # Skip features first, then upsampled decoder features.
        return self.conv(torch.cat([x2, upsampled], dim=1))

# Network head: per-pixel linear projection to the output channels.
class OutConv(nn.Module):
    """1x1 convolution mapping feature channels to ``out_channels`` output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

# -------------------------------------------------------
# Two-Input U-Net
# -------------------------------------------------------
class TwoInputUNet(nn.Module):
    """U-Net with two parallel encoders (``src`` and ``ref``) feeding one decoder.

    Both inputs pass through structurally identical but separately weighted
    encoders. At each of the five resolution levels the two encoder feature
    maps are fused by element-wise addition. The fused maps then serve as the
    bottleneck and skip connections of a standard U-Net decoder.

    ``forward`` returns the raw output of the final 1x1 convolution.
    ``self.sigmoid`` is defined but not applied inside ``forward``.
    """

    def __init__(self, n_channels=1, n_classes=1, bilinear=True):
        super().__init__()

        # Encoder for src; widths 64 -> 128 -> 256 -> 512 -> 512 (down4 is the bottom level).
        self.src_inc = DoubleConv(n_channels, 64)
        self.src_down1 = Down(64, 128)
        self.src_down2 = Down(128, 256)
        self.src_down3 = Down(256, 512)
        self.src_down4 = Down(512, 512)

        # Encoder for ref: same layout, independent weights.
        self.ref_inc = DoubleConv(n_channels, 64)
        self.ref_down1 = Down(64, 128)
        self.ref_down2 = Down(128, 256)
        self.ref_down3 = Down(256, 512)
        self.ref_down4 = Down(512, 512)

        # Decoder. Each Up block concatenates the upsampled decoder features with
        # a fused skip connection of the same width, hence the doubled in_channels.
        self.up1 = Up(512 * 2, 256, bilinear)
        self.up2 = Up(256 * 2, 128, bilinear)
        self.up3 = Up(128 * 2, 64, bilinear)
        self.up4 = Up(64 * 2, 64, bilinear)
        self.outc = OutConv(64, n_classes)
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _encode(x, stages):
        """Run ``x`` through ``stages`` in order and collect every intermediate output."""
        features = []
        for stage in stages:
            x = stage(x)
            features.append(x)
        return features

    def forward(self, src, ref):
        src_feats = self._encode(
            src, (self.src_inc, self.src_down1, self.src_down2, self.src_down3, self.src_down4)
        )
        ref_feats = self._encode(
            ref, (self.ref_inc, self.ref_down1, self.ref_down2, self.ref_down3, self.ref_down4)
        )

        # Fuse the two encoders level by level (shallowest first) by summation.
        skip1, skip2, skip3, skip4, bottom = (s + r for s, r in zip(src_feats, ref_feats))

        # Standard single-U-Net decoding over the fused features.
        x = self.up1(bottom, skip4)
        x = self.up2(x, skip3)
        x = self.up3(x, skip2)
        x = self.up4(x, skip1)
        return self.outc(x)

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import cv2
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import gc
from super_image import EdsrModel, ImageLoader
from PIL import Image
import matplotlib.pyplot as plt
from skimage.exposure import match_histograms

# Cut an image into a row-major grid of full, non-overlapping square tiles.
def split_into_tiles(image, tile_size=500):
    """Return the full ``tile_size x tile_size`` tiles of ``image`` in row-major order.

    Tiling starts at the top-left corner. Leftover rows and columns at the
    bottom and right edges that cannot fill a whole tile are discarded. The
    result is an array of shape ``(n_tiles, tile_size, tile_size, ...)``, or
    an empty array when no full tile fits.
    """
    last_top = image.shape[0] - tile_size
    last_left = image.shape[1] - tile_size
    return np.array([
        image[top:top + tile_size, left:left + tile_size]
        for top in range(0, last_top + 1, tile_size)
        for left in range(0, last_left + 1, tile_size)
    ])

def match_histogram(pred_np, target_np):
    """Remap the intensity distribution of ``pred_np`` onto that of ``target_np``.

    Thin wrapper around ``skimage.exposure.match_histograms`` for
    single-channel images (``channel_axis=None``).

    Parameters
    ----------
    pred_np : np.ndarray (H, W)
        Network output in [0, 1].
    target_np : np.ndarray (H, W)
        MALDI tile in [0, 1] that supplies the reference histogram.
    """
    matched = match_histograms(pred_np, target_np, channel_axis=None)
    return matched

# Function to reconstruct the image from tiles (inverse of split_into_tiles)
def reconstruct_from_tiles(tiles, original_size=(10000, 10000), tile_size=500):
    """Reassemble non-overlapping, row-major tiles into one 2-D image.

    The tiles fill the grid of *full* tile positions in row-major order,
    matching ``split_into_tiles``. When ``original_size`` is not a multiple of
    ``tile_size``, the right and bottom margins that ``split_into_tiles``
    discards are left as zeros.

    Parameters
    ----------
    tiles : sequence of np.ndarray
        Tiles of shape (tile_size, tile_size), in row-major order.
    original_size : tuple of int
        (height, width) of the output image.
    tile_size : int
        Edge length of each square tile.

    Returns
    -------
    np.ndarray
        float32 array of shape ``original_size``.

    Raises
    ------
    IndexError
        If there are fewer tiles than full tile positions.
    """
    height, width = original_size
    reconstructed = np.zeros(original_size, dtype=np.float32)
    # Only full tile positions: split_into_tiles never produces partial edge
    # tiles, and a full tile would not fit into a partial slot anyway.
    positions = [
        (i, j)
        for i in range(0, height - tile_size + 1, tile_size)
        for j in range(0, width - tile_size + 1, tile_size)
    ]
    for index, (i, j) in enumerate(positions):
        reconstructed[i:i + tile_size, j:j + tile_size] = tiles[index]
    return reconstructed

# Pretrained EDSR models keyed by upscaling factor. edsr() is called once per
# tile per channel, so the weights are loaded only the first time each scale
# is needed instead of on every call.
_EDSR_MODELS = {}


def _get_edsr_model(scale):
    """Return the pretrained 'eugenesiow/edsr-base' model for ``scale``, loading it once."""
    if scale not in _EDSR_MODELS:
        _EDSR_MODELS[scale] = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=scale)
    return _EDSR_MODELS[scale]


def _minmax_scale(arr):
    """Linearly rescale ``arr`` to [0, 1]; a constant array maps to all zeros."""
    if not np.issubdtype(arr.dtype, np.floating):
        # Integer input: work in float so (max - min) cannot wrap around.
        arr = arr.astype(np.float64)
    lo, hi = arr.min(), arr.max()
    if hi > lo:
        return (arr - lo) / (hi - lo)
    return np.zeros_like(arr)


def edsr(cropped_maldi, scale=4):
    """Super-resolve a single-channel image with a pretrained EDSR network.

    Parameters
    ----------
    cropped_maldi : np.ndarray
        Image of any numeric dtype, assumed 2-D (single channel). It is
        min-max normalised and quantised to 8 bits before being fed to the
        network.
    scale : int, optional
        Upscaling factor passed to ``EdsrModel.from_pretrained``. The
        checkpoint provides scales 2, 3 and 4; the default of 4 matches the
        previous hard-coded value.

    Returns
    -------
    np.ndarray
        float32 grayscale image in [0, 1], quantised to multiples of 1/255.
        Its spatial size is presumably ``scale`` times the input size.
    """
    model = _get_edsr_model(scale)

    # Min-max normalise, quantise to uint8, and replicate to RGB for the model.
    low_res_u8 = (_minmax_scale(cropped_maldi) * 255).astype(np.uint8)
    low_res_img_pil = Image.fromarray(low_res_u8).convert("RGB")

    inputs = ImageLoader.load_image(low_res_img_pil)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        preds = model(inputs)

    sr_image_np = preds.cpu().detach().numpy().squeeze()  # drop batch dimension
    sr_image_np = (sr_image_np * 255).clip(0, 255).astype(np.uint8)
    if sr_image_np.ndim == 3:
        sr_image_pil = Image.fromarray(np.transpose(sr_image_np, (1, 2, 0)))  # CHW -> HWC
    else:
        sr_image_pil = Image.fromarray(sr_image_np)

    # Collapse to one channel, re-normalise to the full 8-bit range, return floats in [0, 1].
    sr_gray = np.array(sr_image_pil.convert("L"), dtype=np.float32)
    sr_gray_u8 = (_minmax_scale(sr_gray) * 255).astype(np.uint8)
    return (sr_gray_u8 / 255.0).astype('float32')

In [None]:
import os
import torch.optim as optim
from tqdm import tqdm
import gc
import numpy as np
import cv2

results_dir = r"..\results\high_dose"

gc.collect()
torch.cuda.empty_cache()

# Weight of the SSIM term in the loss; the MSE term gets (1 - ssim_weight).
ssim_weight = 0.15
# Optimisation steps per tile (each tile is fitted by its own freshly initialised model).
num_epochs = 100

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def combined_loss(pred, src, upsampled):
    """Blend intensity fidelity to the MALDI tile with structural fidelity to the guide.

    Returns ``(1 - ssim_weight) * MSE(pred, upsampled)
    + ssim_weight * (1 - SSIM(pred, src))``, where ``upsampled`` is the
    super-resolved MALDI tile and ``src`` is the fluorescence guide tile.
    SSIM uses ``data_range=1``, so the inputs are expected to lie in [0, 1].
    """
    mse_loss = nn.MSELoss()(pred, upsampled)
    ssim_loss = 1 - ssim(pred, src, data_range=1, size_average=True)
    return (1 - ssim_weight) * mse_loss + ssim_weight * ssim_loss


# The fluorescence images are the same for every MALDI channel, so tile and
# normalise them once instead of once per channel.
fluorescence_tiles = (split_into_tiles(fluorescence, 1024) / 255.0).astype('float32')
fluorescence_tritc_tiles = (split_into_tiles(fluorescence_tritc, 1024) / 255.0).astype('float32')

for channel in range(23):
    print(f'Processing: Channel {channel}')

    # The 128 px MALDI tiles are paired by index with the 1024 px fluorescence
    # tiles (assumes the MALDI image is 8x coarser per axis; TODO confirm).
    maldi_ihc_channel = maldi_ihc[channel, :, :].copy()
    maldi_tiles_orig = (split_into_tiles(maldi_ihc_channel, 128) / 255.0).astype('float32')
    maldi_tiles = maldi_tiles_orig.copy()
    if maldi_tiles.shape[0] < fluorescence_tiles.shape[0]:
        raise ValueError(
            f"Channel {channel}: only {maldi_tiles.shape[0]} MALDI tiles for "
            f"{fluorescence_tiles.shape[0]} fluorescence tiles"
        )

    # Per-channel output directories.
    base_dir = os.path.join(results_dir, f"{channel}")
    output_dir = os.path.join(base_dir, "outputs")                         # processed output tiles
    fluorescence_dir = os.path.join(base_dir, "fluorescence")              # fluorescence guide tiles
    fluorescence_tritc_dir = os.path.join(base_dir, "fluorescence_tritc")  # TRITC fluorescence tiles
    maldi_dir = os.path.join(base_dir, "maldi")                            # EDSR-upscaled MALDI tiles
    original_maldi_dir = os.path.join(base_dir, "original_maldi")          # resized original MALDI tiles
    for directory in (output_dir, fluorescence_dir, fluorescence_tritc_dir, maldi_dir, original_maldi_dir):
        os.makedirs(directory, exist_ok=True)

    gc.collect()
    torch.cuda.empty_cache()

    processed_tiles = []

    for i in range(fluorescence_tiles.shape[0]):
        print(f"Processing tile {i}...")

        # Fresh network and optimizer per tile: every tile is fitted independently.
        model = TwoInputUNet().to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-3)

        fluorescence_tile_np = fluorescence_tiles[i]
        fluorescence_tritc_tile = fluorescence_tritc_tiles[i]
        # Plain resize of the original MALDI tile; used as the histogram reference.
        maldi_tile_orig = cv2.resize(maldi_tiles_orig[i], (1024, 1024), interpolation=cv2.INTER_AREA)
        # EDSR-upscaled MALDI tile, resized to the fluorescence tile size.
        maldi_tile_np = cv2.resize(edsr((maldi_tiles[i] * 255.0).astype('uint8')), (1024, 1024),
                                   interpolation=cv2.INTER_AREA)

        # Save the input tiles as images (used later for stitching).
        plt.imsave(os.path.join(fluorescence_dir, f"fluorescence_tile_{i}.png"), fluorescence_tile_np, cmap='gray')
        plt.imsave(os.path.join(fluorescence_tritc_dir, f"fluorescence_tile_{i}.png"), fluorescence_tritc_tile,
                   cmap='gray')
        plt.imsave(os.path.join(maldi_dir, f"maldi_tile_{i}.png"), maldi_tile_np)
        plt.imsave(os.path.join(original_maldi_dir, f"maldi_tile_{i}.png"), maldi_tile_orig)

        # (H, W) -> (B=1, C=1, H, W) tensors on the training device.
        fluorescence_tile = torch.from_numpy(fluorescence_tile_np).to(device).unsqueeze(0).unsqueeze(0)
        maldi_tile = torch.from_numpy(maldi_tile_np).to(device).unsqueeze(0).unsqueeze(0)

        model.train()
        for epoch in range(num_epochs):
            outputs = model(fluorescence_tile, maldi_tile)  # guide (src) + MALDI (ref)
            loss = combined_loss(outputs, fluorescence_tile, maldi_tile)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}")

        # NOTE(review): `outputs` comes from the last forward pass, i.e. before the
        # final optimizer.step(). The network is trained on its raw output, so this
        # sigmoid is not part of the objective. It is strictly monotonic, and
        # histogram matching depends only on intensity ordering, so it should have
        # essentially no effect on the matched result.
        processed_tile = torch.sigmoid(outputs).detach().cpu().squeeze().numpy()
        processed_tile = match_histogram(processed_tile, maldi_tile_orig)
        processed_tiles.append(processed_tile)

        output_filename = os.path.join(output_dir, f"output_tile_{i}.png")
        plt.imsave(output_filename, processed_tile)
        print(f"Saved processed tile {i} to {output_filename}")

        # Drop this tile's model, optimizer and graph so that empty_cache()
        # can actually release their GPU memory before the next model is built.
        del model, optimizer, outputs, loss
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from glob import glob

def stitch_tiles_blended(tiles, grid_shape=(5, 5), tile_size=1000, overlap=100, blend=True):
    """Stitch square tiles laid out on a grid whose neighbours share ``overlap`` pixels.

    Tile ``idx`` goes to row ``idx // grid_shape[1]`` and column
    ``idx % grid_shape[1]``, with a stride of ``tile_size - overlap`` pixels.
    With ``blend=True`` and ``overlap > 0``, each tile is weighted by a linear
    ramp across every side it shares with a neighbour, and the sum is
    normalised by the accumulated weights, which cross-fades the overlaps.
    Outer image borders are not tapered, so they keep their tile values.

    Returns
    -------
    np.ndarray
        float32 array of shape ``(step * rows + overlap, step * cols + overlap)``.
        Pixels covered by no tile are 0.
    """
    n_rows, n_cols = grid_shape
    step = tile_size - overlap
    out_h = step * n_rows + overlap
    out_w = step * n_cols + overlap
    stitched = np.zeros((out_h, out_w), dtype=np.float32)
    weight_map = np.zeros_like(stitched)

    use_ramps = blend and overlap > 0
    ramp = np.linspace(0, 1, overlap) if use_ramps else None

    def axis_window(index, count):
        # Taper only the sides that overlap a neighbour. Tapering the outer
        # edges down to 0 would make the normalisation zero the image border.
        w = np.ones(tile_size)
        if use_ramps:
            if index > 0:
                w[:overlap] = ramp
            if index < count - 1:
                w[-overlap:] = ramp[::-1]
        return w

    for idx, tile in enumerate(tiles):
        row, col = divmod(idx, n_cols)
        mask = np.outer(axis_window(row, n_rows), axis_window(col, n_cols))
        y_start = row * step
        x_start = col * step

        stitched[y_start:y_start + tile_size, x_start:x_start + tile_size] += tile * mask
        weight_map[y_start:y_start + tile_size, x_start:x_start + tile_size] += mask

    stitched /= np.maximum(weight_map, 1e-6)
    return stitched

def load_tiles(tile_dir, prefix, grid_shape, tile_size=1000):
    """Load the ``grid_shape[0] * grid_shape[1]`` tiles ``<prefix>_tile_<i>.png`` from ``tile_dir``.

    Each tile is read as 8-bit grayscale and scaled to float32 in [0, 1].
    A tile whose file does not exist is replaced by a ``tile_size x tile_size``
    block of zeros, so the grid stays complete.

    Raises
    ------
    OSError
        If a tile file exists but OpenCV cannot decode it.
    """
    tiles = []
    for i in range(grid_shape[0] * grid_shape[1]):
        tile_path = os.path.join(tile_dir, f"{prefix}_tile_{i}.png")
        if not os.path.exists(tile_path):
            tiles.append(np.zeros((tile_size, tile_size), dtype=np.float32))
            continue
        img = cv2.imread(tile_path, cv2.IMREAD_GRAYSCALE)
        if img is None:  # cv2.imread reports failure by returning None, not by raising
            raise OSError(f"Could not read tile image: {tile_path}")
        tiles.append(img.astype(np.float32) / 255.0)
    return tiles

def _stitch_dir(tile_dir, prefix, cmap, grid_shape, tile_size):
    """Load ``<prefix>_tile_<i>.png`` from ``tile_dir``, blend-stitch the tiles,
    and save the result as ``stitched.png`` in the same directory.

    Returns the loaded tiles and the stitched image.
    """
    tiles = load_tiles(tile_dir, prefix, grid_shape, tile_size)
    stitched = stitch_tiles_blended(tiles, grid_shape, tile_size)
    plt.imsave(os.path.join(tile_dir, "stitched.png"), stitched, cmap=cmap)
    return tiles, stitched


# === Main Processing ===
for channel in range(23):
    print(f"Stitching Channel {channel}...")

    base_dir = os.path.join(results_dir, f"{channel}")
    output_dir = os.path.join(base_dir, "outputs")
    fluorescence_dir = os.path.join(base_dir, "fluorescence")
    fluorescence_tritc_dir = os.path.join(base_dir, "fluorescence_tritc")
    maldi_dir = os.path.join(base_dir, "maldi")
    original_maldi_dir = os.path.join(base_dir, "original_maldi")

    grid_shape = (5, 5)
    # NOTE(review): the training cell writes tiles via split_into_tiles(..., 1024),
    # i.e. 1024x1024 PNGs, but they are stitched here as 1000x1000 tiles with a
    # 100 px overlap. If those are the PNGs read here, the blend step fails on
    # the shape mismatch. Confirm which tile geometry is intended.
    tile_size = 1000

    # Processed tiles, fluorescence, fluorescence (TRITC), EDSR MALDI, original MALDI.
    processed_tiles, stitched_processed = _stitch_dir(
        output_dir, "output", 'viridis', grid_shape, tile_size)
    fluorescence_tiles, stitched_fluorescence = _stitch_dir(
        fluorescence_dir, "fluorescence", 'gray', grid_shape, tile_size)
    fluorescence_cell_tiles, stitched_fluorescence_cell = _stitch_dir(
        fluorescence_tritc_dir, "fluorescence", 'gray', grid_shape, tile_size)
    maldi_tiles, stitched_maldi = _stitch_dir(
        maldi_dir, "maldi", 'viridis', grid_shape, tile_size)
    orig_maldi_tiles, stitched_orig_maldi = _stitch_dir(
        original_maldi_dir, "maldi", 'viridis', grid_shape, tile_size)

In [None]:
import os
import cv2
import numpy as np

# MALDI channels that were processed and stitched above.
channels = range(23)

# One entry per channel, in channel order; None marks a missing or unreadable image.
stitched_arrays = []
for channel in channels:
    stitched_path = os.path.join(results_dir, str(channel), "outputs", "stitched.png")
    img = None
    if os.path.exists(stitched_path):
        img = cv2.imread(stitched_path, cv2.IMREAD_GRAYSCALE)  # None if the file cannot be decoded
        if img is None:
            print(f"Could not read stitched image for channel {channel}: {stitched_path}")
    else:
        print(f"Missing stitched image for channel {channel}")
    stitched_arrays.append(None if img is None else img.astype(np.float32) / 255.0)

available = [arr for arr in stitched_arrays if arr is not None]
if not available:
    raise FileNotFoundError(f"No stitched images found under {results_dir}")

# Fill gaps with NaN instead of dropping them. Dropping would silently shift
# every later channel to the wrong index in the saved stack.
missing_fill = np.full_like(available[0], np.nan)
stitched_arrays = [missing_fill if arr is None else arr for arr in stitched_arrays]

# Stack into a (C, H, W) array with slot i == channel i
stitched_stack = np.stack(stitched_arrays, axis=0)
print("Final stacked shape:", stitched_stack.shape)  # (C, H, W)
np.save(os.path.join(results_dir, "maldi_ihc_gsr.npy"), stitched_stack)

In [None]:
from tqdm import tqdm

def stitch_tiles_blended(tiles, grid_shape=(5, 5), tile_size=1000, overlap=100, blend=True):
    """Stitch square tiles laid out on a grid whose neighbours share ``overlap`` pixels.

    Tile ``idx`` goes to row ``idx // grid_shape[1]`` and column
    ``idx % grid_shape[1]``, with a stride of ``tile_size - overlap`` pixels.
    With ``blend=True`` and ``overlap > 0``, each tile is weighted by a linear
    ramp across every side it shares with a neighbour, and the sum is
    normalised by the accumulated weights, which cross-fades the overlaps.
    Outer image borders are not tapered, so they keep their tile values.

    Returns
    -------
    np.ndarray
        float32 array of shape ``(step * rows + overlap, step * cols + overlap)``.
        Pixels covered by no tile are 0.
    """
    n_rows, n_cols = grid_shape
    step = tile_size - overlap
    out_h = step * n_rows + overlap
    out_w = step * n_cols + overlap
    stitched = np.zeros((out_h, out_w), dtype=np.float32)
    weight_map = np.zeros_like(stitched)

    use_ramps = blend and overlap > 0
    ramp = np.linspace(0, 1, overlap) if use_ramps else None

    def axis_window(index, count):
        # Taper only the sides that overlap a neighbour. Tapering the outer
        # edges down to 0 would make the normalisation zero the image border.
        w = np.ones(tile_size)
        if use_ramps:
            if index > 0:
                w[:overlap] = ramp
            if index < count - 1:
                w[-overlap:] = ramp[::-1]
        return w

    for idx, tile in enumerate(tiles):
        row, col = divmod(idx, n_cols)
        mask = np.outer(axis_window(row, n_rows), axis_window(col, n_cols))
        y_start = row * step
        x_start = col * step

        stitched[y_start:y_start + tile_size, x_start:x_start + tile_size] += tile * mask
        weight_map[y_start:y_start + tile_size, x_start:x_start + tile_size] += mask

    stitched /= np.maximum(weight_map, 1e-6)
    return stitched

def split_into_tiles(image, tile_size=500):
    """Return the full, non-overlapping ``tile_size`` square tiles of ``image``, row-major.

    Partial tiles at the bottom and right edges are dropped. An image smaller
    than one tile yields an empty array.
    """
    height, width = image.shape[0], image.shape[1]
    tiles = [
        image[r:r + tile_size, c:c + tile_size]
        for r in range(0, height - tile_size + 1, tile_size)
        for c in range(0, width - tile_size + 1, tile_size)
    ]
    return np.array(tiles)

# NOTE(review): `lipids_square` is not defined in any cell shown here; the first
# cell loads `lipid`. Presumably a square-cropped (C, H, W) version of it is
# created elsewhere. Confirm before running.
grid_shape = (5, 5)
tile_size = 1000

lipids_stitched = []
for channel in tqdm(range(lipids_square.shape[0])):
    lipid_square = lipids_square[channel, :, :]
    # Non-overlapping 1000 px tiles, stitched with the same grid and overlap
    # used for the output tiles (presumably so the results line up).
    lipid_tiles = (split_into_tiles(lipid_square, 1000) / 255.0).astype('float32')
    stitched_lipid = stitch_tiles_blended(lipid_tiles, grid_shape, tile_size)
    lipids_stitched.append(stitched_lipid)

lipids_stitched = np.stack(lipids_stitched, axis=0)
np.save(os.path.join(results_dir, 'lipids_stitched.npy'), lipids_stitched)