In [1]:
# https://github.com/hkproj/mistral-llm-notes/blob/main/sliding_window_attention.ipynb

In [2]:
# Canonical word order of the example sentence; the print helpers sort tokens by it.
print_order = ['the', 'cat', 'is', 'on', 'a', 'chair']
# One singleton set per position: before any layer runs, a position only "knows" its own token.
sequence = [{token} for token in print_order]
sequence

[{'the'}, {'cat'}, {'is'}, {'on'}, {'a'}, {'chair'}]

In [5]:
# Window width W that print_layer passes to sliding_window_attention:
# a query at position i is only paired with keys at positions i-W+1 .. i.
sliding_window_size = 3

def sliding_window_attention(seq: list[set[str]], w: int):
    """Build the causal sliding-window "attention" table for a sequence of token sets.

    Entry ``[i][j]`` of the returned table is a new set holding every token of
    ``seq[i]`` (the query) and ``seq[j]`` (the key) when ``j <= i`` and
    ``i - j < w``. Every other entry is ``None``.

    Args:
        seq: One collection of tokens per position.
        w: Window width; each query sees itself and at most ``w - 1`` earlier
            positions. With ``w <= 0`` no entry is filled in.

    Returns:
        A ``len(seq)`` x ``len(seq)`` list of lists containing sets or ``None``.
    """
    n = len(seq)
    table = [[None] * n for _ in range(n)]
    for q_idx in range(n):
        # Keys run from the oldest position still inside the window up to the
        # query itself; everything after the query (upper triangle) and
        # everything older than the window keeps its None.
        oldest_key = max(0, q_idx - w + 1)
        for k_idx in range(oldest_key, q_idx + 1):
            # Fresh set with all tokens of the query and of the key.
            table[q_idx][k_idx] = set().union(seq[q_idx], seq[k_idx])
    return table

def multiple_by_v(attention_scores: list[list[set]], v_sequence: list[set[str]]) -> list[set[str]]:
    """Combine the attention table with the value sequence, one output set per position.

    For output position ``i``, every column ``j`` whose table entry is not
    ``None`` contributes both the tokens of ``v_sequence[j]`` and the tokens of
    the table entry itself.

    Args:
        attention_scores: Square table of sets or ``None``; it is indexed up to
            ``len(v_sequence)`` in both dimensions.
        v_sequence: One collection of tokens per position.

    Returns:
        A list of ``len(v_sequence)`` newly created sets.
    """
    n = len(v_sequence)
    outputs = []
    for q_idx in range(n):
        gathered = set()
        for k_idx in range(n):
            pair_tokens = attention_scores[q_idx][k_idx]
            if pair_tokens is None:
                # Masked-out pair: contributes nothing to this position.
                continue
            # Tokens carried by the value at this key position ...
            gathered.update(v_sequence[k_idx])
            # ... plus the tokens recorded in the attention entry itself.
            gathered.update(pair_tokens)
        outputs.append(gathered)
    return outputs

def print_attention(attention_scores: list[list[set[str]]]):
    """Print the attention table, one row per line with tab-terminated cells.

    ``None`` cells are printed as the text ``None``. Any other cell is printed
    as a list of its tokens sorted by their position in the module-level
    ``print_order`` list (a token missing from that list makes the lookup fail).
    """
    def sentence_position(token):
        # Looked up lazily so print_order is only needed when a cell has tokens.
        return print_order.index(token)

    for row in attention_scores:
        for cell in row:
            if cell is not None:
                text = str(sorted(cell, key=sentence_position))
            else:
                text = 'None'
            print(text, end='\t')
        # Finish the row.
        print()

def print_sequence(seq: list[set[str]]):
    """Print each position of ``seq`` as ``<index>: <tokens>``, one per line.

    Tokens are listed in the order given by the module-level ``print_order``
    list (a token missing from that list makes the lookup fail).
    """
    def sentence_position(token):
        # Looked up lazily so print_order is only needed when a set has tokens.
        return print_order.index(token)

    for position, tokens in enumerate(seq):
        ordered = sorted(tokens, key=sentence_position)
        print(f'{position}: {ordered}')

