In [1]:
!cp -r /kaggle/input/mhafyolo/pytorch/default/1/MHAF-YOLO-main /kaggle/working/

import os
from pathlib import Path

current_dir = Path.cwd()
print("this_dir:", current_dir)

target_dir = Path("/kaggle/working/MHAF-YOLO-main") 
os.chdir(target_dir)  
# Checkpoint of the trained detector; loaded by generate_submission().
model_path = "/kaggle/input/mhaf-m-public/other/default/1/36_data_96_663.pt"
# Minimum confidence a single 2D (per-slice) detection needs to be kept for linking.
CONFIDENCE_THRESHOLD = 0.3
TRACK_CONFIDENCE_THRESHOLD = 0.47 # Minimum (boosted) confidence a finalized 3D track needs to enter 3D NMS

NMS_IOU_THRESHOLD = 0.2 # 3D NMS: a lower-confidence track is dropped when its 3D IoU with a kept track is >= this
CONCENTRATION = 1 # Fraction of slices to process (multiplied by the slice count, so 1 = every slice; not a percent)
BATCH_SIZE = 32 # Batch size for YOLO inference

LINK_IOU_THRESHOLD_2D = 0.3      # 2D IoU a box must exceed to be attached to an existing track
LINK_MAX_MISSED_SLICES = 1      # Largest slice-number gap since a track's last detection before it is closed
LINK_MIN_SLICES_FOR_3D = 3       # Min number of slices with a detection for a track to become a 3D object
LINK_CONFIDENCE_BOOST_FACTOR = 0.2 # Confidence added per slice beyond LINK_MIN_SLICES_FOR_3D (result capped at 1.0)


import numpy as np
import pandas as pd
from PIL import Image
import torch
import cv2
from tqdm.notebook import tqdm
from ultralytics import YOLOv10 # Assuming this is correctly installed/available
import time
from torch.utils.data import Dataset, DataLoader

# Seed NumPy and PyTorch RNGs so any stochastic step is repeatable across runs.
np.random.seed(42)
torch.manual_seed(42)

# Competition input layout: one sub-directory of slice images per test tomogram.
data_path = "/kaggle/input/byu-locating-bacterial-flagellar-motors-2025/"
test_dir = os.path.join(data_path, "test")
submission_path = "/kaggle/working/submission.csv"

device = 'cuda' if torch.cuda.is_available() else 'cpu' # Simpler device check
if device == 'cuda':
    # Speed-oriented settings: let cuDNN auto-tune kernels and allow TF32 math.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False # Usually False for speed
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# No GradScaler is created: gradient scaling is a training-time tool, and
# inference here relies on autocast alone.
# scaler = torch.cuda.amp.GradScaler(enabled=False) # Not typically used for inference like this

