# In [None]:
import re
import matplotlib.pyplot as plt
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from collections import Counter


def read_fasta(fasta_file):
    """Parse a FASTA file and return every record in it.

    Args:
        fasta_file: File name or handle passed straight to ``SeqIO.parse``.

    Returns:
        A list of the records yielded by ``SeqIO.parse(fasta_file, "fasta")``.
    """
    parsed = SeqIO.parse(fasta_file, "fasta")
    # Materialize the iterator so callers can index and re-iterate the result.
    return list(parsed)


def global_alignment(sequence1, sequence2, match_score=1, mismatch_penalty=-1, gap_penalty=-2):
    """Globally align two sequences using a full dynamic-programming table.

    A (len(sequence1) + 1) x (len(sequence2) + 1) score table is filled with a
    linear gap penalty, then walked back from the bottom-right corner to
    rebuild one best-scoring alignment. When several moves explain a cell, the
    diagonal (match/mismatch) is preferred, then a gap in ``sequence2``, then
    a gap in ``sequence1``.

    Args:
        sequence1: First sequence (rows of the table).
        sequence2: Second sequence (columns of the table).
        match_score: Score added when two aligned symbols are equal.
        mismatch_penalty: Score added when two aligned symbols differ.
        gap_penalty: Score added for every gap column.

    Returns:
        A tuple ``(aligned_sequence1, aligned_sequence2)`` of equal-length
        strings in which ``'-'`` marks a gap.
    """
    rows = len(sequence1)
    cols = len(sequence2)

    def pair_score(left, right):
        # Substitution score for aligning one symbol against another.
        return match_score if left == right else mismatch_penalty

    scores = [[0] * (cols + 1) for _ in range(rows + 1)]

    # Border cells: aligning a prefix against nothing costs one gap per symbol.
    for r in range(1, rows + 1):
        scores[r][0] = r * gap_penalty
    for c in range(1, cols + 1):
        scores[0][c] = c * gap_penalty

    for r in range(1, rows + 1):
        above = scores[r - 1]
        current = scores[r]
        symbol1 = sequence1[r - 1]
        for c in range(1, cols + 1):
            current[c] = max(
                above[c - 1] + pair_score(symbol1, sequence2[c - 1]),
                above[c] + gap_penalty,
                current[c - 1] + gap_penalty,
            )

    # The traceback emits columns last-to-first; collect them and reverse once.
    top = []
    bottom = []
    r, c = rows, cols
    while r > 0 and c > 0:
        here = scores[r][c]
        if here == scores[r - 1][c - 1] + pair_score(sequence1[r - 1], sequence2[c - 1]):
            top.append(sequence1[r - 1])
            bottom.append(sequence2[c - 1])
            r -= 1
            c -= 1
        elif here == scores[r - 1][c] + gap_penalty:
            top.append(sequence1[r - 1])
            bottom.append('-')
            r -= 1
        else:
            top.append('-')
            bottom.append(sequence2[c - 1])
            c -= 1

    # Whatever prefix is left on either side can only be paired with gaps.
    for k in range(r, 0, -1):
        top.append(sequence1[k - 1])
        bottom.append('-')
    for k in range(c, 0, -1):
        top.append('-')
        bottom.append(sequence2[k - 1])

    return "".join(reversed(top)), "".join(reversed(bottom))


# def hirschberg(sequence1, sequence2, match_score=1, mismatch_penalty=-1, gap_penalty=-2):
    
#     def needleman_wunsch(seq1, seq2):
#         m, n = len(seq1), len(seq2)
#         matrix = [[0] * (n + 1) for _ in range(2)]
        
#         for j in range(1, n + 1):
#             matrix[0][j] = j * gap_penalty
        
#         for i in range(1, m + 1):
#             matrix[1][0] = i * gap_penalty
#             for j in range(1, n + 1):
#                 match = matrix[0][j - 1] + (match_score if seq1[i - 1] == seq2[j - 1] else mismatch_penalty)
#                 delete = matrix[0][j] + gap_penalty
#                 insert = matrix[1][j - 1] + gap_penalty
#                 matrix[1][j] = max(match, delete, insert)
#             matrix[0], matrix[1] = matrix[1], [0] * (n + 1)
        
#         return matrix[0]
    
#     def traceback(seq1, seq2, matrix):
#         i, j = len(seq1), len(seq2)
#         aligned_seq1, aligned_seq2 = "", ""
        
