In [31]:
import re
import subprocess
import io
import gzip
import sys

import numpy as np
import pandas as pd

## Parsers

def get_seq_df_input_symbols(input_df, seq_df, mane=False):
    """
    Update gene names (HUGO Symbols) of O3D built sequence with names in input file.
    Do it only for entries in the sequence df with available transcript information
    and use transcript ID to get gene name.
    """

    # Split sequence df by entries with available transcript info (Reference_info 0 and 1) and not available ones (-1)
    seq_df_tr_missing = seq_df[seq_df["Reference_info"] == -1].reset_index(drop=True)
    seq_df_tr_available = seq_df[seq_df["Reference_info"] != -1].reset_index(drop=True)

    # Use names from input
    df_mapping = input_df[["Hugo_Symbol", "Feature"]].rename(columns={"Hugo_Symbol" : "Gene", "Feature" : "Ens_Transcr_ID"})
    seq_df_tr_available = seq_df_tr_available.drop(columns=["Gene"]).drop_duplicates().merge(df_mapping, how="left", on="Ens_Transcr_ID")

    # If the same gene is associated to multiple structures, keep the first one obtained from Uniprot (descending, Reference_info 1) or keep the MANE (ascending, Reference_info 0)
    # TO DO: Use the one reviewed (UniProtKB reviewed (Swiss-Prot)), if multiple Uniprot ones are present. The info must be added during the build step
    order_ascending = [True, mane]
    seq_df_tr_available = seq_df_tr_available.sort_values(by=["Gene", "Reference_info"], ascending=order_ascending).drop_duplicates(subset="Gene")

    # If the same genes is associated to multiple structures, keep the one not obtained by Backtranseq (Reference_info 1 or 0)
    seq_df = pd.concat([seq_df_tr_missing, seq_df_tr_available]).sort_values(by=["Gene", "Reference_info"], ascending=[True, False])

    return seq_df.drop_duplicates(subset="Gene").reset_index(drop=True)


def get_hgvsp_mut(df_row):
    """
    Parse mutation entries to get HGVSp_Short format.
    """
    
    amino_acids = df_row["Amino_acids"]
    
    if pd.isna(amino_acids):
        return np.nan
    
    amino_acids = amino_acids.split("/")
    if len(amino_acids) > 1:
        return f"p.{amino_acids[0]}{df_row['Protein_position']}{amino_acids[1]}"
    
    return np.nan
                
                
def filter_transcripts(df, seq_df):
    """
    Filter VEP output by Oncodrive3D transcripts. For genes with NA 
    transcripts in the sequence dataframe, keep canonical ones.
    """
    
    if "CANONICAL" in df.columns and "Feature" in df.columns:
        
        # Genes without available transcript info in O3D built datasets
        df_tr_missing = df[df["Hugo_Symbol"].isin(seq_df.loc[seq_df["Reference_info"] == -1, "Gene"])]
        df_tr_missing = df_tr_missing[df_tr_missing["CANONICAL"] == "YES"]
        
        # Genes with transcript info
        df_tr_available = df[df["Feature"].isin(seq_df.loc[seq_df["Reference_info"] != -1, "Ens_Transcr_ID"])]
        
        return pd.concat((df_tr_available, df_tr_missing))
    
    else:
        print("Failed to filter input by O3D transcripts. Please provide as input the output of VEP with canonical and transcripts information: Exiting..")
        sys.exit(1)


def parse_vep_output(df,
                     seq_df=None, 
                     use_o3d_transcripts=False, 
                     use_input_symbols=False, 
                     mane=False):
    """
    Parse the dataframe in case it is the direct output of VEP without any 
    processing. Rename the columns to match the fields name of a MAF file, 
    and select the canonical transcripts if multiple ones are present.
    """

    df.rename(columns={"SYMBOL": "Hugo_Symbol",
                       "Consequence": "Variant_Classification"}, inplace=True)
            
    # Adapt HUGO_Symbol in seq_df to input file
    if seq_df is not None and use_input_symbols:
        print("Adapting Oncodrive3D HUGO Symbols of built datasets to input file..")
        seq_df = get_seq_df_input_symbols(df, seq_df, mane)
    
    # Transcripts filtering
    if use_o3d_transcripts and seq_df is not None:
        print("Filtering input by Oncodrive3D built transcripts..")
        df = filter_transcripts(df, seq_df)
    elif "CANONICAL" in df.columns:
        df = df[df["CANONICAL"] == "YES"]
            
    # Get HGVSp
    if "HGVSp_Short" not in df.columns and "Amino_acids" in df.columns and "Protein_position" in df.columns:
        df["HGVSp_Short"] = df.apply(get_hgvsp_mut, axis=1)
        
    return df, seq_df


