<a href="https://www.kaggle.com/code/kkm121121/train-clearsight?scriptVersionId=299311530" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
!pip install stable-baselines3[extra] ultralytics gymnasium opencv-python-headless

In [None]:
"""
ClearSight-RL: Oracle AI Training Script (GPU Optimized)

This script trains an AI agent to clean up bad-weather photos (like fog or low light) so that computer vision systems can see objects better. 
It works by comparing a foggy image to a clear "answer key" image. The agent tries out different photo filters (like brightness, contrast, and sharpening). 
If its edits help a pre-trained YOLO object detector find the exact same objects as the clear image, the agent gets a reward. 
We are training it for a very long time (150,000 steps) so it learns to be brave and try complex filter combinations instead of giving up early.
"""

import os
import cv2
import numpy as np
import scipy.stats
import torch
import time
import datetime
import gymnasium as gym
from gymnasium import spaces
from ultralytics import YOLO
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback, BaseCallback, CallbackList
from stable_baselines3.common.vec_env import SubprocVecEnv

# Torch device string shared by the YOLO detector and the PPO policy:
# use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

class ETACallback(BaseCallback):
    """
    Custom Kaggle-safe callback to print live ETA and Progress.

    A progress line is printed each time ``num_timesteps`` reaches the next
    multiple of ``print_freq``. Plain ``print`` is used, so no progress-bar
    widget is required.

    Args:
        total_timesteps: Training budget used to compute percentage and ETA.
        print_freq: Minimum number of timesteps between two progress lines.
        verbose: Forwarded to ``BaseCallback``.
    """
    def __init__(self, total_timesteps, print_freq=2048, verbose=0):
        super().__init__(verbose)
        self.total_timesteps = total_timesteps
        self.print_freq = print_freq
        self.start_time = None
        # Timestep count at (or after) which the next line is printed.
        self._next_print = max(1, int(print_freq))

    def _next_threshold(self) -> int:
        """Return the first multiple of ``print_freq`` above ``num_timesteps``."""
        freq = max(1, int(self.print_freq))
        return (self.num_timesteps // freq + 1) * freq

    def _on_training_start(self) -> None:
        self.start_time = time.time()
        # Re-align in case training resumes with a non-zero step counter.
        self._next_print = self._next_threshold()

    def _on_step(self) -> bool:
        """Print a progress line when the next threshold is reached.

        Always returns True so training is never interrupted by this callback.
        """
        # num_timesteps grows by the number of parallel envs per call, so it
        # may jump over an exact multiple of print_freq; compare against a
        # threshold instead of testing `% print_freq == 0`.
        if self.num_timesteps < self._next_print:
            return True
        self._next_print = self._next_threshold()

        if self.start_time is None:
            self.start_time = time.time()

        elapsed_time = time.time() - self.start_time
        fps = self.num_timesteps / elapsed_time if elapsed_time > 0 else 0.0

        # The step counter can end up above total_timesteps; clamp so the
        # ETA never goes negative and the percentage never exceeds 100.
        remaining_steps = max(0, self.total_timesteps - self.num_timesteps)
        eta_seconds = remaining_steps / fps if fps > 0 else 0.0

        elapsed_str = str(datetime.timedelta(seconds=int(elapsed_time)))
        eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))

        if self.total_timesteps > 0:
            percentage = min(100.0, (self.num_timesteps / self.total_timesteps) * 100)
        else:
            percentage = 100.0
        print(f"[{percentage:.1f}%] Steps: {self.num_timesteps}/{self.total_timesteps} | "
              f"Speed: {int(fps)} it/s | Elapsed: {elapsed_str} | ETA: {eta_str}")
        return True