class GPUProfiler:
    """Context manager that times a block of (possibly GPU-backed) work.

    CUDA kernels are launched asynchronously, so the device is synchronized
    both on entry and on exit (when CUDA is available); the measured interval
    therefore covers the work issued inside the ``with`` body.

    Attributes:
        name: Label for the timed section.
        start_time: ``time.perf_counter()`` reading taken on entry
            (``None`` before the block has been entered).
        elapsed: Duration of the block in seconds, set on exit
            (``None`` until the block has finished).
    """

    def __init__(self, name):
        self.name = name
        self.start_time = None
        self.elapsed = None

    def __enter__(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        # perf_counter is monotonic, so the interval cannot be distorted by
        # wall-clock adjustments.
        self.start_time = time.perf_counter()
        # Return self so `with GPUProfiler(...) as prof:` exposes the timing.
        return self

    def __exit__(self, *args):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self.elapsed = time.perf_counter() - self.start_time
        # Printing is left disabled to keep notebook output short; read
        # `elapsed` from the instance instead.
        # print(f"[PROFILE] {self.name}: {self.elapsed:.3f}s")


# ------------- START: 2D-to-3D Linking and 3D NMS Code -------------
def xywh_to_xyxy(x_center, y_center, w, h):
    """Convert a centre/size box into corner form.

    Returns:
        Tuple ``(x_min, y_min, x_max, y_max)``.
    """
    half_w = w / 2
    half_h = h / 2
    return (x_center - half_w, y_center - half_h,
            x_center + half_w, y_center + half_h)

def calculate_2d_iou(boxA_xyxy, boxB_xyxy):
    """Return the intersection-over-union of two axis-aligned 2D boxes.

    Both boxes are indexable as ``(x_min, y_min, x_max, y_max)``. Boxes that
    do not overlap (or only touch along an edge) yield ``0.0``.
    """
    overlap_w = max(0, min(boxA_xyxy[2], boxB_xyxy[2]) - max(boxA_xyxy[0], boxB_xyxy[0]))
    overlap_h = max(0, min(boxA_xyxy[3], boxB_xyxy[3]) - max(boxA_xyxy[1], boxB_xyxy[1]))

    intersection = overlap_w * overlap_h
    if intersection == 0:
        return 0.0

    area_a = (boxA_xyxy[2] - boxA_xyxy[0]) * (boxA_xyxy[3] - boxA_xyxy[1])
    area_b = (boxB_xyxy[2] - boxB_xyxy[0]) * (boxB_xyxy[3] - boxB_xyxy[1])

    # Union = sum of areas minus the region counted twice.
    union = float(area_a + area_b - intersection)
    return intersection / union

def calculate_3d_iou(boxA_3d, boxB_3d):
    """Return the intersection-over-union of two axis-aligned 3D boxes.

    Each box is indexable as ``(x_min, y_min, z_min, x_max, y_max, z_max)``.
    Boxes with no overlapping volume yield ``0.0``.
    """
    def _volume(box):
        return (box[3] - box[0]) * (box[4] - box[1]) * (box[5] - box[2])

    # Overlap extent along x, y, z; clamped at zero when the boxes are apart.
    overlap = [
        max(0, min(boxA_3d[axis + 3], boxB_3d[axis + 3]) - max(boxA_3d[axis], boxB_3d[axis]))
        for axis in range(3)
    ]
    intersection = overlap[0] * overlap[1] * overlap[2]
    if intersection == 0:
        return 0.0

    union = float(_volume(boxA_3d) + _volume(boxB_3d) - intersection)
    return intersection / union

class Track3D:
    """A chain of 2D detections linked across slices into one 3D candidate.

    ``boxes_2d_slices`` maps a slice index to
    ``(x_min, y_min, x_max, y_max, confidence)``; ``confidences`` keeps the
    per-detection confidences in insertion order.
    """

    _next_id = 0  # Class-wide counter handing out unique track IDs

    def __init__(self, initial_detection_xywhc, z_index, class_id):
        # initial_detection_xywhc: (x_center, y_center, w, h, confidence)
        self.id = Track3D._next_id
        Track3D._next_id += 1

        self.class_id = class_id

        score = initial_detection_xywhc[4]
        corners = xywh_to_xyxy(*initial_detection_xywhc[:4])
        self.boxes_2d_slices = {z_index: (*corners, score)}  # xyxy + conf

        self.last_seen_z = z_index
        self.confidences = [score]

    def add_detection(self, detection_xywhc, z_index):
        """Attach the detection ``(x_center, y_center, w, h, confidence)`` at ``z_index``."""
        score = detection_xywhc[4]
        corners = xywh_to_xyxy(*detection_xywhc[:4])
        self.boxes_2d_slices[z_index] = (*corners, score)
        self.last_seen_z = z_index
        self.confidences.append(score)

    def get_last_box_xyxy(self):
        """Return ``(x_min, y_min, x_max, y_max)`` of the latest detection, or None if the track is empty."""
        if not self.boxes_2d_slices:
            return None
        latest = self.boxes_2d_slices[self.last_seen_z]
        return latest[:4]

    def finalize(self, min_slices_for_3d_object, confidence_boost_factor):
        """Collapse the track into a single 3D box.

        Returns None when the track holds fewer than
        ``min_slices_for_3d_object`` detections. Otherwise returns
        ``(x_min, y_min, z_min, x_max, y_max, z_max, class_id, confidence)``
        where the x/y corners are the means over all slices, z spans the
        first to the last slice index, and the confidence is the mean 2D
        confidence plus ``confidence_boost_factor`` per slice beyond the
        minimum, capped at 1.0.
        """
        slice_count = len(self.boxes_2d_slices)
        if slice_count < min_slices_for_3d_object:
            return None

        slice_boxes = list(self.boxes_2d_slices.values())
        # Average each corner coordinate (index 0..3) over every slice.
        mean_x_min, mean_y_min, mean_x_max, mean_y_max = [
            np.mean([box[corner] for box in slice_boxes]) for corner in range(4)
        ]

        ordered_z = sorted(self.boxes_2d_slices)
        z_first = ordered_z[0]
        z_last = ordered_z[-1]

        extra_slices = max(0, slice_count - min_slices_for_3d_object)
        boosted = np.mean(self.confidences) + confidence_boost_factor * extra_slices
        track_confidence = min(1.0, boosted)

        return (mean_x_min, mean_y_min, z_first,
                mean_x_max, mean_y_max, z_last,
                self.class_id, track_confidence)

def combine_2d_to_3d(all_slice_detections_xywhc,
                     iou_threshold_2d_link,
                     max_missed_slices,
                     min_slices_for_3d_object,
                     confidence_boost_factor,
                     nms_threshold_3d,
                     track_confidence_threshold):
    """Link per-slice 2D detections into 3D tracks, then apply 3D NMS.

    Args:
        all_slice_detections_xywhc: ``{z: [(xc, yc, w, h, cls_id, conf), ...]}``.
            The ``cls_id`` entry is unpacked but ignored; every track is
            created with class 0.
        iou_threshold_2d_link: a detection joins a track only if its 2D IoU
            with the track's latest box exceeds this value.
        max_missed_slices: a track is closed once the current slice index is
            more than this many indices past its last detection.
        min_slices_for_3d_object: forwarded to ``Track3D.finalize``.
        confidence_boost_factor: forwarded to ``Track3D.finalize``.
        nms_threshold_3d: same-class boxes whose 3D IoU with a kept box is
            not below this value are suppressed.
        track_confidence_threshold: finalized tracks below this confidence
            are discarded before NMS.

    Returns:
        List of ``(x_min, y_min, z_min, x_max, y_max, z_max, class_id,
        confidence)`` tuples, highest confidence first; empty if nothing
        survives.
    """
    Track3D._next_id = 0  # Restart track numbering for each tomogram
    live_tracks = []
    finished_tracks = []

    for z in sorted(all_slice_detections_xywhc.keys()):
        # Each candidate: (xyxy box, class id, confidence, (xc, yc, w, h, conf)).
        candidates = []
        for xc, yc, w, h, _yolo_cls_id, conf in all_slice_detections_xywhc.get(z, []):
            candidates.append(
                (xywh_to_xyxy(xc, yc, w, h), 0, conf, (xc, yc, w, h, conf))
            )
        claimed = [False] * len(candidates)

        # Walk the tracks from newest to oldest so popping is index-safe.
        for pos in reversed(range(len(live_tracks))):
            track = live_tracks[pos]

            # Retire tracks whose gap is too large before trying to match.
            if z - track.last_seen_z > max_missed_slices:
                finished_tracks.append(live_tracks.pop(pos))
                continue

            anchor_xyxy = track.get_last_box_xyxy()
            if anchor_xyxy is None:
                continue

            # Greedy choice: the unclaimed same-class candidate with the
            # largest IoU above the linking threshold.
            best_idx = -1
            best_iou = 0.0
            for cand_idx, (cand_xyxy, cand_cls, _, _) in enumerate(candidates):
                if claimed[cand_idx] or cand_cls != track.class_id:
                    continue
                overlap = calculate_2d_iou(anchor_xyxy, cand_xyxy)
                if overlap > iou_threshold_2d_link and overlap > best_iou:
                    best_iou = overlap
                    best_idx = cand_idx

            if best_idx != -1:
                track.add_detection(candidates[best_idx][3], z)
                claimed[best_idx] = True

        # Every detection left unclaimed seeds a new track.
        for cand_idx, (_, cand_cls, _, cand_xywhc) in enumerate(candidates):
            if not claimed[cand_idx]:
                live_tracks.append(Track3D(cand_xywhc, z, cand_cls))

    finished_tracks.extend(live_tracks)

    boxes_3d = []
    for track in finished_tracks:
        box = track.finalize(min_slices_for_3d_object, confidence_boost_factor)
        if box and box[7] >= track_confidence_threshold:
            boxes_3d.append(box)

    if not boxes_3d:
        return []

    # Stable sort, highest confidence (index 7) first.
    boxes_3d.sort(key=lambda candidate: candidate[7], reverse=True)

    kept = []
    while boxes_3d:
        top = boxes_3d.pop(0)
        kept.append(top)
        # NMS is class-specific: other-class boxes always survive.
        boxes_3d = [
            other for other in boxes_3d
            if top[6] != other[6]
            or calculate_3d_iou(top[:6], other[:6]) < nms_threshold_3d
        ]

    return kept
# ------------- END: 2D-to-3D Linking and 3D NMS Code -------------


class TomogramSliceDataset(Dataset):
    """Dataset yielding ``(image, slice_number)`` pairs for a list of slice files.

    A slice number of -1 marks an item that could not be read or whose file
    name could not be parsed; ``collate_fn`` drops such items.
    """

    def __init__(self, paths):
        self.paths = paths

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        slice_path = self.paths[idx]
        image = cv2.imread(str(slice_path))  # cv2 wants a str, not a Path
        if image is None:
            print(f"Warning: Failed to read image {slice_path}. Using a placeholder.")
            # Black 960x960x3 stand-in, tagged with -1 so it is filtered later.
            return np.zeros((960, 960, 3), dtype=np.uint8), -1

        # img = cv2.medianBlur(img, ksize=5)

        # The slice number is the second '_'-separated field of the file stem.
        try:
            slice_num = int(Path(slice_path).stem.split('_')[1])
        except (IndexError, ValueError):
            print(f"Warning: Could not parse slice number from {Path(slice_path).stem}. Using -1.")
            slice_num = -1

        return image, slice_num

    @staticmethod
    def collate_fn(batch):
        """Split a batch into parallel lists, discarding items flagged with -1."""
        images, numbers = zip(*batch)
        keep = [pos for pos, number in enumerate(numbers) if number != -1]
        if not keep:
            return [], []  # Every item in the batch was invalid
        return [images[pos] for pos in keep], [numbers[pos] for pos in keep]


@torch.no_grad()
def process_tomogram(tomo_id, model,
                     link_iou_thresh, link_max_missed,
                     link_min_slices, link_conf_boost,
                     final_nms_iou_3d,
                     track_confidence_thresh):
    """Run 2D detection over one tomogram and reduce it to a single motor location.

    Slices are read from ``test_dir / tomo_id``, a subset is selected
    according to ``CONCENTRATION``, the model is run in batches, detections
    at or above ``CONFIDENCE_THRESHOLD`` are linked into 3D tracks by
    ``combine_2d_to_3d``, and the centre of the highest-confidence surviving
    track is reported.

    Args:
        tomo_id: name of the tomogram directory under ``test_dir``.
        model: callable detector; ``model(list_of_images, verbose=False)``
            must return one result per image exposing ``boxes.conf`` and
            ``boxes.xyxy``.
        link_iou_thresh: 2D IoU threshold used to link boxes across slices.
        link_max_missed: largest slice-number gap tolerated within a track.
        link_min_slices: minimum detections for a track to become 3D.
        link_conf_boost: confidence boost per slice beyond the minimum.
        final_nms_iou_3d: IoU threshold for the final 3D NMS.
        track_confidence_thresh: minimum confidence of a finalized track.

    Returns:
        Dict with keys ``'tomo_id'``, ``'Motor axis 0'`` (z),
        ``'Motor axis 1'`` (y) and ``'Motor axis 2'`` (x). All three axes
        are -1 when no motor is predicted.

    NOTE(review): track gaps are measured in slice numbers parsed from the
    file names, so with ``CONCENTRATION`` < 1 consecutive selected slices may
    be further apart than ``link_max_missed`` — confirm linking parameters if
    sub-sampling is enabled.
    """
    no_motor = {'tomo_id': tomo_id, 'Motor axis 0': -1,
                'Motor axis 1': -1, 'Motor axis 2': -1}

    tomo_path = Path(test_dir) / tomo_id
    files_all = sorted(tomo_path.glob('*.jpg'))

    num_total_slices = len(files_all)
    if num_total_slices == 0:
        return no_motor

    # Always keep at least one slice, even for a very small CONCENTRATION.
    num_selected_slices = max(1, int(num_total_slices * CONCENTRATION))
    if num_selected_slices == 1:
        sel_indices = [num_total_slices // 2]  # Middle slice
    else:
        sel_indices = np.linspace(0, num_total_slices - 1, num_selected_slices).round().astype(int)

    sel_indices = np.unique(sel_indices)
    paths = [files_all[i] for i in sel_indices]
    if not paths:
        return no_motor

    ds = TomogramSliceDataset(paths)
    num_workers = min(4, os.cpu_count() or 1)
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=num_workers, pin_memory=True,
                        collate_fn=TomogramSliceDataset.collate_fn)

    # {slice_num: [(xc, yc, w, h, yolo_class_id, confidence), ...]}
    detections_for_3d_linking = {}
    use_amp = (device == 'cuda')

    for imgs_batch, slice_nums_batch in loader:
        if not imgs_batch:
            continue  # Every image in this batch failed to load

        with GPUProfiler(f"Inference {len(imgs_batch)} slices for {tomo_id}"):
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
            with torch.autocast(device_type='cuda', enabled=use_amp):
                preds = model(imgs_batch, verbose=False)

        for actual_z_slice_num, result in zip(slice_nums_batch, preds):
            if actual_z_slice_num == -1:
                continue  # Slice number could not be parsed for this image

            boxes = result.boxes
            if len(boxes) == 0:
                continue

            # Move all boxes of the slice to the host once instead of
            # triggering one device-to-host transfer per box.
            confs = boxes.conf.cpu().tolist()
            corners = boxes.xyxy.cpu().numpy()

            for conf, (x1, y1, x2, y2) in zip(confs, corners):
                conf = float(conf)
                if conf < CONFIDENCE_THRESHOLD:
                    continue

                w = abs(x2 - x1)
                h = abs(y2 - y1)
                if w == 0 or h == 0:
                    continue  # Degenerate box

                xc = (x1 + x2) / 2
                yc = (y1 + y2) / 2
                yolo_class_id = 0  # Single-class problem

                detections_for_3d_linking.setdefault(actual_z_slice_num, []).append(
                    (xc, yc, w, h, yolo_class_id, conf)
                )

    final_3d_predictions = combine_2d_to_3d(
        detections_for_3d_linking,
        iou_threshold_2d_link=link_iou_thresh,
        max_missed_slices=link_max_missed,
        min_slices_for_3d_object=link_min_slices,
        confidence_boost_factor=link_conf_boost,
        nms_threshold_3d=final_nms_iou_3d,
        track_confidence_threshold=track_confidence_thresh
    )

    if not final_3d_predictions:
        return no_motor

    # combine_2d_to_3d returns boxes sorted by confidence, best first.
    (pred_x_min, pred_y_min, pred_z_min,
     pred_x_max, pred_y_max, pred_z_max) = final_3d_predictions[0][:6]

    # Report the centre of the 3D box, rounded to integer voxel coordinates.
    final_z = int(round((pred_z_min + pred_z_max) / 2))
    final_y = int(round((pred_y_min + pred_y_max) / 2))
    final_x = int(round((pred_x_min + pred_x_max) / 2))

    return {'tomo_id': tomo_id,
            'Motor axis 0': final_z,
            'Motor axis 1': final_y,
            'Motor axis 2': final_x}


def generate_submission():
    """Run inference on every test tomogram and write the submission CSV.

    Loads the detector from ``model_path``, processes each sub-directory of
    ``test_dir`` with ``process_tomogram`` and writes one row per tomogram to
    ``submission_path``.

    Returns:
        The submission ``DataFrame`` with columns ``tomo_id``,
        ``Motor axis 0``, ``Motor axis 1`` and ``Motor axis 2``. When no
        tomogram directories exist, an empty frame with these columns is
        written and returned.
    """
    tomos = sorted(d.name for d in Path(test_dir).iterdir() if d.is_dir())

    # NOTE: the script changes into the MHAF-YOLO checkout before this point;
    # presumably the checkpoint needs that package importable to load.
    print(f"Loading model from: {model_path}")
    model = YOLOv10(model_path)
    model.to(device)

    # Fuse layers for faster inference when the wrapper offers it. Mixed
    # precision is handled by autocast during inference, so the model is
    # deliberately not converted with .half().
    if device == 'cuda' and hasattr(model, 'fuse'):
        model.fuse()

    results = []
    for tomo_id in tqdm(tomos, desc="Processing Tomograms"):
        if device == 'cuda':
            torch.cuda.empty_cache()
        # Keyword arguments keep the NMS IoU threshold and the track
        # confidence threshold from being confused with one another.
        res = process_tomogram(
            tomo_id, model,
            link_iou_thresh=LINK_IOU_THRESHOLD_2D,
            link_max_missed=LINK_MAX_MISSED_SLICES,
            link_min_slices=LINK_MIN_SLICES_FOR_3D,
            link_conf_boost=LINK_CONFIDENCE_BOOST_FACTOR,
            final_nms_iou_3d=NMS_IOU_THRESHOLD,
            track_confidence_thresh=TRACK_CONFIDENCE_THRESHOLD,
        )
        results.append(res)

    # Passing `columns` fixes the column order and still yields a valid
    # (empty) frame when `results` is empty, where column selection on
    # DataFrame([]) would raise KeyError.
    columns = ['tomo_id', 'Motor axis 0', 'Motor axis 1', 'Motor axis 2']
    df = pd.DataFrame(results, columns=columns)
    df.to_csv(submission_path, index=False)
    print("\nSubmission file created:")
    print(df.head())
    return df

# Script entry point: build the submission file and report the total
# wall-clock time of the run.
if __name__ == "__main__":
    start_time = time.time()
    generate_submission()
    end_time = time.time()
    print(f"\nTotal processing time: {end_time - start_time:.2f} seconds")

this_dir: /kaggle/working
Loading model from: /kaggle/input/mhaf-m-public/other/default/1/36_data_96_663.pt


  ckpt = torch.load(file, map_location="cpu")


Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch mod

Processing Tomograms:   0%|          | 0/3 [00:00<?, ?it/s]

Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch mod

  with torch.cuda.amp.autocast(enabled=(device == 'cuda')):


Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock
Switch model to UniRepLKNetBlock

Submission file created:
       tomo_id  Motor axis 0  Motor axis 1  Motor axis 2
0  tomo_003acc            -1            -1            -1
1  tomo_00e047           168           546           602
2  tomo_01a877           146           638           285

Total processing time: 74.85 seconds


In [None]:
# !cp -r /kaggle/input/mhafyolo/pytorch/default/1/MHAF-YOLO-main /kaggle/working/

# model_path = "/kaggle/input/mhaf-m-public/other/default/1/36_data_96_663.pt"
# CONFIDENCE_THRESHOLD = 0.35
# MAX_DETECTIONS_PER_TOMO = 1
# NMS_IOU_THRESHOLD = 0.2
# CONCENTRATION = 0.5
# BATCH_SIZE = 8 

# import os
# from pathlib import Path

# current_dir = Path.cwd()
# print("this_dir:", current_dir)

# target_dir = Path("/kaggle/working/MHAF-YOLO-main") 
# os.chdir(target_dir)  

# import os
# import numpy as np
# import pandas as pd
# from PIL import Image
# import torch
# import cv2
# from tqdm.notebook import tqdm
# from ultralytics import YOLOv10
# import threading
# import time
# from contextlib import nullcontext
# from concurrent.futures import ThreadPoolExecutor

# from pathlib import Path

# from torch.utils.data import Dataset, DataLoader
# from torch.utils.data import DataLoader, TensorDataset



# np.random.seed(42)
# torch.manual_seed(42)


# data_path = "/kaggle/input/byu-locating-bacterial-flagellar-motors-2025/"
# test_dir = os.path.join(data_path, "test")
# submission_path = "/kaggle/working/submission.csv"

# device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# if device.startswith('cuda'):
#     torch.backends.cudnn.benchmark = True
#     torch.backends.cudnn.deterministic = False
#     torch.backends.cuda.matmul.allow_tf32 = True
#     torch.backends.cudnn.allow_tf32 = True

# # AMP scaler for inference (no scaling)
# scaler = torch.cuda.amp.GradScaler(enabled=False)

# class GPUProfiler:
#     def __init__(self, name):
#         self.name = name
#         self.start_time = None
#     def __enter__(self):
#         if torch.cuda.is_available(): torch.cuda.synchronize()
#         self.start_time = time.time()
#     def __exit__(self, *args):
#         if torch.cuda.is_available(): torch.cuda.synchronize()
#         elapsed = time.time() - self.start_time
#         # print(f"[PROFILE] {self.name}: {elapsed:.3f}s")


# def perform_3d_nms(detections, iou_threshold):
#     """
#     Perform 3D Non-Maximum Suppression on detections to merge nearby motors
#     """
#     if not detections:
#         return []
    
#     # Sort by confidence (highest first)
#     detections = sorted(detections, key=lambda x: x['confidence'], reverse=True)
    
#     # List to store final detections after NMS
#     final_detections = []
    
#     # Define 3D distance function
#     def distance_3d(d1, d2):
#         return np.sqrt((d1['z'] - d2['z'])**2 + 
#                        (d1['y'] - d2['y'])**2 + 
#                        (d1['x'] - d2['x'])**2)
    
#     # Maximum distance threshold (based on box size and slice gap)
#     box_size = 24  # Same as annotation box size
#     distance_threshold = box_size * iou_threshold
    
#     # Process each detection
#     while detections:
#         # Take the detection with highest confidence
#         best_detection = detections.pop(0)
#         final_detections.append(best_detection)
#         return final_detections
#     #     # Filter out detections that are too close to the best detection
#         # detections = [d for d in detections if distance_3d(d, best_detection) > distance_threshold]
    
#     return final_detections


# class TomogramSliceDataset(Dataset):
#     def __init__(self, paths, target_size=(960,960)):
#         self.paths = paths
#         self.target_size = target_size
#     def __len__(self):
#         return len(self.paths)
#     def __getitem__(self, idx):
#         p = self.paths[idx]
#         img = cv2.imread(p)
#         if img is None:
#             img = np.array(Image.open(p))
#         # Resize to target size to limit GPU memory usage
#         # img = cv2.resize(img, self.target_size, interpolation=cv2.INTER_AREA)
#         # img = torch.from_numpy(img)
#         return idx, img

#     @staticmethod
#     def collate_fn(batch):
#         paths, imgs = zip(*batch)
#         return list(paths), list(imgs)



# @torch.no_grad()
# def process_tomogram(tomo_id, model, total, idx):
#     tomo_path = os.path.join(test_dir, tomo_id)
#     files = sorted([f for f in os.listdir(tomo_path) if f.endswith('.jpg')])
#     sel = np.linspace(0, len(files)-1, int(len(files)*CONCENTRATION)).round().astype(int)
#     files = [files[i] for i in sel]
#     paths = [os.path.join(tomo_path, f) for f in files]
#     slice_nums = [int(f.split('_')[1].split('.')[0]) for f in files]

#     ds = TomogramSliceDataset(paths)
#     loader = DataLoader(ds, batch_size=32, shuffle=False,
#                         num_workers=4, pin_memory=True,
#                         collate_fn=TomogramSliceDataset.collate_fn)

#     all_dets = []
#     # volm = []
#     for batch_zidx, imgs in loader:
#         with GPUProfiler(f"Inference {len(batch_zidx)} slices"):
#             with torch.cuda.amp.autocast():
#                 preds = model(imgs, verbose=False)
        
#         # imgs = np.concatenate(imgs, axis=0)
#         # (B,H, W) = im/gs.shape
#         # volm.append(imgs.reshape(-1, H, W))
#         for z, result in zip(batch_zidx, preds):
#             z = slice_nums.pop(0)
#             if len(result.boxes) > 0 and result.boxes.conf[0] >= CONFIDENCE_THRESHOLD :
#                 x1,y1,x2,y2 = result.boxes.xyxy[0].cpu().numpy()
#                 all_dets.append({
#                             'z': z,
#                             'y': int(round((y1+y2)/2)),
#                             'x': int(round((x1+x2)/2)),
#                             'confidence': float(result.boxes.conf[0])
#                         })
           
#     # if device.startswith('cuda'): torch.cuda.synchronize()
#     final = perform_3d_nms(all_dets, NMS_IOU_THRESHOLD)
#     # final.sort(key=lambda x: x['confidence'], reverse=True)
#     if not final:
#         return {'tomo_id': tomo_id, 'Motor axis 0': -1,
#                 'Motor axis 1': -1, 'Motor axis 2': -1}
        
#     # if len(all_dets)<3:
#     #     return {'tomo_id': tomo_id, 'Motor axis 0': -1,
#     #             'Motor axis 1': -1, 'Motor axis 2': -1}
        
#     best = final[0]
#     return {'tomo_id': tomo_id,
#             'Motor axis 0': best['z'],
#             'Motor axis 1': best['y'],
#             'Motor axis 2': best['x']}


# def generate_submission():
#     tomos = sorted([d for d in os.listdir(test_dir) if os.path.isdir(os.path.join(test_dir, d))])
#     model = YOLOv10(model_path)
#     model.to(device)
#     if device.startswith('cuda'):
#         model.fuse()
#         if torch.cuda.get_device_capability(0)[0] >= 7:
#             model.model.half()

#     results = []
#     for idx, tomo in enumerate(tqdm(tomos, desc="Tomo loop"), 1):
#         torch.cuda.empty_cache()
#         res = process_tomogram(tomo, model, len(tomos), idx)
#         results.append(res)

#     df = pd.DataFrame(results)[['tomo_id','Motor axis 0','Motor axis 1','Motor axis 2']]
#     df.to_csv(submission_path, index=False)
#     print(df.head())
#     return df

# if __name__ == "__main__":
#     start = time.time()
#     generate_submission()
#     print(f"Total time: {time.time()-start:.2f}s")
