<a href="https://colab.research.google.com/github/gileshall/ML-Biology-Notebooks/blob/main/Taxonomy_with_Transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Cat or Dog ... DNA?

Kaggle first announced its [dogs-vs-cats](https://www.kaggle.com/c/dogs-vs-cats) competition in 2013, asking participants to develop a computer vision model that would classify images of cats versus dogs.  Since its debut nine years ago, thousands of solutions have become available, some of which achieve near-perfect accuracy.

The first build of the human genome was published in 2003. Since then, thousands of genomes have been published for different species from around the world, including dogs and cats.  If we can train a model to distinguish between the image of a dog and the image of a cat, can we also train a model to distinguish between dog DNA and cat DNA? Or between any number of species?

In [None]:
!pip install -q -U transformers biopython datasets pyfaidx tqdm colored wandb joblib


In [None]:
#@title Training Parameters

# Colab form cell: each `#@param` annotation binds the assignment on its line
# to a form widget, and `#@markdown` lines are rendered as help text.

#@markdown Comma separated list of organisms to classify
organisms = 'cat, dog' #@param {type:"string"}
# Split on commas, strip whitespace, drop duplicates and sort so that the
# organism order (and therefore anything derived from it) is deterministic.
organisms = sorted(set(map(str.strip, organisms.split(','))))

#@markdown How many training examples to sample from each genome?
training_samples_per_genome =  10000#@param {type:"integer"}

#@markdown How many evaluation examples to sample from each genome?
eval_samples_per_genome =  1000#@param {type:"integer"}

#@markdown Exclude low-complexity regions from the training set?
#@markdown - none: do not exclude any region
#@markdown - masked: exclude soft-masked regions (train exclusively on upper case bases)
#@markdown - unmasked: include ONLY soft-masked regions (train exclusively on lower case bases)
exclude_region = 'unmasked' #@param ['masked', 'unmasked', 'none']

#@markdown DNABERT kmer length?
klen = 6#@param [3, 4, 5, 6]

#@markdown Do you have a weights and biases account?
#@markdown - Yes: you will be prompted for your wandb API key before training starts.
#@markdown - No: wandb will automatically post your results anonymously
wandb_account_flag = False #@param {type:"boolean"}
# Value later passed as `anonymous=` to wandb.init: 'must' when the user has
# no account, 'never' when they do.
wandb_anon = 'must' if not wandb_account_flag else 'never'

import io
import re
import os
import gzip
import json
import random
from pprint import pprint
from datetime import datetime
from uuid import uuid4
from multiprocessing import cpu_count

import wandb
import datasets
import numpy as np
from colored import fg, bg, attr, set_tty_aware
import pandas as pd
import requests
from joblib import Memory
from tqdm import tqdm
from Bio import Entrez
from pyfaidx import Fasta
from transformers import BertForSequenceClassification, BertTokenizer
from transformers import TrainingArguments, Trainer
from transformers.utils.logging import set_verbosity_error, set_verbosity_info
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

# joblib cache rooted at ./cache; used below to memoize the Entrez queries.
memory = Memory('cache', verbose=False)

project_name = "taxonomy_with_transformers"
total_samples = (training_samples_per_genome + eval_samples_per_genome) * len(organisms)
# The run name doubles as the on-disk location of the tokenized dataset, so it
# must encode every setting that changes the dataset's contents.  The k-mer
# length is included because the saved dataset already holds token ids (and a
# window length) that are specific to one DNABERT vocabulary; without it a
# dataset built for one `klen` would be silently reloaded for another.
run_name = f'{"_".join(organisms)}_{total_samples // 1000}K_exclude_{exclude_region}_k{klen}'
path_output = f'{run_name}/output'
path_dataset = f'{run_name}/dataset'
path_genomes = 'genomes'

# Entrez requests carry a contact address; a random placeholder is generated
# per session.
Entrez.email = f'{uuid4()}@example.com'

# Ask colour-aware libraries to emit colour codes even without a TTY.
os.environ['FORCE_COLOR'] = '3'
GRN = fg('green')
RED = fg('red')
BLU = fg('blue')
RST = attr('reset')
# Number of worker processes used when tokenizing the dataset.
N_PROC = cpu_count()

def load_model(klen=6, id2label=None):
    """Load a pretrained DNABERT sequence classifier and its tokenizer.

    Args:
        klen: k-mer length of the checkpoint to load.  Checkpoints are
            addressed as ``armheb/DNA_bert_{klen}``; only 3 through 6 are
            accepted.
        id2label: mapping of integer class id to a human readable label.  Its
            size determines the number of output classes.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: if ``klen`` is outside 3..6 or ``id2label`` is empty/None.
    """
    # Validate before doing any work; explicit exceptions (rather than
    # ``assert``) keep the checks active under ``python -O``.
    if not 3 <= klen <= 6:
        raise ValueError(f'klen must be between 3 and 6 (inclusive), got {klen!r}')
    if not id2label:
        raise ValueError('id2label must map at least one class id to a label')

    set_verbosity_error()
    msg = f'{GRN}Loading DNABERT (klen={klen}, id2label={id2label}){RST}'
    print(msg)
    model_path = f"armheb/DNA_bert_{klen}"
    # Inverse mapping expected by the model config alongside id2label.
    label2id = {label: idx for (idx, label) in id2label.items()}
    model = BertForSequenceClassification.from_pretrained(
        model_path, num_labels=len(id2label),
        id2label=id2label, label2id=label2id
    )
    # k-mer tokens are upper case, so the tokenizer must not lower-case input.
    tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=False)
    return (model, tokenizer)

