In [2]:
from ultralytics import YOLO, RTDETR
import os
import pdb
import cv2
from PIL import Image
import numpy as np
import pandas as pd
import pickle
import sys
from ensemble_boxes import *
import torch # Added
import torch.nn as nn # Added
import albumentations as A # Added
from albumentations.pytorch.transforms import ToTensorV2 # Added
from timm import create_model # Added

# --- Configuration ---
# NOTE(review): these environment variables are assigned after `torch` and
# `albumentations` have already been imported at the top of the file. If
# either library reads them at import time, the assignment comes too late;
# consider moving these two settings above the imports.
os.environ.update({
    "ALBUMENTATIONS_DISABLE_VERSION_CHECK": "1",
    "CUDA_VISIBLE_DEVICES": "1",
})

# Detector run ids and the run directories derived from them.
RUN = RUN2 = 19
loc, loc2 = (f'runs/detect/train{run_id}' for run_id in (RUN, RUN2))

# Test-time augmentation flag and inference image sizes for the two detectors.
AUGMENT = True
SZ = (928, 448)
SZ2 = (928, 448)  # 928

# Thresholds:
#   THRESH       - confidence filter applied by get_results after inference
#   IOU_THRESH   - IoU threshold passed to weighted boxes fusion
#   CONF         - confidence threshold passed to the YOLO call itself
#   YOLO_IOU_THR - NMS IoU threshold passed to the YOLO call itself
THRESH = 0.5
IOU_THRESH = 0.2
CONF = 0.5  # 0.5
YOLO_IOU_THR = 0.35

# File extension of the input images.
EXT = 'jpg'

# Model paths
# Both detector slots currently point at the same checkpoint.
model_path = model_path2 = "data/yolo_weights/best.pt"

# Directories and files, all rooted at `base_dir`.
base_dir = "data"
test_dir = os.path.join(base_dir, "sample_jpg_images")  # pre/post .jpg inputs
tif_dir = os.path.join(base_dir, "test_images/Images")  # .tif inputs (not used in the visible pipeline)
crops_dir = os.path.join(base_dir, "cropped_images")  # Stage 1 crop pairs
annotated_dir = os.path.join(base_dir, "annotated_images")  # Stage 2 annotated plots
data_pkl_path = os.path.join(base_dir, "data_pkl_path.pkl")  # Stage 1 -> Stage 2 hand-off
submission_csv_path = os.path.join(base_dir, "sample_inference_output.csv")  # submission output
classifier_model_path = os.path.join(
    base_dir,
    "classifier_weights/efficientvit_b0.r224_in1k_3_patch_model_fold_0.pth",
)
ndvi_csv_path = os.path.join(base_dir, "test_ndvi.csv")

# The output directories must exist before Stage 1 / Stage 2 write into them.
os.makedirs(crops_dir, exist_ok=True)
os.makedirs(annotated_dir, exist_ok=True)

# --- Helper Functions (from original code) ---
def get_areas(boxes):
    """Return the area of every ``[x1, y1, x2, y2]`` box in *boxes*.

    Absolute side lengths are used, so a box whose corners are given in
    reverse order still yields a non-negative area.

    Args:
        boxes: iterable of indexable boxes; only the first four entries of
            each box are read.

    Returns:
        A list with one area per box, in input order.
    """
    return [abs(box[2] - box[0]) * abs(box[3] - box[1]) for box in boxes]

def get_results(res, threshold=0.5):
    """Return ``(cls, boxes, conf)`` for detections scoring above *threshold*.

    Args:
        res: a single detection result exposing ``.cpu().numpy()`` and a
            ``boxes`` attribute with ``cls``, ``conf`` and ``xyxyn`` arrays
            (this script passes one element of a YOLO results list).
        threshold: strict lower bound on confidence; detections with
            ``conf <= threshold`` are dropped.

    Returns:
        Class ids, normalized ``xyxy`` boxes and confidences, all filtered
        by the same mask. The boxes stay normalized because the weighted
        boxes fusion step expects [0, 1] coordinates.
    """
    detections = res.cpu().numpy().boxes
    keep = detections.conf > threshold
    return detections.cls[keep], detections.xyxyn[keep], detections.conf[keep]

