In [1]:
import numpy as np
from scipy.spatial import distance
import random
import pandas as pd
pd.options.mode.chained_assignment = None  # default='warn'
from sklearn.metrics import silhouette_score
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import pdist
import time
import json
from scipy.cluster import hierarchy
from scipy.optimize import linear_sum_assignment


In [2]:
def calculate_OT_cost(p, q, reg, cost_matrix, num_iterations, stop_theshold):
    """Compute the entropy-regularized optimal transport plan via Sinkhorn iterations.

    Despite its name, this function returns the transport *plan*; callers obtain
    the scalar cost as ``(plan * cost_matrix).sum()``.

    Parameters
    ----------
    p : array-like, shape (n,)
        Source marginal weights.
    q : array-like, shape (m,)
        Target marginal weights.
    reg : float
        Entropic regularization strength; the Gibbs kernel is
        ``exp(-cost_matrix / reg)``.
    cost_matrix : array-like, shape (n, m)
        Pairwise ground cost between source and target points.
    num_iterations : int
        Maximum number of Sinkhorn updates of the scaling vector ``v``.
    stop_theshold : float
        Stop early once the Euclidean norm of the change in ``v`` between two
        consecutive iterations drops below this value. (The misspelled name is
        kept for backward compatibility with keyword callers.)

    Returns
    -------
    numpy.ndarray, shape (n, m)
        Transport plan ``diag(u) @ K @ diag(v)``.
    """
    p_col = np.asarray(p, dtype=float).reshape(-1, 1)
    q_col = np.asarray(q, dtype=float).reshape(-1, 1)
    kernel = np.exp(-np.asarray(cost_matrix, dtype=float) / reg)

    v = np.ones((kernel.shape[1], 1))
    for _ in range(num_iterations):
        v_next = q_col / (kernel.T @ (p_col / (kernel @ v)))
        converged = np.linalg.norm(v_next - v) < stop_theshold
        v = v_next
        if converged:
            break

    u = p_col / (kernel @ v)
    # Broadcasting is equivalent to diag(u) @ K @ diag(v) but costs O(n*m)
    # instead of building dense diagonal matrices and doing O(n^3) matmuls,
    # which otherwise dominated the timings for large sample sizes.
    return u * kernel * v.T


def fill_ot_distance(df, num_of_iterations, lambda_pen, stop_theshold):
    """Fill ``df`` in place with pairwise regularized OT costs between its rows.

    For every pair of rows ``j <= i``, the Euclidean cost matrix between the
    point sets ``df['data points'][i]`` and ``df['data points'][j]`` is built
    with ``scipy.spatial.distance.cdist``. A Sinkhorn plan is then computed
    with the weights from the ``'p'`` column. The cost ``sum(plan * cost)`` is
    stored at row ``j`` of column ``str(i)``. Only this half of the matrix
    (diagonal included) is filled.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have ``'data points'`` and ``'p'`` columns. Assumes the index
        labels are the integers ``0..len(df)-1`` — TODO confirm against the
        caller that builds ``df``.
    num_of_iterations : int
        Maximum Sinkhorn iterations.
    lambda_pen : float
        Entropic regularization strength passed to ``calculate_OT_cost``.
    stop_theshold : float
        Sinkhorn convergence tolerance passed to ``calculate_OT_cost``.
    """
    for i in range(len(df)):
        points_i = df['data points'][i]
        weights_i = df['p'][i]
        for j in range(i + 1):
            cost_matrix = distance.cdist(points_i, df['data points'][j])
            ot_plan = calculate_OT_cost(weights_i, df['p'][j], lambda_pen,
                                        cost_matrix, num_of_iterations, stop_theshold)
            # Assign through .loc: the chained form ``df[col][row] = value``
            # writes to a temporary under pandas Copy-on-Write and leaves
            # ``df`` unchanged (and its warning is silenced at the top of
            # the notebook).
            df.loc[j, str(i)] = np.multiply(ot_plan, cost_matrix).sum()
    
