
shared_intermediates and contraction trees #104

Closed
Bonnevie opened this issue Sep 4, 2019 · 12 comments
@Bonnevie

Bonnevie commented Sep 4, 2019

The shared_intermediates example in the docs calculates multiple marginals, which is a use case of interest to me. Calculating marginals is equivalent to calculating environment tensors, so I'm interested in using the approach from https://arxiv.org/abs/1310.8023, which, through extensive memoization, essentially guarantees that calculating all marginals costs no more than 3x the cost of calculating a single marginal. This is ensured by keeping every contraction path within the same contraction tree family, so that each tree node corresponds to at most 3 different contractions, all of which can be memoized. This function is called multienv() in various tensor network packages.

Here is a rough but working prototype of a ContractionTree class that can generate the various contraction paths by calls to get_environment_path.

Contraction tree prototype
class ContractionTree:
    """Interface for finding contraction paths for each possible environment tensor with high cache reuse.

    Based on "Improving the efficiency of variational tensor network algorithms" by Evenbly and Pfeifer. Returned
    environment paths stay in same contraction tree family.
    """

    def __init__(self, tensors, optimize='greedy'):
        operands = chain.from_iterable([(tensor.tensor, tensor.axis_names) for tensor in tensors])
        self.partition_path, _ = contract_path(*operands, (), optimize=optimize)
        self.contraction_tree = nx.DiGraph()
        for index, tensor in enumerate(tensors):
            self.contraction_tree.add_node(tensor, index=index)
        self.tensors = tensors
        self.cache = HitDict()

        active_nodes = [tensor for tensor in tensors]
        for path_step in self.partition_path:
            children = []
            for index in sorted(path_step, reverse=True):
                children.append(active_nodes.pop(index))
            active_nodes.append(frozenset(children))
            for child in children:
                self.contraction_tree.add_edge(child, active_nodes[-1])

    def get_environment_path(self, marginal_tensor):
        """Returns opt_einsum path for computing the environment tensor of `marginal_tensor`"""
        environment_tree = self.contraction_tree.copy()

        def redirect_edges(node, child, visited):
            """recursive function that redirects all edges towards initial node passed"""
            visited = visited + [node]
            out_edges = list(environment_tree.out_edges(node))
            for start, end in out_edges:
                if child is None or end != child:
                    environment_tree.remove_edge(start, end)
                    environment_tree.add_edge(end, start)
                if end not in visited:
                    redirect_edges(end, node, visited)
            in_edges = list(environment_tree.in_edges(node))
            for start, end in in_edges:
                if start not in visited:
                    redirect_edges(start, node, visited)

        #  redirect and remove redundant nodes
        redirect_edges(marginal_tensor, None, [marginal_tensor])
        nodes = list(environment_tree.nodes)
        for node in nodes:
            in_edges = environment_tree.in_edges(node)
            out_edges = environment_tree.out_edges(node)
            if len(out_edges) and len(in_edges) == 1:
                origin = list(in_edges)[0][0]
                for out_edge in out_edges:
                    environment_tree.add_edge(origin, out_edge[1])
                environment_tree.remove_node(node)

        #  maintain dict mapping tensor -> index position in list of tensors
        positions = {tensor: index for index, tensor in enumerate(self.tensors)}
        positions = {tensor: position - (positions[marginal_tensor] < position)
                     for tensor, position in positions.items()}
        del positions[marginal_tensor] #  assume the query tensor is removed from input string

        #  rebuild path from reoriented contraction tree
        path = []
        active = len(self.tensors) - 1
        for node in nx.topological_sort(environment_tree):
            children = [start for start, end in environment_tree.in_edges(node)]
            if len(children) > 1:
                path.append(tuple(positions[tensor] for tensor in children))
                positions = {tensor: position - sum(positions[removed_tensor] < position for removed_tensor in children)
                             for tensor, position in positions.items() if tensor not in children}
                active = active - len(children) + 1
                positions[node] = active - 1
        return path

My issue is:

  1. Using the HitDict class proposed for counting cache hits, it seems that using these paths gives me 0 cache hits?? I am using jax, but tried with numpy as well.
  2. Is shared_intermediates capable of reusing high-order intermediates as found in the contraction tree?

I am calling it using interleaved input (I am using a tensor wrapper with an axis_names attribute):

    def get_environment(self, tensors, environment_of, optimize='auto', **kwargs):
        environment_tensors = [tensor for tensor in tensors if tensor != environment_of]
        operands = chain.from_iterable([(tensor.tensor, tensor.axis_names) for tensor in environment_tensors])
        return oe.contract(*operands, environment_of.axis_names, optimize=optimize, **kwargs)
    
    def marginals(all_tensors, marginal_tensors):
        tree = ContractionTree(all_tensors)
        marginals = []
        for marginal_tensor in marginal_tensors:
            path = tree.get_environment_path(marginal_tensor)
            with shared_intermediates(tree.cache):
                marginals.append(get_environment(all_tensors, marginal_tensor, optimize=path))
        return marginals, tree.cache

I realize this might be outside the scope of opt_einsum, but it also seems like this would be a valuable extension of shared_intermediates - or it might already be supported via the current caching mechanism and there's just a bug in my implementation :)

@dgasmith
Owner

dgasmith commented Sep 4, 2019

This is great, @fritzo may have some interesting use cases to try out. A few notes:

  • shared_intermediates should be able to use high-order intermediates.
  • I think this is likely in-scope of opt_einsum and functionality that would be great to have. We have talked about algorithms that would maximally optimize overlap before, but did not have a good way of exploring those.
  • A dependence on networkx might be heavy for us; we could consider it an optional dependency.

Can you provide the HitDict as well? It might also be useful to try a few simple examples with ~2 straightforward intermediates to try and get a handle on what is going on.

@Bonnevie
Author

Bonnevie commented Sep 4, 2019

I think the networkx dependency is avoidable; you just need some kind of data structure for the tree. I was considering anytree as well, but since the tree changes frequently it wasn't ideal.

The HitDict was proposed by @jcmgray in #85. I repeat it here for convenience:

class HitDict(dict): 
     
    def __init__(self, *args, **kwargs): 
        self.gets = 0 
        self.hits = 0 
        super().__init__(*args, **kwargs)

    def __contains__(self, k):
        self.gets += 1
        does_contain = super().__contains__(k)
        if does_contain:
            self.hits += 1
        return does_contain

    @property
    def misses(self):
        return self.gets - self.hits

It's only used for ease of review; ContractionTree does not otherwise depend on it, and it could be replaced.

Is there any obvious explanation for why there are no cache hits? The cache does get accessed, but only for writing.

@jcmgray
Collaborator

jcmgray commented Sep 4, 2019

If I make the example workable (I needed to call tree.get_environment as a method, not a function) (also, how does one make code collapsible on GitHub??):

from itertools import chain
from opt_einsum import contract_path, shared_intermediates
import networkx as nx
import opt_einsum as oe
import quimb.tensor as qtn


class CustomTensor(qtn.Tensor):
    
    @property
    def tensor(self):
        return self.data
    
    @property
    def axis_names(self):
        return self.inds

    
class HitDict(dict): 
     
    def __init__(self, *args, **kwargs): 
        self.gets = 0 
        self.hits = 0 
        super().__init__(*args, **kwargs)

    def __contains__(self, k):
        self.gets += 1
        does_contain = super().__contains__(k)
        if does_contain:
            self.hits += 1
        return does_contain

    @property
    def misses(self):
        return self.gets - self.hits

    
class ContractionTree:
    """Interface for finding contraction paths for each possible environment tensor with high cache reuse.

    Based on "Improving the efficiency of variational tensor network algorithms" by Evenbly and Pfeifer. Returned
    environment paths stay in same contraction tree family.
    """

    def __init__(self, tensors, optimize='greedy'):
        operands = chain.from_iterable([(tensor.tensor, tensor.axis_names) for tensor in tensors])
        self.partition_path, _ = contract_path(*operands, (), optimize=optimize)
        self.contraction_tree = nx.DiGraph()
        for index, tensor in enumerate(tensors):
            self.contraction_tree.add_node(tensor, index=index)
        self.tensors = tensors
        self.cache = HitDict()

        active_nodes = [tensor for tensor in tensors]
        for path_step in self.partition_path:
            children = []
            for index in sorted(path_step, reverse=True):
                children.append(active_nodes.pop(index))
            active_nodes.append(frozenset(children))
            for child in children:
                self.contraction_tree.add_edge(child, active_nodes[-1])

    def get_environment_path(self, marginal_tensor):
        """Returns opt_einsum path for computing the environment tensor of `marginal_tensor`"""
        environment_tree = self.contraction_tree.copy()

        def redirect_edges(node, child, visited):
            """recursive function that redirects all edges towards initial node passed"""
            visited = visited + [node]
            out_edges = list(environment_tree.out_edges(node))
            for start, end in out_edges:
                if child is None or end != child:
                    environment_tree.remove_edge(start, end)
                    environment_tree.add_edge(end, start)
                if end not in visited:
                    redirect_edges(end, node, visited)
            in_edges = list(environment_tree.in_edges(node))
            for start, end in in_edges:
                if start not in visited:
                    redirect_edges(start, node, visited)

        #  redirect and remove redundant nodes
        redirect_edges(marginal_tensor, None, [marginal_tensor])
        nodes = list(environment_tree.nodes)
        for node in nodes:
            in_edges = environment_tree.in_edges(node)
            out_edges = environment_tree.out_edges(node)
            if len(out_edges) and len(in_edges) == 1:
                origin = list(in_edges)[0][0]
                for out_edge in out_edges:
                    environment_tree.add_edge(origin, out_edge[1])
                environment_tree.remove_node(node)

        #  maintain dict mapping tensor -> index position in list of tensors
        positions = {tensor: index for index, tensor in enumerate(self.tensors)}
        positions = {tensor: position - (positions[marginal_tensor] < position)
                     for tensor, position in positions.items()}
        del positions[marginal_tensor] #  assume the query tensor is removed from input string

        #  rebuild path from reoriented contraction tree
        path = []
        active = len(self.tensors) - 1
        for node in nx.topological_sort(environment_tree):
            children = [start for start, end in environment_tree.in_edges(node)]
            if len(children) > 1:
                path.append(tuple(positions[tensor] for tensor in children))
                positions = {tensor: position - sum(positions[removed_tensor] < position for removed_tensor in children)
                             for tensor, position in positions.items() if tensor not in children}
                active = active - len(children) + 1
                positions[node] = active - 1
        return path

    def get_environment(self, tensors, environment_of, optimize='auto', **kwargs):
        environment_tensors = [tensor for tensor in tensors if tensor != environment_of]
        operands = chain.from_iterable([(tensor.tensor, tensor.axis_names) for tensor in environment_tensors])
        return oe.contract(*operands, environment_of.axis_names, optimize=optimize, **kwargs)

    
def marginals(all_tensors, marginal_tensors):
    tree = ContractionTree(all_tensors)
    marginals = []
    for marginal_tensor in marginal_tensors:
        path = tree.get_environment_path(marginal_tensor)
        with shared_intermediates(tree.cache):
            marginals.append(tree.get_environment(all_tensors, marginal_tensor, optimize=path))

    return marginals, tree.cache

psi = qtn.MPS_rand_state(10, 4)
tn = psi.H & psi
tensors = [CustomTensor(x.data, x.inds) for x in tn]
mgs, cache = marginals(tensors, [tensors[i] for i in range(10, 20)])

I get:

print(cache.hits)
# 144
print(cache.misses)
# 44

So seems to be working at least with this example?

@fritzo
Contributor

fritzo commented Sep 4, 2019

cc @eb8680

In Pyro we didn't see how to implement a guaranteed 3x overhead for computing all marginals, but we found empirically that using the greedy optimizer usually memoized enough intermediate expressions.

Since then we've reimplemented our marginal computations to not rely on shared_intermediates, instead using a simple tape-based adjoint algorithm similar to backprop. This new implementation avoids the n^2 python overhead that can't be eliminated with shared_intermediates. For details see:

I think it's probably better to use such a tape-based mechanism than to make shared_intermediates smarter, because of the python overhead of shared_intermediates.
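
(For illustration, and not Pyro's actual code: a minimal sketch of the tape/adjoint idea on a chain-shaped network. One forward sweep builds prefix messages, one backward sweep builds suffix messages, and each environment is then a single outer product, so the total work is roughly 3x one contraction with only O(n) Python overhead. The chain_environments helper and the matrix-chain setup are made up for this sketch.)

import numpy as np

def chain_environments(mats):
    """Environments (= gradients of the scalar contraction) for a chain
    Z = v0 @ M1 @ ... @ M_{n-2} @ v_{n-1}, via one forward and one backward sweep."""
    n = len(mats)
    prefix = [None] * n  # prefix[k] = v0 @ M1 @ ... @ mats[k]
    suffix = [None] * n  # suffix[k] = mats[k] @ ... @ v_{n-1}
    prefix[0] = mats[0]
    for k in range(1, n):
        prefix[k] = prefix[k - 1] @ mats[k]
    suffix[-1] = mats[-1]
    for k in range(n - 2, -1, -1):
        suffix[k] = mats[k] @ suffix[k + 1]
    envs = []
    for k in range(n):
        left = prefix[k - 1] if k > 0 else np.array(1.0)
        right = suffix[k + 1] if k < n - 1 else np.array(1.0)
        envs.append(np.multiply.outer(left, right))  # dZ/d mats[k]
    return envs

# sanity check: contracting each operand with its environment reproduces Z
mats = [np.random.randn(3), np.random.randn(3, 4), np.random.randn(4, 5), np.random.randn(5)]
Z = mats[0] @ mats[1] @ mats[2] @ mats[3]
assert all(np.allclose(np.sum(m * e), Z) for m, e in zip(mats, chain_environments(mats)))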

@jcmgray use the <details> ... </details> tag with plenty of newlines.

@Bonnevie
Author

Bonnevie commented Sep 4, 2019

@jcmgray thanks for putting it together. My own implementation is a bit different from what I presented here, with get_environment actually being a class method of a larger class, so maybe there's something deeper at play preventing the cache from being used. What is the cache indexed on exactly? Nice to see that it works for an MPS! Did not know about quimb either, that might come in handy.

@fritzo why would it not guarantee 3x overhead if you can calculate all marginals with 3 computations per node in the contraction tree?

@fritzo
Contributor

fritzo commented Sep 4, 2019

why would [ContractionTree] not guarantee 3x overhead ... ?

If I understand correctly, ContractionTree guarantees linear growth as measured in tensor ops, but quadratic growth as measured in the Python code needed to find those tensor ops (e.g. an O(n)-cost topological sort is called for each of O(n) marginals, resulting in O(n^2)). In Pyro we perform lots of small tensor operations and cannot compile away that quadratic overhead. Hence we switched to a tape-based implementation with linear growth both in tensor op count and in Python overhead.

@jcmgray
Collaborator

jcmgray commented Sep 4, 2019

What is the cache indexed on exactly?

It is cached on id(array) I believe, so it has to be the same actual python object passed in, not just the same underlying values.
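
(A quick way to see this, as a sketch not taken from the thread: re-running the same contraction with the same array objects inside one shared_intermediates context adds no new cache entries, whereas equal-valued copies have new ids and therefore miss the cache.)

import numpy as np
import opt_einsum as oe

x = np.random.rand(10, 10)
y = np.random.rand(10, 10)

with oe.shared_intermediates() as cache:
    oe.contract('ab,bc->ac', x, y)
    n_entries = len(cache)
    oe.contract('ab,bc->ac', x, y)                # same objects -> fully cached, no new entries
    assert len(cache) == n_entries
    oe.contract('ab,bc->ac', x.copy(), y.copy())  # same values, new ids -> cache misses
    assert len(cache) > n_entries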

@Bonnevie
Author

Bonnevie commented Sep 5, 2019

@fritzo ah, I see what you mean now, but that scaling comes from converting the contraction tree to a contraction path and not from shared_intermediates, correct? Is the overhead from the sort significant in the grand scheme of things?
Also, any chance you could explain how a tape-based method can replace shared_intermediates vis-a-vis memoization? I cannot quite tell where to look for the answer in the Pyro code you referenced.

P.S. Is Pyro suitable for discrete model inference?

@fritzo
Contributor

fritzo commented Sep 5, 2019

@Bonnevie You might find Jason Eisner's paper insightful:
Inside-Outside and Forward-Backward Algorithms Are Just Backprop

Yes, Pyro is suitable for discrete model inference; here is a tutorial. Also, examples/einsum.py computes marginals using Pyro's taped wrapper around opt_einsum. See the accompanying paper, especially the related work section.

@Bonnevie
Author

Bonnevie commented Sep 6, 2019

@fritzo thanks for taking the time to reply. In a wild coincidence, somebody made a passing mention of that very same paper on twitter yesterday evening, so I had it open in a tab while reading your reply :) I need to read it, but I guess the key is that environment tensors are the same as the tensor gradients since the einsum/tensor network is multilinear, so you can piggyback on the efficient autograd backends?
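
(To spell out that reasoning, a short note not from the thread: the contraction is linear in each operand, so the gradient of the scalar value with respect to a given tensor is exactly that tensor's environment.)

Z = \sum_{\text{all indices}} \prod_j T_j
\qquad\Longrightarrow\qquad
\frac{\partial Z}{\partial T_i[a_1 \ldots a_k]}
= \sum_{\text{indices not on } T_i} \prod_{j \ne i} T_j
= \mathrm{Env}(T_i)[a_1 \ldots a_k]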

Thanks for taking the time to give some details, will look through your linked material.

@Bonnevie
Author

Calculating marginals seems to be as simple as

import opt_einsum as oe
import numpy as onp
import jax.numpy as np
from jax import grad

equation, shapes = oe.helpers.rand_equation(5, 2)
tensors = [np.array(onp.random.randn(*shape)) for shape in shapes]
expr = oe.contract_expression(equation, *shapes)
envs = grad(expr, argnums=list(range(5)))(*tensors)  # tuple: one environment per operand

Based on this, I don't think there is any reason for a contraction tree implementation - it's basically just a more cumbersome implementation of auto-diff. Of course, you need to actually have autodiff in your backend to exploit it.
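
(As a quick consistency check, a sketch assuming the envs, expr, and tensors from the snippet above: by the same multilinearity, contracting each tensor with its own gradient/environment reproduces the value of the full contraction.)

value = expr(*tensors)
for tensor, env in zip(tensors, envs):
    assert np.allclose(np.sum(tensor * env), value)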

@fritzo why does your implementation of autodiff marginals span so many files? Is it only because you modify einsum to be log-sum-exp-like? Also, out of interest, do you use autodiff/tapes for MAP? I can see you have a Ring class lurking in the background.

@fritzo
Contributor

fritzo commented Sep 16, 2019

why does your implementation ... span so many files?

Ha ha, that's a fair criticism. I guess part of the reason is that it grew organically. Part of the factoring is intentional, however:

  • We wanted to separate the taping mechanism from the mathematical Ring operations, and we implement a few different rings, e.g. the (logaddexp,add) ring.
  • We wanted a jittable, differentiable implementation of marginals, but due to limitations in PyTorch we couldn't grad(jit(grad(...))) as you can in JAX (a tiny JAX sketch of this follows below). Hence we built our own taping mechanism, which is jittable.
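
(A tiny, hypothetical JAX illustration of the grad(jit(grad(...))) composition from the last bullet; nothing Pyro-specific.)

import jax
import jax.numpy as jnp

f = lambda x: jnp.sum(x ** 3)             # scalar-valued function
df = jax.jit(jax.grad(f))                 # jitted first derivative: 3 * x**2
ddf = jax.grad(lambda x: jnp.sum(df(x)))  # differentiating through the jitted grad is fine in JAX
print(ddf(jnp.arange(3.0)))               # 6 * x -> [ 0.  6. 12.]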

@dgasmith closed this as completed Nov 7, 2021