This notebook illustrates how to use Masked Language Modeling for this competition.

Observation: most of the dataset names consist of only words with uppercased-first-letter and some stopwords like `on`, `in`, `and` (e.g. `Early Childhood Longitudinal Study`, `Trends in International Mathematics and Science Study`). 

Thus, one approach to find the datasets is: 
- Locate all the sequences of capitalized words (these sequences may contain some stopwords), 
- Replace each sequence with one of 2 special symbols (e.g. `$` and `#`), implying if that sequence represents a dataset name or not.
- Have the model learn the MLM task.
- Distil Bert -> BERT Model -> DistilRoberta -> ROBERTA

The code below shows how to train a model for that purpose with the help of the `huggingface`.

# Install packages

In [2]:
!pip install fsspec==0.9.0

Collecting fsspec==0.9.0
  Downloading fsspec-0.9.0-py3-none-any.whl (107 kB)
[K     |████████████████████████████████| 107 kB 420 kB/s eta 0:00:01
Installing collected packages: fsspec
  Attempting uninstall: fsspec
    Found existing installation: fsspec 0.8.7
    Uninstalling fsspec-0.8.7:
      Successfully uninstalled fsspec-0.8.7
Successfully installed fsspec-0.9.0


In [3]:
import networkx as nx
!pip install datasets --no-index --find-links=../input/coleridge-packages/packages/datasets
!pip install ../input/coleridge-packages/seqeval-1.2.2-py3-none-any.whl
!pip install ../input/coleridge-packages/tokenizers-0.10.1-cp37-cp37m-manylinux1_x86_64.whl
!pip install ../input/coleridge-packages/transformers-4.5.0.dev0-py3-none-any.whl

Looking in links: ../input/coleridge-packages/packages/datasets
Processing /kaggle/input/coleridge-packages/packages/datasets/datasets-1.5.0-py3-none-any.whl
Processing /kaggle/input/coleridge-packages/packages/datasets/tqdm-4.49.0-py2.py3-none-any.whl
Processing /kaggle/input/coleridge-packages/packages/datasets/xxhash-2.0.0-cp37-cp37m-manylinux2010_x86_64.whl
Processing /kaggle/input/coleridge-packages/packages/datasets/huggingface_hub-0.0.7-py3-none-any.whl
Installing collected packages: tqdm, xxhash, huggingface-hub, datasets
  Attempting uninstall: tqdm
    Found existing installation: tqdm 4.59.0
    Uninstalling tqdm-4.59.0:
      Successfully uninstalled tqdm-4.59.0
Successfully installed datasets-1.5.0 huggingface-hub-0.0.7 tqdm-4.49.0 xxhash-2.0.0
Processing /kaggle/input/coleridge-packages/seqeval-1.2.2-py3-none-any.whl
Installing collected packages: seqeval
Successfully installed seqeval-1.2.2
Processing /kaggle/input/coleridge-packages/tokenizers-0.10.1-cp37-cp37m-manylinu

# Import

In [4]:
import os
import re
import json
import time
import datetime
import random
import glob
import importlib

import numpy as np
import pandas as pd

from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import KFold

import tensorflow as tf
import tensorflow.keras as keras
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling, \
TFAutoModel, AutoConfig

import transformers
import nltk
import pickle
nltk.download('wordnet')
nltk.download('averaged_perceptron_tagger')
  
sns.set()
random.seed(123)
np.random.seed(456)

[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /usr/share/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


In [5]:
AUTO = tf.data.experimental.AUTOTUNE
def seed_it_all(seed=7):
    """ Attempt to be Reproducible """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
seed_it_all()

In [6]:
model_checkpoint = "roberta-large"
LOAD_FROM_PREV = False
PREPROCESSED_PATH = './'
SAVE_PATH = './'
MAX_LEN = 512 # pad values up to 512.
MAX_LBL = 60 # Max in lbl = 60# Longer than 300 is typically spanish or random text
OVERLAP = 20
train_corpus = None
val_corpus = None

DATASET_SYMBOL = '$' # this symbol represents a dataset name
NONDATA_SYMBOL = '#' # this symbol represents a non-dataset name

In [7]:
def freeze(layers):
  for layer in layers:
    for parameter in layer.parameters():
      parameter.requires_grad = False


# Load data

# Larger Sentences:
- No duplicates of Sentences, just all tokens at once.
- We have multi class labels, why not use them.

In [8]:
def merge_duplicates(csv):
    titles = set(csv.pub_title)
    df = {'pub_title': [], 'cleaned_label': []}
    for title in tqdm(titles):
        all_rows = csv[np.equal(csv.pub_title, title)]
        all_labels = []
        for row in all_rows.iterrows():
            row = row[1]
            labels = row.cleaned_label.split("|")
            all_labels += labels
            pub_title = row.pub_title
        all_labels = set(all_labels)
        df['pub_title'] += [pub_title]
        df['cleaned_label'] += ['|'.join(string for string in all_labels)]
    return pd.DataFrame(df)
        
            

In [9]:
def compute_all_labels(csv):
    labels = []
    for row in csv.iterrows():
        row = row[1]
        labels += row.cleaned_label.split("|")
    return labels
def get_labels(csv):
    labels = []
    for row in csv.iterrows():
        row = row[1]
        row_label = row.cleaned_label.split("|")
        labels += [row_label]
    return labels
def create_undirected_graph(csv):
    # Creates an Undirected Graph from the dataset
    labels = get_labels(csv)# (N, )
    unique_labels = compute_all_labels(csv)
    datasets_to_pub = {label: [] for label in unique_labels}
    
    pub_titles = csv.pub_title # (N, )
    # Create Mini Graph to create bigger graph
    for i in tqdm(range(len(pub_titles))):
        ex_label = labels[i]
        ex_title = pub_titles[i]
        for label in ex_label:
            datasets_to_pub[label] += [ex_title]

    graph = nx.DiGraph()
    graph.add_nodes_from(pub_titles)
    for label in tqdm(datasets_to_pub):
        # Double Loop
        titles = datasets_to_pub[label]
        for pub1_idx in range(len(titles)):
            for pub2_idx in range(pub1_idx, len(titles)):
                graph.add_edge(titles[pub1_idx], titles[pub2_idx])
    # Create Undirected Graph
    graph = graph.to_undirected()
    # compute Strongly Connected Components
    components = nx.strongly_connected_components(graph.to_directed())
    all_components = []
    for comp in components:
        all_components += [comp]
    return all_components
def grab_from_components(csv, components):
    # Restores a CSV from the Components
    return csv[csv.pub_title.isin(components)]
def create_train_test_splits(csv, components):
    # The Datasets are too intermingled to properly split, so we just chuck the two largest(4000 + 8000 = 12000) into train,
    # the 3rd largest(1000) into test, and split across the tiny ones(17 components, accumulated to 100 nodes)
    component_to_length = []
    for comp in components:
        component_to_length += [(comp, len(comp))]
    # Sort
    key_fn = lambda x: x[1]
    component_to_length = sorted(component_to_length, key = key_fn, reverse = True, )
    
    train_guarenteed = [t[0] for t in component_to_length[:3]]
    test_guarenteed = []
    to_split = [t[0] for t in component_to_length[3:]]
    splitter = KFold(shuffle = True, random_state = 42)
    
    FOLDS = []
    for train_idx, test_idx in splitter.split(to_split):
        train = train_guarenteed
        test = test_guarenteed
        for idx in train_idx:
            train += [to_split[idx]]
        for idx in test_idx:    
            test += [to_split[idx]]
        # Convert List of Sets to Single Set
        train_set = set()
        test_set = set()
        
        for comp in train:
            train_set.update(comp)
        for comp in test:
            test_set.update(comp)
        train_set = grab_from_components(csv, train_set)
        test_set = grab_from_components(csv, test_set)
        FOLDS += [(train_set, test_set)]
    return FOLDS
def get_splits(csv):
    # function that automates everything for you.
    print("Generating SCCs...................")
    all_components = create_undirected_graph(csv)
    print('generatings splits..........')
    return create_train_test_splits(csv, all_components)

In [10]:
if not LOAD_FROM_PREV:
  
  import copy
  # train
  train_path = '../input/coleridgeinitiative-show-us-the-data/train.csv'
  paper_train_folder = '../input/coleridgeinitiative-show-us-the-data/train/'

  train = pd.read_csv(train_path)
  # Group by publication, training labels should have the same form as expected output.
  train = train.groupby('Id').agg({
      'pub_title': 'first',
      'dataset_title': '|'.join,
      'dataset_label': '|'.join,
      'cleaned_label': '|'.join
  }).reset_index()    
  del train['dataset_label']
  del train['dataset_title']
  del train['Id']
  train = merge_duplicates(train)
  ORIG_FOLDS = get_splits(train)
  FOLDS = copy.deepcopy(ORIG_FOLDS)
  print('train size: ', len(train))

  all_datasets = pd.read_csv(train_path)
  # Group by publication, training labels should have the same form as expected output.
  all_datasets = all_datasets.groupby('Id').agg({
      'pub_title': 'first',
      'dataset_title': '|'.join,
      'dataset_label': '|'.join,
      'cleaned_label': '|'.join
  }).reset_index() 

  # Post Process Fold
  NEW_FOLDS = []
  for train, test in FOLDS:
      train = train.sort_values('pub_title').reset_index(drop = True)
      test = test.sort_values('pub_title').reset_index(drop = True)
      train_dataset = all_datasets[np.isin(all_datasets.pub_title, train.pub_title)].sort_values('pub_title').reset_index(drop = True)
      test_dataset = all_datasets[np.isin(all_datasets.pub_title, test.pub_title)].sort_values('pub_title').reset_index(drop = True)
      
      ## Add back in the ids
      train['Id'] = train_dataset.Id
      train['dataset_label'] = train_dataset.dataset_label
      train['dataset_title'] = train_dataset.dataset_title
      
      test['Id'] = test_dataset.Id
      test['dataset_label'] = test_dataset.dataset_label
      test['dataset_title'] = test_dataset.dataset_title
      
      NEW_FOLDS += [(train, test)]
      FOLDS = NEW_FOLDS

HBox(children=(FloatProgress(value=0.0, max=14271.0), HTML(value='')))


Generating SCCs...................


HBox(children=(FloatProgress(value=0.0, max=14271.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=130.0), HTML(value='')))


generatings splits..........
train size:  14271



# Prepare data for train MLM

### Auxiliary functions

In [11]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True, return_token_type_ids = True, return_attention_masks = True)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=482.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898823.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1355863.0, style=ProgressStyle(descript…




In [15]:
def clean_paper_sentence(s):
    """
    This function is essentially clean_text without lowercasing.
    """
    s = re.sub('[^A-Za-z0-9]+', ' ', str(s).lower()).strip()
    s = re.sub(' +', ' ', s)
    count = 0
    start_span = 0
    spans_to_remove = []
    for idx, letter in enumerate(s):
        if letter == ' ':
            continue
        if letter.isdigit():
            if count == 0:
                start_span = idx
            count += 1
            
        else:
            if count > 10:
                # 10 Numbers in a row, likely tabular data, remove the entire string
                spans_to_remove += [(start_span, idx)]
                #print(s[start_span:idx])
            count = 0
    # Remove the spans, starting from the back 
    spans_to_remove.reverse()
    removed = False
    for idx1, idx2 in spans_to_remove:
        s = s[:idx1] + s[idx2:]
        removed = True
    
    return s

def shorten_sentences(sentences):
    """
    Sentences that have more than MAX_LENGTH words will be split
    into multiple sentences with overlappings.
    """
    sentences = sentences[1:-1]
    length = MAX_LEN - 2
    MAX_SENTS = 5000# Wayyyy more than the max anyways(200)
    short_sentences = np.zeros((MAX_SENTS, length), dtype = np.int32)
    cur_idx = 0
    MIN_WORDS = 25
    
    words = sentences
    if len(words) > length:
        for p in range(0, len(words), length - OVERLAP):
            new_words = words[p:p + length]
            padded_words = np.ones((length), dtype = np.int32) * tokenizer.pad_token_id
            if len(new_words) < MIN_WORDS:
                continue
            else:
                padded_words[:len(new_words)] = new_words
                short_sentences[cur_idx, :] = padded_words
                cur_idx += 1

    else:
        padded_words = np.ones((length), dtype = np.int32) * tokenizer.pad_token_id
        padded_words[:len(words)] = words
        short_sentences[cur_idx, :] = padded_words
        cur_idx += 1
    short_sentences = short_sentences[:cur_idx]
    
    return short_sentences

def find_sublist(big_list, small_list):
    """
    find all positions of $small_list in $big_list.
    """
    all_positions = []
    for i in range(len(big_list) - len(small_list) + 1):
        if small_list == big_list[i:i+len(small_list)]:
            all_positions.append(i)
    
    return all_positions

def jaccard_similarity_list(l1, l2):
    """
    Return the Jaccard Similarity score of 2 lists.
    """
    intersection = len(list(set(l1).intersection(l2)))
    union = (len(l1) + len(l2)) - intersection
    return float(intersection) / union

connection_tokens = {'s', 'of', 'and', 'in', 'on', 'for', 'data', 'dataset'}
def find_negative_candidates(sentence, labels):
    """
    Extract negative samples for Masked Dataset Modeling from a given $sentence.
    A negative candidate should be a continuous sequence of at least 2 words, 
    each of these words either has the first letter in uppercase or is one of
    the connection words ($connection_tokens). Furthermore, the connection 
    tokens are not allowed to appear at the beginning and the end of the
    sequence. Lastly, the sequence must be quite different to any of the 
    ground truth labels (measured by Jaccard similarity).
    """
    def candidate_qualified(words, labels):
        while len(words) and words[0].lower() in connection_tokens:
            words = words[1:]
        while len(words) and words[-1].lower() in connection_tokens:
            words = words[:-1]
        
        return len(words) >= 5 and \
               all(jaccard_similarity_list(words, label) < 0.75 for label in labels)
    
    candidates = []
    
    phrase_start, phrase_end = -1, -1
    for id in range(1, len(sentence)):
        word = sentence[id]
        if word[0].isupper() or word in connection_tokens:
            if phrase_start == -1:
                phrase_start = phrase_end = id
            else:
                phrase_end = id
        else:
            if phrase_start != -1:
                if candidate_qualified(sentence[phrase_start:phrase_end+1], labels):
                    candidates.append((phrase_start, phrase_end))
                phrase_start = phrase_end = -1
    
    if phrase_start != -1:
        if candidate_qualified(sentence[phrase_start:phrase_end+1], labels):
            candidates.append((phrase_start, phrase_end))
    
    return candidates

In [16]:
def load_corpus(train):    
    corpus = []
    count = 0
    for paper_id, dataset_labels in tqdm(train[['Id', 'dataset_label']].itertuples(index=False)):
        labels = [clean_paper_sentence(label) for label in dataset_labels.split('|')]
        # Create a Full Dataset 
        
        with open(f'{paper_train_folder}/{paper_id}.json', 'r') as f:
            paper = json.load(f)
        content = '. '.join(section['text'] for section in paper)
        sentences = list(set([clean_paper_sentence(sentence) for sentence in [content]]))
        sentences = tokenizer(sentences)['input_ids'][0]
        
        sentences = shorten_sentences(sentences) # make sentences short, Each of these sentences store roughly 512 tokens(Pad or truncate)
        sentences = tokenizer.batch_decode(sentences)
        sentences = [sentence for sentence in sentences if len(sentence) > 10] # only accept sentences with length > 10 chars
        # FIND LABELS Per sentence 
        LABEL = []
        SENTENCES = []
        
        ZERO_SENTENCES = []
        ZERO_LABEL = []
        for sentence in sentences:
          lbl_sentence = []
          for lbl in labels:
            if lbl in sentence:
              lbl_sentence += [lbl]
          if len(lbl_sentence) == 0:
            # check if negative sample
            candidates = find_negative_candidates(sentence, labels)
            ZERO_SENTENCES += [sentence]
            ZERO_LABEL += [lbl_sentence]
          else:
            SENTENCES += [sentence]
            LABEL += [lbl_sentence]
        # Random Select 1.2 * sentences - Slightly more neg than pos(like in real dataset)
        num_pos = int(len(SENTENCES) * 1.2)
        try:
            indices = np.random.randint(0, high = len(ZERO_SENTENCES) - 1, size = num_pos)
            indices = list(set(indices.tolist()))
            for IDX in indices:
                SENTENCES += [ZERO_SENTENCES[IDX]]
                LABEL += [ZERO_LABEL[IDX]]
        except:
            pass
        NEW_SENTENCES = []
        NEW_LABELS = []
        for i in range(len(SENTENCES)):
          lbl = LABEL[i]
          sent = SENTENCES[i].split()
          processed_lbl = []
          for j, l in enumerate(lbl):
            if j == len(lbl) - 1:
              processed_lbl += l.split()
            else:
              processed_lbl += l.split() + ['|']
          corpus += [(sent, processed_lbl)]
          
        count += 1
        if count % 100 == 0:
            print(count)
    return corpus

### Save data to a file

In [17]:
if not LOAD_FROM_PREV:
  VAL_LABELS =[]
  FOLD_IDX = 0
  FOLDS = NEW_FOLDS
  FOLDS = FOLDS[FOLD_IDX]


  train, val = FOLDS
  # load corpus
  if train_corpus is None:
    train_corpus = load_corpus(train)
  if val_corpus is None:
    val_corpus = load_corpus(val)

  with open(f'train_mlm.json', 'w') as f:
      for idx in range(len(train_corpus)):
          sentence, label = train_corpus[idx] 
          row_json = {'text':sentence, 'label': label}
          json.dump(row_json, f)
          f.write('\n')
  # Map Dataset to Selected Corpus
  val_labels = []
  with open(f'val_mlm.json', 'w') as f:
      for idx in range(len(val_corpus)):
          sentence, label = val_corpus[idx]
          row_json = {'text':sentence, 'label': label}

          json.dump(row_json, f)
          f.write('\n')

  VAL_LABELS = val_labels

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Token indices sequence length is longer than the specified maximum sequence length for this model (102479 > 512). Running this sequence through the model will result in indexing errors


100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900
9000
9100
9200
9300
9400
9500
9600
9700
9800
9900
10000
10100
10200
10300
10400
10500
10600
10700
10800
10900
11000
11100
11200
11300
11400
11500
11600
11700
11800
11900
12000
12100
12200
12300
12400
12500
12600
12700
12800
12900
13000
13100
13200
13300
13400
13500
13600
13700
13800
13900
14000
14100
14200



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




# Fine-tune the Transformer

In [26]:
if not LOAD_FROM_PREV:
  datasets = load_dataset('json',
              data_files={
                  'train' : f'train_mlm.json'
              },
              keep_in_memory = True
  )
  val_dataset = load_dataset('json',
    data_files = {
        'test': f'val_mlm.json' 
    },
    keep_in_memory = True
  )
  datasets["train"][:5]

# Augment the Dataset

In [27]:
def fix_ending_tokens(sentence, token_type_ids, att_mask):
    # Find End Token
    end_token = np.argmax(np.equal(sentence, tokenizer.eos_token_id).astype(np.int32))
    # Replace this token with padding
    sentence[end_token] = tokenizer.pad_token_id
    # Find the first pad token
    pad_token = np.argmax(np.equal(sentence, tokenizer.pad_token_id).astype(np.int32))
    # Replace the first token with End
    sentence[pad_token] = tokenizer.eos_token_id
    att_mask[pad_token] = 1
    # Rest is padding 
    sentence[pad_token + 1:] = tokenizer.pad_token_id
    att_mask[pad_token + 1:] = 0


    # CHECK IF LEN IS > MAXLEN(TRUNCATE)
    if len(sentence) > MAX_LEN:
        sentence = sentence[:MAX_LEN]
        att_mask = att_mask[:MAX_LEN]
        token_type_ids = token_type_ids[:MAX_LEN]
        # check if last token is padding
        if sentence[-1] == tokenizer.pad_token_id:
            # Skip
            # Nothing changes
            pass
        else:
            # Replace Last token with End
            sentence[-1] = tokenizer.eos_token_id
            # Fix Att_mask
            att_mask[-1] = 1

    # CHECK IF LEN is < MAXLEN(PAD)
    if len(sentence) < MAX_LEN:
        padded_sentence = np.ones(MAX_LEN, dtype = sentence.dtype) * tokenizer.pad_token_id
        padded_sentence[:len(sentence)] = sentence 
        sentence = padded_sentence
        
        padded_att_mask = np.zeros(MAX_LEN, dtype = att_mask.dtype) 
        padded_att_mask[:len(att_mask)] = att_mask
        att_mask = padded_att_mask
        
        padded_token_type_ids = np.zeros(MAX_LEN, dtype = token_type_ids.dtype)
        padded_token_type_ids[:len(token_type_ids)] = token_type_ids
        token_type_ids = padded_token_type_ids
    return sentence, token_type_ids, att_mask

### Tokenize and collate data

BATCHES DON't WORK RN.

In [28]:
def tokenize_function(examples):
    
    values = tokenizer([" ".join(ex) for ex in examples['text']], return_attention_mask=True, return_token_type_ids=True)
    input_ids = values['input_ids']
    token_type_ids = values['token_type_ids']
    attention_mask = values['attention_mask']
    
    new_input_ids = np.zeros((len(input_ids), MAX_LEN), dtype = np.int32)
    new_token_type_ids = np.zeros((len(input_ids), MAX_LEN), dtype = np.int32)
    new_attention_mask = np.zeros((len(input_ids), MAX_LEN), dtype = np.int32)
    
    for b in range(len(input_ids)):
        i_id = np.array(input_ids[b])
        tti = np.array(token_type_ids[b])
        att_msk = np.array(attention_mask[b])
        
        i_id, tti, att_msk = fix_ending_tokens(i_id, tti, att_msk)
        
        new_input_ids[b, :] = i_id
        new_token_type_ids[b, :] = tti
        new_attention_mask[b, :] = att_msk
        
        
    values['input_ids'] = new_input_ids.tolist()
    values['token_type_ids'] = new_token_type_ids.tolist()
    values['attention_mask'] = new_attention_mask.tolist()
    
    encoded = tokenizer([" ".join(ex) for ex in examples['label']],  return_token_type_ids= False, return_attention_mask=False)
    values['label'] = encoded['input_ids']
    return values
if not LOAD_FROM_PREV:
  tokenized_train_dataset = datasets.map(tokenize_function, batched = True, num_proc = 1, remove_columns= ['text'])
  tokenized_val_dataset = val_dataset.map(tokenize_function, batched = True, num_proc = 1, remove_columns = ['text'])

In [29]:
def _bytes_feature(value, is_list=False):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    
    if not is_list:
        value = [value]
    
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

def _float_feature(value, is_list=False):
    """Returns a float_list from a float / double."""
        
    if not is_list:
        value = [value]
        
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def _int64_feature(value, is_list=False):
    """Returns an int64_list from a bool / enum / int / uint."""
        
    if not is_list:
        value = [value]
        
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


In [30]:
def serialize_tokenized(attention_mask, input_ids, label, token_type_ids):
    """
    Creates a tf.Example message ready to be written to a file from 4 features.

    Args:
        image (TBD): TBD
        other: Either the image_id or the target inchi
    
    Returns:
        A tf.Example Message ready to be written to file
    """
    # Create a dictionary mapping the feature name to the 
    # tf.Example-compatible data type.
    
    # PAD THE VALUE
    NUM_PAD = MAX_LEN - len(input_ids)
    PAD_ID = tokenizer.pad_token_id
    input_ids += [PAD_ID] * NUM_PAD
    token_type_ids += [0] * NUM_PAD
    attention_mask += [0] * NUM_PAD
    NUM_LBL_PAD = MAX_LBL - len(label)
    label += [PAD_ID] * NUM_LBL_PAD
 
    feature = {
        'attention_mask': _int64_feature(attention_mask, is_list = True),
        'input_ids': _int64_feature(input_ids, is_list = True),
        'label': _int64_feature(label, is_list = True),
        'token_type_ids': _int64_feature(token_type_ids, is_list = True) 
    }
    # Create a Features message using tf.train.Example.
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

In [31]:
def write_tfrecords(tokenized, save_path, shard_size = 100000):
  try:
    os.mkdir(f"{PREPROCESSED_PATH}{save_path}")
  except:
    pass
  dataset_length = len(tokenized)
  counts = []
  for i in range((dataset_length // shard_size ) + 1):
    start_idx = i * shard_size
    end_idx = (i + 1) * shard_size
    
    file_name = f"{PREPROCESSED_PATH}{save_path}{i}.tfrec"
    all_values = tokenized[start_idx: end_idx]
    
    attention_mask = all_values['attention_mask']
    input_ids = all_values['input_ids']
    label = all_values['label']
    token_type_ids = all_values['token_type_ids']
    with open(file_name, 'w') as file:
      pass
    with tf.io.TFRecordWriter(file_name) as writer:
        count = 0
        for idx in tqdm(range(len(input_ids))):
          att_mask = attention_mask[idx]
          i_id = input_ids[idx]
          lbl = label[idx]
          tti = token_type_ids[idx]
            
          ex = serialize_tokenized(att_mask, i_id, lbl, tti)
          writer.write(ex)
        print(count)
  return counts

In [32]:
try:
  os.mkdir(PREPROCESSED_PATH)
except:
  pass

In [33]:
TRAIN_COUNTS = write_tfrecords(tokenized_train_dataset['train'], 'train_tfrecords/')
VAL_COUNTS = write_tfrecords(tokenized_val_dataset['test'], 'test_tfrecords/')

HBox(children=(FloatProgress(value=0.0, max=56701.0), HTML(value='')))


0


HBox(children=(FloatProgress(value=0.0, max=167.0), HTML(value='')))


0
