In [1]:
# system level
import os
from os import path
import sys


# deep learning
from scipy.stats import pearsonr, spearmanr
import numpy as np
import torch
from torch import nn
from torchvision import models,transforms
import torch.optim as optim
import wandb
from sklearn.model_selection import GroupKFold

# data 
import matplotlib.pyplot as plt
import pandas as pd
import cv2
from torch.utils.data import DataLoader, TensorDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# local
from nerf_qa.DISTS_pytorch.DISTS_pt import DISTS, prepare_image

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
DATA_DIR = "/home/ccl/Datasets/NeRF-QA"
REF_DIR = path.join(DATA_DIR, "Reference")
SYN_DIR = path.join(DATA_DIR, "NeRF-QA_videos")
SCORE_FILE = path.join(DATA_DIR, "NeRF_VQA_MOS.csv")

In [3]:
import torch
import torch.nn as nn

class VQAModel(nn.Module):
    def __init__(self, dists_mean, dists_std):
        super(VQAModel, self).__init__()
        self.dists_model = DISTS()
        # Initialize mean and std as trainable parameters
        self.dists_mean = nn.Parameter(torch.tensor([dists_mean], dtype=torch.float32))
        self.dists_std = nn.Parameter(torch.tensor([-dists_std], dtype=torch.float32))

    def compute_dists_with_batches(self, ref_batches, dist_batches):
        all_scores = []  # Collect scores from all batches as tensors

        for ref_batch, dist_batch in zip(ref_batches, dist_batches):
            ref_images = ref_batch[0].to(device)  # Assuming ref_batch[0] is the tensor of images
            dist_images = dist_batch[0].to(device)  # Assuming dist_batch[0] is the tensor of images
            scores = self.dists_model(ref_images, dist_images, require_grad=True, batch_average=False)  # Returns a tensor of scores
            
            # Collect scores tensors
            all_scores.append(scores)

        # Concatenate all score tensors into a single tensor
        all_scores_tensor = torch.cat(all_scores, dim=0)

        # Compute the average score across all batches
        average_score = torch.mean(all_scores_tensor) if all_scores_tensor.numel() > 0 else torch.tensor(0.0).to(device)

        return average_score
        
    def forward(self, ref_batches, dist_batches):
        raw_scores = self.compute_dists_with_batches(ref_batches, dist_batches)
        
        # Normalize raw scores using the trainable mean and std
        normalized_scores = (raw_scores - self.dists_mean) / self.dists_std
        return normalized_scores


In [4]:
# Read the CSV file
scores_df = pd.read_csv(SCORE_FILE)

loss_fn = nn.MSELoss()


# Initialize a new run with wandb
wandb.init(project='nerf-qa', config={
    "seed": 42,
    "resize": True,
    "epochs": 200,
    "batch_size": 4,
    "forward_batch_size": 64,
    "optimizer": {
        "type": "adam",
        "lr": 1e-5,
        "eps": 1e-8,
        "beta1": 0.9,
        "beta2": 0.999,
    },
})
config = wandb.config

# Number of splits for GroupKFold
unique_references_count = scores_df['reference_filename'].nunique()
num_folds = min(unique_references_count, 8)

# Example function to load a video and process it frame by frame
def load_video_frames(video_path):
    cap = cv2.VideoCapture(video_path)
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # Convert frame to RGB (from BGR) and then to tensor
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
        frame = transforms.ToPILImage()(frame)
        frame = prepare_image(frame, resize=config.resize).squeeze(0)
        frames.append(frame)
    cap.release()
    return torch.stack(frames)

# Batch creation function
def create_batches(frames, forward_batch_size):
    # Create a dataset and dataloader for efficient batching
    dataset = TensorDataset(frames)
    dataloader = DataLoader(dataset, batch_size=forward_batch_size, shuffle=False)
    return dataloader



# Initialize GroupKFold
gkf = GroupKFold(n_splits=num_folds)

# Extract reference filenames as groups for GroupKFold
groups = scores_df['reference_filename'].values

global_step = 0
plccs = []
srccs = []
rsmes = []

