In [1]:
!pip install einops

Collecting einops
  Obtaining dependency information for einops from https://files.pythonhosted.org/packages/29/0b/2d1c0ebfd092e25935b86509a9a817159212d82aa43d7fb07eca4eeff2c2/einops-0.7.0-py3-none-any.whl.metadata
  Downloading einops-0.7.0-py3-none-any.whl.metadata (13 kB)
Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0


In [2]:
import random
import torch
import copy

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

# Import Models

In [3]:
device = "cpu"
torch.set_default_device(device)

In [4]:
phi = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
tokenizer_phi = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
config_phi = AutoConfig.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)

phi = phi.eval()
phi.zero_grad()

config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

configuration_phi.py:   0%|          | 0.00/2.03k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-1_5:
- configuration_phi.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_phi.py:   0%|          | 0.00/33.5k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-1_5:
- modeling_phi.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


pytorch_model.bin:   0%|          | 0.00/2.84G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/69.0 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/237 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/1.08k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

In [5]:
gpt = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer_gpt = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
config_gpt = AutoConfig.from_pretrained("gpt2", trust_remote_code=True)

gpt.eval()
gpt.zero_grad()

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [6]:
phi_emb = phi.get_input_embeddings()
gpt_emb = gpt.get_input_embeddings()

# Define Helper Functions

In [7]:
def multiencode(tok, words, return_tensors="pt"):
    if (isinstance(words, list) or isinstance(words, tuple)) and not isinstance(words, str):
        return torch.cat([tok.encode(word, return_tensors="pt") for word in words], dim=-1)
    else:
        return tok.encode(words, return_tensors="pt")

In [8]:
def get_closest_emb(emb, word, k=1, decode=True, tok=None, avg=False):
    # If input is a string tokenize it
    if (isinstance(word, str) or isinstance(word[0], str)) and tok is not None:
        word = emb(multiencode(tok, word))
    # Calculate average if flag and needed
    if word.shape[1] != 1 and avg:
        word = torch.unsqueeze(torch.mean(word, dim=1), dim=1)
    # Compute normalized distances
    distances = torch.norm(emb.weight.data - word, dim=2)
    # Compute top k smalles indices
    topk = torch.squeeze(torch.topk(distances, k=k, largest=False).indices)
    # If one element, unsqueeze it
    if k == 1:
        topk = torch.unsqueeze(topk, dim=0)
    # Decode closest k
    if decode and tok is not None:
        topk = [tok.decode(c) for c in topk.tolist()]
    return topk

def emb_arithmetic(emb, tok, words, k=1, avg=False, norm=False):
    # Convert all words into tokens
    words_emb = [multiencode(tok, word) for word in words]
    # Convert all tokens into embeddings
    for i, word in enumerate(words_emb):
        if word.shape[1] == 1:
            words_emb[i] = emb(word)
        elif avg:
            words_emb[i] = torch.unsqueeze(torch.mean(emb(word), dim=1), dim=1)
        else:
            raise Exception(f"{words[i]} is not a single token: {word}")
    # Compute embeddings
    w1 = words_emb[0]
    w2 = words_emb[1]
    w3 = words_emb[2]
    # Do embedding arithmetic
    if norm:
        w = torch.nn.functional.normalize(w1 - w2 + w3, dim=0)
    else:
        w = w1 - w2 + w3
    # Get closest k
    closest = get_closest_emb(emb, w, k=k, decode=True, tok=tok)
    return (w, closest)

def print_results(res):
     for i, r in enumerate(res):
        print(f"{i+1}) {r}")

In [9]:
def batch_emb_arithmetic(emb, tok, queries, k=5, avg=False, norm=False, out=True):
    ret = []
    for q in queries:
        if out:
            print("##########################")
            # Print title
            title_q = q
            if not isinstance(q[0], str):
                title_q = [qq[0] for qq in q]
            print(f"{title_q[0]} - {title_q[1]} + {title_q[2]} =")
        # Compute and print results
        res = emb_arithmetic(emb, tok, q, k=k, avg=avg, norm=norm)
        if out:
            print_results(res[1])
        ret.append(res)
    return ret

def evaluate_batch(results, solutions, out=True):
    
    def get_rank(r, s, out=0):
        try:
            return r.index(s)
        except ValueError:
            return out
    
    k = len(results[0][1])
    ev = []
    for res, sol in zip(list(map(lambda x: x[1], results)), solutions):
        # Get rank of each solution for each result outputs
        ranks = [get_rank(res, s, out=k) for s in sol]
        # Append best rank to final evaluation list
        ev.append(min(ranks))
    # Return score
    score = 1 - ( sum(ev) / (k * len(solutions)) )
    if out:
        print(f"{ev} -> {score}")
    return score

