In [1]:
import torch

from minicons import cwe

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def cosine(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Pairwise cosine similarity between the rows of two 2-D tensors.

    Args:
        a: Tensor of shape ``(n, d)``; each row is treated as one vector.
        b: Tensor of shape ``(m, d)``; each row is treated as one vector.
        eps: Lower bound applied to every row norm before dividing, so an
            all-zero row yields similarities of 0 rather than NaN.

    Returns:
        Tensor of shape ``(n, m)`` whose ``[i, j]`` entry is the cosine
        similarity between ``a[i]`` and ``b[j]``.
    """
    # Floor each row norm at eps so zero rows never divide by zero;
    # keepdim=True keeps the norms as a column that broadcasts over each row.
    a_unit = a / a.norm(dim=1, keepdim=True).clamp(min=eps)
    b_unit = b / b.norm(dim=1, keepdim=True).clamp(min=eps)
    # Dot products of unit-length rows are the cosine similarities.
    return torch.mm(a_unit, b_unit.transpose(0, 1))

In [3]:
# Build the minicons `cwe.CWE` wrapper for the 'bert-base-uncased' model
# identifier, pinned to the CPU.
model = cwe.CWE(
    "bert-base-uncased",
    device="cpu",
)

In [14]:
# (sentence, word) pairs: every sentence contains the word "fair", used in a
# different context each time.
queries = [
    (sentence, "fair")
    for sentence in (
        "There is a fair number of bright stars, both single and double, in Lepus.",
        "Using most or all of a work does not bar a finding of fair use.",
        "The rivalry has had its fair share of fights as well.",
    )
]

In [15]:
# Extract representations for the queries from layers 2 and 8 only.
layer_embs = model.extract_representation(
    queries,
    layer=[2, 8],
)

In [16]:
# Show the extraction result: a list with one tensor per requested layer,
# each holding one row per query (three rows here).
layer_embs

[tensor([[ 1.2002,  1.0760,  0.6308,  ..., -0.6405, -1.1925, -1.4545],
         [ 1.2930,  1.4050,  0.1073,  ..., -0.0766, -1.3806, -1.1698],
         [ 1.2247,  1.1270,  0.4763,  ..., -0.0422, -1.3762, -1.2325]]),
 tensor([[ 0.8524,  0.1222,  0.1619,  ..., -0.7295,  0.1747,  0.0786],
         [ 0.8999, -0.2127,  0.1966,  ...,  0.2875,  0.1784, -0.4899],
         [ 0.6417, -0.6083, -0.1415,  ..., -0.2819, -0.4101, -0.2447]])]

In [17]:
# Re-run the extraction with layer='all'. This rebinds `layer_embs`, replacing
# the two-layer result computed earlier.
# NOTE(review): presumably the returned list is indexed by layer number; confirm
# what index 0 corresponds to in minicons before interpreting "Layer 0".
layer_embs = model.extract_representation(
    queries,
    layer="all",
)

In [25]:
# For a few layers, compare the first query against every query: row 0 of the
# query-by-query cosine similarity matrix.
for layer in (0, 8, 11):
    embs = layer_embs[layer]
    first_sim = cosine(embs, embs)[0]
    print(f"Layer {layer}: {first_sim}")

Layer 0: tensor([1.0000, 0.9374, 0.9580])
Layer 8: tensor([1.0000, 0.4051, 0.6927])
Layer 11: tensor([1.0000, 0.3901, 0.6451])