# Group K-Fold Cross-Validation
for fold, (train_idx, val_idx) in enumerate(gkf.split(scores_df, groups=groups), 1):
    print(f"Fold {fold}/{num_folds}")
    
    # Split the data into training and validation sets
    train_df = scores_df.iloc[train_idx]
    val_df = scores_df.iloc[val_idx]

    # Shuffle train_df with random seed
    train_df = train_df.sample(frac=1, random_state=config.seed).reset_index(drop=True)

    # Compute the mean and standard deviation for MOS and DISTS columns
    mos_mean = train_df['MOS'].mean()
    mos_std = train_df['MOS'].std()
    dists_mean = train_df['DISTS'].mean()
    dists_std = train_df['DISTS'].std()
    print(f"Fold Stats: MOS ({mos_mean}, {mos_std}) DISTS ({dists_mean}, {dists_std})")
    
    # Reset model and optimizer for each fold (if you want to start fresh for each fold)
    model = VQAModel(dists_mean=dists_mean, dists_std=dists_std).to(device)
    betas = (config.optimizer['beta1'], config.optimizer['beta2'])
    optimizer = optim.Adam(model.parameters(), lr=config.optimizer['lr'], betas=betas, eps=config.optimizer['eps'])
    
    train_size = train_df.shape[0]
    val_size = val_df.shape[0]

    # Training loop
    for epoch in range(wandb.config.epochs):
        model.train()  # Set model to training mode
        total_loss = 0
        batch_loss = 0
        optimizer.zero_grad()  # Initialize gradients to zero at the start of each epoch

        for index, (i, row) in enumerate(train_df.iterrows(), 1):  # Start index from 1 for easier modulus operation
            # Load frames
            dist_video_path = path.join(SYN_DIR, row['distorted_filename'])
            ref_video_path = path.join(REF_DIR, row['reference_filename'])
            ref = load_video_frames(ref_video_path)
            dist = load_video_frames(dist_video_path)
            ref = create_batches(ref, config.forward_batch_size)
            dist = create_batches(dist, config.forward_batch_size)
            
            # Compute score
            predicted_score = model(ref, dist)
            
            # Normalize scores
            target_score_normalized = torch.tensor((row['MOS'] - mos_mean) / mos_std, device=device, dtype=torch.float32)
            
            # Compute loss
            loss = loss_fn(predicted_score, target_score_normalized)
            
            # Accumulate gradients
            loss.backward()
            total_loss += loss.item()
            batch_loss += loss.item()
            print("normalized scores:", target_score_normalized.item(), predicted_score.item())
            
            if index % config.batch_size == 0 or index == train_size:

                # Scale gradients
                accumulation_steps = ((index-1) % config.batch_size) + 1
                global_step += accumulation_steps
                for param in model.parameters():
                    if param.grad is not None:
                        param.grad /= accumulation_steps
                
                # Update parameters every batch_size steps or on the last iteration
                optimizer.step()
                optimizer.zero_grad()  # Zero the gradients after updating
                average_batch_loss = batch_loss / config.batch_size
                wandb.log({
                    "Train Metrics Dict/batch_loss": average_batch_loss,
                    "Train Metrics Dict/rmse": np.sqrt(average_batch_loss),
                    }, step=global_step)
                batch_loss = 0
        
        # Validation step
        model.eval()  # Set model to evaluation mode
        with torch.no_grad():
            eval_loss = 0
            all_rmse = []
            all_target_scores = []  # List to store all target scores
            all_predicted_scores = []  # List to store all predicted scores

            for index, row in val_df.iterrows():
                # Load frames
                dist_video_path = path.join(SYN_DIR, row['distorted_filename'])
                ref_video_path = path.join(REF_DIR, row['reference_filename'])
                ref = load_video_frames(ref_video_path)
                dist = load_video_frames(dist_video_path)
                ref = create_batches(ref, config.forward_batch_size)
                dist = create_batches(dist, config.forward_batch_size)
                
                # Compute score
                predicted_score = model(ref, dist)
                all_predicted_scores.append(float(predicted_score.item()))
                
                # Normalize scores
                target_score_normalized = torch.tensor((row['MOS'] - mos_mean) / mos_std, device=device, dtype=torch.float32)
                all_target_scores.append(float(target_score_normalized.item()))
                
                # Compute loss
                loss = loss_fn(predicted_score, target_score_normalized)
                eval_loss += loss.item()
                all_rmse.append(float(np.sqrt(loss.item())))

            
            # Convert lists to arrays for correlation computation
            all_target_scores = np.array(all_target_scores)
            all_predicted_scores = np.array(all_predicted_scores)
            
            # Compute PLCC and SRCC
            plcc = pearsonr(all_target_scores, all_predicted_scores)[0]
            srcc = spearmanr(all_target_scores, all_predicted_scores)[0]
            
            # Average loss over validation set
            eval_loss /= len(val_df)
            rsme = np.mean(all_rmse)

            if epoch == wandb.config.epochs-1:
                # last epoch
                plccs.append(float(plcc))
                srccs.append(float(srcc))
                rsmes.append(float(rsme))

            # Log to wandb
            wandb.log({
                "Eval Metrics Dict/batch_loss": eval_loss,
                "Eval Metrics Dict/rmse": rsme,
                "Eval Metrics Dict/rmse_hist": wandb.Histogram(np.array(all_rmse)),
                "Eval Metrics Dict/plcc": plcc,
                "Eval Metrics Dict/srcc": srcc,
            }, step=global_step)

            
        # Logging the average loss
        average_loss = total_loss / len(scores_df)
        print(f"Epoch {epoch+1}, Average Loss: {average_loss}")
        wandb.log({ "Train Metrics Dict/total_loss": average_batch_loss }, step=global_step)

