# In [None]:
import glob
import gzip
import logging
import os
import shutil
import signal
import subprocess
import threading
import time
import urllib.request
from datetime import datetime
from typing import Optional, Tuple

class SalmonIndexBuilder:
    """Download a GENCODE transcriptome and build a Salmon index from it.

    The pipeline steps are exposed as separate methods, normally called in
    this order: ``download_transcriptome`` -> ``decompress_fasta`` ->
    ``build_salmon_index`` (which validates the result) -> ``cleanup``.
    All artefacts live under ``work_dir``.
    """

    # Substrings that mark a Salmon output line as routine progress output.
    _PROGRESS_MARKERS = (
        "Round", "Pass", "junctions count", "Hash table",
        "Reallocating", "Threads", "Vertex length",
    )
    # Case-sensitive substrings that mark a line as a warning, in addition to
    # any line containing "warning" (case-insensitive).
    _WARNING_MARKERS = (
        "had length less than equal to",
        "Version Server Response",
    )

    def __init__(self, work_dir: str, gencode_version: str = "v38", genome: str = "human"):
        """Create the working directory, configure logging and derive paths.

        Args:
            work_dir: Directory holding downloads, logs and the index.
            gencode_version: GENCODE release tag including the leading "v".
            genome: "human" selects the human release; any other value
                selects the mouse release.
        """
        self.work_dir = work_dir
        self.gencode_version = gencode_version
        self.genome = genome.lower()

        os.makedirs(work_dir, exist_ok=True)

        # Set up logging with timestamps and file output.
        # NOTE: logging.basicConfig only has an effect while the root logger
        # has no handlers, so only the first configuration in a process wins.
        log_file = os.path.join(work_dir, f"salmon_index_build_{datetime.now():%Y%m%d_%H%M%S}.log")
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(log_file),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

        self.fasta_gz = os.path.join(work_dir, f"gencode.{gencode_version}.transcripts.fa.gz")
        # Strip only the trailing ".gz": str.replace would also rewrite any
        # ".gz" that happens to occur inside work_dir.
        self.fasta = self.fasta_gz[:-len(".gz")]
        self.index_dir = os.path.join(work_dir, f"salmon_index_{gencode_version}")

    def check_disk_space(self, required_gb: int = 50) -> None:
        """Raise RuntimeError if less than ``required_gb`` GiB is free in work_dir."""
        available_gb = shutil.disk_usage(self.work_dir).free / (1024**3)

        if available_gb < required_gb:
            raise RuntimeError(
                f"Insufficient disk space. Available: {available_gb:.1f}GB, Required: {required_gb}GB"
            )

    def download_transcriptome(self) -> None:
        """Download the gzipped transcriptome FASTA unless it already exists.

        The file is fetched to a temporary ``.part`` path and only renamed to
        its final name once it passes the size check, so an interrupted run
        never leaves a truncated file that a later run would accept.

        Raises:
            RuntimeError: Not enough free disk space.
            ValueError: The downloaded file is smaller than 1 MB.
            Exception: Any download error is logged and re-raised.
        """
        if os.path.exists(self.fasta_gz):
            self.logger.info(f"Transcriptome file already exists: {self.fasta_gz}")
            return

        base_url = "https://ftp.ebi.ac.uk/pub/databases/gencode"
        species = "Gencode_human" if self.genome == "human" else "Gencode_mouse"
        filename = f"gencode.{self.gencode_version}.transcripts.fa.gz"
        # Drop only the leading "v" of the version tag (e.g. "v38" -> "38").
        version = self.gencode_version
        release = version[1:] if version.startswith('v') else version
        url = f"{base_url}/{species}/release_{release}/{filename}"

        self.logger.info(f"Downloading transcriptome from: {url}")
        partial = self.fasta_gz + ".part"
        try:
            self.check_disk_space()
            urllib.request.urlretrieve(url, partial)
            self.logger.info("Download completed successfully")

            file_size = os.path.getsize(partial)
            self.logger.info(f"Downloaded file size: {file_size/1024/1024:.2f} MB")

            if file_size < 1000000:  # Less than 1MB
                raise ValueError(f"Downloaded file is too small ({file_size} bytes)")

            os.replace(partial, self.fasta_gz)
        except Exception as e:
            self.logger.error(f"Failed to download transcriptome: {str(e)}")
            raise
        finally:
            # Also runs on KeyboardInterrupt, which `except Exception` misses.
            if os.path.exists(partial):
                os.remove(partial)

    def decompress_fasta(self) -> None:
        """Decompress the downloaded FASTA unless the output already exists.

        Output is written to a temporary ``.part`` path and renamed only after
        the size sanity check passes.

        Raises:
            RuntimeError: Not enough free disk space.
            ValueError: The decompressed file is not larger than the archive.
            Exception: Any I/O or gzip error is logged and re-raised.
        """
        if os.path.exists(self.fasta):
            self.logger.info(f"Decompressed FASTA already exists: {self.fasta}")
            return

        self.logger.info("Decompressing FASTA file...")
        partial = self.fasta + ".part"
        try:
            self.check_disk_space()
            with gzip.open(self.fasta_gz, 'rb') as f_in, open(partial, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)

            compressed_size = os.path.getsize(self.fasta_gz)
            decompressed_size = os.path.getsize(partial)
            self.logger.info(f"Compressed size: {compressed_size/1024/1024:.2f} MB")
            self.logger.info(f"Decompressed size: {decompressed_size/1024/1024:.2f} MB")

            if decompressed_size <= compressed_size:
                raise ValueError("Decompressed file is smaller than compressed file")

            os.replace(partial, self.fasta)
            self.logger.info("Decompression completed successfully")
        except Exception as e:
            self.logger.error(f"Failed to decompress FASTA: {str(e)}")
            raise
        finally:
            if os.path.exists(partial):
                os.remove(partial)

    @classmethod
    def _is_warning(cls, line: str) -> bool:
        """Return True if a Salmon output line should go to the warning file."""
        return "warning" in line.lower() or any(m in line for m in cls._WARNING_MARKERS)

    @classmethod
    def _is_progress(cls, line: str) -> bool:
        """Return True if a Salmon output line is routine progress output."""
        return any(m in line for m in cls._PROGRESS_MARKERS)

    def _drain_stdout(self, stream) -> None:
        """Consume Salmon's stdout until EOF, logging progress lines at INFO."""
        for raw in stream:
            text = raw.strip()
            if not text:
                continue
            if self._is_progress(text):
                self.logger.info(text)
            else:
                self.logger.debug(text)

    def _drain_stderr(self, stream, warn_f) -> None:
        """Consume Salmon's stderr until EOF.

        Warnings are written one per line to ``warn_f``, progress lines are
        logged at INFO and everything else at ERROR. Blank lines are skipped.
        """
        for raw in stream:
            text = raw.rstrip("\r\n")
            if not text.strip():
                continue
            if self._is_warning(text):
                # Exactly one newline per warning so the later line count
                # equals the number of warnings.
                warn_f.write(f"{text}\n")
            elif self._is_progress(text):
                self.logger.info(text.strip())
            else:
                self.logger.error(text.strip())

    def build_salmon_index(self, threads: int = 16, kmer_length: int = 31, timeout: int = 7200) -> None:
        """Run ``salmon index`` on the decompressed FASTA and validate the result.

        Any existing index directory is removed first. stdout and stderr are
        drained by two background threads so a full pipe can never block
        Salmon, while the calling thread enforces ``timeout``.

        Args:
            threads: Value passed to Salmon's ``-p`` option.
            kmer_length: Value passed to Salmon's ``-k`` option.
            timeout: Maximum run time in seconds before Salmon is killed.

        Raises:
            TimeoutError: Salmon did not finish within ``timeout`` seconds.
            subprocess.CalledProcessError: Salmon exited with a non-zero code.
            RuntimeError: Not enough disk space, or index validation failed.
        """
        if os.path.exists(self.index_dir):
            self.logger.info(f"Removing existing index directory: {self.index_dir}")
            shutil.rmtree(self.index_dir)

        os.makedirs(self.index_dir, exist_ok=True)

        cmd = [
            "salmon", "index",
            "-t", self.fasta,
            "-i", self.index_dir,
            "-p", str(threads),
            "--gencode",
            "-k", str(kmer_length),
            # Keep sequence-identical transcripts instead of collapsing them
            # (see the Salmon documentation for --keepDuplicates).
            "--keepDuplicates"
        ]

        self.logger.info("Building Salmon index...")
        self.logger.info(f"Command: {' '.join(cmd)}")

        # File that collects warning lines emitted by Salmon.
        warning_file = os.path.join(self.work_dir, "salmon_index_warnings.log")
        process = None

        try:
            self.check_disk_space()
            start_time = time.monotonic()

            with open(warning_file, 'w') as warn_f:
                process = subprocess.Popen(
                    cmd,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    text=True,
                    bufsize=1
                )
                readers = [
                    threading.Thread(target=self._drain_stdout, args=(process.stdout,), daemon=True),
                    threading.Thread(target=self._drain_stderr, args=(process.stderr, warn_f), daemon=True),
                ]
                for reader in readers:
                    reader.start()

                try:
                    return_code = process.wait(timeout=timeout)
                except subprocess.TimeoutExpired:
                    raise TimeoutError(f"Index building timed out after {timeout} seconds") from None
                finally:
                    # Make sure the child is gone so both pipes reach EOF and
                    # the reader threads finish before warn_f is closed.
                    if process.poll() is None:
                        process.kill()
                        process.wait()
                    for reader in readers:
                        reader.join()
                    process.stdout.close()
                    process.stderr.close()

            build_time = time.monotonic() - start_time

            if return_code != 0:
                raise subprocess.CalledProcessError(
                    return_code,
                    cmd,
                    output=None,
                    stderr="Salmon index building failed"
                )

            # Count warnings (one line per warning).
            with open(warning_file) as f:
                warning_count = sum(1 for _ in f)

            self.logger.info(f"Salmon index built successfully in {build_time:.1f} seconds")
            self.logger.info(f"Found {warning_count} warnings (saved to {warning_file})")

        except Exception as e:
            self.logger.error(f"Error during index building: {str(e)}")
            # Never leave a half-built index behind.
            if os.path.exists(self.index_dir):
                shutil.rmtree(self.index_dir)
            raise
        finally:
            if process is not None and process.poll() is None:
                process.kill()

        # Validate immediately after building
        if not self.validate_index():
            raise RuntimeError("Index validation failed after building")

    def validate_index(self) -> bool:
        """Check that the index directory exists, is complete and is not tiny.

        Returns:
            True if every expected file is present and the directory's
            entries total at least 1 MB; False otherwise (the reason is
            logged).
        """
        self.logger.info("Validating Salmon index...")

        if not os.path.exists(self.index_dir):
            self.logger.error("Index directory does not exist")
            return False

        required_files = [
            'complete_ref_lens.bin',
            'ctable.bin',
            'ctg_offsets.bin',
            'duplicate_clusters.tsv',
            'info.json',
            'mphf.bin',
            'pos.bin',
            'pre_indexing.log',
            'rank.bin',
            'refseq.bin',
            'seq.bin',
            'versionInfo.json'
        ]

        existing_files = os.listdir(self.index_dir)
        self.logger.info(f"Files in index directory: {existing_files}")

        missing_files = [f for f in required_files
                         if not os.path.exists(os.path.join(self.index_dir, f))]

        if missing_files:
            self.logger.error(f"Missing index files: {missing_files}")
            return False

        # Check file sizes
        total_size = sum(os.path.getsize(os.path.join(self.index_dir, f)) for f in existing_files)
        self.logger.info(f"Total index size: {total_size/1024/1024:.2f} MB")

        if total_size < 1000000:  # Less than 1MB
            self.logger.error("Index files are suspiciously small")
            return False

        self.logger.info("Index validation successful")
        return True

    def cleanup(self) -> None:
        """Remove the decompressed FASTA; the .gz archive and index are kept."""
        self.logger.info("Cleaning up intermediate files...")

        if os.path.exists(self.fasta):
            os.remove(self.fasta)
            self.logger.info(f"Removed decompressed FASTA: {self.fasta}")

