In [None]:
!pip install rasterio

In [None]:
import glob
import os
import sys
import time
import zipfile

import numpy as np
import pandas as pd
import pyproj
import rasterio
from rasterio.windows import Window
from rasterio.windows import bounds # Correct import
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

In [None]:
# Path to the Sentinel-2 .SAFE.zip archive to unpack.
zip_path = '/content/S2B_MSIL2A_20250405T042659_N0511_R133_T46QBM_20250405T063356.SAFE.zip'

# Directory that receives the archive's contents.
extract_path = '/content/extracted_data'

# Create the extraction directory. exist_ok=True keeps the cell re-runnable
# without a separate existence check.
os.makedirs(extract_path, exist_ok=True)

# Open the archive read-only and unpack every member under extract_path.
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

print(f"Extraction complete. Files are extracted to: {extract_path}")

In [None]:
# Read the kiln spreadsheet into a DataFrame. The labelling step later reads
# its "Longitude" and "Latitude" columns.
kiln_locations = pd.read_excel('/content/all_kilns_dec102024.xlsx')

In [None]:
# ---- Run configuration ----

# Unpacked Sentinel-2 .SAFE directory that the band rasters are globbed from.
safe_dir = '/content/extracted_data/S2B_MSIL2A_20250405T042659_N0511_R133_T46QBM_20250405T063356.SAFE'

# Patch geometry in pixels. With stride equal to patch_size the extraction
# loop steps exactly one patch at a time, so patches do not overlap.
patch_size = 64
stride = 64

# Divisor applied to raw band values before they are rescaled to 0-255.
SENTINEL_SCALE_FACTOR = 10000.0

# Seed passed to the undersampling, shuffling and splitting steps.
random_seed = 42

# Output folder and one .npz archive path per split.
output_dir = "/content/prepared_datasets_balanced_1_1"
train_output_file = os.path.join(output_dir, "kiln_train_balanced_1_1.npz")
val_output_file = os.path.join(output_dir, "kiln_val_balanced_1_1.npz")
test_output_file = os.path.join(output_dir, "kiln_test_balanced_1_1.npz")
os.makedirs(output_dir, exist_ok=True)

# Step 1: Extract Patches and Metadata
print("--- Starting Patch Extraction & Metadata Collection ---")
start_time = time.time()

try:
    # One glob per 10 m band, in the channel order used for every patch.
    # Taking [0] of an empty match list raises IndexError, handled below.
    band_paths = [
        glob.glob(os.path.join(safe_dir, 'GRANULE', '*', 'IMG_DATA', 'R10m', band_pattern))[0]
        for band_pattern in (
            '*_B04_10m.jp2',  # Red
            '*_B03_10m.jp2',  # Green
            '*_B02_10m.jp2',  # Blue
        )
    ]
    print("Located band files.")
except IndexError:
    # At least one band pattern matched nothing.
    print(f"Error: Could not find one or more band files in {safe_dir} structure.")
    sys.exit(1)
except Exception as e:
    # Any other failure while building the list is reported and aborts the run.
    print(f"An error occurred locating band files: {e}")
    sys.exit(1)

# Validate the kiln table first, so a missing DataFrame aborts the run before
# any raster handle has been opened.
if 'kiln_locations' not in locals() or not isinstance(kiln_locations, pd.DataFrame):
    print("Error: kiln_locations DataFrame not found or not loaded.")
    sys.exit(1)

# Open one dataset handle per band. The handles are left open on purpose:
# the patch-reading loop that follows reads from them and then closes them.
srcs = []
try:
    for path in band_paths:
        srcs.append(rasterio.open(path))
except Exception:
    # A later band failed to open: release the handles already acquired and
    # let the original error propagate unchanged.
    for src in srcs:
        src.close()
    raise

# Georeferencing and pixel dimensions are taken from the first band and used
# for every band in the patch loop.
transform = srcs[0].transform
crs = srcs[0].crs
width = srcs[0].width
height = srcs[0].height

# Transformer from the raster CRS to EPSG:4326. always_xy=True fixes the
# axis order to (x, y), i.e. (lon, lat) on the output side.
try:
    to_latlon = pyproj.Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
except Exception as e:
    print(f"Error initializing pyproj Transformer: {e}")
    # Do not leave the rasters open when aborting.
    for src in srcs:
        src.close()
    sys.exit(1)

# Parallel lists: entry k of the metadata list describes entry k of the data list.
all_patch_data_list = []
all_patch_metadata_list = []