def parse_mutations(maf):
    """
    Parse HGVSp_Short in maf.
    """
    
    # Ensure the required 'HGVSp_Short' column is present and not empty
    if 'HGVSp_Short' not in maf.columns or maf['HGVSp_Short'].isnull().all():
        print("Missing or empty 'HGVSp_Short' column in input MAF data.")
        sys.exit(1)

    # Parse the position, wild type, and mutation type from 'HGVSp_Short'
    maf.dropna(subset="HGVSp_Short", inplace=True)
    maf['Pos'] = maf['HGVSp_Short'].apply(lambda x: re.sub(r"\D", "", x)).astype(np.int32)
    maf['WT'] = maf['HGVSp_Short'].apply(lambda x: re.findall(r"\D", x)[2])
    maf['Mut'] = maf['HGVSp_Short'].apply(lambda x: re.findall(r"\D", x)[3])

    # Parse cols
    columns_to_keep = ['Hugo_Symbol', 'Pos', 'WT', 'Mut', 'Tumor_Sample_Barcode', 'Feature', 'Transcript_ID']
    columns_to_keep = [col for col in columns_to_keep if col in maf.columns]
    maf = maf[columns_to_keep]
    maf = maf.rename(columns={'Hugo_Symbol' : 'Gene', 'Feature': 'Transcript_ID'})

    return maf.sort_values(by=['Gene', 'Pos']).reset_index(drop=True)


def add_transcript_info(maf, seq_df):
    """
    Add transcript status information.
    """
    
    if 'Transcript_ID' not in maf.columns:
        maf['Transcript_ID'] = np.nan
    maf = maf.merge(seq_df[[col for col in ['Gene', 'Ens_Transcr_ID', 'Refseq_prot'] if col in seq_df.columns]].drop_duplicates(), 
                    on='Gene', how='left').rename(columns={"Ens_Transcr_ID" : "O3D_transcript_ID"})

    # Vectorized conditions for setting Transcript_status
    conditions = [
        maf['Transcript_ID'].isna(),
        maf['O3D_transcript_ID'].isna(),
        maf['Transcript_ID'] != maf['O3D_transcript_ID'],
        maf['Transcript_ID'] == maf['O3D_transcript_ID']
    ]
    choices = ['Input_missing', 'O3D_missing', 'Mismatch', 'Match']
    maf['Transcript_status'] = np.select(conditions, choices, default=np.nan)

    # Log transcript report
    transcript_report = maf['Transcript_status'].value_counts().reset_index(name='Count')
    transcript_report = ", ".join([f"{status} = {count}" for status, count in transcript_report.to_numpy()])
    print(f"Transcript status of {len(maf)} mutations: {transcript_report}")

    return maf


def read_input(input_path):
    """
    Read input file optimizing memory usage.
    """

    cols_to_read = ["Variant_Classification",
                    "Tumor_Sample_Barcode",
                    "Feature", 
                    "Transcript_ID",
                    "Consequence", 
                    "SYMBOL", 
                    "Hugo_Symbol",
                    "CANONICAL", 
                    "HGVSp_Short",
                    "Amino_acids", 
                    "Protein_position"]
    
    header = pd.read_table(input_path, nrows=0)
    cols_to_read = [col for col in cols_to_read if col in header.columns]
    dtype_mapping = {col : "object" for col in cols_to_read}
    dtype = {key: dtype_mapping[key] for key in cols_to_read if key in dtype_mapping}
    
    return pd.read_table(input_path, usecols=cols_to_read, dtype=dtype)


