Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to specify metadata columns in CSV loader #11576

Merged
merged 5 commits into from
Oct 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions libs/langchain/langchain/document_loaders/csv_loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import csv
from io import TextIOWrapper
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Sequence

from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
Expand Down Expand Up @@ -36,6 +36,7 @@ def __init__(
self,
file_path: str,
source_column: Optional[str] = None,
metadata_columns: Sequence[str] = (),
csv_args: Optional[Dict] = None,
encoding: Optional[str] = None,
autodetect_encoding: bool = False,
Expand All @@ -46,13 +47,15 @@ def __init__(
file_path: The path to the CSV file.
source_column: The name of the column in the CSV file to use as the source.
Optional. Defaults to None.
metadata_columns: A sequence of column names to use as metadata. Optional.
csv_args: A dictionary of arguments to pass to the csv.DictReader.
Optional. Defaults to None.
encoding: The encoding of the CSV file. Optional. Defaults to None.
autodetect_encoding: Whether to try to autodetect the file encoding.
"""
self.file_path = file_path
self.source_column = source_column
self.metadata_columns = metadata_columns
self.encoding = encoding
self.csv_args = csv_args or {}
self.autodetect_encoding = autodetect_encoding
Expand Down Expand Up @@ -85,9 +88,9 @@ def load(self) -> List[Document]:

def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
docs = []

csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore
for i, row in enumerate(csv_reader):
content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
try:
source = (
row[self.source_column]
Expand All @@ -98,7 +101,17 @@ def __read_file(self, csvfile: TextIOWrapper) -> List[Document]:
raise ValueError(
f"Source column '{self.source_column}' not found in CSV file."
)
content = "\n".join(
f"{k.strip()}: {v.strip()}"
for k, v in row.items()
if k not in self.metadata_columns
)
metadata = {"source": source, "row": i}
for col in self.metadata_columns:
try:
metadata[col] = row[col]
except KeyError:
raise ValueError(f"Metadata column '{col}' not found in CSV file.")
doc = Document(page_content=content, metadata=metadata)
docs.append(doc)

Expand Down
Loading