In [1]:
import re
import os
import pickle
import glob
import time
import json
import collections

import numpy as np
import pandas as pd
from tqdm import tqdm
from json.decoder import JSONDecodeError

# Seed Python's built-in `random` module so runs are repeatable.
# NOTE(review): only `random` is seeded here; numpy/torch generators are not -- confirm nothing relies on them.
import random
random.seed(333)

# Load environment variables from a local .env file. The API credentials read
# later through os.environ (GOOGLE_API, GOOGLE_CSE, MISTRAL_API_KEY) are expected there.
from dotenv import load_dotenv
load_dotenv()

# Root folder for the intermediate search/text results, and the dataset label
# used to name the per-dataset output sub-folders.
base_folder = 'Data'
dataset_name = 'example'

## Set up the species, traits and trait values of interest
---

In [None]:
# Minimal inline example of the expected inputs: a list of binomial species
# names, and a mapping from each trait name to its allowed trait values.
species_to_query = ['Ceroxylon peruvianum', 'Calamus australis']

# Fruit and flower colour use the same vocabulary; each trait gets its own copy.
_colour_values = [
    'black', 'blue', 'brown', 'green', 'grey', 'ivory',
    'orange', 'pink', 'purple', 'red', 'white', 'yellow',
]

traits_dict = dict([
    ('Fruit Colour', list(_colour_values)),
    ('Flower Colour', list(_colour_values)),
    ('Crown layer', ['both', 'canopy', 'understorey']),
    ('Fruit Size', ['large', 'small']),
    ('Fruit Shape', ['ellipsoid', 'elongate', 'fusiform', 'globose', 'ovoid', 'pyramidal', 'rounded']),
    ('Conspicuousness', ['conspicuous', 'cryptic']),
])


In [2]:
# Alternative input: read the species and the trait vocabulary from an .xlsx workbook
file_path = 'WFO_Solanum_description.xlsx'  # Replace with the path to your .xlsx file

# sheet_name=None loads every sheet and returns a dictionary of dataframes keyed by sheet;
# header=None keeps the first row as data instead of using it as column names
sheets_dict = pd.read_excel(file_path, sheet_name=None, index_col=None, header=None)

# Pick the first three sheets by their position in the dictionary
sheet_names = list(sheets_dict.keys())
sheet1 = sheets_dict[sheet_names[0]]
sheet2 = sheets_dict[sheet_names[1]]
sheet3 = sheets_dict[sheet_names[2]]

def keep_first_two_words(text):
    """Return the first two whitespace-separated words of ``text``, joined by one space.

    A string with fewer than two words is returned with the words it has
    (an empty or blank string gives an empty string).
    """
    return ' '.join(text.split()[:2])

# Take the first ten entries of the first column of sheet 1 and reduce each
# of them to its first two words.
species_to_query = [keep_first_two_words(entry) for entry in sheet1.iloc[0:10, 0]]

# NOTE: the hard-coded list below replaces the names derived from the sheet above.
species_to_query = [
    'Solanum lycopersicum',
    'Solanum dulcamara',
    'Solanum nigrum',
    'Solanum tuberosum',
    'Solanum laxum',
    'Solanum americanum',
    'Solanum melongena',
    'Solanum pseudocapsicum',
    'Solanum laciniatum',
    'Solanum villosum',
]

def clean_string(s):
    """Remove newline / carriage-return characters from a string and trim it.

    Values that are not strings are returned unchanged.
    """
    if not isinstance(s, str):
        return s
    # Extend the character class if more special characters need removing
    without_breaks = re.sub(r'[\n\r]', '', s)
    return without_breaks.strip()

# Build {trait name: [allowed values]} from sheet 3. A row whose second column
# is empty starts a new trait (named by the first column); a row with a value
# in the second column adds that value to the most recently started trait.
traits_dict = {}
current_name = None

for index, row in sheet3.iterrows():
    name = clean_string(row[0])  # first column: trait name (newlines removed, trimmed)
    value = row[1]               # second column: trait value, or NaN on a trait-name row

    if pd.isna(value) and isinstance(name, str) and name != ' ':  # trait-name row
        current_name = name
        traits_dict[current_name] = []
    elif not pd.isna(value) and current_name is not None:  # trait-value row
        traits_dict[current_name].append(value)

