## Task 1: generate simulated data with deletions

We are going to generate sequences containing two motifs: ATATTCA (motif A) and GTACTGC (motif B).

They will either appear together, separated by a nucleotide distance that is a multiple of 10: 10, 20, 30, ... nucleotides apart (deletion tokens do not count toward this distance, but there can be deletions between the motifs).

Or only one of them will appear.

All other parts of the sequences will be randomly generated, with deletions.

Then we will add "-" tokens to mark the deletions. Should we have deletions inside the motifs too?

In [None]:
# this cell generates data with 5 bases as distance between motif A and B, and could have multiple motifs in one sequence(1, 2, 3, 4, 5)
# for now there should not be any deletions in the motifs
import random
from dataclasses import dataclass
from statistics import mode
from typing import List, Optional, Tuple

from torch import mode

# Alphabet used for the random background.
DNA = list("ACGT")

# The two motifs planted into the simulated sequences.
MOTIF_A = "ATATTCA"
MOTIF_B = "GTACTGC"


def rand_dna(n: int, rng: random.Random) -> str:
    """Return a string of `n` bases, each drawn from DNA with `rng.choice`."""
    bases = []
    for _ in range(n):
        bases.append(rng.choice(DNA))
    return "".join(bases)


@dataclass
class Sequence:
    """One simulated sequence plus the metadata written to its FASTA header."""
    seq: str
    label: str               # "both", "A_only", "B_only", "no_motif"
    deletions: int    # number of '-' tokens counted in seq
    pos_a: Optional[int]     # start index of motif A, or None; make_example stores a comma-joined string of indices when several motifs are placed
    pos_b: Optional[int]     # start index of motif B, or None; same comma-joined form as pos_a for several motifs
    gap: Optional[int]       # nt between motifs if both else None

import math

# Adds '-' deletion tokens to a random background sequence.
def augmentation(seq: str, frac: float, rng: random.Random) -> str:
    """Insert a random number of '-' deletion tokens into `seq`.

    Args:
        seq: sequence to augment.
        frac: upper bound on the number of tokens, as a fraction of len(seq).
            Values <= 0 return `seq` unchanged.
        rng: random generator; every random draw goes through it so that the
            result is reproducible from the generator's seed.

    Returns:
        `seq` with between 0 and max(1, floor(frac * len(seq))) tokens
        inserted (never more than len(seq)). The original characters keep
        their order; tokens are added, not substituted, so the result is longer.
    """
    if frac <= 0:
        return seq
    n = len(seq)
    # Use the caller's generator, not the global `random` module, so a fixed
    # seed always yields the same number of tokens.
    k = rng.randint(0, max(1, math.floor(frac * n)))
    positions = list(range(n))
    rng.shuffle(positions)
    chars = list(seq)
    # Inserting in ascending order leaves each token at index p of the output.
    # NOTE: because p < n, tokens never land in the last k output positions.
    for p in sorted(positions[:k]):
        chars.insert(p, '-')
    return "".join(chars)


def _background(n: int, frac: float, with_deletions: bool, rng: random.Random) -> str:
    """Return `n` random bases, passed through `augmentation` when `with_deletions` is set."""
    bases = rand_dna(n, rng)
    return augmentation(bases, frac, rng) if with_deletions else bases


