shared_intermediates and contraction trees #104
Comments
This is great; @fritzo may have some interesting use cases to try out. A few notes:
It's only used for ease of review. Is there any obvious explanation for why there are no cache hits? The cache does get accessed, but only for writing.
If I make the example workable (I needed to call ...):

```python
from itertools import chain

import networkx as nx
import opt_einsum as oe
import quimb.tensor as qtn
from opt_einsum import contract_path, shared_intermediates


class CustomTensor(qtn.Tensor):
    # thin wrapper exposing quimb tensors through `.tensor` and `.axis_names`
    @property
    def tensor(self):
        return self.data

    @property
    def axis_names(self):
        return self.inds


class HitDict(dict):
    """Dict that counts lookups so cache hits/misses can be inspected."""

    def __init__(self, *args, **kwargs):
        self.gets = 0
        self.hits = 0
        super().__init__(*args, **kwargs)

    def __contains__(self, k):
        self.gets += 1
        does_contain = super().__contains__(k)
        if does_contain:
            self.hits += 1
        return does_contain

    @property
    def misses(self):
        return self.gets - self.hits


class ContractionTree:
    """Interface for finding contraction paths for each possible environment
    tensor with high cache reuse.

    Based on "Improving the efficiency of variational tensor network
    algorithms" by Evenbly and Pfeifer. Returned environment paths stay in the
    same contraction tree family.
    """

    def __init__(self, tensors, optimize='greedy'):
        operands = chain.from_iterable([(tensor.tensor, tensor.axis_names) for tensor in tensors])
        self.partition_path, _ = contract_path(*operands, (), optimize=optimize)
        self.contraction_tree = nx.DiGraph()
        for index, tensor in enumerate(tensors):
            self.contraction_tree.add_node(tensor, index=index)
        self.tensors = tensors
        self.cache = HitDict()
        # replay the path, adding an internal tree node for each contraction step
        active_nodes = [tensor for tensor in tensors]
        for path_step in self.partition_path:
            children = []
            for index in sorted(path_step, reverse=True):
                children.append(active_nodes.pop(index))
            active_nodes.append(frozenset(children))
            for child in children:
                self.contraction_tree.add_edge(child, active_nodes[-1])

    def get_environment_path(self, marginal_tensor):
        """Return an opt_einsum path for computing the environment tensor of `marginal_tensor`."""
        environment_tree = self.contraction_tree.copy()

        def redirect_edges(node, child, visited):
            """Recursively redirect all edges towards the initial node passed."""
            visited = visited + [node]
            out_edges = list(environment_tree.out_edges(node))
            for start, end in out_edges:
                if child is None or end != child:
                    environment_tree.remove_edge(start, end)
                    environment_tree.add_edge(end, start)
                    if end not in visited:
                        redirect_edges(end, node, visited)
            in_edges = list(environment_tree.in_edges(node))
            for start, end in in_edges:
                if start not in visited:
                    redirect_edges(start, node, visited)

        # redirect and remove redundant nodes
        redirect_edges(marginal_tensor, None, [marginal_tensor])
        nodes = list(environment_tree.nodes)
        for node in nodes:
            in_edges = environment_tree.in_edges(node)
            out_edges = environment_tree.out_edges(node)
            if len(out_edges) and len(in_edges) == 1:
                origin = list(in_edges)[0][0]
                for out_edge in out_edges:
                    environment_tree.add_edge(origin, out_edge[1])
                environment_tree.remove_node(node)

        # maintain dict mapping tensor -> index position in list of tensors
        positions = {tensor: index for index, tensor in enumerate(self.tensors)}
        positions = {tensor: position - (positions[marginal_tensor] < position)
                     for tensor, position in positions.items()}
        del positions[marginal_tensor]  # assume the query tensor is removed from the input string

        # rebuild the path from the reoriented contraction tree
        path = []
        active = len(self.tensors) - 1
        for node in nx.topological_sort(environment_tree):
            children = [start for start, end in environment_tree.in_edges(node)]
            if len(children) > 1:
                path.append(tuple(positions[tensor] for tensor in children))
                positions = {tensor: position - sum(positions[removed_tensor] < position for removed_tensor in children)
                             for tensor, position in positions.items() if tensor not in children}
                active = active - len(children) + 1
                positions[node] = active - 1
        return path

    def get_environment(self, tensors, environment_of, optimize='auto', **kwargs):
        environment_tensors = [tensor for tensor in tensors if tensor != environment_of]
        operands = chain.from_iterable([(tensor.tensor, tensor.axis_names) for tensor in environment_tensors])
        return oe.contract(*operands, environment_of.axis_names, optimize=optimize, **kwargs)


def marginals(all_tensors, marginal_tensors):
    tree = ContractionTree(all_tensors)
    marginals = []
    for marginal_tensor in marginal_tensors:
        path = tree.get_environment_path(marginal_tensor)
        with shared_intermediates(tree.cache):
            marginals.append(tree.get_environment(all_tensors, marginal_tensor, optimize=path))
    return marginals, tree.cache


psi = qtn.MPS_rand_state(10, 4)
tn = psi.H & psi
tensors = [CustomTensor(x.data, x.inds) for x in tn]
mgs, cache = marginals(tensors, [tensors[i] for i in range(10, 20)])
```

I get:

```python
print(cache.hits)
# 144
print(cache.misses)
# 44
```

So it seems to be working, at least with this example?
cc @eb8680 In Pyro we didn't see how to implement a guaranteed 3x overhead for computing all marginals, but we found empirically that the greedy optimizer usually memoized enough intermediate expressions. Since then we've reimplemented our marginal computations to not rely on `shared_intermediates`, instead using a simple tape-based adjoint algorithm similar to backprop. This new implementation avoids the n^2 Python overhead that can't be eliminated with `shared_intermediates`. For details see:

I think it's probably better to use such a tape-based mechanism than to make `shared_intermediates` smarter, because of the Python overhead of `shared_intermediates`. @jcmgray use the ...
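To make the tape idea concrete, here is a rough sketch of an adjoint pass over an opt_einsum contraction path. This is only an illustration, not Pyro's implementation: `environments_via_tape` is a hypothetical helper, and it assumes the equation contracts to a scalar and that every index of each operand also appears in some other operand or the output.

```python
import numpy as np
import opt_einsum as oe


def environments_via_tape(eq, *tensors, optimize='greedy'):
    """Compute every input tensor's environment with one forward contraction
    and one reverse sweep over a recorded tape of pairwise contractions."""
    inputs, output = eq.split('->')
    terms = inputs.split(',')
    path, _ = oe.contract_path(eq, *tensors, optimize=optimize)

    # forward pass: each live operand is (node_id, index_string, array)
    live = [(i, terms[i], np.asarray(t)) for i, t in enumerate(tensors)]
    tape = []
    next_id = len(tensors)
    for step in path:
        group = [live.pop(i) for i in sorted(step, reverse=True)]
        remaining = set(''.join(s for _, s, _ in live)) | set(output)
        out_inds = ''.join(dict.fromkeys(
            c for _, s, _ in group for c in s if c in remaining))
        sub_eq = ','.join(s for _, s, _ in group) + '->' + out_inds
        out = oe.contract(sub_eq, *[a for _, _, a in group])
        tape.append((group, out_inds, next_id))
        live.append((next_id, out_inds, out))
        next_id += 1

    # backward sweep: the adjoint of the final scalar is 1; walking the tape in
    # reverse, each operand's adjoint is the output adjoint contracted with its
    # siblings (valid because the contraction is multilinear in every operand)
    adjoint = {live[0][0]: np.array(1.0)}
    for group, out_inds, out_id in reversed(tape):
        adj_out = adjoint[out_id]
        for i, (node_id, inds, _) in enumerate(group):
            siblings = [g for j, g in enumerate(group) if j != i]
            sub_eq = ','.join([out_inds] + [s for _, s, _ in siblings]) + '->' + inds
            adjoint[node_id] = oe.contract(sub_eq, adj_out, *[a for _, _, a in siblings])

    # environments of the original inputs, in input order
    return [adjoint[i] for i in range(len(tensors))]


# tiny check on a three-tensor ring: contracting a tensor with its environment
# reproduces the full network value
a, b, c = np.random.rand(2, 3), np.random.rand(3, 4), np.random.rand(4, 2)
envs = environments_via_tape('ab,bc,ca->', a, b, c)
print(np.einsum('ab,ab->', envs[0], a), np.einsum('ab,bc,ca->', a, b, c))
```

Each intermediate is produced once on the forward pass and touched a constant number of times on the backward sweep, which is where the "roughly 3x the cost of one contraction" bound discussed in this thread comes from.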
@jcmgray thanks for putting it together. My own implementation is a bit different from what I presented here, with ...

@fritzo why would it not guarantee 3x overhead if you can calculate all marginals with 3 computations per node in the contraction tree?
If I understand correctly, `ContractionTree` guarantees linear growth as measured in tensor ops, but quadratic growth as measured in Python code to find those tensor ops (e.g. an ...)
It is cached on ...
@fritzo ah, I see what you mean now, but that scaling comes from converting the contraction tree to a contraction path and not from ...

p.s. is Pyro suitable for discrete model inference?
@Bonnevie You might find Jason Eisner's paper insightful.

Yes, Pyro is suitable for discrete model inference; here is a tutorial. Also, examples/einsum.py computes marginals using Pyro's taped wrapper around opt_einsum. See the accompanying paper, especially the related work section.
@fritzo thanks for taking the time to reply. In a wild coincidence, somebody made a passing mention of that very same paper on twitter yesterday evening, so I had it open in a tab while reading your reply :) I need to read it, but I guess the key is that environment tensors are the same as the tensor gradients since the einsum/tensor network is multilinear, so you can piggyback on the efficient autograd backends? Thanks for taking the time to give some details, I will look through your linked material.
Calculating marginals then seems to be as simple as contracting the network to a scalar and taking its gradient with respect to each input tensor.
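For instance, with a JAX backend this collapses to a single gradient call. A minimal sketch, where the three-tensor equation and shapes are placeholders for the real network:

```python
import jax
import jax.numpy as jnp


def network_value(tensors):
    # the network contracts to a scalar and is multilinear in every tensor,
    # so the gradient with respect to tensor i is exactly its environment
    return jnp.einsum('ab,bc,ca->', *tensors)


tensors = [jnp.ones((2, 3)), jnp.ones((3, 4)), jnp.ones((4, 2))]
environments = jax.grad(network_value)(tensors)  # one environment per input tensor
```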
Based on this, I don't think there is any reason for a contraction tree implementation - it's basically just a more cumbersome implementation of auto-diff. Of course, you need to actually have autodiff in your backend to exploit it.

@fritzo why does your implementation of autodiff marginals span so many files? Is it only because you modify einsum to be log-sum-exp-like? Also, out of interest, do you use autodiff/tapes for MAP? I can see you have a `Ring` class lurking in the background.
Ha ha, that's a fair criticism. I guess part of the reason is that it grew organically. Part of the factoring is intentional, however:
The `shared_intermediates` example in the docs is to calculate multiple marginals, which is a use case of interest to me. Calculating marginals is equivalent to calculating environment tensors, so I'm interested in using the approach from https://arxiv.org/abs/1310.8023, which basically guarantees that calculating all marginals costs at most 3x the cost of calculating one marginal, by using extensive memoization. This is ensured by making each contraction path a member of the same contraction tree family, with each tree node at worst corresponding to 3 different contractions that can be appropriately memoized. This function is called `multienv()` in various tensor network packages.

Here is a rough but working prototype of a `ContractionTree` class that can generate the various contraction paths by calls to `get_environment_path`.

Contraction tree prototype

My issue is with `jax`, but I tried with `numpy` as well. Is `shared_intermediates` capable of reusing high-order intermediates as found in the contraction tree?

I am calling it using interleaved input (I am using a wrapper of tensor with an `axis_names` attribute):
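A minimal sketch of that interleaved call pattern under `shared_intermediates`, assuming `tensors` is the list of wrapped tensors and each wrapper exposes `.tensor` (the array) and `.axis_names` (its index labels):

```python
from itertools import chain

import opt_einsum as oe
from opt_einsum import shared_intermediates

cache = {}
environments = []
with shared_intermediates(cache):
    for target in tensors:
        others = [t for t in tensors if t is not target]
        # interleaved input: array, indices, array, indices, ..., output indices
        operands = chain.from_iterable((t.tensor, t.axis_names) for t in others)
        environments.append(oe.contract(*operands, target.axis_names))
```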
opt_einsum
, but it also seems like this would be a valuable extension ofshared_intermediates
- or it might already be supported via the current caching mechanism and there's just a bug in my implementation :)The text was updated successfully, but these errors were encountered: