# Human vs Pig Classification

The distribution of k-mers in a genome can act as a _signature_ for a given species. In this notebook, we exploit this fact by training a neural network to distinguish extracts of a human genome from extracts of a pig genome. We use `torchmers` to encode DNA sequences into their k-mer spectra and feed the counts to a [multilayer perceptron](https://en.wikipedia.org/wiki/Multilayer_perceptron).
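
To make the idea of a k-mer spectrum concrete, here is a minimal sketch that counts k-mers in plain Python. It is not how `torchmers` computes spectra. The function name `kmer_spectrum` and the base order A, C, G, T used for indexing are assumptions made for illustration only.

```python
def kmer_spectrum(sequence, k, alphabet="ACGT"):
    """Count every k-mer of `sequence` in a vector with len(alphabet) ** k entries."""
    base_to_digit = {base: digit for digit, base in enumerate(alphabet)}
    counts = [0] * (len(alphabet) ** k)

    for start in range(len(sequence) - k + 1):
        index = 0
        for base in sequence[start:start + k]:
            if base not in base_to_digit:
                # Skip windows that contain ambiguous bases such as N
                break
            # Read the k-mer as a number in base len(alphabet)
            index = index * len(alphabet) + base_to_digit[base]
        else:
            counts[index] += 1

    return counts


spectrum = kmer_spectrum("ACGTAC", k=2)
print(len(spectrum))   # 16 possible 2-mers
print(spectrum[1])     # 2, because "AC" (index 1) occurs twice
```

Two sequences with different base composition produce different count vectors, which is the signal the classifier in this notebook relies on.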

## Download Data

We will use chromosome 1 of the human and pig reference genomes as training and test data:

In [1]:
!curl https://ftp.ensembl.org/pub/release-110/fasta/homo_sapiens/dna/Homo_sapiens.GRCh38.dna.chromosome.1.fa.gz | gunzip > human.fasta
!curl https://ftp.ensembl.org/pub/release-110/fasta/sus_scrofa/dna/Sus_scrofa.Sscrofa11.1.dna.primary_assembly.1.fa.gz | gunzip > pig.fasta


## Dataset Preparation

The training dataset consists of non-overlapping patches of length `patch_length` taken from the DNA sequences. Extracts from the human genome are assigned the label `0`, while pig extracts are assigned the label `1`.

10,000 segments from each species will serve as validation data. We hold out a fixed number per species, rather than a fraction, to offset the different lengths of the two chromosomes.
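
The patching and hold-out scheme described above can be sketched without any dependencies. The helper names `patch_bounds` and `split_indices` are hypothetical. They only mirror the idea behind `DNADataset` and `random_split` used in the cells below, with toy sizes instead of whole chromosomes.

```python
import random


def patch_bounds(sequence_length, patch_length):
    """Return (start, end) offsets of all complete, non-overlapping patches."""
    n_patches = sequence_length // patch_length
    return [(i * patch_length, (i + 1) * patch_length) for i in range(n_patches)]


def split_indices(n_patches, n_test, seed=0):
    """Randomly reserve exactly `n_test` patch indices for validation."""
    indices = list(range(n_patches))
    random.Random(seed).shuffle(indices)
    return indices[n_test:], indices[:n_test]


bounds = patch_bounds(sequence_length=5000, patch_length=1024)
train_idx, test_idx = split_indices(len(bounds), n_test=1)
print(bounds[-1])                      # (3072, 4096): the trailing 904 bases are dropped
print(len(train_idx), len(test_idx))   # 3 1
```

Because the number of held-out patches is fixed, both species contribute equally to the validation set even though their chromosomes differ in length.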

In [2]:
import sys; sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, random_split
from Bio import SeqIO
from tqdm import tqdm

from torchmers.modules import KMerFrequencyEncoder, MLP
from torchmers.tokenizers import Tokenizer
from torchmers.utils import k_mer_from_index

In [3]:
class DNADataset(Dataset):
    def __init__(self, fasta_path, patch_length, tokenizer, label):
        self.sequences = ''.join(
            str(record.seq) for record in SeqIO.parse(fasta_path, 'fasta')
        )
        self.patch_length = patch_length
        self.tokenizer = tokenizer
        self.label = label

    def __len__(self):
        return len(self.sequences) // self.patch_length

    def __getitem__(self, idx):
        patch = self.sequences[idx * self.patch_length:(idx + 1) * self.patch_length]
        tokens = self.tokenizer.encode(patch)
        return tokens, self.label

In [4]:
tokenizer = Tokenizer.from_name('DNA')

patch_length = 1024
batch_size = 512
# Fall back to the CPU when no GPU is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

labels = {
    'human': 0,
    'pig': 1
}

def load_dataset(species, test_samples=10_000):
    dataset = DNADataset(
        f'{species}.fasta',
        patch_length,
        tokenizer,
        label=labels[species]
    )

    train_split, test_split = random_split(
        dataset,
        [len(dataset) - test_samples, test_samples]
    )

    return train_split, test_split


human_train, human_test = load_dataset('human')
pig_train, pig_test = load_dataset('pig')

train_set = human_train + pig_train
test_set = human_test + pig_test

train_batches = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Evaluation does not depend on sample order, so no shuffling is needed
test_batches = DataLoader(test_set, batch_size=batch_size, shuffle=False)


## Define Model

The classification model consists of two main components: a k-mer frequency encoder layer that efficiently encodes DNA sequences into their frequency spectra, and an MLP with a single layer. Since the MLP contains no hidden layers in this configuration (`num_layers=1`), it boils down to a linear model on the k-mer counts, fitted by gradient descent with the Adam optimizer.

Both the k-mer frequency encoder and MLP class are part of the `torchmers` package.
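
As a rough illustration of what such a linear model computes, the sketch below applies a bias-free weight vector to log-scaled counts. The function names are hypothetical, and the use of `log(1 + count)` is an assumption about what `log_counts=True` does. The actual transform in `torchmers` may differ.

```python
import math


def linear_logit(counts, weights):
    """Logit of a bias-free linear model applied to log-scaled k-mer counts."""
    # log1p(count) = log(1 + count) keeps absent k-mers at exactly zero;
    # whether `log_counts=True` uses this exact transform is an assumption.
    return sum(w * math.log1p(c) for w, c in zip(weights, counts))


def predict_label(counts, weights):
    """Return 1 (pig) for a positive logit and 0 (human) otherwise."""
    return int(linear_logit(counts, weights) > 0)


# Toy spectrum with three features and hand-picked weights
counts = [3, 0, 1]
weights = [-0.5, 2.0, 0.8]
print(predict_label(counts, weights))   # 0, because the logit is negative
```

Each k-mer contributes independently to the logit, which is what makes the learned weights directly interpretable later on.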

In [5]:
# Set the length of the k-mers here
k = 7

model = nn.Sequential(
    KMerFrequencyEncoder(k=k, log_counts=True),
    MLP(
        input_dim=4 ** k,
        hidden_dim=256,  # unused here: num_layers=1 creates no hidden layer
        output_dim=1,
        num_layers=1,
        bias=False
    ),
    nn.Flatten(0, 1)
).to(device)

optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

Printing the network definition reveals that it contains only a single linear layer that maps from the k-mer spectrum with $$4^k$$ entries (16,384 for `k = 7`) to a single unit, corresponding to the binary classification logit.

In [6]:
model

Sequential(
  (0): KMerFrequencyEncoder()
  (1): MLP(
    (net): Sequential(
      (0): Linear(in_features=16384, out_features=1, bias=False)
    )
  )
  (2): Flatten(start_dim=0, end_dim=1)
)
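
The `in_features=16384` in the printout is simply `4 ** 7`. The short sketch below recomputes it and compares it with the number of k-mer windows in one patch. It assumes that every window is counted with a stride of 1, and the helper names are made up for illustration.

```python
def spectrum_size(k, alphabet_size=4):
    """Number of distinct k-mers, i.e. input features of the linear layer."""
    return alphabet_size ** k


def windows_per_patch(patch_length, k):
    """Number of k-mer windows in one patch (upper bound on non-zero features)."""
    return patch_length - k + 1


k = 7
patch_length = 1024
print(spectrum_size(k))                    # 16384
print(windows_per_patch(patch_length, k))  # 1018
```

Under that assumption, a single patch can populate at most 1,018 of the 16,384 spectrum entries, so the inputs are sparse. Increasing `k` multiplies the number of features by four each time while the number of windows barely changes.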

The model will be trained for 10 epochs with the Adam optimizer. After each epoch, the loss and classification accuracy on the test dataset are computed.
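
For reference, the two metrics can be written out in plain Python. This is a sketch with hypothetical function names, not the PyTorch implementation. Unlike the loop below, which averages the per-batch losses, it averages over individual examples.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy for one logit and a 0/1 target."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def evaluate(logits, targets):
    """Return (mean loss, accuracy in percent) over a list of examples."""
    losses = [bce_with_logits(x, y) for x, y in zip(logits, targets)]
    # A logit above 0 corresponds to a predicted probability above 0.5
    correct = sum(int((x > 0) == bool(y)) for x, y in zip(logits, targets))
    return sum(losses) / len(losses), 100.0 * correct / len(targets)


loss, accuracy = evaluate(logits=[2.0, -1.0, 0.5, -3.0], targets=[1, 0, 0, 0])
print(f"loss: {loss:.4f}, acc: {accuracy:.2f}")
```

Thresholding the logit at 0 is equivalent to thresholding the sigmoid output at 0.5, which is why the loop never needs to apply the sigmoid explicitly.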

In [7]:
for epoch in range(10):
    model.train()

    with tqdm(train_batches) as pbar:
        for seqs, targets in pbar:  # `targets` avoids shadowing the global `labels` dict
            seqs = seqs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()

            logits = model(seqs)

            loss = F.binary_cross_entropy_with_logits(logits, targets.float())
            loss.backward()

            # A logit above 0 means a predicted probability above 0.5 for pig
            accuracy = ((logits > 0) == targets).float().mean().item() * 100

            optimizer.step()

            pbar.set_description(f'TRAIN | EPOCH {epoch} | loss: {loss.item():.4f}, acc: {accuracy:.2f}')
    
    model.eval()
    
    val_loss = 0
    val_corr = 0

    with torch.no_grad():
        for seqs, targets in test_batches:
            seqs = seqs.to(device)
            targets = targets.to(device)

            logits = model(seqs)

            val_loss += F.binary_cross_entropy_with_logits(logits, targets.float()).item()
            val_corr += ((logits > 0) == targets).float().sum().item()
        
        val_loss_avg = val_loss / len(test_batches)
        val_acc = val_corr / len(test_set) * 100
        
    print(f'VALID | EPOCH {epoch} | loss: {val_loss_avg:.4f}, acc: {val_acc:.2f}')


TRAIN | EPOCH 0 | loss: 0.2105, acc: 84.62: 100%|██████████| 960/960 [00:10<00:00, 89.26it/s]
VALID | EPOCH 0 | loss: 0.3538, acc: 81.04
TRAIN | EPOCH 1 | loss: 0.3906, acc: 69.23: 100%|██████████| 960/960 [00:10<00:00, 90.85it/s]
VALID | EPOCH 1 | loss: 0.3542, acc: 81.38
TRAIN | EPOCH 2 | loss: 0.4727, acc: 61.54: 100%|██████████| 960/960 [00:10<00:00, 90.73it/s]
VALID | EPOCH 2 | loss: 0.3613, acc: 81.34
TRAIN | EPOCH 3 | loss: 0.5518, acc: 69.23: 100%|██████████| 960/960 [00:10<00:00, 90.11it/s]
VALID | EPOCH 3 | loss: 0.3547, acc: 81.34
TRAIN | EPOCH 4 | loss: 0.3146, acc: 92.31: 100%|██████████| 960/960 [00:10<00:00, 90.34it/s]
VALID | EPOCH 4 | loss: 0.3526, acc: 81.22
TRAIN | EPOCH 5 | loss: 0.1497, acc: 92.31: 100%|██████████| 960/960 [00:10<00:00, 90.71it/s]
VALID | EPOCH 5 | loss: 0.3546, acc: 81.36
TRAIN | EPOCH 6 | loss: 0.3143, acc: 84.62: 100%|██████████| 960/960 [00:10<00:00, 90.87it/s]
VALID | EPOCH 6 | loss: 0.3535, acc: 81.30
TRAIN | EPOCH 7 | loss: 0.4549, acc: 69.23: 100%|██████████| 960/960 [00:10<00:00, 90.83it/s]
VALID | EPOCH 7 | loss: 0.3650, acc: 81.08
TRAIN | EPOCH 8 | loss: 0.3700, acc: 76.92: 100%|██████████| 960/960 [00:10<00:00, 90.82it/s]
VALID | EPOCH 8 | loss: 0.3662, acc: 80.94
TRAIN | EPOCH 9 | loss: 0.5026, acc: 76.92: 100%|██████████| 960/960 [00:10<00:00, 90.51it/s]
VALID | EPOCH 9 | loss: 0.3560, acc: 81.42

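As we can see, the model distinguishes human from pig genome extracts with a validation accuracy of around 81%, a level it already reaches after the first epoch. The training loss and accuracy shown in the progress bars describe only the last mini-batch of each epoch, which is why they fluctuate so strongly. The accuracy could likely be improved further by using larger k-mers, a deeper model, or more hyperparameter tuning in general.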

## Interpreting the Model

Since our single-layer neural network is essentially just a linear model, we can inspect its only weight matrix to see which k-mers matter most for the classification decision.

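Human has the label `0`, so the most negative weights mark the k-mers that push a prediction most strongly towards human. Conversely, the most positive weights mark the k-mers most indicative of pig. In the output below, the human weights are shown negated, and the pig list is printed in ascending order, so its strongest k-mers come last.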
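
The ranking idea can be sketched on a toy weight vector. The decoder `index_to_kmer` below is a hypothetical stand-in for `k_mer_from_index`. It assumes that spectrum indices are base-4 numbers with the digit order A, C, G, T, which may not match the actual `torchmers` convention.

```python
def index_to_kmer(index, k, alphabet="ACGT"):
    """Decode a spectrum index into its k-mer, reading the index in base 4."""
    bases = []
    for _ in range(k):
        index, digit = divmod(index, len(alphabet))
        bases.append(alphabet[digit])
    return "".join(reversed(bases))


def rank_kmers(weights, k, top=2):
    """Return the `top` most negative and most positive k-mers with their weights."""
    order = sorted(range(len(weights)), key=lambda i: weights[i])
    lowest = [(index_to_kmer(i, k), weights[i]) for i in order[:top]]
    highest = [(index_to_kmer(i, k), weights[i]) for i in reversed(order[-top:])]
    return lowest, highest


# Toy weights for all 16 possible 2-mers; only three are non-zero
weights = [0.0] * 16
weights[0] = -0.4    # AA
weights[5] = 0.9     # CC
weights[10] = 0.7    # GG
human_like, pig_like = rank_kmers(weights, k=2, top=2)
print(human_like[0])   # ('AA', -0.4)
print(pig_like)        # [('CC', 0.9), ('GG', 0.7)]
```

In this sketch both lists are ordered from strongest to weakest, which makes them easier to compare side by side.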

In [8]:
weights = model[1].net[0].weight[0]
weight_index_pairs = list(zip(*weights.sort()))

print('Top 10 most important k-mers for human:')

for weight, i in weight_index_pairs[:10]:
    print(f'{-weight.item():.3f} {k_mer_from_index(i.item(), k)}')

print()
print('Top 10 most important k-mers for pig:')

for weight, i in weight_index_pairs[-10:]:
    print(f'{weight.item():.3f} {k_mer_from_index(i.item(), k)}')

Top 10 most important k-mers for human:
0.351 AAAAAAA
0.334 GTAATCC
0.332 GGATTAC
0.298 GTCTCGC
0.290 TCGAACT
0.287 TCAAGCG
0.270 CGCGGTG
0.268 GCGATCC
0.257 ACAGGCG
0.249 GGAGTGC

Top 10 most important k-mers for pig:
0.300 TCCCCCC
0.302 GGTTCGA
0.326 TCGAACC
0.330 CTAGTCG
0.348 CGATCCC
0.352 GGGGGGA
0.370 GGATCGA
0.400 TCGATCC
0.805 GGGGGGG
0.818 CCCCCCC
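
One pattern in the lists above is that several k-mers are reverse complements of each other, for example `GTAATCC` and `GGATTAC` in the human list. This would be expected if a species-specific motif occurs on both strands of the chromosome, although that reading is an interpretation and not something the model output proves. The sketch below, with a hypothetical helper `reverse_complement`, checks the pig list for such pairs.

```python
def reverse_complement(kmer):
    """Return the sequence as read from the opposite DNA strand."""
    complement = {"A": "T", "C": "G", "G": "C", "T": "A"}
    return "".join(complement[base] for base in reversed(kmer))


# The ten pig k-mers printed above
pig_kmers = [
    "TCCCCCC", "GGTTCGA", "TCGAACC", "CTAGTCG", "CGATCCC",
    "GGGGGGA", "GGATCGA", "TCGATCC", "GGGGGGG", "CCCCCCC",
]

# Keep each pair once by requiring the k-mer to sort before its partner
pairs = [
    (kmer, reverse_complement(kmer))
    for kmer in pig_kmers
    if reverse_complement(kmer) in pig_kmers and kmer < reverse_complement(kmer)
]
print(pairs)
```

Summing the counts of a k-mer and its reverse complement before classification would halve the number of features. Whether that helps accuracy here is untested.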
