In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import os
from tqdm import tqdm
import h5py
import random
from einops import rearrange  # Ensure this module is installed
import glob

In [56]:
# Set random seeds for reproducibility
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)

# Hyperparameters & Configuration
config = {
    'batch_size': 8,
    'lr': 1e-4,
    'num_epochs': 50,
    'num_workers': 2,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'channels': [16, 32, 64],
    'kernels': [1, 3],
    'expansion_factor': 2,
    'reduction_ratio': 16,
    'save_dir': './checkpoints',
    'data_path': r"D:\LoDoPaB-CT\observation_train\observation_train_001.hdf5",  # Ensure this line ends with a comma or nothing
    'data_percentage': 0.05  # No comma at the end if it's the last element
}


In [5]:
# Create save directory
os.makedirs(config['save_dir'], exist_ok=True)

# Ensure dataset file exists
if not os.path.exists(config['data_path']):
    raise FileNotFoundError(f"Dataset file not found: {config['data_path']}")

## Model Components (Efficient & Compact)
class ChannelShuffle(nn.Module):
    def __init__(self, groups):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        batch_size, num_channels, height, width = x.size()
        channels_per_group = num_channels // self.groups
        x = x.view(batch_size, self.groups, channels_per_group, height, width)
        x = torch.transpose(x, 1, 2).contiguous()
        x = x.view(batch_size, -1, height, width)
        return x

class MultiKernelDepthwiseConv(nn.Module):
    def __init__(self, in_channels, kernels=[1, 3]):
        super().__init__()
        self.kernels = kernels
        self.groups = len(kernels)
        
        self.dw_convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels // self.groups, in_channels // self.groups, kernel_size=k, padding=k//2, 
                          groups=in_channels // self.groups, bias=False),
                nn.BatchNorm2d(in_channels // self.groups),
                nn.ReLU6(inplace=True)
            ) for k in kernels
        ])
        
        self.channel_shuffle = ChannelShuffle(self.groups)
        
    def forward(self, x):
        splits = torch.split(x, x.size(1) // self.groups, dim=1)
        out = [conv(splits[i]) for i, conv in enumerate(self.dw_convs)]
        out = torch.cat(out, dim=1)
        return self.channel_shuffle(out)

class ConvolutionalMultiFocalAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super().__init__()
        reduced_channels = max(1, in_channels // reduction_ratio)
        
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, reduced_channels, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced_channels, in_channels, kernel_size=1),
            nn.Sigmoid()
        )
        
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=5, padding=2),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        ca = self.channel_attention(x)
        x_ca = x * ca
        max_pool = torch.max(x_ca, dim=1, keepdim=True)[0]
        avg_pool = torch.mean(x_ca, dim=1, keepdim=True)
        spatial_pool = torch.cat([max_pool, avg_pool], dim=1)
        sa = self.spatial_attention(spatial_pool)
        return x_ca * sa

