# In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from scipy import interpolate
import numpy as np

def _split_sorted_xy(xy_pairs):
    """Return ``(x_list, y_list)`` from ``[x, y]`` pairs, sorted by x.

    A new sorted list is built, so the caller's data is never reordered.
    """
    ordered = sorted(xy_pairs, key=lambda pair: pair[0])
    return [x for x, _ in ordered], [y for _, y in ordered]


def _draw_eye_guide(ax, x_data, y_data, smoothing, line_thickness):
    """Draw a dashed black cubic smoothing spline through x-sorted data on ``ax``.

    The spline is evaluated at 300 evenly spaced points between the first and
    last x value, so ``x_data`` must already be sorted ascending.
    """
    spline = interpolate.UnivariateSpline(x_data, y_data, k=3, s=smoothing)
    domain = np.linspace(x_data[0], x_data[-1], 300)
    ax.plot(domain, spline(domain), linestyle='--', color='black',
            linewidth=line_thickness, dashes=(3, 4))


def plot_shared_x_data(data_sets, axis_label_fontsize=30, line_thickness=0, tick_label_fontsize=26, tick_width=4,
                       x_axis_limits=None, y_axis_limits_list=None, x_tick_labels=None, y_tick_labels_list=None,
                       x_axis_label=r'2$\it{θ}$ (degrees)', y_axis_label_list=None, plot_title=None, point_size=0,
                       file_name=None, eye_guide=True):
    """
    Plot up to three data sets as vertically stacked subplots sharing one x-axis.

    The first three entries of ``data_sets`` are drawn top to bottom, one per
    subplot. If a fourth entry is present it is overlaid on the bottom subplot
    in its own colour. Each set is sorted by x before plotting. Optionally, a
    dashed black cubic smoothing spline (``scipy.interpolate.UnivariateSpline``)
    is drawn through each set as a guide to the eye.

    Side effect: matplotlib's global ``rcParams`` (font family, weight, size)
    are modified, which also affects figures created afterwards.

    Args:
        data_sets (list): Data sets, each a list of ``[x, y]`` pairs, e.g.
            ``[ [[x, y], [x, y], ...], [[x, y], ...] ]``. Only the first four
            entries are used. The input lists are not modified.
        axis_label_fontsize (int): Font size of the axis labels (the title
            uses this value + 2).
        line_thickness (float): Line width of the dashed guide curves; the
            default 0 makes them invisible.
        tick_label_fontsize (int): Global font size (used for tick labels).
        tick_width (float): Width of the ticks and of the axes frame (spines).
        x_axis_limits (tuple, optional): ``(xmin, xmax)`` for the shared x-axis.
        y_axis_limits_list (list, optional): Per-subplot ``(ymin, ymax)`` or None.
        x_tick_labels (sequence, optional): Tick positions for the shared x-axis.
        y_tick_labels_list (list, optional): Per-subplot y tick positions or None.
        x_axis_label (str): Label under the bottom subplot (mathtext allowed).
        y_axis_label_list (list of str, optional): Per-subplot y-axis labels;
            extra entries beyond the three subplots are ignored.
        plot_title (str, optional): Figure title.
        point_size (float): Marker size for the data points. For the first
            three sets, no markers are drawn when this is 0.
        file_name (str, optional): If given, the figure is saved there (300 dpi,
            transparent background); otherwise ``plt.show()`` is called.
        eye_guide (bool): Whether to draw the spline guide curves (applies to
            every plotted set, including the overlaid fourth one).

    Note:
        With ``eye_guide=True``, scipy rejects a data set with too few points
        for a cubic spline. See the ``UnivariateSpline`` docs for its
        requirements on ``x``.
    """
    n_panels = 3

    # Marker colours (RGB in 0..1) per data set; index 3 is the overlay set.
    colors = [(0/255, 190/255, 150/255), (110/255, 136/255, 194/255),
              (186/255, 60/255, 145/255), (186/255, 60/255, 145/255)]

    # Smoothing factor ``s`` of each data set's guide spline.
    s_values = [0.001, 0.04, 0.001, 0.001]

    # Set font and weight globally
    plt.rcParams['font.family'] = 'Arial'
    plt.rcParams['font.weight'] = 'bold'
    plt.rcParams['font.size'] = tick_label_fontsize

    # Per-panel options default to "nothing special" for every panel.
    if y_axis_label_list is None:
        y_axis_label_list = [''] * n_panels
    if y_axis_limits_list is None:
        y_axis_limits_list = [None] * n_panels
    if y_tick_labels_list is None:
        y_tick_labels_list = [None] * n_panels

    # With more than one row, subplots() always returns an array of Axes,
    # so it is directly iterable regardless of how many data sets there are.
    fig, axes = plt.subplots(n_panels, 1, figsize=(8, 10), sharex=True, dpi=300)

    panels = zip(axes, data_sets[:n_panels], y_axis_label_list,
                 y_axis_limits_list, y_tick_labels_list)
    for i, (ax, xy_data, y_axis_label, y_axis_limits, y_tick_labels) in enumerate(panels):
        # Sorting by x gives the spline increasing x and makes x[0]/x[-1] the data range.
        x_data, y_data = _split_sorted_xy(xy_data)

        if eye_guide:
            _draw_eye_guide(ax, x_data, y_data, s_values[i], line_thickness)

        if point_size > 0:
            ax.plot(x_data, y_data, linewidth=0, marker='o', markersize=point_size, color=colors[i])
        else:
            # Nothing visible is drawn, but the data still takes part in autoscaling.
            ax.plot(x_data, y_data, linewidth=0)

        # An optional fourth data set is overlaid on the bottom panel.
        if i == n_panels - 1 and len(data_sets) > n_panels:
            x_extra, y_extra = _split_sorted_xy(data_sets[n_panels])
            if eye_guide:
                _draw_eye_guide(ax, x_extra, y_extra, s_values[n_panels], line_thickness)
            ax.plot(x_extra, y_extra, linewidth=0, marker='o', markersize=point_size,
                    color=colors[n_panels])

        # Set axis limits / ticks (explicit None checks so numpy arrays work too)
        if x_axis_limits is not None:
            ax.set_xlim(x_axis_limits)
        if y_axis_limits is not None:
            ax.set_ylim(y_axis_limits)
        if y_tick_labels is not None:
            ax.set_yticks(y_tick_labels)

        # Customize ticks and spines thickness
        for spine in ax.spines.values():
            spine.set_linewidth(tick_width)
        ax.tick_params(which='major', width=tick_width, length=10)
        ax.tick_params(which='minor', width=tick_width, length=5, direction='out')

        ax.set_ylabel(y_axis_label, fontsize=axis_label_fontsize, weight='bold')

    # The x-axis is shared, so ticks set on the bottom panel apply to all panels.
    if x_tick_labels is not None:
        axes[-1].set_xticks(x_tick_labels)

    # Set the x-axis label for the bottom subplot
    axes[-1].set_xlabel(x_axis_label, fontsize=axis_label_fontsize, weight='bold')

    if plot_title:
        fig.suptitle(plot_title, fontsize=axis_label_fontsize + 2, weight='bold')

    # Remove vertical space so the panels touch
    plt.subplots_adjust(hspace=0)

    # Save or show the plot
    if file_name:
        plt.savefig(file_name, dpi=300, transparent=True)
    else:
        plt.show()

















