In [139]:
import torch
from transformers import BertModel, BertForMaskedLM, BertConfig
from torch.utils.data import TensorDataset, random_split
from ipywidgets import IntProgress
from transformers import BertTokenizer
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import get_linear_schedule_with_warmup, AdamW

import pandas as pd
import numpy as np
import gensim
import random
import sys
import glob
import os
import datetime
from nltk import sent_tokenize
from nltk import word_tokenize
from scipy.spatial.distance import cosine
import warnings
from gensim.parsing.preprocessing import strip_non_alphanum, stem_text, preprocess_string
from itertools import product

In [4]:
# Input paths to text datasets.
plant_abstracts_corpus_path = "../data/corpus_related_files/untagged_text_corpora/phenotypes_all.txt"
plant_phenotype_descriptions_path = "../../plant-data/genes_texts_annots.csv"

In [5]:
# Preparing the dataset that combines the dataset of plant phenotype descriptions and scrapped abstracts.
corpus = open(plant_abstracts_corpus_path, 'r').read()
sentences_from_corpus = sent_tokenize(corpus)
phenotype_descriptions = " ".join(pd.read_csv(plant_phenotype_descriptions_path)["descriptions"].values)
times_to_duplicate_phenotype_dataset = 5
sentences_from_descriptions = sent_tokenize(phenotype_descriptions)
sentences_from_descriptions_duplicated = list(np.repeat(sentences_from_descriptions, times_to_duplicate_phenotype_dataset))
sentences_from_corpus_and_descriptions = sentences_from_corpus+sentences_from_descriptions_duplicated
random.shuffle(sentences_from_corpus_and_descriptions)
random.shuffle(sentences_from_corpus)
random.shuffle(sentences_from_descriptions)
sentences_from_corpus_and_descriptions = [preprocess_string(s) for s in sentences_from_corpus_and_descriptions]
sentences_from_corpus = [preprocess_string(s) for s in sentences_from_corpus]
sentences_from_descriptions = [preprocess_string(s) for s in sentences_from_descriptions]
assert len(sentences_from_corpus_and_descriptions) == len(sentences_from_corpus)+(times_to_duplicate_phenotype_dataset*len(sentences_from_descriptions))
print(len(sentences_from_corpus_and_descriptions))
print(len(sentences_from_corpus))
print(len(sentences_from_descriptions))

346894
172374
34904


In [146]:
sentences = [" ".join(s) for s in sentences_from_corpus_and_descriptions]
#sentences = [" ".join(s) for s in sentences_from_corpus_and_descriptions[:50]]
sentences[:5]

['genet molecular analys cbp overlap function development pathwai hyl',
 'addit meristem gene like sequenc us phylogenet analys bamboo speci',
 'advanc backcross progeni resist phenotyp tag marker us accumul blast resist upland rice',
 'ran gtpase activ protein rangap import ran signal involv nucleocytoplasm transport spindl organ postmitot nuclear assembl',
 'reduc size leaf stage usual small plant']

In [147]:
# Preparing a vocabulary file based on these sentences for BERT.
vocabulary_file_path = "../data/corpus_related_files/vocabulary/vocab.txt"
vocabulary = set()
for s in sentences_from_corpus_and_descriptions:
    vocabulary.update(s)
vocabulary.update(["[PAD]","[SEP]","[UNK]","[MASK]"])
print(len(vocabulary))
vocabulary_size = len(vocabulary)
print(list(vocabulary)[:10])
with open(vocabulary_file_path, "w") as f:
    for token in list(vocabulary):
        f.write(token+"\n")
print("done")

35788
['vitl', 'aodelta', 'doq', 'mon', 'supersensit', 'fruitless', 'riesl', 'panorama', 'phapb', 'thermocycl']
done


In [149]:
# Creating and parameratizing the small BERT architecture.
vocab_size = vocabulary_size
small_bert_configuration = BertConfig(
    vocab_size=vocab_size, 
    hidden_size=50, 
    num_hidden_layers=2, 
    num_attention_heads=2,
    intermediate_size=100,
    max_position_embeddings=200,
    return_dict=True,   
)
model = BertForMaskedLM(small_bert_configuration)

# An easier to read description of the model, from BERT fine-tuning with PyTorch.
params = list(model.named_parameters())
print('The BERT model has {:} different named parameters.\n'.format(len(params)))
print('==== Embedding Layer ====\n')
for p in params[0:5]:
    print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))
print('\n==== First Transformer ====\n')
for p in params[5:21]:
    print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))
print('\n==== Output Layer ====\n')
for p in params[-4:]:
    print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))

The BERT model has 42 different named parameters.

==== Embedding Layer ====

bert.embeddings.word_embeddings.weight                   (35788, 50)
bert.embeddings.position_embeddings.weight                 (200, 50)
bert.embeddings.token_type_embeddings.weight                 (2, 50)
bert.embeddings.LayerNorm.weight                               (50,)
bert.embeddings.LayerNorm.bias                                 (50,)

==== First Transformer ====