def make_example(
    length: int,
    number_of_motifs: int,     # e.g. 1, 2, 3, 4, 5
    gap: int,               # have a fixed gap between motifs: 5
    deletions: float,
    if_deletions: bool,
    rng: random.Random
) -> Sequence:
    """Build one simulated sequence containing `number_of_motifs` motifs.

    Motifs alternate A, B, A, B, ... Every A->B transition is separated by
    exactly `gap` random bases; the remaining background is split at random
    between the prefix, the B->A spacers and the suffix.

    Args:
        length: number of bases (motifs + background) before deletion tokens
            are inserted; with deletions the returned string is longer.
        number_of_motifs: 0 gives a "no_motif" sequence, 1 gives "A_only" or
            "B_only" with equal probability, 2 or more gives "both".
        gap: number of random bases between motif A and the following motif B.
        deletions: fraction forwarded to `augmentation` as `frac`.
        if_deletions: when False, no deletion tokens are inserted at all.
        rng: random generator used for every random draw.

    Returns:
        A Sequence. With several motifs, `pos_a` / `pos_b` are comma-joined
        strings of start indices; with one motif they are an int or None.

    Raises:
        ValueError: if `number_of_motifs` is negative or the motifs and fixed
            gaps do not fit into `length`.
    """
    if number_of_motifs < 0:
        raise ValueError("number_of_motifs must be a non-negative integer.")
    if length < len(MOTIF_A) * number_of_motifs + 1:
        raise ValueError("Sequence length too short for motifs.")

    if number_of_motifs == 0:
        seq = _background(length, deletions, if_deletions, rng)
        return Sequence(seq=seq, label="no_motif", deletions=seq.count('-'),
                        pos_a=None, pos_b=None, gap=None)

    if number_of_motifs == 1:
        # equally likely to be A_only or B_only
        label = rng.choice(["A_only", "B_only"])
        motif = MOTIF_A if label == "A_only" else MOTIF_B
        start = rng.randint(0, length - len(motif))
        prefix = _background(start, deletions, if_deletions, rng)
        suffix = _background(length - start - len(motif), deletions, if_deletions, rng)
        seq = prefix + motif + suffix
        # Deletion tokens in the prefix shift the motif, so measure the prefix.
        motif_pos = len(prefix)
        return Sequence(
            seq=seq,
            label=label,
            deletions=seq.count('-'),
            pos_a=motif_pos if label == "A_only" else None,
            pos_b=motif_pos if label == "B_only" else None,
            gap=None,
        )

    # Two or more motifs: alternate A, B, A, B, ...
    motif_list = [MOTIF_A if i % 2 == 0 else MOTIF_B for i in range(number_of_motifs)]
    transitions = list(zip(motif_list, motif_list[1:]))
    fixed_gap_count = sum(1 for a, b in transitions if a == MOTIF_A and b == MOTIF_B)
    ba_count = len(transitions) - fixed_gap_count

    min_len = sum(len(m) for m in motif_list) + fixed_gap_count * gap
    remaining = length - min_len
    if remaining < 0:
        raise ValueError(
            f"Length {length} too short for {number_of_motifs} motifs with gap {gap} "
            f"(need >= {min_len})."
        )

    # distribute remaining into prefix + B->A gaps + suffix
    cuts = sorted(rng.randint(0, remaining) for _ in range(ba_count + 1))
    bounds = [0] + cuts + [remaining]
    prefix_len, *ba_gap_lens, suffix_len = [hi - lo for lo, hi in zip(bounds, bounds[1:])]

    prefix = _background(prefix_len, deletions, if_deletions, rng)
    suffix = _background(suffix_len, deletions, if_deletions, rng)

    parts = [prefix]
    cursor = len(prefix)        # index in the final string where the next piece starts
    position_a = []
    position_b = []
    ba_gap_iter = iter(ba_gap_lens)

    for i, motif in enumerate(motif_list):
        (position_a if motif == MOTIF_A else position_b).append(cursor)
        parts.append(motif)
        cursor += len(motif)

        if i < len(motif_list) - 1:
            # fixed gap after A, random gap after B
            spacer_len = gap if motif == MOTIF_A else next(ba_gap_iter)
            spacer = _background(spacer_len, deletions, if_deletions, rng)
            parts.append(spacer)
            cursor += len(spacer)

    parts.append(suffix)
    seq = "".join(parts)

    return Sequence(
        seq=seq,
        label="both",
        deletions=seq.count('-'),
        pos_a=','.join(str(p) for p in position_a),
        pos_b=','.join(str(p) for p in position_b),
        gap=gap,
    )


def generate_dataset(
    n: int, # number of sequences to generate
    length: int = 150, # length of each sequence before deletion tokens are inserted
    gap: int = 5, # number of bases between motif A and the following motif B
    deletions: float = 0.2, # fraction forwarded to make_example as `deletions`
    no_del_seqs: float = 0.25, # probability that a sequence gets no deletion tokens
    seed: int = 727 # random seed
) -> List[Sequence]:
    """Generate `n` simulated sequences with 0 to 5 motifs each.

    The motif count is sampled per sequence with probabilities
    0: 0.05, 1: 0.10, 2: 0.10, 3: 0.25, 4: 0.25, 5: 0.25.

    Returns:
        List of `n` Sequence objects produced by `make_example`.
    """
    rng = random.Random(seed)

    # (upper bound on r, motif count); r above every bound means 5 motifs
    motif_count_bounds = ((0.05, 0), (0.15, 1), (0.25, 2), (0.5, 3), (0.75, 4))

    data = []
    for _ in range(n):
        r = rng.random()
        number_of_motifs = next(
            (count for bound, count in motif_count_bounds if r < bound), 5
        )
        if_deletions = not (rng.random() < no_del_seqs)

        # Forward the caller's `gap` instead of a hard-coded 5.
        ex = make_example(
            length=length,
            gap=gap,
            deletions=deletions,
            if_deletions=if_deletions,
            number_of_motifs=number_of_motifs,
            rng=rng,
        )
        data.append(ex)

    return data


# Settings for this run.
n = 10000            # number of sequences
length = 150         # `length` argument of generate_dataset
deletions = 0.2      # `deletions` argument of generate_dataset
no_del_seqs = 0.05   # `no_del_seqs` argument of generate_dataset
gaps = 5             # passed as the `gap` argument of generate_dataset

dataset = generate_dataset(
    n=n,
    length=length,
    gap=gaps,
    deletions=deletions,
    no_del_seqs=no_del_seqs,
    seed=284,
)

