# 210818 Create query seqs

In [1]:
from pathlib import Path
import json
from itertools import count

In [2]:
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO
import numpy as np
import pandas as pd
import h5py as h5
from tqdm import tqdm

In [3]:
import matplotlib.pyplot as plt
import seaborn as sns

In [4]:
from gambit.kmers import KmerSpec, find_kmers, dense_to_sparse
from gambit.metric import jaccard_sparse_array
from gambit.db import load_database
from gambit.classify import classify
from gambit.signatures import SignatureArray, SignaturesMeta
from gambit.signatures.hdf5 import HDF5Signatures

In [5]:
# Date stamp used to name this test database, in short (YYMMDD) and ISO form.
DATESTR = '210818'
DATESTR_LONG = '2021-08-18'

DBNAME = f'testdb_{DATESTR}'

## File paths

In [6]:
# Final outputs go in ``output/``; intermediate files in ``tmp/``.
outdir, tmpdir = Path('output'), Path('tmp')

In [7]:
# Input files for this notebook. The reference-database file names are
# derived from DBNAME (rather than repeating the date literally) so they stay
# in sync if DATESTR changes.
infiles = dict(
    params=tmpdir / 'params.json',
    taxa=tmpdir / 'taxa.csv',
    ref_genomes=outdir / f'{DBNAME}-genomes.db',
    ref_sigs=outdir / f'{DBNAME}-signatures.h5',
    taxon_centers=tmpdir / 'taxon-centers.fasta',
)

## Load data

In [8]:
# Generation parameters written by an earlier step.
_params = json.loads(infiles['params'].read_text())

# Number of mutations to apply to a taxon center, indexed by taxon level
# (see ``MUTATION_COUNTS[taxon.level]`` below).
MUTATION_COUNTS = _params['n_mutations']

### Taxa

In [9]:
# Taxa table; later cells read its ``name``, ``level`` and ``threshold`` columns.
taxa_df = pd.read_csv(filepath_or_buffer=infiles['taxa'])

In [10]:
# Map each taxon name to its row position in ``taxa_df``.
name_to_index = dict(zip(taxa_df['name'], count()))

### Taxon centers

In [11]:
# Center sequence of each taxon, keyed by FASTA record id, stored as raw bytes.
taxon_centers = {
    rec.id: bytes(rec.seq)
    for rec in SeqIO.parse(str(infiles['taxon_centers']), 'fasta')
}

### Database

In [12]:
# Load the reference database from the genomes file and the signatures file.
db = load_database(infiles['ref_genomes'], infiles['ref_sigs'])

In [13]:
# Reference signatures selected/reordered by ``db.sig_indices``. Presumably
# aligned index-for-index with ``db.genomes``: later cells index ``ref_sigs``
# with positions from ``enumerate(db.genomes)`` -- confirm against gambit.
ref_sigs = db.signatures[db.sig_indices]

In [14]:
# K-mer spec of the reference signatures; query signatures below are
# computed with this same spec so they can be compared to the references.
kspec = db.signatures.kmerspec

## Functions

In [15]:
# ASCII codes of the four nucleotides, used as replacement bases by mutate_seq.
NUC_BYTES = np.array([ord(base) for base in 'ACGT'], dtype=np.uint8)

In [16]:
def mutate_seq(a, n):
    """Overwrite ``n`` distinct, randomly chosen positions of ``a`` in place.

    Each chosen position receives a base drawn uniformly from ``NUC_BYTES``.
    The drawn base may equal the one already there, so fewer than ``n``
    positions may actually change value.

    Parameters
    ----------
    a : bytearray
        Mutable sequence to modify in place.
    n : int
        Number of distinct positions to overwrite (numpy raises if it
        exceeds ``len(a)``, since positions are sampled without replacement).
    """
    # Sample all positions first, then draw one base per position, in the
    # same order as before so seeded runs reproduce the same sequences.
    positions = np.random.choice(len(a), n, replace=False)
    for pos in positions:
        a[pos] = np.random.choice(NUC_BYTES)

