In [None]:
import gc
import os
import sys

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from datasets import load_dataset
from matplotlib.ticker import MaxNLocator

sys.path.append(os.path.abspath(".."))  # add parent folder to path
from utils import extract_token_hidden_states, setup_tokenizer, setup_model, get_device

def prepare_dataset(lang1, lang2, tokenizer, n_samples=1000, random_state=2025, min_chars=3):
    """Load a Tatoeba parallel corpus and return a tokenized random sample of sentence pairs.

    Args:
        lang1, lang2: language codes passed to ``load_dataset("tatoeba", ...)``; they also
            become the names of the two sentence columns.
        tokenizer: object exposing ``tokenize(str)`` (e.g. a Hugging Face tokenizer).
        n_samples: number of pairs to sample. If fewer pairs survive filtering, all of them
            are returned instead of ``DataFrame.sample`` raising ``ValueError``.
        random_state: seed for the sample, so repeated calls (e.g. one per model) draw
            the same pairs.
        min_chars: pairs where either sentence is shorter than this many characters are
            dropped.

    Returns:
        DataFrame with columns ``[lang1, lang2, f"{lang1}_tokens", f"{lang2}_tokens"]`` and a
        fresh 0..n-1 index; row i holds one translation pair and its token lists.
    """
    raw = load_dataset("tatoeba", lang1=lang1, lang2=lang2)['train'].to_pandas()

    # Each 'translation' cell is a {lang_code: sentence} mapping; split it into one
    # column per language and keep only those two columns.
    pairs = raw.assign(
        **{
            lang1: raw['translation'].apply(lambda t: t[lang1]),
            lang2: raw['translation'].apply(lambda t: t[lang2]),
        }
    )[[lang1, lang2]]

    # Drop pairs where either side is too short to be a meaningful sentence.
    long_enough = (pairs[lang1].str.len() >= min_chars) & (pairs[lang2].str.len() >= min_chars)
    pairs = pairs[long_enough]

    sample = (
        pairs
        .sample(n=min(n_samples, len(pairs)), random_state=random_state)
        .reset_index(drop=True)
    )
    return sample.assign(
        **{
            f'{lang1}_tokens': sample[lang1].apply(tokenizer.tokenize),
            f'{lang2}_tokens': sample[lang2].apply(tokenizer.tokenize),
        }
    )

def compute_crosslingual_cosine(hidden_en, hidden_ko, top_k=3):
    results = {}

    for layer in hidden_en:
        en_vecs = hidden_en[layer]  # shape: (N, D)
        ko_vecs = hidden_ko[layer]  # shape: (N, D)

        # Normalize to unit vectors for cosine similarity
        en_norm = F.normalize(en_vecs, p=2, dim=1)  # (N, D)
        ko_norm = F.normalize(ko_vecs, p=2, dim=1)  # (N, D)

        # Compute cosine similarity: (N x D) @ (D x N) = (N x N)
        sim_matrix = en_norm @ ko_norm.T  # (N, N)

        # For each English vector, get top-k most similar Korean vectors
        topk_values, topk_indices = torch.topk(sim_matrix, k=top_k, dim=1)  # (N, top_k)

        # Check if correct alignment exists in top-k (optional accuracy check)
        correct = torch.arange(sim_matrix.size(0)).to(topk_indices.device)
        hits = (topk_indices == correct.unsqueeze(1)).any(dim=1).float()  # 1 if correct in top-k

        results[layer] = {
            "similarity_matrix": sim_matrix,
            "topk_indices": topk_indices,
            "topk_values": topk_values,
            "topk_accuracy": hits.mean().item(),  # overall top-k accuracy
        }

    return results

In [None]:
# Sweep every (model, language pair): extract per-layer hidden states for both sides of
# the sampled parallel sentences and record layer-wise top-k retrieval accuracy as
# results_dict[model_name]["<lang1>-<lang2>"] = (layers, accuracies).
language_pairs = [("en", "ko"), ("de", "en"), ("de", "ko")]
models = ["Tower-Babel/Babel-9B-Chat", "google/gemma-3-12b-it", "meta-llama/Llama-2-7b-chat-hf"]
top_k = 3  # retrieval depth passed to compute_crosslingual_cosine

results_dict = {}

device = get_device()

