In [None]:
print('Cell 1: Starting imports and device configuration')
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models, transforms
from torchvision.models.segmentation import deeplabv3_resnet50
from torchvision.models import VGG19_Weights
import cv2
import numpy as np
from PIL import Image
import os
import glob
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.amp import GradScaler, autocast
from tqdm import tqdm

print('Cell 1: Imports completed successfully')
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
if torch.cuda.is_available():
    print(f'CUDA device name: {torch.cuda.get_device_name(0)}')
    print(f'CUDA version: {torch.version.cuda}')
print('Cell 1: Device configuration completed')

Cell 1: Starting imports and device configuration
Cell 1: Imports completed successfully
Using device: cuda
CUDA device name: NVIDIA GeForce RTX 3060
CUDA version: 12.4
Cell 1: Device configuration completed
Cell 1: Imports completed successfully
Using device: cuda
CUDA device name: NVIDIA GeForce RTX 3060
CUDA version: 12.4
Cell 1: Device configuration completed


In [2]:
print('Cell 2: Starting configuration setup')

# Configuration
config = {
    'batch_size': 8,
    'lr': 1e-4,
    'num_epochs': 2,  # INCREASED: Train for more epochs for better convergence
    'device': device,
    'save_path': 'fusion_model_improved.pth',
    'w_ssim': 1.0,
    'w_grad': 1.0,
    'w_perc': 0.1,
    'w_int': 5.0,     # ADDED: Weight for the new L1 intensity loss
    'image_size': (256, 256)
}

print('Configuration set:')
for key, value in config.items():
    print(f'  {key}: {value}')
print('Cell 2: Configuration setup completed')

Cell 2: Starting configuration setup
Configuration set:
  batch_size: 8
  lr: 0.0001
  num_epochs: 2
  device: cuda
  save_path: fusion_model_improved.pth
  w_ssim: 1.0
  w_grad: 1.0
  w_perc: 0.1
  w_int: 5.0
  image_size: (256, 256)
Cell 2: Configuration setup completed


In [3]:
print('Cell 3: Starting LLVIPDataset class definition with Augmentation')
import torchvision.transforms.functional as TF
import random

