In [None]:
import math
import os
import random
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed

import boto3
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [None]:
# S3 key prefix of the dataset inside the bucket.
DATASET_PATH = "basketball-dataset"
# Local directory the S3 prefix is mirrored into (and later split in place).
LOCAL_DATASET_PATH = "basketball-dataset/"

# Split proportions, as fractions of the whole dataset.
TRAIN_RATIO = 0.80
VAL_RATIO = 0.15
TEST_RATIO = 0.05
# Fail fast on a typo: the split only uses these relative to each other,
# so a bad value would otherwise be absorbed silently.
if abs(TRAIN_RATIO + VAL_RATIO + TEST_RATIO - 1.0) > 1e-9:
    raise ValueError(
        f"TRAIN_RATIO + VAL_RATIO + TEST_RATIO must sum to 1, "
        f"got {TRAIN_RATIO + VAL_RATIO + TEST_RATIO}"
    )

# Credentials and bucket come from the environment so no secrets live in the
# notebook; an unset variable yields None.
AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")
AWS_S3_BUCKET = os.environ.get("AWS_S3_BUCKET")
AWS_S3_ENDPOINT = os.environ.get("AWS_S3_ENDPOINT")

In [None]:
# One S3 client, shared by every helper in this notebook. Passing
# AWS_S3_ENDPOINT lets the same code target an S3-compatible store.
session = boto3.Session(
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    aws_access_key_id=AWS_ACCESS_KEY_ID,
)
client = session.client(service_name="s3", endpoint_url=AWS_S3_ENDPOINT)

In [None]:
def count_files(bucket_name, folder):
    """Count the objects in ``bucket_name`` whose key starts with ``folder``.

    ``folder`` is used verbatim as an S3 key prefix, so no keys are filtered
    out: keys ending in "/" are counted too, and a prefix without a trailing
    "/" also matches keys of sibling prefixes that share its leading text.
    """
    pages = client.get_paginator('list_objects_v2').paginate(
        Bucket=bucket_name, Prefix=folder
    )
    # A page with no matching keys has no 'Contents' entry at all.
    return sum(len(page.get('Contents', [])) for page in pages)

def download_file(bucket_name, key, local_file_path, s3_client=None):
    """Download one S3 object to ``local_file_path``, creating parent directories.

    Best-effort: any exception is printed instead of raised, so one bad object
    does not abort a bulk download running in a thread pool.

    Parameters
    ----------
    bucket_name : str
        Source bucket.
    key : str
        Object key to fetch.
    local_file_path : str
        Destination file path; may be a bare file name (saved in the cwd).
    s3_client : optional
        Object with a ``download_file(bucket, key, filename)`` method.
        Defaults to the notebook's module-level ``client``.

    Returns
    -------
    bool
        True if the download succeeded, False if an error was reported.
    """
    try:
        s3 = client if s3_client is None else s3_client
        parent_dir = os.path.dirname(local_file_path)
        # os.makedirs('') raises, so only create a parent when the path has one.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        s3.download_file(bucket_name, key, local_file_path)
    except Exception as e:
        print(f"Error downloading {key}: {e}")
        return False
    return True

def download_folder(bucket_name, folder, local_dir=None, max_workers=10):
    """Mirror every object under the S3 "folder" ``folder`` into ``local_dir``.

    Parameters
    ----------
    bucket_name : str
        Source bucket.
    folder : str
        Key prefix to mirror. A trailing "/" is added if missing, so
        "basketball-dataset" does not also pull in "basketball-dataset-old/...".
        An empty string mirrors the whole bucket.
    local_dir : str, optional
        Destination directory (created if needed); defaults to the cwd.
    max_workers : int, default 10
        Number of concurrent download threads.

    Keys ending in "/" (folder markers) are skipped, as is any key whose
    relative path would land outside ``local_dir`` (e.g. one containing "..").
    Per-file failures are reported by ``download_file``; the progress bar
    advances as downloads *finish*, not as they are queued.
    """
    if local_dir is None:
        local_dir = os.getcwd()
    os.makedirs(local_dir, exist_ok=True)

    prefix = folder if not folder or folder.endswith('/') else folder + '/'
    root = os.path.abspath(local_dir)

    # List the bucket once; the same list drives the progress total and the jobs.
    jobs = []
    paginator = client.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
        for obj in page.get('Contents', []):
            key = obj['Key']
            if key.endswith('/'):
                continue  # folder marker: no file content to download
            target = os.path.abspath(os.path.join(root, key[len(prefix):]))
            # Refuse keys that would escape local_dir (e.g. "prefix/../../x").
            if os.path.commonpath([root, target]) != root:
                print(f"Skipping {key}: resolves outside {local_dir}")
                continue
            jobs.append((key, target))

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(download_file, bucket_name, key, target)
            for key, target in jobs
        ]
        with tqdm(total=len(futures), desc="Downloading files", unit="file") as pbar:
            for future in as_completed(futures):
                # Surface anything download_file did not handle itself.
                future.result()
                pbar.update(1)

