In [1]:
import pandas as pd
import langchain
import numpy as np
import os
import json

from langchain.document_loaders import DataFrameLoader
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

%load_ext autoreload
%autoreload 2

import sys
sys.path.append('/home/alexis_cunin/12_taxobservatory')

import country_by_country
from country_by_country.rag_engine.llm import get_llm
from country_by_country.rag_engine.rag import Extraction

# 1. Build vectorstore
---

## Parse and create vectorstore

##### Create vectorstore

In [2]:
file_path = '/home/alexis_cunin/12_taxobservatory/data/AngloAmerican_2021_CbCR/tables_img/md_tables.xlsx'
md_tables = pd.read_excel(file_path)
md_tables.head()

Unnamed: 0,page_num,md_table
0,24,| | income and expense items: ...
1,23,| | Tax Jurisdiction | 8 Name U ...
2,3,| | Revenues | Tangibl...
3,21,| | Tax Jurisdiction | Name ...
4,10,| | Tax Jurisdiction | 8 Name H ...


In [3]:
df_tables = md_tables[['md_table', 'page_num']].rename(columns={
    "md_table": "text",
    "page_num": "page"
})
df_tables.head()

Unnamed: 0,text,page
0,| | income and expense items: ...,24
1,| | Tax Jurisdiction | 8 Name U ...,23
2,| | Revenues | Tangibl...,3
3,| | Tax Jurisdiction | Name ...,21
4,| | Tax Jurisdiction | 8 Name H ...,10


## Table summarization

In [4]:
table_loader = DataFrameLoader(df_tables, page_content_column="text")
tables = table_loader.load()

In [5]:
table_prompt_text = """
You are an assistant responsible for extracting information from markdown tables.
Extract in a string all the key words and themes from the markdown table below, and add as many words from their lexical fields.

Guidelines:
- If country names are present in the table, list them all in the summary.
- If financial KPIs are present in the table, list them in the summary (e.g. income, tax income, number of employees).

Table: {element}
Summary:"""
prompt = ChatPromptTemplate.from_template(table_prompt_text)

model = get_llm()
table_summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

  from .autonotebook import tqdm as notebook_tqdm


Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/alexis_cunin/.cache/huggingface/token
Login successful


In [6]:
table_summaries = []
validated_tables = []
remaining_tables = [x for x in tables if x not in validated_tables]
table_errors = []

In [7]:
from tqdm.auto import tqdm
import time

for table in tqdm(tables):
    try:
        results = table_summarize_chain.invoke(table.page_content)
        table_summaries.append(results)
        validated_tables.append(table)
    except Exception as e:
        table_errors.append((table, str(e)))
        print(f"Error for text {table}: {e}")
        time.sleep(10)  # Wait 10 seconds before retrying

  0%|          | 0/26 [00:00<?, ?it/s]

 35%|███▍      | 9/26 [00:36<01:02,  3.70s/it]

Error for text page_content='|    | Currency USD         | Revenues        | Profit/(Loss   | Income Tax                                 | Income Tax                                  | Income Tax                      | Tangible Assets other than Cash and Cash    | CBCR Effective                               | Statutory Corporate   | Explanation of sianificant                  |                                             |                                                                                       |                                                          |                         |                                                                                                                                                              |\n|---:|:---------------------|:----------------|:---------------|:-------------------------------------------|:--------------------------------------------|:--------------------------------|:--------------------------------------------|:---

100%|██████████| 26/26 [04:21<00:00, 10.04s/it]


In [8]:
print("Summarized tables:", len(validated_tables))
print("Errors:", len(table_errors))

Summarized tables: 25
Errors: 1


In [9]:
for table, error in table_errors:
    print(error)

(ReadTimeoutError("HTTPSConnectionPool(host='api-inference.huggingface.co', port=443): Read timed out. (read timeout=120)"), '(Request ID: 8e5a0661-5932-4bcb-ae09-6df3de2ab686)')


In [10]:
table_summaries[4]

