In [63]:
import os
from urllib.request import urlretrieve
from ftplib import FTP
import gzip
import xml.etree.ElementTree as ET
import pandas as pd
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
import csv
from datasets import load_dataset, Dataset, DatasetDict

In [64]:
# Settings

# Remote source
ftp_host = "ftp.ncbi.nlm.nih.gov"
ftp_dir = "/pubmed/baseline"
max_files = 300  # upper bound on how many archives are fetched

# Local layout: everything lives under download_dir
download_dir = "../data/pubmed_baseline"
gz_dir = download_dir  # archives are parsed from the folder they were downloaded to
output_csv = f"{download_dir}/pubmed_filtered.csv"
output_csv_small = f"{download_dir}/pubmed_filtered_small.csv"
train_csv = f"{download_dir}/pubmed_train.csv"
test_csv = f"{download_dir}/pubmed_test.csv"
val_csv = f"{download_dir}/pubmed_val.csv"
train_tokenized = f"{download_dir}/pubmed_train"
test_tokenized = f"{download_dir}/pubmed_test"
val_tokenized = f"{download_dir}/pubmed_val"

# Parsing / filtering
min_words = 20  # abstracts with fewer words than this are discarded
max_workers = 8  # parallel worker processes; tune to the CPU

# Split sizes (rows)
train_size = 1_000_000
test_size = 3000
val_size = 200

# Maximum sequence length used when tokenizing
max_len = 512

In [20]:


os.makedirs(download_dir, exist_ok=True)

# Connect to FTP and list files.
# The context manager closes the connection even when login/cwd/listing
# raises, which the previous explicit ftp.quit() did not guarantee.
files = []
with FTP(ftp_host) as ftp:
    ftp.login()  # anonymous login
    ftp.cwd(ftp_dir)
    ftp.retrlines("NLST", files.append)

# Filter .gz files and limit number (".gz.md5" checksum files do not match)
gz_files = sorted(f for f in files if f.endswith(".gz"))[:max_files]
gz_files

['pubmed25n0001.xml.gz',
 'pubmed25n0002.xml.gz',
 'pubmed25n0003.xml.gz',
 'pubmed25n0004.xml.gz',
 'pubmed25n0005.xml.gz',
 'pubmed25n0006.xml.gz',
 'pubmed25n0007.xml.gz',
 'pubmed25n0008.xml.gz',
 'pubmed25n0009.xml.gz',
 'pubmed25n0010.xml.gz',
 'pubmed25n0011.xml.gz',
 'pubmed25n0012.xml.gz',
 'pubmed25n0013.xml.gz',
 'pubmed25n0014.xml.gz',
 'pubmed25n0015.xml.gz',
 'pubmed25n0016.xml.gz',
 'pubmed25n0017.xml.gz',
 'pubmed25n0018.xml.gz',
 'pubmed25n0019.xml.gz',
 'pubmed25n0020.xml.gz',
 'pubmed25n0021.xml.gz',
 'pubmed25n0022.xml.gz',
 'pubmed25n0023.xml.gz',
 'pubmed25n0024.xml.gz',
 'pubmed25n0025.xml.gz',
 'pubmed25n0026.xml.gz',
 'pubmed25n0027.xml.gz',
 'pubmed25n0028.xml.gz',
 'pubmed25n0029.xml.gz',
 'pubmed25n0030.xml.gz',
 'pubmed25n0031.xml.gz',
 'pubmed25n0032.xml.gz',
 'pubmed25n0033.xml.gz',
 'pubmed25n0034.xml.gz',
 'pubmed25n0035.xml.gz',
 'pubmed25n0036.xml.gz',
 'pubmed25n0037.xml.gz',
 'pubmed25n0038.xml.gz',
 'pubmed25n0039.xml.gz',
 'pubmed25n0040.xml.gz',


In [22]:
# Download
# Files already present are skipped so the cell can be re-run cheaply;
# delete a file from download_dir to force it to be fetched again.
base_url = f"https://{ftp_host}{ftp_dir}/"
for fname in tqdm(gz_files):
    dest = os.path.join(download_dir, fname)
    if os.path.exists(dest):
        continue
    # Fetch into a temporary name and rename only on success, so an
    # interrupted transfer never leaves a truncated archive at `dest`.
    # The ".part" suffix keeps it out of the later "*.gz" directory scan.
    partial = dest + ".part"
    urlretrieve(base_url + fname, partial)
    os.replace(partial, dest)

print(f"\n✅ Downloaded {len(gz_files)} files to `{download_dir}/`")


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:51<00:00,  5.79it/s]