# output the sequences in FASTA format
def write_fasta(dataset, filepath):
    """Write every record of `dataset` to `filepath` as a two-line FASTA entry.

    The header carries a 1-based running index followed by the record's label,
    pos_a, pos_b, gap and deletions, separated by '|'; the second line is the
    record's seq.
    """
    with open(filepath, "w") as handle:
        index = 0
        for record in dataset:
            index += 1
            fields = [
                f">seq{index:04d}",
                f"label={record.label}",
                f"posAmotif={record.pos_a}",
                f"posBmotif={record.pos_b}",
                f"gaplength={record.gap}",
                f"deletions={record.deletions}",
            ]
            handle.write("|".join(fields) + "\n")
            handle.write(record.seq + "\n")
            

# Write the training set; open() does not create the target directory.
write_fasta(
    dataset,
    f"simulated_sequences_distance=5/new_augumented_sequence_size{n}_length{length}_deletions{deletions}_nodeletionseq{no_del_seqs}.fasta",
)

#write_fasta(dataset, f"test_data_distance=5/augumented_sequence_size{n}_length{length}_deletions{deletions}_nodeletionseq{no_del_seqs}.fasta")

In [1]:
import random
from dataclasses import dataclass
from typing import List, Optional, Tuple

# Alphabet used for the random background.
DNA = list("ACGT")

# The two motifs planted into the simulated sequences.
MOTIF_A = "ATATTCA"
MOTIF_B = "GTACTGC"


def rand_dna(n: int, rng: random.Random) -> str:
    """Return a string of `n` bases, each drawn from DNA with `rng.choice`."""
    bases = []
    for _ in range(n):
        bases.append(rng.choice(DNA))
    return "".join(bases)


@dataclass
class Sequence:
    """One simulated sequence plus the metadata written to its FASTA header."""
    seq: str
    label: str               # "both", "A_only", "B_only", "no_motif"
    deletions: int    # number of '-' tokens counted in seq
    pos_a: Optional[int]     # start index of motif A in seq, or None
    pos_b: Optional[int]     # start index of motif B in seq, or None
    gap: Optional[int]       # nt between motifs if both else None

import math

# Adds '-' deletion tokens to a random background sequence.
def augmentation(seq: str, frac: float, rng: random.Random) -> str:
    """Insert a random number of '-' deletion tokens into `seq`.

    Args:
        seq: sequence to augment.
        frac: upper bound on the number of tokens, as a fraction of len(seq).
            Values <= 0 return `seq` unchanged.
        rng: random generator; every random draw goes through it so that the
            result is reproducible from the generator's seed.

    Returns:
        `seq` with between 1 and max(1, floor(frac * len(seq))) tokens
        inserted (never more than len(seq)). The original characters keep
        their order; tokens are added, not substituted, so the result is longer.
    """
    if frac <= 0:
        return seq
    n = len(seq)
    # Use the caller's generator, not the global `random` module, so a fixed
    # seed always yields the same number of tokens.
    k = rng.randint(1, max(1, math.floor(frac * n)))
    positions = list(range(n))
    rng.shuffle(positions)
    chars = list(seq)
    # Inserting in ascending order leaves each token at index p of the output.
    # NOTE: because p < n, tokens never land in the last k output positions.
    for p in sorted(positions[:k]):
        chars.insert(p, '-')
    return "".join(chars)

def _background(n: int, frac: float, with_deletions: bool, rng: random.Random) -> str:
    """Return `n` random bases, passed through `augmentation` when `with_deletions` is set."""
    bases = rand_dna(n, rng)
    return augmentation(bases, frac, rng) if with_deletions else bases