# Dataset class for LLVIP dataset
class LLVIPDataset(Dataset):
    """Dataset for LLVIP visible (RGB) and infrared (grayscale) images."""
    def __init__(self, root_dir, train=True): # Removed transform argument
        self.train = train
        mode = 'train' if train else 'test'
        self.vis_path = os.path.join(root_dir, 'visible', mode)
        self.ir_path = os.path.join(root_dir, 'infrared', mode)
        
        self.vis_images = sorted([f for f in os.listdir(self.vis_path) 
                                 if f.endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tif'))])
        self.ir_images = sorted([f for f in os.listdir(self.ir_path) 
                                 if f.endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tif'))])
        
        self.pairs = []
        for vis_img in self.vis_images:
            if vis_img in self.ir_images:
                self.pairs.append(vis_img)
        
        print(f"Found {len(self.pairs)} image pairs in {mode} set")

        # Define transforms internally
        self.transform_vis = transforms.Compose([
            transforms.Resize(config['image_size']),
            transforms.ToTensor(),
        ])
        self.transform_ir = transforms.Compose([
            transforms.Resize(config['image_size']),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
        ])

    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        img_name = self.pairs[idx]
        
        vis_img = Image.open(os.path.join(self.vis_path, img_name)).convert('RGB')
        ir_img = Image.open(os.path.join(self.ir_path, img_name)).convert('L')
        
        # Apply initial transforms
        vis_tensor = self.transform_vis(vis_img)
        ir_tensor = self.transform_ir(ir_img)

        # --- ADDED: Paired Data Augmentation for training set ---
        if self.train:
            # Random horizontal flip
            if random.random() > 0.5:
                vis_tensor = TF.hflip(vis_tensor)
                ir_tensor = TF.hflip(ir_tensor)

            # Random rotation
            if random.random() > 0.5:
                angle = random.uniform(-10, 10)
                vis_tensor = TF.rotate(vis_tensor, angle)
                ir_tensor = TF.rotate(ir_tensor, angle)
            
        return vis_tensor, ir_tensor, img_name

print('Cell 3: LLVIPDataset class definition completed')

Cell 3: Starting LLVIPDataset class definition with Augmentation
Cell 3: LLVIPDataset class definition completed


In [4]:
print('Cell 4: Starting simplified collate function setup')

# The transforms are now inside the Dataset class. This function just batches the tensors.
def collate_fn(batch):
    vis_tensors, ir_tensors, names = zip(*batch)
    
    vis_batch = torch.stack(vis_tensors)
    ir_batch = torch.stack(ir_tensors)
    
    return vis_batch, ir_batch, names

print('Collate function simplified for batching tensors')
print('Cell 4: Simplified collate function setup completed')

Cell 4: Starting simplified collate function setup
Collate function simplified for batching tensors
Cell 4: Simplified collate function setup completed


In [5]:
print('Cell 5: This cell is currently empty - placeholder for future code')

Cell 5: This cell is currently empty - placeholder for future code


In [6]:
print('Cell 6: Starting loss components definition')
# Loss components (from the original codebase)
class SSIM(nn.Module):
    def __init__(self, window_size=11, C1=0.01**2, C2=0.03**2):
        super().__init__()
        self.window_size = window_size
        self.C1 = C1
        self.C2 = C2
        gauss = cv2.getGaussianKernel(window_size, window_size/6)
        gauss = gauss @ gauss.T
        w = torch.from_numpy(gauss.astype(np.float32))[None, None]
        self.register_buffer('window', w)
        print(f'SSIM initialized with window_size={window_size}')

    def _filt(self, x):
        pad = self.window_size//2
        # CORRECTED: Cast window to match input tensor's dtype
        window = self.window.to(device=x.device, dtype=x.dtype)
        return F.conv2d(x, window, padding=pad, groups=x.size(1))

    def forward(self, x, y):
        # x,y: (B,1,H,W) in [0,1]
        mu_x = self._filt(x)
        mu_y = self._filt(y)
        mu_x2, mu_y2, mu_xy = mu_x*mu_x, mu_y*mu_y, mu_x*mu_y
        sigma_x2 = self._filt(x*x) - mu_x2
        sigma_y2 = self._filt(y*y) - mu_y2
        sigma_xy = self._filt(x*y) - mu_xy
        ssim = ((2*mu_xy + self.C1)*(2*sigma_xy + self.C2)) / ((mu_x2 + mu_y2 + self.C1)*(sigma_x2 + sigma_y2 + self.C2) + 1e-8)
        return ssim.mean()

class GradientLoss(nn.Module):
    def __init__(self):
        super().__init__()
        kx = np.array([[1,0,-1],[2,0,-2],[1,0,-1]], dtype=np.float32)
        ky = np.array([[1,2,1],[0,0,0],[-1,-2,-1]], dtype=np.float32)
        self.register_buffer('kx', torch.from_numpy(kx)[None, None])
        self.register_buffer('ky', torch.from_numpy(ky)[None, None])
        print('GradientLoss initialized')

    def forward(self, fused, vis, ir):
        def grad(img):
            # Handle both single-channel and multi-channel images
            if img.size(1) > 1:
                img_gray = img.mean(dim=1, keepdim=True)
            else:
                img_gray = img
            
            # CORRECTED: Cast kernels to match input tensor's dtype
            kx_ = self.kx.to(device=img_gray.device, dtype=img_gray.dtype)
            ky_ = self.ky.to(device=img_gray.device, dtype=img_gray.dtype)

            gx = F.conv2d(img_gray, kx_, padding=1)
            gy = F.conv2d(img_gray, ky_, padding=1)
            return torch.sqrt(gx*gx + gy*gy + 1e-8)
            
        gF = grad(fused)
        gV = grad(vis)
        gI = grad(ir)
        gT = torch.max(gV, gI)
        return F.l1_loss(gF, gT)

class VGGPerceptual(nn.Module):
    def __init__(self, device):
        super().__init__()
        vgg = models.vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features
        self.slice1 = nn.Sequential(*[vgg[i] for i in range(4)])   # relu1_2
        self.slice2 = nn.Sequential(*[vgg[i] for i in range(4,9)]) # relu2_2
        for p in self.parameters():
            p.requires_grad = False
        self.to(device)
        print(f'VGGPerceptual initialized on device: {device}')

    def forward(self, x):
        # Ensure input is 3-channel for VGG
        if x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1)
        elif x.size(1) > 3:
            x = x[:, :3, :, :]  # Take first 3 channels if more
            
        mean = torch.tensor([0.485, 0.456, 0.406], device=x.device)[None,:,None,None]
        std  = torch.tensor([0.229, 0.224, 0.225], device=x.device)[None,:,None,None]
        x = (x - mean) / std
        f1 = self.slice1(x)
        f2 = self.slice2(f1)
        return f1, f2

def perceptual_loss(vgg: VGGPerceptual, fused, vis, ir):
    fF1, fF2 = vgg(fused)
    vF1, vF2 = vgg(vis)
    iF1, iF2 = vgg(ir)
    return 0.5*(F.l1_loss(fF1, vF1) + F.l1_loss(fF1, iF1) + F.l1_loss(fF2, vF2) + F.l1_loss(fF2, iF2))

class FusionLoss(nn.Module):
    def __init__(self, device, w_ssim=1.0, w_grad=1.0, w_perc=0.1, w_int=5.0): # Added w_int
        super().__init__()
        self.ssim = SSIM()
        self.grad = GradientLoss()
        self.vgg = VGGPerceptual(device)
        self.w_ssim = w_ssim
        self.w_grad = w_grad
        self.w_perc = w_perc
        self.w_int = w_int  # Store new weight
        print(f'FusionLoss initialized with weights: SSIM={w_ssim}, Grad={w_grad}, Perc={w_perc}, Int={w_int}')
        
    def forward(self, fused, vis, ir):
        # --- ADDED: L1 Intensity Loss ---
        # Encourages the fused image to retain the brightest pixels from either input
        intensity_target = torch.max(vis, ir)
        l_intensity = F.l1_loss(fused, intensity_target)

        # --- Grayscale conversions for other losses ---
        if fused.size(1) > 1:
            fused_gray = fused.mean(dim=1, keepdim=True)
        else:
            fused_gray = fused
        if vis.size(1) > 1:
            vis_gray = vis.mean(dim=1, keepdim=True)
        else:
            vis_gray = vis
        if ir.size(1) > 1:
            ir_gray = ir.mean(dim=1, keepdim=True)
        else:
            ir_gray = ir
            
        # --- Original loss calculations ---
        l_ssim = 0.5*(1.0 - self.ssim(fused_gray, vis_gray)) + 0.5*(1.0 - self.ssim(fused_gray, ir_gray))
        l_grad = self.grad(fused, vis, ir)
        l_perc = perceptual_loss(self.vgg, fused, vis, ir)
        
        # --- MODIFIED: Combined total loss ---
        total = self.w_int * l_intensity + self.w_ssim*l_ssim + self.w_grad*l_grad + self.w_perc*l_perc
        
        # --- MODIFIED: Return dictionary with new loss component ---
        return total, {"intensity": l_intensity.item(), "ssim": l_ssim.item(), "grad": l_grad.item(), "perc": l_perc.item()}

print('Cell 6: Loss components definition completed')

Cell 6: Starting loss components definition
Cell 6: Loss components definition completed


In [7]:
print('Cell 7: Starting UNetFusion model definition')
# Simplified U-Net based Fusion Model
class UNetFusion(nn.Module):
    def __init__(self):
        super().__init__()
        print('Initializing UNetFusion model...')
        
        # Encoder (pretrained ResNet backbone)
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.encoder1 = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1
        )
        self.encoder2 = resnet.layer2
        self.encoder3 = resnet.layer3
        self.encoder4 = resnet.layer4
        
        # Decoder with skip connections
        self.decoder4 = nn.Sequential(
            nn.Conv2d(2048, 1024, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
        self.decoder3 = nn.Sequential(
            nn.Conv2d(2048, 512, 3, padding=1),  # 1024 from decoder + 1024 from encoder3
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
        self.decoder2 = nn.Sequential(
            nn.Conv2d(1024, 256, 3, padding=1),  # 512 from decoder + 512 from encoder2
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
        self.decoder1 = nn.Sequential(
            nn.Conv2d(512, 128, 3, padding=1),  # 256 from decoder + 256 from encoder1
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
        
        # Final convolution - CORRECTED
        self.final_conv = nn.Sequential(
            # Add the missing Upsample layer to go from 128x128 to 256x256
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, 3, padding=1), # Input channels from decoder1 is 128
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, 1),
            nn.Sigmoid()
        )
        print('UNetFusion model initialized successfully')

    def forward(self, vis, ir):
        # Process both inputs through the encoder
        vis1 = self.encoder1(vis)
        vis2 = self.encoder2(vis1)
        vis3 = self.encoder3(vis2)
        vis4 = self.encoder4(vis3)
        
        ir1 = self.encoder1(ir)
        ir2 = self.encoder2(ir1)
        ir3 = self.encoder3(ir2)
        ir4 = self.encoder4(ir3)
        
        # Fuse features at each level (simple addition)
        fused4 = vis4 + ir4
        fused3 = vis3 + ir3
        fused2 = vis2 + ir2
        fused1 = vis1 + ir1
        
        # Decode with skip connections
        d4 = self.decoder4(fused4)
        d3 = self.decoder3(torch.cat([d4, fused3], dim=1))
        d2 = self.decoder2(torch.cat([d3, fused2], dim=1))
        d1 = self.decoder1(torch.cat([d2, fused1], dim=1))
        
        # Final output
        return self.final_conv(d1)

print('Cell 7: UNetFusion model definition completed')

Cell 7: Starting UNetFusion model definition
Cell 7: UNetFusion model definition completed


In [8]:
print('Cell 8: Starting train_model function definition')
def train_model():
    print('Initializing training components...')
    import time
    # Initialize model, dataset, and loss
    model = UNetFusion().to(config['device'])
    full_dataset = LLVIPDataset('LLVIP', train=True)

    # Split dataset into train and validation (80% train, 20% val)
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True,
                                 num_workers=4, collate_fn=collate_fn)
    val_dataloader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False,
                               num_workers=4, collate_fn=collate_fn)

    optimizer = optim.Adam(model.parameters(), lr=config['lr'])
    scaler = GradScaler(device='cuda')
    criterion = FusionLoss(config['device'])

    print(f'Full dataset size: {len(full_dataset)} samples')
    print(f'Training set size: {len(train_dataset)} samples')
    print(f'Validation set size: {len(val_dataset)} samples')
    print(f'Batch size: {config["batch_size"]}')
    print(f'Number of training batches per epoch: {len(train_dataloader)}')
    print(f'Number of validation batches per epoch: {len(val_dataloader)}')
    print(f'Total epochs: {config["num_epochs"]}')
    print('Starting training loop...')

    # Track best validation loss for model saving
    best_val_loss = float('inf')
    training_history = []

    # Training loop with real-time updates
    epoch_bar = tqdm(range(config['num_epochs']), desc='🎯 Training Progress', unit='epoch',
                     bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]',
                     ncols=100, miniters=1, mininterval=0.1)

    for epoch in epoch_bar:
        epoch_start_time = time.time()
        tqdm.write(f'\n🚀 Epoch {epoch+1}/{config["num_epochs"]} - Starting...')

        model.train()
        total_loss = 0
        ssim_loss = 0
        grad_loss = 0
        perc_loss = 0
        batch_count = 0

        # Training phase with enhanced progress bar
        train_bar = tqdm(train_dataloader,
                        desc=f'📚 Training Epoch {epoch+1}',
                        leave=False,
                        unit='batch',
                        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}] {postfix}',
                        ncols=120, miniters=1, mininterval=0.05)

        train_start_time = time.time()
        for batch_idx, (vis, ir, _) in enumerate(train_bar):
            batch_start_time = time.time()
            vis = vis.to(config['device'])
            ir = ir.to(config['device'])

            optimizer.zero_grad()

            with autocast(device_type='cuda'):
                fused = model(vis, ir)
                loss, loss_components = criterion(fused, vis, ir)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            batch_time = time.time() - batch_start_time
            total_loss += loss.item()
            ssim_loss += loss_components['ssim']
            grad_loss += loss_components['grad']
            perc_loss += loss_components['perc']
            batch_count += 1

            # Update progress bar with detailed metrics - force refresh
            current_loss = total_loss / batch_count
            current_ssim = ssim_loss / batch_count
            current_grad = grad_loss / batch_count
            current_perc = perc_loss / batch_count

            train_bar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'avg_loss': f'{current_loss:.4f}',
                'ssim': f'{current_ssim:.4f}',
                'grad': f'{current_grad:.4f}',
                'perc': f'{current_perc:.4f}',
                'batch_time': f'{batch_time:.3f}s'
            }, refresh=True)

            # Force refresh every few batches
            if batch_idx % 5 == 0:
                train_bar.refresh()

        train_time = time.time() - train_start_time
        avg_train_loss = total_loss / len(train_dataloader)
        avg_train_ssim = ssim_loss / len(train_dataloader)
        avg_train_grad = grad_loss / len(train_dataloader)
        avg_train_perc = perc_loss / len(train_dataloader)

        # Validation phase with enhanced progress bar
        model.eval()
        val_total_loss = 0
        val_ssim_loss = 0
        val_grad_loss = 0
        val_perc_loss = 0
        val_batch_count = 0

        val_bar = tqdm(val_dataloader,
                      desc=f'✅ Validating Epoch {epoch+1}',
                      leave=False,
                      unit='batch',
                      bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}] {postfix}',
                      ncols=120, miniters=1, mininterval=0.05)

        val_start_time = time.time()
        with torch.no_grad():
            for batch_idx, (vis, ir, _) in enumerate(val_bar):
                batch_start_time = time.time()
                vis = vis.to(config['device'])
                ir = ir.to(config['device'])

                with autocast(device_type='cuda'):
                    fused = model(vis, ir)
                    loss, loss_components = criterion(fused, vis, ir)

                batch_time = time.time() - batch_start_time
                val_total_loss += loss.item()
                val_ssim_loss += loss_components['ssim']
                val_grad_loss += loss_components['grad']
                val_perc_loss += loss_components['perc']
                val_batch_count += 1

                # Update validation progress bar - force refresh
                current_val_loss = val_total_loss / val_batch_count
                val_bar.set_postfix({
                    'val_loss': f'{loss.item():.4f}',
                    'avg_val_loss': f'{current_val_loss:.4f}',
                    'batch_time': f'{batch_time:.3f}s'
                }, refresh=True)

                # Force refresh every few batches
                if batch_idx % 3 == 0:
                    val_bar.refresh()

        val_time = time.time() - val_start_time
        avg_val_loss = val_total_loss / len(val_dataloader)
        avg_val_ssim = val_ssim_loss / len(val_dataloader)
        avg_val_grad = val_grad_loss / len(val_dataloader)
        avg_val_perc = val_perc_loss / len(val_dataloader)

        epoch_time = time.time() - epoch_start_time

        # Store epoch metrics
        epoch_metrics = {
            'epoch': epoch + 1,
            'train_loss': avg_train_loss,
            'train_ssim': avg_train_ssim,
            'train_grad': avg_train_grad,
            'train_perc': avg_train_perc,
            'val_loss': avg_val_loss,
            'val_ssim': avg_val_ssim,
            'val_grad': avg_val_grad,
            'val_perc': avg_val_perc,
            'epoch_time': epoch_time,
            'train_time': train_time,
            'val_time': val_time
        }
        training_history.append(epoch_metrics)

        # Update epoch progress bar with final metrics
        epoch_bar.set_postfix({
            'train_loss': f'{avg_train_loss:.4f}',
            'val_loss': f'{avg_val_loss:.4f}',
            'time': f'{epoch_time:.1f}s'
        }, refresh=True)
        epoch_bar.refresh()

        # Print comprehensive epoch summary
        tqdm.write(f'📊 Epoch {epoch+1}/{config["num_epochs"]} Summary (⏱️ {epoch_time:.1f}s total):')
        tqdm.write(f'  🎓 Training (⏱️ {train_time:.1f}s):')
        tqdm.write(f'    Loss: {avg_train_loss:.4f} | SSIM: {avg_train_ssim:.4f} | Grad: {avg_train_grad:.4f} | Perc: {avg_train_perc:.4f}')
        tqdm.write(f'  ✅ Validation (⏱️ {val_time:.1f}s):')
        tqdm.write(f'    Loss: {avg_val_loss:.4f} | SSIM: {avg_val_ssim:.4f} | Grad: {avg_val_grad:.4f} | Perc: {avg_val_perc:.4f}')

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'model_state_dict': model.state_dict(),
                'config': config,
                'epoch': epoch + 1,
                'best_val_loss': best_val_loss,
                'training_history': training_history
            }, config['save_path'])
            tqdm.write(f'💾 Best model saved! (Val Loss: {best_val_loss:.4f})')

    # Final summary
    tqdm.write(f'\n🎉 Training completed!')
    tqdm.write(f'🏆 Best validation loss: {best_val_loss:.4f}')
    tqdm.write(f'💾 Final model saved to {config["save_path"]}')

    # Print training history summary
    tqdm.write(f'\n📈 Training History Summary:')
    tqdm.write(f'{"Epoch":<5} {"Train Loss":<12} {"Val Loss":<12} {"Train SSIM":<12} {"Val SSIM":<12} {"Time":<8}')
    tqdm.write('-' * 80)
    for metrics in training_history[-5:]:  # Show last 5 epochs
        tqdm.write(f'{metrics["epoch"]:<5} {metrics["train_loss"]:<12.4f} {metrics["val_loss"]:<12.4f} '
                  f'{metrics["train_ssim"]:<12.4f} {metrics["val_ssim"]:<12.4f} {metrics["epoch_time"]:<8.1f}s')

