Skip to content

Commit

Permalink
Extract document loaders and parsers into separate modules (#354)
Browse files Browse the repository at this point in the history
- extract PDF, POI document parsers into separate modules
- extract and simplify S3 document loader into a separate module
  • Loading branch information
langchain4j committed Dec 18, 2023
1 parent 99faffe commit 3731f33
Show file tree
Hide file tree
Showing 48 changed files with 1,000 additions and 1,475 deletions.
56 changes: 56 additions & 0 deletions document-loaders/langchain4j-document-loader-amazon-s3/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.24.0</version>
<relativePath>../../langchain4j-parent/pom.xml</relativePath>
</parent>

<artifactId>langchain4j-document-loader-amazon-s3</artifactId>
<name>LangChain4j Amazon S3 document loader</name>
<packaging>jar</packaging>

<dependencies>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
</dependency>

<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>s3</artifactId>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>localstack</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
package dev.langchain4j.data.document.loader.amazon.s3;

import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentLoader;
import dev.langchain4j.data.document.DocumentParser;
import dev.langchain4j.data.document.source.amazon.s3.AmazonS3Source;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
import software.amazon.awssdk.core.ResponseInputStream;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;
import software.amazon.awssdk.services.s3.model.*;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;

import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static java.util.stream.Collectors.toList;
import static software.amazon.awssdk.regions.Region.US_EAST_1;

public class AmazonS3DocumentLoader {

private static final Logger log = LoggerFactory.getLogger(AmazonS3DocumentLoader.class);

private final S3Client s3Client;

public AmazonS3DocumentLoader(S3Client s3Client) {
this.s3Client = ensureNotNull(s3Client, "s3Client");
}

/**
* Loads a single document from the specified S3 bucket based on the specified object key.
*
* @param bucket S3 bucket to load from.
* @param key The key of the S3 object which should be loaded.
* @param parser The parser to be used for parsing text from the object.
* @return A document containing the content of the S3 object.
* @throws RuntimeException If {@link S3Exception} occurs.
*/
public Document loadDocument(String bucket, String key, DocumentParser parser) {
try {
GetObjectRequest getObjectRequest = GetObjectRequest.builder()
.bucket(ensureNotBlank(bucket, "bucket"))
.key(ensureNotBlank(key, "key"))
.build();
ResponseInputStream<GetObjectResponse> inputStream = s3Client.getObject(getObjectRequest);
AmazonS3Source source = new AmazonS3Source(inputStream, bucket, key);
return DocumentLoader.load(source, parser);
} catch (S3Exception e) {
throw new RuntimeException(e);
}
}

/**
* Loads all documents from an S3 bucket.
* Skips any documents that fail to load.
*
* @param bucket S3 bucket to load from.
* @param parser The parser to be used for parsing text from the object.
* @return A list of documents.
* @throws RuntimeException If {@link S3Exception} occurs.
*/
public List<Document> loadDocuments(String bucket, DocumentParser parser) {
return loadDocuments(bucket, null, parser);
}

/**
* Loads all documents from an S3 bucket.
* Skips any documents that fail to load.
*
* @param bucket S3 bucket to load from.
* @param prefix Only keys with the specified prefix will be loaded.
* @param parser The parser to be used for parsing text from the object.
* @return A list of documents.
* @throws RuntimeException If {@link S3Exception} occurs.
*/
public List<Document> loadDocuments(String bucket, String prefix, DocumentParser parser) {
List<Document> documents = new ArrayList<>();

ListObjectsV2Request listObjectsV2Request = ListObjectsV2Request.builder()
.bucket(ensureNotBlank(bucket, "bucket"))
.prefix(prefix)
.build();

ListObjectsV2Response listObjectsV2Response = s3Client.listObjectsV2(listObjectsV2Request);

List<S3Object> filteredS3Objects = listObjectsV2Response.contents().stream()
.filter(s3Object -> !s3Object.key().endsWith("/") && s3Object.size() > 0)
.collect(toList());

for (S3Object s3Object : filteredS3Objects) {
String key = s3Object.key();
try {
Document document = loadDocument(bucket, key, parser);
documents.add(document);
} catch (Exception e) {
log.warn("Failed to load an object with key '{}' from bucket '{}', skipping it.", key, bucket, e);
}
}

return documents;
}

public static Builder builder() {
return new Builder();
}

public static class Builder {

private Region region = US_EAST_1;
private String endpointUrl;
private String profile;
private boolean forcePathStyle;
private AwsCredentials awsCredentials;

/**
* Set the AWS region. Defaults to US_EAST_1
*
* @param region The AWS region.
* @return The builder instance.
*/
public Builder region(String region) {
this.region = Region.of(region);
return this;
}

/**
* Set the AWS region. Defaults to US_EAST_1
*
* @param region The AWS region.
* @return The builder instance.
*/
public Builder region(Region region) {
this.region = region;
return this;
}

/**
* Specifies a custom endpoint URL to override the default service URL.
*
* @param endpointUrl The endpoint URL.
* @return The builder instance.
*/
public Builder endpointUrl(String endpointUrl) {
this.endpointUrl = endpointUrl;
return this;
}

/**
* Set the profile defined in AWS credentials. If not set, it will use the default profile.
*
* @param profile The profile defined in AWS credentials.
* @return The builder instance.
*/
public Builder profile(String profile) {
this.profile = profile;
return this;
}

/**
* Set the forcePathStyle. When enabled, it will use the path-style URL
*
* @param forcePathStyle The forcePathStyle.
* @return The builder instance.
*/
public Builder forcePathStyle(boolean forcePathStyle) {
this.forcePathStyle = forcePathStyle;
return this;
}

/**
* Set the AWS credentials. If not set, it will use the default credentials.
*
* @param awsCredentials The AWS credentials.
* @return The builder instance.
*/
public Builder awsCredentials(AwsCredentials awsCredentials) {
this.awsCredentials = awsCredentials;
return this;
}

public AmazonS3DocumentLoader build() {
AwsCredentialsProvider credentialsProvider = createCredentialsProvider();
S3Client s3Client = createS3Client(credentialsProvider);
return new AmazonS3DocumentLoader(s3Client);
}

private AwsCredentialsProvider createCredentialsProvider() {
if (!isNullOrBlank(profile)) {
return ProfileCredentialsProvider.create(profile);
}

if (awsCredentials != null) {
return awsCredentials.toCredentialsProvider();
}

return DefaultCredentialsProvider.create();
}

private S3Client createS3Client(AwsCredentialsProvider credentialsProvider) {

S3ClientBuilder s3ClientBuilder = S3Client.builder()
.region(region)
.forcePathStyle(forcePathStyle)
.credentialsProvider(credentialsProvider);

if (!isNullOrBlank(endpointUrl)) {
try {
s3ClientBuilder.endpointOverride(new URI(endpointUrl));
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
}

return s3ClientBuilder.build();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package dev.langchain4j.data.document.loader.amazon.s3;

import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;

import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;

/**
* Represents an AWS credentials object, including access key ID, secret access key, and optional session token.
*/
public class AwsCredentials {

private final String accessKeyId;
private final String secretAccessKey;
private final String sessionToken;

public AwsCredentials(String accessKeyId, String secretAccessKey) {
this(accessKeyId, secretAccessKey, null);
}

public AwsCredentials(String accessKeyId, String secretAccessKey, String sessionToken) {
this.accessKeyId = ensureNotBlank(accessKeyId, "accessKeyId");
this.secretAccessKey = ensureNotBlank(secretAccessKey, "secretAccessKey");
this.sessionToken = sessionToken;
}

public AwsCredentialsProvider toCredentialsProvider() {
return StaticCredentialsProvider.create(toCredentials());
}

private software.amazon.awssdk.auth.credentials.AwsCredentials toCredentials() {
if (sessionToken != null) {
return AwsSessionCredentials.create(accessKeyId, secretAccessKey, sessionToken);
}
return AwsBasicCredentials.create(accessKeyId, secretAccessKey);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package dev.langchain4j.data.document.source.amazon.s3;

import dev.langchain4j.data.document.DocumentSource;
import dev.langchain4j.data.document.Metadata;

import java.io.InputStream;

import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static java.lang.String.format;

public class AmazonS3Source implements DocumentSource {

public static final String SOURCE = "source";

private final InputStream inputStream;
private final String bucket;
private final String key;

public AmazonS3Source(InputStream inputStream, String bucket, String key) {
this.inputStream = ensureNotNull(inputStream, "inputStream");
this.bucket = ensureNotBlank(bucket, "bucket");
this.key = ensureNotBlank(key, "key");
}

@Override
public InputStream inputStream() {
return inputStream;
}

@Override
public Metadata metadata() {
return Metadata.from(SOURCE, format("s3://%s/%s", bucket, key));
}
}
Loading

0 comments on commit 3731f33

Please sign in to comment.