def make_example(
    length: int,
    mode: str,                     # "both" | "A_only" | "B_only" | "no_motif" 
    gaps: List[int],               # e.g. [10,20,30,...]
    deletions: float,
    if_deletions: bool,
    rng: random.Random
) -> Sequence:
    """Build one simulated sequence for the requested `mode`.

    Args:
        length: number of bases (motifs + background) before deletion tokens
            are inserted; with deletions the returned string is longer.
        mode: "both" (A, gap, B), "A_only", "B_only" or "no_motif".
        gaps: candidate A-to-B gap lengths; one is chosen at random for
            "both". Every entry must be divisible by 10.
        deletions: fraction forwarded to `augmentation` as `frac`.
        if_deletions: when False, no deletion tokens are inserted at all.
        rng: random generator used for every random draw.

    Returns:
        A Sequence whose `pos_a` / `pos_b` are indices into the final string
        (deletion tokens included) and whose `gap` is the chosen number of
        bases between the motifs (deletion tokens not included).

    Raises:
        ValueError: if `length` is too short, a gap is not divisible by 10,
            or `mode` is not one of the four supported values.
    """
    if length < max(len(MOTIF_A), len(MOTIF_B)) + 1:
        raise ValueError("Sequence length too short for motifs.")

    # we don't allow gap not divisible by 10
    invalid = [g for g in gaps if g % 10 != 0]
    if invalid:
        raise ValueError(
            f"Invalid gaps detected: {invalid}. "
            "All gaps must be divisible by 10 (e.g. 10, 20, 30, ...)."
        )

    if mode == "both":
        gap = rng.choice(gaps)
        total_motif_len = len(MOTIF_A) + gap + len(MOTIF_B)
        if total_motif_len > length:
            raise ValueError(
                f"Length {length} too short for both motifs with gap {gap} "
                f"(need >= {total_motif_len})."
            )

        start = rng.randint(0, length - total_motif_len)
        prefix = _background(start, deletions, if_deletions, rng)
        between = _background(gap, deletions, if_deletions, rng)
        suffix = _background(length - start - total_motif_len, deletions, if_deletions, rng)

        seq = prefix + MOTIF_A + between + MOTIF_B + suffix
        pos_a = len(prefix)
        return Sequence(
            seq=seq,
            label="both",
            deletions=seq.count('-'),
            pos_a=pos_a,
            pos_b=pos_a + len(MOTIF_A) + len(between),
            gap=gap,
        )

    if mode in ("A_only", "B_only"):
        motif = MOTIF_A if mode == "A_only" else MOTIF_B
        start = rng.randint(0, length - len(motif))
        # The prefix is drawn once and used both in seq and for the position.
        prefix = _background(start, deletions, if_deletions, rng)
        suffix = _background(length - start - len(motif), deletions, if_deletions, rng)
        seq = prefix + motif + suffix
        motif_pos = len(prefix)
        return Sequence(
            seq=seq,
            label=mode,
            deletions=seq.count('-'),
            pos_a=motif_pos if mode == "A_only" else None,
            pos_b=motif_pos if mode == "B_only" else None,
            gap=None,
        )

    if mode == "no_motif":
        seq = _background(length, deletions, if_deletions, rng)
        return Sequence(seq=seq, label="no_motif", deletions=seq.count('-'),
                        pos_a=None, pos_b=None, gap=None)

    raise ValueError("mode must be: 'both', 'A_only', 'B_only', or 'no_motif'")


def generate_dataset(
    n: int, # number of sequences to generate
    length: int = 150, # length of each sequence before deletion tokens are inserted
    gaps: Optional[List[int]] = None, # list of possible gaps between motif A and motif B: 10, 20, 30, ..., 100
    deletions: float = 0.2, # fraction forwarded to make_example as `deletions`
    no_del_seqs: float = 0.25, # probability that a sequence gets no deletion tokens
    p_both: float = 0.8, # probability that a sequence contains both motifs
    p_a_only: float = 0.1,  # probability that a sequence contains only motif A
    p_b_only: float = 0.1,  # probability that a sequence contains only motif B
    p_no_motif: float = 0.0, # probability that a sequence contains no motif
    seed: int = 727 # random seed
) -> List[Sequence]:
    """Generate `n` simulated sequences, sampling the mode of each one.

    Returns:
        List of `n` Sequence objects produced by `make_example`.

    Raises:
        ValueError: if the four mode probabilities do not sum to 1.
    """
    if gaps is None:
        gaps = list(range(10, 101, 10))  # 10,20,...,100

    if abs((p_both + p_a_only + p_b_only + p_no_motif) - 1.0) > 1e-9:
        raise ValueError(
            "Probabilities must sum to 1 (p_both + p_a_only + p_b_only + p_no_motif)."
        )

    rng = random.Random(seed)

    data = []
    for _ in range(n):
        r = rng.random()
        # r is in range [0,1), so we can use it to randomly select mode, given the probabilities
        if r < p_both:
            mode = "both"
        elif r < p_both + p_a_only:
            mode = "A_only"
        elif r < p_both + p_a_only + p_b_only:
            mode = "B_only"
        else:
            mode = "no_motif"

        if_deletions = not (rng.random() < no_del_seqs)

        ex = make_example(
            length=length,
            mode=mode,
            gaps=gaps,
            deletions=deletions,
            if_deletions=if_deletions,
            rng=rng,
        )
        data.append(ex)

    return data


# Settings for this run.
n = 10000           # number of sequences
length = 150        # `length` argument of generate_dataset
deletions = 0.2     # `deletions` argument of generate_dataset
no_del_seqs = 0.1   # `no_del_seqs` argument of generate_dataset
gaps = [10 * step for step in range(1, 11)]  # 10, 20, ..., 100

dataset = generate_dataset(
    n=n,
    length=length,
    gaps=gaps,
    deletions=deletions,
    no_del_seqs=no_del_seqs,
    seed=831,
)

