In [1]:
using PCT

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling PCT [ef708a43-c8a3-43f4-8f65-1f04ee4c5bb0]


Define the function in a functional DSL.

In [2]:
# Parse the anonymous function x ↦ ∑ᵢ x(i) into the PCT DSL.
# `@pct` returns the parsed function together with a context; the context is
# not needed for this example, so it is discarded.
f, _ = @pct (x::RV) -> ∑(i, x(i))
f

(x) -> 
    ∑((i), x(i))

Decompose into primitives through *fibration*.

In [3]:
# Break `f` into its chain of primitive operations (the fibration shown below).
cf = f |> decompose

x: ∑i ◀ {i ⇥ x: ℳ x ◀ ->i}


Get the (code for the) pullback through *partial chain*.

In [4]:
# Generate the pullback code from the decomposed chain.
df = cf |> pp

(_z_0, _k_0) -> 
    (_i_0) -> 
        ∑((i), δ(i, _i_0, _k_0))

Get the gradient by setting $k_0 = 1$.

In [5]:
# Run `eval_all`, then `propagate_k`, which substitutes 1 for the cotangent `_k_0`
# and leaves the gradient as a function of the index `_i_0`.
df = propagate_k(eval_all(df))

(_i_0) -> 
    ∑((i), δ(i, _i_0, 1))

Simplify the gradient using compiler-style optimization on an equivalence graph (e-graph).

In [6]:
# `simplify` prints its timing and rule trace; its first element is the
# simplified expression, which is displayed as the cell result.
first(simplify(df))

  0.013353 seconds (25.59 k allocations: 1.682 MiB, 99.66% compilation time)
1 contract_delta
∑((i), δ(i, _i_0, 1))
1
-->

  0.000000 seconds


(_i_0) -> 
    1

## Example: Hartree–Fock

Variables have types that encode the *dimension* and the *symmetries* of the tensor.

Constants are introduced as variables in the outer scope.

In [7]:
# Declare a tensor space `T` for the two-electron integrals: four `I` index
# arguments mapping to `C` values, plus two index-permutation symmetries.
# NOTE(review): `:conj` presumably means the permuted entry equals the complex
# conjugate of the original and `:id` that it is equal as-is — confirm in PCT.
# The lambda parameter `J::T` introduces `J` as a constant recorded in `ctx`;
# the body `_` is only a placeholder.
f, ctx = @pct begin    
    @space T begin
        type = (I, I, I, I) -> C
        symmetries = (((2, 1, 4, 3), :conj), ((3, 4, 1, 2), :id))
    end
    (J::T) -> _
end
f

(J) -> 
    _

The electron repulsion integral (ERI) contraction, $\sum_{ijpqrs} C_{pi}^{*} C_{qi} C_{rj}^{*} C_{sj} J_{pqrs}$.

In [8]:
# Build the ERI contraction ∑ C(p,i)' C(q,i) C(r,j)' C(s,j) J(p,q,r,s) as a
# function of the coefficient matrix `C`. Passing `f ctx` to `@pct` presumably
# makes the constant `J` from the previous declaration available; `fc` is then
# applied to the parsed function.
g = fc(@pct f ctx (C::CM) -> sum((i, j, p, q, r, s),
    C(p, i)' * C(q, i) * C(r, j)' * C(s, j) * J(p, q, r, s))) 

(C) -> 
    ∑((i, j, p, q, r, s), (J(p, q, r, s)⋅C(s, j)⋅C(q, i)⋅C(p, i)'⋅C(r, j)'))

Decompose into the fibers.  

In [9]:
# Split the ERI contraction into its fibers (primitive operations).
cg = g |> decompose

C: ∑ijpqrs ◀ {i, j, p, q, r, s ⇥ C: *(J(p, q, r, s)⋅C(q, i)⋅C(p, i)'⋅C(r, j)') ◀ ℳ C ◀ C: ->s|C: ->j}


In [10]:
# Generate the raw pullback code for the decomposed contraction.
dg = cg |> pp

