## First pass at downloading chunks from Wikimedia Enterprise (WME)

In [None]:
# !pip install torch==2.9.0 --extra-index-url https://download.pytorch.org/whl/cpu
# !conda install -c pytorch faiss-cpu
# !pip install -U sentence-transformers transformers safetensors huggingface-hub
# !pip install datasets[faiss]
# !pip install faiss-cpu

In [None]:
import requests
import gzip
import json
import mwparserfromhell
from datasets import Dataset
from sentence_transformers import SentenceTransformer
import os

from dotenv import load_dotenv
load_dotenv()

In [None]:
# WME credentials are read from the environment (e.g. from a .env file via
# the load_dotenv() call above); a missing variable raises KeyError.
wme_username, wme_password = (
    os.environ['wme_username'],
    os.environ['wme_password'],
)

In [None]:
# Load the all-MiniLM-L6-v2 embedding model from a local Hugging Face cache
# snapshot and run it on CPU.
# NOTE: the absolute path below points into one specific user's home directory.
model_path = "/home/htriedman/.cache/huggingface/hub/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/c9745ed1d9f207416be6d2e6f8de32d1f16199bf"
model = SentenceTransformer(model_name_or_path=model_path, device="cpu")

In [None]:
# Exchange the WME username/password for API tokens.
resp = requests.post(
    'https://auth.enterprise.wikimedia.com/v1/login',
    data={
        'username': wme_username,
        'password': wme_password,
    },
    # Without a timeout, requests can block indefinitely on a stalled connection.
    timeout=30,
)
# Fail here on rejected credentials or server errors, instead of later with a
# less helpful KeyError when tokens['access_token'] is looked up.
resp.raise_for_status()
tokens = resp.json()

In [None]:
snapshot_identifier = "enwiki_namespace_0"

# download a single sample chunk, switch comment to download all ~400 chunks
# url = f"https://api.enterprise.wikimedia.com/v2/snapshots/{snapshot_identifier}/download"
url = f"https://api.enterprise.wikimedia.com/v2/snapshots/{snapshot_identifier}/chunks/0/download"
# NOTE(review): not used by any of the cells visible in this notebook.
destination_directory = "./extracted_enwiki_chunk0"

headers = {
    "Accept": "application/json",
    "Authorization": f"Bearer {tokens['access_token']}"
}

# Parsed page objects from the downloaded chunk: one entry per JSON line.
chunk = []


try:
    # (connect, read) timeouts so a stalled connection cannot hang the cell;
    # the read timeout bounds the wait between bytes, not the whole download.
    response = requests.get(url, headers=headers, timeout=(10, 300))
    response.raise_for_status()  # raise an exception for bad status codes

    # check content type
    content_type = response.headers.get('Content-Type', '')
    if content_type not in ['application/zip', 'binary/octet-stream']:
        print(f"Warning: Expected a zip file, but received {content_type}")

    # decompress and decode the payload
    decompressed_data = gzip.decompress(response.content)
    decoded_string = decompressed_data.decode('utf-8')

    # Keep only the span from the first '{' to the last '}', dropping any
    # non-JSON bytes around the payload (presumably archive framing — confirm).
    decoded_string = decoded_string[decoded_string.find('{'):decoded_string.rfind('}')+1]

    # Split on '\n' only: str.splitlines() also breaks on characters such as
    # U+2028, which may appear unescaped inside JSON string values.
    for page in decoded_string.split('\n'):
        try:
            chunk.append(json.loads(page))
        except json.JSONDecodeError:
            # Report the unparseable line (truncated) and keep going.
            print(f"{page[:1000]} load failed")

    print(f"{len(chunk)} total pages in this chunk")

except UnicodeDecodeError as e:
    print(f"Error decoding the decompressed data. {e}")
    print("The data might be binary or use a different character encoding.")
except Exception as e:
    # Notebook-level boundary: report the failure; `chunk` keeps whatever was
    # parsed before the error (possibly nothing).
    print(f"An unexpected error occurred: {e}")

In [None]:
# Turn the list of parsed page dicts into a Hugging Face Dataset.
ds = Dataset.from_list(chunk)

In [None]:
# Persist the raw chunk so it can be reloaded from disk.
ds.save_to_disk('data/chunk_0_hf')

In [None]:
# Reload the raw chunk saved under data/chunk_0_hf.
ds = Dataset.load_from_disk('data/chunk_0_hf')

In [None]:
def get_to_embed(article):
    """Attach the text to embed to *article* and return it.

    The text is the article name, its abstract, and the first 500 characters
    of the wikitext with templates removed, separated by blank lines.

    Args:
        article: mapping with ``name``, ``abstract`` and
            ``article_body['wikitext']`` entries.

    Returns:
        The same ``article`` object with a ``to_embed`` string added.
    """
    name = article['name']
    abstract = article['abstract']

    wikitext = mwparserfromhell.parse(article['article_body']['wikitext'])
    for template in wikitext.filter_templates():
        try:
            wikitext.remove(template)
        except ValueError:
            # remove() raises ValueError when the node is no longer in the
            # tree, typically a template nested in one already removed.
            continue

    # Slicing clamps to the available length, so short texts need no special case.
    first_chars = str(wikitext)[:500]

    article['to_embed'] = f"{name}\n\n{abstract}\n\n{first_chars}"
    return article


ds = ds.map(get_to_embed)

In [None]:
def do_embedding(article):
    """Encode ``article['to_embed']`` and store the result under ``'embedding'``.

    Uses the module-level ``model``; returns the same ``article`` object.
    """
    text = article['to_embed']
    article['embedding'] = model.encode(text, convert_to_numpy=True)
    return article


ds = ds.map(do_embedding)

In [None]:
# Persist the dataset, now including the 'to_embed' and 'embedding' columns.
ds.save_to_disk('data/chunk_0_hf_embed')

In [None]:
# Reload the embedded dataset saved under data/chunk_0_hf_embed.
ds = Dataset.load_from_disk('data/chunk_0_hf_embed')

In [None]:
# Build a FAISS index over the 'embedding' column for nearest-neighbour lookup
# (library default index settings; presumably L2 distance — confirm in the datasets docs).
ds.add_faiss_index(column='embedding')

In [None]:
# Embed a free-text query and print the names and scores of its 10 nearest
# articles in the 'embedding' index.
q = 'Cats are so interesting. I especially like tigers.'
emb_q = model.encode(q, convert_to_numpy=True)

scores, retrieved = ds.get_nearest_examples('embedding', emb_q, k=10)

for r, s in zip(retrieved['name'], scores):
    print(f"{r}, {s}")

In [None]:
# Write the FAISS index for the 'embedding' column to test_idx.faiss.
ds.save_faiss_index('embedding', 'test_idx.faiss')

In [None]:
# Reload the index from the file written by save_faiss_index earlier in this
# notebook ('test_idx.faiss'); the previous name 'my_index.faiss' is never created here.
ds.load_faiss_index('embedding', 'test_idx.faiss')

## TODOs
- parse `{{cite xxx` and `{{citation` templates from source data
- parallelize to download multiple chunks at a time
- parse grokipedia