def parse_maf_input(input_path, seq_df=None, use_o3d_transcripts=False, use_input_symbols=False, mane=False):
    """
    Parsing and process MAF input data.
    """

    # Load, parse from VEP and update seq_df if needed
    print("Reading input mutations file..")
    maf = read_input(input_path)
    print(f"Processing [{len(maf)}] total mutations..")
    maf, seq_df = parse_vep_output(maf, seq_df, use_o3d_transcripts, use_input_symbols, mane)

    # Extract and parse missense mutations
    maf = maf[maf['Variant_Classification'].str.contains('Missense_Mutation|missense_variant')]
    if "Protein_position" in maf.columns:
        maf = maf[~maf['Protein_position'].astype(str).str.contains('-')] # Filter DBS
    print(f"Processing [{len(maf)}] missense mutations..")
    maf = parse_mutations(maf)                       
    
    # Add transcript status from seq_df
    if seq_df is not None:
        maf = add_transcript_info(maf, seq_df)
    
    return maf.reset_index(drop=True), seq_df


In [32]:
o3d_datasets_dir = "/workspace/nobackup/scratch/oncodrive3d/datasets_mane_240506"
seq_df = pd.read_csv(os.path.join(o3d_datasets_dir, "seq_for_mut_prob.tsv"), sep="\t")

In [41]:
import os

cohort = "TCGA_WXS_LUAD"
o3d_datasets_dir = "/workspace/nobackup/scratch/oncodrive3d/datasets_mane_240506"
seq_df = pd.read_csv(os.path.join(o3d_datasets_dir, "seq_for_mut_prob.tsv"), sep="\t")
input_path = f"/workspace/projects/clustering_3d/o3d_analysys/datasets/input/cancer_202404/vep/{cohort}.vep.tsv.gz"
data, seq_df = parse_maf_input(input_path, 
                                seq_df, 
                                use_o3d_transcripts=True,
                                use_input_symbols=True, 
                                mane=True)


Reading input mutations file..
Processing [1358452] total mutations..
Adapting Oncodrive3D HUGO Symbols of built datasets to input file..
Filtering input by Oncodrive3D built transcripts..
Processing [98457] missense mutations..
Transcript status of 98457 mutations: Match = 94233, O3D_missing = 4224


In [43]:
data

Unnamed: 0,Gene,Pos,WT,Mut,Transcript_ID,O3D_transcript_ID,Transcript_status
0,A1BG,9,L,F,ENST00000263100,ENST00000263100,Match
1,A1BG,31,L,M,ENST00000263100,ENST00000263100,Match
2,A1BG,81,F,L,ENST00000263100,ENST00000263100,Match
3,A1CF,85,M,K,ENST00000373997,ENST00000373997,Match
4,A1CF,111,N,T,ENST00000373997,ENST00000373997,Match
...,...,...,...,...,...,...,...
98452,ZZZ3,281,K,R,ENST00000370801,ENST00000370801,Match
98453,ZZZ3,389,R,I,ENST00000370801,ENST00000370801,Match
98454,ZZZ3,562,L,F,ENST00000370801,ENST00000370801,Match
98455,ZZZ3,591,V,I,ENST00000370801,ENST00000370801,Match


In [46]:
data = data[data["Gene"] == "TP53"].reset_index(drop=True)
data

Unnamed: 0,Gene,Pos,WT,Mut,Transcript_ID,O3D_transcript_ID,Transcript_status
0,TP53,57,D,N,ENST00000269305,ENST00000269305,Match
1,TP53,77,P,L,ENST00000269305,ENST00000269305,Match
2,TP53,105,G,D,ENST00000269305,ENST00000269305,Match
3,TP53,105,G,D,ENST00000269305,ENST00000269305,Match
4,TP53,105,G,C,ENST00000269305,ENST00000269305,Match
...,...,...,...,...,...,...,...
166,TP53,337,R,L,ENST00000269305,ENST00000269305,Match
167,TP53,337,R,P,ENST00000269305,ENST00000269305,Match
168,TP53,337,R,C,ENST00000269305,ENST00000269305,Match
169,TP53,337,R,C,ENST00000269305,ENST00000269305,Match


In [55]:
data_top1 = data.drop_duplicates(subset=["Gene", "Pos"], keep="first")
data_top1