print(f"Extracting patches...")
try:
    # Loop over the image by patch. Offsets are top-left pixel coordinates;
    # the range bounds skip partial patches at the right and bottom edges.
    for i, (row_offset, col_offset) in enumerate(
        (r, c)
        for r in range(0, height - patch_size + 1, stride)
        for c in range(0, width - patch_size + 1, stride)
    ):
        window = Window(col_off=col_offset, row_off=row_offset, width=patch_size, height=patch_size)
        patch_data = np.zeros((patch_size, patch_size, len(srcs)), dtype=np.float32)
        valid_bands = True
        for band_idx, src in enumerate(srcs):
            band_data = src.read(1, window=window).astype(np.float32)
            # A band that is empty or entirely zero disqualifies the whole patch.
            if band_data.size == 0 or np.all(band_data == 0):
                valid_bands = False
                break
            # Rescale raw values to 0-255 and clip anything outside that range.
            # NOTE(review): no additive offset is applied before scaling; confirm
            # against the product metadata whether this baseline requires one.
            scaled_band_data = (band_data / SENTINEL_SCALE_FACTOR) * 255.0
            scaled_band_data = np.clip(scaled_band_data, 0, 255)
            patch_data[..., band_idx] = scaled_band_data

        if not valid_bands or np.all(patch_data == 0):
            continue

        # Patch extent in the raster CRS, then two opposite corners in lon/lat.
        # NOTE(review): only (left, bottom) and (right, top) are reprojected, so
        # the lon/lat box is approximate if the raster grid is rotated relative
        # to lon/lat. Transformer.transform_bounds would give the full envelope.
        patch_bounds_coords = bounds(window, transform=transform)
        left, bottom, right, top = patch_bounds_coords
        lon_left, lat_bottom = to_latlon.transform(left, bottom)
        lon_right, lat_top = to_latlon.transform(right, top)

        all_patch_data_list.append(patch_data)
        all_patch_metadata_list.append({
            "lon_left": lon_left, "lat_bottom": lat_bottom,
            "lon_right": lon_right, "lat_top": lat_top
        })

        # Progress is reported by grid position; skipped patches never reach here.
        if (i + 1) % 5000 == 0:
            print(f"  Extracted {i + 1} potential patches...")
finally:
    # Release the raster handles even if a read or reprojection fails part-way.
    for src in srcs:
        src.close()

if not all_patch_data_list:
    print("Error: No valid patches were extracted. Exiting.")
    sys.exit(1)

print(f"Finished extraction. Created {len(all_patch_data_list)} valid patches.")
print(f"Patch extraction took: {time.time() - start_time:.2f} seconds.")

#  Step 2: Convert Patch Data to NumPy Array
print("\nConverting patch data to NumPy array...")
start_time = time.time()

# Build one array from the list of per-patch arrays; the patch index becomes
# the leading axis.
X_all = np.asarray(all_patch_data_list)

# Drop the list so the patch data is not held in memory twice.
del all_patch_data_list

print(f"Conversion took: {time.time() - start_time:.2f} seconds.")
print(f"X_all shape: {X_all.shape}")

#  Step 3: Create Metadata DataFrame & Assign Labels
print("\nCreating metadata DataFrame and assigning labels...")
start_time = time.time()
df_all_metadata = pd.DataFrame(all_patch_metadata_list)
# The default RangeIndex doubles as the patch's position in the patch array.
df_all_metadata['patch_index'] = df_all_metadata.index

# Pull the four bound columns out once as NumPy arrays, so each kiln costs a
# few array comparisons instead of building a filtered DataFrame.
patch_lon_left = df_all_metadata['lon_left'].to_numpy()
patch_lon_right = df_all_metadata['lon_right'].to_numpy()
patch_lat_bottom = df_all_metadata['lat_bottom'].to_numpy()
patch_lat_top = df_all_metadata['lat_top'].to_numpy()

# True for every patch whose lon/lat box contains at least one kiln.
is_kiln_patch = np.zeros(len(df_all_metadata), dtype=bool)

print(f"Matching {len(kiln_locations)} kiln locations to patches...")
for kiln_lon, kiln_lat in zip(kiln_locations["Longitude"], kiln_locations["Latitude"]):
    # Inclusive bounds on all four sides; a NaN coordinate compares False
    # everywhere and therefore matches no patch.
    is_kiln_patch |= (
        (patch_lon_left <= kiln_lon) & (patch_lon_right >= kiln_lon) &
        (patch_lat_bottom <= kiln_lat) & (patch_lat_top >= kiln_lat)
    )

# Keep the set of positive patch indices available under its original name.
kiln_patch_indices = set(np.where(is_kiln_patch)[0].tolist())

# 1 = patch contains a kiln, 0 = it does not.
y_all = is_kiln_patch.astype(np.int32)

print(f"Label assignment took: {time.time() - start_time:.2f} seconds.")
print(f"y_all shape: {y_all.shape}")
unique_all, counts_all = np.unique(y_all, return_counts=True)
print(f"Overall class distribution: {dict(zip(unique_all, counts_all))}")

#  Step 4: Balance the Entire Dataset (1:1 Undersampling)
print("\nPerforming 1:1 undersampling on the full dataset...")
start_time = time.time()

