In [None]:
!pip install osfclient --quiet
!pip install plotnine --quiet
!pip install git+https://github.com/jspsych/eyetracking-utils.git --quiet

: 

In [None]:
import os
from google.colab import userdata
import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
import argparse

import osfclient

import et_util.dataset_utils as dataset_utils

In [None]:
# Copy the OSF credentials stored as Colab secrets into environment variables,
# in order: OSF_TOKEN first, then OSF_USERNAME.
for env_var, secret_name in (('OSF_TOKEN', 'osftoken'), ('OSF_USERNAME', 'osfusername')):
    os.environ[env_var] = userdata.get(secret_name)

In [None]:
# Download the packed single-eye TFRecords archive from OSF project uf2sh.
# NOTE(review): presumably authenticates via the OSF_* environment variables set
# in the previous cell; on a re-run `osf fetch` may refuse to overwrite the
# existing local file — TODO confirm whether a force/update flag is wanted.
!osf -p uf2sh fetch single_eye_tfrecords.tar.gz

In [None]:
!mkdir single_eye_tfrecords
!tar -xf single_eye_tfrecords.tar.gz -C single_eye_tfrecords

In [None]:

# Path to directory containing TFRecord files
TFRECORDS_PATH = 'single_eye_tfrecords'

# Path to save CSV output
OUTPUT_CSV = 'eye_image_quality_metrics.csv'

# Maximum number of subjects to process (None for all)
MAX_SUBJECTS = 10  # Change to None to process all subjects

# OSF token for uploading results (None skips the upload). Read from the
# environment (populated from Colab secrets earlier in this notebook) so the
# token never has to be pasted into the notebook source.
OSF_TOKEN = os.environ.get('OSF_TOKEN')

# OSF project ID for uploading results (None skips the upload)
OSF_PROJECT_ID = None  # Set to your project ID if you want to upload

In [None]:
def parse(element):
    """Decode one serialized single-eye TFRecord example.

    Args:
        element: scalar string tensor holding a serialized ``tf.train.Example``.

    Returns:
        tuple: ``(image, phase, coords, subject_id)`` where ``image`` is the
        ``uint8`` tensor deserialized from the ``eye_img`` feature, ``phase``
        and ``subject_id`` are int64 scalars, and ``coords`` is the list
        ``[x, y]`` of float32 scalars.
    """
    # Every feature stays declared even though only some are returned: a
    # FixedLenFeature without a default is required, so this keeps the same
    # parsing contract (records missing any of them fail to parse).
    data_structure = {
        'landmarks': tf.io.FixedLenFeature([], tf.string),
        'img_width': tf.io.FixedLenFeature([], tf.int64),
        'img_height': tf.io.FixedLenFeature([], tf.int64),
        'x': tf.io.FixedLenFeature([], tf.float32),
        'y': tf.io.FixedLenFeature([], tf.float32),
        'eye_img': tf.io.FixedLenFeature([], tf.string),
        'phase': tf.io.FixedLenFeature([], tf.int64),
        'subject_id': tf.io.FixedLenFeature([], tf.int64),
    }

    content = tf.io.parse_single_example(element, data_structure)

    # eye_img stores a serialized tensor (presumably written with
    # tf.io.serialize_tensor), not an encoded JPEG/PNG.
    image = tf.io.parse_tensor(content['eye_img'], out_type=tf.uint8)
    coords = [content['x'], content['y']]

    return image, content['phase'], coords, content['subject_id']

In [None]:
# Load every record into one split (train_split=1.0; val/test get 0.0).
# NOTE(review): despite its name, `test_data` receives the *first* returned
# split, presumably the train split in (train, val, test) order — i.e. all of
# the data. group_function keys records by subject_id.
test_data, _, _ = dataset_utils.process_tfr_to_tfds(
    # Reuse the configured path. os.path.join(..., '') keeps the trailing
    # separator the original literal had, in case the helper concatenates
    # directory and file names (TODO confirm in dataset_utils).
    os.path.join(TFRECORDS_PATH, ''),
    parse,
    train_split=1.0,
    val_split=0.0,
    test_split=0.0,
    random_seed=12604,
    group_function=lambda img, phase, coords, subject_id: subject_id
)

