In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| default_exp meiosis

In [None]:
from chewc.core import *
import torch

## meiosis
> Simulating meiosis and recombination for various crossing actions

In [None]:


def gamma_interference_model(length, rate, shape, device):
    """
    Simulate crossover positions on one chromosome from gamma-distributed intervals.

    The number of candidate crossovers is drawn from Poisson(rate * length).
    That many intervals are then sampled from Gamma(concentration=shape,
    rate=rate), accumulated into positions, and every position at or beyond
    ``length`` is discarded.

    Parameters:
    length (float): Length of the chromosome.
    rate (float): Rate of crossover events. Used both as the Poisson intensity
                  per unit length and as the rate parameter of the gamma
                  distribution.
    shape (float): Shape (concentration) parameter for the gamma distribution.
    device (torch.device): Device to perform computations on.

    Returns:
    torch.Tensor: 1-D tensor of non-decreasing crossover positions, all
                  strictly smaller than ``length``. May be empty.
    """
    # torch.poisson only accepts floating-point tensors. Forcing the default
    # float dtype keeps integer ``rate``/``length`` arguments from producing an
    # int64 tensor, which would make the call raise a RuntimeError.
    expected_count = torch.tensor(
        [rate * length], dtype=torch.get_default_dtype(), device=device
    )
    num_crossovers = int(torch.poisson(expected_count).item())

    # NOTE(review): the sampled intervals have mean shape / rate, so the
    # realised crossover density only matches ``rate`` when shape == 1.
    # Confirm this parameterisation is intended before relying on it.
    interval_distribution = torch.distributions.gamma.Gamma(shape, rate)
    intervals = interval_distribution.sample((num_crossovers,)).to(device)

    # Running sum turns interval lengths into absolute positions.
    crossover_positions = torch.cumsum(intervals, dim=0)
    return crossover_positions[crossover_positions < length]


def simulate_meiosis(num_chromosomes, map_length, num_individuals, num_crossovers, device):
    """
    Draw uniformly random, sorted crossover positions for every individual and chromosome.

    Parameters:
    num_chromosomes (int): Number of chromosomes.
    map_length (float): Length of the chromosome map; uniform draws in [0, 1)
                        are scaled by this value.
    num_individuals (int): Number of individuals.
    num_crossovers (int): Number of crossovers drawn per chromosome.
    device (torch.device): Device to perform computations on.

    Returns:
    torch.Tensor: Crossover positions sorted in ascending order along the last axis.
                  Shape: (num_individuals, num_chromosomes, num_crossovers)
    """
    draw_shape = (num_individuals, num_chromosomes, num_crossovers)
    uniform_draws = torch.rand(draw_shape, device=device)
    scaled_positions = uniform_draws * map_length
    # torch.sort returns (values, indices); only the sorted values are needed.
    ordered_positions, _ = torch.sort(scaled_positions, dim=-1)
    return ordered_positions


def simulate_gametes(genetic_map, parent_genomes, device):
    """
    Simulate the formation of gametes for multiple parents with one crossover per chromosome.

    For every individual and chromosome a single crossover position is drawn
    via ``simulate_meiosis``. The marker nearest to that position becomes the
    switch point: loci before it are copied from haplotype 0 of the parent,
    loci from it onward are copied from haplotype 1.

    Parameters:
    genetic_map (torch.Tensor): Positions of genetic markers on the chromosomes.
                                 Shape: (num_chromosomes, num_loci)
    parent_genomes (torch.Tensor): Genomes of the parents.
                                    Shape: (num_individuals, ploidy, num_chromosomes, num_loci)
    device (torch.device): Device to perform computations on.

    Returns:
    torch.Tensor: The resultant gametes.
                  Shape: (num_individuals, ploidy//2, num_chromosomes, num_loci)

    Notes:
    - Only gamete slot 0 is written and only parental haplotypes 0 and 1 are
      read; for ploidy > 2 the remaining gamete slots stay zero.
    - The crossover position is drawn against ``genetic_map.max()`` for every
      chromosome, i.e. the largest marker position over the whole map.
    """
    num_individuals, ploidy, num_chromosomes, num_loci = parent_genomes.shape
    gamete_genomes = torch.zeros(
        (num_individuals, ploidy // 2, num_chromosomes, num_loci),
        dtype=parent_genomes.dtype,
        device=device,
    )

    # One crossover position per individual and chromosome;
    # shape (num_individuals, num_chromosomes, 1).
    crossover_positions = simulate_meiosis(num_chromosomes, genetic_map.max(), num_individuals, 1, device)

    locus_indices = torch.arange(num_loci, device=device)

    # All individuals are processed at once per chromosome instead of in a
    # Python loop; the per-individual result is unchanged.
    for chrom in range(num_chromosomes):
        # (num_loci,) marker positions broadcast against (num_individuals, 1)
        # crossover sites -> (num_individuals, num_loci) distances.
        distances = torch.abs(genetic_map[chrom] - crossover_positions[:, chrom])
        # Index of the marker closest to each individual's crossover site.
        crossover_index = torch.argmin(distances, dim=-1, keepdim=True)

        # True from the crossover marker to the end of the chromosome.
        after_crossover = locus_indices >= crossover_index

        gamete_genomes[:, 0, chrom] = torch.where(
            after_crossover,
            parent_genomes[:, 1, chrom],
            parent_genomes[:, 0, chrom],
        )

    return gamete_genomes

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()