<a href="https://colab.research.google.com/github/dgizdevans/master/blob/main/ai_project/data_sorter_for_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import random
from google.colab import auth
from google.cloud import storage

In [None]:
# --- Google Cloud configuration ---
# Both identifiers are deployment-specific; replace them with your own values.
project_id = "ai-group-project"  # Google Cloud project ID
bucket_name = "ai-group-project-data"  # Cloud Storage bucket name

# Authenticate the Colab user, then create the Storage client and fetch the
# bucket object that the following cells read from and write to.
auth.authenticate_user()
client = storage.Client(project=project_id)
bucket = client.get_bucket(bucket_name)

In [None]:
# Source and target locations. All values are object-name prefixes inside the
# bucket, not paths on the local file system.

# Labeled images
labeled_data_img_path = "data/labeled_data/images/Train"
# Label files belonging to those images
labeled_data_labels_path = "data/labeled_data/labels/Train"
# data.yaml that is copied next to the generated splits
source_data_yaml_path = "data/labeled_data/data.yaml"
# Root prefix under which the train/val/test datasets are written
target_path = "datasets/model"

In [None]:
# Fraction of the paired samples assigned to each subset
split_ratios = {
    'train': 0.6,
    'val': 0.2,
    'test': 0.2,
}

In [None]:
# Get the list of objects stored under the labeled image / label folders.
# `prefix` is matched as a plain string prefix, so without a trailing "/" a
# value ending in ".../Train" would also pick up sibling folders such as
# ".../Train_old/" or ".../Training/". Normalizing to exactly one trailing
# slash restricts each listing to the intended folder.
blobs_images = list(bucket.list_blobs(prefix=labeled_data_img_path.rstrip('/') + '/'))
blobs_labels = list(bucket.list_blobs(prefix=labeled_data_labels_path.rstrip('/') + '/'))

In [None]:
# Keep only object names that look like image files / label files.
# The extension check is done on the lower-cased name so that files such as
# "IMG_01.JPG" are not silently dropped, and ".jpeg" is accepted next to ".jpg".
image_extensions = ('.jpg', '.jpeg', '.png')
images = [blob.name for blob in blobs_images if blob.name.lower().endswith(image_extensions)]
labels = [blob.name for blob in blobs_labels if blob.name.lower().endswith('.txt')]

In [None]:
# Index images and labels by file stem (file name without its extension) so
# that an image and its label share the same key.
# os.path.splitext strips only the final extension, so "frame.001.jpg" keeps
# the stem "frame.001"; cutting at the first "." would reduce it to "frame"
# and make distinct files collide on one key.
base_image_names = {os.path.splitext(os.path.basename(img))[0]: img for img in images}
base_label_names = {os.path.splitext(os.path.basename(lbl))[0]: lbl for lbl in labels}

In [None]:
# Build (image, label) pairs; an image whose stem has no label entry is skipped
paired_images_labels = [
    (image_name, base_label_names[stem])
    for stem, image_name in base_image_names.items()
    if stem in base_label_names
]

In [None]:
# Shuffle and split the data.
# A dedicated, seeded generator makes the split reproducible after a kernel
# restart. Shuffling a sorted copy (rather than the pair list in place) means
# re-running only this cell yields the same split as well.
split_seed = 42
shuffled_pairs = sorted(paired_images_labels)
random.Random(split_seed).shuffle(shuffled_pairs)
total_count = len(shuffled_pairs)

# Slice boundaries. int() truncates, so any rounding remainder lands in 'test'.
train_end = int(split_ratios['train'] * total_count)
val_end = int((split_ratios['train'] + split_ratios['val']) * total_count)

splits = {
    'train': shuffled_pairs[:train_end],
    'val': shuffled_pairs[train_end:val_end],
    'test': shuffled_pairs[val_end:],
}

In [None]:
# Helper function to copy files in GCS
def copy_blob(bucket, source_blob_name, destination_blob_name):
    """Copy one object to another name inside the same bucket.

    Args:
        bucket: Bucket object; ``bucket.blob(name)`` must return a blob handle.
        source_blob_name: Name of the object to copy from.
        destination_blob_name: Name of the object to copy to.

    Returns:
        None.

    A single ``rewrite`` call may transfer only part of a large object, in
    which case it returns a continuation token. The call is therefore repeated
    with that token until the returned token is ``None`` (copy finished).
    """
    source_blob = bucket.blob(source_blob_name)
    destination_blob = bucket.blob(destination_blob_name)

    token = None
    while True:
        token, _bytes_rewritten, _total_bytes = destination_blob.rewrite(source_blob, token=token)
        if token is None:
            break

In [None]:
# Per-split counters for copied images and labels, all starting at zero
stats = {
    split_name: {'images': 0, 'labels': 0}
    for split_name in ('train', 'val', 'test')
}

In [None]:
# Distribute the files: copy every (image, label) pair into its split folder.
# NOTE: nothing in this cell removes objects that already exist under
# target_path, so files left there by an earlier run stay in place.
for split, data in splits.items():
    # Reset the counters so that re-running this cell does not double count
    stats[split] = {'images': 0, 'labels': 0}

    # Destination folders are the same for every pair of this split
    split_images_dir = os.path.join(target_path, split, 'images')
    split_labels_dir = os.path.join(target_path, split, 'labels')

    for image_path, label_path in data:
        # Copy images
        split_image_path = os.path.join(split_images_dir, os.path.basename(image_path))
        copy_blob(bucket, image_path, split_image_path)
        stats[split]['images'] += 1
        # Copy labels
        split_label_path = os.path.join(split_labels_dir, os.path.basename(label_path))
        copy_blob(bucket, label_path, split_label_path)
        stats[split]['labels'] += 1

In [None]:
# Place a copy of the source data.yaml directly under the target prefix
target_data_yaml_path = os.path.join(target_path, "data.yaml")
copy_blob(
    bucket,
    source_blob_name=source_data_yaml_path,
    destination_blob_name=target_data_yaml_path,
)

In [None]:
# Report how many images and labels were counted for each split
print("\nData distribution statistics:")
for split, counts in stats.items():
    print(f"{split.capitalize()} set:")
    print(f"  Images: {counts['images']}")
    print(f"  Labels: {counts['labels']}")


Data distribution statistics:
Train set:
  Images: 488
  Labels: 488
Val set:
  Images: 163
  Labels: 163
Test set:
  Images: 163
  Labels: 163


In [None]:
# Final status line; it is printed unconditionally and does not inspect the
# results of the copy operations.
completion_message = "\nData has been successfully distributed into train, val, and test sets."
print(completion_message)


Data has been successfully distributed into train, val, and test sets.
