# Install/import libraries

In [1]:
#!pip install transformers[sentencepiece] datasets sacrebleu rouge_score py7zr -q
#!pip install transformers
#!pip install bert_score
#!pip install textstat

In [17]:
from transformers import pipeline, set_seed
import matplotlib.pyplot as plt
import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BartTokenizerFast, BartForConditionalGeneration
import nltk
from nltk.tokenize import sent_tokenize
from tqdm import tqdm
import torch
from datasets import load_dataset, load_metric
import textstat
from transformers import Seq2SeqTrainingArguments, Trainer, Seq2SeqTrainer
from transformers import DataCollatorForSeq2Seq
import os

# Load data

In [3]:
train_key='train.json'
test_key='test.json'
val_key='val.json'

In [4]:
data_location = 'data/'

In [5]:
# Train
train_data=pd.read_json(f'{data_location}{train_key}')
print("Train Shape: ", train_data.shape)
train_data.head(1)

Train Shape:  (4346, 8)


Unnamed: 0,id,year,title,sections,headings,abstract,summary,keywords
0,elife-35500-v1,2018,National and regional seasonal dynamics of all...,[[It is well-established that death rates vary...,"[Introduction, Results, Discussion, Materials ...","[In temperate climates , winter deaths exceed ...","[In the USA , more deaths happen in the winter...",[epidemiology and global health]


In [6]:
# Val
val_data=pd.read_json(f'{data_location}{val_key}')
print("Val Shape: ", val_data.shape)
val_data.head(1)

Val Shape:  (241, 8)


Unnamed: 0,id,year,title,sections,headings,abstract,summary,keywords
0,elife-15477-v3,2016,Increasing Notch signaling antagonizes PRC2-me...,"[[Cell-fate decisions are controlled , on the ...","[Introduction, Results, Discussion, Materials ...",[Cell-fate reprograming is at the heart of dev...,[The DNA in genes encodes the basic informatio...,[developmental biology]


In [7]:
# Test
test_data=pd.read_json(f'{data_location}{test_key}')
print("Test Shape: ", test_data.shape)
test_data.head(1)

Test Shape:  (241, 8)


Unnamed: 0,id,year,title,sections,headings,abstract,summary,keywords
0,elife-37443-v3,2018,Cerebellar implementation of movement sequence...,"[[Most movements are comprised of sequences .,...","[Introduction, Results, Discussion, Materials ...","[Most movements are not unitary , but are comp...",[Imagine a gymnastics competition in which par...,[neuroscience]


# Preprocessing

In [8]:
def reshape_dataframe(df, columns_to_keep):
    """
    Reshapes a dataframe based on its 'sections' and 'headings' columns. Each unique heading
    becomes a column in the reshaped dataframe, where the entries are the corresponding sections.

    Parameters:
    - df (pd.DataFrame): The input dataframe, which  have columns named 'id', 'sections',
                         and 'headings' etc. The 'sections' column should contain lists of strings,
                         while the 'headings' column should contain lists of headings corresponding
                         to the sections.

    - columns_to_keep (list of str): List of columns from the original dataframe that should
                                     be retained in the final reshaped dataframe.

    Returns:
    - pd.DataFrame: A reshaped dataframe where each unique heading from the 'headings' column
                    is now its own column. Original columns specified in 'columns_to_keep' are
                    also retained.
    """

    section_dicts = []
    for index, row in df.iterrows():
        temp_dict = {'id': row['id']}
        row_headings = row['headings']

        row_sections = row['sections']

        for heading, section in zip(row_headings, row_sections):
            temp_dict[heading] = section

        section_dicts.append(temp_dict)

    section_df = pd.DataFrame(section_dicts)
    result_df = pd.merge(df, section_df, on='id', how='outer')
    result_df.rename(columns={'summary': 'lay summary'}, inplace=True)
    result_df = result_df[columns_to_keep]
    result_df = result_df.dropna(subset=['Introduction', 'abstract'])
    result_df = result_df.reset_index(drop=True)
    return result_df

In [9]:
# The columns you want to keep BASED ON IITR:
cols_to_keep = ['Introduction', 'abstract', 'lay summary']
IITR_train_df = reshape_dataframe(train_data, cols_to_keep)
IITR_val_df = reshape_dataframe(val_data, cols_to_keep)
IITR_test_df = reshape_dataframe(test_data, cols_to_keep)

## Combine Introduction & Abstract based on fragment

In [10]:
def combine_texts(row, k=0.6):
    intro_part = row['Introduction'][:int(len(row['Introduction']) * k)]
    return intro_part + row['abstract']

In [11]:
IITR_train_df['intro_abstract_combined'] = IITR_train_df.apply(lambda row: combine_texts(row, k=0.6), axis=1)
IITR_test_df['intro_abstract_combined'] = IITR_test_df.apply(lambda row: combine_texts(row, k=0.6), axis=1)
IITR_val_df['intro_abstract_combined'] = IITR_val_df.apply(lambda row: combine_texts(row, k=0.6), axis=1)

In [12]:
IITR_train_df.drop(['Introduction', 'abstract'],inplace=True, axis=1)
IITR_test_df.drop(['Introduction', 'abstract'],inplace=True, axis=1)
IITR_val_df.drop(['Introduction', 'abstract'],inplace=True, axis=1)
IITR_train_df

Unnamed: 0,lay summary,intro_abstract_combined
0,"[In the USA , more deaths happen in the winter...",[It is well-established that death rates vary ...
1,[Most people have likely experienced the disco...,[Dysregulated complement activation is increas...
2,[The immune system protects an individual from...,"[HOIL-1 ( encoded by the RBCK1 gene ) , HOIP (..."
3,[The brain adapts to control our behavior in d...,[Flexible control of cognitive processes is fu...
4,[Cells use motor proteins that to move organel...,[Myosin 5a moves in a hand-over-hand fashion w...
...,...,...
4311,[To defend itself against bacteria and viruses...,[Antibodies are immunogenic proteins expressed...
4312,[DNA is tightly packaged in a material called ...,[The eukaryotic genome is packaged into chroma...
4313,[Associative learning is a simple learning abi...,[The temporal and spatial heterogeneity of any...
4314,"[In 1848 , a railroad worker named Phineas Gag...",[Correlates of decision variables are routinel...


## Convert list instances to String

In [13]:
import ast
from datasets import Dataset

def str_list_to_str(s):
    # If it's a list, then join its items into a string
    if isinstance(s, list):
        return ' '.join(s)
    # If it's already a string, return as it is
    return s



IITR_train_df = IITR_train_df.applymap(str_list_to_str)
IITR_test_df = IITR_test_df.applymap(str_list_to_str)
IITR_val_df = IITR_val_df.applymap(str_list_to_str)

In [14]:
train_dataset = Dataset.from_pandas(IITR_train_df)
val_dataset = Dataset.from_pandas(IITR_val_df)
test_dataset = Dataset.from_pandas(IITR_test_df)

# BART

In [15]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [18]:
model_ckpt = "facebook/bart-large-cnn"
model_bart = BartForConditionalGeneration.from_pretrained(model_ckpt).to(device)
tokenizer = BartTokenizerFast.from_pretrained(model_ckpt)

In [19]:
max_input_length = 1024
max_target_length = 512

# Tokenize data

In [20]:
def tokenize_function(examples):
    inputs = tokenizer(examples['intro_abstract_combined'], max_length=max_input_length, padding="max_length", truncation=True)
    targets = tokenizer(examples['lay summary'], max_length=max_target_length, padding="max_length", truncation=True)

    return {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "labels": targets.input_ids
    }

train_dataset = train_dataset.map(tokenize_function, batched=True)
val_dataset = val_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)

  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

# Fine-tuning

In [21]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    num_train_epochs=5,  
    per_device_train_batch_size=8,
    evaluation_strategy="epoch",
    save_strategy='epoch',
    logging_dir="./logs",
    report_to="none",
    fp16=True,
    gradient_accumulation_steps=4
)

In [24]:
trainer = Seq2SeqTrainer(
    model=model_bart,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)


Using cuda_amp half precision backend


In [25]:
os.environ['WANDB_DISABLED'] = 'true'
trainer.train()

The following columns in the training set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: intro_abstract_combined, lay summary. If intro_abstract_combined, lay summary are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 4316
  Num Epochs = 5
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 4
  Total optimization steps = 675


Epoch,Training Loss,Validation Loss
1,No log,1.934365
2,No log,1.890086
3,No log,1.878118
4,1.893000,1.869586
5,1.893000,1.875278


The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: intro_abstract_combined, lay summary. If intro_abstract_combined, lay summary are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 239
  Batch size = 8
Saving model checkpoint to ./results/checkpoint-135
Configuration saved in ./results/checkpoint-135/config.json
Model weights saved in ./results/checkpoint-135/pytorch_model.bin
The following columns in the evaluation set don't have a corresponding argument in `BartForConditionalGeneration.forward` and have been ignored: intro_abstract_combined, lay summary. If intro_abstract_combined, lay summary are not expected by `BartForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 239
  Batch size = 8
Saving model checkpoint to ./results/checkpo

TrainOutput(global_step=675, training_loss=1.8103749367042825, metrics={'train_runtime': 1761.8702, 'train_samples_per_second': 12.248, 'train_steps_per_second': 0.383, 'total_flos': 4.676611731357696e+16, 'train_loss': 1.8103749367042825, 'epoch': 5.0})

In [26]:
#trainer.save_model('my_saved_model')
#model_bart.save_pretrained("saved_the_model_bart")
#tokenizer.save_pretrained("saved_the_tokenizer.save_pretrained")

Saving model checkpoint to my_saved_model
Configuration saved in my_saved_model/config.json
Model weights saved in my_saved_model/pytorch_model.bin


# Inference