File: RoundTripTest.java
@@ -17,20 +17,27 @@

package com.mongodb.spark.sql.connector;

import static com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorHelper.CATALOG;
import static com.mongodb.spark.sql.connector.schema.ConverterHelper.TIMESTAMP_NTZ_TYPE;
import static java.time.ZoneOffset.UTC;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertIterableEquals;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

import com.mongodb.spark.sql.connector.beans.BoxedBean;
import com.mongodb.spark.sql.connector.beans.ComplexBean;
import com.mongodb.spark.sql.connector.beans.DateTimeBean;
import com.mongodb.spark.sql.connector.beans.PrimitiveBean;
import com.mongodb.spark.sql.connector.config.WriteConfig;
import com.mongodb.spark.sql.connector.config.WriteConfig.TruncateMode;
import com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorTestCase;
import java.sql.Date;
import java.sql.Timestamp;
import java.time.Instant;
import java.time.LocalDate;
import java.time.ZoneOffset;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -39,8 +46,12 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.ValueSource;

public class RoundTripTest extends MongoSparkConnectorTestCase {

@@ -68,8 +79,9 @@ void testPrimitiveBean() {
assertIterableEquals(dataSetOriginal, dataSetMongo);
}

@Test
void testBoxedBean() {
@ParameterizedTest
@EnumSource(TruncateMode.class)
void testBoxedBean(final TruncateMode mode) {
// Given
List<BoxedBean> dataSetOriginal =
singletonList(new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, true));
@@ -79,7 +91,12 @@ void testBoxedBean() {
Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);

Dataset<BoxedBean> dataset = spark.createDataset(dataSetOriginal, encoder);
dataset.write().format("mongodb").mode("Overwrite").save();
dataset
.write()
.format("mongodb")
.mode("Overwrite")
.option(WriteConfig.TRUNCATE_MODE_CONFIG, mode.name())
.save();

// Then
List<BoxedBean> dataSetMongo = spark
@@ -92,24 +109,29 @@ void testBoxedBean() {
assertIterableEquals(dataSetOriginal, dataSetMongo);
}

@Test
void testDateTimeBean() {
@ParameterizedTest()
@ValueSource(strings = {"true", "false"})
void testDateTimeBean(final String java8DateTimeAPI) {
assumeTrue(TIMESTAMP_NTZ_TYPE != null);
TimeZone original = TimeZone.getDefault();
try {
TimeZone.setDefault(TimeZone.getTimeZone(ZoneOffset.UTC));
TimeZone.setDefault(TimeZone.getTimeZone(UTC));

// Given
long oneHour = TimeUnit.MILLISECONDS.convert(1, TimeUnit.HOURS);
long oneDay = oneHour * 24;

Instant epoch = Instant.EPOCH;
List<DateTimeBean> dataSetOriginal = singletonList(new DateTimeBean(
new Date(oneDay * 365),
new Timestamp(oneDay + oneHour),
LocalDate.of(2000, 1, 1),
Instant.EPOCH));
epoch,
LocalDateTime.ofInstant(epoch, UTC)));

// when
SparkSession spark = getOrCreateSparkSession();
SparkSession spark = getOrCreateSparkSession(
getSparkConf().set("spark.sql.datetime.java8API.enabled", java8DateTimeAPI));
Encoder<DateTimeBean> encoder = Encoders.bean(DateTimeBean.class);

Dataset<DateTimeBean> dataset = spark.createDataset(dataSetOriginal, encoder);
@@ -162,4 +184,32 @@ void testComplexBean() {
.collectAsList();
assertIterableEquals(dataSetOriginal, dataSetMongo);
}

@Test
void testCatalogAccessAndDelete() {
List<BoxedBean> dataSetOriginal = asList(
new BoxedBean((byte) 1, (short) 2, 0, 4L, 5.0f, 6.0, true),
new BoxedBean((byte) 1, (short) 2, 1, 4L, 5.0f, 6.0, true),
new BoxedBean((byte) 1, (short) 2, 2, 4L, 5.0f, 6.0, true),
new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, false),
new BoxedBean((byte) 1, (short) 2, 4, 4L, 5.0f, 6.0, false),
new BoxedBean((byte) 1, (short) 2, 5, 4L, 5.0f, 6.0, false));

SparkSession spark = getOrCreateSparkSession();
Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);
spark
.createDataset(dataSetOriginal, encoder)
.write()
.format("mongodb")
.mode("Overwrite")
.save();

String tableName = CATALOG + "." + HELPER.getDatabaseName() + "." + HELPER.getCollectionName();
List<Row> rows = spark.sql("select * from " + tableName).collectAsList();
assertEquals(6, rows.size());

spark.sql("delete from " + tableName + " where booleanField = false and intField > 3");
rows = spark.sql("select * from " + tableName).collectAsList();
assertEquals(4, rows.size());
}
}
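
Side note, not part of this change set: the sketch below shows roughly how the two features exercised above, the configurable truncate mode on Overwrite saves and SQL access through a registered Mongo catalog, could look from user code. The class name, connection string, and database/collection names are placeholders, and the "database"/"collection" option keys are assumptions rather than taken from this diff; the catalog registration mirrors what MongoSparkConnectorHelper wires up for the tests.

import static java.util.Collections.singletonList;

import com.mongodb.spark.sql.connector.MongoCatalog;
import com.mongodb.spark.sql.connector.beans.BoxedBean;
import com.mongodb.spark.sql.connector.config.MongoConfig;
import com.mongodb.spark.sql.connector.config.WriteConfig;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

public class TruncateModeAndCatalogSketch {
  public static void main(final String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local")
        // Register the connector's catalog with Spark SQL under a placeholder name,
        // the same pattern MongoSparkConnectorHelper uses for the test suite.
        .config("spark.sql.catalog.mongo_catalog", MongoCatalog.class.getCanonicalName())
        // Placeholder connection string.
        .config(MongoConfig.PREFIX + MongoConfig.CONNECTION_STRING_CONFIG, "mongodb://localhost:27017")
        .getOrCreate();

    Dataset<BoxedBean> dataset = spark.createDataset(
        singletonList(new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, true)),
        Encoders.bean(BoxedBean.class));

    // testBoxedBean iterates every TruncateMode value when saving in Overwrite mode;
    // this sketch just picks the first one.
    WriteConfig.TruncateMode mode = WriteConfig.TruncateMode.values()[0];
    dataset.write()
        .format("mongodb")
        .mode("Overwrite")
        .option(WriteConfig.TRUNCATE_MODE_CONFIG, mode.name())
        // Assumed option keys for the target namespace; the tests rely on
        // preconfigured defaults instead.
        .option("database", "placeholder_db")
        .option("collection", "placeholder_coll")
        .save();

    // With the catalog registered, the collection is addressable from plain SQL,
    // including DELETE, as testCatalogAccessAndDelete shows.
    spark.sql("SELECT * FROM mongo_catalog.placeholder_db.placeholder_coll").show();
  }
}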
File: DateTimeBean.java
@@ -21,24 +21,28 @@
import java.sql.Timestamp;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.Objects;

public class DateTimeBean implements Serializable {
private java.sql.Date sqlDate;
private java.sql.Timestamp sqlTimestamp;
private java.time.LocalDate localDate;
private java.time.Instant instant;
private java.time.LocalDateTime localDateTime;

public DateTimeBean() {}

public DateTimeBean(
final Date sqlDate,
final Timestamp sqlTimestamp,
final LocalDate localDate,
final Instant instant) {
final Instant instant,
final LocalDateTime localDateTime) {
this.sqlDate = sqlDate;
this.sqlTimestamp = sqlTimestamp;
this.localDate = localDate;
this.localDateTime = localDateTime;
this.instant = instant;
}

@@ -66,6 +70,14 @@ public void setLocalDate(final LocalDate localDate) {
this.localDate = localDate;
}

public LocalDateTime getLocalDateTime() {
return localDateTime;
}

public void setLocalDateTime(final LocalDateTime localDateTime) {
this.localDateTime = localDateTime;
}

public Instant getInstant() {
return instant;
}
@@ -86,20 +98,28 @@ public boolean equals(final Object o) {
return Objects.equals(sqlDate, that.sqlDate)
&& Objects.equals(sqlTimestamp, that.sqlTimestamp)
&& Objects.equals(localDate, that.localDate)
&& Objects.equals(localDateTime, that.localDateTime)
&& Objects.equals(instant, that.instant);
}

@Override
public int hashCode() {
return Objects.hash(sqlDate, sqlTimestamp, localDate, instant);
return Objects.hash(sqlDate, sqlTimestamp, localDate, localDateTime, instant);
}

@Override
public String toString() {
return "DateTimeBean{" + "sqlDate="
+ sqlDate + ", sqlTimestamp="
+ sqlTimestamp + ", localDate="
+ localDate + ", instant="
+ instant + '}';
return "DateTimeBean{"
+ "sqlDate="
+ sqlDate
+ ", sqlTimestamp="
+ sqlTimestamp
+ ", localDate="
+ localDate
+ ", localDateTime="
+ localDateTime
+ ", instant="
+ instant
+ '}';
}
}
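
Background, not part of the diff: the new localDateTime field pairs with the java8DateTimeAPI parameter added to RoundTripTest. spark.sql.datetime.java8API.enabled is a standard Spark SQL setting: when true, temporal values surface as java.time types (Instant, LocalDate); when false, the legacy java.sql.Timestamp and java.sql.Date types are used, so the bean has to round-trip cleanly under both. A minimal sketch of flipping the flag on a plain session; the class name and session settings are placeholders.

import com.mongodb.spark.sql.connector.beans.DateTimeBean;
import java.util.Collections;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

public class Java8DateTimeSketch {
  public static void main(final String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local")
        // true  -> temporal values surface as java.time.Instant / java.time.LocalDate
        // false -> the legacy java.sql.Timestamp / java.sql.Date types are used
        .config("spark.sql.datetime.java8API.enabled", "true")
        .getOrCreate();

    // The bean encoder picks up all five temporal fields, including the new
    // localDateTime, which is what the TIMESTAMP_NTZ_TYPE guard in RoundTripTest
    // checks Spark support for.
    Encoder<DateTimeBean> encoder = Encoders.bean(DateTimeBean.class);
    spark.createDataset(Collections.<DateTimeBean>emptyList(), encoder).printSchema();
  }
}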
File: MongoSparkConnectorHelper.java
@@ -28,6 +28,7 @@
import com.mongodb.client.model.UpdateOptions;
import com.mongodb.client.model.Updates;
import com.mongodb.connection.ClusterType;
import com.mongodb.spark.sql.connector.MongoCatalog;
import com.mongodb.spark.sql.connector.config.MongoConfig;
import java.io.File;
import java.io.IOException;
@@ -62,6 +63,7 @@ public class MongoSparkConnectorHelper
"{_id: '%s', pk: '%s', dups: '%s', i: %d, s: '%s'}";
private static final String COMPLEX_SAMPLE_DATA_TEMPLATE =
"{_id: '%s', nested: {pk: '%s', dups: '%s', i: %d}, s: '%s'}";
public static final String CATALOG = "mongo_catalog";

private static final Logger LOGGER = LoggerFactory.getLogger(MongoSparkConnectorHelper.class);

@@ -146,6 +148,7 @@ public SparkConf getSparkConf() {
.set("spark.sql.streaming.checkpointLocation", getTempDirectory())
.set("spark.sql.streaming.forceDeleteTempCheckpointLocation", "true")
.set("spark.app.id", "MongoSparkConnector")
.set("spark.sql.catalog." + CATALOG, MongoCatalog.class.getCanonicalName())
.set(
MongoConfig.PREFIX + MongoConfig.CONNECTION_STRING_CONFIG,
getConnectionString().getConnectionString())
File: (server version helper)
@@ -101,6 +101,10 @@ public boolean isAtLeastFiveDotZero() {
return getMaxWireVersion() >= 12;
}

public boolean isAtLeastSixDotZero() {
return getMaxWireVersion() >= 17;
}

public boolean isAtLeastSevenDotZero() {
return getMaxWireVersion() >= 21;
}
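
For reference, not part of this change: MongoDB 6.0 reports maxWireVersion 17, which is why the new isAtLeastSixDotZero() gate compares against that constant (the fullDocumentBeforeChange test in the next file is skipped on older servers). A rough sketch of deriving such a check from the hello command, with a placeholder connection string and class name:

import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import org.bson.Document;

public class WireVersionSketch {
  public static void main(final String[] args) {
    // Placeholder connection string; the test helper uses its own configured client.
    try (MongoClient client = MongoClients.create("mongodb://localhost:27017")) {
      Document hello = client.getDatabase("admin").runCommand(new Document("hello", 1));
      int maxWireVersion = hello.getInteger("maxWireVersion");
      // MongoDB 6.0 reports maxWireVersion 17; 7.0 reports 21.
      boolean isAtLeastSixDotZero = maxWireVersion >= 17;
      System.out.println("6.0 or newer? " + isAtLeastSixDotZero);
    }
  }
}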
File: (streaming read integration test)
@@ -39,6 +39,8 @@

import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import com.mongodb.client.model.ChangeStreamPreAndPostImagesOptions;
import com.mongodb.client.model.CreateCollectionOptions;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.InsertManyOptions;
import com.mongodb.client.model.Updates;
@@ -54,6 +56,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.TimeoutException;
import java.util.function.BiConsumer;
@@ -304,6 +307,60 @@ void testStreamWithPublishFullDocumentOnly(final String collectionsConfigModeStr
msg)));
}

@Test
void testStreamFullDocumentBeforeChange() {
assumeTrue(supportsChangeStreams());
assumeTrue(isAtLeastSixDotZero());

CollectionsConfig.Type collectionsConfigType = CollectionsConfig.Type.SINGLE;
testIdentifier = computeTestIdentifier("FullDocBeforeChange", collectionsConfigType);

testStreamingQuery(
createMongoConfig(collectionsConfigType)
.withOption(
ReadConfig.READ_PREFIX
+ ReadConfig.STREAM_LOOKUP_FULL_DOCUMENT_BEFORE_CHANGE_CONFIG,
"required"),
DOCUMENT_BEFORE_CHANGE_SCHEMA,
withSourceDb(
"Create the collection",
(msg, db) -> db.createCollection(
collectionName(),
new CreateCollectionOptions()
.changeStreamPreAndPostImagesOptions(
new ChangeStreamPreAndPostImagesOptions(true)))),
withSource("inserting 0-25", (msg, coll) -> coll.insertMany(createDocuments(0, 25))),
withMemorySink("Expected to see 25 documents", (msg, ds) -> {
List<Row> rows = ds.collectAsList();
assertEquals(25, rows.size(), msg);
assertTrue(
rows.stream()
.map(r -> r.getString(r.fieldIndex("fullDocumentBeforeChange")))
.allMatch(Objects::isNull),
msg);
}),
withSource(
"Updating all",
(msg, coll) ->
coll.updateMany(new BsonDocument(), Updates.set("a", new BsonString("a")))),
withMemorySink(
"Expecting to see 50 documents and the last 25 have fullDocumentBeforeChange",
(msg, ds) -> {
List<Row> rows = ds.collectAsList();
assertEquals(50, rows.size());
assertTrue(
rows.subList(0, 24).stream()
.map(r -> r.getString(r.fieldIndex("fullDocumentBeforeChange")))
.allMatch(Objects::isNull),
msg);
assertTrue(
rows.subList(25, 50).stream()
.map(r -> r.getString(r.fieldIndex("fullDocumentBeforeChange")))
.noneMatch(Objects::isNull),
msg);
}));
}

@ParameterizedTest
@ValueSource(strings = {"SINGLE", "MULTIPLE", "ALL"})
void testStreamPublishFullDocumentOnlyHandlesCollectionDrop(
@@ -707,6 +764,11 @@ void testReadsWithParseMode() {
createStructField("clusterTime", DataTypes.StringType, false),
createStructField("fullDocument", DataTypes.StringType, true)));

private static final StructType DOCUMENT_BEFORE_CHANGE_SCHEMA = createStructType(asList(
createStructField("operationType", DataTypes.StringType, false),
createStructField("clusterTime", DataTypes.StringType, false),
createStructField("fullDocumentBeforeChange", DataTypes.StringType, true)));

@SafeVarargs
private final void testStreamingQuery(
final MongoConfig mongoConfig,
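
Closing note, not part of the diff: testStreamFullDocumentBeforeChange depends on the collection having change stream pre- and post-images enabled, a MongoDB 6.0+ feature; on the Spark side the test then sets ReadConfig.STREAM_LOOKUP_FULL_DOCUMENT_BEFORE_CHANGE_CONFIG to "required" so every update event carries a pre-image. Besides the CreateCollectionOptions route used in the test, an existing collection can be switched over with collMod. A rough sketch with placeholder class, database, and collection names:

import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoClients;
import com.mongodb.client.model.ChangeStreamPreAndPostImagesOptions;
import com.mongodb.client.model.CreateCollectionOptions;
import org.bson.Document;

public class PreAndPostImagesSketch {
  public static void main(final String[] args) {
    try (MongoClient client = MongoClients.create("mongodb://localhost:27017")) {
      // New collection: enable pre/post images up front, as the test does.
      client.getDatabase("placeholder_db").createCollection(
          "new_coll",
          new CreateCollectionOptions()
              .changeStreamPreAndPostImagesOptions(new ChangeStreamPreAndPostImagesOptions(true)));

      // Existing collection: the same switch can be flipped with collMod.
      client.getDatabase("placeholder_db").runCommand(
          new Document("collMod", "existing_coll")
              .append("changeStreamPreAndPostImages", new Document("enabled", true)));
    }
  }
}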