In [1]:
# Environment for a single-process, single-GPU run. These look like the standard
# torch.distributed env:// rendezvous variables, presumably consumed when
# Llama.build initializes its process group / model parallelism
# (its output reports "model parallel with size 1") — TODO confirm.
%env RANK=0
%env WORLD_SIZE=1
%env MASTER_ADDR=127.0.0.1
%env MASTER_PORT=2020

env: RANK=0
env: WORLD_SIZE=1
env: MASTER_ADDR=127.0.0.1
env: MASTER_PORT=2020


In [2]:
# Locations of the Llama weights, its tokenizer, and the trained
# hierarchical-compression checkpoint (relative to the notebook's directory).
llama_checkpoint_dir, tokenizer_path, compression_checkpoint_dir = (
    "modified_llama/llama-2-7b",
    "modified_llama/tokenizer.model",
    "cv_library/attention_model.pt",
)

# Size limits handed to Llama.build.
max_seq_len, max_batch_size = 128, 4

In [3]:
# Two sentences that share the word "Neptune" with different meanings
# (the planet vs. the Roman god); their compressed context vectors are
# compared level by level further down.
content_string, query_string = (
    "Neptune is the eighth and farthest known planet from the Sun.",
    "Neptune is the Roman god of freshwater and the sea in Roman religion.",
)

In [4]:
from modified_llama.llama import Llama
import torch

def _tokenize_single(llama, text):
    """Tokenize one sentence with ``llama.tokenize`` and keep only the token lists.

    ``tokenize`` is called with ``max_seq_len`` and a one-element list holding the
    pair ``("", text)``; it yields pairs whose second element is kept and whose
    first element is discarded.
    """
    return [token_list for _, token_list in llama.tokenize(max_seq_len, [("", text)])]


# Build the (modified) Llama generator that produces context vectors.
print("Building generator...")
generator = Llama.build(
    ckpt_dir=llama_checkpoint_dir,
    tokenizer_path=tokenizer_path,
    max_seq_len=max_seq_len,
    max_batch_size=max_batch_size,
)
print("Built generator!")

# Tokenize both sentences first, then generate context vectors for each,
# asking for as many generation steps as the sentence has tokens.
content_tokens = _tokenize_single(generator, content_string)
query_tokens = _tokenize_single(generator, query_string)

content_cvs = generator.generate(
    content_tokens,
    max_gen_len=len(content_tokens[0]),
)
query_cvs = generator.generate(
    query_tokens,
    max_gen_len=len(query_tokens[0]),
)

Building generator...
> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1


  _C._set_default_tensor_type(t)


Loaded in 11.72 seconds
Built generator!


In [5]:
# Release the Llama generator before loading the compression network.
# `del` only removes this notebook's name for it: gc.collect() reclaims any
# reference cycles still holding the model, and empty_cache() returns the
# now-unused cached CUDA blocks held by PyTorch's allocator.
import gc

del generator
gc.collect()
torch.cuda.empty_cache()

In [6]:
from cv_library.hierarchical_compression import HierarchicalAttention
from cv_library.loss_functions import sequence_similarity
from pathlib import Path
import torch.nn.functional as F

def load_compression_network(input_shape, checkpoint):
    """Build a HierarchicalAttention network and load trained weights into it.

    Args:
        input_shape: shape of the context-vector tensor to be compressed;
            forwarded unchanged to the ``HierarchicalAttention`` constructor.
        checkpoint: mapping returned by ``torch.load`` on the compression
            checkpoint; must contain a ``'model_state_dict'`` entry.

    Returns:
        The network with weights loaded, switched to eval mode for inference.
    """
    network = HierarchicalAttention(input_shape)
    network.load_state_dict(checkpoint['model_state_dict'])
    # nn.Modules are constructed in training mode; eval() disables any
    # dropout / batch-norm updates so the similarity scores are reproducible.
    network.eval()
    return network


# Reset the default dtype to float32 for the compression network — the Llama
# build step appears to change the default tensor type (see its
# `_set_default_tensor_type` warning in the output above).
torch.set_default_dtype(torch.float32)

checkpoint_path = Path(compression_checkpoint_dir)
if not checkpoint_path.is_file():
    # Fail loudly and before building anything. The previous `exit(1)` does not
    # reliably abort a running cell in a Jupyter kernel, so execution could fall
    # through and report scores from an untrained network.
    raise FileNotFoundError(
        f"Path provided for hierarchical compression model not a file: {checkpoint_path}"
    )

with torch.device("cuda"), torch.no_grad():
    # Load the checkpoint once; both networks below share the same weights.
    # NOTE: torch.load unpickles the file — only point it at trusted checkpoints.
    checkpoint = torch.load(checkpoint_path)

    # Each network is sized from its own input's shape.
    content_network = load_compression_network(content_cvs.shape, checkpoint)
    compressed_content = content_network(content_cvs.clone().to(torch.float32))

    # Kept under the original global name so later cells can still use it.
    compression_network = load_compression_network(query_cvs.shape, checkpoint)
    compressed_query = compression_network(query_cvs.clone().to(torch.float32))

    print(content_string)
    print(query_string)
    # One (content, query) output pair per compression level: print that level's
    # vector width and the sequence_similarity score between the two sentences.
    for cc, cq in zip(compressed_content, compressed_query):
        vector_size = cc.shape[-1]
        sim_score = sequence_similarity(cc, cq)
        print(f"{vector_size}: {sim_score.item()}")

Layer #1 output size: 512
Layer #2 output size: 64
Layer #3 output size: 8
Layer #1 output size: 512
Layer #2 output size: 64
Layer #3 output size: 8
Neptune is the eighth and farthest known planet from the Sun.
Neptune is the Roman god of freshwater and the sea in Roman religion.
512: 0.8026992082595825
64: 0.34545978903770447
8: 0.5567335486412048
