In [3]:
import os
import pandas as pd
import shutil
from PIL import Image, ImageEnhance, ImageFilter
import torchvision.transforms as T
import matplotlib.pyplot as plt
import random


In [4]:
# Check how many images each category holds and whether every image is 256x256.

root_dir = "/Users/jqyang/Documents/comp9517/Aerial_Landscapes"
category_stats = {}

# os.listdir returns entries in arbitrary order; sort so the report is stable.
for category in sorted(os.listdir(root_dir)):
    category_path = os.path.join(root_dir, category)
    if not os.path.isdir(category_path):
        continue

    image_files = sorted(os.listdir(category_path))
    wrong_size_images = []
    unreadable_images = []

    for img_name in image_files:
        img_path = os.path.join(category_path, img_name)

        try:
            # The context manager closes the underlying file as soon as the
            # size has been read, so no file handles are leaked.
            with Image.open(img_path) as img:
                size = img.size
        except Exception as e:
            print(f"ERROR: {img_path}, {e}")
            # A file that cannot be opened has not been verified, so it must
            # not be counted as a valid 256x256 image.
            unreadable_images.append(img_name)
            continue

        if size != (256, 256):
            wrong_size_images.append((img_name, size))

    category_stats[category] = {
        'total': len(image_files),
        'all_256': not wrong_size_images and not unreadable_images,
        'wrong_images': wrong_size_images,
        'unreadable_images': unreadable_images,
    }

# Print a short report per category.
for category, info in category_stats.items():
    print(f"\nCategory: {category}")
    print(f"Total images: {info['total']}")
    print(f"All 256x256: {info['all_256']}")
    if info['wrong_images']:
        print("Wrong size images:")
        for name, size in info['wrong_images']:
            print(f"  {name}: {size}")
    if info['unreadable_images']:
        print("Unreadable images:")
        for name in info['unreadable_images']:
            print(f"  {name}")


Category: Agriculture
Total images: 800
All 256x256: True

Category: Forest
Total images: 800
All 256x256: True

Category: River
Total images: 800
All 256x256: True

Category: City
Total images: 800
All 256x256: True

Category: Highway
Total images: 800
All 256x256: True

Category: Railway
Total images: 800
All 256x256: True

Category: Lake
Total images: 800
All 256x256: True

Category: Residential
Total images: 800
All 256x256: True

Category: Airport
Total images: 800
All 256x256: True

Category: Beach
Total images: 800
All 256x256: True

Category: Port
Total images: 800
All 256x256: True

Category: Mountain
Total images: 800
All 256x256: True

Category: Grassland
Total images: 800
All 256x256: True

Category: Desert
Total images: 800
All 256x256: True

Category: Parking
Total images: 800
All 256x256: True


In [5]:
# Split every category into a training set and a test set (8:2).
#
# NOTE: files are copied, never removed. If dst_root still holds a split
# produced by an earlier run with a different file order, delete it before
# re-running, otherwise images from both runs end up mixed together.

src_root = "/Users/jqyang/Documents/comp9517/Aerial_Landscapes" 
dst_root = "/Users/jqyang/Documents/comp9517/split_dataset"  

for split in ["train", "test"]:
    split_path = os.path.join(dst_root, split)
    os.makedirs(split_path, exist_ok = True)

random.seed(42) 

# Categories are visited in sorted order because every shuffle consumes the
# seeded random stream; an arbitrary visiting order would change the split.
for category in sorted(os.listdir(src_root)):
    category_path = os.path.join(src_root, category)
    if not os.path.isdir(category_path):
        continue

    # os.listdir gives no ordering guarantee, so sort before shuffling:
    # only then does the fixed seed reproduce the same split on every run.
    # Non-file entries are skipped because shutil.copy cannot copy them.
    image_files = sorted(
        name for name in os.listdir(category_path)
        if os.path.isfile(os.path.join(category_path, name))
    )
    random.shuffle(image_files)

    # First 80% of the shuffled list goes to train, the rest to test.
    split_idx = int(len(image_files) * 0.8)
    train_files = image_files[:split_idx]
    test_files = image_files[split_idx:]

    for split, file_list in [("train", train_files), ("test", test_files)]:
        dst_category_path = os.path.join(dst_root, split, category)
        os.makedirs(dst_category_path, exist_ok=True)

        for file in file_list:
            src_path = os.path.join(category_path, file)
            dst_path = os.path.join(dst_category_path, file)
            shutil.copy(src_path, dst_path)  # copy, leaving the source dataset untouched