# output the sequences in FASTA format
def write_fasta(dataset, filepath):
    """Write every record of `dataset` to `filepath` as a two-line FASTA entry.

    The header carries a 1-based running index followed by the record's label,
    pos_a, pos_b, gap and deletions, separated by '|'; the second line is the
    record's seq.
    """
    with open(filepath, "w") as handle:
        index = 0
        for record in dataset:
            index += 1
            fields = [
                f">seq{index:04d}",
                f"label={record.label}",
                f"posAmotif={record.pos_a}",
                f"posBmotif={record.pos_b}",
                f"gaplength={record.gap}",
                f"deletions={record.deletions}",
            ]
            handle.write("|".join(fields) + "\n")
            handle.write(record.seq + "\n")
            

# Write the training set; open() does not create the target directory.
write_fasta(
    dataset,
    f"simulated_sequences/augumented_sequence_size{n}_length{length}_deletions{deletions}_nodeletionseq{no_del_seqs}.fasta",
)

#write_fasta(dataset, f"test_data/augumented_sequence_size{n}_length{length}_deletions{deletions}_nodeletionseq{no_del_seqs}.fasta")

## generate data for the baseline model without deletion


In [1]:
import random
from dataclasses import dataclass
from typing import List, Optional, Tuple

# Alphabet used for the random background.
DNA = list("ACGT")

# The two motifs planted into the simulated sequences.
MOTIF_A = "ATATTCA"
MOTIF_B = "GTACTGC"


def rand_dna(n: int, rng: random.Random) -> str:
    """Return a string of `n` bases, each drawn from DNA with `rng.choice`."""
    bases = []
    for _ in range(n):
        bases.append(rng.choice(DNA))
    return "".join(bases)


@dataclass
class Sequence:
    """One simulated sequence plus the metadata written to its FASTA header."""
    seq: str
    label: str               # "both", "A_only", "B_only", "no_motif"
    deletions: int    # number of '-' tokens counted in seq
    pos_a: Optional[int]     # start index of motif A in seq, or None
    pos_b: Optional[int]     # start index of motif B in seq, or None
    gap: Optional[int]       # nt between motifs if both else None

import math

# Adds '-' deletion tokens to a random background sequence.
def augmentation(seq: str, frac: float, rng: random.Random) -> str:
    """Insert max(1, floor(frac * len(seq))) '-' tokens into `seq`.

    A non-positive `frac` returns `seq` untouched. Token positions are the
    first entries of an `rng`-shuffled index list, so at most len(seq) tokens
    are placed and each ends up at its chosen index of the output string.
    """
    if frac <= 0:
        return seq
    length = len(seq)
    n_tokens = max(1, math.floor(frac * length))
    order = list(range(length))
    rng.shuffle(order)
    token_slots = set(order[:n_tokens])

    # Walk the output positions: a chosen slot gets a token, every other slot
    # gets the next original character.
    remaining = iter(seq)
    out = []
    for slot in range(length + len(token_slots)):
        if slot in token_slots:
            out.append('-')
        else:
            out.append(next(remaining))
    return "".join(out)

def _background(n: int, frac: float, with_deletions: bool, rng: random.Random) -> str:
    """Return `n` random bases, passed through `augmentation` when `with_deletions` is set."""
    bases = rand_dna(n, rng)
    return augmentation(bases, frac, rng) if with_deletions else bases


def make_example(
    length: int,
    mode: str,                     # "both" | "A_only" | "B_only" | "no_motif" 
    gaps: List[int],               # e.g. [10,20,30,...]
    deletions: float,
    if_deletions: bool,
    rng: random.Random
) -> Sequence:
    """Build one simulated sequence for the requested `mode`.

    Args:
        length: number of bases (motifs + background) before deletion tokens
            are inserted; with deletions the returned string is longer.
        mode: "both" (A, gap, B), "A_only", "B_only" or "no_motif".
        gaps: candidate A-to-B gap lengths; one is chosen at random for
            "both". Every entry must be divisible by 10.
        deletions: fraction forwarded to `augmentation` as `frac`.
        if_deletions: when False, no deletion tokens are inserted at all.
        rng: random generator used for every random draw.

    Returns:
        A Sequence whose `pos_a` / `pos_b` are indices into the final string
        (deletion tokens included) and whose `gap` is the chosen number of
        bases between the motifs (deletion tokens not included).

    Raises:
        ValueError: if `length` is too short, a gap is not divisible by 10,
            or `mode` is not one of the four supported values.
    """
    if length < max(len(MOTIF_A), len(MOTIF_B)) + 1:
        raise ValueError("Sequence length too short for motifs.")

    # we don't allow gap not divisible by 10
    invalid = [g for g in gaps if g % 10 != 0]
    if invalid:
        raise ValueError(
            f"Invalid gaps detected: {invalid}. "
            "All gaps must be divisible by 10 (e.g. 10, 20, 30, ...)."
        )

    if mode == "both":
        gap = rng.choice(gaps)
        total_motif_len = len(MOTIF_A) + gap + len(MOTIF_B)
        if total_motif_len > length:
            raise ValueError(
                f"Length {length} too short for both motifs with gap {gap} "
                f"(need >= {total_motif_len})."
            )

        # No separate "deletions in gap" draw: its value was never used and
        # rng.randint(1, int(gap * deletions)) raises when that bound is < 1.
        start = rng.randint(0, length - total_motif_len)
        prefix = _background(start, deletions, if_deletions, rng)
        between = _background(gap, deletions, if_deletions, rng)
        suffix = _background(length - start - total_motif_len, deletions, if_deletions, rng)

        seq = prefix + MOTIF_A + between + MOTIF_B + suffix
        pos_a = len(prefix)
        return Sequence(
            seq=seq,
            label="both",
            deletions=seq.count('-'),
            pos_a=pos_a,
            pos_b=pos_a + len(MOTIF_A) + len(between),
            gap=gap,
        )

    if mode in ("A_only", "B_only"):
        motif = MOTIF_A if mode == "A_only" else MOTIF_B
        start = rng.randint(0, length - len(motif))
        # The prefix is drawn once and used both in seq and for the position.
        prefix = _background(start, deletions, if_deletions, rng)
        suffix = _background(length - start - len(motif), deletions, if_deletions, rng)
        seq = prefix + motif + suffix
        motif_pos = len(prefix)
        return Sequence(
            seq=seq,
            label=mode,
            deletions=seq.count('-'),
            pos_a=motif_pos if mode == "A_only" else None,
            pos_b=motif_pos if mode == "B_only" else None,
            gap=None,
        )

    if mode == "no_motif":
        seq = _background(length, deletions, if_deletions, rng)
        return Sequence(seq=seq, label="no_motif", deletions=seq.count('-'),
                        pos_a=None, pos_b=None, gap=None)

    raise ValueError("mode must be: 'both', 'A_only', 'B_only', or 'no_motif'")


