In [2]:
#NOTE: Create a .env file and put gemini keys in there as `GEMINI_API_KEY=...`
%load_ext dotenv
%dotenv

In [3]:
import os

parent_dir = os.path.dirname(os.getcwd())
data_dir = f'{parent_dir}/data'
results_dir = f'{parent_dir}/gemini_output/markdown_wiki'

## Download data from Hugging Face

In [None]:
from datasets import load_dataset

dataset = load_dataset("wikipedia", "20220301.en", split="train", streaming=True)

### Filter for medical domain

In [5]:
import re
from tqdm import tqdm

medical_keywords = [
    "medicine", "health", "disease", "treatment", "medical", "hospital", "doctor", "nurse",
    "pharmacy", "surgery", "clinical", "therapy", "diagnosis", "patient", "epidemic", "virus",
    "bacteria", "vaccine"
]

def is_medical_article(article):
    title = article["title"].lower()
    return any(re.search(rf"\b{keyword}\b", title) for keyword in medical_keywords)

def split_into_paragraphs(article):
    paragraphs = article["text"].split("\n\n")
    return paragraphs

In [8]:
articles = []
max_paragraphs = 50000
pbar = tqdm(total=max_paragraphs, desc="Processing Paragraphs")

for article in dataset:
    if is_medical_article(article):
        paragraphs = split_into_paragraphs(article)
        articles.extend(paragraphs)
        pbar.update(len(paragraphs))
        if len(articles) >= max_paragraphs:
            break


Processing Paragraphs:  12%|█▏        | 6153/50000 [01:23<21:45, 33.59it/s] '(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 8335e763-66be-40ec-95e4-1b622b019b90)')' thrown while requesting GET https://huggingface.co/datasets/wikipedia/resolve/main/data/20220301.en/train-00000-of-00041.parquet
Retrying in 1s [Retry 1/5].
Processing Paragraphs: 50010it [09:20, 68.05it/s]                            

In [15]:
articles[0]

'Motor neuron diseases or motor neurone diseases (MNDs) are a group of rare neurodegenerative disorders that selectively affect motor neurons, the cells which control voluntary muscles of the body. They include amyotrophic lateral sclerosis (ALS), progressive bulbar palsy (PBP), pseudobulbar palsy, progressive muscular atrophy (PMA), primary lateral sclerosis (PLS), spinal muscular atrophy (SMA) and monomelic amyotrophy (MMA), as well as some rarer variants resembling ALS.'

### Save the data

In [20]:
import pandas as pd

min_words = 10
max_words = 1024

df = pd.DataFrame(columns=['abstract'], data=articles)
df['#words'] = [len(a.split()) for a in articles]
df = df[(df['#words'] >= min_words) & (df['#words'] <= max_words)]
df.drop_duplicates(inplace=True)
df.describe()

Unnamed: 0,#words
count,38430.0
mean,76.719542
std,64.191047
min,10.0
25%,35.0
50%,61.0
75%,99.0
max,1004.0


In [24]:
df.drop(columns=['#words'])
df.to_csv(f'{data_dir}/Wikipedia_articles.tsv', sep='\t', index=False)

### Clean abstracts
* Some abstracts may contain HTML tags wheras others may contain URL links.
* We decided to retain the URL links but remove the HTML tags.
* Section headings such as `AIM`, `OBSERVATION`, `CONCLUSION`, etc present in the PubMed abstracts are removed as most model tend to summarise and thus section headings are not needed.

In [26]:
from bs4 import BeautifulSoup
from typing import List

def clean_html(strings: List[str]) -> List[str]:
    count_html = 0
    count_url = 0
    
    url_regex = r"(https?://[^\s]+)"
    # We remove special patterns which can be misidentified as HTML tags
    patterns_to_exclude = [r'<<.*?>>']

    for i in range(len(strings)):
        for pattern in patterns_to_exclude:
            strings[i] = re.sub(pattern, '', strings[i])
        soup = BeautifulSoup(strings[i], "html.parser")
        if soup.find():
            strings[i] = soup.get_text()
            count_html += 1
        if bool(re.search(url_regex, strings[i])):
            count_url += 1
    
    print(f"Number of abstracts with HTML tags: {count_html}")
    print(f"Number of abstracts with URLs: {count_url}")
    
    return strings