for model_name in models:
    tokenizer = setup_tokenizer(model_name)
    model = setup_model(model_name, device=device)
    results_dict[model_name] = {}

    for lang1, lang2 in language_pairs:
        dataset = prepare_dataset(lang1, lang2, tokenizer)

        hidden_1 = extract_token_hidden_states(
            model=model,
            tokenizer=tokenizer,
            inputs=dataset[f'{lang1}_tokens'].tolist(),
            tokenizer_name=model_name,
            device=device
        )

        hidden_2 = extract_token_hidden_states(
            model=model,
            tokenizer=tokenizer,
            inputs=dataset[f'{lang2}_tokens'].tolist(),
            tokenizer_name=model_name,
            device=device
        )

        results = compute_crosslingual_cosine(hidden_1, hidden_2, top_k=top_k)

        layers = sorted(results.keys())
        accuracies = [results[l]['topk_accuracy'] for l in layers]

        results_dict[model_name][f"{lang1}-{lang2}"] = (layers, accuracies)

    # Release this model and the per-layer tensors derived from it before the next
    # checkpoint is loaded. Otherwise the old model is still referenced while
    # `setup_model` builds the next one, so two large models are resident at once.
    # torch.cuda.empty_cache() is harmless when CUDA is not in use.
    del model, hidden_1, hidden_2, results
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# One panel per model: top-k retrieval accuracy vs. layer for every language pair.
# `squeeze=False` always returns a 2-D array of Axes, so a single model needs no special
# case. The figure width scales with the number of panels so the titles don't collide.
fig, axes = plt.subplots(
    1,
    len(models),
    figsize=(6 * len(models), 5),
    sharex=False,
    sharey=True,
    squeeze=False,
)

for col_idx, model_name in enumerate(models):
    ax = axes[0, col_idx]
    model_short = model_name.split("/")[-1]

    pairs_for_model = results_dict.get(model_name, {})
    if not pairs_for_model:
        ax.set_title(f"{model_short}\n(no results)")
        ax.axis('off')
        continue

    for pair, (layers, accs) in pairs_for_model.items():
        ax.plot(layers, accs, marker='o', label=pair, alpha=0.7)

    # Fit axes per column
    ax.relim()
    ax.autoscale(enable=True, axis='both', tight=True)
    ax.margins(x=0.02, y=0.05)
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))  # integer layer ticks

    ax.set_title(model_short, fontsize=20)
    if col_idx == 0:
        ax.set_ylabel("Top-K Accuracy")
    ax.set_xlabel("Layer")
    ax.grid(True)
    ax.legend(title="Lang pair", fontsize=8)

plt.tight_layout()
plt.show()


In [None]:
                # NOTE(review): this cell is a fragment. It begins inside a loop body whose
                # headers are missing. It also uses `values`, `src`, `tgt`, `pair_group`,
                # `i`, `j` and `metric`, which are not defined in any visible cell. Because
                # the first line is indented at the top of the cell, running it as-is
                # raises IndentationError. Restore the enclosing loops (and the figure/axes
                # setup) or delete the cell.
                # Draw one directional series (src -> tgt) on the current panel.
                ax.plot(
                    layers,
                    values,
                    marker='o',
                    label=f"{src} → {tgt}",
                    alpha=0.7
                )
                # The title names the model and the first pair of the group being drawn.
                ax.set_title(f"{model_name.split('/')[-1]}: {pair_group[0][0]}↔{pair_group[0][1]}", fontsize=20)
                ax.grid(True)

        # Axis labels only on the outer edge of the grid: x on the last row, y on the
        # first column (assumes i indexes models/rows and j columns; TODO confirm).
        if i == len(models) - 1:
            ax.set_xlabel("Layer", fontsize=18)
        if j == 0:
            ax.set_ylabel(metric, fontsize=18)

        # Grid, tick styling and legend for this panel
        ax.grid(True)
        ax.tick_params(axis='both', which='major', labelsize=14)
        ax.legend(title="Language Pair", fontsize=14)

# Adjust layout and show plot
plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leaves a top margin, presumably for a suptitle set in the missing part
plt.show()

In [None]:
import os
import sys

# Make the parent directory (the RQ1 directory, per the original note) importable so
# that `classification` can be found.
sys.path.append(os.path.abspath(os.pardir))
from classification import WordNonwordClassifier

# Checkpoint to analyse. Other checkpoints this notebook has been run with:
#   "Tower-Babel/Babel-9B-Chat", "google/gemma-3-12b-pt", "meta-llama/Llama-2-7b-chat-hf"
model_name = "google/gemma-3-12b-it"

