In [26]:
from tqdm import tqdm
import h5py
import os
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.utils as vutils
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg16
from piqa import SSIM, LPIPS
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchsummary import summary

# Set device

In [27]:
# Select the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Root folder that contains the dataset sub-folders (adjust to your actual directory).
BASE_PATH = ""  # Replace with your actual path

# Alternative configuration (currently disabled): the *_train folders.
# TRAIN_OBS_DIR = os.path.join(BASE_PATH, "observation_train")
# TRAIN_GT_DIR = os.path.join(BASE_PATH, "ground_truth_train")

# NOTE: despite the TRAIN_ prefix, these currently point at the *_test folders.
TRAIN_OBS_DIR = os.path.join(BASE_PATH, "observation_test")
TRAIN_GT_DIR = os.path.join(BASE_PATH, "ground_truth_test")

# Custom Dataset class

In [28]:
# Custom Dataset class for LoDoPaB-CT
class LoDoPaBDataset(Dataset):
    """In-memory dataset of paired observation / ground-truth arrays read from HDF5.

    Every file in ``obs_dir`` is paired with the file in ``gt_dir`` whose name is
    obtained by replacing ``"observation"`` with ``"ground_truth"``. The ``'data'``
    dataset of each file is loaded as float16, all files are concatenated along
    axis 0, and each of the two resulting arrays is min/max-scaled to [-1, 1]
    using its own global minimum and maximum.

    Args:
        obs_dir: Directory holding the observation HDF5 files.
        gt_dir: Directory holding the matching ground-truth HDF5 files.
        num_files: If truthy, only the first ``num_files`` files (sorted by name)
            are loaded; otherwise every file is loaded.

    Raises:
        ValueError: If ``obs_dir`` contains no files, or if the observation and
            ground-truth arrays end up with different numbers of samples.
    """

    def __init__(self, obs_dir, gt_dir, num_files=None):
        self.obs_files = sorted(os.listdir(obs_dir))
        if num_files:
            self.obs_files = self.obs_files[:num_files]
        self.obs_dir = obs_dir
        self.gt_dir = gt_dir

        # Fail early with a readable message instead of np.concatenate's
        # "need at least one array to concatenate".
        if not self.obs_files:
            raise ValueError(f"No observation files found in {obs_dir!r}")

        print("Loading dataset...")
        obs_chunks = []
        gt_chunks = []
        for file_name in self.obs_files:
            obs_file = os.path.join(obs_dir, file_name)
            gt_file = os.path.join(gt_dir, file_name.replace("observation", "ground_truth"))
            with h5py.File(obs_file, 'r') as f_obs, h5py.File(gt_file, 'r') as f_gt:
                # Stored as float16, which halves the memory of a float32 copy.
                obs_chunks.append(f_obs['data'][:].astype(np.float16))
                gt_chunks.append(f_gt['data'][:].astype(np.float16))

        self.obs_data = self._normalize(np.concatenate(obs_chunks, axis=0))
        self.gt_data = self._normalize(np.concatenate(gt_chunks, axis=0))

        # __len__ is driven by obs_data, so a shorter gt array would only surface
        # later as an IndexError in the middle of an epoch.
        if len(self.obs_data) != len(self.gt_data):
            raise ValueError(
                f"Sample count mismatch: {len(self.obs_data)} observations vs "
                f"{len(self.gt_data)} ground-truth images"
            )
        print(f"Dataset loaded: {self.obs_data.shape[0]} samples")

    @staticmethod
    def _normalize(data):
        """Min/max-scale ``data`` to [-1, 1]; constant input maps to all zeros."""
        lo = np.min(data)
        span = np.max(data) - lo
        if span == 0:
            # A constant array would otherwise be 0/0 -> NaN everywhere; use the
            # midpoint of the target range instead.
            return np.zeros_like(data)
        return (data - lo) / span * 2 - 1

    def __len__(self):
        return len(self.obs_data)

    def __getitem__(self, idx):
        """Return ``(obs, gt)`` as float32 tensors with a leading channel axis."""
        # Shapes noted by the original author: obs [1, 1000, 513], gt [1, 362, 362].
        obs = torch.tensor(self.obs_data[idx], dtype=torch.float32).unsqueeze(0)
        gt = torch.tensor(self.gt_data[idx], dtype=torch.float32).unsqueeze(0)
        return obs, gt

# Cache clearing

In [29]:
# Ask PyTorch's CUDA caching allocator to release cached memory it is not using.
torch.cuda.empty_cache()

In [30]:
# Load the whole dataset into memory and wrap it in a shuffled loader.
# num_workers=0 keeps batch loading in the main process.
dataset = LoDoPaBDataset(TRAIN_OBS_DIR, TRAIN_GT_DIR)
dataloader = DataLoader(dataset, batch_size=12, shuffle=True, num_workers=0)

print(f"Dataset size: {len(dataset)} samples")
print(f"Number of batches: {len(dataloader)}")

# Optional smoke test (disabled): pull a single batch and report its shapes.
# print("Testing dataloader...")
# for i, (obs, gt) in enumerate(dataloader):
#     print(f"Batch {i+1}/{len(dataloader)} - Obs shape: {obs.shape}, GT shape: {gt.shape}")
#     break
# print("Dataloader test complete")

Loading dataset...
Dataset loaded: 3553 samples
Dataset size: 3553 samples
Number of batches: 297


# Custom modules

In [31]:
class MultiKernelDepthwiseConv(nn.Module):
    """Depthwise 3x3 convolution that keeps channel count and spatial size.

    Despite the name, only a single 3x3 kernel size is used: one filter per
    input channel (``groups=in_channels``) with ``padding=1``.
    """

    def __init__(self, in_channels):
        super().__init__()
        depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            padding=1,
            groups=in_channels,
        )
        self.conv = depthwise

    def forward(self, x):
        """Apply the depthwise convolution to ``x``."""
        filtered = self.conv(x)
        return filtered

In [32]:
class ConvolutionalMultiFocalAttention(nn.Module):
    """Element-wise gate: a 1x1 convolution followed by a sigmoid scales the input."""

    def __init__(self, in_channels):
        super().__init__()
        # Channel-preserving 1x1 convolution that produces the gate pre-activations.
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x):
        """Return ``sigmoid(conv(x)) * x``."""
        gate = torch.sigmoid(self.conv(x))
        return gate * x

