In [1]:
# Given a search query, return all results that are related to that query.

In [28]:
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI
from langchain.embeddings import OpenAIEmbeddings
from vectorize_dataset import load_descriptions_data, create_db

import requests
from helpers import clean_up_tags


In [6]:
class DatasetRecommender:
    """Recommend Hugging Face datasets from a free-text query or from a dataset URL.

    On construction the descriptions data is loaded, a vector store is built
    from it with the embeddings backbone, and a RetrievalQA chain is wired on
    top of that store.
    """

    def __init__(self, llm_backbone=None, embeddings_backbone=None):
        """Build the vector store, the retriever and the QA chain.

        Args:
            llm_backbone: LLM passed to the RetrievalQA chain. Defaults to a
                fresh ``OpenAI()`` when omitted.
            embeddings_backbone: Embeddings object passed to ``create_db``.
                Defaults to a fresh ``OpenAIEmbeddings()`` when omitted.
        """
        # Defaults are created lazily: instantiating them in the signature
        # would run once at class-definition time and share one client
        # across every recommender instance.
        self.llm_backbone = llm_backbone if llm_backbone is not None else OpenAI()
        self.embeddings_backbone = (
            embeddings_backbone if embeddings_backbone is not None else OpenAIEmbeddings()
        )
        self.hf_df = load_descriptions_data()
        self.db = create_db(self.hf_df, self.embeddings_backbone)
        self.datasets_url_base = "https://huggingface.co/datasets/"
        # expose this index in a retriever interface
        self.retriever = self.db.as_retriever(search_type="similarity", search_kwargs={"k":2})
        # create a chain to answer questions
        self.qa = RetrievalQA.from_chain_type(
            llm=self.llm_backbone, chain_type="stuff", retriever=self.retriever, return_source_documents=True)

    def recommend_based_on_text(self, query):
        """Answer a free-text query and link the datasets used as sources.

        Args:
            query (str): The question or search text passed to the QA chain.

        Returns:
            dict: ``{'message': <chain answer>, 'datasets': [<dataset URL>, ...]}``
            with one URL per source document returned by the chain.
        """
        result = self.qa({"query": query})
        response_text = result['result']
        source_documents = result['source_documents']
        linked_datasets = [f"{self.datasets_url_base}{x.metadata['id']}" for x in source_documents]
        return {'message': response_text, 'datasets': linked_datasets}

    def get_similar_datasets(self, query_url):
        """Find datasets whose descriptions are similar to the dataset at ``query_url``.

        Args:
            query_url (str): URL of the dataset to use as the query.

        Returns:
            dict: ``{'datasets': [<dataset URL>, ...]}``. The list is empty
            when no description or tags could be retrieved for ``query_url``.
        """
        retrieved_metadata = get_dataset_metadata(query_url)
        # get_dataset_metadata returns {} (or a partial dict) when the lookup
        # fails, so read the fields defensively instead of raising KeyError.
        description = retrieved_metadata.get('description') or ""
        # presumably clean_up_tags returns a string, since it is concatenated
        # with the description -- verify against helpers.clean_up_tags
        tag_text = clean_up_tags(retrieved_metadata['tags']) if 'tags' in retrieved_metadata else ""
        cleaned_description = description + tag_text
        if not cleaned_description:
            return {'datasets': []}
        # Search this instance's own vector store rather than a notebook global.
        similar_documents = self.db.similarity_search(cleaned_description)
        # Drop the query dataset itself. NOTE: this is a substring test of the
        # id against the URL, so an id that is a prefix of the queried id is
        # filtered out as well.
        similar_datasets = [
            f"{self.datasets_url_base}{x.metadata['id']}"
            for x in similar_documents
            if x.metadata['id'] not in query_url
        ]
        return {'datasets': similar_datasets}

In [7]:
# Build the recommender with its default backbones. The constructor loads the
# descriptions data and builds the vector store, which produces the log output below.
db_lookup = DatasetRecommender()

