In [None]:
import torch

def meiosis(haplotypes: torch.Tensor, recombination_map: torch.Tensor, crossover_interference: float = 2.6, p: float = 0):
    """Simulate meiosis for an individual and return its recombinant products.

    For every chromosome a single set of crossover events is drawn with
    ``simulate_crossovers`` and applied to all products. Product ``k`` starts on
    parental haplotype ``k`` and hands off to the next haplotype (cyclically) at every
    crossover. At each locus the products therefore carry each parental allele exactly
    once, and no allele ever moves to a different locus. For ploidy 2 this yields the
    two reciprocal recombinant chromatids; for ploidy 1 the haplotype is copied.

    NOTE(review): for ploidy > 2 the cyclic hand-off is a simplification, not a model
    of polyploid pairing — confirm this is acceptable for the caller.

    Args:
        haplotypes (torch.Tensor): Haplotypes of shape (ploidy, n_chromosomes, n_loci_per_chr).
        recombination_map (torch.Tensor): Recombination map of shape
            (n_chromosomes, n_loci_per_chr). Each row is passed to ``simulate_crossovers``.
        crossover_interference (float, optional): Crossover interference parameter,
            forwarded to ``simulate_crossovers``. Defaults to 2.6.
        p (float, optional): Proportion of crossovers from the non-interfering pathway,
            forwarded to ``simulate_crossovers``. Defaults to 0.

    Returns:
        torch.Tensor: Products with the same shape and dtype as ``haplotypes``. Chromosomes
        at index >= ``recombination_map.size(0)`` are left as zeros.

    Raises:
        ValueError: If the number of loci in ``recombination_map`` differs from ``haplotypes``.
    """
    ploidy = haplotypes.size(0)
    n_chromosomes = recombination_map.size(0)
    if recombination_map.size(1) != haplotypes.size(2):
        raise ValueError(
            f"recombination_map has {recombination_map.size(1)} loci per chromosome "
            f"but haplotypes have {haplotypes.size(2)}"
        )

    gametes = torch.zeros_like(haplotypes)
    if ploidy == 0:
        return gametes

    # Column vector of product indices, broadcast against the per-locus crossover counts.
    product_ids = torch.arange(ploidy, device=haplotypes.device).unsqueeze(1)

    for chrom in range(n_chromosomes):
        events = simulate_crossovers(recombination_map[chrom], crossover_interference, p)
        # Number of crossovers at or before each locus; a crossover at locus i affects i onward.
        n_switches = torch.cumsum(events.to(device=haplotypes.device, dtype=torch.long), dim=0)
        # Which parental haplotype each product copies at each locus: shape (ploidy, n_loci).
        source = (product_ids + n_switches) % ploidy
        # gametes[k, chrom, j] = haplotypes[source[k, j], chrom, j]
        gametes[:, chrom] = torch.gather(haplotypes[:, chrom], 0, source)

    return gametes

def _stationary_gamma_renewal(length: float, shape: float, rate: float) -> torch.Tensor:
    """Return the sorted points of a stationary gamma renewal process on [0, length)."""
    shape_t = torch.tensor(float(shape), dtype=torch.float64)
    rate_t = torch.tensor(float(rate), dtype=torch.float64)
    # In a stationary renewal process the gap straddling the origin is length-biased
    # (Gamma(shape + 1, rate) for Gamma(shape, rate) gaps), and the origin falls uniformly
    # inside it. The first point is therefore a uniform fraction of that gap, so the
    # chromosome start is not treated as if a chiasma had just occurred there.
    straddling_gap = torch.distributions.Gamma(shape_t + 1.0, rate_t).sample()
    first = torch.rand((), dtype=torch.float64) * straddling_gap

    gap_distribution = torch.distributions.Gamma(shape_t, rate_t)
    # Draw gaps in batches sized to the expected number of points (mean gap = shape / rate).
    batch = int(length * rate / shape) + 8
    chunks = [first.reshape(1)]
    last = first.item()
    while last < length:
        new_points = last + torch.cumsum(gap_distribution.sample((batch,)), dim=0)
        chunks.append(new_points)
        last = new_points[-1].item()

    points = torch.cat(chunks)
    return points[points < length]