def run_wbf(bboxes, confs, labels, image_size_w, image_size_h, iou_thr=0.50, skip_box_thr=0.0001, weights=None):
    """Fuse per-model detections with weighted boxes fusion (WBF).

    Args:
        bboxes: one array of normalized ``[x1, y1, x2, y2]`` boxes per model.
        confs: one array of confidences per model, aligned with *bboxes*.
        labels: one array of class labels per model, aligned with *bboxes*.
        image_size_w: width in pixels used to de-normalize x coordinates.
        image_size_h: height in pixels used to de-normalize y coordinates.
        iou_thr: IoU above which boxes are fused.
        skip_box_thr: boxes with confidence below this are ignored.
        weights: optional per-model weights forwarded to WBF.

    Returns:
        ``(boxes, scores, labels)`` from WBF, with *boxes* scaled back to
        absolute pixel coordinates.
    """
    fused_boxes, fused_scores, fused_labels = weighted_boxes_fusion(
        list(bboxes),
        list(confs),
        list(labels),
        weights=weights,
        iou_thr=iou_thr,
        skip_box_thr=skip_box_thr,
    )
    # Columns 0/2 are x coordinates (scale by width); 1/3 are y (scale by height).
    fused_boxes[:, [0, 2]] *= image_size_w
    fused_boxes[:, [1, 3]] *= image_size_h
    return fused_boxes, fused_scores, fused_labels


# --- Stage 1: Object Detection and Cropping ---
print("[INFO] - Starting Stage 1: Object Detection and Cropping...")

# Full paths of every `*_pre_disaster.<EXT>` file in test_dir, sorted.
imgs = sorted(
    os.path.join(test_dir, name)
    for name in os.listdir(test_dir)
    if name.endswith(f'_pre_disaster.{EXT}')
)

# Two detectors whose predictions are fused with WBF in the loop below;
# model_path and model_path2 currently name the same checkpoint.
model, model2 = YOLO(model_path), YOLO(model_path2)

# image_id -> {'pre_crops': [...], 'post_crops': [...], 'bboxes': [...]}
data = {}