(_z_0, _k_0) -> 
    (_i_2, _i_7) -> 
        ∑((i, j, p, q, r, s), (δ(j, _i_7, δ(s, _i_2, (C(p, i)⋅C(r, j)⋅J(p, q, r, s)'⋅C(q, i)'⋅_k_0)))+∑((_i_3), (δ(i, _i_7, δ(q, _i_2, (δ(C(s, j), _i_3, _k_0)⋅C(p, i)⋅C(r, j)⋅J(p, q, r, s)'⋅_i_3')))+∑((_i_0), (δ(i, _i_7, δ(p, _i_2, (δ(_i_0, C(q, i), δ(_i_3, C(s, j), _k_0'))⋅J(p, q, r, s)⋅C(r, j)'⋅_i_0⋅_i_3)))+∑((_i_4), δ(j, _i_7, δ(r, _i_2, (δ(_i_4, C(p, i)', δ(_i_0, C(q, i), δ(_i_3, C(s, j), _k_0')))⋅J(p, q, r, s)⋅_i_0⋅_i_3⋅_i_4))))))))))

The raw pullback above is unwieldy; it really needs simplification.

Get the gradient by setting $k_0 = 1$.

In [11]:
# Run `eval_all`, then `propagate_k` to replace the cotangent `_k_0` by 1.
dg = propagate_k(eval_all(dg))

(_i_2, _i_7) -> 
    ∑((i, j, p, q, r, s), (δ(j, _i_7, δ(s, _i_2, (C(p, i)⋅C(r, j)⋅J(p, q, r, s)'⋅C(q, i)')))+∑((_i_3), (δ(i, _i_7, δ(q, _i_2, (δ(C(s, j), _i_3, 1)⋅C(p, i)⋅C(r, j)⋅J(p, q, r, s)'⋅_i_3')))+∑((_i_0), (δ(i, _i_7, δ(p, _i_2, (δ(_i_0, C(q, i), δ(_i_3, C(s, j), 1))⋅J(p, q, r, s)⋅C(r, j)'⋅_i_0⋅_i_3)))+∑((_i_4), δ(j, _i_7, δ(r, _i_2, (δ(_i_4, C(p, i)', δ(_i_0, C(q, i), δ(_i_3, C(s, j), 1)))⋅J(p, q, r, s)⋅_i_0⋅_i_3⋅_i_4))))))))))

In [12]:
# Apply the default simplification rules, keep the simplified expression,
# and display it.
dg = first(simplify(dg))
dg

  0.334155 seconds (736.09 k allocations: 39.812 MiB, 77.02% compilation time)