def generate_dataset(
    n: int, # number of sequences to generate
    length: int = 150, # length of each sequence before deletion tokens are inserted
    gaps: Optional[List[int]] = None, # list of possible gaps between motif A and motif B: 10, 20, 30, ..., 100
    deletions: float = 0, # fraction forwarded to make_example as `deletions`
    no_del_seqs: float = 1, # probability that a sequence gets no deletion tokens
    p_both: float = 0.8, # probability that a sequence contains both motifs
    p_a_only: float = 0.1,  # probability that a sequence contains only motif A
    p_b_only: float = 0.1,  # probability that a sequence contains only motif B
    p_no_motif: float = 0.0, # probability that a sequence contains no motif
    seed: int = 727 # random seed
) -> List[Sequence]:
    """Generate `n` simulated sequences, sampling the mode of each one.

    With the defaults (`deletions=0`, `no_del_seqs=1`) no deletion tokens are
    requested for any sequence.

    Returns:
        List of `n` Sequence objects produced by `make_example`.

    Raises:
        ValueError: if the four mode probabilities do not sum to 1.
    """
    if gaps is None:
        gaps = list(range(10, 101, 10))  # 10,20,...,100

    if abs((p_both + p_a_only + p_b_only + p_no_motif) - 1.0) > 1e-9:
        raise ValueError(
            "Probabilities must sum to 1 (p_both + p_a_only + p_b_only + p_no_motif)."
        )

    rng = random.Random(seed)

    data = []
    for _ in range(n):
        r = rng.random()
        # r is in range [0,1), so we can use it to randomly select mode, given the probabilities
        if r < p_both:
            mode = "both"
        elif r < p_both + p_a_only:
            mode = "A_only"
        elif r < p_both + p_a_only + p_b_only:
            mode = "B_only"
        else:
            mode = "no_motif"

        if_deletions = not (rng.random() < no_del_seqs)

        ex = make_example(
            length=length,
            mode=mode,
            gaps=gaps,
            deletions=deletions,
            if_deletions=if_deletions,
            rng=rng,
        )
        data.append(ex)

    return data


# Settings for this run.
n = 10000         # number of sequences
length = 150      # `length` argument of generate_dataset
deletions = 0     # `deletions` argument of generate_dataset
no_del_seqs = 1   # `no_del_seqs` argument of generate_dataset
gaps = [10 * step for step in range(1, 11)]  # 10, 20, ..., 100

dataset = generate_dataset(
    n=n,
    length=length,
    gaps=gaps,
    deletions=deletions,
    no_del_seqs=no_del_seqs,
    seed=538,
)

# output the sequences in FASTA format
def write_fasta(dataset, filepath):
    """Write every record of `dataset` to `filepath` as a two-line FASTA entry.

    The header carries a 1-based running index followed by the record's label,
    pos_a, pos_b, gap and deletions, separated by '|'; the second line is the
    record's seq.
    """
    with open(filepath, "w") as handle:
        index = 0
        for record in dataset:
            index += 1
            fields = [
                f">seq{index:04d}",
                f"label={record.label}",
                f"posAmotif={record.pos_a}",
                f"posBmotif={record.pos_b}",
                f"gaplength={record.gap}",
                f"deletions={record.deletions}",
            ]
            handle.write("|".join(fields) + "\n")
            handle.write(record.seq + "\n")
            

