In [1]:
import collections
import glob
import json
import os
import pickle
import warnings

import spacy
import torch
import torch.nn as nn
import transformers
from spacy import displacy
from torch import cuda
from tqdm import tqdm
from transformers import DistilBertTokenizer, DistilBertModel

In [2]:
# Load the pretrained DistilBERT tokenizer and encoder (base, uncased).
# `tokenizer` is read as a global by `classify_text`; `bert` is the backbone
# that gets wrapped by the `BERT` classifier class.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')  
bert = DistilBertModel.from_pretrained('distilbert-base-uncased')
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if cuda.is_available() else 'cpu'
# Silence every Python warning from this point on. This runs *after* the
# `from_pretrained` calls above, so it has no effect on anything they emit.
warnings.filterwarnings("ignore")

# spaCy transformer pipeline; used further down to split paragraphs into
# sentences (`doc.sents`).
nlp = spacy.load('en_core_web_trf')


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
class BERT(nn.Module):
    """Two-class classification head on top of a DistilBERT encoder.

    The hidden state of the first token of each sequence is passed through
    Linear(768 -> 512) -> ReLU -> Dropout(0.1) -> Linear(512 -> 2) and then
    a log-softmax, so the module returns log-probabilities, not
    probabilities.

    Args:
        bert: Encoder module. It is called with the keyword arguments given
            to ``forward`` and must return an object exposing
            ``last_hidden_state``.
    """

    def __init__(self, bert):
        super().__init__()

        # Encoder backbone
        self.bert = bert

        # Classification head. The attribute names are part of the saved
        # checkpoint's state_dict keys, so they must not be renamed.
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(768, 512)
        self.fc2 = nn.Linear(512, 2)
        # Note: despite the attribute name this is a *log*-softmax.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, **kwargs):
        """Return per-class log-probabilities for a batch.

        Args:
            **kwargs: Forwarded unchanged to the wrapped encoder.

        Returns:
            Tensor with one row per input sequence and two columns
            (log-probability of each class).
        """
        # Representation of the first token of every sequence
        first_token = self.bert(**kwargs).last_hidden_state[:, 0]

        # Hidden layer: linear -> ReLU -> dropout
        hidden = self.dropout(self.relu(self.fc1(first_token)))

        # Output layer followed by log-softmax over the class dimension
        return self.softmax(self.fc2(hidden))

In [10]:
# Location of the fine-tuned description-classifier checkpoint
location = "../../../models/saved_weights/"
modelname = "saved_weights_BERT_description_classifier.pt"
model_save_name = modelname
path = location + model_save_name

# Wrap the encoder in the classifier head and place it on the chosen device
model = BERT(bert).to(device)

# The checkpoint is deserialised onto the CPU first; load_state_dict then
# copies the values into the model's own parameters.
model.load_state_dict(
    torch.load(path, map_location=torch.device('cpu'))
)

# Inference mode (turns dropout off); as the last expression of the cell
# this also displays the model architecture.
model.eval()

