# Fast tensor contraction

## Block-sparse tensor construction

In [1]:
using ITensors: BlockSparseTensor, Block, contract, nzblocks

In [2]:
# A rank-3 block-sparse tensor. Each axis is partitioned into blocks of the
# listed sizes:
#   axis 1 -> 3 blocks of sizes 2, 2, 3 (total extent 7)
#   axis 2 -> 2 blocks of sizes 4, 3    (total extent 7)
#   axis 3 -> 2 blocks of sizes 3, 4    (total extent 7)
bst_dims = ([2, 2, 3], [4, 3], [3, 4])

# Block multi-indices of the only two blocks that are stored; no other block
# is allocated.
bst_blocks = [(1, 1, 1), (3, 2, 2)]

# Allocate the stored blocks (zero-initialized) with ComplexF64 elements,
# passing the per-axis block sizes one axis at a time.
bst = BlockSparseTensor{ComplexF64}(bst_blocks, bst_dims[1], bst_dims[2], bst_dims[3])

Dim 1: [2, 2, 3]
Dim 2: [4, 3]
Dim 3: [3, 4]
NDTensors.BlockSparse{ComplexF64, Vector{ComplexF64}, 3}
 7×7×7
Block(1, 1, 1)
 [1:2, 1:4, 1:3]
[:, :, 1] =
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im

[:, :, 2] =
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im

[:, :, 3] =
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im

Block(3, 2, 2)
 [5:7, 5:7, 4:7]
[:, :, 1] =
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im

[:, :, 2] =
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im

[:, :, 3] =
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im

[:, :, 4] =
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.0im
 0.0 + 0.0im  0.0 + 0.0im  0.0 + 0.

## Access to individual blocks

In [3]:
# Assigning a scalar to a block sets every element of that block: all entries of
# Block(1, 1, 1) become 2.0 and all entries of Block(3, 2, 2) become 3.0.
bst[Block(1, 1, 1)] = 2.0;
bst[Block(3, 2, 2)] = 3.0;

# `@show` already prints the tensor; the trailing semicolon stops the notebook
# from displaying the returned value a second time.
@show bst;

bst = Dim 1: [2, 2, 3]
Dim 2: [4, 3]
Dim 3: [3, 4]
NDTensors.BlockSparse{ComplexF64, Vector{ComplexF64}, 3}
 7×7×7
Block(1, 1, 1)
 [1:2, 1:4, 1:3]
[:, :, 1] =
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im

[:, :, 2] =
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im

[:, :, 3] =
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im

Block(3, 2, 2)
 [5:7, 5:7, 4:7]
[:, :, 1] =
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im

[:, :, 2] =
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im

[:, :, 3] =
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im

[:, :, 4] =
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.

Dim 1: [2, 2, 3]
Dim 2: [4, 3]
Dim 3: [3, 4]
NDTensors.BlockSparse{ComplexF64, Vector{ComplexF64}, 3}
 7×7×7
Block(1, 1, 1)
 [1:2, 1:4, 1:3]
[:, :, 1] =
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im

[:, :, 2] =
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im

[:, :, 3] =
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im
 2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im  2.0 + 0.0im

Block(3, 2, 2)
 [5:7, 5:7, 4:7]
[:, :, 1] =
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im

[:, :, 2] =
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im

[:, :, 3] =
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im

[:, :, 4] =
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.0im
 3.0 + 0.0im  3.0 + 0.0im  3.0 + 0.

## Contraction of a chain tensor network ($n = 3$)

In [4]:
# Propagators are rank-2 tensors. Both axes use the same partition into three
# blocks of sizes 2, 3 and 4, so every diagonal block is square.
P_block_dims = ([2, 3, 4], [2, 3, 4])
# Total extent of each axis.
P_dims = map(sum, P_block_dims)
# Only the diagonal blocks (b, b) are stored.
P_blocks = [(b, b) for b in eachindex(P_block_dims[1])]

# Five propagators P[1] … P[5], all sharing this block structure.
P = map(_ -> BlockSparseTensor{ComplexF64}(P_blocks, P_block_dims...), 1:5);

In [5]:
# Test contraction of propagators alone: contract the chain block-sparsely and
# compare every diagonal block of the result with the product of dense blocks.

using Random: MersenneTwister, rand
rng = MersenneTwister(12345678)  # fixed seed so the test is reproducible

# Fill every stored (diagonal) block of every propagator with random entries.
for i in eachindex(P)
    for b in eachindex(first(P_block_dims))
        P[i][Block(b, b)] = rand(rng, P_block_dims[1][b], P_block_dims[2][b])
    end
end

# Labels are attached to propagators as follows:
# P5_{1,2} P4_{2,3} P3_{3,4} P2_{4,5} P1_{5,6}

# Test contraction order: (P5 * (P4 * P3)) * (P2 * P1)
R43 = contract(P[4], (2, 3), P[3], (3, 4), (2, 4));
R543 = contract(P[5], (1, 2), R43, (2, 4), (1, 4));
R21 = contract(P[2], (4, 5), P[1], (5, 6), (4, 6));
R = contract(R543, (1, 4), R21, (4, 6), (1, 6))

# Each diagonal block of R must equal the dense product P5 * P4 * P3 * P2 * P1 of
# the corresponding blocks. An explicit error is used instead of `@assert` (which
# may be disabled) and reports which block failed.
for b in eachindex(first(P_block_dims))
    bl = Block(b, b)
    R_mat = convert(Matrix{ComplexF64}, R[bl])
    R_mat_ref = prod(convert(Matrix{ComplexF64}, P[k][bl]) for k in 5:-1:1)
    isapprox(R_mat, R_mat_ref) ||
        error("block-sparse chain contraction disagrees with dense product in $bl")
end

In [63]:
# Interaction lines: a rank-1 tensor consisting of a single block of size N_int.
N_int = 6 # Number of pair interactions
Δ = BlockSparseTensor{ComplexF64}([(1,)], [N_int])

# Interaction vertices are rank-3 tensors. Axis 1 (the interaction index) always
# has exactly one block of size N_int; axes 2 and 3 reuse the propagator blocks.
O_block_dims = ([N_int], P_block_dims...)
# Stored blocks: (1, i, j) for every ordered pair of distinct propagator blocks
# i ≠ j among blocks 1:3, produced directly in lexicographic (sorted) order.
O_blocks = [(1, i, j) for i in 1:3 for j in 1:3 if i != j]

O = BlockSparseTensor{ComplexF64}(O_blocks, O_block_dims...);

In [54]:
"""
Cost of a single edge contraction.
"""
function edge_cost(T1::BlockSparseTensor, axis1::Int, T2::BlockSparseTensor, axis2)
    cost = 0
    # TODO: Optimize assuming that T1 and T2 are constructed from sorted lists of Blocks
    for bl1 in nzblocks(T1)
        for bl2 in nzblocks(T2)
            if bl1[axis1] == bl2[axis2]
                @assert size(T1[bl1], axis1) == size(T2[bl2], axis2)
                cost += length(T1[bl1]) * 
                        prod(d -> (d == axis2) ? 1 : size(T2[bl2], d), 1:length(bl2))
            end
        end
    end
    return cost
end

edge_cost(O, 3, P[5], 1)

972

In [55]:
# TODO: List all edges of the chain diagram
# TODO: Generate all contraction paths
# TODO: define cost() for a contraction path