# Write the baseline set; open() does not create the target directory.
write_fasta(
    dataset,
    f"simulated_sequences/sequence_size{n}_length{length}_deletions{deletions}_nodeletionseq{no_del_seqs}.fasta",
)

#write_fasta(dataset, f"test_data/augumented_sequence_size{n}_length{length}_deletions{deletions}_nodeletionseq{no_del_seqs}.fasta")

## A new idea for creating the data set:

In training, we can include the same sequence multiple times: one copy has no deletions,

the other has randomly inserted deletion tokens —> the model might learn that they are all biologically the same

(not sure if we should still do it)

In [None]:


import random
from dataclasses import dataclass
from typing import List, Optional, Tuple

# Alphabet used for the random background.
DNA = list("ACGT")

# The two motifs planted into the simulated sequences.
MOTIF_A = "ATATTCA"
MOTIF_B = "GTACTGC"


def rand_dna(n: int, rng: random.Random) -> str:
    """Return a string of `n` bases, each drawn from DNA with `rng.choice`."""
    bases = []
    for _ in range(n):
        bases.append(rng.choice(DNA))
    return "".join(bases)


@dataclass
class Sequence:
    """One simulated sequence plus the metadata written to its FASTA header."""
    seq: str
    label: str               # "both", "A_only", "B_only", "no_motif"
    deletions: int    # number of '-' tokens counted in seq
    pos_a: Optional[int]     # start index of motif A in seq, or None
    pos_b: Optional[int]     # start index of motif B in seq, or None
    gap: Optional[int]       # nt between motifs if both else None

import math

# Adds '-' deletion tokens to a random background sequence.
def augmentation(seq: str, frac: float, rng: random.Random) -> str:
    """Insert max(1, floor(frac * len(seq))) '-' tokens into `seq`.

    A non-positive `frac` returns `seq` untouched. Token positions are the
    first entries of an `rng`-shuffled index list, so at most len(seq) tokens
    are placed and each ends up at its chosen index of the output string.
    """
    if frac <= 0:
        return seq
    length = len(seq)
    n_tokens = max(1, math.floor(frac * length))
    order = list(range(length))
    rng.shuffle(order)
    token_slots = set(order[:n_tokens])

    # Walk the output positions: a chosen slot gets a token, every other slot
    # gets the next original character.
    remaining = iter(seq)
    out = []
    for slot in range(length + len(token_slots)):
        if slot in token_slots:
            out.append('-')
        else:
            out.append(next(remaining))
    return "".join(out)

def _assemble(backgrounds: List[str], motifs: List[str], label: str, gap: Optional[int]) -> Sequence:
    """Interleave `backgrounds` and `motifs` (bg, motif, bg, ...) and record motif starts."""
    parts = [backgrounds[0]]
    cursor = len(backgrounds[0])    # index where the next piece starts
    starts = {}
    for motif, filler in zip(motifs, backgrounds[1:]):
        starts[motif] = cursor
        parts.append(motif)
        parts.append(filler)
        cursor += len(motif) + len(filler)
    seq = "".join(parts)
    return Sequence(
        seq=seq,
        label=label,
        deletions=seq.count('-'),
        pos_a=starts.get(MOTIF_A),
        pos_b=starts.get(MOTIF_B),
        gap=gap,
    )


def make_example(
    length: int,
    mode: str,                     # "both" | "A_only" | "B_only" | "no_motif" 
    gaps: List[int],               # e.g. [10,20,30,...]
    deletions: float,
    rng: random.Random
) -> List[Sequence]:
    """Build one sequence twice: without deletions and with deletion tokens.

    Both versions share the same random background and motif layout; the
    second one has every background segment passed through `augmentation`.

    Args:
        length: number of bases (motifs + background) of the clean sequence.
        mode: "both" (A, gap, B), "A_only", "B_only" or "no_motif".
        gaps: candidate A-to-B gap lengths; one is chosen at random for
            "both". Every entry must be divisible by 10.
        deletions: fraction forwarded to `augmentation` as `frac`.
        rng: random generator used for every random draw.

    Returns:
        `[clean, with_deletions]`, two Sequence objects with the same label
        and gap. Positions are indices into each sequence's own string.

    Raises:
        ValueError: if `length` is too short, a gap is not divisible by 10,
            or `mode` is not one of the four supported values.
    """
    if length < max(len(MOTIF_A), len(MOTIF_B)) + 1:
        raise ValueError("Sequence length too short for motifs.")

    # we don't allow gap not divisible by 10
    invalid = [g for g in gaps if g % 10 != 0]
    if invalid:
        raise ValueError(
            f"Invalid gaps detected: {invalid}. "
            "All gaps must be divisible by 10 (e.g. 10, 20, 30, ...)."
        )

    gap = None
    if mode == "both":
        gap = rng.choice(gaps)
        total_motif_len = len(MOTIF_A) + gap + len(MOTIF_B)
        if total_motif_len > length:
            raise ValueError(
                f"Length {length} too short for both motifs with gap {gap} "
                f"(need >= {total_motif_len})."
            )
        start = rng.randint(0, length - total_motif_len)
        motifs = [MOTIF_A, MOTIF_B]
        backgrounds = [
            rand_dna(start, rng),
            rand_dna(gap, rng),
            rand_dna(length - start - total_motif_len, rng),
        ]
    elif mode in ("A_only", "B_only"):
        motif = MOTIF_A if mode == "A_only" else MOTIF_B
        start = rng.randint(0, length - len(motif))
        motifs = [motif]
        backgrounds = [
            rand_dna(start, rng),
            rand_dna(length - start - len(motif), rng),
        ]
    elif mode == "no_motif":
        motifs = []
        backgrounds = [rand_dna(length, rng)]
    else:
        raise ValueError("mode must be: 'both', 'A_only', 'B_only', or 'no_motif'")

    # Augment the very same background segments so the pair differs only in
    # the inserted deletion tokens.
    backgrounds_del = [augmentation(segment, deletions, rng) for segment in backgrounds]

    sequence = _assemble(backgrounds, motifs, mode, gap)
    sequence_del = _assemble(backgrounds_del, motifs, mode, gap)
    return [sequence, sequence_del]


