# Countries and Capitals

Characterizing Mechanisms for Factual Recall in Language Models <br>
https://arxiv.org/pdf/2310.15910.pdf

In [15]:
from measureLM import helpers

import pandas as pd
import random

In [46]:
# Load the country/capital table and pull the two columns out as plain lists.
df = pd.read_csv(
    helpers.ROOT_DIR / "data" / "CountryCapital" / "country-capital.csv"
)
countries = df["country"].to_list()
capitals = df["capital"].to_list()
# Bare expression: rendered as the cell output.
df

Unnamed: 0,country,capital,type
0,Abkhazia,Sukhumi,countryCapital
1,Afghanistan,Kabul,countryCapital
2,Akrotiri and Dhekelia,Episkopi Cantonment,countryCapital
3,Albania,Tirana,countryCapital
4,Algeria,Algiers,countryCapital
...,...,...,...
243,Wallis and Futuna,Mata-Utu,countryCapital
244,Western Sahara,El Aaiún,countryCapital
245,Yemen,Sanaá,countryCapital
246,Zambia,Lusaka,countryCapital


In [49]:
def format_prompt(country, city, contextCity=None):
    """Build the capital-city question prompt for ``country``.

    Args:
        country: Name inserted into the question (and into the context
            sentence, when one is requested).
        city: Accepted for call-site symmetry; it does not appear anywhere
            in the generated text.
        contextCity: When not ``None``, a sentence claiming this city is the
            capital of ``country`` is placed in front of the question.

    Returns:
        The question prompt, optionally preceded by the context sentence and
        a single space.
    """
    question = f"Q: What is the capital of {country}? A:"

    # Only ``None`` disables the context; an empty string still adds a prefix.
    if contextCity is None:
        return question

    statement = f"The capital of {country} is {contextCity}."
    return f"{statement} {question}"


def load_country_capitals(n_pairs=10, n_wrong_contexts=0, seed=0):
    """Sample country/capital pairs and build memory and in-context prompts.

    Reads ``country-capital.csv`` from ``helpers.ROOT_DIR / "data" /
    "CountryCapital"``, draws ``n_pairs`` rows at random and, for each row,
    builds one prompt without context, one with the correct capital as
    context, and ``n_wrong_contexts`` prompts whose context names the capital
    of a different, randomly drawn row.

    Args:
        n_pairs: Number of country/capital rows to sample.
        n_wrong_contexts: Number of wrong-context prompts per sampled row.
        seed: Seed passed to ``random.seed`` before sampling.

    Returns:
        A list with one entry per sampled row, each of the form
        ``[capital, memory_prompt, right_context_prompt, *wrong_context_prompts]``.

    Raises:
        ValueError: Propagated from ``random.sample`` when ``n_pairs`` exceeds
            the number of rows, or ``n_wrong_contexts`` exceeds the number of
            eligible distractor rows.
    """
    df = pd.read_csv(helpers.ROOT_DIR / "data" / "CountryCapital" / "country-capital.csv")
    countries, capitals = df["country"].to_list(), df["capital"].to_list()

    # NOTE: this reseeds the module-level RNG, so later ``random`` calls made
    # by the caller are affected as well.
    random.seed(seed)
    idcs = random.sample(range(len(countries)), n_pairs)

    prompts = []
    for idx in idcs:
        country, capital = countries[idx], capitals[idx]
        memory_prompt = format_prompt(country, capital, contextCity=None)
        right_context_prompt = format_prompt(country, capital, contextCity=capital)
        country_capital_prompts = [capital, memory_prompt, right_context_prompt]

        # A distractor must differ from the true capital by *name*, not just
        # by row index: if two rows list the same capital, drawing the other
        # row would produce a "wrong" context identical to the right one.
        # The explicit index check is kept so the row itself is excluded even
        # when its capital does not compare equal to itself (e.g. a NaN cell).
        distractor_idcs = [
            other_idx
            for other_idx, other_capital in enumerate(capitals)
            if other_idx != idx and other_capital != capital
        ]
        wrong_context_idcs = random.sample(distractor_idcs, n_wrong_contexts)
        country_capital_prompts.extend(
            format_prompt(country, capital, contextCity=capitals[wrong_context_idx])
            for wrong_context_idx in wrong_context_idcs
        )
        prompts.append(country_capital_prompts)
    return prompts
    
# Example run: 5 sampled country/capital pairs, each with 2 wrong-context
# prompts, using seed 0. The bare name below is rendered as the cell output.
prompts = load_country_capitals(n_pairs=5, n_wrong_contexts=2, seed=0) 
prompts

[['São Tomé',
  'Q: What is the capital of São Tomé and Príncipe? A:',
  'The capital of São Tomé and Príncipe is São Tomé. Q: What is the capital of São Tomé and Príncipe? A:',
  'The capital of São Tomé and Príncipe is Buenos Aires. Q: What is the capital of São Tomé and Príncipe? A:',
  'The capital of São Tomé and Príncipe is Cairo. Q: What is the capital of São Tomé and Príncipe? A:'],
 ['Reykjavík',
  'Q: What is the capital of Iceland? A:',
  'The capital of Iceland is Reykjavík. Q: What is the capital of Iceland? A:',
  'The capital of Iceland is Bamako. Q: What is the capital of Iceland? A:',
  'The capital of Iceland is Luxembourg. Q: What is the capital of Iceland? A:'],
 ['Victoria',
  'Q: What is the capital of Seychelles? A:',
  'The capital of Seychelles is Victoria. Q: What is the capital of Seychelles? A:',
  'The capital of Seychelles is Dublin. Q: What is the capital of Seychelles? A:',
  'The capital of Seychelles is Montevideo. Q: What is the capital of Seychelles?