print('Cell 8: train_model function definition completed')

Cell 8: Starting train_model function definition
Cell 8: train_model function definition completed


In [9]:
print('Cell 10: Starting infer_model function definition')
# Inference function
def infer_model():
    print('Loading model for inference...')
    # Load model
    checkpoint = torch.load(config['save_path'])
    model = UNetFusion().to(config['device'])
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print('Model loaded successfully')

    # Load test dataset
    test_dataset = LLVIPDataset('LLVIP', train=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

    print(f'Test dataset size: {len(test_dataset)} samples')

    # Create output directory
    os.makedirs('results', exist_ok=True)
    print('Results directory created')

    # Process test images with progress bar
    print('Starting inference on test images...')
    inference_bar = tqdm(test_loader,
                        desc='🔍 Running Inference',
                        unit='image',
                        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]',
                        ncols=100, miniters=1, mininterval=0.1)

    with torch.no_grad():
        for i, (vis, ir, names) in enumerate(inference_bar):
            vis = vis.to(config['device'])
            ir = ir.to(config['device'])

            fused = model(vis, ir)

            # Convert to numpy for visualization
            vis_img = vis.squeeze().permute(1, 2, 0).cpu().numpy()
            ir_img = ir.squeeze().mean(0).cpu().numpy()  # Convert to grayscale
            fused_img = fused.squeeze().permute(1, 2, 0).cpu().numpy()

            # Plot results
            plt.figure(figsize=(15, 5))
            plt.subplot(131)
            plt.imshow(vis_img)
            plt.title('Visible (RGB)')
            plt.axis('off')

            plt.subplot(132)
            plt.imshow(ir_img, cmap='gray')
            plt.title('Infrared')
            plt.axis('off')

            plt.subplot(133)
            plt.imshow(fused_img)
            plt.title('Fused (RGB)')
            plt.axis('off')

            plt.savefig(f'results/{names[0]}_fused.png', bbox_inches='tight', dpi=150)
            plt.close()

            # Update progress bar with current image name
            inference_bar.set_postfix({'image': names[0]}, refresh=True)

            if i >= 4:  # Save only first 5 examples
                break

    tqdm.write('Inference completed. Results saved in results/ directory')