#         while i > 0 or j > 0:
#             if i > 0 and j > 0 and matrix[i][j] == matrix[i - 1][j - 1] + (match_score if seq1[i - 1] == seq2[j - 1] else mismatch_penalty):
#                 aligned_seq1 = seq1[i - 1] + aligned_seq1
#                 aligned_seq2 = seq2[j - 1] + aligned_seq2
#                 i -= 1
#                 j -= 1
#             elif i > 0 and matrix[i][j] == matrix[i - 1][j] + gap_penalty:
#                 aligned_seq1 = seq1[i - 1] + aligned_seq1
#                 aligned_seq2 = "-" + aligned_seq2
#                 i -= 1
#             else:
#                 aligned_seq1 = "-" + aligned_seq1
#                 aligned_seq2 = seq2[j - 1] + aligned_seq2
#                 j -= 1
        
#         return aligned_seq1, aligned_seq2
    
#     m, n = len(sequence1), len(sequence2)
#     if m == 0:
#         return "-" * n, sequence2
#     elif n == 0:
#         return sequence1, "-" * m
    
#     if m == 1 or n == 1:
#         return needleman_wunsch(sequence1, sequence2), sequence2
    
#     mid = n // 2
#     score_l = needleman_wunsch(sequence1, sequence2[:mid])
#     score_r = needleman_wunsch(sequence1[::-1], sequence2[::-1][mid:])[::-1]

   
#     score = [score_l[i] + score_r[i] for i in range(m + 1)]
    
#     max_score_index = score.index(max(score))
#     aligned_seq1_l, aligned_seq2_l = hirschberg(sequence1[:max_score_index], sequence2[:mid])
#     aligned_seq1_r, aligned_seq2_r = hirschberg(sequence1[max_score_index:], sequence2[mid:])
    
#     return aligned_seq1_l + aligned_seq1_r, aligned_seq2_l + aligned_seq2_r


def find_snvs(ref_seq, query_seq):
    """Return the single-nucleotide substitutions between two aligned sequences.

    The sequences are compared column by column. A column is reported only
    when both sides hold a symbol and the symbols differ; columns where either
    side is a gap (``'-'``) are insertions/deletions, not SNVs, and are
    skipped. If the inputs differ in length, only the columns present in both
    are compared instead of raising ``IndexError``.

    Args:
        ref_seq: Aligned reference sequence, with ``'-'`` for gaps.
        query_seq: Aligned query sequence, with ``'-'`` for gaps.

    Returns:
        A list of ``(position, ref_base, query_base)`` tuples, where
        ``position`` is the 1-based alignment column.
    """
    return [
        (position, ref_base, query_base)
        for position, (ref_base, query_base) in enumerate(zip(ref_seq, query_seq), start=1)
        if ref_base != query_base and '-' not in (ref_base, query_base)
    ]

def find_indels(ref_seq, query_seq):
    """Return the insertions and deletions between two aligned sequences.

    The sequences are walked column by column. A gap (``'-'``) in the
    reference means the query carries an extra base (insertion); a gap in the
    query means a reference base is missing (deletion). Columns where both
    sides hold a symbol, including substitutions, are not indels and are
    ignored.

    Args:
        ref_seq: Aligned reference sequence, with ``'-'`` for gaps.
        query_seq: Aligned query sequence, with ``'-'`` for gaps.

    Returns:
        A list of ``(position, kind, base)`` tuples, one per gap column, where
        ``position`` is the 1-based alignment column, ``kind`` is ``'I'``
        (insertion) or ``'D'`` (deletion), and ``base`` is the inserted query
        base or the deleted reference base.
    """
    indels = []
    for position, (ref_base, query_base) in enumerate(zip(ref_seq, query_seq), start=1):
        ref_is_gap = ref_base == '-'
        query_is_gap = query_base == '-'
        if ref_is_gap and not query_is_gap:
            indels.append((position, 'I', query_base))
        elif query_is_gap and not ref_is_gap:
            indels.append((position, 'D', ref_base))
        # Gap against gap, or symbol against symbol: nothing to record.
    return indels

