# PYTORCH CNN Classifier

In [1]:
DATASET = 'none'
VOCAB_SIZE = 0
TOKENIZER = 'none'
KMER = 0

In [2]:
# Parameters
DATASET = "human_enhancers_cohn"
TOKENIZER = "subword"
VOCAB_SIZE = 128


In [3]:
print(DATASET, VOCAB_SIZE, TOKENIZER, KMER)

human_enhancers_cohn 128 subword 0


## Config

In [4]:
import torch
from torch.utils.data import DataLoader

from genomic_benchmarks.dataset_getters.pytorch_datasets import get_dataset
from glp.models import CNN
from glp.tokenizers import get_tokenizer
from glp.tokenizers.utils import build_vocab, coll_factory, check_config, check_seq_lengths

In [5]:
config = {
    "dataset": DATASET,
    "tokenizer": TOKENIZER,
    "dataset_version": 0,
    "epochs": 5,
    "batch_size": 32,
    "use_padding": True,
    "force_download": False,
    "run_on_gpu": True,
    "number_of_classes": 2,
    "embedding_dim": 100,
}
check_config(config)

## Choose the dataset

In [6]:
train_dset = get_dataset(config["dataset"], 'train')

## Tokenizer and vocab

In [7]:
tokenizer = get_tokenizer(config['tokenizer'])
tokenizer.train(train_dset=train_dset, vocab_size=VOCAB_SIZE, kmer=KMER)
vocabulary = build_vocab(train_dset, tokenizer, use_padding=config["use_padding"])

print("vocab len:" ,vocabulary.__len__())
print(vocabulary.get_stoi())

sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=sample.csv --model_prefix=sample --vocab_size=128 --model_type=unigram
sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input: sample.csv
  input_format: 
  model_prefix: sample
  model_type: UNIGRAM
  vocab_size: 128
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  treat_whitespace_as_suffix: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  ⁇ 
}


trainer_interface.cc(456) LOG(INFO) all chars count=10442343
trainer_interface.cc(477) LOG(INFO) Alphabet size=5
trainer_interface.cc(478) LOG(INFO) Final character coverage=1


trainer_interface.cc(510) LOG(INFO) Done! preprocessed 20843 sentences.


unigram_model_trainer.cc(138) LOG(INFO) Making suffix array...


unigram_model_trainer.cc(142) LOG(INFO) Extracting frequent sub strings...


unigram_model_trainer.cc(193) LOG(INFO) Initialized 1000000 seed sentencepieces