## Simplified UltraLightUNet
class UltraLightUNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, channels=[16, 32, 64]):
        super().__init__()
        self.encoder1 = nn.Conv2d(in_channels, channels[0], kernel_size=3, padding=1)
        self.encoder2 = nn.Sequential(nn.MaxPool2d(2), nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
        self.encoder3 = nn.Sequential(nn.MaxPool2d(2), nn.Conv2d(channels[1], channels[2], kernel_size=3, padding=1))

        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.final_conv = nn.Conv2d(channels[0], out_channels, kernel_size=1)

    def forward(self, x):
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        d2 = self.up2(e3)
        d1 = self.up1(d2)
        return self.final_conv(d1)

In [None]:
# # Create Data Loaders
# train_dataset = LoDoPaBDataset(config['data_path'], mode='train')
# val_dataset = LoDoPaBDataset(config['data_path'], mode='val')

# train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=config['num_workers'])
# val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=config['num_workers'])


In [43]:
class LoDoPaBTrainDataset(Dataset):
    def __init__(self, observations_dir, transform=None):
        self.observation_files = sorted(glob.glob(os.path.join(observations_dir, "*.hdf5")))
        self.transform = transform
        self.sample_indices = []  # (file_idx, sample_idx)
        
        for file_idx, obs_file in enumerate(self.observation_files):
            with h5py.File(obs_file, 'r') as f:
                num_samples = f['data'].shape[0]
                self.sample_indices.extend([(file_idx, i) for i in range(num_samples)])
    
    def __len__(self):
        return len(self.sample_indices)
    
    def __getitem__(self, idx):
        file_idx, sample_idx = self.sample_indices[idx]
        with h5py.File(self.observation_files[file_idx], 'r') as f:
            observation = torch.from_numpy(f['data'][sample_idx]).float().unsqueeze(0)
            if self.transform:
                observation = self.transform(observation)
            return observation  # (1, H, W)

class LoDoPaBValDataset(Dataset):
    def __init__(self, observations_dir, ground_truth_dir, transform=None):
        self.observation_files = sorted(glob.glob(os.path.join(observations_dir, "*.hdf5")))
        self.ground_truth_files = sorted(glob.glob(os.path.join(ground_truth_dir, "*.hdf5")))
        self.transform = transform
        self.sample_indices = []
        
        # Verify matching files
        assert len(self.observation_files) == len(self.ground_truth_files)
        
        for file_idx, (obs_file, gt_file) in enumerate(zip(self.observation_files, self.ground_truth_files)):
            with h5py.File(obs_file, 'r') as f_obs, h5py.File(gt_file, 'r') as f_gt:
                assert f_obs['data'].shape == f_gt['data'].shape
                num_samples = f_obs['data'].shape[0]
                self.sample_indices.extend([(file_idx, i) for i in range(num_samples)])
    
    def __getitem__(self, idx):
        file_idx, sample_idx = self.sample_indices[idx]
        with h5py.File(self.observation_files[file_idx], 'r') as f_obs, \
             h5py.File(self.ground_truth_files[file_idx], 'r') as f_gt:
            
            obs = torch.from_numpy(f_obs['data'][sample_idx]).float().unsqueeze(0)
            gt = torch.from_numpy(f_gt['data'][sample_idx]).float().unsqueeze(0)
            
            if self.transform:
                # Apply same transform to both
                stacked = torch.cat([obs, gt], dim=0)
                stacked = self.transform(stacked)
                obs, gt = stacked[0], stacked[1]
            
            return obs, gt  # (1, H, W), (1, H, W)

In [46]:
train_dataset = LoDoPaBTrainDataset(
    observations_dir="/LoDoPaB-CT/observation_train/"
)

# Validation (supervised)
val_dataset = LoDoPaBValDataset(
    observations_dir="/LoDoPaB-CT/observation_test/",
    ground_truth_dir="/LoDoPaB-CT/ground_truth_validation/"
)

AssertionError: 

In [49]:
import h5py
import os
from tqdm import tqdm  # for progress bar (install with: pip install tqdm)

def check_hdf5_corruption(filepath):
    try:
        with h5py.File(filepath, 'r') as f:
            # Attempt to access a key (e.g., 'data') to verify readability
            _ = f['data'][:]  # Try reading a small chunk if files are large
        return True  # File is OK
    except (OSError, KeyError) as e:
        print(f"\nCorrupted file: {filepath} | Error: {str(e)}")
        return False

def scan_directory_for_corrupt_files(directory):
    corrupt_files = []
    hdf5_files = [f for f in os.listdir(directory) if f.endswith('.hdf5')]
    
    print(f"Checking {len(hdf5_files)} HDF5 files in {directory}...")
    for filename in tqdm(hdf5_files):
        filepath = os.path.join(directory, filename)
        if not check_hdf5_corruption(filepath):
            corrupt_files.append(filename)
    
    if not corrupt_files:
        print("\n All files are valid.")
    else:
        print(f"\n Found {len(corrupt_files)} corrupt files:")
        for f in corrupt_files:
            print(f"  - {f}")

# Usage
scan_directory_for_corrupt_files("/LoDoPaB-CT/observation_test/")

Checking 28 HDF5 files in /LoDoPaB-CT/observation_test/...


100%|██████████████████████████████████████████████████████████████████████████████████| 28/28 [00:10<00:00,  2.55it/s]


✅ All files are valid.





In [52]:
import os
import requests
from tqdm import tqdm  # For progress bar (install with `pip install tqdm`)

# Zenodo record URL (LoDoPaB-CT)
ZENODO_RECORD = "https://zenodo.org/record/3384092"
DOWNLOAD_DIR = "./LoDoPaB-CT"  # Where to save files

# List of files to download (modify based on your needs)
FILES_TO_DOWNLOAD = [
    "ground_truth_test.zip",
    # Add more files as needed
]

def download_from_zenodo(file_name, save_dir):
    """Download a file from Zenodo."""
    download_url = f"{ZENODO_RECORD}/files/{file_name}?download=1"
    save_path = os.path.join(save_dir, file_name)
    
    # Create directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)
    
    # Stream download to handle large files
    response = requests.get(download_url, stream=True)
    response.raise_for_status()  # Check for HTTP errors
    
    # Progress bar
    total_size = int(response.headers.get('content-length', 0))
    with open(save_path, 'wb') as f, tqdm(
        desc=file_name,
        total=total_size,
        unit='B',
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
            bar.update(len(chunk))
    
    print(f"Downloaded: {save_path}")

# Download all files
for file in FILES_TO_DOWNLOAD:
    download_from_zenodo(file, DOWNLOAD_DIR)

print("All downloads complete!")

ground_truth_test.zip: 100%|██████████████████████████████████████████████████████| 1.47G/1.47G [17:17<00:00, 1.52MB/s]

Downloaded: ./LoDoPaB-CT\ground_truth_test.zip
All downloads complete!





In [33]:
from torch.utils.data import DataLoader

# Create dataset
dataset = LoDoPaBCTDataset(file_path="/LoDoPaB-CT/observation_train/observation_train_000.hdf5")

# Create data loader
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
print(f"Training dataset size: {len(dataset)}")



Training dataset size: 128


In [None]:
## Training Setup
model = UltraLightUNet().to(config['device'])
optimizer = AdamW(model.parameters(), lr=config['lr'])
criterion = nn.L1Loss()
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

## Training Loop
best_val_loss = float('inf')

for epoch in range(config['num_epochs']):
    print(f"\nEpoch {epoch+1}/{config['num_epochs']}")
    model.train()
    train_loss = sum(criterion(model(low.to(config['device'])), full.to(config['device'])).item() 
                     for low, full in train_loader) / len(train_loader)

    model.eval()
    val_loss = sum(criterion(model(low.to(config['device'])), full.to(config['device'])).item() 
                   for low, full in val_loader) / len(val_loader)

    scheduler.step(val_loss)
    print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), os.path.join(config['save_dir'], 'best_model.pth'))
        print("Saved new best model!")

print("Training complete!")