In [None]:
import os

import leafmap
import numpy as np
import pandas as pd
import rasterio
from samgeo import SamGeo, tms_to_geotiff, get_basemaps

# SAM checkpoint location: ~/segment-anything-services/model-weights
home_dir = os.path.expanduser("~")
out_dir = os.path.join(home_dir, "segment-anything-services/model-weights")
checkpoint = os.path.join(out_dir, "sam_vit_h_4b8939.pth")

# samgeo's SamGeo wrapper configured for the ViT-H Segment Anything model.
sam = SamGeo(model_type="vit_h", checkpoint=checkpoint, sam_kwargs=None)

# WorldStrat scene metadata; sample_rows() groups it by 'IPCC Class'.
worldstrat_df = pd.read_csv("../data/metadata.csv")

def sample_rows(df, n=5, random_state=None):
    """Draw up to ``n`` random rows from every 'IPCC Class' group.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing an ``'IPCC Class'`` column. Rows whose class is
        missing are dropped, as ``groupby`` does by default.
    n : int, default 5
        Maximum number of rows per class. A class with fewer than ``n`` rows
        contributes all of its rows instead of raising ``ValueError``.
    random_state : int, numpy Generator/RandomState, or None, default None
        Forwarded to ``DataFrame.sample``; set it for a reproducible sample.

    Returns
    -------
    pandas.DataFrame
        The sampled rows, ordered by class (``groupby`` sorts its keys), with
        a fresh ``0..k-1`` index.
    """
    # Iterate the groups explicitly instead of using groupby.apply: apply over
    # the grouping column is deprecated in recent pandas, and iteration keeps
    # the 'IPCC Class' column in every piece.
    pieces = [
        group.sample(n=min(n, len(group)), random_state=random_state)
        for _, group in df.groupby('IPCC Class')
    ]
    if not pieces:
        # pd.concat([]) raises, so return an empty frame with df's columns.
        return df.iloc[:0].reset_index(drop=True)
    return pd.concat(pieces, ignore_index=True)


def add_image_paths(df):
    """Attach HR and LR locations for every scene ID in ``df``.

    ``HR_path`` is the scene's ``<ID>_rgb.png`` file. ``LR_path`` is the
    scene's ``L2A/`` directory rather than a single file, because the LR data
    contains several revisits. ``df`` is modified in place and also returned.
    """
    def _paths_from(template):
        # Fill the {scene_id} placeholders of `template` once per ID.
        return df['ID'].apply(lambda scene_id: template.format(scene_id=scene_id))

    df['HR_path'] = _paths_from("../data/HR_Landcover/{scene_id}/{scene_id}_rgb.png")
    df['LR_path'] = _paths_from("../data/LR_Landcover/{scene_id}/L2A/")
    return df

In [None]:
# Sample scenes per IPCC class, then normalise the IDs *before* deriving file
# paths. generate_masks builds the LR file name from `ID`, so HR_path/LR_path
# must come from the same normalised ID or the directory and file name disagree.
# NOTE(review): assumes on-disk scene folders use the underscore form — confirm.
df = sample_rows(worldstrat_df)
df = df.rename(columns={"Unnamed: 0": "ID"})
df['ID'] = df['ID'].str.replace(' ', '_')
df = add_image_paths(df)

In [None]:
def convert_tif_dtype(file_path):
    """Write a uint8 copy of a float32 raster, min-max rescaled to 0-255.

    Parameters
    ----------
    file_path : str
        Path to a raster readable by ``rasterio.open``.

    Returns
    -------
    str
        Path of the ``<name>_uint8<ext>`` file written next to the input when
        the input is float32; otherwise ``file_path`` itself (nothing written).

    Notes
    -----
    A single min/max is taken over all bands, ignoring NaNs. NaN pixels, and
    every pixel of a constant-valued image, become 0.
    """
    # Read everything up front so the source dataset is closed before writing.
    with rasterio.open(file_path) as src:
        image = src.read()  # 3D array: (bands, height, width)
        meta = src.meta.copy()

    if image.dtype != np.float32:
        return file_path

    lo = np.nanmin(image)
    hi = np.nanmax(image)
    span = hi - lo
    if span > 0:
        scaled = (image - lo) / span * 255
    else:
        # Constant (or all-NaN) image: the old formula divided by zero and cast
        # NaN to uint8, which produces arbitrary values.
        scaled = np.zeros_like(image)
    image_u8 = np.clip(np.nan_to_num(scaled, nan=0.0), 0, 255).astype(np.uint8)

    # The float nodata value (often NaN) no longer identifies any pixel after
    # rescaling, and rasterio rejects nodata values outside the uint8 range.
    meta.update(dtype=rasterio.uint8, nodata=None)

    file_name, file_ext = os.path.splitext(os.path.basename(file_path))
    out_path = os.path.join(os.path.dirname(file_path), f"{file_name}_uint8{file_ext}")
    with rasterio.open(out_path, 'w', **meta) as dst:
        dst.write(image_u8)

    # Return the converted file so callers can use it directly.
    return out_path

