In [3]:
import re
import os
import cv2
import numpy as np
from typing import List, Dict, Tuple

def natural_sort_key(s: str) -> List:
    """Natural sort key to correctly sort filenames with numerical values.

    Splits ``s`` into alternating text and digit runs. Digit runs become ints and
    text is lower-cased, so "img_2" sorts before "img_10" regardless of letter case.

    Args:
        s: String to build the key for (typically a filename).

    Returns:
        List alternating str and int chunks, always starting with a (possibly empty)
        str. Keys of different strings therefore compare position by position without
        ever comparing an int to a str.
    """
    # isdecimal rather than isdigit: the text chunks between \d+ runs may consist of
    # characters such as "²", for which isdigit() is True but int() raises ValueError.
    return [int(text) if text.isdecimal() else text.lower() for text in re.split(r'(\d+)', s)]

def collect_filenames_from_directory(base_dir: str, max_level: int = 7) -> Dict[int, List[str]]:
    """Collect the TIFF filenames found in each ``FusedImages_Level_{n}`` folder.

    Args:
        base_dir: Root directory holding ``FusedImages_Level_0`` .. ``FusedImages_Level_{max_level}``.
        max_level: Highest level (inclusive) to look for.

    Returns:
        Mapping level -> filenames ending in ``.tif``/``.tiff`` (any letter case),
        sorted with ``natural_sort_key``. Levels whose folder is missing, or is not
        a directory, are omitted rather than mapped to an empty list.
    """
    tiff_extensions = ('.tif', '.tiff')
    level_filenames = {}
    for level in range(max_level + 1):
        level_dir = os.path.join(base_dir, f'FusedImages_Level_{level}')
        # isdir rather than exists: a stray *file* with this name would make
        # os.listdir raise NotADirectoryError.
        if not os.path.isdir(level_dir):
            continue
        # Case-insensitive match so files such as "IMG_1.TIF" are not silently dropped.
        filenames = [f for f in os.listdir(level_dir) if f.lower().endswith(tiff_extensions)]
        level_filenames[level] = sorted(filenames, key=natural_sort_key)
    return level_filenames

def load_image(image_path: str) -> np.ndarray:
    """Read ``image_path`` as a single-channel grayscale array scaled by 1/255.

    Args:
        image_path: Path of the image file to read with OpenCV.

    Returns:
        Float array of the grayscale pixels divided by 255.0.

    Raises:
        ValueError: If ``cv2.imread`` cannot read the file (it returns None).
    """
    pixels = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if pixels is not None:
        return pixels / 255.0  # Normalize to [0, 1]
    raise ValueError(f"Failed to load image: {image_path}")

