In [None]:
import gzip
import os

import awswrangler as wr
import boto3
import pandas as pd
from botocore.exceptions import ClientError
from tqdm import tqdm

class ProgressBar(tqdm):
    """Progress bar for S3 operations whose ``update_to`` takes a running total.

    NOTE(review): boto3 transfer ``Callback`` hooks are documented as receiving
    the number of bytes moved since the previous call rather than a running
    total -- confirm that whatever drives ``update_to`` supplies a cumulative
    value.
    """

    def update_to(self, bytes_transferred):
        """Advance the bar so that its position equals ``bytes_transferred``.

        Args:
            bytes_transferred (int): Absolute number of bytes transferred so far.

        Returns:
            None
        """
        already_shown = self.n
        # tqdm.update() takes an increment, so convert the absolute position.
        self.update(bytes_transferred - already_shown)


def download_file_with_progress(s3_client, bucket_name, key, local_file):
    """
    Downloads a file from S3 with a progress bar.

    The object's size is fetched with ``head_object`` first so that the bar has
    a total to count towards. Exceptions raised by the client calls are not
    caught here.

    Args:
        s3_client: The boto3 S3 client.
        bucket_name (str): The name of the S3 bucket.
        key (str): The key of the file in the S3 bucket.
        local_file (str): The local file path to save the downloaded file.

    Returns:
        None
    """
    file_size = s3_client.head_object(Bucket=bucket_name, Key=key)['ContentLength']
    with ProgressBar(total=file_size, unit="B", unit_scale=True, desc="Downloading") as progress_bar:
        s3_client.download_file(
            Bucket=bucket_name,
            Key=key,
            Filename=local_file,
            # boto3 invokes the callback with the bytes transferred since the
            # previous call (an increment, not a running total), which is
            # exactly what tqdm's update() expects.
            Callback=progress_bar.update,
        )


def _strip_gz_suffix(key):
    """Return ``key`` without a trailing ``.gz``; other occurrences are kept."""
    suffix = ".gz"
    return key[:-len(suffix)] if key.endswith(suffix) else key


def _extract_gzip_with_progress(local_gz_file, local_extracted_file, chunk_size=64 * 1024):
    """Decompress ``local_gz_file`` into ``local_extracted_file`` with a progress bar."""
    compressed_size = os.path.getsize(local_gz_file)
    # Copy bytes rather than text so the content is not decoded and re-encoded
    # with the platform's default encoding.
    with open(local_gz_file, 'rb') as raw_file, \
            gzip.GzipFile(fileobj=raw_file, mode='rb') as gz_file, \
            open(local_extracted_file, 'wb') as extracted_file, \
            tqdm(total=compressed_size, unit='B', unit_scale=True,
                 desc=f'Extracting {local_gz_file}') as progress_bar:
        while chunk := gz_file.read(chunk_size):
            extracted_file.write(chunk)
            # The bar's total is the compressed size, so advance it by how far
            # the compressed file has been consumed, not by the output length.
            progress_bar.update(raw_file.tell() - progress_bar.n)


def _remove_if_present(*local_paths):
    """Delete each of ``local_paths`` that exists on disk."""
    for local_path in local_paths:
        if os.path.exists(local_path):
            os.remove(local_path)