# Define Test Cases

### Capital Arithmetic

In [10]:
capital_arithmetic = [
    ["Rome Rome", "Italy Italy", "France France"],
    ["Rome Rome", "Italy Italy", "Australia Australia"],
    ["Paris Paris", "France France", "Italy Italy"],
    ["Paris Paris", "France France", "Australia Australia"],
    ["Canberra Canberra", "Australia Australia", "Italy Italy"],
    ["Canberra Canberra", "Australia Australia", "France France"],
]
capital_single = [
    [" Rome", " Italy", " France"],
    [" Rome", " Italy", " Australia"],
    [" Paris", " France", " Italy"],
    [" Paris", " France", " Australia"],
    [" Canberra", " Australia", " Italy"],
    [" Canberra", " Australia", " France"],
]
capital_arithmetic_lowercase = [
    ["rome rome", "italy italy", "france france"],
    ["rome rome", "italy italy", "australia australia"],
    ["paris paris", "france france", "italy italy"],
    ["paris paris", "france france", "australia australia"],
    ["canberra canberra", "australia australia", "italy italy"],
    ["canberra canberra", "australia australia", "france france"],
]
capital_sol = [
    ["Paris", " Paris", "paris", " paris"],
    ["Canberra", " Canberra", "canberra", " canberra"],
    ["Rome", " Rome", "rome", " rome"],
    ["Canberra", " Canberra", "canberra", " canberra"],
    ["Rome", " Rome", "rome", " rome"],
    ["Paris", " Paris", "paris", " paris"],
]

### Sex Arithmetic

In [11]:
sex_arithmetic = [
    ["King King", "Man Man", "Woman Woman"],
    ["Queen Queen", "Woman Woman", "Man Man"],
    ["Prince Prince", "Man Man", "Woman Woman"],
    ["Princess Princess", "Woman Woman", "Man Man"],
    ["Priest Priest", "Man Man", "Woman Woman"],
    ["Nun Nun", "Woman Woman", "Man Man"],
]
sex_arithmetic_lowercase_single = [
    [" king", " man", " woman"],
    [" queen", " woman", " man"],
    [" prince", " man", " woman"],
    [" princess", " woman", " man"],
    [" priest", " man", " woman"],
    [" nun", " woman", " man"],
]
sex_sol = [
    ["Queen", " Queen", "queen", " queen"],
    ["King", " King", "king", " king"],
    ["Princess", " Princess", "princess", " princess"],
    ["Prince", " Prince", "prince", " prince"],
    ["Nun", " Nun", "nun", " nun"],
    ["Priest", " Priest", "priest", " priest"],
]

# Calculate and Display Test Results