print('Cell 10: infer_model function definition completed')

Cell 10: Starting infer_model function definition
Cell 10: infer_model function definition completed


In [10]:
print('Cell 11: Starting main execution')
if __name__ == '__main__':
    print('Beginning training phase...')
    train_model()
    print('Training phase completed.')
    print('Beginning inference phase...')
    infer_model()
    print('Inference phase completed.')
    print('All tasks completed successfully!')
print('Cell 11: Main execution completed')

Cell 11: Starting main execution
Beginning training phase...
Initializing training components...
Initializing UNetFusion model...
UNetFusion model initialized successfully
UNetFusion model initialized successfully
Found 12025 image pairs in train set
SSIM initialized with window_size=11
GradientLoss initialized
Found 12025 image pairs in train set
SSIM initialized with window_size=11
GradientLoss initialized
VGGPerceptual initialized on device: cuda
FusionLoss initialized with weights: SSIM=1.0, Grad=1.0, Perc=0.1, Int=5.0
Full dataset size: 12025 samples
Training set size: 9620 samples
Validation set size: 2405 samples
Batch size: 8
Number of training batches per epoch: 1203
Number of validation batches per epoch: 301
Total epochs: 2
Starting training loop...
VGGPerceptual initialized on device: cuda
FusionLoss initialized with weights: SSIM=1.0, Grad=1.0, Perc=0.1, Int=5.0
Full dataset size: 12025 samples
Training set size: 9620 samples
Validation set size: 2405 samples
Batch size: 8

