In [1]:
import numpy as np
from astropy.io import fits
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
#import matplotlib
#matplotlib.use('Agg') # Comment out if want to show matplotlib plots
from astropy.visualization import ZScaleInterval
from scipy import ndimage
import os
import glob

In [2]:
def findBoundingBox(filePath):
    with fits.open(filePath) as hdul:
        # Assuming the image data is in the primary HDU
        img_data = hdul[0].data
    
        # Replace NaNs and infinite values with zeros
        img_data = np.nan_to_num(img_data, nan=0.0, posinf=0.0, neginf=0.0)
    
        # # Remove any NaN or infinite values
        # valid_data = flat_data[np.isfinite(flat_data)]
    
        # Use ZScale to determine a good range for display and analysis
        zscale = ZScaleInterval()
        z1, z2 = zscale.get_limits(img_data)
    
        # Clip the data to the ZScale range
        clipped_data = np.clip(img_data, z1, z2)
    
        normalized_data = (clipped_data - z1) / (z2 - z1)

    # Reshape for k-means
    pixels = normalized_data.reshape((-1, 1))
    
    # Perform k-means clustering
    kmeans = KMeans(n_clusters=10, random_state=42)
    kmeans.fit(pixels)

    # Get cluster centers and labels
    centers = kmeans.cluster_centers_.flatten()
    labels = kmeans.labels_
    
    # Reshape labels back to image shape
    labeled_image = labels.reshape(img_data.shape)
    
    # Find the darkest cluster (lowest center value)
    darkest_cluster = np.argmax(centers)
    
    # Create mask for the darkest cluster
    dark_mask = labeled_image == darkest_cluster
    
    # Label connected regions in the dark mask
    labeled_mask, num_features = ndimage.label(dark_mask)
    
    # Find the largest connected region
    if num_features > 0:
        region_sizes = ndimage.sum(dark_mask, labeled_mask, range(1, num_features + 1))
        largest_region_label = np.argmax(region_sizes) + 1
        largest_dark_region = labeled_mask == largest_region_label
    else:
        print("num features less 0")
        largest_dark_region = np.zeros_like(dark_mask)
    
    rows, cols = np.where(largest_dark_region)
    
    # Find bounding box of the largest dark region
    box_noise = 50
    
    if np.any(largest_dark_region):
        top = max(0, np.min(rows) - box_noise)
        bottom = min(img_data.shape[0] - 1, np.max(rows) + box_noise)
        left = max(0, np.min(cols) - box_noise)
        right = min(img_data.shape[1] - 1, np.max(cols) + box_noise)
        bbox = (int(top), int(bottom), int(left), int(right))
    else:
        print("Empty box")
        top, bottom, left, right = 0, 0, 0, 0
        bbox = (int(top), int(bottom), int(left), int(right))

    return img_data, bbox

In [3]:
def getImageData(filePath):
    with fits.open(filePath) as hdul:
        # Assuming the image data is in the primary HDU
        img_data = hdul[0].data
    
        # Replace NaNs and infinite values with zeros
        img_data = np.nan_to_num(img_data, nan=0.0, posinf=0.0, neginf=0.0)
    return img_data

In [4]:
def generateImage(img_data, bbox):
    # Draw bounding box on original image
    top, bottom, left, right = bbox
    rect = plt.Rectangle((left, top), right - left, bottom - top,
                         fill=False, edgecolor='red', linewidth=2)
    return rect

