In [2]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import (mean_squared_error, mean_absolute_error,
                             precision_score, recall_score, f1_score, roc_auc_score, jaccard_score)
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Conditional import for U2NET, depending on its location
try:
    from u2net.u2net import U2NET
except ImportError:
    from model.u2net.u2net import U2NET

# Import EfficientNet only once
from efficientnet_pytorch import EfficientNet


In [3]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

sal_model = U2NET(3, 1).to(device)
sal_model.load_state_dict(torch.load('u2net/u2net.pth', map_location=device))
sal_model.eval()


Using device: mps


  sal_model.load_state_dict(torch.load('u2net/u2net.pth', map_location=device))


U2NET(
  (stage1): RSU7(
    (rebnconvin): REBNCONV(
      (conv_s1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_s1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu_s1): ReLU(inplace=True)
    )
    (rebnconv1): REBNCONV(
      (conv_s1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_s1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu_s1): ReLU(inplace=True)
    )
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (rebnconv2): REBNCONV(
      (conv_s1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_s1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu_s1): ReLU(inplace=True)
    )
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (rebnconv3): REBNCONV(
      (conv_s1): Conv2d(32, 32, k

In [4]:
def generate_saliency_mask(image_path, model, device):
    """
    Generates a saliency mask for the given image using U^2-Net.

    Args:
        image_path (str): Path to the image file.
        model (torch.nn.Module): Pre-trained saliency detection model.
        device (torch.device): Device to perform computation on.

    Returns:
        np.array: Saliency mask normalized between 0 and 1.
    """
    image = Image.open(image_path).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = transform(image).unsqueeze(0).to(device)
    
    with torch.no_grad():
        d1, d2, d3, d4, d5, d6, d7 = model(input_tensor)
        pred = d1[:, 0, :, :]
        pred = pred.cpu().numpy()  # Shape: (1, H, W)
        pred = pred.squeeze(0)      # Now pred has shape (H, W)
        pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)  # Normalize to [0,1]
        pred = np.uint8(pred * 255)
        pred = Image.fromarray(pred).resize(image.size, resample=Image.BILINEAR)
        pred = np.array(pred) / 255.0  # Normalize to [0,1]
    
    return pred


In [5]:
class KonIQ10kDataset(Dataset):
    def __init__(self, images_dir, csv_file, saliency_model, device, transform=None):
        """
        Args:
            images_dir (str): Path to images.
            csv_file (str): Path to the CSV file with global scores.
            saliency_model (torch.nn.Module): Pre-trained saliency detection model.
            device (torch.device): Device to perform computation on.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.images_dir = images_dir
        self.global_scores = pd.read_csv(csv_file)
        self.transform = transform
        self.saliency_model = saliency_model
        self.device = device

    def __len__(self):
        return len(self.global_scores)

    def __getitem__(self, idx):
        # Get image filename and global score
        img_name = self.global_scores.iloc[idx, 0]
        score = self.global_scores.iloc[idx, 1]
        
        # Load image
        img_path = os.path.join(self.images_dir, img_name)
        image = Image.open(img_path).convert("RGB")
        image_np = np.array(image)
        
        # Generate saliency mask
        saliency_mask = generate_saliency_mask(img_path, self.saliency_model, self.device)
        saliency_mask = np.expand_dims(saliency_mask, axis=-1)  # Add channel dimension
        
        # Apply transforms
        if self.transform:
            augmented = self.transform(image=image_np, mask=saliency_mask)
            image = augmented['image']
            saliency_mask = augmented['mask']
        
        # Convert mask to binary (threshold can be adjusted)
        saliency_mask = (saliency_mask > 0.5).float()
        
        return image, torch.tensor(score, dtype=torch.float32), saliency_mask.squeeze(0)


In [6]:

def get_transforms(train=True):
    if train:
        return A.Compose([
            A.Resize(224, 224),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
            A.Rotate(limit=15, p=0.3),
            A.Normalize(mean=(0.485, 0.456, 0.406),
                        std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])
    else:
        return A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=(0.485, 0.456, 0.406),
                        std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])


In [7]:

images_dir = '../koniq-10k/images/'  # Update with your actual path
csv_file = '../koniq-10k/annotations/koniq-10k.csv'  # Update with your actual path
batch_size = 16
validation_split = 0.2
shuffle_dataset = True
random_seed= 42

# Initialize the dataset
full_dataset = KonIQ10kDataset(images_dir, csv_file, saliency_model=sal_model, device=device, transform=get_transforms(train=True))

# Creating data indices for training and validation splits:
dataset_size = len(full_dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

train_loader = DataLoader(full_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=0)
valid_loader = DataLoader(full_dataset, batch_size=batch_size, sampler=valid_sampler, num_workers=0)


In [8]:
class EfficientNetIQA(nn.Module):
    def __init__(self, efficientnet_version='efficientnet-b0', pretrained=True):
        super(EfficientNetIQA, self).__init__()
        # Load EfficientNet backbone
        self.backbone = EfficientNet.from_pretrained(efficientnet_version) if pretrained else EfficientNet.from_name(efficientnet_version)
        
        # Remove the classification head
        self.backbone._fc = nn.Identity()
        self.backbone._avg_pooling = nn.Identity()
        
        # Global Quality Assessment Head
        self.global_head = nn.Sequential(
            nn.Linear(1280, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 1)  # Regression output
        )
        
        # Local Quality Assessment Head
        # We'll add convolutional layers to generate a quality map
        self.local_head = nn.Sequential(
            nn.Conv2d(1280, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.Conv2d(512, 1, kernel_size=1),
            nn.Sigmoid()  # Output between 0 and 1
        )
        
    def forward(self, x):
        # Pass through EfficientNet backbone
        features = self.backbone.extract_features(x)  # Shape: [B, 1280, H, W]
        
        # Global Quality
        # Adaptive pooling to get a fixed-size feature vector
        pooled = F.adaptive_avg_pool2d(features, (1,1)).view(features.size(0), -1)  # Shape: [B, 1280]
        global_quality = self.global_head(pooled).squeeze(1)  # Shape: [B]
        
        # Local Quality
        local_quality_map = self.local_head(features)  # Shape: [B, 1, H, W]
        local_quality_map = F.interpolate(
            local_quality_map,
            size=(x.size(2), x.size(3)),  # Ensure size is a tuple (height, width)
            mode='bilinear',
            align_corners=False
        )  # Resizes to [B, 1, H_in, W_in]
        local_quality_map = local_quality_map.squeeze(1)  # Now shape is [B, H_in, W_in]
        
        return global_quality, local_quality_map



In [9]:


# Initialize the model
model = EfficientNetIQA().to(device)

# Define loss functions
criterion_global = nn.MSELoss()
criterion_local = nn.BCELoss()

# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Optionally, define a learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)


Loaded pretrained weights for efficientnet-b0


In [None]:


# Parameters
num_epochs = 30
alpha = 1.0  # Weight for global loss
beta = 1.0   # Weight for local loss

# Initialize optimizer and other components (Assuming they are already defined)
# optimizer = ...
# criterion_global = ...
# criterion_local = ...

# Define scheduler after optimizer is defined
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

best_val_loss = float('inf')
patience_counter = 0
patience = 10

# Initialize GradScaler for mixed precision
scaler = GradScaler()

# Move model to device before starting training
model = EfficientNetIQA().to(device)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    running_global_loss = 0.0
    running_local_loss = 0.0
    
    loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]")
    for images, scores, masks in loop:
        # Move data to the appropriate device
        images = images.to(device, non_blocking=True)
        scores = scores.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)
        
        if masks.dim() == 4 and masks.size(-1) == 1:
            masks = masks.squeeze(-1)

        optimizer.zero_grad()
        
        with autocast():
            # Forward pass
            outputs_global, outputs_local = model(images)
            
            # Compute losses
            loss_global = criterion_global(outputs_global, scores)
            loss_local = criterion_local(outputs_local, masks)
            loss = alpha * loss_global + beta * loss_local
        
        # Backward pass and optimization
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        # Update running losses
        running_loss += loss.item()
        running_global_loss += loss_global.item()
        running_local_loss += loss_local.item()
        
        # Update progress bar
        loop.set_postfix(loss=loss.item(), global_loss=loss_global.item(), local_loss=loss_local.item())
    
    # Validation after each epoch
    model.eval()
    val_loss = 0.0
    val_global_loss = 0.0
    val_local_loss = 0.0
    with torch.no_grad():
        for images, scores, masks in valid_loader:
            # Move data to the appropriate device
            images = images.to(device, non_blocking=True)
            scores = scores.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            
            # Forward pass
            outputs_global, outputs_local = model(images)
            
            # Compute losses
            loss_global = criterion_global(outputs_global, scores)
            loss_local = criterion_local(outputs_local, masks)
            loss = alpha * loss_global + beta * loss_local
            
            # Accumulate validation losses
            val_loss += loss.item()
            val_global_loss += loss_global.item()
            val_local_loss += loss_local.item()
    
    # Calculate average losses
    avg_train_loss = running_loss / len(train_loader)
    avg_train_global_loss = running_global_loss / len(train_loader)
    avg_train_local_loss = running_local_loss / len(train_loader)
    avg_val_loss = val_loss / len(valid_loader)
    avg_val_global_loss = val_global_loss / len(valid_loader)
    avg_val_local_loss = val_local_loss / len(valid_loader)
    
    # Print epoch summary
    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"Train Loss: {avg_train_loss:.4f} | Global: {avg_train_global_loss:.4f} | Local: {avg_train_local_loss:.4f}")
    print(f"Val Loss: {avg_val_loss:.4f} | Global: {avg_val_global_loss:.4f} | Local: {avg_val_local_loss:.4f}")
    
    # Scheduler step after validation
    scheduler.step(avg_val_loss)
    
    # Early Stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        # Save the best model
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break
    
    # Optionally, save model checkpoints
    torch.save(model.state_dict(), f'checkpoint_epoch_{epoch+1}.pth')

# Load the best model after training
model = EfficientNetIQA().to(device)
model.load_state_dict(torch.load('best_model.pth'))
model.eval()


  scaler = GradScaler()


Loaded pretrained weights for efficientnet-b0


  src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
  with autocast():
Epoch [1/30]:  48%|████▊     | 241/504 [16:56<17:44,  4.05s/it, global_loss=0.0274, local_loss=0.706, loss=0.733] 

In [10]:


model.eval()
all_scores = []
all_preds_global = []
all_targets_global = []
all_preds_local = []
all_targets_local = []

with torch.no_grad():
    for images, scores, masks in valid_loader:
        images = images.to(device)
        scores = scores.to(device)
        masks = masks.to(device)
        
        outputs_global, outputs_local = model(images)
        
        # Collect global scores
        all_preds_global.extend(outputs_global.cpu().numpy())
        all_targets_global.extend(scores.cpu().numpy())
        
        # Collect local masks
        all_preds_local.extend(outputs_local.cpu().numpy())
        all_targets_local.extend(masks.cpu().numpy())

# Global Metrics
mse = mean_squared_error(all_targets_global, all_preds_global)
mae = mean_absolute_error(all_targets_global, all_preds_global)
pearson_corr, _ = pearsonr(all_targets_global, all_preds_global)
spearman_corr, _ = spearmanr(all_targets_global, all_preds_global)

print("Global Quality Assessment Metrics:")
print(f"MSE: {mse:.4f}")
print(f"MAE: {mae:.4f}")
print(f"Pearson Correlation: {pearson_corr:.4f}")
print(f"Spearman Correlation: {spearman_corr:.4f}")

# Local Metrics
# Binarize predictions with a threshold (e.g., 0.5)
threshold = 0.5
all_preds_local_bin = (np.array(all_preds_local) > threshold).astype(int)
all_targets_local_bin = (np.array(all_targets_local) > 0.5).astype(int)

# Flatten the masks for metric computation
all_preds_local_bin_flat = all_preds_local_bin.flatten()
all_targets_local_bin_flat = all_targets_local_bin.flatten()

iou = jaccard_score(all_targets_local_bin_flat, all_preds_local_bin_flat)
precision = precision_score(all_targets_local_bin_flat, all_preds_local_bin_flat)
recall = recall_score(all_targets_local_bin_flat, all_preds_local_bin_flat)
f1 = f1_score(all_targets_local_bin_flat, all_preds_local_bin_flat)
auc = roc_auc_score(all_targets_local_bin_flat, np.array(all_preds_local))

print("\nLocal Quality Assessment Metrics:")
print(f"IoU: {iou:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-Score: {f1:.4f}")
print(f"AUC: {auc:.4f}")


  src = F.upsample(src,size=tar.shape[2:],mode='bilinear')


Global Quality Assessment Metrics:
MSE: 0.0132
MAE: 0.0784
Pearson Correlation: 0.0445
Spearman Correlation: 0.0836


ValueError: Found array with dim 3. None expected <= 2.

In [None]:


def visualize_quality_maps(model, dataset, device, num_samples=5):
    model.eval()
    indices = np.random.choice(len(dataset), num_samples, replace=False)
    
    with torch.no_grad():
        for idx in indices:
            image, score, mask = dataset[idx]
            input_image = image.unsqueeze(0).to(device)
            pred_global, pred_local = model(input_image)
            pred_global = pred_global.item()
            pred_local = pred_local.squeeze().cpu().numpy()
            
            # Threshold the local quality map
            pred_local_bin = (pred_local > 0.5).astype(int)
            
            # Original image
            img = image.permute(1, 2, 0).cpu().numpy()
            img = np.clip(img * np.array([0.229, 0.224, 0.225]) + 
                          np.array([0.485, 0.456, 0.406]), 0, 1)
            
            # Ground truth mask
            gt_mask = mask.cpu().numpy()
            
            # Plotting
            fig, axs = plt.subplots(1, 3, figsize=(15,5))
            axs[0].imshow(img)
            axs[0].set_title(f"Original Image\nGlobal Score: {score:.2f}")
            axs[0].axis('off')
            
            axs[1].imshow(img)
            axs[1].imshow(gt_mask, alpha=0.5, cmap='jet')
            axs[1].set_title("Ground Truth Quality Map")
            axs[1].axis('off')
            
            axs[2].imshow(img)
            axs[2].imshow(pred_local, alpha=0.5, cmap='jet')
            axs[2].set_title(f"Predicted Quality Map\nGlobal Score: {pred_global:.2f}")
            axs[2].axis('off')
            
            plt.show()


In [None]:
def visualize_heatmap(model, dataset, device, num_samples=5):
    model.eval()
    indices = np.random.choice(len(dataset), num_samples, replace=False)
    
    with torch.no_grad():
        for idx in indices:
            image, score, mask = dataset[idx]
            input_image = image.unsqueeze(0).to(device)
            pred_global, pred_local = model(input_image)
            pred_global = pred_global.item()
            pred_local = pred_local.squeeze().cpu().numpy()
            
            # Normalize the quality map for better visualization
            pred_local_norm = (pred_local - pred_local.min()) / (pred_local.max() - pred_local.min() + 1e-8)
            
            # Original image
            img = image.permute(1, 2, 0).cpu().numpy()
            img = np.clip(img * np.array([0.229, 0.224, 0.225]) + 
                          np.array([0.485, 0.456, 0.406]), 0, 1)
            
            # Plotting
            fig, axs = plt.subplots(1, 2, figsize=(10,5))
            axs[0].imshow(img)
            axs[0].set_title(f"Original Image\nGlobal Score: {score:.2f}")
            axs[0].axis('off')
            
            axs[1].imshow(img)
            axs[1].imshow(pred_local_norm, alpha=0.6, cmap='jet')
            axs[1].set_title(f"Predicted Quality Heatmap\nGlobal Score: {pred_global:.2f}")
            axs[1].axis('off')
            
            plt.show()


In [None]:
# Visualize Quality Maps
visualize_quality_maps(model, full_dataset, device, num_samples=3)

# Visualize Heatmaps
visualize_heatmap(model, full_dataset, device, num_samples=3)
