In [1]:
# Colab-only: mount Google Drive at /content/drive so the project folder
# entered by the following %cd cell is reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Adjust the path in the next cell if the project folder was uploaded to a different location on your Drive.

In [2]:
# Make the project folder the working directory: the dataset paths ('../dataset/...')
# and the checkpoint / output file names used below are all relative to it.
%cd /content/drive/MyDrive/ELEC576_Team4_Final_Project/OurModel

/content/drive/MyDrive/ELEC576_Team4_Final_Project/OurModel


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.utils import save_image
from PIL import Image
import os
import matplotlib.pyplot as plt
from datetime import datetime
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

class RGBThermalDataset(Dataset):
    """Dataset of paired RGB / thermal images stored in two directories.

    Every regular, non-hidden file in ``rgb_dir`` is one sample. Its thermal
    partner is looked up in ``thermal_dir`` under the name produced by
    :meth:`_thermal_name`.
    """

    # File-name endings that mark an RGB image; the thermal partner drops the
    # "_color" marker and is stored as a .jpg.
    _RGB_SUFFIXES = ('_color.jpg', '_color.png')

    def __init__(self, rgb_dir, thermal_dir, transform=None):
        """
        Args:
            rgb_dir (str): Path to RGB images
            thermal_dir (str): Path to thermal images
            transform (callable, optional): Optional transform to be applied on images
        """
        self.rgb_dir = rgb_dir
        self.thermal_dir = thermal_dir
        self.transform = transform
        # os.listdir() returns entries in arbitrary order and also lists
        # sub-directories (e.g. ".ipynb_checkpoints"). Keep regular, non-hidden
        # files only and sort them so that indices are reproducible across runs.
        self.image_files = sorted(
            name for name in os.listdir(rgb_dir)
            if not name.startswith('.') and os.path.isfile(os.path.join(rgb_dir, name))
        )

    def __len__(self):
        """Return the number of RGB files found in ``rgb_dir``."""
        return len(self.image_files)

    @classmethod
    def _thermal_name(cls, rgb_name):
        """Map an RGB file name to the file name of its thermal partner.

        ``<name>_color.jpg`` and ``<name>_color.png`` both map to ``<name>.jpg``;
        any other name is returned unchanged (both folders share the name).
        """
        for suffix in cls._RGB_SUFFIXES:
            if rgb_name.endswith(suffix):
                return rgb_name[:-len(suffix)] + '.jpg'
        return rgb_name

    def __getitem__(self, idx):
        """Return the ``(rgb_image, thermal_image)`` pair at position ``idx``.

        The same ``transform`` (when given) is applied to both images.
        """
        img_name = self.image_files[idx]
        rgb_path = os.path.join(self.rgb_dir, img_name)
        thermal_path = os.path.join(self.thermal_dir, self._thermal_name(img_name))

        # Force three channels so palette / RGBA / grayscale files still match
        # the 3-channel input the network expects.
        rgb_image = Image.open(rgb_path).convert('RGB')
        # The thermal image keeps whatever mode it is stored in.
        thermal_image = Image.open(thermal_path)

        # Apply transformations if they are provided
        if self.transform:
            rgb_image = self.transform(rgb_image)
            thermal_image = self.transform(thermal_image)

        return rgb_image, thermal_image

class PerceptualLoss(nn.Module):
    """MSE between the features a frozen, truncated VGG16 extracts from two batches."""

    def __init__(self):
        super().__init__()
        # First 23 layers of the pretrained VGG16 feature stack.
        backbone = models.vgg16(pretrained=True).features[:23]
        self.features = nn.Sequential(*backbone.children())
        self.features.eval()

        # The extractor is a fixed reference, never trained.
        self.features.requires_grad_(False)

    def forward(self, x, y):
        """Return the mean squared difference between the features of ``x`` and ``y``.

        A single-channel batch is tiled to three channels first, because the
        extractor consumes 3-channel input.
        """
        def as_three_channel(batch):
            return batch.repeat(1, 3, 1, 1) if batch.size(1) == 1 else batch

        x = as_three_channel(x)
        y = as_three_channel(y)

        feats_x = self.features(x)
        feats_y = self.features(y)
        return nn.functional.mse_loss(feats_x, feats_y)

class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped by an identity shortcut.

    The channel count and the spatial size are preserved.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Compute ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        shortcut = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)