✅ Downloaded 300 files to `../data/pubmed_baseline/`





In [26]:


# Re-declared here for the parsing step; note that max_workers = 4 replaces
# the value assigned in the settings cell.
min_words, max_workers = 20, 4

def _element_text(elem):
    """Return all text inside `elem` (including nested inline tags), stripped.

    Returns "" when `elem` is None.
    """
    if elem is None:
        return ""
    return "".join(elem.itertext()).strip()


def parse_and_filter(file_path, min_word_count=None):
    """Extract (pmid, title, abstract) rows from one gzipped PubMed XML file.

    An article is kept only when it has a PMID, a title and an abstract, its
    first <Language> (if present) is "eng", and the abstract has at least
    `min_word_count` whitespace-separated words.

    Args:
        file_path: Path to a gzip-compressed XML file containing
            <PubmedArticle> elements.
        min_word_count: Minimum number of words required in the abstract.
            When None, the module-level `min_words` is used.

    Returns:
        A list of (pmid, title, abstract) string tuples. If reading or parsing
        fails, the error is printed and the rows collected so far (possibly
        none) are returned.
    """
    if min_word_count is None:
        min_word_count = min_words

    rows = []
    try:
        with gzip.open(file_path, 'rb') as f:
            tree = ET.parse(f)
        root = tree.getroot()
        for article in root.findall(".//PubmedArticle"):
            pmid = (article.findtext(".//PMID") or "").strip()
            lang = article.findtext(".//Language")

            # findtext() only returns the text before the first child tag, so
            # titles/abstracts with inline markup (<i>, <sup>, ...) were cut
            # short or came back empty; itertext() collects the full text.
            title = _element_text(article.find(".//ArticleTitle"))

            # An abstract may be split into several <AbstractText> sections;
            # keep all of them rather than only the first.
            sections = (
                _element_text(section)
                for section in article.findall(".//Abstract/AbstractText")
            )
            abstract = " ".join(text for text in sections if text)

            if not (pmid and title and abstract):
                continue
            if lang and lang.strip().lower() != "eng":
                continue
            if len(abstract.split()) < min_word_count:
                continue

            rows.append((pmid, title, abstract))
    except Exception as e:
        # Best effort: one unreadable archive must not abort the whole run.
        print(f"Error in {file_path}: {e}")
    return rows

def process_files_in_chunks(gz_dir, output_csv, workers=4):
    """Parse every ``*.gz`` file in `gz_dir` in parallel and write one CSV.

    The CSV is (re)created with the header ``pmid,title,abstract``; rows are
    appended as each file finishes, so row order follows completion order,
    not file order.

    Args:
        gz_dir: Directory scanned (non-recursively) for files ending in ".gz".
        output_csv: Path of the CSV file to write; overwritten if it exists.
        workers: Number of worker processes.
    """
    files = sorted(
        os.path.join(gz_dir, name)
        for name in os.listdir(gz_dir)
        if name.endswith(".gz")
    )

    # Open the output once instead of re-opening it for every finished file.
    with open(output_csv, "w", newline='', encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["pmid", "title", "abstract"])
        # Flush before any worker process is started so no buffered data
        # exists in the parent at that point.
        f.flush()

        # Process files in parallel
        with ProcessPoolExecutor(max_workers=workers) as executor:
            futures = [executor.submit(parse_and_filter, path) for path in files]
            for future in as_completed(futures):
                rows = future.result()
                if rows:
                    writer.writerows(rows)
                    # Keep the write incremental: finished files reach disk
                    # even if a later one fails.
                    f.flush()