In [6]:
# Write train.csv / test.csv listing every image of the split dataset
# (path relative to split_root) together with its class label.
split_root = "/Users/jqyang/Documents/comp9517/split_dataset"
csv_records = []

for split in ['train', 'test']:
    split_path = os.path.join(split_root, split)
    # Sorted listings keep the CSV row order identical between runs.
    for category in sorted(os.listdir(split_path)):
        category_path = os.path.join(split_path, category)
        if not os.path.isdir(category_path):
            continue
        for fname in sorted(os.listdir(category_path)):
            relative_path = os.path.join(split, category, fname)
            csv_records.append({
                'image_path': relative_path,
                'label': category
            })

# Explicit columns keep the frame (and the CSV header) well-formed even when
# no image was found; without them the column lookup below raises KeyError.
df = pd.DataFrame(csv_records, columns=['image_path', 'label'])
# Include the path separator so a split name can never match as a mere prefix.
df_train = df[df['image_path'].str.startswith('train' + os.sep)]
df_test = df[df['image_path'].str.startswith('test' + os.sep)]

df_train.to_csv("train.csv", index=False)
df_test.to_csv("test.csv", index=False)

In [None]:
# Plot the number of images per class for the training and the test split.

import seaborn as sns
import matplotlib.pyplot as plt

df_train = pd.read_csv("train.csv")
df_test = pd.read_csv("test.csv")

# One count plot per split; classes are shown in alphabetical order.
for split_df, split_title in (
    (df_train, "Trainset Category Distribution"),
    (df_test, "Testset Category Distribution"),
):
    plt.figure(figsize=(10,4))
    class_order = sorted(split_df['label'].unique())
    sns.countplot(data=split_df, x='label', order=class_order)
    plt.title(split_title)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

In [None]:
# Image enhancement (data augmentation of the training set)

In [12]:

source_dir = "/Users/jqyang/Documents/comp9517/split_dataset/train"
output_root = "/Users/jqyang/Documents/comp9517/augmented_dataset"
# Augmentation types; each one gets its own sub-folder under output_root
augment_types = ["original", "flip", "rotate", "crop", "brightness", "blur"]

# Only real category folders get output directories; stray files in
# source_dir (e.g. .DS_Store) are ignored, as in the other cells.
train_categories = [
    entry for entry in sorted(os.listdir(source_dir))
    if os.path.isdir(os.path.join(source_dir, entry))
]

for aug in augment_types:
    for category in train_categories:
        os.makedirs(os.path.join(output_root, aug, category), exist_ok=True)

# Define augmented operations
def augment_image(img, max_rotation=25, crop_scale=(0.5, 1.0),
                  brightness_factor=1.8, blur_radius=2):
    """Create one augmented copy of ``img`` per augmentation type.

    Args:
        img: PIL image to augment; it is not modified.
        max_rotation: the rotation angle is drawn at random from
            ``[-max_rotation, +max_rotation]`` degrees.
        crop_scale: ``scale`` range passed to ``T.RandomResizedCrop``.
        brightness_factor: factor passed to ``ImageEnhance.Brightness.enhance``.
        blur_radius: radius of the Gaussian blur.

    Returns:
        dict mapping the augmentation name (``"flip"``, ``"rotate"``,
        ``"crop"``, ``"brightness"``, ``"blur"``) to the augmented image.
        The defaults reproduce the previously hard-coded settings.
    """
    width, height = img.size
    transforms = {}

    # 1. Horizontal flip
    transforms["flip"] = img.transpose(Image.FLIP_LEFT_RIGHT)

    # 2. Random rotation
    angle = T.RandomRotation.get_params([-max_rotation, max_rotation])
    transforms["rotate"] = img.rotate(angle=angle)

    # 3. Random crop, resized to the input size (torchvision expects (h, w))
    crop_transform = T.RandomResizedCrop((height, width), scale=crop_scale)
    transforms["crop"] = crop_transform(img)

    # 4. Brightness adjustment
    enhancer = ImageEnhance.Brightness(img)
    transforms["brightness"] = enhancer.enhance(brightness_factor)

    # 5. Gaussian blur
    transforms["blur"] = img.filter(ImageFilter.GaussianBlur(radius=blur_radius))

    return transforms