' | Tax Jurisdiction | 8 Name H                                                                 | 8 H                           | 6 H+                       | 6 U   | C     | 0 U   | 0 H   | 5 U   | 7   | H   | 1 1 3   | 1 1 3   | J   | J   | Ubl   | 8 1 6   | L]   | Iu   | 8 1 L   | 82 HL   | 1   | 8   | 4 UL   |\n| ------ | ------------------- | ------------------------------------------------------------------------- | ------------------------------ | --------------------------- | ------ | ------ | ------ | ------ | ------ | ---- | ---- | -------- | -------- | ---- | ---- | ------ | -------- | ---- | ---- | -------- | -------- | ---- | ---- | -------- |\n| 1 | Canada             | De Beers Canada Inc.                                                     | Yes                           | Yes                        |       |       |       |       |       |     |     |         |         |     |     |       |         |      |      |         |         |     |     |        |\n| 2 | Canada 

In [11]:
df_tables['summary'] = table_summaries
df_tables.head()

ValueError: Length of values (25) does not match length of index (26)

In [None]:
df_tables.to_excel("table_summaries.xlsx", index=False)

## Get summaries

In [None]:
table_img_folder = os.path.join("../data", "AngloAmerican_2021_CbCR/tables_img")
df_tables.to_excel(os.path.join(table_img_folder, "md_tables_with_summary.xlsx"), index=False)
df_tables = pd.read_excel(os.path.join(table_img_folder, "md_tables_with_summary.xlsx"))
table_summaries = list(df_tables['summary'])

In [None]:
loader = DataFrameLoader(df_tables, page_content_column="text")
tables = loader.load()

In [None]:
tables[:5]