In [5]:
def displayLabeledImages(band1path, band2path, band3path, band4path):
    img_data3, bbox_band3 = findBoundingBox(band3path)
    img_data4, bbox_band4 = findBoundingBox(band4path)
    total_bbox = (min(bbox_band3[0], bbox_band4[0]), max(bbox_band3[1], bbox_band4[1]), min(bbox_band3[2], bbox_band4[2]), max(bbox_band3[3], bbox_band4[3]))
    img_data1 = getImageData(band1path)
    img_data2 = getImageData(band2path)
    
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(12, 6))
    zscale = ZScaleInterval()
    
    # Band 3 Image
    z1, z2 = zscale.get_limits(img_data3)
    im1 = ax1.imshow(img_data3, cmap='viridis', vmin=z1, vmax=z2)
    ax1.set_title('Band 3 Image')
    #plt.colorbar(im1, ax=ax1, label='Pixel Value')
    rect1 = generateImage(img_data3, bbox_band3)
    ax1.add_patch(rect1)
    ax1.set_xticks([])
    ax1.set_yticks([])
    
    # Band 4 Image
    z1, z2 = zscale.get_limits(img_data4)
    im2 = ax2.imshow(img_data4, cmap='viridis', vmin=z1, vmax=z2)
    ax2.set_title('Band 4 Image')
    #plt.colorbar(im2, ax=ax2, label='Pixel Value')
    rect2 = generateImage(img_data4, bbox_band4)
    ax2.add_patch(rect2)
    ax2.set_xticks([])
    ax2.set_yticks([])
    
    # Band 1 Image
    z1, z2 = zscale.get_limits(img_data1)
    im3 = ax3.imshow(img_data1, cmap='viridis', vmin=z1, vmax=z2)
    ax3.set_title('Band 1 Image')
    #plt.colorbar(im3, ax=ax3, label='Pixel Value')
    rect3 = generateImage(img_data1, total_bbox)
    ax3.add_patch(rect3)
    ax3.set_xticks([])
    ax3.set_yticks([])

    # Band 2 Image
    z1, z2 = zscale.get_limits(img_data2)
    im4 = ax4.imshow(img_data2, cmap='viridis', vmin=z1, vmax=z2)
    ax4.set_title('Band 2 Image')
    #plt.colorbar(im4, ax=ax4, label='Pixel Value')
    rect4 = generateImage(img_data2, total_bbox)
    ax4.add_patch(rect4)
    ax4.set_xticks([])
    ax4.set_yticks([])
    
    plt.tight_layout()
    plt.show()

In [6]:
#process_wise_dataset("300images-part1", "labeled_images")