def plot_snv_distribution(snvs):
    """Show a bar chart of how often each SNV category occurs.

    SNVs are grouped by the second element of each tuple and the number of
    tuples in each group is drawn as a blue bar.

    Args:
        snvs: Sequence of SNV tuples; only index 1 of each tuple is read.

    Returns:
        None. The figure is displayed with ``plt.show()``. When ``snvs`` is
        empty a message is printed and nothing is drawn.
    """
    # An empty list cannot be unpacked by zip(*...) below, so bail out early.
    if not snvs:
        print("No SNVs to plot.")
        return

    # NOTE(review): the grouping key is snv[1], which find_snvs in this file
    # fills with the reference base, yet the x-axis is labelled 'Position'.
    # Confirm whether the key or the label reflects the intended chart.
    positions, counts = zip(*Counter(snv[1] for snv in snvs).items())
    plt.figure(figsize=(10, 6))
    plt.bar(positions, counts, color='blue')
    plt.xlabel('Position')
    plt.ylabel('Frequency')
    plt.title('SNV Distribution')
    plt.show()


def plot_indel_distribution(indels):
    """Show a bar chart of how often each indel category occurs.

    Indels are grouped by the second element of each tuple and the size of
    each group is drawn as a green bar. With no indels, a message is printed
    and no figure is created.

    Args:
        indels: Sequence of indel tuples; only index 1 of each tuple is read.

    Returns:
        None.
    """
    if not indels:
        print("No indels to plot.")
        return None

    tally = Counter(record[1] for record in indels)
    # Keys and values are read from the same Counter, so they stay paired.
    labels = tuple(tally.keys())
    heights = tuple(tally.values())

    plt.figure(figsize=(10, 6))
    plt.bar(labels, heights, color='green')
    plt.xlabel('Position')
    plt.ylabel('Frequency')
    plt.title('Indel Distribution')
    plt.show()
    return None


def calculate_disease_probability(snvs, indels, ref_fasta, query_fasta):
    """Turn the fraction of variant positions into a single probability score.

    The prior is the number of distinct variant positions divided by the total
    length of the reference sequences. The returned value is
    ``prior / (prior + likelihood * (1 - prior))`` with
    ``likelihood = 1 - prior``.

    NOTE(review): this formula is kept exactly as the file had it; it does not
    look like a standard Bayes update, so confirm the intended model.

    Args:
        snvs: Iterable of SNV tuples whose first element is the position.
        indels: Iterable of indel tuples whose first element is the position.
        ref_fasta: Path of the reference FASTA file; its sequences supply the
            total number of positions.
        query_fasta: Accepted for interface compatibility. Nothing in the
            calculation reads the query sequences, so the file is not parsed.

    Returns:
        The score described above, as a float.

    Raises:
        ValueError: If the reference FASTA contains no sequence data, which
            would otherwise surface as a bare division by zero.
    """
    ref_records = read_fasta(ref_fasta)
    # Keyed by record id, so records sharing an id are counted once.
    ref_seqs = {rec.id: str(rec.seq) for rec in ref_records}

    total_positions = sum(len(ref_seq) for ref_seq in ref_seqs.values())
    if total_positions == 0:
        raise ValueError(f"Reference FASTA {ref_fasta!r} contains no sequence data")

    # Index 0 of each variant tuple is its position; a set removes positions
    # reported by both an SNV and an indel.
    disease_associated_positions = {snv[0] for snv in snvs}
    disease_associated_positions.update(indel[0] for indel in indels)
    disease_associated_count = len(disease_associated_positions)

    prior_probability = disease_associated_count / total_positions
    likelihood = 1 - prior_probability
    posterior_probability = prior_probability / (prior_probability + likelihood * (1 - prior_probability))

    return posterior_probability




def main():
    """Run the full pipeline on the two hard-coded FASTA files.

    Reads the first record of each file, aligns the two sequences globally,
    extracts SNVs and indels from the alignment, plots both distributions and
    prints the score returned by ``calculate_disease_probability``.
    """
    print('hello')
    ref_fasta, query_fasta = "MSSA476.fasta", "mrsastaph.fasta"

    # Both files are parsed before either is indexed.
    ref_records, query_records = read_fasta(ref_fasta), read_fasta(query_fasta)

    # Only the first record of each file takes part in the alignment.
    alignment = global_alignment(
        str(ref_records[0].seq),
        str(query_records[0].seq),
    )
    aligned_ref_seq, aligned_query_seq = alignment

    snvs = find_snvs(aligned_ref_seq, aligned_query_seq)
    indels = find_indels(aligned_ref_seq, aligned_query_seq)

    for draw, variants in ((plot_snv_distribution, snvs), (plot_indel_distribution, indels)):
        draw(variants)

    difference = calculate_disease_probability(snvs, indels, ref_fasta, query_fasta)
    print(f"\nThe probability difference between the antibiotic resistant vs sensitive: {difference:.2%}")

# Run the pipeline only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()


# Output: hello
