# Setup

## Import necessary libraries

In [None]:
import importlib
import sys
import os

# Append the parent directory to the path to import the necessary modules
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

# Import the utilities
from utils import selfutil
from utils import parseutil 

# Now reload the modules to ensure they are up-to-date
importlib.reload(selfutil)
importlib.reload(parseutil)

# Import the functions needed from utils
from utils.selfutil import get_vocabulary
from utils.parseutil import process_spreadsheet

# Other regular imports
import torch
from tqdm import tqdm
import requests
from bs4 import BeautifulSoup
import aiohttp
import asyncio
import shutil
from joblib import Parallel, delayed
import warnings
import time


# Select the compute device once, as module-level globals.
# `devstr` is the requested device string; switch it to "cpu" to force CPU execution.
devstr = "cuda:1"  # "cpu"

# `gpu` mirrors the *request* only: it stays True for any non-"cpu" devstr,
# even when DEVICE below falls back to the CPU because CUDA is unavailable.
gpu = devstr != 'cpu'

if devstr == 'cpu':
    # An explicit CPU request keeps the plain string form
    DEVICE = 'cpu'
elif devstr:
    # Honour the requested CUDA device when CUDA is usable, otherwise fall back to CPU
    DEVICE = torch.device(devstr if torch.cuda.is_available() else 'cpu')
else:
    # Empty device string: use whichever CUDA device is currently selected
    DEVICE = torch.cuda.current_device()

print(DEVICE)

# Download Files

In [None]:
# Function to get links from pages and download all XLS files
async def download_datagov_xls(start_page=1, end_page=2, data_dir='../data/train_big/', max_size_mb=2):
    """
    Scrape catalog.data.gov listing pages for Excel links and download the files.

    Every listing page in ``[start_page, end_page]`` is fetched and each anchor
    whose ``href`` contains ``.xls`` (case-insensitive) is collected. Each unique
    link is then size-checked with a HEAD request and, if acceptable, downloaded
    into ``data_dir``. At most 5 links are processed at the same time.

    Downloading is best-effort: a link that times out, errors, or cannot be
    size-checked is skipped silently and only advances the progress bar.

    Args:
        start_page (int, optional): First listing page to scan (inclusive). Defaults to 1.
        end_page (int, optional): Last listing page to scan (inclusive). Defaults to 2.
        data_dir (str, optional): Existing directory that receives the files.
            Defaults to '../data/train_big/'.
        max_size_mb (int, optional): Files whose advertised Content-Length exceeds
            this many MB are skipped. Defaults to 2.

    Returns:
        None
    """
    # Function-scope import so this cell does not depend on the notebook's import cell
    from urllib.parse import urlparse

    # Check if the directory exists
    if not os.path.exists(data_dir):
        print(f"Directory '{data_dir}' does not exist. Exiting.")
        return

    # Initialize an empty list to accumulate links
    all_links = []

    # Loop through each page in the specified range with a progress bar
    for page_number in tqdm(range(start_page, end_page + 1), desc="Getting Links"):
        base_url = f"https://catalog.data.gov/dataset/?res_format=EXCEL&_res_format_limit=0&_bureauCode_limit=0&page={page_number}"
        try:
            # A timeout keeps one unresponsive listing page from hanging the whole run
            page = requests.get(base_url, timeout=30)
        except requests.exceptions.RequestException as exc:
            print(f"Could not fetch listing page {page_number}: {exc}")
            continue
        soup = BeautifulSoup(page.content, 'html.parser')
        # Add the found links to the accumulated list
        all_links.extend(
            link['href'] for link in soup.find_all('a', href=True) if '.xls' in link['href'].lower()
        )

    # The same file is often linked more than once; keep first-seen order, drop repeats
    all_links = list(dict.fromkeys(all_links))
    print(f"Total XLS links found: {len(all_links)}")

    # Limit concurrency to 5 simultaneous downloads
    sem = asyncio.Semaphore(5)
    loop = asyncio.get_running_loop()
    download_timeout = aiohttp.ClientTimeout(total=1)

    async def fetch(session, url, pbar):
        """Size-check and download one link; never raises, always ticks the bar once."""
        async with sem:
            try:
                # requests is blocking, so run the HEAD probe in the default thread
                # pool instead of stalling the event loop (and the other downloads)
                try:
                    head = await loop.run_in_executor(
                        None, lambda: requests.head(url, timeout=1, allow_redirects=True)
                    )
                except requests.exceptions.RequestException:
                    return

                if head.status_code == 200 and 'Content-Length' in head.headers:
                    file_size_mb = int(head.headers['Content-Length']) / (1024 * 1024)  # Convert to MB
                    if file_size_mb > max_size_mb:
                        # Skip downloading files larger than max_size_mb
                        return
                elif head.status_code != 403:
                    # Skip if unable to get a valid response for size; a 403 on HEAD
                    # is tolerated and the download is attempted anyway
                    return

                # Name the file after the URL *path* so query strings and fragments
                # do not end up in (or corrupt) the filename
                basename = os.path.basename(urlparse(url).path)
                if not basename:
                    return
                filename = os.path.join(data_dir, basename)

                async with session.get(url, timeout=download_timeout, allow_redirects=True) as response:
                    if response.status == 200:
                        # Read the full body before opening the file so a timeout
                        # cannot leave an empty or truncated file behind
                        payload = await response.read()
                        with open(filename, 'wb') as f:
                            f.write(payload)
            except Exception:
                # Deliberate best-effort: timeouts and any other per-link failure
                # just skip that link
                return
            finally:
                # Exactly one progress tick per link, whatever the outcome
                pbar.update(1)

    async with aiohttp.ClientSession() as session:
        # Create a tqdm progress bar for downloading
        with tqdm(total=len(all_links), desc="Downloading Files") as pbar:
            # Schedule every link at once; the semaphore bounds how many run together
            await asyncio.gather(*(fetch(session, url, pbar) for url in all_links))