trainer_interface.cc(516) LOG(INFO) Tokenizing input sentences with whitespace: 20843
trainer_interface.cc(526) LOG(INFO) Done! 20843
unigram_model_trainer.cc(488) LOG(INFO) Using 20843 sentences for EM training


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=451960 obj=656.396 num_tokens=1062834 num_tokens/piece=2.35161


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=328755 obj=662.576 num_tokens=1151170 num_tokens/piece=3.5016


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=214788 obj=659.085 num_tokens=1189558 num_tokens/piece=5.53829


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=195654 obj=651.89 num_tokens=1212596 num_tokens/piece=6.19766


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=144886 obj=653.472 num_tokens=1240509 num_tokens/piece=8.56197


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=141860 obj=649.031 num_tokens=1247362 num_tokens/piece=8.79291


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=106361 obj=652.395 num_tokens=1276529 num_tokens/piece=12.0019


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=106221 obj=648.169 num_tokens=1278410 num_tokens/piece=12.0354


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=79663 obj=653.869 num_tokens=1316403 num_tokens/piece=16.5246


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=79658 obj=648.559 num_tokens=1316903 num_tokens/piece=16.532


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=59743 obj=655.848 num_tokens=1359053 num_tokens/piece=22.7483


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=59742 obj=650.015 num_tokens=1359467 num_tokens/piece=22.7556


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=44806 obj=657.667 num_tokens=1400371 num_tokens/piece=31.2541


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=44806 obj=652.073 num_tokens=1400851 num_tokens/piece=31.2648


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=33604 obj=659.381 num_tokens=1441209 num_tokens/piece=42.888


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=33603 obj=654.176 num_tokens=1441766 num_tokens/piece=42.9059


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=25202 obj=661.335 num_tokens=1482972 num_tokens/piece=58.8434


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=25202 obj=656.203 num_tokens=1483424 num_tokens/piece=58.8614


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=18901 obj=663.245 num_tokens=1526444 num_tokens/piece=80.76


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=18901 obj=658.003 num_tokens=1527069 num_tokens/piece=80.793


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=14175 obj=665.311 num_tokens=1573592 num_tokens/piece=111.012


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=14175 obj=659.734 num_tokens=1574133 num_tokens/piece=111.05


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=10631 obj=667.497 num_tokens=1623310 num_tokens/piece=152.696


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=10631 obj=661.529 num_tokens=1623986 num_tokens/piece=152.759


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=7973 obj=669.647 num_tokens=1677995 num_tokens/piece=210.46


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=7973 obj=663.05 num_tokens=1678566 num_tokens/piece=210.531


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=5979 obj=671.319 num_tokens=1734023 num_tokens/piece=290.019


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=5979 obj=664.29 num_tokens=1734616 num_tokens/piece=290.118


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=4484 obj=672.797 num_tokens=1794114 num_tokens/piece=400.115


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=4484 obj=665.366 num_tokens=1794882 num_tokens/piece=400.286


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=3363 obj=674.688 num_tokens=1858954 num_tokens/piece=552.767


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=3363 obj=666.505 num_tokens=1859784 num_tokens/piece=553.013


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=2522 obj=676.667 num_tokens=1931142 num_tokens/piece=765.718


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=2522 obj=667.926 num_tokens=1931879 num_tokens/piece=766.011


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=1891 obj=679.206 num_tokens=2007806 num_tokens/piece=1061.77


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=1891 obj=669.539 num_tokens=2008799 num_tokens/piece=1062.29


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=1418 obj=680.388 num_tokens=2088648 num_tokens/piece=1472.95


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=1418 obj=670.701 num_tokens=2089565 num_tokens/piece=1473.6


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=1063 obj=684.877 num_tokens=2178168 num_tokens/piece=2049.08


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=1063 obj=671.698 num_tokens=2179396 num_tokens/piece=2050.23


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=797 obj=685.058 num_tokens=2275812 num_tokens/piece=2855.47


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=797 obj=672.737 num_tokens=2277467 num_tokens/piece=2857.55


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=597 obj=687.794 num_tokens=2381618 num_tokens/piece=3989.31


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=597 obj=673.614 num_tokens=2382788 num_tokens/piece=3991.27


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=447 obj=693.326 num_tokens=2498111 num_tokens/piece=5588.62


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=447 obj=674.31 num_tokens=2498693 num_tokens/piece=5589.92


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=335 obj=690.79 num_tokens=2613085 num_tokens/piece=7800.25


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=335 obj=674.858 num_tokens=2614424 num_tokens/piece=7804.25


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=251 obj=694.936 num_tokens=2752494 num_tokens/piece=10966.1


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=251 obj=676.08 num_tokens=2754941 num_tokens/piece=10975.9


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=188 obj=696.701 num_tokens=2907905 num_tokens/piece=15467.6


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=188 obj=677.004 num_tokens=2910538 num_tokens/piece=15481.6


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=141 obj=699.839 num_tokens=3075471 num_tokens/piece=21811.9


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=141 obj=677.844 num_tokens=3078159 num_tokens/piece=21830.9


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=0 size=140 obj=677.835 num_tokens=3088370 num_tokens/piece=22059.8


unigram_model_trainer.cc(504) LOG(INFO) EM sub_iter=1 size=140 obj=677.091 num_tokens=3096875 num_tokens/piece=22120.5
trainer_interface.cc(604) LOG(INFO) Saving model: sample.model
trainer_interface.cc(615) LOG(INFO) Saving vocabs: sample.vocab


