In [128]:
import random
from queue import Queue
import matplotlib.pyplot as plt
import numpy as np
import math

In [129]:
class point:
    """A 2-D point carrying a list of masses, one entry per distribution.

    Parameters
    ----------
    x, y : float
        Coordinates of the point.
    data : iterable, optional
        Masses attached to the point (compute_ot reads ``data[0]`` and
        ``data[1]``).  The values are copied into a new list owned by this
        point; defaults to an empty list.
    """

    def __init__(self, x, y, data=None):
        #point data
        self.x = x
        self.y = y
        # A fresh list per point: the previous ``data=[]`` default was a single
        # list shared by every point, and compute_ot mutates ``data`` in place.
        self.data = [] if data is None else list(data)

    def __repr__(self):
        return f'{{"x": {self.x}, "y": {self.y}}}'

In [130]:
class square:
    """Axis-aligned square cell used by the quadtree.

    ``(x, y)`` is the centre of the cell and ``l`` its side length.  Points
    assigned to the cell are collected in ``points``.
    """

    def __init__(self, x, y, l):
        self.x, self.y, self.l = x, y, l
        self.points = []

    def __repr__(self):
        return "({}, {}, {})".format(self.x, self.y, self.l)

    def contains(self, point):
        """Return True if *point* lies in the half-open cell
        [x - l/2, x + l/2) x [y - l/2, y + l/2)."""
        half = self.l / 2
        in_x = self.x - half <= point.x < self.x + half
        in_y = self.y - half <= point.y < self.y + half
        return in_x and in_y

In [145]:
class quadtree:
    """Point-region quadtree whose nodes each cover one ``square`` cell.

    A leaf stores up to ``capacity`` points in ``square.points``.  When a leaf
    overflows it splits into four half-size children (``topleft``,
    ``topright``, ``botleft``, ``botright``) and pushes its points down.

    Parameters
    ----------
    square : square
        Cell covered by this node.
    capacity : int
        Maximum number of points a leaf holds before it subdivides.  Children
        inherit the same capacity.
    divided : bool, optional
        Whether the node is already subdivided (normally ``False``).
    max_depth : int, optional
        How many further levels of subdivision are allowed below this node.
        A leaf with ``max_depth <= 0`` accepts points beyond ``capacity``
        instead of splitting, so inserting more than ``capacity`` identical
        (or nearly identical) points terminates instead of recursing forever.
    """

    # (child attribute, x-offset sign, y-offset sign); this is also the order
    # in which every method below visits the children.
    _QUADRANTS = (
        ("topleft", -1, 1),
        ("topright", 1, 1),
        ("botleft", -1, -1),
        ("botright", 1, -1),
    )

    def __init__(self, square, capacity, divided=False, max_depth=40):
        #initialize quadtree object
        self.square = square
        self.capacity = capacity
        self.divided = divided
        self.max_depth = max_depth
        self.topleft = None
        self.topright = None
        self.botleft = None
        self.botright = None

    def _children(self):
        """Return the four children (``None`` where absent) in quadrant order."""
        return [getattr(self, name) for name, _, _ in self._QUADRANTS]

    def _make_child(self, sx, sy):
        """Build an empty child covering the quadrant in direction (sx, sy)."""
        x, y, l = self.square.x, self.square.y, self.square.l
        cell = square(x + sx * l / 4, y + sy * l / 4, l / 2)
        return quadtree(cell, self.capacity, max_depth=self.max_depth - 1)

    def _insert_into_children(self, point):
        """Offer *point* to every child; cells are half-open, so normally
        exactly one child (the one containing the point) keeps it."""
        for name, sx, sy in self._QUADRANTS:
            child = getattr(self, name)
            if child is None:
                # Quadrant was pruned by killemptychildren(); recreate it only
                # when the point actually falls inside it.
                child = self._make_child(sx, sy)
                if not child.square.contains(point):
                    continue
                setattr(self, name, child)
            child.insert(point)

    def subdivide(self):
        """Split this leaf into four quadrants and move its points into them."""
        for name, sx, sy in self._QUADRANTS:
            setattr(self, name, self._make_child(sx, sy))
        self.divided = True
        for p in self.square.points:
            self._insert_into_children(p)
        self.square.points = []

    def insert(self, point):
        """Insert *point* into the tree; points outside the root cell are ignored."""
        if not self.square.contains(point):
            return
        if self.divided:
            self._insert_into_children(point)
        elif len(self.square.points) < self.capacity or self.max_depth <= 0:
            # Room left, or maximum depth reached: keep the point in this leaf.
            self.square.points.append(point)
        else:
            self.subdivide()
            self._insert_into_children(point)

    def killemptychildren(self):
        """Prune child leaves that hold no points (set them to ``None``) and
        recurse into the remaining children.

        A no-op on a leaf; safe to call repeatedly.
        """
        if not self.divided:
            return
        for name, _, _ in self._QUADRANTS:
            child = getattr(self, name)
            if child is None:
                continue
            if not child.divided and not child.square.points:
                setattr(self, name, None)
            else:
                child.killemptychildren()

    def printsub(self):
        """Print every non-empty leaf cell followed by the list of its points."""
        if not self.divided and self.square.points:
            print(self.square)
            print(self.square.points)
            return
        for child in self._children():
            if child is not None:
                child.printsub()