In [12]:
test_capital = [
    ["Phi-1.5", batch_emb_arithmetic(phi_emb, tokenizer_phi, capital_arithmetic, k=100, avg=True, norm=False, out=False)],
    ["Phi-1.5 Normalized", batch_emb_arithmetic(phi_emb, tokenizer_phi, capital_arithmetic, k=100, avg=True, norm=True, out=False)],
    ["Phi-1.5 Lowercase", batch_emb_arithmetic(phi_emb, tokenizer_phi, capital_arithmetic_lowercase, k=100, avg=True, norm=False, out=False)],
    ["Phi-1.5 Normalized Lowercase", batch_emb_arithmetic(phi_emb, tokenizer_phi, capital_arithmetic_lowercase, k=100, avg=True, norm=True, out=False)],
    ["Phi-1.5 Single", batch_emb_arithmetic(phi_emb, tokenizer_phi, capital_single, k=100, avg=True, norm=False, out=False)],
    ["Phi-1.5 Normalized Single", batch_emb_arithmetic(phi_emb, tokenizer_phi, capital_single, k=100, avg=True, norm=True, out=False)],
    ["GPT-2", batch_emb_arithmetic(gpt_emb, tokenizer_gpt, capital_arithmetic, k=100, avg=True, norm=False, out=False)],
    ["GPT-2 Normalized", batch_emb_arithmetic(gpt_emb, tokenizer_gpt, capital_arithmetic, k=100, avg=True, norm=True, out=False)],
    ["GPT-2 Lowercase", batch_emb_arithmetic(gpt_emb, tokenizer_gpt, capital_arithmetic_lowercase, k=100, avg=True, norm=False, out=False)],
    ["GPT-2 Normalized Lowercase", batch_emb_arithmetic(gpt_emb, tokenizer_gpt, capital_arithmetic_lowercase, k=100, avg=True, norm=True, out=False)],
    ["GPT-2 Single", batch_emb_arithmetic(gpt_emb, tokenizer_gpt, capital_single, k=100, avg=True, norm=False, out=False)],
    ["GPT-2 Normalized Single", batch_emb_arithmetic(gpt_emb, tokenizer_gpt, capital_single, k=100, avg=True, norm=True, out=False)],
]
test_sex = [
    ["Phi-1.5", batch_emb_arithmetic(phi_emb, tokenizer_phi, sex_arithmetic, k=100, avg=True, norm=False, out=False)],
    ["Phi-1.5 Normalized", batch_emb_arithmetic(phi_emb, tokenizer_phi, sex_arithmetic, k=100, avg=True, norm=True, out=False)],
    ["Phi-1.5 Lowercase Single", batch_emb_arithmetic(phi_emb, tokenizer_phi, sex_arithmetic_lowercase_single, k=100, avg=True, norm=False, out=False)],
    ["Phi-1.5 Normalized Lowercase Single", batch_emb_arithmetic(phi_emb, tokenizer_phi, sex_arithmetic_lowercase_single, k=100, avg=True, norm=True, out=False)],
    ["GPT-2", batch_emb_arithmetic(gpt_emb, tokenizer_gpt, sex_arithmetic, k=100, avg=True, norm=False, out=False)],
    ["GPT-2 Normalized", batch_emb_arithmetic(gpt_emb, tokenizer_gpt, sex_arithmetic, k=100, avg=True, norm=True, out=False)],
    ["GPT-2 Lowercase Single", batch_emb_arithmetic(gpt_emb, tokenizer_gpt, sex_arithmetic_lowercase_single, k=100, avg=True, norm=False, out=False)],
    ["GPT-2 Normalized Lowercase Single", batch_emb_arithmetic(gpt_emb, tokenizer_gpt, sex_arithmetic_lowercase_single, k=100, avg=True, norm=True, out=False)],
]

In [13]:
print("Capital rankings")
for name, result in test_capital:
    print("---------------------------")
    print(name)
    evaluate_batch(result, capital_sol)

Capital rankings
---------------------------
Phi-1.5
[100, 100, 100, 100, 100, 100] -> 0.0
---------------------------
Phi-1.5 Normalized
[14, 100, 5, 21, 9, 6] -> 0.7416666666666667
---------------------------
Phi-1.5 Lowercase
[100, 100, 100, 100, 100, 100] -> 0.0
---------------------------
Phi-1.5 Normalized Lowercase
[100, 100, 100, 100, 100, 100] -> 0.0
---------------------------
Phi-1.5 Single
[100, 100, 100, 100, 100, 100] -> 0.0
---------------------------
Phi-1.5 Normalized Single
[3, 25, 13, 24, 6, 4] -> 0.875
---------------------------
GPT-2
[2, 67, 7, 11, 64, 4] -> 0.7416666666666667
---------------------------
GPT-2 Normalized
[4, 3, 15, 1, 12, 4] -> 0.935
---------------------------
GPT-2 Lowercase
[100, 100, 100, 100, 100, 100] -> 0.0
---------------------------
GPT-2 Normalized Lowercase
[100, 100, 100, 100, 100, 100] -> 0.0
---------------------------
GPT-2 Single
[2, 5, 3, 12, 4, 4] -> 0.95
---------------------------
GPT-2 Normalized Single
[3, 2, 11, 3, 9, 2] -> 

In [14]:
print("Sex rankings")
for name, result in test_sex:
    print("---------------------------")
    print(name)
    evaluate_batch(result, sex_sol)

Sex rankings
---------------------------
Phi-1.5
[100, 100, 100, 100, 100, 100] -> 0.0
---------------------------
Phi-1.5 Normalized
[6, 7, 5, 24, 100, 100] -> 0.5966666666666667
---------------------------
Phi-1.5 Lowercase Single
[100, 100, 100, 100, 100, 100] -> 0.0
---------------------------
Phi-1.5 Normalized Lowercase Single
[4, 3, 2, 4, 16, 7] -> 0.94
---------------------------
GPT-2
[3, 2, 3, 4, 100, 100] -> 0.6466666666666667
---------------------------
GPT-2 Normalized
[2, 2, 5, 3, 29, 100] -> 0.765
---------------------------
GPT-2 Lowercase Single
[1, 1, 1, 1, 7, 3] -> 0.9766666666666667
---------------------------
GPT-2 Normalized Lowercase Single
[0, 2, 1, 1, 3, 8] -> 0.975


# Visualizations 

