In [238]:
import os
import sys
import configparser
from dataclasses import dataclass

import datasets
import torch

from torch import nn
from nltk import sent_tokenize
from datasets import load_from_disk
from tokenizers import Tokenizer
from transformers import BertModel, AutoTokenizer, AutoModel, BertTokenizerFast, BertPreTrainedModel
from torch.utils.data import DataLoader

# Loading

In [270]:
# Custom tokenizer serialized to JSON, wrapped so it exposes the BERT fast-tokenizer API.
# NOTE(review): `tokenizer` is not referenced by the visible cells below; they all use
# `additional_tokenizer` instead — confirm which one is intended for training.
file_name = "../tokenizers/tokenizer_2_1000.json"

base_tokenizer = Tokenizer.from_file(file_name)
tokenizer = BertTokenizerFast(tokenizer_object=base_tokenizer)

# Stock pretrained tokenizer; this is the one handed to `tokenize` and `CustomDataCollator`.
additional_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [6]:
# Project configuration (parsed here; not read by the visible cells below).
config = configparser.ConfigParser()
config.read('../config.ini')

# Locate the Hugging Face datasets cache without hard-coding a user-specific
# absolute path: honour HF_DATASETS_CACHE when set, otherwise fall back to the
# default location under the current user's home directory.
root_dir = os.environ.get(
    "HF_DATASETS_CACHE",
    os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets"),
)
data_dir = "updated_wiki40b"

path = os.path.join(root_dir, data_dir, "long_small_dataset")

# Dataset previously written to disk; the bare expression displays its summary.
dataset = load_from_disk(path)
dataset

# Prototyping

In [12]:
# Prototype on at most the first 100 rows; clamp so a shorter dataset does not
# make `select` fail on out-of-range indices.
small_dataset = dataset.select(range(min(100, len(dataset))))

In [281]:
@dataclass
class Args:
    """Hyperparameters shared by tokenization, batching and the hierarchical model.

    Declared as real dataclass fields so that the generated ``__repr__`` shows
    the values and ``__eq__`` compares them (with a hand-written ``__init__``
    and no fields, every instance compared equal to every other).

    Attributes:
        max_sentence_len: Maximum number of token ids per sentence, including
            the leading ``[CLS]`` and trailing ``[SEP]``.
        max_document_len: Maximum number of sentences kept per document.
        per_device_train_batch_size: Batch size handed to the ``DataLoader``.
        pretrained_model_path: Name or path passed to ``from_pretrained`` for
            the lower (sentence-level) encoder.
        upper_hidden_dimension: ``d_model`` of the upper transformer encoder.
        upper_nhead: Number of attention heads of the upper transformer encoder.
        upper_num_layers: Number of layers of the upper transformer encoder.
    """

    max_sentence_len: int
    max_document_len: int
    per_device_train_batch_size: int
    pretrained_model_path: str
    upper_hidden_dimension: int
    upper_nhead: int
    upper_num_layers: int


args = Args(
    max_sentence_len=128,
    max_document_len=64,
    per_device_train_batch_size=8,
    pretrained_model_path="bert-base-uncased",
    upper_hidden_dimension=768,
    upper_nhead=8,
    upper_num_layers=2,
)


In [141]:
def tokenize(example, tokenizer, args, num_articles=2):
    """Split each article of ``example`` into sentences and convert them to token ids.

    For every ``i`` in ``1..num_articles`` the raw text stored under
    ``article_{i}`` is replaced by a list of sentences, each a list of token
    ids of the form ``[CLS] + ids + [SEP]`` truncated to at most
    ``args.max_sentence_len`` entries. A parallel ``mask_{i}`` entry holding a
    ``1`` for every token is added.

    Args:
        example: Mapping holding the article texts; it is modified in place.
        tokenizer: Object providing ``encode(text, add_special_tokens=False)``
            and ``convert_tokens_to_ids(token)``.
        args: Object exposing ``max_sentence_len``.
        num_articles: Number of ``article_{i}`` entries to process
            (defaults to 2, the previously hard-coded value).

    Returns:
        The same ``example`` mapping with the tokenized articles and masks.
    """
    # Special-token ids are the same for every sentence, so look them up once.
    cls_id = tokenizer.convert_tokens_to_ids("[CLS]")
    sep_id = tokenizer.convert_tokens_to_ids("[SEP]")
    # Two positions are reserved for [CLS] and [SEP]; clamp at zero so a very
    # small limit cannot turn into a negative slice bound.
    body_len = max(args.max_sentence_len - 2, 0)

    def tokenize_helper(article):
        sentences = []
        for sentence in sent_tokenize(article):
            token_ids = tokenizer.encode(sentence, add_special_tokens=False)
            sentences.append([cls_id] + token_ids[:body_len] + [sep_id])

        # TODO: check for attention_mask ID
        mask = [[1] * len(sentence) for sentence in sentences]
        return sentences, mask

    for i in range(1, num_articles + 1):
        example[f"article_{i}"], example[f"mask_{i}"] = tokenize_helper(example[f"article_{i}"])

    return example

