# In [2]:
from mmdew.bucket_stream2 import BucketStream
from mmdew.bucket_stream_old import BucketStream as OldBucketStream
import numpy as np
import pandas as pd
from mmdew.mmd import MMD
from sklearn import metrics
import numpy.linalg as la
from mmdew.bucket_stream2 import Bucket
import  math

def _pool_buckets(buckets):
    """Pool a non-empty run of buckets into one weighted sample.

    Each bucket's weights are multiplied by that bucket's
    ``uncompressed_capacity`` and the stacked result is divided by the summed
    capacity of all buckets in the run.

    Returns:
        ``(elements, weights, total_uncompressed_capacity)``.
    """
    total_capacity = sum(bucket.uncompressed_capacity for bucket in buckets)
    # One concatenate per array instead of growing the arrays bucket by bucket.
    elements = np.concatenate([bucket.elements for bucket in buckets])
    weights = np.concatenate(
        [bucket.weights * bucket.uncompressed_capacity for bucket in buckets]
    )
    return elements, weights * (1 / total_capacity), total_capacity


def mmd(split, buckets, k):
    """Evaluate the weighted squared-MMD expression between two runs of buckets.

    ``buckets[:split]`` forms the "start" sample and ``buckets[split:]`` the
    "end" sample.  With pooled weights ``a`` (start) and ``b`` (end) the value
    computed is ``a.T K_ss a + b.T K_ee b - 2 a.T K_se b``.

    Args:
        split: Index at which ``buckets`` is cut; both sides must be non-empty.
        buckets: Sequence of objects exposing ``elements``, ``weights`` and
            ``uncompressed_capacity``.  The final ``[0][0]`` indexing requires
            the weights to be 2-D column vectors (shape ``(n, 1)``).
        k: Kernel callable ``k(A, B)`` returning the Gram matrix between the
            rows of ``A`` and the rows of ``B``.

    Returns:
        ``(value, start_uncompressed_capacity, end_uncompressed_capacity)``.

    Raises:
        IndexError: If the split leaves no bucket on one of the two sides.
    """
    start = buckets[:split]
    end = buckets[split:]
    if len(start) == 0 or len(end) == 0:
        raise IndexError(
            "split must leave at least one bucket on each side of the cut"
        )

    start_elements, start_weights, start_uncompressed_capacity = _pool_buckets(start)
    end_elements, end_weights, end_uncompressed_capacity = _pool_buckets(end)

    addend_1 = start_weights.T @ k(start_elements, start_elements) @ start_weights
    addend_2 = end_weights.T @ k(end_elements, end_elements) @ end_weights
    addend_3 = start_weights.T @ k(start_elements, end_elements) @ end_weights
    # Each addend is a 1x1 matrix; unwrap it into a scalar.
    value = (addend_1 + addend_2 - (2 * addend_3))[0][0]
    return value, start_uncompressed_capacity, end_uncompressed_capacity

def _merge_bucket_pair(current, previous, k, apply_subsampling):
    """Fuse two buckets into one, optionally compressing onto a subsample."""
    joined_elements = np.concatenate((current.elements, previous.elements))
    joined_weights = np.concatenate((current.weights, previous.weights))
    joined_uncompressed_capacity = (
        current.uncompressed_capacity + previous.uncompressed_capacity
    )

    # NOTE: the original author remarked that this subsampling may be too
    # aggressive and that a larger subsample (e.g. based on the combined
    # uncompressed capacity) might be preferable.
    if apply_subsampling and joined_uncompressed_capacity > 16:
        m = round(math.sqrt(joined_uncompressed_capacity))  # size of the subsample
        # TODO: switch to a random subsample, e.g.
        # np.random.default_rng().integers(len(joined_elements), size=m);
        # for now the first m joined points are kept deterministically.
        subsample = joined_elements[np.arange(m)]
    else:
        m = joined_uncompressed_capacity
        subsample = joined_elements

    K_z = k(subsample, joined_elements)
    K_m_inv = la.pinv(k(subsample, subsample))
    # The factor .5 assumes both buckets carry the same uncompressed capacity,
    # which is the only case in which `merge` combines two buckets.
    new_weights = .5 * K_m_inv @ K_z @ joined_weights
    return Bucket(
        elements=subsample,
        weights=new_weights,
        capacity=m,
        uncompressed_capacity=joined_uncompressed_capacity,
    )


