In [None]:
import os
import logging
import rasterio
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from datetime import datetime
from pathlib import Path
from sklearn.model_selection import train_test_split

2024-11-15 18:19:23.096527: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-15 18:19:23.096563: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-15 18:19:23.096600: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


### Logging

In [None]:
# Configure root logging exactly once. `basicConfig` is a no-op when the root
# logger already has handlers, so a second call (or a cell re-run) would silently
# ignore the format; `force=True` replaces any existing root handlers instead.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
)
logger = logging.getLogger(__name__)

### Load Dataset

In [None]:
def load_geotiff(file_path):
    """Load a GeoTIFF with rasterio and split it into a normalized image and a binary mask.

    Bands 1-3 are used as the image and band 4 as the mask. Pixels equal to the
    file's nodata value (32767 when the file declares none) are set to 0 first.
    The image is divided by 255 and clipped to [0, 1] (NOTE(review): this assumes
    8-bit-range pixel values — confirm for the dataset). The mask is binarized
    to {0, 1}.

    Args:
        file_path: Path to the GeoTIFF (anything ``rasterio.open`` accepts).

    Returns:
        ``(image, mask)`` as float32 arrays of shape (H, W, 3) and (H, W, 1), or
        ``(None, None)`` when the file has fewer than 4 bands, the mask has no
        positive pixels, or reading/processing fails (the error is logged).
    """
    try:
        with rasterio.open(file_path) as src:
            # (bands, height, width) -> float32 for processing
            array = src.read().astype(np.float32)

            # `src.nodata or 32767` would wrongly discard a declared nodata of 0,
            # so test for None explicitly.
            nodata = src.nodata if src.nodata is not None else 32767

        # NaN never compares equal to itself, so a NaN nodata needs isnan().
        if np.isnan(nodata):
            array[np.isnan(array)] = 0
        else:
            array[array == nodata] = 0

        band_count = array.shape[0]
        if band_count < 4:
            logger.warning(
                f"Skipping {file_path}: expected at least 4 bands "
                f"(3 image + 1 mask), got {band_count}"
            )
            return None, None

        # (bands, height, width) -> (height, width, bands)
        array = np.transpose(array, (1, 2, 0))

        image = np.clip(array[:, :, :3] / 255.0, 0, 1)
        mask = (array[:, :, 3:4] > 0).astype(np.float32)  # keep mask as (H, W, 1)

        # Samples without any positive mask pixel are dropped.
        if not np.any(mask):
            return None, None

        return image, mask

    except Exception as e:
        logger.error(f"Error processing {file_path}: {str(e)}")
        return None, None

def prepare_dataset(data_dir, image_size=(128, 128)):
    """Load every GeoTIFF under ``data_dir`` into stacked image and mask arrays.

    Files ending in ``.tif``/``.tiff`` (any case) are found recursively and
    processed in sorted path order. Sorting makes the sample order, and hence any
    seeded train/validation split done downstream, independent of the
    filesystem's listing order. Files for which ``load_geotiff`` returns
    ``(None, None)`` are skipped.

    Args:
        data_dir: Root directory to search (str or Path).
        image_size: Target ``(height, width)`` passed to ``tf.image.resize``.

    Returns:
        ``(images_array, masks_array, file_paths)``: numpy arrays stacking the
        resized float32 images and masks along a new first axis, plus the list of
        ``Path`` objects for the files that were kept, in the same order.

    Raises:
        ValueError: if no file produced a valid image/mask pair.
    """
    data_dir = Path(data_dir)
    logger.info(f"Looking for TIF files in: {data_dir}")

    # Path.glob order is filesystem-dependent; sort for reproducibility.
    tif_files = sorted(
        list(data_dir.glob('**/*.[Tt][Ii][Ff]'))
        + list(data_dir.glob('**/*.[Tt][Ii][Ff][Ff]'))
    )
    logger.info(f"Found {len(tif_files)} TIF files")

    images, masks, file_paths = [], [], []
    for i, file_path in enumerate(tif_files):
        if i % 100 == 0:
            logger.info(f"Processing file {i+1}/{len(tif_files)}")

        img, mask = load_geotiff(str(file_path))
        if img is None or mask is None:
            continue

        # Default (bilinear) resize for the image; nearest-neighbour for the mask
        # so it only takes values already present in the mask.
        img = tf.cast(tf.image.resize(img, image_size), tf.float32)
        mask = tf.cast(tf.image.resize(mask, image_size, method='nearest'), tf.float32)

        images.append(img)
        masks.append(mask)
        file_paths.append(file_path)

    if not images:
        raise ValueError(f"No valid images found in {data_dir}")

    skipped = len(tif_files) - len(images)
    if skipped:
        logger.info(f"Skipped {skipped} files without a usable image/mask")

    images_array = np.array(images)
    masks_array = np.array(masks)

    logger.info(f"Final dataset shape: {images_array.shape}, masks shape: {masks_array.shape}")
    logger.info(f"Number of images with masks: {len(images)}")

    return images_array, masks_array, file_paths

def visualize_samples(images, file_paths, title, num_samples=16, save_path='/app/plots/'):
    """Show a grid of randomly chosen images, each titled with its file name and value range.

    Args:
        images: Indexable collection of images displayable by ``imshow``.
        file_paths: Paths aligned with ``images``; only the file name is shown.
        title: Text placed on the first line of every subplot title.
        num_samples: Maximum number of images to draw (capped at ``len(images)``).
        save_path: NOTE(review): currently unused — the figure is only shown,
            never saved. Kept for interface compatibility.
    """
    plt.close('all')

    # np.random.choice(..., replace=False) raises if asked for more than exist.
    n_shown = min(num_samples, len(images))
    if n_shown == 0:
        logger.warning("visualize_samples: no images to display")
        return

    rows = int(np.sqrt(n_shown))
    cols = int(np.ceil(n_shown / rows))

    fig = plt.figure(figsize=(20, 20))
    indices = np.random.choice(len(images), n_shown, replace=False)

    for i, idx in enumerate(indices):
        ax = fig.add_subplot(rows, cols, i + 1)

        img = images[idx]
        filename = Path(file_paths[idx]).name

        ax.imshow(img)
        ax.set_title(f'Title: {title}\n' +
                     f'File: {filename}\n' +
                     f'Range: [{np.min(img):.3f}, {np.max(img):.3f}]')
        ax.axis('off')

    fig.tight_layout()
    plt.show()

In [None]:
IMAGE_SIZE = (128, 128)
# Data root; override with the DATA_DIR environment variable instead of editing
# the notebook. Defaults to the original container path.
DATA_DIR = Path(os.environ.get('DATA_DIR', '/app/data'))

logger.info("Loading dataset...")
X, y, file_paths = prepare_dataset(DATA_DIR, IMAGE_SIZE)

### Train/Test Split

In [None]:
BATCH_SIZE = 32
EPOCHS = 100

# Hold out 20% of the samples for validation; the fixed random_state makes the
# split repeatable for a given sample order.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

logger.info(f"Training set: {X_train.shape}, {y_train.shape}")
logger.info(f"Validation set: {X_val.shape}, {y_val.shape}")