Skip to content

Commit

Permalink
Avoid code duplication in Merge DataFrame method (#5657)
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey Smirnov committed Jun 2, 2021
1 parent f7658b2 commit dc8daf0
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 193 deletions.
284 changes: 113 additions & 171 deletions src/Microsoft.Data.Analysis/DataFrame.Join.cs
Expand Up @@ -142,220 +142,162 @@ public DataFrame Join(DataFrame other, string leftSuffix = "_left", string right
}

// TODO: Merge API with an "On" parameter that merges on a column common to 2 dataframes

/// <summary>
/// Merge DataFrames with a database style join
/// </summary>
/// <param name="other"></param>
/// <param name="leftJoinColumn"></param>
/// <param name="rightJoinColumn"></param>
/// <param name="leftSuffix"></param>
/// <param name="rightSuffix"></param>
/// <param name="joinAlgorithm"></param>
/// <returns></returns>
public DataFrame Merge<TKey>(DataFrame other, string leftJoinColumn, string rightJoinColumn, string leftSuffix = "_left", string rightSuffix = "_right", JoinAlgorithm joinAlgorithm = JoinAlgorithm.Left)

private static Dictionary<TKey, long> Merge<TKey>(DataFrame retainedDataFrame, DataFrame supplementaryDataFame, string retainedJoinColumnName, string supplemetaryJoinColumnName, out PrimitiveDataFrameColumn<long> retainedRowIndices, out PrimitiveDataFrameColumn<long> supplementaryRowIndices, bool isInner = false, bool calculateIntersection = false)
{
// A simple hash join
DataFrame ret = new DataFrame();
DataFrame leftDataFrame = this;
DataFrame rightDataFrame = other;
Dictionary<TKey, long> intersection = calculateIntersection ? new Dictionary<TKey, long>(EqualityComparer<TKey>.Default) : null;

// The final table size is not known until runtime
long rowNumber = 0;
PrimitiveDataFrameColumn<long> leftRowIndices = new PrimitiveDataFrameColumn<long>("LeftIndices");
PrimitiveDataFrameColumn<long> rightRowIndices = new PrimitiveDataFrameColumn<long>("RightIndices");
if (joinAlgorithm == JoinAlgorithm.Left)
{
// First hash other dataframe on the rightJoinColumn
DataFrameColumn otherColumn = other.Columns[rightJoinColumn];
Dictionary<TKey, ICollection<long>> multimap = otherColumn.GroupColumnValues<TKey>(out HashSet<long> otherColumnNullIndices);
retainedRowIndices = new PrimitiveDataFrameColumn<long>("RetainedIndices");
supplementaryRowIndices = new PrimitiveDataFrameColumn<long>("SupplementaryIndices");

// First hash supplementary dataframe
DataFrameColumn supplementaryColumn = supplementaryDataFame.Columns[supplemetaryJoinColumnName];
Dictionary<TKey, ICollection<long>> multimap = supplementaryColumn.GroupColumnValues<TKey>(out HashSet<long> supplementaryColumnNullIndices);

// Go over the records in this dataframe and match with the dictionary
DataFrameColumn thisColumn = Columns[leftJoinColumn];
// Go over the records in this dataframe and match with the dictionary
DataFrameColumn retainedColumn = retainedDataFrame.Columns[retainedJoinColumnName];

for (long i = 0; i < thisColumn.Length; i++)
for (long i = 0; i < retainedColumn.Length; i++)
{
var retainedValue = retainedColumn[i];
if (retainedValue != null)
{
var thisColumnValue = thisColumn[i];
if (thisColumnValue != null)
//Get all rows from the supplementary dataframe that satisfy the JOIN condition
if (multimap.TryGetValue((TKey)retainedValue, out ICollection<long> rowIndices))
{
if (multimap.TryGetValue((TKey)thisColumnValue, out ICollection<long> rowNumbers))
foreach (long rowIndex in rowIndices)
{
foreach (long row in rowNumbers)
retainedRowIndices.Append(i);
supplementaryRowIndices.Append(rowIndex);

//store intersection if required
if (calculateIntersection)
{
leftRowIndices.Append(i);
rightRowIndices.Append(row);
if (!intersection.ContainsKey((TKey)retainedValue))
{
intersection.Add((TKey)retainedValue, rowIndex);
}
}
}
else
{
leftRowIndices.Append(i);
rightRowIndices.Append(null);
}
}
else
{
foreach (long row in otherColumnNullIndices)
{
leftRowIndices.Append(i);
rightRowIndices.Append(row);
}
if (isInner)
continue;

retainedRowIndices.Append(i);
supplementaryRowIndices.Append(null);
}
}
}
else if (joinAlgorithm == JoinAlgorithm.Right)
{
DataFrameColumn thisColumn = Columns[leftJoinColumn];
Dictionary<TKey, ICollection<long>> multimap = thisColumn.GroupColumnValues<TKey>(out HashSet<long> thisColumnNullIndices);

DataFrameColumn otherColumn = other.Columns[rightJoinColumn];
for (long i = 0; i < otherColumn.Length; i++)
else
{
var otherColumnValue = otherColumn[i];
if (otherColumnValue != null)
{
if (multimap.TryGetValue((TKey)otherColumnValue, out ICollection<long> rowNumbers))
{
foreach (long row in rowNumbers)
{
leftRowIndices.Append(row);
rightRowIndices.Append(i);
}
}
else
{
leftRowIndices.Append(null);
rightRowIndices.Append(i);
}
}
else
foreach (long row in supplementaryColumnNullIndices)
{
foreach (long thisColumnNullIndex in thisColumnNullIndices)
{
leftRowIndices.Append(thisColumnNullIndex);
rightRowIndices.Append(i);
}
retainedRowIndices.Append(i);
supplementaryRowIndices.Append(row);
}
}
}

return intersection;
}


/// <summary>
/// Merge DataFrames with a database style join
/// </summary>
/// <param name="other"></param>
/// <param name="leftJoinColumn"></param>
/// <param name="rightJoinColumn"></param>
/// <param name="leftSuffix"></param>
/// <param name="rightSuffix"></param>
/// <param name="joinAlgorithm"></param>
/// <returns></returns>
public DataFrame Merge<TKey>(DataFrame other, string leftJoinColumn, string rightJoinColumn, string leftSuffix = "_left", string rightSuffix = "_right", JoinAlgorithm joinAlgorithm = JoinAlgorithm.Left)
{
//In an outer join the joined dataframe retains each row, even if no matching row exists in the supplementary dataframe.
//Outer joins subdivide further into left outer joins (the left dataframe is retained), right outer joins (the right dataframe is retained) and full outer joins (both dataframes are retained).

PrimitiveDataFrameColumn<long> retainedRowIndices;
PrimitiveDataFrameColumn<long> supplementaryRowIndices;
DataFrame supplementaryDataFrame;
DataFrame retainedDataFrame;
bool isLeftDataFrameRetained;

if (joinAlgorithm == JoinAlgorithm.Left || joinAlgorithm == JoinAlgorithm.Right)
{
isLeftDataFrameRetained = (joinAlgorithm == JoinAlgorithm.Left);

supplementaryDataFrame = isLeftDataFrameRetained ? other : this;
var supplementaryJoinColumn = isLeftDataFrameRetained ? rightJoinColumn : leftJoinColumn;

retainedDataFrame = isLeftDataFrameRetained ? this : other;
var retainedJoinColumn = isLeftDataFrameRetained ? leftJoinColumn : rightJoinColumn;

Merge<TKey>(retainedDataFrame, supplementaryDataFrame, retainedJoinColumn, supplementaryJoinColumn, out retainedRowIndices, out supplementaryRowIndices);

}
else if (joinAlgorithm == JoinAlgorithm.Inner)
{
// Hash the column with the smaller RowCount
long leftRowCount = Rows.Count;
long rightRowCount = other.Rows.Count;
// use as supplementary (for Hashing) the dataframe with the smaller RowCount
isLeftDataFrameRetained = (Rows.Count > other.Rows.Count);

bool leftColumnIsSmaller = leftRowCount <= rightRowCount;
DataFrameColumn hashColumn = leftColumnIsSmaller ? Columns[leftJoinColumn] : other.Columns[rightJoinColumn];
DataFrameColumn otherColumn = ReferenceEquals(hashColumn, Columns[leftJoinColumn]) ? other.Columns[rightJoinColumn] : Columns[leftJoinColumn];
Dictionary<TKey, ICollection<long>> multimap = hashColumn.GroupColumnValues<TKey>(out HashSet<long> smallerDataFrameColumnNullIndices);
supplementaryDataFrame = isLeftDataFrameRetained ? other : this;
var supplementaryJoinColumn = isLeftDataFrameRetained ? rightJoinColumn : leftJoinColumn;

for (long i = 0; i < otherColumn.Length; i++)
{
var otherColumnValue = otherColumn[i];
if (otherColumnValue != null)
{
if (multimap.TryGetValue((TKey)otherColumnValue, out ICollection<long> rowNumbers))
{
foreach (long row in rowNumbers)
{
leftRowIndices.Append(leftColumnIsSmaller ? row : i);
rightRowIndices.Append(leftColumnIsSmaller ? i : row);
}
}
}
else
{
foreach (long nullIndex in smallerDataFrameColumnNullIndices)
{
leftRowIndices.Append(leftColumnIsSmaller ? nullIndex : i);
rightRowIndices.Append(leftColumnIsSmaller ? i : nullIndex);
}
}
}
retainedDataFrame = isLeftDataFrameRetained ? this : other;
var retainedJoinColumn = isLeftDataFrameRetained ? leftJoinColumn : rightJoinColumn;

Merge<TKey>(retainedDataFrame, supplementaryDataFrame, retainedJoinColumn, supplementaryJoinColumn, out retainedRowIndices, out supplementaryRowIndices, true);
}
else if (joinAlgorithm == JoinAlgorithm.FullOuter)
{
DataFrameColumn otherColumn = other.Columns[rightJoinColumn];
Dictionary<TKey, ICollection<long>> multimap = otherColumn.GroupColumnValues<TKey>(out HashSet<long> otherColumnNullIndices);
Dictionary<TKey, long> intersection = new Dictionary<TKey, long>(EqualityComparer<TKey>.Default);
//In a full outer join we would like to retain data from both sides, so we do it in 2 steps: first we do a LEFT JOIN and then we add the lost data from the RIGHT side

//Step 1
//Do LEFT JOIN
isLeftDataFrameRetained = true;

// Go over the records in this dataframe and match with the dictionary
DataFrameColumn thisColumn = Columns[leftJoinColumn];
Int64DataFrameColumn thisColumnNullIndices = new Int64DataFrameColumn("ThisColumnNullIndices");
supplementaryDataFrame = isLeftDataFrameRetained ? other : this;
var supplementaryJoinColumn = isLeftDataFrameRetained ? rightJoinColumn : leftJoinColumn;

for (long i = 0; i < thisColumn.Length; i++)
{
var thisColumnValue = thisColumn[i];
if (thisColumnValue != null)
{
if (multimap.TryGetValue((TKey)thisColumnValue, out ICollection<long> rowNumbers))
{
foreach (long row in rowNumbers)
{
leftRowIndices.Append(i);
rightRowIndices.Append(row);
if (!intersection.ContainsKey((TKey)thisColumnValue))
{
intersection.Add((TKey)thisColumnValue, rowNumber);
}
}
}
else
{
leftRowIndices.Append(i);
rightRowIndices.Append(null);
}
}
else
{
thisColumnNullIndices.Append(i);
}
}
for (long i = 0; i < otherColumn.Length; i++)
retainedDataFrame = isLeftDataFrameRetained ? this : other;
var retainedJoinColumn = isLeftDataFrameRetained ? leftJoinColumn : rightJoinColumn;

var intersection = Merge<TKey>(retainedDataFrame, supplementaryDataFrame, retainedJoinColumn, supplementaryJoinColumn, out retainedRowIndices, out supplementaryRowIndices, calculateIntersection: true);

//Step 2
//Do RIGHT JOIN to retain all data from supplementary DataFrame too (take into account data intersection from the first step to avoid duplicates)
DataFrameColumn supplementaryColumn = supplementaryDataFrame.Columns[supplementaryJoinColumn];

for (long i = 0; i < supplementaryColumn.Length; i++)
{
var value = otherColumn[i];
var value = supplementaryColumn[i];
if (value != null)
{
if (!intersection.ContainsKey((TKey)value))
{
leftRowIndices.Append(null);
rightRowIndices.Append(i);
retainedRowIndices.Append(null);
supplementaryRowIndices.Append(i);
}
}
}

// Now handle the null rows
foreach (long? thisColumnNullIndex in thisColumnNullIndices)
{
foreach (long otherColumnNullIndex in otherColumnNullIndices)
{
leftRowIndices.Append(thisColumnNullIndex.Value);
rightRowIndices.Append(otherColumnNullIndex);
}
if (otherColumnNullIndices.Count == 0)
{
leftRowIndices.Append(thisColumnNullIndex.Value);
rightRowIndices.Append(null);
}
}
if (thisColumnNullIndices.Length == 0)
{
foreach (long otherColumnNullIndex in otherColumnNullIndices)
{
leftRowIndices.Append(null);
rightRowIndices.Append(otherColumnNullIndex);
}
}
}
else
throw new NotImplementedException(nameof(joinAlgorithm));

for (int i = 0; i < leftDataFrame.Columns.Count; i++)

DataFrame ret = new DataFrame();

//insert columns from left dataframe (this)
for (int i = 0; i < this.Columns.Count; i++)
{
ret.Columns.Insert(i, leftDataFrame.Columns[i].Clone(leftRowIndices));
ret.Columns.Insert(i, this.Columns[i].Clone(isLeftDataFrameRetained ? retainedRowIndices : supplementaryRowIndices));
}
for (int i = 0; i < rightDataFrame.Columns.Count; i++)

//insert columns from right dataframe (other)
for (int i = 0; i < other.Columns.Count; i++)
{
DataFrameColumn column = rightDataFrame.Columns[i].Clone(rightRowIndices);
DataFrameColumn column = other.Columns[i].Clone(isLeftDataFrameRetained ? supplementaryRowIndices : retainedRowIndices);
SetSuffixForDuplicatedColumnNames(ret, column, leftSuffix, rightSuffix);
ret.Columns.Insert(ret.Columns.Count, column);
}
Expand Down

0 comments on commit dc8daf0

Please sign in to comment.