<a href="https://colab.research.google.com/github/faraway1nspace/AnathemTransformer/blob/main/dev/notebooks/dev_anathem_transformer_base_layers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Development Notebook: build and test base layers for Anathem Transformer (aka Silo'd Transformer)

### Notes
- the Google miniature BERT models use the same vocabulary as bert-large-uncased and the same attention-head size (64); the number of heads scales with the hidden size as H/64 (e.g. `A-8` for `H-512`)
- the Google miniature-models paper discusses the classification and distillation tasks and corpora, including:
    - *NLI*: Natural language inference involves classifying pairs of sentences (a premise and a hypothesis) as entailment, contradiction, or neutral. This task is representative of the scenario in which proxy data is non-trivial to gather (Gururangan et al., 2018). We chose MNLI (Williams et al., 2018) as our target dataset. Since strictly in-domain data is difficult to obtain, we supplement DT with two other sentence-pair datasets: SNLI (Bowman et al., 2015) and QQP (Chen et al., 2018).
    - *sentiment analysis*
- the best model on the MTEB leaderboard is e5-large (24 layers). Its model card pools by averaging the token embeddings rather than taking the CLS token. It is "instruction fine-tuned" only in a loose sense: it requires "query: " and "passage: " prefixes.
- distillation example: https://github.com/philschmid/knowledge-distillation-transformers-pytorch-sagemaker/blob/master/knowledge-distillation.ipynb
    - they set the temperature to 2, which results in a flatter probability distribution. I could make this dynamic, e.g. start at 0.5 and progress to 1
    - they set alpha to 0.5, which balances the label loss against the distillation loss
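A minimal sketch of the loss those two settings feed into, assuming the usual formulation (the linked notebook's exact code may differ). It computes the KL divergence between the temperature-softened teacher and student distributions, scales it by T^2 so gradient magnitudes stay comparable across temperatures, and mixes it with the hard-label cross-entropy by alpha. `distillation_loss` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits: torch.Tensor,
                      teacher_logits: torch.Tensor,
                      labels: torch.Tensor,
                      temperature: float = 2.0,
                      alpha: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: alpha * label-loss + (1 - alpha) * distil-loss."""
    # hard-label loss on the student's raw logits
    label_loss = F.cross_entropy(student_logits, labels)
    # soft-target loss: KL(teacher || student) on temperature-softened distributions
    distil_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2  # T^2 keeps gradient magnitudes comparable across temperatures
    return alpha * label_loss + (1.0 - alpha) * distil_loss


# a higher temperature flattens the distribution: the top class gets less mass
logits = torch.tensor([[4.0, 1.0, 0.0]])
p_t1 = F.softmax(logits / 1.0, dim=-1)
p_t2 = F.softmax(logits / 2.0, dim=-1)
```

When the student already matches the teacher, the KL term vanishes and only `alpha * label_loss` remains.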

#### Loss MLM - hf example:
- https://github.com/huggingface/transformers/blob/601ac5b1dc1438f00d09696588f2deb0f045ae3b/src/transformers/modeling_bert.py#L1001-L1004
    - note that `CrossEntropyLoss` is initialized with the default `ignore_index=-100`, so for the masked-token objective I can set the label of every non-masked token to -100 and those positions are excluded from the loss
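A toy check of that behaviour (shapes, label ids, and seed are made up). Positions labelled -100 contribute nothing, and the mean is taken over the masked positions only. This is the same flatten-then-CrossEntropy pattern the linked HF code uses.

```python
import torch
import torch.nn as nn

vocab_size = 10
torch.manual_seed(0)
# logits for 1 sequence of 4 tokens: shape (batch, seq_len, vocab)
logits = torch.randn(1, 4, vocab_size)
# labels: original token id at masked positions, -100 everywhere else
labels = torch.tensor([[-100, 3, -100, 7]])

loss_fct = nn.CrossEntropyLoss()  # ignore_index defaults to -100
mlm_loss = loss_fct(logits.view(-1, vocab_size), labels.view(-1))

# the same value, computed explicitly over the two masked positions only
manual = nn.functional.cross_entropy(logits[0, [1, 3]], torch.tensor([3, 7]))
```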


#### DataCollator for Masked MLM - hf example
- https://github.com/huggingface/transformers/blob/ee88ae59940fd4b2c8fc119373143d7a1175c651/src/transformers/data/data_collator.py#L607
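A simplified sketch of the masking scheme the linked `DataCollatorForLanguageModeling` implements. It selects about 15% of the non-special tokens as targets. Of those, 80% become `[MASK]`, 10% become a random token, and 10% stay unchanged. Every other label is set to -100. `mask_tokens` and its arguments are hypothetical and are not the collator's API.

```python
import torch


def mask_tokens(input_ids: torch.Tensor, special_tokens_mask: torch.Tensor,
                mask_token_id: int, vocab_size: int,
                mlm_probability: float = 0.15, generator=None):
    """Hypothetical, simplified re-implementation of HF's 80/10/10 MLM masking.

    input_ids: (batch, seq) token ids; special_tokens_mask: (batch, seq) bool, True = never mask.
    Returns (masked inputs, labels) where labels are -100 at non-target positions.
    """
    labels = input_ids.clone()
    inputs = input_ids.clone()

    # pick ~mlm_probability of the non-special tokens as prediction targets
    probability_matrix = torch.full(labels.shape, mlm_probability)
    probability_matrix.masked_fill_(special_tokens_mask, 0.0)
    targets = torch.bernoulli(probability_matrix, generator=generator).bool()
    labels[~targets] = -100  # ignored by CrossEntropyLoss

    # 80% of the targets -> [MASK]
    replace = torch.bernoulli(torch.full(labels.shape, 0.8), generator=generator).bool() & targets
    inputs[replace] = mask_token_id

    # half of the remaining 20% (i.e. 10%) -> a random token; the last 10% stay unchanged
    randomize = torch.bernoulli(torch.full(labels.shape, 0.5), generator=generator).bool() & targets & ~replace
    random_ids = torch.randint(vocab_size, labels.shape, generator=generator)
    inputs[randomize] = random_ids[randomize]
    return inputs, labels


g = torch.Generator().manual_seed(0)
ids = torch.randint(5, 100, (2, 16), generator=g)
special = torch.zeros_like(ids, dtype=torch.bool)
special[:, 0] = True   # e.g. [CLS]
special[:, -1] = True  # e.g. [SEP]
inputs, labels = mask_tokens(ids, special, mask_token_id=4, vocab_size=100, generator=g)
```

With the custom tokenizer in this notebook, the extra prepended CLS tokens would also need to be flagged in `special_tokens_mask` so they are never selected as targets.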


# Dataset specifics

### From the Google mini-architectures:
- with labels: Williams et al. 2018 (NLI task), citation: https://aclanthology.org/N18-1101/; available at https://huggingface.co/datasets/multi_nli
    - how should I process these: joined with [SEP], as separate sentence pairs, or both?
    - I could use sentence pairs for both teaching (distillation) and labels, I guess (why not)
    - I could also include concatenated text, strictly with labels (but what would be the point? Better sub-sectioning of the input data, not so much a sentence-vector thing)
- without labels, used for teaching: "Since strictly in-domain data is difficult to obtain, we supplement DT with two other sentence-pair datasets: SNLI (Bowman et al., 2015) and QQP (Chen et al., 2018)."
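On the "[SEP] or sentence pairs" question, a BERT-style tokenizer given two texts already produces the pair layout `[CLS] A [SEP] B [SEP]`. Its `token_type_ids` are 0 for the first segment and 1 for the second. The small sketch below spells the layout out on toy ids. `build_pair_inputs` is a hypothetical helper, and 101/102 are the `[CLS]`/`[SEP]` ids in the standard uncased BERT vocabulary.

```python
from typing import List, Optional, Tuple


def build_pair_inputs(ids_a: List[int], ids_b: Optional[List[int]],
                      cls_id: int, sep_id: int) -> Tuple[List[int], List[int]]:
    """Hypothetical helper mirroring BERT's single/pair layout:
    [CLS] A [SEP]            -> token types 0...0
    [CLS] A [SEP] B [SEP]    -> token types 0...0 1...1
    """
    input_ids = [cls_id] + ids_a + [sep_id]
    token_type_ids = [0] * len(input_ids)
    if ids_b is not None:
        input_ids += ids_b + [sep_id]
        token_type_ids += [1] * (len(ids_b) + 1)
    return input_ids, token_type_ids


# toy premise / hypothesis ids
pair_ids, pair_types = build_pair_inputs([10, 11, 12], [20, 21], cls_id=101, sep_id=102)
```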

### 1) MLM Tasks
- Pile (multi-domain, books, wiki, law, and more) - curate and remove twitter  
    - see urls at: https://github.com/EleutherAI/the-pile/blob/master/the_pile/datasets.py
    - https://the-eye.eu/public/AI/pile_preliminary_components/
- Supplements to pile:  
    - https://huggingface.co/datasets/him1411/EDGAR10-Q - numeric filings
    - eloukas/edgar-corpus - annual reports (but it is split into awkward sections)
    - LEDGAR .jsonl https://drive.switch.ch/index.php/s/j9S0GRMAbGZKa1A - this can be streamed too
    - Pile of Law - https://huggingface.co/datasets/pile-of-law/pile-of-law - but cannot be streamed
- JanosAudran/financial-reports-sec - SEC financial reports in small sentences
- RefinedWeb - a competitor to Pile, curated common-crawl - https://arxiv.org/abs/2306.01116
- CNN_dailymail? ag_news?

### A) Retrieval Tasks
In general, what loss would I use for the QA & retrieval tasks? Distillation is obvious, but what about
- SQUAD - has QA pairs - squad_v2
    - good for distillation
- ORCA - has GPT-like prompting QA pairs: https://huggingface.co/datasets/Open-Orca/OpenOrca/viewer/Open-Orca--OpenOrca/train?row=29
- Simple-Wiki https://huggingface.co/datasets/embedding-data/simple-wiki - has paraphrases
- embedding-data/coco_captions_quintets - multiple captions as paraphrases
- embedding-data/simple-wiki - pairs of paraphrases from wikipedia
- embedding-data/SPECTER - triplets of {anchor, pos, neg}: short, headline-like snippets in technical/statistical/science fields
- https://huggingface.co/embedding-data - has a lot of retrieval tasks
- LLukas22/scidocs - titles and abstracts
- LEDGAR - could possibly build triplets from examples sharing the same label
- Rahmaa/ElsevieR_ClEaN - possible relation between title and abstract
- embedding-data/WikiAnswers - 25 question paraphrases (maybe no answers)
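One common answer to the loss question for pair/triplet data, sketched here as an assumption rather than a decision, is an in-batch-negatives contrastive loss. It is similar in spirit to sentence-transformers' `MultipleNegativesRankingLoss` (whose default scale is 20). Each anchor should score its own positive above every other candidate in the batch. `in_batch_negatives_loss` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def in_batch_negatives_loss(anchors: torch.Tensor, candidates: torch.Tensor,
                            scale: float = 20.0) -> torch.Tensor:
    """anchors: (batch, dim); candidates: (n >= batch, dim) whose first `batch` rows are
    the positives (row i matches anchor i); every other row serves as a negative."""
    a = F.normalize(anchors, dim=-1)
    c = F.normalize(candidates, dim=-1)
    scores = scale * a @ c.T  # (batch, n) scaled cosine similarities
    targets = torch.arange(scores.size(0), device=scores.device)  # true pairs sit on the diagonal
    return F.cross_entropy(scores, targets)
```

For the triplet datasets (SPECTER, QQP_triplets), hard negatives can be appended as extra candidates, e.g. `in_batch_negatives_loss(anchors, torch.cat([positives, negatives]))`. The targets still point at the first `batch` columns.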

### B) QA Tasks
- squad_2
- WikiHow - used by S-BERT (questions and articles) - needs to be manually downloaded - https://github.com/mahnazkoupaee/WikiHow-Dataset/
- trivia_qa - 680 question, answer, evidence triplets, but the context strings are very long (Wikipedia-length) and the questions are mostly pop culture
- LLukas22/fiqa - financial QA, like conversations
- embedding-data/WikiAnswers - question-duplicates as paraphrases
- embedding-data/QQP_triplets - question-duplicates plus negatives (Quora)
- LLukas22/lfqa_preprocessed - question and answers 226k
- gbharti/finance-alpaca (like FIQA - finance Q&A)
- embedding-data/PAQ_pairs - wikipedia question & answers
- the_pile_stack_exchange - single texts, but can be split into question, answer
- cais/mmlu - multiple choice, but some of the answers are longer (need to filter)
- sciq - science questions - see question and support
- wiki_qa - wikipedia QA
- qasc - high-school questions - can combine the "facts" into a support
- pubmed_qa - science QA with answers
- EnglishDictionary - auto-convert entries into "What is the definition of X?" questions

### C) NER tasks
- tner/ontonotes5 - has > 12 entities and 59.9k
- tner/multinerd - 23 entities and a 157k test set; see also tner/wikineural, which has a 98.8k training set?


# Teacher Models

## Embeddings
MTEB leaderboard:

- instructor-xl / large - this does best, but it prepends domain-specific instructions (e.g. for science or Wikipedia text); it might be possible to do the same with the Pile dataset: https://huggingface.co/hkunlp/instructor-xl
- https://huggingface.co/intfloat/e5-large-v2 - winner otherwise
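If e5-large-v2 is used as the teacher, its model card pools embeddings by averaging the last hidden states over non-padding tokens and expects `query: ` / `passage: ` prefixes on the inputs. The sketch below re-implements that pooling in plain PyTorch for illustration. It is not copied from the card, and the sample query is made up.

```python
import torch


def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Mean of the token embeddings, ignoring padding positions.
    last_hidden_states: (batch, seq, hidden); attention_mask: (batch, seq) of 0/1."""
    mask = attention_mask.unsqueeze(-1).to(last_hidden_states.dtype)  # (batch, seq, 1)
    summed = (last_hidden_states * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid dividing by zero for all-padding rows
    return summed / counts


# e5 expects these prefixes on its inputs
queries = ["query: " + q for q in ["what is an indemnity clause?"]]
passages = ["passage: " + p for p in ["A standard indemnity clause is a waiver clause."]]
```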






#### Playing Around with novel architectures

In [None]:
%pip install torch transformers datasets zstandard rank_bm25 langdetect
#%pip install langdetect
from langdetect import detect




In [None]:
from transformers import AutoModel, AutoTokenizer
import torch
from torch.utils.data import DataLoader, Dataset  # the class is `Dataset`, not `DataSet`
from typing import List, Optional
from torch import nn
import torch.nn.functional as F
from torch.cuda import is_available
if is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

from transformers.models.bert.modeling_bert import BertEncoder
from transformers.activations import ACT2FN
import copy

model_string = 'google/bert_uncased_L-12_H-512_A-8' # alternative: 'distilroberta-base'
tokenizer = AutoTokenizer.from_pretrained(model_string)
basemod = AutoModel.from_pretrained(model_string)
basemod.to(device)


In [None]:
text = [
    "A standard indemnity clause is a waiver clause that states that one party won't hold the other liable for damages, losses, or costs associated with issues.",
    "It usually consists of two elements: a trigger event or circumstance and a payment obligation2. The trigger event or circumstance is the breach of the agreement, misconduct, or negligence of the indemnifying party or its affiliates"
]

In [None]:
from transformers import BertTokenizer


class CustomTokenizer:
    def __init__(self, model_string='google/bert_uncased_L-12_H-512_A-8', n_cls_prepend = 4, n_pad_to_multiple_of=4):
        self.base_tokenizer = AutoTokenizer.from_pretrained(model_string)
        self.n_cls_prepend = n_cls_prepend
        self.n_pad_to_multiple_of = n_pad_to_multiple_of
        for k in dir(self.base_tokenizer):
            if not (k[0]=='_' or k=='tokenize' or k=='encode' or k=='build_inputs_with_special_tokens' or k == 'batch_encode_plus'):
                setattr(self,k,getattr(self.base_tokenizer, k))

    def __call__(self, text, pad_to_multiple_of=None, add_special_tokens = True, return_tensors=None, *args, **kwargs):
        if pad_to_multiple_of is None:
            pad_to_multiple_of = self.n_pad_to_multiple_of

        # run through base tokenizer
        tokens = self.base_tokenizer(
            text,
            pad_to_multiple_of=(pad_to_multiple_of if not add_special_tokens else False),
            add_special_tokens=add_special_tokens,
            return_tensors=return_tensors if (not add_special_tokens) else None,
            *args,
            **kwargs
        )
        if add_special_tokens:
            tokens = self._prepend_extra_cls_tokens_because_of_maxpooling(tokens, return_tensors)

        return tokens

    def _num_pad_tokens(self, token_list):
        """Calculates how many PAD tokens to append to sequence to make a multiple of X"""
        return (self.n_pad_to_multiple_of - ((len(token_list)+(self.n_cls_prepend-1)) % self.n_pad_to_multiple_of)) % self.n_pad_to_multiple_of

    def _prepend_extra_cls_tokens_because_of_maxpooling(self, tokens, return_tensors=None):
        n_cls_prepend = self.n_cls_prepend
        # prepend (n-1) CLS tokens to the front of the token_ids (because of maxpooling)
        # also pad so that the total length is a multiple of n_cls_prepend
        #num_pad_tokens = (self.n_pad_to_multiple_of - ((len_tokens+(n_cls_prepend-1)) % self.n_pad_to_multiple_of)) % self.n_pad_to_multiple_of
        tokens['input_ids'] = [
            [self.cls_token_id]*(n_cls_prepend-1)+input_id + [self.pad_token_id]*self._num_pad_tokens(input_id)
            for input_id
            in tokens['input_ids']
        ]
        tokens['attention_mask'] = [
            [1]*(n_cls_prepend-1)+attnmask +[0]*self._num_pad_tokens(attnmask)
            for attnmask
            in tokens['attention_mask']
        ]
        if 'token_type_ids' in tokens.keys():
            tokens['token_type_ids'] = [
                [toktypeid[0]]*(n_cls_prepend-1)+toktypeid +[toktypeid[-1]]*self._num_pad_tokens(toktypeid)
                for toktypeid
                in tokens['token_type_ids']
            ]
        if return_tensors == 'pt':
            for k,v in tokens.items():
                tokens[k] = torch.LongTensor(v)
        return tokens

    def encode(self, text, pad_to_multiple_of=None, add_special_tokens = True, *args, **kwargs):
        if pad_to_multiple_of is None:
            pad_to_multiple_of = self.n_pad_to_multiple_of
        encoded = self.base_tokenizer.encode(text, add_special_tokens=add_special_tokens, *args, **kwargs)
        if add_special_tokens:
            # prepend (n_cls_prepend - 1) extra CLS tokens; the base tokenizer already added one
            encoded = [self.cls_token_id]*(self.n_cls_prepend-1) + encoded
        if bool(pad_to_multiple_of):
            num_pad_tokens = (pad_to_multiple_of - (len(encoded) % pad_to_multiple_of)) % pad_to_multiple_of
            encoded += [self.pad_token_id] * num_pad_tokens
        return encoded

    def tokenize(self, text, add_special_tokens=True, *args, **kwargs):
        toks = self.base_tokenizer.tokenize(text, add_special_tokens=add_special_tokens, *args, **kwargs)
        if add_special_tokens:
            toks = [self.cls_token] * (self.n_cls_prepend-1) + toks
        return toks

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ):
        out = self.base_tokenizer.build_inputs_with_special_tokens(token_ids_0, token_ids_1)
        # same number of extra CLS tokens as __call__ and encode
        return [self.cls_token_id]*(self.n_cls_prepend-1) + out

    def batch_encode_plus(self, batch_text_or_text_pairs, *args, **kwargs):
        batched_encoded = self.base_tokenizer.batch_encode_plus( batch_text_or_text_pairs, *args, **kwargs)
        #batched_encoded.update({'foo':'bar'})
        return batched_encoded



# Note: the vanilla LineByLineTextDataset just calls tokenizer.__call__ with `add_special_tokens=True` (padding to a multiple of N is optional),
# ... so I need to ensure that whatever base function the tokenizer pipeline calls ends up using MY new function.
# tokenizer.__call__ uses neither `encode` nor `tokenize`, otherwise my modifications would show up.
# It looks like `prepare_for_model` (and maybe `batch_prepare_for_model`) is what adds the special tokens,
# and `prepare_for_model` just calls `build_inputs_with_special_tokens`, so maybe intervene there?
#         if add_special_tokens:
#            sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
#            token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
# Overriding `build_inputs_with_special_tokens` didn't work either (probably because __call__ runs on
# self.base_tokenizer, which never sees methods defined on this wrapper, and because AutoTokenizer returns
# a "fast" tokenizer whose special tokens are added by its Rust post-processor).

# FOOFU:
# see how .pad works: https://github.com/huggingface/transformers/blob/c5454eba9eac00a3e7d0a46a3d25aacd43187f1e/src/transformers/tokenization_utils_base.py#L2887
# note the `self.model_input_names[0]` list for a tokenizer -> I should update this for my custom inputs
# ... and there is also a ._pad function

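The padding arithmetic in `_num_pad_tokens` can be checked in isolation. After prepending (n_cls_prepend - 1) extra CLS tokens and appending the PADs, the total length must be a multiple of 4, so that two rounds of 2x max-pooling divide the sequence evenly. `num_pad_tokens` below is a standalone copy of that formula for illustration, with the defaults set to the class defaults.

```python
def num_pad_tokens(n_tokens: int, n_cls_prepend: int = 4, multiple: int = 4) -> int:
    """Same formula as CustomTokenizer._num_pad_tokens: number of PADs so that
    (n_cls_prepend - 1) extra CLS tokens + n_tokens + PADs is a multiple of `multiple`."""
    return (multiple - ((n_tokens + (n_cls_prepend - 1)) % multiple)) % multiple


# 10 tokens (incl. the base [CLS]/[SEP]) + 3 extra CLS = 13 -> 3 PADs -> 16
print(num_pad_tokens(10))  # 3
```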

In [None]:
tokenizer2 = CustomTokenizer()
tokenizer2.pad_token_id

In [None]:
#toks = tokenizer2.encode(text[0], add_special_tokens=True)
#print(len(toks)) # works
#print(toks[:10])

tokens = tokenizer2(text, padding='longest', return_tensors=None) # doesn't work, obviously
#print(tokens)
print(len(tokens['input_ids'][0]))
print(len(tokens['attention_mask'][0]))

print(len(tokens['input_ids'][1]))
print(len(tokens['attention_mask'][1]))

tokens

#tokenizer2.batch_encode_plus(text, add_special_tokens=True) # doesn't work


In [None]:
dir(basemod)
# base embedding layers
layer_emb = copy.deepcopy(basemod._modules['embeddings'])


In [None]:
# base transformer layer (first encoder layer, full hidden size)
layer_basetransformer = copy.deepcopy(basemod._modules['encoder']._modules['layer']._modules['0'])

In [None]:
# text
text = [
    "A standard indemnity clause is a waiver clause that states that one party won't hold the other liable for damages, losses, or costs associated with legal issues1.",
    "It usually consists of two elements: a trigger event or circumstance and a payment obligation2. The trigger event or circumstance is the breach of the agreement, willful misconduct, or negligence of the indemnifying party or its affiliates"
]

import math

#padding_length = int(math.ceil(max_length / 4)) * 4
tokens = tokenizer(text, padding=True, return_tensors='pt', pad_to_multiple_of=4).to(device)  # keep inputs on the model's device
input_shape = tokens['input_ids'].size()

# change token padding to be multiple of 4
#ideal_length = int(math.ceil(input_shape[-1] / 4)) * 4 # should be a multiple of 4
#if input_shape[-1]!=ideal_length:
#  tokens = tokenizer(text,padding='max_length', max_length = ideal_length, return_tensors='pt')
#  input_shape = tokens['input_ids'].size()

token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
tokens['token_type_ids'] = token_type_ids
past_key_values_length =0

# need to extend attention mask
extended_attention_mask = basemod.get_extended_attention_mask(tokens['attention_mask'], input_shape)
tokens['extended_attention_mask'] = extended_attention_mask
print(tokens.keys())
print(tokens['input_ids'].shape)
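What `get_extended_attention_mask` produces for an encoder can be reproduced with a hypothetical helper. Recent `transformers` versions broadcast the 0/1 mask to `[batch, 1, 1, seq]` and turn it into an additive bias using `torch.finfo(dtype).min` (older versions used -10000.0). That broadcast is why the extended masks printed further down have shapes like `[2, 1, 1, 48]`.

```python
import torch


def to_additive_mask(attention_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """Hypothetical helper: (batch, seq) 0/1 mask -> (batch, 1, 1, seq) additive mask,
    broadcastable over heads and query positions."""
    extended = attention_mask[:, None, None, :].to(dtype)
    # real tokens get 0; padding gets a huge negative number, so softmax gives it ~0 weight
    return (1.0 - extended) * torch.finfo(dtype).min


mask = torch.tensor([[1, 1, 1, 0]])
additive = to_additive_mask(mask)
scores = torch.zeros(1, 2, 4, 4)  # (batch, heads, queries, keys)
probs = torch.softmax(scores + additive, dim=-1)
```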


In [None]:
silo_dimensions = {0:basemod.config.hidden_size,
                  1:basemod.config.hidden_size//2,
                  2:basemod.config.hidden_size//4,
                  }
reintegration_dim = silo_dimensions[1] + silo_dimensions[2]



In [None]:
embedding_output = layer_emb(
            input_ids=tokens['input_ids'],
            position_ids=tokens.get('position_ids',None),
            token_type_ids=tokens['token_type_ids'],
            inputs_embeds=None,
            past_key_values_length=past_key_values_length
)
print(embedding_output.shape)


In [None]:
# basemodel transformer outputs: *full bert model
out_l1 = layer_basetransformer(
    hidden_states = embedding_output,
    attention_mask = tokens['extended_attention_mask'],#tokens['attention_mask'],
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    #past_key_values=0,
    #use_cache=None,
    output_attentions=True,
    #output_hidden_states=True,
    #return_dict=True
)

hidden_states_l1 = out_l1[0]
self_attention_l1 = out_l1[1]


In [None]:
# Next Layer:
# Query -> max pool and reduce  hidden dimension // 2
# Key -> reduce hidden_dim // 2
# value -> reduce hidden_dim //2
#maxpool_l2 = nn.MaxPool2d((2,1), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)

maxpool_l2 = nn.Sequential(
    nn.Dropout(0.05),
    nn.MaxPool2d((2,1), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True),
)

maxpool_l2_attn = nn.MaxPool1d((2), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)

In [None]:
# reduce dimension of hidden states
hiddens_states_l1_reduced = maxpool_l2(hidden_states_l1)
print(hidden_states_l1.shape)
print(hiddens_states_l1_reduced.shape)

# reduce dimension of attention mask
attention_mask_l1_reduced = maxpool_l2_attn(tokens['attention_mask'].float())
print(attention_mask_l1_reduced.shape)

# extend the dimension of the reduced attention_mask
print(input_shape)
extended_attention_mask_l1_reduced = basemod.get_extended_attention_mask(attention_mask_l1_reduced, attention_mask_l1_reduced.shape)
print(tokens['extended_attention_mask'].shape)
print(extended_attention_mask_l1_reduced.shape)

torch.Size([2, 48, 768])
torch.Size([2, 24, 768])
torch.Size([2, 24])
torch.Size([2, 48])
torch.Size([2, 1, 1, 48])
torch.Size([2, 1, 1, 24])
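The pooling cells above rely on how PyTorch treats unbatched inputs. `MaxPool2d` reads a 3-D tensor as `(C, H, W)`, and `MaxPool1d` reads a 2-D tensor as `(C, L)`. The shape-only sketch below uses random data with arbitrary sizes. It confirms that `MaxPool2d((2, 1))` halves the sequence axis of a `[batch, seq, hidden]` tensor, and that pooling the 0/1 mask keeps a slot whenever either token of the pair is real.

```python
import torch
import torch.nn as nn

B, L, H = 2, 8, 6
hidden = torch.randn(B, L, H)
# kernel (2, 1) on (batch, seq, hidden): max over adjacent token pairs, hidden dim untouched
pool_seq = nn.MaxPool2d((2, 1), ceil_mode=True)
pooled = pool_seq(hidden)  # (2, 4, 6)

# the 0/1 mask pooled the same way along the sequence
mask = torch.tensor([[1., 1., 1., 1., 1., 0., 0., 0.],
                     [1., 1., 1., 0., 0., 0., 0., 0.]])
pooled_mask = nn.MaxPool1d(2, ceil_mode=True)(mask)  # (2, 4)
```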


In [None]:
# Try multi-headed attention with differently sized query and key/value inputs

In [None]:
import torch
import torch.nn as nn
import math
from typing import Optional, Tuple
import copy

class BertSelfAttnDimensionReduction(nn.Module):
    """Bert Attention Layer that uses a dimension-reduced version of the query, so to reduce the dimension of the outputs"""
    def __init__(
        self,
        config,
        hidden_size_input=768,
        hidden_size_query = None,
        position_embedding_type=None,
        dim_reduction = 2
    ):
        """Special type of Bert Self attention that reduces the dimension of the inputs by half"""
        super().__init__()
        if (config.hidden_size // dim_reduction) % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.dim_reduction = dim_reduction
        self.hidden_size_input = hidden_size_input
        self.hidden_size_reduced = hidden_size_input // dim_reduction
        if hidden_size_query is None:
            hidden_size_query = hidden_size_input
        self.hidden_size_query = hidden_size_query
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(self.hidden_size_reduced / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(self.hidden_size_query, self.all_head_size)
        self.key = nn.Linear(self.hidden_size_input, self.all_head_size)
        self.value = nn.Linear(self.hidden_size_input, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
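        mixed_query_layer = self.query(hidden_states)
        use_cache = past_key_value is not None  # read by the relative-position branch below

        # This module always acts as cross-attention: queries come from the (pooled) `hidden_states`,
        # keys and values come from the full-resolution `encoder_hidden_states`, and
        # `encoder_attention_mask` must mask that sequence's padding tokens.
        if encoder_hidden_states is None:
            raise ValueError("`encoder_hidden_states` is required: keys and values are computed from it")

        key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
        value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)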

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if encoder_attention_mask is not None:
            # add the precomputed (extended, additive) mask of the key/value sequence;
            # `attention_mask` (the pooled query sequence's mask) is not used here
            attention_scores = attention_scores + encoder_attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

bertlayer_l2_reduction = BertSelfAttnDimensionReduction(
    config=basemod.config,
    hidden_size_input=basemod.config.hidden_size,
    position_embedding_type=basemod.config.position_embedding_type,
    dim_reduction = 2
)

bertlayer_l3_reduction = BertSelfAttnDimensionReduction(
    config=basemod.config,
    hidden_size_input=basemod.config.hidden_size // 2,
    position_embedding_type=basemod.config.position_embedding_type,
    dim_reduction = 2
)

In [None]:
out_l2 = bertlayer_l2_reduction(
        hidden_states = hiddens_states_l1_reduced,
        attention_mask = extended_attention_mask_l1_reduced,
        head_mask=None,
        encoder_hidden_states = hidden_states_l1,
        encoder_attention_mask= tokens['extended_attention_mask'],
        past_key_value=None,
        output_attentions=False
    )
hidden_states_l2 = out_l2[0]
print(hidden_states_l2.shape)

torch.Size([2, 24, 384])
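Shape-wise, `BertSelfAttnDimensionReduction` is cross-attention with two differences from a standard layer: its query sequence is shorter (pooled), and its projections are narrower than the input. The stripped-down sketch below isolates that computation, using plain weight tensors instead of `nn.Linear`, with no bias, dropout, or relative positions. The function name and all sizes are hypothetical. It shows that the output length follows the query while the additive key mask hides padded tokens.

```python
import math
import torch


def reduced_cross_attention(q_in, kv_in, wq, wk, wv, num_heads, kv_additive_mask=None):
    """q_in: (B, Lq, Hq) pooled sequence; kv_in: (B, Lk, Hk) full-resolution sequence.
    wq/wk/wv project to a smaller width D (Hk // dim_reduction in the notebook)."""
    B, Lq, _ = q_in.shape
    Lk = kv_in.shape[1]
    D = wq.shape[1]
    d = D // num_heads  # per-head width

    def split(x, L):  # (B, L, D) -> (B, heads, L, d)
        return x.view(B, L, num_heads, d).permute(0, 2, 1, 3)

    q, k, v = split(q_in @ wq, Lq), split(kv_in @ wk, Lk), split(kv_in @ wv, Lk)
    scores = q @ k.transpose(-1, -2) / math.sqrt(d)  # (B, heads, Lq, Lk)
    if kv_additive_mask is not None:
        scores = scores + kv_additive_mask  # (B, 1, 1, Lk): masks padded keys
    ctx = torch.softmax(scores, dim=-1) @ v  # (B, heads, Lq, d)
    return ctx.permute(0, 2, 1, 3).reshape(B, Lq, D)


torch.manual_seed(0)
full = torch.randn(2, 8, 16)                         # (batch, seq, hidden)
pooled = full.view(2, 4, 2, 16).max(dim=2).values    # max-pool pairs of tokens
wq, wk, wv = (torch.randn(16, 8) for _ in range(3))  # project 16 -> 8 (dim_reduction = 2)
out = reduced_cross_attention(pooled, full, wq, wk, wv, num_heads=2)  # (2, 4, 8)
```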


In [None]:
# Next dimension reduction:
hiddens_states_l2_reduced = maxpool_l2(hidden_states_l2)
print(hidden_states_l2.shape)
print(hiddens_states_l2_reduced.shape)

# reduce dimension of attention mask
attention_mask_l2_reduced = maxpool_l2_attn(attention_mask_l1_reduced.float())
print(attention_mask_l2_reduced.shape)

# extend the dimension of the reduced attention_mask
extended_attention_mask_l2_reduced = basemod.get_extended_attention_mask(attention_mask_l2_reduced, attention_mask_l2_reduced.shape)
print(extended_attention_mask_l2_reduced.shape)

if True:
  out_l3 = bertlayer_l3_reduction(
        hidden_states = hiddens_states_l2_reduced, # input has been maxpooled
        attention_mask = extended_attention_mask_l2_reduced,
        head_mask=None,
        encoder_hidden_states = hidden_states_l2,
        encoder_attention_mask= extended_attention_mask_l1_reduced,
        past_key_value=None,
        output_attentions=False
    )
  hidden_states_l3 = out_l3[0]
  print(hidden_states_l3.shape)


# The outputs of bertlayer_l3_reduction can now run through three regular BertLayers

torch.Size([2, 24, 384])
torch.Size([2, 12, 384])
torch.Size([2, 12])
torch.Size([2, 1, 1, 12])
torch.Size([2, 12, 192])


In [None]:
# The outputs of bertlayer_l3_reduction can now run through three regular BertLayers (a BertEncoder)
# at a quarter of the base hidden size; note that intermediate_size is inherited unchanged from the base config

config_lowres_encoder = copy.deepcopy(basemod.config)
config_lowres_encoder.hidden_size = config_lowres_encoder.hidden_size//4
config_lowres_encoder.num_hidden_layers = 3
print(config_lowres_encoder)

encoder_lowres = BertEncoder(config_lowres_encoder)

RobertaConfig {
  "_name_or_path": "distilroberta-base",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 192,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 3,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.29.2",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}



In [None]:
out_encoder_lowres = encoder_lowres(
    hidden_states=hidden_states_l3,
    attention_mask=extended_attention_mask_l2_reduced,
    head_mask = None,
    return_dict=True,
)
hidden_states_lowres = out_encoder_lowres[0]
print(hidden_states_lowres.shape)

torch.Size([2, 12, 192])


In [None]:
## Upresolution Layer: up-resolution from dim-3 to dim-2 is as follows:
# hs_l3 -> upsampled sequence-length as hs-l2
# -> could have another attention-based mechanism that expands dimension of hs-l2

class InterpolateCombo(nn.Module):
    """there could also be an attentive way to do this"""
    def __init__(self, scale_factor=2, dropout=0.05, alpha=0.667):
        """Arguments:
        :param scaler_factor: float, multiple of up-scaling
        :param dropout: float, dropout proportion
        :param alpha: float, mixture weight between nearest-neighbor vs linear-interpolation
        """
        super(InterpolateCombo, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.dropout = nn.Dropout(dropout)
        self.a = alpha

    def forward(self, x):
        x_trans = x.transpose(-2,-1)
        z = self.a*self.interp(x_trans, mode='nearest',scale_factor=self.scale_factor) + (1-self.a)*self.interp(x_trans, mode='linear',scale_factor=self.scale_factor)
        z = self.dropout(z)
        return z.transpose(-2,-1)

#hidden_states_upscaled_3to2_nearest = nn.functional.interpolate(hidden_states_lowres.transpose(-2,-1), scale_factor=2, mode='nearest').transpose(-2,-1)
#hidden_states_upscaled_3to2_linear = nn.functional.interpolate(hidden_states_lowres.transpose(-2,-1), scale_factor=2, mode='linear').transpose(-2,-1)

upscaler_x2 = InterpolateCombo(scale_factor=2)

In [None]:
hidden_states_upscaled3to2 = upscaler_x2(hidden_states_lowres)
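To see what `InterpolateCombo` does to each hidden channel, here is a toy two-token sequence (made-up values) upsampled both ways and mixed with the default alpha. `nn.functional.interpolate` resizes the last axis, which is why the class first transposes `[batch, seq, hidden]` to `[batch, hidden, seq]`. `align_corners=False` is passed explicitly here; it matches PyTorch's default for `mode='linear'`.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[[0.0, 1.0]]])  # (batch=1, channels=1, seq=2)
nearest = F.interpolate(x, scale_factor=2, mode="nearest")                      # [0, 0, 1, 1]
linear = F.interpolate(x, scale_factor=2, mode="linear", align_corners=False)   # [0, 0.25, 0.75, 1]
alpha = 0.667
mixed = alpha * nearest + (1 - alpha) * linear  # mostly nearest, softened by the linear term
```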


In [None]:
## BertAttentiveIntegrator

class BertCrossAttention(nn.Module):
    def __init__(
        self,
        config,
        hidden_size,
        hidden_size_query,
        hidden_size_keyvalue=None,
        position_embedding_type=None
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.hidden_size_query = hidden_size_query
        if hidden_size_keyvalue is None:
            hidden_size_keyvalue = hidden_size
        self.hidden_size_keyvalue = hidden_size_keyvalue
        if self.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(self.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(self.hidden_size_query, self.all_head_size)
        self.key = nn.Linear(self.hidden_size_keyvalue, self.all_head_size)
        self.value = nn.Linear(self.hidden_size_keyvalue, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        query_hidden_states: Optional[torch.FloatTensor] = None,
        query_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(query_hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
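The following stripped-down sketch of the same computation is useful for checking shapes and masking in isolation. `TinyCrossAttention` is a hypothetical helper with no relative positions, dropout, head mask or cache. As in `BertCrossAttention`, the additive mask of shape (batch, 1, 1, key_len) is applied to the keys, so padded key positions receive (near-)zero probability. As in the class above, a query-side mask is not used.

```python
import math

import torch
from torch import nn


class TinyCrossAttention(nn.Module):
    """Hypothetical minimal multi-head cross-attention (queries and keys/values from different stacks)."""

    def __init__(self, hidden_size: int, hidden_size_query: int, num_heads: int):
        super().__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        # the query projection maps the (narrower) query width up to hidden_size
        self.query = nn.Linear(hidden_size_query, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, hidden) -> (batch, heads, seq, head_size)
        batch, seq_len, _ = x.shape
        return x.view(batch, seq_len, self.num_heads, self.head_size).permute(0, 2, 1, 3)

    def forward(self, hidden_states, query_hidden_states, additive_mask=None):
        q = self._split_heads(self.query(query_hidden_states))
        k = self._split_heads(self.key(hidden_states))
        v = self._split_heads(self.value(hidden_states))
        scores = q @ k.transpose(-1, -2) / math.sqrt(self.head_size)
        if additive_mask is not None:
            # (batch, 1, 1, key_len): 0.0 for real tokens, -10000.0 for padding
            scores = scores + additive_mask
        probs = scores.softmax(dim=-1)
        context = probs @ v  # (batch, heads, query_len, head_size)
        batch, _, query_len, _ = context.shape
        context = context.permute(0, 2, 1, 3).reshape(batch, query_len, self.num_heads * self.head_size)
        return context, probs


torch.manual_seed(0)
attn = TinyCrossAttention(hidden_size=384, hidden_size_query=192, num_heads=12)
keyvalue = torch.randn(2, 24, 384)  # e.g. mid-res hidden states
query = torch.randn(2, 24, 192)     # e.g. upscaled low-res hidden states
mask = torch.ones(2, 24)
mask[1, 20:] = 0.0                  # last 4 tokens of example 1 are padding
additive_mask = (1.0 - mask[:, None, None, :]) * -10000.0
context, probs = attn(keyvalue, query, additive_mask)
print(context.shape)  # torch.Size([2, 24, 384])
```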

In [None]:
bertlayer_l3_to_l2_crossattn = BertCrossAttention(
        config=basemod.config,
        hidden_size=silo_dimensions[1],
        hidden_size_query=silo_dimensions[2],
        position_embedding_type=None
    )

In [None]:
print(hidden_states_upscaled3to2.shape)
print(hidden_states_l2.shape)
print(attention_mask_l1_reduced.shape)
print(extended_attention_mask_l1_reduced.shape)

torch.Size([2, 24, 192])
torch.Size([2, 24, 384])
torch.Size([2, 24])
torch.Size([2, 1, 1, 24])


In [None]:
out_l2_postencode = bertlayer_l3_to_l2_crossattn(
    hidden_states = hidden_states_l2,
    attention_mask = extended_attention_mask_l1_reduced,
    head_mask = None,
    query_hidden_states = hidden_states_upscaled3to2,
    query_attention_mask = attention_mask_l1_reduced
)
hidden_states_l2_postencode = out_l2_postencode[0]
print(hidden_states_l2_postencode.shape)
assert hidden_states_l2_postencode.shape == hidden_states_l2.shape

torch.Size([2, 24, 384])


In [None]:
print(basemod.config.hidden_size)
print(basemod.config.intermediate_size)
print(basemod.config.intermediate_size/basemod.config.hidden_size)

768
3072
4.0


In [None]:
# how does bert actually work?
"""
input = x

BertLayer:
- BertAttention
--- x2 = BertSelfAttention(x)
--- x3 = BertSelfOutput(x2,x) -> lnorm(drop(f(x2)) + x)
- BertIntermediate (expansion: 4*hidden_size)
--- x4_ex = activation(f(x3)) # expansion (4*)
- BertOutput
--- x5 = lnorm(drop(f(x4_ex)) + x3 )


inputs = x_l2, x_l3_up

BertIntegrativeLayer:
- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
- x3 = lnorm(drop(f(x2)) + x_l2)
- x4_ex = activation( f(cat(x3, x_l3_up))  )
- x5 = lnorm(drop(f(x4_ex)) + x3)
"""


class BertIntegrativeLayer(nn.Module):
    """Vanilla Bert Layer, but integrates other hiddens states from a parallel transformers stack typically low-re"""
    def __init__(
            self,
            config,
            hidden_size,
            hidden_size_query,
            intermediate_size=None
        ):
        super().__init__()
        #self.chunk_size_feed_forward = config.chunk_size_feed_forward
        #self.seq_len_dim = 1
        self.cat = torch.cat
        if intermediate_size is None:
            intermediate_size = int(4*hidden_size)
        self.intermediate_size = intermediate_size
        self.hidden_size = hidden_size
        self.hidden_size_query = hidden_size_query
        self.hidden_size_concat = int(hidden_size + hidden_size_query)

        # cross attention between (low-res) query and hidden layers below
        self.attention = BertCrossAttention(
            config,
            hidden_size,
            hidden_size_query,
            position_embedding_type="absolute"
        )
        self.is_decoder = config.is_decoder
        #self.intermediate = BertIntermediate(config)
        #self.output = BertOutput(config)
        #- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
        #- x3 = lnorm(drop(f(x2)) + x_l2)
        #- x4_ex = activation( f(cat(x3, x_l3_up))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)

        # corresponds to BertAttention SelfOutput
        self.output_attn = nn.Linear(self.hidden_size, self.hidden_size)
        self.lnorm_attn = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_attn = nn.Dropout(config.hidden_dropout_prob)

        # corresponds to BertIntermediate
        self.intermediate = nn.Linear(self.hidden_size_concat, self.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        # corresponds to BertOutput
        self.output_intm = nn.Linear(self.intermediate_size, self.hidden_size)
        self.lnorm_intm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_intm = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        query_hidden_states: Optional[torch.FloatTensor] = None,
        query_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        # cross attn between hiddens states and (low-res) query vector
        cross_attn_outputs = self.attention(
            hidden_states = hidden_states,
            attention_mask = attention_mask,
            head_mask = head_mask,
            query_hidden_states = query_hidden_states,
            query_attention_mask = query_attention_mask
        )
        cross_hidden_states = cross_attn_outputs[0]

        # first Add+Norm skip connection (BertSelfOutput)
        cross_hidden_states = self.dropout_attn(self.output_attn(cross_hidden_states))
        hidden_states = self.lnorm_attn(cross_hidden_states + hidden_states)

        # intermediate expansion
        intermediate_states = self.intermediate_act_fn(self.intermediate(
            self.cat((hidden_states, query_hidden_states),axis=2)
        ))
        assert intermediate_states.shape[0]==hidden_states.shape[0]
        assert intermediate_states.shape[1]==hidden_states.shape[1]

        # BertOutput
        intermediate_states = self.dropout_intm(self.output_intm(intermediate_states))
        out_states = self.lnorm_intm(intermediate_states + hidden_states)

        #- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
        #- x3 = lnorm(drop(f(x2)) + x_l2)
        #- x4_ex = activation( f(cat(x3, x_l3_up))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)
        return out_states
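The pseudocode at the top of this cell can be sketched with stock PyTorch modules to make the data flow explicit. This is an illustration under assumptions, not the notebook's layer. `TinyIntegrativeBlock` is hypothetical. It swaps in `nn.MultiheadAttention` for `BertCrossAttention`, so the low-res query is first projected to the hidden width, and it assumes GELU as the activation.

```python
import torch
from torch import nn


class TinyIntegrativeBlock(nn.Module):
    """Hypothetical sketch of the BertIntegrativeLayer pseudocode (x2 ... x5)."""

    def __init__(self, hidden_size, hidden_size_query, intermediate_size, num_heads=4, dropout=0.1):
        super().__init__()
        # nn.MultiheadAttention outputs the query width, so lift the query to hidden_size first
        self.query_proj = nn.Linear(hidden_size_query, hidden_size)
        self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=dropout, batch_first=True)
        self.f_attn = nn.Linear(hidden_size, hidden_size)
        self.lnorm_attn = nn.LayerNorm(hidden_size)
        # the expansion sees both the updated states and the raw low-res query
        self.f_expand = nn.Linear(hidden_size + hidden_size_query, intermediate_size)
        self.f_contract = nn.Linear(intermediate_size, hidden_size)
        self.lnorm_out = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, x_l2, x_l3_up, key_padding_mask=None):
        # x2 = CrossAttention(k, v = x_l2, q = x_l3_up); key_padding_mask is True at padded positions
        x2, _ = self.cross_attn(self.query_proj(x_l3_up), x_l2, x_l2, key_padding_mask=key_padding_mask)
        # x3 = lnorm(drop(f(x2)) + x_l2)
        x3 = self.lnorm_attn(self.dropout(self.f_attn(x2)) + x_l2)
        # x4_ex = activation(f(cat(x3, x_l3_up)))
        x4_ex = self.act(self.f_expand(torch.cat((x3, x_l3_up), dim=-1)))
        # x5 = lnorm(drop(f(x4_ex)) + x3)
        return self.lnorm_out(self.dropout(self.f_contract(x4_ex)) + x3)


block = TinyIntegrativeBlock(hidden_size=384, hidden_size_query=192, intermediate_size=1536).eval()
x_l2 = torch.randn(2, 24, 384)
x_l3_up = torch.randn(2, 24, 192)
padding = torch.zeros(2, 24, dtype=torch.bool)
padding[1, 20:] = True
out = block(x_l2, x_l3_up, key_padding_mask=padding)
print(out.shape)  # torch.Size([2, 24, 384])
```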


In [None]:

# from low-res to mid-res
bert_integrative_layer_midres = BertIntegrativeLayer(
    basemod.config,
    hidden_size=silo_dimensions[1],
    hidden_size_query=silo_dimensions[2],
    intermediate_size=silo_dimensions[1]*4,
)

# from mid-res to high-res
bert_integrative_layer_hires = BertIntegrativeLayer(
    basemod.config,
    hidden_size=silo_dimensions[0],
    hidden_size_query=reintegration_dim,
    intermediate_size=silo_dimensions[0]*4,
)

In [None]:
hidden_states_midres = bert_integrative_layer_midres(
    hidden_states = hidden_states_l2,
    attention_mask = extended_attention_mask_l1_reduced,
    head_mask = None,
    query_hidden_states = hidden_states_upscaled3to2,
    query_attention_mask = attention_mask_l1_reduced
)
print(hidden_states_midres.shape)
assert hidden_states_midres.shape == hidden_states_l2.shape

torch.Size([2, 24, 384])


In [None]:
# upscale the l2 and l3 to the full dimension
upscaler_x4 = InterpolateCombo(scale_factor=4)
hidden_states_upscaled3to1 = upscaler_x4(hidden_states_lowres)
hidden_states_upscaled2to1 = upscaler_x2(hidden_states_midres)

hidden_states_upscaled = torch.cat(
    (hidden_states_upscaled2to1, hidden_states_upscaled3to1),
    axis=2)

print(hidden_states_upscaled.shape)

torch.Size([2, 48, 576])


In [None]:
# final layer to bring it up to full dimension
hidden_states_hires = bert_integrative_layer_hires(
    hidden_states = hidden_states_l1,
    attention_mask = extended_attention_mask,
    head_mask = None,
    query_hidden_states = hidden_states_upscaled,
    query_attention_mask = extended_attention_mask
)
print(hidden_states_hires.shape)
assert hidden_states_hires.shape == hidden_states_l1.shape

torch.Size([2, 48, 768])


In [None]:
hidden_states_hires.shape

torch.Size([2, 48, 768])

In [None]:
attention_mask_l1_reduced.shape

torch.Size([2, 24])

### The Reduce and Integrate layer:
- this is like a Transformer block, but it:
    - does dimension reduction along the sequence and embedding dimensions
    - includes a skip connection from the previous hidden states of the same dimension
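The sequence reduction itself happens outside this layer: adjacent positions are max-pooled. The `maxpool` and `maxpool_attn` modules built in `initialize_baselayers` do this, and `maxpool` also applies dropout to the hidden states first. The following sketch shows what those pooling modules do to a (batch, seq_len, dim) tensor and to its 1/0 padding mask. The tensor sizes are illustrative.

```python
import torch
from torch import nn

# kernel (2, 1): max over pairs of adjacent sequence positions, separately for every feature;
# MaxPool2d reads a 3D (batch, seq_len, dim) input as an unbatched (C, H, W) tensor
seq_pool = nn.MaxPool2d((2, 1), ceil_mode=True)
# the padding mask is pooled the same way: a reduced position is real if either source position was
mask_pool = nn.MaxPool1d(2, ceil_mode=True)

hidden = torch.randn(2, 48, 768)
mask = torch.ones(2, 48)
mask[1, 30:] = 0.0  # example 1 has 18 padding tokens

hidden_reduced = seq_pool(hidden)
mask_reduced = mask_pool(mask)
print(hidden_reduced.shape, mask_reduced.shape)  # torch.Size([2, 24, 768]) torch.Size([2, 24])

# ceil_mode=True keeps a trailing odd position instead of dropping it
print(seq_pool(torch.randn(2, 25, 8)).shape)  # torch.Size([2, 13, 8])
```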

In [None]:



# this is the layer that just does cross-attention between a seq-reduced query and full-size value and key


"""
input = x

BertLayer:
- BertAttention
--- x2 = BertSelfAttention(x)
--- x3 = BertSelfOutput(x2,x) -> lnorm(drop(f(x2)) + x)
- BertIntermediate (expansion: 4*hidden_size)
--- x4_ex = activation(f(x3)) # expansion (4*)
- BertOutput
--- x5 = lnorm(drop(f(x4_ex)) + x3 )


inputs = x_l2, x_l3_up

BertIntegrativeLayer:
- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
- x3 = lnorm(drop(f(x2)) + x_l2)
- x4_ex = activation( f(cat(x3, x_l3_up))  )
- x5 = lnorm(drop(f(x4_ex)) + x3)


BertReduceAddIntegrativeLayer
inputs = x_l1, x_l1_reduced, x_l2_prev
- x2 = BertCrossAttention(k,v=x_l1, q= cat(x_l1_reduced, x_l2_prev) ) -notice three inputs
- x3 = lnorm(drop(f(x2)) + x_l2_prev)
- x4_ex = activation( f(cat(x3, x_l1_reduced))  )
- x5 = lnorm(drop(f(x4_ex)) + x3)
"""


class BertReduceAddIntegrativeLayer(nn.Module):
    """Bert Layer that does dimenion reduction along embedding-dimenion and integrations a skip connection"""
    def __init__(
            self,
            config,
            hidden_size,
            hidden_size_input=None,
            hidden_size_query=None,
            intermediate_size=None,
            dim_reduction=2,
            do_concat_hidden_and_query = True
        ):
        super().__init__()
        #self.chunk_size_feed_forward = config.chunk_size_feed_forward
        #self.seq_len_dim = 1
        self.cat = torch.cat
        self.do_concat_hidden_and_query = do_concat_hidden_and_query
        assert bool(do_concat_hidden_and_query), 'not implemented: concatenation of query and hidden-states must happen'
        self.hidden_size = hidden_size
        if dim_reduction is None:
            dim_reduction = 2
        self.dim_reduction = dim_reduction
        if intermediate_size is None:
            intermediate_size = int(4*hidden_size)
        self.intermediate_size = intermediate_size
        if hidden_size_input is None:
            hidden_size_input = hidden_size
        self.hidden_size_input = hidden_size_input
        if hidden_size_query is None:
            hidden_size_query = hidden_size_input
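        # the attention query is cat(query_hidden_states, hidden_states) when do_concat_hidden_and_query
        self.hidden_size_query = hidden_size_query + do_concat_hidden_and_query*hidden_size
        # the intermediate layer consumes cat(hidden_states, query_hidden_states), see forward()
        self.hidden_size_concat = int(hidden_size + hidden_size_query)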

        # cross attention between (low-res) query and hidden layers below
        self.attention = BertSelfAttnDimensionReduction(
            config,
            hidden_size_input=self.hidden_size_input,
            hidden_size_query = self.hidden_size_query,
            position_embedding_type="absolute",
            dim_reduction = self.dim_reduction
        )
        self.is_decoder = config.is_decoder
        #inputs = x_l1, x_l1_reduced, x_l2_prev
        #- x2 = BertCrossAttention(k,v=x_l1, q= cat(x_l1_reduced, x_l2_prev) ) -notice three inputs
        #- x3 = lnorm(drop(f(x2)) + x_l2_prev)
        #- x4_ex = activation( f(cat(x3, x_l1_reduced))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)

        # corresponds to BertAttention SelfOutput
        self.output_attn = nn.Linear(self.hidden_size, self.hidden_size)
        self.lnorm_attn = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_attn = nn.Dropout(config.hidden_dropout_prob)

        # corresponds to BertIntermediate
        self.intermediate = nn.Linear(self.hidden_size_concat, self.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        # corresponds to BertOutput
        self.output_intm = nn.Linear(self.intermediate_size, self.hidden_size)
        self.lnorm_intm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_intm = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        inputs: torch.Tensor, # higher-resolution inputs for key and values (long sequence dimension)
        hidden_states: torch.Tensor, # previous hidden-states for skip connection (short sequence-dim, low-res)
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        query_hidden_states: torch.FloatTensor = None, # sequence-reduced version of the inputs, used for the query (short sequence-dim)
        query_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        if self.do_concat_hidden_and_query:
            query_hidden_states_plus = torch.cat((query_hidden_states, hidden_states),axis=2)
        # cross attn between (low-res) query vector and (high-res) key-values
        cross_attn_outputs = self.attention(
            query_hidden_states_plus, # query (short seq-dim, high-res)
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states = inputs, # for key/value (longer sequence dimension, high-res)
            past_key_value=past_key_value,
            output_attentions=output_attentions,
        )
        cross_hidden_states = cross_attn_outputs[0]

        # first Add+Norm skip connection (BertSelfOutput)
        cross_hidden_states = self.dropout_attn(self.output_attn(cross_hidden_states))
        hidden_states = self.lnorm_attn(cross_hidden_states + hidden_states)

        # intermediate expansion
        intermediate_states = self.intermediate_act_fn(self.intermediate(
            self.cat((hidden_states, query_hidden_states),axis=2)
        ))
        assert intermediate_states.shape[0]==hidden_states.shape[0]
        assert intermediate_states.shape[1]==hidden_states.shape[1]

        # BertOutput
        intermediate_states = self.dropout_intm(self.output_intm(intermediate_states))
        out_states = self.lnorm_intm(intermediate_states + hidden_states)

        #inputs = x_l1, x_l1_reduced, x_l2_prev
        #- x2 = BertCrossAttention(k,v=x_l1, q= cat(x_l1_reduced, x_l2_prev) ) -notice three inputs
        #- x3 = lnorm(drop(f(x2)) + x_l2_prev)
        #- x4_ex = activation( f(cat(x3, x_l1_reduced))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)
        return out_states
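The feature widths implied by the `BertReduceAddIntegrativeLayer` pseudocode can be tallied by hand. `reduce_add_widths` is a hypothetical helper that does only this bookkeeping. With `do_concat_hidden_and_query=True`, both the attention query and the input to the intermediate projection are `hidden_size + hidden_size_query` wide. In the notebook's calls `hidden_size_input == hidden_size_query`, so either size gives the same number there.

```python
def reduce_add_widths(hidden_size, hidden_size_input, hidden_size_query, concat_hidden_and_query=True):
    """Hypothetical bookkeeping of feature widths in the BertReduceAddIntegrativeLayer pseudocode."""
    return {
        # q = cat(x_l1_reduced, x_l2_prev)
        "attn_query": hidden_size_query + (hidden_size if concat_hidden_and_query else 0),
        # k, v = x_l1 (full sequence length, higher resolution)
        "attn_keyvalue": hidden_size_input,
        # x3 = lnorm(drop(f(x2)) + x_l2_prev) keeps the width of the skip connection
        "x3": hidden_size,
        # x4_ex = activation(f(cat(x3, x_l1_reduced)))
        "intermediate_in": hidden_size + hidden_size_query,
    }


print(reduce_add_widths(384, 768, 768))  # high-res -> mid-res
# {'attn_query': 1152, 'attn_keyvalue': 768, 'x3': 384, 'intermediate_in': 1152}
print(reduce_add_widths(192, 384, 384))  # mid-res -> low-res
# {'attn_query': 576, 'attn_keyvalue': 384, 'x3': 192, 'intermediate_in': 576}
```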


In [None]:
# initialize the mid-resolution BertReduceAndIntegrate layer
bert_reduce_add_integrate_midres = BertReduceAddIntegrativeLayer(
    config,
    hidden_size = silo_dimensions[1], # size of mid-res
    hidden_size_input=silo_dimensions[0],
    hidden_size_query=silo_dimensions[0],
    intermediate_size=silo_dimensions[1]*3,
    dim_reduction=2,
    do_concat_hidden_and_query = True
)

bert_reduce_add_integrate_lowres = BertReduceAddIntegrativeLayer(
    config,
    hidden_size = silo_dimensions[2], # size of low-res
    hidden_size_input=silo_dimensions[1],
    hidden_size_query=silo_dimensions[1],
    intermediate_size=silo_dimensions[2]*3,
    dim_reduction=2,
    do_concat_hidden_and_query = True
)

In [None]:
# Reduce sequence-dim from l1->l2, and from high-res->mid-res
hidden_states_hires_reduced = maxpool_l2(hidden_states_hires)
assert hidden_states_hires_reduced.shape[1] == hidden_states_midres.shape[1] # reduced-seq-dim should be same as mid-res hidden-states
print(hidden_states_midres.shape)
hidden_states_midres = bert_reduce_add_integrate_midres(
    inputs = hidden_states_hires, # high-res outputs of the previous layer (keys, values)
    hidden_states = hidden_states_midres, # previous hidden-states for skip connection (short sequence-dim, mid-res)
    attention_mask = extended_attention_mask_l1_reduced,
    head_mask=None,
    query_hidden_states = hidden_states_hires_reduced # reduced version of high-res inputs (reduced along sequence dimension)
)
print(hidden_states_midres.shape)

torch.Size([2, 24, 384])
torch.Size([2, 24, 384])


In [None]:
# Reduce sequence-dim from l2->l3, and from mid-res->low-res
hidden_states_midres_reduced = maxpool_l2(hidden_states_midres)
assert hidden_states_midres_reduced.shape[1] == hidden_states_lowres.shape[1] # reduced-seq-dim should be same as low-res hidden-states
print(hidden_states_midres_reduced.shape)

if True:
  print(hidden_states_lowres.shape)
  hidden_states_lowres = bert_reduce_add_integrate_lowres(
      inputs = hidden_states_midres, # mid-res outputs of the previous layer (keys, values)
      hidden_states = hidden_states_lowres, # previous hidden-states for skip connection (short sequence-dim, low-res)
      attention_mask = extended_attention_mask_l2_reduced,
      head_mask=None,
      query_hidden_states = hidden_states_midres_reduced # reduced version of mid-res inputs (reduced along sequence dimension)
  )
  print(hidden_states_lowres.shape)

torch.Size([2, 12, 384])
torch.Size([2, 12, 192])
torch.Size([2, 12, 192])


In [None]:
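try:
    # in transformers, get_extended_attention_mask is a method of ModuleUtilsMixin rather than a
    # module-level function, so this import is expected to fail and the fallback below is defined
    from transformers.modeling_utils import get_extended_attention_mask
except ImportError:
    def get_extended_attention_mask(self, attention_mask: torch.Tensor, input_shape: Tuple[int], device: torch.device) -> torch.Tensor: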
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
                            ),
                            causal_mask,
                        ],
                        axis=-1,
                    )

                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask
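In `transformers`, `get_extended_attention_mask` is a method of `ModuleUtilsMixin`, which `PreTrainedModel` inherits. With a loaded model, it can therefore also be called on the model itself (in recent releases as `basemod.get_extended_attention_mask(attention_mask, input_shape)`). Recent releases also fill masked positions with the dtype's minimum value instead of -10000.0. The conversion is small enough to sketch without `self`. `make_additive_mask` is a hypothetical stand-alone helper that mirrors the encoder and decoder branches above, without the `past_key_values` prefix handling.

```python
import torch


def make_additive_mask(attention_mask: torch.Tensor, causal: bool = False, masked_value: float = -10000.0) -> torch.Tensor:
    """Hypothetical helper: 1/0 padding mask (batch, seq_len) -> additive mask for attention scores."""
    mask = attention_mask.to(torch.float32)
    seq_len = mask.shape[1]
    if causal:
        # decoder: query position i may only attend to key positions j <= i, and never to padding
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=mask.device))
        extended = causal_mask[None, None, :, :] * mask[:, None, None, :]  # (batch, 1, seq, seq)
    else:
        # encoder: broadcast the key mask over heads and query positions
        extended = mask[:, None, None, :]  # (batch, 1, 1, seq)
    # 0.0 where attention is allowed, masked_value where it is not
    return (1.0 - extended) * masked_value


padding_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
print(make_additive_mask(padding_mask).shape)               # torch.Size([2, 1, 1, 3])
print(make_additive_mask(padding_mask, causal=True).shape)  # torch.Size([2, 1, 3, 3])
```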

### Base-Layer nn.Module
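The `make_config` function in the next cell derives every silo width from the base model's `hidden_size` and the two sequence ratios. The following is the same arithmetic as a stand-alone sketch. `derived_sizes` is a hypothetical helper. It uses `round()` where `make_config` uses `int()`, so that a ratio quotient landing just below an integer in floating point is not truncated. For a 768-wide base such as `distilroberta-base`, it reproduces the widths seen in the cells above: 384 and 192, plus a 576-wide concatenated query.

```python
def derived_sizes(hidden_size=768, scale_ratio2=0.5, scale_ratio3=0.25, mult2=4.0, mult3=4.0):
    """Hypothetical re-statement of the size arithmetic in make_config."""
    hidden_l2 = int(hidden_size * scale_ratio2)
    hidden_l3 = int(hidden_size * scale_ratio3)
    return {
        "scale_factor2": round(1 / scale_ratio2),             # high-res -> mid-res sequence factor
        "scale_factor3": round(scale_ratio2 / scale_ratio3),  # mid-res -> low-res sequence factor
        "hidden_size_l2": hidden_l2,
        "hidden_size_l3": hidden_l3,
        "intermediate_size_l1": int(hidden_l2 * mult2),
        "intermediate_size_l2": int(hidden_l3 * mult3),
        "query_size1": hidden_l2 + hidden_l3,  # width of cat(upscaled mid-res, upscaled low-res)
        "query_size2": hidden_l3,
    }


print(derived_sizes())
```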

In [None]:
import copy
from transformers import AutoModel, AutoTokenizer, AutoConfig
import torch
from torch import nn
from torch import Tensor

from transformers.models.bert.modeling_bert import BertEncoder
from transformers.activations import ACT2FN
from typing import List, Optional, Tuple, Union

def make_config(
    modelstring = "distilroberta-base",
    num_transformer_stacks = 2, # number of transformer stacks
    scale_ratio2 = 0.5, # reduce sequence-length by X, from high-res to mid-res
    scale_ratio3 = 0.25, # reduce sequence-length by Y, from high-res to low-res
    multipler_intermediate2 = 4.0, # intermediate size is a multiple of hidden size
    multipler_intermediate3 = 4.0, # intermediate size is a multiple of hidden size
    num_layers_l2 = 1, # mid-res encoder
    num_layers_l3 = 3, # low-res encoder
    dropout_scaling = 0.05, # dropout when downscaling from one sequence length to the next
    use_cheap_integrator_for_stacks = [],
    do_mlm=False,# whether to output MLM token predictions
    do_cls=False,# whether to output a pooled sentence-vector for sequence classification
):
    #if True:
    #modelstring = "distilroberta-base"
    #scale_ratio2 = 0.5
    #scale_ratio3 = 0.25
    #scale_intermediate2 = 4
    #scale_intermediate3 = 4
    base_config = AutoConfig.from_pretrained(modelstring)
    config_l2 = copy.deepcopy(base_config)
    config_l3 = copy.deepcopy(base_config)
    setattr(base_config,'model_string', modelstring)
    setattr(base_config,'num_transformer_stacks',num_transformer_stacks)
    setattr(base_config,'num_layers_l2', num_layers_l2)
    setattr(base_config,'num_layers_l3', num_layers_l3)
    setattr(base_config,'scale_ratio2', scale_ratio2)
    setattr(base_config,'scale_ratio3', scale_ratio3)
    setattr(base_config,'scale_factor2', int(1/base_config.scale_ratio2))
    setattr(base_config,'scale_factor3', int(1/base_config.scale_ratio3*base_config.scale_ratio2))
    setattr(base_config,"hidden_size_l2", int(base_config.hidden_size * scale_ratio2))
    setattr(base_config,"hidden_size_l3", int(base_config.hidden_size * scale_ratio3))
    setattr(base_config,"intermediate_size_l1", int(base_config.hidden_size_l2*multipler_intermediate2))
    setattr(base_config,"intermediate_size_l2", int(base_config.hidden_size_l3*multipler_intermediate3))
    setattr(base_config,"query_size1", base_config.hidden_size_l2 + base_config.hidden_size_l3)
    setattr(base_config,"query_size2", base_config.hidden_size_l3)
    setattr(base_config,"dropout_scaling", dropout_scaling)
    setattr(base_config,"use_cheap_integrator_for_stacks", use_cheap_integrator_for_stacks)
    setattr(base_config, "do_mlm", do_mlm)
    setattr(base_config, "do_cls", do_cls)

    # make the configuration for the l2 mid-res encoder
    config_l2.hidden_size = base_config.hidden_size_l2
    config_l2.num_hidden_layers = num_layers_l2
    setattr(base_config, 'config_l2', config_l2)

    # make the configuration for the l3 encoder
    config_l3.hidden_size = base_config.hidden_size_l3
    config_l3.num_hidden_layers = num_layers_l3
    setattr(base_config, 'config_l3', config_l3)
    return base_config


def initialize_baselayers(config, basemod = None, tokenizer=None, stack_id=0):
    """Initializes the embeddings and first stack of layers for the Anathem transformers"""
    # initialize the basemodel
    if basemod is None:
        basemod = AutoModel.from_pretrained(config.model_string)
    if tokenizer is None:
        # download pretrained tokenizer
        tokenizer = AutoTokenizer.from_pretrained(config.model_string)

    device = basemod.device
    setattr(config, 'device', device)

    # get basemodel's embeddings
    layer_embedding = copy.deepcopy(basemod._modules['embeddings'])

    # get basemodel's first transformer block
    layer_basetransformer = copy.deepcopy(basemod._modules['encoder']._modules['layer']._modules['0'])

    # initialize the maxpooling downsamplers
    maxpool = nn.Sequential(
        nn.Dropout(config.dropout_scaling),
        nn.MaxPool2d((2,1), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)
    )
    # pooling the attention has no dropout
    maxpool_attn = nn.MaxPool1d((2), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)

    # initialize downsampling attention layers
    bert_reducer_l2 = BertSelfAttnDimensionReduction(
        config=config,
        hidden_size_input=config.hidden_size,
        position_embedding_type=config.position_embedding_type,
        dim_reduction = config.scale_factor2
    )
    # 1/4 hidden size
    bert_reducer_l3 = BertSelfAttnDimensionReduction(
        config=config,
        hidden_size_input=config.hidden_size_l2,
        position_embedding_type=config.position_embedding_type,
        dim_reduction = config.scale_factor3
    )

    # initialize the mid-resolution BertEncoder
    bert_encoder_midres = BertEncoder(config.config_l2)
    # initialize the low-resolution BertEncoder
    bert_encoder_lowres = BertEncoder(config.config_l3)

    # initialize the upscalers
    upscaler_x2 = InterpolateCombo(scale_factor=config.scale_factor3, dropout=config.dropout_scaling)
    upscaler_x4 = InterpolateCombo(scale_factor=int(1/config.scale_ratio3), dropout=config.dropout_scaling)

    # initialize the BertIntegrative Layers: low res to mid res
    bert_integrative_layer_2 = BertIntegrativeLayer(
        config,
        hidden_size=config.hidden_size_l2,
        hidden_size_query=config.hidden_size_l3,
        intermediate_size=config.intermediate_size_l2
    )

    do_cheap_integrator = (stack_id in config.use_cheap_integrator_for_stacks)
    # from mid-res to high-res
    if do_cheap_integrator:
        # cheap (non-transformer) method to integrate high- and mid-res hidden states
        bert_integrative_layer_1 = CheapMLPIntegrativeLayer(
            config,
            hidden_size=config.hidden_size,
            hidden_size_query=config.query_size1,
            intermediate_size=config.intermediate_size_l1
        )
    else:
        # full Transformer layer as mid-to-highres upscaling
        bert_integrative_layer_1 = BertIntegrativeLayer(
            config,
            hidden_size=config.hidden_size,
            hidden_size_query=config.query_size1,
            intermediate_size=config.intermediate_size_l1//2
        )

    return (
        tokenizer,
        basemod,
        layer_embedding,
        layer_basetransformer,
        maxpool,
        maxpool_attn,
        bert_reducer_l2,
        bert_reducer_l3,
        bert_encoder_midres,
        bert_encoder_lowres,
        upscaler_x2,
        upscaler_x4,
        bert_integrative_layer_2,
        bert_integrative_layer_1
    )

def initialize_midlayers(config, basemod=None, tokenizer=None):
    """Initializes all the intermediate layers for the Anathem transformers"""
    # initialize the maxpooling downsamplers
    maxpool = nn.Sequential(
        nn.Dropout(config.dropout_scaling),
        nn.MaxPool2d((2,1), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)
    )
    # pooling the attention has no dropout
    maxpool_attn = nn.MaxPool1d((2), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)

    # initialize bert attentive downsampling and skipconnection (1/2 embedding dim)
    bert_reduceintegrator_l2 = BertReduceAddIntegrativeLayer(
        config,
        config.hidden_size_l2, # size of mid-res
        hidden_size_input=config.hidden_size, # size full-resolution
        hidden_size_query=config.hidden_size, # size full-resolution
        intermediate_size=config.intermediate_size_l1, # BertIntermediate dimension (expansion *4 the hiddensize)
        dim_reduction=config.scale_factor2, # reduce embedding dimension by factor of 2
        do_concat_hidden_and_query = True
    )

    # 1/4 the size
    bert_reduceintegrator_l3 = BertReduceAddIntegrativeLayer(
        config,
        config.hidden_size_l3, # size of low-res
        hidden_size_input=config.hidden_size_l2, # size of mid-res
        hidden_size_query=config.hidden_size_l2, # size of mid-res
        intermediate_size=config.intermediate_size_l2, # BertIntermediate dimension
        dim_reduction=config.scale_factor3, # reduce embedding dimension by a further factor of 2 (mid-res to low-res)
        do_concat_hidden_and_query = True
    )

    # initialize the mid- and low-resolution BertEncoders
    bert_encoder_midres = BertEncoder(config.config_l2)
    bert_encoder_lowres = BertEncoder(config.config_l3)

    # initialize the upscalers
    upscaler_x2 = InterpolateCombo(scale_factor=config.scale_factor3, dropout=config.dropout_scaling)
    upscaler_x4 = InterpolateCombo(scale_factor=int(1/config.scale_ratio3), dropout=config.dropout_scaling)

    # initialize the BertIntegrative Layers: low res to mid res
    bert_integrative_layer_2 = BertIntegrativeLayer(
        config,
        hidden_size=config.hidden_size_l2,
        hidden_size_query=config.hidden_size_l3,
        intermediate_size=config.intermediate_size_l2
    )

    # from mid-res to high-res
    bert_integrative_layer_1 = BertIntegrativeLayer(
        config,
        hidden_size=config.hidden_size,
        hidden_size_query=config.query_size1,
        intermediate_size=config.intermediate_size_l1
    )

    return (
        maxpool,
        maxpool_attn,
        bert_reduceintegrator_l2,
        bert_reduceintegrator_l3,
        bert_encoder_midres,
        bert_encoder_lowres,
        upscaler_x2,
        upscaler_x4,
        bert_integrative_layer_2,
        bert_integrative_layer_1
    )


class AnathemBaseModule(nn.Module):
    """First stack of layers (including the embeddings) that goes full circle: high-res to low-res and back to high-res"""
    def __init__(
            self,
            config,
            basemod=None,
            tokenizer=None,
            past_key_values_length = None,
            device = None
        ):
        super().__init__()
        self.config = config

        # initialize the layers
        (
            tokenizer, basemod,
            layer_embedding,
            layer_basetransformer,
            maxpool,
            maxpool_attn,
            bert_reducer_l2,
            bert_reducer_l3,
            bert_encoder_midres,
            bert_encoder_lowres,
            upscaler_x2,
            upscaler_x4,
            bert_integrative_layer_2,
            bert_integrative_layer_1
        ) = initialize_baselayers(config, basemod, tokenizer)

        self.get_extended_attention_mask = basemod.get_extended_attention_mask
        self.embedding = layer_embedding
        self.layer_basetransformer = layer_basetransformer
        self.maxpool = maxpool
        self.maxpool_attn = maxpool_attn
        self.bert_reducer_l2 = bert_reducer_l2
        self.bert_reducer_l3 = bert_reducer_l3
        self.bert_encoder_midres = bert_encoder_midres
        self.bert_encoder_lowres = bert_encoder_lowres
        self.upscaler_x2 = upscaler_x2
        self.upscaler_x4 = upscaler_x4
        self.bert_integrative_layer_2 = bert_integrative_layer_2
        self.bert_integrative_layer_1 = bert_integrative_layer_1
        if device is None:
            self.to(basemod.device)
            #print(self.device)
            self.device = basemod.device
        else:
            self.to(device)
            self.device = device

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False
    ):
        input_shape = input_ids.size()
        # as in HF BertModel: cached keys have shape (batch, heads, past_len, head_dim)
        past_key_values_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]

        # extend attention mask
        extended_attention_mask_l1 = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
        # downsample the attention mask to l2 dimension
        attention_mask_l2 = self.maxpool_attn(attention_mask.float())
        extended_attention_mask_l2 = self.get_extended_attention_mask(attention_mask_l2,attention_mask_l2.shape, self.device)
        # downsample the attention mask to l3 dimension
        attention_mask_l3 = self.maxpool_attn(attention_mask_l2.float())
        extended_attention_mask_l3 = self.get_extended_attention_mask(attention_mask_l3,attention_mask_l3.shape, self.device)

        # embed
        embedding_output = self.embedding(
            input_ids = input_ids,
            position_ids = position_ids,
            token_type_ids = token_type_ids,
            #input_embeds=None,
            past_key_values_length = past_key_values_length
        )

        # first transformer block (vanilla transformer)
        out_l1 = self.layer_basetransformer(
            hidden_states = embedding_output,
            attention_mask = extended_attention_mask_l1,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=output_attentions
        )
        hidden_states_l1 = out_l1[0]

        # downsample sequence 1 to the length of sequence 2 (max-pool adjacent token pairs)
        hiddens_states_l1_reduced = self.maxpool(hidden_states_l1)

        # reduce the embedding dimension on sequence 2
        out_l2 = self.bert_reducer_l2(
            hidden_states = hiddens_states_l1_reduced,
            attention_mask = extended_attention_mask_l2,
            head_mask=head_mask,
            encoder_hidden_states = hidden_states_l1,
            encoder_attention_mask= extended_attention_mask_l1,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
        )
        hidden_states_l2 = out_l2[0]

        # Vanilla transformers block at mid-resolution (1/2 seq-length)
        out_encoder = self.bert_encoder_midres(
            hidden_states=hidden_states_l2,
            attention_mask=extended_attention_mask_l2,
            head_mask = head_mask,
            return_dict=return_dict
        )
        hidden_states_l2 = out_encoder[0]

        # reduce sequence length (1/4 seq-length)
        hiddens_states_l2_reduced = self.maxpool(hidden_states_l2)

        # reduce the embedding dimension on sequence 3
        out_l3 = self.bert_reducer_l3(
            hidden_states = hiddens_states_l2_reduced,
            attention_mask = extended_attention_mask_l3,
            head_mask=head_mask,
            encoder_hidden_states = hidden_states_l2,
            encoder_attention_mask= extended_attention_mask_l2,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
        )
        hidden_states_l3 = out_l3[0]

        #print(hidden_states_l3.shape)
        #print(extended_attention_mask_l3.shape)
        # BertEncoder at low-res
        out_encoder = self.bert_encoder_lowres(
            hidden_states=hidden_states_l3,
            attention_mask=extended_attention_mask_l3,
            head_mask = head_mask,
            return_dict=return_dict
        )
        hidden_states_l3 = out_encoder[0]

        # upscaling: l3 to l2
        hidden_states_upscaled3to2 = self.upscaler_x2(hidden_states_l3)

        # integrate sequence-2 and upscaled sequence-3
        hidden_states_l2 = self.bert_integrative_layer_2(
            hidden_states = hidden_states_l2,
            attention_mask = extended_attention_mask_l2,
            head_mask = head_mask,
            query_hidden_states = hidden_states_upscaled3to2,
            query_attention_mask = attention_mask_l2
        )

        # upscaling: l3/l2 to l1 sequence length
        hidden_states_upscaled3to1 = self.upscaler_x4(hidden_states_l3)
        hidden_states_upscaled2to1 = self.upscaler_x2(hidden_states_l2)
        hidden_states_upscaled = torch.cat((
            hidden_states_upscaled2to1, hidden_states_upscaled3to1
        ),axis=2)

        # integrate low-resolution information back to original dimension
        hidden_states_l1 = self.bert_integrative_layer_1(
            hidden_states = hidden_states_l1,
            attention_mask = extended_attention_mask_l1,
            head_mask = head_mask,
            query_hidden_states = hidden_states_upscaled,
            query_attention_mask = extended_attention_mask_l1
        )
        if not return_dict:
            return (
                (hidden_states_l1, hidden_states_l2, hidden_states_l3),
                (extended_attention_mask_l1, extended_attention_mask_l2, extended_attention_mask_l3)
            )
        return {
            "hidden_states": (hidden_states_l1, hidden_states_l2, hidden_states_l3),
            "attention":(extended_attention_mask_l1, extended_attention_mask_l2, extended_attention_mask_l3)
        }


class AnathemMidModule(nn.Module):
    """Stack of layers that goes full circle: high-res to low-res and back to high-res"""
    def __init__(
            self,
            config,
            basemod=None,
            tokenizer=None,
            past_key_values_length = None,
            device=None,
        ):
        super().__init__()
        self.config = config

        # initialize the layers
        (
            maxpool,
            maxpool_attn,
            bert_reducerintegrator_l2,
            bert_reducerintegrator_l3,
            bert_encoder_midres,
            bert_encoder_lowres,
            upscaler_x2,
            upscaler_x4,
            bert_integrative_layer_2,
            bert_integrative_layer_1
        ) = initialize_midlayers(config, basemod, tokenizer)

        self.get_extended_attention_mask = get_extended_attention_mask
        self.maxpool = maxpool
        self.maxpool_attn = maxpool_attn
        self.bert_reducerintegrator_l2 = bert_reducerintegrator_l2
        self.bert_reducerintegrator_l3 = bert_reducerintegrator_l3
        self.bert_encoder_midres = bert_encoder_midres
        self.bert_encoder_lowres = bert_encoder_lowres
        self.upscaler_x2 = upscaler_x2
        self.upscaler_x4 = upscaler_x4
        self.bert_integrative_layer_2 = bert_integrative_layer_2
        self.bert_integrative_layer_1 = bert_integrative_layer_1
        if device is None:
            self.to(basemod.device)
            #print(self.device)
            self.device = basemod.device
        else:
            self.to(device)
            self.device = device

    def forward(
        self,
        hidden_states_highres: torch.Tensor,
        hidden_states_midres: torch.Tensor,
        hidden_states_lowres: torch.Tensor,
        attention_mask: Optional[List[torch.FloatTensor]] = None,
        extended_attention_mask_highres: Optional[List[torch.FloatTensor]] = None,
        extended_attention_mask_midres: Optional[List[torch.FloatTensor]] = None,
        extended_attention_mask_lowres: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False
    ):
        input_shape = hidden_states_highres.shape[:2]
        past_key_values_length =0 if past_key_values is None else len(past_key_values)

        # extend attention masks (only those not already passed in by the previous stack)
        if extended_attention_mask_highres is None:
            extended_attention_mask_highres = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
        if extended_attention_mask_midres is None or extended_attention_mask_lowres is None:
            # the mid-res mask is also needed to build the low-res mask
            attention_mask_midres = self.maxpool_attn(attention_mask.float())
        if extended_attention_mask_midres is None:
            extended_attention_mask_midres = self.get_extended_attention_mask(attention_mask_midres, attention_mask_midres.shape, self.device)
        if extended_attention_mask_lowres is None:
            attention_mask_lowres = self.maxpool_attn(attention_mask_midres.float())
            extended_attention_mask_lowres = self.get_extended_attention_mask(attention_mask_lowres, attention_mask_lowres.shape, self.device)

        # downsample sequence 1 to the length of sequence 2 (max-pool adjacent token pairs)
        hiddens_states_l1_reduced = self.maxpool(hidden_states_highres)

        # reduce the embedding dimension on sequence 2
        hidden_states_l2 = self.bert_reducerintegrator_l2(
            inputs = hidden_states_highres, # high-res output of the previous stack (keys, values)
            hidden_states = hidden_states_midres, # previous mid-res hidden states, for the skip connection
            attention_mask = extended_attention_mask_midres,
            head_mask=None,
            query_hidden_states = hiddens_states_l1_reduced
        )

        # Vanilla transformers at mid-resolution (1/2 sequence-length)
        out_encoder = self.bert_encoder_midres(
            hidden_states=hidden_states_l2,
            attention_mask=extended_attention_mask_midres,
            head_mask = None,
            return_dict=return_dict
        )
        hidden_states_l2 = out_encoder[0]

        # reduce sequence length (to 1/4 sequence-length)
        hiddens_states_l2_reduced = self.maxpool(hidden_states_l2)

        # reduce the embedding dimension on sequence 3
        hidden_states_l3 = self.bert_reducerintegrator_l3(
            inputs = hidden_states_midres, # mid-res output of the previous stack (keys, values)
            hidden_states = hidden_states_lowres, # previous low-res hidden states, for the skip connection
            attention_mask = extended_attention_mask_lowres,
            head_mask=None,
            query_hidden_states = hiddens_states_l2_reduced
        )

        # BertEncoder at low-res
        out_encoder = self.bert_encoder_lowres(
            hidden_states=hidden_states_l3,
            attention_mask=extended_attention_mask_lowres,
            head_mask = None,
            return_dict=return_dict
        )
        hidden_states_lowres = out_encoder[0]

        # upscaling: l3 to l2
        hidden_states_upscaled3to2 = self.upscaler_x2(hidden_states_lowres)

        # integrate sequence-2 and upscaled sequence-3
        hidden_states_midres = self.bert_integrative_layer_2(
            hidden_states = hidden_states_l2,
            attention_mask = extended_attention_mask_midres,
            head_mask = None,
            query_hidden_states = hidden_states_upscaled3to2
        )

        # upscaling: l3/l2 to l1 sequence length
        hidden_states_upscaled3to1 = self.upscaler_x4(hidden_states_lowres)
        hidden_states_upscaled2to1 = self.upscaler_x2(hidden_states_midres)
        hidden_states_upscaled = torch.cat((
            hidden_states_upscaled2to1, hidden_states_upscaled3to1
        ),axis=2)

        # integrate low-resolution information back to original dimension
        hidden_states_highres = self.bert_integrative_layer_1(
            hidden_states = hidden_states_highres,
            attention_mask = extended_attention_mask_highres,
            head_mask = None,
            query_hidden_states = hidden_states_upscaled,
            query_attention_mask = extended_attention_mask_highres
        )
        if not return_dict:
            return (
                (hidden_states_highres, hidden_states_midres, hidden_states_lowres),
                (extended_attention_mask_highres, extended_attention_mask_midres, extended_attention_mask_lowres)
            )
        return {
            "hidden_states": (hidden_states_highres, hidden_states_midres, hidden_states_lowres),
            "attention":(extended_attention_mask_highres, extended_attention_mask_midres, extended_attention_mask_lowres)
        }

class BertClassificationHead(nn.Module):
    def __init__(self, config, n_classes = 1, activation = 'sigmoid', device=None):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size*2, n_classes)
        if activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'sigmoid':
            self.activation = torch.sigmoid
        elif activation == 'none':
            self.activation = nn.Identity()  # raw logits, e.g. for cross_entropy
        else:
            raise ValueError(f"unknown activation: {activation}")
        if device is not None:
            self.to(device)

    def forward(self, hidden_states, attention_mask) -> torch.Tensor:
        # "pool" the sequence two ways and concatenate (hence hidden_size*2 in self.dense):
        # 1) the hidden state of the first ([CLS]) token
        output_vectors=[]
        first_token_tensor = hidden_states[:, 0]
        output_vectors.append(first_token_tensor)
        # 2) mean pooling over the non-padding tokens
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        sum_embeddings = torch.sum(hidden_states * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        output_vectors.append(sum_embeddings / sum_mask)
        # concatenate
        pooled_output = torch.cat(output_vectors, dim=1)
        logits = self.dense(pooled_output)
        return self.activation(logits)


def tokenize_anathem(text, device=device):
    # pad to a multiple of 4 so that two rounds of 2x max-pooling (and the x2/x4
    # upscaling back) line up exactly with the original sequence length
    tokens = tokenizer(text,padding=True, return_tensors='pt', pad_to_multiple_of=4)
    input_shape = tokens['input_ids'].size()

    # single-segment inputs: set all token_type_ids to 0
    token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
    tokens['token_type_ids'] = token_type_ids
    for k,v in tokens.items():
        tokens[k] = v.to(device)

    return tokens
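A quick way to check the shape bookkeeping of the downsampling path outside the full model. This is a minimal sketch that reuses only the pooling settings from `initialize_midlayers` (the dropout in `maxpool` is left out) on random tensors; the widths and the 37-token example are arbitrary. It also shows why `tokenize_anathem` pads to a multiple of 4: with `ceil_mode=True` a 45-token sequence pools to 23 and then 12 positions, and upscaling 12 by x4 gives 48, which no longer matches the original length.

```python
import math
import torch
import torch.nn as nn

# same pooling settings as initialize_midlayers (dropout omitted in this sketch)
maxpool = nn.MaxPool2d((2, 1), ceil_mode=True)  # halves the sequence axis only
maxpool_attn = nn.MaxPool1d(2, ceil_mode=True)  # halves a (batch, seq_len) mask

def padded_length(n_tokens: int, multiple: int = 4) -> int:
    """Length after tokenizer(..., padding=True, pad_to_multiple_of=multiple)."""
    return int(math.ceil(n_tokens / multiple)) * multiple

seq_len = padded_length(45)            # 48
hidden = torch.randn(2, seq_len, 512)  # (batch, seq_len, hidden)
mask = torch.ones(2, seq_len)
mask[1, 37:] = 0.0                     # second example: 37 real tokens, the rest is padding

# MaxPool2d treats a 3-D tensor as unbatched (C, H, W), so the (2, 1) kernel
# pools pairs of adjacent tokens and leaves the hidden axis untouched.
hidden_l2 = maxpool(hidden)      # (2, 24, 512)
hidden_l3 = maxpool(hidden_l2)   # (2, 12, 512)
mask_l2 = maxpool_attn(mask)     # (2, 24): a pooled slot is "real" if either source token was
mask_l3 = maxpool_attn(mask_l2)  # (2, 12)

print(hidden_l2.shape, hidden_l3.shape)
print(mask_l2[1].sum().item(), mask_l3[1].sum().item())  # 19.0 10.0

# without padding: 45 -> 23 -> 12 positions, and 12 * 4 = 48 != 45
odd = maxpool(maxpool(torch.randn(1, 45, 8)))
print(odd.shape[1] * 4)  # 48
```

Because a pooled position counts as real whenever either of its two tokens is real, the low-res masks can only keep or extend the attended region, never shrink it.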

In [None]:
#config = make_config('distilroberta-base')
#config = make_config('t5-small') # can't use t5 because it uses relative position embeddings
config = make_config('google/bert_uncased_L-12_H-512_A-8') #

if False:
  (tokenizer,basemod,layer_embedding,layer_basetransformer,maxpool,maxpool_attn,bert_reducer_l2,
   bert_reducer_l3,bert_encoder_lowres,upscaler_x2,upscaler_x4,bert_integrative_layer_2,bert_integrative_layer_1) = initialize(config)

# make the basemod and tokenizer
basemod = AutoModel.from_pretrained(config.model_string)
basemod.to(device)
tokenizer = AutoTokenizer.from_pretrained(config.model_string)



Some weights of the model checkpoint at google/bert_uncased_L-12_H-512_A-8 were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# the Anathem encoder includes the embeddings and first transformer block
anathem_encoder1 = AnathemBaseModule(config, basemod, tokenizer)
anathem_encoder2 = AnathemMidModule(config, basemod)

In [None]:
cls_head = BertClassificationHead(config, n_classes = 3, activation = 'none',device=device)
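The classification head pools two ways and concatenates the results, which is why `self.dense` takes `config.hidden_size*2` inputs. The standalone function below (a hypothetical helper, not part of the notebook) repeats the same CLS + masked-mean pooling on a tiny hand-checkable tensor: the padded third token must not affect the mean.

```python
import torch

def cls_and_mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Concatenate the first-token vector with the mask-weighted mean over tokens."""
    cls_vec = hidden_states[:, 0]                                # (batch, hidden)
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq, 1)
    summed = (hidden_states * mask).sum(dim=1)                   # padding contributes 0
    counts = mask.sum(dim=1).clamp(min=1e-9)                     # avoid division by zero
    return torch.cat([cls_vec, summed / counts], dim=1)          # (batch, 2 * hidden)

hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # (1, 3, 2)
mask = torch.tensor([[1, 1, 0]])                                   # last token is padding
pooled = cls_and_mean_pool(hidden, mask)
print(pooled)  # tensor([[1., 2., 2., 3.]])
```

Note that the head is called with `tokens['attention_mask']`, the full-resolution mask, which matches the full-resolution `hidden_states[0]` it pools.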


In [None]:
text = [
    "* Welcome home to this gorgeously upgraded, beautifully maintained, three-bedroom home with double attached garage. Drive up to this quiet cul-de-sac and let the experience begin. On the main floor, you’ll notice the abundance of natural light. There is a separate office with view over the front of the property. The layout was customized, with a great open living space. The kitchen is a chef’s dream, with a breakfast bar, granite countertops, stainless steel appliance package, a pantry, and a view out to the sunny west facing yard.",
    "There’s room for formal dining and the family room has a gas fireplace to relax by on the cooler nights. Out back, there’s a stunner of a deck, perfect for BBQ season! Upstairs, you’ll find a massive bonus room with tons of windows. There are two, secondary bedrooms and the master suite is amazing",
]

In [None]:
tokens = tokenize_anathem(text,device)

In [None]:
#stack 1
out1 = anathem_encoder1(
      input_ids = tokens['input_ids'],
      attention_mask = tokens['attention_mask'],
      token_type_ids = tokens['token_type_ids']
)
(hidden_states, extended_attention_masks) = out1
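Stack 1 returns the extended attention masks for all three resolutions so that stack 2 can reuse them instead of rebuilding them from `attention_mask`. As far as I can tell, for a 2-D mask recent `transformers` versions build the extended mask inside `get_extended_attention_mask` roughly as below (older versions used -10000.0 instead of the dtype minimum). `make_additive_mask` is a hypothetical stand-in, used only to show the resulting shape and its effect on the attention softmax.

```python
import torch

def make_additive_mask(attention_mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Hypothetical helper: (batch, seq) 1/0 mask -> (batch, 1, 1, seq) additive mask."""
    mask = attention_mask[:, None, None, :].to(dtype)  # broadcasts over heads and query positions
    return (1.0 - mask) * torch.finfo(dtype).min       # 0 where attended, very negative where padded

pooled_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # e.g. an output of maxpool_attn
ext = make_additive_mask(pooled_mask)

scores = torch.zeros(1, 8, 4, 4)                # (batch, heads, query_len, key_len) raw scores
probs = torch.softmax(scores + ext, dim=-1)     # the padded key gets (numerically) zero weight
print(ext.shape)  # torch.Size([1, 1, 1, 4])
```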



In [None]:
# stack2
out2 = anathem_encoder2(
      hidden_states_highres = hidden_states[0],
      hidden_states_midres = hidden_states[1],
      hidden_states_lowres = hidden_states[2],
      extended_attention_mask_highres = extended_attention_masks[0],
      extended_attention_mask_midres = extended_attention_masks[1],
      extended_attention_mask_lowres = extended_attention_masks[2]
)
(hidden_states, extended_attention_masks) = out2

cls_head(hidden_states[0], tokens['attention_mask'])



tensor([[-0.8376, -0.3891, -0.6668],
        [-0.8747, -0.3621, -0.7735]], device='cuda:0',
       grad_fn=<AddmmBackward0>)

In [None]:
out1[0][0].shape

torch.Size([2, 48, 768])
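The shape above is (batch, sequence, hidden) for the full-resolution output; the mid- and low-res outputs have 1/2 and 1/4 of the sequence length. Before `bert_integrative_layer_1`, both are upscaled back to full length and concatenated on the feature axis. `InterpolateCombo` is defined earlier in the notebook; the sketch below assumes plain nearest-neighbour upsampling via `torch.nn.functional.interpolate` and hypothetical widths of hidden/2 and hidden/4, only to illustrate the shape arithmetic (the concatenated width is presumably what `config.query_size1` has to match).

```python
import torch
import torch.nn.functional as F

def upsample_seq(hidden: torch.Tensor, factor: int) -> torch.Tensor:
    """Hypothetical: nearest-neighbour upsampling along the sequence axis of (batch, seq, hidden)."""
    x = hidden.transpose(1, 2)                            # (batch, hidden, seq) for F.interpolate
    x = F.interpolate(x, scale_factor=factor, mode="nearest")
    return x.transpose(1, 2)                              # (batch, seq * factor, hidden)

h_l2 = torch.randn(2, 24, 256)  # hypothetical mid-res states (hidden / 2)
h_l3 = torch.randn(2, 12, 128)  # hypothetical low-res states (hidden / 4)

up = torch.cat([upsample_seq(h_l2, 2), upsample_seq(h_l3, 4)], dim=2)
print(up.shape)  # torch.Size([2, 48, 384])
```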

In [None]:
####

In [None]:
## Next steps, do something simple like sentiment analysis

In [None]:
from datasets import load_dataset  # list_datasets is deprecated in newer `datasets` releases
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import numpy as np
from tqdm import tqdm
from torch.optim import AdamW
from sklearn.metrics import precision_recall_fscore_support
from scipy.special import softmax
#datasets_list = list_datasets()
#[k for k in datasets_list if 'phrasebank' in k]


In [None]:
#[k for k in datasets_list if 'phrasebank' in k]

dataset = load_dataset('financial_phrasebank', 'sentences_75agree')

# split
sentences = dataset['train']['sentence']  # read each column once instead of once per example
labels = dataset['train']['label']
idx_train, idx_val = train_test_split(np.arange(len(sentences)), test_size=0.1)
dataset_train = [{'text': sentences[idx], 'label': labels[idx]} for idx in idx_train]
dataset_val = [{'text': sentences[idx], 'label': labels[idx]} for idx in idx_val]




In [None]:
print(len(dataset_train)); print(len(dataset_val))

3107
346


In [None]:
class MyDataset(Dataset):
    """torch dataset."""

    def __init__(self, dataset):
        self.data = dataset
        self.n = len(self.data)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        unit = self.data[idx]
        return unit

In [None]:
ds_train = MyDataset(dataset_train)
ds_val = MyDataset(dataset_val)

In [None]:
batch_size_train = 12
batch_size_val = 36
lr = 0.00005
eval_iter = 20
n_epochs = 1

In [None]:
dl_train = DataLoader(ds_train, batch_size=batch_size_train, shuffle=True)
dl_val = DataLoader(ds_val, batch_size=batch_size_val, shuffle=False)
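The training loop relies on PyTorch's default collate function turning a list of dicts into a dict of batched fields: the strings stay a Python list (which is what `tokenize_anathem` expects) and the integer labels become a tensor (so `batch['label'].to(device)` works). A minimal check with made-up toy examples; a plain list also works as a map-style dataset, independent of `MyDataset`:

```python
import torch
from torch.utils.data import DataLoader

# toy examples with made-up labels, only to show how default collation behaves
examples = [
    {"text": "example a", "label": 2},
    {"text": "example b", "label": 0},
    {"text": "example c", "label": 1},
]
loader = DataLoader(examples, batch_size=2, shuffle=False)
batch = next(iter(loader))
print(batch["text"])   # ['example a', 'example b']
print(batch["label"])  # tensor([2, 0])
```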

In [None]:
optimizer = AdamW(list(anathem_encoder1.parameters()) + list(anathem_encoder2.parameters()) + list(cls_head.parameters()), lr=lr)

In [None]:

optimizer.zero_grad()
anathem_encoder1.train()
anathem_encoder2.train()
cls_head.train()
for epoch in range(n_epochs):

  for iteration, batch in enumerate(tqdm(dl_train, disable=True)):

      # tokenize the batch
      tokens = tokenize_anathem(batch['text'],device)
      target = batch['label'].to(device)

      optimizer.zero_grad()

      out1 = anathem_encoder1(
        input_ids = tokens['input_ids'],
        attention_mask = tokens['attention_mask'],
        token_type_ids = tokens['token_type_ids']
      )
      (hidden_states, extended_attention_masks) = out1

      features,_ = anathem_encoder2(
          hidden_states_highres = hidden_states[0],
          hidden_states_midres = hidden_states[1],
          hidden_states_lowres = hidden_states[2],
          extended_attention_mask_highres = extended_attention_masks[0],
          extended_attention_mask_midres = extended_attention_masks[1],
          extended_attention_mask_lowres = extended_attention_masks[2]
      )

      # prediction
      preds = cls_head(features[0], tokens['attention_mask'])

      # loss
      loss = nn.functional.cross_entropy(preds, target)
      loss.backward()
      optimizer.step()

      # do evaluation
      if ((iteration+1) % eval_iter)==0:
          anathem_encoder1.eval()
          anathem_encoder2.eval()
          cls_head.eval()
          # tokenize the eval
          eval_logits = []
          eval_targets = []
          for i, batch_eval in enumerate(tqdm(dl_val, disable=True)):
              with torch.no_grad():
                  # tokenize the batch
                  tokens_eval = tokenize_anathem(batch_eval['text'], device)
                  labels_eval = batch_eval['label'].to(device)
                  out_eval1 = anathem_encoder1(
                      input_ids = tokens_eval['input_ids'],
                      attention_mask = tokens_eval['attention_mask'],
                      token_type_ids = tokens_eval['token_type_ids']
                  )
                  (hidden_states, extended_attention_masks) = out_eval1
                  features,_ = anathem_encoder2(
                      hidden_states_highres = hidden_states[0],
                      hidden_states_midres = hidden_states[1],
                      hidden_states_lowres = hidden_states[2],
                      extended_attention_mask_highres = extended_attention_masks[0],
                      extended_attention_mask_midres = extended_attention_masks[1],
                      extended_attention_mask_lowres = extended_attention_masks[2]
                  )
                  # prediction
                  batch_logits = cls_head(features[0], tokens_eval['attention_mask'])
                  eval_logits+=batch_logits.detach().tolist()
                  eval_targets+=labels_eval.detach().tolist()

          eval_prec,eval_recall,eval_f1,eval_support = precision_recall_fscore_support(eval_targets, np.array(eval_logits).argmax(axis=1),zero_division=0)
          print('E:%d; i:%d: f1:%0.3f (%0.3f); prec:%0.3f (%0.3f); rec:%0.3f (%0.3f)' % (epoch, iteration, eval_f1.mean(), eval_f1.min(), eval_prec.mean(), eval_prec.min(), eval_recall.mean(), eval_recall.min()))
          cls_head.train()
          anathem_encoder1.train()
          anathem_encoder2.train()






E:0; i:19: f1:0.402 (0.000); prec:0.352 (0.000); rec:0.469 (0.000)
E:0; i:39: f1:0.326 (0.000); prec:0.400 (0.000); rec:0.372 (0.000)
E:0; i:59: f1:0.459 (0.158); prec:0.531 (0.405); rec:0.485 (0.095)
E:0; i:79: f1:0.506 (0.305); prec:0.583 (0.450); rec:0.494 (0.231)
E:0; i:99: f1:0.499 (0.190); prec:0.555 (0.383); rec:0.551 (0.116)
E:0; i:119: f1:0.552 (0.280); prec:0.663 (0.568); rec:0.534 (0.179)
E:0; i:139: f1:0.661 (0.469); prec:0.708 (0.600); rec:0.636 (0.385)

KeyboardInterrupt: ignored
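The logged metrics are per-class values from `precision_recall_fscore_support`, summarized as mean (min) over the three classes. A min F1 of 0.000, as in the first evaluations, means at least one class had no correct predictions. The hypothetical numpy helper below is a sketch intended to mirror that computation with `zero_division=0`, on a toy example where class 0 is never predicted:

```python
import numpy as np

def per_class_prf(targets, preds, n_classes):
    """Per-class precision, recall and F1, using 0.0 wherever a ratio is undefined."""
    targets, preds = np.asarray(targets), np.asarray(preds)
    prec, rec, f1 = np.zeros(n_classes), np.zeros(n_classes), np.zeros(n_classes)
    for c in range(n_classes):
        tp = np.sum((preds == c) & (targets == c))
        fp = np.sum((preds == c) & (targets != c))
        fn = np.sum((preds != c) & (targets == c))
        prec[c] = tp / (tp + fp) if tp + fp > 0 else 0.0
        rec[c] = tp / (tp + fn) if tp + fn > 0 else 0.0
        denom = prec[c] + rec[c]
        f1[c] = 2 * prec[c] * rec[c] / denom if denom > 0 else 0.0
    return prec, rec, f1

# class 0 is never predicted, so its F1 is 0 and the minimum drops to 0
targets = [0, 1, 1, 2, 2, 2]
preds = [1, 1, 1, 2, 2, 1]
prec, rec, f1 = per_class_prf(targets, preds, n_classes=3)
print('f1:%0.3f (%0.3f)' % (f1.mean(), f1.min()))  # f1:0.489 (0.000)
```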

In [None]:
target

tensor([1, 0, 2, 1, 1, 0, 1, 1, 1, 1, 1, 2])

## Test performance speed

In [None]:
# how many parameters in the model in total
from math import prod
nparam = 0
for encoder in [anathem_encoder1, anathem_encoder2]:
    for na,l in encoder.named_parameters():
        nparam+=prod(l.data.shape)
print('Number of parameters for anathem: %d' % nparam)
# 33676544

Number of parameters for anathem: 33283328
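The counting loop above sums `prod(l.data.shape)` per encoder; `p.numel()` gives the same number per tensor. If two modules ever share a parameter tensor (for example, a layer reused in both), summing module by module would count it twice. A small hypothetical helper that deduplicates by tensor identity:

```python
import torch.nn as nn

def count_parameters(*modules: nn.Module, trainable_only: bool = False) -> int:
    """Total number of scalar parameters across modules; shared tensors are counted once."""
    seen, total = set(), 0
    for m in modules:
        for p in m.parameters():
            if id(p) in seen or (trainable_only and not p.requires_grad):
                continue
            seen.add(id(p))
            total += p.numel()
    return total

print(count_parameters(nn.Linear(4, 3), nn.Linear(3, 2)))  # 23  (15 + 8)
```

Usage in this notebook would be `count_parameters(anathem_encoder1, anathem_encoder2)`.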


In [None]:
# compare this to a standard BERT model (the commented line loads all-MiniLM-L6-v2 instead)
#other_mod = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
other_mod = AutoModel.from_pretrained('google/bert_uncased_L-12_H-512_A-8')



In [None]:
nparam = 0
for na,l in other_mod.named_parameters():
    nparam+=prod(l.data.shape)

print('Number of parameters for other-mod: %d' % nparam)

# number of parameters for anathem-trans: 33676544 (google/bert_uncased_L-12_H-512_A-8)
# number of parameters for anathem-trans: 78973824 (including 2 more mid-res encoders)
# number of parameters for anathem-trans: 73062528 (with a 768 dimension)
# Number of parameters for distilroberta: 82118400 (with a 768 dimension)
# Number of parameters  all-MiniLM-L6-v2: 22713216
# Number of parameters google/bert_uncased_L-12_H-512_A-8: 53982720 (512 dim, 12L)


Number of parameters for other-mod: 53982720


## Test Performance Speed at inference (CPU)
- distilroberta-base: 10 batches, 23.517s, CPU
- google/bert_uncased_L-12_H-512_A-8: 10 batches, 12.44s, CPU
- anathem (distilroberta-768): 10 batches, 23.23s, CPU
- anathem (google/bert_uncased_L-12_H-512_A-8): 10 batches, ~7.5s, CPU

## Test Performance Speed at inference (GPU)
- anathem (google/bert_uncased_L-12_H-512_A-8): 30 batches, 0.79s, GPU
- google/bert_uncased_L-12_H-512_A-8: 30 batches, 0.8s, GPU


In [None]:
import time

In [None]:
time1 = time.time()
for iteration, batch in enumerate(tqdm(dl_train, disable=True)):
    if iteration>30:
        time2 = time.time()
        print(time2-time1)
        break
    with torch.no_grad():
        tokens = tokenize_anathem(batch['text'])
        (hidden_states, extended_attention_masks) = anathem_encoder1(
            input_ids = tokens['input_ids'],
            attention_mask = tokens['attention_mask'],
            token_type_ids = tokens['token_type_ids']
        )
        features,_ = anathem_encoder2(
            hidden_states_highres = hidden_states[0],
            hidden_states_midres = hidden_states[1],
            hidden_states_lowres = hidden_states[2],
            extended_attention_mask_highres = extended_attention_masks[0],
            extended_attention_mask_midres = extended_attention_masks[1],
            extended_attention_mask_lowres = extended_attention_masks[2]
        )

0.8027215003967285


In [None]:
time3 = time.time()
for iteration, batch in enumerate(tqdm(dl_train, disable=True)):
    if iteration>30:
        time4 = time.time()
        print(time4-time3)
        break
    with torch.no_grad():
        tokens = tokenize_anathem(batch['text'])
        out = basemod(
            input_ids = tokens['input_ids'],
            attention_mask = tokens['attention_mask'],
            token_type_ids = tokens['token_type_ids']
        )

0.7066085338592529


In [None]:
eval

array([0.        , 0.86464646, 0.52173913])

In [None]:
eval_prec,eval_recall,eval_f1,eval_support = precision_recall_fscore_support(eval_targets, np.array(eval_logits).argmax(axis=1),zero_division=0)

## Variant: Possibly Faster Integrative Layer

The above version uses a BertIntegrativeLayer in which the high-res hidden states are the keys/values and the upscaled low-res hidden states are the query.

This variant flips the roles: the high-res hidden states are the query (so the attention itself does the upscaling) and the low-res hidden states are the keys and values.
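The role swap can be illustrated with PyTorch's built-in `nn.MultiheadAttention`, used here only as a stand-in for the notebook's own attention layers; all sizes are arbitrary. Because the output length follows the query, each high-res position receives a mixture of low-res keys/values:

```python
import torch
from torch import nn

torch.manual_seed(0)
batch, len_high, len_low, dim, heads = 2, 16, 4, 64, 8

hidden_highres = torch.randn(batch, len_high, dim)  # query: full-resolution tokens
hidden_lowres = torch.randn(batch, len_low, dim)    # keys/values: pooled (4x shorter) tokens
lowres_padding = torch.zeros(batch, len_low, dtype=torch.bool)
lowres_padding[1, -1] = True                        # e.g. the last pooled slot of example 1 is padding

cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, batch_first=True)
upscaled, attn_weights = cross_attn(
    query=hidden_highres,
    key=hidden_lowres,
    value=hidden_lowres,
    key_padding_mask=lowres_padding,  # True = do not attend to this key position
)
print(upscaled.shape)      # torch.Size([2, 16, 64]): one output per high-res position
print(attn_weights.shape)  # torch.Size([2, 16, 4]): head-averaged weights over low-res keys
```

A plausible reason this variant could be faster: the score matrix is `len_high x len_low` rather than `len_high x len_high`, so with 4x pooling it is a quarter of the size of full self-attention over the high-res sequence.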

#### Variant #2 has slightly fewer parameters: 33283328 vs 336

In [None]:
%pip install torch transformers datasets zstandard rank_bm25 langdetect


Collecting transformers
  Downloading transformers-4.31.0-py3-none-any.whl (7.4 MB)
Collecting datasets
  Downloading datasets-2.14.2-py3-none-any.whl (518 kB)
Collecting zstandard
  Downloading zstandard-0.21.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.7 MB)
Collecting rank_bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Collecting langdetect
  Downloading langdetect-1.0.9.tar.gz (981 kB)
  Preparing metadata (setup.py) ... done
Collecting huggin

In [None]:
import copy
import math
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda import is_available
from torch.utils.data import DataLoader, Dataset

from langdetect import detect
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, AutoTokenizer, BertTokenizer
from transformers.activations import ACT2FN
from transformers.models.bert.modeling_bert import BertEncoder
from transformers.tokenization_utils_base import BatchEncoding
from transformers.utils import PaddingStrategy

# run on the GPU when one is available
device = torch.device('cuda') if is_available() else torch.device('cpu')

EncodedInput = List[int]

In [None]:
class CustomTokenizer:
    def __init__(
        self,
        model_string='google/bert_uncased_L-12_H-512_A-8',
        n_cls_prepend = 4,
        n_pad_to_multiple_of=4,
        downscale_multiple=2
    ):
        # initialize the tokenizer from the base model
        self.base_tokenizer = AutoTokenizer.from_pretrained(model_string)
        # how many cls tokens to prepend to the fullsize data
        self.n_cls_prepend = n_cls_prepend
        self.n_pad_to_multiple_of = n_pad_to_multiple_of
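        # expose the base tokenizer's public attributes (cls_token_id, pad_token_id, padding_side, ...) on this
        # wrapper, except for the methods that this class overrides
        for k in dir(self.base_tokenizer):
            if not ((k[0]=='_') or (k in ['tokenize','encode','build_inputs_with_special_tokens','batch_encode_plus','encode_plus','pad'])):
                setattr(self,k,getattr(self.base_tokenizer, k))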
        self.downscale_multiple = downscale_multiple
        # downscale attention
        self.maxpool_attn = nn.MaxPool1d(
            (self.downscale_multiple), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True
        )

        # ensure excess_cls_ids are included for .pad operations
        if 'excess_cls_ids' not in self.base_tokenizer.model_input_names:
            self.base_tokenizer.model_input_names += ['excess_cls_ids']

    def __call__(self, text, pad_to_multiple_of=None, add_special_tokens = True, return_tensors=None, *args, **kwargs):
        if pad_to_multiple_of is None:
            pad_to_multiple_of = self.n_pad_to_multiple_of
        tokens = self.base_tokenizer(
            text,
            pad_to_multiple_of=(pad_to_multiple_of if not add_special_tokens else False),
            add_special_tokens=add_special_tokens,
            return_tensors=return_tensors if (not add_special_tokens) else None,
            *args,
            **kwargs
        )
        if add_special_tokens:
            tokens = self._batch_prepend_extra_cls_tokens_because_of_maxpooling(tokens, return_tensors)

        # downscale the attention, add to tokens
        tokens = self.downscale_attention(
            tokens, downscale_multiple=[self.downscale_multiple, self.downscale_multiple],name='attention_mask'
        )
        # downscale the excess_cls_ids, add to tokens
        tokens = self.downscale_attention(
            tokens, downscale_multiple=[self.downscale_multiple, self.downscale_multiple],name='excess_cls_ids'
        )
        return tokens

    def __len__(self):
        return len(self.base_tokenizer)

    def _num_pad_tokens(self, token_list):
        """Number of PAD tokens to append so that (n_cls_prepend - 1) extra CLS tokens + token_list + padding is a multiple of n_pad_to_multiple_of"""
        return (self.n_pad_to_multiple_of - ((len(token_list)+(self.n_cls_prepend-1)) % self.n_pad_to_multiple_of)) % self.n_pad_to_multiple_of

    def _prepend_extra_cls_tokens_because_of_maxpooling(self, tokens,return_tensors=None):
        n_cls_prepend = self.n_cls_prepend
        # prepend (n-1) CLS tokens to the front of the token_ids (because of maxpooling)
        # also pad so that the total length is a multiple of n_pad_to_multiple_of (see _num_pad_tokens)
        tokens['input_ids'] = [self.cls_token_id]*(n_cls_prepend-1)+tokens['input_ids'] + [self.pad_token_id]*self._num_pad_tokens(tokens['input_ids'])
        tokens['excess_cls_ids'] = [0]*(n_cls_prepend)+tokens['attention_mask'][1:] +[0]*self._num_pad_tokens(tokens['attention_mask'])
        tokens['attention_mask'] = [1]*(n_cls_prepend-1)+tokens['attention_mask'] +[0]*self._num_pad_tokens(tokens['attention_mask'])
        if 'token_type_ids' in tokens.keys():
            tokens['token_type_ids'] = [
                tokens['token_type_ids'][0]
            ]*(n_cls_prepend-1) + tokens['token_type_ids'] + [tokens['token_type_ids'][-1]]*self._num_pad_tokens(tokens['token_type_ids'])
        if return_tensors == 'pt':
            for k,v in tokens.items():
                tokens[k] = torch.LongTensor(v)
        return tokens

    def _batch_prepend_extra_cls_tokens_because_of_maxpooling(self, tokens,return_tensors=None):
        n_cls_prepend = self.n_cls_prepend
        # prepend (n-1) CLS tokens to the front of the token_ids (because of maxpooling)
        # also pad so that the total length is a multiple of n_pad_to_multiple_of (see _num_pad_tokens)
        tokens['input_ids'] = [
            [self.cls_token_id]*(n_cls_prepend-1)+input_id + [self.pad_token_id]*self._num_pad_tokens(input_id)
            for input_id
            in tokens['input_ids']
        ]
        tokens['excess_cls_ids'] = [
            [0]*(n_cls_prepend)+attnmask[1:] +[0]*self._num_pad_tokens(attnmask)
            for attnmask
            in tokens['attention_mask']
        ]
        tokens['attention_mask'] = [
            [1]*(n_cls_prepend-1)+attnmask +[0]*self._num_pad_tokens(attnmask)
            for attnmask
            in tokens['attention_mask']
        ]
        if 'token_type_ids' in tokens.keys():
            tokens['token_type_ids'] = [
                # repeat the first token-type id for the extra CLS tokens and the last one for the padding
                [toktypeid[0]]*(n_cls_prepend-1)+toktypeid +[toktypeid[-1]]*self._num_pad_tokens(toktypeid)
                for toktypeid
                in tokens['token_type_ids']
            ]
        if return_tensors == 'pt':
            for k,v in tokens.items():
                tokens[k] = torch.LongTensor(v)
        return tokens

    def encode(self, text, pad_to_multiple_of=4, add_special_tokens = True, *args, **kwargs):
        encoded = self.base_tokenizer.encode(text, pad_to_multiple_of=False, add_special_tokens=add_special_tokens, *args, **kwargs)
        if add_special_tokens:
            # prepend (n_cls_prepend - 1) extra CLS tokens, as in the other encoding methods
            encoded = [self.cls_token_id]*(self.n_cls_prepend-1) + encoded
        if pad_to_multiple_of:
            num_pad_tokens = (pad_to_multiple_of - (len(encoded) % pad_to_multiple_of)) % pad_to_multiple_of
            encoded += [self.pad_token_id] * num_pad_tokens
        return encoded

    def encode_plus(self, text, add_special_tokens=True, return_tensors=None, *args, **kwargs):
        tokens = self.base_tokenizer.encode_plus(text, add_special_tokens=add_special_tokens, return_tensors=return_tensors, *args, **kwargs)
        if add_special_tokens:
            tokens = self._prepend_extra_cls_tokens_because_of_maxpooling(tokens, return_tensors)
        return tokens

    def tokenize(self, text, add_special_tokens=True, *args, **kwargs):
        toks = self.base_tokenizer.tokenize(text, add_special_tokens=add_special_tokens, *args, **kwargs)
        if add_special_tokens:
            toks = [self.cls_token] * (self.n_cls_prepend-1) + toks
        return toks

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ):
        out = self.base_tokenizer.build_inputs_with_special_tokens(token_ids_0, token_ids_1)
        return [self.cls_token_id]*(self.n_cls_prepend-1) + out

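    def batch_encode_plus(self, batch_text_or_text_pairs, add_special_tokens=True, return_tensors=None, *args, **kwargs):
        batched_encoded = self.base_tokenizer.batch_encode_plus(
            batch_text_or_text_pairs,
            add_special_tokens=add_special_tokens,
            # tensors are created only after the extra CLS tokens have been prepended
            return_tensors=return_tensors if not add_special_tokens else None,
            *args,
            **kwargs
        )
        if add_special_tokens:
            batched_encoded = self._batch_prepend_extra_cls_tokens_because_of_maxpooling(batched_encoded, return_tensors)
        return batched_encoded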

    def downscale_attention(self, tokens, downscale_multiple=None, name = 'attention_mask'):
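        """
        Reduces the sequence dimension with self.maxpool_attn (kernel = self.downscale_multiple), once per entry
        in downscale_multiple, and adds each result to the tokens dictionary as '<name>_l2', '<name>_l3', ...
        (the values in downscale_multiple are currently only used to count the downscaling steps)
        """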
        if downscale_multiple is None:
            downscale_multiple = [self.downscale_multiple, self.downscale_multiple]

        # fullsize attention
        attn = tokens[name]
        if not isinstance(attn, torch.Tensor):
            attn = torch.Tensor(attn)

        for i, mult in enumerate(downscale_multiple):
            name_of_downsized_attn = '%s_l%d' % (name, i+2)
            with torch.no_grad():
                attn = self.maxpool_attn(attn.float())
            tokens[name_of_downsized_attn] = attn
        return tokens

    def pad(
        self,
        encoded_inputs,
        pad_to_multiple_of=4,
        return_tensors=None,
        padding: Union[bool, str, PaddingStrategy] = True,
        max_length: Optional[int] = None,
        *args,
        **kwargs
    ):
        """Pad a list of tokenized-inputs to the same batch-length, with special processing of Anathem-specific inputs"""

        # which are conventional inputs and which are anathem specific
        conventional_input_nm = [k for k in encoded_inputs[0].keys() if k in ['input_ids', 'token_type_ids','attention_mask']]
        unconventional_input_nm = [k for k in encoded_inputs[0].keys() if k not in conventional_input_nm]

        # pad the vanilla inputs
        conventional_encoded_inputs = self.base_tokenizer.pad([
                {k:v for k,v in encoded_input.items() if k in conventional_input_nm}
                for encoded_input in encoded_inputs
            ], pad_to_multiple_of=pad_to_multiple_of, return_tensors=return_tensors, padding=padding, max_length=max_length, *args, **kwargs
        )

        # deal with the remaining inputs
        padding_strategy, _, max_length, _ = self.base_tokenizer._get_padding_truncation_strategies(
            padding=padding, max_length=max_length, verbose=False
        )

        # the Anathem-specific inputs must be padded one example at a time
        special_anathem_inputs = [
                {k:v for k,v in encoded_input.items() if k in unconventional_input_nm}
                for encoded_input in encoded_inputs
        ]
        special_anathem_encoded_inputs = self.pad_special_anathem_inputs(
            special_anathem_inputs=special_anathem_inputs,
            encoded_inputs=conventional_encoded_inputs,
            max_length=max_length,
            padding_strategy=padding_strategy,#: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors
        )
        # merge the padded Anathem-specific inputs into the conventional BatchEncoding
        conventional_encoded_inputs.update(special_anathem_encoded_inputs)

        # downscale the attention and add to inputs
        conventional_encoded_inputs = self.downscale_attention(
            conventional_encoded_inputs,
            downscale_multiple=[self.downscale_multiple, self.downscale_multiple],
            name='attention_mask'
        )
        # downscale the excess_cls_ids, add to tokens
        conventional_encoded_inputs = self.downscale_attention(
            conventional_encoded_inputs,
            downscale_multiple=[self.downscale_multiple, self.downscale_multiple],
            name='excess_cls_ids'
        )
        return conventional_encoded_inputs

    def pad_special_anathem_inputs(
        self,
        special_anathem_inputs,
        encoded_inputs,
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors=None,
    ):
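        required_input = encoded_inputs[self.model_input_names[0]]
        # works for both tensors and lists of lists: (batch_size, padded sequence length)
        batch_size, max_length = len(required_input), len(required_input[0])
        assert batch_size == len(special_anathem_inputs)
        assert isinstance(special_anathem_inputs, list)
        # pad every example to the same length as the (already padded) conventional inputs
        padding_strategy = PaddingStrategy.MAX_LENGTH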
        special_anathem_batch_outputs = {}
        for i in range(batch_size):
            inputs = special_anathem_inputs[i]
            assert isinstance(inputs, dict)
            outputs = self._pad_special_anathem_input(
                inputs,
                max_length=max_length,
                padding_strategy=padding_strategy,
                pad_to_multiple_of=pad_to_multiple_of
            )
            for key, value in outputs.items():
                if key not in special_anathem_batch_outputs:
                    special_anathem_batch_outputs[key] = []
                special_anathem_batch_outputs[key].append(value)

        return BatchEncoding(special_anathem_batch_outputs, tensor_type=return_tensors) # converts to tensors if return_tensors is set

    def _pad_special_anathem_input(
        self,
        special_anathem_input,
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None
    ) -> dict:
        """
        Pad encoded Anathem-specific inputs (on left/right and up to predefined length or max length in the batch)
        """
        assert isinstance(special_anathem_input, dict)
        len_required_input = len(special_anathem_input[list(special_anathem_input.keys())[0]])
        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len_required_input != max_length

        # pad each Anathem-specific input with 0s up to max_length
        if needs_to_be_padded:
            special_anathem_outputs = dict.fromkeys(special_anathem_input.keys())
            difference = max_length - len_required_input
            if self.padding_side == "right":
                for k in special_anathem_input.keys():
                    special_anathem_outputs[k] = special_anathem_input[k] + [0] * difference
            elif self.padding_side == "left":
                for k in special_anathem_input.keys():
                    special_anathem_outputs[k] = [0] * difference + special_anathem_input[k]
            else:
                raise ValueError("Invalid padding side: " + str(self.padding_side))

            return special_anathem_outputs
        return special_anathem_input
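The length bookkeeping in `CustomTokenizer` can be checked by hand. Below is a minimal pure-function restatement, assuming `n_cls_prepend=4`, `n_pad_to_multiple_of=4`, two 2x max-pool downscales, and BERT-uncased ids (`[CLS]`=101, `[PAD]`=0). The function names are hypothetical; the formulas mirror `_num_pad_tokens`, `_prepend_extra_cls_tokens_because_of_maxpooling` and `downscale_attention`:

```python
from typing import List

import torch
import torch.nn.functional as F

N_CLS_PREPEND = 4        # as in CustomTokenizer(n_cls_prepend=4)
PAD_MULTIPLE = 4         # as in n_pad_to_multiple_of=4
CLS_ID, PAD_ID = 101, 0  # BERT-uncased [CLS] and [PAD] ids


def num_pad_tokens(n_tokens: int) -> int:
    """Same formula as CustomTokenizer._num_pad_tokens."""
    return (PAD_MULTIPLE - ((n_tokens + (N_CLS_PREPEND - 1)) % PAD_MULTIPLE)) % PAD_MULTIPLE


def prepend_and_pad(input_ids: List[int], attention_mask: List[int]):
    n_pad = num_pad_tokens(len(input_ids))
    ids = [CLS_ID] * (N_CLS_PREPEND - 1) + input_ids + [PAD_ID] * n_pad
    mask = [1] * (N_CLS_PREPEND - 1) + attention_mask + [0] * n_pad
    # excess_cls_ids: 0 on all N_CLS_PREPEND leading [CLS] slots, 1 on the remaining real tokens
    excess = [0] * N_CLS_PREPEND + attention_mask[1:] + [0] * n_pad
    return ids, mask, excess


def downscale(mask: List[int], times: int = 2, factor: int = 2) -> List[List[int]]:
    """Max-pool the mask `times` times; a pooled slot is 1 if any position inside it is 1."""
    x = torch.tensor([[mask]], dtype=torch.float)  # shape (1, 1, seq_len)
    out = []
    for _ in range(times):
        x = F.max_pool1d(x, kernel_size=factor, ceil_mode=True)
        out.append(x[0, 0].long().tolist())
    return out
```

For a 4-token input `[101, 2000, 2001, 102]` this gives 8 ids (3 extra `[CLS]`, 1 `[PAD]`), an attention mask that pools to `[1, 1, 1, 1]` and then `[1, 1]`, and `excess_cls_ids` that pool to `[0, 0, 1, 1]` and then `[0, 1]`. With 4 leading `[CLS]` slots and two 2x poolings, the first slot of the lowest-resolution sequence covers only `[CLS]` positions, which appears to be the motivation for the prepending. A 34-token input comes out at 40 tokens, matching the `input_ids 40` line in the `encode_plus` output.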

In [None]:
tokenizer = CustomTokenizer(
        model_string='google/bert_uncased_L-12_H-512_A-8',
        n_cls_prepend = 4,
        n_pad_to_multiple_of=4,
        downscale_multiple=2
    )

Downloading (…)lve/main/config.json:   0%|          | 0.00/384 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Using bos_token, but it is not set yet.
Using eos_token, but it is not set yet.


In [None]:
tokenizer.base_tokenizer.model_input_names

['input_ids', 'token_type_ids', 'attention_mask', 'excess_cls_ids']

In [None]:
text = [
    "A standard [MASK] clause is a waiver clause that states that one party won't hold the other liable for damages, losses, or costs associated with issues.",
    "It usually consists of two elements: a trigger event or circumstance and a [MASK] obligation. The trigger event or circumstance is the [MASK] of the agreement, misconduct, or negligence of the indemnifying party or its affiliates"
]

tokens = tokenizer(text, return_tensors='pt', padding=True)

In [None]:
# NOTE
# In the vanilla DataCollatorForLanguageModeling, if the data are pre-tokenized (unpadded),
#    the collator simply pads input_ids and attention_mask (but not the generated excess_cls_ids,
#    nor attention_mask_l2 / attention_mask_l3)
#    ... but I created these _l2, _l3 inputs assuming that everything was already padded properly.
# So, adding excess_cls_ids to model_input_names doesn't automatically cause the behaviour I wanted:
# the base tokenizer's _pad only pads attention_mask, token_type_ids and special_tokens_mask alongside input_ids,
#    and there is no generic list of names to force padding of arbitrary inputs.

# options:
# --- write an updated "pad" method for the tokenizer that also pads the Anathem-specific inputs
tokens = [tokenizer.encode_plus(txt, add_special_tokens=True) for txt in text]

for tok in tokens:
    for k,v in tok.items():
        print(k,len(v))
        print(k,v)
print('---')

pad_out = tokenizer.pad(tokens, pad_to_multiple_of=4, return_tensors='pt')
print('CONVENTIONAL')
print(pad_out)

#for k,v in tokenizer.base_tokenizer.pad(tokens, pad_to_multiple_of=4, return_tensors='pt').items():
print('SPECIAL')
print(pad_out)
for k,v in pad_out.items():
    print(k, len(v))
    for j in v:
        print(len(j))


# still need to do: reduce attention_mask
# return as tensor
# merge and make a BatchEncoding

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


input_ids 40
input_ids [101, 101, 101, 101, 1037, 3115, 103, 11075, 2003, 1037, 23701, 6299, 11075, 2008, 2163, 2008, 2028, 2283, 2180, 1005, 1056, 2907, 1996, 2060, 20090, 2005, 12394, 1010, 6409, 1010, 2030, 5366, 3378, 2007, 3314, 1012, 102, 0, 0, 0]
token_type_ids 40
token_type_ids [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
attention_mask 40
attention_mask [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
excess_cls_ids 40
excess_cls_ids [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
input_ids 48
input_ids [101, 101, 101, 101, 2009, 2788, 3774, 1997, 2048, 3787, 1024, 1037, 9495, 2724, 2030, 25652, 1998, 1037, 103, 14987, 1012, 1996, 9495, 2724, 2030, 25652, 2003, 1996, 103, 1997, 1996, 3820, 1010, 23337, 1010, 2030, 27988, 1997, 1996, 27427, 6633, 3490, 14116, 2
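The workaround described in the NOTE comment amounts to padding every key, not only the ones the base tokenizer knows about. A minimal single-pass sketch of that idea, assuming right padding and that all keys within one example share a length (`pad_all_keys` and `pad_values` are hypothetical names). It simplifies `CustomTokenizer.pad`, which pads conventional and Anathem-specific keys separately; the downscaled `_l2`/`_l3` masks would still have to be recomputed from the padded tensors afterwards:

```python
from typing import Dict, List

import torch


def pad_all_keys(features: List[Dict[str, List[int]]], pad_values: Dict[str, int], multiple: int = 4) -> Dict[str, torch.Tensor]:
    """Right-pad every key of every example to one shared length that is a multiple of `multiple`."""
    max_len = max(len(v) for f in features for v in f.values())
    target = -(-max_len // multiple) * multiple  # round up to the next multiple
    batch = {}
    for key in features[0]:
        fill = pad_values.get(key, 0)  # keys without an explicit pad value are padded with 0
        batch[key] = torch.tensor(
            [f[key] + [fill] * (target - len(f[key])) for f in features], dtype=torch.long
        )
    return batch
```

Padding `excess_cls_ids` with 0 is consistent with how `_pad_special_anathem_input` treats it: padded positions are never counted as real tokens.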

In [None]:
type(pad_out)

transformers.tokenization_utils_base.BatchEncoding

In [None]:
class BertSelfAttnDimensionReduction(nn.Module):
    """Bert Attention Layer that uses a dimension-reduced version of the query, so to reduce the dimension of the outputs"""
    def __init__(
        self,
        config,
        hidden_size_input=768,
        hidden_size_query = None,
        position_embedding_type=None,
        dim_reduction = 2
    ):
        """Special type of Bert self-attention that reduces the hidden dimension of the outputs by a factor of dim_reduction (default 2)"""
        super().__init__()
        if (hidden_size_input // dim_reduction) % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The reduced hidden size ({hidden_size_input // dim_reduction}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.dim_reduction = dim_reduction
        self.hidden_size_input = hidden_size_input
        self.hidden_size_reduced = hidden_size_input // dim_reduction
        if hidden_size_query is None:
            hidden_size_query = hidden_size_input
        self.hidden_size_query = hidden_size_query
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(self.hidden_size_reduced / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(self.hidden_size_query, self.all_head_size)
        self.key = nn.Linear(self.hidden_size_input, self.all_head_size)
        self.value = nn.Linear(self.hidden_size_input, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)

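        # Keys and values come from `encoder_hidden_states` (cross-attention); if none are given,
        # fall back to self-attention over `hidden_states` with its own attention mask
        # (this requires hidden_size_query == hidden_size_input, the default).
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
            encoder_attention_mask = attention_mask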

        key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
        value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)
        use_cache = past_key_value is not None  # needed by the relative-position branch below

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if encoder_attention_mask is not None:
            # Apply the (additive) attention mask, precomputed for all layers in the BertModel forward() function
            attention_scores = attention_scores + encoder_attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs


class InterpolateCombo(nn.Module):
    """Upsample the sequence dimension by blending nearest-neighbour and linear interpolation (an attention-based upsampler is a possible alternative)"""
    def __init__(self, scale_factor=2, dropout=0.05, alpha=0.667):
        """Arguments:
        :param scale_factor: float, multiple of up-scaling
        :param dropout: float, dropout proportion
        :param alpha: float, mixture weight on nearest-neighbour interpolation (1 - alpha goes to linear interpolation)
        """
        super(InterpolateCombo, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.dropout = nn.Dropout(dropout)
        self.a = alpha

    def forward(self, x):
        # interpolate() resizes the last dimension, so move the sequence axis there: (batch, hidden, seq)
        x_trans = x.transpose(-2,-1)
        z = self.a*self.interp(x_trans, mode='nearest',scale_factor=self.scale_factor) + (1-self.a)*self.interp(x_trans, mode='linear',scale_factor=self.scale_factor)
        z = self.dropout(z)
        return z.transpose(-2,-1)


class BertCrossAttention(nn.Module):
    def __init__(
        self,
        config,
        hidden_size,
        hidden_size_query,
        hidden_size_keyvalue=None,
        position_embedding_type=None
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.hidden_size_query = hidden_size_query
        if hidden_size_keyvalue is None:
            hidden_size_keyvalue = hidden_size
        self.hidden_size_keyvalue = hidden_size_keyvalue
        if self.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(self.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(self.hidden_size_query, self.all_head_size)
        self.key = nn.Linear(self.hidden_size_keyvalue, self.all_head_size)
        self.value = nn.Linear(self.hidden_size_keyvalue, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        query_hidden_states: Optional[torch.FloatTensor] = None,
        query_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(query_hidden_states)

        # Keys and values are projected from `hidden_states`; queries come from `query_hidden_states`,
        # which may have a different hidden size and sequence length.
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel's forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
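For the `relative_key` and `relative_key_query` options, the forward pass above turns each (query position, key position) pair into a row index of `distance_embedding`, which has `2 * max_position_embeddings - 1` rows. The sketch below is a plain-Python illustration (no torch) of that index arithmetic, `distance + max_position_embeddings - 1`. The toy value `max_position_embeddings=4` is an assumption chosen for readability, and the cached (`use_cache`) branch, where the left index is the single value `key_length - 1`, is not covered.

```python
# Sketch of the relative-position index arithmetic used above (pure Python).
# `max_position_embeddings` here is an assumed toy value, not the model's.

def relative_position_indices(query_length, key_length, max_position_embeddings):
    """Return a query_length x key_length table of distance_embedding row indices."""
    table = []
    for l in range(query_length):        # plays the role of position_ids_l
        row = []
        for r in range(key_length):      # plays the role of position_ids_r
            distance = l - r             # in [-(key_length - 1), query_length - 1]
            # shift so that the most negative distance maps to a non-negative row
            row.append(distance + max_position_embeddings - 1)
        table.append(row)
    return table


idx = relative_position_indices(query_length=3, key_length=3, max_position_embeddings=4)
for row in idx:
    print(row)
# [3, 2, 1]
# [4, 3, 2]
# [5, 4, 3]
```

Every index stays inside `[0, 2 * max_position_embeddings - 2]` as long as neither sequence is longer than `max_position_embeddings`; in cross-attention the query and key lengths may differ, and the same arithmetic still applies.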


class BertReduceAddIntegrativeLayer(nn.Module):
    """Bert layer that performs dimension reduction along the embedding dimension and integrates a skip connection"""
    def __init__(
            self,
            config,
            hidden_size,
            hidden_size_input=None,
            hidden_size_query=None,
            intermediate_size=None,
            dim_reduction=2,
            do_concat_hidden_and_query = True
        ):
        super().__init__()
        #self.chunk_size_feed_forward = config.chunk_size_feed_forward
        #self.seq_len_dim = 1
        self.cat = torch.cat
        self.do_concat_hidden_and_query = do_concat_hidden_and_query
        assert bool(do_concat_hidden_and_query), 'not implemented: concatenation of query and hidden-states must happen'
        self.hidden_size = hidden_size
        if dim_reduction is None:
            dim_reduction = 2
        self.dim_reduction = dim_reduction
        if intermediate_size is None:
            intermediate_size = int(4*hidden_size)
        self.intermediate_size = intermediate_size
        if hidden_size_input is None:
            hidden_size_input = hidden_size
        self.hidden_size_input = hidden_size_input
        if hidden_size_query is None:
            hidden_size_query = hidden_size_input
        self.hidden_size_query = hidden_size_query + do_concat_hidden_and_query*hidden_size
        self.hidden_size_concat = int(hidden_size + hidden_size_query) # forward concatenates hidden_states with the (raw) query_hidden_states

        # cross attention between (low-res) query and hidden layers below
        self.attention = BertSelfAttnDimensionReduction(
            config,
            hidden_size_input=self.hidden_size_input,
            hidden_size_query = self.hidden_size_query,
            position_embedding_type="absolute",
            dim_reduction = self.dim_reduction
        )
        self.is_decoder = config.is_decoder
        #inputs = x_l1, x_l1_reduced, x_l2_prev
        #- x2 = BertCrossAttention(k,v=x_l1, q= cat(x_l1_reduced, x_l2_prev) ) -notice three inputs
        #- x3 = lnorm(drop(f(x2)) + x_l2_prev)
        #- x4_ex = activation( f(cat(x3, x_l1_reduced))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)

        # corresponds to BertAttention SelfOutput
        self.output_attn = nn.Linear(self.hidden_size, self.hidden_size)
        self.lnorm_attn = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_attn = nn.Dropout(config.hidden_dropout_prob)

        # corresponds to BertIntermediate
        self.intermediate = nn.Linear(self.hidden_size_concat, self.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        # corresponds to BertOutput
        self.output_intm = nn.Linear(self.intermediate_size, self.hidden_size)
        self.lnorm_intm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_intm = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        inputs: torch.Tensor, # higher-resolution inputs for key and values (long sequence dimension)
        hidden_states: torch.Tensor, # previous hidden-states for skip connection (short sequence-dim, low-res)
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        query_hidden_states: torch.FloatTensor = None, # hidden-states for query (short sequence-dim, low-res)
        query_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        if self.do_concat_hidden_and_query:
            query_hidden_states_plus = torch.cat((query_hidden_states, hidden_states),axis=2)
        # cross attn between (low-res) query vector and (high-res) key-values
        cross_attn_outputs = self.attention(
            query_hidden_states_plus, # query (short seq-dim, low-res)
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states = inputs, # for key/value (longer sequence dimension, high-res)
            past_key_value=past_key_value,
            output_attentions=output_attentions,
        )
        cross_hidden_states = cross_attn_outputs[0]

        # first Add+Norm skip connection (BertSelfOutput)
        cross_hidden_states = self.dropout_attn(self.output_attn(cross_hidden_states))
        hidden_states = self.lnorm_attn(cross_hidden_states + hidden_states)

        # intermediate expansion
        intermediate_states = self.intermediate_act_fn(self.intermediate(
            self.cat((hidden_states, query_hidden_states),axis=2)
        ))
        assert intermediate_states.shape[0]==hidden_states.shape[0]
        assert intermediate_states.shape[1]==hidden_states.shape[1]

        # BertOutput
        intermediate_states = self.dropout_intm(self.output_intm(intermediate_states))
        out_states = self.lnorm_intm(intermediate_states + hidden_states)

        #inputs = x_l1, x_l1_reduced, x_l2_prev
        #- x2 = BertCrossAttention(k,v=x_l1, q= cat(x_l1_reduced, x_l2_prev) ) -notice three inputs
        #- x3 = lnorm(drop(f(x2)) + x_l2_prev)
        #- x4_ex = activation( f(cat(x3, x_l1_reduced))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)
        return out_states
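To make the data flow in `BertReduceAddIntegrativeLayer.forward` easier to follow, the sketch below traces the tensor shapes at each step. `reduce_add_shapes` is a hypothetical helper, not part of the notebook. The example numbers assume a 768-dim base model with the `make_config` defaults (mid-res hidden size 384, intermediate size 1536) and assume the low-res query is half as long as the high-res `inputs` (8 tokens reduced to 4). The attention output is assumed to have `hidden_size` features, which is what `output_attn`, an `nn.Linear(hidden_size, hidden_size)`, requires.

```python
# Hypothetical helper: shape bookkeeping for BertReduceAddIntegrativeLayer.forward.

def reduce_add_shapes(batch, seq_high, seq_low, hidden_size,
                      hidden_size_input=None, hidden_size_query=None,
                      intermediate_size=None):
    # same defaults as BertReduceAddIntegrativeLayer.__init__
    if hidden_size_input is None:
        hidden_size_input = hidden_size
    if hidden_size_query is None:
        hidden_size_query = hidden_size_input
    if intermediate_size is None:
        intermediate_size = 4 * hidden_size
    return {
        # query = cat(query_hidden_states, hidden_states) along the feature axis
        "query": (batch, seq_low, hidden_size_query + hidden_size),
        # keys/values are the long, high-resolution `inputs`
        "keyvalue": (batch, seq_high, hidden_size_input),
        # attention output and first Add & Norm keep the short sequence length
        "after_attention": (batch, seq_low, hidden_size),
        # BertIntermediate input = cat(hidden_states, query_hidden_states)
        "intermediate_in": (batch, seq_low, hidden_size + hidden_size_query),
        "intermediate_out": (batch, seq_low, intermediate_size),
        # BertOutput projects back down and adds the skip connection
        "output": (batch, seq_low, hidden_size),
    }


shapes = reduce_add_shapes(batch=2, seq_high=8, seq_low=4, hidden_size=384,
                           hidden_size_input=768, hidden_size_query=768,
                           intermediate_size=1536)
for name, shape in shapes.items():
    print(name, shape)
```

The key property is that the sequence length is set by the query (low-res), while the keys/values only contribute their feature content; the output therefore lives at the reduced resolution.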

try:
    from transformers.modeling_utils import get_extended_attention_mask
except ImportError:
    def get_extended_attention_mask(self, attention_mask: torch.Tensor, input_shape: Tuple[int], device: torch.device) -> torch.Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
                            ),
                            causal_mask,
                        ],
                        axis=-1,
                    )

                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask
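The fallback above relies on two tricks: a binary mask becomes an additive one via `(1.0 - mask) * -10000.0`, so that masked positions receive almost zero probability after the softmax, and the causal mask keeps key `j` visible to query `i` only when `j <= i`. The plain-Python sketch below (no torch, single row of scores) illustrates both; the score values are made up for the example.

```python
import math


def extended_mask(mask, large_negative=-10000.0):
    # 1.0 (attend) -> 0.0, 0.0 (ignore) -> large_negative
    return [(1.0 - m) * large_negative for m in mask]


def softmax(scores):
    top = max(scores)                      # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def causal_mask(seq_length):
    # same comparison as `seq_ids[None, :] <= seq_ids[:, None]`:
    # row i (query) may attend to column j (key) only if j <= i
    return [[1.0 if j <= i else 0.0 for j in range(seq_length)] for i in range(seq_length)]


scores = [2.0, 1.0, 3.0]
padding = [1.0, 1.0, 0.0]                  # last token is padding
probs = softmax([s + m for s, m in zip(scores, extended_mask(padding))])
print(probs)                               # the padded position gets ~0 probability
print(causal_mask(3))
```

Because the padded score drops by 10000 before the softmax, the remaining probabilities match a softmax over the unpadded scores alone, which is why adding the mask is "effectively the same as removing these entirely". Note that the fallback is written as a method: it reads `self.config.is_decoder` and `self.dtype`, so it has to be bound to a model-like object to be called.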



In [None]:


# how does bert actually work?
"""
input = x

BertLayer:
- BertAttention
--- x2 = BertSelfAttention(x)
--- x3 = BertSelfOutput(x2,x) -> lnorm(drop(f(x2)) + x)
- BertIntermediate (expansion:  4*hidden_size)
--- x4_ex = activation(f(x3)) # expansion (4*)
- BertOutput
--- x5 = lnorm(drop(f(x4_ex)) + x3 )


inputs = x_l2, x_l3_up

BertIntegrativeLayer:
- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
- x3 = lnorm(drop(f(x2)) + x_l2)
- x4_ex = activation( f(cat(x3, x_l3_up))  )
- x5 = lnorm(drop(f(x4_ex)) + x3)
"""
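Both `BertLayer` and the integrative layers repeat the post-LayerNorm "Add & Norm" step, e.g. `x3 = lnorm(drop(f(x2)) + x)` and `x5 = lnorm(drop(f(x4_ex)) + x3)`. The plain-Python sketch below shows that step for a single token vector. It assumes evaluation mode (dropout is the identity), omits LayerNorm's learned gain and bias, and uses a made-up sublayer output `fx2`.

```python
import math


def layer_norm(x, eps=1e-12):
    # normalise one token vector to zero mean / unit variance (no learned gain or bias)
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def add_and_norm(sublayer_out, residual, eps=1e-12):
    # x3 = lnorm(drop(f(x2)) + x), with dropout treated as the identity (eval mode)
    return layer_norm([a + b for a, b in zip(sublayer_out, residual)], eps)


x = [1.0, 2.0, 3.0, 4.0]        # residual input for one token
fx2 = [0.5, -0.5, 0.5, -0.5]    # hypothetical projected sublayer output f(x2)
x3 = add_and_norm(fx2, x)
print(x3)                       # approximately [-1.0, -1.0, 1.0, 1.0]
```

The residual path is what lets each integrative layer start close to "pass the hidden states through", with the cross-attention and intermediate blocks adding corrections on top.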


class BertIntegrativeLayer(nn.Module):
    """Vanilla Bert layer that also integrates hidden states from a parallel (typically low-res) transformer stack"""
    def __init__(
            self,
            config,
            hidden_size, # dimensions of the (high-res) hidden states; same dimension as output
            hidden_size_keyvalues, # dimensions of (low-res) states used as key/values; 1/2 sequence-length and dim
            hidden_size_query_to_concat=None, # dimensions of (low-res) to concat to hidden_states; 1/2 sequence-length and dim
            intermediate_size=None
        ):
        super().__init__()
        #self.chunk_size_feed_forward = config.chunk_size_feed_forward
        #self.seq_len_dim = 1
        self.cat = torch.cat
        self.hidden_size = hidden_size
        self.hidden_size_keyvalues = hidden_size_keyvalues
        if hidden_size_query_to_concat is None:
            hidden_size_query_to_concat = hidden_size_keyvalues
        self.hidden_size_query_to_concat = hidden_size_query_to_concat
        self.hidden_size_query = int(hidden_size + hidden_size_query_to_concat)
        self.hidden_size_concat = int(hidden_size + hidden_size_query_to_concat)
        if intermediate_size is None:
            intermediate_size = int(4*hidden_size)
        self.intermediate_size = intermediate_size

        # cross attention between (low-res) query and hidden layers below
        self.attention = BertCrossAttention(
            config,
            hidden_size= self.hidden_size, # high dim output
            hidden_size_query = self.hidden_size_query, # high dim query
            hidden_size_keyvalue = self.hidden_size_keyvalues, # low-dim keyvalues
            position_embedding_type="absolute"
        )
        self.is_decoder = config.is_decoder
        #self.intermediate = BertIntermediate(config)
        #self.output = BertOutput(config)
        #- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
        #- x3 = lnorm(drop(f(x2)) + x_l2)
        #- x4_ex = activation( f(cat(x3, x_l3_up))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)

        # corresponds to BertAttention SelfOutput
        self.output_attn = nn.Linear(self.hidden_size, self.hidden_size)
        self.lnorm_attn = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_attn = nn.Dropout(config.hidden_dropout_prob)

        # corresponds to BertIntermediate
        self.intermediate = nn.Linear(self.hidden_size_concat, self.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        # corresponds to BertOutput
        self.output_intm = nn.Linear(self.intermediate_size, self.hidden_size)
        self.lnorm_intm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout_intm = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor, # high-res hidden states (same dimensions as output), used as query
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        keyvalue_hidden_states: torch.Tensor=None, # low-res hidden-states (1/2 seq-dim) used for key-value pairs
        query_to_concat_hidden_states: torch.Tensor=None, # to concatenate to query
        query_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        # cross attn: (high-res) hidden states, concatenated with the upscaled low-res states, query the (low-res) key/values
        cross_attn_outputs = self.attention(
            hidden_states = keyvalue_hidden_states,
            attention_mask = attention_mask,
            head_mask = head_mask,
            query_hidden_states = torch.cat((hidden_states, query_to_concat_hidden_states),axis=2),
            query_attention_mask = query_attention_mask
        )
        cross_hidden_states = cross_attn_outputs[0]
        assert cross_hidden_states.shape[1]==hidden_states.shape[1], f"{cross_hidden_states.shape[1]},{cross_hidden_states.shape[2]} vs {hidden_states.shape[1]},{hidden_states.shape[2]}"
        assert cross_hidden_states.shape[2]==hidden_states.shape[2]


        # first Add+Norm skip connection (BertSelfOutput)
        cross_hidden_states = self.output_attn(cross_hidden_states)
        cross_hidden_states = self.dropout_attn(cross_hidden_states)
        hidden_states = self.lnorm_attn(cross_hidden_states + hidden_states)

        # intermediate expansion
        intermediate_states = self.cat((hidden_states, query_to_concat_hidden_states),axis=2)
        intermediate_states = self.intermediate(intermediate_states)
        intermediate_states = self.intermediate_act_fn(intermediate_states)
        assert intermediate_states.shape[0]==hidden_states.shape[0]
        assert intermediate_states.shape[1]==hidden_states.shape[1]

        # BertOutput
        out_states = self.output_intm(intermediate_states)
        out_states = self.dropout_intm(out_states)
        out_states = self.lnorm_intm(out_states + hidden_states)

        #- x2 = BertCrossAttention(k,v=x_l2, q=x_l3_up)
        #- x3 = lnorm(drop(f(x2)) + x_l2)
        #- x4_ex = activation( f(cat(x3, x_l3_up))  )
        #- x5 = lnorm(drop(f(x4_ex)) + x3)
        return out_states
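The shape assertions in `BertIntegrativeLayer.forward` hold because, in cross-attention, the output has one row per query position regardless of how many key/value positions there are. The sketch below is a single-head, plain-Python scaled dot-product attention (no learned projections, no heads, no mask) with 4 high-res query positions attending over 2 low-res key/value positions. One consequence worth noting: since the scores have shape `(..., query_len, key_len)`, any `attention_mask` passed to this layer must be laid out over the low-res key positions.

```python
import math


def cross_attention(queries, keys, values):
    """Single-head scaled dot-product attention on plain lists.

    queries: Lq x d, keys: Lk x d, values: Lk x dv  ->  output: Lq x dv
    """
    d = len(queries[0])
    out = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # weighted sum of the value vectors, one output row per query
        out.append([sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))])
    return out


# high-res queries: 4 positions; low-res keys/values: 2 positions (half the sequence)
queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0]]
out = cross_attention(queries, keys, values)
for row in out:
    print(row)
```

In the actual layer the query is `cat(hidden_states, query_to_concat_hidden_states)` passed through `self.query`, and keys/values are projections of `keyvalue_hidden_states`; the sketch skips those projections to isolate the length bookkeeping.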



In [None]:



class CheapMLPIntegrativeLayer(nn.Module):
    """Cheap (non-transformer) integrator layer that merges a (low-res) layer with a higher-res one"""
    def __init__(
            self,
            config,
            hidden_size, # dimensions of the (high-res) hidden states; same dimension as output
            hidden_size_keyvalues=None, # dimensions of (low-res) states used as key/values; 1/2 sequence-length and dim
            hidden_size_query_to_concat=None, # dimensions of (low-res) to concat to hidden_states; 1/2 sequence-length and dim
            intermediate_size=None
        ):
        super().__init__()
        #self.chunk_size_feed_forward = config.chunk_size_feed_forward
        #self.seq_len_dim = 1
        self.cat = torch.cat
        self.hidden_size = hidden_size
        if hidden_size_keyvalues is None:
            hidden_size_keyvalues = hidden_size
        self.hidden_size_keyvalues = hidden_size_keyvalues
        if hidden_size_query_to_concat is None:
            hidden_size_query_to_concat = hidden_size_keyvalues
        self.hidden_size_query_to_concat = hidden_size_query_to_concat
        self.hidden_size_query = int(hidden_size + hidden_size_query_to_concat)
        if intermediate_size is None:
            intermediate_size = int(2*hidden_size)
        self.intermediate_size = intermediate_size

        # expand the concatenated features to a multiple of the hidden size
        self.dense_expander = nn.Linear(
            self.hidden_size_query,
            self.intermediate_size
        )
        # deflate back to the same size as the hidden states
        self.dense_deflator = nn.Linear(
            self.intermediate_size,
            self.hidden_size
        )

        # intermediate activation function
        self.intermediate_act_fn = nn.RReLU(0.0625, 0.125)

        # corresponds to BertOutput
        self.lnorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor, # high-res hidden states (same dimensions as output), used as query
        attention_mask = None, # ignored
        head_mask = None, # ignored
        keyvalue_hidden_states =None, # ignored
        query_to_concat_hidden_states: torch.Tensor=None, # to concatenate to hidden_states
        query_attention_mask = None, # ignored
        past_key_value = None, # ignored
        output_attentions = False, # ignored
    ) -> torch.Tensor:

        # concat (lowres) to hidden-states
        inputs = self.cat((hidden_states, query_to_concat_hidden_states),axis=2)
        # expand x2 dimension
        intermediate_states = self.dense_expander(inputs)
        # activation (randomized leaky ReLU)
        intermediate_states = self.intermediate_act_fn(intermediate_states)
        # like BertOutput
        out_states = self.dense_deflator(intermediate_states)
        # dropout
        out_states = self.dropout(out_states)
        # combine with hidden-state inputs
        out_states = self.lnorm(out_states + hidden_states)

        return out_states
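`CheapMLPIntegrativeLayer` uses `nn.RReLU(0.0625, 0.125)`, a randomized leaky ReLU. Per the PyTorch documentation, during training the negative slope is sampled uniformly from `[lower, upper]` for each element, and during evaluation it is fixed to `(lower + upper) / 2`. The scalar sketch below is a plain-Python illustration of that behaviour, not PyTorch's implementation.

```python
import random


def rrelu(x, lower=0.0625, upper=0.125, training=True, rng=random):
    """Randomized leaky ReLU for one value, mirroring nn.RReLU's documented behaviour."""
    if x >= 0:
        return x
    if training:
        slope = rng.uniform(lower, upper)   # a fresh slope is drawn for each element
    else:
        slope = (lower + upper) / 2         # fixed, averaged slope at eval time
    return slope * x


print(rrelu(2.0))                    # positive inputs pass through unchanged
print(rrelu(-1.0, training=False))   # -0.09375 with the notebook's bounds
print(rrelu(-1.0, training=True))    # somewhere in [-0.125, -0.0625]
```

With these bounds the negative slope is small (1/16 to 1/8), so the layer behaves much like a leaky ReLU while the sampling adds a little regularising noise during training.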



In [None]:

def make_config(
    modelstring = "distilroberta-base",
    num_transformer_stacks = 3,
    scale_ratio2 = 0.5,
    scale_ratio3 = 0.25,
    multiplier_intermediate2 = 4.0,
    multiplier_intermediate3 = 4.0,
    num_layers_l2 = 1, # mid-res encoder
    num_layers_l3 = 3, # low-res encoder
    dropout_scaling = 0.05,
    do_cheap_integrator = [1],
    sequence_classification_intermediate_dim = None, # default: [hidden_size, hidden_size*scale_ratio2, hidden_size*scale_ratio3] of the basemodel
    sequence_classification_out_dim = None, # default is 2x the basemodel hidden-dim
    do_mlm =False,
    do_cls = False
):
    #if True:
    #modelstring = "distilroberta-base"
    #scale_ratio2 = 0.5
    #scale_ratio3 = 0.25
    #scale_intermediate2 = 4
    #scale_intermediate3 = 4
    base_config = AutoConfig.from_pretrained(modelstring)
    config_l2 = copy.deepcopy(base_config)
    config_l3 = copy.deepcopy(base_config)
    setattr(base_config, 'model_string', modelstring)
    setattr(base_config,'num_transformer_stacks', num_transformer_stacks)
    setattr(base_config,'num_layers_l2', num_layers_l2)
    setattr(base_config,'num_layers_l3', num_layers_l3)
    setattr(base_config,'scale_ratio2', scale_ratio2)
    setattr(base_config,'scale_ratio3', scale_ratio3)
    setattr(base_config,'scale_factor2', int(1/base_config.scale_ratio2))
    setattr(base_config,'scale_factor3', int(1/base_config.scale_ratio3*base_config.scale_ratio2))
    setattr(base_config,"hidden_size_l2", int(base_config.hidden_size * scale_ratio2))
    setattr(base_config,"hidden_size_l3", int(base_config.hidden_size * scale_ratio3))
    setattr(base_config,"intermediate_size_l1", int(base_config.hidden_size_l2*multiplier_intermediate2))
    setattr(base_config,"intermediate_size_l2", int(base_config.hidden_size_l3*multiplier_intermediate3))
    setattr(base_config,"query_size1", base_config.hidden_size_l2 + base_config.hidden_size_l3)
    setattr(base_config,"query_size2", base_config.hidden_size_l3)
    setattr(base_config,"dropout_scaling", dropout_scaling)
    setattr(base_config,"use_cheap_integrator_for_stacks", do_cheap_integrator)
    setattr(base_config, "do_mlm", do_mlm)
    setattr(base_config, "do_cls", do_cls)

    # hidden dimension
    setattr(
        base_config,
        "sequence_classification_intermediate_dim",
        sequence_classification_intermediate_dim  if sequence_classification_intermediate_dim is not None else [
            int(base_config.hidden_size*s)
            for s in [1, scale_ratio2, scale_ratio3]
        ]
    )
    # final dimension output for sequence classification
    setattr(
        base_config,
        "sequence_classification_out_dim",
        sequence_classification_out_dim  if sequence_classification_out_dim is not None else base_config.hidden_size*2
    )


    # make the configuration for the l2 mid-res encoder
    config_l2.hidden_size = base_config.hidden_size_l2
    config_l2.num_hidden_layers = num_layers_l2
    setattr(base_config, 'config_l2', config_l2)

    # make the configuration for the l3 encoder
    config_l3.hidden_size = base_config.hidden_size_l3
    config_l3.num_hidden_layers = num_layers_l3
    setattr(base_config, 'config_l3', config_l3)
    return base_config
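The derived sizes that `make_config` attaches to `base_config` are simple arithmetic on the base model's `hidden_size`. The sketch below recomputes them without downloading a config; `derived_sizes` is a hypothetical helper, and `hidden_size=768` is assumed because that is the hidden size of `distilroberta-base`, the default `modelstring`.

```python
def derived_sizes(hidden_size=768, scale_ratio2=0.5, scale_ratio3=0.25,
                  multiplier_intermediate2=4.0, multiplier_intermediate3=4.0):
    """Hypothetical helper: the size attributes make_config computes, for a given hidden_size."""
    hidden_size_l2 = int(hidden_size * scale_ratio2)
    hidden_size_l3 = int(hidden_size * scale_ratio3)
    return {
        "scale_factor2": int(1 / scale_ratio2),             # full-res -> mid-res
        # make_config writes int(1/scale_ratio3*scale_ratio2); left-to-right
        # evaluation makes that the same ratio: mid-res -> low-res
        "scale_factor3": int(scale_ratio2 / scale_ratio3),
        "hidden_size_l2": hidden_size_l2,
        "hidden_size_l3": hidden_size_l3,
        "intermediate_size_l1": int(hidden_size_l2 * multiplier_intermediate2),
        "intermediate_size_l2": int(hidden_size_l3 * multiplier_intermediate3),
        "query_size1": hidden_size_l2 + hidden_size_l3,
        "query_size2": hidden_size_l3,
        "sequence_classification_intermediate_dim": [
            int(hidden_size * s) for s in (1, scale_ratio2, scale_ratio3)
        ],
        "sequence_classification_out_dim": hidden_size * 2,
    }


sizes = derived_sizes()
for key, value in sizes.items():
    print(key, value)
```

With the defaults, each resolution level halves both the sequence length and the embedding dimension (768 → 384 → 192), and both scale factors come out as 2.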

def initialize_baselayers(config, basemod = None, tokenizer=None, stack_id=0):
    """Initializes the embeddings and first stack of layers for the Anathem transformers"""
    # initialize the basemodel
    if basemod is None:
        basemod = AutoModel.from_pretrained(config.model_string)
    if tokenizer is None:
        # download pretrained tokenizer
        tokenizer = AutoTokenizer.from_pretrained(config.model_string)

    device = basemod.device
    setattr(config, 'device', device)

    # get basemodel's embeddings
    layer_embedding = copy.deepcopy(basemod._modules['embeddings'])

    # get basemodel's first transformer block
    layer_basetransformer = copy.deepcopy(basemod._modules['encoder']._modules['layer']._modules['0'])

    # initialize the maxpooling downsamplers
    maxpool = nn.Sequential(
        nn.Dropout(config.dropout_scaling),
        nn.MaxPool2d((2,1), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)
    )
    # pooling the attention has no dropout
    maxpool_attn = nn.MaxPool1d((2), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)

    # initialize downsampling attention layers
    bert_reducer_l2 = BertSelfAttnDimensionReduction(
        config=config,
        hidden_size_input=config.hidden_size,
        position_embedding_type=config.position_embedding_type,
        dim_reduction = config.scale_factor2
    )
    # 1/4 hidden size
    bert_reducer_l3 = BertSelfAttnDimensionReduction(
        config=config,
        hidden_size_input=config.hidden_size_l2,
        position_embedding_type=config.position_embedding_type,
        dim_reduction = config.scale_factor3
    )

    # initialize the mid-resolution BertEncoder
    bert_encoder_midres = BertEncoder(config.config_l2)
    # initialize the low-resolution BertEncoder
    bert_encoder_lowres = BertEncoder(config.config_l3)

    # initialize the upscalers
    upscaler_x2 = InterpolateCombo(scale_factor=config.scale_factor3, dropout=config.dropout_scaling)
    upscaler_x4 = InterpolateCombo(scale_factor=int(1/config.scale_ratio3), dropout=config.dropout_scaling)

    # initialize the BertIntegrative Layers: low res to mid res
    bert_integrater_l2 = BertIntegrativeLayer(
        config,
        hidden_size=config.hidden_size_l2,
        hidden_size_keyvalues = config.hidden_size_l3,
        hidden_size_query_to_concat=config.hidden_size_l3,
        intermediate_size=config.intermediate_size_l2
    )

    # from mid-res to high-res: a full BertIntegrativeLayer, or the cheap MLP integrator for the stacks listed in config
    do_cheap_integrator = (stack_id in config.use_cheap_integrator_for_stacks)
    if not do_cheap_integrator:
        bert_integrater_l1 = BertIntegrativeLayer(
            config,
            hidden_size=config.hidden_size,
            hidden_size_keyvalues = config.hidden_size_l2,
            hidden_size_query_to_concat=config.hidden_size_l2,
            intermediate_size=config.intermediate_size_l1
        )
    else:
        bert_integrater_l1 = CheapMLPIntegrativeLayer(
            config,
            hidden_size=config.hidden_size,
            hidden_size_query_to_concat=config.hidden_size_l2,
            intermediate_size=config.hidden_size*2
        )

    return (
        tokenizer,
        basemod,
        layer_embedding,
        layer_basetransformer,
        maxpool,
        maxpool_attn,
        bert_reducer_l2,
        bert_reducer_l3,
        bert_encoder_midres,
        bert_encoder_lowres,
        upscaler_x2,
        upscaler_x4,
        bert_integrater_l2,
        bert_integrater_l1
    )
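The downsamplers built here halve the sequence length. Assuming `maxpool` is applied to `(batch, seq_len, hidden)` tensors, `nn.MaxPool2d((2, 1), ceil_mode=True)` reads them as `(C, H, W)` and takes the max over pairs of adjacent tokens while leaving the hidden dimension intact; `ceil_mode=True` keeps a trailing odd token as its own window. Likewise `maxpool_attn` (`nn.MaxPool1d(2)`) on a float `(batch, seq_len)` mask marks a pooled position as attendable if either source token was. The plain-Python sketch below mimics both for a single sequence; the token values are made up.

```python
def maxpool_sequence(x, kernel=2):
    """Max-pool a (seq_len x dim) list of vectors over non-overlapping windows of
    `kernel` positions; a trailing partial window is kept, like ceil_mode=True."""
    pooled = []
    for start in range(0, len(x), kernel):
        window = x[start:start + kernel]
        pooled.append([max(col) for col in zip(*window)])   # element-wise max per feature
    return pooled


def maxpool_mask(mask, kernel=2):
    # a pooled position is attendable if any of its source tokens is attendable
    return [max(mask[i:i + kernel]) for i in range(0, len(mask), kernel)]


x = [[1, 5], [3, 2], [0, 0], [4, -1], [7, 7]]   # 5 tokens, 2 features
mask = [1, 1, 1, 1, 0]                          # last token is padding
print(maxpool_sequence(x))                      # [[3, 5], [4, 0], [7, 7]]
print(maxpool_mask(mask))                       # [1, 1, 0]
```

Pooling the mask with the same window as the hidden states keeps the two aligned, so the low-res attention mask has exactly `ceil(seq_len / 2)` positions.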

def initialize_midlayers(config, basemod=None, tokenizer=None, stack_id=1):
    """Initializes all the intermediate layers for the Anathem transformers"""
    # initialize the maxpooling downsamplers
    maxpool = nn.Sequential(
        nn.Dropout(config.dropout_scaling),
        nn.MaxPool2d((2,1), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)
    )
    # pooling the attention has no dropout
    maxpool_attn = nn.MaxPool1d((2), stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=True)

    # initialize bert attentive downsampling and skip connection (1/2 embedding dim)
    bert_reduceintegrator_l2 = BertReduceAddIntegrativeLayer(
        config,
        config.hidden_size_l2, # size of mid-res
        hidden_size_input=config.hidden_size, # size full-resolution
        hidden_size_query=config.hidden_size, # size full-resolution
        intermediate_size=config.intermediate_size_l1, # BertIntermediate dimension (expansion: 4x the mid-res hidden size by default)
        dim_reduction=config.scale_factor2, # reduce embedding dimension by factor of 2
        do_concat_hidden_and_query = True
    )

    # 1/4 the size
    bert_reduceintegrator_l3 = BertReduceAddIntegrativeLayer(
        config,
        config.hidden_size_l3, # size of low-res
        hidden_size_input=config.hidden_size_l2, # size of mid-res
        hidden_size_query=config.hidden_size_l2, # size of mid-res
        intermediate_size=config.intermediate_size_l2, # BertIntermediate dimension
        dim_reduction=config.scale_factor3, # reduce embedding dimension by scale_factor3 (2 by default)
        do_concat_hidden_and_query = True
    )

    # initialize the mid- and low-resolution BertEncoders
    bert_encoder_midres = BertEncoder(config.config_l2)
    bert_encoder_lowres = BertEncoder(config.config_l3)

    # initialize the upscalers
    upscaler_x2 = InterpolateCombo(scale_factor=config.scale_factor3, dropout=config.dropout_scaling)
    upscaler_x4 = InterpolateCombo(scale_factor=int(1/config.scale_ratio3), dropout=config.dropout_scaling)

    # initialize the BertIntegrative Layers: from low-res to mid-res
    bert_integrater_l2 = BertIntegrativeLayer(
        config,
        hidden_size=config.hidden_size_l2,
        hidden_size_keyvalues = config.hidden_size_l3,
        hidden_size_query_to_concat=config.hidden_size_l3,
        intermediate_size=config.intermediate_size_l2
    )

    do_cheap_integrator = (stack_id in config.use_cheap_integrator_for_stacks)
    if not do_cheap_integrator:
        # from mid-res to high-res
        bert_integrater_l1 = BertIntegrativeLayer(
            config,
            hidden_size=config.hidden_size,
            hidden_size_keyvalues = config.hidden_size_l2,
            hidden_size_query_to_concat=config.hidden_size_l2,
            intermediate_size=config.intermediate_size_l1
        )
    else:
        bert_integrater_l1 = CheapMLPIntegrativeLayer(
            config,
            hidden_size=config.hidden_size,
            hidden_size_query_to_concat=config.hidden_size_l2,
            intermediate_size=config.hidden_size*2
        )

    return (
        maxpool,
        maxpool_attn,
        bert_reduceintegrator_l2,
        bert_reduceintegrator_l3,
        bert_encoder_midres,
        bert_encoder_lowres,
        upscaler_x2,
        upscaler_x4,
        bert_integrater_l2,
        bert_integrater_l1
    )
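Each stack works on three "silos": the high-res silo keeps the full sequence length, and each lower silo halves it via `maxpool` while `dim_reduction` shrinks the hidden size. The pure-Python sketch below tabulates the resulting shapes. `silo_shapes` is a hypothetical helper, and it assumes a factor of 2 for both sequence length and hidden size at every level (matching the "factor of 2" comments above, but worth checking against `make_config()`). It also illustrates why the tokenizer pads to a multiple of `int(1/config.scale_ratio3)`.

```python
def silo_shapes(seq_len, hidden_size, scale_factor=2, n_silos=3):
    """Return (seq_len, hidden_size) for each silo, high-res first.

    Hypothetical helper: assumes every level divides both the sequence length
    and the hidden size by the same integer `scale_factor`.
    """
    total_reduction = scale_factor ** (n_silos - 1)
    if seq_len % total_reduction != 0:
        # the lowest silo must still have an integer length
        raise ValueError(
            f"seq_len={seq_len} must be a multiple of {total_reduction}; "
            "pad the batch before the first stack"
        )
    shapes = []
    for level in range(n_silos):
        factor = scale_factor ** level
        shapes.append((seq_len // factor, hidden_size // factor))
    return shapes

print(silo_shapes(512, 768))  # [(512, 768), (256, 384), (128, 192)]
```

A sequence of length 510 raises an error here, which is the situation the tokenizer's `n_pad_to_multiple_of` setting is meant to prevent.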


def initialize_finaltransformerlayers(config, basemod=None, tokenizer=None, names_encoder_module = 'encoder', stack_id=3):
    """Initializes the final BertLayer before the output by deep-copying the last BertLayer of `basemod`"""
    # the final block reuses pretrained weights, so `basemod` is required
    assert basemod is not None, "`initialize_finaltransformerlayers` requires the basemod to instantiate the final transformer block"

    # get the Encoder stacks
    assert names_encoder_module in basemod._modules.keys(), 'expected %s in basemod._modules' % names_encoder_module
    basemod_encoder_stack = get_to_bertlayer(basemod, target_layer_name = names_encoder_module)

    # get the name of the final transformer block (-1) in encoder
    names_of_final_transformer_block = list(basemod_encoder_stack._modules['layer']._modules.keys())[-1]

    # get the final transformer block (NN weights pretrained)
    bert_finaltransformer_block = basemod_encoder_stack._modules['layer']._modules[
        names_of_final_transformer_block
    ]

    return copy.deepcopy(bert_finaltransformer_block)
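`initialize_finaltransformerlayers` returns a `copy.deepcopy` of the last pretrained layer rather than the layer itself. The minimal sketch below uses a toy `nn.ModuleList` as a stand-in for the pretrained encoder's layer list (an assumption: no BERT weights are loaded; for Hugging Face BERT models the same block is reachable as `encoder.layer[-1]`). It shows that the copy owns its own parameters, so training the Anathem final block does not modify the base model it was copied from.

```python
import copy

import torch
from torch import nn

# Toy stand-in for a pretrained encoder's layer list (hypothetical, randomly initialised)
encoder_layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])

# Deep-copy the last layer, as initialize_finaltransformerlayers does
final_block = copy.deepcopy(encoder_layers[-1])

# Modify the copy in place; the original layer must be unaffected
with torch.no_grad():
    final_block.weight.zero_()

print(torch.count_nonzero(final_block.weight).item())           # 0
print(torch.count_nonzero(encoder_layers[-1].weight).item() > 0)  # True
```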

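def get_to_bertlayer(basemod, target_layer_name = 'encoder', model_string = None):
    """Clumsily locates a particular layer within a pretrained BERT model.

    Returns `basemod.<target_layer_name>` if it is a direct child of `basemod`.
    Otherwise, if the target sits under a `bert` child (as in `BertForMaskedLM`),
    returns that `bert` module itself rather than the target, because callers
    also use its `get_extended_attention_mask` and `device`.
    """
    if target_layer_name in basemod._modules.keys():
        return basemod._modules[target_layer_name]
    elif 'bert' in basemod._modules.keys() and target_layer_name in basemod._modules['bert']._modules.keys():
        return basemod._modules['bert']
    raise KeyError('could not locate %s in %s' % (target_layer_name, type(basemod).__name__))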
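A more general lookup can walk the module tree with `nn.Module.named_modules()` instead of hard-coding the `bert` child. The sketch below is a hypothetical helper (`find_submodule`, not part of `transformers` or this notebook), exercised on toy modules that only mimic the attribute layout of `BertForMaskedLM`.

```python
from torch import nn

def find_submodule(model: nn.Module, name: str) -> nn.Module:
    """Return the first submodule whose dotted path is `name` or ends with `.name`.

    Hypothetical helper, not part of transformers or this notebook.
    """
    for path, module in model.named_modules():
        if path == name or path.endswith("." + name):
            return module
    raise KeyError(f"no submodule named {name!r}")

# Toy stand-ins mimicking the attribute layout of BertForMaskedLM (no real weights)
class ToyBertModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(10, 4)
        self.encoder = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])

class ToyBertForMaskedLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = ToyBertModel()
        self.cls = nn.Linear(4, 10)

model = ToyBertForMaskedLM()
print(find_submodule(model, "encoder") is model.bert.encoder)  # True
print(find_submodule(model, "cls") is model.cls)               # True
```

Unlike `get_to_bertlayer`, this returns the `encoder` itself rather than its `bert` parent. A caller that needs the wrapper would look up `"bert"` explicitly.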

In [None]:

class AnathemBaseModule(nn.Module):
    """First stack of layers (with embeddings) that goes full circle from high-res to low-res and back to high-res"""
    def __init__(
            self,
            config,
            basemod=None,
            tokenizer=None,
            past_key_values_length = None,
            device = None,
            stack_id=0
        ):
        super().__init__()
        self.config = config

        # initialize the layers
        (
            tokenizer, basemod,
            layer_embedding,
            layer_basetransformer,
            maxpool,
            maxpool_attn,
            bert_reducer_l2,
            bert_reducer_l3,
            bert_encoder_midres,
            bert_encoder_lowres,
            upscaler_x2,
            upscaler_x4,
            bert_integrater_l2,
            bert_integrater_l1
        ) = initialize_baselayers(config, basemod, tokenizer, stack_id=stack_id)

        self.get_extended_attention_mask = basemod.get_extended_attention_mask
        self.embedding = layer_embedding
        self.layer_basetransformer = layer_basetransformer
        self.maxpool = maxpool
        self.maxpool_attn = maxpool_attn
        self.bert_reducer_l2 = bert_reducer_l2
        self.bert_reducer_l3 = bert_reducer_l3
        self.bert_encoder_midres = bert_encoder_midres
        self.bert_encoder_lowres = bert_encoder_lowres
        self.upscaler_x2 = upscaler_x2
        self.upscaler_x4 = upscaler_x4
        self.bert_integrater_l2 = bert_integrater_l2
        self.bert_integrater_l1 = bert_integrater_l1
        self.stack_id = stack_id
        if device is None:
            self.to(basemod.device)
            #print(self.device)
            self.device = basemod.device
        else:
            self.to(device)
            self.device = device

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask_l2: Optional[torch.Tensor] = None,
        attention_mask_l3: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False
    ):
        input_shape = input_ids.size()
        past_key_values_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]

        # extend attention mask
        extended_attention_mask_l1 = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
        # downsample the attention mask to l2 dimension
        if attention_mask_l2 is None:
            attention_mask_l2 = self.maxpool_attn(attention_mask.float())
        extended_attention_mask_l2 = self.get_extended_attention_mask(attention_mask_l2,attention_mask_l2.shape, self.device)
        # downsample the attention mask to l3 dimension
        if attention_mask_l3 is None:
            attention_mask_l3 = self.maxpool_attn(attention_mask_l2.float())
        extended_attention_mask_l3 = self.get_extended_attention_mask(attention_mask_l3,attention_mask_l3.shape, self.device)

        # embed
        embedding_output = self.embedding(
            input_ids = input_ids,
            position_ids = position_ids,
            token_type_ids = token_type_ids,
            #input_embeds=None,
            past_key_values_length = past_key_values_length
        )

        # first transformer block (vanilla transformer)
        out_l1 = self.layer_basetransformer(
            hidden_states = embedding_output,
            attention_mask = extended_attention_mask_l1,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=output_attentions
        )
        hidden_states_l1 = out_l1[0]

        # downsample sequence 1 (full length) to the length of sequence 2 (1/2 seq-length)
        hiddens_states_l1_reduced = self.maxpool(hidden_states_l1)

        # reduce hidden dimension on sequence 2
        out_l2 = self.bert_reducer_l2(
            hidden_states = hiddens_states_l1_reduced,
            attention_mask = extended_attention_mask_l2,
            head_mask=head_mask,
            encoder_hidden_states = hidden_states_l1,
            encoder_attention_mask= extended_attention_mask_l1,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
        )
        hidden_states_l2 = out_l2[0]

        # Vanilla transformers block at mid-resolution (1/2 seq-length)
        out_encoder = self.bert_encoder_midres(
            hidden_states=hidden_states_l2,
            attention_mask=extended_attention_mask_l2,
            head_mask = head_mask,
            return_dict=return_dict
        )
        hidden_states_l2 = out_encoder[0]

        # reduce sequence length (1/4 seq-length)
        hiddens_states_l2_reduced = self.maxpool(hidden_states_l2)

        # reduce hidden dimension on sequence 3
        out_l3 = self.bert_reducer_l3(
            hidden_states = hiddens_states_l2_reduced,
            attention_mask = extended_attention_mask_l3,
            head_mask=head_mask,
            encoder_hidden_states = hidden_states_l2,
            encoder_attention_mask= extended_attention_mask_l2,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
        )
        hidden_states_l3 = out_l3[0]

        #print(hidden_states_l3.shape)
        #print(extended_attention_mask_l3.shape)
        # BertEncoder at low-res
        out_encoder = self.bert_encoder_lowres(
            hidden_states=hidden_states_l3,
            attention_mask=extended_attention_mask_l3,
            head_mask = head_mask,
            return_dict=return_dict
        )
        hidden_states_l3 = out_encoder[0]

        # upscaling: l3 to l2
        hidden_states_upscaled3to2 = self.upscaler_x2(hidden_states_l3)

        # integrate sequence-2 and upscaled sequence-3
        hidden_states_l2 = self.bert_integrater_l2(
            hidden_states = hidden_states_l2,
            attention_mask = extended_attention_mask_l3,
            head_mask = head_mask,
            keyvalue_hidden_states = hidden_states_l3,
            query_to_concat_hidden_states = hidden_states_upscaled3to2,
            query_attention_mask = attention_mask_l2
        )

        # upscaling: l3/l2 to l1 sequence length
        #hidden_states_upscaled3to1 = self.upscaler_x4(hidden_states_l3)
        hidden_states_upscaled2to1 = self.upscaler_x2(hidden_states_l2)
        #hidden_states_upscaled = torch.cat((
        #    hidden_states_upscaled2to1, hidden_states_upscaled3to1
        #),axis=2)

        # integrate low-resolution information back to original dimension
        hidden_states_l1 = self.bert_integrater_l1(
            hidden_states = hidden_states_l1,
            attention_mask = extended_attention_mask_l2,
            head_mask = head_mask,
            keyvalue_hidden_states = hidden_states_l2,
            query_to_concat_hidden_states = hidden_states_upscaled2to1,
            query_attention_mask = extended_attention_mask_l2
        )
        if not return_dict:
            return (
                (hidden_states_l1, hidden_states_l2, hidden_states_l3),
                (extended_attention_mask_l1, extended_attention_mask_l2, extended_attention_mask_l3),
                (attention_mask, attention_mask_l2, attention_mask_l3)
            )
        return {
            "hidden_states": (hidden_states_l1, hidden_states_l2, hidden_states_l3),
            "extended_attention_masks":(extended_attention_mask_l1, extended_attention_mask_l2, extended_attention_mask_l3),
            "attention_masks":(attention_mask, attention_mask_l2, attention_mask_l3)
        }
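`AnathemBaseModule.forward` derives the mid- and low-res masks by max-pooling the full-resolution `attention_mask` with `maxpool_attn`, then converts each into an additive "extended" mask. The sketch below assumes `maxpool_attn` is a kernel-2, stride-2 max-pool (its definition is not shown in this section), and `extend_mask` mirrors the additive-mask convention of Hugging Face's `get_extended_attention_mask` in recent versions. Both helpers are hypothetical. A pooled position stays visible if either token it merges is real.

```python
import torch
import torch.nn.functional as F

def downsample_mask(mask: torch.Tensor, factor: int = 2) -> torch.Tensor:
    """Max-pool a (batch, seq_len) 0/1 mask along the sequence dimension."""
    # F.max_pool1d expects (batch, channels, length): add a channel axis, pool, drop it again
    pooled = F.max_pool1d(mask.float().unsqueeze(1), kernel_size=factor, stride=factor)
    return pooled.squeeze(1)

def extend_mask(mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """Turn a (batch, seq_len) 0/1 mask into an additive (batch, 1, 1, seq_len) mask:
    0 where attention is allowed, a very large negative number on padding."""
    extended = mask[:, None, None, :].to(dtype)
    return (1.0 - extended) * torch.finfo(dtype).min

mask_l1 = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]])  # 5 real tokens, 3 padding
mask_l2 = downsample_mask(mask_l1)  # tensor([[1., 1., 1., 0.]])
mask_l3 = downsample_mask(mask_l2)  # tensor([[1., 1.]])
print(mask_l2, mask_l3, extend_mask(mask_l2).shape)
```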


class AnathemMidModule(nn.Module):
    """Stack of layers that goes full circle from high-res to low-res and back to high-res"""
    def __init__(
            self,
            config,
            basemod=None,
            tokenizer=None,
            past_key_values_length = None,
            device=None,
            stack_id = 1
        ):
        super().__init__()
        self.config = config

        # initialize the layers
        (
            maxpool,
            maxpool_attn,
            bert_reducerintegrator_l2,
            bert_reducerintegrator_l3,
            bert_encoder_midres,
            bert_encoder_lowres,
            upscaler_x2,
            upscaler_x4,
            bert_integrater_l2,
            bert_integrater_l1
        ) = initialize_midlayers(config, basemod, tokenizer, stack_id)

        self.get_extended_attention_mask = basemod.get_extended_attention_mask
        self.maxpool = maxpool
        self.maxpool_attn = maxpool_attn
        self.bert_reducerintegrator_l2 = bert_reducerintegrator_l2
        self.bert_reducerintegrator_l3 = bert_reducerintegrator_l3
        self.bert_encoder_midres = bert_encoder_midres
        self.bert_encoder_lowres = bert_encoder_lowres
        self.upscaler_x2 = upscaler_x2
        self.upscaler_x4 = upscaler_x4
        self.bert_integrater_l2 = bert_integrater_l2
        self.bert_integrater_l1 = bert_integrater_l1
        if device is None:
            self.to(basemod.device)
            #print(self.device)
            self.device = basemod.device
        else:
            self.to(device)
            self.device = device

    def forward(
        self,
        hidden_states_highres: torch.Tensor,
        hidden_states_midres: torch.Tensor,
        hidden_states_lowres: torch.Tensor,
        attention_mask: Optional[List[torch.FloatTensor]] = None,
        extended_attention_mask_highres: Optional[List[torch.FloatTensor]] = None,
        extended_attention_mask_midres: Optional[List[torch.FloatTensor]] = None,
        extended_attention_mask_lowres: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False
    ):
        input_shape = hidden_states_highres.shape[:2]
        past_key_values_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]

        # extend attention masks (only those not already supplied by the previous stack)
        if extended_attention_mask_highres is None:
            extended_attention_mask_highres = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
        if extended_attention_mask_midres is None or extended_attention_mask_lowres is None:
            # downsample the attention mask to the mid-res sequence length
            attention_mask_midres = self.maxpool_attn(attention_mask.float())
        if extended_attention_mask_midres is None:
            extended_attention_mask_midres = self.get_extended_attention_mask(attention_mask_midres, attention_mask_midres.shape, self.device)
        if extended_attention_mask_lowres is None:
            # downsample again to the low-res sequence length
            attention_mask_lowres = self.maxpool_attn(attention_mask_midres.float())
            extended_attention_mask_lowres = self.get_extended_attention_mask(attention_mask_lowres, attention_mask_lowres.shape, self.device)

        # downsample the high-res sequence to the mid-res sequence length
        hiddens_states_l1_reduced = self.maxpool(hidden_states_highres)

        # reduce the hidden dimension and integrate into the mid-res silo (sequence 2)
        hidden_states_l2 = self.bert_reducerintegrator_l2(
            inputs = hidden_states_highres, # high-res outputs of the previous stack (keys, values)
            hidden_states = hidden_states_midres, # previous mid-res hidden-states for the skip connection (1/2 sequence-length)
            attention_mask = extended_attention_mask_midres,
            head_mask=None,
            query_hidden_states = hiddens_states_l1_reduced
        )

        # Vanilla transformers at mid-resolution (1/2 sequence-length)
        out_encoder = self.bert_encoder_midres(
            hidden_states=hidden_states_l2,
            attention_mask=extended_attention_mask_midres,
            head_mask = None,
            return_dict=return_dict
        )
        hidden_states_l2 = out_encoder[0]

        # reduce sequence length (to 1/4 sequence-length)
        hiddens_states_l2_reduced = self.maxpool(hidden_states_l2)

        # reduce the hidden dimension and integrate into the low-res silo (sequence 3)
        hidden_states_l3 = self.bert_reducerintegrator_l3(
            inputs = hidden_states_midres, # mid-res outputs of the previous stack (keys, values)
            hidden_states = hidden_states_lowres, # previous low-res hidden-states for the skip connection (1/4 sequence-length)
            attention_mask = extended_attention_mask_lowres,
            head_mask=None,
            query_hidden_states = hiddens_states_l2_reduced
        )

        # BertEncoder at low-res
        out_encoder = self.bert_encoder_lowres(
            hidden_states=hidden_states_l3,
            attention_mask=extended_attention_mask_lowres,
            head_mask = None,
            return_dict=return_dict
        )
        hidden_states_lowres = out_encoder[0]

        # upscaling: l3 to l2
        hidden_states_upscaled3to2 = self.upscaler_x2(hidden_states_lowres)

        # integrate sequence-2 and upscaled sequence-3
        hidden_states_midres = self.bert_integrater_l2(
            hidden_states = hidden_states_l2,
            attention_mask = extended_attention_mask_lowres,
            head_mask = None,
            keyvalue_hidden_states = hidden_states_lowres,
            query_to_concat_hidden_states = hidden_states_upscaled3to2
        )
        #hidden_states_midres = self.bert_integrative_layer_2(
        #    hidden_states = hidden_states_l2,
        #    attention_mask = extended_attention_mask_midres,
        #    head_mask = None,
        #    query_hidden_states = hidden_states_upscaled3to2)

        # upscaling: l3/l2 to l1 sequence length
        #hidden_states_upscaled3to1 = self.upscaler_x4(hidden_states_lowres)
        hidden_states_upscaled2to1 = self.upscaler_x2(hidden_states_midres)
        #hidden_states_upscaled = torch.cat((hidden_states_upscaled2to1, hidden_states_upscaled3to1),axis=2)

        # integrate low-resolution information back to original dimension
        hidden_states_highres = self.bert_integrater_l1(
            hidden_states = hidden_states_highres,
            attention_mask = extended_attention_mask_midres,
            head_mask = None,
            keyvalue_hidden_states = hidden_states_midres,
            query_to_concat_hidden_states = hidden_states_upscaled2to1
        )

        if not return_dict:
            return (
                (hidden_states_highres, hidden_states_midres, hidden_states_lowres),
                (extended_attention_mask_highres, extended_attention_mask_midres, extended_attention_mask_lowres)
            )
        return {
            "hidden_states": (hidden_states_highres, hidden_states_midres, hidden_states_lowres),
            "extended_attention_masks":(extended_attention_mask_highres, extended_attention_mask_midres, extended_attention_mask_lowres)
        }
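The upscalers (`InterpolateCombo`, defined elsewhere in the notebook) stretch a lower-resolution silo back onto the longer sequence before integration. The sketch below is a plain nearest-neighbour stand-in, not `InterpolateCombo` itself, whose interpolation modes and dropout are not shown here. The hypothetical `upsample_sequence` repeats each low-res position `scale_factor` times, so position `i` lines up with the high-res tokens `2i` and `2i+1` that a kernel-2 `maxpool` merged into it.

```python
import torch
import torch.nn.functional as F

def upsample_sequence(hidden: torch.Tensor, scale_factor: int = 2) -> torch.Tensor:
    """Stretch (batch, seq_len, hidden) to (batch, seq_len * scale_factor, hidden)
    by repeating each position (nearest-neighbour interpolation)."""
    # F.interpolate works on (batch, channels, length): move hidden to the channel axis
    x = hidden.transpose(1, 2)
    x = F.interpolate(x, scale_factor=scale_factor, mode="nearest")
    return x.transpose(1, 2)

low = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2)  # 3 positions, hidden size 2
up = upsample_sequence(low)
print(up.shape)     # torch.Size([1, 6, 2])
print(up[0, :, 0])  # tensor([0., 0., 2., 2., 4., 4.])
```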


class AnathemEncoder(nn.Module):
    """Anathem's core stacks of layers, from the embeddings to the final transformer block"""
    def __init__(
            self,
            config,
            basemod=None,
            tokenizer=None,
            past_key_values_length = None,
            device=None,
        ):
        super().__init__()
        self.config = config
        self.device = device

        # initialize embeddings and first stack
        self.anathem_base_stack = AnathemBaseModule(
            config,
            basemod,
            tokenizer,
            past_key_values_length,
            device,
        )

        # initialize all subsequent stacks
        self.anathem_mid_stack = nn.ModuleList([
            AnathemMidModule(
                config,
                basemod,
                tokenizer,
                past_key_values_length,
                device,
                stack_id = i
            ) for i in range(1, self.config.num_transformer_stacks)
        ])

        # initialize the final transformer modules
        self.final_transformer_block = initialize_finaltransformerlayers(
            config,
            basemod,
            tokenizer,
            stack_id=self.config.num_transformer_stacks+1
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask_l2: Optional[torch.Tensor] = None,
        attention_mask_l3: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = False
    ):

        # embed and run through first stack of transformers
        hidden_states, extended_attention_masks, attention_masks = self.anathem_base_stack(
            input_ids=input_ids,
            attention_mask=attention_mask,
            attention_mask_l2=attention_mask_l2,
            attention_mask_l3=attention_mask_l3,
            token_type_ids=token_type_ids, #: Optional[torch.Tensor] = None,
            position_ids=position_ids,#: Optional[torch.Tensor] = None,
            head_mask=head_mask,#: Optional[torch.Tensor] = None,
            inputs_embeds=None,#: Optional[torch.Tensor] = None,
            encoder_hidden_states=None,#: Optional[torch.Tensor] = None,
            encoder_attention_mask=None,#: Optional[torch.Tensor] = None,
            past_key_values=past_key_values,#: Optional[List[torch.FloatTensor]] = None,
            use_cache=use_cache,#: Optional[bool] = None,
            output_attentions=output_attentions,#: Optional[bool] = None,
            output_hidden_states=output_hidden_states,#: Optional[bool] = None,
            return_dict=False # the base stack's output is unpacked as a tuple
        )

        # middle stack of transformers
        for i, anathem_stack in enumerate(self.anathem_mid_stack):

            # run through each mid stack (stack_id 1 .. num_transformer_stacks-1)
            hidden_states, extended_attention_masks = anathem_stack(
                hidden_states_highres = hidden_states[0],
                hidden_states_midres = hidden_states[1],
                hidden_states_lowres = hidden_states[2],
                extended_attention_mask_highres = extended_attention_masks[0],
                extended_attention_mask_midres = extended_attention_masks[1],
                extended_attention_mask_lowres = extended_attention_masks[2]
            )

        # hidden states (high,med,low resolution)
        hidden_states_highres, hidden_states_midres, hidden_states_lowres = hidden_states

        # run through final transformer block (pretrained)
        out_final = self.final_transformer_block(
            hidden_states = hidden_states_highres,
            attention_mask = extended_attention_masks[0],
            head_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=output_attentions
        )
        #print(type(out_final))
        #print(len(out_final))
        hidden_states_highres = out_final[0]
        if not output_attentions:
            return (hidden_states_highres, hidden_states_midres, hidden_states_lowres), attention_masks

        attention_final = out_final[1]
        return (hidden_states_highres, hidden_states_midres, hidden_states_lowres), attention_masks, attention_final
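In `AnathemEncoder`, `num_transformer_stacks` counts the base stack plus the mid stacks; the pretrained final block is added on top. The pure-Python sketch below lists the blocks in the order `forward` runs them. `describe_stacks` is hypothetical: it builds no layers and assumes `use_cheap_integrator_for_stacks` is a collection of stack ids, as tested by `do_cheap_integrator` in the mid-stack initializer.

```python
def describe_stacks(num_transformer_stacks, use_cheap_integrator_for_stacks=()):
    """Return one description per block AnathemEncoder builds, in execution order.

    Hypothetical helper mirroring the loops in AnathemEncoder.__init__.
    """
    plan = ["stack 0: AnathemBaseModule (embeddings, first BertLayer, full silo cycle)"]
    for stack_id in range(1, num_transformer_stacks):
        # the l2 -> l1 integrator is chosen per stack id
        integrator = (
            "CheapMLPIntegrativeLayer"
            if stack_id in use_cheap_integrator_for_stacks
            else "BertIntegrativeLayer"
        )
        plan.append(f"stack {stack_id}: AnathemMidModule (l2->l1 via {integrator})")
    plan.append("final: deep copy of the last pretrained BertLayer (high-res silo only)")
    return plan

for line in describe_stacks(3, use_cheap_integrator_for_stacks=[2]):
    print(line)
```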


class BertGenericClassificationHead(nn.Module):
    """Instantiates a basic classification head that takes the CLS token and mean of the final layer for classification"""
    def __init__(self, config, n_classes = 1, activation = 'sigmoid', device=None):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size*2, n_classes)
        if activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'sigmoid':
            self.activation = torch.sigmoid
        elif activation == 'none':
            self.activation = lambda x: x
        else:
            raise ValueError("unsupported activation: %s" % activation)
        if device is not None:
            self.to(device)

    def forward(self, hidden_states, attention_mask) -> torch.Tensor:
        # Pool by concatenating the hidden state of the first (CLS) token
        # with a mean over the non-padding tokens (attention_mask == 1).
        output_vectors=[]
        first_token_tensor = hidden_states[:, 0]
        output_vectors.append(first_token_tensor)
        # mean pooling
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        sum_embeddings = torch.sum(hidden_states * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        output_vectors.append(sum_embeddings / sum_mask)
        # concatenate
        pooled_output = torch.cat(output_vectors, dim=1)
        #print(pooled_output.shape)
        logits = self.dense(pooled_output)
        return self.activation(logits)
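The pooling in `BertGenericClassificationHead.forward` can be checked by hand on a tiny tensor. The sketch below reproduces just the CLS-plus-masked-mean step as a standalone function (the name `cls_and_mean_pool` is hypothetical). The padded position is excluded from the mean, and the output width is `2 * hidden_size`, which is why `self.dense` takes `config.hidden_size*2` inputs.

```python
import torch

def cls_and_mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Concatenate the first-token vector with a mean over non-padding tokens."""
    cls = hidden_states[:, 0]
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq, 1)
    summed = (hidden_states * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid division by zero on empty rows
    return torch.cat((cls, summed / counts), dim=1)

hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # last position is padding
mask = torch.tensor([[1, 1, 0]])
print(cls_and_mean_pool(hidden, mask))  # tensor([[1., 2., 2., 3.]])
```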


class AnathemMultiSiloPooler(nn.Module):
    """
    Pools the token-embeddings along the sequence dimension into a final sentence-vector.
    The pooling occurs across all three 'silos' (high-, mid- and low-resolution).
    It combines the CLS (first) tokens and the mean-pooled tokens, each concatenated across silos.
    Use the pooled outputs as the input to any sequence classification head.
    """
    def __init__(
        self,
        config,
        dim_out = None,
        mean_activation = nn.Tanhshrink,
        out_activation = None,
        dims_in = None,
        p_dropout=None,
        device=None
    ):
        super().__init__()

        # dimensions of the hidden states being processed as inputs (one per silo)
        if dims_in is None:
            dims_in = getattr(config, 'sequence_classification_intermediate_dim', None)
        if dims_in is None:
            # fall back to the hidden sizes of the high-, mid- and low-res silos
            dims_in = [config.hidden_size, config.hidden_size_l2, config.hidden_size_l3]
        self.dims_in = dims_in
        self.dim_in = sum(dims_in)
        self.hidden_size = config.hidden_size
        if dim_out is None:
            dim_out = getattr(config, 'sequence_classification_out_dim', None)
        if dim_out is None:
            dim_out = config.hidden_size*2
        self.dim_out = dim_out

        # accept an activation class (e.g. nn.Tanhshrink) as well as an instance (e.g. nn.Tanhshrink())
        if isinstance(mean_activation, type):
            mean_activation = mean_activation()
        self.mean_activation = mean_activation

        # default dropout: the config's hidden dropout (nn.Dropout does not accept None)
        if p_dropout is None:
            p_dropout = config.hidden_dropout_prob

        #self.dense = nn.Linear(config.hidden_size*2, n_classes)
        if out_activation == 'none' or out_activation is None:
            self.activation = lambda x: x
        elif out_activation == 'tanh':
            self.activation = nn.Tanh()
        elif out_activation == 'relu':
            self.activation = nn.ReLU()
        elif out_activation == 'sigmoid':
            self.activation = torch.sigmoid
        else:
            raise ValueError("unsupported out_activation: %s" % out_activation)

        # linear layer operating on the concatenated CLS tokens from all silos
        self.cls_pooler = nn.Sequential(
            nn.Dropout(p_dropout),
            nn.Linear(self.dim_in, int(self.hidden_size)),
        )

        # pre-mean-pooling (shared across silos)
        #self.pre_poolers = [nn.Sequential(
        #    nn.Dropout(p_dropout),
        #    nn.Linear(dim,dim)
        #    ) for dim in self.dims_in
        # ]
        self.pre_poolers = nn.Sequential(
            nn.Dropout(p_dropout),
            self.mean_activation
        )

        # linear layer operating on the concatenated mean-pooled tokens from all silos
        self.mean_pooler = nn.Linear(self.dim_in, self.hidden_size)

        # move to device only after all sub-layers exist, otherwise they stay on the CPU
        if device is not None:
            self.to(device)

    def forward(self, hidden_states, attention_masks, excess_cls_ids=None) -> torch.Tensor:
        """Combines CLS token and mean-pooling for the sentence-vectorization"""

        # CLS/first-tokens from all silos, all concatenated together
        first_token_tensors = self._get_cls_tokens_all_silos(hidden_states)

        # mean pooling
        mean_pooled_tensors = self._mean_pool_all_silos(hidden_states, attention_masks, excess_cls_ids)

        # concatenate CLS and mean
        pooled_output = torch.concat((first_token_tensors, mean_pooled_tensors), axis=1)

        return self.activation(pooled_output)

    def _get_cls_token(self, hidden_state):
        """Grabs the CLS token from a hidden-states"""
        return hidden_state[:, 0]

    def _get_cls_tokens_all_silos(self, hidden_states):
        """Grabs the CLS token from all hidden_states"""
        first_tokens = [
            self._get_cls_token(hidden_state) for hidden_state in hidden_states
        ]
        # concat all first tokens
        all_first_tokens_cat = torch.cat(first_tokens, dim=1)
        # run the concatenated first-tokens through Dense
        all_first_tokens_out = self.cls_pooler(all_first_tokens_cat)
        return all_first_tokens_out

    def _mean_pool(self, hidden_state, attention_mask=None, excess_cls_id=None):
        """Pool along a sequence dimension (for just one silo)"""
        if excess_cls_id is None:
            excess_cls_id = attention_mask
        input_mask_expanded = excess_cls_id.unsqueeze(-1).expand(hidden_state.size()).float()
        sum_embeddings = torch.sum(hidden_state * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        return sum_embeddings / sum_mask

    def _mean_pool_all_silos(self, hidden_states, attention_masks=None, excess_cls_ids=None):
        """Pool along a sequence dimension (for all silos)"""
        if excess_cls_ids is None:
            excess_cls_ids = attention_masks

        # pre-pool: dropout and activation before pooling
        hidden_states = [
            self.pre_poolers(hidden_state) for hidden_state in hidden_states
        ]

        # mean pool each silo
        mean_pooled_states = [
            self._mean_pool(
                hidden_state=hidden_state, excess_cls_id=excess_cls_id
            ) for hidden_state, excess_cls_id
            in zip(hidden_states, excess_cls_ids)
        ]

        # concat all mean-pooled states
        all_mean_pooled_states = torch.cat(mean_pooled_states, dim=1)
        # run the concatenated meanpooled states through Dense
        all_mean_pooled_states = self.mean_pooler(all_mean_pooled_states)
        return all_mean_pooled_states
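The pooler's layer sizes follow from how the silos are concatenated. The CLS tokens of all three silos form a vector of width `sum(dims_in)`, and so do the per-silo mean vectors. Each is projected to `hidden_size`, so the pooled output has width `2 * hidden_size` (the default `dim_out`). The shape-only sketch below leaves out dropout and the Tanhshrink pre-pooling; the silo widths, sequence lengths and batch size are illustrative assumptions.

```python
import torch
from torch import nn

hidden_size = 8
dims_in = [8, 4, 2]  # assumed widths of the high-, mid- and low-res silos
batch = 2

# three silos with halving sequence lengths (16, 8, 4) and no padding
silos = [torch.randn(batch, 16 // 2**i, d) for i, d in enumerate(dims_in)]
masks = [torch.ones(batch, 16 // 2**i) for i in range(3)]

# concatenate the CLS tokens and the masked means across silos
cls_cat = torch.cat([h[:, 0] for h in silos], dim=1)
mean_cat = torch.cat(
    [(h * m.unsqueeze(-1)).sum(1) / m.sum(1, keepdim=True) for h, m in zip(silos, masks)],
    dim=1,
)

# project each to hidden_size, then concatenate
cls_proj = nn.Linear(sum(dims_in), hidden_size)
mean_proj = nn.Linear(sum(dims_in), hidden_size)
pooled = torch.cat((cls_proj(cls_cat), mean_proj(mean_cat)), dim=1)
print(cls_cat.shape, pooled.shape)  # torch.Size([2, 14]) torch.Size([2, 16])
```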


In [None]:
class AnathemTransformer(nn.Module):
    def __init__(
        self,
        config=None,
        device=None,
        do_mlm = None,
        do_cls = None
    ):
        super().__init__()

        # default config
        if config is None:
            config = make_config()
        self.config = config
        self.do_mlm = config.do_mlm if do_mlm is None else do_mlm
        self.do_cls = config.do_cls if do_cls is None else do_cls

        # device
        if device is None:
            if torch.cuda.is_available():
                device = torch.device('cuda')
            else:
                device = torch.device('cpu')
        self.device= device

        # get the base model (and its masked-LM head)
        self.model_string = self.config.model_string
        basemodelLM_pretrained = AutoModelForMaskedLM.from_pretrained(self.model_string)

        # get the pretrained BertModel (the `bert` module that wraps the encoder)
        basemod_pretrained = get_to_bertlayer(
            basemodelLM_pretrained,
            target_layer_name = 'encoder'
        )

        # make the tokenizer (based on pretrained)
        self.tokenizer = CustomTokenizer(
            model_string=self.config.model_string,
            n_cls_prepend = int(1/config.scale_ratio3),
            n_pad_to_multiple_of= int(1/config.scale_ratio3)
        )

        # make the Anathem encoder (embeddings, Anathem stacks and the final pretrained block)
        self.encoder = AnathemEncoder(
            self.config,
            basemod=basemod_pretrained,
            tokenizer=self.tokenizer,
            past_key_values_length = None,
            device=self.device,
        )

        # get the Pretrained maskedLM head
        if self.do_mlm:
            # perform maskedLM
            self.mlm = get_to_bertlayer(
                basemodelLM_pretrained,
                target_layer_name = 'cls'
            )
        else:
            self.mlm = lambda x : x

        # make the sequence-classification head
        if self.do_cls:
            self.pooler = AnathemMultiSiloPooler(
                config=self.config,
                mean_activation = nn.Tanhshrink(),
                dims_in = self.config.sequence_classification_intermediate_dim,
                p_dropout=self.config.hidden_dropout_prob,
                device=self.device
            )

    def _get_name(self):
        return 'ANATHEM_MODEL_FOR_MLM'

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask_l2: Optional[torch.Tensor] = None,
        attention_mask_l3: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        excess_cls_ids: Optional[torch.Tensor] = None,
        excess_cls_ids_l2: Optional[torch.Tensor] = None,
        excess_cls_ids_l3: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False
    ):

        # run through base-layer (embeddings, transformer-block, 1 anathem stack)
        outputs_encoder = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            attention_mask_l2=attention_mask_l2, # optional downsized attention mask for sequence-dim 1/2
            attention_mask_l3=attention_mask_l3, # optional downsized attention mask for sequence-dim 1/4
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=None, # note: forward()'s inputs_embeds / encoder_* arguments are not passed through
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False
        )
        if output_attentions:
            hidden_states, extended_attention_masks, attention = outputs_encoder
        else:
            hidden_states, extended_attention_masks = outputs_encoder
            attention = None

        mlm_logits = None # set only when do_mlm; the BERT MLM head returns a logits tensor, not a dict
        out_pooled_vector = None
        hidden_states_highres, hidden_states_midres, hidden_states_lowres = hidden_states

        # MLM outputs
        if self.do_mlm:
            mlm_logits = self.mlm(hidden_states_highres)

        # sequence pooling (for classification)
        if self.do_cls:
            out_pooled_vector = self.pooler(
                hidden_states=hidden_states,
                attention_masks=(attention_mask, attention_mask_l2, attention_mask_l3),
                excess_cls_ids=(excess_cls_ids, excess_cls_ids_l2, excess_cls_ids_l3)
            )

        if return_dict:
            return {
                'hidden_states':(hidden_states_highres, hidden_states_midres, hidden_states_lowres),
                'pooled':out_pooled_vector,
                'logits':mlm_logits,
                'attention':attention,
                'extended_attention_masks':extended_attention_masks
            }
        return hidden_states, out_pooled_vector, mlm_logits, attention, extended_attention_masks


In [None]:
modelstring_teacher_mlm = 'bert-base-uncased'
model_string = "google/bert_uncased_L-4_H-512_A-8"

config = make_config(
    modelstring = model_string,
    num_transformer_stacks = 3,
    scale_ratio2 = 0.5,
    scale_ratio3 = 0.25,
    multiplier_intermediate2 = 4.0,
    multiplier_intermediate3 = 4.0,
    num_layers_l2 = 1, # mid-res encoder
    num_layers_l3 = 3, # low-res encoder
    dropout_scaling = 0.05,
    do_cheap_integrator = [1],
    do_mlm=True,
    do_cls=True
)


In [None]:

anamod = AnathemTransformer(
        config=config,
        device=None,
        do_mlm = True,
        do_cls = True
    )

teacher_mlm = AutoModelForMaskedLM.from_pretrained(modelstring_teacher_mlm)


from torch import Tensor
class TeacherEmbedder:

    def __init__(self, pretrained_name = 'intfloat/e5-large-v2'):
        self.pretrained_name = pretrained_name
        self.teacher_tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
        self.teacher_embedder = AutoModel.from_pretrained(pretrained_name)

    @staticmethod
    def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def forward(self, input_text, prepend = 'passage: '):
        input_text = [prepend + s for s in input_text]
        with torch.no_grad():
            batch_dict = self.teacher_tokenizer(input_text, max_length=512, padding=True, truncation=True, return_tensors='pt')
            outputs = self.teacher_embedder(**batch_dict)
            embeddings = self.average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        return embeddings

    def __call__(self, input_text, prepend = 'passage: '):
        return self.forward(input_text, prepend=prepend)


teacher_emb = TeacherEmbedder()


Some weights of the model checkpoint at google/bert_uncased_L-4_H-512_A-8 were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).



Using bos_token, but it is not set yet.
Using eos_token, but it is not set yet.



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
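`TeacherEmbedder.average_pool` is a masked mean. It zeroes the padded positions before summing and divides by the number of real tokens, so pad vectors never leak into the sentence embedding. The usage example on the e5 model card then L2-normalizes the embeddings and scores query/passage pairs by dot product, i.e. cosine similarity. A small sketch on toy tensors: `average_pool` is copied from the class above, and `cosine_scores` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F
from torch import Tensor

def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # zero out padded positions, then divide the sum by the number of real tokens
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

def cosine_scores(queries: Tensor, passages: Tensor) -> Tensor:
    # hypothetical helper: after L2 normalization, dot product == cosine similarity
    q = F.normalize(queries, p=2, dim=1)
    p = F.normalize(passages, p=2, dim=1)
    return q @ p.T

# toy batch: 2 sequences of length 3, hidden size 2; the second sequence has 1 pad token
hidden = torch.tensor([[[1.0, 0.0], [3.0, 0.0], [5.0, 0.0]],
                       [[0.0, 2.0], [0.0, 4.0], [9.0, 9.0]]])
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
pooled = average_pool(hidden, mask)   # the [9, 9] pad vector is ignored in row 2
print(pooled)
print(cosine_scores(pooled, pooled))
```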



In [None]:

print(anamod.mlm) # MLM head

BertOnlyMLMHead(
  (predictions): BertLMPredictionHead(
    (transform): BertPredictionHeadTransform(
      (dense): Linear(in_features=512, out_features=512, bias=True)
      (transform_act_fn): GELUActivation()
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
    )
    (decoder): Linear(in_features=512, out_features=30522, bias=True)
  )
)


In [None]:
text = [
    "A standard [MASK] clause is a waiver clause that states that one party won't hold the other liable for damages, losses, or costs associated with issues.",
    "It usually consists of two elements: a trigger event or circumstance and a [MASK] obligation. The trigger event or circumstance is the [MASK] of the agreement, misconduct, or negligence of the indemnifying party or its affiliates"
]

inputs = anamod.tokenizer(text, add_special_tokens=True, return_tensors='pt', padding='longest')

print(inputs.keys())
inputs

outputs = anamod.forward(
    input_ids = inputs['input_ids'],
    attention_mask = inputs['attention_mask'],
    attention_mask_l2 = inputs['attention_mask_l2'],
    attention_mask_l3 = inputs['attention_mask_l3'],
    excess_cls_ids = inputs['excess_cls_ids'],
    excess_cls_ids_l2 = inputs['excess_cls_ids_l2'],
    excess_cls_ids_l3 = inputs['excess_cls_ids_l3']
)
# hidden_states, out_pooled_vector, out_mlm, attention, extended_attention_masks

outputs_teacher_mlm = teacher_mlm(input_ids = inputs['input_ids'], attention_mask=inputs['attention_mask'])


print(outputs[0][0].shape) # full hidden state sequence
print(outputs[0][1].shape) # mid hidden state sequence
print(outputs[0][2].shape) # small hidden state sequence
print(outputs[1].shape) # pooled sentence vector
print(outputs[2].shape) # mlm outputs

print(outputs_teacher_mlm['logits'].shape) # teacher MLM logits shape

# compare the argmax token predictions of the teacher (bert-base-uncased) and the student (Anathem)
for i in range(len(text)):
    print('Bert Base')
    print(anamod.tokenizer.convert_ids_to_tokens(outputs_teacher_mlm['logits'][i].argmax(dim=-1).tolist()))
    print('Anamod')
    print(anamod.tokenizer.convert_ids_to_tokens(outputs[2][i].argmax(dim=-1).tolist()))

# try to embed text with the teacher_emb; these e5 inputs already carry their 'query: ' / 'passage: '
# prefixes, so call forward() with prepend='' to avoid a doubled 'passage: query: ...' prefix
text2 = [
    'query: how much protein should a female eat',
    'query: summit define',
    "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "passage: Definition of summit for English Language Learners. : 1  the highest point of a mountain : the top of a mountain. : 2  the highest level. : 3  a meeting or series of meetings between the leaders of two or more governments."
]
sentence_embeddings = teacher_emb.forward(text2, prepend='')
print(sentence_embeddings.shape)

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'excess_cls_ids', 'attention_mask_l2', 'attention_mask_l3', 'excess_cls_ids_l2', 'excess_cls_ids_l3'])




torch.Size([2, 48, 512])
torch.Size([2, 24, 256])
torch.Size([2, 12, 128])
torch.Size([2, 1024])
torch.Size([2, 48, 30522])
torch.Size([2, 48, 30522])
Bert Base
['.', '.', '.', '.', 'a', 'standard', 'liability', 'clause', 'is', 'a', 'wai', '##ver', 'clause', 'that', 'states', 'that', 'one', 'party', 'won', "'", 't', 'hold', 'the', 'other', 'liable', 'for', 'damages', ',', 'losses', ',', 'or', 'costs', 'associated', 'with', 'issues', '.', 's', '.', '.', 'it', '.', 'the', 'it', 'it', 'it', 'parties', 'one', 'party']
Anamod
['-', 'the', '-', '-', 'a', '-', '-', '-', '.', 'a', '-', '-', '.', '.', 'is', '.', 'the', '.', '-', "'", 's', '.', 'the', 'other', ',', 'for', 'me', ',', 'my', ',', 'or', 'the', '-', 'with', 'the', '.', 'the', 'he', 'he', 'he', '-', '-', ',', ',', ',', 'the', '-', ',']
Bert Base
['.', '.', '.', '.', 'it', 'usually', 'consists', 'of', 'two', 'elements', ':', 'a', 'trigger', 'event', 'or', 'circumstance', 'and', 'a', 'trigger', 'obligation', '.', 'the', 'trigger', 'even
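Comparing the full argmax strings mixes in positions that neither model was asked to predict. A sketch that keeps only the positions holding the `[MASK]` token. `predictions_at_masks` is a hypothetical helper, and the toy example uses 3 as the mask id; in the notebook you would pass `anamod.tokenizer.mask_token_id`.

```python
import torch

def predictions_at_masks(logits: torch.Tensor, input_ids: torch.Tensor, mask_token_id: int):
    """Return, per example, the argmax token ids at the [MASK] positions only."""
    pred_ids = logits.argmax(dim=-1)            # [batch, seq]
    is_mask = input_ids == mask_token_id        # [batch, seq] bool
    return [pred_ids[i][is_mask[i]].tolist() for i in range(input_ids.size(0))]

# toy example: batch of 2, seq of 4, vocab of 5; mask id is 3 in this toy vocabulary
input_ids = torch.tensor([[1, 3, 2, 0], [3, 1, 3, 0]])
logits = torch.zeros(2, 4, 5)
logits[0, 1, 4] = 1.0   # example 0: predicts token 4 at its masked slot
logits[1, 0, 2] = 1.0   # example 1: predicts token 2 at its first masked slot
logits[1, 2, 1] = 1.0   # example 1: predicts token 1 at its second masked slot
print(predictions_at_masks(logits, input_ids, mask_token_id=3))
```

In the notebook this would be called as `predictions_at_masks(outputs[2], inputs['input_ids'], anamod.tokenizer.mask_token_id)` (and likewise for the teacher logits), with the ids decoded via `convert_ids_to_tokens`.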



In [None]:

## Test a batched inference routine, including loss calculations
## steps:
## 1) tokenize the inputs inside a HF dataset (.map + encode_plus)
## 2) loop through a DataLoader that uses an MLM data collator
## 3) run inference with the Anathem student model
## 4) run inference with the teacher (no gradients)
## 5) compute the distillation + label losses and update the student
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling
from torch.optim import AdamW

# load dummy dataset
dataset_glue = load_dataset('glue', 'mrpc', split='test') # small set

# tokens = [tokenizer.encode_plus(txt, add_special_tokens=True) for txt in text]
# tokenize
dataset_glue = dataset_glue.map(lambda e: tokenizer.encode_plus(e['sentence1'], add_special_tokens=True))
print(dataset_glue.features)
dataset_glue = dataset_glue.remove_columns(column_names = ['sentence1','sentence2','idx','label'])
print(dataset_glue.features)
_ = """
{'sentence1': Value(dtype='string', id=None),
 'sentence2': Value(dtype='string', id=None),
 'label': ClassLabel(names=['not_equivalent', 'equivalent'], id=None),
 'idx': Value(dtype='int32', id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'excess_cls_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)}
 """

# MLM collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

# MLM distillation loss function (kl-divergence between teacher and student outputs)
loss_fn_mlm_distil = nn.KLDivLoss(reduction="batchmean")
loss_fn_mlm_labels = nn.CrossEntropyLoss(ignore_index=-100) # non-masked tokens have -100
weights_mlm_distil = 0.5
weights_mlm_labels = (1-weights_mlm_distil)

# dataloader with MLM collator
dl_mlm = DataLoader(dataset_glue, collate_fn=data_collator, batch_size=4)

# optimizer
optimizer = AdamW(anamod.parameters(), lr = 0.00001)

# MLM objective
teacher_mlm.eval()
distillation_temperature = 1.0

for step_i, batch in enumerate(dl_mlm):

    # do inference using anathem model
    # hidden_states, out_pooled_vector, out_mlm, attention, extended_attention_masks
    outputs = anamod.forward(
        input_ids = batch['input_ids'],
        attention_mask = batch['attention_mask'],
        attention_mask_l2 = batch['attention_mask_l2'],
        attention_mask_l3 = batch['attention_mask_l3'],
        excess_cls_ids = batch['excess_cls_ids'],
        excess_cls_ids_l2 = batch['excess_cls_ids_l2'],
        excess_cls_ids_l3 = batch['excess_cls_ids_l3']
    )

    # hidden_states, out_pooled_vector, out_mlm, attention, extended_attention_masks
    with torch.no_grad():
        outputs_teacher_mlm = teacher_mlm(
            input_ids = batch['input_ids'],
            attention_mask=batch['attention_mask']
        )

    # sanity check: student and teacher MLM logits must have the same shape
    assert outputs[2].size() == outputs_teacher_mlm.logits.size()
    # Soften probabilities and compute distillation loss
    loss_mlm_distil = loss_fn_mlm_distil(
            F.log_softmax(outputs[2] / distillation_temperature, dim=-1),
            F.softmax(outputs_teacher_mlm.logits / distillation_temperature, dim=-1)
        ) * (distillation_temperature ** 2) * weights_mlm_distil
    # label loss
    loss_mlm_labels = loss_fn_mlm_labels(
        outputs[2].view(-1, anamod.config.vocab_size),
        batch['labels'].view(-1)
    ) * weights_mlm_labels
    # backward pass on the weighted sum of the distillation and label losses, then update the weights
    optimizer.zero_grad()
    (loss_mlm_distil+loss_mlm_labels).backward()
    optimizer.step()

    if ((step_i+1) % 20) ==0:
        raise NotImplementedError('hit %d' % step_i)



{'sentence1': Value(dtype='string', id=None), 'sentence2': Value(dtype='string', id=None), 'label': ClassLabel(names=['not_equivalent', 'equivalent'], id=None), 'idx': Value(dtype='int32', id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'excess_cls_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)}
{'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'excess_cls_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)}


NotImplementedError: ignored
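Two details of the MLM losses above are worth flagging. `nn.KLDivLoss(reduction="batchmean")` divides the summed KL only by the first dimension, so on `[batch, seq, vocab]` logits it sums over every position (padding included) and its scale grows with sequence length. `CrossEntropyLoss(ignore_index=-100)` with the default mean reduction already averages over the masked tokens only. Below is a sketch of a per-token, padding-aware variant of the same alpha-weighted objective, with alpha = 0.5 and the T² scaling used in the loop. `mlm_distillation_loss` is a hypothetical helper, not the notebook's implementation.

```python
import torch
import torch.nn.functional as F

def mlm_distillation_loss(student_logits, teacher_logits, labels, attention_mask,
                          temperature: float = 1.0, alpha: float = 0.5):
    """alpha * T^2 * KL(teacher || student) averaged over real (non-pad) tokens
       + (1 - alpha) * cross-entropy over masked tokens (labels != -100)."""
    vocab = student_logits.size(-1)
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    # per-token KL divergence, summed over the vocabulary: [batch, seq]
    kl_per_token = F.kl_div(log_p_student, p_teacher, reduction="none").sum(dim=-1)
    keep = attention_mask.to(kl_per_token.dtype)
    kl = (kl_per_token * keep).sum() / keep.sum()
    # cross-entropy skips positions labelled -100 (tokens the collator did not mask)
    ce = F.cross_entropy(student_logits.reshape(-1, vocab), labels.reshape(-1), ignore_index=-100)
    return alpha * kl * temperature ** 2 + (1 - alpha) * ce

torch.manual_seed(0)
student = torch.randn(2, 6, 10, requires_grad=True)
teacher = torch.randn(2, 6, 10)
labels = torch.full((2, 6), -100, dtype=torch.long)
labels[0, 2], labels[1, 4] = 7, 3                      # one masked token per example
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 0]])
loss = mlm_distillation_loss(student, teacher, labels, attention_mask, temperature=2.0)
loss.backward()
print(loss.item())
```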

## Multi-task training: adapted from SBERT (sentence-transformers)

In [None]:
### Standard label-based losses (MNLI)
# -- https://huggingface.co/datasets/multi_nli
dataset_nli3 = load_dataset('multi_nli', split='train') # ~393k training examples

# I think I should keep the text untokenized for the multi-task setup; maybe use the default collator from sbert
dataset_nli3 = dataset_nli3.remove_columns(
    column_names = ['promptID', 'pairID', 'premise_binary_parse', 'premise_parse','hypothesis_binary_parse', 'hypothesis_parse', 'genre']
)

dl_mli3 = DataLoader(dataset_nli3, batch_size=4, shuffle=True)


# make a classification head
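Because `dataset_nli3` is left untokenized, PyTorch's default collation turns each batch into a dict holding lists of strings (premise, hypothesis) and a tensor of labels. The MNLI loops below rely on this: `batch_mnli['premise']` goes straight into the tokenizer and `batch_mnli['label']` into the loss. A tiny sketch, with made-up toy rows standing in for MNLI examples:

```python
import torch
from torch.utils.data import DataLoader

# toy rows standing in for MNLI examples (untokenized text + integer label)
rows = [
    {"premise": "A man is playing a guitar.", "hypothesis": "A person plays music.", "label": 0},
    {"premise": "A dog runs in a park.", "hypothesis": "A cat sleeps indoors.", "label": 2},
    {"premise": "Two kids read a book.", "hypothesis": "Children are outside.", "label": 1},
]

loader = DataLoader(rows, batch_size=2, shuffle=False)
batch = next(iter(loader))
print(batch["premise"])   # a plain list of strings, ready for a tokenizer call
print(batch["label"])     # collated into an int64 tensor for CrossEntropyLoss
```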



In [None]:

class ClassifierMNLI3(nn.Module):
    """Sentence-pair classifier (SBERT-style): concatenates the two pooled sentence vectors and,
    optionally, their difference, then applies dropout and a linear layer to get n_labels logits"""
    def __init__(
        self,
        hidden_size = 512,
        do_subtract = True,
        dropout = 0.1,
        n_labels = 3
    ):
        """Each pooled Anathem sentence vector is 2*hidden_size wide (1024 for hidden_size=512, see the outputs above)"""
        super().__init__()

        self.hidden_size = hidden_size
        self.do_subtract = do_subtract
        self.dropout_p = dropout
        self.n_labels = n_labels
        # input1 and input2 (2*hidden_size each), plus (input1 - input2) if do_subtract
        self.size_of_concatenated_inputs = self.hidden_size*2*2 + self.do_subtract*self.hidden_size*2

        # final output
        self.layer = nn.Sequential(
            nn.Dropout(self.dropout_p),
            nn.Linear(self.size_of_concatenated_inputs, self.n_labels)
        )

    def forward(self, input1, input2):
        features = [input1, input2]
        if self.do_subtract:
            # only append the difference when the Linear layer was sized for it
            features.append(torch.sub(input1, input2))
        features_concat = torch.cat(features, dim=1)
        return self.layer(features_concat)


# Make classifier for MNLI labelled data
classifier_mnli3 = ClassifierMNLI3(
    hidden_size = anamod.config.hidden_size,
    n_labels=3
)
classifier_mnli3.train()
anamod.train()
optimizer = torch.optim.AdamW(
    list(anamod.encoder.parameters()) +  list(anamod.pooler.parameters()) + list(classifier_mnli3.parameters()),
    lr=0.0001
)

# make loss function (3 labels)
loss_fn_nmli3 = nn.CrossEntropyLoss()
weights_mnli_distil = 0.5
weights_mnli_labels = (1-weights_mnli_distil)

loss_fn_mnli3_distil = nn.MSELoss()
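`ClassifierMNLI3` follows the Sentence-BERT classification objective. The two sentence vectors u and v are concatenated with their difference and fed to a single linear layer. One difference is worth knowing: SBERT's `SoftmaxLoss` uses the absolute difference |u − v| by default, whereas the class above uses the signed difference u − v. The signed version is not symmetric in premise and hypothesis. A minimal sketch of both feature constructions. It assumes, as in the outputs above, that the pooled vector is `2 * hidden_size` = 1024 wide. `sbert_pair_features` is a hypothetical helper.

```python
import torch
import torch.nn as nn

def sbert_pair_features(u: torch.Tensor, v: torch.Tensor, use_abs: bool = True) -> torch.Tensor:
    # (u, v, |u - v|) as in SBERT's SoftmaxLoss default; use_abs=False gives (u, v, u - v)
    diff = torch.abs(u - v) if use_abs else u - v
    return torch.cat((u, v, diff), dim=1)

hidden_size = 512
pooled_dim = 2 * hidden_size            # assumption: width of the Anathem pooled vector
head = nn.Linear(3 * pooled_dim, 3)     # 3 MNLI labels

u, v = torch.randn(4, pooled_dim), torch.randn(4, pooled_dim)
feats = sbert_pair_features(u, v)       # [4, 3072]
logits = head(feats)                    # [4, 3]
print(feats.shape, logits.shape)
```

The 3 × 1024 = 3072 input width matches `size_of_concatenated_inputs` in `ClassifierMNLI3` when `do_subtract=True`.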


In [None]:
for i, batch_mnli in enumerate(dl_mli3):
    optimizer.zero_grad()
    # get tokens
    tokens_mnli_1 = anamod.tokenizer(batch_mnli['premise'],pad_to_multiple_of=4, add_special_tokens = True, return_tensors='pt', padding='longest')
    tokens_mnli_2 = anamod.tokenizer(batch_mnli['hypothesis'],pad_to_multiple_of=4, add_special_tokens = True, return_tensors='pt', padding='longest')

    # student embeddings
    out_student_mnli1 = anamod.forward(
            input_ids = tokens_mnli_1['input_ids'],
            attention_mask = tokens_mnli_1['attention_mask'],
            attention_mask_l2 = tokens_mnli_1['attention_mask_l2'],
            attention_mask_l3 = tokens_mnli_1['attention_mask_l3'],
            excess_cls_ids = tokens_mnli_1['excess_cls_ids'],
            excess_cls_ids_l2 = tokens_mnli_1['excess_cls_ids_l2'],
            excess_cls_ids_l3 = tokens_mnli_1['excess_cls_ids_l3']
    )
    out_student_mnli2 = anamod.forward(
            input_ids = tokens_mnli_2['input_ids'],
            attention_mask = tokens_mnli_2['attention_mask'],
            attention_mask_l2 = tokens_mnli_2['attention_mask_l2'],
            attention_mask_l3 = tokens_mnli_2['attention_mask_l3'],
            excess_cls_ids = tokens_mnli_2['excess_cls_ids'],
            excess_cls_ids_l2 = tokens_mnli_2['excess_cls_ids_l2'],
            excess_cls_ids_l3 = tokens_mnli_2['excess_cls_ids_l3']
    )

    # raw sentence-vectors from student
    feature_student_mnli1, feature_student_mnli2 = out_student_mnli1[1], out_student_mnli2[1]
    # MNLI predictions (3 labels)
    pred_mnli3 = classifier_mnli3(feature_student_mnli1, feature_student_mnli2)
    # MNLI 3-class cross-entropy loss
    loss_cls_nmli3 = loss_fn_nmli3(pred_mnli3, batch_mnli['label']) * weights_mnli_labels

    # distillation loss against the teacher (e5) sentence embeddings
    feature_teacher_nmli1 = teacher_emb(input_text=batch_mnli['premise'], prepend = 'passage: ')
    feature_teacher_nmli2 = teacher_emb(input_text=batch_mnli['hypothesis'], prepend = 'passage: ')
    # MNLI distillation loss
    loss_mnli_distil = (
        loss_fn_mnli3_distil(feature_student_mnli1, feature_teacher_nmli1) + loss_fn_mnli3_distil(feature_student_mnli2, feature_teacher_nmli2)
    )*weights_mnli_distil
    # backprop
    (loss_mnli_distil + loss_cls_nmli3).backward()

    # update weights
    optimizer.step()

    if (i+1)%3 ==0:
        print(loss_cls_nmli3.detach().item())





0.6361832022666931
0.5656223297119141
0.3880550265312195


KeyboardInterrupt: ignored

In [None]:
# Combine MLM distillation (teacher_mlm) with MNLI classification + embedding distillation
optimizer = AdamW(list(anamod.parameters()) + list(classifier_mnli3.parameters()), lr = 0.00001)

# MLM objective
teacher_mlm.eval()
distillation_temperature = 1.0
for i,(batch_mlm, batch_mnli) in enumerate(zip(dl_mlm, dl_mli3)):
    optimizer.zero_grad()
    # do inference using anathem model on the current MLM batch
    # hidden_states, out_pooled_vector, out_mlm, attention, extended_attention_masks
    outputs = anamod.forward(
        input_ids = batch_mlm['input_ids'],
        attention_mask = batch_mlm['attention_mask'],
        attention_mask_l2 = batch_mlm['attention_mask_l2'],
        attention_mask_l3 = batch_mlm['attention_mask_l3'],
        excess_cls_ids = batch_mlm['excess_cls_ids'],
        excess_cls_ids_l2 = batch_mlm['excess_cls_ids_l2'],
        excess_cls_ids_l3 = batch_mlm['excess_cls_ids_l3']
    )

    with torch.no_grad():
        # mlm teacher outputs
        outputs_teacher_mlm = teacher_mlm(
            input_ids = batch_mlm['input_ids'],
            attention_mask=batch_mlm['attention_mask']
        )
        # TODO: a paragraph-embedding loss with teacher_emb would need the original text, NOT pre-tokenized ids

    # sanity check: student and teacher MLM logits must have the same shape
    assert outputs[2].size() == outputs_teacher_mlm.logits.size()
    # Soften probabilities and compute distillation loss
    loss_mlm_distil = loss_fn_mlm_distil(
            F.log_softmax(outputs[2] / distillation_temperature, dim=-1),
            F.softmax(outputs_teacher_mlm.logits / distillation_temperature, dim=-1)
        ) * (distillation_temperature ** 2) * weights_mlm_distil
    # MLM label loss (only masked positions count; the rest are labelled -100)
    loss_mlm_labels = loss_fn_mlm_labels(
        outputs[2].view(-1, anamod.config.vocab_size),
        batch_mlm['labels'].view(-1)
    ) * weights_mlm_labels

    # BACKPROP MLM label loss and distillation loss
    (loss_mlm_distil+loss_mlm_labels).backward()

    # NLI task: get tokens
    tokens_mnli_1 = anamod.tokenizer(batch_mnli['premise'],pad_to_multiple_of=4, add_special_tokens = True, return_tensors='pt', padding='longest')
    tokens_mnli_2 = anamod.tokenizer(batch_mnli['hypothesis'],pad_to_multiple_of=4, add_special_tokens = True, return_tensors='pt', padding='longest')

    # student embeddings
    out_student_mnli1 = anamod.forward(
            input_ids = tokens_mnli_1['input_ids'],
            attention_mask = tokens_mnli_1['attention_mask'],
            attention_mask_l2 = tokens_mnli_1['attention_mask_l2'],
            attention_mask_l3 = tokens_mnli_1['attention_mask_l3'],
            excess_cls_ids = tokens_mnli_1['excess_cls_ids'],
            excess_cls_ids_l2 = tokens_mnli_1['excess_cls_ids_l2'],
            excess_cls_ids_l3 = tokens_mnli_1['excess_cls_ids_l3']
    )
    out_student_mnli2 = anamod.forward(
            input_ids = tokens_mnli_2['input_ids'],
            attention_mask = tokens_mnli_2['attention_mask'],
            attention_mask_l2 = tokens_mnli_2['attention_mask_l2'],
            attention_mask_l3 = tokens_mnli_2['attention_mask_l3'],
            excess_cls_ids = tokens_mnli_2['excess_cls_ids'],
            excess_cls_ids_l2 = tokens_mnli_2['excess_cls_ids_l2'],
            excess_cls_ids_l3 = tokens_mnli_2['excess_cls_ids_l3']
    )
    # raw sentence-vectors from student
    feature_student_mnli1, feature_student_mnli2 = out_student_mnli1[1], out_student_mnli2[1]
    # MNLI predictions (3 labels)
    pred_mnli3 = classifier_mnli3(feature_student_mnli1, feature_student_mnli2)
    # MNLI 3-class cross-entropy loss
    loss_cls_nmli3 = loss_fn_nmli3(pred_mnli3, batch_mnli['label'])
    # teacher (e5) sentence embeddings
    feature_teacher_nmli1 = teacher_emb(input_text=batch_mnli['premise'], prepend = 'passage: ')
    feature_teacher_nmli2 = teacher_emb(input_text=batch_mnli['hypothesis'], prepend = 'passage: ')
    # MNLI distillation loss
    loss_mnli_distil = (
        loss_fn_mnli3_distil(feature_student_mnli1, feature_teacher_nmli1) + loss_fn_mnli3_distil(feature_student_mnli2, feature_teacher_nmli2)
    )*weights_mnli_distil
    # backprop (these gradients accumulate on top of the MLM gradients above)
    (loss_mnli_distil + loss_cls_nmli3).backward()
    # update weights with the combined MLM + MNLI gradients
    optimizer.step()

    if (i+1)%4 ==0:
        print(loss_cls_nmli3.detach().item())



1.3287630081176758
1.1084638833999634
1.1774473190307617
1.0645709037780762
1.091556429862976
1.1649658679962158
1.319928765296936
1.1654601097106934
0.9826673865318298
1.1563453674316406
1.0446501970291138
1.1165382862091064
1.1049705743789673
0.9217707514762878
1.14559006690979
1.1429061889648438
0.9149771928787231
1.207316279411316
1.1845396757125854
1.2629420757293701
0.9769338369369507
1.0895546674728394
1.0898280143737793
1.1648684740066528
0.9611557126045227
1.044935703277588
1.144046425819397
1.099448561668396
1.0884103775024414
1.142393946647644
1.0853071212768555
1.1239224672317505
1.0658488273620605
1.1993112564086914
0.9642707109451294
1.182077407836914
1.3221166133880615
1.1279082298278809
1.0723700523376465
1.1399314403533936
1.0013256072998047
1.1049387454986572
1.0147031545639038
1.2314361333847046
1.0651648044586182
1.1327135562896729
0.9887092709541321
1.0250582695007324
1.1199613809585571
1.094027042388916
1.091330885887146
1.098750114440918
1.1193275451660156
1.1657
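The `TrainerMultiTask.fit` docstring below describes the sentence-transformers scheme. Each training objective is sampled in turn for one batch, and an epoch lasts as many steps as the smallest dataloader has batches. The `zip(dl_mlm, dl_mli3)` loop above behaves the same way, because `zip` stops at the shortest iterable. A minimal, framework-free sketch of that scheduling logic follows. `round_robin_batches` is a hypothetical helper, and plain lists stand in for the DataLoaders.

```python
from typing import Any, Iterator, List, Optional, Sequence, Tuple

def round_robin_batches(dataloaders: List[Sequence],
                        steps_per_epoch: Optional[int] = None) -> Iterator[Tuple[int, Any]]:
    """Yield (objective_index, batch), cycling through the objectives one batch at a time.
    By default an epoch lasts as many steps as the smallest dataloader has batches."""
    if steps_per_epoch is None:
        steps_per_epoch = min(len(dl) for dl in dataloaders)
    iterators = [iter(dl) for dl in dataloaders]
    for _ in range(steps_per_epoch):
        for idx, dl in enumerate(dataloaders):
            try:
                batch = next(iterators[idx])
            except StopIteration:
                # restart an exhausted dataloader (only happens if steps_per_epoch is set larger)
                iterators[idx] = iter(dl)
                batch = next(iterators[idx])
            yield idx, batch

mlm_batches = ["mlm0", "mlm1", "mlm2"]
nli_batches = ["nli0", "nli1"]
order = list(round_robin_batches([mlm_batches, nli_batches]))
print(order)
```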

In [None]:
class TrainerMultiTask:
    """Adapted from the UKPLab/sentence-transformers .fit() function"""
    def __init__(
            self,
            do_reload = True,
            epochs_total_lifetime = 5,
            scheduler: str = 'WarmupLinear',
            warmup_steps: int = 10000,
            optimizer_class: Type[Optimizer] = torch.optim.AdamW,
            optimizer_params : Dict[str, object]= {'lr': 2e-5},
            weight_decay: float = 0.01,
            evaluation_steps: int = 0,
            output_path: str = None,
            save_best_model: bool = True,
            max_grad_norm: float = 2.0,
            use_amp: bool = False,
            callback: Callable[[float, int, int], None] = None,
            show_progress_bar: bool = False,
            checkpoint_path: str = 'checkpoint.pt',
            checkpoint_path_optimizer: str = 'checkpoint_optimizer.pt',
            checkpoint_path_scheduler: str = 'checkpoint_scheduler.pt',
            checkpoint_path_trainer_state: str = 'checkpoint_trainer_state.json',
            checkpoint_save_steps: int = 500,
            checkpoint_save_total_limit: int = 0,
            do_minimize_global_objective: bool = True
        ):
            self.epochs_global = -1 # track the total number of epochs
            self.epochs_total_lifetime = epochs_total_lifetime # total number of epochs over lifetime
            self.global_step = 0 # track the total number of steps
            self.do_minimize = do_minimize_global_objective
            self.best_score = 9999999 if self.do_minimize else -9999999
            self.output_path = output_path
            self.checkpoint_path = checkpoint_path
            self.checkpoint_path_optimizer = checkpoint_path_optimizer
            self.checkpoint_path_scheduler = checkpoint_path_scheduler
            self.checkpoint_path_trainer_state = checkpoint_path_trainer_state
            self.scheduler_state_dict = None
            self.optimizer_state_dict = None
            self.trainer_state = None
            self.loss_models_states = None
            self.model_state = None # set below only when reloading a checkpoint
            if do_reload:
                print('attempting to reload cached model, optimizer, scheduler, and saved trainer state')
                model_state, loss_models_states = self.load_saved_model(self.checkpoint_path)
                self.model_state = model_state
                self.loss_models_states = loss_models_states
                self.scheduler_state_dicts = self.load_saved_scheduler(self.checkpoint_path_scheduler)
                self.optimizer_state_dicts = self.load_saved_optimizer(self.checkpoint_path_optimizer)
                self.trainer_state = self.load_saved_trainer_state(self.checkpoint_path_trainer_state)

    def fit(self,
            train_objectives: Iterable[Tuple[DataLoader, nn.Module]],
            model=None,
            weights_train_objectives:List = None,
            teachers: List = None,
            evaluator: SentenceEvaluator = None,
            epochs: int = 1,
            epochs_total_lifetime = None,
            steps_per_epoch = None,
            scheduler: str = None, # 'WarmupLinear',
            warmup_steps: int = 10000,
            optimizer_class: Type[Optimizer] = torch.optim.AdamW,
            optimizer_params : Dict[str, object]= {'lr': 2e-5},
            weight_decay: float = 0.01,
            evaluation_steps: int = 0,
            save_best_model: bool = True,
            max_grad_norm: float = 2.0,
            use_amp: bool = False,
            callback: Callable[[float, int, int], None] = None,
            show_progress_bar: bool = True,
            checkpoint_path = None,
            checkpoint_path_optimizer= None,
            checkpoint_path_scheduler= None,
            checkpoint_path_trainer_state= None,
            checkpoint_save_steps: int = 500,
            checkpoint_save_total_limit: int = 2
            ):
        """
        Train the model with the given training objectives.
        Each training objective is sampled in turn for one batch. The weighted losses of all
        objectives are accumulated, and a single optimizer step is taken per iteration.
        Unless steps_per_epoch is given, an epoch has as many iterations as the smallest
        DataLoader, so that each dataset is trained on equally.

        :param train_objectives: Tuples of (DataLoader, LossFunction). Pass more than one for multi-task learning
        :param model: The model being trained (shared by the loss functions); the saved model state is loaded into it if one exists
        :param weights_train_objectives: Per-objective loss weights, one per train objective
        :param teachers: Per-objective teacher models passed to each loss function, one per train objective
        :param evaluator: An evaluator (sentence_transformers.evaluation) evaluates the model performance during training on held-out dev data. It is used to determine the best model that is saved to disk.
        :param epochs: Number of epochs for training
        :param epochs_total_lifetime: Total number of epochs over the trainer's lifetime, used to size the learning-rate schedule (defaults to the value given to the constructor)
        :param steps_per_epoch: Number of training steps per epoch. If set to None (default), one epoch is equal to the size of the smallest DataLoader in train_objectives.
        :param scheduler: Learning rate scheduler. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts
        :param warmup_steps: Behavior depends on the scheduler. For WarmupLinear, the learning rate is increased from 0 up to the maximal learning rate over this many training steps; after that, it is decreased linearly back to zero.
        :param optimizer_class: Optimizer
        :param optimizer_params: Optimizer parameters
        :param weight_decay: Weight decay for model parameters
        :param evaluation_steps: If > 0, evaluate the model using evaluator after each number of training steps
        :param save_best_model: If true, the best model (according to evaluator) is stored at the output_path given to the constructor
        :param max_grad_norm: Used for gradient clipping (maximum gradient norm).
        :param use_amp: Use Automatic Mixed Precision (AMP). Only for Pytorch >= 1.6.0
        :param callback: Callback function that is invoked after each evaluation.
                It must accept the following three parameters in this order:
                `score`, `epoch`, `steps`
        :param show_progress_bar: If True, output a tqdm progress bar
        :param checkpoint_path: File-path prefix for model checkpoints; each checkpoint is saved as '<checkpoint_path>-<step>'
        :param checkpoint_path_optimizer: File-path prefix for optimizer checkpoints
        :param checkpoint_path_scheduler: File-path prefix for scheduler checkpoints
        :param checkpoint_path_trainer_state: JSON file that stores the trainer state
        :param checkpoint_save_steps: Will save a checkpoint after so many steps
        :param checkpoint_save_total_limit: Total number of checkpoints to store
        """
        if self.model_state is not None:
            print('reloading saved model state into model')
            model.load_state_dict(self.model_state)
        self.model = model

        # paths (optional update)
        self.checkpoint_path = checkpoint_path if checkpoint_path is not None else self.checkpoint_path
        self.checkpoint_path_optimizer = checkpoint_path_optimizer if checkpoint_path_optimizer is not None else self.checkpoint_path_optimizer
        self.checkpoint_path_scheduler = checkpoint_path_scheduler if checkpoint_path_scheduler is not None else self.checkpoint_path_scheduler
        self.checkpoint_path_trainer_state = checkpoint_path_trainer_state if checkpoint_path_trainer_state is not None else self.checkpoint_path_trainer_state
        self._target_device = model.device
        self.max_grad_norm = max_grad_norm
        self.weight_decay = weight_decay
        self.warmup_steps = warmup_steps
        self.optimizer_params = optimizer_params
        self.evaluation_steps = evaluation_steps

        if use_amp:
            from torch.cuda.amp import autocast
            scaler = torch.cuda.amp.GradScaler()

        #self.to(self._target_device)

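        dataloaders = [dataloader for dataloader, _ in train_objectives]

        # defaults: equal loss weights and no teacher for every objective
        if weights_train_objectives is None:
            weights_train_objectives = [1.0] * len(train_objectives)
        if teachers is None:
            teachers = [None] * len(train_objectives)

        # Use smart batching for dataloaders without a collate function
        for dli, dataloader in enumerate(dataloaders):
            if dataloader.collate_fn is None:
                print('using default batch collators for dataloader %d' % dli)
                dataloader.collate_fn = self.smart_batching_collate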

        loss_models = [loss for _, loss in train_objectives]
        for midx, loss_model in enumerate(loss_models):
            if self.loss_models_states is not None:
                # reload each loss_model.classifier's saved states
                if hasattr(loss_model, 'classifier'):
                    loss_model.classifier.load_state_dict(self.loss_models_states[midx])
            loss_model.to(self._target_device)

        if steps_per_epoch is None or steps_per_epoch == 0:
            steps_per_epoch = min([len(dataloader) for dataloader in dataloaders])

        if epochs_total_lifetime is None:
            epochs_total_lifetime = self.epochs_total_lifetime
        num_train_steps = int(steps_per_epoch * epochs_total_lifetime)

        # Prepare optimizers
        #optimizers = []
        #schedulers = []
        #for model_idx, loss_model in enumerate(loss_models):
        #    param_optimizer = list(loss_model.named_parameters())#
        #    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        #    optimizer_grouped_parameters = [
        #        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
        #        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        #    ]
        #    optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
        #    scheduler_obj = self._get_scheduler(optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps)
        #    if self.optimizer_state_dicts is not None:
        #        # reload optimizer states
        #        optimizer.load_state_dict(self.optimizer_state_dicts[model_idx])
        #    if self.scheduler_state_dicts is not None:
        #        # relead scheduler states
        #        scheduler_obj.load_state_dict(self.scheduler_state_dicts[model_idx])
        #    optimizers.append(optimizer)
        #    schedulers.append(scheduler_obj)

        # from: https://stackoverflow.com/questions/46377599/when-to-use-individual-optimizers-in-pytorch
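        optimizer_parameters = {}
        for loss_model in loss_models:
            for name, param in loss_model.named_parameters():
                # loss models may wrap the same underlying model, so de-duplicate parameters by identity
                optimizer_parameters.setdefault(id(param), (name, param))
        optimizer_parameters = list(optimizer_parameters.values())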

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in optimizer_parameters if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
            {'params': [p for n, p in optimizer_parameters if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
        scheduler_obj = self._get_scheduler(optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps)
        if self.optimizer_state_dicts is not None:
            # reload optimizer states
            #optimizer.load_state_dict(self.optimizer_state_dicts[model_idx])
            optimizer.load_state_dict(self.optimizer_state_dicts)
        if self.scheduler_state_dicts is not None:
            # reload scheduler states
            #scheduler_obj.load_state_dict(self.scheduler_state_dicts[model_idx])
            scheduler_obj.load_state_dict(self.scheduler_state_dicts)

        global_step = self.global_step
        data_iterators = [iter(dataloader) for dataloader in dataloaders]

        num_train_objectives = len(train_objectives)

        for epoch in trange(epochs, desc="Epoch", disable=not show_progress_bar):
            self.epochs_global += 1
            training_steps = 0

            for loss_model in loss_models:
                loss_model.zero_grad()
                loss_model.train()

            for _ in trange(steps_per_epoch, desc="Iteration", smoothing=0.05, disable=not show_progress_bar):

                # loop through multiple tasks
                for train_idx in range(num_train_objectives):
                    loss_model = loss_models[train_idx]
                    loss_weight = weights_train_objectives[train_idx]
                    teacher = teachers[train_idx]
                    data_iterator = data_iterators[train_idx]

                    try:
                        data = next(data_iterator)
                    except StopIteration:
                        data_iterator = iter(dataloaders[train_idx])
                        data_iterators[train_idx] = data_iterator
                        data = next(data_iterator)

                    features, labels = data
                    features = list(map(lambda batch: batch_to_device(batch, self._target_device), features))
                    if labels is not None:
                        labels = labels.to(self._target_device)

                    loss_value = loss_model(features, labels, teacher=teacher)
                    loss_value = loss_value * loss_weight
                    loss_value.backward()

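                # one optimizer step per iteration, after the gradients of all objectives have accumulated
                torch.nn.utils.clip_grad_norm_(
                    [p for group in optimizer.param_groups for p in group['params']], max_grad_norm
                )
                optimizer.step()
                optimizer.zero_grad()
                scheduler_obj.step()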

                # TODO: integrate amp: https://discuss.pytorch.org/t/ddp-amp-gradient-accumulation-calling-optimizer-step-leads-to-nan-loss/162624
                training_steps += 1
                global_step += 1
                self.global_step = global_step

                if evaluation_steps > 0 and training_steps % evaluation_steps == 0:
                    self._eval_during_training(evaluator, self.output_path, save_best_model, epoch, training_steps, callback)

                    for loss_model in loss_models:
                        loss_model.zero_grad()
                        loss_model.train()

                if self.checkpoint_path is not None and checkpoint_save_steps is not None and checkpoint_save_steps > 0 and global_step % checkpoint_save_steps == 0:
                    self._save_checkpoint(
                        model, optimizer, scheduler_obj, loss_models, checkpoint_save_total_limit, global_step
                    )

            self._eval_during_training(evaluator, self.output_path, save_best_model, epoch, -1, callback)

        #if evaluator is None and output_path is not None:   #No evaluator, but output path: save final model version
        #    self.save(output_path)

        if self.checkpoint_path is not None:
            self._save_checkpoint(
                model, optimizer, scheduler_obj, loss_models, checkpoint_save_total_limit, global_step
            )

    def evaluate(self, evaluator: SentenceEvaluator, output_path: str = None):
        """
        Evaluate the model

        :param evaluator:
            the evaluator
        :param output_path:
            the evaluator can write the results to this path
        """
        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)
        return evaluator(self, output_path)

    def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps, callback):
        """Runs evaluation during the training"""
        eval_path = output_path
        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)
            eval_path = os.path.join(output_path, "eval")
            os.makedirs(eval_path, exist_ok=True)

        if evaluator is not None:
            score = evaluator(self, output_path=eval_path, epoch=epoch, steps=steps)
            if callback is not None:
                callback(score, epoch, steps)
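            # respect the direction of the objective (minimize vs. maximize)
            is_better = score < self.best_score if self.do_minimize else score > self.best_score
            if is_better: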
                self.best_score = score
                if save_best_model:
                    self.save(output_path)

    def _save_checkpoint(
        self,
        model,
        optimizers,
        schedulers,
        loss_models,
        checkpoint_save_total_limit,
        step,
        checkpoint_path = None,
        checkpoint_path_optimizer = None,
        checkpoint_path_scheduler = None,
        checkpoint_path_trainer_state =None
    ):
        # Store new checkpoint
        checkpoint_path = checkpoint_path if checkpoint_path is not None else self.checkpoint_path
        checkpoint_path_optimizer = checkpoint_path_optimizer if checkpoint_path_optimizer is not None else self.checkpoint_path_optimizer
        checkpoint_path_scheduler = checkpoint_path_scheduler if checkpoint_path_scheduler is not None else self.checkpoint_path_scheduler
        checkpoint_path_trainer_state = checkpoint_path_trainer_state if checkpoint_path_trainer_state is not None else self.checkpoint_path_trainer_state

        # model states
        self.model_state = model.state_dict()
        self.loss_models_states = [self._grab_loss_states(loss_model) for loss_model in loss_models]
        torch.save({
            'epochs_global':self.epochs_global, 'global_step':self.global_step, 'step':step,
            'model_state_dict':self.model_state,
            'loss_models_state_dicts':self.loss_models_states,
        }, "%s-%08d" % (checkpoint_path, step))

        # optimizer (a single optimizer covers all loss models)
        self.optimizer_state_dicts = optimizers.state_dict()
        torch.save({
            'epochs_global':self.epochs_global, 'global_step':self.global_step, 'step':step,
            'optimizer_state_dicts':self.optimizer_state_dicts,
        }, "%s-%08d" % (checkpoint_path_optimizer, step))

        # scheduler
        self.scheduler_state_dicts = schedulers.state_dict()
        torch.save({
            'epochs_global':self.epochs_global, 'global_step':self.global_step, 'step':step,
            'scheduler_state_dicts':self.scheduler_state_dicts,
        }, "%s-%08d" % (checkpoint_path_scheduler, step))

        # trainer info
        with open(checkpoint_path_trainer_state, 'w') as jcon:
            trainer_objs_to_save = {
                'epochs_global':self.epochs_global, 'global_step':self.global_step, 'step':step,
                'max_grad_norm':self.max_grad_norm,
                'weight_decay':self.weight_decay,
                'warmup_steps':self.warmup_steps,
                'optimizer_params':self.optimizer_params,
                'evaluation_steps':self.evaluation_steps,
                'checkpoint_path': "%s-%08d" % (checkpoint_path, step),
                'checkpoint_path_optimizer': "%s-%08d" % (checkpoint_path_optimizer, step),
                'checkpoint_path_scheduler': "%s-%08d" % (checkpoint_path_scheduler, step),
            }
            json.dump(trainer_objs_to_save, jcon)

        # Delete old checkpoints: model, optimizer, and scheduler files all end in "-<step>"
        if checkpoint_save_total_limit is not None and checkpoint_save_total_limit > 0:
            for prefix in [checkpoint_path, checkpoint_path_optimizer, checkpoint_path_scheduler]:
                dir_to_checkpoints = os.path.dirname(prefix) or '.'
                pattern = re.compile(re.escape(os.path.basename(prefix)) + r'-([0-9]+)')
                old_checkpoints = []
                for f in os.listdir(dir_to_checkpoints):
                    match = pattern.fullmatch(f)
                    if match:
                        old_checkpoints.append({
                            'step': int(match.group(1)), 'path': os.path.join(dir_to_checkpoints, f)
                        })

                # keep only the most recent `checkpoint_save_total_limit` files (they are files, not folders)
                old_checkpoints = sorted(old_checkpoints, key=lambda x: x['step'])
                for old_checkpoint in old_checkpoints[:-checkpoint_save_total_limit]:
                    print('deleting old checkpoint: %s' % old_checkpoint['path'])
                    os.remove(old_checkpoint['path'])

    @staticmethod
    def _grab_loss_states(loss_model):
        """Gets the state_dict() of the classifier embedded in a loss function (None if it has no classifier)"""
        if hasattr(loss_model, 'classifier'):
            return loss_model.classifier.state_dict()
        return None

    def load_saved_model(self, checkpoint_path=None):
        """reload saved model and loss-model states"""
        checkpoint_path = self.checkpoint_path if checkpoint_path is None else checkpoint_path
        saved_dict = torch.load(checkpoint_path)
        return saved_dict['model_state_dict'], saved_dict['loss_models_state_dicts']

    def load_saved_scheduler(self, checkpoint_path_scheduler=None):
        """reload saved scheduler state"""
        checkpoint_path_scheduler = self.checkpoint_path_scheduler if checkpoint_path_scheduler is None else checkpoint_path_scheduler
        saved_dict = torch.load(checkpoint_path_scheduler)
        return saved_dict['scheduler_state_dicts']

    def load_saved_optimizer(self, checkpoint_path_optimizer=None):
        """reload saved optimizer state"""
        checkpoint_path_optimizer = self.checkpoint_path_optimizer if checkpoint_path_optimizer is None else checkpoint_path_optimizer
        saved_dict = torch.load(checkpoint_path_optimizer)
        return saved_dict['optimizer_state_dicts']

    def load_saved_trainer_state(self, checkpoint_path_trainer_state=None):
        """reload the saved trainer state (JSON), restore its settings on the trainer, and return it"""
        checkpoint_path_trainer_state = self.checkpoint_path_trainer_state if checkpoint_path_trainer_state is None else checkpoint_path_trainer_state
        with open(checkpoint_path_trainer_state, 'r') as jcon:
            trainer_state = json.load(jcon)
        self.epochs_global = trainer_state['epochs_global']
        self.global_step = trainer_state['global_step']
        self.step = trainer_state['step']
        self.max_grad_norm = trainer_state['max_grad_norm']
        self.weight_decay = trainer_state['weight_decay']
        self.warmup_steps = trainer_state['warmup_steps']
        self.optimizer_params = trainer_state['optimizer_params']
        self.evaluation_steps = trainer_state['evaluation_steps']
        return trainer_state

    def _load_auto_model(self, model_name_or_path):
        """
        Creates a simple Transformer + Mean Pooling model and returns the modules
        """
        logger.warning("No sentence-transformers model found with name {}. Creating a new one with MEAN pooling.".format(model_name_or_path))
        transformer_model = Transformer(model_name_or_path)
        pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), 'mean')
        return [transformer_model, pooling_model]

    def _load_sbert_model(self, model_path):
        """
        Loads a full sentence-transformers model
        """
        # Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
        config_sentence_transformers_json_path = os.path.join(model_path, 'config_sentence_transformers.json')
        if os.path.exists(config_sentence_transformers_json_path):
            with open(config_sentence_transformers_json_path) as fIn:
                self._model_config = json.load(fIn)

            if '__version__' in self._model_config and 'sentence_transformers' in self._model_config['__version__'] and self._model_config['__version__']['sentence_transformers'] > __version__:
                logger.warning("You try to use a model that was created with version {}, however, your version is {}. This might cause unexpected behavior or errors. In that case, try to update to the latest version.\n\n\n".format(self._model_config['__version__']['sentence_transformers'], __version__))

        # Check if a readme exists
        model_card_path = os.path.join(model_path, 'README.md')
        if os.path.exists(model_card_path):
            try:
                with open(model_card_path, encoding='utf8') as fIn:
                    self._model_card_text = fIn.read()
            except:
                pass

        # Load the modules of sentence transformer
        modules_json_path = os.path.join(model_path, 'modules.json')
        with open(modules_json_path) as fIn:
            modules_config = json.load(fIn)

        modules = OrderedDict()
        for module_config in modules_config:
            module_class = import_from_string(module_config['type'])
            module = module_class.load(os.path.join(model_path, module_config['path']))
            modules[module_config['name']] = module

        return modules

    @staticmethod
    def load(input_path):
        return SentenceTransformer(input_path)

    @staticmethod
    def _get_scheduler(optimizer, scheduler: str, warmup_steps: int, t_total: int):
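        """
        Returns the correct learning rate scheduler. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts.
        If scheduler is None, a constant learning rate (constantlr) is used.
        """
        scheduler = 'constantlr' if scheduler is None else scheduler.lower()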
        if scheduler == 'constantlr':
            return transformers.get_constant_schedule(optimizer)
        elif scheduler == 'warmupconstant':
            return transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
        elif scheduler == 'warmuplinear':
            return transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
        elif scheduler == 'warmupcosine':
            return transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
        elif scheduler == 'warmupcosinewithhardrestarts':
            return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
        else:
            raise ValueError("Unknown scheduler {}".format(scheduler))

    @property
    def device(self) -> device:
        """
        Get torch.device from module, assuming that the whole module has one device.
        """
        try:
            return next(self.parameters()).device
        except StopIteration:
            # For nn.DataParallel compatibility in PyTorch 1.5

            def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
                return tuples

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].device

    @property
    def tokenizer(self):
        """
        Property to get the tokenizer that is used by this model
        """
        return self.model.tokenizer

    #@tokenizer.setter
    #def tokenizer(self, value):
    #    self._first_module().tokenizer = value

    @property
    def max_seq_length(self):
        """
        Property to get the maximal input sequence length for the model. Longer inputs will be truncated.
        """
        return self.model._first_module().max_seq_length

    @max_seq_length.setter
    def max_seq_length(self, value):
        """
        Property to set the maximal input sequence length for the model. Longer inputs will be truncated.
        """
        self.model._first_module().max_seq_length = value

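You can check the checkpoint naming and pruning logic in `_save_checkpoint` on its own. Below is a minimal sketch that uses only the standard library. It assumes each checkpoint is a single file named `<prefix>-<zero-padded step>`, as written by `torch.save` above. The helper names `checkpoint_name` and `prune_checkpoints` are hypothetical. The sketch uses an integer format (`%08d`) for the step. `%08g` is a fragile choice for this, because `%g` switches to scientific notation (`1e+06`) once the step reaches one million.

```python
import os
import re
import tempfile


def checkpoint_name(prefix: str, step: int) -> str:
    """Hypothetical helper: '<prefix>-<step>' with the step zero-padded to 8 digits."""
    return "%s-%08d" % (prefix, step)


def prune_checkpoints(prefix: str, keep: int) -> list:
    """Hypothetical helper: delete all but the `keep` newest '<prefix>-<step>' files.

    Returns the deleted paths. keep <= 0 means "keep everything".
    """
    directory = os.path.dirname(prefix) or "."
    pattern = re.compile(re.escape(os.path.basename(prefix)) + r"-([0-9]+)")
    found = []
    for fname in os.listdir(directory):
        match = pattern.fullmatch(fname)
        if match:
            found.append((int(match.group(1)), os.path.join(directory, fname)))
    found.sort()  # oldest step first
    to_delete = [path for _, path in found[:-keep]] if keep > 0 else []
    for path in to_delete:
        os.remove(path)  # checkpoints are files, so os.remove rather than shutil.rmtree
    return to_delete


# Usage: write five empty "checkpoints" and keep the two newest.
with tempfile.TemporaryDirectory() as tmp:
    prefix = os.path.join(tmp, "checkpoint.pt")
    for step in (500, 1000, 1500, 2000, 2500):
        open(checkpoint_name(prefix, step), "w").close()
    deleted = prune_checkpoints(prefix, keep=2)
    remaining = sorted(os.listdir(tmp))
    print(remaining)  # ['checkpoint.pt-00002000', 'checkpoint.pt-00002500']
```

The pattern is anchored with `fullmatch` on the file's base name. As a result, `checkpoint.pt` never accidentally matches `checkpoint_optimizer.pt-...` files that live in the same folder.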

### Load a Standard Dataset for MLM task

Also need to grab datasets here: https://arxiv.org/pdf/1908.08962.pdf

```
    The Pile dataset looks good: https://pile.eleuther.ai/
    https://arxiv.org/abs/2101.00027
    PubMed Central, ArXiv, GitHub, the FreeLaw Project, Stack Exchange, the US
    Patent and Trademark Office, PubMed, Ubuntu IRC, HackerNews, YouTube, PhilPapers, and NIH ExPorter.
    We also introduce OpenWebText2 and
    BookCorpus2, which are extensions of the original
    OpenWebText (Gokaslan and Cohen, 2019) and
    BookCorpus (Zhu et al., 2015; Kobayashi, 2018)
    datasets, respectively.
    In addition, we incorporate several existing high-quality datasets: Books3 (Presser, 2020), Project Gutenberg (PG-19) (Rae et al., 2019), OpenSubtitles (Tiedemann, 2016), English Wikipedia, DM Mathematics (Saxton et al., 2019), EuroParl
    (Koehn, 2005), and

    About the law:
    and other metadata, we focused specifically on
    court opinions due to an abundance of full-text
    entries. This data is entirely within the public domain.

```

Scientific Papers: You can use the scientific_papers dataset, which includes a large collection of scientific papers from various domains. It covers research articles from fields such as computer science, physics, biology, and more.

Patents: The patent_citations dataset contains patent text data along with citation information, making it suitable for training language models with a focus on technical and scientific domains.

ArXiv: The arxiv dataset includes research papers from the arXiv repository, covering a wide range of scientific disciplines. It can be used to enhance the exposure of your model to academic literature.

PubMed: The pubmed dataset consists of abstracts from biomedical research articles indexed in PubMed. It is a valuable resource if you want to incorporate biomedical and life sciences content into your MLM pretraining.

joelito/Multi_Legal_Pile - use subset `en_all` to access EU-courts, and other datasets


Looks like streaming data is available:
https://huggingface.co/learn/nlp-course/chapter5/4?fw=pt

In [None]:
### Load a standard dataset
%pip install transformers datasets zstandard rank_bm25 langdetect
# need the zstandard to use the streaming data function from huggingface datasets

Collecting transformers
  Downloading transformers-4.31.0-py3-none-any.whl (7.4 MB)
Collecting datasets
  Downloading datasets-2.14.3-py3-none-any.whl (519 kB)
Collecting zstandard
  Downloading zstandard-0.21.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.7 MB)
Collecting rank_bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Collecting langdetect
  Downloading langdetect-1.0.9.tar.gz (981 kB)
  Preparing metadata (setup.py) ... done
Collecting huggin

In [None]:
import lzma
from datasets import load_dataset
from itertools import islice
from datasets import interleave_datasets # for interweaving streaming datasets
#from transformers import BertTokenizer, LineByLineTextDataset, DataCollatorForLanguageModeling
from spacy.lang.en import English
import spacy
import re
import random
import numpy as np
import os
import pickle
from langdetect import detect
import copy

In [None]:
#import zstandard  # maybe not necessary
# Notes on the Pile
# the largest ones are tarred and cannot be loaded (like openwebtext2), but some are already available on huggingface anyway
# dataset4 = load_dataset("the_pile_openwebtext2", split='train', streaming=True)  # load like THIS!!
# consider using books2: RyokoExtra/books2-1.2-lite
# 'the_pile_books3',
# 'the_pile_stack_exchange'
# 'the_pile_openwebtext2'
# 'Cohere/wikipedia-22-12'
# 'tiiuae/falcon-refinedweb' # see also google's C4
# see more under conceptofmind/pile

# base_url = "https://the-eye.eu/public/AI/pile/"
data_files = [
     ("tiiuae/falcon-refinedweb",None, 18.11),# CC
     ('Cohere/wikipedia-22-12','en', 14.40), # see also: conceptofmind/pile_wikipedia_en
     #("the_pile_books3", None, 12.07), # alternative? bookcorpusopen (no); Multi-Domain-Expert-Layers/the_pile_books3_packed_128k
     ("Multi-Domain-Expert-Layers/the_pile_books3_packed_128k",None,12.07, 34500), # 34.5k
     ("the_pile_openwebtext2",None, 10.01),
     ("macrocosm/arxiv_abstracts",None, 3.75), # just the abstracts k: abstract
     ("ccdv/pubmed-summarization",None, 3.75),# PMC # I should reduce this, use wikipedia instead
     ('https://the-eye.eu/public/AI/pile_preliminary_components/FreeLaw_Opinions.jsonl.zst',None,  3.0), # freelaw THE EYE DELETED THE ORIGINAL DATA
     ('the_pile_stack_exchange',None,  5.13),
     ("conceptofmind/pile_uspto_backgrounds",None, 3.00),
     ("https://the-eye.eu/public/AI/pile_preliminary_components/PUBMED_title_abstracts_2019_baseline.jsonl.zst",None,  3.07),
     #"pg19", 0.1, # project gutenberg FAILS
     #("https://the-eye.eu/public/AI/pile_v2/data/EuroParliamentProceedings_1996_2011.jsonl.zst", None, 0.73), # NON ENGLISH
     #('https://the-eye.eu/public/AI/pile_preliminary_components/EuroParliamentProceedings_1996_2011.jsonl.zst',None, 0.73), # NON ENGLISH
     ("pile-of-law/pile-of-law",'euro_parl',1),
     ("conceptofmind/pile_hacker_news", None,2),
     #("https://the-eye.eu/public/AI/pile_preliminary_components/PhilArchive.jsonl.zst", None, 0.38), #(( philosophy papers -- its taking too long?
     ('https://the-eye.eu/public/AI/pile_v2/data/PhilArchive.jsonl.zst', None, 0.38), # does this work better?
     ("https://the-eye.eu/public/AI/pile_v2/data/NIH_ExPORTER_awarded_grant_text.jsonl.zst",None, 0.30),
     ("https://drive.switch.ch/index.php/s/j9S0GRMAbGZKa1A/download?path=%2F&files=LEDGAR_2016-2019.jsonl.zip", None, 5.0),# ledgar worked
     ("pile-of-law/pile-of-law",'r_legaladvice', 1.0),
     ("pile-of-law/pile-of-law",'exam_outlines',0.5),
     ("pile-of-law/pile-of-law",'cc_casebooks',0.5),
     ("eloukas/edgar-corpus",None, 4.0),
     #("orieg/elsevier-oa-cc-by",None,3.75) fails (takes too long)
     ("Rahmaa/ElsevieR_ClEaN",None,3.75),#
     ('ashraq/financial-news-articles', None, 1.0),
     ('pile-of-law/pile-of-law','courtlistener_opinions',  3.0), # freelaw THE EYE DELETED THE ORIGINAL DATA
     ('suolyer/pile_nih-exporter',['validation','test'], 0.30/2),
     ('EleutherAI/pile','all',10.01 + 5.13 + 3.07) # backup for openwebtext2, stackexchange, and pubmed abstracts
    #"https://huggingface.co/datasets/pile-of-law/pile-of-law/blob/main/data/train.edgar.jsonl.xz"
]

print(len(data_files))


data_streaming_config = {
    'files':data_files,
    'val_size':2000,
    'min_seq_length':48,
    'max_seq_length':512,
    'max_chunk_size':6,
    'train_chunk_size':6000,
    'max_chunk_start':1000000,
    "seed":42,
    "do_cc":True,
    "do_wikipedia":True,
    "do_book3":True, # deleted from pile, no backup -- maybe book corpus and bookcorpus2?
    "do_openwebtext2":False, # deleted from pile, see pilebackup
    "do_arxiv":True,
    "do_pmc-articles":True,
    "do_freelawopinions":False, # deleted from https://pile.eleuther.ai/
    "do_stackexchange":False, # deleted from pile, see pilebackup
    "do_upto":True,
    "do_pubmed-abstracts":False, # deleted from https://pile.eleuther.ai/ see pilebackup
    "do_EuroParliamentProceedings_1996_2011":True,
    "do_hackernews":True,
    "do_philpapers":False, # this crashes my computer
    "do_NIH_ExPORTER_awarded_grant":True, # deleted from pile.eleuther.ai, oops, looks like it is restored
    "do_ledgar":True,
    "do_r_legaladvice":True,
    "do_legalexams":True,
    "do_casetexts":True,
    "do_edgar":True,
    "do_elseiver":True,
    'do_financialnews':True,
    'do_pilelawopinions_sub':True,
    'do_nih-backup':False,
    'do_pilebackupfiltered':True, # backup pile, filtered: nope, it depends on pile.eleuther.ai
}

# smaller development configuration: this assignment overrides the full configuration above
data_streaming_config = {
    'files':data_files,
    'val_size':200,
    'min_seq_length':48,
    'max_seq_length':512,
    'max_chunk_size':6,
    'train_chunk_size':300,
    'max_chunk_start':6000,
    "seed":42,
    "do_cc":False,
    "do_wikipedia":False,
    "do_book3":True, # deleted from the Pile, but the alternative seems to work
    "do_openwebtext2":False, # deleted from the Pile, see pilebackup
    "do_arxiv":False,
    "do_pmc-articles":False,
    "do_freelawopinions":False, # deleted from https://pile.eleuther.ai/
    "do_stackexchange":False, # deleted from the Pile, see pilebackup
    "do_upto":False,
    "do_pubmed-abstracts":False, # deleted from https://pile.eleuther.ai/, see pilebackup
    "do_EuroParliamentProceedings_1996_2011":False,
    "do_hackernews":False,
    "do_philpapers":False, # this crashes my computer
    "do_NIH_ExPORTER_awarded_grant":False, # deleted from pile.eleuther.ai, oops, looks like it is restored
    "do_ledgar":False,
    "do_r_legaladvice":False,
    "do_legalexams":False,
    "do_casetexts":True,
    "do_edgar":True,
    "do_elseiver":False,
    'do_financialnews':False,
    'do_pilelawopinions_sub':False,
    'do_nih-backup':False,
    'do_pilebackupfiltered':False, # backup Pile, filtered: nope, it depends on pile.eleuther.ai
}


24
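Because `data_streaming_config` is assigned twice in the cell above, only the second, smaller development configuration is in effect when the later cells run. A small helper such as the hypothetical `enabled_sources` below is one way to confirm which `do_*` sources will actually be streamed; it is shown here on an excerpt of that development config.

```python
def enabled_sources(cfg):
    """Return the source names whose `do_*` flag is True, in dict order (hypothetical helper)."""
    return [key[len("do_"):] for key, value in cfg.items() if key.startswith("do_") and value]

# Excerpt of the development config above; non-flag keys are ignored by the helper
dev_config_excerpt = {
    "val_size": 200,
    "do_cc": False,
    "do_wikipedia": False,
    "do_book3": True,
    "do_casetexts": True,
    "do_edgar": True,
    "do_elseiver": False,
}
print(enabled_sources(dev_config_excerpt))  # → ['book3', 'casetexts', 'edgar']
```

Calling `enabled_sources(data_streaming_config)` in the notebook should report the same three sources, since those are the only flags set to `True` in the second dictionary.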


In [None]:

import re

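def remove_first_http_url(text):
    """Removes the first http(s) URL from `text`."""
    pattern = r'http[s]*://[^ ]+'
    return re.sub(pattern, '', text, count=1)

def parse_hacker_news(text):
    """Cleans a Hacker News thread: drops the separator lines and the first line of each block (usernames)."""
    # normalize the "------" and "======" separators to "~~~"
    text = text.replace("------\n", "~~~\n").replace("======\n", "~~~\n")
    # drop the first line of every block, then re-join everything with spaces
    blocks = [" ".join(block.split('\n')[1:]) for block in text.split("~~~\n")]
    return remove_first_http_url(" ".join(blocks))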
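To see what `parse_hacker_news` keeps, it helps to run it on a toy thread. The layout below (a title, then blocks that each start with a username line after a separator) is a hypothetical input chosen to exercise the separators the function handles, not a verified sample of the real corpus; the two functions repeat the same logic as the cell above so the example runs on its own.

```python
import re

def remove_first_http_url(text):
    # Remove only the first http(s) URL, as in the cell above
    return re.sub(r'http[s]*://[^ ]+', '', text, count=1)

def parse_hacker_news(text):
    # Treat "------" and "======" separator lines like "~~~", then drop the
    # first line of every block before re-joining the rest with spaces
    text = text.replace("------\n", "~~~\n").replace("======\n", "~~~\n")
    blocks = [" ".join(block.split("\n")[1:]) for block in text.split("~~~\n")]
    return remove_first_http_url(" ".join(blocks))

# Hypothetical thread: a title, then blocks that each start with a username line
thread = (
    "Ask HN: Why Rust?\n======\n"
    "alice\nI like it.\n------\n"
    "bob\nSee https://example.com for more.\n"
)
print(repr(parse_hacker_news(thread)))  # → ' I like it.  See  for more. '
```

The title line is dropped too, because it is the first line of the first block, and the removed URL leaves a double space behind; `" ".join(result.split())` would collapse the extra whitespace if that matters.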


In [None]:
def make_streaming_datasets(data_streaming_config, streaming_datasets = None):
    """Makes the streaming dataset, like Pile but includes others"""

    print('consider adding: ashraq/financial-news-articles, for financial news')

    def casetext_skip_first_k_char(example):
        example['text'] = example['text'][120000:].replace('\n'," ")
        return example

    def edgar_consolidate_sections(example):
        example['text'] = example['section_1'] + "\n" + example['section_2'] + "\n" + example['section_3'] + "\n" + example['section_7']
        return example

    def clean_elseiver_mlm(example):
        example['text'] = example['Clean_Title'] + " - " + example['Clean_Summary'] + "\n" + example['Clean_Text']
        return example

    def clean_financial_news(example):
        example['text'] = example['title'] + "\n" + example['text']
        return example

    if streaming_datasets is None:
        streaming_datasets = []

    # new probabilities
    probabilities = []

    data_files = data_streaming_config['files']

    if data_streaming_config['do_cc']:
        # CommonCrawl
        j = 0
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(data_files[j][0], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['url', 'timestamp', 'dump', 'segment', 'image_urls']).rename_column('content','text')
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_wikipedia']:
        # wikipedia
        j = 1
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(data_files[j][0], data_files[j][1], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['id', 'title', 'url', 'wiki_id', 'views', 'paragraph_id', 'langs'])
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_book3']:
        # the_pile_books3: need to figure out how to skip a certain amount of tokens
        j = 2
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(data_files[j][0], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['title'])
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_openwebtext2']:
        # the_pile_openwebtext2:
        j = 3
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(data_files[j][0], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['title','reddit_scores'])
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_arxiv']:
        # arxiv_abstracts:
        j = 4
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(data_files[j][0], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['embeddings', 'doi']).rename_column('abstract','text')
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_pmc-articles']:
        # PMC articles
        j = 5
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(data_files[j][0], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['abstract']).rename_column('article','text')
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_freelawopinions']:
        # Freelaw opinions
        j = 6
        streaming_datasets.append(load_dataset('json', data_files=data_files[j][0], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['meta'])
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_stackexchange']:
        j = 7
        # stackexchange
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(path=data_files[j][0], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['domain'])
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_upto']:
        j = 8
        # upto
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(path=data_files[j][0], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['meta'])
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_pubmed-abstracts']:
        j = 9
        # pubmed abstracts
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset('json', data_files=data_files[j][0], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['meta'])
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_EuroParliamentProceedings_1996_2011']:
        j = 10
        #EuroParliamentProceedings_1996_2011
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(data_files[j][0], data_files[j][1], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['created_timestamp', 'downloaded_timestamp', 'url'])
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_hackernews']:
        j = 11
        # hackernews discussions
        print("Trying '%s'" % data_files[j][0])
        print('Hacker news needs extra cleaning to remove ===== username and ----- username and ~~~ username')
        streaming_datasets.append(load_dataset(data_files[j][0], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['meta'])
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_philpapers']:
        j = 12
        # philosophy papers / philpapers
        streaming_datasets.append(load_dataset('json', data_files=data_files[j][0], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['meta'])
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_NIH_ExPORTER_awarded_grant']:
        j = 13
        # NIH_ExPORTER_awarded_grant_text
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset('json', data_files=data_files[j][0], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['meta'])
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_ledgar']:
        j = 14
        # LEDGAR_2016: ("https://drive.switch.ch/index.php/s/j9S0GRMAbGZKa1A/download?path=%2F&files=LEDGAR_2016-2019.jsonl.zip", None, 5.0),# ledgar worked
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset('json', data_files=data_files[j][0], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['label', 'source']).rename_column('provision','text')
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_r_legaladvice']:
        j = 15
        # r_legaladvice
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(data_files[j][0], data_files[j][1], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['created_timestamp', 'downloaded_timestamp', 'url'])#.rename_column('provision','text')
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_legalexams']:
        j = 16
        # legal exams
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(data_files[j][0], data_files[j][1], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['created_timestamp', 'downloaded_timestamp', 'url'])#.rename_column('provision','text')
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_casetexts']:
        j = 17
        # case text books
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(data_files[j][0], data_files[j][1], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['created_timestamp', 'downloaded_timestamp', 'url']).map(casetext_skip_first_k_char)
        probabilities.append(data_files[j][-1])


    if data_streaming_config['do_edgar']:
        j = 18
        # edgar corpus
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(data_files[j][0], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].map(edgar_consolidate_sections).remove_columns([
            'filename', 'cik', 'year', 'section_1A', 'section_1B', 'section_4', 'section_1', 'section_2', 'section_3', 'section_7',
            'section_5', 'section_6', 'section_8', 'section_9', 'section_10', 'section_7A', 'section_9A', 'section_9B',
            'section_11', 'section_12', 'section_13', 'section_14', 'section_15' #
        ])
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_elseiver']:
        j = 19
        # Elsevier
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(data_files[j][0], None, split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].map(clean_elseiver_mlm).remove_columns(['Unnamed: 0', 'Clean_Title', 'Clean_Text', 'Clean_Summary'])
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_financialnews']:
        j = 20
        # financial_news
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(data_files[j][0], None, split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].map(clean_financial_news).remove_columns(['title','url'])
        probabilities.append(data_files[j][-1])

    if data_streaming_config['do_pilelawopinions_sub']:
        j = 21
        # SUBSTITUTE: pile-of-law opinions
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(data_files[j][0], data_files[j][1], split="train", streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['created_timestamp', 'downloaded_timestamp', 'url'])
        probabilities.append(data_files[j][-1])

    # do_nih-backup
    if data_streaming_config['do_nih-backup']:
        j = 22
        # BACKUP: NIH ExPORTER from suolyer/pile_nih-exporter (validation and test splits, each given the listed weight 0.30/2)
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(data_files[j][0], split=data_files[j][1][0], streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['meta'])
        probabilities.append(data_files[j][-1])
        streaming_datasets.append(load_dataset(data_files[j][0], split=data_files[j][1][1], streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].remove_columns(['meta'])
        probabilities.append(data_files[j][-1])

    # backup using the Pile's 'all' config, keeping only 'NIH ExPorter', 'OpenWebText2', 'PubMed Abstracts', 'StackExchange' and 'Wikipedia (en)'
    if data_streaming_config['do_pilebackupfiltered']:
        j = 23
        print("Trying '%s'" % data_files[j][0])
        streaming_datasets.append(load_dataset(data_files[j][0], data_files[j][1], split='train',streaming=True))
        streaming_datasets[-1] = streaming_datasets[-1].filter(
            lambda x: x['meta']['pile_set_name'] in ['NIH ExPorter','OpenWebText2','PubMed Abstracts','StackExchange','Wikipedia (en)']
        ).remove_columns(['meta'])
        probabilities.append(data_files[j][-1])

    assert len(streaming_datasets)==len(probabilities)
    return streaming_datasets, probabilities


def fetch_and_combine_streaming_mlm_data(
    data_streaming_config,
    stopping_strategy ='all_exhausted',
):
    """Builds every enabled streaming dataset and interleaves them into one stream, sampling each source in proportion to its weight"""

    # make all the streaming datasets
    datasets_to_stream, dataset_probabilities = make_streaming_datasets(
        data_streaming_config, streaming_datasets = None
    )
    # normalize the probabilities
    dataset_probabilities = [
        p/sum(dataset_probabilities) for p in dataset_probabilities
    ]

    print('DONE initializing streaming datasets')
    #return datasets_to_stream

    # combine the datasets to stream together
    datasets_combined = interleave_datasets(
        datasets_to_stream,
        stopping_strategy = stopping_strategy,
        probabilities = dataset_probabilities,
        seed = data_streaming_config['seed']
    )
    return datasets_combined
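The normalization step above turns the relative weights in `data_files` (the last element of each tuple) into sampling probabilities for `interleave_datasets`. The sketch below mimics that sampling with NumPy: at each step one source index is drawn according to the probabilities. It is a simplified stand-in, not the `datasets` library's implementation, and it ignores what happens when a source runs out (which is what `stopping_strategy` governs). The three weights are copied from the `data_files` entries for EDGAR, the pile-of-law opinions and the Pile 'all' backup; the helper names are hypothetical.

```python
import numpy as np

def normalize(weights):
    """Turn relative weights into sampling probabilities that sum to 1."""
    total = sum(weights)
    return [w / total for w in weights]

def simulate_source_draws(probabilities, n_draws, seed=42):
    """Simplified stand-in for probability-based interleaving: each step picks a
    source index with the given probabilities (no exhaustion handling)."""
    rng = np.random.default_rng(seed)
    return rng.choice(len(probabilities), size=n_draws, p=probabilities)

# Weights from data_files: edgar, pile-of-law opinions, Pile 'all' backup
weights = [4.0, 3.0, 10.01 + 5.13 + 3.07]
probs = normalize(weights)
draws = simulate_source_draws(probs, n_draws=10_000)
empirical = np.bincount(draws, minlength=len(probs)) / len(draws)
print([round(p, 3) for p in probs])
print(np.round(empirical, 3))
```

Over many draws the share of examples from each source approaches its normalized weight, so a source's weight only matters relative to the weights of the other sources that are switched on.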


In [None]:

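import re
import random
import spacy

CHAR_PER_WORD = 6.36 # average number of characters per word (converts word budgets into character budgets)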
nlp = spacy.load("en_core_web_sm")
nlp.add_pipe("sentencizer")
config = {
    'max_seq_length':512,
    'min_seq_length':48,
    'max_chunk_size':6,
    'min_sentence_len':20,
    'seed':42
}

class ExampleProcessor:
    def __init__(
        self,
        config=config,
        char_per_word = CHAR_PER_WORD,
        nlp =nlp,
    ):
        self.nlp = nlp
        self.char_per_word = char_per_word
        self.max_seq_length = config.get('max_seq_length', 512) # maximum word-length for chunks for mlm objective (else split)
        self.min_seq_length = config.get('min_seq_length', 128) # min sequence length for chunks (else discard)
        self.max_chunk_size = config.get('max_chunk_size', 5) # maximum number of chunks of text to take (each ~512 in length)
        self.min_sentence_len = config.get('min_sentence_len', 20) # for next-sentence, min sentence size to merge together
        self.seed = config.get('seed', 42)
        self.max_chunk_length = self.max_chunk_size * self.max_seq_length
        self.max_chunk_length_char = int(self.max_chunk_length*self.char_per_word)
        self.min_seq_length_char = int(self.min_seq_length*self.char_per_word)
        self.min_sentence_length_char = int(self.min_sentence_len*self.char_per_word)

    @staticmethod
    def split_into_chunks(text, chunk_char_size, overlapping_size = 50):
        chunks = []
        start = 0
        end = chunk_char_size + overlapping_size
        while start < len(text):
            chunk = text[start:end]
            period_index = chunk.find(". ")
            if period_index != -1:
                chunk = chunk[period_index + 1:]
            else:
                first_space_index = chunk.find(" ")
                if first_space_index != -1:
                    chunk = chunk[first_space_index + 1:]
            # Check if the chunk has been split and contains more than one word
            #if start > 0 and " " in chunk:
            if end < len(text) and " " in chunk and chunk[-1]!=" ":
                last_space_index = chunk.rfind(" ")
                chunk = chunk[:last_space_index]
            chunks.append(chunk)
            start += chunk_char_size
            end += chunk_char_size
        return chunks

    def split_chunk_into_sentences(self, chunk, discard_first_sentence=True, discard_last_sentence=True ):
        doc = self.nlp(chunk)
        MAX_CHAR_LEN = int(self.max_seq_length*self.char_per_word)
        sentences = [sent.text for sent in doc.sents]
        if discard_first_sentence:
            sentences = sentences[1:]
        if discard_last_sentence:
            sentences = sentences[:-1]

        super_list_concatenated = [] # accumulates concatenated sentences
        super_list_raw_sentences = [] # accumulates raw sentences (for next-sentence prediction)
        buffer = []
        buffer_len = 0

        for sentence in sentences:
            sentence_len = len(sentence)

            if buffer_len + sentence_len > MAX_CHAR_LEN:
                super_list_concatenated.append(" ".join(buffer))
                super_list_raw_sentences.extend(buffer)
                buffer = []
                buffer_len = 0

            buffer.append(sentence)
            buffer_len += sentence_len

        if buffer:  # If there are any remaining sentences in the buffer
            super_list_concatenated.append(" ".join(buffer))
            super_list_raw_sentences.extend(buffer)

        return super_list_concatenated, super_list_raw_sentences

    def _sample_chunk_span(self, text, max_chunk_length_char):
        chunks = self.split_into_chunks(text, max_chunk_length_char)
        # randomly sample from the chunks
        #FOOBAR SAMPLE FROM CHUNKS
        return random.choice(chunks)

    def is_too_small_quickcheck(self, text, textlen=None):
        if textlen is None: textlen = len(text.strip())
        return textlen < self.min_seq_length_char*0.9

    def is_too_small(self, nwords):
        return nwords < self.min_seq_length

    def is_larger_than_max_chunk_quickcheck(self, text, textlen=None):
        """if it is larger than a chunksize, then we need to sample chunks"""
        if textlen is None: textlen = len(text.strip())
        return textlen > self.max_chunk_length_char

    def is_short_than_a_chunk(self, text, textlen=None):
        """if it is shorter than a chunk, then we'll take all text, in chunks"""
        if textlen is None: textlen = len(text.strip())
        return textlen < self.max_chunk_length_char

    def is_smaller_than_two_paragraphs(self, text):
        charlen = len(text)
        if charlen < (1.5*self.max_seq_length*self.char_per_word):
            return True, re.split(r"[\s\n\r]+",text.strip())
        if charlen > (2.5*self.max_seq_length*self.char_per_word):
            return False, None
        # inbetween cases, split and calculate the number of words
        textsplit = re.split(r"[\s\n\r]+",text.strip())
        nwords = len(textsplit)
        if nwords < 1.2*self.max_seq_length:
            return True, textsplit
        return False, textsplit

    @staticmethod
    def preprocess_sentences(list_of_sentences, min_sentence_char_length):
        """Merges small sentences in a sequence of sentence, until the strings are greater than `min_sentence_char_length`"""
        processed_sentences = []
        buffer = ""

        for sentence in list_of_sentences:
            if len(sentence) < min_sentence_char_length:
                buffer = buffer + " " + sentence
                if (len(buffer)>=min_sentence_char_length):
                    processed_sentences.append(buffer.strip())
                    buffer = ""
            else:
                if (len(buffer)<min_sentence_char_length):
                    to_add = buffer + " " + sentence
                    processed_sentences.append(to_add.strip())
                    buffer = ""
                else:
                    processed_sentences.extend([buffer.strip(), sentence.strip()])

        if buffer:  # If there are any remaining sentences in the buffer
            processed_sentences.append(buffer.strip())

        return processed_sentences

    def process(self, text):
        """Chunks and samples large portions of text"""

        charlen = len(text.strip())

        # DISCARD if it is too small for the corpus
        if self.is_too_small_quickcheck(text, charlen):

            return {'text':[], 'do_accept':False, 'sentences':[]}

        # sample span of chunks: if it larger than our max chunk size
        if self.is_larger_than_max_chunk_quickcheck(text, charlen):
            text_span_chunks = self._sample_chunk_span(text, self.max_chunk_length_char)
        else:
            text_span_chunks = text

        # if it is shorter than roughly two sequences, accept it all as one unit (the tokenizer truncates it later)
        is_smaller_than_2_paras, textsplit = self.is_smaller_than_two_paragraphs(text_span_chunks)

        if is_smaller_than_2_paras:

            # check if less than minsize
            if self.is_too_small(len(textsplit)):
                # if too small, return nothing
                return {'text':[], 'do_accept':False, 'sentences':[]}

            # return text to be truncated
            return {'text':[text_span_chunks], 'do_accept':True, 'sentences':[]}

        # leftover cases: text that needs to be chunked into ~512 / max_seq_len
        text_to_return, sentences_to_return = self.split_chunk_into_sentences(text_span_chunks)

        # return text strings as list of chunks, flag
        return {
            'text':text_to_return,
            'do_accept':True,
            'sentences':self.preprocess_sentences(sentences_to_return, self.min_sentence_length_char),
        }

    def __call__(self, text):
        return self.process(text)

example_processor = ExampleProcessor(config=data_streaming_config, char_per_word = CHAR_PER_WORD, nlp =nlp)
text = """As the aircraft approached Pearl Harbor, the weather cleared, as if on cue. This enabled the strike formations to use the battery of searchlights at Kahuku Point as a navigation aid to guide them toward their targets. Dawn was now breaking. As sunlight streamed over the horizon, the airborne strike force pressed home its attack over Pearl Harbor, achieving complete surprise. Dive-bombers and torpedo planes went to work on the ships lying at anchor along Battleship Row, where the U.S. Navy's capital ships were berthed. Fighter aircraft peeled off and strafed the airfield, hitting parked planes, fuel storage tanks, and hangars. Army Air Corps pilots rushed to take off after the attacking force, but by the time they were aloft, the attackers had completed their strikes and vanished. Failing to locate the attackers, the Army aircraft returned to base, whereupon a second wave of carrier strike aircraft hit them. A _New York Times_ reporter on the scene reported that the attacks were "unopposed by the defense, which was caught virtually napping. Surveying the results, the American defenders were filled with anger—and relief. The attack, executed on the morning of Sunday, _February 7, 1932_ , occurred at the outset of a U.S. Army-Navy war game called Grand Joint Exercise 4. Rear Admiral Harry Yarnell, commander of the newly commissioned American aircraft carriers _Saratoga_ and _Lexington_ , had launched the attacking planes. The "bombs" dropped were flour bags, which could be found splattered on the Navy's ships still sitting at anchor. Surveying the results, the American defenders were filled with anger—and relief. The attack, executed on the morning of Sunday, _February 7, 1932_ , occurred at the outset of a U.S. Army-Navy war game called Grand Joint Exercise 4. Rear Admiral Harry Yarnell, commander of the newly commissioned American aircraft carriers _Saratoga_ and _Lexington_ , had launched the attacking planes. 
The "bombs" dropped were flour bags, which could be found splattered on the Navy's ships still sitting at anchor.Red-faced, the Army Air Corps commanders sought to minimize the attack's results. They argued that the damage incurred to Hickam Field was minimal, and asserted that they had found and attacked Yarnell's carriers. Finally, they protested the attack on legal grounds—it was improper to begin a war on Sunday! The war game's umpires sided with the Army. Their report made no mention of Yarnell's attack but concluded that "it is doubtful if air attacks can be launched against Oahu in the face of strong defensive aviation without subjecting the attacking carriers to the danger of material damage and consequent great loss in the attacking] air force. Nearly ten years later carriers of the Imperial Japanese Navy, attacking Pearl Harbor on Sunday, December 7, 1941, proved that Admiral Yarnell, not the umpires or the Army, had gauged the future correctly. The admiral had been willing to confront uncomfortable possibilities, whereas others had not. Although America was shocked by the Japanese attack, many in the Navy were not. As Admiral Chester W. Nimitz, the architect of the Navy's victorious campaign against Japan, ruefully admitted, "Nothing that happened in the Pacific was strange or unexpected. ## **THE DAWN OF BLITZKRIEG**"""
text += text
text += text
text += text
text += text
foo = example_processor(text = text)
foo,is_good, foo_sentences = foo.values()
print(is_good)
print('mlm_sentences')
print(foo)

print('next sentences:') # this seems to be working okay
print(foo_sentences)
print(len(foo_sentences))


# works: test preprocess_sentences
print(example_processor.preprocess_sentences(["This is fine.","foo",'sh',"This is fine and long.","This is also find and long.",'No', "This is long and good."], 10))

# works, this returns split sentences
example_processor.split_chunk_into_sentences(
    chunk="This is the first sentence. This is the 2nd sentence and another. I'm the third sentence. Hello, this is me. 5th sentence here. And finally its me.",
    discard_first_sentence=True, discard_last_sentence=True
)


#### A sample of 1000 streamed documents yields...
... approximately 1523 MLM examples of ~512 words each
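The processor works in characters rather than tokens, converting each word budget with `CHAR_PER_WORD = 6.36`. The worked numbers below recompute those thresholds for the values used in `data_streaming_config` (`max_seq_length=512`, `min_seq_length=48`, `max_chunk_size=6`, plus the default `min_sentence_len=20`), then extrapolate the ratio from the note above to the sample sizes in the two configs. The extrapolation assumes the "1000" refers to streamed documents and that the ratio carries over to other sources and sample sizes, which it may not; `char_budgets` is a hypothetical helper.

```python
CHAR_PER_WORD = 6.36  # same characters-per-word constant as the ExampleProcessor cell

def char_budgets(max_seq_length=512, min_seq_length=48, max_chunk_size=6,
                 min_sentence_len=20, char_per_word=CHAR_PER_WORD):
    """Recompute the word-to-character thresholds that ExampleProcessor derives from its config."""
    min_seq_length_char = int(min_seq_length * char_per_word)
    return {
        # documents longer than this are sampled down to one span of chunks
        "max_chunk_length_char": int(max_chunk_size * max_seq_length * char_per_word),
        "min_seq_length_char": min_seq_length_char,
        # is_too_small_quickcheck discards anything shorter than this
        "quick_discard_below_char": min_seq_length_char * 0.9,
        "min_sentence_length_char": int(min_sentence_len * char_per_word),
        # is_smaller_than_two_paragraphs: below -> one unit, above -> always split into sentences
        "single_unit_below_char": 1.5 * max_seq_length * char_per_word,
        "always_split_above_char": 2.5 * max_seq_length * char_per_word,
    }

for name, value in char_budgets().items():
    print(f"{name}: {value:.1f}")

# Rough extrapolation: ~1523 MLM examples per 1000 streamed documents
EXAMPLES_PER_DOC = 1523 / 1000
for n_docs in (200, 300, 2000, 6000):  # val_size / train_chunk_size values from the two configs
    print(n_docs, "documents ->", round(n_docs * EXAMPLES_PER_DOC), "MLM examples (approx.)")
```

With these settings a document must exceed about 19.5k characters before only a sampled span of it is kept, and anything under roughly 275 characters is dropped before spaCy is ever called.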

In [None]:
# FUNCTIONS TO MAKE THE TRAINING AND VAL SETs
import numpy as np
import os
import pickle

## convert the streaming dataset in a static dataset
def convert_streaming_dataset_to_static_corpus(streaming_dataset, skip=0, take=1000):
    """Takes a streaming_dataset and converts it into a list of examples"""
    if skip !=0:
        dataset_to_make_static = streaming_dataset.skip(skip).take(take)
    else:
        dataset_to_make_static = streaming_dataset.take(take)

    examples_static_mlm = [] # data for MLM objective
    examples_static_nextsentence = [] # data for next sentence task
    for i, example in enumerate(dataset_to_make_static):
        # chunk text into ~512 text-strings, and sentences
        examples_processed = example_processor(text = example['text'])
        # chunk, accept/reject, sentences
        example_parsed, do_accept, parsed_sentences = examples_processed.values()
        if do_accept:
            # mlm gets the chunks of text-strings
            examples_static_mlm.extend(example_parsed)
            if len(parsed_sentences)>4:
                # sentences for next sentence prediction: make triplet of s1,s2,opposite, where opposites get label=1
                examples_static_nextsentence.extend(
                    convert_sequence_into_nextsentence_pairs(parsed_sentences)
                )
                #FOOFU - STOPPED HERE TO FIGURE OUT WHY MY NEXT-SENTENCE STUFF IS SO LONG
        if (i+1)%100==0:
            print("...streaming size: %d" % len(examples_static_mlm))

    return examples_static_mlm, examples_static_nextsentence

def convert_sequence_into_nextsentence_pairs(list_of_sentences):
    """Converts a list of sentences into a list of dicts, with next-sentence pairs"""
    n = len(list_of_sentences)

    def opposite(i,n):
        return (i + round(n/2+1)) % n

    list_of_nextsentence_pairs = []
    # loop through sequence, make triplet of anchor, next and an opposite
    for o1,o2 in zip(range(0,n-1), range(1,n)):
        s_anchor = list_of_sentences[o1]
        s_next = list_of_sentences[o2]
        s_opposite = list_of_sentences[opposite(o1,n)]
        list_of_nextsentence_pairs.append(
            {
                "anchor":s_anchor,
                "next":s_next,
                "opposite":s_opposite
            }
        )
    return list_of_nextsentence_pairs

print(convert_sequence_into_nextsentence_pairs(['a','b','c','d','e','f']))


def train_test_splits_from_stream(
    streaming_dataset,
    val_size = 100,#2000,
    epoch = 0,
    chunk_size = 500,#6000,
    max_chunk_start = 1000000,
    path_to_val_cache = 'val_mlm_cache.pkl'
):
    """
    val_size = 100: number of items at the start of the stream reserved for the validation set
    epoch = 0: offsets the seed used to sample where the training chunk starts
    chunk_size = 500: number of stream items taken for the training chunk
    max_chunk_start = 1000000: the training chunk starts at a random offset in [val_size, max_chunk_start)
    path_to_val_cache: pickle file used to cache and reload the validation set
    """
    if os.path.isfile(path_to_val_cache):
        print('RELOADING VAL SET: %s' % path_to_val_cache)
        with open(path_to_val_cache,'rb') as pcon:
            val_corpus_list = pickle.load(pcon)
            val_corpus_nextsentence = pickle.load(pcon)
        print('VAL SET SIZE: %d' % len(val_corpus_list))
    else:
        # stream validation set
        print('STREAMING VAL DATA: %d' % val_size)
        val_corpus_list, val_corpus_nextsentence = convert_streaming_dataset_to_static_corpus(
            streaming_dataset, skip=0, take=val_size
        )
        # save the validation corpus
        print('SAVING VAL SET: %s' % path_to_val_cache)
        with open(path_to_val_cache,'wb') as pcon:
            pickle.dump(val_corpus_list, pcon)
            pickle.dump(val_corpus_nextsentence, pcon)

    # take a random interger to start the streaming of training data
    skip_to_start_streaming_training_data = np.random.RandomState(
        42 + epoch
    ).randint(val_size, max_chunk_start)

    # stream training data
    print('STREAMING TRAIN DATA: %d STARTING AT: %d' % (chunk_size,skip_to_start_streaming_training_data))
    train_corpus_mlm, train_corpus_nextsentence = convert_streaming_dataset_to_static_corpus(
        streaming_dataset,
        skip=skip_to_start_streaming_training_data,
        take=chunk_size
    )
    print('TRAIN SET SIZE: %d' % len(train_corpus_mlm))
    train_data_mlm = {
        'train':train_corpus_mlm,
        'val':val_corpus_list,
        'epoch':epoch,
        'index_stream':skip_to_start_streaming_training_data
    }
    # NOTE: the next-sentence data (train_corpus_nextsentence, val_corpus_nextsentence) is built above but not returned yet
    return train_data_mlm
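`convert_sequence_into_nextsentence_pairs` picks each negative ("opposite") by jumping `round(n/2 + 1)` positions around the passage, modulo `n`. The quick check below re-implements only that index rule to see how far the negative lands from its anchor; `min_circular_gap` is a hypothetical helper. Python 3's `round` sends halves to the nearest even integer, so `n = 5` gives a jump of 4, which with wrap-around is the same as stepping one sentence backwards.

```python
def opposite(i, n):
    # Same index rule as in convert_sequence_into_nextsentence_pairs
    return (i + round(n / 2 + 1)) % n

def min_circular_gap(n):
    """Smallest wrap-around distance between an anchor and its 'opposite' sentence."""
    offset = round(n / 2 + 1) % n
    return min(offset, n - offset)

# Passages need more than 4 processed sentences to be paired, so n starts at 5
for n in range(5, 11):
    print(f"n={n}: jump={round(n / 2 + 1)}, min gap={min_circular_gap(n)}")

print([opposite(i, 5) for i in range(4)])  # → [4, 0, 1, 2]
```

So for the shortest accepted passages, every anchor except the first gets the sentence immediately before it as its negative, which puts those negatives much closer to the anchor than in longer passages.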

In [None]:

# stream and combine the MLM datasets
datasets_mlm_streaming_combined = fetch_and_combine_streaming_mlm_data(
    data_streaming_config, #stopping_strategy ='all_exhausted'
)

# create the training set and validation set (save and reload later)
datasets_static = train_test_splits_from_stream(
    datasets_mlm_streaming_combined,
    val_size = data_streaming_config['val_size'],#2000,
    epoch = 0,
    chunk_size =  data_streaming_config['train_chunk_size'],#6000,
    max_chunk_start = data_streaming_config['max_chunk_start'],#1000000,
    path_to_val_cache = 'val_mlm_cache.pkl'
)
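The validation cache written by the call above (`path_to_val_cache='val_mlm_cache.pkl'`) holds two objects dumped back-to-back into one file, so it has to be read with two `pickle.load` calls in the same order. A minimal round-trip sketch using a temporary directory and toy data (the contents are illustrative, not real validation examples):

```python
import os
import pickle
import tempfile

# toy stand-ins for val_corpus_list and val_corpus_nextsentence
mlm_chunks = ["first chunk of text", "second chunk of text"]
nextsentence = [{"anchor": "A toy sentence."}]

path = os.path.join(tempfile.mkdtemp(), "val_mlm_cache.pkl")

# write: two sequential dumps into the same file handle
with open(path, "wb") as pcon:
    pickle.dump(mlm_chunks, pcon)
    pickle.dump(nextsentence, pcon)

# read: two sequential loads, in the same order as the dumps
with open(path, "rb") as pcon:
    reloaded_mlm = pickle.load(pcon)
    reloaded_nextsentence = pickle.load(pcon)
```

A third `pickle.load` on the same handle would raise `EOFError`, which makes a cache written with a different number of objects easy to detect.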

In [None]:
for s in datasets_static['train']:
    print(s.replace('\n',' ')[:150] + "\n----")

### Redo the MLM datasets to loop through each stream individually (like QA below)
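Looping through each stream individually means every dataset contributes a number of examples proportional to its (normalized) weight. The sketch below restates the per-stream counts used in `initialize_and_get_mlm_streaming_datasets`, namely `max(int(prob*val_size), 1)` for validation and `max(int(round(train_chunk_size*prob)), 1)` for training, with the weights copied from `mlm_files` and the sizes from `data_streaming_config_mlm`; `take_counts` is a hypothetical helper:

```python
# unnormalized sampling weights, in the same order as mlm_files
weights = [18.21, 18.11, 14.40, 12.07, 0, 0, 3.00, 2.0, 2.0, 1.0,
           5.0, 1.0, 0.1, 0.5, 5.0, 2.0, 1.0, 3.0, 1.0, 0.5]
val_size, train_chunk_size = 200, 400  # as in data_streaming_config_mlm
total = sum(weights)


def take_counts(weight):
    """Hypothetical helper: per-stream (validation, training) sizes, each at least 1."""
    prob = weight / total
    n_val = max(int(prob * val_size), 1)                   # truncates
    n_train = max(int(round(train_chunk_size * prob)), 1)  # rounds
    return n_val, n_train


# zero-weight streams are skipped entirely, as in the loop over files
counts = [take_counts(w) for w in weights if w != 0]
print(counts[:4])  # [(40, 81), (40, 81), (32, 64), (26, 54)]
```

The first four pairs match the "take N from ..." lines in the cell output further down (Pile, RefinedWeb, Wikipedia, Books3). The `max(..., 1)` floor means very small weights such as `exam_outlines` (0.1) still contribute one example to each split.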

In [None]:
# https://the-eye.eu/public/AI/pile/train/00.jsonl.zst



RANDOM:https://the-eye.eu/public/AI/pile/train/00.jsonl.zst;https://the-eye.eu/public/AI/pile/train/01.jsonl.zst;https://the-eye.eu/public/AI/pile/train/02.jsonl.zst;https://the-eye.eu/public/AI/pile/train/03.jsonl.zst;https://the-eye.eu/public/AI/pile/train/04.jsonl.zst;https://the-eye.eu/public/AI/pile/train/05.jsonl.zst;https://the-eye.eu/public/AI/pile/train/06.jsonl.zst;https://the-eye.eu/public/AI/pile/train/07.jsonl.zst;https://the-eye.eu/public/AI/pile/train/08.jsonl.zst;https://the-eye.eu/public/AI/pile/train/09.jsonl.zst;https://the-eye.eu/public/AI/pile/train/10.jsonl.zst;https://the-eye.eu/public/AI/pile/train/11.jsonl.zst;https://the-eye.eu/public/AI/pile/train/12.jsonl.zst;https://the-eye.eu/public/AI/pile/train/13.jsonl.zst;https://the-eye.eu/public/AI/pile/train/14.jsonl.zst;https://the-eye.eu/public/AI/pile/train/15.jsonl.zst;https://the-eye.eu/public/AI/pile/train/16.jsonl.zst;https://the-eye.eu/public/AI/pile/train/17.jsonl.zst;https://the-eye.eu/public/AI/pile/train
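The string above is a `RANDOM:`-prefixed, semicolon-separated list of zero-padded Pile training shards (`00.jsonl.zst`, `01.jsonl.zst`, ...), the same shape that `make_url_for_random_pile` (defined in the next cell) returns. That function pads with `"%02g"`, which for non-negative integers below 100 gives the same result as `"%02d"`. A minimal sketch of the same construction with `str.join`; `build_random_pile_urls` is a hypothetical name and `n_files=30` mirrors the `range(30)` used there:

```python
def build_random_pile_urls(n_files=30,
                           base="https://the-eye.eu/public/AI/pile/train/"):
    """Hypothetical helper: 'RANDOM:' + ';'-joined shard URLs 00 .. n_files-1."""
    urls = ["%s%02d.jsonl.zst" % (base, i) for i in range(n_files)]
    return "RANDOM:" + ";".join(urls)


random_pile_urls = build_random_pile_urls()
print(random_pile_urls[:80])
```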

In [None]:
def make_url_for_random_pile():
    random_pile_urls = "RANDOM:"
    for i in range(30):
        if i>0: random_pile_urls+=";"
        random_pile_urls+= "https://the-eye.eu/public/AI/pile/train/%02g.jsonl.zst" % i
    return random_pile_urls

def clean_stream_refinedweb(x):
    x['text'] = x['content']
    return x

def clean_stream_arxiv(x):
    x['text'] = x['abstract']
    return x

def clean_stream_pubmedsum(x):
    x['text'] = x['article']
    return x

def remove_first_http_url(text):
    """Removes the first http(s) URL from a Hacker News post."""
    pattern = r'https?://[^ ]+'
    return re.sub(pattern, '', text, count=1)

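def parse_hacker_news(text):
    """Removes Hacker News thread separators (------, ======, ~~~), drops the first line of each block, and removes the first URL."""
    # normalize the three separator styles, then split the thread into blocks
    blocks = text.replace("------\n","~~~\n").replace("======\n","~~~\n").split("~~~\n")
    # drop the first line of each block and join the remaining lines with spaces
    return remove_first_http_url(" ".join([" ".join(j.split('\n')[1:]) for j in blocks]))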

def clean_hackernews(x):
    x['text'] = parse_hacker_news(x['text'])
    return x

def clean_ledgarmlm(x):
    x['text'] = x['provision']
    return x

def clean_casetextbook(example):
    # discard the first 8% of the characters, then replace newlines with spaces
    discard = int(0.08*len(example['text']))
    example['text'] = example['text'][discard:].replace('\n'," ")
    return example

def clean_edgarcorpus(example):
    example['text'] = example['section_1'] + "\n" + example['section_2'] + "\n" + example['section_3'] + "\n" + example['section_7']
    return example

def clean_elseiver_mlm(example):
    example['text'] = example['Clean_Title'] + " - " + example['Clean_Summary'] + "\n" + example['Clean_Text']
    return example

def clean_financial_news_mlm(example):
    example['text'] = example['title'] + "\n" + example['text']
    return example

def filter_pileall_mlm(x):
    return x['meta']['pile_set_name'] in ['NIH ExPorter','OpenWebText2','PubMed Abstracts','StackExchange','Wikipedia (en)','ArXiv']

def filter_europarl_mlm(x):
    return len(x['text'])>60*7 # keep docs of at least a short paragraph (60*7 = 420 characters)

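def clean_irs_advice_mlm(x):
    """Removes numeric table rows, dashed rules and page footers from IRS legal advice memos, then flattens newlines."""
    text = x['text']
    # remove lines (or runs of lines) made up only of digits, punctuation (, . % $ + - =) and whitespace, e.g. rows of figures
    text = re.sub(r'^[\d,.%$+\-\s\=]+\n?$',"",text,flags=re.MULTILINE | re.DOTALL)
    # remove horizontal rules of 10 or more dashes
    text = re.sub(r'\-{10,}',"",text)
    # remove lines ending in a page footer such as "Page 3"
    text = re.sub(r'^(.*)?[Pp]age\s\d+\n?$',"",text,flags=re.MULTILINE)
    x['text'] = text.replace("\n"," ").strip()
    return x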

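def clean_secproceedings_mlm(x):
    """Strips the preamble of an SEC administrative proceeding and removes bare paragraph numbers."""
    text = x['text']
    # keep everything after the first line that is exactly "I." (the heading lines themselves are removed)
    parts = re.split(r"^I\.\n", text, flags=re.MULTILINE)
    if len(parts) > 1:
        text = "".join(parts[1:])
    else:
        # no "I." heading at the start of a line: drop the first 10 lines instead
        text = '\n'.join(text.split('\n')[10:])

    # remove lines that contain only a paragraph number, e.g. "12", "3." or "(4)"
    text = re.sub(r'^(\()*\d+[\.\)]?\n?$', '', text,flags=re.MULTILINE)

    x['text'] = text
    return x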

mlm_streaming_cleaning_functions = {
    'EleutherAI/pile/all':(lambda x: x, filter_pileall_mlm, ['meta']),
    "tiiuae/falcon-refinedweb":(clean_stream_refinedweb, None, ['url', 'timestamp', 'dump', 'segment', 'image_urls','content']),
    "Cohere/wikipedia-22-12":(lambda x : x, None, ['id', 'title', 'url', 'wiki_id', 'views', 'paragraph_id', 'langs']),
    "Multi-Domain-Expert-Layers/the_pile_books3_packed_128k":(lambda x: x, None, ['meta']),
    "macrocosm/arxiv_abstracts":(clean_stream_arxiv, None, ['embeddings', 'doi','abstract']),
    "ccdv/pubmed-summarization":(clean_stream_pubmedsum, None, ['abstract','article']),
    "conceptofmind/pile_uspto_backgrounds":(lambda x : x ,None, ['meta']),
    "pile-of-law/pile-of-law/euro_parl":(lambda x : x, filter_europarl_mlm, ['created_timestamp', 'downloaded_timestamp', 'url']),
    #"philArchive": fails, but available as subset in eloukas/edgar-corpus as domain=='PhilPapers'
    "conceptofmind/pile_hacker_news":(clean_hackernews, None,['meta']),
    "https://the-eye.eu/public/AI/pile_v2/data/NIH_ExPORTER_awarded_grant_text.jsonl.zst":(lambda x:x, None,['meta']),
    "https://drive.switch.ch/index.php/s/j9S0GRMAbGZKa1A/download?path=%2F&files=LEDGAR_2016-2019.jsonl.zip":(clean_ledgarmlm,None,['provision','source']),
    "pile-of-law/pile-of-law/r_legaladvice":(lambda x : x, None, ['created_timestamp', 'downloaded_timestamp', 'url']),
    "pile-of-law/pile-of-law/exam_outlines":(lambda x : x, None, ['created_timestamp', 'downloaded_timestamp', 'url']),
    "pile-of-law/pile-of-law/cc_casebooks":(clean_casetextbook, None, ['created_timestamp', 'downloaded_timestamp', 'url']), # clean_casetextbook
    "eloukas/edgar-corpus":(
        clean_edgarcorpus, None, [
            'filename', 'cik', 'year', 'section_1A', 'section_1B', 'section_4', 'section_1', 'section_2', 'section_3', 'section_7',
            'section_5', 'section_6', 'section_8', 'section_9', 'section_10', 'section_7A', 'section_9A', 'section_9B',
            'section_11', 'section_12', 'section_13', 'section_14', 'section_15'
        ]),
    "Rahmaa/ElsevieR_ClEaN":(clean_elseiver_mlm, None, ['Unnamed: 0', 'Clean_Title', 'Clean_Text', 'Clean_Summary']),
    'ashraq/financial-news-articles':(clean_financial_news_mlm, None, ['title','url']),
    'pile-of-law/pile-of-law/courtlistener_opinions':(lambda x : x, None, ['created_timestamp', 'downloaded_timestamp', 'url']),
    "pile-of-law/pile-of-law/sec_administrative_proceedings":(clean_secproceedings_mlm, None, ['created_timestamp', 'downloaded_timestamp', 'url']),
    "pile-of-law/pile-of-law/irs_legal_advice_memos":(clean_irs_advice_mlm, None, ['created_timestamp', 'downloaded_timestamp', 'url'])
}

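# entries: (dataset path or URL, subset/config name, sampling weight (normalized into a probability later), approx. number of examples, option (name of postprocess subsetting), partition info for shuffling as (n_files, examples per file) or False)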
mlm_files = [
    ('EleutherAI/pile','all', 18.21, 700000, "mlm", (30, 1000000)), # 10.01 + 5.13 + 3.07 has 30 files (each with? millions of examples)
    ("tiiuae/falcon-refinedweb", None, 18.11, 968000000, "mlm", (5534, 174000)), # CC; has 5534 files as parquet (each with ~174919)
    ("Cohere/wikipedia-22-12", 'en', 14.40, 8590000, "mlm",(351, 100000)), # wikipedia has 351 files (each with 100000 examples)
    ("Multi-Domain-Expert-Layers/the_pile_books3_packed_128k", None, 12.07, 34500, "mlm", (15, 9900)), # has 15 files (each with ~9978/9983)
    ("macrocosm/arxiv_abstracts", None, 0, 2250000, "mlm", False), # set to zero because in PILE (has 23 parquet files)
    ("ccdv/pubmed-summarization", None, 0, 120000, "mlm", False), # was 3.75; set to zero because Elsevier (below) and PubMed (in the Pile) are already covered
    ("conceptofmind/pile_uspto_backgrounds", None, 3.00, 11000000, "mlm",(139, 80000)), # has 139 Files (each with 80024)
    ("pile-of-law/pile-of-law",'euro_parl', 2.0, 7254, "mlm", False),
    ("conceptofmind/pile_hacker_news", None, 2.0, 1570000, "mlm", (20, 78500)), # has 20 files (each with ~78599 or 78598)
    ("https://the-eye.eu/public/AI/pile_v2/data/NIH_ExPORTER_awarded_grant_text.jsonl.zst", None, 1.0, 985651, "mlm", False), # change to 1 # fails?
    ("https://drive.switch.ch/index.php/s/j9S0GRMAbGZKa1A/download?path=%2F&files=LEDGAR_2016-2019.jsonl.zip", None, 5.0, 1000000, "mlm", False),
    ("pile-of-law/pile-of-law",'r_legaladvice', 1.0, 109740, "mlm", False),
    ("pile-of-law/pile-of-law",'exam_outlines', 0.1, 12, "mlm",False), # useless (but interesting)
    ("pile-of-law/pile-of-law",'cc_casebooks',0.5, 59 ,"mlm",False),
    ("eloukas/edgar-corpus", "full", 5.0, 47000, "mlm",(28, 4000)), # has 28 files, each with 1k-5k examples (variable: 1060 in the first year vs 5508 in 2018)
    ("Rahmaa/ElsevieR_ClEaN", None, 2.0, 31600, "mlm", False),
    ('ashraq/financial-news-articles', None, 1.0, 306000, "mlm", (2, 153100)), # has 2 files (each with 153121)
    ('pile-of-law/pile-of-law','courtlistener_opinions',  3.0, 1000000 , "mlm", (16, 229000)), # has 16 files (each with 229678 to 526543)
    ('pile-of-law/pile-of-law',"sec_administrative_proceedings", 1.0, 10805, "mlm", False), # 118.4 MiB
    ('pile-of-law/pile-of-law',"irs_legal_advice_memos", 0.5, 442, "mlm", False), # 35.8 MiB
] # need to estimate the size of the ['NIH ExPorter','OpenWebText2','PubMed Abstracts','StackExchange','Wikipedia (en)'] subset


# same entry format as mlm_files, with most weights set to zero
foo = [
    ('EleutherAI/pile','all', 0, 700000, "mlm", (30, 1000000)), # 10.01 + 5.13 + 3.07 has 30 files (each with? millions of examples)
    ("tiiuae/falcon-refinedweb", None, 0, 968000000, "mlm", (5534, 174000)), # CC; has 5534 files as parquet (each with ~174919)
    ("Cohere/wikipedia-22-12", 'en', 0, 8590000, "mlm",(351, 100000)), # wikipedia has 351 files (each with 100000 examples)
    ("Multi-Domain-Expert-Layers/the_pile_books3_packed_128k", None, 0, 34500, "mlm", (15, 9900)), # has 15 files (each with ~9978/9983)
    ("macrocosm/arxiv_abstracts", None, 0, 2250000, "mlm", False), # set to zero because in PILE (has 23 parquet files)
    ("ccdv/pubmed-summarization", None, 0, 120000, "mlm", False), # was 3.75; set to zero because Elsevier (below) and PubMed (in the Pile) are already covered
    ("conceptofmind/pile_uspto_backgrounds", None, 0, 11000000, "mlm",(139, 80000)), # has 139 Files (each with 80024)
    ("pile-of-law/pile-of-law",'euro_parl', 0, 7254, "mlm", False),
    ("conceptofmind/pile_hacker_news", None, 0, 1570000, "mlm", (20, 78500)), # has 20 files (each with ~78599 or 78598)
    ("https://the-eye.eu/public/AI/pile_v2/data/NIH_ExPORTER_awarded_grant_text.jsonl.zst", None, 0, 985651, "mlm", False), # change to 1
    ("https://drive.switch.ch/index.php/s/j9S0GRMAbGZKa1A/download?path=%2F&files=LEDGAR_2016-2019.jsonl.zip", None, 0, 1000000, "mlm", False),
    ("pile-of-law/pile-of-law",'r_legaladvice', 0, 109740, "mlm", False),
    ("pile-of-law/pile-of-law",'exam_outlines', 0, 12, "mlm",False), # useless (but interesting)
    ("pile-of-law/pile-of-law",'cc_casebooks',0, 59 ,"mlm",False),
    ("eloukas/edgar-corpus", "full", 1, 47000, "mlm",(28, 4000)), # has 28 files, each with 1k-5k examples (variable: 1060 in the first year vs 5508 in 2018)
    ("Rahmaa/ElsevieR_ClEaN", None, 0, 31600, "mlm", False),
    ('ashraq/financial-news-articles', None, 1.0, 306000, "mlm", (2, 153100)), # has 2 files (each with 153121)
    ('pile-of-law/pile-of-law','courtlistener_opinions',  3.0, 1000000 , "mlm", (16, 229000)), # has 16 files (each with 229678 to 526543)
    ('pile-of-law/pile-of-law',"sec_administrative_proceedings", 1.0, 10805, "mlm", False), # 118.4 MiB
    ('pile-of-law/pile-of-law',"irs_legal_advice_memos", 0, 442, "mlm", False), # 35.8 MiB
] # need to estimate the size of the ['NIH ExPorter','OpenWebText2','PubMed Abstracts','StackExchange','Wikipedia (en)'] subset


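print([k[2] for k in mlm_files])
total_prob = sum([k[2] for k in mlm_files])
# print each dataset's normalized sampling probability next to the cleaning-function key it is paired with (pairing is by position)
for entry, key in zip(mlm_files, mlm_streaming_cleaning_functions.keys()):
    print("%0.3f  %s ||| %s\n" % (entry[2]/total_prob, entry[0], key))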

data_streaming_config = {
    'files':mlm_files,
    'val_size':2000,
    'min_seq_length':48,
    'max_seq_length':512,
    'max_chunk_size':6,
    'train_chunk_size':6000,
    'max_chunk_start':1000000,
    "seed":42,
}

def chunk_docs_into_chunks_and_sentences(
    list_of_strings,
    nlp=None,
    config_chunking=None,
    seed = 42,
    fieldname='text'
):
    """Splits long docs into chunks that do not exceed max_seq_length, and into sentences for next-sentence prediction."""
    if nlp is None:
        nlp = spacy.load("en_core_web_sm")
        nlp.add_pipe("sentencizer")

    if config_chunking is None:
        config_chunking = {
            'max_seq_length':512,
            'min_seq_length':48,
            'max_chunk_size':6,
            'min_sentence_len':20,
            'seed':seed
        }
    else:
        config_chunking.update({'seed':seed})

    # initialize the example processor
    example_processor = ExampleProcessor(
        config=config_chunking, char_per_word = CHAR_PER_WORD, nlp =nlp
    )

    examples_static_chunks = [] # data for MLM objective
    examples_static_nextsentence = [] # data for next sentence task
    for example in list_of_strings:
        # chunk text into strings of at most max_seq_length, and into sentences
        if isinstance(example, str):
            examples_processed = example_processor(text=example)
        elif isinstance(example, dict):
            examples_processed = example_processor(text=example[fieldname])
        else:
            raise TypeError('expected a str or dict, got %s' % type(example).__name__)
        # chunks, accept/reject flag, sentences
        example_parsed, do_accept, parsed_sentences = examples_processed.values()
        if do_accept:
            # mlm gets the text-strings chunked to size 512
            examples_static_chunks.extend(example_parsed)
            if len(parsed_sentences)>4:
                # sentences for next sentence prediction: make triplet of s1,s2,opposite, where opposites get label=1
                examples_static_nextsentence.extend(
                    convert_sequence_into_nextsentence_pairs(parsed_sentences)
                )

    return examples_static_chunks, examples_static_nextsentence

def nwords_quick(text):
    return len([w for w in text.split(" ") if len(w)>0])

def initialize_and_get_mlm_streaming_datasets(
    data_streaming_config,
    streaming_cleaning_functions,
    start_proportion = None,
    epoch=0,
    seed=42,
    path_to_val_cache = 'cache_val_mlm.pkl',
    path_to_train_cache_epoch = 'cache_train_mlm_%03g.pkl',
    do_check_english = True
):
    """Converts stream of unlabelled text data into static datasets for: MLM task and next-sentence-prediction task"""
    # list of files to stream
    files = data_streaming_config['files']
    # number of examples to take from stream for validation set
    val_size = data_streaming_config['val_size']
    # number of examples to take from stream for training set
    train_chunk_size = data_streaming_config['train_chunk_size']
    min_seq_len = data_streaming_config['min_seq_length']
    # normalization constant for normalizing the weights into probabilities
    probability_normalization_const = sum([x[2] for x in files])

    # where to initialize start-stream for training data
    if start_proportion is None:
        start_proportion = np.random.RandomState(seed+epoch).uniform()*0.95

    # reload cached files
    path_to_train_cache = path_to_train_cache_epoch % epoch if '%03g' in path_to_train_cache_epoch else path_to_train_cache_epoch
    do_make_valset = not os.path.isfile(path_to_val_cache)
    do_make_trainset = not os.path.isfile(path_to_train_cache)
    if not do_make_valset:
        print('RELOADING VAL-MLM SET: iter=%s' % path_to_val_cache)
        with open(path_to_val_cache,'rb') as pcon:
            datalist_val_mlm_static = pickle.load(pcon)
            datalist_val_sentences_static = pickle.load(pcon)
        print('VAL-MLM SET SIZE: %d' % len(datalist_val_mlm_static))
    else:
        datalist_val_mlm_static, datalist_val_sentences_static = [],[]
    if not do_make_trainset:
        print('RELOADING TRAIN-MLM SET: iter=%s' % path_to_train_cache)
        with open(path_to_train_cache,'rb') as pcon:
            datalist_train_mlm_static = pickle.load(pcon)
            datalist_train_sentences_static = pickle.load(pcon)
        print('TRAIN-MLM EPOCH-%d SET SIZE: %d' % (epoch, len(datalist_train_mlm_static)))
    else:
        datalist_train_mlm_static, datalist_train_sentences_static = [],[]

    if (do_make_trainset or do_make_valset):

        # initialize the nlp-sentencizer for chunking
        nlp = spacy.load("en_core_web_sm")
        nlp.add_pipe("sentencizer")

        # loop through datasets
        for (mlm_nm, set_nm, prob, dataset_size, special_handling, partition_shuffle), dataset_key in zip(
            files, streaming_cleaning_functions.keys()
        ):
            if prob ==0:
                continue
            prob /= probability_normalization_const

            # get cleaning & filter functions for streaming data functionality
            clean_func, filter_func, removefeature_names = streaming_cleaning_functions[dataset_key]

            # set arguments for the load_dataset (huggingface repos)
            load_dataset_args = {
                'path':mlm_nm, 'name':set_nm, 'split':'train', 'streaming':True
            }
            # for other non-huggingface repos, path needs to be a "builder"
            if mlm_nm.endswith('.jsonl') or mlm_nm.endswith('.jsonl.zip') or mlm_nm.endswith('.jsonl.zst'):
                load_dataset_args.update({'path':'json','data_files':mlm_nm})

            # special processing of datasets with multiple partitions (files)
            if bool(partition_shuffle): # or str(epoch)=='val':

                n_files, n_per_file = partition_shuffle
                dataset_size = n_per_file
                print('trying %s initialization (shuffling through %d files)' % (mlm_nm, n_files))

                # whether there is a filter
                if filter_func is None:
                    dset_stream = load_dataset(**load_dataset_args)
                else:
                    dset_stream = load_dataset(**load_dataset_args).filter(filter_func)

                # validation set
                if do_make_valset:
                    # take from stream
                    n_valset_take = max(int(prob*val_size), 1)
                    print('take %d from %s validation'% (n_valset_take, mlm_nm))
                    dset_stream_val = dset_stream.take(n_valset_take).map(clean_func).remove_columns(removefeature_names)
                    # convert stream to a static set (and check english language)
                    dset_static_val_thisset =[
                        e['text'] for e in dset_stream_val
                        if bool(re.search(r"\w+",e['text'][:200])) and (nwords_quick(e['text'][:1000])>min_seq_len)
                    ]
                # training set
                if do_make_trainset:
                    # randomly skip a bunch from this set
                    skip_to_start = int(start_proportion*n_per_file)
                    take_from_this_set = max(int(round(train_chunk_size*prob)),1)
                    print('take %d from %s training'% (take_from_this_set, mlm_nm))
                    # shuffle: take a random data partition (from the dataset's list of files)
                    dset_stream_train = dset_stream.shuffle(
                        seed = seed+epoch, buffer_size = skip_to_start+take_from_this_set,
                    )
                    dset_stream_train = dset_stream_train.skip(
                        skip_to_start # random skip through dataset to new start position
                    ).take(
                        take_from_this_set # take this amount for the training set
                    ).map(clean_func).remove_columns(removefeature_names)
                    # convert training to static dataset
                    dset_static_train_thisset =[
                        e['text'] for e in dset_stream_train
                        if bool(re.search(r"\w+",e['text'][:200])) and (nwords_quick(e['text'][:1000])>min_seq_len)
                    ]
            else:
                # regular streaming
                print('trying %s initialization' % mlm_nm)
                # whether there is a filter
                if filter_func is None:
                    dset_stream = load_dataset(**load_dataset_args).map(clean_func).remove_columns(removefeature_names)
                else:
                    dset_stream = load_dataset(**load_dataset_args).filter(filter_func).map(clean_func).remove_columns(removefeature_names)
                # take from stream
                n_valset_take = max(int(prob*val_size), 1) # size of valset
                print('take %d from %s validation'% (n_valset_take, mlm_nm))
                skip_to_start = int(start_proportion*(dataset_size-n_valset_take)) # random point to skip to
                n_train_take = max(int(round(train_chunk_size*prob)),1) # size of train set
                print('take %d from %s train'% (n_train_take, mlm_nm))
                if do_make_valset:
                    dset_stream_val = dset_stream.take(n_valset_take)
                    dset_static_val_thisset = [
                        e['text'] for e in dset_stream_val
                        if bool(re.search(r"\w+",e['text'][:200])) and (nwords_quick(e['text'][:1100])>min_seq_len)
                    ]
                if do_make_trainset:
                    dset_stream_train = dset_stream.skip(n_valset_take+skip_to_start).take(n_train_take)
                    dset_static_train_thisset = [
                        e['text'] for e in dset_stream_train
                        if bool(re.search(r"\w+",e['text'][:200])) and (nwords_quick(e['text'][:1100])>min_seq_len)
                    ]
            print('Done getting streams/reloading from %s' % mlm_nm)
            # check language, chunk sentences
            if do_make_valset:
                # discard non-english (only if do_check_english)
                if do_check_english:
                    dset_static_val_thisset =[
                        e for e in dset_static_val_thisset
                        if detect(e[:200]+" hello")=='en'
                    ]
                    print('done val language check')
                # chunk the docs (512-tokens and next-sentence prediction sentences)
                dset_val_chunked_for_mlm, dset_val_nextsentence = chunk_docs_into_chunks_and_sentences(
                    list_of_strings=dset_static_val_thisset,
                    config_chunking=copy.deepcopy(data_streaming_config),
                    seed=seed+epoch,
                    nlp=nlp
                )
                print('done val longtext chunking')
                # add to val set
                datalist_val_mlm_static.extend(dset_val_chunked_for_mlm)
                datalist_val_sentences_static.extend(dset_val_nextsentence)

            # check language, chunk sentences
            if do_make_trainset:
                # discard non-english (only if do_check_english)
                if do_check_english:
                    dset_static_train_thisset =[
                        e for e in dset_static_train_thisset
                        if detect(e[:200] +" hello")=='en'
                    ]
                    print('done train language check')
                # chunk the docs (512-tokens and next-sentence prediction sentences)
                dset_train_chunked_for_mlm, dset_train_nextsentence = chunk_docs_into_chunks_and_sentences(
                    list_of_strings=dset_static_train_thisset,
                    config_chunking=copy.deepcopy(data_streaming_config),
                    seed=seed+epoch,
                    nlp=nlp
                )
                print('done trains longtext chunking')

                # ensure that none of the examples in the training set are in the validation set
                # (compared against the full validation set, so this also works when it was reloaded from cache)
                val_chunk_set = set(datalist_val_mlm_static)
                val_anchor_set = set(vtlt['anchor'] for vtlt in datalist_val_sentences_static)
                dset_train_chunked_for_mlm = [
                    s for s in dset_train_chunked_for_mlm
                    if s not in val_chunk_set
                ]
                dset_train_nextsentence = [
                    tlt for tlt in dset_train_nextsentence
                    if tlt['anchor'] not in val_anchor_set
                ]

                # add to training set
                datalist_train_mlm_static.extend(dset_train_chunked_for_mlm)
                datalist_train_sentences_static.extend(dset_train_nextsentence)

        print('Done collecting streaming data')

    if do_make_valset:
        print('saving streamed validation data: %s' % path_to_val_cache)
        with open(path_to_val_cache,'wb') as pcon:
            pickle.dump(datalist_val_mlm_static, pcon)
            pickle.dump(datalist_val_sentences_static, pcon)
    if do_make_trainset:
        print('saving streamed training for epoch %d: %s' % (epoch, path_to_train_cache))
        with open(path_to_train_cache,'wb') as pcon:
            pickle.dump(datalist_train_mlm_static, pcon)
            pickle.dump(datalist_train_sentences_static, pcon)

    return {
        'train':{
            'mlm':datalist_train_mlm_static,
            'nextsentence':datalist_train_sentences_static
        },
        'val':{
            'mlm':datalist_val_mlm_static,
            'nextsentence':datalist_val_sentences_static
        },
        'epoch':epoch,
        'index_stream':start_proportion
    }
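To make the Hacker News cleaning concrete, here is a self-contained restatement of the same steps (normalize the `------`/`======`/`~~~` separators, drop the first line of each block, remove only the first URL) applied to a made-up thread. The names `remove_first_url_sketch` and `parse_hn_sketch` are hypothetical, chosen so they do not shadow the notebook's functions, and the thread text is invented:

```python
import re


def remove_first_url_sketch(text):
    """Remove only the first http(s) URL (up to the next space)."""
    return re.sub(r"https?://[^ ]+", "", text, count=1)


def parse_hn_sketch(text):
    """Same steps as parse_hacker_news, written out one per line."""
    # normalize the three separator styles to a single one
    text = text.replace("------\n", "~~~\n").replace("======\n", "~~~\n")
    blocks = text.split("~~~\n")
    # drop the first line of each block and flatten the remaining lines
    flattened = [" ".join(block.split("\n")[1:]) for block in blocks]
    return remove_first_url_sketch(" ".join(flattened))


thread = ("Ask HN: Best editor?\n======\n"
          "alice\nI use vim, see https://www.vim.org for docs\n~~~\n"
          "bob\nEmacs\n")
cleaned = parse_hn_sketch(thread)
print(" ".join(cleaned.split()))  # I use vim, see for docs Emacs
```

Note that the first line of every block is discarded, including the thread's opening line, and that any URL after the first one is kept.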

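Both the weight printout and the main loop pair `mlm_files` with `mlm_streaming_cleaning_functions` purely by position (`zip(files, streaming_cleaning_functions.keys())`), so inserting or reordering one entry silently attaches the wrong cleaning function to a dataset. A cheap guard, assuming the naming convention visible in the dictionary above (each key equals the dataset path, or the path plus `"/" + subset`); `check_alignment` is a hypothetical helper, and the toy entries below are trimmed copies of real ones:

```python
def check_alignment(files, cleaning_functions):
    """Hypothetical guard: fail fast if the positional zip pairs mismatched entries."""
    if len(files) != len(cleaning_functions):
        raise ValueError("%d files but %d cleaning entries"
                         % (len(files), len(cleaning_functions)))
    for entry, key in zip(files, cleaning_functions):
        path, subset = entry[0], entry[1]
        # allowed key names: the bare path, or path/subset when a subset is given
        allowed = {path} if subset is None else {path, "%s/%s" % (path, subset)}
        if key not in allowed:
            raise ValueError("entry %r is paired with cleaning key %r" % (path, key))
    return True


files = [("pile-of-law/pile-of-law", "euro_parl", 2.0),
         ("Cohere/wikipedia-22-12", "en", 14.40)]
funcs = {"pile-of-law/pile-of-law/euro_parl": None, "Cohere/wikipedia-22-12": None}
print(check_alignment(files, funcs))  # True
```

With the current entries, `check_alignment(mlm_files, mlm_streaming_cleaning_functions)` should return `True`; calling it once before streaming turns a silent mispairing into an immediate error.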

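Before language detection, every streamed document passes the same two cheap checks: at least one word character in the first 200 characters, and more than `min_seq_length` space-separated words in a prefix of the text (1000 characters in the partitioned branch, 1100 in the regular one). A standalone sketch of that predicate; `passes_quick_filter` and `count_words_quick` are hypothetical names, the latter restating `nwords_quick` so the block runs on its own:

```python
import re


def count_words_quick(text):
    # count non-empty tokens separated by single spaces (same logic as nwords_quick)
    return len([w for w in text.split(" ") if len(w) > 0])


def passes_quick_filter(text, min_seq_len=48, prefix_chars=1000):
    """Hypothetical predicate mirroring the list-comprehension filters above."""
    has_word_char = bool(re.search(r"\w+", text[:200]))
    return has_word_char and count_words_quick(text[:prefix_chars]) > min_seq_len


short = "too short"
long_doc = " ".join(["word"] * 60)
symbols = "--- === *** " * 20
print(passes_quick_filter(short), passes_quick_filter(long_doc),
      passes_quick_filter(symbols))  # False True False
```

Because the count splits on `" "` only, words separated by newlines count as a single token (`"a\nb"` gives 1), so line-broken text is undercounted relative to `str.split()`.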

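The train/validation leakage check tests each training chunk and each next-sentence anchor for membership in the validation data. Against a list, every lookup compares with every validation item; converting the validation side to a set makes each lookup constant-time on average. A minimal sketch, assuming chunks are strings and next-sentence items are dicts with a hashable `'anchor'` value as in the function above; `drop_val_overlap` is a hypothetical helper:

```python
def drop_val_overlap(train_chunks, train_triplets, val_chunks, val_triplets):
    """Hypothetical helper: remove training items that also appear in validation."""
    val_chunk_set = set(val_chunks)
    val_anchor_set = {t["anchor"] for t in val_triplets}
    kept_chunks = [s for s in train_chunks if s not in val_chunk_set]
    kept_triplets = [t for t in train_triplets if t["anchor"] not in val_anchor_set]
    return kept_chunks, kept_triplets


chunks, triplets = drop_val_overlap(
    train_chunks=["a", "b", "c"],
    train_triplets=[{"anchor": "x"}, {"anchor": "y"}],
    val_chunks=["b"],
    val_triplets=[{"anchor": "y"}],
)
print(chunks, triplets)  # ['a', 'c'] [{'anchor': 'x'}]
```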
In [None]:
data_streaming_config_mlm = {
    'files':mlm_files,
    'val_size':200, #2000,
    'min_seq_length':48,
    'max_seq_length':512,
    'max_chunk_size':6,
    'train_chunk_size':400, #6000,
    'max_chunk_start':1000000,
    "seed":42,
}

dataset_static_mlm = initialize_and_get_mlm_streaming_datasets(
    data_streaming_config=data_streaming_config_mlm,
    streaming_cleaning_functions=mlm_streaming_cleaning_functions,
    start_proportion = None,
    epoch=0,
    seed=42,
    path_to_val_cache = 'cache_val_mlm.pkl',
    path_to_train_cache_epoch = 'cache_train_mlm_%03g.pkl',
    do_check_english = True
)

trying EleutherAI/pile initialization (shuffling through 30 files)


Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/EleutherAI--pile/ebea56d358e91cf4d37b0fde361d563bed1472fbd8221a21b38fc8bb4ba554fb


take 40 from EleutherAI/pile validation
take 81 from EleutherAI/pile training
Done getting streams/reloading from EleutherAI/pile
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying tiiuae/falcon-refinedweb initialization (shuffling through 5534 files)



Using custom data configuration default-33459d9d641983a0


take 40 from tiiuae/falcon-refinedweb validation
take 81 from tiiuae/falcon-refinedweb training
Done getting streams/reloading from tiiuae/falcon-refinedweb
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying Cohere/wikipedia-22-12 initialization (shuffling through 351 files)


https://huggingface.co/datasets/Cohere/wikipedia-22-12/resolve/main/wikipedia-22-12.py not found in cache or force_download set to True, downloading to /root/.cache/huggingface/datasets/downloads/180c6137cd9094d23825845c08f6900920a553ac7e54194a637614a5560de97a.68c8a9f3e2f910fbcfa2c448b2952d98f3fcb27851f1eaa458912bb48ed30fdd.py.incomplete



storing https://huggingface.co/datasets/Cohere/wikipedia-22-12/resolve/main/wikipedia-22-12.py in cache at /root/.cache/huggingface/datasets/downloads/180c6137cd9094d23825845c08f6900920a553ac7e54194a637614a5560de97a.68c8a9f3e2f910fbcfa2c448b2952d98f3fcb27851f1eaa458912bb48ed30fdd.py
creating metadata file for /root/.cache/huggingface/datasets/downloads/180c6137cd9094d23825845c08f6900920a553ac7e54194a637614a5560de97a.68c8a9f3e2f910fbcfa2c448b2952d98f3fcb27851f1eaa458912bb48ed30fdd.py


storing https://huggingface.co/datasets/Cohere/wikipedia-22-12/resolve/main/README.md in cache at /root/.cache/huggingface/datasets/downloads/be2fa19ec21cc54472c197e371cb0104c1c77b4ceb374848b3728232bbf250b2.bbc2711fab962231ef931afb17ad75612451b0ed6bc8a0c4eb49ef9597bec483
creating metadata file for /root/.cache/huggingface/datasets/downloads/be2fa19ec21cc54472c197e371cb0104c1c77b4ceb374848b3728232bbf250b2.bbc2711fab962231ef931afb17ad75612451b0ed6bc8a0c4eb49ef9597bec483

take 32 from Cohere/wikipedia-22-12 validation
take 64 from Cohere/wikipedia-22-12 training
Done getting streams/reloading from Cohere/wikipedia-22-12
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying Multi-Domain-Expert-Layers/the_pile_books3_packed_128k initialization (shuffling through 15 files)


https://huggingface.co/datasets/Multi-Domain-Expert-Layers/the_pile_books3_packed_128k/resolve/8383ca80b6c70bf4f2c4f1d1aade7c3547d5515d/README.md not found in cache or force_download set to True, downloading to /root/.cache/huggingface/datasets/downloads/6b1f219e0080bb3df6e4a2e774919dace5e7b739a3ab0a14ad08caeefe55648d.f87f0d6639b6b775e27dd780337041f798bcec78baac4a2465f8dba7e6f26a37.incomplete



storing https://huggingface.co/datasets/Multi-Domain-Expert-Layers/the_pile_books3_packed_128k/resolve/8383ca80b6c70bf4f2c4f1d1aade7c3547d5515d/README.md in cache at /root/.cache/huggingface/datasets/downloads/6b1f219e0080bb3df6e4a2e774919dace5e7b739a3ab0a14ad08caeefe55648d.f87f0d6639b6b775e27dd780337041f798bcec78baac4a2465f8dba7e6f26a37
creating metadata file for /root/.cache/huggingface/datasets/downloads/6b1f219e0080bb3df6e4a2e774919dace5e7b739a3ab0a14ad08caeefe55648d.f87f0d6639b6b775e27dd780337041f798bcec78baac4a2465f8dba7e6f26a37

take 26 from Multi-Domain-Expert-Layers/the_pile_books3_packed_128k validation
take 54 from Multi-Domain-Expert-Layers/the_pile_books3_packed_128k training
Done getting streams/reloading from Multi-Domain-Expert-Layers/the_pile_books3_packed_128k
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying conceptofmind/pile_uspto_backgrounds initialization (shuffling through 139 files)


https://huggingface.co/datasets/conceptofmind/pile_uspto_backgrounds/resolve/e0f63b46cd575a4a979ee781d2fdc18b71e942de/README.md not found in cache or force_download set to True, downloading to /root/.cache/huggingface/datasets/downloads/5465db393b952a7915729478c4d895cc516c99b62f5012dbf5bba5b27bafba92.b15a39e499c2f40936f1c9f678cef5392d0f3a081478242bf2c46dbba34381fe.incomplete
storing https://huggingface.co/datasets/conceptofmind/pile_uspto_backgrounds/resolve/e0f63b46cd575a4a979ee781d2fdc18b71e942de/README.md in cache at /root/.cache/huggingface/datasets/downloads/5465db393b952a7915729478c4d895cc516c99b62f5012dbf5bba5b27bafba92.b15a39e499c2f40936f1c9f678cef5392d0f3a081478242bf2c46dbba34381fe
creating metadata file for /root/.cache/huggingface/datasets/downloads/5465db393b952a7915729478c4d895cc516c99b62f5012dbf5bba5b27bafba92.b15a39e499c2f40936f1c9f678cef5392d0f3a081478242bf2c46dbba34381fe
https://huggingface.co/datasets/conceptofmind/pile_uspto_backgrounds/resolve/main/dataset_infos.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/datasets/downloads/4ba1e4f13436ce6a22f3ba1c0270b392d2752ef2e7310085b4f9e5655aeb2087.11834c1546e7ac126879d91e75fb7f6e448df6371f1cadb7a21fd1667aedd198.incomplete
storing https://huggingface.co/datasets/conceptofmind/pile_uspto_backgrounds/resolve/main/dataset_infos.json in cache at /root/.cache/huggingface/datasets/downloads/4ba1e4f13436ce6a22f3ba1c0270b392d2752ef2e7310085b4f9e5655aeb2087.11834c1546e7ac126879d91e75fb7f6e448df6371f1cadb7a21fd1667aedd198
creating metadata file for /root/.cache/huggingface/datasets/downloads/4ba1e4f13436ce6a22f3ba1c0270b392d2752ef2e7310085b4f9e5655aeb2087.11834c1546e7ac126879d91e75fb7f6e448df6371f1cadb7a21fd1667aedd198

take 6 from conceptofmind/pile_uspto_backgrounds validation
take 13 from conceptofmind/pile_uspto_backgrounds training
Done getting streams/reloading from conceptofmind/pile_uspto_backgrounds
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying pile-of-law/pile-of-law initialization


Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/pile-of-law--pile-of-law/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60


take 4 from pile-of-law/pile-of-law validation
take 9 from pile-of-law/pile-of-law train
Error reading file: https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.euro_parl.jsonl.xz
Error reading file: https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.euro_parl.jsonl.xz
Done getting streams/reloading from pile-of-law/pile-of-law
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying conceptofmind/pile_hacker_news initialization (shuffling through 20 files)


https://huggingface.co/datasets/conceptofmind/pile_hacker_news/resolve/7051165c182ce2740056a6a446b8e035b1504173/README.md not found in cache or force_download set to True, downloading to /root/.cache/huggingface/datasets/downloads/d53d88f53a0562ce380fa3ae041306d7a5a354dc29076c6639e2a6b8cdc87717.1804517b49993b0d00718832783d06bca3eab878fa82ea73775c686c050564ca.incomplete
storing https://huggingface.co/datasets/conceptofmind/pile_hacker_news/resolve/7051165c182ce2740056a6a446b8e035b1504173/README.md in cache at /root/.cache/huggingface/datasets/downloads/d53d88f53a0562ce380fa3ae041306d7a5a354dc29076c6639e2a6b8cdc87717.1804517b49993b0d00718832783d06bca3eab878fa82ea73775c686c050564ca
creating metadata file for /root/.cache/huggingface/datasets/downloads/d53d88f53a0562ce380fa3ae041306d7a5a354dc29076c6639e2a6b8cdc87717.1804517b49993b0d00718832783d06bca3eab878fa82ea73775c686c050564ca
https://huggingface.co/datasets/conceptofmind/pile_hacker_news/resolve/main/dataset_infos.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/datasets/downloads/84a14297df0a60fecfc2018c3071696af7efd080438a29bb86756649db5966c9.aef2f1447a1ac3c8ddf91619dea8e585a4021a283d18aec1a02b243891d6bb86.incomplete
storing https://huggingface.co/datasets/conceptofmind/pile_hacker_news/resolve/main/dataset_infos.json in cache at /root/.cache/huggingface/datasets/downloads/84a14297df0a60fecfc2018c3071696af7efd080438a29bb86756649db5966c9.aef2f1447a1ac3c8ddf91619dea8e585a4021a283d18aec1a02b243891d6bb86
creating metadata file for /root/.cache/huggingface/datasets/downloads/84a14297df0a60fecfc2018c3071696af7efd080438a29bb86756649db5966c9.aef2f1447a1ac3c8ddf91619dea8e585a4021a283d18aec1a02b243891d6bb86

take 4 from conceptofmind/pile_hacker_news validation
take 9 from conceptofmind/pile_hacker_news training
Done getting streams/reloading from conceptofmind/pile_hacker_news
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying https://the-eye.eu/public/AI/pile_v2/data/NIH_ExPORTER_awarded_grant_text.jsonl.zst initialization


Using custom data configuration default-7266146d08ae7164
Loading Dataset Infos from /usr/local/lib/python3.10/dist-packages/datasets/packaged_modules/json


take 2 from https://the-eye.eu/public/AI/pile_v2/data/NIH_ExPORTER_awarded_grant_text.jsonl.zst validation
take 4 from https://the-eye.eu/public/AI/pile_v2/data/NIH_ExPORTER_awarded_grant_text.jsonl.zst train
Done getting streams/reloading from https://the-eye.eu/public/AI/pile_v2/data/NIH_ExPORTER_awarded_grant_text.jsonl.zst
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying https://drive.switch.ch/index.php/s/j9S0GRMAbGZKa1A/download?path=%2F&files=LEDGAR_2016-2019.jsonl.zip initialization


Using custom data configuration default-62fbd92a495ab431
Loading Dataset Infos from /usr/local/lib/python3.10/dist-packages/datasets/packaged_modules/json


take 11 from https://drive.switch.ch/index.php/s/j9S0GRMAbGZKa1A/download?path=%2F&files=LEDGAR_2016-2019.jsonl.zip validation
take 22 from https://drive.switch.ch/index.php/s/j9S0GRMAbGZKa1A/download?path=%2F&files=LEDGAR_2016-2019.jsonl.zip train
Done getting streams/reloading from https://drive.switch.ch/index.php/s/j9S0GRMAbGZKa1A/download?path=%2F&files=LEDGAR_2016-2019.jsonl.zip
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying pile-of-law/pile-of-law initialization


Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/pile-of-law--pile-of-law/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60


take 2 from pile-of-law/pile-of-law validation
take 4 from pile-of-law/pile-of-law train
Error reading file: https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.r_legaldvice.jsonl.xz
Error reading file: https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.r_legaldvice.jsonl.xz
Done getting streams/reloading from pile-of-law/pile-of-law
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying pile-of-law/pile-of-law initialization


Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/pile-of-law--pile-of-law/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60


take 1 from pile-of-law/pile-of-law validation
take 1 from pile-of-law/pile-of-law train
Error reading file: https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.examoutlines.jsonl.xz
Error reading file: https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.examoutlines.jsonl.xz
Done getting streams/reloading from pile-of-law/pile-of-law
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying pile-of-law/pile-of-law initialization


Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/pile-of-law--pile-of-law/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60


take 1 from pile-of-law/pile-of-law validation
take 2 from pile-of-law/pile-of-law train
Error reading file: https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.cc_casebooks.jsonl.xz
Error reading file: https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.cc_casebooks.jsonl.xz
Done getting streams/reloading from pile-of-law/pile-of-law
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying eloukas/edgar-corpus initialization (shuffling through 28 files)


Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/eloukas--edgar-corpus/c2f9ada1db31915d6af4cc19f0ad9486cd0bab93c5c26bb32850e5a1f74f2bd7


take 11 from eloukas/edgar-corpus validation
take 22 from eloukas/edgar-corpus training
Done getting streams/reloading from eloukas/edgar-corpus
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying Rahmaa/ElsevieR_ClEaN initialization


Using custom data configuration default-20220b2cf5c24f3d
Loading Dataset Infos from /usr/local/lib/python3.10/dist-packages/datasets/packaged_modules/csv


take 4 from Rahmaa/ElsevieR_ClEaN validation
take 9 from Rahmaa/ElsevieR_ClEaN train
Done getting streams/reloading from Rahmaa/ElsevieR_ClEaN
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying ashraq/financial-news-articles initialization (shuffling through 2 files)


Using custom data configuration default-1b76e4a10823adc2


take 2 from ashraq/financial-news-articles validation
take 4 from ashraq/financial-news-articles training
Done getting streams/reloading from ashraq/financial-news-articles
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying pile-of-law/pile-of-law initialization (shuffling through 16 files)


Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/pile-of-law--pile-of-law/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60


take 6 from pile-of-law/pile-of-law validation
Error reading file: https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.courtlisteneropinions.0.jsonl.xz
Error reading file: https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.courtlisteneropinions.1.jsonl.xz


Exception ignored in: <generator object PileOfLaw._generate_examples at 0x7fa601563990>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/datasets/iterable_dataset.py", line 207, in __iter__
    yield from self.generate_examples_fn(**self.kwargs)
RuntimeError: generator ignored GeneratorExit
Exception ignored in: <generator object ExamplesIterable.__iter__ at 0x7fa601563680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/datasets/iterable_dataset.py", line 207, in __iter__
    yield from self.generate_examples_fn(**self.kwargs)
RuntimeError: generator ignored GeneratorExit


take 13 from pile-of-law/pile-of-law training
Error reading file: https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.courtlisteneropinions.0.jsonl.xz
Error reading file: https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.courtlisteneropinions.1.jsonl.xz


Exception ignored in: <generator object PileOfLaw._generate_examples at 0x7fa601562dc0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/datasets/iterable_dataset.py", line 207, in __iter__
    yield from self.generate_examples_fn(**self.kwargs)
RuntimeError: generator ignored GeneratorExit
Exception ignored in: <generator object ExamplesIterable.__iter__ at 0x7fa601562e30>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/datasets/iterable_dataset.py", line 207, in __iter__
    yield from self.generate_examples_fn(**self.kwargs)
RuntimeError: generator ignored GeneratorExit


Done getting streams/reloading from pile-of-law/pile-of-law
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying pile-of-law/pile-of-law initialization


Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/pile-of-law--pile-of-law/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60


take 2 from pile-of-law/pile-of-law validation
take 4 from pile-of-law/pile-of-law train
Error reading file: https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.sec.jsonl.xz
Error reading file: https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.sec.jsonl.xz
Done getting streams/reloading from pile-of-law/pile-of-law
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
trying pile-of-law/pile-of-law initialization


Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/pile-of-law--pile-of-law/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60


take 1 from pile-of-law/pile-of-law validation
take 2 from pile-of-law/pile-of-law train
Error reading file: https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.irs_legal_advice_memos.jsonl.xz
Error reading file: https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.irs_legal_advice_memos.jsonl.xz
Done getting streams/reloading from pile-of-law/pile-of-law
done val language check
done val longtext chunking
done train language check
done trains longtext chunking
Done collecting streaming data
saving streamed validation data: cache_val_mlm.pkl
saving streamed training for epoch 0: cache_train_mlm_000.pkl
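
The log above ends by pickling the streamed examples: validation once (`cache_val_mlm.pkl`) and training once per epoch (`cache_train_mlm_000.pkl`). Below is a minimal sketch of that caching pattern, assuming plain `pickle` and a zero-padded epoch index in the filename. The helpers `train_cache_path`, `save_cache` and `load_cache` are hypothetical names, not functions defined in this notebook.

```python
import os
import pickle
import tempfile
from typing import Any, List

VAL_CACHE = "cache_val_mlm.pkl"  # validation cache name, as printed in the log


def train_cache_path(epoch: int, prefix: str = "cache_train_mlm") -> str:
    # hypothetical naming scheme matching the log: epoch 0 -> cache_train_mlm_000.pkl
    return f"{prefix}_{epoch:03d}.pkl"


def save_cache(examples: List[Any], path: str) -> None:
    # write to a temp file in the same directory, then atomically rename it,
    # so an interrupted save cannot leave a truncated cache file at `path`
    directory = os.path.dirname(os.path.abspath(path))
    with tempfile.NamedTemporaryFile("wb", dir=directory, delete=False) as tmp:
        pickle.dump(examples, tmp)
    os.replace(tmp.name, path)


def load_cache(path: str) -> List[Any]:
    with open(path, "rb") as f:
        return pickle.load(f)
```

On a later run, one could check `os.path.exists(train_cache_path(epoch))` and reload the cached examples instead of re-streaming them from the hubs.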


In [None]:
!rm *.pkl
np.random.choice(dataset_static_mlm['train']['nextsentence'])#[-2]

rm: cannot remove '*.pkl': No such file or directory


{'anchor': 'The next item is the report by Marit Paulsen, on behalf of the Committee on Agriculture and Rural Development, on the evaluation and assessment of the Animal Welfare Action Plan 2006-2010.',
 'next': 'Mr President, animal welfare is, in fact, something that most of the citizens of Europe care about. Animal welfare is not just about animals.',
 'opposite': "This obviously envisages both the allocation of appropriate funds to this area and the efficient use of the support opportunities offered by them so that Member States can invest in modern, innovative solutions intended for the benefit of the animals' welfare."}
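
Each `nextsentence` record pairs an `anchor` sentence with the `next` sentence and an `opposite` sentence. The notebook's own construction of these records is not shown in this section. The following is a hypothetical sketch, assuming the `opposite` is drawn from the same document at least `min_gap` sentences away from the anchor; `make_nextsentence_examples` and `min_gap` are illustrative names only.

```python
import random
from typing import Dict, List, Optional


def make_nextsentence_examples(
    sentences: List[str], min_gap: int = 3, seed: Optional[int] = None
) -> List[Dict[str, str]]:
    """Build {'anchor', 'next', 'opposite'} records from one document's sentences.

    'next' is the sentence right after the anchor; 'opposite' is a sentence at
    least `min_gap` positions away (a hypothetical choice of negative).
    """
    if min_gap < 2:
        raise ValueError("min_gap must be >= 2 so 'opposite' can never be 'next'")
    rng = random.Random(seed)
    examples = []
    for i in range(len(sentences) - 1):
        # candidate negatives: sentences far enough from the anchor in either direction
        far = [j for j in range(len(sentences)) if abs(j - i) >= min_gap]
        if not far:
            continue  # document too short to supply a distant sentence
        examples.append({
            "anchor": sentences[i],
            "next": sentences[i + 1],
            "opposite": sentences[rng.choice(far)],
        })
    return examples
```

Drawing the negative from the same document keeps the topic similar, which makes the next-sentence decision harder than with a random sentence from another corpus.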

In [None]:
%%time
# no longer works: the_pile_books3; maybe use SaylorTwift/the_pile_books3_minus_gutenberg
# the_pile_stack_exchange: is down; maybe use donfu/oa-stackexchange (but it is badly sorted) or teven/stackexchange (but it has other languages)
# openwebtext2: vietgpt/the_pile_openwebtext2
# 'EleutherAI/pile','pubmed_central' dead
# 'EleutherAI/pile','free_law' fails
# 'EleutherAI/pile','nih_exporter', dead
# 'EleutherAI/pile','hacker_news', dead
#foo = load_dataset('json',data_files = "https://the-eye.eu/public/AI/pile_neox/data/PhilPapersDataset_text_document.bin",split='train') # fails

# still works: "EleutherAI/pile" with the 'all' config
#foo = load_dataset('json',data_files= "https://the-eye.eu/public/AI/pile_v2/data/NIH_ExPORTER_awarded_grant_text.jsonl.zst", split="train", streaming=True)
#foo = load_dataset('text',data_files = "https://the-eye.eu/public/AI/pile_neox/data/PhilPapersDataset_text_document.bin",split='train', encoding='latin-1',streaming=True)
#foo = load_dataset('the_pile_books3',split='train',streaming=True) # fails
#foo = load_dataset('bookcorpusopen',split='train',streaming=True) # fails: FileNotFoundError: https://the-eye.eu/public/AI/pile_preliminary_components/books1.tar.gz
#foo =  load_dataset('hieule/vie-book-v2',split='train',streaming=True) # fails
#foo =  load_dataset('Cohere/wikipedia-22-12',split='train',streaming=True) # not working recently, but it could be a temporary or HF thing
#foo =  load_dataset('Multi-Domain-Expert-Layers/the_pile_books3_packed_128k',split='train',streaming=True) # not working recently, but it could be a temporary or HF thing
#foo = load_dataset('pile-of-law/pile-of-law','courtlistener_opinions',split='train',streaming=False) # works, but MASSIVE
#foo = load_dataset('ArmelR/sharded-pile',split='train',streaming=True).filter(lambda x : x['domain']=='PhilPapers') # no, just keeps spinning :(
#foo = load_dataset('pile-of-law/pile-of-law','tax_rulings',split='train',streaming=False) # this text is ugly: no paragraph breaks; replace _ with -
#foo = load_dataset('pile-of-law/pile-of-law','sec_administrative_proceedings',split='train',streaming=False) # split with .split("I.\n")[1:]
#foo = load_dataset('pile-of-law/pile-of-law','irs_legal_advice_memos',split='train',streaming=False) # works, but MASSIVE
#foo =load_dataset("pile-of-law/pile-of-law",'euro_parl', split='train',streaming=False)
#foo =load_dataset("json",data_files='https://the-eye.eu/public/AI/pile/train/00.jsonl.zst',split='train',streaming=True)
#print(len(foo))



# option 1: shuffle, then just skip ahead (e.g. 60k examples)
# CONCLUSION: I think the best option is to shuffle, then deliberately keep track of (mark down) the total number of samples consumed, i.e., only load within a partition
skip=2000
take=20
#bar = load_dataset('EleutherAI/pile','all', split='train',streaming=True).shuffle(skip+take).skip(skip).take(take)
#bar = load_dataset("Cohere/wikipedia-22-12",'en', split='train',streaming=True).shuffle(skip+take).skip(skip).take(take)
#bar = load_dataset("Multi-Domain-Expert-Layers/the_pile_books3_packed_128k", split='train',streaming=True).shuffle(skip+take).skip(skip).take(take)
# NB: the first positional argument of IterableDataset.shuffle is `seed` (buffer_size defaults to 1000),
# so .shuffle(skip+take) only sets the seed; it is written as a keyword below to make that explicit
bar = load_dataset('eloukas/edgar-corpus', 'full', split='train',streaming=True).shuffle(seed=skip+take).skip(skip).take(take)
for i,e in enumerate(bar):
    print((detect(e['section_1'][:200]), e['section_1'][:100].replace("\n","")))
    if (i+1)<take:
        continue
    break

#
#CPU times: user 6.26 s, sys: 392 ms, total: 6.66 s (shuffle) (pile)
#Wall time: 26.6 s (shuffle) (pile)
#CPU times: user 6.2 s, sys: 329 ms, total: 6.53 s (no shuffle) (pile)
#Wall time: 32.2 s (no shuffle) (pile)

#CPU times: user 1.4 s, sys: 38.6 ms, total: 1.43 s (shuffle cohere)
#Wall time: 7.87 s (shuffle cohere)
#CPU times: user 2.21 s, sys: 43 ms, total: 2.25 s (no shuffle cohere)
#Wall time: 11.1 s (no shuffle cohere)

#CPU times: user 16.5 s, sys: 1.71 s, total: 18.2 s (shuffle book3)
#Wall time: 55.3 s (shuffle book3)
#CPU times: user 13.9 s, sys: 1.94 s, total: 15.8 s (no shuffle book3)
#Wall time: 2min 58s (no shuffle book3)

#CPU times: user 4.59/5.01 s, sys: 245 ms, total: 4.83 s (shuffle edgar)
#Wall time: 24.6/21.7 s (shuffle edgar)
#CPU times: user 3.7 s, sys: 171 ms, total: 3.87 s (no shuffle edgar)
#Wall time: 17.6 s (no shuffle edgar)


('en', 'ITEM 1. BUSINESS.IntroductionUAL Corporation ("UAL" or the "Company") was incorporated under the l')
('en', 'ITEM 1. BUSINESSGENERAL - ------- Health Care REIT, Inc. (the "Company"), founded in 1970, is a rea')
('en', 'Item 1. Business.Rockefeller Center Properties, Inc. (the "Company") was incorporated in Delaware o')
('en', 'ITEM 1. BUSINESSERLY Industries Inc., (the "Company" or "ERLY"), incorporated in California in 1964')
('en', 'Item 1. Business. - ------ --------General DescriptionThe St. Paul Companies, Inc. ("The St. Paul"')
('en', 'Item 1. Description of BusinessAmstar Corporation (the "Company" or "Amstar") is a privately held D')
('en', 'Item 1. Business.GENERALPacific Bell (the "Company") was incorporated in 1906 under the laws of th')
('en', 'ITEM 1. BUSINESSTHE PARTNERSHIP. Jones Cable Income Fund 1-A, Ltd. (the "Partnership") is a Colorad')
('en', 'Item 1. Business - ------- --------General - -------The Stride Rite Corporation is the leading mar')
('en', "ITEM 1
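
The cell above streams a corpus, shuffles it, then reads only a `.skip(skip).take(take)` window, and concludes that the safest approach is to shuffle and only read within a counted partition. For streaming data, `datasets` shuffles approximately: it fills a buffer of `buffer_size` examples and yields a random one from the buffer as each new example arrives (it also shuffles shard order, which this sketch omits). The pure-Python sketch below mimics that buffer and carves disjoint validation and training partitions out of one shuffled pass. `buffered_shuffle` and `partition_stream` are hypothetical helpers, not the library's implementation.

```python
import itertools
import random
from typing import Iterable, Iterator, List, Tuple, TypeVar

T = TypeVar("T")


def buffered_shuffle(stream: Iterable[T], buffer_size: int, seed: int) -> Iterator[T]:
    """Approximate shuffle: each incoming example evicts a random buffer slot."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    rng = random.Random(seed)
    buffer: List[T] = []
    for example in stream:
        if len(buffer) < buffer_size:
            buffer.append(example)  # still filling the buffer
            continue
        i = rng.randrange(buffer_size)
        yield buffer[i]             # emit a random buffered example
        buffer[i] = example         # and replace it with the new one
    rng.shuffle(buffer)             # drain the remainder in random order
    yield from buffer


def partition_stream(
    stream: Iterable[T], n_val: int, n_train: int, buffer_size: int, seed: int
) -> Tuple[List[T], List[T]]:
    """Take validation first, then training, from the same shuffled iterator,
    so the two partitions can never overlap."""
    shuffled = buffered_shuffle(stream, buffer_size, seed)
    val = list(itertools.islice(shuffled, n_val))
    train = list(itertools.islice(shuffled, n_train))
    return val, train
```

With a small buffer, examples stay close to their original position, which is one reason a larger `buffer_size` (or a large skip) matters for sources whose files are sorted, such as the badly sorted stackexchange dump noted above.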

### Q&A Triplets!

Here I build a triplet dataset of queries, positive answers, and negatives (where available).

B) QA Tasks
- squad_2
- WikiHow - used by S-BERT (questions and articles) - needs to be manually downloaded - https://github.com/mahnazkoupaee/WikiHow-Dataset/
- trivia_qa - 680 question-answer-evidence triplets, but the context strings are very long (Wikipedia-like) and the questions are mostly pop culture
- LLukas22/fiqa - financial QA, like conversations
- embedding-data/WikiAnswers - question-duplicates as paraphrases
- embedding-data/QQP_triplets - question-duplicates plus negatives (Quora)
- LLukas22/lfqa_preprocessed - question and answers 226k (from REDDIT)
- DONE gbharti/finance-alpaca (like FIQA - finance Q&A) on 14k?
- DONE embedding-data/PAQ_pairs - wikipedia question & answers
- GONE the_pile_stack_exchange - single texts, but can be split into question, answer
- DONE donfu/oa-stackexchange - 6.3 million (AND GROWING -- must monitor)
- cais/mmlu - multiple choice, but some of the answers are longer (need to filter)
- DONE sciq - science questions - see question and support
- DONE wiki_qa - wikipedia QA
- qasc - high-school questions - can combine the "facts" into a support
- pubmed_qa - science QA with answers
- DONE JoBeer/eclassTrainST - can easily convert into question-answer pairs
- dictionary -
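
The cleaning functions in this notebook emit one record per source row with the keys `query`, `positives`, `negatives` and `type`. For sources like `wiki_qa`, each row holds a single candidate answer with a 0/1 label, so several rows share one question. A hypothetical `merge_triplets` helper (not part of the notebook) could fold those rows into one triplet per query:

```python
from typing import Dict, Iterable, List


def merge_triplets(records: Iterable[Dict]) -> List[Dict]:
    """Fold per-row QA records into one record per distinct query.

    Positives and negatives are concatenated in arrival order, duplicates dropped.
    """
    merged: Dict[str, Dict] = {}  # dicts keep insertion order (Python 3.7+)
    for r in records:
        if r["query"] not in merged:
            merged[r["query"]] = {
                "query": r["query"],
                "positives": [],
                "negatives": [],
                "type": r.get("type", "qa_triplet"),
            }
        out = merged[r["query"]]
        for key in ("positives", "negatives"):
            for text in r.get(key, []):
                if text not in out[key]:
                    out[key].append(text)
    return list(merged.values())
```

Queries that end up with no positive answer could then be dropped, e.g. `[r for r in merge_triplets(rows) if r["positives"]]`.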

In [None]:
#JoBeer/eclassTrainST
#foo =  load_dataset('gart-labor/eclassTrainST',split='train',streaming=True).map(clean_eclassTrainST).remove_columns(['text', 'entailment', 'contradiction', 'label'])
#foo =  load_dataset('gbharti/finance-alpaca',split='train',streaming=True)  # good, financial questions
#foo =  load_dataset('gart-labor/eclassTrainST',split='train',streaming=True) # NAD; just for paraphrased questions, not for QA
# foo =  load_dataset('parquet',data_files = 'https://huggingface.co/datasets/gart-labor/eclassTrainST/resolve/main/data/eval-00001-of-00001-d8aa08935841e6a9.parquet',split='train',streaming=False) # NAD; just for paraphrased questions, not for QA
#foo =  load_dataset('wiki_qa',split='train',streaming=True) # excellent; with negatives and positives
#foo =  load_dataset('THUDM/webglm-qa',split='train',streaming=True) # excellent; with negatives and positives
foo = load_dataset("sciq",split='train',streaming=False) #
#add definitions
if True:
    # print the first dozen examples, then the last example and its field names
    for j,e in enumerate(foo):
        print(e)
        if j > 10:
            break
    print(e)
    print(e.keys())

Generating train split:   0%|          | 0/11679 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

{'question': 'What type of organism is commonly used in preparation of foods such as cheese and yogurt?', 'distractor3': 'viruses', 'distractor1': 'protozoa', 'distractor2': 'gymnosperms', 'correct_answer': 'mesophilic organisms', 'support': 'Mesophiles grow best in moderate temperature, typically between 25°C and 40°C (77°F and 104°F). Mesophiles are often found living in or on the bodies of humans or other animals. The optimal growth temperature of many pathogenic mesophiles is 37°C (98°F), the normal human body temperature. Mesophilic organisms have important uses in food preparation, including cheese, yogurt, beer and wine.'}
{'question': 'What phenomenon makes global winds blow northeast to southwest or the reverse in the northern hemisphere and northwest to southeast or the reverse in the southern hemisphere?', 'distractor3': 'tropical effect', 'distractor1': 'muon effect', 'distractor2': 'centrifugal effect', 'correct_answer': 'coriolis effect', 'support': 'Without Coriolis Effe

In [None]:
e.keys()

dict_keys(['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support'])

In [None]:
from torch.utils import data as torch_data
from rank_bm25 import BM25Okapi
import pandas as pd
import os
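
`rank_bm25.BM25Okapi` is imported here, presumably (an assumption) to mine hard negatives for the many sources above whose cleaned records have `negatives = []`. To show the idea without the dependency, the sketch below scores passages with a self-contained Okapi BM25 (`k1=1.5`, `b=0.75`, matching the `BM25Okapi` defaults) and returns the top-scoring passages that are not the known positive. It uses the non-negative IDF variant `log(1 + (N - n + 0.5) / (n + 0.5))`, which differs from rank_bm25's epsilon-floored IDF. `tokenize`, `bm25_scores` and `mine_hard_negatives` are hypothetical names.

```python
import math
import re
from collections import Counter
from typing import List


def tokenize(text: str) -> List[str]:
    # lowercase word tokens; a real pipeline would likely use a better tokenizer
    return re.findall(r"\w+", text.lower())


def bm25_scores(query: str, passages: List[str], k1: float = 1.5, b: float = 0.75) -> List[float]:
    """Okapi BM25 score of `query` against each passage (non-negative IDF variant)."""
    docs = [tokenize(p) for p in passages]
    n_docs = len(docs)
    avgdl = sum(len(d) for d in docs) / n_docs if n_docs else 0.0
    df = Counter(term for d in docs for term in set(d))  # document frequency per term
    q_terms = tokenize(query)
    scores = []
    for d in docs:
        tf = Counter(d)
        s = 0.0
        for term in q_terms:
            if term not in tf:
                continue  # term absent: contributes nothing (also avoids avgdl == 0)
            idf = math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
            norm = tf[term] + k1 * (1 - b + b * len(d) / avgdl)
            s += idf * tf[term] * (k1 + 1) / norm
        scores.append(s)
    return scores


def mine_hard_negatives(query: str, positive: str, passages: List[str], k: int = 1) -> List[str]:
    """Return the k highest-scoring passages that are not the known positive."""
    scores = bm25_scores(query, passages)
    ranked = sorted(range(len(passages)), key=lambda i: scores[i], reverse=True)
    return [passages[i] for i in ranked if passages[i] != positive][:k]
```

Lexically similar but unhelpful passages make harder negatives than random ones; the risk is that some mined "negatives" are actually valid answers, so a filter or manual spot-check is advisable.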

In [None]:
STACKEXCHANGE_NONQUANT_DOMAINS = [
    "stackexchange-"+k for k in [
        "academia",
        "aviation",
        "bicycles",
        "biology",
        "buddhism",
        "chemistry",
        "chess",
        "christianity",
        "coffee",
        "cogsci",
        "cooking",
        "crafts",
        "cseducators",
        "diy",
        "drones",
        "earthscience",
        "ebooks",
        "electronics",
        "english",
        "expatriates",
        "fitness",
        "freelancing",
        "gardening",
        "gaming",
        "genealogy",
        "ham",
        "hardwarerecs",
        "health",
        "hinduism",
        "history",
        "homebrew",
        "hsm",
        "interpersonal",
        "iot",
        "islam",
        "judaism",
        "law",
        "lifehacks",
        "linguistics",
        "literature",
        "martialarts",
        "materials",
        "mechanics",
        "moderators",
        "money",
        "music",
        "mythology",
        "outdoors",
        "parenting",
        "patents",
        "pets",
        "philosophy",
        "pm",
        "politics",
        "security",
        "skeptics",
        "softwarerecs",
        "sustainability",
        "travel",
        "vegetarianism",
        "woodworking",
        "workplace",
        "worldbuilding",
        "writers"
        ]
    ]

def clean_webglmqa(x):
    x['query']=x['question']
    x['positives'] = [x['answer']]
    x['negatives'] = []
    x['type'] = 'qa_triplet'
    return x

def clean_stream_PAQ_pairs(x):
    x['query'] = x['set'][0]
    x['positives'] = [x['set'][1]]
    x['negatives'] = []
    x['type'] = 'qa_triplet'
    return x

def clean_stream_finance_alpaca(x):
    x['query'] = x['instruction']
    x['positives'] = [x['output']]
    x['negatives'] = []
    x['type'] = 'qa_triplet'
    return x

def clean_stream_wiki_qa(x):
    x['query'] = x['question']
    is_pos = x['label']
    answer = x['answer']
    pos = [answer] if is_pos else []
    neg = [answer] if (not is_pos) else []
    x['positives'] = pos
    x['negatives'] = neg
    x['type'] = 'qa_triplet'
    return x

def clean_stream_oa_stackexchange(x):
    x['query'] = x['INSTRUCTION']
    x['positives'] = [x['RESPONSE']]
    x['negatives'] = []
    x['type'] = 'qa_triplet'
    return x

def clean_stream_sciqa(x):
    x['query'] = x['question']
    x['positives'] = [x['support']]
    x['negatives'] = []
    x['type'] = 'qa_triplet'
    return x

def filter_os_stackexchange(x):
    return x['SOURCE'] in STACKEXCHANGE_NONQUANT_DOMAINS

def get_name_and_description_eclassTrainST(text):
    description, name = text.split("; Name:")
    return description.replace("Description: ","").strip(), name.strip()

def clean_eclassTrainST(x):
    """This set isn't really about entailment/contradiction; it is really a dictionary"""
    description, name = get_name_and_description_eclassTrainST(x['text'])
    pos, _ = get_name_and_description_eclassTrainST(x['entailment'])
    extra, _ = get_name_and_description_eclassTrainST(x['contradiction'])
    x['query'] = 'What is a "%s"?' % name
    x['positives'] = [pos]
    x['negatives'] = []
    # the 'contradiction' text is an extra positive when label == 'entailment', otherwise a negative
    if x['label'] == 'entailment':
        x['positives'].append(extra)
    else:
        x['negatives'].append(extra)
    x['type'] = 'qa_triplet'
    return x

#dict_keys(['question_id', 'question', 'document_title', 'answer', 'label'])
qa_streaming_cleaning_functions = {
    'embedding-data/PAQ_pairs':(clean_stream_PAQ_pairs, None, ['query','positives','negatives'],['set']),
    'gbharti/finance-alpaca':(clean_stream_finance_alpaca,None, ['query','positives','negatives'],['input', 'output', 'text', 'instruction']),
    'wiki_qa':(clean_stream_wiki_qa, None, ['query','positives','negatives'],['question_id', 'question', 'document_title', 'answer', 'label']),
    'donfu/oa-stackexchange':(clean_stream_oa_stackexchange, filter_os_stackexchange, ['query','positives','negatives'], ['INSTRUCTION', 'RESPONSE', 'SOURCE', 'METADATA']),
    'gart-labor/eclassTrainST':(clean_eclassTrainST, None, ['query','positives','negatives'], ['text', 'entailment', 'contradiction', 'label']),
    'THUDM/webglm-qa':( clean_webglmqa, None, ['query','positives','negatives'], ['question','answer','references']),
    'sciq': (clean_stream_sciqa, None, ['query','positives','negatives'], ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support']),
    #'LLukas22/lfqa_preprocessed':() REDDIT QUESTION ANSWERS (ASK historians, ask me like I'M FIVE)
    }

qa_files = [
    ('embedding-data/PAQ_pairs',None, 0.1, 7.29*10**6, 'qa_triplet', False), # wikipedia pop culture pairs # get from 'set', 7.29*10**6
    ('gbharti/finance-alpaca',None, 0.1, 6.89*10**5, 'qa_triplet', False), # Stanford's Alpaca (https://github.com/tatsu-lab/stanford_alpaca) and FiQA (https://sites.google.com/view/fiqa/) with another 1.3k pairs custom generated using GPT3.5
    ('wiki_qa',None, 0.1, 20.4*10**3, 'qa_triplet', False), # Wiki Question Answering corpus from Microsoft. with multiple negatives that are similar!
    ('donfu/oa-stackexchange',None, 0.1, 6330000, 'qa_triplet', (14, int(6330000//14))), # stack-exchange question-answer pairs, across lots of domains; notice the original is 6.6 million, but there is a filter
    ('gart-labor/eclassTrainST', None, 0.02, 450912, 'qa_triplet', False), # questions about trade / business stuff
    ('THUDM/webglm-qa', None, 0.1, 43600, 'qa_triplet', False),
    ('sciq',None, 0.1, 11679, 'qa_triplet', False), # science questions from Allenai, with a question and support
    #('LLukas22/lfqa_preprocessed', None, 0.1, 226000,'qa_triplet',False), # REDDIT QUESTION ANSWERS (ASK historians, ask me like I'M FIVE)

    # TODO: add the dictionary definitions
]

qadata_streaming_config = {
    'files':qa_files,
    'max_seq_length':512,
    'prepend_q': 'query: ',
    'prepend_a': 'passage: ',
    'val_size':100,
    'train_chunk_size':500,
    'seed':42,
}

def initialize_qa_streaming_datasets(data_streaming_config, streaming_cleaning_functions):
    files = data_streaming_config['files']
    qa_streaming_datsets, qa_probabilities, qa_datasizes = [],[],[]
    for (qa_nm, set_nm, prob, dataset_size, special_handling, partition_shuffle) in files:

        if prob ==0:
            continue
        # get cleaning & filter functions for streaming data / map & filters
        clean_func, filter_func, feature_names, removefeature_names = streaming_cleaning_functions[qa_nm]

        # arguments for the load_dataset (huggingface repos)
        load_dataset_args = {
            'path':qa_nm, 'name':set_nm, 'split':'train', 'streaming':True
        }
        # for other non-huggingface repos, path needs to be a "builder"
        if qa_nm.endswith('.jsonl') or qa_nm.endswith('.jsonl.zip') or qa_nm.endswith('.jsonl.zst'):
            load_dataset_args.update({'path':'json','data_files':qa_nm})

        print('trying %s' % qa_nm)
        if filter_func is None:
            dset_stream = load_dataset(**load_dataset_args).map(clean_func).remove_columns(removefeature_names)
        else:
            dset_stream = load_dataset(**load_dataset_args).filter(filter_func).map(clean_func).remove_columns(removefeature_names)

        qa_streaming_datsets.append(dset_stream)
        qa_probabilities.append(prob)
        qa_datasizes.append(dataset_size)

    print('done initializing the QA streaming datasets')
    return qa_streaming_datsets, qa_probabilities, qa_datasizes

def streaming_skip(skip, list_of_streaming_datasets, probabilities, datasizes, seed=42, convert_to_static = False):
    """Loops through a list of streaming datasets and skips the first examples of each, in proportion to its sampling probability"""
    out = []
    normalized_p = [p/sum(probabilities) for p in probabilities]
    for dset, p, size in zip(list_of_streaming_datasets, normalized_p, datasizes):
        skip_in_this_set = max(0, int(round(p*skip)))
        out.append(dset.skip(skip_in_this_set))
    return out

def streaming_take(skip, start_proportion, chunksize, list_of_streaming_datasets, probabilities, datasizes,  convert_to_static = False):
    """Takes some examples based on a starting point within the dataset, as a proportion of its total size"""
    out = []
    normalized_p = [p/sum(probabilities) for p in probabilities]
    for j, (dset, p, size) in enumerate(zip(list_of_streaming_datasets, normalized_p, datasizes)):
        # examples reserved for the validation set at the head of this stream
        skip_in_this_set = int(round(p*skip))
        # afterwards, where to start?
        skip_to_start = int(start_proportion*(size-skip_in_this_set))
        take_from_this_set = int(round(chunksize*p))
        # always skip past the validation examples, even when start_proportion is 0
        n_to_skip = skip_in_this_set + skip_to_start
        if n_to_skip > 0:
            dset_skipped = dset.skip(n_to_skip).take(take_from_this_set)
        else:
            dset_skipped = dset.take(take_from_this_set)

        if not convert_to_static:
            # option to return the streaming dataset
            out.append(dset_skipped)
        else:
            # option just to convert the streaming dataset to static outputs
            for example in dset_skipped:
                example['source_id'] = j
                out.append(example)
        print('done %d' % j)
    return out

def train_test_splits_from_stream_qa(
    streaming_dataset,
    val_size = 100,#2000,
    epoch = 0,
    chunk_size = 500,#6000,
    path_to_val_cache = 'val_qa_cache.pkl',
    probabilities = None,
    datasizes = None,
    seed=42
):
    """
    Streams a static validation set and a static training chunk from a list of streaming datasets.

    val_size = 100, number of examples (summed over all datasets) taken from the head of the streams for the validation set
    epoch = 0, added to the seed when sampling the start position of the training chunk
    chunk_size = 500, number of examples (summed over all datasets) taken for the training chunk
    path_to_val_cache = pickle file used to cache and reload the validation set
    """
    if os.path.isfile(path_to_val_cache):
        print('RELOADING VAL-QA SET: iter=%s' % path_to_val_cache)
        with open(path_to_val_cache,'rb') as pcon:
            val_corpus_list = pickle.load(pcon)
        print('VAL-QA SET SIZE: %d' % len(val_corpus_list))
    else:
        # stream validation set
        print('STREAMING VAL-QA DATA: %d' % val_size)
        val_corpus_list = streaming_take(
            skip=0,
            start_proportion=0,
            chunksize=val_size,
            list_of_streaming_datasets=streaming_dataset,
            probabilities=probabilities,
            datasizes=datasizes,
            convert_to_static = True
        )
        print('REALIZED VAL-QA DATA: %d' % len(val_corpus_list))
        # save the validation corpus
        print('SAVING VAL-QA SET: %s' % path_to_val_cache)
        with open(path_to_val_cache,'wb') as pcon:
            pickle.dump(val_corpus_list, pcon)

    # sample a random start position for the training stream
    # (as a proportion of each dataset's size, seeded by seed + epoch)
    train_start_proportion = np.random.RandomState(seed + epoch).random()*0.99
    print(train_start_proportion)

    # stream training data
    print('STREAMING TRAIN QA-DATA: %d STARTING AT: %0.3f' % (chunk_size,train_start_proportion))
    train_corpus_list = streaming_take(
            skip=val_size,
            start_proportion=train_start_proportion,
            chunksize=chunk_size,
            list_of_streaming_datasets=streaming_dataset,
            probabilities=probabilities,
            datasizes=datasizes,
            convert_to_static = True
        )

    print('REALISED TRAIN QA-DATA SIZE: %d' % len(train_corpus_list))
    return {
        'train':train_corpus_list,
        'val':val_corpus_list,
        'epoch':epoch,
        'index_stream':train_start_proportion
    }



def initialize_and_get_triplet_streaming_datasets(
    data_streaming_config,
    streaming_cleaning_functions,
    start_proportion = None,
    epoch=0,
    seed=42,
    path_to_val_cache = 'cache_val_qa.pkl',
    path_to_train_cache_epoch = 'cache_train_qa_%03g.pkl',
    do_check_english = True,
    name = 'QA'
):
    """Streams (query, positives, negatives) examples from several sources into static train/val lists, with pickle caching, for the triplet/retrieval task"""
    # list of files to stream
    print('FOOFU: WIP converting this function from the MLM for triplets')
    files = data_streaming_config['files']
    # number of examples to take from stream for validation set
    val_size = data_streaming_config['val_size']
    # number of examples to take from stream for training set
    train_chunk_size = data_streaming_config['train_chunk_size']
    min_seq_len = data_streaming_config.get('min_seq_length') # optional; not used yet
    # normalization constant for normalizing the weights into probabilities
    probability_normalization_const = sum([x[2] for x in files])

    # where to initialize start-stream for training data
    if start_proportion is None:
        start_proportion = np.random.RandomState(seed+epoch).uniform()*0.95

    # reload cached files
    path_to_train_cache = (path_to_train_cache_epoch % epoch) if '%03g' in path_to_train_cache_epoch else path_to_train_cache_epoch
    do_make_valset = not os.path.isfile(path_to_val_cache)
    do_make_trainset = not os.path.isfile(path_to_train_cache)
    if not do_make_valset:
        print(f'RELOADING VAL-{name} SET: iter=%s' % path_to_val_cache)
        with open(path_to_val_cache,'rb') as pcon:
            datalist_val_triplet_static = pickle.load(pcon)
        print(f'VAL-{name} SET SIZE: %d' % len(datalist_val_triplet_static))
    else:
        datalist_val_triplet_static = []
    if not do_make_trainset:
        print(f'RELOADING TRAIN-{name} SET: iter=%s' % path_to_train_cache)
        with open(path_to_train_cache,'rb') as pcon:
            datalist_train_triplet_static = pickle.load(pcon)
        print(f'TRAIN-{name} EPOCH-%d SET SIZE: %d' % (epoch, len(datalist_train_triplet_static)))
    else:
        datalist_train_triplet_static = []

    if (do_make_trainset or do_make_valset):

        # loop through datasets
        for (data_nm, set_nm, prob, dataset_size, special_handling, partition_shuffle), dataset_key in zip(
            files, streaming_cleaning_functions.keys()
        ):
            if prob ==0:
                continue
            prob /= probability_normalization_const

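            # get cleaning & filter functions for streaming data functionality
            # (removefeature_names is always the last element; some configs also carry a list of feature names before it)
            clean_func, filter_func, *_, removefeature_names = streaming_cleaning_functions[dataset_key]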

            # set arguments for the load_dataset (huggingface repos)
            load_dataset_args = {
                'path':data_nm, 'name':set_nm, 'split':'train', 'streaming':True
            }
            # for other non-huggingface repos, path needs to be a "builder"
            if data_nm.endswith('.jsonl') or data_nm.endswith('.jsonl.zip') or data_nm.endswith('.jsonl.zst'):
                load_dataset_args.update({'path':'json','data_files':data_nm})

            # special processing of datasets with multiple partitions
            if bool(partition_shuffle): # or str(epoch)=='val':

                n_files, n_per_file = partition_shuffle
                dataset_size = n_per_file
                print('trying %s initialization (shuffling through %d files)' % (data_nm, n_files))

                # whether there is a filter
                if filter_func is None:
                    dset_stream = load_dataset(**load_dataset_args)
                else:
                    dset_stream = load_dataset(**load_dataset_args).filter(filter_func)

                # validation set
                if do_make_valset:
                    # take from stream
                    n_valset_take = max(int(prob*val_size), 1)
                    print('take %d from %s validation'% (n_valset_take, data_nm))
                    dset_stream_val = dset_stream.take(n_valset_take).map(clean_func).remove_columns(removefeature_names)
                    # convert stream to a static set and do check
                    dset_static_val_thisset = [
                        e for e in dset_stream_val if bool(re.search(r"\w+",e['query'][:200]))
                    ]
                # training set
                if do_make_trainset:
                    # randomly skip a bunch from this set
                    skip_to_start = int(start_proportion*n_per_file)
                    take_from_this_set = max(int(round(train_chunk_size*prob)),1)
                    print('take %d from %s training'% (take_from_this_set, data_nm))
                    # shuffle: take a random data partition (from the dataset's list of files)
                    dset_stream_train = dset_stream.shuffle(
                        seed = seed+epoch, buffer_size = skip_to_start+take_from_this_set,
                    )
                    dset_stream_train = dset_stream_train.skip(
                        skip_to_start # random skip through dataset to new start position
                    ).take(
                        take_from_this_set # take this amount for the training set
                    ).map(clean_func).remove_columns(removefeature_names)
                    # convert training to static dataset
                    dset_static_train_thisset = [
                        e for e in dset_stream_train if bool(re.search(r"\w+",e['query'][:200]))
                    ]
            else:
                # regular streaming
                print('trying %s initialization' % data_nm)
                # whether there is a filter
                if filter_func is None:
                    dset_stream = load_dataset(**load_dataset_args).map(clean_func).remove_columns(removefeature_names)
                else:
                    dset_stream = load_dataset(**load_dataset_args).filter(filter_func).map(clean_func).remove_columns(removefeature_names)
                # take from stream
                n_valset_take = max(int(prob*val_size), 1) # size of valset
                print('take %d from %s validation'% (n_valset_take, data_nm))
                skip_to_start = int(start_proportion*(dataset_size-n_valset_take)) # random point to skip to
                n_train_take = max(int(round(train_chunk_size*prob)),1) # size of train set
                print('take %d from %s train'% (n_train_take, data_nm))
                if do_make_valset:
                    dset_stream_val = dset_stream.take(n_valset_take)
                    dset_static_val_thisset = [
                        e for e in dset_stream_val if bool(re.search(r"\w+",e['query'][:200]))
                    ]
                if do_make_trainset:
                    dset_stream_train = dset_stream.skip(n_valset_take+skip_to_start).take(n_train_take)
                    dset_static_train_thisset = [
                        e for e in dset_stream_train if bool(re.search(r"\w+",e['query'][:200]))
                    ]
            print('Done getting streams/reloading from %s' % data_nm)
            # check language, chunk sentences
            if do_make_valset:
                # discard non-english
                dset_static_val_thisset =[
                    e for e in dset_static_val_thisset if detect(e['query'][:200]+" hello")=='en'
                ]
                print('done val language check')
                # add to val set
                datalist_val_triplet_static.extend(dset_static_val_thisset)

            # check language, chunk sentences
            if do_make_trainset:
                # discard non-english
                dset_static_train_thisset =[
                    e for e in dset_static_train_thisset if detect(e['query'][:200] +" hello")=='en'
                ]
                print('done train language check')

                # ensure that none of the examples in the training set are in the validation set
                if do_make_valset:
                    val_queries = set([q['query'] for q in dset_static_val_thisset])
                    dset_static_train_thisset = [
                        s for s in dset_static_train_thisset if s['query'] not in val_queries
                    ]

                # add to training set
                datalist_train_triplet_static.extend(dset_static_train_thisset)

        print(f'Done collecting {name} streaming data')

    if do_make_valset:
        print('saving streamed %s validation data: %s' % (name, path_to_val_cache))
        with open(path_to_val_cache,'wb') as pcon:
            pickle.dump(datalist_val_triplet_static, pcon)

    if do_make_trainset:
        print('saving streamed %s training for epoch %d: %s' % (name, epoch, path_to_train_cache))
        with open(path_to_train_cache,'wb') as pcon:
            pickle.dump(datalist_train_triplet_static, pcon)

    return {
        'train':datalist_train_triplet_static,
        'val':datalist_val_triplet_static,
        'epoch':epoch,
        'index_stream':start_proportion
    }


class DatasetTriplets(torch_data.Dataset):
    def __init__(
        self,
        list_of_data=None,
        n_negatives = 3,
        topk_negatives_discard = 6, # retrieve the top (k + n_negatives) BM25 hits and keep the last n_negatives as negatives
        focal_text_name ='query',
        positives_text_name ='positives',
        negativess_text_name ='negatives',
        seed = 42,
        label_processor_class = None # (optional) function to process negatives
    ):
        self.n_negatives = n_negatives
        self.topk_negatives_discard = topk_negatives_discard
        self.data = {}
        self.focal_text_name =focal_text_name
        self.positives_text_name = positives_text_name
        self.negativess_text_name = negativess_text_name
        self.seed = seed
        self.random = np.random.RandomState(self.seed)
        self.label_processor_class = label_processor_class

        if list_of_data is not None and len(list_of_data)>0:
            # loop through the data and add each triplet: export a pandas df as final data
            self.df = self.process(list_of_data)
        else:
            self.df = pd.DataFrame(columns=['query','pos','neg'])

    def process(self, list_of_data):
        """Makes (query,pos,neg)-triplets, converts samples to dataframe for pytorch iteration"""

        # loop through the data and add each triplets
        self._loop_through_list_of_data_and_add_to_selfdata(
            list_of_data = list_of_data
        )

        # add positives to self.data
        self._find_positives_and_add_to_data()

        # add negatives to self.data
        self._find_negatives_and_add_to_data()

        # harden the dataset to pandas dataframe
        df = self.sample_data_and_make_static_dataframe()
        return df

    def _loop_through_list_of_data_and_add_to_selfdata(
        self,
        list_of_data
    ):
        """loops through and adds the positive/focal texts and negatives"""
        for raw_example in list_of_data:
            # add each element to the data
            self._add_triplet_to_data(
                focal_texts=raw_example[self.focal_text_name],
                positve_texts=raw_example[self.positives_text_name],
                negative_texts=raw_example[self.negativess_text_name],
            )
        self.focal_texts_as_keys = list(self.data.keys())

    def _add_triplet_to_data(
        self,
        focal_texts,
        positve_texts,
        negative_texts
    ):
        """add focal text to the data"""
        do_add_focals = False
        if isinstance(focal_texts,list):
            # several equivalent focal texts: the first (sorted) one is the key, the rest become positives
            sorted_focals = sorted(focal_texts)
            focal_text = sorted_focals[0]
            do_add_focals = True
        elif isinstance(focal_texts, str):
            focal_text = focal_texts
        else:
            raise TypeError('focal text must be a str or a list of str, got %s' % type(focal_texts))
        if focal_text not in self.data.keys():
            self.data[focal_text] = {'positives':[], 'negatives':[]}
        self.data[focal_text]['positives'] += [p for p in positve_texts if p not in self.data[focal_text]['positives']]
        self.data[focal_text]['negatives'] += negative_texts if (negative_texts is not None) else []
        if do_add_focals:
            self.data[focal_text]['positives'] += sorted_focals[1:]

    def _build_corpus_of_potential_negatives(self):
        potential_corpus = [
            self.data[k]['positives'][:1] for k in self.focal_texts_as_keys
        ]
        potential_corpus = [
            'NEGATIVE' if (not bool(s)) else s[0] for s in potential_corpus
        ]
        tokenized_corpus = [s.lower().split(" ") for s in potential_corpus]
        bm25 = BM25Okapi(tokenized_corpus)
        return {'bm25':bm25, 'corpus':potential_corpus}

    def _find_negative(
        self,
        focal_text_as_query,
        positive_examples=None,
        use_focal_text = True,
        use_positives=True,
        bm25_corpus=None,
        corpus = None
    ):
        """Given a query, uses BM25 to find similar but wrong answers, to serve as triplet negatives; for a single query"""
        positive_examples = positive_examples or []
        bmquery = (focal_text_as_query if use_focal_text else "") + " " + (positive_examples[0] if (use_positives and positive_examples) else "")
        bmquery = bmquery.strip()
        bmquery_tokenized = bmquery.lower().split(" ")
        top_results = bm25_corpus.get_top_n(bmquery_tokenized, corpus, n=self.topk_negatives_discard + self.n_negatives)
        # remove any text that is equivalent to the query / focal texts, and the 'NEGATIVE' placeholder
        top_results = [
            s for s in top_results
            if s not in positive_examples + [focal_text_as_query, 'NEGATIVE']
        ]
        # keep the lowest-ranked of the retrieved hits: similar to the query, but not near-duplicates
        potential_negatives = top_results[-1*self.n_negatives:]
        return potential_negatives

    def _find_positives_and_add_to_data(self):
        """For data that has a label, this can be used to artificially find and create synthetic positives"""
        pass

    def _find_negatives_and_add_to_data(self):
        """Uses BM25 to find similar but wrong answers, to serve as triplet negatives; loop over data"""

        # build bm25 corpus
        bm25_corpus = self._build_corpus_of_potential_negatives()

        # loop through data, find examples which don't have negatives
        for k,d in self.data.items():
            if not bool(d['negatives']):
                negatives = self._find_negative(
                    focal_text_as_query=k,
                    positive_examples=d['positives'],
                    use_focal_text = True,
                    use_positives=bool(d['positives']),
                    bm25_corpus=bm25_corpus['bm25'],
                    corpus = bm25_corpus['corpus']
                )
                d['negatives']+= negatives
        print('done finding negatives')

    def sample_data_and_make_static_dataframe(self, seed = 42):
        focals =[]
        pos =[]
        neg = []
        for query,d in self.data.items():
            for j in range(min(self.n_negatives, len(d['negatives']))):
                if len(d['positives'])==0:
                    continue
                elif len(d['positives'])==1:
                    pos+=d['positives']
                elif len(d['positives'])>1:
                    pos.append(self.random.choice(d['positives']))
                neg.append(d['negatives'][j])
                focals.append(query)
        df = pd.DataFrame({'query':focals, 'pos':pos, 'neg':neg})
        return df

    def __len__(self):
        return len(self.df)

    def __getitem__(self,idx):
        #key = self.focal_texts_as_keys[idx]
        #return {**{'query':key}, **self.data[key]}
        return self.df.iloc[idx].to_dict()



In [None]:

# intialize the qa streaming dataset (QA)
qa_streaming_datsets, qa_probabilities, qa_datasizes = initialize_qa_streaming_datasets(
    qadata_streaming_config,
    qa_streaming_cleaning_functions
)

qa_statics_datsets = train_test_splits_from_stream_qa(
    streaming_dataset=qa_streaming_datsets,
    val_size = 100,#2000,
    epoch = 0,
    chunk_size = 500,#6000,
    path_to_val_cache = 'val_qa_cache.pkl',
    probabilities = qa_probabilities,
    datasizes = qa_datasizes,
    seed=qadata_streaming_config['seed']
)


trying embedding-data/PAQ_pairs


trying gbharti/finance-alpaca


trying wiki_qa


trying donfu/oa-stackexchange
trying gart-labor/eclassTrainST
done initializing the QA streaming datasets
STREAMING VAL-QA DATA: 100
done 0
done 1
done 2
done 3
done 4
REALIZED VAL-QA DATA: 101
SAVING VAL-QA SET: val_qa_cache.pkl
0.3707947176588889
STREAMING TRAIN QA-DATA: 500 STARTING AT: 0.371
done 0
done 1
done 2
done 3
done 4
REALISED TRAIN QA-DATA SIZE: 381
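`streaming_take` splits each requested size across sources in proportion to the normalized weights in `qa_files`, rounding every share with `int(round(chunksize*p))`. Assuming the five sources of this run had weights 0.1, 0.1, 0.1, 0.1 and 0.02, rounding alone turns the 100 requested validation examples into 101, which matches the log above. The 500 training shares, however, sum to exactly 500, so the drop to 381 is not a rounding effect; presumably some streams yielded fewer rows than requested after the skip. The sketch also mimics the skip/take windows with `itertools.islice` on a plain iterator standing in for a Hugging Face `IterableDataset`, to show that the training window should start after the validation head. `split_counts` and `stream_windows` are hypothetical helpers.

```python
from itertools import islice


def split_counts(chunksize, weights):
    """Per-source sample counts, proportional to the normalized weights (as in streaming_take)."""
    total = sum(weights)
    return [int(round(chunksize * w / total)) for w in weights]


def stream_windows(make_stream, n_val, n_train, start_proportion, dataset_size):
    """Validation = head of the stream; training = a window that starts after the validation head."""
    val = list(islice(make_stream(), n_val))
    offset = n_val + int(start_proportion * (dataset_size - n_val))
    train = list(islice(make_stream(), offset, offset + n_train))
    return val, train


weights = [0.1, 0.1, 0.1, 0.1, 0.02]  # assumed weights of the five sources in the run logged above
print(split_counts(100, weights), sum(split_counts(100, weights)))
print(split_counts(500, weights), sum(split_counts(500, weights)))

val, train = stream_windows(lambda: iter(range(1000)), n_val=24, n_train=119,
                            start_proportion=0.0, dataset_size=1000)
print(val[-1], train[0])  # the training window begins right after the validation head
```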


In [None]:
for i,e in enumerate(qa_statics_datsets['val']):
    if i<20:
        continue
    print("-------\nQ:%s\nA:%s" % (e['query'], e['positives'][0].replace("\n"," ") if bool(e['positives']) else e['negatives'][0].replace("\n"," ")))


In [None]:


qa_torchdataset_val = DatasetTriplets(
    list_of_data = qa_statics_datsets['val'],
    n_negatives= 3,
    focal_text_name ='query',
    positives_text_name ='positives',
    negativess_text_name ='negatives',
)
qa_torchdataset_train = DatasetTriplets(
    list_of_data = qa_statics_datsets['train'],
    n_negatives= 3,
    focal_text_name ='query',
    positives_text_name ='positives',
    negativess_text_name ='negatives',
)

done finding negatives
{'positives': [], 'negatives': ['In physics , circular motion is a movement of an object along the circumference of a circle or rotation along a circular path.', 'It can be uniform, with constant angular rate of rotation (and constant speed), or non-uniform with a changing rate of rotation.', 'The rotation around a fixed axis of a three-dimensional body involves circular motion of its parts.', 'The equations of motion describe the movement of the center of mass of a body.', 'Examples of circular motion include: an artificial satellite orbiting the Earth at constant height, a stone which is tied to a rope and is being swung in circles, a car turning through a curve in a race track , an electron moving perpendicular to a uniform magnetic field , and a gear turning inside a mechanism.', "Since the object's velocity vector is constantly changing direction, the moving object is undergoing acceleration by a centripetal force in the direction of the center of rotation."

In [None]:
print(len(qa_torchdataset_train))
qa_torchdataset_train[400]

768


{'query': 'Determine Current Controller in Use for Kohana\nWhat is the best way to determine which Controller class a Kohana application is presently using?\n\nExamples:\n\n  * ` \\- `_defaultControllerName_`\n  * ` \\- "frontpage"\n  * ` \\- "contact"',
 'pos': '**_The following applies to Kohana 2 instances..._**\n\nYou can do this by using the Router library. By default, this library is located in `/system/libraries/Router.php` \\- go ahead and copy it into `/application/libraries` as is the standard practice for all libraries being used.\n\nNow, from within your application you can get the controller value from the static Router class:\n    \n    \n    print Router::$controller; // outputs current Controller\n    \n\nDocumentation',
 'neg': "No, a `StringBuilder` is a purely managed resource. You should just get rid of all references to it. Everything else is taken care of by the garbage collector:\n    \n    \n    StringBuilder sb = ...;\n    // ... do work\n    sb = null; // or s
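The training dataset has 768 rows although only 381 examples were streamed, because `sample_data_and_make_static_dataframe` writes one row per negative (up to `n_negatives`) for every query that has at least one positive, choosing a positive at random when there are several. Queries with an empty `positives` list, like the record printed above, contribute no rows. A stdlib-only sketch of that expansion, with `expand_to_triplets` as a hypothetical stand-in and Python's `random` replacing `np.random.RandomState`:

```python
import random


def expand_to_triplets(data, n_negatives=3, seed=42):
    """One (query, pos, neg) row per negative, up to n_negatives per query."""
    rng = random.Random(seed)
    rows = []
    for query, d in data.items():
        if not d["positives"]:
            continue  # a triplet needs at least one positive
        for neg in d["negatives"][:n_negatives]:
            rows.append({"query": query, "pos": rng.choice(d["positives"]), "neg": neg})
    return rows


data = {
    "q1": {"positives": ["p1"], "negatives": ["n1", "n2", "n3", "n4"]},
    "q2": {"positives": [], "negatives": ["n5"]},
    "q3": {"positives": ["pa", "pb"], "negatives": ["n6"]},
}
rows = expand_to_triplets(data)
print(len(rows))  # 3 rows for q1, none for q2, 1 for q3
```

Because each query can appear up to `n_negatives` times, a random split of these rows would leak queries across splits; splitting at the query level, as the streaming code does, avoids that.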

### A) Retrieval Tasks
In general, which loss should I use for the QA & retrieval tasks? Distillation is the obvious choice, but what about the following datasets?
- SQUAD - has QA pairs - squad_v2
    - good for distillation
- ORCA - has GPT-like prompting QA pairs: https://huggingface.co/datasets/Open-Orca/OpenOrca/viewer/Open-Orca--OpenOrca/train?row=29
- DONE Simple-Wiki https://huggingface.co/datasets/embedding-data/simple-wiki - has paraphrases
- DONE embedding-data/coco_captions_quintets - multiple captions as paraphrases
- DONE embedding-data/simple-wiki - pairs of paraphrases from wikipedia
- DONE embedding-data/SPECTER - triplets of {anchor, pos, neg}, short headline-like snippets in technical/statistical/science fields
- https://huggingface.co/embedding-data - has a lot of retrieval tasks
- LLukas22/scidocs - titles and abstracts
- DONE allenai/scirepeval - cite_prediction - has query, pos, neg based on citations
- DONE LEDGAR - could build triplets from examples with the same label
- Rahmaa/ElsevieR_ClEaN - possible relation between title and abstract
- embedding-data/WikiAnswers - 25 question paraphrases (maybe no answers)
- cnn_dailymail - summarization possibility 287k (beware |||?)
- multi_news - another summarization 45k (beware |||?)
- DONE xsum - BBC extreme summarization 204k
- DONE lighteval/legal_summarization - legal summarization of bills (BillSum 18.8k)
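For the (query, pos, neg) triplets built above, one common candidate (an assumption here, not a decision recorded in these notes) is a softmax contrastive loss of the InfoNCE type: score the query against its positive and its negatives, divide by a temperature, and take the cross-entropy with the positive as the target class. A stdlib-only sketch for a single triplet, assuming embeddings have already been computed; `triplet_contrastive_loss` is hypothetical and the temperature of 0.05 is an arbitrary illustrative value.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def triplet_contrastive_loss(query, positive, negatives, temperature=0.05):
    """Cross-entropy of picking the positive among [positive] + negatives (InfoNCE-style)."""
    logits = [cosine(query, positive) / temperature]
    logits += [cosine(query, neg) / temperature for neg in negatives]
    # log-sum-exp with the max subtracted, for numerical stability
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[0]


good = triplet_contrastive_loss([1.0, 0.0], [1.0, 0.0], [[0.0, 1.0]])
bad = triplet_contrastive_loss([1.0, 0.0], [0.0, 1.0], [[1.0, 0.0]])
print(good, bad)  # near 0 when the positive is closest, large when a negative is
```

In a batched PyTorch version, the positives of the other queries in the batch usually act as extra negatives (in-batch negatives); sentence-transformers' `MultipleNegativesRankingLoss` follows this pattern and also accepts an explicit hard negative per example.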


In [None]:
from datasets import load_dataset

# sources inspected while writing the cleaning functions below (uncomment one at a time)
#foo =  load_dataset("embedding-data/simple-wiki",split='train',streaming=True)
#foo =  load_dataset("embedding-data/coco_captions_quintets",split='train',streaming=True).take(2000)
#foo =  load_dataset("embedding-data/SPECTER",split='train',streaming=True)
#foo = load_dataset(**{'path': 'embedding-data/SPECTER', 'name':None, 'split':'train', 'streaming':True})
#foo =  load_dataset("paws",'labeled_final',split='train',streaming=True)
#foo =  load_dataset("embedding-data/QQP_triplets",None,split='train',streaming=True)
#foo = load_dataset("allenai/scirepeval", 'cite_prediction', split='train',streaming=True)
# foo = load_dataset(**{'path': 'allenai/scirepeval', 'name':'cite_prediction', 'split':'train', 'streaming':True})
#foo = load_dataset('json', data_files="https://drive.switch.ch/index.php/s/j9S0GRMAbGZKa1A/download?path=%2F&files=LEDGAR_2016-2019.jsonl.zip", split="train", streaming=False)
#foo = load_dataset(**{'path': 'json', 'name':None, 'data_files':'https://drive.switch.ch/index.php/s/j9S0GRMAbGZKa1A/download?path=%2F&files=LEDGAR_2016-2019.jsonl.zip', 'split':'train', 'streaming':True})
foo =  load_dataset("lighteval/legal_summarization","BillSum",split='train',streaming=True)

# print the first ~100 records of the current source (BillSum) to inspect its fields
for j,e in enumerate(foo):
    print(e)
    if j > 100:
        break
print(e.keys())

In [None]:


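def clean_legalsum(x):
    # truncate the bill text, presumably to ~600 tokens at ~6.7 characters per token
    MAX_CHAR_LEN_BILLSUM = int(6.7*600)
    text = x['article'][:MAX_CHAR_LEN_BILLSUM]
    if 'SEC. 2.' in text:
        # skip the short-title section: keep the text after 'SEC. 2.', minus its heading sentence
        text = ".".join(text.split('SEC. 2.')[1].split('.')[1:])
    elif 'SHORT TITLE' in text:
        text = text.split('SHORT TITLE')[1]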
    x['query'] = x['summary']
    x['positives'] = [text.strip()]
    x['negatives'] = []
    x['type'] = 'sts_triplet'
    return x

def clean_xsum(x):
    x['query'] = x['summary']
    x['negatives'] = []
    x['positives'] = [x['document']]
    x['type'] = 'sts_triplet'
    return x

def clean_eurlex(x):
    x['query'] = x['text']
    x['negatives'] = []
    x['positives'] = []
    x['type'] = 'sts_by_textlabel'
    x['label'] = x['eurovoc_concepts']
    return x

def clean_allenai_citeprediction(x):
    x['query'] = x['query']['abstract']
    pos = x['pos']['abstract']
    x['positives'] = [pos] if pos is not None else []
    neg = x['neg']['abstract']
    x['negatives'] = [neg] if neg is not None else []
    x['type'] = 'sts_triplet'
    return x

def clean_simple_wiki(x):
    x['query'] = x['set'][0]
    x['positives'] = [x['set'][1]]
    x['negatives'] = []
    x['type'] = 'sts_triplet'
    return x

def clean_coco_captions_quintets(x):
    x['query'] = x['set'][0]
    x['positives'] = x['set'][1:]
    x['negatives'] = []
    x['type'] = 'sts_triplet'
    return x

def clean_specter(x):
    x['query'] = x['set'][0]
    x['positives'] = [x['set'][1]]
    x['negatives'] = [x['set'][2]]
    x['type'] = 'sts_triplet'
    return x

def clean_paws(x):
    x['query'] = x['sentence1']
    x['positives'] = [x['sentence2']]
    x['negatives'] = []
    x['type'] = 'sts_triplet'
    return x

def clean_qqp(x):
    x['query'] = x['set']['query']
    x['positives'] = x['set']['pos']
    x['negatives'] = x['set']['neg']
    x['type'] = 'sts_triplet'
    return x

def clean_ledgarlabelled(x):
    x['query'] = x['provision']
    x['negatives'] = []
    x['positives'] = []
    x['type'] = 'sts_by_textlabel'
    return x

#dict_keys(['question_id', 'question', 'document_title', 'answer', 'label'])
sts_streaming_cleaning_functions = {
    'xsum':(clean_xsum, None, ['query','positives','negatives'],['summary','id','document']),
    'embedding-data/simple-wiki':(clean_simple_wiki, None, ['query','positives','negatives'],['set']),
    'embedding-data/coco_captions_quintets':(clean_coco_captions_quintets,None, ['query','positives','negatives'],['set']),
    'embedding-data/SPECTER':(clean_specter,None, ['query','positives','negatives'],['set']),
    'paws':(clean_paws,None, ['query','positives','negatives'],['id', 'sentence1', 'sentence2', 'label']),
    'embedding-data/QQP_triplets':(clean_qqp,None, ['query','positives','negatives'],['set']),
    "allenai/scirepeval":(clean_allenai_citeprediction, None,  ['query','positives','negatives'], ['pos','neg']),
    "lighteval/legal_summarization":(clean_legalsum, None, ['query','positives','negatives'], ['article', 'summary']),
    "https://drive.switch.ch/index.php/s/j9S0GRMAbGZKa1A/download?path=%2F&files=LEDGAR_2016-2019.jsonl.zip":(
        clean_ledgarlabelled, None, ['query','label'], ['provision','source']
    ),
    "eurlex":(clean_eurlex, None,  ['query','positives','negatives'], ['celex_id', 'title', 'text', 'eurovoc_concepts']),
    #'':(,None, ['query','positives','negatives'],['']),
    #'':(,None, ['query','positives','negatives'],['']),
}

DEFAULT_PROB = 1.0
sts_files = [
    # dataset name, subset, take_probability, dataset size, task type
    ('xsum', None, DEFAULT_PROB, 204000, 'sts_by_triplet'),
    ('embedding-data/simple-wiki',None, DEFAULT_PROB, 102000, 'sts_by_triplet'), # wikipedia paraphrases
    ('embedding-data/coco_captions_quintets',None, DEFAULT_PROB,82800, 'sts_by_triplet'), # caption paraphrases
    ('embedding-data/SPECTER',None, DEFAULT_PROB,684000, 'sts_by_triplet'), # {anchor, pos, neg} triplets of short science/technical snippets
    ('paws','labeled_final',DEFAULT_PROB, 49400, 'sts_by_triplet'), # paws paraphrases
    ('embedding-data/QQP_triplets',None,DEFAULT_PROB, 102000, 'sts_by_triplet'), # Quora question-pair triplets
    ("allenai/scirepeval", 'cite_prediction',DEFAULT_PROB, 676000, 'sts_by_triplet'), # citation-based (query, pos, neg) abstracts
    ("lighteval/legal_summarization","BillSum", DEFAULT_PROB, 18900, 'sts_by_triplet'),
    ('https://drive.switch.ch/index.php/s/j9S0GRMAbGZKa1A/download?path=%2F&files=LEDGAR_2016-2019.jsonl.zip', None, DEFAULT_PROB, 1000000, 'sts_by_label'),
    ('eurlex', None, DEFAULT_PROB, 45000, 'sts_by_label')
    #('',None, 0.1,?*10**5),
    #('',None, 0.1,?*10**5),
    #('',None, 0.1,?*10**5),
]

stsdata_streaming_config = {
    'files':sts_files,
    'max_seq_length':512,
    'prepend_q': 'passage: ',
    'prepend_a': 'passage: ',
    'val_size':100,
    'train_chunk_size':500,
    'seed':42,
}
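The registry `sts_streaming_cleaning_functions` maps each source to a 4-tuple. `initialize_qa_streaming_datasets` is not shown in this section, so the following is only a guess at how an entry might be applied: run the cleaning function on each record, then drop the listed source columns, similar in spirit to `datasets`' `map(..., remove_columns=...)`. The meaning of the second slot (always `None` here) is unknown, so the sketch ignores it; `apply_cleaning` and `clean_xsum_like` are hypothetical names.

```python
from typing import Callable, Iterable, Iterator, List, Optional, Tuple

# Hypothetical reading of a registry entry: (clean_fn, <unknown slot>, required_columns, columns_to_remove)
RegistryEntry = Tuple[Callable[[dict], dict], Optional[object], List[str], List[str]]

def apply_cleaning(records: Iterable[dict], entry: RegistryEntry) -> Iterator[dict]:
    """Lazily clean records into the shared query/positives/negatives schema."""
    clean_fn, _unused, required, to_remove = entry
    for raw in records:
        x = clean_fn(dict(raw))   # shallow copy so the source record is left untouched
        for col in to_remove:
            x.pop(col, None)      # drop source-specific columns after cleaning
        missing = [c for c in required if c not in x]
        if missing:
            raise KeyError(f"cleaned record is missing columns: {missing}")
        yield x

def clean_xsum_like(x):
    # hypothetical stand-in that mirrors clean_xsum above
    x['query'] = x['summary']
    x['positives'] = [x['document']]
    x['negatives'] = []
    x['type'] = 'sts_triplet'
    return x

records = [{'summary': 'A ferry hit a dock.', 'id': '1', 'document': 'Full article text...'}]
entry = (clean_xsum_like, None, ['query', 'positives', 'negatives'], ['summary', 'id', 'document'])
print(list(apply_cleaning(records, entry)))
```

Checking the required columns right after cleaning surfaces schema mismatches (for example a source whose cleaner forgets `negatives`) at the first record instead of deep inside the training loop.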


In [None]:
# initialize streaming data for sts tasks
sts_streaming_datsets, sts_probabilities, sts_datasizes = initialize_qa_streaming_datasets(
    stsdata_streaming_config,
    sts_streaming_cleaning_functions
)

# split and make-static (train and val sets, non-streaming)
sts_statics_datsets = train_test_splits_from_stream_qa(
    streaming_dataset=sts_streaming_datsets,
    val_size = 100,#2000,
    epoch = 0,
    chunk_size = 2000,#6000,
    path_to_val_cache = 'val_sts_cache.pkl',
    probabilities = sts_probabilities,
    datasizes = sts_datasizes,
    seed=stsdata_streaming_config['seed']
)


trying xsum
trying embedding-data/simple-wiki
trying embedding-data/coco_captions_quintets
trying embedding-data/SPECTER
trying paws
trying embedding-data/QQP_triplets
trying allenai/scirepeval
{'train': 'https://ai2-s2-research-public.s3.us-west-2.amazonaws.com/scirepeval/train/cite_prediction/train.jsonl', 'val': 'https://ai2-s2-research-public.s3.us-west-2.amazonaws.com/scirepeval/train/cite_prediction/val.jsonl'}
trying lighteval/legal_summarization
trying https://drive.switch.ch/index.php/s/j9S0GRMAbGZKa1A/download?path=%2F&files=LEDGAR_2016-2019.jsonl.zip
trying eurlex
done initializing the QA streaming datasets
RELOADING VAL-QA SET: iter=val_sts_cache.pkl
VAL-QA SET SIZE: 0
0.3707947176588889
STREAMING TRAIN QA-DATA: 2000 STARTING AT: 0.371
done 0
done 1
done 2
done 3
done 4
done 5
done 6
done 7
done 8
done 9
REALISED TRAIN QA-DATA SIZE: 2000
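Each `sts_files` entry carries a take probability and a dataset size. One plausible way to turn these into a streaming mixture, not necessarily what `initialize_qa_streaming_datasets` does, is to weight each source by `take_prob * size` and pick the source of every next record at random. `mixing_probabilities` and `interleave` are hypothetical helpers; unlike the stopping strategies of `datasets.interleave_datasets`, this version simply drops a source once it is exhausted and keeps sampling from the rest.

```python
import numpy as np

def mixing_probabilities(files):
    """files: tuples shaped like `sts_files` entries: (name, subset, take_prob, size, task_type)."""
    weights = np.array([take_prob * size for _, _, take_prob, size, _ in files], dtype=float)
    return weights / weights.sum()

def interleave(streams, probabilities, n, seed=42):
    """Draw up to n records, picking the source of each record at random by `probabilities`."""
    rng = np.random.default_rng(seed)
    iterators = [iter(s) for s in streams]
    probs = np.array(probabilities, dtype=float)  # copy, because exhausted sources are zeroed below
    out = []
    while len(out) < n and probs.sum() > 0:
        k = rng.choice(len(iterators), p=probs / probs.sum())
        try:
            out.append(next(iterators[k]))
        except StopIteration:
            probs[k] = 0.0  # source exhausted: keep sampling from the remaining ones
    return out
```

Size-proportional weights mean large sources such as SPECTER and scirepeval dominate each chunk; lowering their `take_prob` is the natural knob for rebalancing.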


In [None]:
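# print one example per task type (printing all 2000 records exceeds the notebook's IOPub data-rate limit)
seen_types = set()
for e in sts_statics_datsets['train']:
    if e.get('type') not in seen_types:
        seen_types.add(e.get('type'))
        print(e)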

{'query': 'A ferry with 650 people aboard crashed into a dock during gale force winds earlier on Tuesday.', 'negatives': [], 'positives': ['All of the passengers had to be transferred from the ferry in Holyhead, Anglesey after the incident.\nThe Irish Ferries Jonathan Swift vessel was preparing to set off from Holyhead to Dublin before midday when heavy gusts pushed it into its berth.\nNo passengers were injured during the incident. Irish Ferries cancelled three Swift services following the incident.\nA spokesman for Irish Ferries said the aluminium hull ferry was being inspected to assess damage to the body and a replacement services was taking passengers to Dublin.\nHe said: "Just as she was leaving the berth at Holyhead she was caught by a gust of wind and blown back in. She was only yards off the berth and the ropes had been loosened.\n"We don\'t know if the hull has been punctured. We don\'t think that\'s likely, but if it has then I don\'t know if they can do the repair in Holyhe



{'query': '7.9.2005 EN Official Journal of the European Union L 230/7\nCOMMISSION REGULATION (EC) No 1450/2005\nof 5 September 2005\namending Annex V to Council Regulation (EC) No 1210/2003 concerning restrictions on economic and financial relations with Iraq\nTHE COMMISSION OF THE EUROPEAN COMMUNITIES\n,\nHaving regard to the Treaty establishing the European Community,\nHaving regard to Council Regulation (EC) No 1210/2003 of 7 July 2003 concerning certain specific restrictions on economic and financial relations with Iraq\xa0(1), and in particular Article 11(c) thereof,\nWhereas:\n(1) Annex V to Regulation (EC) No 1210/2003 lists the competent authorities to which specific functions related to the implementation of that Regulation are attributed.\n(2) Belgium, Germany, Lithuania and the Netherlands have requested that the address details concerning their competent authorities be amended.\n(3) Annex V to Regulation (EC) No 1210/2003 should therefore be amended accordingly,\nAnnex V to

In [None]:
import re
from multiprocessing import Pool

import nltk
import numpy as np
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer, WordNetLemmatizer
from nltk.tokenize import word_tokenize

# Download the stopword list, WordNet (for lemmatization) and the punkt tokenizer models
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('punkt')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True
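`LabelProcesserLedgar.preprocess_label` uses these NLTK resources to lowercase, tokenize, drop stopwords, then stem and lemmatize each label, so that near-identical labels such as "The Borrowers’ obligation" and "The Borrower's obligations" map to the same key. As a dependency-free illustration of that idea, here is a sketch that swaps in a tiny stopword list and a crude trailing-"s" stripper for NLTK's stopwords, Porter stemmer and WordNet lemmatizer; it approximates the idea and will not reproduce NLTK's output token for token.

```python
import re

# Tiny stand-in stopword list (NLTK's English list is much longer)
STOPWORDS = {"the", "a", "an", "of", "and", "or", "to", "in", "on", "for", "by", "with"}

def crude_stem(token):
    # strip a trailing plural "s" (keep "ss" endings and short words)
    if len(token) > 3 and token.endswith("s") and not token.endswith("ss"):
        return token[:-1]
    return token

def normalize_label(label):
    """Map a label (str or list of str) to a sorted list of normalized tokens."""
    if isinstance(label, list):
        return sorted(tok for part in label for tok in normalize_label(part))
    text = label.lower().replace("\u2019", "'")  # curly apostrophe -> straight
    text = re.sub(r"'s\b|'", " ", text)         # drop possessives and stray apostrophes
    tokens = re.findall(r"[a-z0-9]+", text)
    return sorted(crude_stem(t) for t in tokens if t not in STOPWORDS)

print(normalize_label("The Borrowers’ obligation"), normalize_label("The Borrower's obligations"))
```

Sorting the tokens makes the key order-free, which is what lets the processers use `tuple(...)` of the result as a dictionary key.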

In [None]:

class LabelProcesser:

    def __init__(
        self,
        pos_thres = 0.97,
        neg_thres = 0.9,
        min_similarity_matrix_pos = 0.34,
        max_similarity_matrix_neg = 0.30,
        examples=None, seed=42, textname='text',labelname='label'
    ):
        self.pos_thres = pos_thres # max jaccard similarity between the query text and a positive text
        self.neg_thres = neg_thres # max jaccard similarity between the query text and a negative text
        self.min_similarity_matrix = min_similarity_matrix_pos # min label similarity for a positive's label
        self.max_similarity_matrix = max_similarity_matrix_neg # max label similarity for a negative's label
        #self.lemmatizer = WordNetLemmatizer()
        #self.stemmer = PorterStemmer()
        #self.stop_words = set(stopwords.words('english'))
        self.random = np.random.RandomState(seed) # used by find_negative to sample labels
        self.label_corpus =None
        self.label2stem =None
        self.textname=textname
        self.labelname=labelname

        if examples is not None and len(examples)>0:

            # build corpus from examples
            label_corpus, label2stem = self.build_corpus_by_labels(examples)
            self.label_corpus = label_corpus
            self.label2stem = label2stem

            # build label-similarity matrix
            self.SimMat = self.compute_similarity_matrix(list(self.label_corpus.keys()))

    def preprocess_label(self, text):
        pass

    @staticmethod
    def jaccard_similarity(tokens1, tokens2):
        set1 = set(tokens1)
        set2 = set(tokens2)
        union = set1.union(set2)
        if not union:
            return 0.0 # both token lists empty: treat as no overlap instead of dividing by zero
        intersection = set1.intersection(set2)
        return len(intersection) / len(union)

    def build_corpus_by_labels(self, list_of_dict_with_labels_and_text):
        """Makes a dictionary of (tokenized/stemmed) labels:List[str] as the corpus by labels"""
        pass

    def _compute_similarity_for_processor_func(self, pair):
        """to be used internally with Pool map similarity functions"""
        idx, j, tokens1, tokens2 = pair
        return idx, j, self.jaccard_similarity(tokens1, tokens2)

    def compute_similarity_matrix(self, corpus):
        """Compute the pairwise Jaccard similarity matrix between label keys, in parallel"""
        corpus_size = len(corpus)

        # Create an empty similarity matrix
        similarity_matrix = np.zeros((corpus_size, corpus_size))

        # Generate all pairwise combinations of indices and texts
        pairs = [(i, j, corpus[i], corpus[j]) for i in range(corpus_size) for j in range(i + 1, corpus_size)]

        # Use parallel processing to compute similarities efficiently
        with Pool() as pool:
            results = pool.map(self._compute_similarity_for_processor_func, pairs)

        # Fill in the similarity matrix
        for i,j, similarity in results:
            #i, j = divmod(idx, corpus_size)
            similarity_matrix[i, j] = similarity
            similarity_matrix[j, i] = similarity

        # do not threshold the similarity matrix here: that would create positives among the negatives
        return similarity_matrix

    @staticmethod
    def is_in(tuple1, tuple2):
        """True if the elements of tuple1 are a subset of tuple2's, or vice versa"""
        s1=set(tuple1); s2 = set(tuple2)
        if not bool(s1.difference(s2)):
            return True
        return not bool(s2.difference(s1))

    @staticmethod
    def _quick_text_hash(text):
        return re.sub(r"\W+","",text.lower())

    def find_positive(
        self,
        query_text, # text of anchor/query (used to ensure not too similar, like an exact match)
        query_labelstem, # processed label (often a multi-label)
        corpus_keys, # corpus keys of other labels to find matches
        max_candidates=15
    ):
        """find positive match, based on best overlap of multi-label"""
        # first, check if there are other text with same label
        query_label_hash = self._quick_text_hash(query_text)

        # get all text with same label
        best_candidates_text = [
            s for s in self.label_corpus[query_labelstem] if self._quick_text_hash(s)!=query_label_hash
        ]
        if len(best_candidates_text)==0:
            # no other text with this exact label: look for texts with overlapping labels
            kidx = corpus_keys.index(query_labelstem)
            # get similarities with other keys
            k_similarities = self.SimMat[kidx]
            if k_similarities.max()==0:
                #print("%s has no matches:" % '-'.join(query_labelstem))
                return []
            else:
                idx_bests = np.argsort(-1*k_similarities)[:max_candidates]
                # get most similar labels
                label_candidates = [
                    corpus_keys[j] for j in idx_bests if k_similarities[j]>= self.min_similarity_matrix
                ]
                # assert that the labels are AT LEAST inside of each other -- otherwise, no match
                label_candidates = [
                    lab for lab in label_candidates if self.is_in(lab, query_labelstem)
                ]
                if len(label_candidates)==0:
                    #print("%s has no matches:" % '-'.join(query_labelstem))
                    return []

                # get the texts of the top candidate labels (at most 100)
                best_candidates_text = [subs for s in [
                    self.label_corpus[lab] for lab in label_candidates
                ] for subs in s][:100]

                # ensure candidate texts are not the same as the query text
                best_candidates_text = [
                    s for s in best_candidates_text if self._quick_text_hash(s)!=query_label_hash
                ]
                if len(best_candidates_text)==0:
                    #print("%s has no matches:" % '-'.join(query_labelstem))
                    return []

        # grab the first candidate text that is NOT too similar (jaccard) to the query text
        best_candidates_text = best_candidates_text[::-1]
        top_match = None
        query_text_tokenized = [w for w in query_text.split(" ") if bool(re.search(r"\w+",w))]
        while top_match is None and len(best_candidates_text)>0:
            candidate_text = best_candidates_text.pop()
            # check that they aren't too similar in text
            candidate_text_tokenized = [w for w in candidate_text.split(" ") if bool(re.search(r"\w+",w))]
            candidate_sim_score = self.jaccard_similarity(query_text_tokenized, candidate_text_tokenized)
            if candidate_sim_score < self.pos_thres:
                top_match = candidate_text
                return [top_match]
        #print("%s has no matches:" % '-'.join(query_labelstem))
        #print('Its candidate pool was:')
        #print(best_candidates_text[:4])
        return []

    def find_positives(self, examples):
        if True:
            # find positives
            for idx, example in enumerate(examples):
                pos = self.find_positive(
                    query_text=example[self.textname],
                    query_labelstem=self.label2stem[tuple(example[self.labelname])],
                    corpus_keys = list(self.label_corpus.keys()),
                )
                example.update({'positives':pos})
                examples[idx] = example

        return examples

    def find_negative(self, query_text, query_labelstem, corpus_keys, max_candidates=15, n_negatives=1):
        # first, check if there are other text with same label
        query_label_hash = self._quick_text_hash(query_text)
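        # get similarities with other keys
        if len(corpus_keys) < 2:
            return [] # no other label to draw a negative from
        kidx = corpus_keys.index(query_labelstem)
        k_similarities = self.SimMat[kidx]
        if k_similarities.max()==0:
            # no label overlaps: sample a different label at random, by index, because
            # np.random.choice can fail or warn on a list of variable-length tuples
            best_candidate_label = query_labelstem
            while best_candidate_label == query_labelstem:
                best_candidate_label = corpus_keys[self.random.randint(len(corpus_keys))]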
        else:
            idx_bests = np.argsort(-1*k_similarities)[:max_candidates]
            # get most similar labels
            label_candidates = [
                corpus_keys[j] for j in idx_bests if (k_similarities[j]!=0 and k_similarities[j] <= self.max_similarity_matrix)
            ]
            # assert that the labels have some disjoint labels
            label_candidates = [
                lab for lab in label_candidates if not self.is_in(lab, query_labelstem)
            ] # disjoint entirely
            # sample randomly from candidate labels
            if len(label_candidates)>0:
                best_candidate_label_idx = self.random.choice(np.arange(len(label_candidates)))
                best_candidate_label = label_candidates[best_candidate_label_idx]
            # sample randomly from entire corpus
            elif len(label_candidates)==0:
                # pick random
                best_candidate_label = query_labelstem
                while best_candidate_label == query_labelstem:
                    best_candidate_label_idx = self.random.choice(np.arange(len(corpus_keys)))
                    best_candidate_label = corpus_keys[best_candidate_label_idx]

        # grab best text
        best_candidates_text = self.label_corpus[best_candidate_label]
        if len(best_candidates_text)==0:
            return []

        # ensure texts and query are not the same
        best_candidates_text = [
            s for s in best_candidates_text if self._quick_text_hash(s)!=query_label_hash
        ]
        if len(best_candidates_text)==0:
            return []

        # ensure texts are not very similar
        top_matches = []
        query_text_tokenized = [w for w in query_text.split(" ") if bool(re.search(r"\w+",w))]
        while len(top_matches) < n_negatives and len(best_candidates_text)>0:
            candidate_text = best_candidates_text.pop()
            # check that they aren't too similar in text
            candidate_text_tokenized = [w for w in candidate_text.split(" ") if bool(re.search(r"\w+",w))]
            candidate_sim_score = self.jaccard_similarity(query_text_tokenized, candidate_text_tokenized)
            if candidate_sim_score < self.neg_thres:
                top_matches.append(candidate_text)
                if len(top_matches)==n_negatives:
                    return top_matches
        # fewer than n_negatives dissimilar texts found: return what was found (possibly empty)
        return top_matches

    def find_negatives(self, examples, n_negatives=1):
        if True:
            # find negatives
            for idx, example in enumerate(examples):
                neg = self.find_negative(
                    query_text=example[self.textname],
                    query_labelstem=self.label2stem[tuple(example[self.labelname])],
                    corpus_keys = list(self.label_corpus.keys()),
                    n_negatives=n_negatives
                )
                example.update({'negatives':neg})
                examples[idx] = example

        return examples


class LabelProcesserLedgar(LabelProcesser):
    """Preprocesses labels of LEDGAR for semantic similarity, as well as functionality for finding positive and negative pairs"""

    def __init__(self, pos_thres = 0.97, neg_thres = 0.9, min_similarity_matrix_pos =0.33, max_similarity_matrix_neg=0.3, examples=None, seed=42, textname='text',labelname='label'):
        self.pos_thres = pos_thres # jaccard similarity index max
        self.neg_thres = neg_thres # jaccard similarity index max
        self.min_similarity_matrix = min_similarity_matrix_pos # threshold the similarity matrix by this, else 0
        self.max_similarity_matrix = max_similarity_matrix_neg # threshold the similarity matrix by this, else 0
        self.lemmatizer = WordNetLemmatizer()
        self.stemmer = PorterStemmer()
        self.stop_words = set(stopwords.words('english'))
        self.random = np.random.RandomState(seed)
        self.label_corpus =None
        self.label2stem =None
        self.textname=textname
        self.labelname=labelname
        #print(self.preprocess_label("The Borrowers’ obligation"))
        #print(self.preprocess_label("The Borrower's obligations"))

        if examples is not None and len(examples)>0:

            # build corpus from examples
            label_corpus, label2stem = self.build_corpus_by_labels(examples)
            self.label_corpus = label_corpus
            self.label2stem = label2stem

            # build label-similarity matrix
            self.SimMat = self.compute_similarity_matrix(list(self.label_corpus.keys()))

    def preprocess_label(self, text):
        if isinstance(text,str):
            tokens = word_tokenize(text.lower())
            # Remove stop words
            filtered_tokens = [token for token in tokens if token not in self.stop_words]
            # Perform lemmatization and stemming
            processed_tokens = [self.lemmatizer.lemmatize(self.stemmer.stem(token)) for token in filtered_tokens]
            processed_tokens = [w for w in processed_tokens if w not in ["'", "’", "’s", "'s", "(",")", ",", "."]]
            # Return the sorted list of stemmed/lemmatized, stopword-free tokens
            return sorted(processed_tokens)

        elif isinstance(text,list):
            if len(text)==1:
                return self.preprocess_label(text[0])
            all_labels = [self.preprocess_label(l) for l in text]
            return sorted([subl for l in all_labels for subl in l])
        else:
            raise NotImplementedError(text)

    def build_corpus_by_labels(self, list_of_dict_with_labels_and_text):
        """Makes a dictionary of (tokenized/stemmed) labels:List[str] as the corpus by labels"""
        label_corpus = {}
        label2lem = {}
        for example in list_of_dict_with_labels_and_text:
            label = example[self.labelname]
            s = example[self.textname]
            if tuple(label) not in label2lem:
                labelstemmed = tuple(self.preprocess_label(label))
                label2lem[tuple(label)] = labelstemmed
            else:
                labelstemmed = label2lem[tuple(label)]
            if labelstemmed not in label_corpus.keys():
                label_corpus[labelstemmed] = []
            if s not in label_corpus[labelstemmed]:
                label_corpus[labelstemmed].append(s)

        # next, calculate the similarities between all pairs of keys
        return label_corpus, label2lem


class DatasetTripletsSimilarityByCoLabel(DatasetTriplets):

    def process(self, list_of_data):
        """Makes (query,pos,neg)-triplets, converts samples to dataframe for pytorch iteration"""

        # initialize the LabelProcessor
        label_processor = self.label_processor_class(
            examples = list_of_data,
            textname = self.focal_text_name
        )

        # find positives
        list_of_data = label_processor.find_positives(list_of_data)

        # only do ones with positives (otherwise no point)
        #list_of_data = [example for example in list_of_data if len(example['positives'])>0]
        #print(len(list_of_data))

        # find negatives
        list_of_data = label_processor.find_negatives(list_of_data, n_negatives=self.n_negatives)
        print(len(list_of_data))

        # loop through the data and add each triplets
        self._loop_through_list_of_data_and_add_to_selfdata(list_of_data = list_of_data)

        # harden the dataset to pandas dataframe
        df = self.sample_data_and_make_static_dataframe(self.data)
        return df #pd.DataFrame({})

    def _build_corpus_of_potential_negatives(self):
        pass

    def _find_negative(self):
        pass

    def _find_positives_and_add_to_data(self):
        """For data that has a label, this can be used to artificially find and create synthetic positives"""
        pass

    def _find_negatives_and_add_to_data(self):
        pass
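`compute_similarity_matrix` fills a symmetric Jaccard matrix over the label keys with `multiprocessing.Pool`. For a few thousand keys a plain sequential loop over `itertools.combinations` may be fast enough, and it avoids pickling the bound method (and with it the processer's text corpus) for the worker processes. A minimal sketch with hypothetical helper names, which returns 0.0 for two empty label tuples instead of dividing by zero and, like the original, leaves the diagonal at 0 so a label never matches itself:

```python
from itertools import combinations

import numpy as np

def jaccard(tokens1, tokens2):
    set1, set2 = set(tokens1), set(tokens2)
    union = set1 | set2
    return len(set1 & set2) / len(union) if union else 0.0  # two empty labels -> 0.0

def similarity_matrix(label_keys):
    """Symmetric Jaccard matrix over label keys; the diagonal stays 0."""
    sets = [set(k) for k in label_keys]
    n = len(sets)
    sim = np.zeros((n, n))
    for i, j in combinations(range(n), 2):
        sim[i, j] = sim[j, i] = jaccard(sets[i], sets[j])
    return sim

print(similarity_matrix([("borrower", "obligation"), ("obligation",), ("notice",)]))
```

The loop does n(n-1)/2 set operations, so it grows quadratically; past tens of thousands of distinct label keys the parallel version (or a sparse, inverted-index approach) becomes worthwhile.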


In [None]:
sts_statics_datsets['train'][0]

{'query': "COMMISSION DECISION of 10 February 1999 amending the Decision on the Liaison Group on the Elderly (notified under document number C(1999) 211) (1999/141/EC)\nTHE COMMISSION OF THE EUROPEAN COMMUNITIES\n,\nHaving regard to the Treaty establishing the European Community,\nWhereas, in the light of developments at Community level, it is necessary to adjust the membership of the Group set up by Commission Decision 91/544/EEC (1), as amended by Decision 93/417/EEC (2); whereas at the same time, in the interests of administrative efficiency, the terms of office of the Chairman and of the Members of the Group should be reduced,\nDecision 91/544/EEC is amended as follows:\n1. in Article 3(2), '25 members` is replaced by '24 members`;\n2. Article 4(3) is amended as follows:\n(a) in each case, 'five seats` is replaced by 'four seats`;\n(b) the following indent is added:\n'- ESCU-European Senior Citizens Union: four seats`;\n3. in Article 5(1), '18 months` is replaced by '12 months`;\n4

In [None]:
class LabelProcesserEurlex(LabelProcesser):
    """Preprocesses labels of EURLEX for semantic similarity, as well as functionality for finding positive and negative pairs"""

    def __init__(self, pos_thres = 0.97, neg_thres = 0.9, min_similarity_matrix_pos =0.33, max_similarity_matrix_neg =0.30,  examples=None, seed=42, textname='text',labelname='label'):
        self.pos_thres = pos_thres # jaccard similarity index max
        self.neg_thres = neg_thres # jaccard similarity index max
        self.min_similarity_matrix = min_similarity_matrix_pos # threshold the similarity matrix by this, else 0
        self.max_similarity_matrix = max_similarity_matrix_neg # threshold the similarity matrix by this, else 0
        self.random = np.random.RandomState(seed)
        self.label_corpus =None
        self.label2stem =None
        self.textname=textname
        self.labelname=labelname
        #print(self.preprocess_label("The Borrowers’ obligation"))
        #print(self.preprocess_label("The Borrower's obligations"))

        if examples is not None and len(examples)>0:

            # build corpus from examples
            label_corpus, label2stem = self.build_corpus_by_labels(examples)
            self.label_corpus = label_corpus
            self.label2stem = label2stem

            # build label-similarity matrix
            self.SimMat = self.compute_similarity_matrix(list(self.label_corpus.keys()))

    def preprocess_label(self, text):
        # eurlex labels are already "tokenized": each label is a list of EuroVoc concept IDs
        if isinstance(text,str):
            return text
        elif isinstance(text,list):
            if len(text)==1:
                return text
            return sorted(list(set(text)))
        else:
            raise NotImplementedError(text)

    def build_corpus_by_labels(self, list_of_dict_with_labels_and_text):
        """Makes a dictionary of (tokenized/stemmed) labels:List[str] as the corpus by labels"""
        label_corpus = {}
        label2lem = {}
        for example in list_of_dict_with_labels_and_text:
            label = example[self.labelname]
            s = example[self.textname]
            if tuple(label) not in label2lem:
                labelstemmed = tuple(self.preprocess_label(label))
                label2lem[tuple(label)] = labelstemmed
            else:
                labelstemmed = label2lem[tuple(label)]
            if labelstemmed not in label_corpus.keys():
                label_corpus[labelstemmed] = []
            if s not in label_corpus[labelstemmed]:
                label_corpus[labelstemmed].append(s)

        # next, calculate the similarities between all pairs of keys
        return label_corpus, label2lem
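`find_positive` and `find_negative` mine pairs from labels: a positive shares the query's label set (or a nested, sufficiently similar one), and a negative comes from a partially overlapping but non-nested label set. The sketch below strips this down to the two simplest cases for illustration: positives must have exactly the same label set and negatives must share no label at all, so it yields easier negatives than the classes above. `mine_triplets` is a hypothetical helper and expects each label to be a list of strings.

```python
import random

def label_key(label):
    # normalize a multi-label (list of strings) into a hashable, order-free key
    return tuple(sorted(set(label)))

def mine_triplets(examples, text_key="query", label_field="label", n_negatives=1, seed=42):
    """Return copies of `examples` with 'positives' (same label set) and 'negatives' (disjoint labels)."""
    rng = random.Random(seed)
    buckets = {}
    for ex in examples:
        buckets.setdefault(label_key(ex[label_field]), []).append(ex[text_key])
    out = []
    for ex in examples:
        key, text = label_key(ex[label_field]), ex[text_key]
        # positive: another text with exactly the same label set
        positives = [t for t in buckets[key] if t != text][:1]
        # negatives: texts whose labels share nothing with the query's labels
        pool = [t for k, texts in buckets.items() if not set(k) & set(key) for t in texts]
        negatives = rng.sample(pool, min(n_negatives, len(pool)))
        out.append({**ex, "positives": positives, "negatives": negatives})
    return out
```

Fully disjoint negatives are cheap to find but often trivially easy; the "overlapping but not nested" rule in `find_negative` is a way to get harder negatives from the same labels.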

In [None]:
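# the label processer reads example['label'], so pass only the label-based examples
sts_train_labelled = [
    example for example in sts_statics_datsets['train'] if example['type']=='sts_by_textlabel'
]

label_processer_eurlex = LabelProcesserEurlex(
    pos_thres = 0.97,
    neg_thres = 0.9,
    min_similarity_matrix_pos =0.33,
    examples=sts_train_labelled,
    seed=42,
    textname='query',
    labelname='label'
)

In [None]:
# find_positives/find_negatives update the example dicts in place, so the
# corresponding records in sts_statics_datsets['train'] are updated as well
sts_train_labelled = label_processer_eurlex.find_positives(sts_train_labelled)

sts_train_labelled = label_processer_eurlex.find_negatives(sts_train_labelled, n_negatives=3)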

In [None]:
foo = [e for e in sts_statics_datsets['train'] if bool(e['positives'])]
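The cell above keeps only the examples that received a positive. A small, hypothetical helper that reports how many examples ended up with positives, negatives, or both can make the effect of the thresholds (`pos_thres`, `min_similarity_matrix_pos`, and so on) easier to track between runs:

```python
def triplet_coverage(examples):
    """Count how many examples have positives, negatives, or both."""
    has_pos = [bool(e.get("positives")) for e in examples]
    has_neg = [bool(e.get("negatives")) for e in examples]
    return {
        "n": len(examples),
        "with_positives": sum(has_pos),
        "with_negatives": sum(has_neg),
        "complete_triplets": sum(p and n for p, n in zip(has_pos, has_neg)),
    }
```

For example, `triplet_coverage(sts_statics_datsets['train'])` after the mining cells.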

In [None]:
sts_torchdataset_train_eurlex = DatasetTripletsSimilarityByCoLabel(
    list_of_data=[
        example for example in sts_statics_datsets['train'] if example['type']=='sts_by_textlabel'
    ],
    n_negatives= 3,
    focal_text_name ='query',
    positives_text_name ='positives',
    negativess_text_name ='negatives',
    seed = 42,
    label_processor_class = LabelProcesserEurlex
)

2000


In [None]:
sts_torchdataset_train_ledgar = DatasetTripletsSimilarityByCoLabel(
    list_of_data=[
        example for example in sts_statics_datsets['train'] if example['type']=='sts_by_textlabel'
    ],
    n_negatives= 3,
    focal_text_name ='query',
    positives_text_name ='positives',
    negativess_text_name ='negatives',
    seed = 42,
    label_processor_class = LabelProcesserLedgar
)


1250


In [None]:
sts_torchdataset_train_eurlex[270]

{'query': '15.6.2007 EN Official Journal of the European Union L 155/31\nCOMMISSION REGULATION (EC) No 662/2007\nof 14 June 2007\nfixing the export refunds on white and raw sugar exported without further processing\nTHE COMMISSION OF THE EUROPEAN COMMUNITIES\n,\nHaving regard to the Treaty establishing the European Community,\nHaving regard to Council Regulation (EC) No 318/2006 of 20\xa0February 2006 on the common organisation of the market in the sugar sector\xa0(1), and in particular the second subparagraph of Article 33(2) thereof,\nWhereas:\n(1) Article 32 of Regulation (EC) No 318/2006 provides that the difference between prices on the world market for the products listed in Article 1(1)(b) of that Regulation and prices for those products on the Community market may be covered by an export refund.\n(2) Given the present situation on the sugar market, export refunds should therefore be fixed in accordance with the rules and certain criteria provided for in Articles 32 and 33 of Re

In [None]:
for example in sts_statics_datsets['train']:
    if example['type']=='sts_by_textlabel':
        assert 'label' in example.keys()


In [None]:
labelprocessor = LabelProcesserLedgar(examples = [
  example for example in sts_statics_datsets['train'] if example['type']=='sts_by_textlabel'
])

foopos = labelprocessor.find_positives([
  example for example in sts_statics_datsets['train'] if example['type']=='sts_by_textlabel'
])

print(sum([bool(d['positives']) for d in foopos])/len(foopos))

fooneg = labelprocessor.find_negatives([
  example for example in sts_statics_datsets['train'] if example['type']=='sts_by_textlabel'
])

print(sum([bool(d['negatives']) for d in fooneg])/len(fooneg))

['borrow', 'oblig']
['borrow', 'oblig']
0.4376



1.0


In [None]:

# convert to torch dataset (val)
sts_torchdataset_val = DatasetTriplets(
    list_of_data = [
       x for x in sts_statics_datsets['val'] if x.get('type','na') == 'sts_triplet'
    ],
    n_negatives= 3,
    focal_text_name ='query',
    positives_text_name ='positives',
    negativess_text_name ='negatives',
)
# convert to torch dataset (train)
print('STS DatasetTriplet')
sts_torchdataset_train = DatasetTriplets(
    list_of_data = [
       x for x in sts_statics_datsets['train'] if x.get('type','na')== 'sts_triplet'
    ],
    n_negatives= 3,
    focal_text_name ='query',
    positives_text_name ='positives',
    negativess_text_name ='negatives',
)

done finding negatives
STS DatasetTriplet
done finding negatives
{'positives': [], 'negatives': ['We present a novel method for approximately equilibrating a matrix using only multiplication by the matrix and its transpose. Our method is based on convex optimization and projected stochastic gradient descent, using an unbiased estimate of a gradient obtained by a randomized method. Our method provably converges in expectation and empirically gets good results with a small number of iterations. We show how the method can be applied as a preconditioner for matrix-free iterative algorithms, substantially reducing the iterations required to reach a given level of precision. We also derive a novel connection between equilibration and condition number, showing that equilibrationminimizes an upper bound on the condition number over all choices of row and column scalings.']}
this is missing a positive example
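A record like the one flagged above, with an empty `positives` list, cannot yield a (query, pos, neg) triplet. Below is a minimal sketch of per-record triplet sampling that skips such records. `sample_triplet` is a hypothetical helper and is not the `DatasetTriplets.__getitem__` implementation. The output keys `query` / `pos` / `neg` mirror the item printed below.

```python
import random
from typing import List, Optional

def sample_triplet(record: dict, rng: random.Random, n_negatives: int = 1) -> Optional[dict]:
    """Draw one (query, pos, neg) triplet from a record; return None if it is unusable."""
    positives: List[str] = record.get("positives") or []
    negatives: List[str] = record.get("negatives") or []
    if not positives or not negatives:
        return None  # e.g. a record with 'positives': [] cannot form a triplet
    triplet = {"query": record["query"], "pos": rng.choice(positives)}
    chosen = rng.sample(negatives, min(n_negatives, len(negatives)))
    # a single negative is returned as a string, several as a list
    triplet["neg"] = chosen[0] if n_negatives == 1 else chosen
    return triplet
```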


In [None]:
sts_torchdataset_train[95]

{'query': 'After two summers , Scout and Jem find small presents in a tree outside the Radley place .',
 'pos': 'Following two summers of friendship with Dill , Scout and Jem find that someone is leaving them small gifts in a tree outside the Radley place .',
 'neg': 'It is of critical relevance that designers are able to comprehend the various kinds of design-level modifications that a system undergoes throughout its entire lifecycle. In this respect, an interesting and useful operation between subsequent system versions is the model difference calculation and representation. In this paper, a metamodel independent approach to the representation of model differences which is agnostic of the calculation method is presented. Given two models which conform to a metamodel, their difference is conforming to another metamodel derived from the former by an automated transformation. Difference models are first-class entities which induce transformations able to apply the modifications they spe

### Classification Datasets

- SNLI - no, that will be its own task
- ag_news classification - 4 labels (World, Sports, Business, Sci/Tech) https://huggingface.co/datasets/ag_news/viewer/default/train?row=100039
- dbpedia_14 - topic classification (not news) over 14 DBpedia ontology classes, e.g. artist, building https://huggingface.co/datasets/dbpedia_14
- sentiment analysis -- ?
- ccdv/patent-classification - 25k (abstract)
- fkdosilovic/docee-event-classification (21.9k) - 59 event types (fire, disaster)
- scholarly360/contracts-classification-instruction-llm-experiments - 6.05k (clauses) -- no, I think these are just the auto-labels from LEDGAR
- 'rcds/swiss_judgment_prediction','mt_en', (59703 examples) (NO, it is auto-translated)
- **'tum-nlp/cannot-dataset'** - like entailment, but contains paraphrases & negations
- samchain/BIS_Speeches_97_23 - next sentence prediction
- I could synthetically make another next-sentence-prediction using wikipedia?

I could combine all into a multilabel exercise
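Here is one way the multilabel idea could work, sketched under assumptions. Each dataset's labels are namespaced by dataset name (so an `ag_news` label can never collide with a `docee` label) and then indexed into one joint multi-hot space. `build_joint_label_space` and the toy rows are hypothetical.

```python
from typing import Dict, List, Tuple

def build_joint_label_space(datasets: Dict[str, List[Tuple[str, List[str]]]]):
    """Namespace each dataset's labels ('ag_news:Sports') and index them jointly.

    `datasets` maps a dataset name to (text, [labels]) pairs; returns the
    label -> index map and (text, multi-hot vector) examples.
    """
    label2idx: Dict[str, int] = {}
    for name, rows in datasets.items():
        for _, labels in rows:
            for lab in labels:
                label2idx.setdefault(f"{name}:{lab}", len(label2idx))
    examples = []
    for name, rows in datasets.items():
        for text, labels in rows:
            vec = [0] * len(label2idx)
            for lab in labels:
                vec[label2idx[f"{name}:{lab}"]] = 1
            examples.append((text, vec))
    return label2idx, examples

toy = {"ag_news": [("t1", ["Sports"]), ("t2", ["World"])], "docee": [("t3", ["fire"])]}
label2idx, joint_examples = build_joint_label_space(toy)
```

A design caveat: a zero in another dataset's columns means "unknown", not "negative". A real multilabel loss would probably need a per-source mask so that, for example, an ag_news text is not penalized on docee event labels.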

In [None]:

## Need a function that:
# ... takes the first 5000 entries as the dev set
# ... then skips 5000 to set the starting position for the train set
# ... then randomly picks another start position to cycle through all the data
# ... then what? Hardens (materializes) it and splits it into 512-word chunks? Filters out short segments (<200 words)?




def nwords(sentence):
    return len([w for w in sentence.split(' ') if len(w)>0])

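from datasets import load_dataset, interleave_datasets

def process_streaming_mlm_data(data_config):
    """Interleaves several streaming JSON datasets (e.g. Pile components, LEDGAR) into one
    weighted stream of {'text': ...} records; the dev/train split happens downstream"""
    if data_config['dataset_probabilities'] is not None:
        # rescale the raw weights so they sum to 1
        total_weight = sum(data_config['dataset_probabilities'])
        dataset_probabilities = [a/total_weight for a in data_config['dataset_probabilities']]
    else:
        dataset_probabilities = [1.0/len(data_config['files']) for _ in range(len(data_config['files']))]

    # load each file as a stream and reduce it to a single 'text' column
    datasets_to_stream = []
    for file_to_stream in data_config['files']:
        dataset_to_stream = load_dataset("json", data_files=file_to_stream, split="train", streaming=True)
        if 'LEDGAR' in file_to_stream:
            # LEDGAR rows have no 'meta' column ('provision', 'label', 'source' instead);
            # assumes the LEDGAR file can be recognized by its URL
            dataset_to_stream = dataset_to_stream.remove_columns(["label", "source"]).rename_column("provision", "text")
        else:
            dataset_to_stream = dataset_to_stream.remove_columns("meta")
        datasets_to_stream.append(dataset_to_stream)

    # combine the datasets to stream together
    datasets_combined = interleave_datasets(
        datasets_to_stream,
        stopping_strategy ='all_exhausted',
        probabilities = dataset_probabilities
    )
    return datasets_combined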

# streaming datasets
datasets_combined = process_streaming_mlm_data(data_streaming_config)


import torch
from torch.utils.data import Dataset

class MLMDataset(Dataset):
    """Word-chunked texts for MLM. Open question: pre-tokenize here? If so, the
    collator would only need to call .pad"""
    def __init__(self, input_text, tokenizer=None, max_seq_length=512, min_seq_length=200):
        self.texts = []
        self.tokenizer = tokenizer  # not used yet (see docstring)
        self.max_seq_length = max_seq_length
        self.min_seq_length = min_seq_length
        for text in input_text:
            # split on single spaces and drop empty strings (same word count as nwords)
            text_split = [w for w in text.split(" ") if len(w) > 0]
            word_count = len(text_split)
            if self.min_seq_length <= word_count <= self.max_seq_length:
                self.texts.append(text)
            elif word_count > self.max_seq_length:
                # consecutive, non-overlapping windows of max_seq_length words
                chunks = [
                    text_split[i:i+self.max_seq_length] for i in range(0, word_count, self.max_seq_length)
                ]
                chunks = [" ".join(s) for s in chunks if len(s)>=self.min_seq_length]
                self.texts.extend(chunks)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx]

KeyError: ignored
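The chunking rule inside `MLMDataset` can be isolated as a pure function, which makes it easy to test without any dataset. `chunk_by_words` is a hypothetical helper that follows the same rule: windows of at most `max_words` whitespace words, dropping any window shorter than `min_words`.

```python
from typing import List

def chunk_by_words(text: str, max_words: int = 512, min_words: int = 200) -> List[str]:
    """Split text into consecutive windows of at most `max_words` words and drop
    any window (including a short whole text) with fewer than `min_words` words."""
    words = [w for w in text.split(" ") if w]  # same word definition as nwords()
    chunks = [words[i:i + max_words] for i in range(0, len(words), max_words)]
    return [" ".join(c) for c in chunks if len(c) >= min_words]

long_text = " ".join(str(i) for i in range(1100))
chunks = chunk_by_words(long_text)  # windows of 512, 512 and 76 words; the last is dropped
```

These windows count whitespace words, not WordPiece tokens. A 512-word chunk will usually exceed 512 tokens once split into subwords and wrapped in `[CLS]` / `[SEP]`, so tokenization still needs `truncation=True` (or a smaller word budget).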

In [None]:

dataset1 = load_dataset("json", data_files=data_files[0], split="train", streaming=True)
print(next(iter(dataset1)))

dataset2 = load_dataset("json", data_files=data_files[1], split="train", streaming=True)
dataset3 = load_dataset("json", data_files=data_files[2], split="train",streaming=True)

# streaming datasets
streaming_datasets = [
    dataset1.remove_columns("meta"),
    dataset2.remove_columns("meta"),
    dataset3.remove_columns(["label","source"]).rename_column('provision','text') # ledgar
]

combined_dataset = interleave_datasets(streaming_datasets)
combined_dataset = combined_dataset.skip(10001)
next(iter(combined_dataset))

In [None]:
dataset4 = load_dataset("pile-of-law/pile-of-law",'euro_parl',split='train',streaming=True)
dataset4 = dataset4.skip(1000)

INFO:datasets.info:Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/pile-of-law--pile-of-law/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60


In [None]:
# can I load ledgar
dataset3 = load_dataset("json", data_files=data_files[2], split="train",streaming=True)


INFO:datasets.builder:Using custom data configuration default-de993bbf5aabe685
INFO:datasets.info:Loading Dataset Infos from /usr/local/lib/python3.10/dist-packages/datasets/packaged_modules/json


In [None]:
# these pile-of-law subsets cannot be streamed, but they can be loaded individually
all_pile_datasets = ['r_legaladvice', 'courtlistener_docket_entry_documents', 'atticus_contracts', 'courtlistener_opinions', 'federal_register',
           'bva_opinions', 'us_bills', 'cc_casebooks', 'tos', 'euro_parl', 'nlrb_decisions', 'scotus_oral_arguments', 'cfr', 'state_codes',
           'scotus_filings', 'exam_outlines', 'edgar', 'cfpb_creditcard_contracts', 'constitutions', 'congressional_hearings', 'oig',
           'olc_memos', 'uscode', 'founding_docs', 'ftc_advisory_opinions', 'echr', 'eurlex', 'tax_rulings', 'un_debates', 'fre', 'frcp',
           'canadian_decisions', 'eoir', 'dol_ecab', 'icj-pcij', 'uspto_office_actions', 'ed_policy_guidance', 'acus_reports', 'hhs_alj_opinions',
           'sec_administrative_proceedings', 'fmshrc_bluebooks', 'resource_contracts', 'medicaid_policy_guidance', 'irs_legal_advice_memos', 'doj_guidance_documents'
    ]

dataset3 = load_dataset('pile-of-law/pile-of-law',all_pile_datasets[0],split='train')

print(next(iter(dataset3)))

INFO:datasets.info:Loading Dataset Infos from /root/.cache/huggingface/modules/datasets_modules/datasets/pile-of-law--pile-of-law/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60
INFO:datasets.builder:Generating dataset pile-of-law (/root/.cache/huggingface/datasets/pile-of-law___pile-of-law/r_legaladvice/0.0.0/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60)
INFO:datasets.builder:Dataset not on Hf google storage. Downloading and preparing it from source


Downloading and preparing dataset pile-of-law/r_legaladvice to /root/.cache/huggingface/datasets/pile-of-law___pile-of-law/r_legaladvice/0.0.0/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60...


INFO:datasets.utils.file_utils:https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.r_legaldvice.jsonl.xz not found in cache or force_download set to True, downloading to /root/.cache/huggingface/datasets/downloads/3401aa15961b3081a5a04646851c71451f98bc46a642f049a73b5bf2e7ce9876.incomplete


INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/train.r_legaldvice.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/3401aa15961b3081a5a04646851c71451f98bc46a642f049a73b5bf2e7ce9876
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/3401aa15961b3081a5a04646851c71451f98bc46a642f049a73b5bf2e7ce9876
INFO:datasets.download.download_manager:Downloading took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min


INFO:datasets.utils.file_utils:https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/validation.r_legaldvice.jsonl.xz not found in cache or force_download set to True, downloading to /root/.cache/huggingface/datasets/downloads/a1ef937f954b208b1e34406796793ca2d775f7d96ade8fbb7fef66979430b6a8.incomplete


INFO:datasets.utils.file_utils:storing https://huggingface.co/datasets/pile-of-law/pile-of-law/resolve/main/data/validation.r_legaldvice.jsonl.xz in cache at /root/.cache/huggingface/datasets/downloads/a1ef937f954b208b1e34406796793ca2d775f7d96ade8fbb7fef66979430b6a8
INFO:datasets.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/a1ef937f954b208b1e34406796793ca2d775f7d96ade8fbb7fef66979430b6a8
INFO:datasets.download.download_manager:Downloading took 0.0 min
INFO:datasets.download.download_manager:Checksum Computation took 0.0 min
INFO:datasets.builder:Generating train split


INFO:datasets.builder:Generating validation split


INFO:datasets.utils.info_utils:Unable to verify splits sizes.


Error reading file: /root/.cache/huggingface/datasets/downloads/a1ef937f954b208b1e34406796793ca2d775f7d96ade8fbb7fef66979430b6a8
Dataset pile-of-law downloaded and prepared to /root/.cache/huggingface/datasets/pile-of-law___pile-of-law/r_legaladvice/0.0.0/c1090502f95031ebfad49ede680394da5532909fa46b7a0452be8cddecc9fa60. Subsequent calls will reuse this data.
{'text': 'Title: Landlord broke lease agreement, what are my rights? (Chicago, IL)\nQuestion:Our landlord has been promising us a washer/dryer unit since we moved in (July 2015). When we resigned the lease August 2016, we wrote into the lease that an in-unit washer and dryer would be installed by September 30th 2016.\n\nSince September 30th, there have been continuous delays in getting the W/D installed. Since it has now been almost a month past the date the W/D was supposed to be installed, I am wondering what types of rights as a tenant I have? \n\nThanks ahead of time for any and all advice given.\nAnswer #1: You can let your la

In [None]:
next(iter(streaming_datasets[1])) # works

{'text': '543 U.S. 1079\nBARNESv.UNITED STATES.\nNo. 04-7550.\nSupreme Court of United States.\nJanuary 10, 2005.\n\n1\nC. A. 8th Cir. Certiorari denied. Reported below: 374 F. 3d 601.\n\n'}

In [None]:
#dataset_head = pubmed_dataset_streamed.skip(10000) # skipping


In [None]:
combined_dataset = interleave_datasets(streaming_datasets)
combined_dataset = combined_dataset.skip(10001)
next(iter(combined_dataset))

{'text': '\n517 U.S. 706 (1996)\nQUACKENBUSH, CALIFORNIA INSURANCE COMMISSIONER\nv.\nALLSTATE INSURANCE CO.\nNo. 95-244.\nUnited States Supreme Court.\nArgued February 20, 1996.\nDecided June 3, 1996.\nCERTIORARI TO THE UNITED STATES COURT OF APPEALS FOR THE NINTH CIRCUIT\n*708 *708 O\'Connor, J., delivered the opinion for a unanimous Court. Scalia, J., post, p. 731, and Kennedy, J., post, p. 733, filed concurring opinions.\nKarl L. Rubinstein argued the cause for petitioner. With him on the briefs were Dana Carli Brooks, Melissa S. Kooistra, William W. Palmer, and David L. Shapiro. \nDonald Francis Donovan argued the cause for respondent. With him on the brief were Carl Micarelli, Joseph D. Lee, and James G. Sporleder.[*]\n*709 Justice O\'Connor, delivered the opinion of the Court.\nIn this case, we consider whether an abstention-based remand order is appealable as a final order under 28 U. S. C. § 1291, and whether the abstention doctrine first recognized in Burford v. Sun Oil Co., 3

In [None]:
data_files_streaming = [
    "https://the-eye.eu/public/AI/pile_preliminary_components/PUBMED_title_abstracts_2019_baseline.jsonl.zst",
    "https://the-eye.eu/public/AI/pile_preliminary_components/FreeLaw_Opinions.jsonl.zst",
    "https://drive.switch.ch/index.php/s/j9S0GRMAbGZKa1A/download?path=%2F&files=LEDGAR_2016-2019.jsonl.zip", # ledgar worked

]
dataset_probabilities = [
    14.40,
    6.12,
    6
]

data_streaming_config = {
    'files':data_files_streaming,
    'val_size':10000,
    'min_seq_length':200,
    'max_seq_length':512,
    'dataset_probabilities':dataset_probabilities
}

# streaming datasets
datasets_combined = process_streaming_mlm_data(data_streaming_config)


INFO:datasets.builder:Using custom data configuration default-6e3092816c4f845b
INFO:datasets.info:Loading Dataset Infos from /usr/local/lib/python3.10/dist-packages/datasets/packaged_modules/json
INFO:datasets.builder:Using custom data configuration default-a1d9e8eaedd958cd
INFO:datasets.info:Loading Dataset Infos from /usr/local/lib/python3.10/dist-packages/datasets/packaged_modules/json
INFO:datasets.builder:Using custom data configuration default-de993bbf5aabe685
INFO:datasets.info:Loading Dataset Infos from /usr/local/lib/python3.10/dist-packages/datasets/packaged_modules/json


KeyError: ignored
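`dataset_probabilities = [14.40, 6.12, 6]` are raw weights that `process_streaming_mlm_data` rescales to sum to 1 before passing them to `interleave_datasets(..., stopping_strategy='all_exhausted')`. The toy below is a simplified pure-Python stand-in for that behaviour and is not the `datasets` implementation. It draws a source by probability, restarts any source that runs out (oversampling), and stops once every source has been exhausted at least once. `normalize` and `interleave_all_exhausted` are hypothetical names, and sources are assumed non-empty.

```python
import random
from typing import Iterator, List, Optional, Sequence

def normalize(weights: Optional[Sequence[float]], n: int) -> List[float]:
    """Rescale raw weights (e.g. relative corpus sizes) to probabilities; uniform if None."""
    if weights is None:
        return [1.0 / n] * n
    total = sum(weights)
    return [w / total for w in weights]

def interleave_all_exhausted(sources: List[list], probs: List[float], seed: int = 42) -> Iterator:
    """Toy weighted interleave: draw a source by probability, restart sources that
    run out, and stop once every source has been exhausted at least once."""
    rng = random.Random(seed)
    positions = [0] * len(sources)
    exhausted = [False] * len(sources)
    while not all(exhausted):
        k = rng.choices(range(len(sources)), weights=probs)[0]
        if positions[k] == len(sources[k]):  # restart (oversample) a finished source
            positions[k] = 0
        yield sources[k][positions[k]]
        positions[k] += 1
        if positions[k] == len(sources[k]):
            exhausted[k] = True

probs = normalize([14.40, 6.12, 6], 3)  # roughly 0.54, 0.23, 0.23
```

With "all_exhausted", the smaller sources get repeated until the largest one finishes. For a Pile-sized stream, that end point is effectively never reached, so in practice the stream would be cut with `.take(...)` anyway.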

In [None]:
if True:

    # make the dev set: the first val_size records of the combined stream
    datastream_for_dev = datasets_combined.take(data_streaming_config['val_size'])

    # materialize ("harden") the streamed dev set into a list
    dataslist_for_dev = list(datastream_for_dev)

    # reject any texts shorter than data_streaming_config['min_seq_length'] words
    dataslist_for_dev = [s['text'] for s in dataslist_for_dev if nwords(s['text']) >= data_streaming_config['min_seq_length']]

    # maybe use line by line (note: LineByLineTextDataset takes a file_path, not examples)
    #dataset = LineByLineTextDataset(tokenizer=tokenizer, file_path=..., block_size = 512)
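The plan in the comment block (first `val_size` records as dev, then skip ahead and start training at a random offset) can be sketched over any iterator. In the sketch below, `itertools.islice` stands in for `IterableDataset.take` / `.skip`. `split_stream` and `max_offset` are hypothetical, and the `text` field and word threshold follow `data_streaming_config`.

```python
import itertools
import random
from typing import Iterable, Iterator, List, Tuple

def split_stream(stream: Iterable[dict], val_size: int, max_offset: int,
                 min_words: int = 200, seed: int = 42) -> Tuple[List[str], Iterator[dict]]:
    """Materialize the first `val_size` records as a dev set, then start the train
    stream at a random offset past the dev records so the two never overlap."""
    it = iter(stream)
    dev_records = list(itertools.islice(it, val_size))        # like .take(val_size)
    dev_texts = [r["text"] for r in dev_records
                 if len([w for w in r["text"].split(" ") if w]) >= min_words]
    offset = random.Random(seed).randint(0, max_offset)
    train_stream = itertools.islice(it, offset, None)         # like .skip(offset)
    return dev_texts, train_stream
```

With a Hugging Face `IterableDataset`, each `.take()` / `.skip()` call restarts from the beginning of the stream, so the train side would be `datasets_combined.skip(val_size + offset)`. Passing a fixed `seed` to `interleave_datasets` is probably the safest way to make both passes see the same interleaving order.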


In [None]:
dataset_mlm_val = MLMDataset(input_text=dataslist_for_dev, tokenizer=tokenizer, max_seq_length=data_streaming_config['max_seq_length'], min_seq_length=data_streaming_config['min_seq_length'])

In [None]:
from transformers import DataCollatorForLanguageModeling
# masks mlm_probability (default 0.15) of the tokens; all other positions get label -100
collator_mlm = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm = True,
    pad_to_multiple_of = 4
)
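`DataCollatorForLanguageModeling` applies BERT's masking recipe. It selects about 15% of the non-special tokens as prediction targets. Of those, 80% become `[MASK]`, 10% become a random token, and 10% are left unchanged. Every other position gets the label -100, which `CrossEntropyLoss` ignores (see the ignore-index note in the header). The pure-Python sketch below reproduces that scheme for a single sequence. It is not the collator's code: `mask_tokens` is a hypothetical helper, and 101/102/103 are the `[CLS]`/`[SEP]`/`[MASK]` ids of the bert-uncased vocab.

```python
import random
from typing import List, Sequence, Tuple

IGNORE_INDEX = -100  # label value that CrossEntropyLoss skips by default

def mask_tokens(input_ids: Sequence[int], mask_id: int, vocab_size: int,
                special_ids: Sequence[int] = (), mlm_probability: float = 0.15,
                seed: int = 0) -> Tuple[List[int], List[int]]:
    """BERT-style masking: select ~mlm_probability of non-special tokens; of those,
    80% -> [MASK], 10% -> random token, 10% unchanged. Unselected labels = -100."""
    rng = random.Random(seed)
    inputs, labels = list(input_ids), [IGNORE_INDEX] * len(input_ids)
    for i, tok in enumerate(input_ids):
        if tok in special_ids or rng.random() >= mlm_probability:
            continue
        labels[i] = tok                      # the loss is computed only at these positions
        r = rng.random()
        if r < 0.8:
            inputs[i] = mask_id
        elif r < 0.9:
            inputs[i] = rng.randrange(vocab_size)
        # else: keep the original token
    return inputs, labels

ids = [101] + list(range(2000, 2100)) + [102]
masked, labels = mask_tokens(ids, mask_id=103, vocab_size=30522,
                             special_ids=(101, 102), mlm_probability=0.5, seed=1)
```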


In [None]:
len(dataslist_for_dev[7].split(" "))

215

In [None]:
# Example of loading multiple datasets
if False:
    from datasets import load_dataset, concatenate_datasets
    from transformers import BertTokenizer

    # Download Wikipedia, OpenWebText and BookCorpus
    wikipedia_dataset = load_dataset('wikipedia', '20200501.en', split='train')
    openwebtext_dataset = load_dataset('openwebtext', split='train')
    bookcorpus_dataset = load_dataset('bookcorpus', split='train')

    # Preprocess and tokenize the datasets
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def preprocess_function(examples):
        return tokenizer(examples['text'], truncation=True, padding='max_length')

    # drop the original columns so all three datasets end up with the same features
    # (concatenate_datasets requires matching columns)
    wikipedia_dataset = wikipedia_dataset.map(preprocess_function, batched=True, remove_columns=wikipedia_dataset.column_names)
    openwebtext_dataset = openwebtext_dataset.map(preprocess_function, batched=True, remove_columns=openwebtext_dataset.column_names)
    bookcorpus_dataset = bookcorpus_dataset.map(preprocess_function, batched=True, remove_columns=bookcorpus_dataset.column_names)

    # Combine the datasets (Dataset has no .concatenate method)
    combined_dataset = concatenate_datasets([wikipedia_dataset, openwebtext_dataset, bookcorpus_dataset])

    # Shuffle the dataset
    combined_dataset = combined_dataset.shuffle(seed=42)

    # Split once, so that train and validation do not overlap
    splits = combined_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset, val_dataset = splits['train'], splits['test']

    # Convert the datasets to PyTorch tensors
    train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
    val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

    # Print some examples from the dataset
    print(train_dataset[:5])
    print(val_dataset[:5])


In [None]:
from transformers import BertTokenizer, LineByLineTextDataset, DataCollatorForLanguageModeling
import datasets
from datasets import load_dataset
# openwebtext_dataset = datasets.load_dataset('openwebtext') full dataset
#openwebtext_dataset = datasets.load_dataset('openwebtext', split=f'train[:{0.03}]') # doesn't work: percent slices are written as split='train[:3%]'

pubmed_dataset_streamed = load_dataset(
    "json", data_files=data_files, split="train", streaming=True
)

Downloading and preparing dataset openwebtext/plain_text to /root/.cache/huggingface/datasets/openwebtext/plain_text/1.0.0/6f68e85c16ccc770c0dd489f4008852ea9633604995addd0cd76e293aed9e521...


Dataset openwebtext downloaded and prepared to /root/.cache/huggingface/datasets/openwebtext/plain_text/1.0.0/6f68e85c16ccc770c0dd489f4008852ea9633604995addd0cd76e293aed9e521. Subsequent calls will reuse this data.


ValueError: ignored

In [None]:
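from torch.utils.data import Subset

# LineByLineTextDataset reads lines from a text file: it takes file_path, not examples
# (path_to_text_file is a placeholder for wherever the texts were written to disk)
dataset = LineByLineTextDataset(tokenizer=tokenizer, file_path=path_to_text_file, block_size = 512)

# Create a subset of the dataset with the desired number of samples
subset_dataset = Subset(dataset, range(1000))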

In [None]:
from transformers import AutoModelForMaskedLM
basemodelLM = AutoModelForMaskedLM.from_pretrained("google/bert_uncased_L-4_H-512_A-8")

Some weights of the model checkpoint at google/bert_uncased_L-4_H-512_A-8 were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
basemodelLM

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 512, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=512, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_aff

In [None]:
## try to grab a MLM classification head
## let's verify they have the same vocabulary
# models from: https://arxiv.org/pdf/1908.08962.pdf
import torch
from transformers import AutoModelForMaskedLM, AutoConfig, AutoTokenizer
modelstring_base = "google/bert_uncased_L-12_H-512_A-8" #
modelstring_base = "google/bert_uncased_L-4_H-512_A-8"
#modelstring_base = 'google/bert_uncased_L-6_H-512_A-8'
basemod = AutoModelForMaskedLM.from_pretrained(modelstring_base)
basemod_tokenizer = AutoTokenizer.from_pretrained(modelstring_base)
# the miniature Google BERTs have a vocab size of: 30522

modelstring_lg = 'bert-large-uncased' # I think the google-team used this for the miniature models
# bert-large uncased has a vocab size of: 30522
#modelstring_lg = "google/bert_uncased_L-12_H-768_A-12"
largmod = AutoModelForMaskedLM.from_pretrained(modelstring_lg) #
largmod_tokenizer = AutoTokenizer.from_pretrained(modelstring_lg)#"google/bert_uncased_L-12_H-768_A-12")


# note: which datasets used to train large
# wikipedia
# bookcorpus
# ... but see more about datasets and models in: https://arxiv.org/pdf/1908.08962.pdf


text = "For Ex Works (EXW) terms, the Supplier will [MASK] all risk and liability for the Delivered [MASK] up until delivering the goods to the nominated Carrier."
with torch.no_grad():
    inputs1 = basemod_tokenizer(text, return_tensors='pt')
    outputs1 = basemod(**inputs1)
    preds1 = outputs1.logits
    inputs2 = largmod_tokenizer(text, return_tensors='pt')
    outputs2 = largmod(**inputs2)
    preds2 = outputs2.logits

    assert (inputs1['input_ids']-inputs2['input_ids']).sum()==0, 'ids are different'

    predicted_token_ids1 = preds1[0].argmax(dim=-1)
    predicted_token_ids2 = preds2[0].argmax(dim=-1)

    print(basemod_tokenizer.convert_ids_to_tokens(predicted_token_ids1))
    print(basemod_tokenizer.convert_ids_to_tokens(predicted_token_ids2))


# confirmation: the miniature BERTs and bert-large-uncased tokenize this text to identical input_ids (the assert above passes)

Some weights of the model checkpoint at google/bert_uncased_L-4_H-512_A-8 were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


['delivery', 'for', 'ex', '##w', '(', 'ex', '##w', ')', 'terms', ',', 'the', 'supplier', 'will', 'reduce', 'all', 'risk', 'and', 'liability', 'for', 'the', 'delivered', 'goods', 'up', 'until', 'delivering', 'the', 'goods', 'to', 'the', 'delivery', 'carrier', '.', 'is']
['.', 'for', 'ex', 'works', '(', 'ex', '##w', ')', 'terms', ',', 'the', 'supplier', 'will', 'assume', 'all', 'risk', ',', 'liability', 'for', 'the', 'delivered', 'goods', 'up', 'until', 'delivering', 'the', 'goods', 'to', 'the', 'responsible', 'carrier', '.', '.']


['the', 'for', 'ex', '##w', '(', 'ex', '##w', ')', 'terms', ',', 'the', 'supplier', 'will', 'cover', 'all', 'risk', 'and', 'liability', 'for', 'the', 'delivered', 'goods', 'up', 'until', 'delivering', 'the', 'goods', 'to', 'the', 'delivered', 'carrier', '.', '.']
['.', 'for', 'ex', 'works', '(', 'ex', '##w', ')', 'terms', ',', 'the', 'supplier', 'will', 'assume', 'all', 'risk', ',', 'liability', 'for', 'the', 'delivered', 'goods', 'up', 'until', 'delivering', 'the', 'goods', 'to', 'the', 'responsible', 'carrier', '.', '.']
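The two rows printed for each run (`predicted_token_ids1` and `predicted_token_ids2`) agree at most positions and differ at only a few. Below is a small sketch for quantifying that overlap. It assumes the two predictions are aligned token lists of equal length, which holds here because the inputs are identical (see the `input_ids` difference below). `token_agreement` is a hypothetical helper and is not part of the notebook.

```python
from typing import List, Tuple


def token_agreement(tokens_a: List[str], tokens_b: List[str]) -> Tuple[float, List[int]]:
    """Return (fraction of positions where both predictions match, indices that differ).

    Hypothetical helper: assumes both lists are aligned position-by-position,
    e.g. two outputs of tokenizer.convert_ids_to_tokens(...) for the same input.
    """
    if len(tokens_a) != len(tokens_b):
        raise ValueError("predictions must be aligned and equal in length")
    if not tokens_a:
        return 1.0, []
    mismatches = [i for i, (a, b) in enumerate(zip(tokens_a, tokens_b)) if a != b]
    agreement = 1.0 - len(mismatches) / len(tokens_a)
    return agreement, mismatches


# usage on the first five tokens of the first pair of printed rows
row_1 = ['delivery', 'for', 'ex', '##w', '(']
row_2 = ['.', 'for', 'ex', 'works', '(']
rate, diff_positions = token_agreement(row_1, row_2)
print(rate, diff_positions)
```

For full rows, pass the two lists returned by `basemod_tokenizer.convert_ids_to_tokens(...)` directly.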


In [None]:
predicted_token_ids1


tensor([ 1996,  2005,  4654,  2860,  1006,  4654,  2860,  1007,  3408,  1010,
         1996, 17024,  2097,  3104,  2035,  3891,  1998, 14000,  2005,  1996,
         5359,  5350,  2039,  2127, 12771,  1996,  5350,  2000,  1996,  5359,
         6839,  1012,  1012])

In [None]:
inputs1['input_ids']-inputs2['input_ids']

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0]])
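An all-zero difference shows that the two tokenizers produced identical ids for this one sentence. That is consistent with the note that the Google miniature models share bert-large-uncased's vocabulary size, but it only covers the tokens in this sentence. The sketch below is a fuller check under one assumption: each vocabulary is available as a token-to-id dict, which Hugging Face tokenizers return from `get_vocab()`. `compare_vocabs` and `vocabs_identical` are hypothetical helpers.

```python
from typing import Dict, List


def compare_vocabs(vocab_a: Dict[str, int], vocab_b: Dict[str, int]) -> Dict[str, List[str]]:
    """Report tokens present in only one vocab, and shared tokens mapped to different ids."""
    only_a = sorted(set(vocab_a) - set(vocab_b))
    only_b = sorted(set(vocab_b) - set(vocab_a))
    id_mismatch = sorted(t for t in set(vocab_a) & set(vocab_b) if vocab_a[t] != vocab_b[t])
    return {"only_a": only_a, "only_b": only_b, "id_mismatch": id_mismatch}


def vocabs_identical(vocab_a: Dict[str, int], vocab_b: Dict[str, int]) -> bool:
    """True only if both vocabs contain the same tokens with the same ids."""
    return not any(compare_vocabs(vocab_a, vocab_b).values())


# in the notebook this could be, e.g.:
#   vocabs_identical(basemod_tokenizer.get_vocab(), other_tokenizer.get_vocab())
# where `other_tokenizer` is a hypothetical name for the bert-large-uncased tokenizer.
# Toy vocabularies (the second one deliberately differs on '##w'):
toy_a = {"[CLS]": 101, "ex": 4654, "##w": 2860}
toy_b = {"[CLS]": 101, "ex": 4654, "##w": 2861}
print(compare_vocabs(toy_a, toy_b))
```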

In [None]:
basemod._modules['cls']

BertOnlyMLMHead(
  (predictions): BertLMPredictionHead(
    (transform): BertPredictionHeadTransform(
      (dense): Linear(in_features=512, out_features=512, bias=True)
      (transform_act_fn): GELUActivation()
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
    )
    (decoder): Linear(in_features=512, out_features=30522, bias=True)
  )
)
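The printed head consists of a 512-to-512 dense transform with GELU and LayerNorm, followed by a 512-to-30522 decoder onto the vocabulary. The worked sketch below counts its parameters from those shapes. A `Linear(in, out, bias=True)` holds `in*out + out` parameters, and an elementwise-affine `LayerNorm(d)` holds `2*d`. In the Hugging Face BERT implementation, the decoder weight is tied to the input word-embedding matrix by default (`tie_word_embeddings=True`). Under that assumption, the head adds far fewer new parameters than its raw total suggests. The helper names are hypothetical.

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameters in torch.nn.Linear: weight (out x in) plus optional bias (out)."""
    return in_features * out_features + (out_features if bias else 0)


def layernorm_params(dim: int, elementwise_affine: bool = True) -> int:
    """Parameters in torch.nn.LayerNorm over one dimension: weight and bias of size dim."""
    return 2 * dim if elementwise_affine else 0


HIDDEN, VOCAB = 512, 30522  # shapes printed above for the miniature BERT's MLM head

transform = linear_params(HIDDEN, HIDDEN) + layernorm_params(HIDDEN)
decoder = linear_params(HIDDEN, VOCAB)
total = transform + decoder
# assumption: decoder weight tied to the word embeddings, so only its bias is new
new_if_tied = total - HIDDEN * VOCAB

print(f"transform: {transform:,}  decoder: {decoder:,}  total: {total:,}  new if tied: {new_if_tied:,}")
```

This gives about 15.9M parameters in total. Almost all of them sit in the decoder weight, so only about 0.29M are new if the weights are tied.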

In [None]:
config = make_config('google/bert_uncased_L-12_H-512_A-8')

# make the basemod and tokenizer
basemod = AutoModel.from_pretrained(config.model_string)
basemod.to(device)
tokenizer = AutoTokenizer.from_pretrained(config.model_string)


Some weights of the model checkpoint at google/bert_uncased_L-12_H-512_A-8 were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
anathem_encoder1 = AnathemBaseModule(config, basemod, tokenizer)
anathem_encoder2 = AnathemMidModule(config, basemod)
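The next cell passes `extended_attention_masks` from the base module to the mid module. In Hugging Face BERT, an "extended" attention mask is the 0/1 padding mask reshaped to broadcast over heads and query positions (`[batch, 1, 1, seq_len]`). It is then converted into an additive bias: 0 for tokens to keep and a large negative number for padding, which is added to the attention scores before the softmax. Recent `transformers` versions use the dtype's minimum value for that number; older ones used -10000. Whether `AnathemBaseModule` builds its three resolution-specific masks exactly this way is an assumption. The pure-Python sketch below only illustrates the standard convention and its effect. `extend_attention_mask` and `masked_softmax` are hypothetical helpers.

```python
import math
from typing import List


def extend_attention_mask(mask: List[List[int]], neg: float = -1e4) -> List[List[List[List[float]]]]:
    """Turn a [batch, seq_len] 0/1 mask into an additive [batch, 1, 1, seq_len] bias."""
    return [[[[0.0 if m == 1 else neg for m in row]]] for row in mask]


def masked_softmax(scores: List[float], additive_mask: List[float]) -> List[float]:
    """Softmax over one row of attention scores after adding the mask bias."""
    biased = [s + b for s, b in zip(scores, additive_mask)]
    top = max(biased)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in biased]
    total = sum(exps)
    return [e / total for e in exps]


attention_mask = [[1, 1, 1, 0]]  # one sequence whose last position is padding
ext = extend_attention_mask(attention_mask)
probs = masked_softmax([2.0, 1.0, 0.5, 3.0], ext[0][0][0])
print([round(p, 4) for p in probs])  # the padding position gets (numerically) zero weight
```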

In [None]:
time1 = time.time()
# time both Anathem modules over the first 31 batches (iterations 0-30), including data loading
for iteration, batch in enumerate(tqdm(dl_train, disable=True)):
    if iteration>30:
        if torch.cuda.is_available():
            torch.cuda.synchronize() # GPU kernels run asynchronously; wait for them before reading the clock
        time2 = time.time()
        print(time2-time1)
        break
    with torch.no_grad():
        tokens = tokenize_anathem(batch['text'])
        (hidden_states, extended_attention_masks) = anathem_encoder1(
            input_ids = tokens['input_ids'],
            attention_mask = tokens['attention_mask'],
            token_type_ids = tokens['token_type_ids']
        )
        # feed the high-, mid- and low-resolution streams into the mid module
        features,_ = anathem_encoder2(
            hidden_states_highres = hidden_states[0],
            hidden_states_midres = hidden_states[1],
            hidden_states_lowres = hidden_states[2],
            extended_attention_mask_highres = extended_attention_masks[0],
            extended_attention_mask_midres = extended_attention_masks[1],
            extended_attention_mask_lowres = extended_attention_masks[2]
        )

1.2566087245941162
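The printed 1.2566 s covers 31 batches (iterations 0 to 30) plus dataloader overhead, or about 0.041 s per batch. The comment in the next cell reports 3.198 s over 200 iterations for the new method, or about 0.016 s per iteration. These per-iteration figures are only comparable if batch size, sequence length and device match, and the notebook does not state them.

Below is a hypothetical benchmarking helper that reports seconds per iteration. It excludes warm-up iterations and uses `time.perf_counter`, a monotonic high-resolution clock. On a GPU, pass a `sync` callable such as `torch.cuda.synchronize` so that queued kernels finish before the clock is read.

```python
import time
from typing import Callable, Iterable, Optional


def seconds_per_iteration(
    step: Callable[[object], None],
    batches: Iterable[object],
    n_iters: int,
    n_warmup: int = 2,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Average wall-clock seconds per call of `step` (including fetching the batch).

    `batches` must yield at least n_warmup + n_iters items. `sync`, if given, is
    called right before each clock read (e.g. torch.cuda.synchronize on a GPU).
    """
    if n_iters <= 0:
        raise ValueError("n_iters must be positive")
    it = iter(batches)
    for _ in range(n_warmup):  # warm-up calls are not timed
        step(next(it))
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(n_iters):
        step(next(it))
    if sync is not None:
        sync()
    return (time.perf_counter() - start) / n_iters


# toy usage: a step that sleeps ~1 ms per batch.
# In the notebook, `step` could be a hypothetical wrapper around the forward pass above, e.g.
#   seconds_per_iteration(run_anathem_forward, dl_train, n_iters=30, sync=torch.cuda.synchronize)
per_iter = seconds_per_iteration(lambda b: time.sleep(0.001), range(100), n_iters=20)
print(f"{per_iter * 1000:.2f} ms per iteration")
```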


In [None]:
# the new method takes: 3.198051929473877 / 200 iterations (I can't really te)