In [None]:
from google.colab import files
import zipfile
import os
import gzip
import shutil

def process_protein_zips(input_zip, output_faa='combined_proteins.faa', debug=False):
    """
    Unpack a ZIP of nested protein ZIPs and merge every .faa/.faa.gz file into one FASTA file.

    This function is useful for bulk-processing protein sequence datasets distributed in nested ZIP
    archives (e.g., one compressed file per species/protein). It tracks which of the ZIP files found
    directly after unpacking the main archive yield protein files, and prints a warning listing the
    ones that do not, together with their member names.

    Workflow:
    1. Extract the main ZIP archive into a directory named after the archive (extension stripped).
       An existing directory with that name is deleted first.
    2. Recursively extract all nested ZIP files found within that directory; each nested ZIP is
       deleted once extracted, and its member list is remembered for reporting.
    3. Locate all .faa and .faa.gz files and map them back to their originating ZIP.
    4. Warn about any ZIP files that did not produce .faa files, along with their contents.
    5. Concatenate all .faa and .faa.gz files (sorted by path) into a single combined FASTA file.
       A newline is inserted after any file that does not end with one, so that the next file's
       '>' header always starts on its own line.
    6. Return the path to the combined FASTA file and the total number of protein sequences.

    The function temporarily changes the process working directory to the extraction directory and
    always restores it, even when an error occurs.

    Parameters:
    - input_zip (str): Path to the input ZIP archive containing nested ZIPs with .faa/.faa.gz files.
    - output_faa (str): Filename for the combined output FASTA file, created inside the extraction
      directory. Default is 'combined_proteins.faa'.
    - debug (bool): If True, prints debug information about file discovery, extraction, and sequence counts.

    Returns:
    - output_path (str): Path to the final combined FASTA file.
    - seq_count (int): Total number of protein sequences (lines starting with '>') in the combined file.

    Raises:
    - FileNotFoundError: If input_zip does not exist.
    - zipfile.BadZipFile: If input_zip or any nested .zip file is not a valid ZIP archive.

    Example:
    >>> process_protein_zips("my_protein_dataset.zip", output_faa="all_proteins.faa", debug=True)
    """

    initial_dir = os.getcwd()

    # 1. Extract the main ZIP file
    base_dir = os.path.splitext(input_zip)[0]
    if os.path.exists(base_dir):
        shutil.rmtree(base_dir)
    with zipfile.ZipFile(input_zip, 'r') as zip_ref:
        zip_ref.extractall(base_dir)

    # All paths below are relative to the extraction directory. The try/finally guarantees the
    # caller's working directory is restored even if extraction or merging fails.
    os.chdir(base_dir)
    try:
        # Store all original zip names for verification
        original_zips = set()
        for root, _, filenames in os.walk('.'):
            for filename in filenames:
                if filename.endswith('.zip'):
                    original_zips.add(os.path.join(root, filename))

        if debug:
            print(f"[DEBUG] Found {len(original_zips)} initial ZIP files")

        # Member names of every extracted ZIP, keyed by the ZIP's path. The ZIPs are deleted right
        # after extraction, so this is the only record left for the warning report in step 4.
        zip_contents = {}

        # 2. Recursively extract ALL nested ZIPs
        def extract_all_zips(directory='.'):
            for root, _, filenames in os.walk(directory):
                for filename in filenames:
                    if not filename.endswith('.zip'):
                        continue
                    zip_path = os.path.join(root, filename)
                    extract_dir = zip_path[:-4]  # strip the '.zip' suffix
                    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                        zip_contents[zip_path] = zip_ref.namelist()
                        zip_ref.extractall(extract_dir)
                    os.remove(zip_path)
                    if debug:
                        print(f"[DEBUG] Extracted nested ZIP: {zip_path}")
                    # The freshly created directory is not part of the walk in progress,
                    # so descend into it explicitly to find deeper ZIPs.
                    extract_all_zips(extract_dir)

        extract_all_zips()

        # 3. Find all .faa/.faa.gz files and track their origins
        faa_files = []
        zip_to_faa = {}  # Track which zips produced which faa files

        for root, _, filenames in os.walk('.'):
            for filename in filenames:
                if not filename.endswith(('.faa', '.faa.gz')):
                    continue
                full_path = os.path.join(root, filename)
                faa_files.append(full_path)

                # Walk up the directory chain until a directory matching an original ZIP
                # (same path plus '.zip') is found.
                parent_dir = root
                while parent_dir != '.':
                    candidate = parent_dir + '.zip'
                    if candidate in original_zips:
                        zip_to_faa.setdefault(candidate, []).append(full_path)
                        break
                    parent_dir = os.path.dirname(parent_dir)

        # 4. Identify problematic/missing files
        unproductive_zips = sorted(original_zips.difference(zip_to_faa))
        if unproductive_zips:
            print("\n⚠️ WARNING: Some ZIP files didn't produce .faa files:")
            for zip_file in unproductive_zips:
                print(f" - {os.path.basename(zip_file)}")
                # Report the member list captured before the ZIP was removed.
                print(f"   Contents: {zip_contents.get(zip_file, [])}")

        # 5. Combine the files
        faa_files.sort()
        if debug:
            print(f"[DEBUG] Found {len(faa_files)} .faa files to combine.")

        with open(output_faa, 'w') as outfile:
            for faa in faa_files:
                opener = gzip.open if faa.endswith('.gz') else open
                # Read the whole file before writing anything so that an unreadable or corrupt
                # input is skipped entirely rather than leaving a partial record in the output.
                try:
                    with opener(faa, 'rt') as infile:
                        content = infile.read()
                except Exception as e:  # deliberate best-effort: report and keep going
                    print(f"[ERROR] Could not read {faa}: {e}")
                    continue
                outfile.write(content)
                # Without a trailing newline the next file's '>' header would be glued to the
                # last residue line, corrupting that record and the sequence count.
                if content and not content.endswith('\n'):
                    outfile.write('\n')

        # 6. Count sequences
        with open(output_faa, 'r') as combined:
            seq_count = sum(1 for line in combined if line.startswith('>'))
    finally:
        os.chdir(initial_dir)

    output_path = os.path.join(base_dir, output_faa)

    if debug:
        print(f"[DEBUG] Final combined file at: {output_path}")
        print(f"[DEBUG] Total sequences counted: {seq_count}")

    return output_path, seq_count