## Imports

In [1]:
!pip install farm-haystack transformers ollama
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
!pip install sentence-transformers
!pip install "farm-haystack[inference]"

Looking in indexes: https://download.pytorch.org/whl/cpu


In [1]:
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever
from haystack.schema import Document
import ollama
import os
import pickle
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
CUR_LANG = "zh"
def EMBEDDINGS_PATH(lang: str):
    return f"embeddings_{CUR_LANG}.pkl"

## Prepare key terms embeddings

In [11]:
if (os.path.exists(EMBEDDINGS_PATH(CUR_LANG))):
    with open(EMBEDDINGS_PATH(CUR_LANG), 'rb') as f:
        document_store = pickle.load(f)
        retriever = EmbeddingRetriever(
            document_store=document_store,
            embedding_model="sentence-transformers/all-MiniLM-L6-v2",
            use_gpu=False 
        )
else:
    document_store = InMemoryDocumentStore(embedding_dim=384)
    key_terms = pd.read_csv("terms.csv")
    documents = [Document(content=term[1][CUR_LANG], id=term[1]["id"]) for term in key_terms.iterrows()]
    document_store.write_documents(documents)
    retriever = EmbeddingRetriever(
            document_store=document_store,
            embedding_model="sentence-transformers/all-MiniLM-L6-v2",
            use_gpu=False 
        )
    document_store.update_embeddings(retriever)
    with open(EMBEDDINGS_PATH(CUR_LANG), 'wb') as f:
        pickle.dump(document_store, f)

Batches: 100%|██████████| 229/229 [00:02<00:00, 88.34it/s] s/s]
Documents Processed: 10000 docs [00:02, 3611.43 docs/s]         


## Prepare question embedding and question for Llama3.2

In [28]:
def get_relevant_key_terms(question, top_k=10):
    results = retriever.retrieve(query=question, top_k=top_k)
    return {doc.content.replace('"', '') : doc.id for doc in results}

os.popen('ollama serve')
client = ollama.Client()

def ask_llama(question, relevant_key_terms):
    prompt = f"""
        Here is a search prompt by a user of the UN digital library:
        Question: {question}
        Here is a list of key terms. Each key term is in its own square brackets.
        Key Terms: {relevant_key_terms}
        Select the two key terms that are most relevant to the question.
        Output only the selected key terms as they are presented to you: each key term within square brackets.
        Separate each set of square brackets with a semi-colon and no spaces.
        """

    return client.chat(model="llama3.2", messages=[{"role": "user", "content": prompt}])['message']['content']


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Error: listen tcp 127.0.0.1:11434: bind: address already in use


## Ask question

In [37]:
#question = "I would like to know more about the impact of world war 1 on the economy of Germany."
question = "world war II"

In [38]:
# Get relevant key terms
key_terms = get_relevant_key_terms(question)

key_terms_array = ["[" + term.replace('"', '') + "]" for term in key_terms.keys()]
print("Key Terms:", key_terms_array)

# Join the key terms array into a single string separated by semi-colons
key_terms_string = ";".join(key_terms_array)
print("Relevant Key Terms:", key_terms_string)

Batches: 100%|██████████| 1/1 [00:00<00:00, 48.55it/s]

Key Terms: ['[世界大战（1939-1945]', '[泰国]', '[日本]', '[核武器国家]', '[国籍]', '[国徽]', '[国语]', '[无核武器国家]', '[国际日]', '[国际年]']
Relevant Key Terms: [世界大战（1939-1945];[泰国];[日本];[核武器国家];[国籍];[国徽];[国语];[无核武器国家];[国际日];[国际年]





In [42]:
max_tries = 5
is_answer_valid = False

i = 0
while not is_answer_valid and i < max_tries:
    answer = ask_llama(question, key_terms_string)
    print("Answer:", answer)

    answer_split = [term.strip() for term in answer.split(';')]
    is_answer_valid = isinstance(answer, str) and len(answer_split) <= 2
    is_answer_valid = is_answer_valid and all([term.strip() in key_terms_array for term in answer.split(';')])
    print(is_answer_valid)
    i += 1

if i == max_tries:
    # select two first key terms
    answer = ";".join(key_terms_array[:2])
    answer_split = [term.strip() for term in answer.split(';')]
    print("Answer:", answer)

Answer: [世界大战（1939-1945);][日本]
False
Answer: [世界大战（1939-1945);[日本];[核武器国家];[国籍]];[无核武器国家]; [国际日]; [国际年]
False
Answer: [世界大战（1939-1945）;[日本];[核武器国家]
False
Answer: [世界大战（1939-1945）; [日本]
False
Answer: [世界大战（1939-1945）; [日本]; [核武器国家]
False
Answer: [世界大战（1939-1945];[泰国]


In [40]:
print(answer_split)
output = ""
for term in answer_split:
    output += f"subjectheading:{term}"
print(output)

['[世界大战（1939-1945]', '[泰国]']
subjectheading:[世界大战（1939-1945]subjectheading:[泰国]


In [43]:
term1 = answer_split[0].replace('[', '').replace(']', '')
term2 = answer_split[1].replace('[', '').replace(']', '')
print(term1, term2)
print(key_terms[term1], key_terms[term2])


世界大战（1939-1945 泰国
1007089 1006486