In [17]:
def make_mutations(a, m, n):
    """Lazily yield ``m`` independently mutated copies of ``a``.

    Each copy is a fresh ``bytearray`` mutated at ``n`` positions by
    ``mutate_seq``; ``a`` itself is never modified.
    """
    for _ in range(m):
        mutant = bytearray(a)
        mutate_seq(mutant, n)
        yield mutant

In [18]:
def find_kmers_multi(kspec, arrs):
    """Compute one sparse k-mer signature for a collection of sequences.

    Each ``find_kmers`` call receives the previous call's dense result via
    ``dense_out`` (presumably so k-mers from all sequences land in the same
    dense array -- see gambit's ``find_kmers``). The final dense array is
    converted with ``dense_to_sparse``.

    Parameters
    ----------
    kspec : KmerSpec
        K-mer spec to search with.
    arrs : iterable of sequences
        Sequences (e.g. the parts of a fragmented query) to combine.

    Raises
    ------
    ValueError
        If ``arrs`` is empty; there would be no dense array to convert.
    """
    out = None
    for arr in arrs:
        out = find_kmers(kspec, arr, dense_out=out, sparse=False)
    if out is None:
        raise ValueError('arrs must contain at least one sequence')
    return dense_to_sparse(out)

In [19]:
def split_seq(a, n, minlen):
    """Randomly split ``a`` into ``n`` contiguous parts of length >= ``minlen``.

    The ``len(a) - n * minlen`` positions beyond the minimum are distributed
    between the parts using ``n - 1`` uniformly drawn cut points.

    Parameters
    ----------
    a : sliceable sequence (e.g. bytes / bytearray)
        Sequence to split.
    n : int
        Number of parts; must be at least 1.
    minlen : int
        Minimum length of every part.

    Returns
    -------
    list
        ``n`` slices of ``a`` which concatenate back to ``a``.

    Raises
    ------
    ValueError
        If ``n < 1`` or ``a`` is shorter than ``n * minlen``.
    """
    total = len(a)
    if n < 1:
        raise ValueError(f'n must be at least 1, got {n}')
    rem = total - n * minlen
    if rem < 0:
        # An exact fit (rem == 0) is valid: every part gets exactly minlen.
        raise ValueError(
            f'sequence of length {total} is too short for {n} parts of at least {minlen}'
        )

    # Keep this exact RNG call so seeded runs reproduce the same splits.
    cuts = np.sort(np.random.choice(rem + 1, n - 1))
    bounds = [0, *cuts, rem]
    lengths = np.diff(bounds) + minlen
    assert len(lengths) == n
    assert all(lengths >= minlen)
    assert sum(lengths) == total

    parts = []
    start = 0
    for length in lengths:
        parts.append(a[start:start + length])
        start += length

    assert start == total
    return parts

## Make query seqs

### Standard

In [20]:
# Bounds passed to np.random.randint, whose upper bound is exclusive: each
# query is split into between 5 and 10 parts (inclusive).
N_PARTS_RANGE = (5, 10 + 1)
# Minimum length of each query part (passed to split_seq).
MIN_PART_LEN = 100

In [21]:
query_seqs = []
query_sigs = []
_rows = []


np.random.seed(123)

# One query per taxon that has a classification threshold. For each taxon,
# keep drawing mutated, fragmented copies of its center sequence until one
# classifies cleanly (no warnings/errors, unique closest match) to the
# expected result.
for taxon in tqdm(taxa_df.itertuples()):
    if pd.isnull(taxon.threshold):
        continue

    center = taxon_centers[taxon.name]
    is_root = taxon.name == 'root'

    while True:
        seq = bytearray(center)
        mutate_seq(seq, MUTATION_COUNTS[taxon.level])

        n_parts = np.random.randint(*N_PARTS_RANGE)
        parts = split_seq(seq, n_parts, MIN_PART_LEN)

        sig = find_kmers_multi(kspec, parts)
        dists = jaccard_sparse_array(sig, ref_sigs, distance=True)
        result = classify(db.genomes, dists, strict=True)

        # Reject anything but a clean, successful classification.
        if result.warnings or result.error or not result.success:
            continue

        # Reject ties between the two smallest distances (closest is unique).
        ranked = np.sort(dists)
        if ranked[0] == ranked[1]:
            continue

        predicted = result.predicted_taxon
        if is_root:
            # The root query must not be assigned to any taxon.
            accepted = predicted is None
        else:
            accepted = predicted is not None and predicted.name == taxon.name
        if accepted:
            break

    query_seqs.append(parts)
    query_sigs.append(sig)

    if is_root:
        _rows.append((
            'unclassifiable',
            None,
            None,
            result.closest_match.genome.description,
        ))
    else:
        _primary = result.primary_match.genome.description
        _closest = result.closest_match.genome.description
        _rows.append((taxon.name, taxon.name, _primary, _closest))