def wasserstein_distance(samples_P, samples_Q):
    """Exact 1-Wasserstein distance between two equal-size 1-D empirical samples.

    The cost matrix holds the absolute differences between all sample pairs.
    The assignment problem on it is solved with
    ``scipy.optimize.linear_sum_assignment``, and the result is the mean
    cost of the optimal matching.

    Parameters
    ----------
    samples_P, samples_Q : array-like, shape (n,)
        Samples from the two distributions, each with uniform weight ``1/n``.

    Returns
    -------
    float
        Mean absolute displacement under the optimal one-to-one matching.

    Raises
    ------
    ValueError
        If the samples are empty or have different lengths (a one-to-one
        matching, and hence this estimator, is then undefined).
    """
    samples_P = np.asarray(samples_P, dtype=float)
    samples_Q = np.asarray(samples_Q, dtype=float)
    if len(samples_P) == 0:
        raise ValueError("samples must be non-empty")
    if len(samples_P) != len(samples_Q):
        raise ValueError(
            f"samples must have equal length, got {len(samples_P)} and {len(samples_Q)}"
        )
    pairwise_distances = np.abs(samples_P[:, None] - samples_Q[None, :])
    row_ind, col_ind = linear_sum_assignment(pairwise_distances)
    optimal_transport_cost = pairwise_distances[row_ind, col_ind].sum()
    return optimal_transport_cost / len(samples_P)

# A verbatim second copy of `wasserstein_distance` used to live here; it only
# shadowed the identical definition above, so it was removed.

# A placeholder `calculate_OT_cost` returning `np.sum(pairwise_distances)` was
# removed from here. It shadowed the Sinkhorn implementation at the top of this
# cell, so the "regularized" benchmark timed a plain sum, and
# `fill_ot_distance` received a scalar instead of a transport plan.



In [3]:
# Benchmark: wall-clock time of exact OT (Hungarian assignment) vs
# entropy-regularized OT (Sinkhorn) between two independent 1-D Gaussian
# samples, for increasing sample sizes.
num_samples_list = [10,100,200,300,600,900,1200,1600,2000,2400,3000]
num_micro_reps = 10       # timing repetitions per sample size
seed = 42
sinkhorn_reg = 0.5        # entropic regularization strength
sinkhorn_max_iter = 200   # cap on Sinkhorn iterations
sinkhorn_tol = 10**-4     # Sinkhorn convergence tolerance

# Seeded generator so the sampled data (and hence the workload) is
# reproducible; previously `seed` was defined but never applied.
rng = np.random.default_rng(seed)

times_list_reg_total = []
times_list_unreg_total = []

for num_samples in num_samples_list:
    print(f"Number of samples: {num_samples}")

    # Uniform marginals are setup, not part of the OT solve being timed.
    uniform_weights = np.full(num_samples, 1.0 / num_samples)
    reg_times = []
    unreg_times = []

    for _ in range(num_micro_reps):
        # Independent samples (no common random numbers).
        samples_P_no_crn = rng.normal(0, 1, num_samples)
        samples_Q_no_crn = rng.normal(5, 1, num_samples)

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which matters for the sub-millisecond small-n runs.
        start_time = time.perf_counter()
        wasserstein_distance(samples_P_no_crn, samples_Q_no_crn)
        unreg_times.append(time.perf_counter() - start_time)

        # Cost-matrix construction is timed here, mirroring the exact solver,
        # which also builds it internally.
        start_time = time.perf_counter()
        pairwise_distances = np.abs(samples_P_no_crn[:, None] - samples_Q_no_crn[None, :])
        calculate_OT_cost(uniform_weights, uniform_weights, sinkhorn_reg,
                          pairwise_distances, sinkhorn_max_iter, sinkhorn_tol)
        reg_times.append(time.perf_counter() - start_time)

    times_list_reg_total.append(reg_times)
    times_list_unreg_total.append(unreg_times)

    
    
    


Number of samples: 10
Number of samples: 100
Number of samples: 200
Number of samples: 300
Number of samples: 600
Number of samples: 900
Number of samples: 1200
Number of samples: 1600
Number of samples: 2000
Number of samples: 2400
Number of samples: 3000


In [6]:
def write_list_to_file(data_list, filename):
    """Persist ``data_list`` to ``filename`` as its Python ``str()`` text.

    Any existing file at ``filename`` is truncated. The content written is
    exactly ``str(data_list)`` with no trailing newline.

    Parameters:
    data_list (list of lists): Nested list to serialize.
    filename (str): Destination path.
    """
    serialized = str(data_list)
    with open(filename, mode='w') as out_file:
        out_file.write(serialized)
# Persist the per-size timing lists (regularized first, then exact).
write_list_to_file(data_list=times_list_reg_total, filename="../../data/regularized.txt")
write_list_to_file(data_list=times_list_unreg_total, filename="../../data/unregularized.txt")