df['abstract'] = clean_html(df['abstract'].tolist())

  soup = BeautifulSoup(strings[i], "html.parser")


Number of abstracts with HTML tags: 10
Number of abstracts with URLs: 44


In [27]:
# Save the downloaded and cleaned data
df.to_csv(f'{data_dir}/Wikipedia_articles_cleaned.tsv', sep='\t', index=False)

## Using Google Gemini API

See the getting started guide for more information:
https://ai.google.dev/gemini-api/docs/get-started/python

In [4]:
import os

import google.generativeai as genai
from google.api_core.retry import Retry

genai.configure(api_key=os.environ['GEMINI_API_KEY'])

# Create the model
# See https://ai.google.dev/api/python/google/generativeai/GenerativeModel
generation_config = {
  "temperature": 1,
  "top_p": 0.95,
  "top_k": 64,
  "max_output_tokens": 8192,
  "response_mime_type": "text/plain",
}
# Safety settings are disabled as input text describing
# patient's mental health experience can contain disturbing
# content which is blocked by Gemini API filters.
safety_settings = [
  {
    "category": "HARM_CATEGORY_HARASSMENT",
    "threshold": "BLOCK_NONE",
  },
  {
    "category": "HARM_CATEGORY_HATE_SPEECH",
    "threshold": "BLOCK_NONE",
  },
  {
    "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "threshold": "BLOCK_NONE",
  },
  {
    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
    "threshold": "BLOCK_NONE",
  },
]

model = genai.GenerativeModel(
    model_name="gemini-1.5-flash",
    safety_settings=safety_settings,
    generation_config=generation_config,
)

def generate(description: str) -> str:
    '''
    Generate output using Gemini-Flash API.
    Response is in markdown format.
    '''
    intro = "The below text contains some biomedical literature which is difficult for a layperson to understand."
    # Below instructions are used by the model to convert the description into a structured format
    instructions = "For the above text, create a simplified English version of the text which can be understood by a native English layperson with no medical background. Put the section heading as English simplified. The output section should have 1 paragraphs corresponding to the input text.\nNext, create an even more simpler version of the text which can be understood by a native English school kid with no medical background. Put the section heading as English super simplified. The output section should have 1 paragraphs corresponding to the input text.\n\nNext, created translated version of the simplified text in the following languages: Mandarin, followed by Spanish, followed by Arabic, followed by Hindi, followed by Bengali, followed by Portuguese, followed by Russian, followed by Japanese, followed by Punjabi\nPut the section heading as Langauge name Simplified. The output section should have 1 paragraphs corresponding to the input text. If some English terms excluding acronyms and numbers can't be translated then transliterate them.\n\nLet the voice in simplified text be same as in the original text so that the person narrating appears consistent. If there any URL links present in the original text then retain them in the simplified text as well."

    response = model.generate_content(f"{instructions}\n\n{description}", request_options={'timeout': 150, 'retry': Retry()})
    return response.text

### Test sample

#### Wikipedia sample article

Motor neuron diseases or motor neurone diseases (MNDs) are a group of rare neurodegenerative disorders that selectively affect motor neurons, the cells which control voluntary muscles of the body. They include amyotrophic lateral sclerosis (ALS), progressive bulbar palsy (PBP), pseudobulbar palsy, progressive muscular atrophy (PMA), primary lateral sclerosis (PLS), spinal muscular atrophy (SMA) and monomelic amyotrophy (MMA), as well as some rarer variants resembling ALS.