weighted_score = -1.0 * np.mean(rsmes) + 1.0 * np.mean(plccs) + 1.0 * np.mean(srccs)
# Log to wandb
wandb.log({
    "Eval Metrics Dict/weighted_score_cv_mean": weighted_score,
    "Eval Metrics Dict/rmse_cv_mean": np.mean(rsmes),
    "Eval Metrics Dict/rmse_cv_std": np.std(rsmes),
    "Eval Metrics Dict/rmse_cv_hist": wandb.Histogram(np.array(rsmes)),
    "Eval Metrics Dict/plcc_cv_mean": np.mean(plccs),
    "Eval Metrics Dict/plcc_cv_std": np.std(plccs),
    "Eval Metrics Dict/plcc_cv_hist": wandb.Histogram(np.array(plccs)),
    "Eval Metrics Dict/srcc_cv_mean": np.mean(srccs),
    "Eval Metrics Dict/srcc_cv_std": np.std(srccs),
    "Eval Metrics Dict/srcc_cv_hist": wandb.Histogram(np.array(srccs)),
}, step=global_step)



Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mkobejean[0m ([33maizu-nerf[0m). Use [1m`wandb login --relogin`[0m to force relogin


Fold 1/8
Fold Stats: MOS (2.9460675, 1.0108160114069558) DISTS (0.1080630695174138, 0.05652211786027942)


  return F.mse_loss(input, target, reduction=self.reduction)


normalized scores: -0.6653708219528198 -0.09004860371351242
normalized scores: 0.5071471929550171 0.08984322845935822
normalized scores: 0.8732870221138 0.21695443987846375
normalized scores: 0.52683424949646 0.03394448012113571
normalized scores: 1.387524962425232 1.0153530836105347
normalized scores: 0.4404684007167816 0.9531010985374451
normalized scores: -1.353824496269226 -1.8891092538833618
normalized scores: 0.9242359399795532 -0.3183012306690216
normalized scores: -1.558510661125183 -1.6944804191589355
normalized scores: 0.877738893032074 0.8796359896659851
normalized scores: 0.37487781047821045 0.4416847229003906
normalized scores: 1.4264044761657715 1.3944883346557617
normalized scores: -1.4700672626495361 -0.18274003267288208
normalized scores: -0.8767842054367065 -0.19497105479240417
normalized scores: -0.5009492039680481 -0.7245040535926819
normalized scores: 0.6589057445526123 1.318225622177124
normalized scores: 0.3492549657821655 0.53061842918396
normalized scores: 0.77

In [None]:
wandb.log({
    "Eval Metrics Dict/rmse_cv_mean": np.mean(rsmes),
    "Eval Metrics Dict/plcc_cv_mean": np.mean(plccs),
    "Eval Metrics Dict/srcc_cv_mean": np.mean(srccs),
}, step=global_step)

wandb.finish()



0,1
Eval Metrics Dict/batch_loss,▅█▇▁
Eval Metrics Dict/plcc,█▇▃▁
Eval Metrics Dict/plcc_cv_mean,▁
Eval Metrics Dict/plcc_cv_std,▁
Eval Metrics Dict/rmse,▆█▅▁
Eval Metrics Dict/rmse_cv_mean,▁
Eval Metrics Dict/rmse_cv_std,▁
Eval Metrics Dict/srcc,█▁█▆
Eval Metrics Dict/srcc_cv_mean,▁
Eval Metrics Dict/srcc_cv_std,▁

0,1
Eval Metrics Dict/batch_loss,0.63753
Eval Metrics Dict/plcc,0.81039
Eval Metrics Dict/plcc_cv_mean,0.84049
Eval Metrics Dict/plcc_cv_std,0.0301
Eval Metrics Dict/rmse,0.60563
Eval Metrics Dict/rmse_cv_mean,0.74944
Eval Metrics Dict/rmse_cv_std,0.14381
Eval Metrics Dict/srcc,0.8087
Eval Metrics Dict/srcc_cv_mean,0.79696
Eval Metrics Dict/srcc_cv_std,0.01174
