## Task 1: generate simulated data with deletions

We are going to generate sequences containing the motifs ATATTCA and GTACTGC.

Either both motifs appear together, separated by a nucleotide distance that is a multiple of 10 (10, 20, 30, ... nucleotides apart; deletion tokens do not count towards this distance), and there can be deletions inside the gap between them.

Or only one of them appears (and a small fraction of sequences contains neither motif).

All other parts of the sequences are generated randomly, with deletions.

Deletions are represented by "-" tokens. Open question: should deletions also be allowed inside the motifs?

In [None]:
import random
from dataclasses import dataclass
from typing import List, Optional, Tuple

# Alphabet used for random background nucleotides.
DNA = list("ACGT")

# The two motifs planted into the simulated sequences.
MOTIF_A = "ATATTCA"
MOTIF_B = "GTACTGC"


def rand_dna(n: int, rng: random.Random) -> str:
    """Return a string of ``n`` nucleotides drawn uniformly from ``DNA``.

    Every character is a separate ``rng.choice`` call, so the output is fully
    determined by the state of ``rng``. ``n <= 0`` yields an empty string.
    """
    bases = [rng.choice(DNA) for _ in range(n)]
    return "".join(bases)


@dataclass
class Sequence:
    """One simulated example.

    Fields:
        seq: the generated string of nucleotides and "-" deletion tokens.
        label: which motifs were planted: "both", "A_only", "B_only"
            or "no_motif".
        deletions: number of "-" tokens in ``seq``.
        pos_a: 0-based start index of MOTIF_A in ``seq``, or None.
        pos_b: 0-based start index of MOTIF_B in ``seq``, or None.
        gap: nucleotides between the two motifs when both are present,
            otherwise None.
    """

    seq: str
    label: str
    deletions: int
    pos_a: Optional[int]
    pos_b: Optional[int]
    gap: Optional[int]

import math

# this function can be used to augment the random background with deletions
def augmentation(seq: str, frac: float, rng: random.Random) -> str:
    if frac <= 0:
        return seq
    n = len(seq)
    k = max(1, math.floor(frac * n))
    positions = list(range(n))
    rng.shuffle(positions)
    replace_pos = positions[:k]
    s = list(seq)
    for p in replace_pos:
        s[p] = '-'
    return "".join(s)

def _background(n: int, frac: float, if_deletions: bool, rng: random.Random) -> str:
    """Return ``n`` random nucleotides, with a ``frac`` of them turned into "-" if requested."""
    # rand_dna is drawn before augmentation, matching the order the RNG was
    # consumed in the original inline code, so a given seed gives the same data.
    dna = rand_dna(n, rng)
    return augmentation(dna, frac, rng) if if_deletions else dna


def make_example(
    length: int,
    mode: str,                     # "both" | "A_only" | "B_only" | "no_motif"
    gaps: List[int],               # e.g. [10,20,30,...]
    deletions: float,
    if_deletions: bool,
    rng: random.Random
) -> Sequence:
    """Generate one simulated sequence of exactly ``length`` characters.

    Args:
        length: total length of the returned string, including "-" tokens.
        mode: "both", "A_only", "B_only" or "no_motif".
        gaps: candidate nucleotide distances between the end of MOTIF_A and
            the start of MOTIF_B (mode "both"); each must be a multiple of 10.
        deletions: fraction of background positions overwritten by "-"; also
            bounds how many "-" tokens are inserted into the inter-motif gap.
        if_deletions: when False, no "-" tokens are produced at all.
        rng: source of all randomness.

    Returns:
        A ``Sequence``; ``pos_a``/``pos_b`` index into ``seq`` (which includes
        "-" tokens) and ``deletions`` is the number of "-" in ``seq``.

    Raises:
        ValueError: ``length`` too short, a gap not divisible by 10, empty
            ``gaps`` in mode "both", or an unknown ``mode``.
    """
    if length < max(len(MOTIF_A), len(MOTIF_B)) + 1:
        raise ValueError("Sequence length too short for motifs.")

    # we don't allow gap not divisible by 10
    invalid = [g for g in gaps if g % 10 != 0]
    if invalid:
        raise ValueError(
            f"Invalid gaps detected: {invalid}. "
            "All gaps must be divisible by 10 (e.g. 10, 20, 30, ...)."
        )

    if mode == "both":
        if not gaps:
            # rng.choice([]) would otherwise fail with an opaque IndexError.
            raise ValueError("gaps must contain at least one value for mode 'both'.")
        gap = rng.choice(gaps)
        total_motif_len = len(MOTIF_A) + gap + len(MOTIF_B)
        if total_motif_len > length:
            raise ValueError(
                f"Length {length} too short for both motifs with gap {gap} "
                f"(need >= {total_motif_len})."
            )

        deletions_in_gap = 0
        if if_deletions:
            # Insert up to `deletions` * gap "-" tokens into the gap, but never
            # more than the gap has nucleotides, nor more than the free room left
            # in the sequence (otherwise the start draw below gets a negative
            # upper bound). A bound below 1 means no gap deletions, instead of
            # crashing on rng.randint(1, 0).
            room = length - total_motif_len
            max_gap_deletions = min(int(gap * deletions), room, gap)
            if max_gap_deletions >= 1:
                deletions_in_gap = rng.randint(1, max_gap_deletions)

        start = rng.randint(0, length - total_motif_len - deletions_in_gap)
        prefix_len = start
        suffix_len = length - prefix_len - total_motif_len - deletions_in_gap

        # Background flanks get deletions by overwriting; the gap gets them by
        # insertion below, so the gap keeps exactly `gap` nucleotides.
        prefix = _background(prefix_len, deletions, if_deletions, rng)
        between = rand_dna(gap, rng)
        suffix = _background(suffix_len, deletions, if_deletions, rng)

        if deletions_in_gap > 0:
            indices = list(range(gap))
            rng.shuffle(indices)
            del_indices = set(indices[:deletions_in_gap])
            between = "".join(
                "-" + c if i in del_indices else c
                for i, c in enumerate(between)
            )

        seq = prefix + MOTIF_A + between + MOTIF_B + suffix
        pos_a = prefix_len
        pos_b = prefix_len + len(MOTIF_A) + gap + deletions_in_gap
        return Sequence(seq=seq, label="both", deletions=seq.count("-"),
                        pos_a=pos_a, pos_b=pos_b, gap=gap)

    if mode in ("A_only", "B_only"):
        motif = MOTIF_A if mode == "A_only" else MOTIF_B
        start = rng.randint(0, length - len(motif))
        # Left-to-right evaluation keeps the prefix drawn before the suffix.
        seq = (
            _background(start, deletions, if_deletions, rng)
            + motif
            + _background(length - start - len(motif), deletions, if_deletions, rng)
        )
        return Sequence(
            seq=seq,
            label=mode,
            deletions=seq.count("-"),
            pos_a=start if mode == "A_only" else None,
            pos_b=start if mode == "B_only" else None,
            gap=None,
        )

    if mode == "no_motif":
        seq = _background(length, deletions, if_deletions, rng)
        return Sequence(seq=seq, label="no_motif", deletions=seq.count("-"),
                        pos_a=None, pos_b=None, gap=None)

    raise ValueError("mode must be one of: 'both', 'A_only', 'B_only', or 'no_motif'")