vocab len: 127
{'CTTG': 125, 'TCCA': 124, 'GCTG': 123, 'T': 122, 'CATG': 121, 'ATTA': 120, 'AAAAA': 119, 'A': 116, 'AAC': 115, 'G': 114, 'C': 112, 'CAGC': 111, 'GGAG': 110, 'AGTG': 109, 'AC': 107, 'AATT': 106, 'ATC': 104, 'GTA': 102, 'TTA': 98, 'CATT': 100, 'ATT': 97, 'ATTT': 92, 'TACA': 91, 'GCT': 90, 'CAAA': 105, 'TTCC': 67, 'CCAC': 89, 'GTGG': 23, 'TGTT': 87, 'TGGA': 10, 'TCCT': 85, 'ATGG': 82, 'AAGA': 58, 'TTTG': 80, 'CTT': 34, 'TAT': 96, 'CTGT': 113, 'CACA': 75, 'GTG': 74, 'TAA': 49, 'ACC': 72, 'CTCA': 70, 'GAG': 68, 'CCAT': 83, 'AGG': 101, 'AGAA': 61, 'CAGG': 60, 'TGGT': 59, 'AAAT': 56, 'CAGT': 55, 'TGGG': 54, 'AGCA': 52, 'AAAA': 51, 'CAGA': 31, 'AAT': 50, 'TATT': 118, 'TCT': 73, 'TTTTT': 84, 'TGA': 32, 'TAG': 46, 'TAAT': 94, 'TTTT': 81, 'CAG': 12, 'CCC': 5, 'TCA': 24, 'TGCT': 117, 'AGGA': 43, 'ATGT': 27, 'ACAG': 47, 'TGAG': 42, 'AGC': 108, 'TTCA': 41, 'AAAG': 44, 'TTC': 39, 'CTGG': 25, 'GGT': 9, 'GCC': 33, 'AAGG': 19, 'AAG': 76, 'CTTT': 36, 'CAAG': 63, 'AGAG': 69, 'CG': 65, 'TTC

## Dataloader and batch preparation

In [8]:
# Run on GPU or CPU
device = 'cuda' if config["run_on_gpu"] and torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

max_tok_len = check_seq_lengths(dataset=train_dset, tokenizer=tokenizer)

# Data Loader
collate = coll_factory(vocabulary, tokenizer, device, pad_to_length = max_tok_len)

train_loader = DataLoader(train_dset, batch_size=config["batch_size"], shuffle=True, collate_fn=collate)

Using cpu device


max_tok_len  188


not all sequences are of the same length


In [9]:
tokenizer(train_dset[1][0])

['▁',
 'CCAT',
 'CTA',
 'TTTTT',
 'GAA',
 'TCCT',
 'TTT',
 'AACA',
 'CTTT',
 'ATG',
 'AGAG',
 'AGA',
 'AAG',
 'AACA',
 'GAA',
 'CAGG',
 'CAGG',
 'CAGG',
 'CTG',
 'TGTT',
 'TAA',
 'CCC',
 'CAC',
 'AACA',
 'CCAC',
 'CAT',
 'AGAG',
 'CCT',
 'AAAA',
 'GCT',
 'TGGG',
 'AGT',
 'AAA',
 'TACA',
 'ATTT',
 'CTTT',
 'GGC',
 'ACC',
 'TTG',
 'TAAT',
 'AAAA',
 'CAT',
 'GAA',
 'CAA',
 'TGAG',
 'TTT',
 'CTCA',
 'CCA',
 'TGT',
 'TAT',
 'TAG',
 'TCT',
 'CTGG',
 'ATT',
 'CCA',
 'CG',
 'CTCC',
 'GGC',
 'ATA',
 'CTT',
 'ACA',
 'CAT',
 'TTA',
 'CAG',
 'AAA',
 'AGAG',
 'AGA',
 'CAGA',
 'CAA',
 'AGAG',
 'AGT',
 'ACT',
 'GTG',
 'CTG',
 'AAAG',
 'GTT',
 'CAGA',
 'CAGA',
 'AGAA',
 'TTC',
 'GGG',
 'ACT',
 'CAAG',
 'CCA',
 'CAA',
 'CAA',
 'CATT',
 'AAT',
 'GAGG',
 'AAA',
 'GGG',
 'AGG',
 'GTA',
 'GGG',
 'CAG',
 'CCC',
 'AGT',
 'GAA',
 'CAA',
 'TGAG',
 'CTG',
 'CTC',
 'GC',
 'ATC',
 'AGT',
 'TTTG',
 'AAA',
 'GAGA',
 'CAGT',
 'AGAG',
 'AAGA',
 'CACA',
 'CAAA',
 'CAT',
 'TCCT',
 'AATT',
 'TCTT',
 'GC',
 'CTT',
 'GGC'

## Model

In [10]:
model = CNN(
    number_of_classes=config["number_of_classes"],
    vocab_size=vocabulary.__len__(),
    embedding_dim=config["embedding_dim"],
    input_len=max_tok_len
).to(device)

## Training

In [11]:
model.train(train_loader, epochs=config["epochs"])

Epoch 0


  x = torch.tensor(pad(x), dtype=torch.long)


Train metrics: 
 Accuracy: 67.8%, Avg loss: 0.639029 

Epoch 1


Train metrics: 
 Accuracy: 68.0%, Avg loss: 0.634062 

Epoch 2


Train metrics: 
 Accuracy: 68.9%, Avg loss: 0.630115 

Epoch 3


Train metrics: 
 Accuracy: 73.2%, Avg loss: 0.618496 

Epoch 4


Train metrics: 
 Accuracy: 75.4%, Avg loss: 0.616339 



## Testing

In [12]:
test_dset = get_dataset(config["dataset"], 'test')
test_loader = DataLoader(test_dset, batch_size=config["batch_size"], shuffle=True, collate_fn=collate)

acc, f1 = model.test(test_loader)
acc, f1

p  3474 ; tp  2104.2328075170517 ; fp  902.8944049580023
recall  0.6057089255950062 ; precision  0.6997485170523053
num_batches 218
correct 4676
size 6948
Test metrics: 
 Accuracy: 0.672999, F1 score: 0.649342, Avg loss: 0.655398 



(0.6729994242947611, 0.6493416155963629)