In [None]:
! pip3 install code-lens@git+https://github.com/cisnlp/code-lens

### Imports

In [None]:
import os 

# GPU selection: edit the id list to match your machine (comma-separated ids).
# This cell runs before the cell that imports torch.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
from tqdm import tqdm

# Reproducibility: seed both torch and NumPy with the same value.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
from code_lens import LlamaHelper
from code_lens import generate_heatmap
from code_lens import visualize_heatmap

In [None]:
# Change to your own token, model, and cache path.
#
# The Hugging Face token is taken from the HF_TOKEN environment variable when
# it is set, so a real secret never has to be saved inside the notebook. The
# placeholder below is only a fallback; replace it if HF_TOKEN is not exported.
hf_token = os.environ.get('HF_TOKEN', 'hf_XxxxXXXxxXxxXXxxxxXXxXXxxXxXXxxxxX')
custom_model = "codellama/CodeLlama-7b-hf"
cache_directory = './transformers_cache/'
load_in_8bit = True # False

# `model` and `tokenizer` are only defined when a model id is given; the
# heatmap cell further down needs both of them.
if custom_model is not None:
    model = LlamaHelper(dir=custom_model, load_in_8bit=load_in_8bit, hf_token=hf_token,cache_directory=cache_directory)
    tokenizer = model.tokenizer

### Example

In [None]:
# Java statements; entry i is paired with entry i of rust_snippets below.
java_snippets = [
    'String message = "Hello";',
    'public class MyClass {}',
    'public int value = 5;',
    'public void doSomething() {}',
    'int result = add(3, 5);',
    'for (int i = 0; i < 10; i++)',
    'if (x > 5) { /* ... */ }',
    'try { /* ... */ } catch (Exception e) { /* ... */ }',
    'System.out.println("Hello");'
]

# Rust counterparts, in the same order as java_snippets.
rust_snippets = [
    'let message = "Hello";',
    'struct MyClass {}',
    'let value: i32 = 5;',
    'fn do_something() {}',
    'let result = add(3, 5);',
    'for i in 0..10',
    'if x > 5 { /* ... */ }',
    'match result { Ok(value) => { /* ... */ }, Err(e) => { /* ... */ } }',
    'println!("Hello");'
]

# One "Java: ... - Rust: ..." line per pair, newline-separated, with any
# leading/trailing whitespace stripped from the final string.
prompt = "\n".join(
    f'Java: {java_line} - Rust: {rust_line}'
    for java_line, rust_line in zip(java_snippets, rust_snippets)
).strip()

# Show the assembled prompt.
print(prompt)

In [None]:
# Decoding settings forwarded to generate_heatmap.
num_beams = 3
max_length = 2

# Layers to inspect and the position window to analyse.
# NOTE(review): the positions presumably index tokens of the tokenized prompt,
# so they would need adjusting if the prompt or tokenizer changes - confirm.
layers = [10, 20, 21, 25, 31]
min_position = 123
max_position = 134

heatmap_data = generate_heatmap(
    model=model,
    tokenizer=tokenizer,
    device=device,
    text=prompt,
    layers=layers,
    num_beams=num_beams,
    max_length=max_length,
    min_position=min_position,
    max_position=max_position,
)

In [None]:
# Show every layer that was analysed.
layers_to_show = heatmap_data['layers']

# Token index window passed to the plot: the start is capped at the number of
# tokens, the stop is floored at 0.
# NOTE(review): the stop is not capped at the token count and the start is not
# floored at 0 - confirm whether the two clamps are the intended way round.
window_start = min(min_position, len(heatmap_data['tokens']))
window_stop = max(max_position, 0)
token_indices_to_show = range(window_start, window_stop)

visualize_heatmap(heatmap_data, layers_to_show, token_indices_to_show, trunc_size=6)