# Drop the traits that are not queried. pop(..., None) is used instead of
# `del` so that a workbook lacking one of these headers does not raise KeyError.
# The '' entry is presumably created by trait-name rows that are empty after cleaning.
excluded_traits = [
    'Cultivated and non-native distribution',
    'Major clade (group)',
    'Geographic distribution',
    'Minor clade (group)',
    'Anther length',
    'Filament length',
    '',
]
for excluded in excluded_traits:
    traits_dict.pop(excluded, None)

## Custom Search Engine Setup
---

To retrieve the corresponding URLs for the species that we are going to check, we use the Google Custom Search API. To this end, we must first assign the appropriate credentials and then construct a function that queries a custom search engine. See more about the Custom Search API and how to set it up here: [Custom Search API](https://developers.google.com/custom-search/v1/overview)

In [6]:
# Set up the google API credentials 
from googleapiclient.discovery import build

# Read the Custom Search credentials from the environment; os.environ[...]
# raises KeyError when a variable is not set.
google_api_key = os.environ['GOOGLE_API']
cse_id = os.environ['GOOGLE_CSE']

# run a query against the custom Google search engine
def google_search(exact_term, other_search_term, api_key, cse_id, **kwargs):

    """
    Queries the custom Google search engine and returns the result items.

    Args:
        exact_term (str): Sent as the ``exactTerms`` parameter of the request.
        other_search_term (str): Sent as the ``q`` (query) parameter.
        api_key (str): Google API developer key.
        cse_id (str): Identifier of the custom search engine (``cx`` parameter).
        **kwargs: Additional parameters forwarded to ``cse().list``.

    Returns:
        list: The ``items`` of the response, or an empty list when the
        response carries no ``items`` key.
    """
    service = build("customsearch", "v1", developerKey=api_key)
    res = service.cse().list(exactTerms=exact_term, q=other_search_term, cx=cse_id, **kwargs).execute()
    # A query without hits has no 'items' key; indexing it directly would
    # raise KeyError and abort the whole species loop.
    return res.get('items', [])

## Query the Custom Search Engine
---

In the following, we define the search terms that we will search along with the names of the species defined before. Then, we utilize the **google_search** function defined in the previous blocks to retrieve the results from the custom search engine.

Throughout the notebook, we store all the intermediate data that emerge using the *pickle* python package. In the following, we are storing the richer retrieved urls that include the site title and other information (*species_urls_full*) and just the urls (*species_urls*).

In [7]:
# the search terms to use along with the selected species
search_terms = ["description","characteristics"]

# per species: just the links (species_urls) and the full result records (species_urls_full)
species_urls = collections.defaultdict(list)
species_urls_full = collections.defaultdict(list)

# parse all the species and search terms
for species in tqdm(species_to_query):
    for search_term in search_terms:
        # The species name is passed as the exact term and the search term as
        # the free-text query, e.g. "Archontophoenix maxima" + "description".
        search_results = google_search(species, search_term, api_key=google_api_key, cse_id=cse_id)

        # Accumulate the full records over all search terms. A plain
        # assignment here would keep only the results of the last search term.
        species_urls_full[species].extend(search_results)

        # Record just the links
        species_urls[species].extend(result['link'] for result in search_results)


# Persist the search results of all species
path_to_save = f'{base_folder}/Search_Query_Results/{dataset_name}_dataset/'

# create the folder if it does not exist
os.makedirs(path_to_save, exist_ok = True)

# one pickle for the bare links and one for the full result records
url_path = f'{path_to_save}urls.pkl'
full_url_path = f'{path_to_save}full_urls.pkl'

for pickle_path, payload in ((url_path, species_urls), (full_url_path, species_urls_full)):
    with open(pickle_path, 'wb') as f:
        pickle.dump(payload, f)

100%|██████████| 10/10 [00:08<00:00,  1.16it/s]


### URL Processing: Converting results to plain text
---

We need to extract the text from the links in order to be able to process them and use them for prompting. We use the *requests* and *bs4* packages to this end in the following snippet to define a function to extract paragraphs from urls. 

In [8]:
# for parsing links 
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from bs4 import BeautifulSoup

def extract_paragraphs_from_urls(url_dict):
    """
    Downloads every URL and collects the text of its <p> elements per species.

    Args:
        url_dict (dict): A dictionary where the keys are the species names and the values are lists of URLs.

    Returns:
        collections.defaultdict: A dictionary where the keys are the species names and the values are
        lists of paragraph texts. A URL that cannot be fetched or parsed contributes the
        placeholder string "Invalid URL".
    """

    # Initialise the dictionary that will contain the plain text for each species
    text_dict = collections.defaultdict(list)

    # Retry configuration shared by all requests
    retry = Retry(total=1,
                  connect=1,
                  backoff_factor=0.5)
    adapter = HTTPAdapter(max_retries=retry)

    # A single session serves all downloads and is closed when the block exits
    # (previously a new session was created for every URL and never closed).
    with requests.Session() as session:
        session.mount('http://', adapter)
        session.mount('https://', adapter)

        # Loop over species and URLS to extract the text
        for (species, urls) in tqdm(url_dict.items(), desc="Species", leave=True, position=0):
            for url in tqdm(urls, desc=f"URLs Species: {species}", leave=False, position=0):

                try:
                    response = session.get(url, timeout=3)

                    soup = BeautifulSoup(response.text, 'html.parser')
                    for paragraph in soup.find_all('p'):
                        text_dict[species].append(paragraph.text)

                # `except Exception` keeps the best-effort behaviour but, unlike a bare
                # `except:`, still lets KeyboardInterrupt / SystemExit stop a long crawl.
                except Exception:
                    text_dict[species].append("Invalid URL")

    return text_dict

# Fetch the paragraph text behind every retrieved URL, then pickle it next to the URL lists
species_text = extract_paragraphs_from_urls(species_urls)

paragraphs_path = F"{path_to_save}paragraphs.pkl"
with open(paragraphs_path, 'wb') as f:
    pickle.dump(species_text, f)

Species: 100%|██████████| 10/10 [02:40<00:00, 16.03s/it]                             


In [None]:
# Sanity check: show the first five paragraphs extracted for the last queried species.
species_text[species_to_query[-1]][:5]

### Text Processing: Filter and Clean
---

The extracted paragraphs are highly "noisy": they contain special and irrelevant characters, some characters cannot be decoded correctly, and others are repeated, e.g., whitespace characters. The following snippet uses the *re* package along with some regular expressions to clean the text. After cleaning is performed, we remove any potential duplicate entries (by converting the list of paragraphs to a python set) and then return the valid pieces of text.

In [10]:
def regex_cleaner(string):
    """
    Normalises whitespace and strips punctuation/symbols from a piece of text.

    The substitutions below are applied in the listed order, each one to the
    output of the previous one; the result is stripped of surrounding whitespace.

    Args:
        string (str): The raw text.

    Returns:
        str: The cleaned text.
    """
    # The regex patterns are raw strings: "\s" inside a normal string literal
    # is an invalid escape sequence and triggers a warning on recent Python versions.
    #
    # NOTE(review): the order is kept as it was, so the output is unchanged, but
    # it leaves several rules without effect:
    #   * whitespace is collapsed first, so the newline, tab and \xa0 rules
    #     find nothing left to match;
    #   * symbols are stripped before the malformed-character rules run, so the
    #     rules containing a stripped symbol cannot match either.
    # Confirm the intended order before rearranging.
    cleaners = [
        # Replace multiple consecutive whitespace characters (spaces, tabs, newlines) with a single space character
        (r"\s+", " "),
        # Replace multiple consecutive newline characters with a single newline character
        (r"\n+", "\n"),
        # Replace multiple consecutive tab characters with a single tab character
        (r"\t+", "\t"),
        # remove every character that is not a word character, whitespace or '/'
        (r'[^\w\s/]', ''),
        # replace malformed characters,
        ('Â',''),
        ('â€“', '-'),
        ('·','.'),
        ('Ã','x'),
        ('\xa0', ' '),
        ('â€‰', ''),
        ('â€', '-'),
        ('x©', 'e'),
    ]

    # Apply each regular expression pattern and its replacement to the input string
    for (cleaner, replacement) in cleaners:
        string = re.sub(cleaner, replacement, string)

    # Return the cleaned string
    return string.strip()

def filter_species_dict(text_dict):
    """
    Filters the descriptions in a dictionary of species, removing invalid text and duplicates.

    Every description is cleaned with ``regex_cleaner``; cleaned texts are kept
    only when they are longer than 1 and shorter than 10000 characters.
    Duplicates are dropped while keeping the order of first appearance, so the
    result is the same on every run (a ``set`` would return the texts in an
    order that can change between runs).

    Args:
        text_dict (dict): A dictionary where the keys are the species names and the values are lists of text descriptions.

    Returns:
        dict: A filtered dictionary where the keys are the species names and the values are lists of valid and unique text descriptions.
    """

    valid_species_dict = {}

    # Loop through each species and its descriptions in the dictionary
    for idx, (species, descriptions) in enumerate(tqdm(text_dict.items(), leave=False, position=0)):
        # Create a progress bar for the species
        species_pbar = tqdm(descriptions, leave=False, position=1, desc=f"{idx} {species}")

        # Clean every description and keep those within the length bounds
        cleaned_descriptions = (regex_cleaner(description) for description in species_pbar)
        valid_descriptions = [text for text in cleaned_descriptions if 1 < len(text) < 10000]

        # Remove duplicates; dict keys are unique and keep insertion order
        valid_species_dict[species] = list(dict.fromkeys(valid_descriptions))

    return valid_species_dict

# Clean and de-duplicate the extracted paragraphs, then pickle the result
species_text_cleaned = filter_species_dict(species_text)

paragraphs_cleaned_path = f"{path_to_save}paragraphs_cleaned.pkl"
with open(paragraphs_cleaned_path, 'wb') as f:
    pickle.dump(species_text_cleaned, f)

                                      

### Text Processing: Cleaned Text to Descriptions Classification
---

We now have a useful list of texts for each species from the information extracted from the web. However, it is highly likely that most of the extracted and cleaned parts do not contain any useful information for the respective species. To this end, we trained a Description Classifier; this model takes as input a piece of text and classifies the text as being a description or not. We retain the ones that are classified as descriptions while discarding the rest.

The model is based on the BERT model. We first import all necessary libraries and define the essential components; then we load the trained checkpoint and perform classification.

Requires: python -m spacy download en_core_web_trf

In [None]:
import torch
from torch import cuda
import torch.nn as nn
import transformers
from transformers import DistilBertTokenizer, DistilBertModel
import warnings
import spacy
from spacy import displacy
warnings.filterwarnings("ignore")


# Load some utilities
# Run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if cuda.is_available() else 'cpu'
# Tokenizer and pretrained DistilBERT backbone ('distilbert-base-uncased' checkpoint).
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
bert = DistilBertModel.from_pretrained('distilbert-base-uncased').to(device)
# spaCy pipeline; paragraph_to_descriptions uses it to split paragraphs into sentences (doc.sents).
nlp = spacy.load('en_core_web_trf')

# encoder plus a small classification head, fine-tuned for description classification
class BERT(nn.Module):
    """Encoder followed by a two-layer head that emits log-probabilities for two classes."""

    def __init__(self, bert):
        """
        Args:
            bert: Encoder module whose output exposes ``last_hidden_state``.
        """
        super().__init__()

        # Encoder (Distil Bert model)
        self.bert = bert

        # Classification head
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(768, 512)   # hidden projection
        self.fc2 = nn.Linear(512, 2)     # output layer, one unit per class
        # Despite the attribute name, this is a log-softmax over the class dimension
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, **kwargs):
        """Run the encoder on ``kwargs`` and classify the hidden state of the first token.

        Returns:
            Tensor with one row per input sequence and two columns (log-probabilities).
        """
        # Hidden state of the first token of every sequence
        first_token = self.bert(**kwargs).last_hidden_state[:, 0]

        # dense layer 1 -> ReLU -> dropout
        hidden = self.dropout(self.relu(self.fc1(first_token)))

        # dense layer 2 -> log-softmax
        return self.softmax(self.fc2(hidden))
    
