# Install PyTorch (CUDA 12.4 build) and minicons

In [12]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
!pip install minicons

Looking in indexes: https://download.pytorch.org/whl/cu124



[notice] A new release of pip is available: 23.2.1 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip





[notice] A new release of pip is available: 23.2.1 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import os
import random
from typing import Optional

import torch
from minicons import cwe

# Report whether PyTorch can see a CUDA device; the value is shown as the cell output.
torch.cuda.is_available()

True

# Load stimuli and model

In [3]:
# Read every stimulus file in the folder into a list of whitespace-separated words.
stimuli_folder = "./stimuli"
stimuli = []
# Sort the listing: os.listdir() returns entries in arbitrary order, which
# would make positional lookups such as stimuli[0] differ between runs/machines.
for file in sorted(os.listdir(stimuli_folder)):
    stimulus_path = os.path.join(stimuli_folder, file)
    # Skip sub-directories (e.g. Jupyter's .ipynb_checkpoints); open() fails on them.
    if not os.path.isfile(stimulus_path):
        continue
    with open(stimulus_path, 'r') as stimulus:
        file_contents = stimulus.read()
    # str.split() with no argument splits on any run of whitespace, newlines
    # included. Deleting the newlines first would glue the last word of a line
    # to the first word of the next line.
    text_split_at_space = file_contents.split()
    stimuli.append(text_split_at_space)

# Load model: run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
model = cwe.CWE(model_name='bert-base-uncased', device=device)



# Produce sliding-window context representations

In [4]:
# Function to get sliding window context representations
def get_sliding_window_context_representations(data: list[str], window_size: int, layer: Optional[int]) -> list[list]:
    """Extract one representation per word, each computed within a trailing context window.

    For every position ``pos`` in ``data`` the context is the target word
    preceded by up to ``window_size`` earlier words, joined with single
    spaces (so a context holds at most ``window_size + 1`` words). The pair
    ``[context_string, target_word]`` is passed to the module-level ``model``
    through ``model.extract_representation``.

    Args:
        data: The words of one story, in reading order.
        window_size: Number of preceding words to include as context.
            Must be non-negative.
        layer: Forwarded unchanged to ``model.extract_representation``.
            ``None`` is accepted here (presumably the library's default
            layer selection -- TODO confirm against the minicons docs).

    Returns:
        A list with one entry per word of ``data``. Each entry is a two-item
        list ``[[context_string, target_word], representation]``, where
        ``representation`` is whatever ``model.extract_representation``
        returned for that pair.

    Raises:
        ValueError: If ``window_size`` is negative. The context slice would
            then be empty and would not contain the target word.

    NOTE(review): the target is handed to the model as a plain string. When
    the same word also occurs earlier inside the window, which occurrence the
    model resolves it to depends on minicons -- confirm it is the final one.
    """
    if window_size < 0:
        raise ValueError(f"window_size must be non-negative, got {window_size}")

    results = []
    for pos, target in enumerate(data):
        # Clamp at 0 so the first few words simply get a shorter context.
        start = max(0, pos - window_size)
        context_words = [" ".join(data[start:pos + 1]), target]
        representation = model.extract_representation(context_words, layer=layer)
        results.append([context_words, representation])
    return results

## Extract context representations for one story

In [5]:
# Representations for the first loaded story: 4 preceding words of context per target word.
layer = 12
story_representations = get_sliding_window_context_representations(
    data=stimuli[0],
    window_size=4,
    layer=layer,
)

  attn_output = torch.nn.functional.scaled_dot_product_attention(


In [6]:
# Spot-check one randomly chosen position of the story.
random_index = random.randint(0, len(story_representations) - 1)
context, context_representation = story_representations[random_index]

# Report the layer through the `layer` variable used for the extraction, so
# the label cannot drift from the value that was actually requested.
print(f"Random context (index {random_index}) with its layer {layer} representation:\n{context}\n{context_representation}")


Random context (index 1595) with its layer 12 representation:
['to them. And {BR} I', 'I']
tensor([[-9.7179e-01,  5.6414e-02, -7.4843e-02, -3.6394e-01,  6.5555e-01,
          3.7119e-01,  2.6344e-01,  8.7313e-01, -3.3810e-01, -4.6067e-01,
         -4.3446e-01,  2.5819e-02,  1.5414e-01, -6.9084e-02, -1.0626e-01,
         -1.2962e-01,  4.3342e-01,  2.5932e-03,  6.1435e-02, -1.1143e-01,
         -3.5935e-02, -6.9124e-02, -1.0290e+00, -4.1050e-01,  4.9770e-01,
          8.0817e-02, -3.2828e-01, -2.4316e-01,  3.9930e-01,  2.3143e-01,
          4.0242e-01, -8.3865e-02, -1.2525e-02,  4.3984e-01, -1.3059e+00,
         -1.8592e-01, -2.5856e-01,  4.9956e-01, -6.0177e-01,  3.8538e-01,
         -3.1858e-01, -5.5120e-01,  2.2789e-01, -8.7787e-01,  1.0639e-01,
         -4.5276e-01,  9.7732e-01,  3.8869e-01,  2.6941e-01,  6.6239e-01,
         -1.0177e+00, -9.0851e-03, -3.4074e-02, -2.4615e-01,  1.0542e-02,
         -2.7447e-01,  6.7200e-01,  4.6330e-01, -8.6003e-01, -5.5495e-01,
          1.1377e+00,

## Extract context representations for all stories

In [7]:
# One list of representations per story, in the same order as `stimuli`.
# The loop variable is comprehension-local, so no module-level name is rebound.
all_representations = [
    get_sliding_window_context_representations(data=story_words, window_size=5, layer=None)
    for story_words in stimuli
]



In [11]:
# Show the context pair and the representation stored at position 10 of the first story.
context, representation = all_representations[0][10]

print(context, representation, sep="\n")

['that the universe that we inhabit', 'inhabit']
tensor([[ 5.4745e-01,  8.8868e-01,  2.0234e-01, -2.6698e-01,  4.1239e-01,
          1.0069e-01,  3.8865e-01,  7.0537e-01,  5.4214e-01, -1.2111e+00,
          4.8280e-01, -1.4419e-01, -1.0896e+00,  1.0598e+00, -4.1688e-01,
          4.1683e-01, -2.0451e-01,  5.6005e-03, -2.1767e-01,  4.8985e-01,
         -3.9682e-01,  2.0072e-01, -2.9084e-01,  2.3084e-01,  2.4638e-01,
         -1.9373e-02, -1.0311e-01,  8.1191e-01, -1.3286e-01, -4.1009e-01,
         -2.4934e-01,  1.1043e+00,  1.2458e-01,  2.3373e-01, -5.7065e-02,
          5.9015e-01, -1.8757e-02, -5.5312e-02, -1.3660e-01,  5.1150e-01,
          2.4910e-01, -3.0834e-01,  2.7299e-01, -1.6936e-01, -2.1597e-01,
         -2.5065e-01,  4.3360e-01,  2.6983e-01,  3.2872e-01,  6.1394e-01,
         -5.4159e-01,  1.0774e+00, -7.9726e-01,  4.7545e-02,  8.0608e-01,
          9.8120e-01,  6.5279e-01, -5.5325e-01,  6.5457e-02, -3.2394e-01,
         -1.7371e-01,  9.0868e-01,  1.0276e-01, -2.7416e-01, -8