class RGBToThermalNet(nn.Module):
    """Encoder-decoder CNN mapping a 3-channel image to a 1-channel map.

    The encoder has three conv/BN/ReLU/residual stages (64, 128, 256 channels)
    with 2x2 max-pooling after the first two. The decoder upsamples twice with
    stride-2 transposed convolutions and adds an encoder feature map before
    each decoder residual block.
    """

    def __init__(self):
        super().__init__()

        def stage(in_channels, out_channels):
            # conv -> batch-norm -> ReLU -> residual block
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                ResidualBlock(out_channels),
            ]

        # Encoder (layer indices matter: forward() slices this Sequential)
        self.encoder = nn.Sequential(
            *stage(3, 64),      # indices 0-3
            nn.MaxPool2d(2, 2),  # index 4
            *stage(64, 128),    # indices 5-8
            nn.MaxPool2d(2, 2),  # index 9
            *stage(128, 256),   # indices 10-13
        )

        # Decoder with skip connections
        self.decoder_conv1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder_res1 = ResidualBlock(128)
        self.decoder_conv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder_res2 = ResidualBlock(64)
        self.final_conv = nn.Conv2d(64, 1, kernel_size=1)

    @staticmethod
    def _match_size(skip, reference):
        """Bilinearly resize ``skip`` to ``reference``'s spatial size when their shapes differ."""
        if reference.size() != skip.size():
            skip = nn.functional.interpolate(
                skip, size=reference.shape[2:], mode='bilinear', align_corners=False
            )
        return skip

    def forward(self, x):
        """Run the encoder, then decode with two additive skip connections.

        Both skip tensors are taken *after* the pooling layer of their stage
        (the slices below include indices 4 and 9). They are therefore half
        the resolution of the decoder tensor they join, and are upsampled
        before the addition.
        """
        # Encoder
        stage1 = self.encoder[:5](x)          # first stage, pooled
        stage2 = self.encoder[5:10](stage1)   # second stage, pooled
        bottleneck = self.encoder[10:](stage2)

        # Decoder, skip connection 1
        up = self.decoder_conv1(bottleneck)
        up = up + self._match_size(stage2, up)
        up = self.decoder_res1(up)

        # Decoder, skip connection 2
        up = self.decoder_conv2(up)
        up = up + self._match_size(stage1, up)
        up = self.decoder_res2(up)

        return self.final_conv(up)

class Visualizer:
    """Writes side-by-side comparison figures and a running loss curve to one folder."""

    def __init__(self, output_dir='output'):
        """Create ``output_dir`` if needed and start with empty loss histories."""
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.train_losses = []
        self.val_losses = []

    def save_images(self, epoch, rgb_img, thermal_img, output_img):
        """Save a 1x3 figure (input, ground truth, prediction) as ``comparison_epoch_<epoch>.png``."""
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Use tensor_to_image() to convert tensors to correct format
        converted = [tensor_to_image(t) for t in (rgb_img, thermal_img, output_img)]
        titles = ('Input RGB', 'Ground Truth Thermal', 'Generated Thermal')
        colormaps = (None, 'inferno', 'inferno')

        for axis, picture, title, colormap in zip(axes, converted, titles, colormaps):
            axis.imshow(picture, cmap=colormap)
            axis.set_title(title)

        target_file = os.path.join(self.output_dir, f'comparison_epoch_{epoch}.png')
        fig.savefig(target_file)
        plt.close(fig)

    def update_loss_plot(self, train_loss, val_loss):
        """Record one epoch's losses and redraw ``loss_plot.png`` from the full history."""
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)

        fig, axis = plt.subplots(figsize=(10, 5))
        axis.plot(self.train_losses, label='Training Loss')
        axis.plot(self.val_losses, label='Validation Loss')
        axis.set_xlabel('Epoch')
        axis.set_ylabel('Loss')
        axis.legend()
        fig.savefig(os.path.join(self.output_dir, 'loss_plot.png'))
        plt.close(fig)