# Instantiate the classifier and restore the fine-tuned weights
model = BERT(bert).to(device)

modelname = "saved_weights_BERT_description_classifier.pt"
location = "models/"

model_save_name = modelname
path = f"{location}{model_save_name}"
# The checkpoint is deserialised onto the CPU and then copied into the model's parameters
state_dict = torch.load(path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
# Switch to evaluation mode
model.eval()


### Classification Function and Processing Functions
---

Having defined the description classifier, we now construct the appropriate functions for classifying the cleaned texts to descriptions. Instead of directly classifying the whole paragraph, it is first split into sentences using the *SpaCy* package. We store both individual sentences that were classified as descriptions, as well as new paragraphs containing only the descriptive sentences.

In [12]:
def classify_text(span, model, truncation=True, device = 'cpu', threshold=0.6):

    """
    Uses a trained bert classifier to see if a span
    belongs to a species description or not.

    Args:
        span (str): The text to classify.
        model: The classifier; called with the tokenizer output as keyword arguments.
        truncation (bool): Forwarded to the tokenizer.
        device (str): Device the tokenized inputs are moved to.
        threshold (float): The span counts as a description when the
            probability of class 1 exceeds this value. Defaults to 0.6, the
            value that was previously hard-coded.

    Returns:
        bool: True when the span is classified as a description.
    """

    with torch.no_grad():
        # Tokenize input (uses the module-level `tokenizer`)
        inputs = tokenizer(span, return_tensors="pt", truncation=truncation).to(device)
        # Predict class
        outputs = model(**inputs)
        # Normalise to probabilities; softmax gives the same result for logits and log-probabilities
        probabilities = outputs.softmax(1)
        # Compare the probability of class 1 with the threshold
        span_class = (probabilities[:, 1] > threshold).item()

        return span_class


def paragraph_to_descriptions(paragraph_dict, device = 'cpu'):
    """Extracts the descriptive sentences from the paragraphs of each species.

    Each paragraph is split into sentences with the module-level spaCy
    pipeline ``nlp``; every sentence is checked with ``classify_text`` using
    the module-level ``model``.

    Args:
        paragraph_dict (dict): A dictionary where keys are species and values
            are lists of paragraphs.
        device (str): Device forwarded to ``classify_text``.

    Returns:
        Tuple[Dict[str, List[str]], Dict[str, List[str]]]: A tuple of two
        dictionaries, in this order:
        1. species -> descriptions, each one being the accepted sentences of
           one paragraph joined by a space;
        2. species -> the individual sentences that passed the classification check.

    """
    description_paragraph_dict = collections.defaultdict(list)
    description_sentence_dict = collections.defaultdict(list)

    for (species, paragraphs) in tqdm(paragraph_dict.items(), desc="Species", leave=True, position=0):
        for paragraph in tqdm(paragraphs, desc=f"Paragraphs Species: {species}", leave=False, position=0):

            # Only paragraphs of at most 80000 characters are processed
            if len(paragraph) <= 80000:
                # Sentences of this paragraph that pass the classification check
                accepted = [
                    sent.text
                    for sent in nlp(paragraph).sents
                    if classify_text(sent.text, model=model, device = device)
                ]

                # Paragraphs without any accepted sentence leave no trace
                if accepted:
                    description_sentence_dict[species].extend(accepted)
                    description_paragraph_dict[species].append(' '.join(accepted))

    return description_paragraph_dict, description_sentence_dict

# classify sentences and save
# paragraph_to_descriptions returns (paragraph-level dict, sentence-level dict)
# in that order; unpacking them the other way round swaps the two results and
# stores each one under the other's file name.
description_paragraph, description_sentence_dict = paragraph_to_descriptions(species_text_cleaned, device = device)

description_sentence_path = f'{path_to_save}description_sentences.pkl'
description_paragraphs_path = f'{path_to_save}description_paragraphs.pkl'

with open(description_sentence_path, 'wb') as f:
    pickle.dump(description_sentence_dict, f) 

with open(description_paragraphs_path, 'wb') as f:
    pickle.dump(description_paragraph, f)

Paragraphs Species: Solanum lycopersicum:   0%|          | 0/439 [00:00<?, ?it/s]

Species: 100%|██████████| 10/10 [02:41<00:00, 16.15s/it]                                     


In [None]:
import pprint
# Pretty-print the classified descriptions of every species for a manual check
# (nesting limited to depth 4, lines up to 500 characters wide).
pp = pprint.PrettyPrinter(depth=4, width=500)
pp.pprint(description_sentence_dict)

### LLM Prompting
---

We now have the processed and classified text for each species. With this data at hand, we can prompt an LLM to explore if we can find information about species traits. First, we need to define the LLM client, e.g. MistralClient, load/set the traits that we want to explore and then use the sentences from the previous snippet to construct a prompt that we feed to the LLM.

In [14]:
# Set Up the mistral/LLM API credentials
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

# The API key is read from the environment; os.environ[...] raises KeyError when it is not set.
mistral_api_key = os.environ["MISTRAL_API_KEY"]
# Name of the chat model passed to client.chat in the prompting loop.
# NOTE(review): this rebinds `model`, which up to this cell referred to the BERT
# description classifier used by paragraph_to_descriptions; re-running the
# classification cell after this one needs the classifier cell to be re-run first.
model = "mistral-medium-latest"

client = MistralClient(api_key=mistral_api_key)

In [15]:
data_folder = 'Solanum10/'

# Folder that receives all prompt results of this dataset
results_folder = f'{data_folder}Prompt_Results/Mistral/{dataset_name}_dataset'
os.makedirs(results_folder, exist_ok = True)

# Keep a plain-text copy of the trait vocabulary next to the results, one trait per line
with open(f'{results_folder}/traits.txt', 'w') as f:
    f.writelines(f'{tname}: {tvalues}\n' for tname, tvalues in traits_dict.items())

### LLM Prompting: Main Loop
---

Having loaded the traits, we proceed with the core loop that is responsible for the prompting of the LLM. For each species and its sentences, we construct an appropriate prompt (see the *text* variable below) that is sent to the LLM via the client chat function. We record the response of the LLM, making sure that it is in an appropriate form (JSON) before proceeding; otherwise we re-prompt the LLM. We save the results in human-readable form, i.e., plain text format. Then, we can proceed with the aggregation of the information for all the species.

In [None]:
# traits are in traits
traits_names = list(traits_dict.keys())
traits_names_cap = [ tr.capitalize() for tr in traits_dict.keys() ]

# The trait vocabulary shown to the LLM is the same for every species, so it is
# built once: {"Trait a": ['v1', 'v2'], "Trait b": [...]}
pos_traits = '{' + ', '.join(
    '"{}": {}'.format(trait_name.capitalize(), traits_dict[trait_name])
    for trait_name in traits_names
) + '}'

max_attempts = 3          # number of prompts sent per species before giving up
retry_wait_seconds = 5    # pause between two attempts

# Defined up front so the final print works even when no species is processed
content = None

# parse all the species
for idx, species in enumerate(description_sentence_dict):

    print('Cur Species Num: {}/{} Name: {}'.format(idx+1, len(description_sentence_dict), species))

    # create the folder for the species. replace blanks with underscores to avoid potential issues
    species_folder = results_folder + '/{}'.format(species.replace(' ', '_'))
    os.makedirs(species_folder, exist_ok = True)

    # this is the list with all the sentences, we are gonna iterate and combine.
    sentences_ = description_sentence_dict[species]

    # save cleaned sentences in a txt format
    with open(f'{species_folder}/sentences_cleaned.txt', 'w') as f:
        for sent in sentences_:
            f.write(sent+'\n')

    # all the reponses and contents only for the species
    responses_full = []
    contents = []

    cur_paragraph = '\n'.join(sentences_)

    # Build the prompt. The wording is part of the experiment and is kept verbatim.
    # NOTE(review): two of the concatenated pieces join without a space
    # ("descriptions,followed" and "number,either") -- confirm before changing the prompt.
    text = 'We are interested in obtaining botanical trait information about the species {}.\n\n'.format(species)
    text += 'We will provide an input text with botanical descriptions,'\
            'followed by a dictionary where each key \'name\' represents a trait name, '\
            'referring to specific organ or other element of the plant, and is associated to a list '\
            'with all possible trait values for that trait, [\'value_1\', \'value_2\', ..., \'value_n\'].\n\n'

    text += 'Input text:\n'
    text += cur_paragraph +'\n\n'

    text += 'Initial dictionary of traits with all possible values:\n'
    text += pos_traits +'\n\n'

    text += 'Turn each string s in the list of values in the dictionary into a sublist (s,b), where b is a binary number,'\
             'either 0 or 1, indicating whether there is strong evidence for value s in the input text. '
    text+= 'Double check that \'value_i\' is reported referring to trait \'name\' in the text, '\
            'and not to a different trait. Set \'b\' to \'0\' if you are not sure about '\
            'the association. Do not modify the initial trait names and trait values, and do not add other traits than those provided in the dictionary. '\
            'Return the dictionary of traits and sublists of (value, evidence) containing all possible names and (value, evidence) tuples.\n\n'
    text += 'Output first a dictionary in JSON format, followed by a very short textual explanation for each positive response.\n\n'

    cur_path = '{}/results/'.format(species_folder)

    os.makedirs(cur_path, exist_ok = True)

    messages = [ChatMessage(role="user", content = text)]

    # Reset per species: if every attempt fails, the answer of the previous
    # species must not be written into this species' folder.
    chat_response = None
    content = None

    for attempt in range(max_attempts):
        try:
            chat_response = client.chat(
                model=model,
                #response_format={"type": "json_object"},
                messages=messages,
            ).choices[0].message.content

            # Take everything from the first '{' to the last '}' of the answer
            # and accept it only when it parses as JSON.
            candidate = re.search(r'{.*}', chat_response, re.DOTALL).group()
            json.loads(candidate)
            content = candidate
            break
        except Exception as e:
            # Covers API errors, a response without a {...} block
            # (re.search returns None) and invalid JSON.
            print('Some Kind of Error, {}'.format(e))
            if attempt < max_attempts - 1:
                time.sleep(retry_wait_seconds)

    if content is None:
        # NOTE(review): post_processing is not visible here -- confirm that it
        # tolerates a species folder without result files.
        print('No valid JSON response for {} after {} attempts; skipping.'.format(species, max_attempts))
        continue

    with open('{}/mistral_response_full.txt'.format(cur_path), 'w') as f:
        f.write(str(chat_response))
    with open('{}/mistral_response_content_only.txt'.format(cur_path), 'w') as f:
        f.write(content)

    with open('{}/mistral_sent_info_and_content.txt'.format(cur_path), 'w') as f:
        f.write('{}\n\n{}'.format(text, content))

    responses_full.append(str(chat_response))
    contents.append(content)

    with open('{}/responses_all.txt'.format(species_folder), 'w') as f:
        for resp in responses_full:
            f.write(resp + '\n\n')
    with open('{}/contents_all.txt'.format(species_folder), 'w') as f:
        for cont in contents:
            f.write(cont + '\n\n')

# Show the JSON part of the last response (None when the last species was skipped)
print(content)
    

### Post Processing 
---

All the responses of the LLM for each species are saved in their corresponding folders. All that is left is to aggregate the information from the individual species and output the results. We use the functions defined in the *aggregate_traits.py* file, found in the notebook's folder, and specifically the *post_processing* function.

In [None]:
# aggregate_traits.py is expected in the notebook's folder.
from aggregate_traits import post_processing
# Aggregate the per-species LLM responses saved under results_folder.
# NOTE(review): post_processing is defined outside this notebook -- its exact inputs/outputs are not visible here.
post_processing(traits_dict, species_to_query,results_folder)