In [1]:
import json
import math
import os
import random

import pandas as pd
from tqdm import tqdm

In [44]:
def _write_lines(texts, path):
    """Write each item of ``texts`` on its own line of a UTF-8 text file.

    Unlike ``Series.to_csv``, no CSV quoting is applied, so a value that
    contains a comma is written verbatim instead of being wrapped in ``"``.
    Missing values become empty lines and embedded line breaks are collapsed
    to spaces, so line ``i`` of every file always belongs to example ``i``.
    """
    with open(path, 'w', encoding='utf-8') as fh:
        for text in texts:
            value = '' if pd.isna(text) else str(text)
            fh.write(' '.join(value.splitlines()) + '\n')


def prep_ft_input(df,
                  output_folder,
                  prop=None,
                  sample_size=5000,
                  seed=7):
    """Sample QA pairs, split them into train/val/test, and write the split files.

    For each split, questions go to ``<split>.source`` and answers to
    ``<split>.target`` inside ``output_folder`` (created if missing), one
    example per line, so line ``i`` of a ``.source`` file pairs with line
    ``i`` of the matching ``.target`` file.

    Args:
        df: DataFrame with ``'question'`` and ``'answer'`` columns.
        output_folder: Directory to write into; a trailing slash is optional.
        prop: Mapping with ``'train'``, ``'val'`` and ``'test'`` fractions
            summing to 1. Defaults to ``{'train': 0.8, 'val': 0.1, 'test': 0.1}``.
            The test split receives every row left after train and val.
        sample_size: Rows to draw before splitting, capped at ``len(df)``;
            ``None`` uses every row.
        seed: ``random_state`` passed to ``DataFrame.sample`` for a
            reproducible split.

    Returns:
        ``(train, val, test)`` DataFrames keeping the index labels of ``df``.

    Raises:
        AssertionError: if the proportions do not add up to 1.
    """
    if prop is None:
        prop = {'train': 0.8, 'val': 0.1, 'test': 0.1}

    # Tolerant comparison: e.g. 0.7 + 0.2 + 0.1 == 0.9999999999999999 in
    # floating point, which an exact `== 1` check would wrongly reject.
    assert math.isclose(sum(prop.values()), 1.0), 'The proportions do not add up to 1.'

    # Sampling more rows than exist raises in pandas, so cap at len(df).
    n_rows = len(df) if sample_size is None else min(sample_size, len(df))
    sampled = df.sample(n=n_rows, random_state=seed)

    n_train = int(prop['train'] * len(sampled))
    n_val = int(prop['val'] * len(sampled))

    train = sampled.iloc[:n_train]
    val = sampled.iloc[n_train:n_train + n_val]
    test = sampled.iloc[n_train + n_val:]

    print('Train dataset size:', len(train))
    print('Validation dataset size:', len(val))
    print('Test dataset size:', len(test))

    os.makedirs(output_folder, exist_ok=True)
    for split_name, split in (('train', train), ('val', val), ('test', test)):
        _write_lines(split['question'], os.path.join(output_folder, split_name + '.source'))
        _write_lines(split['answer'], os.path.join(output_folder, split_name + '.target'))

    return train, val, test

In [37]:
# Load the biomedical SQuAD-format training corpus; each row has a
# 'paragraphs' column that is parsed in the next cell.
corpus_path = 'segmented_msmarco_v2/squad.biomedical.train.json'
qa_corpora = pd.read_json(corpus_path, orient='split')
print(qa_corpora.shape[0])
qa_corpora.head()

22134


