# Fast tensor contraction

## Block-sparse tensor construction

In [1]:
using ITensors: BlockSparseTensor, Block, contract

In [2]:
# Rank-3 example (N = 3): sizes of the blocks along each dimension
bst_dims = (
    [2, 2, 3],  # 1st dimension: 3 blocks
    [4, 3],     # 2nd dimension: 2 blocks
    [3, 4],     # 3rd dimension: 2 blocks
)
# Block multi-indices of the only two non-vanishing blocks
bst_blocks = [(1, 1, 1), (3, 2, 2)]

# Build the block-sparse tensor; its memory is zero-initialized
bst = BlockSparseTensor{ComplexF64}(bst_blocks, bst_dims[1], bst_dims[2], bst_dims[3])

Dim 1: [2, 2, 3]
Dim 2: [4, 3]
Dim 3: [3, 4]
NDTensors.BlockSparse{ComplexF64, Vector{ComplexF64}, 3}
 7×7×7
Block(1, 1, 1)
 [1:2, 1:4, 1:3]
[:, :, 1] =
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im

[:, :, 2] =
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im

[:, :, 3] =
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im

Block(3, 2, 2)
 [5:7, 5:7, 4:7]
[:, :, 1] =
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im

[:, :, 2] =
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im

[:, :, 3] =
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im

[:, :, 4] =
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im

## Access to individual blocks

In [3]:
# Assign a constant to each of the two stored blocks
for (blk, val) in (Block(1, 1, 1) => 2.0, Block(3, 2, 2) => 3.0)
    bst[blk] = val
end

@show bst

bst = Dim 1: [2, 2, 3]
Dim 2: [4, 3]
Dim 3: [3, 4]
NDTensors.BlockSparse{ComplexF64, Vector{ComplexF64}, 3}
 7×7×7
Block(1, 1, 1)
 [1:2, 1:4, 1:3]
[:, :, 1] =
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im

[:, :, 2] =
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im

[:, :, 3] =
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im

Block(3, 2, 2)
 [5:7, 5:7, 4:7]
[:, :, 1] =
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im

[:, :, 2] =
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im

[:, :, 3] =
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im

[:, :, 4] =
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im

Dim 1: [2, 2, 3]
Dim 2: [4, 3]
Dim 3: [3, 4]
NDTensors.BlockSparse{ComplexF64, Vector{ComplexF64}, 3}
 7×7×7
Block(1, 1, 1)
 [1:2, 1:4, 1:3]
[:, :, 1] =
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im

[:, :, 2] =
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im

[:, :, 3] =
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im

Block(3, 2, 2)
 [5:7, 5:7, 4:7]
[:, :, 1] =
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im

[:, :, 2] =
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im

[:, :, 3] =
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im

[:, :, 4] =
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im

## Contraction of a chain tensor network ($n = 3$)

In [4]:
# Propagators: 2 dimensions, 3 square blocks along each dimension
P_block_dims = ([2, 3, 4], [2, 3, 4])
# Total extent of each dimension
P_dims = map(sum, P_block_dims)
# Only the diagonal blocks (i, i) are non-zero
P_blocks = [(i, i) for i in eachindex(first(P_block_dims))]

# Five propagators sharing the same block structure
P = [BlockSparseTensor{ComplexF64}(P_blocks, P_block_dims...) for _ in 1:5];

In [5]:
# Sanity check: contract the propagators on their own

using Random: MersenneTwister, rand
rng = MersenneTwister(12345678)

# Fill every diagonal block of every propagator with random entries
# (propagators in the outer loop, blocks in the inner loop)
for prop in P, b in eachindex(first(P_block_dims))
    prop[Block(b, b)] = rand(rng, P_block_dims[1][b], P_block_dims[2][b])
end

# Index labels attached to the propagators along the chain:
# P5_{1,2} P4_{2,3} P3_{3,4} P2_{4,5} P1_{5,6}

# Pairwise contraction order under test: (P5 * (P4 * P3)) * (P2 * P1)
R43 = contract(P[4], (2, 3), P[3], (3, 4), (2, 4));
R543 = contract(P[5], (1, 2), R43, (2, 4), (1, 4));
R21 = contract(P[2], (4, 5), P[1], (5, 6), (4, 6));
R = contract(R543, (1, 4), R21, (4, 6), (1, 6))

# Each diagonal block of R must agree with the dense matrix product
# P5 * P4 * P3 * P2 * P1 of the corresponding propagator blocks
for b in eachindex(first(P_block_dims))
    bl = Block(b, b)
    # Dense matrix copy of block `bl` of tensor T
    as_matrix(T) = convert(Matrix{ComplexF64}, T[bl])
    R_mat = as_matrix(R)
    R_mat_ref = as_matrix(P[5]) *
                as_matrix(P[4]) *
                as_matrix(P[3]) *
                as_matrix(P[2]) *
                as_matrix(P[1])
    @assert isapprox(R_mat, R_mat_ref)
end

In [6]:
# Interaction lines
N_int = 6  # number of pair interactions
# Diagonal tensor made of N_int blocks (b, b), each of size 1×1
Δ = BlockSparseTensor{ComplexF64}(
    [(b, b) for b in 1:N_int],
    fill(1, N_int),
    fill(1, N_int),
)

Dim 1: [1, 1, 1, 1, 1, 1]
Dim 2: [1, 1, 1, 1, 1, 1]
NDTensors.BlockSparse{ComplexF64, Vector{ComplexF64}, 2}
 6×6
Block(1, 1)
 [1:1, 1:1]
 0.0 + 0.0im

Block(2, 2)
 [2:2, 2:2]
 0.0 + 0.0im

