In [81]:
from datasets import load_dataset
# Expert-labelled PubMedQA subset (question, contexts, final_decision per record).
dataset = load_dataset(path="pubmed_qa", name="pqa_labeled")

  0%|          | 0/1 [00:00<?, ?it/s]

In [82]:
dataset["train"][0]  # inspect the schema of the first training record

{'pubid': 21645374,
 'question': 'Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?',
 'context': {'contexts': ['Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.',
   'The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), ce

In [83]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Tokenizer for the fine-tuned BioBERT checkpoint; used below to measure input length.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="blizrys/biobert-v1.1-finetuned-pubmedqa"
)


In [84]:
# Use a pipeline as a high-level helper
from transformers import pipeline

# Reuse the tokenizer loaded above so that `truncate_text_to_fit` and the
# pipeline count tokens with the very same object (and it is not loaded twice).
pipe = pipeline(
    "text-classification",
    model="blizrys/biobert-v1.1-finetuned-pubmedqa",
    tokenizer=tokenizer,
)

In [85]:
MAX_TOKENS = 512  # BERT's maximum sequence length, including special tokens


def truncate_text_to_fit(text, max_tokens=MAX_TOKENS, tok=None):
    """Drop trailing words from ``text`` until it fits the model's input window.

    The token budget is ``max_tokens - 2``, leaving room for the two special
    tokens BERT wraps around a single sequence ([CLS] ... [SEP]).

    Args:
        text: Input string.
        max_tokens: Model's maximum sequence length; must be >= 2.
        tok: Object with a ``tokenize(str) -> list`` method. Defaults to the
            notebook-level ``tokenizer`` from an earlier cell.

    Returns:
        ``text`` unchanged if it already fits. Otherwise, the longest prefix of
        its whitespace-separated words (re-joined with single spaces) that fits.
        This is ``""`` if not even the first word fits.

    Raises:
        ValueError: if ``max_tokens < 2``. With no room for content, the old
            word-dropping loop never terminated.
    """
    if max_tokens < 2:
        raise ValueError(f"max_tokens must be >= 2, got {max_tokens}")
    if tok is None:
        tok = tokenizer  # notebook global, looked up at call time
    budget = max_tokens - 2

    def fits(candidate):
        return len(tok.tokenize(candidate)) <= budget

    if fits(text):
        return text

    words = text.split()
    # Binary search for the largest word count k (< len(words)) whose prefix fits.
    # Needs O(log n) tokenize calls instead of one per dropped word. Assumes the
    # token count never decreases as words are added (expected for BERT's
    # whitespace-then-WordPiece tokenization; re-check if the tokenizer changes).
    lo, hi = 0, len(words) - 1  # a 0-word prefix ("") always fits since budget >= 0
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if fits(" ".join(words[:mid])):
            lo = mid
        else:
            hi = mid - 1
    return " ".join(words[:lo])


In [86]:
def create_input_text(sample):
    """Build the flat model input string for one PubMedQA record.

    Args:
        sample: A record with a ``'question'`` string and a
            ``sample['context']['contexts']`` list of paragraph strings.

    Returns:
        ``{'input_text': "question: <question> context: <paragraphs>"}``. The
        dict form lets ``Dataset.map`` add it as a new column.

    NOTE(review): confirm this prompt format matches how the fine-tuned model
    was trained; it is not visible from this notebook.
    """
    # Join paragraphs with a space: joining with "" fused the last word of one
    # paragraph onto the first word of the next.
    paragraphs = " ".join(sample['context']['contexts'])
    return {'input_text': "question: " + sample['question'] + " context: " + paragraphs}

# Add an 'input_text' column built from each record's question and contexts.
modified_dataset = dataset['train'].map(function=create_input_text)

  0%|          | 0/1000 [00:00<?, ?ex/s]

In [87]:
# Classify every record (after fitting it into the model window). pubids and
# results stay aligned index-for-index, in dataset order.
pubids = []
results = []
for record in modified_dataset:
    pubids.append(record['pubid'])
    model_input = truncate_text_to_fit(record['input_text'])
    results.append(pipe(model_input))

Token indices sequence length is longer than the specified maximum sequence length for this model (513 > 512). Running this sequence through the model will result in indexing errors


In [88]:
# Translate the model's raw class ids into PubMedQA answer strings.
# NOTE(review): this id -> answer assignment is not derived anywhere in the
# notebook; confirm it against the model config's id2label. 'maybe' is never
# predicted in the evaluation output.
label_mapping = {
    'LABEL_2': 'yes',
    'LABEL_1': 'no',
    'LABEL_0': 'maybe'
}

converted_results = []
for prediction in results:
    top_entry = prediction[0]  # first entry of the per-input result list
    converted_results.append(label_mapping[top_entry['label']])

In [89]:
# Ground-truth answers keyed by pubid, restricted to records we predicted on.
# A set gives O(1) membership; the previous `in pubids` list scan was O(n) per record.
predicted_ids = set(pubids)
actual = {
    record['pubid']: record['final_decision']
    for record in dataset['train']
    if record['pubid'] in predicted_ids
}

In [90]:
# Pair each pubid with its predicted answer. zip() would silently truncate on a
# length mismatch, and dict() would silently collapse duplicate pubids, so
# check both explicitly.
assert len(pubids) == len(converted_results), "pubids and predictions differ in length"
assert len(set(pubids)) == len(pubids), "duplicate pubids would be collapsed"
predictions = dict(zip(pubids, converted_results))

In [91]:
from sklearn.metrics import accuracy_score, classification_report


# Both dicts must cover exactly the same pubids before they can be aligned.
# dict key views compare with set semantics.
if predictions.keys() != actual.keys():
    raise ValueError("Keys in predictions and ground truth do not match!")

In [92]:
# Align labels and predictions on one shared pubid order. A missing key now
# raises KeyError instead of silently mis-pairing two separately sorted lists.
eval_ids = sorted(actual)
y_true = [actual[pid] for pid in eval_ids]
y_pred = [predictions[pid] for pid in eval_ids]

# Compute accuracy
acc = accuracy_score(y_true, y_pred)
print(f"Accuracy: {acc}")

# Compute detailed classification report. When a class is never predicted
# (here 'maybe'), its precision is undefined; zero_division=0 reports it as 0
# explicitly instead of emitting UndefinedMetricWarning.
report = classification_report(y_true, y_pred, zero_division=0)
print(report)

Accuracy: 0.732
              precision    recall  f1-score   support

       maybe       0.00      0.00      0.00       110
          no       0.64      0.75      0.69       338
         yes       0.79      0.86      0.83       552

    accuracy                           0.73      1000
   macro avg       0.48      0.54      0.51      1000
weighted avg       0.65      0.73      0.69      1000



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