In [5]:
sample = "Motor neuron diseases or motor neurone diseases (MNDs) are a group of rare neurodegenerative disorders that selectively affect motor neurons, the cells which control voluntary muscles of the body. They include amyotrophic lateral sclerosis (ALS), progressive bulbar palsy (PBP), pseudobulbar palsy, progressive muscular atrophy (PMA), primary lateral sclerosis (PLS), spinal muscular atrophy (SMA) and monomelic amyotrophy (MMA), as well as some rarer variants resembling ALS."
output = generate(sample)

#### Simplified texts generated

In [6]:
from IPython.display import Markdown, display

display(Markdown(output.replace('**\n', '**<br>')))

## English Simplified

Motor neuron diseases (MNDs) are a group of uncommon diseases that damage the nerve cells that control our muscles. These diseases make it harder to move our bodies because they affect the nerves that tell our muscles what to do. Some examples of MNDs include amyotrophic lateral sclerosis (ALS), progressive bulbar palsy (PBP), pseudobulbar palsy, progressive muscular atrophy (PMA), primary lateral sclerosis (PLS), spinal muscular atrophy (SMA), and monomelic amyotrophy (MMA). There are also some rarer diseases that are similar to ALS. 

## English Super Simplified

Motor neuron diseases (MNDs) are like diseases that hurt the special cells in our bodies that help us move. These diseases make it hard to use our muscles because they damage the cells that tell our muscles what to do. There are lots of different kinds of MNDs, like ALS, PBP, PMA, PLS, SMA, and MMA. There are also some other MNDs that are rare but act like ALS. 

## Mandarin Simplified

运动神经元疾病 (MNDs) 是一组罕见的疾病，会损害控制我们肌肉的神经细胞。这些疾病使我们难以移动身体，因为它们会影响告诉我们肌肉该做什么的神经。一些 MND 的例子包括肌萎缩侧索硬化症 (ALS)、进行性延髓麻痹 (PBP)、假延髓麻痹、进行性肌萎缩 (PMA)、原发性侧索硬化症 (PLS)、脊髓性肌萎缩 (SMA) 和单肢肌萎缩 (MMA)。还有一些更罕见的疾病类似于 ALS。

## Spanish Simplified

Las enfermedades de la motoneurona (EMN) son un grupo de trastornos neurodegenerativos poco comunes que afectan selectivamente a las neuronas motoras, las células que controlan los músculos voluntarios del cuerpo. Incluyen la esclerosis lateral amiotrófica (ELA), la parálisis bulbar progresiva (PBP), la parálisis pseudobulbar, la atrofia muscular progresiva (PMA), la esclerosis lateral primaria (PLS), la atrofia muscular espinal (SMA) y la amiotrofia monomélica (MMA), así como algunas variantes más raras que se asemejan a la ELA.

## Arabic Simplified

أمراض الخلايا العصبية الحركية (MNDs) هي مجموعة من الاضطرابات العصبية التنكسية النادرة التي تؤثر بشكل انتقائي على الخلايا العصبية الحركية، وهي الخلايا التي تتحكم في العضلات الإرادية في الجسم. وتشمل هذه الأمراض التصلب الجانبي الضموري (ALS)، والشلل البُلبي التقدمي (PBP)، والشلل الزائف البُلبي، والضمور العضلي التقدمي (PMA)، والتصلب الجانبي الأولي (PLS)، والضمور العضلي الشوكي (SMA)، والضمور العضلي أحادي العظم (MMA)، بالإضافة إلى بعض الاختلافات النادرة التي تشبه ALS.

## Hindi Simplified

मोटर न्यूरॉन रोग (MNDs) दुर्लभ न्यूरोडिजेनेरेटिव विकारों का एक समूह है जो शरीर की स्वैच्छिक मांसपेशियों को नियंत्रित करने वाली कोशिकाओं, मोटर न्यूरॉन्स को चुनिंदा रूप से प्रभावित करते हैं। इसमें एमियोट्रोफिक लेटरल स्क्लेरोसिस (ALS), प्रोग्रेसिव बल्बर पैरालिसिस (PBP), स्यूडोबल्बर पैरालिसिस, प्रोग्रेसिव मस्कुलर एट्रोफी (PMA), प्राइमरी लेटरल स्क्लेरोसिस (PLS), स्पाइनल मस्कुलर एट्रोफी (SMA) और मोनोमेलिक एमियोट्रोफी (MMA) शामिल हैं, साथ ही ALS जैसी कुछ दुर्लभ रूप भी शामिल हैं।