1 sum_dist
∑((i, j, p, q, r, s), (δ(j, _i_7, δ(s, _i_2, (C(p, i)⋅C(r, j)⋅J(p, q, r, s)'⋅C(q, i)')))+∑((_i_3), (δ(i, _i_7, δ(q, _i_2, (δ(C(s, j), _i_3, 1)⋅C(p, i)⋅C(r, j)⋅J(p, q, r, s)'⋅_i_3')))+∑((_i_0), (δ(i, _i_7, δ(p, _i_2, (δ(_i_0, C(q, i), δ(_i_3, C(s, j), 1))⋅J(p, q, r, s)⋅C(r, j)'⋅_i_0⋅_i_3)))+∑((_i_4), δ(j, _i_7, δ(r, _i_2, (δ(_i_4, C(p, i)', δ(_i_0, C(q, i), δ(_i_3, C(s, j), 1)))⋅J(p, q, r, s)⋅_i_0⋅_i_3⋅_i_4))))))))))
(∑((i, j, p, q, r, s), δ(j, _i_7, δ(s, _i_2, (C(p, i)⋅C(r, j)⋅J(p, q, r, s)'⋅C(q, i)'))))+∑((i, j, p, q, r, s, _i_3), (δ(i, _i_7, δ(q, _i_2, (δ(C(s, j), _i_3, 1)⋅C(p, i)⋅C(r, j)⋅J(p, q, r, s)'⋅_i_3')))+∑((_i_0), (δ(i, _i_7, δ(p, _i_2, (δ(_i_0, C(q, i), δ(_i_3, C(s, j), 1))⋅J(p, q, r, s)⋅C(r, j)'⋅_i_0⋅_i_3)))+∑((_i_4), δ(j, _i_7, δ(r, _i_2, (δ(_i_4, C(p, i)', δ(_i_0, C(q, i), δ(_i_3, C(s, j), 1)))⋅J(p, q, r, s)⋅_i_0⋅_i_3⋅_i_4)))))))))
-->

  0.034677 seconds (195.48 k allocations: 8.229

(_i_2, _i_7) -> 
    (∑((i, p, q, s), (J(p, q, _i_2, s)⋅C(s, _i_7)⋅C(q, i)⋅C(p, i)'))+∑((i, p, q, r), (C(p, i)⋅C(r, _i_7)⋅J(p, q, r, _i_2)'⋅C(q, i)'))+∑((j, q, r, s), (J(_i_2, q, r, s)⋅C(s, j)⋅C(q, _i_7)⋅C(r, j)'))+∑((j, p, r, s), (C(p, _i_7)⋅C(r, j)⋅J(p, _i_2, r, s)'⋅C(s, j)')))

In [13]:
# Second pass with `symmetry_settings`, so the declared symmetries of `J`
# can merge equivalent terms; then display the result.
dg = first(simplify(dg; settings=symmetry_settings))
dg

  0.000009 seconds (25 allocations: 976 bytes)
1 symmetry
(∑((i, p, q, s), (J(p, q, _i_2, s)⋅C(s, _i_7)⋅C(q, i)⋅C(p, i)'))+∑((i, p, q, r), (C(p, i)⋅C(r, _i_7)⋅J(p, q, r, _i_2)'⋅C(q, i)'))+∑((j, q, r, s), (J(_i_2, q, r, s)⋅C(s, j)⋅C(q, _i_7)⋅C(r, j)'))+∑((j, p, r, s), (C(p, _i_7)⋅C(r, j)⋅J(p, _i_2, r, s)'⋅C(s, j)')))
((2.0⋅∑((i, p, q, r), (C(p, i)⋅C(r, _i_7)⋅J(p, q, r, _i_2)'⋅C(q, i)')))+∑((j, q, r, s), (J(_i_2, q, r, s)⋅C(s, j)⋅C(q, _i_7)⋅C(r, j)'))+∑((j, p, r, s), (C(p, _i_7)⋅C(r, j)⋅J(p, _i_2, r, s)'⋅C(s, j)')))
<->

  0.011593 seconds (66.60 k allocations: 3.409 MiB, 76.81% compilation time)
2 sum_in
((2.0⋅∑((i, p, q, r), (C(p, i)⋅C(r, _i_7)⋅J(p, q, r, _i_2)'⋅C(q, i)')))+∑((j, q, r, s), (J(_i_2, q, r, s)⋅C(s, j)⋅C(q, _i_7)⋅C(r, j)'))+∑((j, p, r, s), (C(p, _i_7)⋅C(r, j)⋅J(p, _i_2, r, s)'⋅C(s, j)')))
(∑((i, p, q, r), (C(p, i)⋅C(r, _i_7)⋅J(p, q, r, _i_2)'⋅C(q, i)'⋅2.0))+∑((j, q, r, s), (J(_i_2, q, r, s)⋅C(s, j)⋅C(q, _i_7)⋅C(r, j)'))+∑((j, p, r, s), (C(p, _i_7)⋅C(r, j)⋅J(p, _i_2, r, s)'

(_i_2, _i_7) -> 
    ∑((i, p, q, r), (J(_i_2, r, q, p)⋅C(p, i)⋅C(r, _i_7)⋅C(q, i)'⋅4.0))

## Example: Maximally Localized Wannier Functions

A number of assumptions are needed for simplification.

- Periodic domain for the $k$ points.
- Symmetric domain for the $b$ vectors, which should also not be contracted.
- The $M_{mn}$ matrices have a conjugation symmetry, $M_{pq}(k, k') = \overline{M_{qp}(k', k)}$.
- The weights $w$ are symmetric, $w(-b) = w(b)$.

In [187]:
# Declare the index domains, tensor spaces and constants for the MLWF example.
# The lambda parameters `S::Mmn` and `w::Sym` introduce `S` (Resta's
# polarization matrix) and `w` (the weights) as constants recorded in `ctx`;
# the body `_` is only a placeholder.
f, ctx = @pct begin

    # Domain of the k points: -N … N-1, periodic.
    @domain P begin
        base = I
        lower = -N
        upper = N - 1
        periodic = true
    end

    # Domain of the b vectors: symmetric range -N … N, marked non-contractable.
    @domain Q begin
        base = I
        lower = -N
        upper = N
        contractable = false
    end

    # Space of S(p, q, k, k'); NOTE(review): `:conj` under the (2, 1, 4, 3)
    # permutation presumably encodes S(p, q, k, k') == conj(S(q, p, k', k)).
    @space Mmn begin
        type = (I, I, I, I) -> C
        symmetries = (((2, 1, 4, 3), :conj),)
    end

    # One-index space for w; `:ineg` presumably means w(-b) == w(b) — confirm.
    @space Sym begin
        type = (I,) -> C
        symmetries = (((1,), :ineg),)
    end

    # Space of the gauge U(p, n, k).
    @space Gauge begin
        type = (I, I, I) -> C
    end

    # Space of ρ(n, b).
    @space Density begin
        type = (I, I) -> C
    end

    (S::Mmn, w::Sym) -> _
end
f

(S, w) -> 
    _

Maximize $\sum_{nb}|\hat{\rho}_n(\mathrm{b})|^2$ 

In a Γ-point calculation this quantity is known as $|X_n|^2 + |Y_n|^2 + |Z_n|^2$.

$U$ is the gauge and $S$ is Resta's polarization matrix.

In [188]:
# Objective ∑_{n,b} ρ(n, b)' ρ(n, b) as a function of the gauge `U`, where
# ρ(n, b) = ∑_{k,p,q} U(p, n, k)' S(p, q, k, k + b) U(q, n, k + b).
# The lambda `(ρ::Density) -> …` is applied immediately to the lambda
# `(n, b) -> …` that defines ρ; `fc` is then applied to the parsed function.
g = fc(@pct f ctx (U::Gauge) -> ((ρ::Density) -> sum((n::I, b::Q), ρ(n, b)' * ρ(n, b)))(
    (n::I, b::Q) -> sum((k::P, p, q), U(p, n, k)' * S(p, q, k, k + b) * U(q, n, k + b))))

(U) -> 
    ((ρ) -> 
        ∑((n, b), (ρ(n, b)'⋅ρ(n, b))))((n, b) -> 
        ∑((k, p, q), (U(p, n, k)'⋅U(q, n, (b+k))⋅S(p, q, k, (b+k)))))

Decompose into the primitives.

In [189]:
# Evaluate the immediately-applied ρ lambda first, then split into primitives.
cg = g |> eval_all |> decompose

U: ∑nb ◀ {n, b ⇥ U: *∑((k, p, q), (U(p, n, k)⋅U(q, n, (b+k))'⋅S(p, q, k, (b+k))')) ◀ ∑kpq ◀ {k, p, q ⇥ U: *(U(q, n, (b+k))⋅S(p, q, k, (b+k))) ◀ † ◀ ℳ U ◀ U: ->p|U: ->n|U: ->k}}


Calculate the pullback.

In [190]:
# Generate the raw pullback code for the MLWF objective.
dg = cg |> pp

(_z_0, _k_0) -> 
    (_i_1, _i_2, _i_7) -> 
        ∑((n, b), (∑((k, p, q), (δ(k, _i_7, δ(n, _i_2, δ(p, _i_1, (_k_0'⋅∑((k, p, q), (U(q, n, (b+k))'⋅S(p, q, k, (b+k))'⋅U(p, n, k)))⋅U(q, n, (b+k))⋅S(p, q, k, (b+k))))))+∑((_i_5), δ((b+k), _i_7, δ(n, _i_2, δ(q, _i_1, (δ(U(p, n, k)', _i_5, (_k_0⋅∑((k, p, q), (U(p, n, k)'⋅U(q, n, (b+k))⋅S(p, q, k, (b+k))))))⋅_i_5'⋅S(p, q, k, (b+k))')))))))+∑((_i_0, k, p, q), (δ(k, _i_7, δ(n, _i_2, δ(p, _i_1, (δ(∑((k, p, q), (U(p, n, k)'⋅U(q, n, (b+k))⋅S(p, q, k, (b+k)))), _i_0, _k_0)⋅_i_0'⋅U(q, n, (b+k))⋅S(p, q, k, (b+k))))))+∑((_i_5), δ((b+k), _i_7, δ(n, _i_2, δ(q, _i_1, (_i_5⋅δ(_i_5, U(p, n, k), (_i_0⋅δ(_i_0, ∑((k, p, q), (U(p, n, k)'⋅U(q, n, (b+k))⋅S(p, q, k, (b+k)))), _k_0')))⋅S(p, q, k, (b+k))')))))))))

Get the gradient.

In [191]:
# Replace the cotangent `_k_0` by 1 to obtain the gradient.
dg = propagate_k(dg)

(_i_1, _i_2, _i_7) -> 
    ∑((n, b), (∑((k, p, q), (δ(k, _i_7, δ(n, _i_2, δ(p, _i_1, (∑((k, p, q), (U(q, n, (b+k))'⋅S(p, q, k, (b+k))'⋅U(p, n, k)))⋅U(q, n, (b+k))⋅S(p, q, k, (b+k))))))+∑((_i_5), δ((b+k), _i_7, δ(n, _i_2, δ(q, _i_1, (δ(U(p, n, k)', _i_5, ∑((k, p, q), (U(p, n, k)'⋅U(q, n, (b+k))⋅S(p, q, k, (b+k)))))⋅_i_5'⋅S(p, q, k, (b+k))')))))))+∑((_i_0, k, p, q), (δ(k, _i_7, δ(n, _i_2, δ(p, _i_1, (δ(∑((k, p, q), (U(p, n, k)'⋅U(q, n, (b+k))⋅S(p, q, k, (b+k)))), _i_0, 1)⋅_i_0'⋅U(q, n, (b+k))⋅S(p, q, k, (b+k))))))+∑((_i_5), δ((b+k), _i_7, δ(n, _i_2, δ(q, _i_1, (_i_5⋅δ(_i_5, U(p, n, k), (_i_0⋅δ(_i_0, ∑((k, p, q), (U(p, n, k)'⋅U(q, n, (b+k))⋅S(p, q, k, (b+k)))), 1)))⋅S(p, q, k, (b+k))')))))))))

In [192]:
# Apply the default simplification rules and keep the simplified expression.
dg = first(simplify(dg))

  0.042001 seconds (225.52 k allocations: 8.873 MiB, 21.33% gc time)
1 sum_dist
∑((n, b), (∑((k, p, q), (δ(k, _i_7, δ(n, _i_2, δ(p, _i_1, (∑((k, p, q), (U(q, n, (b+k))'⋅S(p, q, k, (b+k))'⋅U(p, n, k)))⋅U(q, n, (b+k))⋅S(p, q, k, (b+k))))))+∑((_i_5), δ((b+k), _i_7, δ(n, _i_2, δ(q, _i_1, (δ(U(p, n, k)', _i_5, ∑((k, p, q), (U(p, n, k)'⋅U(q, n, (b+k))⋅S(p, q, k, (b+k)))))⋅_i_5'⋅S(p, q, k, (b+k))')))))))+∑((_i_0, k, p, q), (δ(k, _i_7, δ(n, _i_2, δ(p, _i_1, (δ(∑((k, p, q), (U(p, n, k)'⋅U(q, n, (b+k))⋅S(p, q, k, (b+k)))), _i_0, 1)⋅_i_0'⋅U(q, n, (b+k))⋅S(p, q, k, (b+k))))))+∑((_i_5), δ((b+k), _i_7, δ(n, _i_2, δ(q, _i_1, (_i_5⋅δ(_i_5, U(p, n, k), (_i_0⋅δ(_i_0, ∑((k, p, q), (U(p, n, k)'⋅U(q, n, (b+k))⋅S(p, q, k, (b+k)))), 1)))⋅S(p, q, k, (b+k))')))))))))
(∑((n, b, k, p, q), (δ(k, _i_7, δ(n, _i_2, δ(p, _i_1, (∑((k, p, q), (U(q, n, (b+k))'⋅S(p, q, k, (b+k))'⋅U(p, n, k)))⋅U(q, n, (b+k))⋅S(p, q, k, (b+k))))))+∑((_i_5), δ((b+k), _i_7, δ(n, _i_2, δ(q, _i_1, (δ(U(p, n, k)', _i_5, ∑((k, p, q), (U(p, n, k)

(_i_1, _i_2, _i_7) -> 
    (∑((b, q, k, p, _i_0), (U(_i_0, _i_2, (b+k))'⋅S(p, _i_0, k, (b+k))'⋅2.0⋅U(q, _i_2, (_i_7+b))⋅U(p, _i_2, k)⋅S(_i_1, q, _i_7, (_i_7+b))))+∑((b, p, k, _i_0, q), (U(_i_0, _i_2, k)'⋅S(p, _i_1, ((b⋅-1.0)+_i_7), _i_7)'⋅2.0⋅U(q, _i_2, (b+k))⋅U(p, _i_2, ((b⋅-1.0)+_i_7))⋅S(_i_0, q, k, (b+k)))))

In [193]:
# Second pass with `symmetry_settings` so the declared symmetries (periodic k,
# symmetric b, conjugation symmetry of S) can merge the two sums.
dg = first(simplify(dg; settings=symmetry_settings))

  0.005332 seconds (83.50 k allocations: 2.405 MiB)
1 shift
(∑((b, q, k, p, _i_0), (U(_i_0, _i_2, (b+k))'⋅S(p, _i_0, k, (b+k))'⋅2.0⋅U(q, _i_2, (_i_7+b))⋅U(p, _i_2, k)⋅S(_i_1, q, _i_7, (_i_7+b))))+∑((b, p, k, _i_0, q), (U(_i_0, _i_2, k)'⋅S(p, _i_1, ((b⋅-1.0)+_i_7), _i_7)'⋅2.0⋅U(q, _i_2, (b+k))⋅U(p, _i_2, ((b⋅-1.0)+_i_7))⋅S(_i_0, q, k, (b+k)))))
(∑((b, q, k, p, _i_0), (U(_i_0, _i_2, k)'⋅S(p, _i_0, ((b⋅-1.0)+k), k)'⋅2.0⋅U(q, _i_2, (_i_7+b))⋅U(p, _i_2, ((b⋅-1.0)+k))⋅S(_i_1, q, _i_7, (_i_7+b))))+∑((b, p, k, _i_0, q), (U(_i_0, _i_2, k)'⋅S(p, _i_1, ((b⋅-1.0)+_i_7), _i_7)'⋅2.0⋅U(q, _i_2, (b+k))⋅U(p, _i_2, ((b⋅-1.0)+_i_7))⋅S(_i_0, q, k, (b+k)))))
<->

  0.012391 seconds (172.37 k allocations: 5.046 MiB)
2 sum_sym
(∑((b, q, k, p, _i_0), (U(_i_0, _i_2, k)'⋅S(p, _i_0, ((b⋅-1.0)+k), k)'⋅2.0⋅U(q, _i_2, (_i_7+b))⋅U(p, _i_2, ((b⋅-1.0)+k))⋅S(_i_1, q, _i_7, (_i_7+b))))+∑((b, p, k, _i_0, q), (U(_i_0, _i_2, k)'⋅S(p, _i_1, ((b⋅-1.0)+_i_7), _i_7)'⋅2.0⋅U(q, _i_2, (b+k))⋅U(p, _i_2, ((b⋅-1.0)+_i_7))⋅S(_i_0, q,

(_i_1, _i_2, _i_7) -> 
    ∑((b, p, k, _i_0, q), (U(_i_0, _i_2, (b+k))'⋅4.0⋅U(q, _i_2, k)⋅U(p, _i_2, (_i_7+b))⋅S(_i_0, q, (b+k), k)⋅S(_i_1, p, _i_7, (_i_7+b))))