# In [None]:
from collections import defaultdict, Counter
from Bio import SeqIO

# Module-level search parameters; search_fasta and the driver functions read
# them as globals.
k, m, g = 5, 20, 3  # k-mer length, min diagonal score, max gap joining diagonals
threshold_percent = 0.5  # fraction of the query length a hit's score must reach
kmer_frequency_threshold = 2  # min query occurrences of a k-mer when frequency filtering is on


## Build a dictionary mapping every k-mer of the query sequence to its start positions (optionally keeping only k-mers that occur frequently)

def kmers_query(query, k, use_frequent_only=False, kmer_frequency_threshold=2):
    """Map every k-mer of *query* to the list of positions where it starts.

    Parameters
    ----------
    query : str
        Query sequence to index.
    k : int
        k-mer length; must be positive.
    use_frequent_only : bool
        When True, keep only k-mers occurring at least
        ``kmer_frequency_threshold`` times in *query*; a smaller index means
        fewer candidate hits during the search.
    kmer_frequency_threshold : int
        Minimum occurrence count used when ``use_frequent_only`` is True.
        Defaults to 2 (the module-level setting); search_fasta passes it
        explicitly.

    Returns
    -------
    collections.defaultdict(list)
        k-mer -> ascending list of start positions. Empty when ``k`` is
        longer than ``query``.

    Raises
    ------
    ValueError
        If ``k`` is not positive.
    """
    if k <= 0:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    kmer_index = defaultdict(list)
    for i in range(len(query) - k + 1):
        kmer_index[query[i:i + k]].append(i)

    if not use_frequent_only:
        return kmer_index

    # Each position list holds one entry per occurrence, so its length is the
    # k-mer's frequency in the query -- no separate counting pass is needed.
    filtered_index = defaultdict(list)
    for kmer, positions in kmer_index.items():
        if len(positions) >= kmer_frequency_threshold:
            filtered_index[kmer] = positions

    print(f"Frequency filtering: {len(kmer_index)} k-mers to {len(filtered_index)} k-mers")
    return filtered_index



## Make a dictionary with the diagonals as keys and the number of k-mer matches on each particular diagonal as values
## When a k-mer from the target sequence matches a k-mer from the query sequence, the match implies a potential alignment between the two sequences.
## At this point, the FastA algorithm treats all k-mer matches that fall on the same diagonal as evidence for a single, potentially continuous, ungapped alignment.

def score_diagonals(target, kmer_index, k):
    """Count query/target k-mer hits per diagonal.

    Every k-mer starting at position ``j`` of *target* that also starts at
    position ``i`` of the query (per *kmer_index*) casts one vote for the
    diagonal ``j - i``.

    Returns
    -------
    (defaultdict(int), int)
        Diagonal offset -> number of hits on it, and the total number of hits.
    """
    diagonal_hits = defaultdict(int)
    hit_total = 0

    for j in range(len(target) - k + 1):
        # .get() (unlike indexing a defaultdict) never inserts missing keys.
        for i in kmer_index.get(target[j:j + k], ()):
            diagonal_hits[j - i] += 1
            hit_total += 1

    return diagonal_hits, hit_total





## Keep diagonals scoring at least m and join those whose offsets differ by at most g; each joined group forms a larger, more comprehensive alignment region.

def filter_and_join_diagonals(S, m, g, verbose=False):
    """Join nearby high-scoring diagonals into regions and return the best region score.

    Parameters
    ----------
    S : mapping
        Diagonal offset -> number of k-mer hits on that diagonal.
    m : int
        Minimum hits for a diagonal to be kept.
    g : int
        Maximum offset difference between consecutive kept diagonals for them
        to be merged into the same region.
    verbose : bool
        When True, also return the per-region breakdown.

    Returns
    -------
    int, or (int, list of dict) when ``verbose``
        The highest region score, i.e. the sum of ``S`` over the diagonals
        of the region. This is 0 when no diagonal reaches ``m``, so callers
        can compare it numerically against a threshold. With ``verbose``, a
        tuple ``(best_score, region_details)`` is returned. ``region_details``
        holds one ``{'diagonals': [...], 'score': int}`` dict per region, in
        ascending diagonal order.
    """
    significant = sorted(d for d in S if S[d] >= m)

    # Greedy left-to-right grouping: a diagonal joins the current region when
    # it lies within g of that region's last (largest) diagonal.
    regions = []
    for d in significant:
        if regions and d - regions[-1][-1] <= g:
            regions[-1].append(d)
        else:
            regions.append([d])

    region_details = [
        {'diagonals': diags, 'score': sum(S[x] for x in diags)} for diags in regions
    ]
    best_score = max((region['score'] for region in region_details), default=0)

    return (best_score, region_details) if verbose else best_score