[Document(page_content='|    | income and expense items:                                                                                                                                        | Main Business Activity(ies   |\n|---:|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------|\n|  1 | The nature ofthe main business activity(ies) carried by the constituent entity in the relevant taxjurisdiction, by ticking one or more of the appropriate boxes: |                              |', metadata={'page': 24, 'summary': '\n\n|    | income and expense items:                                                                                                                                        | Main Business Activity(ies   |\n|---:|:--------------------------------------------------------------------------------------------------------------------------

## Add tables to vectorstore

In [None]:
import uuid

from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain.vectorstores import Chroma
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings

# The vectorstore to use to index the child chunks
embedding_function = HuggingFaceEmbeddings(
    model_name='BAAI/bge-base-en-v1.5',  # BAAI/bge-base-en-v1.5 // thenlper/gte-small
    encode_kwargs={
        "show_progress_bar": True,
        "batch_size": 1
    },
)

In [None]:
k = 3

In [None]:
vectorstore = Chroma(
    collection_name="summaries",
    embedding_function=embedding_function,
)

In [None]:
# The storage layer for the parent documents
store = InMemoryStore()
id_key = "doc_id"

# The retriever (empty to start)
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
    search_kwargs={"k": k},
)

In [None]:
table_ids = [str(uuid.uuid4()) for _ in tables]
summary_tables = [
    Document(page_content=s, metadata={id_key: table_ids[i]})
    for i, s in enumerate(table_summaries)
]
summary_tables[:5]

[Document(page_content='\n\n|    | income and expense items:                                                                                                                                        | Main Business Activity(ies   |\n|---:|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------|\n|  1 | The nature ofthe main business activity(ies) carried by the constituent entity in the relevant taxjurisdiction, by ticking one or more of the appropriate boxes: |                              |\nSummary:\n\n|    | income and expense items:                                                                                                                                        | Main Business Activity(ies   |\n|---:|:-------------------------------------------------------------------------------------------------------------------------------------------------

In [None]:
# Add tables
retriever.vectorstore.add_documents(summary_tables)
retriever.docstore.mset(list(zip(table_ids, tables)))

TypeError: sentence_transformers.SentenceTransformer.SentenceTransformer.encode() got multiple values for keyword argument 'show_progress_bar'

# 2. RAG
---

## Testing simple retrieval

In [None]:
retriever.get_relevant_documents("Ireland")

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches: 100%|██████████| 1/1 [00:00<00:00, 19.25it/s]


[Document(page_content='|    | Tax Jurisdiction   | Name                                                    | 8: H                             | 6 Lu                   | 6 Hh     | 6 3 I   | 6 H   | 6 U   | 5  5 7 5   | 0 7 UHh 6   | 1 6 L   | U   | 8 Ih   | Ul   | 8 1L   | 1   | Iu   | Iu   | 1 1 1 1   | 1 1 V   | L   | 8   | 1 UL   |\n|---:|:-------------------|:--------------------------------------------------------|:---------------------------------|:-----------------------|:---------|:--------|:------|:------|:-----------|:------------|:--------|:----|:-------|:-----|:-------|:----|:-----|:-----|:----------|:--------|:----|:----|:-------|\n|  1 | Isle of Man        | Element Six (Legacy Pensions) Limited                   | Yes                              | 2S                     |          |         |       |       |            |             |         |     |        |      |        |     |      |      |           |         |     |     |        |\n|  2 | Israel             | De 

In [None]:
def get_top_k_docs(retriever, question, k):
    retriever.search_kwargs = {"k": k}
    docs = retriever.get_relevant_documents(question)
    df_docs = dict()
    for i, doc in enumerate(docs):
        df_docs["Doc {}".format(i+1)] = dict()
        df_docs["Doc {}".format(i+1)]['page_content'] = doc.page_content
        # df_docs["Doc {}".format(i+1)]['file_name'] = doc.metadata['source']
        df_docs["Doc {}".format(i+1)]['page_number'] = doc.metadata['page']
        # df_docs["Doc {}".format(i+1)]['content_type'] = ("tableau" if doc.metadata['type']=='MarkdownTable' else "texte")
    return df_docs

In [None]:
question = "How many employees in Ireland"
get_top_k_docs(retriever, question, k)

Batches: 100%|██████████| 1/1 [00:00<00:00, 24.27it/s]


{'Doc 1': {'page_content': '|    | nclude cash taxes paid by the     | permanent establishments, the stated capitalis reported by the                                                                     |\n|---:|:----------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------|\n|  1 | s reportedbv comnanv              | Number of Employees                                                                                                                |\n|  2 | negative amountsin Table 1_       | The total number of employees on a full-time equivalent (FTE) basis of allthe constituent entities resident for tax purposesin the |\n|  3 | Year)                             | relevant taxjurisdiction The number of employeeshas been renorted onthebasis of averaae emnlovmentlevels for thevear               |\n|  4 | xpense recorded on taxable        |                                             

## Testing simple RAG

In [None]:
retriever.search_kwargs = {"k": k}

In [None]:
from langchain_core.runnables import RunnablePassthrough
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

# Prompt template
template = """
Answer the question directly, using only the following markdown table as the single element of the list:
{context}

Question: {question}
Answer:
"""
prompt = ChatPromptTemplate.from_template(template)

# LLM
model = get_llm()

# RAG pipeline
chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)

In [None]:
question = "How many employees in Ireland?"
result = chain.invoke(question)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches: 100%|██████████| 1/1 [00:00<00:00, 27.58it/s]


In [None]:
with open('rag_result.txt', 'w') as f:
    f.write(result)

## Final chain

In [None]:
%load_ext autoreload
%autoreload 2
from country_by_country.rag_engine.rag import Extraction

llm = get_llm()
chain = Extraction(retriever, llm)
questions = [
    "How many employees in Ireland?",
    "What revenue has been declared in France?",
]

df_answers = chain.run(questions)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Batches: 100%|██████████| 1/1 [00:00<00:00, 33.74it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 32.08it/s]


In [None]:
df_answers

Unnamed: 0,question,answer,Doc 1 page,Doc 1 relevant content,Doc 2 page,Doc 2 relevant content,Doc 3 page,Doc 3 relevant content
0,How many employees in Ireland?,"Human: \n You are an economical expert, wit...",24,| | nclude cash taxes paid by the | per...,8,| | Tax Jurisdiction | Name 2 8 222 ...,2,| | nglo American is a leading global minin...
1,What revenue has been declared in France?,"Human: \n You are an economical expert, wit...",24,| | nclude cash taxes paid by the | per...,2,| | nglo American is a leading global minin...,24,| | income and expense items: ...


In [None]:
df_answers.to_excel('rag_answers.xlsx', index=False)