def create_pairs_with_fallback(
    level_filenames: Dict[int, List[str]],
    base_dir: str,
    max_level: int = 7,
) -> Dict[int, List[Tuple[Tuple[str, str, str], Tuple[np.ndarray, np.ndarray, np.ndarray]]]]:
    """Create pairs of images with their corresponding fused outputs.

    For each level L in ``0 .. max_level - 1``, the level's files are walked with a
    sliding window. Files i and i+1 are paired with
    ``FusedImages_Level_{L+1}/Fused_Image_Level_{L+1}_{i}.tif``. Pairs whose fused
    output is missing, or whose images fail to load, are skipped (load errors are
    printed).

    Fallback: when at least one pair is found at level L, the fused filenames that
    were actually found become the input list for level L+1, replacing whatever was
    collected from disk for that level. The walk stops at the first level that is
    absent from ``level_filenames``.

    Args:
        level_filenames: Mapping level -> ordered filenames, e.g. from
            ``collect_filenames_from_directory``. It is not modified.
        base_dir: Root directory containing the ``FusedImages_Level_{n}`` folders.
        max_level: Levels ``0 .. max_level - 1`` are paired.

    Returns:
        Mapping level -> list of ``((img1_name, img2_name, fused_name),
        (img1, img2, fused_img))`` entries. Levels with no pairs are omitted.
    """
    # Work on a copy. The fallback rewrites the next level's file list, and doing
    # that in the caller's dict made repeated calls with the same input differ.
    level_filenames = dict(level_filenames)
    pairs_dict = {}

    for current_level in range(max_level):
        if current_level not in level_filenames:
            break

        current_files = level_filenames[current_level]
        current_dir = os.path.join(base_dir, f'FusedImages_Level_{current_level}')
        next_level_dir = os.path.join(base_dir, f'FusedImages_Level_{current_level + 1}')
        pairs = []
        next_level_files = []

        # Sliding window: (file 0, file 1), (file 1, file 2), ...
        for i, (img1_filename, img2_filename) in enumerate(zip(current_files, current_files[1:])):
            output_filename = f'Fused_Image_Level_{current_level + 1}_{i}.tif'
            fused_img_path = os.path.join(next_level_dir, output_filename)

            # Check for the fused output first, so a missing output does not cost
            # two input-image loads.
            if not os.path.exists(fused_img_path):
                continue

            try:
                img1 = load_image(os.path.join(current_dir, img1_filename))
                img2 = load_image(os.path.join(current_dir, img2_filename))
            except ValueError as e:
                print(f"Error loading images: {e}")
                continue

            try:
                fused_img = load_image(fused_img_path)
            except ValueError as e:
                print(f"Error loading fused image: {e}")
                continue

            pairs.append(((img1_filename, img2_filename, output_filename), (img1, img2, fused_img)))
            next_level_files.append(output_filename)

        # Store pairs for the current level
        if pairs:
            pairs_dict[current_level] = pairs
            level_filenames[current_level + 1] = next_level_files

            # Print pairs for verification
            print(f"\nLevel {current_level} pairs:")
            for (fname1, fname2, fused_fname), (img1, img2, fused_img) in pairs:
                print(f"  - Pair: {fname1} with {fname2} -> {fused_fname}")
                print(f"    Shapes: img1 {img1.shape}, img2 {img2.shape}, fused {fused_img.shape}")

    return pairs_dict

# Dataset root and the deepest fusion level to scan / pair.
base_dir, max_level = "../FusedDataset", 7

# Step 1: list the TIFF files per level. Step 2: pair adjacent files with their fused outputs.
collected_filenames = collect_filenames_from_directory(base_dir=base_dir, max_level=max_level)
fusion_pairs = create_pairs_with_fallback(
    level_filenames=collected_filenames,
    base_dir=base_dir,
    max_level=max_level,
)


Level 0 has 69 images:
  - Fused_Image_Level_0_0.tif
  - Fused_Image_Level_0_1.tif
  - Fused_Image_Level_0_10.tif
  - Fused_Image_Level_0_11.tif
  - Fused_Image_Level_0_12.tif
  - Fused_Image_Level_0_13.tif
  - Fused_Image_Level_0_14.tif
  - Fused_Image_Level_0_15.tif
  - Fused_Image_Level_0_16.tif
  - Fused_Image_Level_0_17.tif
  - Fused_Image_Level_0_18.tif
  - Fused_Image_Level_0_19.tif
  - Fused_Image_Level_0_2.tif
  - Fused_Image_Level_0_20.tif
  - Fused_Image_Level_0_21.tif
  - Fused_Image_Level_0_22.tif
  - Fused_Image_Level_0_23.tif
  - Fused_Image_Level_0_24.tif
  - Fused_Image_Level_0_25.tif
  - Fused_Image_Level_0_26.tif
  - Fused_Image_Level_0_27.tif
  - Fused_Image_Level_0_28.tif
  - Fused_Image_Level_0_29.tif
  - Fused_Image_Level_0_3.tif
  - Fused_Image_Level_0_30.tif
  - Fused_Image_Level_0_31.tif
  - Fused_Image_Level_0_32.tif
  - Fused_Image_Level_0_33.tif
  - Fused_Image_Level_0_34.tif
  - Fused_Image_Level_0_35.tif
  - Fused_Image_Level_0_36.tif
  - Fused_Image_Leve