In [None]:
# Example usage
start_page = 1
end_page = 1
data_dir = '../data/train/'

# Run the combined function
await download_datagov_xls(start_page, end_page, data_dir)

# Validate files in Directory

In [None]:
# Directory whose spreadsheets are scanned
data_dir = '../data/train/'

# get_vocabulary returns a pair: the vocabulary object and the processed file paths
spreadsheet_vocab, file_paths = get_vocabulary(data_dir)

# Report how many tokens the vocabulary holds and how many files contributed
print(f'\n\nVocabulary size: {len(spreadsheet_vocab._word2idx)}')
print(f'Files Processed: {len(file_paths)}')

In [None]:
from joblib import Parallel, delayed
import os
import warnings
import time
import json
from tqdm import tqdm
from utils.parseutil import process_spreadsheet

def _remove_if_present(file_path):
    """Delete file_path, ignoring the case where it has already been removed."""
    try:
        os.remove(file_path)
    except FileNotFoundError:
        pass


def _validate_spreadsheet_file(file_path, vocab, max_rows, max_cols, max_processing_time):
    """
    Parse one spreadsheet in a worker; return file_path if it is usable, else delete it and return None.

    A file is rejected when process_spreadsheet raises, returns None for either
    token output, emits a UserWarning, or takes longer than max_processing_time
    seconds. The time limit is checked after the fact; it does not interrupt a
    slow parse.
    """
    try:
        # Start timing the processing
        start_time = time.time()

        # Record warnings instead of printing them so they can be inspected below
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")  # Trigger all warnings
            x_tok, y_tok = process_spreadsheet(file_path, vocab=vocab, max_rows=max_rows, max_cols=max_cols)

        elapsed = time.time() - start_time
        usable = (
            x_tok is not None
            and y_tok is not None
            and not any(issubclass(item.category, UserWarning) for item in caught)
            and elapsed <= max_processing_time
        )
    except Exception:
        # NOTE(review): any failure deletes the file, including failures that are
        # not the file's fault (e.g. a broken vocab or environment) — kept as-is
        # because that is the existing contract.
        usable = False

    if usable:
        return file_path

    _remove_if_present(file_path)
    return None