In [33]:
class Discriminator(nn.Module):
    """Small convolutional discriminator producing one probability per image.

    Two stride-2 convolutions are followed by global average pooling, a linear
    layer and a sigmoid, so the output has shape ``(batch, 1)`` with values
    in [0, 1].

    Args:
        in_channels: Number of channels of the input images.
        img_size: Stored on the instance as ``img_size``; none of the layers
            defined here read it.
    """

    def __init__(self, in_channels=1, img_size=(256, 256)):
        super().__init__()
        self.img_size = img_size

        # Feature extractor: conv -> LeakyReLU -> conv -> BatchNorm -> LeakyReLU.
        feature_layers = [
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.conv_layers = nn.Sequential(*feature_layers)

        # Collapses every channel's feature map to a single value (1x1).
        self.global_pooling = nn.AdaptiveAvgPool2d(1)

        # Maps the 128 pooled channel values to one score.
        self.fc = nn.Linear(128, 1)

    def forward(self, x):
        """Return the per-image probability for the batch ``x``."""
        features = self.conv_layers(x)
        pooled = self.global_pooling(features)

        # (batch, 128, 1, 1) -> (batch, 128) so the linear layer can consume it.
        batch_size = pooled.size(0)
        flat = pooled.view(batch_size, -1)

        score = self.fc(flat)
        return torch.sigmoid(score)

In [34]:
class UltraLightUNetGenerator(nn.Module):
    """Three-level U-Net-style generator whose output matches the input size.

    A stride-2 7x7 stem reduces the resolution to roughly half, three encoder
    stages follow (the last two each start with a 2x max-pool), and two
    upsampling stages restore the stem resolution while adding the matching
    encoder features as skip connections. The final map is bilinearly resized
    back to the spatial size of the input.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the generated tensor.
        channels: Widths of the three encoder levels; any indexable sequence
            with at least three entries is accepted.
    """

    def __init__(self, in_channels=1, out_channels=1, channels=(16, 32, 64)):
        super().__init__()
        c0, c1, c2 = channels[0], channels[1], channels[2]

        self.initial_conv = nn.Sequential(
            nn.Conv2d(in_channels, c0, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(c0),
            nn.ReLU(inplace=True)
        )

        self.encoder1 = self._make_encoder_block(c0, c0)
        self.encoder2 = self._make_encoder_block(c0, c1, downsample=True)
        self.encoder3 = self._make_encoder_block(c1, c2, downsample=True)

        self.up2 = self._make_decoder_block(c2, c1)
        self.up1 = self._make_decoder_block(c1, c0)

        self.final_conv = nn.Conv2d(c0, out_channels, kernel_size=3, padding=1)

    def _make_encoder_block(self, in_ch, out_ch, downsample=False):
        """Optional 2x max-pool, then depthwise 3x3 -> 3x3 conv -> BatchNorm -> ReLU."""
        layers = []
        if downsample:
            layers.append(nn.MaxPool2d(2))
        layers.extend([
            nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1, groups=in_ch),  # Depthwise Conv
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        ])
        return nn.Sequential(*layers)

    @staticmethod
    def _make_decoder_block(in_ch, out_ch):
        """2x bilinear upsample, then 3x3 conv -> BatchNorm -> ReLU."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _resize_like(tensor, reference):
        """Bilinearly resize ``tensor`` to the spatial size of ``reference``."""
        target = (reference.size(2), reference.size(3))
        return F.interpolate(tensor, size=target, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to (N, out_channels, H, W)."""
        original_size = (x.size(2), x.size(3))  # Restored at the very end

        x = self.initial_conv(x)   # stem: about half the input resolution
        e1 = self.encoder1(x)      # same resolution as the stem output
        e2 = self.encoder2(e1)     # max-pooled once
        e3 = self.encoder3(e2)     # max-pooled twice

        # Max-pooling floors odd sizes, so an upsampled map can be one pixel
        # smaller than its skip partner; the skip is resized before the sum.
        d2 = self.up2(e3)
        d2 = d2 + self._resize_like(e2, d2)

        d1 = self.up1(d2)
        d1 = d1 + self._resize_like(e1, d1)

        output = self.final_conv(d1)
        # Undo the stem's stride-2 reduction so the output matches the input size.
        output = F.interpolate(output, size=original_size, mode='bilinear', align_corners=True)
        return output


# GAN Losses

In [35]:
# GAN Losses
class GANLosses:
    """Bundle of loss terms and evaluation metrics for the GAN.

    Images passed to ``perceptual_loss`` and ``compute_metrics`` are treated as
    single-channel tensors scaled to [-1, 1]; both methods map them to [0, 1]
    with ``(x + 1) / 2`` before further processing.

    Args:
        device: Device on which the VGG feature extractor and the piqa metric
            modules are placed.
    """

    # Per-channel statistics used to normalise inputs for the pretrained VGG.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, device):
        self.device = device
        self.adv_loss = nn.BCELoss()
        # First 16 feature layers of a pretrained VGG16, frozen and in eval mode.
        # NOTE(review): newer torchvision releases prefer the `weights=` argument
        # over `pretrained=True` - confirm against the installed version.
        self.vgg = vgg16(pretrained=True).features[:16].to(device).eval()
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.pixel_loss = nn.L1Loss()
        self.ssim = SSIM().to(device)
        self.lpips = LPIPS().to(device)

    def adversarial_loss(self, pred, is_real):
        """Binary cross-entropy of ``pred`` against all-ones (real) or all-zeros (fake)."""
        target = torch.ones_like(pred) if is_real else torch.zeros_like(pred)
        return self.adv_loss(pred, target)

    def _to_vgg_input(self, img):
        """Turn a 1-channel [-1, 1] image batch into a normalised 3-channel VGG input."""
        # The ImageNet statistics below are defined for images in [0, 1], so the
        # [-1, 1] data is rescaled first (same mapping as in compute_metrics).
        rgb = ((img + 1) / 2).repeat(1, 3, 1, 1)
        mean = torch.tensor(self._IMAGENET_MEAN, device=self.device).view(1, 3, 1, 1)
        std = torch.tensor(self._IMAGENET_STD, device=self.device).view(1, 3, 1, 1)
        return (rgb - mean) / std

    def perceptual_loss(self, generated, target):
        """Mean squared error between the VGG features of ``generated`` and ``target``."""
        gen_features = self.vgg(self._to_vgg_input(generated))
        target_features = self.vgg(self._to_vgg_input(target))
        return F.mse_loss(gen_features, target_features)

    def compute_metrics(self, generated, target):
        """Return a dict with 'PSNR', 'SSIM' and 'LPIPS' as Python floats (no gradients)."""
        with torch.no_grad():
            gen = (generated + 1) / 2
            tgt = (target + 1) / 2
            # PSNR for a peak value of 1; identical inputs give +inf.
            psnr = 10 * torch.log10(1 / F.mse_loss(gen, tgt))
            # NOTE(review): piqa's SSIM/LPIPS may require 3-channel inputs inside
            # [0, 1]; single-channel or out-of-range tensors might be rejected -
            # confirm against the piqa documentation.
            ssim_val = self.ssim(gen, tgt)
            lpips_val = self.lpips(gen, tgt)
            return {'PSNR': psnr.item(), 'SSIM': ssim_val.item(), 'LPIPS': lpips_val.item()}


In [36]:
def train_gan(generator, discriminator, loss_fn, optimizer_G, optimizer_D, dataloader, epochs, device):
    """Alternate discriminator and generator updates for ``epochs`` epochs.

    For every batch the discriminator is trained to separate ``batch[0]`` from
    generator output, then the generator is trained with an adversarial term
    plus ``loss_fn.pixel_loss`` against ``batch[0]``. One line per epoch is
    printed with the losses of the last batch and the epoch duration.

    Args:
        generator: Module mapping a tensor shaped like ``batch[0]`` to an image
            of the same shape.
        discriminator: Module returning a probability per image.
        loss_fn: Object providing ``adversarial_loss(pred, is_real)`` and
            ``pixel_loss(a, b)``.
        optimizer_G: Optimizer over the generator parameters.
        optimizer_D: Optimizer over the discriminator parameters.
        dataloader: Iterable of batches; only element 0 of each batch is used.
        epochs: Number of passes over ``dataloader``.
        device: Device the batch tensors are moved to.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Ensure BatchNorm layers use batch statistics even if the caller left the
    # networks in eval mode.
    generator.train()
    discriminator.train()

    for epoch in range(epochs):
        start_time = time.time()  # Start timer for epoch
        d_loss = None
        g_loss = None

        for batch in dataloader:
            # NOTE(review): batch[0] is the observation; batch[1] (ground truth)
            # is never used - confirm this is the intended training target.
            real_images = batch[0].to(device)  # Ensure tensor, not list

            # Generate fake images
            # NOTE(review): the generator input is pure Gaussian noise shaped
            # like the batch, not the observation itself - confirm intent.
            fake_images = generator(torch.randn_like(real_images))

            # --- Discriminator Update ---
            optimizer_D.zero_grad()
            real_loss = loss_fn.adversarial_loss(discriminator(real_images), True)
            # detach(): the discriminator step must not backpropagate into the generator.
            fake_loss = loss_fn.adversarial_loss(discriminator(fake_images.detach()), False)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            # --- Generator Update ---
            optimizer_G.zero_grad()
            g_loss = loss_fn.adversarial_loss(discriminator(fake_images), True) + loss_fn.pixel_loss(fake_images, real_images)
            g_loss.backward()
            optimizer_G.step()

        # Without this check an empty loader would fail below with an
        # UnboundLocalError on d_loss, hiding the real cause.
        if d_loss is None:
            raise ValueError("dataloader yielded no batches; nothing to train on")

        epoch_time = time.time() - start_time  # Calculate epoch duration

        # Print epoch stats with time (losses are those of the final batch)
        print(
            f"Epoch [{epoch+1}/{epochs}], "
            f"D Loss: {d_loss.item():.4f}, "
            f"G Loss: {g_loss.item():.4f}, "
            f"Time: {epoch_time:.2f}s"
        )

# Training Setup

In [37]:
# Training Setup
# Same device selection expression as in the configuration cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Freshly initialised networks with their default constructor arguments.
generator = UltraLightUNetGenerator().to(device)
discriminator = Discriminator().to(device)
loss_fn = GANLosses(device)
# One Adam optimizer per network, both with lr=1e-4 and betas=(0.5, 0.999).
optimizer_G = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))

In [None]:
# Train for 5 epochs; train_gan prints one summary line per epoch.
train_gan(generator, discriminator, loss_fn, optimizer_G, optimizer_D, dataloader, 5, device)

Epoch [1/5], D Loss: 0.5250, G Loss: 1.3356, Time: 87.81s
Epoch [2/5], D Loss: 0.5125, G Loss: 1.2838, Time: 88.41s


In [None]:
# Shape smoke test on a freshly initialised (untrained) generator with odd-sized input.
# NOTE(review): this rebinds `generator`, so a model trained in an earlier cell
# is no longer reachable under that name afterwards.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = UltraLightUNetGenerator().to(device)
sample_input = torch.randn(8, 1, 513, 513).to(device)
output = generator(sample_input)
print(f"Output shape: {output.shape}")  # Should be [8, 1, 513, 513]

Output shape: torch.Size([8, 1, 513, 513])


In [None]:
# Re-create the device handle and a new, untrained generator.
# NOTE(review): `generator` is rebound here; any previously trained weights
# held under that name are dropped.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = UltraLightUNetGenerator().to(device)

In [None]:
# Shape smoke test with an even-sized input; relies on `generator` and `device`
# already being defined by an earlier cell.
sample_input = torch.randn(8, 1, 512, 512).to(device)
output = generator(sample_input)
print(f"Output shape: {output.shape}")

Output shape: torch.Size([8, 1, 512, 512])


In [None]:
# Layer-by-layer summaries (torchsummary) of newly constructed, untrained networks.
# NOTE(review): both `generator` and `discriminator` are rebound to fresh
# instances here, replacing whatever those names referred to before.
generator = UltraLightUNetGenerator().to(device)
summary(generator, input_size=(1, 1000, 508))  # Channels, Height, Width

discriminator = Discriminator().to(device)
summary(discriminator, input_size=(1, 362, 362))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 16, 500, 254]             800
       BatchNorm2d-2         [-1, 16, 500, 254]              32
              ReLU-3         [-1, 16, 500, 254]               0
            Conv2d-4         [-1, 16, 500, 254]             160
            Conv2d-5         [-1, 16, 500, 254]           2,320
       BatchNorm2d-6         [-1, 16, 500, 254]              32
              ReLU-7         [-1, 16, 500, 254]               0
         MaxPool2d-8         [-1, 16, 250, 127]               0
            Conv2d-9         [-1, 16, 250, 127]             160
           Conv2d-10         [-1, 32, 250, 127]           4,640
      BatchNorm2d-11         [-1, 32, 250, 127]              64
             ReLU-12         [-1, 32, 250, 127]               0
        MaxPool2d-13          [-1, 32, 125, 63]               0
           Conv2d-14          [-1, 32, 