def print_layer(input: list[set[str]], layer_num: int) -> list[set[str]]:
    """Run one sliding-window layer on ``input``, printing every stage.

    Prints the layer input, the attention table computed with the module-level
    ``sliding_window_size``, and the resulting output sequence.

    Args:
        input: Token sets entering the layer, one per position.
        layer_num: Number used only to label the printed sections.

    Returns:
        The layer output, suitable as the ``input`` of the next layer.
    """
    # Stage 1: what enters the layer.
    print(f'Layer {layer_num} input:')
    print_sequence(input)

    # Stage 2: pairwise token mixing restricted to the sliding window.
    scores = sliding_window_attention(input, sliding_window_size)
    print()
    print(f'Layer {layer_num} attention scores:')
    print_attention(scores)

    # Stage 3: fold the table back into one token set per position.
    layer_output = multiple_by_v(scores, input)
    print()
    print(f'Layer {layer_num} output:')
    print_sequence(layer_output)
    return layer_output

In [6]:
# Layer 1: each position mixes with itself and at most (sliding_window_size - 1) preceding positions.
output_layer_1 = print_layer(sequence, 1)

Layer 1 input:
0: ['the']
1: ['cat']
2: ['is']
3: ['on']
4: ['a']
5: ['chair']

Layer 1 attention scores:
['the']	None	None	None	None	None	
['the', 'cat']	['cat']	None	None	None	None	
['the', 'is']	['cat', 'is']	['is']	None	None	None	
None	['cat', 'on']	['is', 'on']	['on']	None	None	
None	None	['is', 'a']	['on', 'a']	['a']	None	
None	None	None	['on', 'chair']	['a', 'chair']	['chair']	

Layer 1 output:
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['cat', 'is', 'on']
4: ['is', 'on', 'a']
5: ['on', 'a', 'chair']


In [7]:
# Layer 2: feed layer 1's output back in. Positions now pick up tokens from beyond
# their own window (e.g. position 3 gains 'the').
output_layer_2 = print_layer(output_layer_1, 2)

Layer 2 input:
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['cat', 'is', 'on']
4: ['is', 'on', 'a']
5: ['on', 'a', 'chair']

Layer 2 attention scores:
['the']	None	None	None	None	None	
['the', 'cat']	['the', 'cat']	None	None	None	None	
['the', 'cat', 'is']	['the', 'cat', 'is']	['the', 'cat', 'is']	None	None	None	
None	['the', 'cat', 'is', 'on']	['the', 'cat', 'is', 'on']	['cat', 'is', 'on']	None	None	
None	None	['the', 'cat', 'is', 'on', 'a']	['cat', 'is', 'on', 'a']	['is', 'on', 'a']	None	
None	None	None	['cat', 'is', 'on', 'a', 'chair']	['is', 'on', 'a', 'chair']	['on', 'a', 'chair']	

Layer 2 output:
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['the', 'cat', 'is', 'on']
4: ['the', 'cat', 'is', 'on', 'a']
5: ['cat', 'is', 'on', 'a', 'chair']


In [8]:
# Layer 3: after this layer the last position holds all six tokens of the sentence.
output_layer_3 = print_layer(output_layer_2, 3)

Layer 3 input:
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['the', 'cat', 'is', 'on']
4: ['the', 'cat', 'is', 'on', 'a']
5: ['cat', 'is', 'on', 'a', 'chair']

Layer 3 attention scores:
['the']	None	None	None	None	None	
['the', 'cat']	['the', 'cat']	None	None	None	None	
['the', 'cat', 'is']	['the', 'cat', 'is']	['the', 'cat', 'is']	None	None	None	
None	['the', 'cat', 'is', 'on']	['the', 'cat', 'is', 'on']	['the', 'cat', 'is', 'on']	None	None	
None	None	['the', 'cat', 'is', 'on', 'a']	['the', 'cat', 'is', 'on', 'a']	['the', 'cat', 'is', 'on', 'a']	None	
None	None	None	['the', 'cat', 'is', 'on', 'a', 'chair']	['the', 'cat', 'is', 'on', 'a', 'chair']	['cat', 'is', 'on', 'a', 'chair']	

Layer 3 output:
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['the', 'cat', 'is', 'on']
4: ['the', 'cat', 'is', 'on', 'a']
5: ['the', 'cat', 'is', 'on', 'a', 'chair']