🎯 Training Progress:   0%|                                                | 0/2 [00:00<?, ?epoch/s]


🚀 Epoch 1/2 - Starting...


🎯 Training Progress:   0%|                                                | 0/2 [04:45<?, ?epoch/s][00:00<?, ?batch/s] [A

📊 Epoch 1/2 Summary (⏱️ 285.3s total):
  🎓 Training (⏱️ 255.0s):
    Loss: 2.2509 | SSIM: 0.7201 | Grad: 0.3399 | Perc: 1.1716
  ✅ Validation (⏱️ 30.3s):
    Loss: 2.2569 | SSIM: 0.7222 | Grad: 0.3464 | Perc: 1.1819


🎯 Training Progress:  50%|███████████████████▌                   | 1/2 [04:46<04:46, 286.04s/epoch]

💾 Best model saved! (Val Loss: 2.2569)

🚀 Epoch 2/2 - Starting...


🎯 Training Progress: 100%|███████████████████████████████████████| 2/2 [09:30<00:00, 285.08s/epoch][00:00<?, ?batch/s] [A
  checkpoint = torch.load(config['save_path'])

  checkpoint = torch.load(config['save_path'])


📊 Epoch 2/2 Summary (⏱️ 284.1s total):
  🎓 Training (⏱️ 254.1s):
    Loss: 2.2516 | SSIM: 0.7202 | Grad: 0.3398 | Perc: 1.1715
  ✅ Validation (⏱️ 30.0s):
    Loss: 2.2587 | SSIM: 0.7226 | Grad: 0.3466 | Perc: 1.1824

🎉 Training completed!
🏆 Best validation loss: 2.2569
💾 Final model saved to fusion_model_improved.pth

📈 Training History Summary:
Epoch Train Loss   Val Loss     Train SSIM   Val SSIM     Time    
--------------------------------------------------------------------------------
1     2.2509       2.2569       0.7201       0.7222       285.3   s
2     2.2516       2.2587       0.7202       0.7226       284.1   s
Training phase completed.
Beginning inference phase...
Loading model for inference...
Initializing UNetFusion model...
UNetFusion model initialized successfully
Model loaded successfully
Found 3463 image pairs in test set
Test dataset size: 3463 samples
Results directory created
Starting inference on test images...
UNetFusion model initialized successfully
Model loa

🔍 Running Inference:   0%|                                     | 4/3463 [00:01<23:40,  2.43image/s]

Inference completed. Results saved in results/ directory
Inference phase completed.
All tasks completed successfully!
Cell 11: Main execution completed



