# Pretrain data preparation

## Imports

In [None]:
import os
from urllib.request import urlretrieve
from ftplib import FTP
import gzip
import xml.etree.ElementTree as ET
import pandas as pd
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
import csv
from datasets import load_dataset, Dataset, DatasetDict

In [None]:
# Settings

# --- Remote source and local download location ---
ftp_host = "ftp.ncbi.nlm.nih.gov"
ftp_dir = "/pubmed/baseline"
download_dir = "../data/pubmed_baseline"
max_files = 300  # upper bound on how many baseline archives are downloaded

# --- Parsing / filtering ---
gz_dir = download_dir  # the archives are parsed from the directory they were downloaded to
min_words = 20  # abstracts with fewer words than this are dropped
max_workers = 8  # number of parser processes; adjust to the available CPU cores

# --- Output files (all stored next to the downloaded archives) ---
output_csv = f"{download_dir}/pubmed_filtered.csv"
output_csv_small = f"{download_dir}/pubmed_filtered_small.csv"
train_csv = f"{download_dir}/pubmed_train.csv"
test_csv = f"{download_dir}/pubmed_test.csv"
val_csv = f"{download_dir}/pubmed_val.csv"
train_tokenized = f"{download_dir}/pubmed_train"
test_tokenized = f"{download_dir}/pubmed_test"
val_tokenized = f"{download_dir}/pubmed_val"

# --- Split sizes (number of rows per split) ---
train_size = 1_000_000
test_size = 3000
val_size = 200

# Per-example token cap; the token-count estimate clamps at this same value.
max_len = 300

### Download the PubMed data

In [None]:
os.makedirs(download_dir, exist_ok=True)

# List the remote baseline directory over FTP.
# The context manager closes the connection even when login/cwd/NLST raises,
# so a failed listing no longer leaves a dangling control connection behind.
files = []
with FTP(ftp_host) as ftp:
    ftp.login()  # no credentials given -> anonymous login
    ftp.cwd(ftp_dir)
    ftp.retrlines("NLST", files.append)

# Keep only the .gz archives, in name order, limited to the first max_files.
gz_files = sorted(f for f in files if f.endswith(".gz"))[:max_files]
gz_files

In [None]:
# Download the selected archives over HTTPS (same host and path as the FTP listing).
base_url = f"https://{ftp_host}{ftp_dir}/"
for fname in tqdm(gz_files):
    dest = os.path.join(download_dir, fname)

    # Re-running the cell must not fetch archives that are already complete.
    if os.path.exists(dest):
        continue

    # Download to a temporary name and rename only after success: an
    # interrupted transfer then never leaves a truncated file under the final
    # ".gz" name, where the parsing step would pick it up and fail on it.
    partial = dest + ".part"
    urlretrieve(base_url + fname, partial)
    os.replace(partial, dest)

print(f"\n✅ Downloaded {len(gz_files)} files to `{download_dir}/`")


# Process files in parallel

In [None]:
# NOTE: these assignments shadow the values from the Settings cell for every
# cell executed after this one (max_workers there is 8, here it is 4).
max_workers = 4  # number of parser processes used below
min_words = 20  # minimum number of words an abstract must contain

def _element_text(element):
    """Return the complete text of *element*, including text inside nested tags."""
    if element is None:
        return ""
    # Element.text stops at the first child element; itertext() also yields the
    # text inside inline markup (<i>, <sup>, ...) and the tails that follow it.
    return "".join(element.itertext()).strip()


def parse_and_filter(file_path, min_word_count=None):
    """Extract (pmid, title, abstract) rows from one gzipped PubMed XML file.

    An article is kept only if it has a PMID, a title and an abstract, its
    language (when present) is "eng", and its abstract has at least the
    required number of whitespace-separated words.

    Titles and abstracts are read with their inline markup flattened, and all
    <AbstractText> sections of a structured abstract are joined with a single
    space, so neither is cut off at the first nested tag or first section.

    Args:
        file_path: Path to a ``.xml.gz`` file.
        min_word_count: Minimum number of words in the abstract. When ``None``
            (the default) the module-level ``min_words`` setting is used.

    Returns:
        A list of ``(pmid, title, abstract)`` string tuples. If the file cannot
        be read or parsed, the error is printed and the rows collected so far
        (normally none) are returned.
    """
    # Resolved outside the try block so a missing setting fails loudly instead
    # of being reported as a per-file parse error.
    threshold = min_words if min_word_count is None else min_word_count

    rows = []
    try:
        with gzip.open(file_path, "rb") as f:
            root = ET.parse(f).getroot()

        for article in root.findall(".//PubmedArticle"):
            pmid = article.findtext(".//PMID")
            lang = article.findtext(".//Language")
            title = _element_text(article.find(".//ArticleTitle"))

            abstract_el = article.find(".//Abstract")
            if abstract_el is None:
                sections = []
            else:
                sections = [
                    _element_text(section)
                    for section in abstract_el.findall("AbstractText")
                ]
            abstract = " ".join(section for section in sections if section)

            if not (pmid and title and abstract):
                continue
            if lang and lang.strip().lower() != "eng":
                continue
            if len(abstract.split()) < threshold:
                continue

            rows.append((pmid.strip(), title, abstract))
    except Exception as e:
        # Deliberately best-effort: one unreadable archive must not abort the
        # whole batch, so the problem is reported and the file is skipped.
        print(f"Error in {file_path}: {e}")
    return rows

