In [None]:
# Language of the snippets that are translated from.
source_lang = "C#"

# Candidate languages to translate into. The list may contain `source_lang`
# itself; the main loop skips that entry.
targets = [
    "PHP",
    "Java",
    "Javascript",
    "Python",
    "C++",
    "C",
    "C#",
]

### Import and Load Model

In [None]:
import os

# GPU id(s) this notebook should use, written as a string.
# Change the value to match your own machine.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import numpy as np
from tqdm import tqdm
import json
import torch
import gc

# One shared seed for both PyTorch and NumPy so that repeated runs start
# from the same random state.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
from .llamawrapper import LlamaHelper
from .utils import generate_heatmap

In [None]:
# Change to your own token, model, and cache path

# Prefer the HF_TOKEN environment variable so that a real access token never
# has to be written into (and committed with) the notebook. The placeholder
# below is only used when the variable is not set.
hf_token = os.environ.get("HF_TOKEN", "hf_xxXxXxXXXXxxxxxXXxxxxxXXXXXXXxXXxx")
# custom_model = "codellama/CodeLlama-7b-hf"
custom_model = "meta-llama/Llama-3.1-8B"

cache_directory = './transformers_cache/'
load_in_8bit = False

# `model` and `tokenizer` are only defined when a model id is configured;
# the later cells rely on both names.
if custom_model is not None:
    model = LlamaHelper(
        dir=custom_model,
        device=device,
        load_in_8bit=load_in_8bit,
        hf_token=hf_token,
        cache_directory=cache_directory,
    )
    tokenizer = model.tokenizer

### Compute Lens for Language

In [None]:
# Settings forwarded to `generate_heatmap` by `test`.
num_beams = 10
max_length = 1
# Layer indices 0..31 to inspect.
layers = [layer_idx for layer_idx in range(32)]

# Dataset locations, relative to the notebook's working directory.
parallel_path = "../datasets/parallel/code_snippets"
prompt_path = "../datasets/parallel/prompts"

In [None]:
def get_path(lang, postfix):
    """Return the path of the prompt file for `lang` with the given `postfix`.

    The file name is the lower-cased language name, an underscore, the
    postfix and a `.txt` extension, located under the module-level
    `prompt_path` directory.
    """
    file_name = f'{lang.lower()}_{postfix}.txt'
    return f'{prompt_path}/{file_name}'

def find_last_index(lst, value):
    """Return the index of the last occurrence of `value` in `lst`.

    Raises:
        ValueError: if `value` does not occur in `lst`.
    """
    # Searching a reversed copy finds the last occurrence first; its offset
    # from the end is then converted back to an index from the start.
    reversed_copy = lst[::-1]
    offset_from_end = reversed_copy.index(value)
    return len(lst) - 1 - offset_from_end

def save(data, file_path):
    """Write `data` to `file_path` as JSON.

    The file is created or overwritten, encoded as UTF-8, indented by four
    spaces, and non-ASCII characters are escaped.
    """
    with open(file_path, 'w', encoding='utf-8') as out_file:
        json.dump(
            data,
            out_file,
            indent=4,
            ensure_ascii=True,
        )

In [None]:
def prompt(source_snippets, target_snippets, source_lang, target_lang):
    """Build a few-shot translation prompt from paired snippets.

    Each source/target pair becomes one line of the form
    `<source_lang>: <source> - <target_lang>: <target>`. Pairs are formed
    with `zip`, so extra items in the longer list are ignored. Surrounding
    whitespace of the combined text is stripped before it is returned.
    """
    lines = (
        f'{source_lang}: {s} - {target_lang}: {t}\n'
        for s, t in zip(source_snippets, target_snippets)
    )
    return ''.join(lines).strip()


def test(intext, soruce_text):
    """Return the result of `generate_heatmap` for `intext`.

    Args:
        intext: Full prompt text passed to `generate_heatmap`.
        soruce_text: Text whose token count is used as `min_position`
            (the misspelled name is kept for interface compatibility).

    The token count of `intext` is used as `max_position`. The model,
    tokenizer, device, layers and beam settings are read from module-level
    names.
    """
    prefix_token_count = len(tokenizer.tokenize(soruce_text))
    total_token_count = len(tokenizer.tokenize(intext))

    return generate_heatmap(
        model=model,
        tokenizer=tokenizer,
        device=device,
        text=intext,
        layers=layers,
        num_beams=num_beams,
        max_length=max_length,
        min_position=prefix_token_count,
        max_position=total_token_count,
        batch_size=1,
    )

In [None]:
# Main: for every target language, build a few-shot translation prompt and
# store the heatmap data returned by `test` for each parallel snippet.

max_num_parallel_sent = 100


def _read_stripped_lines(path):
    """Return the lines of the UTF-8 text file at `path`, each stripped."""
    with open(path, 'r', encoding='utf-8') as file:
        return [line.strip() for line in file]


def _read_snippet(path):
    """Return the stripped 'snippet' field of the UTF-8 JSON file at `path`."""
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)['snippet'].strip()


source_path = get_path(source_lang, 'fewshot')
source_snippets = _read_stripped_lines(source_path)

# Computed outside the f-string below: reusing the same quote character inside
# an f-string expression is a SyntaxError before Python 3.12.
model_tag = custom_model.replace('/', '-')

# `os.listdir` returns entries in arbitrary order; sorting makes the slice pick
# the same files on every run. The listing does not depend on the target
# language, so it is done once.
file_ids = sorted(os.listdir(f'{parallel_path}/{source_lang}'))[:max_num_parallel_sent]

for target_lang in targets:

    if target_lang == source_lang:
        continue

    target_path = get_path(target_lang, 'fewshot')
    target_snippets = _read_stripped_lines(target_path)

    output_path = f'./outputs-{model_tag}/{source_lang.lower()}-{target_lang.lower()}'

    # `makedirs` also creates the missing `./outputs-<model>` parent directory,
    # which `os.mkdir` cannot do, and tolerates an existing directory.
    os.makedirs(output_path, exist_ok=True)

    p = prompt(source_snippets, target_snippets, source_lang, target_lang)

    for file_id in tqdm(file_ids):

        result_path = os.path.join(output_path, file_id)

        # Results from an earlier run are kept, which makes the loop resumable.
        if os.path.exists(result_path):
            continue

        source_test = [_read_snippet(f'{parallel_path}/{source_lang}/{file_id}')]
        target_test = [_read_snippet(f'{parallel_path}/{target_lang}/{file_id}')]
        p_test = prompt(source_test, target_test, source_lang, target_lang)

        intext = p + '\n' + p_test

        # `intext` ends with the (already stripped) target snippet, so the
        # prefix up to and including "<target_lang>: " is everything before it.
        # Searching for the language name instead would also match occurrences
        # inside the snippet itself (e.g. any capital "C" when the target is C).
        last_index = len(intext) - len(target_test[0])

        try:
            with torch.no_grad():
                heatmap_data = test(intext, intext[:last_index])

            save(heatmap_data, result_path)
            del heatmap_data

        except Exception as exc:
            # Best effort: report the failure and move on to the next file.
            # No output file is written, so a later run retries this one.
            # KeyboardInterrupt is deliberately not caught.
            print(f'Failed on {file_id} ({source_lang} -> {target_lang}): {exc!r}')

        finally:
            # Release cached GPU memory and collect garbage after every file,
            # whether or not it succeeded.
            torch.cuda.empty_cache()
            gc.collect()
            gc.collect()