# Stage 1 main loop. For every pre-disaster image with a matching
# post-disaster image:
#   - run both detectors on the pre image;
#   - fuse their boxes with WBF;
#   - save the same box region from the pre and post images as a crop pair.
# `data[image_id]` records the crop filenames and boxes for Stage 2.
for im_pre_path in imgs:
    print(f". Processing {os.path.basename(im_pre_path)}", end="", flush=True)
    # The post image path is the pre path with `_pre_disaster` -> `_post_disaster`.
    im_post_path = im_pre_path.replace('_pre_disaster', '_post_disaster')
    # Image id = base file name with the `_pre_disaster.<EXT>` suffix stripped.
    image_id = os.path.basename(im_pre_path).replace(f'_pre_disaster.{EXT}', '')

    # Check if post image exists
    if not os.path.exists(im_post_path):
        print(f"\nWarning: Post image not found for {image_id}, skipping.")
        continue

    # NOTE(review): the entry is created before the read/inference checks below.
    # The two `continue` paths inside the try leave this empty entry in `data`,
    # so Stage 2 writes zero-count rows for it. The `except` path instead
    # deletes the entry, so that id is absent from Stage 2's output.
    # Confirm which behaviour is intended.
    data[image_id] = {'pre_crops': [], 'post_crops': [], 'bboxes': []} # Initialize structure

    try:
        # Read images AFTER checking existence
        img_pre = cv2.imread(im_pre_path)
        img_post = cv2.imread(im_post_path)

        if img_pre is None or img_post is None:
             print(f"\nWarning: Could not read pre or post image for {image_id}, skipping.")
             continue

        # Run inference on pre-disaster image
        # Both detectors see only the pre image (passed by path, not as the
        # already-loaded array). The post image is later cropped with the same boxes.
        results_pre = model(im_pre_path, imgsz=SZ, augment=AUGMENT, conf=CONF, iou=YOLO_IOU_THR, verbose=False)
        results_pre2 = model2(im_pre_path, imgsz=SZ2, augment=AUGMENT, conf=CONF, iou=YOLO_IOU_THR, verbose=False)

        if not results_pre or not results_pre2:
             print(f"\nWarning: No results from inference for {image_id}, skipping.")
             continue

        # One image was passed, so only the first result of each list is used.
        res_pre = results_pre[0]
        res_pre2 = results_pre2[0]
        orig_h, orig_w = res_pre.orig_shape # Get original dimensions (unpacked as height, width)

        # Keep detections with confidence > THRESH; boxes stay normalized for WBF.
        ccls_pre, bbxs_pre_norm, cnfs_pre = get_results(res_pre, threshold=THRESH)
        ccls_pre2, bbxs_pre2_norm, cnfs_pre2 = get_results(res_pre2, threshold=THRESH)

        # Run Weighted Boxes Fusion (WBF) - Pass original dimensions
        # run_wbf returns boxes scaled back to absolute pixel coordinates.
        # scores_pre / labels_pre are not used below.
        boxes_abs, scores_pre, labels_pre = run_wbf(
                                         [bbxs_pre_norm, bbxs_pre2_norm],
                                         [cnfs_pre, cnfs_pre2],
                                         [ccls_pre, ccls_pre2], # Pass class labels
                                         orig_w, orig_h, # Pass original W and H
                                         iou_thr=IOU_THRESH
                                         )

        # Crop and save patches based on WBF results
        for i in range(len(boxes_abs)):
            x1, y1, x2, y2 = map(int, boxes_abs[i]) # Ensure integer coordinates (int() truncates)

            # Clamp coordinates to image boundaries to avoid errors
            x1 = max(0, x1)
            y1 = max(0, y1)
            x2 = min(orig_w, x2)
            y2 = min(orig_h, y2)

            # Ensure valid box dimensions
            if x1 >= x2 or y1 >= y2:
                #print(f"\nWarning: Invalid box dimensions for {image_id} box {i}, skipping crop.")
                continue

            # Same pixel region from both images, so the crops are aligned pairs.
            crop_pre = img_pre[y1:y2, x1:x2]
            crop_post = img_post[y1:y2, x1:x2]

            # Basic check for empty crops
            if crop_pre.size == 0 or crop_post.size == 0:
                 #print(f"\nWarning: Empty crop generated for {image_id} box {i}, skipping crop.")
                 continue

            # Save crops
            # `i` is the WBF box index, so skipped boxes leave gaps in the numbering.
            pre_crop_fname = f"{image_id}_X_{i}_pre.jpg"
            post_crop_fname = f"{image_id}_X_{i}_post.jpg"
            cv2.imwrite(os.path.join(crops_dir, pre_crop_fname), crop_pre)
            cv2.imwrite(os.path.join(crops_dir, post_crop_fname), crop_post)

            # Store filenames and the corresponding absolute bounding box
            # The three lists are appended together, so index k refers to the same box.
            # NOTE(review): this stores the original float box, not the clamped
            # integer coordinates that were used for the crop above.
            data[image_id]['pre_crops'].append(pre_crop_fname)
            data[image_id]['post_crops'].append(post_crop_fname)
            data[image_id]['bboxes'].append(boxes_abs[i]) # Store the absolute bbox coordinates

    except Exception as e:
        # Best-effort: log the failure and continue with the next image.
        print(f"\nError processing {image_id}: {e}")
        # Clean up potentially partial data for this ID
        # (crop files already written to crops_dir are left on disk).
        if image_id in data:
            del data[image_id]

# Persist the Stage 1 crop index so Stage 2 can reload it from disk.
print("\n[INFO] - Stage 1 finished. Saving crop data...")
with open(data_pkl_path, 'wb') as pkl_out:
    pickle.dump(data, pkl_out)
print(f"[INFO] - Crop data saved to {data_pkl_path}")

# --- Classifier Definition ---
CLASSIFIER_SZ = 128  # side length of the square crops fed to the classifier
CLASSIFIER_MODEL = 'efficientvit_b0.r224_in1k'  # timm backbone name


