In [2]:
import sys
sys.path.append('/Users/tunadorable/local-repos/next-concept-predictor/venv/lib/python3.11/site-packages')

In [3]:
import torch
import torch.nn as nn

In [4]:
class CosineNormalization(nn.Module):
    def __init__(self, dim):
        super(CosineNormalization, self).__init__()
        self.dim = dim

    def forward(self, x):
        norm = x.pow(2).sum(dim=self.dim, keepdim=True).sqrt()
        return x / norm

In [5]:
# Example usage
cosine_norm = CosineNormalization(dim=2)  # normalizing along dimension d in a (b,n,d)
input_tensor = torch.randn(2, 3, 2)  # Example input with b=2, n=3 and d=2
normalized_tensor = cosine_norm(input_tensor)

In [6]:
input_tensor

tensor([[[ 1.8320, -0.0399],
         [-1.9510,  0.6486],
         [-0.7931,  1.0168]],

        [[ 0.7166,  0.9704],
         [ 0.9085, -1.2551],
         [-0.3724, -0.1848]]])

In [7]:
normalized_tensor

tensor([[[ 0.9998, -0.0218],
         [-0.9489,  0.3155],
         [-0.6150,  0.7885]],

        [[ 0.5941,  0.8044],
         [ 0.5863, -0.8101],
         [-0.8957, -0.4446]]])

In [19]:
# now let's check C against E to find our token of interest
import torch.nn.functional as F

In [35]:
def cosine_similarity_batch(C, E):
    # Normalize the vectors in C and E to have unit length
    #C_norm = F.normalize(C, p=2, dim=1)
    #E_norm = F.normalize(E, p=2, dim=1)
    # actually they should be going in already normalized
    # however i do need to look into just using F.normalize instead of the function above

    # Compute cosine similarity
    similarity = torch.matmul(C, E.T)
    print("similarity shape: ", similarity.shape)
    print("similarity: ",similarity)
    
    # Find the indices of the most similar vectors
    most_similar_indices = torch.argmax(similarity, dim=1)
    print("most_similar_indices shape: ",most_similar_indices.shape)
    print("most_similar_indices: ",most_similar_indices)
    
    # Select the vectors from E based on these indices
    selected_vectors = E[most_similar_indices]

    return selected_vectors

In [None]:
# Example Usage
b, d, v = 2, 4, 10  # Example dimensions
C = torch.randn(b, d)  # Tensor of shape (b, d)
E = torch.randn(v, d)  # Embedding matrix of shape (v, d)

In [33]:
cos_norm_1 = CosineNormalization(dim=1)  # normalizing along dimension d in a (x,d)

C_norm = cos_norm_1(C)
E_norm = cos_norm_1(E)

In [39]:
result = cosine_similarity_batch(C_norm, E_norm)
print("result shape: ",result.shape)  # Should be (b, d)

similarity shape:  torch.Size([2, 10])
similarity:  tensor([[ 0.8218,  0.4293, -0.8940, -0.3390, -0.4694, -0.1859,  0.4158, -0.4712,
         -0.2731,  0.8189],
        [-0.8956, -0.4395,  0.6102, -0.2264,  0.0549,  0.0660,  0.0517,  0.1872,
          0.3604, -0.8554]])
most_similar_indices shape:  torch.Size([2])
most_similar_indices:  tensor([0, 2])
result shape:  torch.Size([2, 4])


In [25]:
print(C_norm)
print(E_norm)
print(result)

tensor([[ 0.7605, -0.0776,  0.6437, -0.0352],
        [-0.7925, -0.0571, -0.3291,  0.5102]])
tensor([[ 0.6790,  0.4184,  0.5071, -0.3268],
        [ 0.8471, -0.3412, -0.3651,  0.1809],
        [-0.6143,  0.4927, -0.6089, -0.0958],
        [-0.1163,  0.3717, -0.3900, -0.8344],
        [ 0.1721, -0.0535, -0.9525, -0.2454],
        [-0.6059, -0.3464,  0.3513, -0.6240],
        [ 0.1537,  0.3670,  0.5488,  0.7352],
        [-0.4630,  0.8167, -0.1044, -0.3282],
        [-0.3290,  0.8725,  0.0887,  0.3501],
        [ 0.9893, -0.0718,  0.0897, -0.0900]])
tensor([[ 0.6790,  0.4184,  0.5071, -0.3268],
        [-0.6143,  0.4927, -0.6089, -0.0958]])


In [42]:
# now let's define a "neighborhood" size around each token vector
# if C is not actually within one of these neighborhoods, meaning if it doesn't reach a 
#     critical level of cosine similarity with a vector in E, then it'll just return
#     the vector that was originally in C
def cosine_similarity_batch_with_threshold(C, E, similarity_threshold):

    # Compute cosine similarity
    similarity = torch.matmul(C_norm, E_norm.T)
    print("similarity shape: ", similarity.shape)
    print("similarity: ",similarity)

    # Find the indices and values of the most similar vectors
    most_similar_values, most_similar_indices = torch.max(similarity, dim=1)
    print("most_similar_indices shape: ",most_similar_indices.shape)
    print("most_similar_indices: ",most_similar_indices)
    print("most_similar_values shape: ",most_similar_values.shape)
    print("most_similar_values: ",most_similar_values)

    # Check if the most similar vector is above the threshold
    is_above_threshold = most_similar_values >= similarity_threshold

    # Select the vectors from E or keep the original from C based on the threshold check
    selected_vectors = torch.where(is_above_threshold.unsqueeze(1), E[most_similar_indices], C)

    return selected_vectors

In [43]:
# Example Usage (C, E, and C_norm and E_norm were defined earlier)
similarity_threshold = 0.8  # Define your similarity threshold

result = cosine_similarity_batch_with_threshold(C_norm, E_norm, similarity_threshold)
print("result shape: ",result.shape)  # Should be (b, d)

similarity shape:  torch.Size([2, 10])
similarity:  tensor([[ 0.8218,  0.4293, -0.8940, -0.3390, -0.4694, -0.1859,  0.4158, -0.4712,
         -0.2731,  0.8189],
        [-0.8956, -0.4395,  0.6102, -0.2264,  0.0549,  0.0660,  0.0517,  0.1872,
          0.3604, -0.8554]])
most_similar_indices shape:  torch.Size([2])
most_similar_indices:  tensor([0, 2])
most_similar_values shape:  torch.Size([2])
most_similar_values:  tensor([0.8218, 0.6102])
result shape:  torch.Size([2, 4])


In [41]:
print(C_norm)
print(E_norm)
print(result)
# we see here, at least in the example that's here while i'm writing this, that
#   one vector got replaced and another did not

tensor([[ 0.7605, -0.0776,  0.6437, -0.0352],
        [-0.7925, -0.0571, -0.3291,  0.5102]])
tensor([[ 0.6790,  0.4184,  0.5071, -0.3268],
        [ 0.8471, -0.3412, -0.3651,  0.1809],
        [-0.6143,  0.4927, -0.6089, -0.0958],
        [-0.1163,  0.3717, -0.3900, -0.8344],
        [ 0.1721, -0.0535, -0.9525, -0.2454],
        [-0.6059, -0.3464,  0.3513, -0.6240],
        [ 0.1537,  0.3670,  0.5488,  0.7352],
        [-0.4630,  0.8167, -0.1044, -0.3282],
        [-0.3290,  0.8725,  0.0887,  0.3501],
        [ 0.9893, -0.0718,  0.0897, -0.0900]])
tensor([[ 0.6790,  0.4184,  0.5071, -0.3268],
        [-0.7925, -0.0571, -0.3291,  0.5102]])