Unnamed: 0,Gene,Pos,WT,Mut,Transcript_ID,O3D_transcript_ID,Transcript_status
0,TP53,57,D,N,ENST00000269305,ENST00000269305,Match
1,TP53,77,P,L,ENST00000269305,ENST00000269305,Match
2,TP53,105,G,D,ENST00000269305,ENST00000269305,Match
5,TP53,110,R,L,ENST00000269305,ENST00000269305,Match
8,TP53,120,K,E,ENST00000269305,ENST00000269305,Match
...,...,...,...,...,...,...,...
157,TP53,286,E,G,ENST00000269305,ENST00000269305,Match
162,TP53,331,Q,H,ENST00000269305,ENST00000269305,Match
163,TP53,334,G,V,ENST00000269305,ENST00000269305,Match
166,TP53,337,R,L,ENST00000269305,ENST00000269305,Match


In [54]:
data_top3 = data.groupby(["Gene", "Pos"], as_index=False).head(1)
data_top3

Unnamed: 0,Gene,Pos,WT,Mut,Transcript_ID,O3D_transcript_ID,Transcript_status
0,TP53,57,D,N,ENST00000269305,ENST00000269305,Match
1,TP53,77,P,L,ENST00000269305,ENST00000269305,Match
2,TP53,105,G,D,ENST00000269305,ENST00000269305,Match
5,TP53,110,R,L,ENST00000269305,ENST00000269305,Match
8,TP53,120,K,E,ENST00000269305,ENST00000269305,Match
...,...,...,...,...,...,...,...
157,TP53,286,E,G,ENST00000269305,ENST00000269305,Match
162,TP53,331,Q,H,ENST00000269305,ENST00000269305,Match
163,TP53,334,G,V,ENST00000269305,ENST00000269305,Match
166,TP53,337,R,L,ENST00000269305,ENST00000269305,Match


In [45]:
data_top5 = data.groupby(["Gene", "Pos"], as_index=False).head(5)
data_top5

Unnamed: 0,Gene,Pos,WT,Mut,Transcript_ID,O3D_transcript_ID,Transcript_status
0,A1BG,9,L,F,ENST00000263100,ENST00000263100,Match
1,A1BG,31,L,M,ENST00000263100,ENST00000263100,Match
2,A1BG,81,F,L,ENST00000263100,ENST00000263100,Match
3,A1CF,85,M,K,ENST00000373997,ENST00000373997,Match
4,A1CF,111,N,T,ENST00000373997,ENST00000373997,Match
...,...,...,...,...,...,...,...
98452,ZZZ3,281,K,R,ENST00000370801,ENST00000370801,Match
98453,ZZZ3,389,R,I,ENST00000370801,ENST00000370801,Match
98454,ZZZ3,562,L,F,ENST00000370801,ENST00000370801,Match
98455,ZZZ3,591,V,I,ENST00000370801,ENST00000370801,Match


### Mut profile

In [38]:
from itertools import product

import daiquiri
import numpy as np
import json
from tqdm import tqdm


def get_unif_gene_miss_prob(size):
    """
    Get a uniformly distributed gene missense mutation 
    probability vector.
    """
    
    vector = np.ones(size)
    vector[0] = 0
    
    return vector / sum(vector)


def mut_rate_vec_to_dict(mut_rate):
    """
    Convert the vector of mut mut_rate of 96 channels to a dictionary of 192 
    items: the keys are mutations in trinucleotide context (e.g., "ACA>A") 
    and values are the corresponding mut rate (frequency of mut normalized 
    for the nucleotide content).
    """
    
    cb  = dict(zip('ACGT', 'TGCA'))
    mut_rate_dict = {}
    i = 0
    for ref in ['C', 'T']:
        for alt in cb.keys():
            if ref == alt:
                continue
            else:
                for p in product(cb.keys(), repeat=2):
                    mut = f"{p[0]}{ref}{p[1]}>{alt}"
                    cmut = f"{cb[p[1]]}{cb[ref]}{cb[p[0]]}>{cb[alt]}"
                    mut_rate_dict[mut] = mut_rate[i]
                    mut_rate_dict[cmut] = mut_rate[i]
                    i +=1
                    
    return mut_rate_dict


def get_codons(dna_seq):
    """
    Get the list of codons from a DNA sequence.
    """
    
    return [dna_seq[i:i+3] for i in [n*3 for n in range(int(len(dna_seq) / 3))]]


def translate_dna_to_prot(dna_seq, gencode):
    """
    Translate a DNA sequence into amino acid sequence.
    """
    
    return "".join([gencode[codon] for codon in get_codons(dna_seq)])


def codons_trinucleotide_context(lst_contexts):
    
    return list(zip(lst_contexts[::3], lst_contexts[1::3], lst_contexts[2::3]))