def merge_buckets_with_subsampling(bucket_list, k, apply_subsampling=True):
    """Merge all buckets in `bucket_list` into a single bucket.

    The list is folded from its end: the last two buckets are merged, the
    result is merged with the bucket before them, and so on until one bucket
    remains.  A single-bucket list returns that bucket unchanged.

    When `apply_subsampling` is true and the merged pair represents more than
    16 uncompressed points, only ``round(sqrt(n))`` of the joined points are
    kept; the new weights are ``.5 * pinv(K_mm) @ K_mz @ joined_weights``.

    Args:
        bucket_list: Non-empty sequence of buckets exposing ``elements``,
            ``weights`` and ``uncompressed_capacity``.
        k: Kernel callable ``k(A, B)`` returning the Gram matrix between the
            rows of ``A`` and the rows of ``B``.
        apply_subsampling: Set to ``False`` to keep every joined point
            regardless of size.  Defaults to ``True`` so that two-argument
            calls keep working.

    Returns:
        The single remaining bucket.

    Raises:
        IndexError: If `bucket_list` is empty.
    """
    merged = bucket_list[-1]
    # Iterative fold; the former recursive call dropped `k` and
    # `apply_subsampling` and therefore failed for two or more buckets.
    for previous in reversed(bucket_list[:-1]):
        merged = _merge_bucket_pair(merged, previous, k, apply_subsampling)
    return merged

def merge(buckets, k):
    """Collapse trailing buckets that have equal uncompressed capacity.

    While the last two buckets report the same ``uncompressed_capacity`` they
    are replaced by their merge (with subsampling enabled); the check is then
    repeated on the new tail.

    Args:
        buckets: List of buckets, oldest first.
        k: Kernel callable forwarded to `merge_buckets_with_subsampling`.

    Returns:
        The input list itself when nothing was merged, otherwise a new list.
    """
    # Loop instead of recursing; the former recursive call omitted `k`, and
    # the merge call omitted the subsampling flag.
    while (
        len(buckets) >= 2
        and buckets[-2].uncompressed_capacity == buckets[-1].uncompressed_capacity
    ):
        merged_tail = merge_buckets_with_subsampling(buckets[-2:], k, True)
        buckets = buckets[:-2] + [merged_tail]
    return buckets


def insert_no_cut(buckets, element, k):
    """Append `element` as a new single-point bucket and merge the tail.

    The new bucket holds the element as one row, a weight of 1, and a
    capacity and uncompressed capacity of 1.  `buckets` is extended in place
    before being handed to `merge`, whose result is returned.
    """
    point = np.array(element).reshape(1, -1)
    unit_weight = np.array(1).reshape(1, -1)
    singleton = Bucket(
        elements=point,
        weights=unit_weight,
        capacity=1,
        uncompressed_capacity=1,
    )
    buckets += [singleton]
    return merge(buckets, k)

def run_new_test(rbf_kernel=False, random_data=False, limit=64, apply_subsampling=True, sample_random=False, even_sized=False):
    """Feed two synthetic samples through `insert_no_cut` to build a bucket list.

    Args:
        rbf_kernel: Use scikit-learn's RBF kernel with ``gamma=1`` when true,
            otherwise a linear kernel.
        random_data: When true, draw ``limit`` scalars per sample from
            ``normal(0, 1)`` and ``normal(1, 2)``; otherwise use ``limit``
            copies of ``[1, 2, 3]`` and of ``[4, 5, 6]``.
        limit: Number of points generated per sample.  All ``limit`` points
            of X and the first ``limit - 1`` points of Y are inserted.
        apply_subsampling: Not read by the code shown here.
        sample_random: Not read by the code shown here.
        even_sized: Not read by the code shown here.
    """
    if random_data:
        X = np.random.normal(0, 1, limit)
        Y = np.random.normal(1, 2, limit)
    else:
        X = np.repeat([[1, 2, 3]], limit, axis=0)
        Y = np.repeat([[4, 5, 6]], limit, axis=0)

    if rbf_kernel:
        def k(a, b):
            return metrics.pairwise.rbf_kernel(a, b, gamma=1)
    else:
        # NOTE(review): the original called metrics.pairwise.distance_metrics(X, Y),
        # which accepts no arguments and raised TypeError on first use.  A
        # linear kernel is assumed to be the intended alternative -- confirm.
        def k(a, b):
            return metrics.pairwise.linear_kernel(a, b)

    bs = []
    # The loops were hard-coded to 128 and 127 iterations, which overran the
    # data for the default limit of 64.  Y keeps one insert fewer than X, as
    # in the original counts.
    for i in range(limit):
        bs = insert_no_cut(bs, X[i], k)
    for i in range(limit - 1):
        bs = insert_no_cut(bs, Y[i], k)