# Positions of the non-kiln (0) and kiln (1) patches.
indices_class_0 = np.where(y_all == 0)[0]
indices_class_1 = np.where(y_all == 1)[0]

n_class_0_all = len(indices_class_0)
n_class_1_all = len(indices_class_1)

if n_class_1_all == 0:
    print("Error: No kiln samples found in the dataset. Cannot balance.")
    sys.exit(1)

if n_class_0_all == 0:
    # Without this guard the 1:1 sample would be empty and the split step
    # would fail with a less helpful error.
    print("Error: No non-kiln samples found in the dataset. Cannot balance.")
    sys.exit(1)

n_majority_keep = n_class_1_all

# Seed once, before either branch draws. np.random.choice takes no `seed`
# argument, so reproducibility must come from seeding the generator itself.
np.random.seed(random_seed)

if n_class_0_all < n_majority_keep:
    print(f"Warning: Fewer non-kiln samples ({n_class_0_all}) than kiln samples ({n_class_1_all}). "
          f"Using all non-kiln samples and sampling kiln samples down.")
    n_majority_keep = n_class_0_all
    selected_minority_indices = np.random.choice(indices_class_1, size=n_majority_keep, replace=False)
    selected_majority_indices = indices_class_0
else:
    print(f"Keeping all {n_class_1_all} kiln samples.")
    print(f"Randomly selecting {n_majority_keep} non-kiln samples for a 1:1 balance.")
    selected_majority_indices = np.random.choice(indices_class_0, size=n_majority_keep, replace=False)
    selected_minority_indices = indices_class_1

# Kiln indices first, then non-kiln; the shuffle below removes this ordering.
balanced_indices = np.concatenate([selected_minority_indices, selected_majority_indices])

X_balanced = X_all[balanced_indices]
y_balanced = y_all[balanced_indices]

# Shuffle patches and labels together so they stay aligned.
X_balanced, y_balanced = shuffle(X_balanced, y_balanced, random_state=random_seed)

print(f"Balancing took: {time.time() - start_time:.2f} seconds.")
print(f"\nTotal BALANCED dataset shape: X={X_balanced.shape}, y={y_balanced.shape}")
unique_balanced, counts_balanced = np.unique(y_balanced, return_counts=True)
print(f"Total BALANCED dataset class distribution: {dict(zip(unique_balanced, counts_balanced))}")

# The unbalanced arrays are not used after this point.
del X_all
del y_all

#  Step 5: Split Balanced Data into Train/Validation/Test Sets
print("\nSplitting BALANCED data into Train (70%), Validation (15%), Test (15%)...")

# Stage 1: hold out 30% of the balanced data, stratified on the label.
X_train, X_temp, y_train, y_temp = train_test_split(
    X_balanced,
    y_balanced,
    test_size=0.30,
    stratify=y_balanced,
    random_state=random_seed,
)

# Stage 2: cut the held-out 30% in half, giving validation and test sets.
X_val, X_test, y_val, y_test = train_test_split(
    X_temp,
    y_temp,
    test_size=0.50,
    stratify=y_temp,
    random_state=random_seed,
)

# Per-split label tallies, computed up front and reported below.
unique_train, counts_train = np.unique(y_train, return_counts=True)
unique_val, counts_val = np.unique(y_val, return_counts=True)
unique_test, counts_test = np.unique(y_test, return_counts=True)

print("\nFinal Dataset Shapes:")
print(f"  Train: X={X_train.shape}, y={y_train.shape}")
print(f"  Val:   X={X_val.shape}, y={y_val.shape}")
print(f"  Test:  X={X_test.shape}, y={y_test.shape}")

# Stratified splits of a 1:1 dataset should each come out close to 50/50.
print(f"  Train distribution: {dict(zip(unique_train, counts_train))}")
print(f"  Val distribution:   {dict(zip(unique_val, counts_val))}")
print(f"  Test distribution:  {dict(zip(unique_test, counts_test))}")

#  Step 6: Save the Final Datasets
print("\n--- Saving Final Datasets ---")
try:
    # Each split goes to its own compressed .npz archive, with the arrays
    # stored under split-specific keys (X_train/y_train, X_val/y_val, ...).
    print(f"Saving BALANCED training set to {train_output_file}...")
    np.savez_compressed(train_output_file, X_train=X_train, y_train=y_train)

    print(f"Saving BALANCED validation set to {val_output_file}...")
    np.savez_compressed(val_output_file, X_val=X_val, y_val=y_val)

    print(f"Saving BALANCED test set to {test_output_file}...")
    np.savez_compressed(test_output_file, X_test=X_test, y_test=y_test)

    print("\nAll dataset splits saved successfully.")
except Exception as e:
    print(f"Error saving dataset files: {e}")
    # Re-raise so a failed save stops the cell instead of falling through to
    # the completion message below.
    raise

print("\nPreprocessing and dataset creation complete.")