# Step 4: Enhanced search function with detailed reporting

def search_fasta(query, fasta_file, use_frequent_kmers=False, verbose=False):
    """Score every record of a FASTA file against *query* with the FastA heuristic.

    Reads the module-level parameters ``k``, ``m``, ``g``,
    ``threshold_percent`` and ``kmer_frequency_threshold``.

    Parameters
    ----------
    query : str
        Query sequence.
    fasta_file
        Passed straight to ``Bio.SeqIO.parse(..., "fasta")``.
    use_frequent_kmers : bool
        Index only query k-mers occurring at least
        ``kmer_frequency_threshold`` times.
    verbose : bool
        Request per-region details from ``filter_and_join_diagonals`` and
        store them under ``'region_details'`` in each result.

    Returns
    -------
    list of dict
        One entry per record whose best region score reaches
        ``int(len(query) * threshold_percent)``, sorted by ``total_score``
        descending. Keys: ``gene_id``, ``total_score``, ``sequence_length``,
        ``total_kmer_matches``, ``score_ratio`` (plus ``region_details`` when
        ``verbose``).
    """
    print("Starting FastA search...")
    print(f"Query length: {len(query)}")
    print(f"Parameters: k={k}, m={m}, g={g}, threshold={threshold_percent}")

    kmer_index = kmers_query(query, k, use_frequent_kmers, kmer_frequency_threshold)
    print(f"Built k-mer index with {len(kmer_index)} unique k-mers")

    results = []
    required_score = int(len(query) * threshold_percent)
    print(f"Required score threshold: {required_score}")

    sequences_processed = 0
    for record in SeqIO.parse(fasta_file, "fasta"):
        sequences_processed += 1
        sequence = str(record.seq)

        S, total_matches = score_diagonals(sequence, kmer_index, k)

        # Accept either a bare score or a (score, details) pair, so the
        # verbose path never tries to unpack a plain number.
        outcome = filter_and_join_diagonals(S, m, g, verbose=verbose)
        if isinstance(outcome, tuple):
            total_score, region_details = outcome
        else:
            total_score, region_details = outcome, None

        # A non-numeric score (e.g. a "no match" message) means no diagonal
        # reached m: skip the record rather than comparing str with int.
        if not isinstance(total_score, (int, float)) or total_score < required_score:
            continue

        result_entry = {
            'gene_id': record.id,
            'total_score': total_score,
            'sequence_length': len(sequence),
            'total_kmer_matches': total_matches,
            'score_ratio': total_score / len(query) if len(query) > 0 else 0,
        }
        if verbose:
            result_entry['region_details'] = region_details
        results.append(result_entry)

    results.sort(key=lambda x: x['total_score'], reverse=True)

    print("\nSearch completed:")
    print(f"Sequences processed: {sequences_processed}")
    print(f"Matches above threshold: {len(results)}")

    return results

# Function to display results nicely
def display_results(results, top_n=10):
    """Display search results in a formatted way.

    Prints a ranked table of at most ``top_n`` hits, followed by the details
    of ``results[0]``.

    Parameters
    ----------
    results : list of dict
        Entries with keys ``gene_id``, ``total_score``, ``sequence_length``
        and ``score_ratio``. Assumed to be sorted best-first (search_fasta
        sorts by ``total_score`` descending).
    top_n : int
        Maximum number of table rows; negative values are treated as 0.
    """
    if not results:
        print("No matches found above threshold.")
        return

    # Clamp so a negative top_n cannot turn the slice into
    # "all but the last |top_n| rows".
    shown = results[:max(top_n, 0)]

    print(f"\nTop {len(shown)} matches:")
    print("-" * 80)
    print(f"{'Rank':<4} {'Gene ID':<20} {'Score':<8} {'Seq Len':<8} {'Score Ratio':<12}")
    print("-" * 80)

    for rank, result in enumerate(shown, 1):
        print(f"{rank:<4} {result['gene_id']:<20} {result['total_score']:<8} "
              f"{result['sequence_length']:<8} {result['score_ratio']:<12.3f}")

    # The early return above guarantees results[0] exists.
    best_match = results[0]
    print("\nHighest matching sequence:")
    print(f"Gene ID: {best_match['gene_id']}")
    print(f"Total Score: {best_match['total_score']}")
    print(f"Sequence Length: {best_match['sequence_length']}")
    print(f"Score Ratio: {best_match['score_ratio']:.3f}")

