In [1]:
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [2]:
import pandas as pd
import torch
from tqdm import tqdm  # For progress tracking
from transformers import pipeline, BartForConditionalGeneration, BartTokenizer

In [3]:
# Assume `df` is your DataFrame with 'input_text' and 'output_text' columns
#df = pd.read_csv('medquad.csv')  # Replace with your dataset

In [4]:
#print(f"Number of unique questions before aggregation: {df.shape[0]}")

Number of unique questions before aggregation: 16412


In [5]:
## Check for duplicate questions
#print(f"Number of duplicated questions: {df['question'].duplicated().sum()}")

Number of duplicated questions: 1428


In [6]:
# Define a function to aggregate answers
# def aggregate_answers(group):
#     # Concatenate unique answers with a separator to preserve structure
#     # Convert answers to strings before joining to handle potential numeric types
#     combined_answer = " ".join([str(answer) for answer in group['answer'].unique()])
#     return combined_answer

# # Apply aggregation to remove duplicate questions but keep comprehensive answers
# aggregated_df = df.groupby('question').apply(lambda group: pd.Series({
#     'question': group['question'].iloc[0],
#     'answer': aggregate_answers(group)
# }))


# # Check the results
# print(f"Number of unique questions after aggregation: {aggregated_df.shape[0]}")

# # Save the cleaned and aggregated dataset to a new CSV file
# aggregated_df.to_csv('medquad_cleaned.csv', index=False)


Number of unique questions after aggregation: 14984


In [3]:
# Load the model and tokenizer
model_name = "facebook/bart-large-cnn"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name).to("cuda")

# Create a summarization pipeline and specify device=0 for GPU
pipe = pipeline("summarization", model=model, tokenizer=tokenizer, device=0)

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.58k [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

In [4]:
# Load your dataset (Replace this path with your dataset path)
df = pd.read_csv("/content/medquad_cleaned.csv")

In [5]:
# Check for any NaN or non-string entries in the 'answer' column
invalid_entries = df[df['answer'].isna() | (df['answer'].apply(lambda x: not isinstance(x, str)))]
print(f"Number of invalid entries: {len(invalid_entries)}")

# Display the invalid entries
if not invalid_entries.empty:
    display(invalid_entries)

# Remove rows with NaN values in the 'answer' column
df = df.dropna(subset=['answer'])

# Ensure all remaining entries in the 'answer' column are strings
df['answer'] = df['answer'].astype(str)


Number of invalid entries: 5


Unnamed: 0,question,answer
11195,What is (are) Emery-Dreifuss muscular dystroph...,
11196,What is (are) Emery-Dreifuss muscular dystroph...,
11262,What is (are) Familial HDL deficiency ?,
11474,What is (are) HELLP syndrome ?,
13143,What is (are) X-linked lymphoproliferative syn...,


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['answer'] = df['answer'].astype(str)


In [6]:
# Optional: Use half-precision (FP16) for faster inference
model.half()

BartForConditionalGeneration(
  (model): BartModel(
    (shared): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): BartScaledWordEmbedding(50264, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
    

In [8]:
# Set batch size; adjust based on GPU memory
batch_size = 32

# Summarization function for batch processing
def summarize_batch(text_list, min_length=150, max_length=350):
    inputs = tokenizer(
        text_list, return_tensors="pt", padding="longest", truncation=True, max_length=1024
    ).to("cuda")

    summary_ids = model.generate(
        inputs=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
        min_length=min_length,
        length_penalty=0.9,
        num_beams=6,
        no_repeat_ngram_size=3,
        early_stopping=True,
    )

    summaries = [tokenizer.decode(g, skip_special_tokens=True) for g in summary_ids]
    return summaries

# Process the dataset in batches
summaries = []
num_batches = len(df) // batch_size + (1 if len(df) % batch_size != 0 else 0)

for i in tqdm(range(num_batches), desc="Summarizing Batches"):
    batch_texts = df['answer'][i * batch_size:(i + 1) * batch_size].tolist()
    batch_summaries = summarize_batch(batch_texts)
    summaries.extend(batch_summaries)
    # Clear GPU cache
    torch.cuda.empty_cache()

# Add the summaries to the dataframe
df['summary'] = summaries

Summarizing Batches: 100%|██████████| 469/469 [2:35:50<00:00, 19.94s/it]


In [9]:
# Save the summarized dataset to a new CSV file
df.to_csv("medquad_summarized.csv", index=False)