def generate_dataset(
    n: int, # number of examples to generate (each may contribute two sequences)
    length: int = 150, # length of each sequence before deletion tokens are inserted
    gaps: Optional[List[int]] = None, # list of possible gaps between motif A and motif B: 10, 20, 30, ..., 100
    deletions: float = 0.2, # fraction forwarded to make_example as `deletions`
    p_both: float = 0.8, # probability that a sequence contains both motifs
    p_a_only: float = 0.1,  # probability that a sequence contains only motif A
    p_b_only: float = 0.1,  # probability that a sequence contains only motif B
    p_no_motif: float = 0.0, # probability that a sequence contains no motif
    seed: int = 727 # random seed
) -> List[Sequence]:
    """Generate a flat list of sequences from `n` calls to `make_example`.

    Returns:
        Flat list of Sequence objects; when `make_example` returns a list
        (e.g. a clean / with-deletions pair) its items are added one by one.

    Raises:
        ValueError: if the four mode probabilities do not sum to 1.
    """
    if gaps is None:
        gaps = list(range(10, 101, 10))  # 10,20,...,100

    if abs((p_both + p_a_only + p_b_only + p_no_motif) - 1.0) > 1e-9:
        raise ValueError(
            "Probabilities must sum to 1 (p_both + p_a_only + p_b_only + p_no_motif)."
        )

    rng = random.Random(seed)

    data = []
    for _ in range(n):
        r = rng.random()
        # r is in range [0,1), so we can use it to randomly select mode, given the probabilities
        if r < p_both:
            mode = "both"
        elif r < p_both + p_a_only:
            mode = "A_only"
        elif r < p_both + p_a_only + p_b_only:
            mode = "B_only"
        else:
            mode = "no_motif"

        ex = make_example(length=length, mode=mode, gaps=gaps, deletions=deletions, rng=rng)
        # Flatten list results so write_fasta receives Sequence records rather
        # than nested lists.
        if isinstance(ex, list):
            data.extend(ex)
        else:
            data.append(ex)

    return data


# Settings for this run.
n = 10000         # `n` argument of generate_dataset
length = 150      # `length` argument of generate_dataset
deletions = 0.2   # `deletions` argument of generate_dataset
gaps = [10 * step for step in range(1, 11)]  # 10, 20, ..., 100

dataset = generate_dataset(
    n=n,
    length=length,
    gaps=gaps,
    deletions=deletions,
    seed=294,
)

# output the sequences in FASTA format
def write_fasta(dataset, filepath):
    """Write every record of `dataset` to `filepath` as a two-line FASTA entry.

    The header carries a 1-based running index followed by the record's label,
    pos_a, pos_b, gap and deletions, separated by '|'; the second line is the
    record's seq.
    """
    with open(filepath, "w") as handle:
        index = 0
        for record in dataset:
            index += 1
            fields = [
                f">seq{index:04d}",
                f"label={record.label}",
                f"posAmotif={record.pos_a}",
                f"posBmotif={record.pos_b}",
                f"gaplength={record.gap}",
                f"deletions={record.deletions}",
            ]
            handle.write("|".join(fields) + "\n")
            handle.write(record.seq + "\n")
            

# This cell has no `no_del_seqs` setting, so the file name carries only
# n, length and deletions. open() does not create the target directory.
write_fasta(
    dataset,
    f"simulated_sequences/comparison_sequence_size{n}_length{length}_deletions{deletions}.fasta",
)

#write_fasta(dataset, f"test_data/comparison_sequence_size{n}_length{length}_deletions{deletions}.fasta")