queries_df = pd.DataFrame.from_records(
    _rows,
    columns=['name', 'predicted', 'primary', 'closest'],
)

49it [00:00, 82.28it/s]


In [22]:
# Standard queries were only accepted if classification produced no warnings.
queries_df = queries_df.assign(warnings=False)

### Inconsistent (query within threshold of two sibling taxa)

In [23]:
# Two taxa used to build the warning-producing queries below.
target1, target2 = 'A2_B1', 'A2_B2'
# Their classification thresholds from the taxa table.
thresh1, thresh2 = (
    taxa_df.loc[name_to_index[name], 'threshold']
    for name in (target1, target2)
)

In [24]:
# Positions (in db.genomes / ref_sigs) of reference genomes whose taxon name
# starts with target1.
# NOTE(review): a bare prefix match would also catch e.g. 'A2_B10'; confirm
# the taxonomy never has such siblings.
in_target1 = [
    idx
    for idx, genome in enumerate(db.genomes)
    if genome.taxon is not None and genome.taxon.name.startswith(target1)
]
target1_sigs = ref_sigs[in_target1]
len(in_target1)

23

In [25]:
# Positions (in db.genomes / ref_sigs) of reference genomes whose taxon name
# starts with target2.
# NOTE(review): a bare prefix match would also catch e.g. 'A2_B20'; confirm
# the taxonomy never has such siblings.
in_target2 = [
    idx
    for idx, genome in enumerate(db.genomes)
    if genome.taxon is not None and genome.taxon.name.startswith(target2)
]
target2_sigs = ref_sigs[in_target2]
len(in_target2)

23

In [26]:
np.random.seed(0)

_center2 = taxon_centers[target2]

# Mutate target2's center until the query lands within threshold of BOTH
# target1 and target2 reference genomes.
for i in tqdm(count()):
    seq = bytearray(_center2)
    # NOTE(review): hard-coded level index 3 rather than target2's own
    # level -- confirm this is intended.
    mutate_seq(seq, MUTATION_COUNTS[3])

    n_parts = np.random.randint(*N_PARTS_RANGE)
    parts = split_seq(seq, n_parts, MIN_PART_LEN)
    sig = find_kmers_multi(kspec, parts)

    # Minimum distance to each target's reference genomes (target1 first).
    d1, d2 = (
        jaccard_sparse_array(sig, group_sigs, distance=True).min()
        for group_sigs in (target1_sigs, target2_sigs)
    )

    if d1 < thresh1 and d2 < thresh2:
        break

354it [00:01, 282.34it/s]


In [27]:
dists = jaccard_sparse_array(sig, ref_sigs, distance=True)
result = classify(db.genomes, dists, strict=True)

# The query should classify to 'A2' (presumably the common parent of the two
# targets) and carry warnings. Messages make a failed regeneration diagnosable.
assert result.success, 'classification of the "inconsistent" query failed'
assert result.predicted_taxon.name == 'A2', (
    f'expected prediction A2, got {result.predicted_taxon.name}'
)
assert result.warnings, 'expected warnings for the "inconsistent" query'

In [28]:
query_seqs.append(parts)
query_sigs.append(sig)

row = dict(
    name='inconsistent',
    predicted='A2',
    primary=result.primary_match.genome.description,
    closest=result.closest_match.genome.description,
    warnings=True,
)

# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0.
queries_df = pd.concat([queries_df, pd.DataFrame([row])], ignore_index=True)