if __name__ == "__main__":
    # Pipeline parameters
    work_dir = "/beegfs/scratch/ric.broccoli/kubacki.michal/SRF_Snords/salmon_index"
    gencode_version = "v38"
    genome = "human"
    threads = 16

    builder = SalmonIndexBuilder(work_dir, gencode_version, genome)

    try:
        # Verify Salmon installation first. A missing executable raises
        # FileNotFoundError (an OSError), not CalledProcessError, so both
        # must be caught for the "not in PATH" message to be shown.
        try:
            result = subprocess.run(
                ["salmon", "--version"],
                capture_output=True,
                text=True,
                check=True
            )
        except (OSError, subprocess.CalledProcessError) as err:
            raise RuntimeError("Salmon is not properly installed or not in PATH") from err
        builder.logger.info(f"Salmon version: {result.stdout.strip()}")

        # Execute pipeline
        builder.download_transcriptome()
        builder.decompress_fasta()
        builder.build_salmon_index(threads=threads, timeout=7200)  # 2 hour timeout

        print(f"\nSalmon index created successfully at: {builder.index_dir}")

    except Exception as e:
        print(f"\nError creating Salmon index: {str(e)}")
        raise
    finally:
        # Remove the decompressed FASTA whether or not the build succeeded.
        builder.cleanup()