In [None]:
class ImageQualityMetrics:
    """Per-image quality metrics for eye crops.

    Every metric is a static method taking a single image: a NumPy array,
    an array-like, or an eager tensor exposing ``.numpy()`` (e.g. the
    ``tf.Tensor`` images yielded by a tf.data iterator). A 2-D image is
    treated as single-channel; a 3-channel image is converted to grayscale
    (as RGB) for the metrics that need a single channel.
    """

    @staticmethod
    def to_numpy(img_tensor):
        """Convert a tensor-like image to a NumPy array with a channel axis.

        Args:
            img_tensor: NumPy array, array-like, or an object with a
                ``.numpy()`` method (e.g. an eager ``tf.Tensor``).

        Returns:
            np.ndarray: the image; a 2-D input gains a trailing channel axis,
            giving shape (H, W, 1).
        """
        # Duck-type on .numpy() instead of isinstance(..., tf.Tensor): any
        # eager tensor converts, and the metrics need no TensorFlow import.
        if not isinstance(img_tensor, np.ndarray) and hasattr(img_tensor, 'numpy'):
            img = img_tensor.numpy()
        else:
            img = np.asarray(img_tensor)
        if img.ndim == 2:
            img = np.expand_dims(img, axis=-1)
        return img

    @staticmethod
    def _to_gray(img):
        """Return ``img`` as a 2-D single-channel array.

        3-channel images are converted with ``cv2.COLOR_RGB2GRAY``; any other
        image just has its singleton axes squeezed away.
        """
        img = ImageQualityMetrics.to_numpy(img)
        if img.shape[-1] == 3:
            return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        return img.squeeze()

    @staticmethod
    def brightness(img):
        """Mean pixel value over all pixels and channels."""
        img = ImageQualityMetrics.to_numpy(img)
        return float(np.mean(img))

    @staticmethod
    def contrast(img):
        """Global contrast: standard deviation of all pixel values."""
        img = ImageQualityMetrics.to_numpy(img)
        return float(np.std(img))

    @staticmethod
    def entropy(img):
        """Shannon entropy, in bits, of the 8-bit intensity histogram.

        Values are binned into 256 unit-width bins over [0, 256) and the
        counts normalized to probabilities. A constant image therefore scores
        0, and an image split evenly between two values scores 1.
        Returns 0.0 when no pixel falls inside the histogram range.
        """
        img = ImageQualityMetrics.to_numpy(img)
        # Use raw counts, not density=True: densities are not probabilities
        # (bin width != 1), which made the entropy of a flat image negative.
        counts, _ = np.histogram(img.ravel(), bins=256, range=(0, 256))
        total = counts.sum()
        if total == 0:
            return 0.0
        probs = counts[counts > 0] / total
        return float(-np.sum(probs * np.log2(probs)))

    @staticmethod
    def laplacian_variance(img):
        """Variance of the Laplacian, a focus/sharpness measure (higher = sharper)."""
        img_gray = ImageQualityMetrics._to_gray(img)
        return float(cv2.Laplacian(img_gray, cv2.CV_64F).var())

    @staticmethod
    def gradient_magnitude(img):
        """Mean Sobel gradient magnitude (edge strength)."""
        img_gray = ImageQualityMetrics._to_gray(img)
        sobelx = cv2.Sobel(img_gray, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(img_gray, cv2.CV_64F, 0, 1, ksize=3)
        return float(np.mean(np.sqrt(sobelx**2 + sobely**2)))

    @staticmethod
    def blur_detection(img):
        """Blur measure from the spread of the Laplacian (higher = less blur).

        Computed as the squared standard deviation of the Laplacian, i.e. the
        same quantity as ``laplacian_variance`` up to float rounding. This is
        *not* the Just Noticeable Blur (JNB) metric. It is retained as its own
        metric so the output columns stay unchanged.
        """
        img_gray = ImageQualityMetrics._to_gray(img)
        laplacian = cv2.Laplacian(img_gray, cv2.CV_64F)
        _, std = cv2.meanStdDev(laplacian)
        return float(std[0][0] ** 2)

    @staticmethod
    def noise_estimation(img):
        """Mean absolute difference between the image and its 3x3 median filter."""
        img_gray = ImageQualityMetrics._to_gray(img)
        median_filtered = cv2.medianBlur(img_gray.astype(np.uint8), 3)
        # Subtract in float: subtracting two uint8 arrays wraps around
        # (e.g. 10 - 12 == 254), which grossly inflated the estimate.
        noise = np.abs(img_gray.astype(np.float64) - median_filtered.astype(np.float64))
        return float(np.mean(noise))

    @staticmethod
    def local_contrast(img, region_size=5):
        """Std of the (2*region_size)-wide square patch at the image centre.

        Assumes the eye region sits near the centre of the crop. The patch is
        clipped to the image bounds.
        """
        img_gray = ImageQualityMetrics._to_gray(img)
        h, w = img_gray.shape
        center_y, center_x = h // 2, w // 2
        y1, y2 = max(0, center_y - region_size), min(h, center_y + region_size)
        x1, x2 = max(0, center_x - region_size), min(w, center_x + region_size)
        return float(np.std(img_gray[y1:y2, x1:x2]))

    @staticmethod
    def dark_channel_prior(img):
        """Mean of a 3x3 minimum filter over the grayscale image.

        This is a grayscale stand-in for the dark channel prior used in haze
        detection. Borders are edge-replicated, so each output pixel is the
        minimum of its in-image 3x3 neighbourhood.
        """
        img_gray = ImageQualityMetrics._to_gray(img)
        kernel_size = 3
        padded = np.pad(img_gray, kernel_size // 2, mode='edge')
        # Vectorized min filter: one (kernel_size x kernel_size) window per
        # pixel, replacing the per-pixel Python double loop.
        windows = np.lib.stride_tricks.sliding_window_view(padded, (kernel_size, kernel_size))
        dark_channel = windows.min(axis=(-2, -1))
        return float(np.mean(dark_channel))

    @staticmethod
    def specular_reflection_count(img, threshold=240, min_area=5):
        """Count bright blobs (pixels > ``threshold``) of at least ``min_area`` pixels.

        Args:
            img: image to analyse.
            threshold: intensity above which a pixel counts as bright.
            min_area: smallest connected blob (8-connectivity) that is counted;
                smaller blobs are treated as noise.

        Returns:
            int: number of qualifying blobs.
        """
        img_gray = ImageQualityMetrics._to_gray(img)
        bright_spots = (img_gray > threshold).astype(np.uint8)
        num_labels, _, stats, _ = cv2.connectedComponentsWithStats(bright_spots, connectivity=8)
        # Label 0 is the background component; skip it.
        areas = stats[1:num_labels, cv2.CC_STAT_AREA]
        return int(np.count_nonzero(areas >= min_area))

    @staticmethod
    def pupil_detection_score(img):
        """Radius of the first circle found by a Hough circle transform.

        This is a rough proxy for pupil detectability. The value is a radius in
        pixels, not a calibrated confidence. Returns 0.0 when no circle is
        found or when OpenCV rejects the input.
        """
        img_gray = ImageQualityMetrics._to_gray(img)
        # Smooth first to reduce noise before circle detection.
        blurred = cv2.GaussianBlur(img_gray, (5, 5), 0)
        try:
            # Heuristic parameters; they may need tuning for your crops.
            circles = cv2.HoughCircles(
                blurred,
                cv2.HOUGH_GRADIENT,
                dp=1,
                minDist=20,
                param1=50,
                param2=30,
                minRadius=5,
                maxRadius=25,
            )
        except cv2.error:
            # e.g. an unsupported input dtype: score as "no pupil" instead of
            # aborting the whole run, while letting unrelated errors surface.
            return 0.0
        if circles is None or len(circles[0]) == 0:
            return 0.0
        return float(circles[0][0][2])

    @staticmethod
    def compute_all_metrics(img):
        """Compute every metric for ``img`` and return them as a name -> value dict."""
        # Convert once up front; each metric's own to_numpy() is then cheap.
        img = ImageQualityMetrics.to_numpy(img)
        return {
            'brightness': ImageQualityMetrics.brightness(img),
            'contrast': ImageQualityMetrics.contrast(img),
            'entropy': ImageQualityMetrics.entropy(img),
            'laplacian_variance': ImageQualityMetrics.laplacian_variance(img),
            'gradient_magnitude': ImageQualityMetrics.gradient_magnitude(img),
            'blur_detection': ImageQualityMetrics.blur_detection(img),
            'noise_estimation': ImageQualityMetrics.noise_estimation(img),
            'local_contrast': ImageQualityMetrics.local_contrast(img),
            'dark_channel_prior': ImageQualityMetrics.dark_channel_prior(img),
            'specular_reflection_count': ImageQualityMetrics.specular_reflection_count(img),
            'pupil_detection_score': ImageQualityMetrics.pupil_detection_score(img),
        }

In [None]:
def process_subject_images(dataset, subject_id):
    """Compute quality metrics for every image belonging to one subject.

    Args:
        dataset: ``tf.data.Dataset`` whose elements unpack as
            ``(img, phase, coords, subject_id)``, with ``coords`` indexable as
            ``[x, y]`` (assumes unbatched elements — TODO confirm).
        subject_id: subject to select, as a scalar tensor, NumPy integer or
            Python int.

    Returns:
        list[dict]: one dict per image with every ``ImageQualityMetrics``
        value plus ``subject_id``, ``image_index`` (position among this
        subject's images), ``phase``, ``coord_x`` and ``coord_y``.

    Note:
        ``dataset.filter`` scans the whole dataset, so calling this once per
        subject costs one full pass over the data per subject.
    """
    subject_id = tf.convert_to_tensor(subject_id)
    # Constant for the whole subject; convert it once, outside the loop.
    subject_id_value = subject_id.numpy()

    def _belongs_to_subject(img, phase, coords, sid):
        # Cast so that an int32 query (e.g. built from a Python int) can be
        # compared with the dataset's IDs (``parse`` decodes them as int64);
        # tf.equal rejects mixed dtypes.
        return tf.equal(sid, tf.cast(subject_id, sid.dtype))

    results = []
    for i, (img, phase, coords, _) in enumerate(dataset.filter(_belongs_to_subject)):
        metrics = ImageQualityMetrics.compute_all_metrics(img.numpy())
        metrics.update({
            'subject_id': subject_id_value,
            'image_index': i,
            'phase': phase.numpy(),
            'coord_x': coords[0].numpy(),
            'coord_y': coords[1].numpy(),
        })
        results.append(metrics)

    return results

def process_all_subjects(dataset, subject_ids, max_subjects=None):
    """Compute quality metrics for many subjects and collect them in a DataFrame.

    Args:
        dataset: ``tf.data.Dataset`` passed through to ``process_subject_images``.
        subject_ids: iterable of subject IDs, processed in the given order.
        max_subjects: if not None, only the first ``max_subjects`` IDs are
            processed (0 processes none).

    Returns:
        pd.DataFrame: one row per image, with the columns produced by
        ``process_subject_images``.
    """
    # Materialize so any iterable (set, generator, array) can be sliced.
    subject_ids = list(subject_ids)

    # `is not None` so that max_subjects=0 means "none" rather than "all".
    if max_subjects is not None:
        print(f"Limiting processing to {max_subjects} subjects")
        subject_ids = subject_ids[:max_subjects]

    all_results = []
    report_every = 1000
    for subject_id in tqdm(subject_ids, desc="Processing subjects"):
        previous_total = len(all_results)
        all_results.extend(process_subject_images(dataset, tf.constant(subject_id)))

        # A subject adds many images at once, so an exact `% 1000 == 0` check
        # almost never fires. Report whenever the total crosses a multiple.
        if len(all_results) // report_every > previous_total // report_every:
            print(f"Processed {len(all_results)} images so far...")

    return pd.DataFrame(all_results)

def save_and_upload_results(df, filename, osf_token=None, osf_project_id=None):
    """Write ``df`` to CSV and optionally upload the file to OSF storage.

    Args:
        df: DataFrame to save (written without its index).
        filename: local CSV path; also used as the remote path on OSF.
        osf_token: OSF personal access token; the upload is skipped if falsy.
        osf_project_id: OSF project ID; the upload is skipped if falsy.

    Returns:
        str: ``filename``.

    Upload problems (osfclient missing, request errors) are printed rather
    than raised, so the locally saved CSV is never lost to a failed upload.
    """
    df.to_csv(filename, index=False)
    print(f"Saved results to {filename}")

    # Imported lazily so the upload stays optional: without osfclient the
    # upload is skipped instead of the call failing.
    try:
        from osfclient import OSF as osf_client_cls
    except ImportError:
        osf_client_cls = None

    if osf_client_cls is None:
        print("OSF upload skipped: osfclient not installed")
    elif not (osf_token and osf_project_id):
        print("OSF upload skipped: credentials not provided")
    else:
        try:
            project = osf_client_cls(token=osf_token).project(osf_project_id)
            with open(filename, 'rb') as f:
                # update=True lets an existing remote file of the same name
                # be replaced instead of raising.
                project.storage().create_file(filename, f, update=True)
            print(f"Uploaded {filename} to OSF project {osf_project_id}")
        except Exception as e:
            print(f"Error uploading to OSF: {e}")

    return filename

def visualize_metrics(df, output_file="metric_correlations.png"):
    """Plot and save the pairwise correlation matrix of the numeric metric columns.

    Identifier columns (``subject_id``, ``image_index``, ``phase``) are
    excluded when present. The figure is saved to ``output_file`` and left open
    so it still renders inline in the notebook.

    Args:
        df: DataFrame of per-image metrics (e.g. from ``process_all_subjects``).
        output_file: path of the PNG to write.
    """
    metadata_cols = ['subject_id', 'image_index', 'phase']
    correlation_matrix = (
        df.select_dtypes(include=[np.number])
        .drop(columns=metadata_cols, errors='ignore')
        .corr()
    )
    labels = list(correlation_matrix.columns)
    ticks = range(len(labels))

    # Explicit figure/axes: plt.figure() followed by matshow(fignum=1) drew
    # into whatever figure 1 already was (a stale one on re-runs).
    fig, ax = plt.subplots(figsize=(12, 10))
    # Fix the colour scale to the full correlation range so plots are comparable.
    image = ax.matshow(correlation_matrix.to_numpy(), vmin=-1, vmax=1)
    fig.colorbar(image, ax=ax)
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels, rotation=90)
    ax.set_yticks(ticks)
    ax.set_yticklabels(labels)
    ax.set_title('Correlation between Image Quality Metrics')
    fig.tight_layout()
    fig.savefig(output_file)
    print(f"Saved correlation matrix visualization to {output_file}")

In [None]:
# All records were loaded into `test_data` above (train_split=1.0), so use it
# directly; `process_tfrecords` / `get_subject_id_list` are not defined anywhere.
dataset = test_data

# Unique subject IDs, sorted. Kept as int64 to match the dtype `parse` gives
# subject_id, so the per-subject tf.equal filter compares like with like.
subject_ids = np.unique(np.fromiter(
    dataset.map(lambda img, phase, coords, sid: sid).as_numpy_iterator(),
    dtype=np.int64,
))

# Process images for all subjects
image_quality_df = process_all_subjects(dataset, subject_ids, max_subjects=MAX_SUBJECTS)

# Display summary statistics
print("\nSummary statistics of image quality metrics:")
display(image_quality_df.describe())

# Save and upload results
saved_file = save_and_upload_results(
    image_quality_df,
    filename=OUTPUT_CSV,
    osf_token=OSF_TOKEN,
    osf_project_id=OSF_PROJECT_ID
)

# Visualize relationships between metrics
visualize_metrics(image_quality_df)