### Primary match not closest

In [29]:
# Chimeric base: first 2000 positions from target1's center, the rest from
# target2's center.
_base = taxon_centers[target1][:2000] + taxon_centers[target2][2000:]

np.random.seed(0)

for i in tqdm(count()):
    seq = bytearray(_base)
    mutate_seq(seq, 40)

    n_parts = np.random.randint(*N_PARTS_RANGE)
    parts = split_seq(seq, n_parts, MIN_PART_LEN)
    sig = find_kmers_multi(kspec, parts)

    # Minimum distance to each target's reference genomes (target1 first).
    d1, d2 = (
        jaccard_sparse_array(sig, group_sigs, distance=True).min()
        for group_sigs in (target1_sigs, target2_sigs)
    )

    # Closer to target2 than target1, but outside target2's threshold while
    # still inside target1's.
    if thresh2 < d2 < d1 < thresh1:
        break

0it [00:00, ?it/s]


In [30]:
dists = jaccard_sparse_array(sig, ref_sigs, distance=True)
result = classify(db.genomes, dists, strict=True)

# The query should be predicted as target1 even though its closest match lies
# in target2, with warnings. Messages make a failed regeneration diagnosable.
assert result.success, 'classification of the "primary_not_closest" query failed'
assert result.predicted_taxon.name == target1, (
    f'expected prediction {target1}, got {result.predicted_taxon.name}'
)
assert result.closest_match.genome.taxon.name.startswith(target2), (
    f'expected closest match in {target2}, '
    f'got {result.closest_match.genome.taxon.name}'
)
assert result.warnings, 'expected warnings for the "primary_not_closest" query'

In [31]:
query_seqs.append(parts)
query_sigs.append(sig)

row = dict(
    name='primary_not_closest',
    predicted=target1,
    primary=result.primary_match.genome.description,
    closest=result.closest_match.genome.description,
    warnings=True,
)

# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0.
queries_df = pd.concat([queries_df, pd.DataFrame([row])], ignore_index=True)

## Write output

### Table

In [32]:
# Query table (one row per query), without the pandas index column.
queries_df.to_csv(path_or_buf=tmpdir / 'query-seqs.csv', index=False)

### Sequences

In [33]:
# Directory holding one FASTA file per query.
seqs_dir = tmpdir.joinpath('query-seqs')
seqs_dir.mkdir(parents=False, exist_ok=True)

In [34]:
# Guard against silent data loss: zip() would drop unmatched entries, and
# duplicate query names would overwrite each other's FASTA files.
if len(query_seqs) != len(queries_df):
    raise RuntimeError(
        f'{len(query_seqs)} query sequences but {len(queries_df)} table rows'
    )
if queries_df['name'].duplicated().any():
    raise RuntimeError('query names must be unique (one FASTA file per query)')

for parts, row in zip(query_seqs, queries_df.itertuples()):
    # One record per part, ids numbered from 1: "<query name>-<part number>".
    records = [
        SeqRecord(Seq(part), id=f'{row.name}-{i}', description='')
        for i, part in enumerate(parts, start=1)
    ]

    SeqIO.write(records, str(seqs_dir / (row.name + '.fasta')), 'fasta')

### Signatures

In [35]:
# Extra metadata stored alongside the query signatures.
_meta_extra = {
    'date_created': DATESTR_LONG,
    'author': 'Jared Lumpe',
}

meta = SignaturesMeta(
    name=f'{DBNAME}-queries',
    description=f'Signatures of {DBNAME} query sequences',
    extra=_meta_extra,
)

In [36]:
# Pack the per-query signatures into one SignatureArray, using the k-mer
# spec's coordinate dtype, for writing to HDF5 below.
sigarray = SignatureArray(query_sigs, dtype=kspec.coords_dtype)

In [37]:
# Write all query signatures to a single HDF5 file, passing query names
# (presumably used as signature ids) and the metadata built above.
with h5.File(tmpdir / 'query-signatures.h5', 'w') as f:
    HDF5Signatures.create(f, kspec, sigarray, queries_df['name'], meta)