Block(3, 3)
 [3:3, 3:3]
 0.0 + 0.0im

Block(4, 4)
 [4:4, 4:4]
 0.0 + 0.0im

Block(5, 5)
 [5:5, 5:5]
 0.0 + 0.0im

Block(6, 6)
 [6:6, 6:6]
 0.0 + 0.0im

In [7]:
# Interaction vertices: 3 dimensions
# There is always exactly one block along the 1st dimension (interaction index)
O_block_dims = ([N_int], P_block_dims[1], P_block_dims[2])
# Stored blocks (1, a, b) for every off-diagonal pair a != b with a, b in 1:3;
# the comprehension yields them already in sorted order
O_blocks = [(1, a, b) for a in 1:3 for b in 1:3 if a != b]

O = BlockSparseTensor{ComplexF64}(O_blocks, O_block_dims...);

In [8]:
using ITensors.NDTensors: inds,
                          nzblocks,
                          contract_inds,
                          contract_labels,
                          contract_blocks,
                          are_blocks_contracted,
                          ValLength

"""
Cost of a single tensor pair contraction.
"""
function contraction_cost(T1::BlockSparseTensor,
                          labels1,
                          T2::BlockSparseTensor,
                          labels2,
                          labelsR)
    # Returns the tuple `(cost, blocksR)`:
    #   cost    - sum, over every pair of non-zero blocks of T1 and T2 that
    #             `are_blocks_contracted` accepts, of
    #             length(T1 block) * prod(sizes of the T2 block along the T2
    #             dimensions whose entry in `labels2_to_labelsR` is non-zero).
    #   blocksR - the result block computed by `contract_blocks` for each accepted
    #             pair, in visiting order (T1 blocks in the outer loop). A result
    #             block reached by several pairs is pushed once per pair.
    # `labels1`, `labels2` and `labelsR` are the index labels of T1, T2 and of the
    # contraction result, as passed to `contract`.
    labels1_to_labels2, labels1_to_labelsR, labels2_to_labelsR =
        contract_labels(labels1, labels2, labelsR)

    # Dimensions of T2 with a non-zero entry in `labels2_to_labelsR`; this does not
    # depend on the blocks, so compute it once.
    kept_dims2 = [a2 for (a2, aR) in enumerate(labels2_to_labelsR) if aR != 0]

    blocksR = Block[]
    cost = 0
    for block1 in nzblocks(T1)
        # Number of elements in the T1 block, shared by all T2 partners
        len1 = length(T1[block1])
        for block2 in nzblocks(T2)
            are_blocks_contracted(block1, block2, labels1_to_labels2) || continue
            blockview2 = T2[block2]
            # `init=1` keeps the product well-defined (equal to 1) when `kept_dims2`
            # is empty, i.e. when no dimension of T2 is selected by the filter above.
            cost += len1 * prod(a2 -> size(blockview2, a2), kept_dims2; init=1)
            push!(blocksR,
                  contract_blocks(block1, labels1_to_labelsR,
                                  block2, labels2_to_labelsR,
                                  ValLength(labelsR)))
        end
    end
    return cost, blocksR
end

# Self-test of contraction_cost() on two 5-dimensional tensors,
# each with three non-zero blocks
T1 = BlockSparseTensor{ComplexF64}(
    [(1, 1, 2, 3, 1), (2, 2, 3, 2, 1), (3, 2, 1, 2, 1)],
    [2, 3, 4], [3, 2, 3], [4, 5, 2], [5, 1, 3], [6, 2, 7],
);
T2 = BlockSparseTensor{ComplexF64}(
    [(1, 1, 1, 2, 2), (2, 3, 1, 2, 1), (1, 2, 3, 1, 1)],
    [10, 2, 3], [5, 1, 3], [2, 3, 4], [7, 4, 1], [8, 2, 2],
);

# Labels 1 and 4 appear in both label sets; the remaining labels make up labelsR
labels1 = (1, 2, 3, 4, 5)
labels2 = (6, 4, 1, 9, 10)
labelsR = (2, 3, 5, 6, 10, 9)

R = contract(T1, labels1, T2, labels2, labelsR)
# Block structure of the contraction result
@show contract_inds(inds(T1), labels1, inds(T2), labels2, labelsR)

cost, blocksR = contraction_cost(T1, labels1, T2, labels2, labelsR)

# Expected cost: one term per contributing block pair, each being
# (number of elements of the T1 block) * (extents of the T2 block along labels 6, 9, 10)
@assert cost == (2*3*5*3*6)*(2*4*8) + # Block(1, 1, 2, 3, 1) ∘ Block(2, 3, 1, 2, 1)
                (4*2*4*1*6)*(10*7*8)  # Block(3, 2, 1, 2, 1) ∘ Block(1, 2, 3, 1, 1)
# The predicted result blocks must coincide with those of the actual contraction
@assert blocksR == nzblocks(R)

contract_inds(inds(T1), labels1, inds(T2), labels2, labelsR) = ([3, 2, 3], [4, 5, 2], [6, 2, 7], [10, 2, 3], [8, 2, 2], [7, 4, 1])


In [10]:
# Cost and result blocks of contracting the vertex tensor O (labels 1, 2, 3) with the
# propagator P[5] (labels 4, 3) over their shared label 3; the result carries labels 1, 2, 4
contraction_cost(O, (1, 2, 3), P[5], (4, 3), (1, 2, 4))

(972, Block[Block(1, 1, 2), Block(1, 1, 3), Block(1, 2, 1), Block(1, 2, 3), Block(1, 3, 1), Block(1, 3, 2)])

In [11]:
# TODO: List all edges of the chain diagram
# TODO: Generate all contraction paths
# TODO: define cost() for a contraction path