In [7]:
def createLabeledImages(band1path, band2path, band3path, band4path):
    # Ensure the labeled_images directory exists
    os.makedirs("labeled_images", exist_ok=True)

    # Get bounding boxes for band 3 and 4
    img_data3, bbox_band3 = findBoundingBox(band3path)
    img_data4, bbox_band4 = findBoundingBox(band4path)
    
    # Calculate the total bounding box that encompasses both band 3 and band 4 bounding boxes
    total_bbox = (
        min(bbox_band3[0], bbox_band4[0]),
        max(bbox_band3[1], bbox_band4[1]),
        min(bbox_band3[2], bbox_band4[2]),
        max(bbox_band3[3], bbox_band4[3])
    )
    
    # Load data for bands 1 and 2
    img_data1 = getImageData(band1path)
    img_data2 = getImageData(band2path)
    
    # Process each band
    for img_data, band_path in zip([img_data1, img_data2], [band1path, band2path]):
        fig, ax = plt.subplots()
        ax.imshow(img_data, origin='lower', cmap='viridis')  # Display without brightness scaling

        # Add the bounding box as a rectangle
        rect = patches.Rectangle(
            (total_bbox[0], total_bbox[2]),  # Lower-left corner
            total_bbox[1] - total_bbox[0],  # Width
            total_bbox[3] - total_bbox[2],  # Height
            linewidth=2,
            edgecolor='red',
            facecolor='none'
        )
        ax.add_patch(rect)
        ax.set_xticks([])
        ax.set_yticks([])

        # Save the labeled image
        new_band = os.path.basename(band_path)
        new_path = f"labeled_images/{new_band[:-5]}_LABELED.png"
        plt.savefig(new_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

In [8]:
def process_wise_dataset(input_folder, output_folder):   
    # Get all unique object identifiers (the part after -w1, -w2, -w3, or -w4)
    object_identifiers = set()
    for file in os.listdir(input_folder):
        for band in range(1, 5):
            if f"-w{band}-" in file:
                identifier = file.split(f"-w{band}-", 1)[1].rsplit('.', 1)[0]
                object_identifiers.add(identifier)
                break
    
    # Process each object
    for identifier in object_identifiers:
        # Find paths for all 4 bands
        band_paths = []
        for band in range(1, 5):
            pattern = f"*-w{band}-{identifier}.fits"
            matching_files = glob.glob(os.path.join(input_folder, pattern))
            if matching_files:
                band_paths.append(matching_files[0])
            else:
                print(f"Warning: Band {band} not found for object with identifier {identifier}")
                break
        
        # If we found all 4 bands, process the object
        if len(band_paths) == 4:
            #print(f"Processing object: {indeitifer}")
            displayLabeledImages(*band_paths)
        else:
            print(f"Skipping object {identifier} due to missing bands")

In [9]:
#TESTING
band1 = "0000p605_ac51-w1-int-3_ra358.4953600000003_dec60.37638300000005_asec600.000.fits"
band2 = "0000p605_ac51-w2-int-3_ra358.4953600000003_dec60.37638300000005_asec600.000.fits"
band3 = "0000p605_ac51-w3-int-3_ra358.4953600000003_dec60.37638300000005_asec600.000.fits"
band4 = "0000p605_ac51-w4-int-3_ra358.4953600000003_dec60.37638300000005_asec600.000.fits"
img_data3, bbox_band3 = findBoundingBox(band3)
img_data4, bbox_band4 = findBoundingBox(band4)
# Calculate the total bounding box that encompasses both band 3 and band 4 bounding boxes
total_bbox = (
                    min(bbox_band3[0], bbox_band4[0]),
                    max(bbox_band3[1], bbox_band4[1]),
                    min(bbox_band3[2], bbox_band4[2]),
                    max(bbox_band3[3], bbox_band4[3])
)
print(total_bbox)

FileNotFoundError: [Errno 2] No such file or directory: '0000p605_ac51-w3-int-3_ra358.4953600000003_dec60.37638300000005_asec600.000.fits'

In [10]:
import os
from astropy.io import fits
import csv

# Define your function to process the FITS data (example placeholder)
def process_fits_data(data):
    # Example: Apply some operation to the data
    return data * 2  # Modify this function as needed

def process_fits_in_subfolders(root_folder, output_folder):

    # Loop through each sub-folder in the root folder
    for subfolder in os.listdir(root_folder):
        subfolder_path = os.path.join(root_folder, subfolder)
        files = [""]
        # Check if it is a directory
        if os.path.isdir(subfolder_path):
            files = []
            i = 0
            # Loop through each FITS file in the subfolder
            for filename in os.listdir(subfolder_path):
                if filename.endswith('.fits'):
                    filepath = os.path.join(subfolder_path, filename)
                    files.append(filepath)
                    i+=1
            if len(files) == 4:
                # Get bounding boxes for band 3 and 4
                band3 = ""
                band4 = ""
                for i in range(len(files)):
                    if ('-w3-' in files[i]):
                        band3 = files[i]
                    if ('-w4-' in files[i]):
                        band4 = files[i]
                img_data3, bbox_band3 = findBoundingBox(band3)
                img_data4, bbox_band4 = findBoundingBox(band4)
                # Calculate the total bounding box that encompasses both band 3 and band 4 bounding boxes
                total_bbox = (
                    min(bbox_band3[0], bbox_band4[0]),
                    max(bbox_band3[1], bbox_band4[1]),
                    min(bbox_band3[2], bbox_band4[2]),
                    max(bbox_band3[3], bbox_band4[3])
                )
                csv_filepath = os.path.join(subfolder_path, "bounding_box.csv")
                temp_bbox = [total_bbox[2], total_bbox[3], total_bbox[0], total_bbox[1]]
                # Write the bounding box to the CSV file
                with open(csv_filepath, mode="w", newline="") as csv_file:
                    writer = csv.writer(csv_file)
                    writer.writerow(["xmin", "xmax", "ymin", "ymax"])  # Header row
                    writer.writerow(temp_bbox)
                
            
# Example usage
root_folder = "L3a"  # Replace with your root folder path
output_folder = "test_out"  # Replace with your output folder path
process_fits_in_subfolders(root_folder, output_folder)