## Bengali Simplified

মোটর নিউরন রোগ (MNDs) হলো বিরল নিউরোডেজেনারেটিভ ব্যাধিগুলির একটি গ্রুপ যা শরীরের স্বেচ্ছাসেবক পেশী নিয়ন্ত্রণকারী কোষগুলি, মোটর নিউরনগুলিকে নির্বাচনীভাবে প্রভাবিত করে। এর মধ্যে রয়েছে এমিয়োট্রফিক ল্যাটারাল স্ক্লেরোসিস (ALS), প্রোগ্রেসিভ বালবার প্যারালাইসিস (PBP), সিউডোবালবার প্যারালাইসিস, প্রোগ্রেসিভ মাসকুলার অ্যাট্রফি (PMA), প্রাইমারি ল্যাটারাল স্ক্লেরোসিস (PLS), স্পাইনাল মাসকুলার অ্যাট্রফি (SMA) এবং মোনোমেলাইক এমিয়োট্রফি (MMA), পাশাপাশি ALS এর মতো কিছু বিরল রূপও রয়েছে।

## Portuguese Simplified

Doenças do neurônio motor (DMN) são um grupo de doenças neurodegenerativas raras que afetam seletivamente os neurônios motores, as células que controlam os músculos voluntários do corpo. Elas incluem esclerose lateral amiotrófica (ELA), paralisia bulbar progressiva (PBP), paralisia pseudobulbar, atrofia muscular progressiva (PMA), esclerose lateral primária (PLS), atrofia muscular espinhal (SMA) e amiotrofia monomélica (MMA), bem como algumas variantes mais raras que se assemelham à ELA.

## Russian Simplified

Болезни двигательных нейронов (БДН) - это группа редких нейродегенеративных заболеваний, которые избирательно поражают двигательные нейроны, клетки, контролирующие произвольные мышцы тела. Они включают в себя боковой амиотрофический склероз (БАС), прогрессирующий бульбарный паралич (ПБП), псевдобульбарный паралич, прогрессирующую мышечную атрофию (ПМА), первичный латеральный склероз (ПЛС), спинальную мышечную атрофию (СМА) и мономелическую амиотрофию (ММА), а также некоторые более редкие варианты, напоминающие БАС.

## Japanese Simplified

運動ニューロン疾患（MND）は、身体の随意筋を制御する細胞である運動ニューロンを選択的に侵す、まれな神経変性疾患のグループです。これには、筋萎縮性側索硬化症（ALS）、進行性球麻痺（PBP）、偽球麻痺、進行性筋萎縮症（PMA）、原発性側索硬化症（PLS）、脊髄性筋萎縮症（SMA）、および単肢性筋萎縮症（MMA）などがあり、ALSに似たさらにまれな変異もあります。

## Punjabi Simplified