def train_model(model, train_loader, val_loader, num_epochs=50, device='cuda',
                checkpoint_path='best_version_our_model.pth', save_examples=False):
    """
    Train ``model`` with a combined pixel-wise MSE + perceptual loss.

    After every epoch the model is evaluated on ``val_loader``, the learning
    rate scheduler is stepped on the validation loss, the loss curve is
    redrawn, and the model is checkpointed whenever the validation loss improves.

    Args:
        model (nn.Module): Network to train; moved to ``device``
        train_loader (DataLoader): Yields ``(rgb_batch, thermal_batch)`` pairs for training
        val_loader (DataLoader): Yields ``(rgb_batch, thermal_batch)`` pairs for validation
        num_epochs (int): Number of passes over ``train_loader``
        device (str or torch.device): Device to train on
        checkpoint_path (str): File the best checkpoint is written to
        save_examples (bool): When True, save a comparison figure for the first
            sample of the first training batch of every epoch

    Raises:
        ValueError: If either loader contains no batches
    """
    # Fail up front rather than with a ZeroDivisionError after a whole epoch.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError('train_loader and val_loader must each contain at least one batch')

    # Weight of the perceptual term relative to the pixel-wise MSE.
    perceptual_weight = 0.1

    # Initialize losses
    mse_criterion = nn.MSELoss()
    perceptual_criterion = PerceptualLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)

    def combined_loss(outputs, targets):
        # Outputs and targets may differ in channel count. Broadcasting them
        # explicitly yields the same value nn.MSELoss computes implicitly, but
        # without its size-mismatch warning on every batch.
        outputs_b, targets_b = torch.broadcast_tensors(outputs, targets)
        mse_loss = mse_criterion(outputs_b, targets_b)
        perceptual_loss = perceptual_criterion(outputs, targets)
        return mse_loss + perceptual_weight * perceptual_loss

    # Initialize visualizer
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    visualizer = Visualizer(output_dir=f'output_{timestamp}')

    model = model.to(device)
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        for i, (rgb_imgs, thermal_imgs) in enumerate(train_loader):
            rgb_imgs, thermal_imgs = rgb_imgs.to(device), thermal_imgs.to(device)

            optimizer.zero_grad()
            outputs = model(rgb_imgs)

            loss = combined_loss(outputs, thermal_imgs)

            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Optionally save an example from the first batch of each epoch.
            # The output is detached because it still tracks gradients here.
            if save_examples and i == 0:
                visualizer.save_images(epoch,
                                       rgb_imgs[0],
                                       thermal_imgs[0],
                                       outputs[0].detach())

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for rgb_imgs, thermal_imgs in val_loader:
                rgb_imgs, thermal_imgs = rgb_imgs.to(device), thermal_imgs.to(device)
                outputs = model(rgb_imgs)
                val_loss += combined_loss(outputs, thermal_imgs).item()

        # Average the per-batch losses
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        scheduler.step(val_loss)

        visualizer.update_loss_plot(train_loss, val_loss)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Training Loss: {train_loss:.4f}')
        print(f'Validation Loss: {val_loss:.4f}')
        print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
            }, checkpoint_path)

def _min_max_normalize(array):
    """Rescale ``array`` to [0, 1]; a constant array maps to zeros instead of NaN."""
    low = array.min()
    span = array.max() - low
    if span == 0:
        return np.zeros_like(array)
    return (array - low) / span


def _to_grayscale(array):
    """Collapse a (3, H, W) array to (H, W) with luma weights; 2-D arrays pass through."""
    if array.ndim == 2:
        return array
    return 0.2989 * array[0, :, :] + 0.5870 * array[1, :, :] + 0.1140 * array[2, :, :]


