# Using PyTorch to train a neural network classifier on a dataset of enhancers in Drosophila

In [None]:
# Install the required packages (inside a notebook, use %pip so they go into the kernel's environment)
%pip install genomic-benchmarks
%pip uninstall -y torch torchvision torchaudio
%pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu

In [58]:
# We will download the Drosophila enhancers dataset from
# https://github.com/ML-Bioinfo-CEITEC/genomic_benchmarks/tree/main, introduced in
# Grešová, Katarína, et al. "Genomic benchmarks: a collection of datasets for
# genomic sequence classification." BMC Genomic Data 24.1 (2023): 25.

# Import the required libraries
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torchtext.data.utils import get_tokenizer
import torch.nn.functional as F

from genomic_benchmarks.dataset_getters.pytorch_datasets import DrosophilaEnhancersStark
from genomic_benchmarks.models.torch import CNN
from genomic_benchmarks.dataset_getters.utils import coll_factory, LetterTokenizer, build_vocab
from genomic_benchmarks.data_check import info

In [74]:
# Create a PyTorch Dataset for the training split
train_dset = DrosophilaEnhancersStark(split='train')
# Print a summary of the dataset (version 0)
info('drosophila_enhancers_stark', 0)

Dataset `drosophila_enhancers_stark` has 2 classes: negative, positive.

The length of genomic intervals ranges from 236 to 3237, with average 2118.1238067688746 and median 2142.0.

Totally 6914 sequences have been found, 5184 for training and 1730 for testing.


|          | train | test |
|----------|-------|------|
| negative | 2592  | 865  |
| positive | 2592  | 865  |


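Next, we create a character-level tokenizer and build a vocabulary that maps each nucleotide (plus a few special tokens) to an integer ID, so that the sequences can be fed to the neural network as numbers. Padding, which makes all sequences the same length, is added later when the batches are assembled.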

In [14]:
tokenizer = get_tokenizer(LetterTokenizer())
vocabulary = build_vocab(train_dset, tokenizer, use_padding=False)

print("vocab len:", len(vocabulary))
print(vocabulary.get_stoi())

vocab len: 7
{'<eos>': 6, 'G': 4, 'A': 2, 'C': 3, 'T': 5, '<bos>': 1, '<unk>': 0}
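To make the mapping concrete, here is a minimal, dependency-free sketch of character-level encoding. It hard-codes the `stoi` dictionary printed above and is not the genomic-benchmarks implementation; the `encode` helper is hypothetical, and it does not add the `<bos>`/`<eos>` markers that appear in the vocabulary.

```python
# Copy of the vocabulary printed above (token -> integer ID), hard-coded for illustration.
STOI = {'<unk>': 0, '<bos>': 1, 'A': 2, 'C': 3, 'G': 4, 'T': 5, '<eos>': 6}


def encode(sequence, stoi=STOI):
    """Hypothetical helper: split a DNA string into letters and map each to its ID.

    Letters missing from the vocabulary (for example 'N') fall back to '<unk>'.
    """
    unk_id = stoi['<unk>']
    return [stoi.get(letter, unk_id) for letter in sequence]


print(encode("ACGTN"))  # [2, 3, 4, 5, 0]
```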


## Dataloader and batch preparation

In [15]:
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
print('Using {} device'.format(device))

Using mps device


In [75]:
# coll_factory returns a collate_fn that tokenizes each sample, converts it to
# integer IDs and pads every sequence to the maximum length (3237)
collate = coll_factory(vocabulary, tokenizer, device, pad_to_length=3237)
# Wrap the dataset in a DataLoader that shuffles the samples and groups them into batches
train_loader = DataLoader(train_dset,
                          batch_size=32,
                          shuffle=True,
                          collate_fn=collate)
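The padding step can be pictured with a small sketch in plain Python. It assumes a padding ID of 0 purely for illustration (the vocabulary above was built with `use_padding=False`, and the ID that `coll_factory` actually pads with is not shown here), and the `pad_to_length`/`collate_batch` helpers are hypothetical. A real `collate_fn` would additionally turn the rows into a `torch.long` tensor on the chosen device.

```python
PAD_ID = 0  # assumption for this sketch, not necessarily the value coll_factory uses


def pad_to_length(ids, length, pad_id=PAD_ID):
    """Right-pad a list of token IDs with pad_id until it has exactly `length` items."""
    if len(ids) > length:
        raise ValueError(f"sequence of length {len(ids)} exceeds target {length}")
    return ids + [pad_id] * (length - len(ids))


def collate_batch(batch, length):
    """Turn [(token_ids, label), ...] into (padded rows, labels), like a collate_fn."""
    rows = [pad_to_length(ids, length) for ids, _ in batch]
    labels = [label for _, label in batch]
    return rows, labels


rows, labels = collate_batch([([2, 3, 4], 1), ([5, 5], 0)], length=5)
print(rows)    # [[2, 3, 4, 0, 0], [5, 5, 0, 0, 0]]
print(labels)  # [1, 0]
```

Padding to 3237, the longest interval reported by `info`, guarantees that every sequence fits, so every batch has the same shape.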

## Model

We now initialize the model. Because the collate function pads every sequence to 3237 characters, the input length is 3237, and the dataset has 2 classes.

In [79]:
model = CNN(
    number_of_classes=2,
    vocab_size=len(vocabulary),
    embedding_dim=100,
    input_len=3237,
    device=device
).to(device)
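Why does the model need `input_len`? In a CNN of this kind, the length of the feature map that is flattened into the first fully connected layer depends on the input length. The sketch below applies the output-length formula from the PyTorch `Conv1d`/`MaxPool1d` documentation to a hypothetical stack of three `Conv1d(kernel_size=8)` + `MaxPool1d(2)` blocks; these layer sizes are assumptions for illustration and not necessarily the architecture of the genomic-benchmarks `CNN`.

```python
def out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of Conv1d / MaxPool1d, per the formula in the PyTorch docs."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def feature_length(input_len, blocks=3, conv_kernel=8, pool_kernel=2):
    """Length after `blocks` hypothetical Conv1d + MaxPool1d blocks."""
    length = input_len
    for _ in range(blocks):
        length = out_len(length, conv_kernel)                      # Conv1d, stride 1
        length = out_len(length, pool_kernel, stride=pool_kernel)  # MaxPool1d, stride = kernel
    return length


print(feature_length(3237))  # 398
print(feature_length(3000))  # 368
```

If the padded length changed, the flattened size would change with it and the first linear layer would no longer match, which is presumably why `input_len` has to agree with `pad_to_length`.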

## Training

In [77]:
model.fit(train_loader, epochs=5)

Epoch 0
Train metrics: 
 Accuracy: 49.9%, Avg loss: 0.695347 

Epoch 1
Train metrics: 
 Accuracy: 50.1%, Avg loss: 0.695992 

Epoch 2
Train metrics: 
 Accuracy: 50.0%, Avg loss: 0.693070 

Epoch 3
Train metrics: 
 Accuracy: 50.0%, Avg loss: 0.693147 

Epoch 4
Train metrics: 
 Accuracy: 50.0%, Avg loss: 0.693147 
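The logged loss is informative on its own: 0.693147 matches ln 2 ≈ 0.693147, the binary cross-entropy of a model that predicts probability 0.5 for every sequence. Because the training split is balanced (2592 negatives, 2592 positives), 50% accuracy is also what constant guessing achieves. Together these suggest the network has not learned anything useful within these five epochs. The small check below, using a hypothetical `bce` helper written from the standard definition, reproduces the number.

```python
import math


def bce(p, y):
    """Binary cross-entropy for one example: p = predicted P(class 1), y = true label."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


# A model that always outputs 0.5, evaluated on a balanced set of labels
labels = [0, 1] * 2592
avg_loss = sum(bce(0.5, y) for y in labels) / len(labels)
print(f"{avg_loss:.6f}")     # 0.693147
print(f"{math.log(2):.6f}")  # 0.693147
```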



## Testing

In [82]:
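test_dset = DrosophilaEnhancersStark(split='test', version=0)
test_loader = DataLoader(test_dset, batch_size=32, shuffle=False, collate_fn=collate)

# Evaluation mode (affects layers such as dropout and batch norm), no gradient tracking
model.eval()
predictions = []
with torch.no_grad():
    for x, y in test_loader:
        # Round each single-column output to a 0/1 label
        output = torch.round(model(x))
        predictions.extend(output[:, 0].tolist())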

We will now compute the F1 score of the predictions on the test set.

In [85]:
from sklearn.metrics import f1_score
from genomic_benchmarks.data_check.info import labels_in_order

labels = labels_in_order(dset_name='drosophila_enhancers_stark')
f1_score(labels, predictions)

0.5951492537313433

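The model reaches an F1 score of about 0.595 on the test set. F1 combines precision and recall for the positive class; it is not the fraction of correct predictions, so it should not be read as "59% correct". To report how often the model is right, use `sklearn.metrics.accuracy_score(labels, predictions)` instead.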
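To see why an F1 score should not be read as a percentage of correct predictions, the sketch below computes F1 (the harmonic mean of precision and recall for the positive class) and accuracy from scratch; the helper names are hypothetical. On a balanced test set like this one (865 negatives, 865 positives), a classifier that simply labels every sequence as positive already reaches F1 ≈ 0.667 while being correct only 50% of the time.

```python
def f1_from_predictions(labels, predictions):
    """F1 for the positive class (label 1), computed from true/false positive counts."""
    tp = sum(1 for y, p in zip(labels, predictions) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(labels, predictions) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, predictions) if y == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def accuracy(labels, predictions):
    """Fraction of predictions that match the true labels."""
    return sum(y == p for y, p in zip(labels, predictions)) / len(labels)


# Balanced test split as reported above: 865 negatives and 865 positives
labels = [0] * 865 + [1] * 865
always_positive = [1] * len(labels)
print(round(f1_from_predictions(labels, always_positive), 4))  # 0.6667
print(accuracy(labels, always_positive))                       # 0.5
```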