# PYTORCH CNN Classifier

To run this notebook on an another benchmark, use

```
papermill utils/torch_cnn_character.ipynb torch_cnn_character_experiments/[DATASET NAME].ipynb -p DATASET [DATASET NAME]
```

In [3]:
DATASET = 'demo_mouse_enhancers'
VERSION = 0
BATCH_SIZE = 32
EPOCHS = 5

In [4]:
print(DATASET, VERSION, BATCH_SIZE, EPOCHS)

demo_mouse_enhancers 0 32 5


## Config

In [5]:
import torch
from torch.utils.data import DataLoader

from genomic_benchmarks.dataset_getters.pytorch_datasets import get_dataset
from glp.models import CNN
from glp.tokenizers.tokenizers import CharacterTokenizer
from glp.tokenizers.utils import build_vocab, coll_factory, check_config, check_seq_lengths

In [6]:
config = {
    "dataset": DATASET,
    "dataset_version": VERSION,
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    "use_padding": True,
    "force_download": False,
    "run_on_gpu": True,
    "number_of_classes": 2,
    "embedding_dim": 100,
}
check_config(config)

## Choose the dataset

In [7]:
train_dset = get_dataset(config["dataset"], 'train')

## Tokenizer and vocab

In [8]:
tokenizer = CharacterTokenizer()
tokenizer.train(train_dset)
vocabulary = build_vocab(train_dset, tokenizer, use_padding=config["use_padding"])

print("vocab len:" ,vocabulary.__len__())
print(vocabulary.get_stoi())

vocab len: 9
{'N': 6, 'A': 2, 'C': 4, '<eos>': 7, 'G': 3, '<pad>': 8, 'T': 5, '<bos>': 1, '<unk>': 0}


## Dataloader and batch preparation

In [9]:
# Run on GPU or CPU
device = 'cuda' if config["run_on_gpu"] and torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

max_tok_len = check_seq_lengths(dataset=train_dset, tokenizer=tokenizer)

# Data Loader
collate = coll_factory(vocabulary, tokenizer, device, pad_to_length = max_tok_len)

train_loader = DataLoader(train_dset, batch_size=config["batch_size"], shuffle=True, collate_fn=collate)

Using cpu device
max_tok_len  4709
not all sequences are of the same length


In [10]:
tokenizer(train_dset[1][0])

['<bos>',
 'A',
 'G',
 'C',
 'C',
 'C',
 'T',
 'C',
 'A',
 'A',
 'G',
 'G',
 'A',
 'G',
 'A',
 'C',
 'A',
 'C',
 'A',
 'G',
 'C',
 'A',
 'T',
 'A',
 'C',
 'T',
 'G',
 'T',
 'A',
 'C',
 'T',
 'T',
 'T',
 'C',
 'C',
 'A',
 'C',
 'A',
 'T',
 'A',
 'A',
 'T',
 'T',
 'G',
 'C',
 'T',
 'T',
 'T',
 'T',
 'C',
 'A',
 'G',
 'G',
 'T',
 'C',
 'T',
 'G',
 'C',
 'C',
 'T',
 'C',
 'C',
 'C',
 'A',
 'C',
 'A',
 'C',
 'A',
 'A',
 'T',
 'C',
 'C',
 'T',
 'T',
 'C',
 'A',
 'G',
 'T',
 'G',
 'A',
 'A',
 'T',
 'A',
 'C',
 'C',
 'T',
 'A',
 'A',
 'G',
 'T',
 'C',
 'A',
 'G',
 'C',
 'A',
 'A',
 'A',
 'A',
 'C',
 'A',
 'A',
 'C',
 'A',
 'G',
 'G',
 'A',
 'A',
 'T',
 'T',
 'T',
 'T',
 'T',
 'A',
 'T',
 'A',
 'T',
 'T',
 'G',
 'A',
 'T',
 'C',
 'C',
 'T',
 'G',
 'A',
 'A',
 'A',
 'T',
 'A',
 'C',
 'T',
 'G',
 'G',
 'C',
 'A',
 'A',
 'G',
 'T',
 'C',
 'T',
 'G',
 'T',
 'A',
 'A',
 'G',
 'T',
 'T',
 'A',
 'T',
 'G',
 'G',
 'G',
 'A',
 'C',
 'T',
 'C',
 'A',
 'G',
 'C',
 'T',
 'G',
 'G',
 'G',
 'G',
 'C',
 'A',


## Model

In [11]:
model = CNN(
    number_of_classes=config["number_of_classes"],
    vocab_size=vocabulary.__len__(),
    embedding_dim=config["embedding_dim"],
    input_len=max_tok_len
).to(device)

## Training

In [12]:
model.train(train_loader, epochs=config["epochs"])

Epoch 0


  x = torch.tensor(pad(x), dtype=torch.long)


Train metrics: 
 Accuracy: 72.4%, Avg loss: 0.612228 

Epoch 1
Train metrics: 
 Accuracy: 76.1%, Avg loss: 0.601335 

Epoch 2
Train metrics: 
 Accuracy: 72.0%, Avg loss: 0.611224 

Epoch 3
Train metrics: 
 Accuracy: 78.1%, Avg loss: 0.592939 

Epoch 4
Train metrics: 
 Accuracy: 78.6%, Avg loss: 0.600212 



## Testing

In [13]:
test_dset = get_dataset(config["dataset"], 'test')
test_loader = DataLoader(test_dset, batch_size=config["batch_size"], shuffle=True, collate_fn=collate)

acc, f1 = model.test(test_loader)
acc, f1

p  121 ; tp  93.5970687866211 ; fp  26.871801018714905
recall  0.7735294941043066 ; precision  0.7769398761511029
num_batches 8
correct 188
size 242
Test metrics: 
 Accuracy: 0.776860, F1 score: 0.775231, Avg loss: 0.609989 



(0.7768595041322314, 0.7752309344229413)