@memory.cache
def esearch(block_size=1000, limit=None, **kw):
    """Run an Entrez search and collect the matching ids page by page.

    Results are memoized through ``memory.cache``.

    Args:
        block_size: maximum number of ids requested per page.
        limit: optional cap on the number of ids returned; a falsy value
            means "no cap".
        **kw: forwarded verbatim to ``Entrez.esearch`` (e.g. ``db``, ``term``).

    Returns:
        The parsed result of the first request, with ``'Count'`` converted to
        ``int`` and ``'IdList'`` extended with the ids of all fetched pages.
    """
    # Probe request: one id is enough to learn the total hit count.
    handle = Entrez.esearch(**kw, retmax=1)
    try:
        res = Entrez.read(handle)
    finally:
        handle.close()
    total = res['Count'] = int(res['Count'])
    if limit:
        total = min(total, limit)
    block_size = min(block_size, total)
    while len(res['IdList']) < total:
        pos = len(res['IdList'])
        # Sanity check on the pagination: pages must not overlap.
        assert len(res['IdList']) == len(set(res['IdList']))
        handle = Entrez.esearch(**kw, retstart=pos, retmax=block_size)
        try:
            batch = Entrez.read(handle)['IdList']
        finally:
            handle.close()
        if not batch:
            # The server has nothing more to give even though 'Count'
            # promised more; stop instead of requesting the same page forever.
            break
        res['IdList'] += batch
    # The probe id plus a full page can exceed ``limit`` by one; trim so the
    # cap is honoured exactly.
    del res['IdList'][total:]
    return res

@memory.cache
def esummary(**kw):
    """Fetch and parse an Entrez document summary.

    All keyword arguments are forwarded to ``Entrez.esummary``; the parsed
    response is returned and memoized through ``memory.cache``.
    """
    return Entrez.read(Entrez.esummary(**kw))