ਮੋਟਰ ਨਿਊਰੋਨ ਬਿਮਾਰੀਆਂ (MNDs) ਦੁਰਲੱਭ ਨਿਊਰੋਡੀਜਨਰੇਟਿਵ ਵਿਕਾਰਾਂ ਦਾ ਇੱਕ ਸਮੂਹ ਹਨ ਜੋ ਸਰੀਰ ਦੀਆਂ ਸਵੈਇੱਛਤ ਮਾਸਪੇਸ਼ੀਆਂ ਨੂੰ ਕੰਟਰੋਲ ਕਰਨ ਵਾਲੀਆਂ ਸੈੱਲਾਂ, ਮੋਟਰ ਨਿਊਰੋਨਾਂ ਨੂੰ ਚੁਣੌਤੀ ਦੇਣ ਵਾਲੇ ਪ੍ਰਭਾਵਿਤ ਕਰਦੇ ਹਨ। ਇਨ੍ਹਾਂ ਵਿੱਚ ਐਮੀਓਟ੍ਰੋਫਿਕ ਲੈਟਰਲ ਸਕਲੇਰੋਸਿਸ (ALS), ਪ੍ਰੋਗਰੈਸਿਵ ਬੁਲਬਰ ਪੈਰਾਲਿਸਿਸ (PBP), ਸਿਊਡੋਬੁਲਬਰ ਪੈਰਾਲਿਸਿਸ, ਪ੍ਰੋਗਰੈਸਿਵ ਮਸਕੂਲਰ ਐਟ੍ਰੋਫੀ (PMA), ਪ੍ਰਾਈਮਰੀ ਲੈਟਰਲ ਸਕਲੇਰੋਸਿਸ (PLS), ਸਪਾਈਨਲ ਮਸਕੂਲਰ ਐਟ੍ਰੋਫੀ (SMA) ਅਤੇ ਮੋਨੋਮੈਲਿਕ ਐਮੀਓਟ੍ਰੋਫੀ (MMA) ਸ਼ਾਮਲ ਹਨ, ਨਾਲ ਹੀ ALS ਵਰਗੀਆਂ ਕੁਝ ਹੋਰ ਦੁਰਲੱਭ ਕਿਸਮਾਂ ਵੀ ਸ਼ਾਮਲ ਹਨ। 


### Loading Wikipedia dataset

* Load previously downloaded data as a pandas dataframe

In [5]:
import pandas as pd

df = pd.read_csv(f'{data_dir}/Wikipedia_articles_cleaned.tsv', sep='\t')

### Generate in batches & parallel

In [13]:
import importlib
import os
import re
from time import sleep
from typing import List, Tuple

import numpy as np
from google.api_core.exceptions import ResourceExhausted
from joblib import Parallel, delayed
from requests.exceptions import RequestException
from tqdm import tqdm

progress_bar = None
data_dir = 'data'
results_dir = 'results/markdown_wiki'
RETRIES = 200 # Occasionally, the Gemini API can have glitches
failures = list() # All failed inputs get stored here
SAVE = True
DEBUG = True # Displays errors

REQUEST_TIMEOUT = 180 # If it takes longer than 3 minutes then timeout
GEMINI_API_LIMIT = 1500
GEMINI_MAX_OUT = 8196 # Maximum number of tokens that can be returned
TOKEN_RETURN_RATIO = 16 # Approximate output tokens returned for input text

if not os.path.exists(results_dir):
    os.makedirs(results_dir)


def get_model(api_key: str) -> genai.GenerativeModel:
    '''Returns a model configured with the API key to be used for parallel requests.'''
    module_name = 'google.generativeai'
    _genai = importlib.import_module(module_name)
    _genai.configure(api_key=api_key)
    return _genai.GenerativeModel(
        model_name="gemini-1.5-flash",
        safety_settings=safety_settings,
        generation_config=generation_config,
    )


def save_batch(text: str) -> None:
    matches = list(re.finditer(r"#*\s*Text ID (\d+)", text))
    # Split the text based on Text ids
    for i in range(len(matches)):
        start = matches[i].start()
        if i < len(matches) - 1:
            end = matches[i + 1].start()
        else:
            end = len(text)
        record_number = int(matches[i].group(1))

        with open(f'{results_dir}/{record_number}.md', 'w') as fp:
            fp.write(text[start:end].strip())


