In [14]:
import pandas as pd
import numpy as np
from datasets import Dataset, load_dataset, DatasetDict
from sentence_transformers import SentenceTransformer
import nltk
from nltk.tokenize import sent_tokenize
import networkx as nx
import torch
from torch.nn import CosineSimilarity
import math
from itertools import chain
import torch.nn.functional as FT
# Release any GPU memory still cached by torch from earlier runs in this kernel.
torch.cuda.empty_cache()
# sent_tokenize needs the Punkt tokenizer data; download it here so that
# Restart Kernel -> Run All works on a machine that does not have it yet.
# NOTE(review): newer NLTK releases may look for 'punkt_tab' instead - confirm
# against the pinned NLTK version.
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [20]:
# Default to the CPU and switch to the GPU only when torch reports one is available.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

# Sentence-embedding checkpoint passed to SentenceTransformer further down.
MODEL_NAME = 'sentence-transformers/bert-base-nli-mean-tokens'

In [6]:
# Hub dataset holding the raw notes and the target text for each admission.
dataset_name = "dmacres/mimiciii-hospitalcourse-v2"

# Load the three splits in a fixed order: train, validation, test.
train, valid, test = (
    load_dataset(dataset_name, split = split_name)
    for split_name in ('train', 'validation', 'test')
)

Downloading readme:   0%|          | 0.00/993 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/172M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/163M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/165M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/111M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/108M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/24993 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5356 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5356 [00:00<?, ? examples/s]

In [15]:
# Loaded SentenceTransformer instances keyed by (model name, device), so the
# weights are loaded once per kernel instead of once per call.
_SENTENCE_MODEL_CACHE = {}


def _get_sentence_model(model_name, device):
    """Return the cached SentenceTransformer for (model_name, device), loading it on first use."""
    cache_key = (model_name, device)
    if cache_key not in _SENTENCE_MODEL_CACHE:
        _SENTENCE_MODEL_CACHE[cache_key] = SentenceTransformer(model_name, device = device)
    return _SENTENCE_MODEL_CACHE[cache_key]


def generate_extractive_summary(sentences, embeddings = None, context_size = 980, top_n=None):
    """Pick the most central sentences using PageRank over a cosine-similarity graph.

    Args:
        sentences: Sequence of sentence strings to rank.
        embeddings: Optional 2-D torch tensor with one row per sentence. When
            None, the sentences are encoded with the MODEL_NAME checkpoint on
            DEVICE.
        context_size: Word budget (words are counted with ``split(' ')``).
            Sentences are added in rank order while the running word count is
            still <= context_size, so the final sentence may overshoot the
            budget. Takes precedence over ``top_n`` when both are given.
        top_n: Number of top-ranked sentences to return; only used when
            ``context_size`` is None.

    Returns:
        ``(summary_sentences, None)`` in context_size mode, or
        ``(summary_sentences, summary_embeddings)`` in top_n mode, where
        ``summary_embeddings`` stacks the selected rows in rank order.
        Empty input yields ``([], None)`` in either mode.

    Raises:
        ValueError: If both ``context_size`` and ``top_n`` are None.
    """
    # Validate up front so a bad call fails before any encoding or ranking work.
    if context_size is None and top_n is None:
        raise ValueError("Please pass either a context_size or top_n parameter")

    if embeddings is None:
        if len(sentences) == 0:
            return [], None
        model = _get_sentence_model(MODEL_NAME, DEVICE)
        embeddings = model.encode(sentences, device = DEVICE, convert_to_tensor=True)

    n_sents = embeddings.shape[0]
    if n_sents == 0:
        return [], None

    # Step 1: pairwise cosine similarity between all sentence embeddings.
    similarity_matrix_cuda = FT.cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2)
    similarity_matrix = similarity_matrix_cuda.cpu().numpy()
    # Drop the device copy straight away; only the NumPy matrix is needed below.
    del similarity_matrix_cuda
    torch.cuda.empty_cache()

    # Step 2: PageRank over the similarity graph gives one score per sentence index.
    sentence_similarity_graph = nx.from_numpy_array(similarity_matrix)
    scores = nx.pagerank(sentence_similarity_graph)

    # Step 3: rank sentence *indices* (highest score first, ties broken by the
    # sentence text). Sorting indices keeps text and embeddings aligned and
    # never compares tensors, which raises when two scores are equal.
    ranked_indices = sorted(
        range(n_sents),
        key=lambda idx: (scores[idx], sentences[idx]),
        reverse=True,
    )

    if context_size is not None:
        summarize_text = []
        word_count = 0
        for idx in ranked_indices:
            if word_count > context_size:
                break
            sentence = sentences[idx]
            summarize_text.append(sentence)
            word_count += len(sentence.split(' '))
        return summarize_text, None

    # top_n mode: a slice handles top_n > n_sents; negative values select nothing.
    picked = ranked_indices[:max(top_n, 0)]
    summarize_text = [sentences[idx] for idx in picked]
    if picked:
        summarize_embeddings_cat = torch.stack([embeddings[idx] for idx in picked], dim=0)
    else:
        # Keep the embedding width but return zero rows.
        summarize_embeddings_cat = embeddings[:0]
    return summarize_text, summarize_embeddings_cat