def generate_dataset(
    n: int, # number of sequences to generate
    length: int = 120, # length of each sequence; all sequences have the same length
    gaps: Optional[List[int]] = None, # possible gaps between motif A and motif B; defaults to 10, 20, ..., 100
    deletions: float = 0.1, # fraction of deletions to augment each sequence with
    no_del_seqs: float = 0.25, # probability that a sequence has no deletions at all
    p_both: float = 0.4, # probability that a sequence contains both motifs
    p_a_only: float = 0.25,  # probability that a sequence contains only motif A
    p_b_only: float = 0.25,  # probability that a sequence contains only motif B
    p_no_motif: float = 0.1, # probability that a sequence contains no motif
    seed: int = 727 # random seed
) -> List[Sequence]:
    """Generate ``n`` simulated examples with a single seeded RNG.

    For each example a mode is drawn according to the four ``p_*``
    probabilities, then (with probability ``no_del_seqs``) deletions are
    disabled, and the example is built by ``make_example``.

    Raises:
        ValueError: if any probability is negative, if the four probabilities
            do not sum to 1, or if ``make_example`` rejects the parameters.
    """
    if gaps is None:
        gaps = list(range(10, 101, 10))  # 10,20,...,100

    probs = {
        "p_both": p_both,
        "p_a_only": p_a_only,
        "p_b_only": p_b_only,
        "p_no_motif": p_no_motif,
    }
    negative = [name for name, p in probs.items() if p < 0]
    if negative:
        raise ValueError(f"Probabilities must be non-negative: {negative}.")
    if abs(sum(probs.values()) - 1.0) > 1e-9:
        raise ValueError(
            "Probabilities must sum to 1 (p_both + p_a_only + p_b_only + p_no_motif)."
        )

    rng = random.Random(seed)

    data = []
    for _ in range(n):
        r = rng.random()
        # r is in [0, 1): map it onto the cumulative probabilities to pick a mode.
        if r < p_both:
            mode = "both"
        elif r < p_both + p_a_only:
            mode = "A_only"
        elif r < p_both + p_a_only + p_b_only:
            mode = "B_only"
        else:
            mode = "no_motif"

        # Second draw per example decides whether it gets deletions at all.
        if_deletions = not (rng.random() < no_del_seqs)

        data.append(
            make_example(length=length, mode=mode, gaps=gaps, deletions=deletions,
                         if_deletions=if_deletions, rng=rng)
        )

    return data


# Parameters of the simulated dataset generated (and written to FASTA) below.
n = 1000
length = 150
deletions = 0.1
no_del_seqs = 0.25
gaps = list(range(10, 101, 10))
seed = 727  # explicit seed (same as generate_dataset's default) for reproducibility

dataset = generate_dataset(n=n, length=length, gaps=gaps, deletions=deletions,
                           no_del_seqs=no_del_seqs, seed=seed)

# output the sequences in FASTA format
def write_fasta(dataset, filepath):
    """Write ``dataset`` to ``filepath`` in FASTA format.

    Each record is a header line
    ``>seqNNNN|label=...|posAmotif=...|posBmotif=...|gaplength=...|deletions=...``
    (1-based index, zero-padded to four digits) followed by the sequence on
    one line. Missing parent directories are created first.
    """
    from pathlib import Path

    path = Path(filepath)
    # Output goes into a sub-directory (e.g. "test_data/") that may not exist
    # yet; open() would otherwise raise FileNotFoundError.
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for i, ex in enumerate(dataset, start=1):
            header = (
                f">seq{i:04d}"
                f"|label={ex.label}"
                f"|posAmotif={ex.pos_a}"
                f"|posBmotif={ex.pos_b}"
                f"|gaplength={ex.gap}"
                f"|deletions={ex.deletions}"
            )
            f.write(header + "\n")
            f.write(ex.seq + "\n")
            

# Save the generated dataset; the file name encodes the generation parameters.
# NOTE(review): "augumented" is misspelled in the output name; kept as-is in
# case downstream code looks for this exact file name.
write_fasta(
    dataset,
    filepath=(
        f"test_data/augumented_sequence_size{n}_length{length}"
        f"_deletions{deletions}_nodeletionseq{no_del_seqs}.fasta"
    ),
)