def validate_dir_parallel(directory, vocab, max_rows=10, max_cols=10, max_size_mb=2, max_processing_time=10):
    """
    Validates the contents of a given directory by ensuring all files have supported extensions
    (.xls, .xlsx, .csv) and are processable. Unsupported files, files larger than max_size_mb,
    files that fail to process, emit a UserWarning, or take longer than max_processing_time are
    DELETED from disk.

    Paths that pass are remembered in ``<directory>/cache.json`` and skipped on later runs.
    Cheap extension/size checks run in the calling process; only spreadsheet parsing is
    dispatched to joblib workers. A one-line summary is printed at the end.

    Args:
        directory (str): The path to the directory to be validated.
        vocab: The vocabulary object for encoding tokens (passed through to process_spreadsheet).
        max_rows (int, optional): The maximum number of rows to process. Defaults to 10.
        max_cols (int, optional): The maximum number of columns to process. Defaults to 10.
        max_size_mb (int, optional): The maximum file size in MB to process. Defaults to 2MB.
        max_processing_time (int, optional): The maximum processing time in seconds. Defaults to 10 seconds.

    Returns:
        None
    """
    # List supported file extensions
    supported_extensions = {'.xls', '.xlsx', '.csv'}
    cache_name = "cache.json"

    # Nothing to do unless the path is an actual directory
    if not os.path.isdir(directory):
        return

    # Cache file to store validated file paths
    cache_file_path = os.path.join(directory, cache_name)

    # Load existing cache data; an unreadable cache simply means "revalidate everything"
    validated_files = set()
    if os.path.exists(cache_file_path):
        try:
            with open(cache_file_path, "r") as cache_file:
                validated_files = set(json.load(cache_file))
        except (ValueError, TypeError):
            validated_files = set()

    # Gather all data files in the directory; the cache file itself is bookkeeping,
    # so it is neither validated nor counted
    file_list = [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name != cache_name and os.path.isfile(os.path.join(directory, name))
    ]
    total_files = len(file_list)

    # Forget cache entries for files that no longer exist
    validated_files &= set(file_list)

    # Filter out files that are already validated
    files_to_validate = [path for path in file_list if path not in validated_files]
    deleted_files = 0

    # Cheap checks first, in this process: wrong extension or oversized => delete
    candidates = []
    for file_path in files_to_validate:
        file_extension = os.path.splitext(file_path)[1].lower()
        if file_extension not in supported_extensions:
            _remove_if_present(file_path)
            deleted_files += 1
            continue

        file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
        if file_size_mb > max_size_mb:
            _remove_if_present(file_path)
            deleted_files += 1
            continue

        candidates.append(file_path)

    # Parse the remaining files in parallel. Use half the cores, but never fewer
    # than one worker (cpu_count() may be 1 or None).
    results = []
    if candidates:
        n_jobs = max(1, (os.cpu_count() or 1) // 2)
        results = Parallel(n_jobs=n_jobs)(
            delayed(_validate_spreadsheet_file)(file_path, vocab, max_rows, max_cols, max_processing_time)
            for file_path in tqdm(candidates, desc="Validating Files")
        )

    # Count deletions from the returned results: a counter incremented inside a
    # worker is not visible here when workers are separate processes (joblib's default)
    passed = [result for result in results if result is not None]
    deleted_files += len(results) - len(passed)
    validated_files.update(passed)

    # Save the updated cache data back to cache.json
    with open(cache_file_path, "w") as cache_file:
        json.dump(sorted(validated_files), cache_file)

    remaining_files = total_files - deleted_files
    print(f"Total files: {total_files}, Deleted files: {deleted_files}, Remaining files: {remaining_files}")


In [None]:
# Example: validate the training directory (the path must exist and hold the files)
directory = '../data/train/'

# Reuses the vocabulary built earlier in the notebook (spreadsheet_vocab must already be defined)
vocab = spreadsheet_vocab

# Kick off validation; files that fail any check are deleted from disk
validate_dir_parallel(directory=directory, vocab=vocab)