# Toy Replica Exchange

This notebook gives a very simple overview of a replica exchange code using `dask`. It is much simpler than a real replica exchange code, but it is enough to capture the main parallelization challenge: swaps can lead to complicated dependency graphs.

Overall, we'll think of several "ensembles" that we are sampling simultaneously. We'll also refer to these as "slots," and they are represented by characters (`'A'`, `'B'`, etc.). The value associated with each slot is an integer. The goal of this code is to swap which integer is associated with which ensemble letter.
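As a minimal sketch of that idea, a single swap on a plain dictionary looks like the following. The helper name `swap_slots` is hypothetical and is used only for this illustration; it is not part of the notebook's code.

```python
# Each slot (ensemble letter) maps to an integer value.
slots = {"A": 0, "B": 1, "C": 2}

def swap_slots(slots, first, second):
    """Return a copy of `slots` with the values of two slots exchanged."""
    swapped = dict(slots)  # copy so the input mapping is left untouched
    swapped[first], swapped[second] = slots[second], slots[first]
    return swapped

after = swap_slots(slots, "A", "C")
print(after)  # {'A': 2, 'B': 1, 'C': 0}
```

The letters stay fixed; only the integers move between them.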

In real replica exchange methods, the ensembles are typically temperatures (for replica exchange molecular dynamics, REMD) or path ensembles (for replica exchange transition interface sampling, RETIS). The values are either points in phase space (for REMD) or trajectories (for RETIS).
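To get a feel for the dependency structure mentioned above, the sketch below estimates the earliest "round" in which each swap could run if independent swaps were executed in parallel. It assumes every swap takes the same amount of time; the function name `schedule_depths` and the example swap list are hypothetical, chosen only for illustration.

```python
def schedule_depths(swap_pairs):
    """Return, for each swap, the earliest round in which it could run.

    A swap must wait for the most recent earlier swap that touched either
    of its two slots; swaps on disjoint slots are independent of each other.
    """
    last_round = {}  # slot -> round of the latest swap that wrote to it
    depths = []
    for first, second in swap_pairs:
        depth = 1 + max(last_round.get(first, 0), last_round.get(second, 0))
        last_round[first] = last_round[second] = depth
        depths.append(depth)
    return depths

example_swaps = [("A", "B"), ("C", "D"), ("B", "C"), ("E", "F"), ("A", "F")]
depths = schedule_depths(example_swaps)
print(depths)       # [1, 1, 2, 1, 2]
print(max(depths))  # 2 rounds, compared with 5 sequential steps
```

Swaps such as `('A', 'B')` and `('C', 'D')` share no slots and can run at the same time, while `('B', 'C')` has to wait for both of them.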

In [None]:
import itertools
import random
import time
import dask

In [None]:
class Swapper:
    """Perform swap between specific ensembles.
    
    Parameters
    ----------
    elem_1 : character
        first ensemble of swap pair
    elem_2 : character
        second ensemble of swap pair
    """
    def __init__(self, elem_1, elem_2):
        self.elem_1 = elem_1
        self.elem_2 = elem_2
    
    @property
    def pair(self):
        """convenience to return tuple of swap pair"""
        return (self.elem_1, self.elem_2)
    
    def subdict(self, slots):
        """Mapping of ensemble to value for ensembles in this move"""
        return {slot: slots[slot] for slot in self.pair}
        
    def __call__(self, subdict):
        """Perform the swap"""
        time.sleep(1)  # stand-in for the cost of a real swap attempt
        return {self.elem_1: subdict[self.elem_2], self.elem_2: subdict[self.elem_1]}
        
    def __repr__(self):
        return "{}('{}', '{}')".format(self.__class__.__name__, self.elem_1, self.elem_2)

In [None]:
# every character in ensemble_string represents an ensemble
# swappers are created for all pairs
ensemble_string = "ABCDEF"
swappers = [Swapper(*pair) for pair in itertools.combinations(ensemble_string, 2)]

In [None]:
# preselect 20 random swap moves
swaps = [random.choice(swappers) for i in range(20)]

### Run without dask

In [None]:
# initial conditions give the index in the string as the value of that character
slots = {letter: idx for idx, letter in enumerate(ensemble_string)}
slots

In [None]:
%%time
for swap in swaps:
    subdict = swap.subdict(slots)
    pairs = swap.pair
    swapped_sub = swap(subdict)
    for slot in pairs:
        slots[slot] = swapped_sub[slot]

In [None]:
slots

### Run with dask

In [None]:
# reset initial conditions
slots = {letter: idx for idx, letter in enumerate(ensemble_string)}

Next, you'll modify the code to be task-based, using `dask.delayed`. The cell below starts out exactly the same as the example without dask.

Hints:

1. Does your delayed task return a single object or a tuple? If it returns a tuple of length `N`, use `nout=N`.
2. Is there randomness in your task? If so, use `pure=False`.
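The two hints can be illustrated on unrelated toy functions. This is only a sketch of the `nout` and `pure` keyword arguments of `dask.delayed`, not a solution to the exercise; `swap_values` and `noisy` are hypothetical helpers.

```python
import random
import dask

def swap_values(a, b):
    """Hypothetical helper: return the two inputs in reversed order."""
    return b, a

def noisy(x):
    """Hypothetical helper: a task whose result depends on random state."""
    return x + random.random()

# nout=2 tells dask that the result is a tuple of length 2, so it can be
# unpacked into two separate Delayed objects before anything is computed.
first, second = dask.delayed(swap_values, nout=2)(1, 2)
result = dask.compute(first, second)
print(result)  # (2, 1)

# pure=False gives every call its own key, so two calls with identical
# arguments remain two separate tasks in the graph.
draw_a = dask.delayed(noisy, pure=False)(1)
draw_b = dask.delayed(noisy, pure=False)(1)
print(draw_a.key != draw_b.key)  # True
```

Nothing is executed until `dask.compute` is called; before that, `first`, `second`, `draw_a`, and `draw_b` are only nodes in a task graph.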

In [None]:
# YOUR TURN: modify this code to work with dask.delayed

for swap in swaps:
    subdict = swap.subdict(slots)
    pairs = swap.pair
    swapped_sub = swap(subdict)
    for slot in pairs:
        slots[slot] = swapped_sub[slot]

In [None]:
slots

In [None]:
dask.visualize(*slots.values(), rankdir='TB')

In [None]:
# using the distributed scheduler is optional, but its dashboard is nice to watch!
from dask.distributed import Client, LocalCluster

In [None]:
# trick specific to JURECA
host = !hostname
ip = host[0]+'i'
print(ip)

In [None]:
cluster = LocalCluster(ip=ip)
client = Client(cluster)
client

In [None]:
%%time
dask.compute(slots)[0]

Because we used exactly the same steps, this should give the same result as the version without dask.