In [9]:
# Layer 3 again: this cell repeats the previous one (same input, identical output).
output_layer_3 = print_layer(output_layer_2, 3)

Layer 3 input:
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['the', 'cat', 'is', 'on']
4: ['the', 'cat', 'is', 'on', 'a']
5: ['cat', 'is', 'on', 'a', 'chair']

Layer 3 attention scores:
['the']	None	None	None	None	None	
['the', 'cat']	['the', 'cat']	None	None	None	None	
['the', 'cat', 'is']	['the', 'cat', 'is']	['the', 'cat', 'is']	None	None	None	
None	['the', 'cat', 'is', 'on']	['the', 'cat', 'is', 'on']	['the', 'cat', 'is', 'on']	None	None	
None	None	['the', 'cat', 'is', 'on', 'a']	['the', 'cat', 'is', 'on', 'a']	['the', 'cat', 'is', 'on', 'a']	None	
None	None	None	['the', 'cat', 'is', 'on', 'a', 'chair']	['the', 'cat', 'is', 'on', 'a', 'chair']	['cat', 'is', 'on', 'a', 'chair']	

Layer 3 output:
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['the', 'cat', 'is', 'on']
4: ['the', 'cat', 'is', 'on', 'a']
5: ['the', 'cat', 'is', 'on', 'a', 'chair']


In [11]:

# Layer 4: every position already holds all tokens up to itself, so the output equals the input.
output_layer_4 = print_layer(output_layer_3, 4)

Layer 4 input:
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['the', 'cat', 'is', 'on']
4: ['the', 'cat', 'is', 'on', 'a']
5: ['the', 'cat', 'is', 'on', 'a', 'chair']

Layer 4 attention scores:
['the']	None	None	None	None	None	
['the', 'cat']	['the', 'cat']	None	None	None	None	
['the', 'cat', 'is']	['the', 'cat', 'is']	['the', 'cat', 'is']	None	None	None	
None	['the', 'cat', 'is', 'on']	['the', 'cat', 'is', 'on']	['the', 'cat', 'is', 'on']	None	None	
None	None	['the', 'cat', 'is', 'on', 'a']	['the', 'cat', 'is', 'on', 'a']	['the', 'cat', 'is', 'on', 'a']	None	
None	None	None	['the', 'cat', 'is', 'on', 'a', 'chair']	['the', 'cat', 'is', 'on', 'a', 'chair']	['the', 'cat', 'is', 'on', 'a', 'chair']	

Layer 4 output:
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['the', 'cat', 'is', 'on']
4: ['the', 'cat', 'is', 'on', 'a']
5: ['the', 'cat', 'is', 'on', 'a', 'chair']


In [12]:

# Layer 5: nothing changes any more; the output is identical to the layer 4 output.
output_layer_5 = print_layer(output_layer_4, 5)

Layer 5 input:
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['the', 'cat', 'is', 'on']
4: ['the', 'cat', 'is', 'on', 'a']
5: ['the', 'cat', 'is', 'on', 'a', 'chair']

Layer 5 attention scores:
['the']	None	None	None	None	None	
['the', 'cat']	['the', 'cat']	None	None	None	None	
['the', 'cat', 'is']	['the', 'cat', 'is']	['the', 'cat', 'is']	None	None	None	
None	['the', 'cat', 'is', 'on']	['the', 'cat', 'is', 'on']	['the', 'cat', 'is', 'on']	None	None	
None	None	['the', 'cat', 'is', 'on', 'a']	['the', 'cat', 'is', 'on', 'a']	['the', 'cat', 'is', 'on', 'a']	None	
None	None	None	['the', 'cat', 'is', 'on', 'a', 'chair']	['the', 'cat', 'is', 'on', 'a', 'chair']	['the', 'cat', 'is', 'on', 'a', 'chair']	

Layer 5 output:
0: ['the']
1: ['the', 'cat']
2: ['the', 'cat', 'is']
3: ['the', 'cat', 'is', 'on']
4: ['the', 'cat', 'is', 'on', 'a']
5: ['the', 'cat', 'is', 'on', 'a', 'chair']