bert.encoder.layer.0.attention.self.query.weight            (50, 50)
bert.encoder.layer.0.attention.self.query.bias                 (50,)
bert.encoder.layer.0.attention.self.key.weight              (50, 50)
bert.encoder.layer.0.attention.self.key.bias                   (50,)
bert.encoder.layer.0.attention.self.value.weight            (50, 50)
bert.encoder.layer.0.attention.self.value.bias                 (50,)
bert.encoder.layer.0.attention.output.dense.weight          (50, 50)
bert.encoder.layer.0.attention.output.dense.bias               (

In [150]:
# Creating the tokenizer using the provided vocabulary.
tokenizer = BertTokenizer(vocab_file=vocabulary_file_path)
print(tokenizer)

PreTrainedTokenizer(name_or_path='', vocab_size=35788, model_max_len=1000000000000000019884624838656, is_fast=False, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})


In [152]:
# Testing out the tokenizer with the first few sentences from the dataset.
encoding = tokenizer(sentences[0:3], return_tensors='pt', padding=True, truncation=True)
print(encoding.input_ids)
print(encoding.token_type_ids)
print(encoding.attention_mask)
print(encoding.input_ids.shape)
print(encoding.token_type_ids.shape)
print(encoding.attention_mask.shape)

tensor([[19100, 26733,  6707, 35559,   370, 16016, 30503, 29005, 12615, 18478,
         25404,  7320,  7320,  7320,  7320],
        [19100,  7391, 30381,  6617, 32827, 25279, 14731, 17372, 35559,  1675,
         10033, 25404,  7320,  7320,  7320],
        [19100, 16392,  6437,  3006, 24441, 31378, 27568, 25983, 14731, 17132,
         21417, 24441, 29076, 28620, 25404]])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
torch.Size([3, 15])
torch.Size([3, 15])
torch.Size([3, 15])


In [153]:
# Producing the corresponding dataset of sentences with masked tokens for training.
masked = [" ".join([np.random.choice(['[MASK]',token],p=[0.15,0.85]) for token in s.split()]) for s in sentences]
masked[:5]

['genet molecular analys [MASK] overlap function development pathwai hyl',
 'addit [MASK] gene like sequenc us phylogenet analys bamboo speci',
 'advanc backcross progeni resist [MASK] tag marker us accumul blast resist upland rice',
 'ran gtpase activ [MASK] rangap import ran [MASK] involv nucleocytoplasm transport spindl organ postmitot nuclear assembl',
 'reduc size leaf stage usual small plant']

In [154]:
# Preparing the dataset object that can be read in as batches during the training loop.
inputs_dict = tokenizer(masked, return_tensors='pt', padding=True, truncation=True)
labels_dict = tokenizer(sentences, return_tensors='pt', padding=True, truncation=True)
print(inputs_dict["input_ids"].shape)
print(labels_dict["input_ids"].shape)
dataset = TensorDataset(inputs_dict["input_ids"], inputs_dict["attention_mask"], labels_dict["input_ids"])


# Pick the batch size here.
batch_size = 32
train_dataloader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=batch_size)

KeyboardInterrupt: 

In [143]:
# Creating and parameterizing the necessary objects for the optimizer and learning rate scheduler.
optimizer = AdamW(model.parameters(), lr = 2e-5, eps = 1e-8)
epochs = 5
total_steps = len(train_dataloader)*epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = 0, num_training_steps=total_steps)
print(total_steps)

In [145]:
# The training loop that uses batches from that data loader.
epochs = 20
for epoch_i in range(0, epochs):
    model.train()
    total_train_loss = 0 
    for step,batch in enumerate(train_dataloader):    
        model.zero_grad()
        outputs = model(input_ids=batch[0], attention_mask=batch[1], labels=batch[2])
        loss = outputs.loss
        logits = outputs.logits
        
        print(step, loss)
        
        total_train_loss += loss.item()

        # Perform a backward pass to calculate the gradients.
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        scheduler.step()
        optimizer.step()
        
    avg_train_loss = total_train_loss / len(train_dataloader) 
    print(avg_train_loss)


tensor(10.4577, grad_fn=<NllLossBackward>)
tensor(10.4520, grad_fn=<NllLossBackward>)
10.454821109771729
tensor(10.4580, grad_fn=<NllLossBackward>)
tensor(10.4543, grad_fn=<NllLossBackward>)
10.456151962280273
tensor(10.4523, grad_fn=<NllLossBackward>)
tensor(10.4578, grad_fn=<NllLossBackward>)
10.455024719238281
tensor(10.4567, grad_fn=<NllLossBackward>)
tensor(10.4540, grad_fn=<NllLossBackward>)
10.455362319946289
tensor(10.4596, grad_fn=<NllLossBackward>)
tensor(10.4545, grad_fn=<NllLossBackward>)
10.457032203674316
tensor(10.4544, grad_fn=<NllLossBackward>)
tensor(10.4598, grad_fn=<NllLossBackward>)
10.45707893371582
tensor(10.4534, grad_fn=<NllLossBackward>)
tensor(10.4582, grad_fn=<NllLossBackward>)
10.455801010131836
tensor(10.4576, grad_fn=<NllLossBackward>)
tensor(10.4599, grad_fn=<NllLossBackward>)
10.458718299865723
tensor(10.4571, grad_fn=<NllLossBackward>)
tensor(10.4578, grad_fn=<NllLossBackward>)
10.457447052001953
tensor(10.4591, grad_fn=<NllLossBackward>)
tensor(10.455