import os

In [None]:
# Specify your root directory
root_dir = '../data/LR_Landcover'

# Traverse directory tree
for dir_name, subdir_list, file_list in os.walk(root_dir):
    for file_name in file_list:
        # Check if the file is a tiff file
        if file_name.endswith('.tiff'):
            # Construct the full file path
            file_path = os.path.join(dir_name, file_name)
            # Convert the tiff file if needed
            convert_tif_dtype(file_path)

In [None]:
def generate_masks(df, sam, out_dir, revisit_num=1,
                   csv_path='sampled_dataframe_with_zeroshot_masks.csv'):
    """Run SAM mask generation on each row's HR image and one LR revisit.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``ID``, ``HR_path`` and ``LR_path`` columns (as produced by
        ``add_image_paths``). Modified in place: ``HR_mask_path`` and
        ``LR_mask_path`` columns are added.
    sam : object
        Mask generator exposing ``generate(source, output, **kwargs)``, e.g.
        the ``SamGeo`` instance built at the top of the notebook.
    out_dir : str
        Masks go to ``<out_dir>/HR`` and ``<out_dir>/LR`` (created if missing).
    revisit_num : int, default 1
        LR revisit to segment: reads
        ``<LR_path>/<ID>-<revisit_num>-L2A_data.tiff``.
    csv_path : str or None, default 'sampled_dataframe_with_zeroshot_masks.csv'
        Where to save the updated frame (relative to the working directory
        unless absolute). ``None`` skips saving.

    Returns
    -------
    pandas.DataFrame
        ``df`` itself, with the two mask-path columns filled in.
    """
    hr_out_dir = os.path.join(out_dir, 'HR')
    lr_out_dir = os.path.join(out_dir, 'LR')
    os.makedirs(hr_out_dir, exist_ok=True)
    os.makedirs(lr_out_dir, exist_ok=True)

    # Same SAM settings for both resolutions.
    generate_kwargs = dict(batch=True, foreground=True, erosion_kernel=(3, 3),
                           mask_multiplier=255, unique=True)

    hr_mask_paths = []
    lr_mask_paths = []
    for scene_id, hr_image_path, lr_dir in zip(df['ID'], df['HR_path'], df['LR_path']):
        # os.path.join avoids the doubled "//" the old f-string produced,
        # since LR_path already ends with a separator.
        # NOTE(review): convert_tif_dtype writes *_uint8.tiff copies, but this
        # reads the original -L2A_data.tiff — confirm which file SAM should see.
        lr_image_path = os.path.join(lr_dir, f"{scene_id}-{revisit_num}-L2A_data.tiff")

        # Mask names are derived from the image file names.
        hr_image_id = os.path.splitext(os.path.basename(hr_image_path))[0]
        lr_image_id = os.path.splitext(os.path.basename(lr_image_path))[0]
        hr_mask_path = os.path.join(hr_out_dir, f"{hr_image_id}_mask.tif")
        lr_mask_path = os.path.join(lr_out_dir, f"{lr_image_id}_mask.tif")

        sam.generate(hr_image_path, hr_mask_path, **generate_kwargs)
        sam.generate(lr_image_path, lr_mask_path, **generate_kwargs)

        hr_mask_paths.append(hr_mask_path)
        lr_mask_paths.append(lr_mask_path)

    # Positional assignment is correct even if df's index has duplicate labels
    # (df.loc[idx, ...] would write to every row sharing that label).
    df['HR_mask_path'] = hr_mask_paths
    df['LR_mask_path'] = lr_mask_paths

    if csv_path is not None:
        df.to_csv(csv_path, index=False)
    return df


import rasterio

# Spot-check one LR revisit. The context manager closes the dataset, which the
# bare rasterio.open(...).read() left open.
with rasterio.open("../data/LR_Landcover/UNHCR-MMRs035448/L2A/UNHCR-MMRs035448-5-L2A_data.tiff") as src:
    lr_sample = src.read()
print(lr_sample.dtype, lr_sample.shape)

# The mask generator defined above is `generate_masks`; the previous call used
# the undefined name `generate_masks_zero_shot` and raised NameError.
generate_masks(df, sam, out_dir="sample_1")

In [None]:
# Single-scene smoke test: segment one HR image with mask.tif as the output.
image = "../data/HR_Landcover/Landcover-1000781/Landcover-1000781_rgb.png"
mask = "mask.tif"
sam.generate(
    image,
    mask,
    batch=True,
    foreground=True,
    erosion_kernel=(3, 3),
    mask_multiplier=255,
    unique=True,
)

In [None]:
# Compare the input scene with the mask generated from it in the previous cell.
m = leafmap.image_comparison(
    image,
    mask,  # reuse the variable rather than repeating the "mask.tif" literal
    label1="Satellite Image",
    label2="Image Segmentation",
    show_labels=True,
)