In [132]:
def is_leaf(qtree):
    """Return True if *qtree* has not been subdivided, i.e. it is a leaf node."""
    return not qtree.divided


In [133]:
# Running totals filled in by compute_ot(): the accumulated transport cost, and
# the plan mapping each source (x, y) to a list of ((target x, y), mass) moves.
cost, transport_plan = 0, {}

In [156]:
def compute_ot(qtree, cost_func):
    """Greedy bottom-up transport of mass between two distributions on a quadtree.

    Every point carries ``data == [mass_1, mass_2]``, its mass under each of
    the two distributions.  Mass is matched as low in the tree as possible:
    first co-located mass at each point of a leaf, then, at every node, the
    unmatched mass surfacing from below.  Source mass (distribution 1) is
    sent to sink mass (distribution 2) greedily, in first-come order.  This
    is a heuristic, not an exact OT solver.

    Results accumulate in the module-level globals ``cost`` (total
    transport cost) and ``transport_plan``.  ``transport_plan`` maps
    ``(x, y)`` of a source to ``[((x, y) of target, amount), ...]``.
    Reset both before calling.  Point ``data`` lists are updated in place.

    Parameters
    ----------
    qtree : quadtree or None
        Node to process; ``None`` (a pruned child) contributes nothing.
    cost_func : callable
        ``cost_func(p, q)`` -> cost of moving one unit of mass from p to q.

    Returns
    -------
    list
        Points under this node that still carry unmatched mass.  For internal
        nodes the list is also stored in ``qtree.square.points``.
    """
    global cost
    global transport_plan

    if qtree is None:
        return []

    if not qtree.divided:
        # Leaf: pair mass both distributions hold at the same location (free).
        candidates = qtree.square.points
        for p in candidates:
            shared = min(p.data)
            if shared > 0:
                transport_plan.setdefault((p.x, p.y), []).append(((p.x, p.y), shared))
                p.data = [m - shared for m in p.data]
    else:
        candidates = (compute_ot(qtree.topleft, cost_func)
                      + compute_ot(qtree.topright, cost_func)
                      + compute_ot(qtree.botleft, cost_func)
                      + compute_ot(qtree.botright, cost_func))
        qtree.square.points = candidates

    # Two-pointer greedy matching: each step moves min(source, sink) mass, so
    # at least one side is exhausted exactly and the loop always advances.
    sources = [p for p in candidates if p.data[0] > 0]
    sinks = [p for p in candidates if p.data[1] > 0]
    i = j = 0
    while i < len(sources) and j < len(sinks):
        src, dst = sources[i], sinks[j]
        amount = min(src.data[0], dst.data[1])
        transport_plan.setdefault((src.x, src.y), []).append(((dst.x, dst.y), amount))
        cost += cost_func(src, dst) * amount
        src.data[0] -= amount
        dst.data[1] -= amount
        if src.data[0] <= 0:
            i += 1
        if dst.data[1] <= 0:
            j += 1

    return [p for p in candidates if max(p.data) > 0]

In [135]:
def euclid_dist(a, b):
    """Euclidean distance between two objects exposing ``.x`` and ``.y``."""
    dx = b.x - a.x
    dy = b.y - a.y
    return math.sqrt(dx ** 2 + dy ** 2)

In [157]:
# Test: two distributions with one point each.  Distribution 1 has unit mass at
# (1, 0) and distribution 2 has unit mass at (0, 0), so the only move is
# (1, 0) -> (0, 0) at Euclidean cost 1.

cost = 0
transport_plan = {}

sq = square(0, 0, 3)
qtree = quadtree(sq, 1)
qtree.insert(point(1, 0, [1, 0]))
qtree.insert(point(0, 0, [0, 1]))
qtree.killemptychildren()
qtree.printsub()

compute_ot(qtree, euclid_dist)

print(cost)
print(transport_plan)

# Pin the expected result so a regression in compute_ot fails loudly instead
# of silently printing different numbers.
assert math.isclose(cost, 1.0), cost
assert transport_plan == {(1, 0): [((0, 0), 1)]}, transport_plan

(0.375, 0.375, 0.75)
[{"x": 0, "y": 0}]
(1.125, 0.375, 0.75)
[{"x": 1, "y": 0}]
1.0
{(1, 0): [((0, 0), 1)]}
