In [1]:
import numpy as np
from scipy import ndimage
from sklearn.preprocessing import StandardScaler
from skimage.transform import resize
from sklearn.decomposition import NMF
from sklearn.feature_extraction.image import extract_patches_2d
from sklearn.metrics.pairwise import cosine_similarity

In [2]:

def upscale_msi_data(msi_data, optical_image, target_shape, patch_size=5):
    """
    Upscale MSI data using structural information from a high-resolution optical image.
    
    Parameters:
    -----------
    msi_data : ndarray
        Low-resolution MSI data of shape (height, width, n_channels)
    optical_image : ndarray
        High-resolution optical image of shape (target_height, target_width)
    target_shape : tuple
        Desired output shape (height, width)
    patch_size : int
        Size of patches for local feature extraction
        
    Returns:
    --------
    ndarray
        Upscaled MSI data of shape (target_height, target_width, n_channels)
    """
    
    # Step 1: Initial upscaling using bicubic interpolation
    upscaled_msi = np.zeros((target_shape[0], target_shape[1], msi_data.shape[2]))
    for channel in range(msi_data.shape[2]):
        upscaled_msi[:,:,channel] = resize(msi_data[:,:,channel], 
                                         target_shape, 
                                         order=3,  # bicubic
                                         mode='reflect')
    
    # Step 2: Extract structural features from optical image
    optical_patches = extract_patches_2d(optical_image, 
                                       patch_size=(patch_size, patch_size))
    optical_features = optical_patches.reshape(optical_patches.shape[0], -1)
    
    # Step 3: Normalize features
    scaler = StandardScaler()
    optical_features_normalized = scaler.fit_transform(optical_features)
    
    # Step 4: Learn local patterns using NMF
    n_components = min(20, optical_features_normalized.shape[1])
    nmf = NMF(n_components=n_components, random_state=42)
    optical_patterns = nmf.fit_transform(np.abs(optical_features_normalized))
    
    # Step 5: Build similarity matrix
    similarity_matrix = cosine_similarity(optical_patterns)
    
    # Step 6: Guided filtering using structural similarity
    filtered_msi = np.zeros_like(upscaled_msi)
    for channel in range(upscaled_msi.shape[2]):
        channel_data = upscaled_msi[:,:,channel]
        filtered_channel = guided_filter(channel_data, 
                                      optical_image, 
                                      similarity_matrix,
                                      radius=patch_size//2)
        filtered_msi[:,:,channel] = filtered_channel
    
    return filtered_msi

def guided_filter(msi_channel, guide_image, similarity_matrix, radius=2):
    """
    Apply guided filtering using structural information from the optical image.
    
    Parameters:
    -----------
    msi_channel : ndarray
        Single channel of upscaled MSI data
    guide_image : ndarray
        High-resolution optical image
    similarity_matrix : ndarray
        Pre-computed structural similarity matrix
    radius : int
        Filtering radius
        
    Returns:
    --------
    ndarray
        Filtered MSI channel
    """
    
    # Apply edge-preserving smoothing
    filtered = np.zeros_like(msi_channel)
    
    # Pad images for border handling
    padded_msi = np.pad(msi_channel, radius, mode='reflect')
    padded_guide = np.pad(guide_image, radius, mode='reflect')
    
    # Iterate through each pixel
    for i in range(radius, padded_msi.shape[0] - radius):
        for j in range(radius, padded_msi.shape[1] - radius):
            # Extract local windows
            msi_window = padded_msi[i-radius:i+radius+1, j-radius:j+radius+1]
            guide_window = padded_guide[i-radius:i+radius+1, j-radius:j+radius+1]
            
            # Compute weights based on guide image similarity
            weights = compute_weights(guide_window, similarity_matrix)
            
            # Apply weighted filtering
            filtered[i-radius, j-radius] = np.sum(weights * msi_window)
    
    return filtered

def compute_weights(window, similarity_matrix, epsilon=1e-6):
    """
    Compute filtering weights based on structural similarity.
    """
    window_flat = window.flatten()
    weights = np.dot(similarity_matrix, window_flat)
    weights = np.maximum(weights, epsilon)
    weights = weights / np.sum(weights)
    return weights.reshape(window.shape)

In [3]:
import joblib
import matplotlib.pyplot as plt
from PIL import Image

In [4]:
msi_data = joblib.load(r'E:\CloudDrive\OneDrive - genseccoltd\Projects\vscode-jupyter\gancaofusion-data\size_list, data, data_dr, Profile_list_dr')[2]

In [5]:
msi_data.shape

(10, 205, 255)

In [6]:
msi_data = np.moveaxis(msi_data, 0, -1)
msi_data.shape

(205, 255, 10)

In [7]:
optical_image = np.array(Image.open(r"E:\CloudDrive\OneDrive - genseccoltd\Projects\ij\gancao30\precise regi\aufn moved.jpg"))

In [11]:
# Load your data
msi_data = msi_data[100:150, 0:100, :] # Your low-res MSI data (height, width, channels)
optical_image = optical_image[2000:3000, 0:2000, :] # Your high-res optical image
target_shape = optical_image.shape[:2]
print ("Target shape: ", target_shape)
print ("MSI data shape: ", msi_data.shape)
print ("Optical image shape: ", optical_image.shape)

Target shape:  (1000, 2000)
MSI data shape:  (50, 100, 10)
Optical image shape:  (1000, 2000, 3)


In [13]:
# Apply the upscaling
upscaled_msi = upscale_msi_data(msi_data, optical_image, target_shape)



MemoryError: Unable to allocate 28.8 TiB for an array with shape (1988016, 1988016) and data type float64

In [14]:
import numpy as np
from scipy import ndimage
from skimage.transform import resize
import math

def upscale_msi_data_efficient(msi_data, optical_image, target_shape, tile_size=100):
    """
    Memory-efficient upscaling of MSI data using tiled processing.
    
    Parameters:
    -----------
    msi_data : ndarray
        Low-resolution MSI data of shape (height, width, channels)
    optical_image : ndarray
        High-resolution optical image of shape (target_height, target_width)
    target_shape : tuple
        Desired output shape (height, width)
    tile_size : int
        Size of tiles for processing
        
    Returns:
    --------
    ndarray
        Upscaled MSI data of shape (target_height, target_width, channels)
    """
    
    # Calculate scaling factors
    scale_y = target_shape[0] / msi_data.shape[0]
    scale_x = target_shape[1] / msi_data.shape[1]
    
    # Calculate number of tiles needed
    n_tiles_y = math.ceil(target_shape[0] / tile_size)
    n_tiles_x = math.ceil(target_shape[1] / tile_size)
    
    # Initialize output array
    upscaled_msi = np.zeros((*target_shape, msi_data.shape[2]))
    
    # Process each tile
    for ty in range(n_tiles_y):
        for tx in range(n_tiles_x):
            # Calculate tile boundaries in target space
            y_start = ty * tile_size
            x_start = tx * tile_size
            y_end = min((ty + 1) * tile_size, target_shape[0])
            x_end = min((tx + 1) * tile_size, target_shape[1])
            
            # Calculate corresponding source boundaries
            src_y_start = int(y_start / scale_y)
            src_x_start = int(x_start / scale_x)
            src_y_end = min(math.ceil(y_end / scale_y), msi_data.shape[0])
            src_x_end = min(math.ceil(x_end / scale_x), msi_data.shape[1])
            
            # Extract source tile with padding
            src_tile = msi_data[src_y_start:src_y_end, 
                              src_x_start:src_x_end, 
                              :]
            
            # Extract corresponding optical tile
            opt_tile = optical_image[y_start:y_end, 
                                   x_start:x_end]
            
            # Calculate local target size
            local_target = (y_end - y_start, x_end - x_start)
            
            # Process tile
            tile_result = process_tile(src_tile, opt_tile, local_target)
            
            # Place result in output array
            upscaled_msi[y_start:y_end, x_start:x_end, :] = tile_result
    
    return upscaled_msi

def process_tile(msi_tile, optical_tile, target_shape):
    """
    Process individual tile using guided upscaling.
    
    Parameters:
    -----------
    msi_tile : ndarray
        Source MSI tile data
    optical_tile : ndarray
        Corresponding optical image tile
    target_shape : tuple
        Desired tile output shape
        
    Returns:
    --------
    ndarray
        Upscaled MSI tile
    """
    # Initialize output tile
    upscaled_tile = np.zeros((*target_shape, msi_tile.shape[2]))
    
    # Process each channel
    for channel in range(msi_tile.shape[2]):
        # Basic upscaling using bicubic interpolation
        upscaled_channel = resize(msi_tile[:, :, channel],
                                target_shape,
                                order=3,
                                mode='edge')
        
        # Apply edge-preserving filter guided by optical image
        filtered_channel = edge_preserving_filter(upscaled_channel, 
                                                optical_tile)
        
        upscaled_tile[:, :, channel] = filtered_channel
    
    return upscaled_tile

def edge_preserving_filter(msi_channel, guide_image, sigma_spatial=2, sigma_range=0.1):
    """
    Apply edge-preserving filtering using bilateral filter.
    
    Parameters:
    -----------
    msi_channel : ndarray
        Single channel of upscaled MSI data
    guide_image : ndarray
        High-resolution optical image tile
    sigma_spatial : float
        Spatial sigma for bilateral filter
    sigma_range : float
        Range sigma for bilateral filter
        
    Returns:
    --------
    ndarray
        Filtered MSI channel
    """
    # Normalize guide image
    guide_norm = (guide_image - guide_image.min()) / (guide_image.max() - guide_image.min())
    
    # Apply bilateral filtering
    filtered = ndimage.gaussian_filter(msi_channel, sigma_spatial)
    
    # Edge preservation using guide image gradients
    guide_grad = np.gradient(guide_norm)
    edge_weights = np.exp(-np.sum([g**2 for g in guide_grad], axis=0) / (2 * sigma_range**2))
    
    return filtered * (1 - edge_weights) + msi_channel * edge_weights