Unnamed: 0,paragraphs
0,[{'context': 'Chronic Obstructive Pulmonary Di...
1,[{'context': 'Folic acid is the synthetic form...
2,"[{'context': 'So which minerals do you need, a..."
3,[{'context': 'A triglyceride is an ester deriv...
4,[{'context': 'What causes tooth root resorptio...


In [39]:
# The 'paragraphs' cells were already decoded by pd.read_json into lists of
# SQuAD-style dicts, so no string parsing is needed here. Only the first
# question of the first paragraph (and its first answer) is kept, giving
# exactly one (question, answer) pair per corpus row.
# Iterating the column directly (instead of iterrows) is faster and lets
# tqdm show a real progress bar with a known total.
qa_data = [
    {
        'question': paragraphs[0]['qas'][0]['question'],
        'answer': paragraphs[0]['qas'][0]['answers'][0]['text'],
    }
    for paragraphs in tqdm(qa_corpora['paragraphs'])
]

22134it [00:02, 10604.12it/s]


In [41]:
# One row per QA pair. Normalise the text so every value stays on a single
# line of the .source/.target files written by prep_ft_input: drop double
# quotes and turn carriage returns, newlines and tabs into spaces (deleting
# a tab outright would glue the neighbouring words together).
df = pd.DataFrame(qa_data).replace(
    {'"': '', '\r': ' ', '\n': ' ', '\t': ' '},
    regex=True,
)

print(len(df))
df.head()

22134


Unnamed: 0,question,answer
0,what disability is copd,Chronic Obstructive Pulmonary Disease and Soci...
1,what does folic acid do,crucial for proper brain function and plays an...
2,the importance of minerals in diet,Minerals are incredibly important for health a...
3,triglycerides what are they,A triglyceride is an ester derived from glycer...
4,what causes resorption of teeth,"trauma, periodontitis, orthodontic treatment, ..."


In [45]:
# Draw 3,000 QA pairs and write the train/val/test .source/.target files.
train, val, test = prep_ft_input(
    df,
    sample_size=3000,
    output_folder='input_rag_ft/',
)

Train dataset size: 2400
Validation dataset size: 300
Test dataset size: 300


In [47]:
# Copy every generated split file from `inp` to `out`, deleting any double
# quotes. Both sides are opened as UTF-8 explicitly: the platform default
# encoding (e.g. cp1252 on Windows) can fail to decode, or silently mangle,
# the non-ASCII characters present in some answers.
inp = 'input_rag_ft/'
out = 'finetune_qa_pairs_processed/'
os.makedirs(out, exist_ok=True)  # open(..., 'w') does not create folders

for name in sorted(os.listdir(inp)):
    src_path = os.path.join(inp, name)
    if not os.path.isfile(src_path):  # open() cannot read sub-directories
        continue
    with open(src_path, 'r', encoding='utf-8') as fi, \
            open(os.path.join(out, name), 'w', encoding='utf-8') as fo:
        for line in fi:
            fo.write(line.replace('"', ''))

## Generating samples for qualitative evals

In [54]:
# First 20 training QA pairs as a list of plain dicts, for manual inspection.
sample_train = train.loc[:, ['question', 'answer']].iloc[:20].to_dict(orient='records')
sample_train

[{'question': 'what is ivig used to treat?',
  'answer': 'autoimmune diseases, immune deficiencies, and certain kinds of infections.'},
 {'question': 'process of recovery after a leg fracture',
  'answer': 'Give your fracture time to heal. Your body needs time to mend itself and create new bone; a fracture generally takes between 6 and 12 weeks to heal substantially. When healing, your bone will go through three stages: Inflammation: This process lasts for the first few days after a fracture.'},
 {'question': 'causes of left side chest pain',
  'answer': 'cardiovascular conditions like angina, a heart attack, pericarditis, pulmonary embolism or aortic dissection.'},
 {'question': 'what is fistula',
  'answer': 'an abnormal connection between two parts inside of the body.'},
 {'question': 'what is ptsd mean',
  'answer': "a mental health condition that's triggered by a terrifying event â€” either experiencing it or witnessing it."},
 {'question': 'causes of defatting skin',
  'answer': 

In [55]:
# First 20 held-out test QA pairs as a list of plain dicts, for manual inspection.
sample_test = test.loc[:, ['question', 'answer']].iloc[:20].to_dict(orient='records')
sample_test

[{'question': 'maximum dose of abilify for children',
  'answer': '2 mg once daily'},
 {'question': 'what is translocation mutation',
  'answer': 'Chromosomal translocation, that is a chromosomal segment is moved from one position to another, either within the same chromosome or to another chromosome.'},
 {'question': 'what are some bad things that can happen when having a genetic squencing done',
  'answer': "sequencing the genome doesn't immediately lay open the genetic secrets of an entire species. Even with a rough draft of the human genome sequence in hand, much work remains to be done."},
 {'question': 'idiopathic meaning g', 'answer': 'the cause is not known'},
 {'question': 'what is the pituitary gland',
  'answer': 'a pea-sized gland at the base of the brain, produces a number of hormones.'},
 {'question': 'how do lipids function',
  'answer': 'functions of lipids include storing energy, signaling, and acting as structural components of cell membranes.'},
 {'question': 'which 

In [63]:
# "External" examples should be ones the fine-tuned model never saw.
# `df.head(20)` does not guarantee that: train/val/test are a random sample
# of `df`, so any of its first rows may have been written to a split file.
# Draw only from rows whose question appears in none of the splits.
seen_questions = pd.concat([train, val, test])['question']
unseen = df[~df['question'].isin(seen_questions)]
sample_external = unseen[['question', 'answer']].head(20).to_dict('records')
sample_external

[{'question': 'what disability is copd',
  'answer': 'Chronic Obstructive Pulmonary Disease and Social Security Disability COPD, or chronic obstructive pulmonary disease is a series of lung diseases that damages your lungs, blocking airflow and affecting your ability to breathe.'},
 {'question': 'what does folic acid do',
  'answer': 'crucial for proper brain function and plays an important role in mental and emotional health.'},
 {'question': 'the importance of minerals in diet',
  'answer': 'Minerals are incredibly important for health and to prevent chronic disease.'},
 {'question': 'triglycerides what are they',
  'answer': 'A triglyceride is an ester derived from glycerol and three to four fatty acids.'},
 {'question': 'what causes resorption of teeth',
  'answer': 'trauma, periodontitis, orthodontic treatment, internal bleaching, cysts, tumors or stimuli from a necrotic dental pulp may cause tooth root resorption.'},
 {'question': 'is the mitochondria a organelle or a function',


## Checking for common questions between the examples above

In [61]:
# Questions shared by the displayed train samples and external samples.
# Compare the sample lists themselves so this stays in sync with the cells
# above (expected: an empty set).
{r['question'] for r in sample_train} & {r['question'] for r in sample_external}

set()

In [65]:
# Questions shared by the displayed test samples and external samples.
# Compare the sample lists themselves so this stays in sync with the cells
# above (expected: an empty set).
{r['question'] for r in sample_test} & {r['question'] for r in sample_external}

set()