Found cached dataset parquet (/Users/noahkasmanoff/.cache/huggingface/datasets/nkasmanoff___parquet/nkasmanoff--huggingface-datasets-60bbbd3d2e18598e/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
100%|██████████| 1/1 [00:00<00:00, 102.93it/s]
Created a chunk of size 1253, which is longer than the specified 1000
Created a chunk of size 1253, which is longer than the specified 1000
Created a chunk of size 1253, which is longer than the specified 1000
Created a chunk of size 1253, which is longer than the specified 1000
Created a chunk of size 1253, which is longer than the specified 1000
Created a chunk of size 1253, which is longer than the specified 1000
Created a chunk of size 1140, which is longer than the specified 1000
Created a chunk of size 1140, which is longer than the specified 1000
Created a chunk of size 1087, which is longer than the specified 1000
Created a chunk of size 1190, which is longer than the specified 1000
Created a chunk of size 1164, w

In [None]:
# Free-text query: the result dict holds the chain's answer under 'message'
# and the URLs of the source datasets under 'datasets'.
db_lookup.recommend_based_on_text("Show me datasets which are about natural disasters?")

{'message': ' The HumAID Twitter dataset consists of several thousands of manually annotated tweets that has been collected during 19 major natural disaster events including earthquakes, hurricanes, wildfires, and floods, which happened from 2016 to 2019 across different parts of the World. The annotations in the provided datasets consists of humanitarian categories such as caution and advice, displaced people and evacuations, infrastructure and utility damage, injured or dead people, missing or found people, requests or urgent needs, rescue volunteering or donation effort, and sympathy and support. Additionally, the dataset contains the dataset contains 30,000 messages drawn from events including an earthquake in Haiti in 2010, an earthquake in Chile in 2010, floods in Pakistan in 2010, super-storm Sandy in the U.S.A. in 2012, and news articles spanning a large number of years and 100s of different disasters.',
 'datasets': ['https://huggingface.co/datasets/disaster_response_messages'

In [45]:
def check_api_url(url):
    """
    Ensure "/api" is present in the URL between ".co" and "/datasets".

    The URL is cut at the first ".co" and, in what follows, at the first
    "/datasets". Only the first occurrence of each marker is used, so text
    after them (for example a dataset id that itself contains ".co" or
    "/datasets") is preserved verbatim.

    Args:
    url (str): A URL string

    Returns:
    str: The URL with "/api" inserted right after ".co" if it was missing.
    The URL is returned unchanged when it already contains "/api" before
    "/datasets", or when it has no ".co" or no "/datasets" marker.
    """
    domain, dot_co, remainder = url.partition(".co")
    if not dot_co:
        # No ".co" marker: nothing to anchor the insertion on.
        return url
    before_datasets, datasets_marker, after_datasets = remainder.partition("/datasets")
    if not datasets_marker:
        # Not a datasets URL, so inserting "/api" is not necessary.
        return url
    middle_part = "" if "/api" in before_datasets else "/api"
    return domain + dot_co + middle_part + before_datasets + datasets_marker + after_datasets



def get_dataset_metadata(dataset_url, timeout=10):
    """
    Fetch the 'id', 'description' and 'tags' fields for a dataset URL.

    The URL is first passed through check_api_url, then requested with an
    HTTP GET and the response body is parsed as JSON.

    Args:
    dataset_url (str): URL of the dataset page or of its API endpoint
    timeout (float): Seconds to wait for the server before giving up;
        passed straight to requests.get

    Returns:
    dict: The subset of 'id', 'description' and 'tags' present in the JSON
    response. Empty when the response status is not 200.
    """
    retrieved_metadata = {}
    dataset_url = check_api_url(dataset_url)
    keys_to_retrieve = ['id','description', 'tags']
    # Without a timeout, requests waits indefinitely on an unresponsive server.
    response = requests.get(dataset_url, timeout=timeout)
    if response.status_code == 200:
        response_json = response.json()
        retrieved_metadata = {
            key: response_json[key] for key in keys_to_retrieve if key in response_json
        }

    return retrieved_metadata

In [56]:
# Step-by-step walkthrough of the "similar datasets" lookup for a single dataset URL.
url = "https://huggingface.co/datasets/turkic_xwmt"
retrieved_metadata = get_dataset_metadata(url)
cleaned_description = retrieved_metadata['description'] + clean_up_tags(retrieved_metadata['tags'])
# Query the recommender's own vector store; the `database` name used here
# before is not defined in any cell shown in this notebook.
similar_documents = db_lookup.db.similarity_search(cleaned_description)
# Keep the ids of every hit except the query dataset itself (substring test against the URL).
similar_datasets = [x.metadata['id'] for x in similar_documents if x.metadata['id'] not in url]