def simulate_crossovers(recombination_map: torch.Tensor, crossover_interference: float, p: float):
    """
    Simulate crossover events along one chromosome under the Stahl (gamma + escape) model.

    Chiasmata on the four-strand bundle come from two independent pathways:

    * an interfering pathway, modelled as a stationary gamma renewal process with shape
      ``crossover_interference`` and rate ``2 * crossover_interference * (1 - p)`` per Morgan;
    * a non-interfering pathway, modelled as a Poisson process with rate ``2 * p`` per Morgan.

    Each chiasma involves the simulated product with probability 1/2 (no chromatid
    interference). This gives on average one crossover per Morgan per product.

    NOTE(review): each entry of ``recombination_map`` is assumed to be the genetic length,
    in Morgans, of the interval ending at that locus (the first entry spans the chromosome
    start to locus 0). This matches the previous per-locus probability ``1 - exp(-2 * r)``,
    which is the chance of at least one chiasma in an interval of ``r`` Morgans. Confirm
    against the caller. Negative entries are treated as zero.

    Args:
        recombination_map (torch.Tensor): 1-D tensor of per-interval genetic lengths.
        crossover_interference (float): Gamma shape parameter (nu); must be positive.
        p (float): Proportion of crossovers from the non-interfering pathway, in [0, 1].

    Returns:
        torch.Tensor: Boolean tensor with one entry per locus, on the same device as
        ``recombination_map``. An entry is True when an odd number of crossovers fell in the
        interval ending at that locus, i.e. when the product switches homolog there.
        Two crossovers in the same interval cancel.

    Raises:
        ValueError: If ``p`` is outside [0, 1] or ``crossover_interference`` is not positive.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"p must be in [0, 1], got {p}")
    if crossover_interference <= 0:
        raise ValueError(f"crossover_interference must be positive, got {crossover_interference}")

    n_loci = recombination_map.size(0)
    # Right-hand map position (Morgans) of the interval ending at each locus.
    positions = torch.cumsum(
        recombination_map.detach().to("cpu", torch.float64).clamp_min(0.0), dim=0
    )
    total_length = positions[-1].item() if n_loci else 0.0
    if total_length <= 0.0:
        return torch.zeros(n_loci, dtype=torch.bool, device=recombination_map.device)

    chiasmata = [torch.empty(0, dtype=torch.float64)]
    if p > 0.0:
        # Non-interfering pathway: Poisson number of uniformly placed chiasmata.
        n_free = int(torch.poisson(torch.tensor(2.0 * p * total_length, dtype=torch.float64)).item())
        chiasmata.append(torch.rand(n_free, dtype=torch.float64) * total_length)
    if p < 1.0:
        # Interfering pathway: mean spacing 1 / (2 * (1 - p)) Morgans.
        chiasmata.append(
            _stationary_gamma_renewal(
                total_length,
                crossover_interference,
                2.0 * crossover_interference * (1.0 - p),
            )
        )
    all_chiasmata = torch.cat(chiasmata)

    # No chromatid interference: keep each chiasma for this product with probability 1/2.
    keep = torch.rand(all_chiasmata.numel(), dtype=torch.float64) < 0.5
    crossovers = all_chiasmata[keep]

    # Interval i covers (positions[i - 1], positions[i]]; map each crossover to its locus.
    locus_idx = torch.searchsorted(positions, crossovers).clamp_(max=n_loci - 1)
    counts = torch.bincount(locus_idx, minlength=n_loci)
    crossover_events = counts % 2 == 1
    return crossover_events.to(recombination_map.device)

def recombine_haplotype(haplotype: torch.Tensor, crossover_events: torch.Tensor, partner: "torch.Tensor | None" = None):
    """
    Build a recombinant haplotype by switching between homologs at crossover points.

    Crossovers exchange material between homologous haplotypes; they never move an allele
    to a different locus. The product copies ``haplotype`` up to the first crossover, then
    copies ``partner`` until the next crossover, and so on. A crossover flagged at locus
    ``i`` takes effect from locus ``i`` onward, so two crossovers at the same locus cancel.

    Args:
        haplotype (torch.Tensor): A tensor of haplotype values; loci run along the last dimension.
        crossover_events (torch.Tensor): A tensor of crossover events (1/True for crossover,
            0/False for no crossover), one entry per locus.
        partner (torch.Tensor, optional): The homologous haplotype to exchange material with,
            with the same shape as ``haplotype``. When omitted there is no homolog to exchange
            with, so the result is an unchanged copy of ``haplotype``.

    Returns:
        torch.Tensor: A new tensor with the same shape and dtype as ``haplotype``.

    Raises:
        ValueError: If ``partner`` is given with a shape different from ``haplotype``.
    """
    if partner is None:
        # Recombining a haplotype with itself is invisible: every locus keeps its own allele.
        return haplotype.clone()
    if partner.shape != haplotype.shape:
        raise ValueError(
            f"partner shape {tuple(partner.shape)} does not match haplotype shape {tuple(haplotype.shape)}"
        )

    # Number of crossovers at or before each locus; odd counts mean the product is on the partner.
    n_crossovers = torch.cumsum(
        crossover_events.to(device=haplotype.device, dtype=torch.long), dim=-1
    )
    on_partner = n_crossovers % 2 == 1
    return torch.where(on_partner, partner, haplotype)