def search_for_genome(query=None):
    """Pick the most recently submitted chromosome-level assembly for a query.

    The query is tried as an organism name first and as an accession second;
    the first search type that returns any hit is used.  Assemblies flagged as
    partial, not at 'Chromosome' status, or without a GenBank FTP path are
    skipped.

    Args:
        query: organism name or assembly accession.

    Returns:
        A dict describing the chosen assembly (keys: ``id``, ``query``,
        ``submission_date``, ``accession``, ``name``, ``status``,
        ``organism``, ``taxid``, ``partial``, ``build_key``, ``url``), or
        ``None`` when the search has no hits or no hit passes the filters.
    """
    idlist = None
    for query_type in ('Organism', 'Accession'):
        term = f'{query}[{query_type}]'
        res = esearch(db='assembly', term=term)
        if res['Count'] > 0:
            idlist = res['IdList']
            break
    if idlist is None:
        return None

    asm_list = []
    for asm_id in idlist:
        ret = esummary(db='assembly', id=asm_id)
        summary = ret['DocumentSummarySet']['DocumentSummary'][0]

        # A missing flag is treated as "partial", i.e. the assembly is skipped.
        partial = str(summary.get('PartialGenomeRepresentation', 'true')).lower() == 'true'
        status = summary.get('AssemblyStatus')
        if partial or status != 'Chromosome':
            continue
        ftp_url = summary.get('FtpPath_GenBank')
        if not ftp_url:
            # Nothing to download for this assembly.
            continue

        accession = summary.get('AssemblyAccession')
        name = summary.get('AssemblyName')
        build_key = f'{accession}_{name}'
        # NOTE(review): the file name is derived from accession + name and
        # assumed to match the GenBank FTP directory layout - confirm.
        url = '/'.join([ftp_url.replace('ftp://', 'https://'), f'{build_key}_genomic.fna.gz'])
        asm_list.append({
            'id': asm_id,
            'query': query,
            'submission_date': datetime.strptime(summary.get('SubmissionDate'), "%Y/%m/%d %H:%M"),
            'accession': accession,
            'name': name,
            'status': status,
            'organism': summary.get('Organism'),
            'taxid': summary.get('SpeciesTaxid'),
            'partial': partial,
            'build_key': build_key,
            'url': url,
        })

    if not asm_list:
        # Hits existed but none qualified; report "not found" rather than
        # failing with an IndexError.
        return None
    # Stable sort + last element: the newest submission wins, and among equal
    # dates the one listed last.
    asm_list = sorted(asm_list, key=lambda it: it['submission_date'])
    return asm_list[-1]

def find_all_genomes(genomes, path_genomes=None):
    """Resolve every requested organism to an assembly and assign label ids.

    Args:
        genomes: iterable of organism names / accessions; the position of each
            entry becomes its integer label id.
        path_genomes: directory for genome files; created if missing.

    Returns:
        Dict keyed by the requested name, each value holding ``info`` (the
        assembly record), ``label_id`` and ``label_name`` (the organism name
        reported for the assembly).

    Raises:
        KeyError: if no assembly is found for one of the names.
    """
    os.makedirs(path_genomes, exist_ok=True)
    print(f'{GRN}Searching for genomes:{RST}')

    found = {}
    for (position, organism_query) in enumerate(genomes):
        assembly = search_for_genome(organism_query)
        if assembly is None:
            raise KeyError(f"{RED}Genome not found for '{organism_query}'{RST}")
        print('  - {query} ({organism}), Assembly accession {accession}'.format(**assembly))
        entry = {}
        entry['info'] = assembly
        entry['label_id'] = position
        entry['label_name'] = str(assembly['organism'])
        found[organism_query] = entry
    return found

def load_genome(info=None, path_genomes=None):
    """Open a genome FASTA, downloading and decompressing it on first use.

    Args:
        info: assembly record; ``build_key`` names the local file, ``url``
            points at the gzip-compressed FASTA, and ``query``/``organism``
            are used in progress messages.
        path_genomes: directory holding the decompressed genomes; created if
            missing.

    Returns:
        A ``pyfaidx.Fasta`` opened on the local, decompressed file.

    Raises:
        requests.HTTPError: if the download responds with an error status.
    """
    import shutil  # only needed on the download path

    os.makedirs(path_genomes, exist_ok=True)
    genome_fn = '{build_key}_genome.fa'.format(**info)
    genome_fn = os.path.join(path_genomes, genome_fn)
    if not os.path.exists(genome_fn):
        msg = "  - Downloading '{query}' genome ({organism})".format(**info)
        print(msg)
        url = info['url']
        # Write to a temporary name and rename at the end, so an interrupted
        # download can never be mistaken for a complete genome on the next run.
        tmp_fn = genome_fn + '.part'
        try:
            with requests.get(url, stream=True) as resp:
                # Fail loudly on 4xx/5xx instead of trying to gunzip an error page.
                resp.raise_for_status()
                # Undo any transport-level encoding, as ``resp.content`` would.
                resp.raw.decode_content = True
                # Stream-decompress in chunks: genomes can be gigabytes, so the
                # payload is never held in memory as a whole.
                with gzip.GzipFile(fileobj=resp.raw) as gz, open(tmp_fn, 'wb') as fh:
                    shutil.copyfileobj(gz, fh)
            os.replace(tmp_fn, genome_fn)
        finally:
            # After a successful rename the temp file no longer exists; on any
            # failure this removes the partial download.
            if os.path.exists(tmp_fn):
                os.remove(tmp_fn)
    msg = "  - Loading '{query}' genome ({organism})".format(**info)
    print(msg)
    return Fasta(genome_fn)

