In [24]:
import pandas as pd
import json
import os
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer

In [13]:
project_path = os.getcwd()
train_path = os.path.join(project_path, '..', '..', 'data', 'processed', 'train.csv')
test_path = os.path.join(project_path, '..', '..', 'data', 'processed', 'test.csv')
train_df = pd.read_csv(train_path)
test_df = pd.read_csv(test_path)

In [19]:
# min, max, mean, std of the length of the discharge reports and summaries
lengths = train_df['discharge_report'].str.len()  # Calculate the length of each report

# Calculate min, max, mean, and std
min_length = lengths.min()
max_length = lengths.max()
mean_length = lengths.mean()
std_length = lengths.std()

# Print the results
print(f"Min length report: {min_length}")
print(f"Max length report: {max_length}")
print(f"Mean length report: {mean_length}")
print(f"Standard deviation report: {std_length}")

lengths = train_df['discharge_summary'].str.len()  # Calculate the length of each summary

# Calculate min, max, mean, and std
min_length = lengths.min()
max_length = lengths.max()
mean_length = lengths.mean()
std_length = lengths.std()

# Print the results
print(f"Min length summary: {min_length}")
print(f"Max length summary: {max_length}")
print(f"Mean length summary: {mean_length}")
print(f"Standard deviation summary: {std_length}")


Min length report: 3416
Max length report: 33209
Mean length report: 11007.517482517482
Standard deviation report: 4990.852398247897
Min length summary: 576
Max length summary: 2210
Mean length summary: 1145.2202797202797
Standard deviation summary: 322.84188255421503


In [23]:
train_dataset = Dataset.from_pandas(train_df)
test_dataset = Dataset.from_pandas(test_df)
dataset = DatasetDict({'train': train_dataset, 'test': test_dataset})
print(dataset)


DatasetDict({
    train: Dataset({
        features: ['discharge_report', 'discharge_summary'],
        num_rows: 286
    })
    test: Dataset({
        features: ['discharge_report', 'discharge_summary'],
        num_rows: 96
    })
})


In [26]:
model_name = "GanjinZero/biobart-v2-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
print(tokenizer)

tokenizer_config.json:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.59M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/892k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

BartTokenizerFast(name_or_path='GanjinZero/biobart-v2-base', vocab_size=85401, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	50264: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True, special=True),
}


In [42]:
prefix = "summarize: "

def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["discharge_report"]]
    model_inputs = tokenizer(inputs, truncation=False)  # Don't truncate to see actual length

    labels = tokenizer(text_target=examples["discharge_summary"], truncation=False)  # Don't truncate to see actual length

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True)
tokenized_dataset

Map:   0%|          | 0/286 [00:00<?, ? examples/s]

Map:   0%|          | 0/96 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['discharge_report', 'discharge_summary', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 286
    })
    test: Dataset({
        features: ['discharge_report', 'discharge_summary', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 96
    })
})

In [39]:
# Summary statistics
max_length = max(token_lengths)
min_length = min(token_lengths)
avg_length = sum(token_lengths) / len(token_lengths)
exceed_count = sum(1 for length in token_lengths if length > model_max_length)

3863