def test_model(model_path, test_loader, device='cuda', output_dir='test_results',
               max_saved_batches=100):
    """
    Test a trained model on a test dataset and compute metrics

    MSE is computed on the raw network output. PSNR and SSIM are computed per
    image after prediction and target have each been min-max normalised to
    [0, 1] and the target has been converted to grayscale, so they measure
    structure rather than absolute intensity.

    Args:
        model_path (str): Path to the saved model checkpoint
        test_loader (DataLoader): DataLoader for test dataset
        device (str): Device to run the model on
        output_dir (str): Folder that receives the comparison figures and ``test_metrics.txt``
        max_saved_batches (int): Comparison figures are written only for the
            first ``max_saved_batches`` batches

    Returns:
        dict: Average ``'MSE'``, ``'PSNR'`` and ``'SSIM'`` over all test samples

    Raises:
        ValueError: If ``test_loader`` yields no samples
    """
    # Load the trained model.
    # map_location lets a checkpoint saved on a GPU be restored on whatever
    # ``device`` is. weights_only restricts unpickling to tensors and plain
    # containers, which is all the checkpoint written by train_model holds.
    model = RGBToThermalNet()
    checkpoint = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    # Initialize metrics
    mse_criterion = nn.MSELoss()
    total_mse = 0
    total_psnr = 0
    total_ssim = 0
    total_samples = 0

    test_output_dir = output_dir
    os.makedirs(test_output_dir, exist_ok=True)

    with torch.no_grad():
        for i, (rgb_imgs, thermal_imgs) in enumerate(test_loader):
            rgb_imgs, thermal_imgs = rgb_imgs.to(device), thermal_imgs.to(device)
            outputs = model(rgb_imgs)
            batch_size = rgb_imgs.size(0)

            # Calculate MSE. Broadcasting explicitly gives the same value as
            # nn.MSELoss's implicit broadcast, without its size-mismatch warning.
            outputs_b, targets_b = torch.broadcast_tensors(outputs, thermal_imgs)
            mse = mse_criterion(outputs_b, targets_b).item()
            # The criterion returns a batch mean; weight it by the batch size
            # so a smaller final batch is averaged correctly.
            total_mse += mse * batch_size

            # Calculate PSNR and SSIM for each image in batch
            for j in range(batch_size):
                pred = outputs[j].cpu().numpy().squeeze()
                target = thermal_imgs[j].cpu().numpy().squeeze()

                pred = _min_max_normalize(pred)
                target = _to_grayscale(_min_max_normalize(target))

                # Calculate metrics
                curr_psnr = psnr(target, pred, data_range=1.0)
                curr_ssim = ssim(target, pred, data_range=1.0)

                total_psnr += curr_psnr
                total_ssim += curr_ssim

                if i < max_saved_batches:
                    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

                    # Convert images for visualization: (C, H, W) -> (H, W, C)
                    rgb_img = rgb_imgs[j].cpu().numpy().transpose(1, 2, 0)

                    axes[0].imshow(rgb_img)
                    axes[0].set_title('Input RGB')
                    axes[1].imshow(target, cmap='inferno')
                    axes[1].set_title('Ground Truth Thermal')
                    axes[2].imshow(pred, cmap='inferno')
                    axes[2].set_title(f'Generated Thermal\nPSNR: {curr_psnr:.2f}, SSIM: {curr_ssim:.2f}')

                    fig.savefig(os.path.join(test_output_dir, f'test_sample_{i}_{j}.png'))
                    # Close explicitly: many figures are created in this loop.
                    plt.close(fig)

            total_samples += batch_size

            if (i + 1) % 10 == 0:
                print(f'Processed {i+1}/{len(test_loader)} batches')

    if total_samples == 0:
        raise ValueError('test_loader yielded no samples')

    # Calculate average metrics
    avg_mse = total_mse / total_samples
    avg_psnr = total_psnr / total_samples
    avg_ssim = total_ssim / total_samples

    # Save metrics to file
    metrics = {
        'MSE': avg_mse,
        'PSNR': avg_psnr,
        'SSIM': avg_ssim
    }

    with open(os.path.join(test_output_dir, 'test_metrics.txt'), 'w') as f:
        f.write('Test Metrics:\n')
        f.write(f'Average MSE: {avg_mse:.4f}\n')
        f.write(f'Average PSNR: {avg_psnr:.2f} dB\n')
        f.write(f'Average SSIM: {avg_ssim:.4f}\n')

    print('\nTest Results:')
    print(f'Average MSE: {avg_mse:.4f}')
    print(f'Average PSNR: {avg_psnr:.2f} dB')
    print(f'Average SSIM: {avg_ssim:.4f}')

    return metrics