def sample_genome(genome, samples_per_genome=None, sample_len=None, exclude_region=False, sequence_prefix=('CM',)):
    """Yield random fixed-length windows drawn from a genome.

    Sequences are chosen with probability proportional to their length, a
    start position is drawn uniformly, and the window is kept only if it
    passes the base filter selected by ``exclude_region``.

    Args:
        genome: iterable of records exposing ``.name`` and ``len()``, and
            indexable by record name to obtain the sequence (converted with
            ``str``).
        samples_per_genome: number of windows to yield.
        sample_len: window length in bases.
        exclude_region: ``'masked'`` keeps windows made of upper-case bases,
            ``'unmasked'`` keeps windows made of lower-case bases, ``'none'``
            ignores case.  ``False``/``None`` are treated as ``'none'``.
        sequence_prefix: only records whose name starts with one of these
            prefixes are sampled.

    Yields:
        ``{'seq': <window>, 'name': '<record>:<start>-<end>'}``.

    Raises:
        ValueError: for an unknown ``exclude_region``, when no bases are
            available after prefix filtering, or when no sequence is long
            enough to sample from.

    Note:
        Windows are drawn by rejection sampling with no retry cap; if no
        window can satisfy the filter the generator does not terminate.
    """
    # The declared default (False) and None both mean "exclude nothing".
    if exclude_region is False or exclude_region is None:
        exclude_region = 'none'

    # NB: set math stipulates equality, not subset.  In other words, no matter
    # the exclusion, the example must contain ALL bases
    if exclude_region == 'masked':
        msg_exclude = 'excluding masked regions'
        ok_nucs = set('AGTC')
        fold_case = False
    elif exclude_region == 'unmasked':
        msg_exclude = 'excluding unmasked regions'
        ok_nucs = set('agtc')
        fold_case = False
    elif exclude_region == 'none':
        msg_exclude = ''
        ok_nucs = set('AGTC')
        fold_case = True
    else:
        raise ValueError(exclude_region)

    def f_func(window):
        # Operates on its argument (not on a variable captured from the
        # enclosing scope), so it stays correct however the loop is written.
        bases = window.upper() if fold_case else window
        return set(bases) == ok_nucs

    prefixes = tuple(sequence_prefix)
    records = [(rec.name, len(rec)) for rec in genome if rec.name.startswith(prefixes)]
    seq_total_len = sum(slen for (_, slen) in records)
    # Checked before any arithmetic so the caller sees this error rather than
    # an unpacking failure or a ZeroDivisionError.
    if seq_total_len == 0:
        raise ValueError("Empty genome")
    (seq_names, seq_lens) = zip(*records)

    sample_total_len = samples_per_genome * sample_len
    genome_sample_percent = sample_total_len / seq_total_len * 100
    msg = f'  - Sampling {sample_total_len / 1e6:.02f} megabases from {seq_total_len / 1e6:.02f} megabase genome ({genome_sample_percent:.02f}%) {msg_exclude}'
    print(msg)

    # The sampling loop skips sequences shorter than two windows; if every
    # sequence is that short it would spin forever, so fail up front.
    if samples_per_genome and all(slen < 2 * sample_len for slen in seq_lens):
        raise ValueError(
            f'No sequence is at least {2 * sample_len} bases long; '
            f'cannot draw samples of length {sample_len}'
        )

    seq_weights = [slen / seq_total_len for slen in seq_lens]
    # Materialize the sequences as plain strings once; slicing a str is much
    # cheaper than going back to the genome object for every draw.
    seq_map = {seq_name: str(genome[seq_name]) for seq_name in seq_names}
    for _ in tqdm(range(samples_per_genome), total=samples_per_genome):
        while True:
            seq_name = random.choices(seq_names, seq_weights)[0]
            chrom = seq_map[seq_name]
            if sample_len > (len(chrom) - sample_len):
                continue
            start = random.randint(0, len(chrom) - sample_len)
            end = start + sample_len
            window = chrom[start:end]
            if f_func(window):
                name = f'{seq_name}:{start}-{end}'
                break
        yield {'seq': str(window), 'name': name}

