# Ensemble Generation with STARLING

This notebook shows how to generate structural ensembles with STARLING and how to visualize them with Matplotlib.

**Before starting**, install STARLING locally. See https://github.com/idptools/starling/ for installation instructions and more information.

***Note***: STARLING can only generate ensembles for sequences of up to 380 residues!

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (ensures 3D projection is registered)
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from IPython.display import HTML

import starling

In [None]:
# Full-length Homo sapiens p53 (UniProt P04637); we analyze its N-terminal IDR.
p53 = 'MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD'

# The IDR boundary (the first 103 residues) was defined with metapredict V3.
P53_IDR_LENGTH = 103
p53_idr = p53[:P53_IDR_LENGTH]

# starling.generate() returns a dictionary of ensemble objects; the ensemble for
# our single input sequence is read from the key 'sequence_1'.
ensembles_by_name = starling.generate(p53_idr)
p53_ensemble = ensembles_by_name['sequence_1']

# Pull the 3D conformations out of the ensemble as a trajectory object.
p53_trajectory = p53_ensemble.trajectory

## Visualizing the p53 IDR structural ensemble

Next, we visualize the ensemble as a 3D animation. Dedicated molecular viewers can do this better, but to avoid additional dependencies we use only Matplotlib.

In [None]:
# Animate at most this many conformations (slicing keeps all frames if there are fewer).
num_frames = 100

# mdtraj stores Trajectory.xyz in nanometres; convert to Ångström so the data
# match the "(Å)" axis labels used for this plot.
# NOTE(review): assumes STARLING's trajectory follows the mdtraj nm convention — confirm.
NM_TO_ANGSTROM = 10.0
raw_data = p53_trajectory.traj.xyz[:num_frames]  # (frames, atoms, 3) in nm
data = raw_data * NM_TO_ANGSTROM                 # same array in Å
n_frames, n_points, _ = data.shape

# Bounds over *all* frames so the view box stays fixed for the whole animation.
mins = data.min(axis=(0, 1))
maxs = data.max(axis=(0, 1))
center = (mins + maxs) / 2.0
# Cube half-width: half the largest extent on any axis, plus a 10% margin.
half_range = ((maxs - mins).max() / 2.0) * 1.1

# One Z normalisation shared by every frame so colours are comparable across frames.
# Headroom above the highest Z keeps the topmost segments away from the near-white
# end of Blues_r (50 Å here, i.e. 5 nm).
Z_COLOR_HEADROOM = 50.0
global_z_min = data[:, :, 2].min()
global_z_max = data[:, :, 2].max() + Z_COLOR_HEADROOM
z_norm = plt.Normalize(global_z_min, global_z_max)
cmap = plt.cm.Blues_r

fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection='3d')

def build_segments(coords):
    """Convert an ordered chain of points into consecutive line segments.

    Args:
        coords: array of points in chain order, expected shape (N, 3).

    Returns:
        Array of shape (N - 1, 2, 3) where entry i joins point i to point i + 1.
        With a single point, a single zero-length segment at that point is
        returned (shape (1, 2, 3)) so the caller never receives an empty array.
    """
    if coords.shape[0] >= 2:
        starts = coords[:-1]
        ends = coords[1:]
        return np.stack([starts, ends], axis=1)
    # Degenerate chain: duplicate the lone point into a zero-length segment.
    lone = coords[0]
    return np.array([[[lone[0], lone[1], lone[2]], [lone[0], lone[1], lone[2]]]])

# Seed the collection with frame 1's segments so it is never empty when the
# animation starts (the init hook leaves it as-is rather than resetting it).
initial_segments = build_segments(data[0])
lc = Line3DCollection(initial_segments, linewidth=2, cmap=cmap)
ax.add_collection3d(lc)

# Clean 3D box: no interior gridlines, transparent panes, black pane edges.
ax.grid(False)
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    # Also zero the tick-aligned grid line width. `_axinfo` is private
    # Matplotlib API and may change between versions.
    axis._axinfo['grid']['linewidth'] = 0
    axis.pane.set_facecolor((1, 1, 1, 0))   # fully transparent pane face
    axis.pane.set_edgecolor('black')        # keep the pane border visible


def set_static_limits():
    """Fix a cubic view box around the ensemble and label the axes.

    Uses the module-level ``center`` and ``half_range`` so every axis spans the
    same width, then labels X, Y and Z on the module-level ``ax``.
    """
    lower = center - half_range
    upper = center + half_range
    for dim, set_lim in enumerate((ax.set_xlim, ax.set_ylim, ax.set_zlim)):
        set_lim(lower[dim], upper[dim])
    axis_labels = (
        (ax.set_xlabel, "X (Å)"),
        (ax.set_ylabel, "Y (Å)"),
        (ax.set_zlabel, "Z (Å)"),
    )
    for set_label, text in axis_labels:
        set_label(text)

set_static_limits()
ax.set_autoscale_on(False)  # keep the fixed cube; later artists must not trigger a rescale

# Colour frame 1's segments by their mean Z so the first drawn frame matches update().
seg_z0 = initial_segments.mean(axis=1)[:, 2]
lc.set_color(cmap(z_norm(seg_z0)))

# Optional colorbar for the Z colour scale; off by default to keep the 3D view uncluttered.
SHOW_COLORBAR = False
sm = plt.cm.ScalarMappable(norm=z_norm, cmap=cmap)
sm.set_array([])
if SHOW_COLORBAR:
    cbar = fig.colorbar(sm, ax=ax, pad=0.1)
    cbar.set_label("Z (Å)")

def update(frame_index):
    """Redraw the chain for animation frame ``frame_index`` (0-based).

    Replaces the collection's segments with those built from
    ``data[frame_index]``, recolours each segment by the Z of its midpoint via
    the shared ``cmap``/``z_norm``, and titles the axes with a 1-based counter.

    Returns:
        A one-element tuple holding the updated collection, as FuncAnimation
        expects from its frame callback.
    """
    frame_segments = build_segments(data[frame_index])
    midpoint_z = frame_segments.mean(axis=1)[:, 2]
    lc.set_segments(frame_segments)
    lc.set_color(cmap(z_norm(midpoint_z)))
    ax.set_title(f"Frame {frame_index + 1}/{n_frames}")
    return (lc,)

# FuncAnimation init hook: hand back the already-seeded collection unchanged
# instead of clearing it to an empty state.
def init():
    """Return the animated artists without resetting them."""
    artists = (lc,)
    return artists

ani = animation.FuncAnimation(
    fig,
    update,
    init_func=init,
    frames=n_frames,
    interval=100,   # milliseconds between frames (10 fps)
    blit=False,     # redraw the whole figure each frame
    repeat=True,
)

# Corners of the cubic view box (same box as set by set_static_limits).
x0, y0, z0 = center - half_range
x1, y1, z1 = center + half_range
# Draw the vertical box edge at (min x, min y) so the outline reads as a closed box.
ax.plot([x0, x0], [y0, y0], [z0, z1], color='black', linewidth=1)

# Close the figure so the inline backend does not also render a duplicate static
# copy beneath the animation; the animation still renders from the figure object.
plt.close(fig)
HTML(ani.to_jshtml())

# OPTIONAL saves:
# ani.save("p53_idr_depth_gradient.gif", writer="pillow", fps=10)
# ani.save("p53_idr_depth_gradient.mp4", writer="ffmpeg", fps=10)