In [None]:
def split_dataset(image_dir, label_dir, output_dir, train_ratio, val_ratio, test_ratio, seed=42):
    """Shuffle the images in ``image_dir`` and move them, with their labels, into splits.

    Resulting layout (files are moved, not copied)::

        output_dir/images/{train,val,test}/<name>.<ext>
        output_dir/labels/{train,val,test}/<name>.txt

    Parameters
    ----------
    image_dir : str
        Directory whose regular files are the images. Sub-directories are
        ignored, so ``output_dir/images/<split>`` may live inside it.
    label_dir : str
        Directory holding one ``<image stem>.txt`` label file per image.
    output_dir : str
        Root of the split layout; missing directories are created.
    train_ratio, val_ratio, test_ratio : float
        ``test_ratio`` is the fraction of all images sent to test (rounded up).
        Of the remainder, ``val_ratio / (train_ratio + val_ratio)`` goes to val
        (rounded up) and the rest to train. All must be non-negative, with
        ``test_ratio <= 1`` and ``train_ratio + val_ratio > 0``.
    seed : int, default 42
        Shuffle seed; the same set of files and seed always give the same split.

    Returns
    -------
    dict
        ``{"train": n, "val": n, "test": n}``, the number of images moved per split.

    Raises
    ------
    ValueError
        If the ratios are out of range.
    FileNotFoundError
        If any image has no label file. This is checked before anything is
        moved, so the dataset is never left half-split.
    """
    if min(train_ratio, val_ratio, test_ratio) < 0 or test_ratio > 1 or train_ratio + val_ratio <= 0:
        raise ValueError(
            f"invalid split ratios: train={train_ratio}, val={val_ratio}, test={test_ratio}"
        )

    def label_name(image_file):
        # Label file shares the image's stem with a .txt extension.
        return os.path.splitext(image_file)[0] + '.txt'

    # Sort first: os.listdir order is arbitrary, so sorting plus a seeded RNG
    # is what makes the split reproducible across runs and machines.
    image_files = sorted(
        f for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f))
    )

    missing = [
        f for f in image_files if not os.path.isfile(os.path.join(label_dir, label_name(f)))
    ]
    if missing:
        raise FileNotFoundError(
            f"{len(missing)} image(s) in {image_dir} have no label in {label_dir}, "
            f"e.g. {missing[:5]}"
        )

    random.Random(seed).shuffle(image_files)

    # Held-out shares are rounded up, as train_test_split does for float sizes.
    n_test = math.ceil(test_ratio * len(image_files))
    test_files, remaining = image_files[:n_test], image_files[n_test:]
    n_val = math.ceil(val_ratio / (train_ratio + val_ratio) * len(remaining))
    splits = {
        'train': remaining[n_val:],
        'val': remaining[:n_val],
        'test': test_files,
    }

    for split, files in splits.items():
        image_dest = os.path.join(output_dir, 'images', split)
        label_dest = os.path.join(output_dir, 'labels', split)
        os.makedirs(image_dest, exist_ok=True)
        os.makedirs(label_dest, exist_ok=True)
        for file in files:
            shutil.move(os.path.join(image_dir, file), os.path.join(image_dest, file))
            shutil.move(
                os.path.join(label_dir, label_name(file)),
                os.path.join(label_dest, label_name(file)),
            )

    return {split: len(files) for split, files in splits.items()}

In [None]:
# Start from a clean slate: drop any previous local copy (including the split
# folders a previous run created under it) before mirroring the bucket again.
if os.path.exists(LOCAL_DATASET_PATH):
    shutil.rmtree(LOCAL_DATASET_PATH)

download_folder(
    bucket_name=AWS_S3_BUCKET,
    folder=DATASET_PATH,
    local_dir=LOCAL_DATASET_PATH,
)

In [None]:
# Split the freshly downloaded images/labels in place under LOCAL_DATASET_PATH.
split_dataset(
    image_dir=os.path.join(LOCAL_DATASET_PATH, 'images'),
    label_dir=os.path.join(LOCAL_DATASET_PATH, 'labels'),
    output_dir=LOCAL_DATASET_PATH,
    train_ratio=TRAIN_RATIO,
    val_ratio=VAL_RATIO,
    test_ratio=TEST_RATIO,
)