# **Land Type Classification using Sentinel-2 Satellite Images**

This project aims to automatically classify different types of land (such as farms, bodies of water, urban areas, roads, and forests) from satellite imagery provided by Sentinel-2, a mission run by the European Space Agency.

The outcome of the project will be a trained model that takes a satellite image and labels its land type automatically, which saves time, reduces manual effort, and can provide insights at a national or global scale.




## **1- Problem Definition & Business Understanding**






Land classification from satellite images is a challenging yet crucial task. With the increasing availability of satellite data, especially multispectral imagery from sources like Sentinel-2, it is now possible to automatically identify the type of land present in a given area (e.g., urban, agricultural, water, or desert).

Manual land classification is time-consuming, expensive, and prone to human error. Automating this process using AI models offers a scalable, efficient, and more accurate alternative.

**Business Goal**: To develop a deep learning solution that classifies satellite images into meaningful land types, reducing manual effort and enabling timely, informed decisions.

## **2- Data Collection & Description**

For this project, we are using the EuroSAT dataset, which is based on imagery from the Sentinel-2 satellite provided by the European Space Agency (ESA). The dataset is publicly available on platforms like Kaggle and consists of pre-labeled images representing different land cover classes.

The EuroSAT dataset contains:

**Image size**: 64×64 pixels

**Number of classes**: 10 land types:
*   Annual Crop
*   Forest
*   Herbaceous Vegetation
*   Highway
*   Industrial
*   Pasture
*   Permanent Crop
*   Residential
*   River
*   Sea/Lake


In [None]:
from google.colab import files
files.upload()
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d apollo2506/eurosat-dataset
!unzip eurosat-dataset.zip -d eurosat_data

Streaming output truncated to the last 5000 lines.
  inflating: eurosat_data/EuroSATallBands/River/River_1990.tif  
  inflating: eurosat_data/EuroSATallBands/River/River_1991.tif  
  inflating: eurosat_data/EuroSATallBands/River/River_1992.tif  
  inflating: eurosat_data/EuroSATallBands/River/River_1993.tif  
  inflating: eurosat_data/EuroSATallBands/River/River_1994.tif  
  inflating: eurosat_data/EuroSATallBands/River/River_1995.tif  
  inflating: eurosat_data/EuroSATallBands/River/River_1996.tif  
  inflating: eurosat_data/EuroSATallBands/River/River_1997.tif  
  inflating: eurosat_data/EuroSATallBands/River/River_1998.tif  
  inflating: eurosat_data/EuroSATallBands/River/River_1999.tif  
  inflating: eurosat_data/EuroSATallBands/River/River_2.tif  
  inflating: eurosat_data/EuroSATallBands/River/River_20.tif  
  inflating: eurosat_data/EuroSATallBands/River/River_200.tif  
  inflating: eurosat_data/EuroSATallBands/River/River_2000.tif  
  inflating: eurosat_data/EuroS

In [None]:
!pip install rasterio

Collecting rasterio
  Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Collecting cligj>=0.5 (from rasterio)
  Downloading cligj-0.7.2-py3-none-any.whl.metadata (5.0 kB)
Collecting click-plugins (from rasterio)
  Downloading click_plugins-1.1.1-py2.py3-none-any.whl.metadata (6.4 kB)
Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.2/22.2 MB[0m [31m76.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading cligj-0.7.2-py3-none-any.whl (7.1 kB)
Downloading affine-2.4.0-py3-none-any.whl (15 kB)
Downloading click_plugins-1.1.1-py2.py3-none-any.whl (7.5 kB)
Installing collected packages: cligj, click-plugins, affine, rasterio
Successfully installed affine-2.4.0 click-plugins-1.1.1 cligj-0.7.2 rasterio-1.4.3


In [None]:
# Standard Library
import os
from glob import glob

# Data Processing
import numpy as np
import pandas as pd

# Image Processing
from PIL import Image
import rasterio  # For geospatial raster data
import cv2  # OpenCV for computer vision
import albumentations as A  # Image augmentations
from albumentations.augmentations import transforms
from albumentations.core.composition import OneOf

# Data Visualization
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Machine Learning
from sklearn.model_selection import train_test_split  # Data splitting

# Utilities
from tqdm import tqdm  # Progress bars

In [None]:
# Path to EuroSAT dataset
DATA_DIR = "/content/eurosat_data/EuroSATallBands"

In [None]:
# List and sort class names (each class is a folder; skip the CSV/JSON metadata files)
classes = sorted(
    d for d in os.listdir(DATA_DIR)
    if os.path.isdir(os.path.join(DATA_DIR, d))
)
print("Classes:", classes)

Classes: ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake']


In [None]:
# Get all image paths with .tif extension inside class folders
image_paths = glob(os.path.join(DATA_DIR, "*/*.tif"))


# Create a DataFrame with image paths and their class labels
df = pd.DataFrame({
    "path": image_paths,
    "label": [os.path.basename(os.path.dirname(p)) for p in image_paths]
})

In [None]:
def random_images(df, n=1):
    """
    Returns n random images per class from the dataframe.
    """
    return df.groupby("label").sample(n=n)

### Sample Image's Spectral Bands


*   Each Sentinel-2 image contains 13 spectral bands capturing different wavelengths, including visible, near-infrared (NIR), and shortwave infrared (SWIR) regions. This rich spectral information enables the model to effectively distinguish between various land types based on their unique spectral signatures.
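As a quick reference for the bands mentioned above, the sketch below lists the nominal central wavelength and native ground resolution of each Sentinel-2 band. The dictionary name `SENTINEL2_BANDS` and the helper `bands_at_resolution` are illustrative and not part of the notebook; the wavelengths are rounded nominal values.

```python
# Nominal central wavelength (nm) and native resolution (m) per Sentinel-2 band.
# The names below are illustrative helpers, not part of the notebook.
SENTINEL2_BANDS = {
    "B01": (443, 60),   # Coastal Aerosol
    "B02": (490, 10),   # Blue
    "B03": (560, 10),   # Green
    "B04": (665, 10),   # Red
    "B05": (705, 20),   # Red Edge 1
    "B06": (740, 20),   # Red Edge 2
    "B07": (783, 20),   # Red Edge 3
    "B08": (842, 10),   # NIR
    "B8A": (865, 20),   # Narrow NIR
    "B09": (945, 60),   # Water Vapor
    "B10": (1375, 60),  # Cirrus
    "B11": (1610, 20),  # SWIR 1
    "B12": (2190, 20),  # SWIR 2
}

def bands_at_resolution(res_m):
    """Return the band codes whose native resolution equals res_m metres."""
    return [name for name, (_, res) in SENTINEL2_BANDS.items() if res == res_m]

ten_m = bands_at_resolution(10)
print(ten_m)
```

The four 10 m bands are the visible bands plus the broad NIR band, which is why they are the usual choice for RGB rendering and NDVI.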




In [None]:
band_names = [
    'B01 - Coastal Aerosol',
    'B02 - Blue',
    'B03 - Green',
    'B04 - Red',
    'B05 - Red Edge 1',
    'B06 - Red Edge 2',
    'B07 - Red Edge 3',
    'B08 - NIR',
    'B8A - Narrow NIR',
    'B09 - Water Vapor',
    'B10 - SWIR - Cirrus',
    'B11 - SWIR 1',
    'B12 - SWIR 2'
]


In [None]:
# Sample one random image per class with our utility function and inspect the first one
sample_df = random_images(df, n=1)
sample_path = sample_df['path'].iloc[0]

# Open the image using rasterio to inspect its band structure
with rasterio.open(sample_path) as src:
    # Print total number of spectral bands in the image
    print(f"Number of Bands: {src.count}")

    # Loop through each band and display its data type and shape
    print("Data For Each Band:")
    for i in range(1, src.count + 1):
        print(f"  Band {i}: dtype={src.dtypes[i - 1]}, shape={src.read(i).shape}")

Number of Bands: 13
Data For Each Band:
  Band 1: dtype=uint16, shape=(64, 64)
  Band 2: dtype=uint16, shape=(64, 64)
  Band 3: dtype=uint16, shape=(64, 64)
  Band 4: dtype=uint16, shape=(64, 64)
  Band 5: dtype=uint16, shape=(64, 64)
  Band 6: dtype=uint16, shape=(64, 64)
  Band 7: dtype=uint16, shape=(64, 64)
  Band 8: dtype=uint16, shape=(64, 64)
  Band 9: dtype=uint16, shape=(64, 64)
  Band 10: dtype=uint16, shape=(64, 64)
  Band 11: dtype=uint16, shape=(64, 64)
  Band 12: dtype=uint16, shape=(64, 64)
  Band 13: dtype=uint16, shape=(64, 64)


### Class Distribution


*   The dataset is fairly balanced, with most classes at around 9–11% of the images.

*   SeaLake has the highest share (~13%), which might slightly bias the model toward that class.

*   Pasture is the least represented (~7.25%), so it may need augmentation to avoid weak performance on that class.
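The shares quoted above can be computed directly from the label column. This is a minimal sketch, assuming a pandas Series of labels like `df['label']`; the helper name `class_shares` and the toy labels are made up for illustration and are not the real EuroSAT counts.

```python
import pandas as pd

def class_shares(labels: pd.Series) -> pd.Series:
    """Return each class's share of the dataset in percent, largest first."""
    return (labels.value_counts(normalize=True) * 100).round(2)

# Toy labels (hypothetical, not the real EuroSAT counts)
toy = pd.Series(["SeaLake"] * 5 + ["Forest"] * 3 + ["Pasture"] * 2)
shares = class_shares(toy)
print(shares)
```

On the real data, the same call would be `class_shares(df['label'])`.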



In [None]:
# Count the number of images per class
class_counts = df['label'].value_counts().reset_index()
class_counts.columns = ['Class', 'Count']

# Pie Chart
fig_pie = px.pie(class_counts,
                 names='Class',
                 values='Count',
                 title='Distribution of Images per Class',
                 hole=0.3)
fig_pie.show()

### Top Spectral Bands per Land Type
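Ranking the bands by mean pixel value shows a clear split between SeaLake and the other classes: most classes are led by the SWIR, NIR, and red-edge bands, while SeaLake is led by the Coastal Aerosol and visible bands. This suggests that selecting specific bands (for example SWIR and NIR for vegetation, Coastal/Blue for water) can enhance model accuracy and reduce unnecessary data dimensions.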

In [None]:
def band_stats(df, n_samples=20):
    """
    Returns the mean value of every band for n_samples random images per class.
    """
    records = []  # one row per sampled image
    sampled_df = random_images(df, n_samples)

    for _, row in sampled_df.iterrows():
        with rasterio.open(row['path']) as src:
            band_means = {}
            for i in range(1, src.count + 1):  # rasterio bands are 1-indexed
                band_name = band_names[i - 1]
                band_data = src.read(i)
                band_means[band_name] = band_data.mean()

            records.append({
                'label': row['label'],
                **band_means
            })

    return pd.DataFrame(records)

In [None]:
band_df = band_stats(df)
grouped = band_df.groupby("label").mean().T

top_bands_named = grouped.apply(lambda x: x.sort_values(ascending=False).index[:6])
top_bands_named

label,AnnualCrop,Forest,HerbaceousVegetation,Highway,Industrial,Pasture,PermanentCrop,Residential,River,SeaLake
0,B12 - SWIR 2,B12 - SWIR 2,B12 - SWIR 2,B12 - SWIR 2,B12 - SWIR 2,B12 - SWIR 2,B12 - SWIR 2,B12 - SWIR 2,B12 - SWIR 2,B01 - Coastal Aerosol
1,B07 - Red Edge 3,B07 - Red Edge 3,B08 - NIR,B07 - Red Edge 3,B07 - Red Edge 3,B07 - Red Edge 3,B07 - Red Edge 3,B07 - Red Edge 3,B07 - Red Edge 3,B02 - Blue
2,B08 - NIR,B08 - NIR,B07 - Red Edge 3,B08 - NIR,B08 - NIR,B08 - NIR,B08 - NIR,B08 - NIR,B08 - NIR,B03 - Green
3,B06 - Red Edge 2,B06 - Red Edge 2,B10 - SWIR - Cirrus,B06 - Red Edge 2,B10 - SWIR - Cirrus,B06 - Red Edge 2,B10 - SWIR - Cirrus,B06 - Red Edge 2,B06 - Red Edge 2,B04 - Red
4,B10 - SWIR - Cirrus,B10 - SWIR - Cirrus,B06 - Red Edge 2,B10 - SWIR - Cirrus,B06 - Red Edge 2,B10 - SWIR - Cirrus,B06 - Red Edge 2,B10 - SWIR - Cirrus,B10 - SWIR - Cirrus,B05 - Red Edge 1
5,B05 - Red Edge 1,B01 - Coastal Aerosol,B11 - SWIR 1,B01 - Coastal Aerosol,B01 - Coastal Aerosol,B01 - Coastal Aerosol,B11 - SWIR 1,B01 - Coastal Aerosol,B01 - Coastal Aerosol,B06 - Red Edge 2


## **3- Exploratory Data Analysis (EDA)**

### Spectral Signature Visualization per Land Type
To better understand the distinct spectral characteristics of various land cover types, we visualized a sample image from each class along with its average spectral signature across all bands. This helps identify how different land types (e.g., urban, forest, sea lake) reflect or absorb energy across the spectrum.

We used the natural RGB bands (Bands 4, 3, 2 – Red, Green, Blue) to render the images, and plotted their corresponding spectral signatures using the mean reflectance values across all Sentinel-2 bands.


In [None]:
def get_rgb_image(data):
    """
    Create a natural RGB image using Bands 4, 3, 2 (Red, Green, Blue).
    """
    rgb = np.dstack((data[3], data[2], data[1]))  # Bands are 0-indexed
    rgb = rgb.astype(np.float32)

    # Normalize between 0 and 1 using fixed scaling factor
    rgb /= 2750
    rgb = np.clip(rgb, 0, 1)

    return rgb

In [None]:
# Group the extracted band statistics by label
grouped = band_df.groupby('label')

# Iterate through each land type to display a sample image and its spectral signature
for landcover, group in grouped:
    image_path = df[df['label'] == landcover].sample(1)['path'].values[0]
    mean_signature = group[band_names].mean().values

    # Make The Image RGB
    with rasterio.open(image_path) as src:
        data = src.read()
        rgb_img = get_rgb_image(data)

    # Build the subplots: RGB image on the left, spectral signature on the right
    fig = make_subplots(rows=1, cols=2, subplot_titles=(f'{landcover}' ,'Spectral Signature'))

    fig.add_trace(go.Image(z=(rgb_img * 255).astype(np.uint8)),row=1, col=1)

    fig.add_trace(go.Scatter(
            x=band_names,
            y=mean_signature,
            mode='lines+markers',
            name=landcover,
            line=dict(width=2)),
        row=1, col=2
    )

    fig.update_layout(
        title=f'Spectral Signature Analysis - {landcover}',
        height=500,
        hovermode='x unified',
        showlegend=False
    )

    fig.update_xaxes(showticklabels=False, row=1, col=1)
    fig.update_yaxes(showticklabels=False, row=1, col=1)

    fig.show()


### Image Pixel Analysis
#### Key Patterns
1. **Water Identification**
   - Ultra-low values in NIR/SWIR bands
   - B08 < 500 and B12 < 300 = Strong water signal

2. **Vegetation Health**
   - Forests show highest NIR (B08: ~2500)
   - Crops show strong red-edge (B05: 1500-1700)

3. **Urban Detection**
   - Highest SWIR1 values (B11: 1500+)
   - Consistent high reflectance across all bands
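The water pattern above can be turned into a simple rule-based mask. This is only a sketch: the thresholds are copied from the list above, while the function name `water_mask` and the sample pixel values are hypothetical.

```python
import numpy as np

# Thresholds taken from the "Water Identification" pattern above
NIR_MAX = 500    # B08
SWIR2_MAX = 300  # B12

def water_mask(nir: np.ndarray, swir2: np.ndarray) -> np.ndarray:
    """Flag pixels whose NIR and SWIR 2 values are both below the thresholds."""
    return (nir < NIR_MAX) & (swir2 < SWIR2_MAX)

# Hypothetical 2x2 patch: left column water-like, right column land-like
nir = np.array([[120, 2400], [300, 1800]])
swir2 = np.array([[80, 1500], [250, 900]])
mask = water_mask(nir, swir2)
print(mask)
```

A rule like this is useful as a sanity check on the data, not as a replacement for the trained classifier.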

In [None]:
def pixel_data(df, num_samples=5, pixels_per_band=500):
    result=[]

    for label in df['label'].unique():
        class_data = df[df['label'] == label]
        samples = class_data.sample(num_samples)

        for i, row in samples.iterrows():
            with rasterio.open(row['path']) as img:
                bands = img.read()

                for band_num in range(bands.shape[0]):
                    pixels = bands[band_num].flatten()
                    sampled_pixels = np.random.choice(pixels, size=pixels_per_band)
                    band_name = band_names[band_num]

                    for val in sampled_pixels:
                        result.append({'label': label, 'band': band_name, 'value': float(val)})

    return pd.DataFrame(result)

In [None]:
df_pixels = pixel_data(df)

In [None]:
heatmap_data = df_pixels.groupby(['label', 'band'])['value'].mean().unstack()

plt.figure(figsize=(12, 6))
sns.heatmap(heatmap_data, annot=True, fmt=".1f", cmap='viridis')
plt.title('Mean Pixel Values by Band and Land Type')
plt.show()


## **4- Preprocessing and Feature Engineering**

### Selecting Top 6 Bands per Class
In this step, we:
1. Prioritize the bands needed for NDVI/NDBI calculations (Red, NIR, SWIR 1).
2. Complete the band set to exactly **6 bands** by selecting the top mean bands per class.
3. Save new images containing only these 6 bands for each class.
4. Store paths and labels in a new DataFrame.
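The selection rule in steps 1 and 2 can be sketched as a small standalone function. The helper name `pick_bands` is hypothetical; the ranking passed in is the AnnualCrop column of the top-bands table shown earlier.

```python
def pick_bands(mandatory, ranked, k=6):
    """Return the mandatory bands first, then top-ranked bands until k are chosen."""
    chosen = list(mandatory)
    for band in ranked:
        if len(chosen) >= k:
            break
        if band not in chosen:
            chosen.append(band)
    return chosen

mandatory = ["B04 - Red", "B08 - NIR", "B11 - SWIR 1"]
# Ranking copied from the AnnualCrop column of the top-bands table
ranked = ["B12 - SWIR 2", "B07 - Red Edge 3", "B08 - NIR",
          "B06 - Red Edge 2", "B10 - SWIR - Cirrus", "B05 - Red Edge 1"]
selected = pick_bands(mandatory, ranked)
print(selected)
```

Keeping the mandatory bands at the front means Red, NIR, and SWIR 1 always sit at positions 0, 1, and 2, which is convenient for any later step that reads them by index.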

In [None]:
OUTPUT_DIR = '/content/processed_images'

# Priority bands for NDVI/NDBI, always included
mandatory_bands = ['B04 - Red', 'B08 - NIR', 'B11 - SWIR 1']
label_to_band_indices = {}

for label in top_bands_named.columns:
    top_bands = top_bands_named[label]
    # Start with mandatory NDVI/NDBI bands
    final_bands = mandatory_bands.copy()

    # Add top bands until we reach 6 total
    for b in top_bands:
        if b not in final_bands:
            final_bands.append(b)
        if len(final_bands) == 6:
            break

    # Convert band names to indices, keeping the mandatory bands first so that
    # Red, NIR and SWIR 1 stay at positions 0, 1 and 2 (NDVI/NDBI rely on this order)
    band_idxs = [band_names.index(band) for band in final_bands]
    label_to_band_indices[label] = band_idxs

# Save new processed images with only 6 bands
os.makedirs(OUTPUT_DIR, exist_ok=True)
new_paths = []

print('Saving new processed 6-band images')

for i, row in tqdm(df.iterrows(), total=len(df)):
    label = row['label']
    path = row['path']
    band_idxs = label_to_band_indices[label]

    with rasterio.open(path) as src:
        # Read only the selected bands
        bands = [src.read(i + 1) for i in band_idxs]
        stacked = np.stack(bands)

        # Create class-specific subfolder
        class_dir = os.path.join(OUTPUT_DIR, label)
        os.makedirs(class_dir, exist_ok=True)

        filename = os.path.basename(path)
        new_path = os.path.join(class_dir, filename)

        # Update profile to match 6-band format
        profile = src.profile
        profile.update(count=len(band_idxs))

        with rasterio.open(new_path, 'w', **profile) as dst:
            # Write all selected bands at once (array shape: bands, height, width)
            dst.write(stacked)

        new_paths.append((new_path, label))

# Final DataFrame
df_processed = pd.DataFrame(new_paths, columns=['path', 'label'])

print('all images saved with 6 bands')
df_processed.head()

Saving new processed 6-band images


100%|██████████| 27597/27597 [03:39<00:00, 125.76it/s]

all images saved with 6 bands





Unnamed: 0,path,label
0,/content/processed_images/Highway/Highway_1017...,Highway
1,/content/processed_images/Highway/Highway_861.tif,Highway
2,/content/processed_images/Highway/Highway_31.tif,Highway
3,/content/processed_images/Highway/Highway_708.tif,Highway
4,/content/processed_images/Highway/Highway_2083...,Highway


### Min-Max Scaling
Each band in the processed 6-band images is normalized individually using **Min-Max Scaling**. This maps each band’s pixel values to the 0–1 range, which generally helps model training and puts all bands on a common scale before indices like NDVI or NDBI are computed.
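As a small worked example of the formula `(x - min) / (max - min)`, the sketch below scales a single toy band. The function name `min_max_scale` and the pixel values are illustrative; mapping a constant band to zeros is a choice made in this sketch.

```python
import numpy as np

def min_max_scale(band: np.ndarray) -> np.ndarray:
    """Scale one band to [0, 1]; a constant band maps to all zeros."""
    band = band.astype(np.float32)
    lo, hi = band.min(), band.max()
    if hi == lo:
        return np.zeros_like(band)
    return (band - lo) / (hi - lo)

# Toy 2x2 band with hypothetical raw values
band = np.array([[1000, 1500], [2000, 3000]], dtype=np.uint16)
scaled = min_max_scale(band)
print(scaled)
```

The smallest pixel becomes 0, the largest becomes 1, and every other value falls proportionally in between.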

In [None]:
SCALED_DIR = '/content/processed_images_scaled'
os.makedirs(SCALED_DIR, exist_ok=True)

scaled_paths = []

print('Saving scaled 6-band images')

for i, row in tqdm(df_processed.iterrows(), total=len(df_processed)):
    image_path = row['path']
    label = row['label']

    with rasterio.open(image_path) as src:
        data = src.read().astype(np.float32)

        # Min-Max scaling per band
        for b in range(data.shape[0]):
            band = data[b]
            min_val = band.min()
            max_val = band.max()
            if max_val > min_val:
                data[b] = (band - min_val) / (max_val - min_val)
            else:
                # Constant band: map to 0 so the output stays within [0, 1]
                data[b] = 0.0

        # Save scaled image
        class_dir = os.path.join(SCALED_DIR, label)
        os.makedirs(class_dir, exist_ok=True)

        filename = os.path.basename(image_path)
        scaled_path = os.path.join(class_dir, filename)

        profile = src.profile
        profile.update(dtype=rasterio.float32)

        with rasterio.open(scaled_path, 'w', **profile) as dst:
            dst.write(data)

        scaled_paths.append((scaled_path, label))

# Final scaled dataframe
final_df = pd.DataFrame(scaled_paths, columns=['path', 'label'])

print('All scaled 6-band images saved')
final_df.head()


Saving scaled 6-band images


100%|██████████| 27597/27597 [03:05<00:00, 148.55it/s]


All scaled 6-band images saved


Unnamed: 0,path,label
0,/content/processed_images_scaled/Highway/Highw...,Highway
1,/content/processed_images_scaled/Highway/Highw...,Highway
2,/content/processed_images_scaled/Highway/Highw...,Highway
3,/content/processed_images_scaled/Highway/Highw...,Highway
4,/content/processed_images_scaled/Highway/Highw...,Highway


### NDVI
*  We create new images that show how much vegetation (like trees, grass, and crops) is present in each satellite image.

*  These are called **NDVI images** (Normalized Difference Vegetation Index). NDVI highlights green areas by comparing near-infrared light, which healthy plants reflect strongly, with red light, which they absorb.
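To make the index concrete, here is a minimal sketch of the NDVI formula on three made-up pixels. The function name `ndvi` and the sample values are assumptions for illustration only.

```python
import numpy as np

def ndvi(nir: np.ndarray, red: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """NDVI = (NIR - Red) / (NIR + Red); eps guards against division by zero."""
    nir = nir.astype(np.float32)
    red = red.astype(np.float32)
    return np.clip((nir - red) / (nir + red + eps), -1.0, 1.0)

# Hypothetical pixels: dense vegetation, bare soil, water
nir = np.array([3000, 1500, 200])
red = np.array([500, 1500, 400])
values = ndvi(nir, red)
print(values.round(3))
```

Vegetation gives a clearly positive value, equal NIR and red give a value near zero, and water gives a negative value.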


In [None]:
NDVI_dir='/content/ndvi_scaled_images'

def ndvi_images(df):

    os.makedirs(NDVI_dir, exist_ok=True)
    ndvi_paths = []

    print('Generating NDVI images')

    for i, row in tqdm(df.iterrows(), total=len(df)):
        image_path = row['path']
        label = row['label']

        with rasterio.open(image_path) as src:
            data = src.read().astype(np.float32)

            # NDVI = (NIR - Red) / (NIR + Red)
            red = data[0]
            nir = data[1]

            ndvi = (nir - red) / (nir + red + 1e-6)
            ndvi = np.clip(ndvi, -1, 1)

            # Save NDVI image
            ndvi_path = os.path.join(NDVI_dir, os.path.basename(image_path))
            profile = src.profile
            profile.update(count=1, dtype='float32')

            with rasterio.open(ndvi_path, 'w', **profile) as dst:
                dst.write(ndvi, 1)

            ndvi_paths.append(ndvi_path)

    df['ndvi_path'] = ndvi_paths
    print('NDVI images generated')

    return df

In [None]:
final_df = ndvi_images(final_df)
final_df.head()

Generating NDVI images


100%|██████████| 27597/27597 [02:48<00:00, 163.30it/s]

NDVI images generated





Unnamed: 0,path,label,ndvi_path
0,/content/processed_images_scaled/Highway/Highw...,Highway,/content/ndvi_scaled_images/Highway_1017.tif
1,/content/processed_images_scaled/Highway/Highw...,Highway,/content/ndvi_scaled_images/Highway_861.tif
2,/content/processed_images_scaled/Highway/Highw...,Highway,/content/ndvi_scaled_images/Highway_31.tif
3,/content/processed_images_scaled/Highway/Highw...,Highway,/content/ndvi_scaled_images/Highway_708.tif
4,/content/processed_images_scaled/Highway/Highw...,Highway,/content/ndvi_scaled_images/Highway_2083.tif


### NDBI
*   We create new images that show how much built-up or urban area (like buildings, roads, and cities) is present in each satellite image.

*   These are called **NDBI images** (Normalized Difference Built-up Index). NDBI highlights man-made surfaces by comparing shortwave-infrared (SWIR) light with near-infrared (NIR) light, since built-up surfaces tend to reflect more SWIR than NIR.
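A short sketch of how NDBI values might be read: the index is `(SWIR - NIR) / (SWIR + NIR)`, and positive values are commonly interpreted as built-up. The function name `ndbi`, the zero threshold, and the sample pixels are assumptions for illustration.

```python
import numpy as np

def ndbi(swir: np.ndarray, nir: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """NDBI = (SWIR - NIR) / (SWIR + NIR); eps guards against division by zero."""
    swir = swir.astype(np.float32)
    nir = nir.astype(np.float32)
    return np.clip((swir - nir) / (swir + nir + eps), -1.0, 1.0)

# Hypothetical pixels: a rooftop-like pixel and a forest-like pixel
swir = np.array([2000, 1200])
nir = np.array([1200, 2500])
values = ndbi(swir, nir)
built_up = values > 0  # assumed threshold for this illustration
print(values.round(3), built_up)
```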

In [None]:
NDBI_dir = '/content/ndbi_scaled_images'

def ndbi_images(df, nir_band_index=1, swir_band_index=2):
    os.makedirs(NDBI_dir, exist_ok=True)
    ndbi_paths = []

    print('Generating NDBI images')

    for i, row in tqdm(df.iterrows(), total=len(df)):
        image_path = row['path']
        label = row['label']

        with rasterio.open(image_path) as src:
            data = src.read().astype(np.float32)

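            # Use the band positions passed in as arguments instead of hard-coded indices
            nir = data[nir_band_index]
            swir = data[swir_band_index]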

            ndbi = (swir - nir) / (swir + nir + 1e-6)
            ndbi = np.clip(ndbi, -1, 1)

            ndbi_path = os.path.join(NDBI_dir, os.path.basename(image_path))
            profile = src.profile
            profile.update(count=1, dtype='float32')

            with rasterio.open(ndbi_path, 'w', **profile) as dst:
                dst.write(ndbi, 1)

            ndbi_paths.append(ndbi_path)

    df['ndbi_path'] = ndbi_paths
    print('NDBI images generated')
    return df


In [None]:
final_df = ndbi_images(final_df)
final_df.head()

Generating NDBI images


100%|██████████| 27597/27597 [02:43<00:00, 168.95it/s]


NDBI images generated


Unnamed: 0,path,label,ndvi_path,ndbi_path
0,/content/processed_images_scaled/Highway/Highw...,Highway,/content/ndvi_scaled_images/Highway_1017.tif,/content/ndbi_scaled_images/Highway_1017.tif
1,/content/processed_images_scaled/Highway/Highw...,Highway,/content/ndvi_scaled_images/Highway_861.tif,/content/ndbi_scaled_images/Highway_861.tif
2,/content/processed_images_scaled/Highway/Highw...,Highway,/content/ndvi_scaled_images/Highway_31.tif,/content/ndbi_scaled_images/Highway_31.tif
3,/content/processed_images_scaled/Highway/Highw...,Highway,/content/ndvi_scaled_images/Highway_708.tif,/content/ndbi_scaled_images/Highway_708.tif
4,/content/processed_images_scaled/Highway/Highw...,Highway,/content/ndvi_scaled_images/Highway_2083.tif,/content/ndbi_scaled_images/Highway_2083.tif


### Splitting the Dataset: Training, Validation, and Testing
We divided our satellite image dataset into three smaller sets (70% / 15% / 15%):

*   **Training Set**: This is the main set used to teach our model. It learns patterns from these images.

*   **Validation Set**: While the model is learning, we use this set to check how well it's doing and adjust settings if needed.

*   **Test Set**: After the model finishes learning, we test it on this final set to see how well it performs on completely new data.

The split is stratified by label, so each of these sets keeps the same mix of land types (like forests, rivers, etc.) as the full dataset.

In [None]:
train_df, temp_df = train_test_split(
    final_df,
    test_size=0.3,
    random_state=42,
    stratify=final_df['label']
)

val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=42,
    stratify=temp_df['label']
)

print(f'Train size: {len(train_df)}')
print(f'Validation size: {len(val_df)}')
print(f'Test size: {len(test_df)}')

Train size: 19317
Validation size: 4140
Test size: 4140
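One way to confirm that stratification preserved the class mix is to compare class proportions between two splits. This is a sketch using hypothetical helper names (`class_proportions`, `max_proportion_gap`) and toy frames standing in for `train_df` and `val_df`.

```python
import pandas as pd

def class_proportions(frame: pd.DataFrame) -> pd.Series:
    """Fraction of rows per label, sorted by label name."""
    return frame["label"].value_counts(normalize=True).sort_index()

def max_proportion_gap(a: pd.DataFrame, b: pd.DataFrame) -> float:
    """Largest absolute difference in class share between two splits."""
    pa, pb = class_proportions(a), class_proportions(b)
    return float(pa.subtract(pb, fill_value=0.0).abs().max())

# Toy splits (hypothetical) standing in for train_df and val_df
train = pd.DataFrame({"label": ["Forest"] * 70 + ["River"] * 30})
val = pd.DataFrame({"label": ["Forest"] * 14 + ["River"] * 6})
gap = max_proportion_gap(train, val)
print(f"Largest class-share gap: {gap:.4f}")
```

With a stratified split the gap should be close to zero; on the real data the call would be `max_proportion_gap(train_df, val_df)`.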


### Image Augmentation
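*   Enhances model training by creating diverse image variations (flips, rotations, brightness/contrast changes, small shifts and scaling).
*   Class-balanced: under-represented classes receive more augmented images, which helps prevent model bias toward the larger classes.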








In [None]:
AUG_DIR = '/content/augmented_images'

# Augmentation pipeline
augment = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.2),
    A.RandomRotate90(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=20, p=0.3)
])

def apply_augmentation(train_df, max_aug_ratio=0.5):
    os.makedirs(AUG_DIR, exist_ok=True)
    new_paths = []
    new_labels = []

    # Get mean number of images per class
    class_counts = train_df['label'].value_counts()
    mean_count = int(class_counts.mean())

    print(f'Mean class size: {mean_count}')

    for label in class_counts.index:
        class_df = train_df[train_df['label'] == label]
        current_count = len(class_df)

        # Decide how many images to augment
        if current_count < mean_count:
            target = min(mean_count - current_count, int(current_count * max_aug_ratio))
        else:
            target = int(current_count * 0.1)  # apply light augmentation to big classes

        print(f'\n Augmenting class {label} with {target} new images')

        class_folder = os.path.join(AUG_DIR, label)
        os.makedirs(class_folder, exist_ok=True)

        for i in range(target):
            # Randomly select one image from the class
            row = class_df.sample(1).iloc[0]
            image_path = row['path']

            with rasterio.open(image_path) as src:
                # Keep the scaled float32 values; casting to uint16 would truncate them to 0
                image = src.read().astype(np.float32)
                profile = src.profile

            # Apply one augmentation to all bands together (H, W, C layout)
            # so every band receives the same geometric transform
            hwc = np.ascontiguousarray(np.transpose(image, (1, 2, 0)))
            augmented = augment(image=hwc)['image']
            augmented = np.transpose(augmented, (2, 0, 1))

            # Save augmented image
            aug_filename = f'{label}_aug_{i}.tif'
            aug_path = os.path.join(class_folder, aug_filename)

            profile.update(dtype='float32')

            with rasterio.open(aug_path, 'w', **profile) as dst:
                dst.write(augmented)

            new_paths.append(aug_path)
            new_labels.append(label)

    # Append to original train_df
    aug_df = pd.DataFrame({
        'path': new_paths,
        'label': new_labels
    })

    train_df_augmented = pd.concat([train_df, aug_df], ignore_index=True)
    print(f'\n Augmentation done. Final train size: {len(train_df_augmented)}')

    return train_df_augmented



ShiftScaleRotate is a special case of Affine transform. Please use Affine transform instead.



In [None]:
train_df = apply_augmentation(train_df)

Mean class size: 1931

 Augmenting class SeaLake with 251 new images

 Augmenting class AnnualCrop with 210 new images

 Augmenting class Forest with 210 new images

 Augmenting class Residential with 210 new images

 Augmenting class HerbaceousVegetation with 210 new images

 Augmenting class Highway with 181 new images

 Augmenting class Industrial with 181 new images

 Augmenting class River with 181 new images

 Augmenting class PermanentCrop with 181 new images

 Augmenting class Pasture with 531 new images

 Augmentation done. Final train size: 21663
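The per-class targets in the log follow from the rule inside `apply_augmentation`. The sketch below isolates that rule as a pure function; the name `augmentation_target`, the `light_ratio` parameter, and the two class sizes are illustrative assumptions, while 1931 is the mean class size printed in the log.

```python
def augmentation_target(current_count, mean_count, max_aug_ratio=0.5, light_ratio=0.1):
    """Number of augmented images to create for one class."""
    if current_count < mean_count:
        # Small class: close the gap to the mean, capped at a fraction of the class size
        return min(mean_count - current_count, int(current_count * max_aug_ratio))
    # Large class: add a light fixed fraction
    return int(current_count * light_ratio)

# Class sizes are illustrative assumptions; 1931 is the mean printed in the log
small = augmentation_target(1400, 1931)
large = augmentation_target(2100, 1931)
print(small, large)
```

Small classes are topped up toward the mean, while large classes only receive a light 10% of extra images, so the gap between them narrows without fully equalizing the counts.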