# Save an RGB copy of every training image plus one file per augmentation.
for category in os.listdir(source_dir):
    category_path = os.path.join(source_dir, category)
    # Skip stray files (e.g. .DS_Store); listing them would raise
    # NotADirectoryError and abort the whole run.
    if not os.path.isdir(category_path):
        continue

    for fname in os.listdir(category_path):
        fpath = os.path.join(category_path, fname)
        # The name parts do not depend on the augmentation type.
        base, ext = os.path.splitext(fname)

        try:
            # convert() returns an independent in-memory image, so the source
            # file can be closed immediately after it.
            with Image.open(fpath) as raw_img:
                img = raw_img.convert("RGB")

            # save original images
            save_path = os.path.join(output_root, "original", category, fname)
            img.save(save_path)

            # save the augmented images as <name>_<augmentation><ext>
            augmented = augment_image(img)
            for aug_type, aug_img in augmented.items():
                aug_name = f"{base}_{aug_type}{ext}"
                aug_save_path = os.path.join(output_root, aug_type, category, aug_name)
                aug_img.save(aug_save_path)

        except Exception as e:
            # Best effort: report the failing file and carry on with the rest.
            print(f"ERROR: {fpath}, {e}")

In [None]:
# Show, for every augmentation type, one randomly chosen augmented image
# next to the original image it was derived from.

root_dir = "/Users/jqyang/Documents/comp9517/augmented_dataset"
# Keep augmentation folders only: "original" is the reference set, and stray
# files (e.g. .DS_Store) cannot be listed as directories.
aug_types = sorted(
    a for a in os.listdir(root_dir)
    if a != "original" and os.path.isdir(os.path.join(root_dir, a))
)

plt.figure(figsize=(10, 5 * len(aug_types)))

for i, aug_type in enumerate(aug_types):
    aug_type_path = os.path.join(root_dir, aug_type)
    categories = sorted(
        c for c in os.listdir(aug_type_path)
        if os.path.isdir(os.path.join(aug_type_path, c))
    )
    if not categories:
        continue
    # Randomly select a category
    category = random.choice(categories)

    category_aug_path = os.path.join(aug_type_path, category)
    sample_imgs = sorted(os.listdir(category_aug_path))
    if not sample_imgs:
        continue

    # Randomly pick an augmented image
    sample_aug_name = random.choice(sample_imgs)
    sample_base, ext = os.path.splitext(sample_aug_name)
    # Strip only the trailing "_<aug_type>" marker; replacing every occurrence
    # would corrupt names that contain the same text elsewhere.
    suffix = f"_{aug_type}"
    if sample_base.endswith(suffix):
        sample_base = sample_base[:-len(suffix)]

    original_path = os.path.join(root_dir, "original", category, f"{sample_base}{ext}")
    augmented_path = os.path.join(category_aug_path, sample_aug_name)

    try:
        # Both files are closed again once they have been drawn.
        with Image.open(original_path) as img_original:
            with Image.open(augmented_path) as img_augmented:
                # original
                plt.subplot(len(aug_types), 2, 2 * i + 1)
                plt.imshow(img_original)
                plt.axis('off')
                plt.title(f"{aug_type} | {category} - Original")

                # augmented
                plt.subplot(len(aug_types), 2, 2 * i + 2)
                plt.imshow(img_augmented)
                plt.axis('off')
                plt.title(f"{aug_type} | {category} - Augmented")

    except Exception as e:
        print(f"ERROR {aug_type}/{category}/{sample_aug_name}, {e}")

plt.tight_layout()
plt.show()

In [15]:
# Write augmented_train.csv: one row per image in the augmented dataset with
# its path (relative to aug_root), class label and augmentation type.

aug_root = "/Users/jqyang/Documents/comp9517/augmented_dataset"
# Sorted listings keep the CSV row order identical between runs.
augment_types = sorted(os.listdir(aug_root))
records = []

for aug_type in augment_types:
    aug_path = os.path.join(aug_root, aug_type)
    if not os.path.isdir(aug_path):
        continue
    for category in sorted(os.listdir(aug_path)):
        category_path = os.path.join(aug_path, category)
        if not os.path.isdir(category_path):
            continue
        for fname in sorted(os.listdir(category_path)):
            rel_path = os.path.join(aug_type, category, fname) 
            records.append({
                "image_path": rel_path,
                "label": category,
                "augmentation": aug_type
            })

# Explicit columns guarantee a proper CSV header even when no image was found.
df = pd.DataFrame(records, columns=["image_path", "label", "augmentation"])
df.to_csv("augmented_train.csv", index=False)