In [1]:
import logging
import time

from munch import Munch
import numpy as np
import torch
import yaml

from lcn.embedding_alignment import (
    convex_init,
    align,
    align_original,
)
from lcn.alignment_utils import (
    compute_accuracy,
    compute_csls,
    compute_nn,
    load_lexicon,
    load_vectors,
    pickle_cache,
    refine,
)

In [2]:
# Set up logging: every record goes through one stream handler attached to
# the root logger. Handlers already attached to the root logger are discarded
# first, so running this cell again does not stack up additional handlers.
formatter = logging.Formatter(
    fmt='%(asctime)s (%(levelname)s): %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
ch = logging.StreamHandler()
ch.setFormatter(formatter)

logger = logging.getLogger()
logger.handlers = []
logger.addHandler(ch)
logger.setLevel('INFO')

# Configuration

In [3]:
# Parse the experiment configuration file. The whole parsed document is kept
# in `config_seml`; the sections actually used are picked out further down.
with open('configs/embedding_alignment.yaml', 'r') as config_file:
    config_seml = yaml.safe_load(config_file)

In [4]:
# --- Experiment selection -------------------------------------------------
method = 'lcn'  # Change this for other methods: original, full, nystrom, multiscale, sparse, lcn
original = (method == "original")

# --- Language pair and data location --------------------------------------
language_src = 'en'
language_tgt = 'es'
data_dir = "./data"  # Download the data first, as described in the README

# --- Hyperparameters ------------------------------------------------------
# Start from the shared "fixed" settings, then overlay the settings that are
# specific to the chosen method.
config = Munch(config_seml['fixed'])
config.update(config_seml["from_en"][method]["fixed"])
# Settings given as the string "None" are converted to a real None.
if config.nystrom == "None":
    config.nystrom = None
if config.sparse == "None":
    config.sparse = None

# --- Run settings ---------------------------------------------------------
seed = 1111
test = False  # Change to run on the test set
device = "cuda"  # Change to cpu if necessary

# Load data

In [5]:
# Paths of the two embedding files and of the bilingual lexicon. The lexicon
# split depends on whether we evaluate on the test or the validation set.
model_src = f"{data_dir}/wiki.{language_src}.vec"
model_tgt = f"{data_dir}/wiki.{language_tgt}.vec"
val_numbers = "5000-6500" if test else "0-5000"
lexicon = f"{data_dir}/{language_src}-{language_tgt}.{val_numbers}.txt"

# Sinkhorn regularization as a tensor on the compute device.
sinkhorn_reg = torch.tensor(config.sinkhorn_reg, device=device)

logging.info("*** Wasserstein Procrustes ***")

# Seed NumPy and PyTorch so that runs are repeatable.
np.random.seed(seed)
torch.manual_seed(seed)

# `maxload` is forwarded to load_vectors as its second positional argument
# (presumably the maximum number of vectors to read -- confirm in
# lcn.alignment_utils).
maxload = 200_000

# Both languages are loaded the same way, through pickle_cache with a
# "<embedding file>.pkl" path (presumably a cache of the parsed vectors --
# confirm in lcn.alignment_utils).
w_src, x_src = pickle_cache(
    f"{model_src}.pkl",
    load_vectors,
    [model_src, maxload],
    {"norm": True, "center": True},
)
w_tgt, x_tgt = pickle_cache(
    f"{model_tgt}.pkl",
    load_vectors,
    [model_tgt, maxload],
    {"norm": True, "center": True},
)

# Lexicon built from the word lists of both languages; only the first of the
# two returned values is used.
src2tgt, _ = load_lexicon(lexicon, w_src, w_tgt)

# Move the embedding matrices to the compute device as float32 tensors.
x_src = torch.tensor(x_src, dtype=torch.float, device=device)
x_tgt = torch.tensor(x_tgt, dtype=torch.float, device=device)

2021-07-12 01:24:51 (INFO): *** Wasserstein Procrustes ***


# Convex initialization

In [6]:
# Compute the initial map R0 from the first `config.ninit` embeddings of each
# language and log how long it took.
#
# CUDA work is queued asynchronously, so the device is synchronized before
# each clock reading to time the actual computation. The synchronization is
# skipped when not running on a CUDA device: torch.cuda.synchronize() raises
# on machines without CUDA, which would break the `device = "cpu"` setting.
on_cuda = torch.device(device).type == "cuda"

if on_cuda:
    torch.cuda.synchronize()
t0 = time.time()
R0 = convex_init(
    x_src[:config.ninit], x_tgt[:config.ninit], sinkhorn_reg=sinkhorn_reg, apply_sqrt=True
)
if on_cuda:
    torch.cuda.synchronize()
logging.info(f"Done [{time.time() - t0:.1f} sec]")

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

# Main part: Wasserstein Procrustes

In [None]:
# Display the `nystrom` setting of the active configuration (the string
# "None" has already been converted to None in the configuration cell).
config.nystrom

In [None]:
# Run the main alignment starting from a copy of R0 (so R0 itself is left
# untouched) and record the wall-clock time in `runtime`.
#
# CUDA work is queued asynchronously, so the device is synchronized before
# each clock reading to time the actual computation. The synchronization is
# skipped when not running on a CUDA device: torch.cuda.synchronize() raises
# on machines without CUDA, which would break the `device = "cpu"` setting.
on_cuda = torch.device(device).type == "cuda"

if on_cuda:
    torch.cuda.synchronize()
t0 = time.time()
if original:
    # method == "original": use the original alignment routine.
    R = align_original(x_src, x_tgt, R0.clone(), sinkhorn_reg)
else:
    # All other methods share `align` and differ only in their configuration.
    R = align(
        x_src,
        x_tgt,
        R0.clone(),
        sinkhorn_reg=sinkhorn_reg,
        nystrom=config.nystrom,
        sparse=config.sparse,
        lr=config.lr,
        niter=config.niter,
        lr_half_niter=config.lr_half_niter,
        ntrain=config.ntrain,
        print_niter=config.print_niter,
    )
if on_cuda:
    torch.cuda.synchronize()
runtime = time.time() - t0
logging.info(f"Done [{runtime:.1f} sec]")

HBox(children=(FloatProgress(value=0.0, max=300.0), HTML(value='')))

2021-07-12 01:23:52 (INFO): Done [51.3 sec]





# Compute accuracies

In [None]:
# Announce which lexicon split the numbers below refer to.
logging.info("Evaluation on test set" if test else "Evaluation on validation set")

# Map the target embeddings with the learned matrix R (right-multiplication
# by its transpose).
x_tgt_rot = torch.matmul(x_tgt, R.T)

# Precision@1 with compute_nn as the retrieval criterion.
acc_nn = compute_accuracy(x_src, x_tgt_rot, src2tgt, compute_nn)
logging.info(f"NN precision@1: {100 * acc_nn:.2f}%")

# Precision@1 with compute_csls as the retrieval criterion.
acc_csls = compute_accuracy(x_src, x_tgt_rot, src2tgt, compute_csls)
logging.info(f"CSLS precision@1: {100 * acc_csls:.2f}%")

2021-07-12 01:23:53 (INFO): NN precision@1: 78.96%
2021-07-12 01:24:01 (INFO): CSLS precision@1: 82.24%


# Refine embeddings

In [None]:
# Refine the mapped target embeddings using the lexicon pairs in `src2tgt`,
# then re-evaluate with CSLS.
#
# The input to `refine` is rebuilt from `x_tgt` and `R` instead of being read
# from `x_tgt_rot`. This cell overwrites `x_tgt_rot` with the refined
# embeddings, so feeding it its own output would refine a second time whenever
# the cell is re-run (for example after an interrupted run). Rebuilding the
# input keeps the cell idempotent while `x_tgt_rot` still ends up refined.
x_tgt_rot = refine(x_src, x_tgt @ R.T, src2tgt)

acc_refined = compute_accuracy(x_src, x_tgt_rot, src2tgt, compute_csls)
logging.info(f"Refined CSLS precision@1: {100 * acc_refined:.2f}%")

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))

KeyboardInterrupt: 