# Alternative k-mer filtering approach - by removing very common k-mers
def build_kmer_index_filtered_common(query, k, max_frequency_ratio=0.1):
    """
    Alternative filtering approach: remove k-mers that are too common (likely low complexity).

    max_frequency_ratio: remove k-mers that appear in more than this fraction
    of all k-mer positions. A k-mer occurring only once is always kept, so a
    short query (fewer than 1 / max_frequency_ratio positions) does not lose
    its entire index.

    Returns a defaultdict(list) mapping each surviving k-mer to its start
    positions in query.
    """
    kmer_index = defaultdict(list)

    # Build complete index first
    for i in range(len(query) - k + 1):
        kmer_index[query[i:i + k]].append(i)

    n_positions = max(len(query) - k + 1, 0)
    # For integer counts, "count <= int(n * ratio)" is the same test as
    # "count <= n * ratio". The floor of 1 stops truncation to 0 from filtering
    # out every k-mer of a short query.
    max_frequency = max(1, int(n_positions * max_frequency_ratio))

    filtered_index = defaultdict(list)
    for kmer, positions in kmer_index.items():
        if len(positions) <= max_frequency:
            filtered_index[kmer] = positions

    print(f"Filtered common k-mers: {len(kmer_index)} -> {len(filtered_index)} k-mers")
    return filtered_index

# Example usage function
def run_fasta_analysis(query_sequence, fasta_filename):
    """
    Run the FastA search twice and print the top 5 hits of each run.

    The first run uses the full query k-mer index. The second uses only the
    frequent k-mers (at least ``kmer_frequency_threshold`` occurrences).

    Returns a (results_standard, results_frequent) tuple of the result lists
    produced by search_fasta.
    """
    rule = "=" * 60
    for banner_line in (rule, "FastA Algorithm Analysis", rule):
        print(banner_line)

    collected = []
    for frequent_only in (False, True):
        if frequent_only:
            print(f"\n2. FastA search (frequent k-mers only, threshold >= {kmer_frequency_threshold}):")
        else:
            print("\n1. Standard FastA search (all k-mers):")
        run_results = search_fasta(query_sequence, fasta_filename, use_frequent_kmers=frequent_only)
        display_results(run_results, top_n=5)
        collected.append(run_results)

    results_standard, results_frequent = collected
    return results_standard, results_frequent

# Performance comparison function
def compare_k_values(query_sequence, fasta_filename, k_values=(3, 5, 7, 9)):
    """
    Compare FastA performance with different k values.

    Temporarily rebinds the module-level ``k`` (which search_fasta reads) to
    each value in k_values, runs a standard search, and records how many
    hits pass the threshold. The original ``k`` is restored afterwards, even
    if a search raises.

    Returns a dict mapping each tested k value to its number of matches.
    """
    global k  # search_fasta reads the module-level k
    original_k = k

    print("\n" + "="*60)
    print("K-value Comparison Analysis")
    print("="*60)

    comparison_results = {}

    try:
        for k_val in k_values:
            k = k_val
            print(f"\nTesting with k = {k_val}:")
            results = search_fasta(query_sequence, fasta_filename, use_frequent_kmers=False)
            comparison_results[k_val] = len(results)

            if results:
                print(f"Best match: {results[0]['gene_id']} (score: {results[0]['total_score']})")
            else:
                print("No matches found")
    finally:
        # Restore even on error so later searches don't silently keep
        # running with the last k tried here.
        k = original_k

    print("\nSummary - Number of matches found:")
    for k_val, num_matches in comparison_results.items():
        print(f"k = {k_val}: {num_matches} matches")

    return comparison_results