# Extracting Contextual Representations from LMs

In [7]:
import torch

from minicons import cwe
from minicons.utils import character_span # for demonstrating.

## Cosine similarity calculation setup

Below I've defined a function that computes the cosine similarity between every row of tensor A (usually a 2D tensor) and every row of tensor B (also usually a 2D tensor).

For instance, if A is an n x d matrix and B is an m x d matrix, then the resulting matrix will be n x m, where entry (i, j) is the cosine similarity between row i of A and row j of B.

In [5]:
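def cosine(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Pairwise cosine similarity between the rows of a (n x d) and b (m x d)."""
    # Row-wise L2 norms, kept as column vectors so they broadcast over d.
    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
    # Normalize each row, flooring the norm at eps to avoid division by zero.
    a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
    b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
    # Dot products of the normalized rows give an n x m similarity matrix.
    sims = torch.mm(a_norm, b_norm.transpose(0, 1))
    return sims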

In [6]:
# Example with random matrices.
# Both must share the same d dimension (here, 32).

A = torch.randn(3, 32)
B = torch.randn(4, 32)

cosine(A, B)

tensor([[ 0.0463,  0.0113, -0.1011, -0.0974],
        [ 0.0550,  0.0675, -0.2210, -0.1196],
        [-0.1084, -0.3387,  0.2003, -0.0311]])
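
As a sanity check, the same n x m matrix can be reproduced with PyTorch's built-in `torch.nn.functional.cosine_similarity` by broadcasting the two inputs against each other. This is only a sketch for verification: the seed, the names `sims` and `reference`, and the tolerance are choices made here, and `cosine` is restated so the snippet runs on its own.

```python
import torch
import torch.nn.functional as F


def cosine(a, b, eps=1e-8):
    # Same computation as the function defined above, restated for a standalone run.
    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
    a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
    b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
    return torch.mm(a_norm, b_norm.transpose(0, 1))


torch.manual_seed(0)  # arbitrary seed, only for reproducibility
A = torch.randn(3, 32)
B = torch.randn(4, 32)

sims = cosine(A, B)

# Broadcast (3, 1, 32) against (1, 4, 32) and reduce over the last dimension.
reference = F.cosine_similarity(A[:, None, :], B[None, :, :], dim=-1)

print(sims.shape)  # torch.Size([3, 4])
print(torch.allclose(sims, reference, atol=1e-6))
```

If the two results agree, the hand-written function and the built-in one can be used interchangeably for non-zero rows.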

## Initializing the Model

Use the `cwe.CWE` class to initialize an LM. It works with a variety of different models: BERT, RoBERTa, GPT2, etc.

In [3]:
model = cwe.CWE('bert-base-uncased', device="cpu")

## Data formatting

The function primarily used for extracting representations from models is `model.extract_representation()`. It accepts batches of instances represented in either of the following formats:

```
data = [
  (sentence_1, word_1),
  (sentence_2, word_2),
  ....
  (sentence_n, word_n)
]
```
or

```
data = [
  (sentence_1, (start_1, end_1)),
  (sentence_2, (start_2, end_2)),
  ....
  (sentence_n, (start_n, end_n))
]
```
where `(start_i, end_i)` are the character span indices for the target word in the i-th sentence, i.e., `start_i` is the index of the word's first character, and `end_i` is the index one past its last character.

For example, the instance `("I like reading books.", (15, 20))` corresponds to the word `"books"`.
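
A quick way to check a span is to slice the sentence with it. The snippet below uses only built-in string slicing, and the variable names are illustrative:

```python
sentence = "I like reading books."
start, end = 15, 20

# Python slicing excludes the end index, so this picks characters 15 through 19.
target = sentence[start:end]
print(target)  # books
```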

Regardless of which format is specified, `extract_representation()` reduces the input to the second format. For instance, to get the character span indices of *fair* in the first sentence:

In [9]:
queries = [
    ("There is a fair number of bright stars, both single and double, in Lepus.", "fair"),
    ("Using most or all of a work does not bar a finding of fair use.", "fair"),
    ("The rivalry has had its fair share of fights as well.", "fair")
]

In [11]:
# Example of what a character span looks like. Here, "fair" occupies
# character offsets 11 through 14 (0-indexed; the end index 15 is exclusive)
# in the sentence that is the first element of the queries object.
character_span(queries[0][0], 'fair')

(11, 15)
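
To make the reduction to the second format concrete, here is a minimal sketch that converts `(sentence, word)` pairs into `(sentence, (start, end))` pairs. The helper `to_span_format` is hypothetical and written only for illustration; it assumes the first occurrence of the word is the target, and minicons' own handling may differ.

```python
from typing import List, Tuple


def to_span_format(data: List[Tuple[str, str]]) -> List[Tuple[str, Tuple[int, int]]]:
    """Hypothetical helper: turn (sentence, word) pairs into (sentence, span) pairs."""
    converted = []
    for sentence, word in data:
        start = sentence.find(word)  # first occurrence only (an assumption)
        if start == -1:
            raise ValueError(f"{word!r} not found in {sentence!r}")
        converted.append((sentence, (start, start + len(word))))
    return converted


queries = [
    ("There is a fair number of bright stars, both single and double, in Lepus.", "fair"),
    ("Using most or all of a work does not bar a finding of fair use.", "fair"),
    ("The rivalry has had its fair share of fights as well.", "fair"),
]

span_queries = to_span_format(queries)
print(span_queries[0][1])  # (11, 15), matching the character_span output above

# Slicing each sentence with its span recovers the target word.
for sentence, (start, end) in span_queries:
    print(sentence[start:end])
```

Spans are the safer format when the target word appears more than once in a sentence, since a plain word match cannot say which occurrence is meant.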

## Extracting representations

Below is some code to extract the representations of `fair` from each of the query sentences, from layers 2 and 8.

In [15]:
layer_embs = model.extract_representation(queries, layer=[2,8])

In [16]:
layer_embs

[tensor([[ 1.2002,  1.0760,  0.6308,  ..., -0.6405, -1.1925, -1.4545],
         [ 1.2930,  1.4050,  0.1073,  ..., -0.0766, -1.3806, -1.1698],
         [ 1.2247,  1.1270,  0.4763,  ..., -0.0422, -1.3762, -1.2325]]),
 tensor([[ 0.8524,  0.1222,  0.1619,  ..., -0.7295,  0.1747,  0.0786],
         [ 0.8999, -0.2127,  0.1966,  ...,  0.2875,  0.1784, -0.4899],
         [ 0.6417, -0.6083, -0.1415,  ..., -0.2819, -0.4101, -0.2447]])]

You can also just specify `layer='all'` to get representations from all layers!

In [17]:
layer_embs = model.extract_representation(queries, layer='all')

## Cosine computation

Now let's compute the cosine similarity between the representations of "fair" for every pair of query sentences. This gives us a 3x3 matrix, but for brevity, I will print only the first row. This row holds the similarity of "fair" in the first sentence to "fair" in each of the three sentences (which means the first element will be 1.0, since it compares a representation with itself). For simplicity, I will compute these values for layers 0, 8, and 11.

In [25]:
for layer in [0, 8, 11]:
    first_sim = cosine(layer_embs[layer], layer_embs[layer])[0]
    print(f"Layer {layer}: {first_sim}")

Layer 0: tensor([1.0000, 0.9374, 0.9580])
Layer 8: tensor([1.0000, 0.4051, 0.6927])
Layer 11: tensor([1.0000, 0.3901, 0.6451])
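
The full 3x3 matrix has two properties worth knowing when reading these rows: its diagonal is all ones and it is symmetric. The sketch below illustrates this with random stand-in tensors in place of real embeddings (the shape `(3, 768)` and the seed are assumptions made so the snippet runs without loading a model). It uses `torch.nn.functional.normalize` rather than the `cosine` function defined earlier, which computes the same quantity for non-zero rows.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

# Stand-in for one layer's output: 3 query words x 768 hidden dimensions.
# These are random values, NOT real BERT embeddings.
fake_layer = torch.randn(3, 768)

# Unit-normalize each row, then take all pairwise dot products.
unit = F.normalize(fake_layer, dim=1)
sim_matrix = unit @ unit.T

print(sim_matrix.shape)  # torch.Size([3, 3])

# Each representation compared with itself gives 1.0 (up to float error).
print(torch.allclose(sim_matrix.diagonal(), torch.ones(3), atol=1e-5))

# sim(i, j) equals sim(j, i), so only the upper triangle carries information.
print(torch.allclose(sim_matrix, sim_matrix.T, atol=1e-6))
```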
