In [32]:
import os
import h5py
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import maximum_filter


class KPExtractor:
    """Extract keypoint locations from a SuperPoint heatmap stored in an HDF5 file.

    Config keys (all optional):
        threshold_type: stored for reference only; 'nms' (non-maximum
            suppression + threshold) is the only method implemented, and any
            other value currently behaves identically.
        threshold_value: a local maximum must be strictly greater than this
            value to be kept (default 0.015).
        max_keypoints: keep at most this many highest-scoring keypoints.
            ``None`` or any negative value (e.g. ``-1``) means "no limit".
        nms_size: side length of the square NMS window passed to
            ``scipy.ndimage.maximum_filter`` (default 3).
    """

    def __init__(self, config):
        self.threshold_type = config.get('threshold_type', 'nms')
        self.threshold_value = config.get('threshold_value', 0.015)
        # None or negative -> keep all keypoints
        self.max_keypoints = config.get('max_keypoints', None)
        self.nms_size = config.get('nms_size', 3)

    def extract_keypoints(self, file_path):
        """Load the 'superpoint_heatmap' dataset from ``file_path`` and return its keypoints.

        Returns:
            Integer array of shape (K, heatmap.ndim) from ``np.argwhere``
            ((row, col) pairs for a 2-D heatmap). When truncated by
            ``max_keypoints`` the rows are ordered by descending score;
            otherwise they are in ``np.argwhere`` (C) order.

        Raises:
            KeyError: if the file has no 'superpoint_heatmap' dataset.
        """
        with h5py.File(file_path, 'r') as f:
            if 'superpoint_heatmap' not in f:
                raise KeyError(f"'superpoint_heatmap' dataset not found in {file_path}")
            # [()] reads the whole dataset into memory, so the file can be
            # closed before the (CPU-only) NMS work below.
            heatmap = f['superpoint_heatmap'][()]
        return self._keypoints_from_heatmap(heatmap)

    def _keypoints_from_heatmap(self, heatmap):
        """Apply NMS + threshold to an in-memory heatmap and optionally keep the top-k."""
        heatmap = np.asarray(heatmap)
        # A pixel is a keypoint if it equals the maximum of its neighbourhood
        # (i.e. it is a local maximum) and is strictly above the threshold.
        is_peak = (heatmap == maximum_filter(heatmap, size=self.nms_size)) & (
            heatmap > self.threshold_value
        )
        keypoints = np.argwhere(is_peak)  # (row, col)
        # Boolean indexing and argwhere both traverse in C order, so scores[i]
        # belongs to keypoints[i].
        scores = heatmap[is_peak]

        limit = self.max_keypoints
        # A negative limit must not reach the slice below: [:-1] would silently
        # drop the weakest keypoint instead of meaning "unlimited".
        if limit is not None and 0 <= limit < len(keypoints):
            order = np.argsort(scores)[::-1][:limit]
            keypoints = keypoints[order]

        return keypoints

class GTGenerator:
    """Walk ``source_dir`` for heatmap files and save extracted keypoints as ``.npy``.

    The directory layout below ``source_dir`` is mirrored under ``output_dir``;
    each input ``<rel>/<name>.<file_format>`` becomes ``<rel>/<name>.npy``.

    Config keys:
        source_dir (required): root directory searched recursively.
        output_dir (required): root directory for the ``.npy`` outputs.
        file_format: extension to match, with or without a leading dot
            (default 'hdf5').
        output_format: stored for reference only; only NumPy ``.npy`` output
            is implemented.
        extractor_config: dict forwarded to ``KPExtractor``.
    """

    def __init__(self, config):
        self.source_dir = config['source_dir']
        self.output_dir = config['output_dir']
        self.file_format = config.get('file_format', 'hdf5')
        self.output_format = config.get('output_format', 'numpy')
        self.extractor = KPExtractor(config.get('extractor_config', {}))

    def get_hdf5_files(self):
        """Return the sorted paths of all files under ``source_dir`` with the ``file_format`` extension."""
        # Match the real extension ('.hdf5'), not just a trailing substring
        # ('notes_hdf5' must not match).
        suffix = '.' + self.file_format.lstrip('.')
        files = []
        for root, _, filenames in os.walk(self.source_dir):
            files.extend(os.path.join(root, name) for name in filenames if name.endswith(suffix))
        # os.walk order is filesystem-dependent; sort for reproducible runs.
        return sorted(files)

    def _output_path(self, file_path):
        """Map an input path under ``source_dir`` to its ``.npy`` path under ``output_dir``."""
        rel_path = os.path.relpath(file_path, self.source_dir)
        # Strip only the final extension. str.replace on the joined path would
        # also rewrite directory names containing '.hdf5' and would miss other
        # extensions such as '.h5'.
        stem, _ = os.path.splitext(rel_path)
        return os.path.join(self.output_dir, stem + '.npy')

    def save_keypoints(self, keypoints, file_path):
        """Save ``keypoints`` with ``np.save`` to the mirrored ``.npy`` path for ``file_path``."""
        output_file = self._output_path(file_path)
        output_dir = os.path.dirname(output_file)
        if output_dir:  # os.makedirs('') raises when the target is the cwd
            os.makedirs(output_dir, exist_ok=True)
        np.save(output_file, keypoints)

    def run(self):
        """Extract and save keypoints for every matching file.

        Files lacking a 'superpoint_heatmap' dataset (KeyError from the
        extractor) are reported and skipped; any other error propagates.
        """
        for file_path in self.get_hdf5_files():
            try:
                keypoints = self.extractor.extract_keypoints(file_path)
                self.save_keypoints(keypoints, file_path)
            except KeyError as e:
                print(f"Error processing {file_path}: {e}")

In [33]:
# Input/output roots. Override via environment variables rather than editing
# the notebook; the defaults reproduce the original run.
SOURCE_DIR = os.environ.get(
    'SUPERPOINT_GT_SOURCE_DIR',
    "/media/egoedeke/a9a96a4d-b323-489e-a833-13f4ade040c8/glue-factory/outputs/results/superpoint_gt",
)
OUTPUT_DIR = os.environ.get('SUPERPOINT_GT_OUTPUT_DIR', "data/superpoint_gt")

config = {
    'source_dir': SOURCE_DIR,
    'output_dir': OUTPUT_DIR,
    'file_format': 'hdf5',
    'output_format': 'numpy',
    'extractor_config': {
        'threshold_type': 'nms',
        'threshold_value': 0.015,
        # None keeps every keypoint; set an int (e.g. 1000) to keep the top-k.
        # Avoid -1: an extractor that slices with it as a bound ([:-1]) would
        # silently drop the weakest keypoint.
        'max_keypoints': None,
        'binary_heatmap': True,  # NOTE(review): not read by KPExtractor; currently has no effect
    },
}

generator = GTGenerator(config)
generator.run()