In [254]:
# Tokenize every example of the prototype subset with the pretrained tokenizer.
new_dataset = small_dataset.map(
    tokenize,
    fn_kwargs=dict(tokenizer=additional_tokenizer, args=args),
)

100%|██████████| 100/100 [00:02<00:00, 38.60ex/s]


## Data Collator

In [255]:
@dataclass
class CustomDataCollator:
    """Collate documents (lists of tokenized sentences) into padded 3-D tensors.

    Each feature must hold, for every article number ``i`` in
    ``1..num_articles``, an ``article_{i}`` entry (a list of sentences, each a
    list of token ids) and a ``mask_{i}`` entry with the same nesting and
    lengths. Within a batch, sentences are padded or truncated to the longest
    sentence (capped at ``max_sentence_len``) and documents are padded or
    truncated to the longest document (capped at ``max_document_len``).

    Attributes:
        tokenizer: Object providing ``convert_tokens_to_ids``; used to look up
            the ``[PAD]`` id.
        max_sentence_len: Upper bound on tokens per sentence in the batch.
        max_document_len: Upper bound on sentences per document in the batch.
        return_tensors: Kept for interface compatibility; not read by any
            method, the collator always returns ``torch`` tensors.
        num_articles: Number of ``article_{i}`` / ``mask_{i}`` pairs per feature.
    """
    tokenizer: object
    max_sentence_len: int = 128
    max_document_len: int = 32
    return_tensors: str = "pt"
    num_articles: int = 2

    def __call__(self, features: list) -> dict:
        """Build the batch dictionary.

        Args:
            features: List of examples as described in the class docstring.

        Returns:
            Dict mapping every ``article_{i}`` and ``mask_{i}`` to an ``int64``
            tensor of shape ``(len(features), doc_len, sen_len)``.

        Raises:
            AssertionError: If sentence lengths of an article and its mask differ.
        """
        batch = {}

        for article_number in range(1, self.num_articles + 1):
            article_key = f"article_{article_number}"
            mask_key = f"mask_{article_number}"

            sen_len_article = [len(sentence) for instance in features for sentence in instance[article_key]]
            sen_len_mask = [len(sentence) for instance in features for sentence in instance[mask_key]]

            assert sen_len_article == sen_len_mask, (
                f"There is a mismatch for article_{article_number} and mask_{article_number}."
                )

            # `default=0` keeps a batch made only of empty documents from
            # raising on `max()` of an empty sequence.
            sen_len = min(self.max_sentence_len, max(sen_len_article, default=0))
            doc_len = min(
                self.max_document_len,
                max((len(instance[mask_key]) for instance in features), default=0),
            )

            batch_sentences = []
            batch_masks = []
            for feature in features:
                sentences, masks = self.pad_sentence(sen_len, feature, article_number)
                self.pad_document(sentences, masks, doc_len, sen_len)

                batch_sentences.append(sentences)
                batch_masks.append(masks)

            # Every document is now exactly doc_len x sen_len, so the reshape is
            # a no-op except for degenerate (empty) batches, where it keeps the
            # result three-dimensional.
            target_shape = (len(features), doc_len, sen_len)
            # TODO: decide on dtype for tensor, torch.int/torch.long?
            batch[article_key] = torch.tensor(batch_sentences, dtype=torch.int64).reshape(target_shape)
            batch[mask_key] = torch.tensor(batch_masks, dtype=torch.int64).reshape(target_shape)
        return batch

    def pad_sentence(self, sen_len: int, feature: dict, i: int) -> tuple:
        """Return copies of article ``i``'s sentences and masks, each exactly ``sen_len`` long.

        Shorter sentences are right-padded (``[PAD]`` id for tokens, ``0`` for
        masks). Longer ones are cut at ``sen_len`` so that the batch stays
        rectangular; note that this drops whatever trailing ids exceed the cap.
        """
        pad_id = self.tokenizer.convert_tokens_to_ids("[PAD]")
        sentences = [
            (sentence + [pad_id] * (sen_len - len(sentence)))[:sen_len]
            for sentence in feature[f"article_{i}"]
        ]
        # TODO: check for attention_mask ID
        masks = [
            (sentence + [0] * (sen_len - len(sentence)))[:sen_len]
            for sentence in feature[f"mask_{i}"]
        ]
        return sentences, masks

    def pad_document(self, sentences: list, masks: list, doc_len: int, sen_len: "int | None" = None):
        """Pad or truncate ``sentences`` and ``masks`` in place to ``doc_len`` rows.

        Args:
            sentences: Already sentence-padded token id rows of one document.
            masks: Mask rows matching ``sentences``.
            doc_len: Target number of rows.
            sen_len: Width of the padding rows. When omitted it is taken from
                the first existing row (0 for an empty document).
        """
        if sen_len is None:
            sen_len = len(sentences[0]) if sentences else 0

        if len(sentences) < doc_len:
            pad_id = self.tokenizer.convert_tokens_to_ids("[PAD]")
            # Build a fresh list per padding row so rows never alias each other.
            sentences += [[pad_id] * sen_len for _ in range(doc_len - len(sentences))]
            masks += [[0] * sen_len for _ in range(doc_len - len(masks))]
        elif len(sentences) > doc_len:
            del sentences[doc_len:]
            del masks[doc_len:]