def generate_and_save(batch: List[Tuple[int, str]]) -> None:
    '''
    Generates and stores simplified text for the
    given batch using the Google Gemini Flash API.
    Response can be in markdown format or sometimes as plain text.
    '''
    # Below instructions are used by the model to convert the description into a structured format
    intro = f"{len(batch)} biomedical literature texts are provided below which are difficult for a layperson to understand."
    # Below instructions are used by the model to convert the description into a structured format
    instructions = "For each of the above texts, create a simplified English version of the text which can be understood by a native English layperson with no medical background. Put the section heading as English simplified. The output section should have 1 paragraphs corresponding to the input text.\nNext, create an even more simpler version of the text which can be understood by a native English school kid with no medical background. Put the section heading as English super simplified. The output section should have 1 paragraphs corresponding to the input text.\n\nNext, created translated version of the simplified text in the following languages: Mandarin, followed by Spanish, followed by Arabic, followed by Hindi, followed by Bengali, followed by Portuguese, followed by Russian, followed by Japanese, followed by Punjabi\nPut the section heading as Langauge name Simplified. The output section should have 1 paragraphs corresponding to the input text. If some English terms excluding acronyms and numbers can't be translated then transliterate them. Put the heading for each text as ## Text ID X, where X is the id of the text.\n\nLet the voice in simplified text be same as in the original text so that the person narrating appears consistent. If there any URL links present in the original text then retain them in the simplified text as well."
    batch_description = '\n\n'.join([f'Text ID {i}: {desc}'.replace('\n', ' ') for i, desc in batch])
    
    try:
        response = model.generate_content(f"{intro}\n\n{batch_description}\n\n{instructions}", request_options={'timeout': REQUEST_TIMEOUT})

        if SAVE:
            save_batch(response.text)

    except (RequestException, ValueError):
        # For very long output the request can timeout
        # For output containing unsafe text, ValueError is raised
        if DEBUG:
            print(f'Skipped the following indices for producing unsafe outputs:', [i for i, desc in batch])

    except Exception as e:
        global RETRIES
        if RETRIES <= 0:
            print(f"Error for batch: {e}")
            failures.append(batch_description)
        else:
            RETRIES -= 1
            if DEBUG:
                print('Retries left:', RETRIES, f'| {type(e).__name__}')
            sleep(10+RETRIES%10)
            return generate_and_save(batch)

    progress_bar.update(1)


def batch_generate(descriptions: List[str], start_at: int = 0, n_jobs: int = 1) -> None:
    '''
    Generates and stores simplified medical text in batches and in parallel.
    '''
    tasks = list()

    # Gemini has a max output limit of 8196, based on which we dynamically select the size of every batch.
    i = start_at
    while i < len(descriptions) and len(tasks)<GEMINI_API_LIMIT:
        batch = []
        num_words = 0
        for j in range(i, len(descriptions)):
            num_words += len(descriptions[j].split())
            if j > i and num_words * TOKEN_RETURN_RATIO >= GEMINI_MAX_OUT:
                break
            batch.append((j, descriptions[j]))
        tasks.append((batch,))
        i += len(batch)

    global progress_bar
    progress_bar = tqdm(total=len(tasks))

    Parallel(n_jobs=n_jobs, prefer='threads')(delayed(generate_and_save)(*task) for task in tasks)

In [None]:
for i in range(3, 6):
    model = get_model(os.environ[f'GEMINI_API_KEY_{i}'])
    batch_generate(
        descriptions = df['abstract'].tolist(), # Get all records
        start_at = max([int(n.split('.')[0])+1 for n in os.listdir(results_dir) if '.md' in n], default=0), # Skip if previously mined
        n_jobs = 8 # Adjust based on hardware and Gemini API per minute token rate limit
    )
    RETRIES = 200 # Reset the retries
    sleep(1800)



100%|██████████| 519/519 [6:11:28<00:00, 42.94s/it]


### Reproduce missing records
Some records got missed out due to glitches

In [6]:
import os
import markdown
import pandas as pd
from bs4 import BeautifulSoup
from tqdm import tqdm


def markdown_to_text(markdown_string):
    html = markdown.markdown(markdown_string)
    text = ''.join(BeautifulSoup(html, "html.parser").findAll(string=True))
    return text