# [x, y] pairs for the middle panel of the figure (plotted with y label 'Tc (K)').
# The long decimal expansions suggest values read off a published plot — TODO confirm source.
superconductivity = [[0.6249999999999982, 1.4590163934426226], [2.8124999999999982, 2.6229508196721314], [5.520833333333332, 3.40983606557377],
                     [7.1875, 3.4590163934426226], [8.645833333333332, 3.2950819672131146], [10.520833333333332, 3.524590163934426], [12.916666666666664, 3.6721311475409832],
                     [13.229166666666664, 3.557377049180328], [16.25, 3.40983606557377],[20.104166666666664, 3], [25.104166666666664, 2.6557377049180326],
                     [31.145833333333336, 2.3442622950819665]]

# Full list of [x, c/a] pairs; each ratio is computed inline as c / a
# (presumably the two lattice parameters — TODO confirm).
# Not used by the plot call below, which uses c_over_a_ratio_1 / c_over_a_ratio_2.
# x values 10.5, 12.4, 13.3 and 17 occur twice with different ratios, and the list
# is not sorted by x (it rises to 39, then goes back down).
c_over_a_ratio = [[0.16, 4.779002/5.006336], [3.5, 4.50705/4.967145], [4.5, 4.463864/4.956719], [6.5, 4.38991/4.940804], [7.5, 4.339534/4.933373], [9, 4.303594/4.923028], [10.5, 4.254205/4.907403],
                  [12.4, 4.215364/4.892107], [13.3, 4.20367/4.893026], [14.5, 3.34415/5.40997],[15.5, 3.342671/5.395167], [17, 3.341669/5.373317], [18, 3.343599/5.333982], 
                  [21, 3.346762/5.29187], [24, 3.352346/5.25208], [26.5, 3.355109/5.214148], [32, 3.359539/5.134447], [34, 3.355769/5.100621], [39, 3.349792/5.059476], 
                  [35.5, 3.358743/5.07839], [31, 3.371147/5.118287], [22, 3.393147/5.196405], [17, 3.424088/5.312746], [13, 3.439279/5.422773], [10.5, 3.361292/5.495634],
                  [12.4, 3.350413/5.466113], [13.3, 3.346409/5.44492]]

# The first nine entries of c_over_a_ratio (x <= 13.3, c/a roughly 0.86-0.95).
# Plotted as the bottom panel of the figure.
c_over_a_ratio_1 = [[0.16, 4.779002/5.006336], [3.5, 4.50705/4.967145], [4.5, 4.463864/4.956719], [6.5, 4.38991/4.940804], [7.5, 4.339534/4.933373], [9, 4.303594/4.923028], [10.5, 4.254205/4.907403],
                  [12.4, 4.215364/4.892107], [13.3, 4.20367/4.893026]]