def kmerize(seq=None, klen=6):
    """Return every overlapping window of ``klen`` items of ``seq``, in order.

    A sequence shorter than ``klen`` produces an empty list.
    """
    windows = []
    last_start = len(seq) - klen
    for start in range(last_start + 1):
        windows.append(seq[start:start + klen])
    return windows

def tokenize(it, klen=None, kmerize=None, tokenizer=None):
    """Convert one dataset row into fixed-length BERT model inputs.

    Args:
        it: row with a ``'seqs'`` string; it is upper-cased before k-merizing.
        klen: k-mer length passed to ``kmerize``.
        kmerize: callable ``kmerize(seq=..., klen=...)`` returning a list of
            k-mer strings.
        tokenizer: object exposing ``vocab`` (token -> id mapping) and
            ``model_max_length``.

    Returns:
        Dict of 1-D integer arrays ``input_ids``, ``attention_mask`` and
        ``token_type_ids``, each padded to ``model_max_length``.

    Raises:
        KeyError: if a k-mer is not in the vocabulary.
    """
    seq = it['seqs'].upper()
    vocab = tokenizer.vocab
    max_length = tokenizer.model_max_length
    kmers = kmerize(seq=seq, klen=klen)
    inp = ['[CLS]'] + kmers + ['[SEP]']
    toks = [vocab[tok] for tok in inp]

    n_real = len(toks)
    # Inputs already at (or beyond) max_length receive no padding.
    n_pad = max(max_length - n_real, 0)
    # Padding id is 0 - presumably the vocabulary's [PAD] entry; verify
    # against the tokenizer in use.
    ret = {
        "input_ids": np.array(toks + [0] * n_pad, dtype=int),
        # One flag per position: 1 for real tokens, 0 for padding.  Built from
        # the lengths because comparing a Python list to 0 gives a single
        # bool, not an element-wise mask.
        "attention_mask": np.array([1] * n_real + [0] * n_pad, dtype=int),
        "token_type_ids": np.zeros(n_real + n_pad, dtype=int),
    }
    return ret

def build_dataset(
        all_genomes=None,
        training_samples_per_genome=None,
        eval_samples_per_genome=None,
        sample_len=None,
        klen=None,
        tokenizer=None,
        kmerize=None,
        exclude_region=False,
        path_genomes=None
    ):
    """Sample, label, split and tokenize examples from every genome.

    Args:
        all_genomes: dict whose values hold ``label_id``, ``label_name`` and
            ``info`` (the assembly record) for each organism.
        training_samples_per_genome: examples per genome in the train split.
        eval_samples_per_genome: examples per genome in the test split.
        sample_len: window length in bases; when falsy it is derived from
            ``klen`` and the tokenizer's ``model_max_length``.
        klen: k-mer length used for tokenization.
        tokenizer: tokenizer handed to ``tokenize``.
        kmerize: k-merizing callable handed to ``tokenize``.
        exclude_region: region filter forwarded to ``sample_genome``.
        path_genomes: directory where genomes are stored/downloaded.

    Returns:
        A ``datasets.DatasetDict`` with shuffled, tokenized ``'train'`` and
        ``'test'`` splits.
    """
    total_samples_per_genome = training_samples_per_genome + eval_samples_per_genome
    # Default window: with a sliding-window k-merizer this yields
    # model_max_length - 2 k-mers, leaving room for [CLS] and [SEP].
    sample_len = sample_len or (klen + tokenizer.model_max_length - 3)

    ds_list = []
    print(f'\n{GRN}Building training dataset:{RST}')
    for genome_entry in all_genomes.values():
        label_id = genome_entry['label_id']
        genome = load_genome(genome_entry['info'], path_genomes=path_genomes)
        samples = list(sample_genome(
            genome=genome,
            sample_len=sample_len,
            samples_per_genome=total_samples_per_genome,
            exclude_region=exclude_region
        ))
        names = [smp['name'] for smp in samples]
        seqs = [smp['seq'] for smp in samples]
        labels = [label_id] * len(seqs)

        dataset = datasets.Dataset.from_dict(dict(seqs=seqs, names=names, labels=labels))
        # Split per genome so every class contributes exactly the requested
        # number of training examples; the remainder becomes the test split.
        dataset = dataset.train_test_split(train_size=training_samples_per_genome)
        ds_list.append(dataset)
        print()

    print(f'{GRN}Tokenizing, splitting and shuffling datasets{RST}')
    ds_dict = {}
    for key in ('train', 'test'):
        print(f'  - split={key}')
        _ds_list = [ds[key] for ds in ds_list]
        dataset = datasets.concatenate_datasets(_ds_list).shuffle()
        dataset = dataset.map(
            tokenize,
            fn_kwargs={'kmerize': kmerize, 'klen': klen, 'tokenizer': tokenizer},
            # A fingerprint identifies a dataset's contents to the caching
            # layer, so the two splits must not share one.
            new_fingerprint=f'{klen}_tokenizer_{key}',
            num_proc=N_PROC,
        )
        ds_dict[key] = dataset
    dataset = datasets.DatasetDict(ds_dict)
    return dataset