In [18]:
def _split_note_sentences(notes):
    """Flatten the 'text' of each note dict into a single list of sentences, in note order."""
    # NOTE(review): newlines are deleted rather than replaced by a space, so
    # words separated only by a line break are fused together. Kept unchanged
    # here so existing outputs do not shift - confirm whether that is intended.
    per_note_sentences = (
        sent_tokenize(note['text'].replace('_','').replace('\n', ''))
        for note in notes
    )
    return list(chain.from_iterable(per_note_sentences))


def gen_summary_mapper(example, max_batch_size = 50, batch_notes = True, context_size = 980):
    """Add an extractive summary of an example's notes under 'extractive_notes_summ'.

    Args:
        example: Mapping with a 'notes' entry: a list of dicts that each carry
            a 'text' string.
        max_batch_size: Upper bound on the number of notes summarised together
            when ``batch_notes`` is True.
        batch_notes: When True, the notes are split into near-equal batches,
            each batch is summarised with an equal share of the word budget,
            and the per-batch summaries are concatenated in batch order. When
            False, all notes are summarised in one pass.
        context_size: Total word budget for the summary (default 980).

    Returns:
        The same ``example`` object, mutated to include 'extractive_notes_summ'
        (a single space-joined string; empty when there are no notes).
    """
    notes = example['notes']
    n_notes = len(notes)

    # No notes: nothing to summarise, and the batch arithmetic below would divide by zero.
    if n_notes == 0:
        example['extractive_notes_summ'] = ""
        return example

    if batch_notes:
        n_batches = math.ceil(n_notes/max_batch_size)
        # Spread the notes evenly instead of leaving one small trailing batch.
        batch_size = math.ceil(n_notes/n_batches)
        batch_context_size = math.ceil(context_size/n_batches)

        note_sents_meta = []
        for start in range(0, n_notes, batch_size):
            batch = notes[start:start + batch_size]
            sentences = _split_note_sentences(batch)
            sum_note_sents, _ = generate_extractive_summary(sentences, embeddings = None, context_size = batch_context_size, top_n=None)
            note_sents_meta.append(sum_note_sents)

        summary_sentences = list(chain.from_iterable(note_sents_meta))

    else:
        # Tokenise here as well; this branch previously read a variable that
        # was only ever assigned inside the batched branch.
        sentences = _split_note_sentences(notes)
        summary_sentences, _ = generate_extractive_summary(sentences, embeddings = None, context_size = context_size, top_n=None)

    example['extractive_notes_summ'] = " ".join(summary_sentences)

    return example

## Test Mapping