# The entries of c_over_a_ratio for x = 14.5 .. 39, plus its last three entries
# (x = 10.5, 12.4, 13.3). The 35.5, 31, 22, 17 and 13 entries of c_over_a_ratio are
# left out — NOTE(review): confirm that omission is intended.
# Overlaid on the bottom panel as the fourth data set.
c_over_a_ratio_2 = [[14.5, 3.34415/5.40997],[15.5, 3.342671/5.395167], [17, 3.341669/5.373317], [18, 3.343599/5.333982], 
                  [21, 3.346762/5.29187], [24, 3.352346/5.25208], [26.5, 3.355109/5.214148], [32, 3.359539/5.134447], [34, 3.355769/5.100621], [39, 3.349792/5.059476], 
                  [10.5, 3.361292/5.495634],
                  [12.4, 3.350413/5.466113], [13.3, 3.346409/5.44492]]

# Earlier version of site_order, kept as an inert string literal (never executed).
# It additionally held entries at x = 35.5, 31, 22, 17 and 13 (values ~0.5).
"""site_order = [[0.16, 1], [3.5, 1], [4.5, 1], [6.5, 1], [7.5, 1], [9, 1], [10.5, 0.9998], [12.4, 0.95011], [13.3, 0.90684], [14.5, 0.85884], [15.5, 0.85721], [17, 0.85363], [18, 0.82975], [21, 0.8067],
              [24, 0.78419], [26.5, 0.77282], [32, 0.70988], [34, 0.67497], [39, 0.50039], [35.5, 0.5], [31, 0.5], [22, 0.5], [17, 0.50098], [13, 0.5]]"""
# Active [x, site order] pairs (values from 1 down to ~0.5); plotted as the top panel.
site_order = [[0.16, 1], [3.5, 1], [4.5, 1], [6.5, 1], [7.5, 1], [9, 1], [10.5, 0.9998], [12.4, 0.95011], [13.3, 0.90684], [14.5, 0.85884], [15.5, 0.85721], [17, 0.85363], [18, 0.82975], [21, 0.8067],
              [24, 0.78419], [26.5, 0.77282], [32, 0.70988], [34, 0.67497], [39, 0.50039]]


# [x, y] pairs named "from_lit" — presumably values digitized from the literature; TODO confirm source.
# Not referenced by the plot call in this cell. y drops from ~0.90 to ~0.60 between
# x ~ 6.1 and x ~ 6.3; the list is not strictly sorted by x (e.g. 16.64 follows 17.00).
from_lit = [[-0.07832898172323866, 0.9542944785276073],
            [0.6527415143603115, 0.9432515337423312], [1.3315926892950376, 0.9322085889570552], [1.8015665796344642, 0.9233128834355828], [2.4804177545691886, 0.9153374233128834],
            [3.2637075718015627, 0.90920245398773], [5.456919060052217, 0.9049079754601227], [6.135770234986943, 0.9003067484662577],
            [6.343154246100559, 0.601497005988024], [7.642980935875274, 0.6044910179640719], [9.618717504332826, 0.6083832335329342], [9.930675909878753, 0.6119760479041916],
            [11.438474870017417, 0.6074850299401197], [12.738301559792125, 0.61556886227544915], [13.7261698440209, 0.6164670658682635], [14.818024263431651, 0.6173652694610779],
            [17.001733102253155, 0.62395209580838326], [16.63778162911624, 0.6452095808383234], [18.093587521663913, 0.6335329341317366], [20.3812824956674, 0.6392215568862275],
            [21.733102253033096, 0.6592814371257485], [22.772963604852855, 0.6494011976047905], [25.112651646447326, 0.6577844311377245], [27.19237435008685, 0.6613772455089821],
            [29.220103986135392, 0.6679640718562874], [30.519930675910096, 0.668562874251497], [30.46793760831912, 0.6634730538922156], [32.13171577123074, 0.6688622754491018],
            [34.62738301559817, 0.6694610778443113], [35.09532062391707, 0.6706586826347305], [38.16291161178538, 0.6721556886227545]]

# Alternative panel selection, kept for reference:
#data_sets = [c_over_a_ratio, superconductivity, site_order]

# Panels top to bottom: site order, Tc, c/a. The fourth entry (c_over_a_ratio_2)
# is overlaid on the bottom c/a panel by plot_shared_x_data.
data_sets = [site_order, superconductivity, c_over_a_ratio_1, c_over_a_ratio_2]

# c/a is a ratio of two lengths and therefore dimensionless (no Å unit).
# Only the first three labels are used — one per panel.
y_axis_label_list = ['Site Order', 'Tc (K)', 'c/a', 'c/a']

# NOTE(review): x_axis_label is left at the function default (2θ in degrees);
# confirm that this describes the x values of these data sets.
plot_shared_x_data(data_sets, y_axis_label_list=y_axis_label_list, point_size=15, line_thickness=5,
                   y_tick_labels_list=[[0.6, 0.8, 1.0], [1.5, 2.5, 3.5], [0.6, 0.7, 0.8, 0.9]],
                   file_name='superconductivity.pdf')