class DisasterClassifier(nn.Module):
    """Siamese damage classifier for a pre/post-disaster crop pair.

    Both crops are embedded by one shared timm backbone, built with
    ``num_classes=0`` so no classification head is attached. The Euclidean
    distance between the two embeddings is concatenated with a per-image NDVI
    value and passed through a single linear layer followed by a sigmoid.

    Attributes:
        features: shared backbone, created with ``pretrained=False``; weights
            are expected to be loaded afterwards (e.g. via ``load_state_dict``).
        feature_dim: width of the backbone embedding for a
            ``CLASSIFIER_SZ`` x ``CLASSIFIER_SZ`` input.
        fc: ``Linear(2, 1)`` over ``[distance, ndvi]``.
    """

    def __init__(self):
        super().__init__()
        # Load base model without classifier head; weights are loaded later.
        self.features = create_model(CLASSIFIER_MODEL, pretrained=False, num_classes=0)
        self.feature_dim = self._probe_feature_dim()
        # Final layer input = distance (1 value) + NDVI (1 value).
        self.fc = nn.Linear(1 + 1, 1)

    def _probe_feature_dim(self):
        """Return the backbone's output width by running one dummy input.

        The probe runs in eval mode under ``torch.no_grad()``. This avoids
        updating BatchNorm running statistics, building an autograd graph, or
        drawing from the global RNG. The backbone's previous train/eval mode
        is restored afterwards.
        """
        was_training = self.features.training
        self.features.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, 3, CLASSIFIER_SZ, CLASSIFIER_SZ)
                return self.features(dummy).shape[-1]
        finally:
            self.features.train(was_training)

    def euclidean_distance(self, x1, x2):
        """Return the row-wise L2 distance between two feature batches.

        Assumes *x1* and *x2* are ``(batch, feature_dim)`` vectors (as with a
        head-less timm backbone) -- TODO confirm for other backbones. The sum
        runs over dim 1 with ``keepdim=True``, so the result is ``(batch, 1)``.
        """
        return torch.sqrt(torch.sum((x1 - x2)**2, dim=1, keepdim=True))

    def forward(self, pre_image, post_image, ndvi_val):
        """Return the sigmoid "destroyed" score for each pre/post pair.

        Args:
            pre_image: batch of pre-disaster crops. This script passes
                ``(1, 3, CLASSIFIER_SZ, CLASSIFIER_SZ)`` tensors.
            post_image: matching post-disaster crops, same shape.
            ndvi_val: ``(batch, 1)`` NDVI values; concatenated along dim 1
                with the ``(batch, 1)`` distance.

        Returns:
            ``(batch, 1)`` tensor of sigmoid outputs in [0, 1].
        """
        # Same backbone (shared weights) for both branches.
        pre_features = self.features(pre_image)
        post_features = self.features(post_image)

        distance = self.euclidean_distance(pre_features, post_features)
        combined_input = torch.cat([distance, ndvi_val], dim=1)
        return torch.sigmoid(self.fc(combined_input))

# --- Stage 2: Classification and Plotting ---
print("\n[INFO] - Starting Stage 2: Classification and Plotting...")

# Per-image NDVI table, indexed by the 'id' column for lookups by image id.
# Stage 2 cannot run without it, so any failure here aborts the script.
try:
    test_df = pd.read_csv(ndvi_csv_path)
    test_df = test_df.set_index('id')
    print(f"[INFO] - Loaded NDVI data from {ndvi_csv_path}")
except Exception as ndvi_err:
    print(f"Error loading NDVI data: {ndvi_err}")
    sys.exit(1)


# Setup device and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] - Using device: {device}")

# Prefer the fine-tuned checkpoint. If it cannot be loaded, fall back to
# timm's pretrained backbone weights; in that case the final fc layer keeps
# its random initialisation, as the log message says.
# NOTE(review): torch.load unpickles the file -- only point this at trusted
# checkpoints.
classifier_model = DisasterClassifier()
try:
    classifier_model.load_state_dict(torch.load(classifier_model_path, map_location=device))
except Exception as load_err:
    print(f"Error loading classifier model weights: {load_err}")
    print("[INFO] - Attempting to load TIMM pretrained weights for feature extractor...")
    try:
        classifier_model.features = create_model(CLASSIFIER_MODEL, pretrained=True, num_classes=0)
    except Exception as timm_err:
        print(f"[ERROR] - Failed to load TIMM pretrained weights: {timm_err}. Cannot proceed.")
        sys.exit(1)
    print("[INFO] - Loaded TIMM pretrained weights. Final layer remains untrained.")
else:
    print(f"[INFO] - Loaded classifier weights from {classifier_model_path}")