In [21]:
# Smoke-test the mapper on one training example (row 3) before the full-split runs.
sample_example = train.select([3])
a = sample_example.map(gen_summary_mapper, fn_kwargs = {'max_batch_size': 50, 'batch_notes': True})

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Downloading (…)821d1/.gitattributes:   0%|          | 0.00/391 [00:00<?, ?B/s]

Downloading (…)_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading (…)8d01e821d1/README.md:   0%|          | 0.00/3.95k [00:00<?, ?B/s]

Downloading (…)d1/added_tokens.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading (…)01e821d1/config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

Downloading (…)ce_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Downloading (…)nce_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading (…)821d1/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/399 [00:00<?, ?B/s]

Downloading (…)8d01e821d1/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)1e821d1/modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

In [23]:
# Display the one-row result; 'extractive_notes_summ' should appear among the features.
a

Dataset({
    features: ['subject_id', 'hadm_id', 'notes', 'target_text', 'extractive_notes_summ'],
    num_rows: 1
})

## Map function to all splits

In [None]:
# Free cached GPU memory before the long pass over the training split.
torch.cuda.empty_cache()
train_mapper_kwargs = dict(max_batch_size=20, batch_notes=True)
# Summarise every training example, then drop the raw 'notes' column.
train_cos = train.map(gen_summary_mapper, fn_kwargs = train_mapper_kwargs).remove_columns("notes")

Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24993/24993 [12:03:07<00:00,  1.74s/ examples]


In [None]:
# Free cached GPU memory before the pass over the validation split.
torch.cuda.empty_cache()
valid_mapper_kwargs = dict(max_batch_size=20, batch_notes=True)
# Summarise every validation example, then drop the raw 'notes' column.
valid_cos = valid.map(gen_summary_mapper, fn_kwargs = valid_mapper_kwargs).remove_columns("notes")

Map:   0%|          | 0/5356 [00:00<?, ? examples/s]

In [None]:
# Free cached GPU memory before the pass over the test split.
torch.cuda.empty_cache()
test_mapper_kwargs = dict(max_batch_size=20, batch_notes=True)
# Summarise every test example, then drop the raw 'notes' column.
test_cos = test.map(gen_summary_mapper, fn_kwargs = test_mapper_kwargs).remove_columns("notes")

## Push Dataset to Hub

In [None]:
# Bundle the three summarised splits under their split names and upload them as a private dataset.
summarised_splits = {"train": train_cos, "validation": valid_cos, "test": test_cos}
meta_dataset = DatasetDict(summarised_splits)
meta_dataset.push_to_hub(
    "mimiciii-hospitalcourse-cossim-pagerank-batched-extractive-summ-v2",
    private = True,
)

Pushing dataset shards to the dataset hub: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 24244.53it/s]
Pushing dataset shards to the dataset hub:   0%|                                                                                                                      | 0/1 [00:00<?, ?it/s]
Creating parquet from Arrow format:   0%|                                                                                                                             | 0/6 [00:00<?, ?ba/s][A
Creating parquet from Arrow format: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 30.07ba/s][A
Pushing dataset shards to the dataset hub: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:19<00:00, 19.12s/it]
Pushing dataset shards to the dataset hub:   0%| 

## Add back n_notes to metadata

In [None]:
# Switch to the tokenized variant of the dataset; its splits are loaded as train1/valid1/test1.
dataset_name = "dmacres/mimiciii-hospitalcourse-bert-base-uncased-tokenized"
# dataset_name = "dmacres/mimiciii-hospitalcourse"
train1, valid1, test1 = (
    load_dataset(dataset_name, split = split_name)
    for split_name in ('train', 'validation', 'test')
)

In [None]:
# Display the loaded training split's features and row count.
train1

Dataset({
    features: ['subject_id', 'hadm_id', 'notes', 'target_text', 'n_notes', 'total_token_length', 'target_input_ids', 'target_length'],
    num_rows: 24993
})

In [None]:
# hadm_id -> n_notes lookup for the validation split, built from the two columns directly.
VALID_NNOTES_DICT = dict(zip(valid1['hadm_id'], valid1['n_notes']))