# TODO: doc function

def get_miss_mut_prob(dna_seq, 
                      dna_tricontext, 
                      mut_rate_dict, 
                      mutability=False, 
                      get_probability=True, 
                      mut_start_codon=False):
    """
    Generate a list including the probabilities that the 
    codons can mutate resulting into a missense mutations.
    
    Arguments
    ---------
    dna_seq: str
        Sequence of DNA
    mut_rate_dict: dict
        Mutation rate probability as values and the 96 possible
        trinucleotide contexts as keys
    gencode: dict
        Nucleotide as values and codons as keys
        
    Returns
    -------
    missense_prob_vec: list
        List of probabilities (one for each codon or prot res) 
        of a missense mutation  
    """

    # Initialize
    gencode = {
        'ATA':'I', 'ATC':'I', 'ATT':'I', 'ATG':'M',
        'ACA':'T', 'ACC':'T', 'ACG':'T', 'ACT':'T',
        'AAC':'N', 'AAT':'N', 'AAA':'K', 'AAG':'K',
        'AGC':'S', 'AGT':'S', 'AGA':'R', 'AGG':'R',
        'CTA':'L', 'CTC':'L', 'CTG':'L', 'CTT':'L',
        'CCA':'P', 'CCC':'P', 'CCG':'P', 'CCT':'P',
        'CAC':'H', 'CAT':'H', 'CAA':'Q', 'CAG':'Q',
        'CGA':'R', 'CGC':'R', 'CGG':'R', 'CGT':'R',
        'GTA':'V', 'GTC':'V', 'GTG':'V', 'GTT':'V',
        'GCA':'A', 'GCC':'A', 'GCG':'A', 'GCT':'A',
        'GAC':'D', 'GAT':'D', 'GAA':'E', 'GAG':'E',
        'GGA':'G', 'GGC':'G', 'GGG':'G', 'GGT':'G',
        'TCA':'S', 'TCC':'S', 'TCG':'S', 'TCT':'S',
        'TTC':'F', 'TTT':'F', 'TTA':'L', 'TTG':'L',
        'TAC':'Y', 'TAT':'Y', 'TAA':'_', 'TAG':'_',
        'TGC':'C', 'TGT':'C', 'TGA':'_', 'TGG':'W'}

    # Get all codons of the seq
    #logger.debug("Getting codons of seq..")
    codons = get_codons(dna_seq)
    missense_prob_vec = []
    
    # Get the trinucleotide context as list of tuples of 3 elements corresponding to each codon   
    #logger.debug("Getting tri context seq..")                               
    tricontext = codons_trinucleotide_context(dna_tricontext.split(","))
    
    # Iterate through codons and get prob of missense based on context
    for c in range(len(codons)):
        missense_prob = 0
        codon = codons[c]
        aa = gencode[codon]         
        trinucl0, trinucl1, trinucl2  = tricontext[c]

        # Iterate through the possible contexts of a missense mut
        for i, trinucl in enumerate([trinucl0, trinucl1, trinucl2]):
            ref = trinucl[1]
            aa = gencode[codon]

            # Iterate through the possible alt 
            for alt in "ACGT":
                if alt != ref:         
                    alt_codon = [n for n in codon]
                    alt_codon[i] = alt
                    alt_codon = "".join(alt_codon)
                    alt_aa = gencode[alt_codon]  
                    # If there is a missense mut, get prob from context and sum it
                    if alt_aa != aa and alt_aa != "_":
                        if not mutability:
                            mut = f"{trinucl}>{alt}"    # query using only the trinucleotide change
                            mut_prob = mut_rate_dict[mut] if mut in mut_rate_dict else 0                                                                 
                            missense_prob += mut_prob

                        else:
                            cdna_pos = (c * 3) + i  # compute the cDNA position of the residue
                            if cdna_pos in mut_rate_dict:
                                missense_prob += mut_rate_dict[cdna_pos].get(alt, 0)
                            else:
                                missense_prob += 0

        missense_prob_vec.append(missense_prob)

    # Assign 0 prob to the first residue
    if mut_start_codon == False:
        missense_prob_vec[0] = 0
        
    # Convert into probabilities
    if get_probability:
        missense_prob_vec = np.array(missense_prob_vec) / sum(missense_prob_vec)
    
    return list(missense_prob_vec)