def build_experiment(
        organisms=None,
        klen=None,
        training_samples_per_genome=None,
        eval_samples_per_genome=None,
        path_dataset=None,
        **kw,
    ):
    """Assemble the model, tokenizer and dataset for one experiment.

    The dataset is loaded from ``path_dataset`` when that directory exists;
    otherwise it is built from the genomes and saved there.

    Args:
        organisms: ordered organism names; position defines the class id.
        klen: DNABERT k-mer length.
        training_samples_per_genome: train examples per genome.
        eval_samples_per_genome: evaluation examples per genome.
        path_dataset: directory used to cache the built dataset.
        **kw: extra options forwarded to ``build_dataset`` (for example
            ``exclude_region``, ``kmerize``, ``path_genomes``).

    Returns:
        ``(model, tokenizer, dataset)``.
    """
    # Use the caller's genome directory for the lookup step as well as for the
    # dataset build; only fall back to the module-level default when the
    # caller did not supply one.
    genomes_dir = kw.get('path_genomes')
    if genomes_dir is None:
        genomes_dir = path_genomes
    kw['path_genomes'] = genomes_dir

    all_genomes = find_all_genomes(organisms, path_genomes=genomes_dir)
    print()
    id2label = {idx: all_genomes[org]['label_name'] for (idx, org) in enumerate(organisms)}
    (model, tokenizer) = load_model(klen=klen, id2label=id2label)
    if os.path.isdir(path_dataset):
        msg = f'\n{GRN}Loading dataset from {path_dataset}{RST}'
        print(msg)
        dataset = datasets.load_from_disk(path_dataset)
    else:
        dataset = build_dataset(
            all_genomes=all_genomes,
            training_samples_per_genome=training_samples_per_genome,
            eval_samples_per_genome=eval_samples_per_genome,
            klen=klen,
            tokenizer=tokenizer,
            **kw
        )
        # Persist so later runs with the same settings skip the download,
        # sampling and tokenization steps.
        dataset.save_to_disk(path_dataset)

    print(f'\n{BLU}{dataset}{RST}\n')
    return (model, tokenizer, dataset)


# Train the classification task

In [None]:
import inspect

# Training hyper-parameters.
epochs = 4
warmup_steps = 500
train_batch_size = 16
# Evaluation needs no gradients, so a larger batch is used.
eval_batch_size = 4 * train_batch_size

(model, tokenizer, dataset) = build_experiment(
    organisms=organisms,
    klen=klen,
    training_samples_per_genome=training_samples_per_genome,
    eval_samples_per_genome=eval_samples_per_genome,
    exclude_region=exclude_region,
    kmerize=kmerize,
    path_dataset=path_dataset,
    path_genomes=path_genomes
)
# Label names ordered by class id, for reports and plots.
label_names = [
    model.config.id2label[idx] for idx in range(len(model.config.id2label))
]