def process_files_in_chunks(gz_dir, output_csv, workers=4):
    """Parse every ``.gz`` file in *gz_dir* in parallel and collect the rows in a CSV.

    The output file is (re)created with a ``pmid,title,abstract`` header, then
    the rows of each file are appended as soon as that file has been parsed,
    so rows appear in completion order rather than in file-name order.

    Args:
        gz_dir: Directory that is scanned (non-recursively) for ``.gz`` files.
        output_csv: Path of the CSV file to write; an existing file is overwritten.
        workers: Number of worker processes.
    """
    archive_paths = sorted(
        os.path.join(gz_dir, name)
        for name in os.listdir(gz_dir)
        if name.endswith(".gz")
    )

    # Start from a fresh file that contains only the header row.
    with open(output_csv, "w", newline='', encoding="utf-8") as header_file:
        csv.writer(header_file).writerow(["pmid", "title", "abstract"])

    # Fan the files out to worker processes and append each result on arrival.
    with ProcessPoolExecutor(max_workers=workers) as pool:
        pending = [pool.submit(parse_and_filter, path) for path in archive_paths]
        for finished in as_completed(pending):
            batch = finished.result()
            if not batch:
                continue
            with open(output_csv, "a", newline='', encoding="utf-8") as out_file:
                csv.writer(out_file).writerows(batch)

# Parse all downloaded archives and build the filtered CSV.
process_files_in_chunks(gz_dir=gz_dir, output_csv=output_csv, workers=max_workers)
print(f"\n✅ Data written incrementally to: {output_csv}")


In [None]:
# Load the filtered articles written by process_files_in_chunks
# (columns: pmid, title, abstract).
df = pd.read_csv(output_csv)

In [None]:
# Sanity check: (number of rows, number of columns) of the filtered dataset.
df.shape

In [None]:
# Shuffle all rows before splitting. The seed is fixed so that re-running the
# notebook reproduces the same train/test/val membership; reset_index gives
# the shuffled frame a fresh 0..n-1 index.
shuffle_seed = 42
df = df.sample(frac=1, random_state=shuffle_seed).reset_index(drop=True)

In [None]:
# Consecutive, non-overlapping row ranges: train first, then test, then val.
# Positional slicing selects the same rows as label slicing here because df
# carries a default 0..n-1 index.
train_end = train_size
test_end = train_end + test_size
val_end = test_end + val_size

train_df = df.iloc[:train_end]
test_df = df.iloc[train_end:test_end]
val_df = df.iloc[test_end:val_end]

In [None]:
# (rows, columns) of the three splits, in train / test / val order.
tuple(split.shape for split in (train_df, test_df, val_df))

In [None]:
# Persist each split to its own CSV file, without the index column.
for split_frame, split_path in (
    (train_df, train_csv),
    (test_df, test_csv),
    (val_df, val_csv),
):
    split_frame.to_csv(split_path, index=False)

In [None]:
# Preview the first rows of the training split.
train_df.head()

In [None]:
# Load the tokenizer of the target model; it is used to count tokens per example.
from transformers import AutoTokenizer
model_id = "microsoft/Phi-3.5-mini-instruct"
# NOTE(review): trust_remote_code=True allows the loader to execute code shipped
# in the model repository — confirm it is actually required for this tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

In [None]:
# Load only the leading part of the training split for token counting.
# nrows stops parsing after the requested number of rows instead of reading
# the whole training CSV and discarding most of it afterwards.
# NOTE(review): the row count is kept at 200_001 so results stay comparable
# with earlier runs; 200_000 may have been intended — confirm.
token_sample_rows = 200_001
train_df = pd.read_csv(train_csv, nrows=token_sample_rows)

In [None]:
def calculate_token_count(example):
    """Return the number of token ids the tokenizer produces for one example.

    The text that is tokenized is the title, a newline, the abstract and the
    tokenizer's EOS token, without any truncation.

    Args:
        example: Mapping-like row providing ``"title"`` and ``"abstract"``.

    Returns:
        Length of the ``input_ids`` list returned by the module-level ``tokenizer``.
    """
    title, abstract = example["title"], example["abstract"]
    text = f"{title}\n{abstract}{tokenizer.eos_token}"
    encoded = tokenizer(text, truncation=False)
    return len(encoded["input_ids"])

In [None]:
# Token count per row; axis=1 hands each row to calculate_token_count, so the
# tokenizer is invoked once per row.
train_df["n_tokens"] = train_df.apply(calculate_token_count, axis=1)

In [None]:
# Preview the rows again, now including the n_tokens column.
train_df.head()

In [None]:
# Estimated number of training tokens when every example is truncated to at
# most max_len tokens. The cap comes from the max_len setting rather than a
# repeated literal, and clip() applies it to the whole column at once.
train_df["n_tokens"].clip(upper=max_len).sum()

In [None]:
# Spot check one example: row at position 8, column at position 2
# (the abstract, given the pmid/title/abstract column order of the CSV).
train_df.iloc[8, 2]

Total training token count (first 200,001 rows of the training split, each example capped at 300 tokens) = 48_931_561