In [None]:
# hadm_id -> n_notes lookup for the test split, built from the two columns directly.
TEST_NNOTES_DICT = dict(zip(test1['hadm_id'], test1['n_notes']))

In [None]:
# hadm_id -> n_notes lookup for the training split, built from the two columns directly.
TRAIN_NNOTES_DICT = dict(zip(train1['hadm_id'], train1['n_notes']))

In [None]:
def map_nnotes(example, nnotes_dict):
    """Set example['n_notes'] from a lookup keyed by the example's 'hadm_id'.

    Args:
        example: Mapping that contains a 'hadm_id' entry; it is modified in place.
        nnotes_dict: Mapping from hadm_id to the note count for that admission.

    Returns:
        The same ``example`` object with 'n_notes' added.

    Raises:
        KeyError: If the example's hadm_id is not present in ``nnotes_dict``.
    """
    admission_id = example['hadm_id']
    note_count = nnotes_dict[admission_id]
    example['n_notes'] = note_count
    return example


In [None]:
# Attach n_notes to each summarised split, using the lookup built from the matching tokenized split.
train_cos1 = train_cos.map(map_nnotes, fn_kwargs = {'nnotes_dict': TRAIN_NNOTES_DICT})
valid_cos1 = valid_cos.map(map_nnotes, fn_kwargs = {'nnotes_dict': VALID_NNOTES_DICT})
test_cos1 = test_cos.map(map_nnotes, fn_kwargs = {'nnotes_dict': TEST_NNOTES_DICT})

Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24993/24993 [00:01<00:00, 20006.68 examples/s]
Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5356/5356 [00:00<00:00, 16758.12 examples/s]
Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5356/5356 [00:00<00:00, 11215.23 examples/s]


In [None]:
# Re-bundle the splits that now carry n_notes and push them to the same private Hub repo as before.
splits_with_nnotes = {"train": train_cos1, "validation": valid_cos1, "test": test_cos1}
meta_dataset = DatasetDict(splits_with_nnotes)
meta_dataset.push_to_hub(
    "mimiciii-hospitalcourse-cossim-pagerank-batched-extractive-summ-v2",
    private = True,
)

Pushing dataset shards to the dataset hub:   0%|                                                                                                                      | 0/1 [00:00<?, ?it/s]
Creating parquet from Arrow format:   0%|                                                                                                                            | 0/25 [00:00<?, ?ba/s][A
Creating parquet from Arrow format:  16%|██████████████████▌                                                                                                 | 4/25 [00:00<00:00, 24.67ba/s][A
Creating parquet from Arrow format:  28%|████████████████████████████████▍                                                                                   | 7/25 [00:00<00:00, 24.63ba/s][A
Creating parquet from Arrow format:  40%|██████████████████████████████████████████████                                                                     | 10/25 [00:00<00:00, 24.46ba/s][A
Creating parquet from Arrow format:  60%|██

In [None]:
# Reload the pushed test split from the Hub and display it to check the upload.
test_load_dataset = load_dataset(
    "dmacres/mimiciii-hospitalcourse-cossim-pagerank-batched-extractive-summ-v2",
    split = 'test',
)
test_load_dataset

Downloading readme: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 883/883 [00:00<00:00, 12.3MB/s]
Downloading data files:   0%|                                                                                                                                         | 0/3 [00:00<?, ?it/s]
Downloading data:   0%|                                                                                                                                          | 0.00/110M [00:00<?, ?B/s][A
Downloading data:   4%|████▉                                                                                                                            | 4.19M/110M [00:00<00:18, 5.75MB/s][A
Downloading data:  11%|██████████████▋                                                                                                                  | 12.6M/110M [00:01<00:09, 10.0MB/s][A
Downloading data:  19%|███████████████████████

Dataset({
    features: ['subject_id', 'hadm_id', 'target_text', 'extractive_notes_summ', 'n_notes'],
    num_rows: 5356
})