def tensor_to_image(tensor):
    """
    Convert a PyTorch tensor to a numpy array suitable for image display

    Args:
    tensor (torch.Tensor): Input tensor, expected shapes:
        - (C, H, W) for color images
        - (H, W) for grayscale images

    Returns:
    numpy.ndarray: Image array ready for plt.imshow(); (H, W, C) for
        multi-channel input, (H, W) for single-channel input
    """
    # Work on a detached CPU copy: .numpy() refuses tensors that still track
    # gradients or that live on an accelerator.
    tensor = tensor.detach().cpu()

    # Normalize tensor if values are not in [0, 1]
    low, high = tensor.min(), tensor.max()
    if low < 0 or high > 1:
        span = high - low
        if span > 0:
            tensor = (tensor - low) / span
        else:
            # Constant image outside [0, 1]: avoid 0/0 -> NaN
            tensor = torch.zeros_like(tensor)

    # Handle different tensor shapes
    if tensor.dim() == 3:
        # Check for (C, H, W) vs (H, W, C)
        if tensor.size(0) in [1, 3, 4]:  # Channels first
            tensor = tensor.permute(1, 2, 0)

        # Squeeze single-channel images
        if tensor.size(2) == 1:
            tensor = tensor.squeeze(2)

    # Convert to numpy and return the array. The original discarded the
    # converted array and returned the tensor itself.
    return tensor.numpy()

def display_image(tensor, title=None, cmap=None):
    """
    Show a PyTorch tensor as an image in a new 10x6 figure with the axes hidden

    Args:
    tensor (torch.Tensor): Input tensor
    title (str, optional): Title for the image
    cmap (str, optional): Colormap to use; when omitted, 'viridis' is chosen for
        2-D tensors and for 3-D tensors whose last dimension is 1
    """
    plt.figure(figsize=(10, 6))

    if cmap is None:
        rank = tensor.dim()
        looks_single_channel = rank == 2 or (rank == 3 and tensor.size(2) == 1)
        if looks_single_channel:
            cmap = 'viridis'

    image = tensor_to_image(tensor)
    plt.imshow(image, cmap=cmap)

    if title:
        plt.title(title)

    plt.axis('off')
    plt.tight_layout()
    plt.show()

In [4]:
# Shared preprocessing: RGBThermalDataset applies this same pipeline to both
# the RGB and the thermal image of every pair (resize to a fixed 288x384,
# then convert to a tensor).
transform = transforms.Compose([
        transforms.Resize((288, 384)),
        transforms.ToTensor(),
    ])

In [None]:
# Training
# Load every training pair once, then hold out a fixed, seeded fraction for
# validation. Validating on the same images that are trained on would make the
# validation loss, the LR schedule and the best-checkpoint selection all track
# training loss.
full_train_dataset = RGBThermalDataset(
    rgb_dir='../dataset/paired/train/vis',
    thermal_dir='../dataset/paired/train/ir',
    transform=transform
)

VAL_FRACTION = 0.1  # share of the training pairs held out for validation
SPLIT_SEED = 42     # fixed seed so the split is identical on every run

num_val = max(1, int(len(full_train_dataset) * VAL_FRACTION))
num_train = len(full_train_dataset) - num_val
train_dataset, val_dataset = torch.utils.data.random_split(
    full_train_dataset,
    [num_train, num_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=14, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=14, shuffle=False, num_workers=2)

# Initialize model
model = RGBToThermalNet()

# Train model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_model(model, train_loader, val_loader, num_epochs=50, device=device)

In [7]:
# Testing
print("\nStarting testing phase...")

# Resolve the device in this cell instead of relying on test_model's default
# of 'cuda', so the cell does not depend on the training cell having been run
# in this session.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

test_dataset = RGBThermalDataset(
    rgb_dir='../dataset/paired/val/vis',
    thermal_dir='../dataset/paired/val/ir',
    transform=transform)

test_loader = DataLoader(test_dataset,
                         batch_size=14,
                         shuffle=False,
                         num_workers=4)

# NOTE(review): train_model() writes its best checkpoint to
# 'best_version_our_model.pth', but this cell loads 'best_version_our_model_50.pth'.
# Confirm that a copy with this name exists in the working directory.
test_metrics = test_model('best_version_our_model_50.pth', test_loader, device=device)
print(test_metrics)


Starting testing phase...


  checkpoint = torch.load(model_path)
  return F.mse_loss(input, target, reduction=self.reduction)


Processed 10/19 batches


  return F.mse_loss(input, target, reduction=self.reduction)



Test Results:
Average MSE: 0.0399
Average PSNR: 14.49 dB
Average SSIM: 0.4723
{'MSE': 0.039916919909077676, 'PSNR': 14.490360496065204, 'SSIM': 0.4722972928462446}
