diff --git a/src/ragas/integrations/llama_index.py b/src/ragas/integrations/llama_index.py index ff9c5b84e..7ba31b732 100644 --- a/src/ragas/integrations/llama_index.py +++ b/src/ragas/integrations/llama_index.py @@ -43,7 +43,7 @@ def validate_dataset(dataset: dict, metrics: list[Metric]): def evaluate( query_engine, - dataset: dict, + dataset: Dataset, metrics: list[Metric], llm: t.Optional[LlamaindexLLM] = None, embeddings: t.Optional[LlamaIndexEmbeddings] = None, @@ -98,7 +98,7 @@ def evaluate( "answer": answers, } ) - if "ground_truth" in dataset: + if "ground_truth" in dataset.column_names: hf_dataset = hf_dataset.add_column( name="ground_truth", column=dataset["ground_truth"],