def get_generated_record(index: int) -> str:
    file_path = f'{results_dir}/{index}.md'
    with open(file_path, 'r') as fp:
        return fp.read()

results_dir = 'results/markdown_wiki'
data_dir = 'data'
df = pd.read_csv(f'{data_dir}/Wikipedia_articles_cleaned.tsv', sep='\t')
abstracts = df['abstract'].tolist()
gen_record_indices = [int(n.split('.')[0]) for n in os.listdir(results_dir) if '.md' in n]
languages = [
    'english simplified',
    'english super simplified',
    'mandarin simplified',
    'spanish simplified',
    'arabic simplified',
    'hindi simplified',
    'bengali simplified',
    'portuguese simplified',
    'russian simplified',
    'japanese simplified',
    'punjabi simplified'
]

In [7]:
df['#words'] = [len(a.strip().split(' ')) for a in df['abstract'].tolist()]
df.describe()

Unnamed: 0,#words
count,38430.0
mean,76.291543
std,63.973502
min,1.0
25%,34.0
50%,61.0
75%,99.0
max,1007.0


In [8]:
processed_records = list()
failures = list() # Unsafe outputs

for index in tqdm(gen_record_indices):
    md = get_generated_record(index)
    sections = markdown_to_text(md).split('\n')
    json = {'original': abstracts[index]}
    i = 1
    while i < len(sections)-1:
        lang = sections[i].lower().strip()
        if not lang in languages:
            break
        else:
            # Handle both single para and multi-para
            next_i = i+2
            while next_i < len(sections):
                section = sections[next_i].lower().strip()
                if section in languages:
                    break
                else:
                    next_i += 1
            # Sometimes the generated text contains duplicate sections
            if lang not in json:
                section_text = ' '.join(section.strip() for section in sections[i+1:next_i])
                json[lang] = section_text
            i = next_i

    json['languages present'] = len(json) - 2
    if json['languages present'] > 0:
        processed_records.append(json)
    else:
        failures.append(index)

100%|██████████| 32161/32161 [00:01<00:00, 22579.60it/s]


In [9]:
missing_records = [index for index in range(len(abstracts)) if index not in gen_record_indices or index in failures]
print(f'Number of records missing: {len(missing_records)}')

Number of records missing: 6343


In [20]:
import importlib
import os
import re
from time import sleep
from typing import List, Tuple

import numpy as np
from google.api_core.exceptions import ResourceExhausted
from joblib import Parallel, delayed
from requests.exceptions import RequestException
from tqdm import tqdm

progress_bar = None
data_dir = 'data'
results_dir = 'results/markdown_wiki'
RETRIES = 200 # Occasionally, the Gemini API can have glitches
failures = list() # All failed inputs get stored here
SAVE = True
DEBUG = True # Displays errors

REQUEST_TIMEOUT = 180 # If it takes longer than 3 minutes then timeout
GEMINI_API_LIMIT = 1500
GEMINI_MAX_OUT = 8196 # Maximum number of tokens that can be returned
TOKEN_RETURN_RATIO = 16 # Approximate output tokens returned for input text

if not os.path.exists(results_dir):
    os.makedirs(results_dir)


def get_model(api_key: str) -> genai.GenerativeModel:
    '''Returns a model configured with the API key to be used for parallel requests.'''
    module_name = 'google.generativeai'
    _genai = importlib.import_module(module_name)
    _genai.configure(api_key=api_key)
    return _genai.GenerativeModel(
        model_name="gemini-1.5-flash",
        safety_settings=safety_settings,
        generation_config=generation_config,
    )


def save_batch(text: str) -> None:
    matches = list(re.finditer(r"#*\s*Text ID (\d+)", text))
    # Split the text based on Text ids
    for i in range(len(matches)):
        start = matches[i].start()
        if i < len(matches) - 1:
            end = matches[i + 1].start()
        else:
            end = len(text)
        record_number = int(matches[i].group(1))

        with open(f'{results_dir}/{record_number}.md', 'w') as fp:
            fp.write(text[start:end].strip())


