# Prototype for masking k contiguous sentences

In [36]:
import numpy as np
rng = np.random.default_rng()


In [37]:

def merge_intervals(intervals):
    """Merge intervals to be non-overlapping.
    Source: https://github.com/facebookresearch/SpanBERT/blob/main/pretraining/fairseq/data/masking.py
    """
    intervals = sorted(intervals, key=lambda x : x[0])
    merged = []
    for interval in intervals:
        # if the list of merged intervals is empty or if the current
        # interval does not overlap with the previous, simply append it.
        if not merged or merged[-1][1] + 1 < interval[0]:
            merged.append(interval)
        else:
        # otherwise, there is overlap, so we merge the current and previous
        # intervals.
            merged[-1][1] = max(merged[-1][1], interval[1])

    return merged

def create_covered_indices(spans, max_index=None):
    """From a list of spans, create all indices covered by them."""
    spans = np.array(spans)
    if not max_index:
        max_index = spans[:, 1].max()

    covered = np.zeros(max_index, dtype=bool)
    for start, end in spans:
        covered[start:end] = True

    return np.where(covered)[0]

    



In [38]:
def span_masking(sentence_lengths, mask_frac=0.3, k=3):
    n_sentences = len(sentence_lengths)
    n_tokens = sum(sentence_lengths)
    cumsum_lengths = np.cumsum(sentence_lengths)
    spans = []

    span_starts = rng.choice(np.arange(n_sentences), size=n_sentences, replace=False)
    # span_lengths = rng.geometric(p=0.4, size = 10) # we'll probably need a higher p than 0.2 since we're masking at the sentence level
    span_lengths = np.full(n_sentences, k)
    print(span_lengths)

    n_masked_tokens = 0
    for span_start, span_length in zip(span_starts, span_lengths):
        span_end = span_start + span_length
        span_end = min(span_end, n_sentences-1)
        tokens_in_span = cumsum_lengths[span_end] - cumsum_lengths[span_start]
        if (n_masked_tokens + tokens_in_span) / n_tokens > mask_frac:
            break
        spans.append([span_start, span_end])
        spans = merge_intervals(spans)
        print(f"spans are: {spans}")
        n_masked_tokens += tokens_in_span
        print(f"while loop is at {n_masked_tokens / n_tokens} ")

    return create_covered_indices(spans)



In [41]:
sentence_lengths = rng.choice(np.arange(5), size=20)
sentences_to_mask = span_masking(sentence_lengths, mask_frac=0.6)
# the function should be called after `find_sentences_to_mask` (which will need to return also the sentence lengths)
sentences_to_mask
# this can be plugged in to the `mask_tokens_in_sentence` function

[3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3]
spans are: [[3, 6]]
while loop is at 0.12121212121212122 
spans are: [[3, 6], [16, 19]]
while loop is at 0.42424242424242425 
spans are: [[3, 6], [16, 19]]
while loop is at 0.42424242424242425 


array([ 3,  4,  5, 16, 17, 18])