def download_extract_save_to_athena(counter, bucket_name, gz_key, s3_output_prefix, glue_database, glue_table, aws_region="us-east-1", keep_local_files=True):
    """
    Downloads a .gz file from S3, extracts its contents, and saves it to S3 as a dataset using awswrangler.

    The archive and the extracted file are written to the current working
    directory, named after the last path component of ``gz_key``.

    Args:
        counter (int): Position of this file in the caller's listing. When it is
            0 the dataset is written with mode 'overwrite'; otherwise 'append'.
        bucket_name (str): The name of the S3 bucket.
        gz_key (str): The key of the .gz file in the S3 bucket.
        s3_output_prefix (str): The S3 prefix where the dataset will be saved.
        glue_database (str): The Glue database name.
        glue_table (str): The Glue table name.
        aws_region (str): The AWS region of the S3 bucket.
        keep_local_files (bool): When True (the default) the downloaded archive
            and the extracted file are left on disk; when False they are
            removed afterwards, whether or not processing succeeded.

    Returns:
        None

    Note:
        Only ``ClientError`` is caught (and printed); any other exception
        propagates to the caller.
    """
    s3_client = boto3.client("s3", region_name=aws_region)
    local_extracted_file = _strip_gz_suffix(gz_key).split('/')[-1]
    local_gz_file = f'{local_extracted_file}.gz'

    try:
        # Step 1: Download the .gz file from S3
        print(f"Downloading {gz_key} from S3 bucket {bucket_name}...")
        download_file_with_progress(s3_client, bucket_name, gz_key, local_gz_file)
        print(f"Downloaded {gz_key} to {local_gz_file}")

        # Step 2: Extract the .gz file
        print(f"Extracting {local_gz_file}...")
        _extract_gzip_with_progress(local_gz_file, local_extracted_file)
        print(f'Extracted content saved to {local_extracted_file}')

        # Step 3: Save the extracted file to S3 as a dataset and register it in Athena
        print(f"Saving {local_extracted_file} to S3 as a dataset and registering it in Athena...")
        s3_output_path = f"s3://{bucket_name}/{s3_output_prefix}"

        # The extracted JSONL file is local, so read it with pandas; the
        # awswrangler S3 readers operate on s3:// paths.
        df = pd.read_json(local_extracted_file, lines=True)
        mode = 'overwrite' if counter == 0 else 'append'

        # Save the DataFrame to S3 as a dataset
        wr.s3.to_json(
            df=df,
            path=s3_output_path,
            dataset=True,
            database=glue_database,
            table=glue_table,
            mode=mode
        )
        print(f"Dataset saved to {s3_output_path} and registered in Athena (Glue database: {glue_database}, table: {glue_table}).")

    except ClientError as e:
        print(f"Error: {e}")
    finally:
        # Clean up local temporary files only when the caller opted in.
        if not keep_local_files:
            _remove_if_present(local_gz_file, local_extracted_file)


def _iter_object_keys(s3_client, bucket_name, prefix):
    """Yield the key of every object under ``prefix``, following pagination."""
    paginator = s3_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
        for obj in page.get("Contents", []):
            yield obj["Key"]


def extract_all_files(bucket_name, prefix, s3_output_prefix, glue_database, glue_table, aws_region="us-east-1", min_index=0, max_index=100000):
    """
    Processes all .gz files in an S3 bucket with a given prefix.

    Objects are listed page by page, so prefixes holding more keys than a
    single ``list_objects_v2`` response are handled. Every listed object gets a
    zero-based index, whether or not its key ends in ".gz"; only objects whose
    index lies in ``[min_index, max_index)`` are considered, and of those only
    keys ending in ".gz" are processed. The index is passed on to
    ``download_extract_save_to_athena`` as its ``counter`` argument.

    Args:
        bucket_name (str): The name of the S3 bucket.
        prefix (str): The prefix of the .gz files in the S3 bucket.
        s3_output_prefix (str): The S3 prefix where the dataset will be saved.
        glue_database (str): The Glue database name.
        glue_table (str): The Glue table name.
        aws_region (str): The AWS region of the S3 bucket.
        min_index (int): Index of the first listed object to consider (inclusive).
        max_index (int): Index at which to stop (exclusive).

    Returns:
        None

    Note:
        A ``ClientError`` raised while listing is caught and printed.
    """
    s3_client = boto3.client("s3", region_name=aws_region)

    try:
        # List all .gz files in the bucket with the given prefix
        print(f"Listing .gz files in bucket {bucket_name} with prefix {prefix}...")

        found_any = False
        for counter, gz_key in enumerate(_iter_object_keys(s3_client, bucket_name, prefix)):
            found_any = True
            if counter >= max_index:
                # Nothing past this index can qualify; stop requesting pages.
                break
            if counter < min_index or not gz_key.endswith(".gz"):
                continue
            print(f"Processing file: {gz_key}")
            download_extract_save_to_athena(counter, bucket_name, gz_key, s3_output_prefix, glue_database, glue_table, aws_region)

        if not found_any:
            print("No files found.")

    except ClientError as e:
        print(f"Error listing files in S3: {e}")


# Example usage: consider only the first object listed under the raw prefix.
bucket_name = "steve-sagemaker-data-bucket"
prefix = "00_raw/papers/"  # where the .gz inputs live
s3_output_prefix = "01_extracted/papers/"  # destination prefix for the dataset
glue_database = "s2"  # Glue database to register the table in
glue_table = "papers"  # Glue table to register
aws_region = "eu-west-2"  # region the bucket lives in

extract_all_files(
    bucket_name=bucket_name,
    prefix=prefix,
    s3_output_prefix=s3_output_prefix,
    glue_database=glue_database,
    glue_table=glue_table,
    aws_region=aws_region,
    min_index=0,
    max_index=1,
)