In [256]:
# Collator limits come from the shared hyperparameters so they match tokenization.
data_collator = CustomDataCollator(
    tokenizer=additional_tokenizer,
    max_document_len=args.max_document_len,
    max_sentence_len=args.max_sentence_len,
)

# Shuffled loader over the tokenized prototype subset; batches are built by the collator.
train_dataloader = DataLoader(
    new_dataset,
    batch_size=args.per_device_train_batch_size,
    collate_fn=data_collator,
    shuffle=True,
)

In [284]:
class LowerEncoder(BertPreTrainedModel):
    """Sentence-level encoder that maps each input sequence to a single vector.

    Wraps a ``BertModel`` and returns the last-layer hidden state of the first
    position of every sequence (the ``[CLS]`` token when the inputs were built
    by ``tokenize``).
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        # NOTE: the dropout layer is created here but is not applied in `forward`.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # TODO: post_init or init_weights?
        # self.post_init()
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Encode a batch of sequences and return the first position's hidden state."""
        hidden_states = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )['last_hidden_state']
        # Keep only position 0 of every sequence: (batch, seq, hidden) -> (batch, hidden).
        return hidden_states[:, 0, :]


class HiearchicalModel(nn.Module):
    """Two-level document encoder.

    A ``LowerEncoder`` turns every sentence into one vector; a transformer
    encoder then contextualises the sentence vectors of each document, and the
    representation at the first sentence position is returned as the document
    vector.
    """

    def __init__(self, args, **kwargs):
        """Build both encoders from ``args``.

        Reads ``pretrained_model_path``, ``upper_hidden_dimension``,
        ``upper_nhead`` and ``upper_num_layers`` from ``args``. Extra keyword
        arguments are accepted and ignored.
        """
        super().__init__()
        # TODO: from pretrained or config
        self.lower_model = LowerEncoder.from_pretrained(args.pretrained_model_path)
        # NOTE: nn.TransformerEncoder works on its own copies of the layer it is
        # given, so the weights held by this attribute are not used in `forward`.
        # The attribute is kept so that existing state_dict keys stay unchanged.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=args.upper_hidden_dimension,
                                                        nhead=args.upper_nhead,
                                                        batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer=self.encoder_layer,
                                                         num_layers=args.upper_num_layers)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Encode a batch of documents.

        Args:
            input_ids: Token ids of shape (batch_size, sentences, words).
            token_type_ids: Optional tensor shaped like ``input_ids``; forwarded
                per sentence to the lower encoder.
            attention_mask: Optional tensor shaped like ``input_ids`` with 1 for
                real tokens and 0 for padding.

        Returns:
            Tensor of shape (batch_size, hidden_size).
        """
        lower_encoded = []
        for sentence_idx in range(input_ids.size(1)):
            # Keyword arguments are essential here: the second positional
            # parameter of LowerEncoder.forward is `token_type_ids`, so a mask
            # passed positionally would be used as segment ids and the encoder
            # would run without any attention mask.
            lower_encoded.append(
                self.lower_model(
                    input_ids[:, sentence_idx],
                    token_type_ids=None if token_type_ids is None else token_type_ids[:, sentence_idx],
                    attention_mask=None if attention_mask is None else attention_mask[:, sentence_idx],
                )
            )

        # TODO: add document level [CLS]

        lower_output = torch.stack(lower_encoded, dim=1)  # (batch_size, sentences, hidden_size)

        sentence_padding_mask = None
        if attention_mask is not None:
            # A sentence whose mask is all zeros is document-level padding; hide
            # it from the upper encoder's attention.
            sentence_padding_mask = attention_mask.sum(dim=-1) == 0  # (batch_size, sentences)
            # Position 0 is the one read out below; keep it visible so a fully
            # padded document cannot produce an all-masked attention row.
            sentence_padding_mask[:, 0] = False

        upper_output = self.transformer_encoder(
            lower_output, src_key_padding_mask=sentence_padding_mask
        )  # (batch_size, sentences, hidden_size)
        upper_output = upper_output[:, 0, :]  # (batch_size, hidden_size)

        return upper_output


# Instantiate the hierarchical model from the shared hyperparameters; this loads the
# pretrained lower encoder, which emits the checkpoint-loading notice shown below.
upper_model = HiearchicalModel(args)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing LowerEncoder: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing LowerEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LowerEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [285]:
# Smoke test: push the first article of every batch through the model.
# Only the output of the last batch remains bound to `output` afterwards.
for batch in train_dataloader:
    output = upper_model(
        input_ids=batch["article_1"],
        attention_mask=batch["mask_1"],
    )

torch.Size([8, 64, 128])
torch.Size([8, 64, 128])
torch.Size([8, 64, 128])
torch.Size([8, 64, 128])
torch.Size([8, 64, 116])
torch.Size([8, 64, 128])
torch.Size([8, 64, 128])
torch.Size([8, 43, 128])
torch.Size([8, 46, 128])
torch.Size([8, 64, 128])
torch.Size([8, 64, 128])
torch.Size([8, 64, 128])
torch.Size([4, 64, 128])