def calculate_iou(box1, box2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Both boxes are indexable as ``(x_min, y_min, x_max, y_max)``. The result
    is 0 when the union area is not positive.
    """
    # Corners of the overlapping rectangle (may be inverted if disjoint).
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])

    # Clamp each side at zero so disjoint boxes yield an empty overlap.
    overlap_w = max(0, right - left)
    overlap_h = max(0, bottom - top)
    overlap = overlap_w * overlap_h

    first_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    second_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    combined = first_area + second_area - overlap

    if combined > 0:
        return overlap / combined
    return 0

class ClearSightOracleEnv(gym.Env):
    """Gymnasium environment that rewards image edits which help YOLO.

    Each episode samples one (clear, hazy) image pair. YOLO detections on the
    clear image are the "oracle" boxes; the agent edits the hazy image and is
    rewarded when detections on the edited image match the oracle boxes.

    Actions (``Discrete(6)``):
        0: stop the episode
        1: ``convertScaleAbs`` with alpha=1.1, beta=+15 (brighter, more gain)
        2: ``convertScaleAbs`` with alpha=0.9, beta=-15 (darker, less gain)
        3: CLAHE on the L channel of the LAB image
        4: 3x3 sharpening kernel
        5: bilateral filter

    Observation (4 float32 values computed on the grayscale image):
        mean brightness, standard deviation, histogram entropy and variance
        of the Laplacian.

    Args:
        clear_dir: Folder with the clear reference images.
        hazy_dir: Folder with the degraded images.
        max_steps: Number of steps after which the episode is truncated.
    """

    # File extensions treated as images when scanning the dataset folders.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
    # Minimum confidence for a detection on the edited image to be scored.
    DETECTION_CONF_THRESHOLD = 0.25
    # Minimum confidence for a detection on the clear image to be an oracle box.
    ORACLE_CONF_THRESHOLD = 0.3
    # Minimum IoU with a same-class oracle box for a detection to count as a match.
    IOU_MATCH_THRESHOLD = 0.4
    # Unmatched detections cost their confidence times this factor.
    FALSE_POSITIVE_PENALTY = 1.5
    # Relative improvement over the starting score required for the stop bonus.
    IMPROVEMENT_MARGIN = 0.2
    # Bonus granted when the agent stops after a sufficient improvement.
    STOP_BONUS = 3.0
    # Flat cost charged for every filter application.
    STEP_PENALTY = 0.1

    def __init__(self, clear_dir=None, hazy_dir=None, max_steps=5):
        super().__init__()
        self.clear_dir = clear_dir
        self.hazy_dir = hazy_dir
        self.image_pairs = []

        if self.clear_dir and self.hazy_dir and os.path.exists(self.clear_dir) and os.path.exists(self.hazy_dir):
            self._pair_images()

        if not self.image_pairs:
            print("Warning: No paired images found. Generating synthetic dataset.")
            self._generate_synthetic_fallback()

        print(f"Loaded {len(self.image_pairs)} paired images.")
        # The detector is created on first use (see _ensure_yolo).
        self.yolo = None

        self.max_steps = max_steps
        self.action_space = spaces.Discrete(6)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32)
        self.current_image = None
        self.oracle_boxes = []
        self.current_step = 0
        self.base_oracle_score = 0.0
        self.current_oracle_score = 0.0

    def _pair_images(self):
        """Fill ``self.image_pairs`` by matching hazy files to clear files.

        A hazy file matches a clear file when their names without extension
        are equal, or otherwise when the part of the hazy name before the
        first underscore equals the clear name.
        """
        # Sort the listings: os.listdir order is arbitrary, and a stable order
        # is needed for a seeded reset() to pick the same image every run.
        clear_files = sorted(
            f for f in os.listdir(self.clear_dir) if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        clear_map = {os.path.splitext(f)[0]: f for f in clear_files}
        hazy_files = sorted(
            f for f in os.listdir(self.hazy_dir) if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )

        for hazy_f in hazy_files:
            base_name = os.path.splitext(hazy_f)[0]
            # Exact name first, then the prefix before the first underscore.
            for key in (base_name, base_name.split('_')[0]):
                if key in clear_map:
                    self.image_pairs.append({
                        'clear': os.path.join(self.clear_dir, clear_map[key]),
                        'hazy': os.path.join(self.hazy_dir, hazy_f)
                    })
                    break

    def _generate_synthetic_fallback(self):
        """Write one synthetic (clear, hazy) pair under /tmp and register it.

        The clear image is a light-gray 480x640 canvas with a filled red
        circle; the hazy image blends it 40/60 with pure white.
        """
        os.makedirs('/tmp/clearsight_dummy/clear', exist_ok=True)
        os.makedirs('/tmp/clearsight_dummy/hazy', exist_ok=True)
        img = np.ones((480, 640, 3), dtype=np.uint8) * 200
        cv2.circle(img, (320, 240), 50, (0, 0, 255), -1)
        clear_path = '/tmp/clearsight_dummy/clear/fallback_001.jpg'
        cv2.imwrite(clear_path, img)
        hazy_img = cv2.addWeighted(img, 0.4, np.ones_like(img)*255, 0.6, 0)
        hazy_path = '/tmp/clearsight_dummy/hazy/fallback_001.jpg'
        cv2.imwrite(hazy_path, hazy_img)
        self.image_pairs.append({'clear': clear_path, 'hazy': hazy_path})

    @staticmethod
    def _read_image(path):
        """Load an image with OpenCV, raising if it cannot be decoded."""
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of deeper in the pipeline.
        if img is None:
            raise RuntimeError(f"Could not read image: {path}")
        return img

    def _ensure_yolo(self):
        """Create the YOLO detector on first use and return it."""
        if getattr(self, 'yolo', None) is None:
            self.yolo = YOLO('yolov8n.pt')
            self.yolo.to(device)
        return self.yolo

    def _extract_features(self, img):
        """Return the 4-value float32 observation for a BGR image.

        The values are, in order: mean gray level, gray-level standard
        deviation, entropy of the 256-bin gray histogram and variance of the
        Laplacian.
        """
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        mean_brightness = np.mean(gray)
        contrast = np.std(gray)
        hist = np.histogram(gray, bins=256, range=(0, 256))[0]
        hist = hist / (hist.sum() + 1e-5)
        # The small offset keeps empty bins from being exactly zero.
        entropy = scipy.stats.entropy(hist + 1e-5)
        blur = cv2.Laplacian(gray, cv2.CV_64F).var()
        return np.array([mean_brightness, contrast, entropy, blur], dtype=np.float32)

    def _calculate_oracle_reward(self, img):
        """Score YOLO detections on ``img`` against ``self.oracle_boxes``.

        Each detection above the confidence threshold is greedily matched to
        the unmatched oracle box of the same class with the highest IoU. A
        match above the IoU threshold adds the detection confidence; any other
        detection subtracts its confidence times the false-positive penalty.

        Returns:
            The summed score as a float.
        """
        yolo = self._ensure_yolo()
        results = yolo(img, verbose=False, device=device)[0]
        curr_boxes = [b for b in results.boxes if b.conf[0].item() > self.DETECTION_CONF_THRESHOLD]

        # Convert the oracle boxes once per call rather than once per
        # (detection, oracle) pair inside the matching loop.
        oracle = [
            (int(o_box.cls[0].item()), o_box.xyxy[0].cpu().numpy())
            for o_box in self.oracle_boxes
        ]

        reward = 0.0
        matched_oracle_idx = set()

        for c_box in curr_boxes:
            c_conf = c_box.conf[0].item()
            c_cls = int(c_box.cls[0].item())
            c_xyxy = c_box.xyxy[0].cpu().numpy()
            best_iou = 0
            best_idx = -1

            for i, (o_cls, o_xyxy) in enumerate(oracle):
                # Each oracle box can be claimed by one detection only.
                if i in matched_oracle_idx or o_cls != c_cls:
                    continue
                iou = calculate_iou(c_xyxy, o_xyxy)
                if iou > best_iou:
                    best_iou = iou
                    best_idx = i

            if best_iou > self.IOU_MATCH_THRESHOLD:
                reward += c_conf
                matched_oracle_idx.add(best_idx)
            else:
                reward -= (c_conf * self.FALSE_POSITIVE_PENALTY)
        return reward

    def _apply_action(self, img, action):
        """Return ``img`` after applying the filter selected by ``action``.

        Any action outside 1-5 (including 0) returns the image unchanged.
        """
        if action == 1: return cv2.convertScaleAbs(img, alpha=1.1, beta=15)
        elif action == 2: return cv2.convertScaleAbs(img, alpha=0.9, beta=-15)
        elif action == 3:
            # Equalize only the lightness channel so colors are preserved.
            lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
            l, a, b = cv2.split(lab)
            cl = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(l)
            return cv2.cvtColor(cv2.merge((cl, a, b)), cv2.COLOR_LAB2BGR)
        elif action == 4:
            kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
            return cv2.filter2D(img, -1, kernel)
        elif action == 5:
            return cv2.bilateralFilter(img, 9, 75, 75)
        return img

    def reset(self, seed=None, options=None):
        """Start a new episode on a randomly sampled image pair.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.

        Raises:
            RuntimeError: If a sampled image file cannot be read.
        """
        super().reset(seed=seed)
        self.current_step = 0
        yolo = self._ensure_yolo()

        if self.image_pairs:
            idx = self.np_random.integers(0, len(self.image_pairs))
            pair = self.image_pairs[idx]
            clear_img = self._read_image(pair['clear'])
            self.current_image = self._read_image(pair['hazy'])
            oracle_results = yolo(clear_img, verbose=False, device=device)[0]
            self.oracle_boxes = [
                b for b in oracle_results.boxes if b.conf[0].item() > self.ORACLE_CONF_THRESHOLD
            ]
        else:
            # Use the environment's seeded generator (not the global
            # np.random) so this branch is reproducible under `seed` too.
            self.current_image = self.np_random.integers(0, 255, (480, 640, 3), dtype=np.uint8)
            self.oracle_boxes = []

        self.base_oracle_score = self._calculate_oracle_reward(self.current_image)
        self.current_oracle_score = self.base_oracle_score
        return self._extract_features(self.current_image), {}

    def step(self, action):
        """Apply one action and return the Gymnasium 5-tuple.

        Action 0 terminates the episode and grants ``STOP_BONUS`` if the
        current score beats the starting score by ``IMPROVEMENT_MARGIN``.
        Any other action edits the image and yields the change in oracle
        score minus ``STEP_PENALTY``. The episode is truncated once
        ``max_steps`` steps have been taken.
        """
        self.current_step += 1
        terminated = False
        truncated = False
        reward = 0.0

        if action == 0:
            terminated = True
            # The margin is relative to the magnitude of the baseline. A plain
            # `base * 1.2` lowers the bar when the baseline is negative, which
            # would pay the bonus for stopping without any improvement.
            required_score = self.base_oracle_score + self.IMPROVEMENT_MARGIN * abs(self.base_oracle_score)
            if self.current_oracle_score > required_score:
                reward += self.STOP_BONUS
        else:
            self.current_image = self._apply_action(self.current_image, action)
            new_oracle_score = self._calculate_oracle_reward(self.current_image)
            delta_score = new_oracle_score - self.current_oracle_score
            reward += delta_score
            self.current_oracle_score = new_oracle_score
            reward -= self.STEP_PENALTY

        if self.current_step >= self.max_steps:
            truncated = True

        info = {"step": self.current_step, "oracle_score": self.current_oracle_score}
        return self._extract_features(self.current_image), reward, terminated, truncated, info

if __name__ == "__main__":
    # Dataset folders with the clear reference images and their hazy versions.
    CLEAR_DIR = '/kaggle/input/datasets/brunobelloni/outdoor-training-set-ots-reside/clear'
    HAZY_DIR = '/kaggle/input/datasets/brunobelloni/outdoor-training-set-ots-reside/hazy'

    print(f"Hardware detected: {device.upper()}")
    print("Initializing multi-processing environment...")

    def make_env():
        """Environment factory handed to make_vec_env."""
        return ClearSightOracleEnv(clear_dir=CLEAR_DIR, hazy_dir=HAZY_DIR, max_steps=5)

    vec_env = make_vec_env(make_env, n_envs=8, vec_env_cls=SubprocVecEnv, vec_env_kwargs={"start_method": "spawn"})
    eval_env = None
    try:
        eval_env = make_vec_env(make_env, n_envs=1, vec_env_cls=SubprocVecEnv, vec_env_kwargs={"start_method": "spawn"})
        # NOTE(review): eval_freq is counted in callback calls, and each call
        # advances all 8 training envs, so evaluation likely runs about every
        # 16,000 timesteps rather than every 2,000 - confirm this is intended.
        eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/', log_path='./logs/', eval_freq=2000, deterministic=True, render=False)

        total_steps = 150000
        eta_callback = ETACallback(total_timesteps=total_steps, print_freq=2048)
        callback_list = CallbackList([eta_callback, eval_callback])

        print("Configuring PPO Agent...")
        model = PPO("MlpPolicy", vec_env, verbose=1, learning_rate=3e-4, n_steps=1024, batch_size=256, n_epochs=10, gamma=0.99, device=device)

        print("Starting Training Phase...")
        print("NOTE: The AI is collecting its first batch of images. The custom ETA tracker will print updates shortly...")
        start_time = time.time()
        model.learn(total_timesteps=total_steps, callback=callback_list)

        end_time = time.time()
        elapsed_time = end_time - start_time
        formatted_time = str(datetime.timedelta(seconds=int(elapsed_time)))

        print(f"Training Complete in {formatted_time}. Saving model...")
        model.save("clearsight_agent_ots_oracle")
    finally:
        # Always shut down the worker processes, even when training fails;
        # otherwise they stay alive in a long-running notebook kernel.
        if eval_env is not None:
            eval_env.close()
        vec_env.close()