# Kick off the parallel parse of every archive in gz_dir into output_csv
process_files_in_chunks(gz_dir=gz_dir, output_csv=output_csv, workers=max_workers)
print(f"\n✅ Data written incrementally to: {output_csv}")



✅ Data written incrementally to: ../data/pubmed_baseline/pubmed_filtered.csv


In [65]:
# Load the filtered articles written by the parsing step
df = pd.read_csv(filepath_or_buffer=output_csv)

In [66]:
# Shuffle all rows. A fixed random_state makes the train/test/val split
# reproducible across kernel restarts (it was different on every run before).
df = df.sample(frac=1, random_state=42).reset_index(drop=True)

In [67]:
# Carve consecutive, non-overlapping train / test / val slices out of df.
# Fail loudly if there are not enough rows: slicing would otherwise silently
# return short or empty splits.
rows_needed = train_size + test_size + val_size
if len(df) < rows_needed:
    raise ValueError(
        f"Need {rows_needed} rows for the requested splits, but df has only {len(df)}"
    )

# Positional (iloc) slices: half-open bounds, independent of index labels.
train_end = train_size
test_end = train_end + test_size
train_df = df.iloc[:train_end]
test_df = df.iloc[train_end:test_end]
val_df = df.iloc[test_end:test_end + val_size]

In [58]:
# (rows, columns) of each split, in train / test / val order
tuple(split.shape for split in (train_df, test_df, val_df))

((1000000, 3), (3000, 3), (200, 3))

In [68]:
# Write each split to its own CSV, leaving out the index column
for split_df, split_path in (
    (train_df, train_csv),
    (test_df, test_csv),
    (val_df, val_csv),
):
    split_df.to_csv(split_path, index=False)

In [50]:
# Preview the first five training rows
train_df.head(n=5)

Unnamed: 0,pmid,title,abstract
0,2712356,Membrane relationships in murine Meissner corp...,Mechanoreceptive sensory corpuscles (murine Me...
1,8979397,Two-dimensional protein patterns of Arabidopsi...,In order to detect gene products involved in A...
2,1462207,Pathoanatomy of lumbar disc herniation as demo...,Computed tomography/discography was performed ...
3,1807731,An innovative method of teaching Advanced Card...,A demonstration and discussion of the effectiv...
4,3207648,The distribution of CA 125 in the reproductive...,Investigation of serum and tissue homogenates ...


In [38]:
def tokenize_dataset(tokenizer, data_df):
    """Convert a DataFrame of articles into a tokenized `Dataset`.

    Each row becomes the text ``<s>{title}\\n{abstract}</s>``, which is passed
    to `tokenizer` with truncation and padding to the module-level `max_len`.

    Args:
        tokenizer: Callable applied to a list of strings; presumably a
            Hugging Face tokenizer (it is called with `truncation`, `padding`
            and `max_length`) - confirm against the caller.
        data_df: DataFrame with at least "title" and "abstract" columns.

    Returns:
        The dataset returned by `Dataset.map`, i.e. the original columns plus
        whatever fields the tokenizer returns.
    """
    dataset = Dataset.from_pandas(data_df)

    def tokenize(batch):
        # map(..., batched=True) passes a dict of column -> list of values,
        # so one text has to be built per row. Formatting batch['title']
        # directly would embed the repr of the whole list in a single string.
        texts = [
            f"<s>{title}\n{abstract}</s>"
            for title, abstract in zip(batch["title"], batch["abstract"])
        ]
        return tokenizer(texts, truncation=True, padding="max_length", max_length=max_len)

    dataset = dataset.map(tokenize, batched=True)
    return dataset