def generate_and_save(batch: List[Tuple[int, str]]) -> None:
    '''
    Generates and stores simplified text for the
    given batch using the Google Gemini Flash API.
    Response can be in markdown format or sometimes as plain text.
    '''
    # Below instructions are used by the model to convert the description into a structured format
    intro = f"{len(batch)} biomedical literature texts are provided below which are difficult for a layperson to understand."
    # Below instructions are used by the model to convert the description into a structured format
    instructions = "For each of the above texts, create a simplified English version of the text which can be understood by a native English layperson with no medical background. Put the section heading as English simplified. The output section should have 1 paragraphs corresponding to the input text.\nNext, create an even more simpler version of the text which can be understood by a native English school kid with no medical background. Put the section heading as English super simplified. The output section should have 1 paragraphs corresponding to the input text.\n\nNext, created translated version of the simplified text in the following languages: Mandarin, followed by Spanish, followed by Arabic, followed by Hindi, followed by Bengali, followed by Portuguese, followed by Russian, followed by Japanese, followed by Punjabi\nPut the section heading as Langauge name Simplified. The output section should have 1 paragraphs corresponding to the input text. If some English terms excluding acronyms and numbers can't be translated then transliterate them. Put the heading for each text as ## Text ID X, where X is the id of the text.\n\nLet the voice in simplified text be same as in the original text so that the person narrating appears consistent. If there any URL links present in the original text then retain them in the simplified text as well."
    batch_description = '\n\n'.join([f'Text ID {i}: {desc}'.replace('\n', ' ') for i, desc in batch])
    
    try:
        response = model.generate_content(f"{intro}\n\n{batch_description}\n\n{instructions}", request_options={'timeout': REQUEST_TIMEOUT})

        if SAVE:
            save_batch(response.text)

    except (RequestException, ValueError):
        # For very long output the request can timeout
        # For output containing unsafe text, ValueError is raised
        if DEBUG:
            print(f'Skipped the following indices for producing unsafe outputs:', [i for i, desc in batch])

    except Exception as e:
        global RETRIES
        if RETRIES <= 0:
            print(f"Error for batch: {e}")
            failures.append(batch_description)
        else:
            RETRIES -= 1
            if DEBUG:
                print('Retries left:', RETRIES, f'| {type(e).__name__}')
            sleep(10+RETRIES%10)
            return generate_and_save(batch)

    progress_bar.update(1)


def batch_generate(descriptions: List[str], start_at: int = 0, n_jobs: int = 1) -> None:
    '''
    Generates and stores simplified medical text in batches and in parallel.
    '''
    tasks = list()

    # Gemini has a max output limit of 8196, based on which we dynamically select the size of every batch.
    i = start_at
    while i < len(descriptions) and len(tasks)<GEMINI_API_LIMIT:
        batch = []
        num_words = 0
        for j in range(i, len(descriptions)):
            if j in missing_records:
                num_words += len(descriptions[j].split())
                if len(batch) > 1 and num_words * TOKEN_RETURN_RATIO >= GEMINI_MAX_OUT:
                    break
                batch.append((j, descriptions[j]))
        if not batch: break
        tasks.append((batch,))
        i = j
    global progress_bar
    progress_bar = tqdm(total=len(tasks))

    Parallel(n_jobs=n_jobs, prefer='threads')(delayed(generate_and_save)(*task) for task in tasks)

  0%|          | 0/3 [00:39<?, ?it/s]


In [21]:
model = get_model(os.environ[f'GEMINI_API_KEY'])
batch_generate(
    descriptions = df['abstract'].tolist(), # Get all records
    start_at = 0, # Skip if previously mined
    n_jobs = 8 # Adjust based on hardware and Gemini API per minute token rate limit
)

100%|██████████| 3/3 [01:01<00:00, 18.58s/it]