# The constructor requires a language argument; per the original note it plays no role
# in which model is loaded.
word_nonword_cls = WordNonwordClassifier("English", model_name)

In [None]:
# English–Korean with the classifier's model: extract per-layer hidden states, cache them
# to disk, then plot layer-wise top-k retrieval accuracy.
lang1 = "en"
lang2 = "ko"

dataset_en_ko = prepare_dataset(lang1, lang2, word_nonword_cls.tokenizer)

hidden_1 = word_nonword_cls.extract_token_i_hidden_states(dataset_en_ko[f'{lang1}_tokens'].tolist())
hidden_2 = word_nonword_cls.extract_token_i_hidden_states(dataset_en_ko[f'{lang2}_tokens'].tolist())

# The cache directory defaults to the original absolute path. Set TATOEBA_HIDDENS_DIR to
# redirect it on another machine. The directory is created if missing so torch.save
# doesn't fail on a fresh checkout.
hiddens_dir = os.environ.get(
    "TATOEBA_HIDDENS_DIR",
    "/home/hyujang/multilingual-inner-lexicon/data/RQ1/TatoebaHiddens",
)
os.makedirs(hiddens_dir, exist_ok=True)
# Computed outside the f-string: nesting double quotes inside a double-quoted f-string
# (as in {model_name.split("/")[-1]}) is a SyntaxError before Python 3.12.
model_short = model_name.split("/")[-1]
torch.save(hidden_1, os.path.join(hiddens_dir, f"hidden_{model_short}_{lang1}_1.pt"))
torch.save(hidden_2, os.path.join(hiddens_dir, f"hidden_{model_short}_{lang2}_1.pt"))

top_k = 3
results = compute_crosslingual_cosine(hidden_1, hidden_2, top_k=top_k)

layers = sorted(results.keys())
accuracies = [results[l]['topk_accuracy'] for l in layers]

fig, ax = plt.subplots()
ax.plot(layers, accuracies, marker='o')
ax.set_xlabel("Layer")
ax.set_ylabel(f"Top-{top_k} Accuracy")
ax.set_title(f"Cross-Lingual Alignment over Layers ({lang1}-{lang2})")
ax.grid(True)
plt.show()

In [None]:
# German–English with the classifier's model: layer-wise top-k retrieval accuracy.
# (Unlike the en-ko cell, hidden states are not written to disk here.)
lang1, lang2 = "de", "en"

dataset_de_en = prepare_dataset(lang1, lang2, word_nonword_cls.tokenizer)

# One extraction per side, source language first. compute_crosslingual_cosine treats
# row i of both sides as a translation pair.
hidden_1, hidden_2 = [
    word_nonword_cls.extract_token_i_hidden_states(dataset_de_en[f"{lang}_tokens"].tolist())
    for lang in (lang1, lang2)
]

top_k = 3
results = compute_crosslingual_cosine(hidden_1, hidden_2, top_k=top_k)

layers = sorted(results)
accuracies = [results[layer]["topk_accuracy"] for layer in layers]

fig, ax = plt.subplots()
ax.plot(layers, accuracies)
ax.set_xlabel("Layer")
ax.set_ylabel(f"Top-{top_k} Accuracy")
ax.set_title(f"Cross-Lingual Alignment over Layers ({lang1}-{lang2})")
ax.grid(True)
plt.show()

In [None]:
# German–Korean with the classifier's model: layer-wise top-k retrieval accuracy.
# (Unlike the en-ko cell, hidden states are not written to disk here.)
lang1, lang2 = "de", "ko"

dataset_de_ko = prepare_dataset(lang1, lang2, word_nonword_cls.tokenizer)

# One extraction per side, source language first. compute_crosslingual_cosine treats
# row i of both sides as a translation pair.
hidden_1, hidden_2 = [
    word_nonword_cls.extract_token_i_hidden_states(dataset_de_ko[f"{lang}_tokens"].tolist())
    for lang in (lang1, lang2)
]

top_k = 3
results = compute_crosslingual_cosine(hidden_1, hidden_2, top_k=top_k)

layers = sorted(results)
accuracies = [results[layer]["topk_accuracy"] for layer in layers]

fig, ax = plt.subplots()
ax.plot(layers, accuracies)
ax.set_xlabel("Layer")
ax.set_ylabel(f"Top-{top_k} Accuracy")
ax.set_title(f"Cross-Lingual Alignment over Layers ({lang1}-{lang2})")
ax.grid(True)
plt.show()