In [None]:
"""
Plot AlphaFold3 PAE matrix in AF3 style.
Input: JSON file containing key 'pae'.
"""

import json
import sys
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

def _extract_pae(data):
    """Return the square PAE matrix held in a decoded JSON document.

    The matrix is looked up under ``"pae"`` first and then under
    ``"predicted_aligned_error"``. A single-element list wrapped around the
    record is unwrapped first (NOTE(review): believed to be the AlphaFold DB
    layout; confirm against the files you actually feed in).

    Raises:
        TypeError: if the document is not a JSON object.
        KeyError: if neither key is present.
        ValueError: if the stored value is not a non-empty square 2-D matrix.
    """
    if isinstance(data, list) and len(data) == 1:
        data = data[0]
    if not isinstance(data, dict):
        raise TypeError(
            f"expected a JSON object with a PAE matrix, got {type(data).__name__}"
        )

    for key in ("pae", "predicted_aligned_error"):
        if key in data:
            # dtype=float makes ragged input fail here with ValueError
            # instead of producing an object array that breaks imshow later.
            pae = np.asarray(data[key], dtype=float)
            break
    else:
        raise KeyError(
            "no 'pae' or 'predicted_aligned_error' key found in JSON document"
        )

    if pae.ndim != 2 or pae.shape[0] != pae.shape[1] or pae.shape[0] == 0:
        raise ValueError(
            f"PAE matrix must be non-empty and square, got shape {pae.shape}"
        )
    return pae


def plot_pae(json_file, out_file="pae_plot.png"):
    """Plot a predicted-aligned-error (PAE) matrix as a purple heatmap.

    The JSON file is read and validated before any plotting happens. The
    figure is then written to ``out_file`` at 300 dpi, shown with
    ``plt.show()`` and closed.

    Args:
        json_file: Path to a JSON file holding the PAE matrix under the key
            ``"pae"`` or ``"predicted_aligned_error"``.
        out_file: Path of the image file to write.

    Raises:
        TypeError, KeyError, ValueError: if the JSON document does not
            contain a usable square PAE matrix (see ``_extract_pae``).
    """
    with open(json_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    pae = _extract_pae(data)
    n = pae.shape[0]

    # Two-stop colormap: dark purple = low error, light purple = high error.
    custom_cmap = LinearSegmentedColormap.from_list(
        "custom_purple", ["#41005b", "#e2a1fc"]
    )

    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(
        pae,
        cmap=custom_cmap,
        origin="upper",
        vmin=0,
        vmax=30,  # values above 30 Å saturate at the lightest colour
    )

    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label("Predicted aligned error (Å)")

    # Roughly ten ticks per axis, labelled with 1-based residue numbers.
    step = max(1, n // 10)
    tick_positions = list(range(0, n, step))
    tick_labels = [str(pos + 1) for pos in tick_positions]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_labels)

    # imshow draws rows along y and columns along x; in the PAE convention
    # rows are the aligned residue and columns the scored residue.
    ax.set_xlabel("Scored residue")
    ax.set_ylabel("Aligned residue")
    ax.set_title("Predicted aligned error (PAE)")

    fig.tight_layout()
    fig.savefig(out_file, dpi=300)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


if __name__ == "__main__":
    # Script entry point: render the PAE heatmap for one hard-coded prediction.
    # Both paths are absolute and machine-specific; adjust before running elsewhere.
    infile = "/home/markus/MPI_local/data/plotting/o00629_1-521_q14814_1-521/o00629_1-521_q14814_1-521/o00629_1-521_q14814_1-521_confidences.json"
    outfile = "/home/markus/Desktop/Thesis/o00629_1-521_q14814_1-521_pae_plot.png"
    plot_pae(json_file=infile, out_file=outfile)