# Newer transformers releases renamed `evaluation_strategy` to
# `eval_strategy`.  The notebook installs whatever version is current, so the
# keyword accepted by the installed TrainingArguments is chosen at runtime.
_training_arg_names = inspect.signature(TrainingArguments.__init__).parameters
_eval_strategy_key = (
    'eval_strategy' if 'eval_strategy' in _training_arg_names else 'evaluation_strategy'
)

training_args = TrainingArguments(
    output_dir=path_output,
    save_strategy='epoch',
    logging_strategy='steps',
    logging_first_step=True,
    logging_steps=100,
    report_to='wandb',
    warmup_steps=warmup_steps,
    num_train_epochs=epochs,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=eval_batch_size,
    disable_tqdm=False,
    **{_eval_strategy_key: 'epoch'},
)

def compute_confusion_matrix(truth=None, preds=None, label_names=None):
    """Plot the confusion matrix, save it as a PNG and log it to wandb.

    Args:
        truth: true class ids.
        preds: predicted class ids.
        label_names: display names; assumed to be ordered so that index ``i``
            names class id ``i``.

    Returns:
        None; the image is logged under ``eval/confusion_matrix``.
    """
    imgfn = "confusion_matrix.png"
    # Pin the label set so the matrix always has one row/column per class,
    # even when a class is absent from both truth and predictions; otherwise
    # its shape would not match ``display_labels``.
    labels = list(range(len(label_names))) if label_names is not None else None
    matrix = confusion_matrix(truth, preds, labels=labels)
    display = ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=label_names)
    display = display.plot()
    # NOTE(review): the figure is not closed after saving; figures may
    # accumulate if this is called many times - confirm.
    display.figure_.savefig(imgfn, bbox_inches='tight', pad_inches=0, dpi=120)
    log = {"eval/confusion_matrix": wandb.Image(imgfn)}
    wandb.log(log)

def compute_classification_report(truth=None, preds=None, label_names=None):
    """Flatten sklearn's classification report into a single-level dict.

    Nested entries become ``'<section>_<metric>'`` keys; scalar entries keep
    their key.  Every key containing ``'weight'`` is dropped.

    Args:
        truth: true class ids.
        preds: predicted class ids.
        label_names: class names; assumed to be ordered so that index ``i``
            names class id ``i``.

    Returns:
        Dict mapping flattened metric names to values.
    """
    # Pin the label set so the report has one entry per name even when a
    # class is absent from this evaluation; without it the number of observed
    # classes can disagree with ``target_names``.
    labels = list(range(len(label_names))) if label_names is not None else None
    rpt = classification_report(
        truth, preds, labels=labels, target_names=label_names,
        zero_division=0, output_dict=True,
    )
    res = {}
    for (top_cat, item) in rpt.items():
        if isinstance(item, dict):
            for (bot_cat, value) in item.items():
                res[f'{top_cat}_{bot_cat}'] = value
        else:
            res[top_cat] = item
    return {key: value for (key, value) in res.items() if 'weight' not in key}

def compute_metrics(pred):
    """Metric callback for the Trainer.

    Args:
        pred: evaluation output exposing ``label_ids`` (true class ids) and
            ``predictions`` (per-class scores; argmax over the last axis
            gives the predicted class).

    Returns:
        Dict merging the dict results of every metric helper.  Helpers that
        return nothing (they only log) contribute no entries.
    """
    truth = pred.label_ids
    preds = pred.predictions.argmax(-1)
    log = dict()
    calls = [compute_confusion_matrix, compute_classification_report]
    for call in calls:
        ret = call(truth, preds, label_names=label_names)
        if isinstance(ret, dict):
            log.update(ret)
    # Return the accumulated metrics, not just the last helper's return
    # value, so the result does not depend on helper order.
    return log

# Wire the model, training arguments, both dataset splits and the metric
# callback into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    compute_metrics=compute_metrics,
)

# Everything below runs inside a single wandb run; `wandb_anon` controls
# whether the run is posted anonymously.
with wandb.init(project=project_name, name=run_name, anonymous=wandb_anon):
    # Evaluate once before any training step - presumably to record a
    # baseline for the freshly initialized classifier.
    trainer.evaluate()
    training_output = trainer.train()