In [3]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

In [6]:
# Load the lightweight BERT model and its tokenizer
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-mini")
model = AutoModel.from_pretrained("prajjwal1/bert-mini")
model.eval()


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 256, padding_idx=0)
    (position_embeddings): Embedding(512, 256)
    (token_type_embeddings): Embedding(2, 256)
    (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-3): 4 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=256, out_features=256, bias=True)
            (key): Linear(in_features=256, out_features=256, bias=True)
            (value): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=256, out_features=256, bias=True)
            (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)


# Task 1

In [7]:

# Dictionary to cache computed word embeddings
embedding_cache = {}

def get_embedding(word):
    """
    Given a word, computes its embedding using the BERT model.
    If the word tokenizes into multiple tokens, their embeddings are averaged.
    The result is cached for efficiency.
    """
    # Use lower-case since we are using an uncased model.
    word = word.lower()
    if word in embedding_cache:
        return embedding_cache[word]
    with torch.no_grad():
        # Tokenize without adding special tokens so that only the word’s sub–tokens are processed
        inputs = tokenizer(word, add_special_tokens=False, return_tensors="pt")
        outputs = model(**inputs)
        # outputs.last_hidden_state has shape (1, sequence_length, hidden_size)
        token_embeds = outputs.last_hidden_state.squeeze(0)  # shape: (sequence_length, hidden_size)
        # If the word is split into multiple tokens, average the token embeddings
        if token_embeds.dim() == 1:
            embed = token_embeds
        else:
            embed = token_embeds.mean(dim=0)
        embedding_cache[word] = embed
        return embed

def read_analogy_file(file_path):
    """
    Reads the analogy file and splits it into groups.
    Each group starts with a line beginning with a colon (":") specifying the group name.
    The following lines (until the next group header) are assumed to be analogies
    in the format: a b c d
    """
    groups = {}
    current_group = None
    with open(file_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # If the line starts with ":" it indicates a new group.
            if line.startswith(":"):
                group_name = line[1:].strip()
                current_group = group_name
                groups[current_group] = []
            else:
                # Each line should contain exactly 4 words
                tokens = line.split()
                if len(tokens) == 4 and current_group is not None:
                    groups[current_group].append(tokens)
    return groups

def evaluate_group(analogy_list):
    """
    For a given list of analogies (each a list of 4 words [a, b, c, d]),
    compute prediction accuracy for varying cutoff values k based on two measures:
      - Cosine similarity: higher is better.
      - L2 distance: lower is better.
    
    The candidate set is defined as all unique words that appear in the second (b)
    and fourth (d) positions among the analogies in the group.
    Returns the list of ks and two dictionaries mapping each k to the accuracy percentage.
    """
    # Build candidate set: unique words from the second and fourth positions in the group
    candidates = set()
    for tokens in analogy_list:
        candidates.add(tokens[1].lower())
        candidates.add(tokens[3].lower())
    candidates = list(candidates)

    ks = [1, 2, 5, 10, 20]
    cosine_correct = {k: 0 for k in ks}
    l2_correct = {k: 0 for k in ks}
    total = len(analogy_list)
    
    for tokens in analogy_list:
        # Unpack the analogy: a is to b as c is to d.
        a, b, c, d = [word.lower() for word in tokens]
        # Compute the reference difference vector from the known pair (a and b)
        ref_diff = get_embedding(a) - get_embedding(b)
        
        # For each candidate, compute the difference vector with respect to c
        cos_scores = {}
        l2_scores = {}
        for cand in candidates:
            candidate_diff = get_embedding(c) - get_embedding(cand)
            # For cosine similarity: higher is better
            cos_sim = F.cosine_similarity(ref_diff, candidate_diff, dim=0)
            # For L2 distance: lower is better
            l2_distance = torch.norm(ref_diff - candidate_diff, p=2)
            cos_scores[cand] = cos_sim.item()
            l2_scores[cand] = l2_distance.item()
        
        # Sort candidates by cosine similarity (descending order) and by L2 distance (ascending order)
        sorted_cos = sorted(cos_scores.items(), key=lambda x: x[1], reverse=True)
        sorted_l2 = sorted(l2_scores.items(), key=lambda x: x[1])
        
        # For each cutoff k, check if the true answer d is among the top k candidates
        for k in ks:
            top_k_cos = [item[0] for item in sorted_cos[:k]]
            top_k_l2 = [item[0] for item in sorted_l2[:k]]
            if d in top_k_cos:
                cosine_correct[k] += 1
            if d in top_k_l2:
                l2_correct[k] += 1
                
    # Calculate accuracy percentages
    cosine_acc = {k: (cosine_correct[k] / total) * 100 for k in ks}
    l2_acc = {k: (l2_correct[k] / total) * 100 for k in ks}
    
    return ks, cosine_acc, l2_acc

def print_table(group_name, ks, cosine_acc, l2_acc):
    """
    Prints a formatted results table for a given group.
    """
    print("Group:", group_name)
    print("{:<5} {:<30} {:<30}".format("k", "Accuracy (Cosine Similarity)", "Accuracy (L2 Distance)"))
    for k in ks:
        print("{:<5} {:<30.2f} {:<30.2f}".format(k, cosine_acc[k], l2_acc[k]))
    print("\n" + "-"*70 + "\n")
    
if __name__ == "__main__":
    # Read the analogy dataset from the local file
    file_path = "./data/task_1_data.txt"
    groups = read_analogy_file(file_path)
    
    # Choose three groups; one of them must be a 'capital' group.
    selected_groups = {}
    capital_group_key = None
    
    # Look for a group name containing 'capital' (case-insensitive)
    for key in groups.keys():
        if "capital" in key.lower():
            capital_group_key = key
            break
    if capital_group_key is None:
        raise ValueError("No capital-related group found in the dataset!")
    
    selected_groups[capital_group_key] = groups[capital_group_key]
    
    # Add any two other groups (excluding the already selected capital group)
    count = 1
    for key in groups.keys():
        if key == capital_group_key:
            continue
        if count >= 3:
            break
        selected_groups[key] = groups[key]
        count += 1
    
    # Evaluate each selected group and output the results table
    for group_name, analogy_list in selected_groups.items():
        ks, cosine_acc, l2_acc = evaluate_group(analogy_list)
        print_table(group_name, ks, cosine_acc, l2_acc)


Group: capital-common-countries
k     Accuracy (Cosine Similarity)   Accuracy (L2 Distance)        
1     64.82                          98.02                         
2     76.48                          99.01                         
5     86.17                          100.00                        
10    94.47                          100.00                        
20    99.21                          100.00                        

----------------------------------------------------------------------

Group: capital-world
k     Accuracy (Cosine Similarity)   Accuracy (L2 Distance)        
1     16.78                          51.99                         
2     21.68                          59.86                         
5     31.12                          71.24                         
10    40.52                          78.69                         
20    52.83                          85.08                         

---------------------------------------------------------

# Task 2