BERT(
  (bert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(in_features=768, 

In [4]:
def classify_text(span, model, truncation=True):
    """Classify a text span with the trained description classifier.

    The span is tokenized with the notebook-level ``tokenizer``, the
    resulting tensors are moved to the device the model's weights live on,
    and the index of the highest-scoring class is returned.

    Args:
        span (str): Text to classify (typically a single sentence).
        model: Callable classifier returning one row of class scores
            (log-probabilities) per input sequence.
        truncation (bool): Forwarded to the tokenizer; when True, over-long
            inputs are truncated instead of overflowing the model.

    Returns:
        int: Index of the predicted class. Callers in this notebook treat a
        truthy value (1) as "belongs to a species description".
    """
    # Work out where the model's weights live so the inputs can follow them.
    # Without this, a model on CUDA receives CPU tensors and the forward
    # pass fails with a device-mismatch error.
    parameters = model.parameters() if hasattr(model, "parameters") else iter(())
    first_param = next(parameters, None)

    with torch.no_grad():
        # Tokenize input
        inputs = tokenizer(span, return_tensors="pt", truncation=truncation)
        if first_param is not None:
            inputs = {
                name: tensor.to(first_param.device)
                for name, tensor in inputs.items()
            }
        # Predict class scores
        outputs = model(**inputs)
        # exp() is monotonic, so the argmax of the log-probabilities is the
        # same class as the argmax of the probabilities.
        return outputs.argmax(1).item()

In [6]:
folder_text = "../../../data/OpenAI/TextSnippetsCleaned/"

# caribbean_text_dict = pickle.load(open(F"{folder_text}paragraphs_caribbean_cleaned.pkl", 'rb'))
# palms_text_dict = pickle.load(open(F"{folder_text}paragraphs_palms_cleaned.pkl", 'rb'))

# Cleaned PlantNet paragraphs; consumed below as a mapping from species name
# to a list of paragraph strings.
# NOTE: unpickling can execute arbitrary code, so only load files produced by
# this project. The `with` block closes the file handle once loading is done.
with open(F"{folder_text}paragraphs_plantnet_cleaned.pkl", 'rb') as pkl_file:
    plantnet_text_dict = pickle.load(pkl_file)

In [8]:
def paragraph_to_descriptions(paragraph_dict):
    """Filter each species' paragraphs down to description sentences.

    Every paragraph is split into sentences with spaCy and each sentence is
    run through ``classify_text``; only sentences with a truthy prediction
    are kept.

    Args:
        paragraph_dict (dict): Maps a species name to a list of paragraph
            strings.

    Returns:
        Tuple[Dict[str, List[str]], Dict[str, List[str]]]: Two defaultdicts
        keyed by species. The first holds one string per paragraph that had
        at least one accepted sentence (the accepted sentences joined by a
        single space). The second holds every accepted sentence on its own.
        Species without any accepted sentence do not appear as keys.
    """
    # Paragraphs longer than this many characters are skipped outright
    max_chars = 800000

    paragraphs_by_species = collections.defaultdict(list)
    sentences_by_species = collections.defaultdict(list)

    species_bar = tqdm(paragraph_dict.items(), desc="Species", leave=True, position=0)
    for species, paragraphs in species_bar:
        for paragraph in tqdm(paragraphs, desc="Paragraph", leave=False, position=0):
            if len(paragraph) <= max_chars:
                # Sentence-split with spaCy and keep what the classifier accepts
                kept = [
                    sent.text
                    for sent in nlp(paragraph).sents
                    if classify_text(sent.text, model=model)
                ]

                # Only touch the dictionaries when something was accepted, so
                # that no empty entries are created for a species
                if kept:
                    sentences_by_species[species].extend(kept)
                    paragraphs_by_species[species].append(' '.join(kept))

    return paragraphs_by_species, sentences_by_species


def species_paragraphs_to_json(
    text_dict,
    folder_paragraphs="../../../data/OpenAI/DescriptionSnippets/Paragraphs/",
    folder_sentences="../../../data/OpenAI/DescriptionSnippets/Sentences/",
    max_paragraph_chars=50000,
):
    """Filter paragraphs to description sentences and save them per species.

    For every species, each paragraph is split into sentences with spaCy and
    each sentence is classified with ``classify_text``. Two JSON files are
    written per species:

    * ``plantnet_<Species_name>_descriptions_sentences.json`` in
      ``folder_sentences``: ``{species: [accepted sentence, ...]}``
    * ``plantnet_<Species_name>_descriptions_paragraphs.json`` in
      ``folder_paragraphs``: ``{species: [accepted sentences of one
      paragraph joined by a space, ...]}``

    A species with no accepted sentence gets an empty JSON object in both
    files. A species whose paragraph file already exists is skipped, so an
    interrupted run can be resumed by calling the function again.

    Args:
        text_dict (dict): Maps a species name to a list of paragraph strings.
        folder_paragraphs (str): Output folder for the paragraph JSON files.
        folder_sentences (str): Output folder for the sentence JSON files.
        max_paragraph_chars (int): Paragraphs longer than this many
            characters are skipped.

    Returns:
        None
    """
    # Create the output folders up front so the writes below cannot fail on
    # a missing directory.
    for folder in (folder_paragraphs, folder_sentences):
        if folder:
            os.makedirs(folder, exist_ok=True)

    for idx, (species, paragraphs) in enumerate(text_dict.items()):

        # File naming
        species_name = species.replace(' ', '_')
        path_sent = os.path.join(
            folder_sentences, F"plantnet_{species_name}_descriptions_sentences.json"
        )
        path_para = os.path.join(
            folder_paragraphs, F"plantnet_{species_name}_descriptions_paragraphs.json"
        )

        # Already done in an earlier run. A direct existence check replaces
        # re-listing the whole output folder for every species.
        if os.path.exists(path_para):
            continue

        species_sentences = []
        species_paragraphs = []

        for paragraph in tqdm(paragraphs, desc=f"{idx} {species}", leave=False, position=0):

            # Too long to process
            if len(paragraph) > max_paragraph_chars:
                continue

            kept = [
                sent.text
                for sent in nlp(paragraph).sents
                if classify_text(sent.text, model=model)
            ]

            if kept:
                species_sentences.extend(kept)
                species_paragraphs.append(' '.join(kept))

        # Same JSON layout as before: an empty object when nothing passed
        description_sentence_dict = {species: species_sentences} if species_sentences else {}
        description_paragraph_dict = {species: species_paragraphs} if species_paragraphs else {}

        with open(path_sent, 'w') as fp:
            json.dump(description_sentence_dict, fp)

        # Written last on purpose: the paragraph file doubles as the
        # "species finished" marker checked above.
        with open(path_para, 'w') as fp:
            json.dump(description_paragraph_dict, fp)

In [8]:
# caribbean_text_dict = {k: caribbean_text_dict[k][0:20] for k in list(caribbean_text_dict)[:4]}

In [11]:
# Classify every PlantNet paragraph sentence by sentence and write the
# per-species JSON files; species whose paragraph JSON already exists are
# skipped, so this cell can be re-run to resume an interrupted job.
species_paragraphs_to_json(plantnet_text_dict)

47 Blighia sapida:  74%|███████▍  | 358/481 [01:06<00:18,  6.68it/s]                       Token indices sequence length is longer than the specified maximum sequence length for this model (558 > 512). Running this sequence through the model will result in indexing errors
                                                                                              

In [10]:
# caribbean_description_paragraph_dict, caribbean_description_sentence_dict = paragraph_to_descriptions(caribbean_text_dict)
# palms_description_paragraph_dict, palms_description_sentence_dict = paragraph_to_descriptions(palms_text_dict)
# plantnet_description_paragraph_dict, plantnet_description_sentence_dict = paragraph_to_descriptions(plantnet_text_dict)

In [11]:
# folder_text = "../../../data/OpenAI/DescriptionSnippets/"

# with open(F"{folder_text}descriptions_paragraphs_caribbean.pkl", 'wb') as f:
#     pickle.dump(caribbean_description_paragraph_dict, f)
# with open(F"{folder_text}descriptions_sentences_caribbean.pkl", 'wb') as f:
#     pickle.dump(caribbean_description_sentence_dict, f)

# with open(F"{folder_text}descriptions_paragraphs_palms.pkl", 'wb') as f:
#     pickle.dump(palms_description_paragraph_dict, f)
# with open(F"{folder_text}descriptions_sentences_palms.pkl", 'wb') as f:
#     pickle.dump(palms_description_sentence_dict, f)

# with open(F"{folder_text}descriptions_paragraphs_plantnet.pkl", 'wb') as f:
#     pickle.dump(plantnet_description_paragraph_dict, f)
# with open(F"{folder_text}descriptions_sentences_plantnet.pkl", 'wb') as f:
#     pickle.dump(plantnet_description_sentence_dict, f)