# Move to the target device and switch to inference mode (eval() returns self).
classifier_model = classifier_model.to(device).eval()


# Define transformations
# Resize -> normalize (the commonly used ImageNet mean/std) -> CHW tensor.
# `image2` is registered as an extra "image" target, so the post crop goes
# through the same pipeline in the same call as the pre crop.
valid_transform = A.Compose(
    [
        A.Resize(CLASSIFIER_SZ, CLASSIFIER_SZ),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ],
    additional_targets={'image2': 'image'},
)

# Reload the Stage 1 crop index (image_id -> crop filenames and boxes).
print(f"[INFO] - Loading data from {data_pkl_path}")
try:
    with open(data_pkl_path, "rb") as pkl_in:
        data = pickle.load(pkl_in)
except FileNotFoundError:
    print(f"Error: Data file not found at {data_pkl_path}. Please run Stage 1 first.")
    sys.exit(1)

# Submission accumulators: parallel lists of row ids and damage counts.
submission_ids, submission_damage = [], []
total_destroyed_count = 0

# Plotting configuration (OpenCV colour tuples are BGR).
destroyed_color, nondamaged_color = (0, 0, 255), (0, 255, 0)  # red, green
alpha = 0.4  # weight of the filled overlay when blending
outline_thickness = 2

# --- Stage 2 per-image helpers ---


def _append_submission_rows(image_id, no_damage_count, destroyed_count):
    """Append one image's four submission rows to the global accumulators.

    Rows are always emitted in the order no_damage, minor_damage,
    major_damage, destroyed. Minor and major damage are always 0 because the
    classifier produces a single destroyed / not-destroyed decision.
    """
    submission_ids.extend([
        f"{image_id}_X_no_damage",
        f"{image_id}_X_minor_damage",
        f"{image_id}_X_major_damage",
        f"{image_id}_X_destroyed",
    ])
    submission_damage.extend([no_damage_count, 0, 0, destroyed_count])


def _classify_crop_pair(pre_crop_path, post_crop_path, ndvi_tensor):
    """Return whether the classifier calls a pre/post crop pair destroyed.

    Returns ``None`` if either crop file cannot be read. Errors from the
    transform or the model propagate to the caller.
    """
    pre_image = cv2.imread(pre_crop_path)
    post_image = cv2.imread(post_crop_path)
    if pre_image is None or post_image is None:
        return None

    # OpenCV loads BGR; convert both crops to RGB before the shared transform.
    transformed = valid_transform(
        image=cv2.cvtColor(pre_image, cv2.COLOR_BGR2RGB),
        image2=cv2.cvtColor(post_image, cv2.COLOR_BGR2RGB),
    )
    pre_tensor = transformed['image'].unsqueeze(0).to(device)
    post_tensor = transformed['image2'].unsqueeze(0).to(device)

    with torch.no_grad():
        output = classifier_model(pre_tensor, post_tensor, ndvi_tensor)
    # Single sigmoid output; > 0.5 is treated as "destroyed".
    return output.item() > 0.5


def _annotate_post_image(post_img, results):
    """Draw semi-transparent filled boxes plus solid outlines onto *post_img* in place.

    *results* is a list of ``(bbox, is_destroyed)`` pairs. Destroyed boxes use
    ``destroyed_color`` and the rest use ``nondamaged_color``.
    """
    styled = [
        (tuple(map(int, bbox)), destroyed_color if is_destroyed else nondamaged_color)
        for bbox, is_destroyed in results
    ]
    # Filled rectangles on a copy, then blended back with weight `alpha`.
    overlay = post_img.copy()
    for (x1, y1, x2, y2), color in styled:
        cv2.rectangle(overlay, (x1, y1), (x2, y2), color, -1)  # -1 = filled
    cv2.addWeighted(overlay, alpha, post_img, 1 - alpha, 0, post_img)
    # Solid outlines on top of the blended image for visibility.
    for (x1, y1, x2, y2), color in styled:
        cv2.rectangle(post_img, (x1, y1), (x2, y2), color, outline_thickness)


