We change the background color of the output cells, since it wasn't distinguishable from the Markdown text.

In [2]:
%%html
<style>
.output_wrapper .output .output_area .output_subarea {
    background: #E0FFFF
}
</style>

# Datasets
---

In this notebook we'll build the Dataset classes we need to work with all the datasets we have.
First we introduce the datasets, then we group them by how we are going to use them, and finally we build our classes to manage those different datasets and tasks.

# 0.0 Utils
---

We will be using the 🤗 *Datasets* library and the 🤗 *Transformers* library (as we need a tokenizer and a vocabulary), and, for logging, Weights and Biases (`wandb`), which we install independently from Hugging Face and use through its integration.

Let's define all the `imports` and `hyperparameters` in one place.

In [3]:
# ----------------------------------- #
#           Imports
# ----------------------------------- #
import os # generic
import time # logging
from tqdm.auto import tqdm # custom progress bar
import enlighten # another custom progress bar
import io
import json # load/write data
import torch 
import numpy as np
import pandas as pd

# 🤗 Datasets
from datasets import (
    load_dataset, 
    DatasetDict, 
    Dataset as hfDataset
)

# 🤗 Transformers
from transformers import (
    AutoTokenizer, 
    AutoModel, 
    PreTrainedTokenizer, 
    DataCollatorForLanguageModeling, 
    Trainer, 
    BertForMaskedLM,
    TrainingArguments
)

# Padding
from torch.nn.utils.rnn import pad_sequence

# data types
from torch.utils.data import (
    Dataset, 
    DataLoader
)
from typing import (
    Dict, List, Union
)

In [4]:
# ----------------------------------- #
#           Hyperparameters
# ----------------------------------- #

# --------- dataset         --------- #



# --------- logging         --------- #
# generic
verbose = True
debug = False
wandb_flag = False

# Logging on Weights and Biases
if wandb_flag:
    import wandb
    # wandb
    wandb.login()
    wandb.init(project="datasets.explanation")
    # Optional: log both gradients and parameters
    %env WANDB_WATCH=all

# --------- preprocessing   --------- #
# in **partial_prepare_data**
remove_None_papers = True # if True, remove papers with None in any of the selected fields (e.g. abstract or title)
remove_Unused_columns = True # if True, drop the columns not listed in data_field
# in **preprocess**
clean_None_data = False # if True, change every None (abstract or title) to ''
remove_None_data = False # if True (and clean_None_data is False), remove every None abstract/title together with the corresponding title/abstract

# --------- paths           --------- #
# data folder path
data_base_dir = '/home/vivoli/Thesis/data'
s2orc_type = 'full'
N = None

# --------- model/tokenizer --------- #
# Hugging Face model/tokenizer name
MODEL_PATH = 'allenai/scibert_scivocab_uncased'
RUN_NAME   = 'scibert-s2orc'
RUN_NUMBER = 2
RUN_ITER   = 1

output_dir=f'./tmp_trainer/#{RUN_NUMBER}_{RUN_ITER}_{RUN_NAME}'
# seed for reproducibility of experiments
SEED = 1234

In [5]:
# ----------------------------------- #
#           Logging
# ----------------------------------- #
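LOGS_PATH = 'logs'
import logging

# the FileHandlers below fail if the logs folder does not exist
os.makedirs(LOGS_PATH, exist_ok=True)

# Create a custom logger; set its own level to DEBUG so that the handler
# levels below decide what is shown (otherwise it inherits WARNING from root)
logger = logging.getLogger("datasets.explanation")
logger.setLevel(logging.DEBUG)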

# Create handlers
c_handler = logging.StreamHandler()
f_handler = logging.FileHandler(f'{LOGS_PATH}/file.log')
d_handler = logging.FileHandler(f'{LOGS_PATH}/debug.log')

c_handler.setLevel(logging.DEBUG if verbose else logging.WARNING) # verbose is to log everything
f_handler.setLevel(logging.ERROR)
d_handler.setLevel(logging.DEBUG)

# Create formatters and add them to handlers
c_format = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
f_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
d_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

c_handler.setFormatter(c_format)
f_handler.setFormatter(f_format)
d_handler.setFormatter(d_format)

# Add handlers to the logger
logger.addHandler(c_handler)
logger.addHandler(f_handler)
logger.addHandler(d_handler)

logger.warning('This is a warning')
logger.error('This is an error')
logger.info('This is an info')
logger.debug('This is a debug')


datasets.explanation - ERROR - This is an error
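Handler levels only filter records that the logger has already accepted. A logger with no level of its own (`NOTSET`) uses the first level set among its ancestors, by default the root logger's `WARNING`, so `info` and `debug` calls are dropped before any `DEBUG` handler sees them. A minimal sketch with a throwaway, non-propagating logger (the name `levels_demo` and the helper `capture_levels` are hypothetical):

```python
import io
import logging
from typing import List, Optional

def capture_levels(logger_level: Optional[int] = None) -> List[str]:
    """Emit one record per level and return the level names a DEBUG handler received."""
    stream = io.StringIO()
    handler = logging.StreamHandler(stream)
    handler.setLevel(logging.DEBUG)  # the handler itself would accept everything
    handler.setFormatter(logging.Formatter('%(levelname)s'))

    demo = logging.getLogger("levels_demo")  # direct child of the root logger
    demo.propagate = False                   # keep the demo out of other handlers
    demo.setLevel(logging.NOTSET if logger_level is None else logger_level)
    demo.addHandler(handler)
    try:
        demo.debug('debug')
        demo.info('info')
        demo.warning('warning')
        demo.error('error')
    finally:
        demo.removeHandler(handler)
    return stream.getvalue().split()

print(capture_levels())               # no own level: uses the root level (WARNING by default)
print(capture_levels(logging.DEBUG))  # logger explicitly set to DEBUG
```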


# 0.1 KeyPhrase Dataset
---

These are test datasets.

This keyphrase dataset could be useful for testing a model on the keyphrase task or on abstract-title summarization/generation/embedding.

For now, we can avoid implementing the Dataset and DataLoader classes for these objects, although they would be simple.

The `data` object is composed of `500 tuples`, each one containing 4 objects:
- `title_tensor_` is the title encoding (a tensor of integer token ids)
- `abstract_tensor_` is the abstract encoding (a tensor of integer token ids)
- `fulltext_tensor_` is the fulltext encoding (a tensor of integer token ids)
- `keywords_tensor_` is the keywords encoding (a tensor of integer token ids)
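A minimal sketch of what those classes could look like, assuming each of the 500 tuples holds four 1-D sequences of token ids. `KeyPhraseDataset` and `pad_batch` are hypothetical names, and plain Python lists stand in for the tensors so the sketch runs without PyTorch; a map-style dataset only needs `__len__` and `__getitem__`, which is what `torch.utils.data.DataLoader` requires of it.

```python
from typing import List, Sequence, Tuple

# One sample = (title_ids, abstract_ids, fulltext_ids, keywords_ids)
Sample = Tuple[List[int], List[int], List[int], List[int]]

class KeyPhraseDataset:
    """Map-style dataset over the list of 4-tuples."""

    def __init__(self, data: Sequence[Sample]):
        self.data = list(data)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Sample:
        return self.data[idx]

def pad_batch(batch: List[Sample], pad_id: int = 0) -> List[List[List[int]]]:
    """Collate function: for each of the 4 fields, right-pad every sequence
    in the batch to the longest sequence of that field."""
    padded_fields = []
    for field in zip(*batch):  # all titles, then all abstracts, fulltexts, keywords
        max_len = max(len(seq) for seq in field)
        padded_fields.append([list(seq) + [pad_id] * (max_len - len(seq)) for seq in field])
    return padded_fields
```

With real tensors, the class would subclass `torch.utils.data.Dataset`, be passed to `DataLoader(dataset, batch_size=..., collate_fn=pad_batch)`, and `pad_sequence(field, batch_first=True, padding_value=pad_id)` (imported above) would replace the list-based padding.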

# 0.2 S2ORC Dataset
---

## 0.2.1 S2ORC ( jsonl | jsonl.gz ) Loader 
---
First of all, we need to handle the data files, whether they are still compressed (`.jsonl.gz`) or already decompressed (`.jsonl`).

In [None]:
DATA_PATH = '/home/vivoli/Thesis/data' 
!ls $DATA_PATH

In [None]:
s2orc_path = f"{DATA_PATH}/s2orc-{s2orc_type}-20200705v1/{s2orc_type}"
!ls $s2orc_path

In [None]:
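args: Dict = {
    'range': {
        'N': 3,
        'to': True,    # used when N is an int: True -> buckets [0, N), False -> only bucket N
        'into': False, # used when N is a list: True -> N = [start, end] gives buckets [start, end), False -> N lists the buckets explicitly
    },

    'only_extrated': False,   # True: read already decompressed .jsonl files; False: read the .jsonl.gz files
    'keep_extracted': False,  # only relevant when reading .jsonl.gz files

    'mag_field_of_study': ['Computer Science'] # an empty list [] keeps every paper, a list of field names filters; any non-list value disables the filter

}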

In [None]:
if s2orc_type == 'sample':
    metadata_filenames = ["sample"]
    pdf_parses_filenames = ["sample"]

    if N is not None:
        logger.warning("You set 'sample' but you also set `N` for the full bucket range.\n"
                       "The N selection will be discarded, as only the `sample` element will be used.")
        N = args['range']['N'] = 0
        list_range = [N]

elif s2orc_type == 'full':
    N = args['range']['N']

    if N is None:
        logger.warning("You set 'full' but no bucket index was specified.\n"
                       "We'll use index 0, so only the first bucket will be used.")
        N = args['range']['N'] = 0
        list_range = [N]

    elif isinstance(N, (list, tuple)):
        if args['range']['into']:
            list_range = range(N[0], N[1])
            logger.warning(f"The range is intended as [{N[0]}, {N[1]}) (start {N[0]} included, end {N[1]} excluded)")
        else:
            list_range = N
            logger.warning(f"The element list is intended as: {N}")

    elif isinstance(N, int):
        if args['range']['to']:
            list_range = range(0, N)
            logger.warning(f"The range is intended as [0, {N}) (start 0 included, end {N} excluded)")
        else:
            list_range = [N]
            logger.warning(f"The element list is intended as: [{N}]")

    else:
        raise TypeError(f"N must be None, an int or a list of ints, got {type(N).__name__}")

    metadata_filenames = [f"metadata_{n}" for n in list_range]
    pdf_parses_filenames = [f"pdf_parses_{n}" for n in list_range]

else:
    raise ValueError(f"You must select an existing S2ORC dataset: "
                     f"you selected '{s2orc_type}', but the options are 'sample' or 'full'")

extention = 'jsonl' if args['only_extrated'] else 'jsonl.gz'

# defensive check: with the line above this cannot happen
if extention is None:
    message = ("Extension cannot be None, options:\n"
               " - *jsonl* if only_extrated is True\n"
               " - *jsonl.gz* if only_extrated is False")
    logger.error(message)
    raise RuntimeError(message)

# we could also leave {s2orc_path} for later, if the path is long
meta_s2orc_path = f'{s2orc_path}/metadata'
pdfs_s2orc_path = f'{s2orc_path}/pdf_parses'
meta_s2orc = [f'{metadata_filename}.{extention}' for metadata_filename in metadata_filenames]
pdfs_s2orc = [f'{pdf_parses_filename}.{extention}' for pdf_parses_filename in pdf_parses_filenames]

logger.info(f"meta_s2orc len: {len(meta_s2orc)}")
logger.info(f"pdfs_s2orc len: {len(pdfs_s2orc)}")
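The bucket-selection rules in the cell above can be condensed into a small pure function, which makes them easy to check in isolation. `resolve_bucket_range` is a hypothetical helper, not part of the original notebook; like the cell, it treats integer ranges as half-open, so `N=3` selects buckets 0, 1 and 2.

```python
from typing import List, Optional, Sequence, Union

def resolve_bucket_range(N: Optional[Union[int, Sequence[int]]],
                         to: bool = True,
                         into: bool = False) -> List[int]:
    """Return the list of S2ORC bucket indices selected by N.

    - N is None                            -> [0] (first bucket only)
    - N is an int, to=True                 -> [0, ..., N-1]
    - N is an int, to=False                -> [N]
    - N is a 2-item list/tuple, into=True  -> [N[0], ..., N[1]-1]
    - N is a list/tuple, into=False        -> N itself (explicit bucket ids)
    """
    if N is None:
        return [0]
    if isinstance(N, int):
        return list(range(0, N)) if to else [N]
    if isinstance(N, (list, tuple)):
        if into:
            start, end = N  # raises ValueError if N does not have exactly 2 items
            return list(range(start, end))
        return list(N)
    raise TypeError(f"N must be None, an int or a list/tuple of ints, got {type(N).__name__}")
```

For the 'full' case, `[f"metadata_{n}" for n in resolve_bucket_range(args['range']['N'], args['range']['to'], args['range']['into'])]` rebuilds `metadata_filenames`.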

## 0.2.2 Creation (s2orc)
---

Now that we have explored the `S2ORC` structure, we are ready to load the data (starting from the `sample` folder and then moving on to the `full` one). The first thing to do is to create (as we did before) a method to read the JSON files: `json_s2orc_read`.

In [None]:
print(f'{s2orc_type}_meta:\n {meta_s2orc} \n')
print(f'{s2orc_type}_pdfs:\n {pdfs_s2orc} \n')

In [None]:
print(f'meta_s2orc_path: {meta_s2orc_path}')
print(f'pdfs_s2orc_path: {pdfs_s2orc_path}')

Let's see what's inside the folders (`metadata` and `pdf_parses`):

In [None]:
grep_extention = rf".*\.{extention}$"

metadata_output = !ls $meta_s2orc_path | grep "$grep_extention"
print('metadata (len):', len(metadata_output))
print('metadata (first 10):', metadata_output[:10])

pdf_parses_output = !ls $pdfs_s2orc_path | grep "$grep_extention"
print('pdf_parses (len):', len(pdf_parses_output))
print('pdf_parses (first 10):', pdf_parses_output[:10])

#### a. Intersection
---
As we want to examine all the `meta_s2orc` and `pdfs_s2orc` files that actually exist, we need the intersection between those file names and the `metadata_output` and `pdf_parses_output` lists.

In [None]:
toread_meta_s2orc = sorted(list(set(meta_s2orc) & set(metadata_output)))
toread_pdfs_s2orc = sorted(list(set(pdfs_s2orc) & set(pdf_parses_output)))

In [None]:
print(f'toread meta len: {len(toread_meta_s2orc)} \n {toread_meta_s2orc} \n')
print(f'toread pdfs len: {len(toread_pdfs_s2orc)} \n {toread_pdfs_s2orc} \n')
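`s2orc_multichunk_read` later zips the two sorted lists position by position and only checks that their lengths match, so the alignment holds only if the same buckets are present in both folders. A sketch of a hypothetical helper, `pair_by_bucket`, that pairs files on their bucket index instead (it applies to the 'full' naming scheme; the single `sample` file has no index):

```python
import re
from typing import Dict, List, Tuple

def pair_by_bucket(meta_files: List[str], pdf_files: List[str]) -> List[Tuple[str, str]]:
    """Pair 'metadata_<i>.<ext>' with 'pdf_parses_<i>.<ext>' on the bucket index <i>.

    Buckets missing from either folder are dropped, and pairs are ordered
    numerically (0, 1, 2, ..., 10) rather than lexicographically (0, 1, 10, 2).
    """
    def by_index(files: List[str], prefix: str) -> Dict[int, str]:
        pattern = re.compile(rf"^{prefix}_(\d+)\.")
        indexed = {}
        for name in files:
            match = pattern.match(name)
            if match:
                indexed[int(match.group(1))] = name
        return indexed

    meta = by_index(meta_files, "metadata")
    pdfs = by_index(pdf_files, "pdf_parses")
    return [(meta[i], pdfs[i]) for i in sorted(meta.keys() & pdfs.keys())]
```

Usage: `for meta_file, pdf_file in pair_by_bucket(toread_meta_s2orc, toread_pdfs_s2orc): ...` guarantees that each metadata bucket is read together with its own pdf_parses bucket.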

We can now describe the functions in charge of loading the `jsonl` files. They take as input the dataset path `s2orc_path` (f"{DATA_PATH}/s2orc-{s2orc_type}-20200705v1/{s2orc_type}") and then look inside `metadata` and `pdf_parses` for the file names listed in `toread_meta_s2orc` and `toread_pdfs_s2orc`.

(Unused)
```python
def read_json_list(jsonl_path):
    json_list_of_dict = []
    with open(jsonl_path, 'r') as input_json:
        for json_line in input_json:
            json_dict = json.loads(json_line)
            json_list_of_dict.append(json_dict)
    return json_list_of_dict
```

In [None]:
def read_meta_json_list_dict(file_path, args):

    if verbose: print(f'file_path: {file_path}')
    # list of dictionaries, one for each selected row in metadata
    json_list_of_dict = []
    # create a list index
    json_list_of_dict_idx = 0
    # dictionary of indexes, to obtain the object from list, starting from the `paper_id`
    json_dict_of_index = {}

    delete_it = False

    if args['only_extrated']:
        # just open as usual
        input_json = open(file_path, 'r')

        if verbose: print('You chose to use only already unzipped files')

    else:
        # import for unzip
        import gzip
        # open by decompressing on the fly
        gz = gzip.open(file_path, 'rb')
        input_json = io.BufferedReader(gz)

        if args['keep_extracted']:
            if verbose: print('You chose to unzip files and keep the unzipped file (this can fill up the disk)')
        else:
            delete_it = True
            if verbose: print('You chose to unzip files and delete the unzipped file (storage-driven decision)')

    # check if "mag_field_of_study" is in args, and is a valid list
    mag_field_filter = False
    mag_field_all = False
    mag_field_dict = {}
    if isinstance(args.get("mag_field_of_study"), list):
        mag_field_filter = True
        if not args["mag_field_of_study"]:
            if verbose: print("The mag_field_of_study list is empty: every paper will be kept")
            mag_field_all = True
        else:
            for field in args["mag_field_of_study"]:
                mag_field_dict[field] = True

    # tocks_1 = manager.counter(total=1e6, desc='Tocks_1', unit='tocks')
    with input_json:
        for index, json_line in tqdm(enumerate(input_json)):
            json_dict = json.loads(json_line)

            # a paper passes the filter if at least one of its MAG fields was requested
            # (papers with a null mag_field_of_study never pass)
            paper_fields = json_dict.get("mag_field_of_study") or []
            mag_field_filter_pass = any(field in mag_field_dict for field in paper_fields)

            if not mag_field_filter or mag_field_all or mag_field_filter_pass:
                # append the dictionary to the dictionaries' list
                json_list_of_dict.append(json_dict)
                # insert (paper_id, index) pair as (key, value) to the dictionary
                json_dict_of_index[json_dict['paper_id']] = json_list_of_dict_idx
                # increment the list index
                json_list_of_dict_idx += 1

            # tocks_1.update()

    # tocks_1.fill()

    if delete_it:
        # deletion is currently disabled; note that `file_path` is the .jsonl.gz
        # archive itself, not an extracted copy, so os.remove would delete the original data
        # os.remove(file_path)
        if verbose: print(f'[INFO] Delete file operation skipped for {file_path}')

    return json_list_of_dict, json_dict_of_index
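The `mag_field_of_study` filter above reduces to a set intersection. A hypothetical standalone predicate, `passes_mag_filter`, with the same decisions, handy for checking edge cases such as papers whose `mag_field_of_study` is `null` (they are dropped whenever a non-empty filter is set):

```python
from typing import List, Optional

def passes_mag_filter(paper_fields: Optional[List[str]], wanted) -> bool:
    """Decide whether a metadata record is kept by the field-of-study filter.

    - `wanted` is not a list (e.g. None): no filtering, keep the paper
    - `wanted` is an empty list: keep every paper
    - otherwise: keep the paper only if at least one of its MAG fields is wanted
    """
    if not isinstance(wanted, list) or not wanted:
        return True
    return bool(set(paper_fields or []) & set(wanted))
```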

In [None]:
def read_pdfs_json_list_dict(file_path, json_dict_of_index_meta, args):

    if verbose: print(f'file_path: {file_path}')
    # list of dictionaries, one for each selected row in pdf_parses
    json_list_of_dict = []
    # create a list index
    json_list_of_dict_idx = 0
    # dictionary of indexes, to obtain the object from list, starting from the `paper_id`
    json_dict_of_index = {}

    delete_it = False

    if args['only_extrated']:
        # just open as usual
        input_json = open(file_path, 'r')

        if verbose: print('You chose to use only already unzipped files')

    else:
        # import for unzip
        import gzip
        # open by decompressing on the fly
        gz = gzip.open(file_path, 'rb')
        input_json = io.BufferedReader(gz)

        if args['keep_extracted']:
            if verbose: print('You chose to unzip files and keep the unzipped file (this can fill up the disk)')
        else:
            delete_it = True
            if verbose: print('You chose to unzip files and delete the unzipped file (storage-driven decision)')

    # tocks_2 = manager.counter(total=1.5e6, desc='Tocks_2', unit='tocks')
    with input_json:
        for index, json_line in tqdm(enumerate(input_json)):
            json_dict = json.loads(json_line)

            # keep the pdf_parse only if its metadata has been selected
            if json_dict['paper_id'] in json_dict_of_index_meta:
                # append the dictionary to the dictionaries' list
                json_list_of_dict.append(json_dict)
                # insert (paper_id, index) pair as (key, value) to the dictionary
                json_dict_of_index[json_dict['paper_id']] = json_list_of_dict_idx
                # increment the list index
                json_list_of_dict_idx += 1

            # tocks_2.update()

    # tocks_2.fill()

    if delete_it:
        # deletion is currently disabled; note that `file_path` is the .jsonl.gz
        # archive itself, not an extracted copy, so os.remove would delete the original data
        # os.remove(file_path)
        if verbose: print(f'[INFO] Delete file operation skipped for {file_path}')

    return json_list_of_dict, json_dict_of_index
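`read_meta_json_list_dict` and `read_pdfs_json_list_dict` share the same open/iterate/index logic and differ only in which records they keep. A possible refactor, sketched with hypothetical names (`iter_jsonl`, `read_filtered`), passes that decision in as a `keep` callable; `gzip.open` in text mode decompresses on the fly, so no extracted copy is written to disk.

```python
import gzip
import json
from typing import Callable, Dict, Iterator, List, Tuple

def iter_jsonl(file_path: str, compressed: bool) -> Iterator[dict]:
    """Yield one dict per line of a JSON Lines file, plain or gzip-compressed."""
    opener = gzip.open if compressed else open
    with opener(file_path, 'rt', encoding='utf-8') as f:
        for line in f:
            if line.strip():  # skip blank lines, e.g. a trailing newline
                yield json.loads(line)

def read_filtered(file_path: str, compressed: bool,
                  keep: Callable[[dict], bool]) -> Tuple[List[dict], Dict[str, int]]:
    """Return (records, {paper_id: position in records}) for the records where keep(record) is True."""
    records: List[dict] = []
    index: Dict[str, int] = {}
    for record in iter_jsonl(file_path, compressed):
        if keep(record):
            index[record['paper_id']] = len(records)
            records.append(record)
    return records, index
```

For instance, `meta, meta_idx = read_filtered(path_metadata, compressed=not args['only_extrated'], keep=lambda r: True)` followed by `pdfs, pdf_idx = read_filtered(path_pdf_parses, compressed=not args['only_extrated'], keep=lambda r: r['paper_id'] in meta_idx)` reproduces the two readers without the field-of-study filter.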

In [None]:
def s2orc_chunk_read( s2orc_path, meta_s2orc_single_file, pdfs_s2orc_single_file, extention, args ):
    """
    Args:
        - s2orc_path (string): 
            Path to the dataset directory (e.g. '{data}/s2orc-{sample|full}-20200705v1/{sample|full}').
            
        - meta_s2orc_single_file (string): 
            Filename with extension (e.g. 'metadata_0.jsonl.gz') present in `{s2orc_path}/metadata`.

        - pdfs_s2orc_single_file (string): 
            Filename with extension (e.g. 'pdf_parses_0.jsonl.gz') present in `{s2orc_path}/pdf_parses`.
            
        - extention (string | None):
            Either `jsonl` (already decompressed files) or `jsonl.gz` (files to decompress).
            Currently unused: the reading mode is taken from `args['only_extrated']`.
            
        - args (dict):
            Dictionary containing some config params.
            
    Return:
        - json_dict_of_list (dict): 
            Dictionary such as { 'metadata': [...], 'pdf_parses': [...], 'meta_key_idx': {...}, 'pdf_key_idx': {...} },
            where the two lists hold the metadata and pdf_parses dictionaries and the two dicts
            map each `paper_id` to its position in the corresponding list.
    """
    if verbose: print("[INFO-START] Metadata Chunk read  : ", meta_s2orc_single_file)
    if verbose: print("[INFO-START] Pdf parses Chunk read: ", pdfs_s2orc_single_file)

    json_dict_of_list = {'metadata': [], 'pdf_parses': [], 'meta_key_idx': {}, 'pdf_key_idx': {}}
    
    # tocks_0 = manager.counter(total=2, desc='Tocks_0', unit='tocks')    
    
    if verbose: print("[INFO] Metadata read: ", meta_s2orc_single_file)    
    path_metadata = os.path.join(s2orc_path, 'metadata', meta_s2orc_single_file)
    if verbose: print(f"{path_metadata}")
    
    json_list_metadata, json_dict_of_index_meta = read_meta_json_list_dict(path_metadata, args)
    
    json_dict_of_list['metadata'] = json_list_metadata
    json_dict_of_list['meta_key_idx'] = json_dict_of_index_meta
    
    # tocks_0.update()
    
    if verbose: print("[INFO] Pdf_Parses read: ", pdfs_s2orc_single_file)
    path_pdf_parses = os.path.join(s2orc_path, 'pdf_parses', pdfs_s2orc_single_file)
    if verbose: print(f"{path_pdf_parses}")
    
    json_list_pdf_parses, json_dict_of_index_pdf = read_pdfs_json_list_dict(path_pdf_parses, json_dict_of_index_meta, args)
    
    json_dict_of_list['pdf_parses'] = json_list_pdf_parses
    json_dict_of_list['pdf_key_idx'] = json_dict_of_index_pdf
    
    if verbose: print("[INFO-END  ] Chunk read: ", meta_s2orc_single_file, pdfs_s2orc_single_file)                 
    
    # tocks_0.update()
    
    return json_dict_of_list

In [None]:
def s2orc_multichunk_read(s2orc_path, toread_meta_s2orc, toread_pdfs_s2orc, extention, args):
    """
    Args:
        - s2orc_path (string): 
            Path to the dataset directory (e.g. '{data}/s2orc-{sample|full}-20200705v1/{sample|full}').
            
        - toread_meta_s2orc (list of string): 
            List of filenames with extensions (e.g. ['metadata_0.jsonl.gz', 'metadata_1.jsonl.gz'])
            present in `{s2orc_path}/metadata`.

        - toread_pdfs_s2orc (list of string): 
            List of filenames with extensions (e.g. ['pdf_parses_0.jsonl.gz', 'pdf_parses_1.jsonl.gz'])
            present in `{s2orc_path}/pdf_parses`; the i-th file must match the i-th metadata file.
            
        - extention (string | None):
            Either `jsonl` (already decompressed files) or `jsonl.gz` (files to decompress).
            
        - args (dict):
            Dictionary containing some config params.
            
    Return:
        - multichunks_lists (list of dict):
            One `s2orc_chunk_read` dictionary per (metadata, pdf_parses) pair.
    """
    if verbose: print("[INFO-START] Multichunk read")
    if verbose: print(f"[INFO] Metadata reading  : {toread_meta_s2orc}")
    if verbose: print(f"[INFO] Pdf Parses reading: {toread_pdfs_s2orc}")    
    
    assert len(toread_meta_s2orc) == len(toread_pdfs_s2orc), "Files list (metadata and pdfs) must be the same length!"
    assert extention is not None, "Extension must be set!"
    
    if verbose:
        extracted = args['only_extrated']
        keep = args['keep_extracted']
        print("\n[INFO] Data read selection:")
        print(f"    [{'x' if extracted else ' '}] Extracted")
        print(f"    [{'x' if not extracted and keep else ' '}] To Extract and Keep")
        print(f"    [{'x' if not extracted and not keep else ' '}] To Extract and Delete")

    multichunks_lists = []

    # ticks = manager.counter(total=min(len(toread_meta_s2orc), len(toread_pdfs_s2orc)), desc='Ticks', unit='ticks')
    
    for meta_s2orc_single_file , pdfs_s2orc_single_file in tqdm(zip(toread_meta_s2orc, toread_pdfs_s2orc)):

        chunk_list = s2orc_chunk_read( s2orc_path, meta_s2orc_single_file, pdfs_s2orc_single_file, extention, args )
    
        multichunks_lists.append(chunk_list)
        
        # ticks.update()
    
    return multichunks_lists

Important objects are:
    
- `s2orc_path` ('/home/vivoli/Thesis/data/s2orc-full-20200705v1/full')
- `meta_s2orc_path` (f'{s2orc_path}/metadata')
- `pdfs_s2orc_path` (f'{s2orc_path}/pdf_parses')
- `toread_meta_s2orc` ( ['metadata_0.jsonl.gz', 'metadata_1.jsonl.gz'] )
- `toread_pdfs_s2orc` ( ['pdf_parses_0.jsonl.gz', 'pdf_parses_1.jsonl.gz'] )

In [None]:
# manager = enlighten.get_manager()
multichunks_lists = s2orc_multichunk_read(s2orc_path, toread_meta_s2orc, toread_pdfs_s2orc, extention, args)
# manager.stop()

If we use only `sample.jsonl` or a single pair (`metadata_0.jsonl`-`pdf_parses_0.jsonl`), `multichunks_lists` has just one element; in general it holds one element per pair of files read.

We have parsed all the `metadata` and `pdf_parses` elements, so each element of `multichunks_lists` is a dictionary composed as follows:
```python
json_dict_of_list = {
    'metadata': [], 
    'pdf_parses': [], 
    'meta_key_idx': {}, 
    'pdf_key_idx': {}
}
```
In this dictionary we see:
* metadata - `List[dict]` of type `metadata`.
* pdf_parses - `List[dict]` of type `pdf_parses`.
* meta_key_idx - `dict` with keys: `paper_id` and values: `index` in the metadata list.
* pdf_key_idx - `dict` with keys: `paper_id` and values: `index` in the pdf_parses list.
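Because `.get` returns `0` for the first paper of a list, an index must be compared with `None` rather than tested for truthiness. A small hypothetical helper, `get_paper`, that wraps both lookups and tolerates papers present in only one of the two lists (for example, a metadata entry with no parsed PDF):

```python
from typing import Optional, Tuple

def get_paper(chunk: dict, paper_id: str) -> Tuple[Optional[dict], Optional[dict]]:
    """Return (metadata, pdf_parse) for `paper_id` in one chunk, with None where missing."""
    meta_idx = chunk['meta_key_idx'].get(paper_id)
    pdf_idx = chunk['pdf_key_idx'].get(paper_id)
    metadata = chunk['metadata'][meta_idx] if meta_idx is not None else None
    pdf_parse = chunk['pdf_parses'][pdf_idx] if pdf_idx is not None else None
    return metadata, pdf_parse
```

Usage: `metadata, pdf_parse = get_paper(multichunks_lists[0], '18980380')`.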

In [None]:
index = multichunks_lists[0]['meta_key_idx']['18980380']
multichunks_lists[0]['metadata'][index]['paper_id']

In [None]:
index = multichunks_lists[0]['pdf_key_idx']['18980380']
multichunks_lists[0]['pdf_parses'][index]['paper_id']

In [None]:
multichunks_lists[0]['metadata'][0]

## Multichunks getDataset( (id, multichunk) | (single_chunk) )
---

In [None]:
def fuse_dictionaries(single_chunk: dict, 
                      data_field: List[str] =  ["title", "abstract"]) -> hfDataset:
    """
        Build one record per paper: each field in `data_field` is taken from the
        metadata and falls back to pdf_parses when the metadata value is missing.
    """
    # definition of **single_chunk**
    # {'metadata': [], 'pdf_parses': [], 'meta_key_idx': {}, 'pdf_key_idx': {}}

    print(f"len meta single_chunk: {len(single_chunk['metadata'])}")
    print(f"len pdfs single_chunk: {len(single_chunk['pdf_parses'])}")
    
    verbose = False
    print_all_debug = False

    def not_None(element):
        """
            Return False if the element is None, '' or [] (all Falsy in Python).
            `is` must not be used here: `element is []` is always False,
            since a new list is created for every comparison.
        """
        if element is None:
            return False
        if isinstance(element, (str, list)) and len(element) == 0:
            return False
        return True

    class s2orcBaseElement():
        """
            'section': str,
            'text': str,
            'cite_spans': list,
            'ref_spans': list                
        """
        def __init__(self, dictionary):
            self.section: str = dictionary['section']
            self.text: str = dictionary['text']
            self.cite_spans: list = dictionary['cite_spans']
            self.ref_spans: list = dictionary['ref_spans']

        def get_text(self):
            return self.text

    def fuse_field(meta_field, pdf_field):
        """
            Prefer the metadata value and fall back to pdf_parses when it is missing.
            List-valued pdf_parses fields (e.g. 'abstract') are lists of paragraphs,
            which are joined into a single string.
            See https://docs.python.org/3/library/stdtypes.html#truth-value-testing
        """
        if type(pdf_field) == list:
            pdf_field = ' '.join([s2orcBaseElement(elem).get_text() for elem in pdf_field])

        return meta_field if not_None(meta_field) else pdf_field

    paper_list = []
    for key in single_chunk['meta_key_idx']:
        
        if verbose: print(f"[INFO] Analyse metadata dictionary for paper {key}")
        # get metadata dictionary for paper with paper_id: key
        meta_index = single_chunk['meta_key_idx'].get(key, None)
        if verbose: print(f"meta_index: {meta_index}")
        metadata = single_chunk['metadata'][meta_index] if meta_index is not None else dict()
        
        if verbose: print(f"[INFO] Analyse pdf_parses dictionary for paper {key}")
        # get pdf_parses dictionary for paper with paper_id: key
        pdf_index = single_chunk['pdf_key_idx'].get(key, None)
        if verbose: print(f"pdf_index: {pdf_index}")
        pdf_parses = single_chunk['pdf_parses'][pdf_index] if pdf_index is not None else dict()

        if verbose: print(f"[INFO] Start fusion for paper {key}")
        paper = dict()
        for field in data_field:
            if print_all_debug: print(f"[INFO] Fusing field {field} for meta ({metadata.get(field, None)}) and pdf_parses ({pdf_parses.get(field, None)})")
            paper[field] = fuse_field(metadata.get(field, None), pdf_parses.get(field, None))
        
        paper_list.append(paper)
    
    # convert the list of papers to a pandas DataFrame, then to a 🤗 Dataset
    paper_df = pd.DataFrame(paper_list)
    
    return hfDataset.from_pandas(paper_df)

In [None]:
DATA_FIELD =  ["title", "abstract"]
dataset_dict_test = fuse_dictionaries(multichunks_lists[0], data_field=DATA_FIELD)
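The fusion rule applied to each field can be checked in isolation on toy values. `is_empty` and `fuse_field_sketch` are hypothetical names; the paragraph dicts below carry only the `text` key the rule reads, whereas real S2ORC paragraphs also carry `section`, `cite_spans` and `ref_spans`.

```python
from typing import Any

def is_empty(value: Any) -> bool:
    """None, '' and [] count as missing (a length check, not an `is` check)."""
    return value is None or (isinstance(value, (str, list)) and len(value) == 0)

def fuse_field_sketch(meta_value: Any, pdf_value: Any) -> Any:
    """Prefer the metadata value; otherwise fall back to the pdf_parses value.
    List-valued pdf_parses fields (paragraph dicts) are joined into one string."""
    if isinstance(pdf_value, list):
        pdf_value = ' '.join(paragraph['text'] for paragraph in pdf_value)
    return pdf_value if is_empty(meta_value) else meta_value
```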

In [None]:
def getDataset( single_chunk: dict,
                tokenizer: PreTrainedTokenizer,
                max_seq_length: int = None,
                batch_size: int = 64,
                num_workers: int = 4,
                seed: int = SEED,
                data_field: List[str] =  ["title", "abstract"]) -> DatasetDict:
    """Given a single chunk of S2ORC data, prepare the train, test and validation splits.
    :param single_chunk: dictionary returned by `s2orc_chunk_read` (metadata, pdf_parses and their paper_id indexes)
    :param tokenizer: pretrained tokenizer; currently only used to default `max_seq_length`
    :param max_seq_length: maximal sequence length (defaults to `tokenizer.model_max_length`); not applied yet
    :param batch_size: batch size for the dataloaders (not used yet: no DataLoader is built here)
    :param num_workers: number of CPU workers for dataloading (not used yet); on Windows this must be zero
    :param seed: seed for the train/test/validation splits
    :param data_field: fields to keep for each paper
    :return: a DatasetDict with "train", "test" and "valid" splits
    """
    print_all_debug = False
    time_debug = True
    print_some_debug = False

    ## ------------------ ##
    ## -- LOAD DATASET -- ##
    ## ------------------ ##
    if time_debug: start = time.time()
    if time_debug: start_load = time.time()
        
    ## execution
    max_seq_length = tokenizer.model_max_length if not max_seq_length else max_seq_length
    if print_some_debug: print(max_seq_length)
    dataset_dict = fuse_dictionaries(single_chunk, data_field)
    
    # print(dataset_dict)
    
    if print_some_debug: print(dataset_dict)

    if time_debug: end_load = time.time()
    if time_debug: print(f"[TIME] load_dataset: {end_load - start_load}")
    
    ## ------------------ ##
    ## ---- MANAGING ---- ##
    ## ------------------ ##
    if time_debug: start_selection = time.time()
    
    ## execution
    dataset = dataset_dict #['train']
    
    if time_debug: end_selection = time.time()
    if time_debug: print(f"[TIME] dataset_train selection: {end_selection - start_selection}")
    if print_all_debug: print(dataset)
   
    ## ------------------ ##
    ## --- REMOVE none -- ##
    ## ------------------ ##
    if time_debug: start_removing = time.time()
    # clean input removing papers with **None** as abstract/title
    if remove_None_papers:

        ## --------------------- ##
        ## --- REMOVE.indexes -- ##
        ## --------------------- ##
        if time_debug: start_removing_indexes = time.time()
        if print_all_debug: print(data_field)
        
        ## execution
        none_papers_indexes = {}
        for field in data_field:
            none_indexes = [ idx_s for idx_s, s in enumerate(dataset[f"{field}"]) if s is None]
            none_papers_indexes = {**none_papers_indexes, **dict.fromkeys(none_indexes , False)}

        if time_debug: end_removing_indexes = time.time()
        if time_debug: print(f"[TIME] remove.indexes: {end_removing_indexes - start_removing_indexes}")
        if print_all_debug: print(none_papers_indexes)
        
        ## --------------------- ##
        ## --- REMOVE.concat --- ##
        ## --------------------- ##
        if time_debug: start_removing_concat = time.time()
        
        ## execution
        to_remove_indexes = list(none_papers_indexes.keys())

        if time_debug: end_removing_concat = time.time()
        if time_debug: print(f"[TIME] remove.concat: {end_removing_concat - start_removing_concat}")
        if print_all_debug: print(to_remove_indexes)
        if print_all_debug: print([ dataset["abstract"][i] for i in to_remove_indexes])

        ## --------------------- ##
        ## --- REMOVE.filter --- ##
        ## --------------------- ##
        if time_debug: start_removing_filter = time.time()
        
        ## execution
        dataset = dataset.filter((lambda x, ids: none_papers_indexes.get(ids, True)), with_indices=True)
        
        if time_debug: end_removing_filter = time.time()
        if time_debug: print(f"[TIME] remove.filter: {end_removing_filter - start_removing_filter}")
        if print_all_debug: print(dataset)

        
    if time_debug: end_removing = time.time()
    if time_debug: print(f"[TIME] remove None fields: {end_removing - start_removing}")

    ## --------------------- ##
    ## --- REMOVE.column --- ##
    ## --------------------- ##
    if time_debug: start_remove_unused_columns = time.time()
    if remove_Unused_columns:
        # the in-place `remove_columns_` is deprecated in 🤗 Datasets;
        # `remove_columns` returns a new dataset instead
        unused_columns = [column for column in dataset.column_names if column not in data_field]
        if debug: print(unused_columns)
        if unused_columns:
            dataset = dataset.remove_columns(unused_columns)

    if time_debug: end_remove_unused_columns = time.time()
    if time_debug: print(f"[TIME] remove.column: {end_remove_unused_columns - start_remove_unused_columns}")
        
    ## ------------------ ##
    ## --- SPLIT 1.    -- ##
    ## ------------------ ##
    if time_debug: start_first_split = time.time()
    
    # 80% (train), 20% (test + validation)
    ## execution
    train_testvalid = dataset.train_test_split(test_size=0.2, seed=seed)
    
    if time_debug: end_first_split = time.time()
    if time_debug: print(f"[TIME] first [train-(test-val)] split: {end_first_split - start_first_split}")

    ## ------------------ ##
    ## --- SPLIT 2.    -- ##
    ## ------------------ ##
    if time_debug: start_second_split = time.time()
    
    # 10% of total (test), 10% of total (validation)
    ## execution
    test_valid = train_testvalid["test"].train_test_split(test_size=0.5, seed=seed)

    if time_debug: end_second_split = time.time()
    if time_debug: print(f"[TIME] second [test-val] split: {end_second_split - start_second_split}")

    ## execution
    dataset = DatasetDict({"train": train_testvalid["train"],
                          "test": test_valid["test"],
                          "valid": test_valid["train"]})
    if time_debug: end = time.time()
    if time_debug: print(f"[TIME] TOTAL: {end - start}") 
    return dataset
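The two calls to `train_test_split` compose multiplicatively: 20 % of the data is held out, and half of it (0.2 × 0.5 = 0.1 of the total) becomes `test` while the other half becomes `valid`, giving roughly 80/10/10. The stdlib sketch below (a hypothetical `two_stage_split`, not 🤗 Datasets' implementation, whose rounding may differ) makes the proportions and the role of the seed explicit: the same seed always yields the same partition.

```python
import random
from typing import Dict, List, Sequence

def two_stage_split(indices: Sequence[int], seed: int = 1234,
                    test_size: float = 0.2) -> Dict[str, List[int]]:
    """Shuffle once, hold out `test_size` of the data, then halve the held-out
    part into test and valid (about 80/10/10 for test_size=0.2).
    Rounding here is a simple floor."""
    rng = random.Random(seed)  # local RNG: reproducible and independent of global state
    shuffled = list(indices)
    rng.shuffle(shuffled)
    n_holdout = int(len(shuffled) * test_size)
    train, holdout = shuffled[n_holdout:], shuffled[:n_holdout]
    n_test = len(holdout) // 2
    return {"train": train, "test": holdout[:n_test], "valid": holdout[n_test:]}
```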

## Multichunks getDatasets
---

In [19]:
# tokenizer from 'allenai/scibert_scivocab_uncased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = BertForMaskedLM.from_pretrained(MODEL_PATH)

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [20]:
max_seq_length = model.config.max_position_embeddings
print(max_seq_length)

512


In [None]:
%%time

dictionary_input = { "data": ["abstract"], "target": ["title"], "classes": ["mag_field_of_study"]}
dictionary_columns = sum(dictionary_input.values(), [])

# here we use only the first chunk list (`multichunks_lists[0]`), for speed
dataset = getDataset(multichunks_lists[0], tokenizer, data_field=dictionary_columns, max_seq_length=max_seq_length)
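`sum(dictionary_input.values(), [])` concatenates the per-role column lists into the flat list passed as `data_field`. The order matters, because `preprocess` later reads the first column as data and the second as target. An equivalent sketch with `itertools.chain`, which is the idiomatic way to flatten and avoids the repeated list copies of `sum` (negligible here with three entries):

```python
from itertools import chain

# same value as in the cell above
dictionary_input = {"data": ["abstract"], "target": ["title"], "classes": ["mag_field_of_study"]}

# flatten the lists of column names, keeping insertion order (dicts are ordered in Python 3.7+)
dictionary_columns = list(chain.from_iterable(dictionary_input.values()))
print(dictionary_columns)  # ['abstract', 'title', 'mag_field_of_study']
```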

In [None]:
dataset

## S2ORC Preparation
---

To build a generic loading function we take inspiration from [here](https://discuss.huggingface.co/t/pipeline-with-custom-dataset-tokenizer-when-to-save-load-manually/1084/11).

(unused)
```python
from enum import Enum
from typing import List

class S2orcDataField(Enum):
    TITLE: List[str] = ["title"]
    ABSTRACT: List[str] = ["abstract"]
    PAPER_ID: List[str] = ["paper_id"]
    YEAR: List[str] = ["year"]
    MAG_FIELD_OF_STUDY: List[str] = ["mag_field_of_study"]
    S2_URL: List[str] = ["s2_url"]
    TITLE_ABSTRACT: List[str] = ["title", "abstract"]
```

(unused original function)
```python
def prepare_data(dataset_f: str,
                tokenizer: PreTrainedTokenizer,
                max_seq_length: int = None,
                batch_size: int = 64,
                num_workers: int = 0,
                seed: int = SEED,
                data_field: List[str] =  ["title", "abstract"]) -> Dict[str, DataLoader]:
    """Given an input file, prepare the train, test, validation dataloaders.
    :param dataset_f: input file (format: .txt; line by line)
    :param tokenizer: pretrained tokenizer that will prepare the data, i.e. convert tokens into IDs
    :param max_seq_length: maximal sequence length. Longer sequences will be truncated
    :param batch_size: batch size for the dataloaders
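    :param num_workers: number of CPU workers to use during dataloading. On Windows this must be zero
    :param seed: random seed for the train/test/validation splits
    :param data_field: dataset column(s) to tokenize
    :return: a dictionary containing train, test, validation dataloaders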
    """
    max_seq_length = tokenizer.model_max_length if not max_seq_length else max_seq_length

    def preprocess(sentences: List[str]): #-> Dict[str, Union[list, Tensor]]:
        """Preprocess the raw input sentences from the text file.
        :param sentences: a list of sentences (strings)
        :return: a dictionary of "input_ids"
        """
        tokens = [s.strip().split() for s in sentences]
        tokens = [t[:max_seq_length - 1] + [tokenizer.eos_token] for t in tokens]

        # The sequences are not padded here. we leave that to the dataloader in a collate_fn
        # ----------------------------------------------- #
        # -------- TODO include the `collate_fn` -------- #
        # ----------------------------------------------- #
        # That means: a bit slower processing, but a smaller saved dataset size
        encoded_d = tokenizer(tokens,
                             add_special_tokens=False,
                             is_split_into_words=True,
                             return_token_type_ids=False,
                             return_attention_mask=False)

        return {"input_ids": encoded_d["input_ids"]}

    dataset_dict = load_dataset("json", data_files=dataset_f)
    # dataset = Dataset.from_dict({"text": Path(dataset_f).read_text(encoding="utf-8").splitlines()})
    dataset = dataset_dict['train']
    # 80% (train), 20% (test + validation)
    train_testvalid = dataset.train_test_split(test_size=0.2, seed=seed)
    # 10% of total (test), 10% of total (validation)
    test_valid = train_testvalid["test"].train_test_split(test_size=0.5, seed=seed)

    dataset = DatasetDict({"train": train_testvalid["train"],
                          "test": test_valid["test"],
                          "valid": test_valid["train"]})
    print(dataset)
    """
    choose one of the dataset columns:
    - IMPORTANT fields: 
        'title', 'authors', 'abstract', 
    - LESS important fields: 
        'paper_id', 'year', 'arxiv_id', 'acl_id', 'pmc_id', 'pubmed_id', 'doi', 
        'venue', 'journal', 'mag_id', 'mag_field_of_study', 
        'outbound_citations', 'inbound_citations', 'has_outbound_citations', 'has_inbound_citations', 
        'has_pdf_body_text', 'has_pdf_parse', 'has_pdf_parsed_abstract', 'has_pdf_parsed_body_text', 'has_pdf_parsed_bib_entries', 'has_pdf_parsed_ref_entries', 
        's2_url'
    """
    dataset = dataset.map(preprocess, input_columns=data_field, batched=True)
    dataset.set_format("torch", columns=["input_ids"])

    return {partition: DataLoader(ds,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=num_workers,
                                 pin_memory=True) for partition, ds in dataset.items()}
```
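Both preprocessing functions leave a TODO for a `collate_fn`. A minimal sketch, assuming each dataset item is a dict with a 1-D `input_ids` tensor (as produced by `set_format("torch", columns=["input_ids"])`). `make_collate_fn` is a hypothetical name; it relies on `pad_sequence`, which is already imported at the top of the notebook:

```python
from typing import Callable, Dict, List

import torch
from torch.nn.utils.rnn import pad_sequence


def make_collate_fn(pad_token_id: int) -> Callable[[List[Dict[str, torch.Tensor]]], Dict[str, torch.Tensor]]:
    """Build a collate_fn that pads a batch of variable-length `input_ids` (hypothetical helper)."""

    def collate(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        sequences = [item["input_ids"] for item in batch]
        # pad every sequence to the longest one in this batch
        input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_token_id)
        # attention mask from the true lengths: 1 for real tokens, 0 for padding
        lengths = torch.tensor([len(seq) for seq in sequences])
        positions = torch.arange(input_ids.size(1))
        attention_mask = (positions[None, :] < lengths[:, None]).long()
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    return collate
```

It would be passed as `DataLoader(ds, batch_size=batch_size, collate_fn=make_collate_fn(tokenizer.pad_token_id))`. Padding per batch keeps the saved dataset small, at the cost of a little work at load time.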

(unused function)
```python
# tokenizer from 'allenai/scibert_scivocab_uncased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# tokenizer.add_special_tokens({"eos_token": "[EOS]"})
DATA_FIELD =  ["abstract"]

prepare_data(meta_s2orc, tokenizer, data_field=DATA_FIELD)
```

---
---
---
## ❌ PARTIAL PREPARE
---

The max sequence lengths of the tokenizer, the model and the input might differ.
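A minimal sketch of how the three limits could be reconciled: take the smallest of the tokenizer limit, the model limit and an optional user value. It assumes that a tokenizer without a configured limit reports a very large placeholder as `model_max_length`, which `min` then simply ignores. `resolve_max_seq_length` is a hypothetical helper; in this notebook its inputs would be `tokenizer.model_max_length` and `model.config.max_position_embeddings`:

```python
from typing import Optional


def resolve_max_seq_length(tokenizer_max: int, model_max: int, requested: Optional[int] = None) -> int:
    """Return the largest sequence length supported by both tokenizer and model (hypothetical helper)."""
    # a huge placeholder value on the tokenizer side is dropped by min()
    limit = min(tokenizer_max, model_max)
    if requested is None:
        return limit
    if requested > limit:
        print(f"[WARN] requested {requested} > supported {limit}; using {limit}")
    return min(requested, limit)


print(resolve_max_seq_length(int(1e30), 512))  # 512
```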

---
---
---

In [None]:
print(dataset)

In [None]:
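def preprocess(*sentences_by_column, data, target, classes): #-> Dict[str, np.ndarray]:
    """Tokenize the data (abstract) and target (title) columns of a batch.
    :param sentences_by_column: one list of strings per input column, in the order of `input_columns`
    :param data: names of the data columns (expected: ['abstract'])
    :param target: names of the target columns (expected: ['title'])
    :param classes: names of the class columns (not used here)
    :return: a dictionary with "data_input_ids" and "target_input_ids"
    """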
    print_all_debug = False
    time_debug = False
    print_some_debug = False

    if debug: print(f"[INFO-START] Preprocess on data: {data}, target: {target}") 
    
    assert data == ['abstract'], "data should be ['abstract']"
    if debug: print(data)
    assert target == ['title'], "target should be ['title']"
    if debug: print(target)
        
    data_columns_len = len(data)
    target_columns_len = len(target)
    columns_len = data_columns_len + target_columns_len
    
    assert data_columns_len == 1, "data length should be 1"
    if debug: print(data_columns_len)
    assert target_columns_len == 1, "target length should be 1"
    if debug: print(target_columns_len)
        
    sentences_by_column = np.asarray(sentences_by_column)
    input_columns_len = len(sentences_by_column)
    
    if debug: print(f'all sentences (len {input_columns_len}): {sentences_by_column}')
    
    if target_columns_len == 0:
        raise NameError("No target variable selected, \
                    are you sure you don't want any target?")
        
    data_sentences = sentences_by_column[0]
    target_sentences = sentences_by_column[1] # if columns_len == input_columns_len else sentences_by_column[data_columns_len:-1]
    
    if debug: print(data_sentences)
    if debug: print(target_sentences)

    """
    # clean input removing **None**, converting them to **''**
    if clean_None_data:
        data_sentences = np.asarray([ s if s is not None else '' for s in data_sentences])
        target_sentences = np.asarray([ s if s is not None else '' for s in target_sentences])

    # clean input removing papers with **None** as abstract/title
    elif remove_None_data:
        none_data_indexes = np.asarray([ idx_s for idx_s, s in enumerate(data_sentences) if s is None])
        none_target_indexes = np.asarray([ idx_s for idx_s, s in enumerate(target_sentences) if s is None])

        if debug: print(none_data_indexes)
        if debug: print(none_target_indexes)

        to_removed_indexes = np.unique(none_data_indexes, none_target_indexes)

        if debug: print(to_removed_indexes)

        data_sentences = np.delete(data_sentences, to_removed_indexes)
        target_sentences = np.delete(target_sentences, to_removed_indexes)
    
    if debug: print(data_sentences)
    if debug: print(target_sentences)
    """
    
    # sentences = [s for s in sentences if s is not None]
    # tokens = [s.strip().split() for s in sentences]
    # tokens = [t[:max_seq_length - 1] + [tokenizer.eos_token] for t in tokens]

    # The sequences are padded here, to the longest sequence of each batch (`padding=True`).
    # ----------------------------------------------- #
    # -------- TODO include the `collate_fn` -------- #
    # ----------------------------------------------- #
    # Padding in a `collate_fn` instead would mean slightly slower loading, but a smaller saved dataset
    if print_some_debug: print(max_seq_length)
        
    data_encoded_d = tokenizer(
                        text=data_sentences.tolist(),
                        # add_special_tokens=False,
                        # is_pretokenized=True,
                        padding=True, truncation=True, max_length=max_seq_length,
                        return_token_type_ids=False,
                        return_attention_mask=False,
                        # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
                        # receives the `special_tokens_mask`.
                        return_special_tokens_mask=True,
                        return_tensors='np'
    )
    
    target_encoded_d = tokenizer(
                        text=target_sentences.tolist(),
                        # add_special_tokens=False,
                        # is_pretokenized=True,
                        padding=True, truncation=True, max_length=max_seq_length,
                        return_token_type_ids=False,
                        return_attention_mask=False,
                        # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
                        # receives the `special_tokens_mask`.
                        return_special_tokens_mask=True,
                        return_tensors='np'
    )

                            

    if debug: print(data_encoded_d["input_ids"].shape)
    if debug: print(target_encoded_d["input_ids"].shape)
    # return encoded_d
    
    return {"data_input_ids": data_encoded_d["input_ids"], "target_input_ids": target_encoded_d["input_ids"]}
    # return {"input_ids": sum(encoded_d['input_ids'], [])} 

Print an example:
```python 
print(dataset['train'][:10]['title'], dataset['train'][:10]['abstract'])
```

In [None]:
vocab = tokenizer.get_vocab()
print(f"[PAD]: {vocab['[PAD]']}")
print(f"[UNK]: {vocab['[UNK]']}")
print(f"[SEP]: {vocab['[SEP]']}")
print(f"[CLS]: {vocab['[CLS]']}")
print(f"0: {tokenizer.convert_ids_to_tokens(0)}")
print(f"1: {tokenizer.convert_ids_to_tokens(1)}")
print(f"2: {tokenizer.convert_ids_to_tokens(2)}")
print(f"99: {tokenizer.convert_ids_to_tokens(99)}")
print(f"100: {tokenizer.convert_ids_to_tokens(100)}")
print(f"101: {tokenizer.convert_ids_to_tokens(101)}")

In [None]:
tokenizer

Finally, I found [this](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=datasetdict#datasets.DatasetDict.map) documentation for the function `DatasetDict.map` from the `datasets` library.

In [None]:
debug = False

dataset_map = dataset.map(preprocess, input_columns= dictionary_columns, fn_kwargs= dictionary_input, batched=True)
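With `batched=True` and `input_columns` set, `map` calls the function with one list per selected column as positional arguments, plus `fn_kwargs` as keyword arguments. This is why `preprocess` is declared as `preprocess(*sentences_by_column, data, target, classes)`. The sketch below is a simplified, hypothetical stand-in (`batched_map`, not the library code) that reproduces only this calling convention; the real `map` also handles batching, caching and column removal:

```python
from typing import Any, Callable, Dict, List


def batched_map(batch: Dict[str, List[Any]], function: Callable[..., Dict[str, List[Any]]],
                input_columns: List[str], fn_kwargs: Dict[str, Any]) -> Dict[str, List[Any]]:
    """Hypothetical, simplified stand-in for how `map(batched=True)` calls `function` on one batch.

    Each column in `input_columns` is passed positionally (one list per column),
    `fn_kwargs` are passed as keyword arguments, and the returned columns are added to the batch.
    """
    args = [batch[column] for column in input_columns]
    new_columns = function(*args, **fn_kwargs)
    return {**batch, **new_columns}


def toy_preprocess(*sentences_by_column, data, target, classes):
    # same calling convention as `preprocess`: column 0 = abstracts, column 1 = titles
    abstracts, titles = sentences_by_column[0], sentences_by_column[1]
    return {"abstract_words": [len(a.split()) for a in abstracts],
            "title_words": [len(t.split()) for t in titles]}


batch = {"abstract": ["deep nets for text", "a survey"],
         "title": ["Nets", "A short survey"],
         "mag_field_of_study": [["Computer Science"], None]}
out = batched_map(batch, toy_preprocess,
                  input_columns=["abstract", "title", "mag_field_of_study"],
                  fn_kwargs={"data": ["abstract"], "target": ["title"], "classes": ["mag_field_of_study"]})
print(out["abstract_words"], out["title_words"])  # [4, 2] [1, 3]
```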

In [None]:
dataset_map

In [None]:
mag_field_dict: Dict = {
    "Medicine":    0,
    "Biology":     1,
    "Chemistry":   2,
    "Engineering": 4,
    "Computer Science":    5,
    "Physics":     6,
    "Materials Science":     7,
    "Mathematics":        8,
    "Psychology":  9,
    "Economics":   10,
    "Political Science":    11,
    "Business":    12,
    "Geology":     13,
    "Sociology":   14,
    "Geography":   15,
    "Environmental Science":     16,
    "Art":         17,
    "History":     18,
    "Philosophy":  19
    # "null":         3, 
}

# The missing value is an actual null (None), not the string "null":
#
#      real_mag_field_value = paper_metadata['mag_field_of_study']
#
# so we can return the id 3 whenever the value is not a key of the dictionary:
#
#      mag_field_dict.get(real_mag_field_value, 3)
#
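A per-value sketch of the `.get(value, 3)` idea that also copes with `None` (where `list(None)` in `mag_preprocessing` would raise a `TypeError` if a null value reached it). `mag_to_index` and `example_mag_field_dict` are hypothetical names; the example dictionary is a renamed subset so that the real `mag_field_dict` is not overwritten:

```python
from typing import Dict, List, Union

NULL_MAG_ID = 3  # id used in the notebook for papers whose field of study is null

# hypothetical subset of `mag_field_dict`, renamed so the real dictionary is not overwritten
example_mag_field_dict: Dict[str, int] = {"Medicine": 0, "Biology": 1, "Computer Science": 5}


def mag_to_index(value: Union[None, str, List[str]], mapping: Dict[str, int]) -> int:
    """Map one raw `mag_field_of_study` value to a class id.

    Assumption: for multi-field papers only the first field is kept (as `list(ele)[0]`
    does in `mag_preprocessing`); None or an empty list falls back to NULL_MAG_ID.
    """
    if isinstance(value, list):
        value = value[0] if value else None
    return mapping.get(value, NULL_MAG_ID)


print([mag_to_index(v, example_mag_field_dict)
       for v in (["Computer Science", "Medicine"], "Biology", None, [])])  # [5, 1, 3, 3]
```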

In [None]:
def mag_preprocessing(*mags):
    """Map the raw `mag_field_of_study` values of a batch to integer class ids.
    :param mags: one list of values per input column (here only 'mag_field_of_study')
    :return: a dictionary with the "mag_index" column
    """
    debug = False
    
    if debug: print(f"[INFO-START] Mag Preprocess") 
        
    mag_field = np.array(mags)
    input_columns_len = mag_field.shape
    if debug: print(f'pre flatten (len {input_columns_len}): {mag_field}')
    if debug: print(f'pre types: {[type(ele) for ele in mag_field]}')
    if debug: print(f'pre types: {type(mag_field)}')
    
    mag_field = mag_field.flatten()
    input_columns_len = mag_field.shape
    if debug: print(f'after flatten (len {input_columns_len}): {mag_field}')
    if debug: print(f'after types: {[type(ele) for ele in mag_field]}')
    if debug: print(f'after types: {type(mag_field)}')
        
    mag_field = np.array([ele if type(ele) == str else list(ele)[0] for ele in mag_field])
        
    if len(mags) == 0:
        raise ValueError("No mag column selected, \
                    are you sure you don't want any class?")
    
    if debug: print(mag_field)
    if debug: print(mag_field_dict)
    if debug: print([mag_field_dict.get(real_mag_field_value, 3) for real_mag_field_value in mag_field])
        
    mag_index = np.asarray([mag_field_dict.get(real_mag_field_value, 3) for real_mag_field_value in mag_field])
    
    if debug: print(mag_index)
    
    return {"mag_index": mag_index}
    # return

In [None]:
dataset_mag_map = dataset_map.map(mag_preprocessing, input_columns= dictionary_input['classes'], batched=True)

In [None]:
dataset_mag_map

# Rename it as you want
---

- `dataset_map.rename_column`, method for renaming a column
- `dataset_map.set_format`, method for defining which columns are returned (and in which format)

In [None]:
dataset_mag_map = dataset_mag_map.rename_column("data_input_ids", "input_ids")

In [None]:
dataset_mag_map.set_format("torch", columns=["input_ids"])

In [None]:
print(dataset_mag_map['train'][1]['input_ids'].size())

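Then, to persist it across kernel restarts, use IPython's `%store` magic: it saves the object in the IPython profile's database (not in the conda environment), and `%store -r` restores it later.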

In [None]:
%store dataset_mag_map

---
---
## ❌ FAKE PIPELINE for training BERT-based NETS
---
---

In [6]:
%store -r dataset_mag_map

In [7]:
dataset_mag_map

DatasetDict({
    train: Dataset({
        features: ['abstract', 'input_ids', 'mag_field_of_study', 'mag_index', 'target_input_ids', 'title'],
        num_rows: 71952
    })
    test: Dataset({
        features: ['abstract', 'input_ids', 'mag_field_of_study', 'mag_index', 'target_input_ids', 'title'],
        num_rows: 8994
    })
    valid: Dataset({
        features: ['abstract', 'input_ids', 'mag_field_of_study', 'mag_index', 'target_input_ids', 'title'],
        num_rows: 8994
    })
})

In [8]:
# tokenizer: we already have it
# model: we already have it

# If you print some element from `dataset_mag_map['train'][element_index]['input_ids']` you'll see that many
# elements are padding (id 0): count the non-padding ids to get the real length of each sequence
vect = [int((ele != 0).sum()) for ele in dataset_mag_map['train'][:]['input_ids']]


In [9]:
max_vect = max(vect)
min_vect = min(vect)
sum_vect = sum(vect)
len_vect = len(vect)

print(f" max: {max_vect} \n min: {min_vect} \n avg: {sum_vect/len_vect}")

 max: 512 
 min: 2 
 avg: 162.88874527462752


In [10]:
print(dataset_mag_map['train'][:10])

{'input_ids': tensor([[  102, 11261,   669,  ...,     0,     0,     0],
        [  102,  6773,   165,  ...,     0,     0,     0],
        [  102,   833,   111,  ...,     0,     0,     0],
        ...,
        [  102,  5005,  4994,  ...,     0,     0,     0],
        [  102,   121,   238,  ...,     0,     0,     0],
        [  102, 15794,   190,  ...,     0,     0,     0]])}


In [11]:
dataset_mag_map.set_format("torch", columns=["input_ids", "target_input_ids", "mag_index"])

In [13]:
print(dataset_mag_map['train'][:2])

{'input_ids': tensor([[  102, 11261,   669,  ...,     0,     0,     0],
        [  102,  6773,   165,  ...,     0,     0,     0]]), 'mag_index': tensor([8, 5]), 'target_input_ids': tensor([[  102,   130,  4119, 19638,   579,   791,  4604,   727,   467,  3427,
          2713,   131, 16982,   103,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [  102,  2602,  2516,   190,  2652,   645, 28701,  2554,   137,   633,
          1836,   147,  6773,   103,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0]])}



The code below is borrowed from [here](https://github.com/huggingface/transformers/blob/master/examples/language-modeling/run_mlm.py), the 🤗 Transformers `run_mlm.py` example script.

In [21]:
# Data collator
# This one will take care of randomly masking the tokens.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

train_dataset = dataset_mag_map['train']
eval_dataset = dataset_mag_map['valid']

# Initialize TrainingArguments
training_args = TrainingArguments(
    output_dir=output_dir,           # [def.`tmp_trainer`] output directory
    num_train_epochs=3,              # [def.   3 ] total # of training epochs
    per_device_train_batch_size=8,  # [def.   8 ] batch size per device during training
    per_device_eval_batch_size=8,   # [def.   8 ] batch size for evaluation
    evaluation_strategy="no",     # [def. 'no'] no evaluation during training ('steps' evaluates every eval_steps)
    warmup_steps=0,                # [def.   0 ] number of warmup steps for learning rate scheduler
    weight_decay=0,               # [def.   0 ] strength of weight decay 
    learning_rate=5e-5,              # [def. 5e-5] 
    logging_dir='./logs',            # [def. runs/__id__] directory for storing logs. TensorBoard log directory.
)

# Initialize our Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)



In [22]:
# Training

checkpoint = None
train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model()  # Saves the tokenizer too for easy upload
metrics = train_result.metrics

max_train_samples = len(train_dataset)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()



wandb: Currently logged in as: emanuelevivoli (use `wandb login --relogin` to force relogin)
wandb: wandb version 0.10.27 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Step,Training Loss
500,1.74
1000,1.7683
1500,1.7458
2000,1.772
2500,1.7579
3000,1.7181
3500,1.7299
4000,1.7038
4500,1.6933
5000,1.6808


AttributeError: 'Trainer' object has no attribute 'log_metrics'
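The `AttributeError` suggests that the installed 🤗 Transformers version predates `Trainer.log_metrics`/`save_metrics`, so upgrading the library is the likely fix. As a stopgap, here is a hypothetical helper that prints the metrics and writes them to `<output_dir>/<split>_results.json`. The file name is chosen to resemble what newer versions write; treat it as an assumption:

```python
import json
import os
from typing import Dict, Union


def log_and_save_metrics(split: str, metrics: Dict[str, Union[int, float]], output_dir: str) -> str:
    """Print metrics and dump them as JSON; returns the written path (hypothetical helper)."""
    print(f"***** {split} metrics *****")
    for key in sorted(metrics):
        print(f"  {key} = {metrics[key]}")
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, f"{split}_results.json")
    with open(path, "w") as f:
        json.dump(metrics, f, indent=4, sort_keys=True)
    return path
```

In the training cell it would replace the two failing calls: `log_and_save_metrics("train", metrics, training_args.output_dir)`.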

In [None]:
# Evaluation
import math

print("*** Evaluate ***")

metrics = trainer.evaluate()

# no `dataset_args` in this notebook: evaluate on the whole validation split
metrics["eval_samples"] = len(eval_dataset)
perplexity = math.exp(metrics["eval_loss"])
metrics["perplexity"] = perplexity

trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
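`math.exp(eval_loss)` is a perplexity because the reported loss is the mean cross-entropy (natural log) over the predicted (masked) tokens. A small sketch of that relation from per-token log-probabilities; `perplexity_from_log_probs` is a hypothetical name. For a masked LM this is a pseudo-perplexity, since only masked positions are scored:

```python
import math
from typing import List


def perplexity_from_log_probs(log_probs: List[float]) -> float:
    """exp of the mean negative log-likelihood (natural log) of the predicted tokens."""
    mean_nll = -sum(log_probs) / len(log_probs)  # the quantity the cross-entropy loss reports
    return math.exp(mean_nll)


# a model giving probability 0.25 to every correct token has perplexity 4
print(perplexity_from_log_probs([math.log(0.25)] * 3))
```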

---
# 1. Introduction
---

The following datasets were downloaded from the internet (we provide links where we are allowed to). We group the datasets by the task they are mostly used for.

## 1.1 Keyphrase task
---

SOTA: [keyphrase generation](https://arxiv.org/pdf/1704.06879.pdf).

The Keyphrase datasets (***duc***, ***Inspec***, ***Krapivin***, ***NUS***, ***SemEval-2010***, ***KP20k dataset***, ***MagKP-CS***) are structured as follows:

- title
- abstract
- fulltext
- keywords

The only dataset that differs is ***STACKEX***, which, instead of *abstract* and *keywords*, has:

- question (abstract)
- tags (keywords)

Here is a list of the datasets cited above, with some information:

- **duc**: we have not found much information on this dataset so far.

- **Inspec** [(Hulth, 2003)](https://www.aclweb.org/anthology/W03-1028.pdf): this dataset provides *2,000 paper abstracts*. We adopt the *500 testing* papers and their corresponding uncontrolled keyphrases for evaluation, and the remaining *1,500 papers* are used for *training* the supervised baseline models.

- **Krapivin** [(Krapivin et al., 2008)](http://eprints.biblio.unitn.it/1671/1/disi09055-krapivin-autayeu-marchese.pdf): this dataset provides *2,304 papers with full-text* and *author-assigned keyphrases*. However, the authors did not mention how to split the testing data, so we selected the first *400 papers in alphabetical order as the testing data*, and the *remaining* papers are used to *train* the supervised baselines.

- **NUS** [(Nguyen and Kan, 2007)](https://www.comp.nus.edu.sg/~kanmy/papers/icadl2007.pdf): We use both author-assigned and reader-assigned keyphrases and treat *all 211 papers as the testing data*. Since the NUS dataset did not specifically mention the ways of splitting training and testing data, the results of the supervised baseline models are obtained through a *five-fold cross-validation*.

- **SemEval-2010** [(Kim et al., 2010)](https://www.aclweb.org/anthology/S10-1004.pdf): 288 articles were collected from the ACM Digital Library. 100 articles were used for testing and the rest were used for training supervised baselines.

- **KP20k dataset** [(Meng et al., 2018)](https://arxiv.org/abs/1704.06879): They built a new testing dataset that contains the *titles, abstracts, and keyphrases* of *20,000 scientific articles* in computer science. They were *randomly selected from their obtained 567,830 articles*. Thus they took the 20,000 articles in the validation set to train the supervised baselines.

- **MagKP-CS** (from OpenNMT-py and [OpenNMT-kpg-release](https://github.com/memray/OpenNMT-kpg-release)), which is available for download.

- **STACKEX** (from [StackExchange](https://archive.org/details/stackexchange)) has been constructed from the computer science forums (CS/AI) at StackExchange, using “title” + “body” as source text and “tags” as the target keyphrases. After removing questions without valid tags, they collected 330,965 questions. They randomly selected *16,000 for validation*, and another *16,000 as test set*. Note that some questions in StackExchange forums contain large blocks of code, resulting in long texts (sometimes more than 10,000 tokens after tokenization), which are difficult for most neural models to handle. Consequently, the texts have been truncated to 300 tokens for the training split and to 1,000 tokens for the evaluation splits.

###### ⚠️ATTENTION
> As we aren't going to use the Keyphrase datasets for now, we don't need any custom classes to manage them. We will implement these functions and classes as we go, if the need arises.

## 1.2 Sentence embedding task
---

SOTA: [sBERT](https://arxiv.org/abs/1908.10084)

- **SNLI** [(Bowman et al., 2015)](https://arxiv.org/abs/1508.05326) is a collection of *570,000 sentence pairs* annotated with the *labels contradiction, entailment, and neutral*.

- **MultiNLI** [(Williams et al., 2018)](https://arxiv.org/abs/1704.05426) contains *430,000 sentence pairs* and covers a *range of genres of spoken and written text*.

- **SciTail** [(allenai)](http://ai2-website.s3.amazonaws.com/publications/scitail-aaai-2018_cameraready.pdf) is an entailment dataset of about 27k examples. In contrast to SNLI and MultiNLI, it was not crowd-sourced but created from sentences that already exist “in the wild”. *Hypotheses* were created from *science questions* and the corresponding *answer candidates*, while relevant web sentences from a large corpus were used as premises. Models are evaluated based on accuracy.

###### ❌ATTENTION
> As we aren't going to use the NLI task datasets (for now), we don't need any custom classes to manage them. We will implement these functions and classes as we go, if the need arises.

## 1.3 Generic NLP tasks
---

- **S2ORC** [(Lo et al., 2020)](https://github.com/allenai/s2orc) is a large corpus of *81.1M English-language academic papers* spanning many academic disciplines. The corpus consists of *rich metadata, paper abstracts, resolved bibliographic references*, as well as *structured full text for 8.1M open access papers*. Full text is annotated with automatically-detected inline mentions of citations, figures, and tables, each linked to their corresponding paper objects. In S2ORC, they aggregate papers from hundreds of academic publishers and digital archives into a unified source, and create the largest publicly-available collection of machine-readable academic text to date. Built for text mining over academic text.

- **OAG** [(Tang et al., 2008)](http://keg.cs.tsinghua.edu.cn/jietang/publications/KDD08-Tang-et-al-ArnetMiner.pdf)  is a large knowledge graph unifying *two billion-scale academic graphs*: Microsoft Academic Graph (**MAG**) and **AMiner**. In mid 2017, they published OAG v1, which contains *166,192,182 papers from MAG and 154,771,162 papers from AMiner* and generated *64,639,608 linking (matching) relations between the two graphs*. This time, in OAG v2, author, venue and newer publication data and the corresponding matchings are available.



###### ✅ATTENTION
> We are going to use the S2ORC dataset, as it contains full-text data as well as citation/reference information. It also contains author, title and table data, described below.

---
# 2. S2ORC
---

## 2.1 Description (s2orc)
---
The `S2ORC` dataset is in the `data` path under the folder `s2orc-full-20200705v1` (where `s2orc` is the name of the dataset; `full` is the type, as there is also a `sample` version; and `20200705v1` is the version).
The `data` folder can be reached by leaving the project folder and entering the `data` folder.

As you can see (going into `s2orc-full-20200705v1/full/`), there are a `metadata` folder and a `pdf_parses` folder. The main difference, as the names suggest, is that `metadata` only contains some information about each paper (retrieved from the published metadata), while `pdf_parses` contains all the extensive data of the paper (if the paper was available, was parsed correctly, and no restrictions were applied to its data due to limited license permissions). For some reason, the `title` of the paper is contained only in the `metadata` file, but it can be retrieved through the `paper_id` field of the paper itself.
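A minimal sketch for locating the chunk files of both folders with `pathlib`. `DATA_ROOT` is a hypothetical location (adapt it to where `data/` sits relative to the project), and the `*.jsonl*` pattern assumes the chunks are `.jsonl` or `.jsonl.gz` files; `s2orc_files` is a hypothetical name:

```python
from pathlib import Path
from typing import Dict, List

# Hypothetical location: adapt DATA_ROOT to where `data/` sits relative to the project.
DATA_ROOT = Path("..") / "data"
S2ORC_FULL = DATA_ROOT / "s2orc-full-20200705v1" / "full"


def s2orc_files(root: Path = S2ORC_FULL) -> Dict[str, List[Path]]:
    """List the chunk files of the `metadata` and `pdf_parses` folders, sorted by name."""
    return {
        folder: sorted((root / folder).glob("*.jsonl*"))  # matches both .jsonl and .jsonl.gz
        for folder in ("metadata", "pdf_parses")
    }
```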

More information about the `S2ORC` dataset can be read in the [README.md](https://github.com/allenai/s2orc/blob/master/README.md) of the project and in the [project repository](https://github.com/allenai/s2orc/).

### mag field
- MAG fields of study:

| Field of study | All papers | Full text |
|----------------|------------|-----------|
| Medicine       | 12.8M      | 1.8M      |
| Biology        | 9.6M       | 1.6M      |
| Chemistry      | 8.7M       | 484k      |
| n/a            | 7.7M       | 583k      |
| Engineering    | 6.3M       | 228k      |
| Comp Sci       | 6.0M       | 580k      |
| Physics        | 4.9M       | 838k      |
| Mat Sci        | 4.6M       | 213k      |
| Math           | 3.9M       | 669k      |
| Psychology     | 3.4M       | 316k      |
| Economics      | 2.3M       | 198k      |
| Poli Sci       | 1.8M       | 69k       |
| Business       | 1.8M       | 94k       |
| Geology        | 1.8M       | 115k      |
| Sociology      | 1.6M       | 93k       |
| Geography      | 1.4M       | 58k       |
| Env Sci        | 766k       | 52k       |
| Art            | 700k       | 16k       |
| History        | 690k       | 22k       |
| Philosophy     | 384k       | 15k       |

We now need a function that reads all the lines of the `jsonl` files inside both the `metadata` and `pdf_parses` folders. Then we'll
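A minimal sketch of such a reader, assuming the chunk files are JSON lines, either plain (`.jsonl`) or gzip-compressed (`.jsonl.gz`). `read_jsonl` is a hypothetical name. Being a generator, it reads one line at a time, so a whole chunk is never held in memory:

```python
import gzip
import json
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, Union


def read_jsonl(paths: Iterable[Union[str, Path]]) -> Iterator[Dict[str, Any]]:
    """Yield one dict per non-empty line of each JSON-lines file (plain or gzip-compressed)."""
    for path in paths:
        path = Path(path)
        opener = gzip.open if path.suffix == ".gz" else open
        with opener(path, "rt", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:  # skip blank lines
                    yield json.loads(line)
```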

## `metadata` schema

We recommend everyone work with `metadata/` as the starting point.  This is a JSONlines file (one line per paper) with the following keys:

#### Identifier fields

* `paper_id`: a `str`-valued field that is a unique identifier for each S2ORC paper.

* `arxiv_id`: a `str`-valued field for papers on [arXiv.org](https://arxiv.org).

* `acl_id`: a `str`-valued field for papers on [the ACL Anthology](https://www.aclweb.org/anthology/).

* `pmc_id`: a `str`-valued field for papers on [PubMed Central](https://www.ncbi.nlm.nih.gov/pmc/articles).

* `pubmed_id`: a `str`-valued field for papers on [PubMed](https://pubmed.ncbi.nlm.nih.gov/), which includes MEDLINE.  Also known as `pmid` on PubMed.

* `mag_id`: a `str`-valued field for papers on [Microsoft Academic](https://academic.microsoft.com).

* `doi`: a `str`-valued field for the [DOI](http://doi.org/).  

Notably:

* Resolved citation links are represented by the cited paper's `paper_id`.

* The `paper_id` resolves to a Semantic Scholar paper page, which can be verified using the `s2_url` field.

* We don't always have a value for every identifier field.  When missing, they take `null` value.


#### Metadata fields

* `title`: a `str`-valued field for the paper title.  Every S2ORC paper *must* have one, though the source can be from publishers or parsed from PDFs.  We prioritize publisher-provided values over parsed values.

* `authors`: a `List[Dict]`-valued field for the paper authors.  Authors are listed in order.  Each dictionary has the keys `first`, `middle`, `last`, and `suffix` for the author name, which are all `str`-valued with exception of `middle`, which is a `List[str]`-valued field.  Every S2ORC paper *must* have at least one author.

* `venue` and `journal`: `str`-valued fields for the published venue/journal.  *Please note that there is not often agreement as to what constitutes a "venue" versus a "journal". Consolidating these fields is being considered for future releases.*   

* `year`: an `int`-valued field for the published year.  If a paper is preprinted in 2019 but published in 2020, we try to ensure the `venue/journal` and `year` fields agree & prefer non-preprint published info. *We know this decision prohibits certain types of analysis like comparing preprint & published versions of a paper.  We're looking into it for future releases.*  

* `abstract`: a `str`-valued field for the abstract.  These are provided directly from gold sources (not parsed from PDFs).  We preserve newline breaks in structured abstracts, which are common in medical papers, by denoting breaks with `':::'`.     

* `inbound_citations`: a `List[str]`-valued field containing `paper_id` of other S2ORC papers that cite the current paper.  *Currently derived from PDF-parsed bibliographies, but may have gold sources in the future.*

* `outbound_citations`: a `List[str]`-valued field containing `paper_id` of other S2ORC papers that the current paper cites.  Same note as above.   

* `has_inbound_citations`: a `bool`-valued field that is `true` if `inbound_citations` has at least one entry, and `false` otherwise.

* `has_outbound_citations` a `bool`-valued field that is `true` if `outbound_citations` has at least one entry, and `false` otherwise.

We don't always have a value for every metadata field.  When missing, `str` fields take `null` value, while `List` fields are empty lists.
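Given these conventions (missing `str` fields are `null`, and structured abstracts use `':::'` as line breaks), a hypothetical cleaning helper could look like the sketch below. Stripping whitespace around each segment is an assumption about the raw text:

```python
from typing import Any, Dict


def clean_metadata_record(record: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of a metadata record with a usable abstract (hypothetical helper).

    - a null (None) abstract becomes an empty string;
    - the ':::' markers of structured abstracts are turned back into newlines.
    """
    cleaned = dict(record)
    abstract = cleaned.get("abstract") or ""
    cleaned["abstract"] = "\n".join(part.strip() for part in abstract.split(":::"))
    return cleaned
```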

#### PDF parse-related metadata fields

* `has_pdf_parse`:  a `bool`-valued field that is `true` if this paper has a corresponding entry in `pdf_parses/`, which means we had processed that paper's PDF(s) at some point.  The field is `false` otherwise.

* `has_pdf_parsed_abstract`: a `bool`-valued field that is `true` if the paper's PDF parse contains a parsed abstract, and `false` otherwise.   

* `has_pdf_parsed_body_text`: a `bool`-valued field that is `true` if the paper's PDF parse contains parsed body text, and `false` otherwise.

* `has_pdf_parsed_bib_entries`: a `bool`-valued field that is `true` if the paper's PDF parse contains parsed bibliography entries, and `false` otherwise.

* `has_pdf_parsed_ref_entries`: a `bool`-valued field that is `true` if the paper's PDF parse contains parsed reference entries (e.g. tables, figures), and `false` otherwise.

Please note:

* If `has_pdf_parse = false`, the other four fields will not be present in the JSON (trivially `false`).

* If `has_pdf_parse = true` but `has_pdf_parsed_abstract`, `has_pdf_parsed_body_text`, or `has_pdf_parsed_ref_entries` are `false`, this can be because:

    * Our PDF parser failed to extract that element
    * Our PDF parser succeeded but that paper simply did not have that element (e.g. papers without abstracts)
    * Our PDF parser succeeded but that element was removed because the paper is not identified as open-access.  
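Because the four `has_pdf_parsed_*` keys are simply absent when `has_pdf_parse = false`, indexing them directly can raise a `KeyError`. A small sketch of a defensive reader follows; `pdf_flag` and `has_usable_fulltext` are hypothetical helper names, and treating "abstract and body text both present" as usable full text is an assumption for our title/abstract/fulltext use case.

```python
from typing import Dict


def pdf_flag(metadata: Dict, flag: str) -> bool:
    """Read a `has_pdf_parsed_*` flag, treating an absent key as False.

    When `has_pdf_parse` is false the four flags are not in the JSON at all,
    so a plain `metadata[flag]` would raise KeyError.
    """
    if not metadata.get("has_pdf_parse", False):
        return False
    return bool(metadata.get(flag, False))


def has_usable_fulltext(metadata: Dict) -> bool:
    """True if the PDF parse should contain both an abstract and body text."""
    return (pdf_flag(metadata, "has_pdf_parsed_abstract")
            and pdf_flag(metadata, "has_pdf_parsed_body_text"))
```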


##### metadata_CLASS
```python
{
 "paper_id": (string), 
 "title": (string), 
 "authors": [
     {
         "first": (string), 
         "middle": [], 
         "last": (string), 
         "suffix": (string)
     },
     ...
   ]: **Author_Class**, 
 "abstract": (string), 
 "year": (int), 
 "arxiv_id": null, 
 "acl_id": null, 
 "pmc_id": null, 
 "pubmed_id": null, 
 "doi": null, 
 "venue": null, 
 "journal": (string), 
 "mag_id": (string-number), 
 "mag_field_of_study": [
     "Medicine",
     "Computer Science"
   ]: **FieldOfStudy_Enum**, 
 "outbound_citations": [], 
 "inbound_citations": [], 
 "has_outbound_citations": false, 
 "has_inbound_citations": false, 
 "has_pdf_parse": false, 
 "s2_url": (string)
}
```

Here I represent `Author_Class` as an object of the form
```python
{
    "first": (string), 
    "middle": [], 
    "last": (string), 
    "suffix": (string)
}
```
and `FieldOfStudy_Enum` as an enum of strings such as `[ "Medicine", "Computer Science", "Physics", "Mathematics", ... ]`
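Since `mag_field_of_study` is just a list of such strings, selecting papers from one field is a membership test. The generator below is a hypothetical helper, sketched on the assumption that records are already parsed metadata dicts; the `or []` guard is a precaution in case a record has the field set to `null`.

```python
from typing import Dict, Iterable, Iterator


def papers_in_field(records: Iterable[Dict], field: str) -> Iterator[Dict]:
    """Yield the metadata records whose `mag_field_of_study` contains `field`."""
    for record in records:
        # guard with `or []` in case the field is null/missing in a record
        if field in (record.get("mag_field_of_study") or []):
            yield record
```

Being a generator, it can be chained directly over a file that is read line by line, without holding a whole chunk in memory.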


## `pdf_parses` schema

We view `pdf_parses/` as supplementary to the `metadata/` entries.  PDF parses are also represented as JSON Lines files (one line per paper) with the following keys:

* `paper_id`: a `str`-valued field which is the same S2ORC paper ID as in `metadata/`

* `_pdf_hash`: a `str`-valued field.  Internal usage only.  We use this for debugging.

* `abstract` and `body_text` are `List[Dict]`-valued fields representing parsed text from the PDF.  Each `Dict` corresponds to a paragraph.  `List` preserves their original ordering.

* `bib_entries` and `ref_entries` are `Dict`-valued fields representing extracted entities that can be referenced (inline) within the text.
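Since both `metadata/` and `pdf_parses/` are JSON Lines files keyed by `paper_id`, joining them usually starts from a lookup dict, like the `paper_id_to_pdf_parse` dict used in the citation-context cell of this notebook. A minimal sketch follows (`load_jsonl_by_id` is a hypothetical name); note that it keeps the whole chunk in memory.

```python
import json
from typing import Dict


def load_jsonl_by_id(path: str, key: str = "paper_id") -> Dict[str, Dict]:
    """Read a JSON Lines file (one JSON object per line) into a dict keyed by `key`."""
    lookup: Dict[str, Dict] = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank lines, e.g. a trailing newline
                continue
            record = json.loads(line)
            lookup[record[key]] = record
    return lookup
```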

#### example 1

One example paragraph in `abstract` or `body_text` might look like:

```python
{
    "section": "Introduction",
    "text": "Dogs are happier cats [13, 15]. See Figure 3 for a diagram.",
    "cite_spans": [
        {"start": 22, "end": 25, "text": "[13", "ref_id": "BIBREF11"},
        {"start": 27, "end": 30, "text": "15]", "ref_id": "BIBREF30"},
        ...
    ],
    "ref_spans": [
        {"start": 36, "end": 44, "text": "Figure 3", "ref_id": "FIGREF2"},
    ]
}
```

and example `bib_entries` and `ref_entries` might look like:

```python
{
    ...,
    "BIBREF11": {
        "title": "Do dogs dream of electric humans?",
        "authors": [
            {"first": "Lucy", "middle": ["Lu"], "last": "Wang", "suffix": ""}, 
            {"first": "Mark", "middle": [], "last": "Neumann", "suffix": "V"}
        ],
        "year": "", 
        "venue": "barXiv",
        "link": null
    },
    ...
}
```

```python
{
    "TABREF4": {
        "text": "Table 5. Clearly, we achieve SOTA here or something.",
        "type": "table"
    },
    ...,
    "FIGREF2": {
        "text": "Figure 3. This is the caption of a pretty figure.",
        "type": "figure"
    },
    ...
}
```

Notice: 

* Inline `spans` are represented by character start and end indices into the paragraph `text`
* `spans` resolve to `BIBREF`, `TABREF` or `FIGREF` entries.
* `BIBREF` are IDs of bibliographic elements of `bib_entries`.  Bib entries may be missing fields (e.g. `year`).  They can be linked to S2ORC papers, as specified by `link`, but we also preserve any unlinked entries by setting `link` to `null`.
* `FIGREF` and `TABREF` are IDs of figure and table elements of `ref_entries`.  Ref entries contain the caption text of the corresponding object, and also indicate the type of object.
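To make the span mechanics concrete, here is a sketch that slices each span out of the paragraph `text` with its `start`/`end` offsets and looks up the entry it points to (`cite_spans` in `bib_entries`, `ref_spans` in `ref_entries`). The function name `resolve_spans` and the output layout are assumptions for illustration only.

```python
from typing import Dict, List


def resolve_spans(paragraph: Dict, bib_entries: Dict, ref_entries: Dict) -> List[Dict]:
    """Pair each inline span of a paragraph with the entry it points to."""
    resolved = []
    # cite_spans resolve into bib_entries, ref_spans into ref_entries
    for span_key, entries in (("cite_spans", bib_entries), ("ref_spans", ref_entries)):
        for span in paragraph.get(span_key, []):
            ref_id = span.get("ref_id")
            resolved.append({
                # start/end are character offsets into the paragraph text
                "surface": paragraph["text"][span["start"]:span["end"]],
                "ref_id": ref_id,
                # None when ref_id is null or has no matching entry
                "entry": entries.get(ref_id),
            })
    return resolved
```

On the example paragraph above, the surfaces come out as `"[13"`, `"15]"` and `"Figure 3"`, matching the `text` stored in each span.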


#### example 2

You may see empty `pdf_parses/` JSONs that look like: 

```python
{
    "paper_id": "...", 
    "_pdf_hash": "...", 
    "abstract": [], 
    "body_text": [], 
    "bib_entries": {}, 
    "ref_entries": {}
}
```

We keep these around for our internal usage, but the way to interpret these is that there is no usable PDF parse here, despite the corresponding `metadata/` entry still displaying `has_pdf_parse = true`.

These exist when (i) `bib_entries` does not successfully parse *and* (ii) the paper is not open-access, so we had to remove `abstract`, `body_text`, and `ref_entries`.   
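Since these empty entries still show `has_pdf_parse = true` in `metadata/`, it may be worth filtering them out explicitly. A minimal sketch, with `is_empty_pdf_parse` as a hypothetical helper name:

```python
from typing import Dict


def is_empty_pdf_parse(pdf_parse: Dict) -> bool:
    """True for a pdf_parses/ entry that carries no usable content.

    Such entries exist even though metadata still says has_pdf_parse = true.
    """
    content_keys = ("abstract", "body_text", "bib_entries", "ref_entries")
    # empty lists/dicts (or missing keys) are all falsy
    return not any(pdf_parse.get(k) for k in content_keys)
```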



##### pdf_parses_CLASS
```python
{
 "paper_id": (string), 
 "_pdf_hash": (string-number), 
 "abstract": [
     {
         "section": (string) "Abstract", 
         "text": (string), 
         "cite_spans": [
             {
                 "start": (int), 
                 "end": (int), 
                 "text": (string-number), e.g. "[4", 
                 "ref_id": (string)
             }
           ]: **CiteSpan_Class**, 
         "ref_spans": []
     },
     ...
 ]: **TextSection_Class**, 
 "body_text": [], 
 "bib_entries": 
     {
         "BIBREF0": 
             {
              "title": (string), 
              "authors": [
                  {
                      "first": (string), 
                      "middle": [], 
                      "last": (string), 
                      "suffix": (string)
                   }
                 ], 
               "year": (int), 
               "venue": (string), 
               "link": (string-number)
              }, 
          "BIBREF1": 
              {
                  ...
              }
       }: **BIBREF_Class**, 
 "ref_entries": {}
}
```

Here I represent `TextSection_Class` as an object of the form
```python
{
 "section": (string), 
 "text": (string), 
 "cite_spans": [
     {
         "start": (int), 
         "end": (int), 
         "text": (string-number), e.g. "[4", 
         "ref_id": (string)
     }
   ], 
 "ref_spans": []
}
```
where `CiteSpan_Class` itself is another structured object:
```python
{
 "start": (int), 
 "end": (int), 
 "text": (string-number), 
 "ref_id": (string)
}
```
and `BIBREF_Class` as a dictionary keyed by `BIBREF_#`, where each key maps to an object of the form:
```python
"BIBREF_#": 
 {
  "title": (string), 
  "authors": [
      {
          "first": (string), 
          "middle": [] (list of string),
          "last": (string), 
          "suffix": (string)
       }
     ], 
   "year": (int), 
   "venue": (string), 
   "link": null
  }
```
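The same schema can be written as Python `TypedDict`s, which documents the shapes for type checkers and editors without any runtime validation. This is a sketch of my reading of the classes above: the `Optional`/`Union` choices (e.g. a `null` `ref_id`, a `year` that example 1 shows as `""`) are assumptions, not guarantees from S2ORC.

```python
from typing import Dict, List, Optional, TypedDict, Union  # TypedDict: Python 3.8+


class Author(TypedDict):
    first: str
    middle: List[str]
    last: str
    suffix: str


class CiteSpan(TypedDict):
    start: int
    end: int
    text: str
    ref_id: Optional[str]  # assumption: may be null for unresolved spans


class TextSection(TypedDict):
    section: str
    text: str
    cite_spans: List[CiteSpan]
    ref_spans: List[CiteSpan]  # same start/end/text/ref_id shape


class BibEntry(TypedDict):
    title: str
    authors: List[Author]
    year: Union[int, str, None]  # may be missing; example 1 shows ""
    venue: str
    link: Optional[str]  # linked S2ORC paper_id, or None if unlinked


class PdfParse(TypedDict):
    paper_id: str
    _pdf_hash: str
    abstract: List[TextSection]
    body_text: List[TextSection]
    bib_entries: Dict[str, BibEntry]  # keyed by "BIBREF0", "BIBREF1", ...
    ref_entries: Dict[str, Dict[str, str]]  # e.g. "FIGREF2" -> {"text": ..., "type": ...}
```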

## 2.3 Title Abstract - Full text  (s2orc)
---
We have loaded the `S2ORC` dataset and created our (single-chunk) dataset parses; we now want to start creating our dataset objects (classes and loaders).

Let's start with the datasets.

### Dataset creation
We want to create the datasets for papers' title-abstract and fulltext-(title-abstract) generation. 
> We would also like to create a KeyPhrase dataset; we are currently waiting for a response from the `S2ORC` authors to understand where we could obtain the keyphrases/keywords.

To do this, we create the two datasets and save them as `jsonl` files.
We can organize the data folder as follows:
```bash
- data/
    # keyphrase dataset 
    - keyphrase/
        # (title - abstract - fulltext - keyphrase)
        - s2orc/
            - README.md
            - chuncks_dataset_idx.json
            - train/
                - train_0.jsonl
                - train_1.jsonl
                - ...
            - test/
                - test_0.jsonl
                - test_1.jsonl
                - ...
            - val/
                - val_0.jsonl
                - val_1.jsonl
                - ...
    
    # sts datasets
    - sts/ 
        # (title - abstract - cosine_similarity)
        - s2orc_partial/
            - README.md
            - chuncks_dataset_idx.json
            - train/
                - train_0.jsonl
                - train_1.jsonl
                - ...
            - test/
                - test_0.jsonl
                - test_1.jsonl
                - ...
            - val/
                - val_0.jsonl
                - val_1.jsonl
                - ...
                
        # (title - abstract - fulltext - cosine_similarity)
        - s2orc_full/
            - README.md
            - chuncks_dataset_idx.json
            - train/
                - train_0.jsonl
                - train_1.jsonl
                - ...
            - test/
                - test_0.jsonl
                - test_1.jsonl
                - ...
            - val/
                - val_0.jsonl
                - val_1.jsonl
                - ...
```
Each `chuncks_dataset_idx.json` file holds the dictionary that maps the source chunks (`metadata_{id}.jsonl`, `pdf_parses_{id}.jsonl` for `id in range(99)`) to the `{train|test|val}_{id}` split files.
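One possible way to build such a mapping is sketched below. The format `{"<chunk id>": "<split>_<index>"}`, the 80/10/10 ratios and the fixed seed are all assumptions for illustration; they are not the notebook's actual settings.

```python
import random
from typing import Dict, Tuple


def build_chunk_split_index(n_chunks: int,
                            ratios: Tuple[float, float, float] = (0.8, 0.1, 0.1),
                            seed: int = 42) -> Dict[str, str]:
    """Map each source chunk id to a split file name such as "train_0".

    Hypothetical format: {"<chunk id>": "<split>_<index within split>"}.
    """
    ids = list(range(n_chunks))
    random.Random(seed).shuffle(ids)  # deterministic shuffle for a given seed
    n_train = int(ratios[0] * n_chunks)
    n_test = int(ratios[1] * n_chunks)
    splits = {
        "train": ids[:n_train],
        "test": ids[n_train:n_train + n_test],
        "val": ids[n_train + n_test:],  # the remainder goes to validation
    }
    index: Dict[str, str] = {}
    for split, chunk_ids in splits.items():
        for k, chunk_id in enumerate(sorted(chunk_ids)):
            index[str(chunk_id)] = f"{split}_{k}"
    return index
```

The resulting dict can then be written with `json.dump` to `chuncks_dataset_idx.json`. Keying by string ids keeps the JSON round-trip lossless, since JSON object keys are always strings.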

As a first step towards no longer working with the raw chunks (neither metadata nor full text), we summarize the data we want into a new Python structure (a dict) as follows, and save it:

```python
{
    "paper_id": (string-int), 
    "title":  (string),
    "abstract": (string), 
    "fulltext": (string), 
    "keywords": List[string],
}
```
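A minimal sketch of producing one such record from a metadata entry and its (optional) PDF parse, and of saving records as JSON Lines. Joining `body_text` paragraphs with newlines and leaving `keywords` empty (no keyphrase source yet) are assumptions; `build_record` and `write_jsonl` are hypothetical names.

```python
import json
from typing import Dict, List, Optional


def build_record(metadata: Dict, pdf_parse: Optional[Dict] = None) -> Dict:
    """Summarize one S2ORC paper into the flat structure above."""
    paragraphs: List[Dict] = (pdf_parse or {}).get("body_text", [])
    return {
        "paper_id": metadata["paper_id"],
        "title": metadata.get("title") or "",
        # the gold metadata abstract, not the PDF-parsed one
        "abstract": metadata.get("abstract") or "",
        # body paragraphs joined in their original order
        "fulltext": "\n".join(p["text"] for p in paragraphs),
        # no keyphrase source for S2ORC yet, so this stays empty for now
        "keywords": [],
    }


def write_jsonl(records: List[Dict], path: str) -> None:
    """Write one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
```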

1. Get the training/validation data by extracting title-abstract pairs from the `S2ORC` dataset, and get the test data from the `KeyPhrase` (*'inspec', 'krapivin', 'nus', 'semeval', 'kp20k', 'duc', 'stackexchange'*) datasets. Each example should hold a pair of texts (typically a *title* and an *abstract*), possibly with *fulltext* and *keywords* fields. The two texts of a pair can be:

    - completely related (an abstract and its own title)
    - somewhat related (an abstract and a title related by field and keyphrases {cs+(deep learning; metric learning; nlp; sts;)})
    - unrelated but not far away (an abstract and a title from the same field but **not** related by keyphrases {cs+(nlp; transformer;)-vs-(cv; attention)})
    - completely unrelated (abstract and title unrelated in both field and keyphrases {cs+a -vs- phy+z})
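A rough sketch of generating such pairs at three levels, assuming each record carries a hypothetical `field` key (e.g. its first `mag_field_of_study` entry). The labels 1.0 / 0.5 / 0.0 are illustrative, and the "unrelated but not far away" level is left out because it needs keyphrase information we do not have yet.

```python
import random
from typing import Dict, List


def make_pairs(records: List[Dict], seed: int = 0) -> List[Dict]:
    """Build (abstract, title, label) examples at three relatedness levels.

    Each record needs "title", "abstract" and a hypothetical "field" key.
    Labels are illustrative. O(n^2): fine for a sketch, not for a full chunk.
    """
    rng = random.Random(seed)
    pairs = []
    for i, rec in enumerate(records):
        others = [r for j, r in enumerate(records) if j != i]
        same_field = [r for r in others if r["field"] == rec["field"]]
        other_field = [r for r in others if r["field"] != rec["field"]]
        # completely related: the paper's own title
        pairs.append({"abstract": rec["abstract"], "title": rec["title"], "label": 1.0})
        # somewhat related: a title from another paper in the same field
        if same_field:
            pairs.append({"abstract": rec["abstract"],
                          "title": rng.choice(same_field)["title"], "label": 0.5})
        # completely unrelated: a title from a paper in a different field
        if other_field:
            pairs.append({"abstract": rec["abstract"],
                          "title": rng.choice(other_field)["title"], "label": 0.0})
    return pairs
```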



1. **sentence-transformers**: see the [SBERT training example](https://www.sbert.net/docs/training/overview.html#loss-functions).

2. **🤗transformers**: as shown [here](https://huggingface.co/docs/datasets/loading_datasets.html#json-files), the 🤗 Datasets loader for `jsonl` files can be used to load the train/validation datasets. Since we have already loaded the dataset as a dictionary (currently called `multichunks_lists`; its size depends on how many chunks we load in one shot), we could also follow the example [here](https://huggingface.co/docs/datasets/loading_datasets.html#from-a-python-dictionary) to load the dataset from an existing dictionary.
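`Dataset.from_dict` expects a column-oriented dict (`{field: [value, ...]}`), while our `jsonl` files are row-oriented. The stdlib sketch below does that conversion; `jsonl_to_columns` is a hypothetical helper, and it assumes every record shares the keys of the first one (missing keys become `None`, extra keys are dropped).

```python
import json
from typing import Dict, List


def jsonl_to_columns(path: str) -> Dict[str, List]:
    """Read a JSON Lines file into a column-oriented dict {field: [values...]}.

    Columns are taken from the first record's keys: missing keys in later
    records become None and extra keys are dropped.
    """
    columns: Dict[str, List] = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            if not columns:
                columns = {key: [] for key in record}
            for key in columns:
                columns[key].append(record.get(key))
    return columns
```

The result could then be passed to `hfDataset.from_dict(...)` (the `datasets.Dataset` alias imported at the top of the notebook), e.g. for `data/sts/s2orc_full/train/train_0.jsonl`. Alternatively, `load_dataset("json", data_files=...)` reads the files directly without this step.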

In [None]:
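import torch

# TADataset stands for TitleAbstractDataset
class TADataset(torch.utils.data.Dataset):
    """Wraps tokenizer output and labels so a DataLoader/Trainer can index them.

    `encodings` is a dict of equal-length lists (e.g. input_ids, attention_mask)
    as returned by a 🤗 tokenizer; `labels` holds one label per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        # one example: every encoding field at position idx, as a tensor
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)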

In [None]:
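import json

def extract_citation_contexts(metadata_path, pdf_parses_path):
    """Collect paragraphs of ACL papers that cite other (linked) S2ORC papers.

    The two paths point at one chunk, e.g. metadata_{id}.jsonl and pdf_parses_{id}.jsonl.
    """
    # lookup from paper ID to the full pdf parse of that paper
    paper_id_to_pdf_parse = {}
    with open(pdf_parses_path) as f_pdf:
        for line in f_pdf:
            pdf_parse_dict = json.loads(line)
            paper_id_to_pdf_parse[pdf_parse_dict['paper_id']] = pdf_parse_dict

    citation_contexts = []
    with open(metadata_path) as f_meta:
        for line in f_meta:
            metadata_dict = json.loads(line)
            paper_id = metadata_dict['paper_id']
            print(f"Currently viewing S2ORC paper: {paper_id}")

            # suppose we only care about ACL anthology papers
            if not metadata_dict['acl_id']:
                continue

            # and we want only papers with resolved outbound citations
            if not metadata_dict['has_outbound_citations']:
                continue

            # get citation context (paragraphs)!
            if paper_id not in paper_id_to_pdf_parse:
                continue
            # (1) get the full pdf parse from the previously computed lookup dict
            pdf_parse = paper_id_to_pdf_parse[paper_id]

            # (2) pull out fields we need from the pdf parse, including bibliography & text
            bib_entries = pdf_parse['bib_entries']
            paragraphs = pdf_parse['abstract'] + pdf_parse['body_text']

            # (3) loop over paragraphs, grabbing citation contexts
            for paragraph in paragraphs:
                # (4) loop over each inline citation in this paragraph
                for cite_span in paragraph['cite_spans']:
                    # (5) each inline citation can be resolved to a bib entry;
                    # skip spans whose ref_id is null or missing from bib_entries
                    cited_bib_entry = bib_entries.get(cite_span['ref_id'])
                    if cited_bib_entry is None:
                        continue

                    # (6) that bib entry *may* be linked to a S2ORC paper.  if so, grab paragraph
                    linked_paper_id = cited_bib_entry['link']
                    if linked_paper_id:
                        citation_contexts.append({
                            'citing_paper_id': paper_id,
                            'cited_paper_id': linked_paper_id,
                            'context': paragraph['text'],
                            'citation_mention_start': cite_span['start'],
                            'citation_mention_end': cite_span['end'],
                        })
    return citation_contexts

# e.g. citation_contexts = extract_citation_contexts("metadata_0.jsonl", "pdf_parses_0.jsonl")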

# 3. Computing Word Embeddings: `Continuous Bag-of-Words`

The Continuous Bag-of-Words model (CBOW) is frequently used in NLP deep learning. It is a model that tries to predict words given the context of a few words before and a few words after the target word. This is distinct from language modeling, since CBOW is not sequential and does not have to be probabilistic. Typically, CBOW is used to quickly train word embeddings, and these embeddings are used to initialize the embeddings of some more complicated model. Usually, this is referred to as pretraining embeddings. It almost always improves performance by a couple of percent.

The CBOW model is as follows. Given a target word $w_i$ and an $N$ context window on each side, $w_{i-1}, \dots, w_{i-N}$ and $w_{i+1}, \dots, w_{i+N}$, referring to all context words collectively as $C$, CBOW tries to minimize:


$$ -\log p(w_i \mid C) = -\log \mathrm{Softmax}\left( A \left( \sum_{w \in C} q_w \right) + b \right) $$

where $q_w$ is the embedding for word $w$.
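To make each term of the loss concrete, here is a tiny pure-Python evaluation for a single target word. The embeddings, `A` (shape $|V| \times d$) and `b` are made-up toy numbers, and `cbow_nll` is a hypothetical name; the log-sum-exp subtracts the maximum score for numerical stability, which does not change the result.

```python
import math
from typing import List


def cbow_nll(context_embeddings: List[List[float]], A: List[List[float]],
             b: List[float], target: int) -> float:
    """-log p(w_i | C) = -log Softmax(A (sum_{w in C} q_w) + b)[target]."""
    dim = len(context_embeddings[0])
    # sum_{w in C} q_w
    h = [sum(q[d] for q in context_embeddings) for d in range(dim)]
    # scores = A h + b, one score per vocabulary word (A is |V| x dim)
    scores = [sum(a_row[d] * h[d] for d in range(dim)) + b_v for a_row, b_v in zip(A, b)]
    # log of the softmax normalizer, computed stably via log-sum-exp
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[target]
```

With two context embeddings `[1, 0]` and `[0, 1]`, `A = [[1, 0], [0, 1], [0, 0]]` and `b = [0, 0, 0]`, the scores are `[1, 1, 0]`, so $p(w_0 \mid C) = e / (2e + 1)$ and the loss is about 0.862.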

In [None]:
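import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

CONTEXT_SIZE = 2  # 2 words to the left, 2 to the right
EMBEDDING_DIM = 10  # size of each embedding q_w (illustrative choice)
raw_text = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules
called a program. People create programs to direct processes. In effect,
we conjure the spirits of the computer with our spells.""".split()

# By deriving a set from `raw_text`, we deduplicate the array;
# sorting makes the word -> index mapping reproducible across runs
vocab = sorted(set(raw_text))
vocab_size = len(vocab)

word_to_ix = {word: i for i, word in enumerate(vocab)}
data = []
for i in range(CONTEXT_SIZE, len(raw_text) - CONTEXT_SIZE):
    # CONTEXT_SIZE words on each side of the target word
    context = raw_text[i - CONTEXT_SIZE:i] + raw_text[i + 1:i + CONTEXT_SIZE + 1]
    target = raw_text[i]
    data.append((context, target))
print(data[:5])


class CBOW(nn.Module):
    """Sum the context embeddings q_w, apply the affine map A(.) + b,
    and return log-probabilities over the vocabulary."""

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)  # q_w
        self.linear = nn.Linear(embedding_dim, vocab_size)  # A and b

    def forward(self, inputs):
        # inputs: LongTensor of context word indices, shape (2 * CONTEXT_SIZE,)
        summed = self.embeddings(inputs).sum(dim=0)  # sum_{w in C} q_w
        log_probs = F.log_softmax(self.linear(summed), dim=-1)
        return log_probs.unsqueeze(0)  # shape (1, vocab_size), as NLLLoss expects


# turns a list of context words into a tensor of word indices
def make_context_vector(context, word_to_ix):
    idxs = [word_to_ix[w] for w in context]
    return torch.tensor(idxs, dtype=torch.long)


make_context_vector(data[0][0], word_to_ix)  # example

# minimal training loop: minimize -log p(w_i | C) with NLLLoss
# (learning rate and number of epochs are illustrative choices)
model = CBOW(vocab_size, EMBEDDING_DIM)
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)
for epoch in range(10):
    total_loss = 0.0
    for context, target in data:
        optimizer.zero_grad()
        log_probs = model(make_context_vector(context, word_to_ix))
        loss = loss_function(log_probs, torch.tensor([word_to_ix[target]], dtype=torch.long))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"epoch {epoch}: loss {total_loss:.4f}")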