In [None]:
%store -r WEAVIATE_IP
from boto3 import Session

session = Session()
credentials = session.get_credentials()
current_credentials = credentials.get_frozen_credentials()

AWS_ACCESS_KEY = current_credentials.access_key
AWS_SECRET_KEY = current_credentials.secret_key
AWS_SECRET_TOKEN = current_credentials.token
print(f"AWS_ACCESS_KEY:\t{AWS_ACCESS_KEY}")
print(f"AWS_SECRET_KEY:\t{AWS_SECRET_KEY}")
print(f"AWS_SECRET_TOKEN:\t{AWS_SECRET_TOKEN}")
print(f"WEAVIATE_IP:\t{WEAVIATE_IP}")

## Connect

In [None]:
import weaviate

# AWS credentials are forwarded to Weaviate as request headers.
# Headers whose value is None (e.g. no session token) are left out rather than sent empty.
aws_headers = {
    "X-AWS-Access-Key": AWS_ACCESS_KEY,
    "X-AWS-Secret-Key": AWS_SECRET_KEY,
    "X-AWS-Session-Token": AWS_SECRET_TOKEN,
}
aws_headers = {name: value for name, value in aws_headers.items() if value is not None}

# Ports are passed as integers; HTTP and gRPC both target the same host, without TLS.
client = weaviate.connect_to_custom(
    http_host=WEAVIATE_IP, http_port=8080, http_secure=False,
    grpc_host=WEAVIATE_IP, grpc_port=50051, grpc_secure=False,
    headers=aws_headers,
)

client.is_ready()

## Create Collection

In [None]:
from weaviate.classes.config import Configure, Property, DataType

# client.collections.delete("Wikipedia")

# "text" is the only property left vectorizable; every other field is
# metadata and is created with skip_vectorization=True.
metadata_fields = [
    ("title", DataType.TEXT),
    ("wiki_id", DataType.INT),
    ("url", DataType.TEXT),
    ("lang", DataType.TEXT),
    ("lang_id", DataType.INT),
    ("views", DataType.NUMBER),
]
wikipedia_properties = [Property(name="text", data_type=DataType.TEXT)]
wikipedia_properties += [
    Property(name=field_name, data_type=field_type, skip_vectorization=True)
    for field_name, field_type in metadata_fields
]

# Create a collection here - with AWS as a vectorizer, with "embed-multilingual-v2.0" model
client.collections.create(
    name="Wikipedia",

    # TODO: add AWS vectorizer
    ## use embed-multilingual-v2.0 model

    # TODO: (optional) pick a generative model (i.e. gpt-4)
    # generative_config=Configure.Generative.,

    properties=wikipedia_properties,
)

In [None]:
from datasets import load_dataset
from tqdm import tqdm

def import_wiki_data(lang, lang_id, max_rows, skip_rows=0):
    """Stream pre-embedded Wikipedia rows for one language into the "Wikipedia" collection.

    Rows are read in streaming mode from the Hugging Face dataset
    ``Cohere/wikipedia-22-12-{lang}-embeddings``. Each row's ``emb`` value is
    sent as the object's vector together with its properties.

    Args:
        lang: Language code; selects the dataset and is stored in the ``lang`` property.
        lang_id: Numeric language id stored in the ``lang_id`` property.
        max_rows: Row position at which the import stops. Counting starts at
            ``skip_rows``, so ``max_rows - skip_rows`` rows are inserted.
        skip_rows: Number of leading dataset rows to skip before importing.

    Uses the notebook-level ``client`` connection. Prints progress and, if any
    objects failed to import, the number of failures and the last failed object.
    """
    print(f"Importing {max_rows} data items for {lang}")

    dataset = load_dataset(f"Cohere/wikipedia-22-12-{lang}-embeddings", split="train", streaming=True)
    dataset = dataset.skip(skip_rows)

    # The counter tracks the absolute row position, so it starts at the skipped offset.
    counter = skip_rows

    wikipedia = client.collections.get("Wikipedia")

    # Fixed-size batching: send objects in batches of 1000 using 4 concurrent requests.
    with wikipedia.batch.fixed_size(batch_size=1000, concurrent_requests=4) as batch:
        for item in tqdm(dataset, initial=skip_rows, total=max_rows):
            vector = item["emb"]
            data_to_insert = {
                "text": item["text"],
                "wiki_id": item["wiki_id"],
                "title": item["title"],
                "url": item["url"],
                "views": item["views"],
                "lang": lang,
                "lang_id": lang_id,
            }

            # Queue the object with its precomputed vector.
            batch.add_object(properties=data_to_insert, vector=vector)

            # stop after the request number reaches = max_rows
            counter += 1
            if counter >= max_rows:
                break

    # check for errors at the end
    failed_objects = wikipedia.batch.failed_objects
    if len(failed_objects) > 0:
        print("Final error check")
        print(f"Some errors {len(failed_objects)}")
        print(failed_objects[-1])

    print(f"Imported {counter} items for {lang}")
    print("-----------------------------------")

In [None]:
import_per_country = 10_000

# (language code, language id) pairs to import, processed in order.
# Uncomment an entry to include that language as well.
languages_to_import = [
    ("en", 0),
    ("de", 1),
    ("fr", 2),
    # ("es", 3),
    # ("it", 4),
]

for lang_code, lang_number in languages_to_import:
    import_wiki_data(lang_code, lang_number, import_per_country, 0)

## Close the client when done

In [None]:
# Close the connection opened by weaviate.connect_to_custom once all work is finished.
client.close()