# Process each image_id from the detected data
for image_id, entry in data.items():
    print(f"[INFO] - Processing & Plotting: {image_id}", flush=True)

    # Without an NDVI value the classifier cannot run. Still emit zero-count
    # rows so every detected image appears in the submission.
    try:
        ndvi_mean = test_df.loc[image_id, 'NDVI_mean']
    except KeyError:
        print(f"Warning: NDVI value not found for {image_id}; writing zero counts to the submission.")
        print(f"Skipping classification and plotting for {image_id} due to missing NDVI.")
        _append_submission_rows(image_id, 0, 0)
        continue
    ndvi_tensor = torch.tensor([[ndvi_mean]], dtype=torch.float32).to(device)  # shape [1, 1]

    results_for_image = []  # (bbox, is_destroyed) pairs used for plotting
    damage_counts = {0: 0, 3: 0}  # 0 = no_damage, 3 = destroyed

    if not entry['pre_crops']:
        print(f"Warning: No crops found for {image_id}, adding zeros to submission.")

    # Stage 1 appends to these three lists together, so zip keeps them aligned.
    for pre_crop_fname, post_crop_fname, bbox in zip(
        entry['pre_crops'], entry['post_crops'], entry['bboxes']
    ):
        pre_crop_path = os.path.join(crops_dir, pre_crop_fname)
        post_crop_path = os.path.join(crops_dir, post_crop_fname)

        # Best-effort per crop: a failing pair is logged and skipped.
        try:
            is_destroyed = _classify_crop_pair(pre_crop_path, post_crop_path, ndvi_tensor)
        except Exception as e:
            print(f"Error processing crop pair {pre_crop_fname}/{post_crop_fname}: {e}")
            continue
        if is_destroyed is None:
            # Unreadable pair: neither plotted nor counted.
            print(f"Warning: Could not read crop {pre_crop_fname} or {post_crop_fname}. Skipping patch.")
            continue

        results_for_image.append((bbox, is_destroyed))
        if is_destroyed:
            damage_counts[3] += 1
            total_destroyed_count += 1
        else:
            damage_counts[0] += 1

    # --- Plotting for the current image_id ---
    post_img_path = os.path.join(test_dir, f"{image_id}_post_disaster.{EXT}")
    post_img = cv2.imread(post_img_path)

    if post_img is None:
        print(f"Warning: Could not read post image {post_img_path} for annotation.")
    elif not results_for_image:
        print(f"Info: No valid classification results for {image_id}, skipping annotation.")
    else:
        _annotate_post_image(post_img, results_for_image)
        annotated_img_path = os.path.join(annotated_dir, f"{image_id}_post_disaster_annotated.png")
        cv2.imwrite(annotated_img_path, post_img)

    # --- Submission rows for this image_id ---
    _append_submission_rows(image_id, damage_counts[0], damage_counts[3])


# --- Final Submission CSV ---
print("\n[INFO] - Generating final submission file...")
# One row per (image, damage class) pair, in the order accumulated above.
submission_df = pd.DataFrame(dict(id=submission_ids, damage=submission_damage))
submission_df.to_csv(submission_csv_path, index=False)
print(f"[INFO] - Submission file saved to {submission_csv_path}")

# Run summary.
print(f"[INFO] - Total 'destroyed' count across all images: {total_destroyed_count}")
print("[INFO] - Script finished.")

[INFO] - Starting Stage 1: Object Detection and Cropping...
. Processing malawi-cyclone_00000034_pre_disaster.jpg. Processing malawi-cyclone_00000066_pre_disaster.jpg. Processing malawi-cyclone_00000212_pre_disaster.jpg
[INFO] - Stage 1 finished. Saving crop data...
[INFO] - Crop data saved to data\data_pkl_path.pkl

[INFO] - Starting Stage 2: Classification and Plotting...
[INFO] - Loaded NDVI data from data\test_ndvi.csv
[INFO] - Using device: cpu
[INFO] - Loaded classifier weights from data\classifier_weights/efficientvit_b0.r224_in1k_3_patch_model_fold_0.pth
[INFO] - Loading data from data\data_pkl_path.pkl
[INFO] - Processing & Plotting: malawi-cyclone_00000034
[INFO] - Processing & Plotting: malawi-cyclone_00000066
[INFO] - Processing & Plotting: malawi-cyclone_00000212

[INFO] - Generating final submission file...
[INFO] - Submission file saved to data\sample_inference_output.csv
[INFO] - Total 'destroyed' count across all images: 2
[INFO] - Script finished.