def get_miss_mut_prob_dict(mut_rate_dict, seq_df, mutability=False, mutability_config=None):
    """
    Given a dictionary of mut rate in 96 contexts (mut profile) and a 
    dataframe including Uniprot ID, HUGO symbol and DNA sequences, 
    get a dictionary with UniprotID-Fragment as keys and corresponding 
    vectors of missense mutation probabilities as values.
    """

    miss_prob_dict = {}

    if mutability:
        # TODO if the execution time of this step is too long we could
        # parallelize all these loops so that each gene is done in parallel

        # Process any Protein/fragment in the sequence df
        for _, row in seq_df.iterrows():
            # Mutabilities
            mutability_dict = Mutabilities(row.Uniprot_ID, row.Chr, row.Exons_coord, len(row.Seq_dna), row.Reverse_strand, mutability_config).mutabilities_by_pos
            miss_prob_dict[f"{row.Uniprot_ID}-F{row.F}"] = get_miss_mut_prob(row.Seq_dna, row.Tri_context, mutability_dict, mutability=True)

    else:
        # Process any Protein/fragment in the sequence df
        for _, row in seq_df.iterrows():
            miss_prob_dict[f"{row.Uniprot_ID}-F{row.F}"] = get_miss_mut_prob(row.Seq_dna, row.Tri_context, mut_rate_dict)
    
    return miss_prob_dict


def get_cmap(cmap_path, uniprot_id, af_f, cmap_prob_thr=0.5):
    
    cmap_complete_path = f"{cmap_path}/{uniprot_id}-F{af_f}.npy"
    cmap = np.load(cmap_complete_path) 
    cmap = cmap > cmap_prob_thr
    cmap = cmap.astype(int)
    
    return cmap


def get_vol_miss_prob(gene, cmap_path, miss_prob_dict, seq_df):
    
    seq_df_gene = seq_df[seq_df["Gene"] == gene]
    uniprot_id = seq_df_gene['Uniprot_ID'].values[0]
    af_f = seq_df_gene['F'].values[0]
    
    cmap = get_cmap(cmap_path, uniprot_id, af_f)
    gene_miss_prob = np.array(miss_prob_dict[f"{uniprot_id}-F{af_f}"])
    vol_missense_mut_prob = np.dot(cmap, gene_miss_prob)
    
    return vol_missense_mut_prob
    

cmap_path = "/workspace/nobackup/scratch/oncodrive3d/datasets_mane_240506/prob_cmaps"
mut_profile_path = f"/workspace/projects/clustering_3d/o3d_analysys/datasets/input/cancer_202404/mut_profile/{cohort}.sig.json"
mut_profile = json.load(open(mut_profile_path, encoding="utf-8"))
miss_prob_dict = get_miss_mut_prob_dict(mut_rate_dict=mut_profile, seq_df=seq_df)

vol_miss_prob = get_vol_miss_prob("TP53", cmap_path, miss_prob_dict, seq_df)

### Simulations

__TODO:__
* Try once with the same simulation (allowing multiple mutations in the same position)
* Try other iterations by also tweaking the simulations such that we will not have more than n mutations in the same position

In [None]:
import numpy as np

def simulate_unique_mutations(n_mutations, p, size, seed=None):
    """
    Simulate mutations such that no residue is mutated more than once per sample.
    
    Parameters:
    - n_mutations: number of mutations per sample
    - p: probability distribution over residues (must sum to 1)
    - size: number of samples to generate
    - seed: optional random seed
    
    Returns:
    - samples: array of shape (size, n_mutations) with selected residue indices
    """
    rng = np.random.default_rng(seed)
    n_residues = len(p)
    
    if n_mutations > n_residues:
        raise ValueError("n_mutations cannot exceed the number of residues (len(p)) when sampling without replacement.")
    
    samples = np.empty((size, n_mutations), dtype=int)
    for i in range(size):
        samples[i] = rng.choice(n_residues, size=n_mutations, replace=False, p=p)
    
    return samples

In [40]:
def simulate_mutations(n_mutations, p, size, seed=None):
    """
    Simulate the mutations given the mutation rate of a cohort.
    """

    rng = np.random.default_rng(seed=seed)
    samples = rng.multinomial(n_mutations, p, size=size)
    
    return samples

In [None]:
simulate_mutations()