In [15]:
_ = batch_emb_arithmetic(gpt_emb, tokenizer_gpt, capital_arithmetic, k=10, avg=True, norm=True)

##########################
Rome Rome - Italy Italy + France France =
1) France
2) ome
3)  France
4) French
5) Paris
6)  Paris
7)  French
8)  Frenchman
9) ô
10)  Lyon
##########################
Rome Rome - Italy Italy + Australia Australia =
1) Australia
2)  Australia
3)  Sydney
4)  Canberra
5) Australian
6)  Australian
7)  Australians
8)  Perth
9)  Aboriginal
10)  Melbourne
##########################
Paris Paris - France France + Italy Italy =
1) Paris
2) Italy
3)  Paris
4)  Milan
5) Italian
6)  Italy
7)  Lisbon
8)  Budapest
9)  Italians
10)  Venice
##########################
Paris Paris - France France + Australia Australia =
1) Australia
2)  Canberra
3)  Australians
4) Austral
5)  Australia
6) Paris
7)  Sydney
8) Australian
9)  Melbourne
10)  Australian
##########################
Canberra Canberra - Australia Australia + Italy Italy =
1) Italy
2)  Italy
3)  Italians
4) berra
5) Italian
6)  Sicily
7)  Naples
8)  Serie
9)  Italian
10)  Juventus
##########################
Canberra Canbe

In [16]:
_ = batch_emb_arithmetic(phi_emb, tokenizer_phi, capital_arithmetic, k=10, avg=True, norm=True)

##########################
Rome Rome - Italy Italy + France France =
1)  France
2) R
3) France
4) ome
5)  Rome
6)  R
7)  French
8) M
9) r
10) F
##########################
Rome Rome - Italy Italy + Australia Australia =
1)  Australia
2) Australia
3) R
4)  Australian
5) ome
6)  R
7)  Australians
8)  Rome
9) r
10) S
##########################
Paris Paris - France France + Italy Italy =
1)  Paris
2)  Italy
3) Paris
4) Italy
5)  Italian
6)  Rome
7) Italian
8)  Milan
9)  Lisbon
10)  Tokyo
##########################
Paris Paris - France France + Australia Australia =
1)  Australia
2)  Paris
3) Australia
4) Paris
5)  Australian
6)  Sydney
7)  Australians
8)  Melbourne
9) Australian
10)  Canada
##########################
Canberra Canberra - Australia Australia + Italy Italy =
1)  Italy
2) Can
3) Italy
4)  Italian
5) Italian
6)  Can
7)  Canberra
8)  Italians
9) can
10)  Rome
##########################
Canberra Canberra - Australia Australia + France France =
1)  France
2) France
3) Can
4)  Frenc

In [17]:
_ = batch_emb_arithmetic(gpt_emb, tokenizer_gpt, sex_arithmetic_lowercase_single, k=10, avg=True, norm=True)

##########################
 king -  man +  woman =
1)  queen
2)  king
3) Queen
4)  kings
5)  Queen
6)  princess
7)  queens
8)  pope
9)  King
10) King
##########################
 queen -  woman +  man =
1)  queen
2)  queens
3)  king
4)  kings
5)  Queen
6) Queen
7)  emperor
8)  rook
9) King
10)  monarch
##########################
 prince -  man +  woman =
1)  prince
2)  princess
3)  princes
4) Prince
5)  Princess
6)  queen
7)  Prince
8) Prin
9)  Duchess
10)  heroine
##########################
 princess -  woman +  man =
1)  princess
2)  prince
3)  princes
4)  Princess
5) Prin
6)  Prince
7) Prince
8)  king
9)  knight
10)  Cinderella
##########################
 priest -  man +  woman =
1)  priest
2)  priests
3)  Priest
4)  nun
5)  clergy
6)  prostitute
7)  nuns
8)  bishop
9) Catholic
10)  priesthood
##########################
 nun -  woman +  man =
1)  nun
2)  nuns
3)  monk
4)  convent
5)  pope
6)  monks
7)  Jesuit
8)  monastery
9)  priest
10)  Archbishop


In [18]:
res = get_closest_emb(phi_emb, " nun", k=10, tok=tokenizer_phi, avg=True)
print_results(res)

1)  nun
2)  Skydragon
3) wcsstore
4)  "$:/
5) �
6)  guiName
7)  裏�
8)  externalTo
9) ertodd
10)  Smartstocks


In [19]:
res = get_closest_emb(gpt_emb, " nun", k=10, tok=tokenizer_gpt, avg=True)
print_results(res)

1)  nun
2)  nuns
3)  convent
4)  priest
5)  monk
6)  monastery
7) �
8) 
9) 
10) 
