# Chunk Diagnosis

Some notes on what's going on with [`long_terms_tokenizer`](https://github.com/IreneZihuiLi/HiPool/blob/main/Dataset_Split_Class.py).

`long_terms_tokenizer` should really be called `chunk` or `chunk_document` since it does the following: 

- Tokenizes a document using BERT
- Splits the document into chunks
- Returns a dict with a matrix of input ids, a matrix of attention masks, etc.

One of the quirks of this function is that it includes the `[CLS]` and `[SEP]` tokens in the length of the chunk, which makes calculating the overlap more complicated. In this notebook, `[CLS]` and `[SEP]` are also referred to as "start" and "end" for convenience. Here's an example that just includes the chunking logic from the function, ignoring the extra things like token type IDs, labels, etc.:

In [1]:
import math

import torch

In [2]:
# Toy list of tokens (all the ints from 1-100, inclusive)
tokens = list(range(1,101))
chunk_len = 20
overlap_len = 10
stride = overlap_len - 2
number_chunks = math.floor(len(tokens) / stride)
chunks = []
for current in range(number_chunks - 1):
    chunk_toks = tokens[current*stride:current*stride+chunk_len-2]
    # These are easier-to-read stand-ins for [CLS] and [SEP]
    chunk_toks = ["start"] + chunk_toks + ["end"]  
    chunks.append(chunk_toks)

for chunk in chunks:
    print(chunk)

['start', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 'end']
['start', 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 'end']
['start', 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 'end']
['start', 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 'end']
['start', 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 'end']
['start', 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 'end']
['start', 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 'end']
['start', 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 'end']
['start', 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 'end']
['start', 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 'end']
['start', 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 'end']


In [3]:
print(f"Expected number of chunks: {number_chunks}")
print(f"Actual number of chunks: {len(chunks)}")

Expected number of chunks: 12
Actual number of chunks: 11


This implementation mostly works. The overlap is correct if you don't count special tokens (consecutive chunks share 10 content tokens), but the final chunk is dropped: `range(number_chunks - 1)` subtracts 1 even though the upper bound of `range` is already exclusive, so the loop runs 11 times instead of 12.

As we can see, because we are missing a chunk, we don't have full coverage: the last chunk ends at token 98, so tokens 99 and 100 never appear in any chunk.
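Both claims can be checked directly. The sketch below rebuilds the same chunks without the "start"/"end" stand-ins, then measures the overlap between consecutive chunks and lists the tokens that no chunk contains. The names `content_chunks`, `overlaps`, and `missing` are introduced here purely for illustration.

```python
# Rebuild the chunks from the toy example (content tokens only,
# without the "start"/"end" stand-ins).
tokens = list(range(1, 101))
chunk_len = 20
overlap_len = 10
stride = overlap_len - 2
number_chunks = len(tokens) // stride

content_chunks = [
    tokens[i * stride : i * stride + chunk_len - 2]
    for i in range(number_chunks - 1)
]

# Number of tokens shared by each pair of consecutive chunks.
overlaps = [
    len(set(a) & set(b))
    for a, b in zip(content_chunks, content_chunks[1:])
]

# Tokens that never appear in any chunk.
covered = {tok for chunk in content_chunks for tok in chunk}
missing = sorted(set(tokens) - covered)

print(f"Distinct overlap sizes: {sorted(set(overlaps))}")
print(f"Missing tokens: {missing}")
```

For this toy input, every pair of neighbouring chunks shares 10 tokens and the missing tokens are 99 and 100.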

I suspect that this may have been done intentionally to avoid dealing with differently sized chunks, but that can be overcome by padding the last chunk.
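As a rough sketch of that padding idea, the loop can run over all `number_chunks` iterations and pad the short final chunk up to `chunk_len - 2` content tokens before adding the special tokens. The pad value `-1` and the choice to pad before the "end" stand-in are assumptions made for this toy example, not something taken from the original function.

```python
tokens = list(range(1, 101))
chunk_len = 20
overlap_len = 10
stride = overlap_len - 2
number_chunks = len(tokens) // stride
pad = -1  # Hypothetical pad value, chosen only for this toy example

chunks = []
for current in range(number_chunks):
    chunk_toks = tokens[current * stride : current * stride + chunk_len - 2]
    # Pad the content so every chunk holds chunk_len - 2 tokens.
    chunk_toks = chunk_toks + [pad] * (chunk_len - 2 - len(chunk_toks))
    chunks.append(["start"] + chunk_toks + ["end"])

# Every chunk now has the same length, so the list is rectangular.
print({len(chunk) for chunk in chunks})
print(chunks[-1])
```

With the toy values, this yields 12 chunks of length 20, and the last one contains tokens 89 through 100 followed by six pad values.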

The cell below shows what happens when `number_chunks - 1` is replaced with `number_chunks`.

In [4]:
chunks = []
for current in range(number_chunks):
    chunk_toks = tokens[current*stride:current*stride+chunk_len-2]
    chunk_toks = ["start"] + chunk_toks + ["end"]
    chunks.append(chunk_toks)

for chunk in chunks:
    print(chunk)

['start', 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 'end']
['start', 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 'end']
['start', 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 'end']
['start', 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 'end']
['start', 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 'end']
['start', 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 'end']
['start', 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 'end']
['start', 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 'end']
['start', 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 'end']
['start', 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 'end']
['start', 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 'end']
['start', 89, 90, 91, 92, 93, 9

## An alternative chunking implementation

This implementation does not count the start and end tokens toward the length of the chunk. This means we can:

1. Chunk the document.
2. Pad the last chunk.
3. Create a tensor from the chunks.
4. Concatenate all the start and end tokens at once.

This streamlines the chunking logic.

In [22]:
# Toy list of tokens (all the ints from 1-110, inclusive)
tokens = list(range(1,111))
chunk_len = 20
overlap_len = 10
chunks = []
current_idx = 0
while True:
    chunks.append(tokens[current_idx: current_idx + chunk_len])
    if current_idx + chunk_len >= len(tokens):
        break
    else:
        current_idx += chunk_len - overlap_len

# Suppose -1 is pad token
last_chunk_padding = [-1] * (chunk_len - len(chunks[-1]))
chunks[-1] = chunks[-1] + last_chunk_padding
for chunk in chunks:
    print(chunk)

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
[11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]
[21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]
[31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]
[41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60]
[51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70]
[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80]
[71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90]
[81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100]
[91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110]


In [7]:
# Now create tensor and start/end tokens
x = torch.tensor(chunks)
# Now we'll use 101 and 102 to represent [CLS] and [SEP]
start = torch.tensor(101).repeat(x.shape[0]).unsqueeze(dim=1)
end = torch.tensor(102).repeat(x.shape[0]).unsqueeze(dim=1)
y = torch.cat((start, x, end), dim=1)
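
Since the original function also returns attention masks, here is a minimal sketch of how a mask could be derived from the padded tensor. It assumes the toy pad value of `-1` from the chunking cell and uses a tiny two-chunk input (not the 110-token list) so that padding is actually present. The names `input_ids` and `attention_mask` are illustrative.

```python
import torch

pad = -1  # Toy pad value, as assumed in the chunking cell

# Two toy chunks of 5 content tokens; the last one is padded.
x = torch.tensor([[1, 2, 3, 4, 5], [4, 5, 6, pad, pad]])

# 101 and 102 stand in for [CLS] and [SEP].
start = torch.full((x.shape[0], 1), 101, dtype=torch.long)
end = torch.full((x.shape[0], 1), 102, dtype=torch.long)
input_ids = torch.cat((start, x, end), dim=1)

# 1 for real tokens (including start/end), 0 for padding.
attention_mask = (input_ids != pad).long()

print(input_ids)
print(attention_mask)
```

Note that in this layout the end token sits after the padding, so the zeros in the mask fall in the middle of the last row rather than at its end.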

In [8]:
print(y.shape)
print(y)

torch.Size([10, 22])
tensor([[101,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20, 102],
        [101,  11,  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,
          24,  25,  26,  27,  28,  29,  30, 102],
        [101,  21,  22,  23,  24,  25,  26,  27,  28,  29,  30,  31,  32,  33,
          34,  35,  36,  37,  38,  39,  40, 102],
        [101,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,
          44,  45,  46,  47,  48,  49,  50, 102],
        [101,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,
          54,  55,  56,  57,  58,  59,  60, 102],
        [101,  51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,
          64,  65,  66,  67,  68,  69,  70, 102],
        [101,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,
          74,  75,  76,  77,  78,  79,  80, 102],
        [101,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  8

We only have 10 chunks instead of 11 or 12, even though this toy list is 10 tokens longer, because each chunk holds a full 20 content tokens (plus 2 special tokens) instead of 18 (plus 2 special tokens), so the stride is 10 tokens rather than 8.
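
The chunk count for the sliding-window loop can also be worked out in closed form. The sketch below compares a formula against a counter that mirrors the `while` loop from the alternative implementation. Both `expected_chunks` and `count_chunks` are hypothetical helper names, and the formula assumes `overlap_len < chunk_len`.

```python
import math


def expected_chunks(n_tokens, chunk_len, overlap_len):
    """Closed-form chunk count (hypothetical helper, assumes overlap_len < chunk_len)."""
    stride = chunk_len - overlap_len
    if n_tokens <= chunk_len:
        return 1
    return math.ceil((n_tokens - chunk_len) / stride) + 1


def count_chunks(n_tokens, chunk_len, overlap_len):
    """Count chunks by stepping through the same sliding-window loop."""
    count, current_idx = 0, 0
    while True:
        count += 1
        if current_idx + chunk_len >= n_tokens:
            break
        current_idx += chunk_len - overlap_len
    return count


print(expected_chunks(110, 20, 10))
print(count_chunks(110, 20, 10))
```

Both calls give 10 for the 110-token list. For a 100-token list like the one in the first example, the same formula gives 9.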