In [16]:
# Display the level-0 pair entries (bare last expression -> rich cell output).
# NOTE(review): the saved output shows flat (name1, name2, fused_name) triples, whereas
# create_pairs_with_fallback builds ((names), (images)) entries. The output looks
# stale relative to that definition; re-run to confirm.
fusion_pairs[0]

[('Fused_Image_Level_0_0.tif',
  'Fused_Image_Level_0_1.tif',
  'Fused_Image_Level_1_0.tif'),
 ('Fused_Image_Level_0_1.tif',
  'Fused_Image_Level_0_2.tif',
  'Fused_Image_Level_1_1.tif'),
 ('Fused_Image_Level_0_2.tif',
  'Fused_Image_Level_0_3.tif',
  'Fused_Image_Level_1_2.tif'),
 ('Fused_Image_Level_0_3.tif',
  'Fused_Image_Level_0_4.tif',
  'Fused_Image_Level_1_3.tif'),
 ('Fused_Image_Level_0_4.tif',
  'Fused_Image_Level_0_5.tif',
  'Fused_Image_Level_1_4.tif'),
 ('Fused_Image_Level_0_5.tif',
  'Fused_Image_Level_0_6.tif',
  'Fused_Image_Level_1_5.tif'),
 ('Fused_Image_Level_0_6.tif',
  'Fused_Image_Level_0_7.tif',
  'Fused_Image_Level_1_6.tif'),
 ('Fused_Image_Level_0_7.tif',
  'Fused_Image_Level_0_8.tif',
  'Fused_Image_Level_1_7.tif'),
 ('Fused_Image_Level_0_8.tif',
  'Fused_Image_Level_0_9.tif',
  'Fused_Image_Level_1_8.tif'),
 ('Fused_Image_Level_0_9.tif',
  'Fused_Image_Level_0_10.tif',
  'Fused_Image_Level_1_9.tif'),
 ('Fused_Image_Level_0_10.tif',
  'Fused_Image_Level_0_11.t

In [78]:
def get_next_pair(fusion_pairs, level, init_img=0, max_level=7):
    """Print the fused-output filename of pair ``init_img`` at ``level`` and each higher level.

    The same pair index is followed at every level. With the sliding-window naming
    in create_pairs_with_fallback (pair i at level L writes index i at level L+1),
    this presumably traces one lineage of base image ``init_img``.

    Args:
        fusion_pairs: Mapping level -> list of pair entries. Each entry is either a
            flat ``(name1, name2, fused_name)`` tuple or a ``((name1, name2, fused_name),
            (img1, img2, fused_img))`` tuple.
        level: First level to print. It is always printed, even if >= ``max_level``.
        init_img: Pair index to follow, kept unchanged at every level.
        max_level: Printing stops before this level (default 7, as before).
    """
    for current_level in range(level, max(level + 1, max_level)):
        entry = fusion_pairs[current_level][init_img]
        # Accept both the flat name triple and the (names, images) nesting.
        names = entry[0] if len(entry) == 2 else entry
        print(names[2])

# Take the first level-0 pair entry, print its first element, then walk the fused
# outputs upward from level 0.
# NOTE(review): create_pairs_with_fallback stores entries as
# ((name1, name2, fused_name), (img1, img2, fused_img)), so init_image[0] would be the
# names tuple rather than a filename. The saved output suggests fusion_pairs held
# flat name triples when this cell ran; re-run to confirm.
init_image = fusion_pairs[0][0]
print(init_image[0])
get_next_pair(fusion_pairs, 0)

Fused_Image_Level_0_0.tif
Fused_Image_Level_1_0.tif
Fused_Image_Level_2_0.tif
Fused_Image_Level_3_0.tif
Fused_Image_Level_4_0.tif
Fused_Image_Level_5_0.tif
Fused_Image_Level_6_0.tif
Fused_Image_Level_7_0.tif


In [90]:
def get_next_pair(fusion_pairs, level, init_img=1, max_level=7):
    """For an odd ``init_img``, print the fused output of pair ``init_img - 1`` at each level.

    Levels ``level .. max_level - 1`` are visited; the starting level is always
    visited. Because ``init_img`` is kept unchanged at every level, the odd/even test
    gives the same answer everywhere: an even ``init_img`` prints nothing at all.

    Args:
        fusion_pairs: Mapping level -> list of pair entries. Each entry is either a
            flat ``(name1, name2, fused_name)`` tuple or a ``((name1, name2, fused_name),
            (img1, img2, fused_img))`` tuple.
        level: First level to print.
        init_img: Level-0 image index. Odd values select pair ``init_img - 1``.
        max_level: Printing stops before this level (default 7, as before).
    """
    if init_img % 2 != 1:
        return
    for current_level in range(level, max(level + 1, max_level)):
        entry = fusion_pairs[current_level][init_img - 1]
        # Accept both the flat name triple and the (names, images) nesting.
        names = entry[0] if len(entry) == 2 else entry
        print(names[2])

# Print the second element of the first level-0 pair entry, then the fused outputs
# reached from pair 0 upward (get_next_pair is called with its default init_img).
# NOTE(review): with ((names), (images)) entries from create_pairs_with_fallback,
# init_image[1] would be the images tuple, not a filename. The saved output suggests
# flat name triples; re-run to confirm.
init_image = fusion_pairs[0][0]
print(init_image[1])
get_next_pair(fusion_pairs, 0)

Fused_Image_Level_0_1.tif
Fused_Image_Level_1_0.tif
Fused_Image_Level_2_0.tif
Fused_Image_Level_3_0.tif
Fused_Image_Level_4_0.tif
Fused_Image_Level_5_0.tif
Fused_Image_Level_6_0.tif
Fused_Image_Level_7_0.tif


In [94]:
class FusionTracker:
    """Index bookkeeping for a binary fusion tree over ``num_base_images`` level-0 images.

    Model used here: at each level, images 2i and 2i+1 fuse into image i of the next
    level. Level L therefore holds ``num_base_images // 2**L`` images, and a trailing
    odd image at any level is left unpaired.

    NOTE(review): create_pairs_with_fallback in this notebook names outputs with a
    sliding window instead (files i and i+1 -> index i at the next level). Confirm
    which scheme produced the files on disk before using these indices to locate them.
    """

    def __init__(self, num_base_images, max_level=7):
        self.num_base_images = num_base_images
        self.max_level = max_level

    def get_image_name(self, level, index):
        """Generate the image name based on level and index"""
        return f"Fused_Image_Level_{level}_{index}.tif"

    def num_images_at_level(self, level):
        """Number of images at ``level`` under repeated halving (floor division)."""
        return self.num_base_images // (2 ** level)

    def trace_image_path(self, base_image_index):
        """
        Trace a single base image's path through the fusion levels

        Args:
            base_image_index: The index of the base image in level 0

        Returns:
            list: ``(level, index)`` tuples, starting with ``(0, base_image_index)``.
            The path ends at the last image that was actually produced: it stops at
            ``max_level``, or earlier when the current image has no partner at its
            level (a trailing odd image, the top of the tree, or an out-of-range
            index). This matches the pairs that ``get_level_pairs`` generates.
        """
        path = [(0, base_image_index)]
        current_index = base_image_index

        for level in range(self.max_level):
            # Only indices below 2 * (count // 2) have a partner at this level.
            paired_count = 2 * (self.num_images_at_level(level) // 2)
            if not 0 <= current_index < paired_count:
                break
            # Images 2i and 2i+1 contribute to index i one level up
            current_index //= 2
            path.append((level + 1, current_index))

        return path

    def print_image_paths(self):
        """Print the fusion path for each base image"""
        for base_idx in range(self.num_base_images):
            path = self.trace_image_path(base_idx)
            print(f"\nBase image {base_idx} contributes to:")
            for level, idx in path:
                print(f"  Level {level}: {self.get_image_name(level, idx)}")

    def get_level_pairs(self, level):
        """Get all fusion pairs for a given level as (source1, source2, target) names."""
        num_pairs = self.num_images_at_level(level) // 2
        return [
            (
                self.get_image_name(level, 2 * i),
                self.get_image_name(level, 2 * i + 1),
                self.get_image_name(level + 1, i),
            )
            for i in range(num_pairs)
        ]

# Example usage
# Demonstrates FusionTracker on a synthetic 16-image tree; no files are read.
tracker = FusionTracker(16)  # Starting with 16 base images

# Print fusion pairs for each level
print("Fusion pairs by level:")
# Levels 0-6, i.e. the pairs that produce levels 1-7. The 7 duplicates
# tracker.max_level's default. With 16 base images get_level_pairs returns no
# pairs from level 4 on (16 // 2**(level + 1) == 0).
for level in range(7):  # Up to level 7
    print(f"\nLevel {level} fusion pairs:")
    pairs = tracker.get_level_pairs(level)
    for source1, source2, target in pairs:
        print(f"  {source1} + {source2} -> {target}")

print("\nTracking individual base image paths:")
tracker.print_image_paths()

# To get specific base image path
base_image_0_path = tracker.trace_image_path(0)
print("\nDetailed path for base image 0:")
for level, idx in base_image_0_path:
    print(f"Level {level}, Index {idx}: {tracker.get_image_name(level, idx)}")

Fusion pairs by level:

Level 0 fusion pairs:
  Fused_Image_Level_0_0.tif + Fused_Image_Level_0_1.tif -> Fused_Image_Level_1_0.tif
  Fused_Image_Level_0_2.tif + Fused_Image_Level_0_3.tif -> Fused_Image_Level_1_1.tif
  Fused_Image_Level_0_4.tif + Fused_Image_Level_0_5.tif -> Fused_Image_Level_1_2.tif
  Fused_Image_Level_0_6.tif + Fused_Image_Level_0_7.tif -> Fused_Image_Level_1_3.tif
  Fused_Image_Level_0_8.tif + Fused_Image_Level_0_9.tif -> Fused_Image_Level_1_4.tif
  Fused_Image_Level_0_10.tif + Fused_Image_Level_0_11.tif -> Fused_Image_Level_1_5.tif
  Fused_Image_Level_0_12.tif + Fused_Image_Level_0_13.tif -> Fused_Image_Level_1_6.tif
  Fused_Image_Level_0_14.tif + Fused_Image_Level_0_15.tif -> Fused_Image_Level_1_7.tif

Level 1 fusion pairs:
  Fused_Image_Level_1_0.tif + Fused_Image_Level_1_1.tif -> Fused_Image_Level_2_0.tif
  Fused_Image_Level_1_2.tif + Fused_Image_Level_1_3.tif -> Fused_Image_Level_2_1.tif
  Fused_Image_Level_1_4.tif + Fused_Image_Level_1_5.tif -> Fused_Image_Leve

In [128]:
def trace_base_image_path(base_image_index, max_level=7, base_dir="../FusedDataset"):
    """
    Print the file path of a base image and of the image it maps to at each fusion level

    The index at level L+1 is the index at level L floor-divided by 2. The paths are
    only printed; file existence is not checked.

    NOTE(review): create_pairs_with_fallback in this notebook names outputs with a
    sliding window (files i, i+1 -> index i), not by halving. Confirm which scheme
    produced the dataset before relying on these paths.

    Args:
        base_image_index: Index of the starting image in level 0
        max_level: Maximum fusion level to trace to (default 7)
        base_dir: Dataset root containing the FusedImages_Level_{n} folders
            (default "../FusedDataset", the previously hard-coded root)
    """
    current_index = base_image_index

    # Start with the base image
    print(f"{base_dir}/FusedImages_Level_0/Fused_Image_Level_0_{base_image_index}.tif")

    # Trace through each level
    for level in range(1, max_level + 1):
        current_index //= 2
        print(f"{base_dir}/FusedImages_Level_{level}/Fused_Image_Level_{level}_{current_index}.tif")

# Example usage for base image 0
# Prints traced paths for a few base indices. 68 is the last index if level 0 holds
# 69 images, as the earlier "Level 0 has 69 images" output reports. Under the
# halving rule, bases 0 and 1 map to the same file at every level >= 1.
print("Path for base image 0:")
trace_base_image_path(0)

print("\nPath for base image 1:")
trace_base_image_path(1)

print("\nPath for base image 2:")
trace_base_image_path(2)

print("\nPath for base image 68:")
trace_base_image_path(68)

Path for base image 0:
../FusedDataset/FusedImages_Level_0/Fused_Image_Level_0_0.tif
../FusedDataset/FusedImages_Level_1/Fused_Image_Level_1_0.tif
../FusedDataset/FusedImages_Level_2/Fused_Image_Level_2_0.tif
../FusedDataset/FusedImages_Level_3/Fused_Image_Level_3_0.tif
../FusedDataset/FusedImages_Level_4/Fused_Image_Level_4_0.tif
../FusedDataset/FusedImages_Level_5/Fused_Image_Level_5_0.tif
../FusedDataset/FusedImages_Level_6/Fused_Image_Level_6_0.tif
../FusedDataset/FusedImages_Level_7/Fused_Image_Level_7_0.tif

Path for base image 1:
../FusedDataset/FusedImages_Level_0/Fused_Image_Level_0_1.tif
../FusedDataset/FusedImages_Level_1/Fused_Image_Level_1_0.tif
../FusedDataset/FusedImages_Level_2/Fused_Image_Level_2_0.tif
../FusedDataset/FusedImages_Level_3/Fused_Image_Level_3_0.tif
../FusedDataset/FusedImages_Level_4/Fused_Image_Level_4_0.tif
../FusedDataset/FusedImages_Level_5/Fused_Image_Level_5_0.tif
../FusedDataset/FusedImages_Level_6/Fused_Image_Level_6_0.tif
../FusedDataset/FusedIm

In [131]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch
import os

class FusionImageDataset(Dataset):
    """Per-base-image stacks of one image per fusion level, for use with a DataLoader.

    Item ``i`` is ``(images, names)``. ``images`` is ``torch.stack`` of the transformed
    level-0 image ``Fused_Image_Level_0_{b}.tif`` and of ``Fused_Image_Level_{L}_{b // 2**L}.tif``
    for L = 1..max_level. ``names`` lists the matching filenames.

    NOTE(review): the ``// 2`` index rule assumes binary pairing. create_pairs_with_fallback
    in this notebook names outputs with a sliding window (files i, i+1 -> index i).
    Confirm which scheme generated the files.
    """

    def __init__(self, base_indices, transform=None, base_dir="../FusedDataset", max_level=7):
        """
        Args:
            base_indices: Level-0 image indices; one dataset item per index.
            transform: Callable applied to each opened PIL image. It must return a tensor,
                because results are combined with torch.stack. Defaults to ToTensor().
            base_dir: Root folder containing the ``FusedImages_Level_{n}`` subfolders.
            max_level: Highest fusion level included (inclusive).
        """
        if not transform:
            # Without a tensor-producing transform, __getitem__ would pass PIL images
            # to torch.stack and fail, so fall back to ToTensor (imported lazily).
            import torchvision.transforms as transforms
            transform = transforms.ToTensor()
        self.transform = transform
        self.image_groups = []

        for base_idx in base_indices:
            paths = []
            names = []
            current_idx = base_idx
            # Level 0 uses the base index itself; each higher level halves it.
            for level in range(max_level + 1):
                if level > 0:
                    current_idx //= 2
                name = f"Fused_Image_Level_{level}_{current_idx}.tif"
                paths.append(os.path.join(base_dir, f"FusedImages_Level_{level}", name))
                names.append(name)

            self.image_groups.append((paths, names))

    def __len__(self):
        return len(self.image_groups)

    def __getitem__(self, idx):
        """Load, transform and stack every image of group ``idx``; returns ``(tensor, names)``."""
        paths, names = self.image_groups[idx]
        images = []

        for path in paths:
            # The context manager releases the file handle. Image.open alone can keep
            # the file open until garbage collection (Pillow keeps multi-frame
            # formats such as TIFF open for lazy access).
            with Image.open(path) as pil_img:
                images.append(self.transform(pil_img))

        return torch.stack(images), names


import torchvision.transforms as transforms

# Convert each PIL image to a tensor so the per-level images can be torch.stack-ed.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# One dataset item per base index (0, 1 and 68 here).
dataset = FusionImageDataset(
    base_indices=[0, 1, 68],
    transform=transform
)

# batch_size=1 gives one image group per batch. Default collation turns each name
# into a 1-element tuple, as seen in the saved output.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# Example of loading and printing image names
# (every image is still read and transformed even though only the names are printed).
for images, names in dataloader:
    print("\nImage group names:")
    for i, name in enumerate(names):
        print(f"Level {i}: {name}")


Image group names:
Level 0: ('Fused_Image_Level_0_0.tif',)
Level 1: ('Fused_Image_Level_1_0.tif',)
Level 2: ('Fused_Image_Level_2_0.tif',)
Level 3: ('Fused_Image_Level_3_0.tif',)
Level 4: ('Fused_Image_Level_4_0.tif',)
Level 5: ('Fused_Image_Level_5_0.tif',)
Level 6: ('Fused_Image_Level_6_0.tif',)
Level 7: ('Fused_Image_Level_7_0.tif',)

Image group names:
Level 0: ('Fused_Image_Level_0_1.tif',)
Level 1: ('Fused_Image_Level_1_0.tif',)
Level 2: ('Fused_Image_Level_2_0.tif',)
Level 3: ('Fused_Image_Level_3_0.tif',)
Level 4: ('Fused_Image_Level_4_0.tif',)
Level 5: ('Fused_Image_Level_5_0.tif',)
Level 6: ('Fused_Image_Level_6_0.tif',)
Level 7: ('Fused_Image_Level_7_0.tif',)

Image group names:
Level 0: ('Fused_Image_Level_0_68.tif',)
Level 1: ('Fused_Image_Level_1_34.tif',)
Level 2: ('Fused_Image_Level_2_17.tif',)
Level 3: ('Fused_Image_Level_3_8.tif',)
Level 4: ('Fused_Image_Level_4_4.tif',)
Level 5: ('Fused_Image_Level_5_2.tif',)
Level 6: ('Fused_Image_Level_6_1.tif',)
Level 7: ('Fused_

In [132]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch
import os
import torchvision.transforms as transforms

class FusionImageDataset(Dataset):
    """Per-base-image stacks of one image per fusion level, for use with a DataLoader.

    Item ``i`` is ``(images, names)``. ``images`` is ``torch.stack`` of the transformed
    level-0 image ``Fused_Image_Level_0_{b}.tif`` and of ``Fused_Image_Level_{L}_{b // 2**L}.tif``
    for L = 1..max_level. ``names`` lists the matching filenames.

    NOTE(review): the ``// 2`` index rule assumes binary pairing. create_pairs_with_fallback
    in this notebook names outputs with a sliding window (files i, i+1 -> index i).
    Confirm which scheme generated the files.
    """

    def __init__(self, base_indices, transform=None, base_dir="../FusedDataset", max_level=7):
        """
        Args:
            base_indices: Level-0 image indices; one dataset item per index.
            transform: Callable applied to each opened PIL image. It must return a tensor,
                because results are combined with torch.stack. Defaults to
                ``transforms.ToTensor()``.
            base_dir: Root folder containing the ``FusedImages_Level_{n}`` subfolders.
            max_level: Highest fusion level included (inclusive).
        """
        self.transform = transform if transform else transforms.ToTensor()
        self.image_groups = []

        for base_idx in base_indices:
            paths = []
            names = []
            current_idx = base_idx
            # Level 0 uses the base index itself; each higher level halves it.
            for level in range(max_level + 1):
                if level > 0:
                    current_idx //= 2
                name = f"Fused_Image_Level_{level}_{current_idx}.tif"
                paths.append(os.path.join(base_dir, f"FusedImages_Level_{level}", name))
                names.append(name)

            self.image_groups.append((paths, names))

    def __len__(self):
        return len(self.image_groups)

    def __getitem__(self, idx):
        """Load, transform and stack every image of group ``idx``; returns ``(tensor, names)``."""
        paths, names = self.image_groups[idx]
        images = []

        for path in paths:
            # The context manager releases the file handle. Image.open alone can keep
            # the file open until garbage collection (Pillow keeps multi-frame
            # formats such as TIFF open for lazy access).
            with Image.open(path) as pil_img:
                images.append(self.transform(pil_img))

        return torch.stack(images), names

# Example usage
# In a Jupyter kernel __name__ is "__main__", so this guard does not stop the block
# from running when the cell executes (the saved output below confirms it ran).
if __name__ == "__main__":    
    # transform omitted -> the dataset falls back to transforms.ToTensor().
    dataset = FusionImageDataset(base_indices=[0, 1, 2, 68])
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
    
    # Example loading images and printing their properties
    for images, names in dataloader:
        print("\nLoaded image group:")
        # images: batch of one stack of per-level tensors; names: per-level 1-tuples
        # produced by DataLoader collation (see the saved output).
        for i, (img, name) in enumerate(zip(images[0], names)):  # images[0] to remove batch dimension
            print(f"Level {i}: {name}")
            print(f"Image shape: {img.shape}, Value range: [{img.min():.2f}, {img.max():.2f}]")
            # If you want to do something with each image:
            # process_image(img)  # img is a tensor of shape [channels, height, width]


Loaded image group:
Level 0: ('Fused_Image_Level_0_0.tif',)
Image shape: torch.Size([1, 512, 512]), Value range: [0.22, 0.95]
Level 1: ('Fused_Image_Level_1_0.tif',)
Image shape: torch.Size([1, 512, 512]), Value range: [0.27, 0.95]
Level 2: ('Fused_Image_Level_2_0.tif',)
Image shape: torch.Size([1, 512, 512]), Value range: [0.30, 0.94]
Level 3: ('Fused_Image_Level_3_0.tif',)
Image shape: torch.Size([1, 512, 512]), Value range: [0.33, 0.84]
Level 4: ('Fused_Image_Level_4_0.tif',)
Image shape: torch.Size([1, 512, 512]), Value range: [0.34, 0.84]
Level 5: ('Fused_Image_Level_5_0.tif',)
Image shape: torch.Size([1, 512, 512]), Value range: [0.35, 0.88]
Level 6: ('Fused_Image_Level_6_0.tif',)
Image shape: torch.Size([1, 512, 512]), Value range: [0.35, 0.84]
Level 7: ('Fused_Image_Level_7_0.tif',)
Image shape: torch.Size([1, 512, 512]), Value range: [0.31, 0.86]

Loaded image group:
Level 0: ('Fused_Image_Level_0_1.tif',)
Image shape: torch.Size([1, 512, 512]), Value range: [0.21, 0.96]
Leve