In [16]:
import torch

import lightning as L
import numpy as np
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
import sys
import os
# Parent of the current working directory. It is appended to sys.path (only if
# not already present) so the `src.modules...` imports below can be resolved.
base_path = f"{os.getcwd()}/.."
if base_path not in sys.path:
    sys.path.append(base_path)
    
from src.modules.trainer import IsingLightningModule
from src.modules.sampler import DiscreteJarzynskiIntegrator

In [17]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
cuda_ok = torch.cuda.is_available()
if cuda_ok:
    device = 'cuda:0'
else:
    device = 'cpu'
print(cuda_ok)
print(device)

def grab(x):
    """Convert a torch tensor to a NumPy array; return any other object unchanged.

    Tensors are detached from the autograd graph and copied to the CPU first.
    """
    if not torch.is_tensor(x):
        return x
    return x.detach().cpu().numpy()

def ess(At):
    """Normalised effective sample size (ESS / N) from log importance weights.

    Computes ``mean(exp(At))**2 / mean(exp(2*At))``, a value in (0, 1].
    The ratio is evaluated in log space with ``logsumexp`` so that large
    log-weights do not overflow ``exp``. As a result it is invariant to adding
    a constant to ``At``.

    Parameters
    ----------
    At : torch.Tensor
        Log importance weights. All elements are pooled (the tensor is flattened).
        Integer tensors are cast to the default floating dtype.

    Returns
    -------
    torch.Tensor
        0-dim tensor holding the ESS fraction.
    """
    log_w = At.reshape(-1)
    if not log_w.is_floating_point():
        log_w = log_w.to(torch.get_default_dtype())
    n = log_w.numel()
    # (sum w)^2 / (N * sum w^2)  ==  exp(2*lse(A) - lse(2A)) / N
    log_ratio = 2 * torch.logsumexp(log_w, dim=0) - torch.logsumexp(2 * log_w, dim=0)
    return torch.exp(log_ratio) / n

True
cuda:0


## Load model

In [18]:
# Experiment configuration.
# NOTE(review): this rebinds `L`, which the first cell imported as the `lightning`
# alias; a distinct name for the lattice size would avoid the shadowing.
L = 15                   # lattice side length
target = 'potts'         # NOTE(review): not referenced elsewhere in this notebook view
name = "potts_critical"  # tag used in the saved .npy filenames
ckpt = f"{base_path}/ckpts/potts_final.ckpt"

# IsingLightningModule also covers the Potts case.
model = IsingLightningModule.load_from_checkpoint(ckpt)
model = model.to(device)

Js: tensor(1.) tensor(0.)
Bs: tensor(0.) tensor(0.)
betas: tensor(1.0010) tensor(1.0010)
n_cat: 3
Added warm_up of:  1000


## Initialize simulator

In [19]:
n_step = 300
# End time of the schedule. It must be defined before the grid is built, because the
# "uniform" and "tanh" schedules depend on it.
final_t = 1.0
discretization = "exp_2"


def make_time_grid(discretization, n_step, final_t=1.0):
    """Build the time grid of ``n_step + 1`` points used by the integrator.

    Parameters
    ----------
    discretization : str
        One of ``"uniform"``, ``"tanh"``, ``"exp"``, ``"exp_2"``.
    n_step : int
        Number of integration steps (the grid has ``n_step + 1`` points).
    final_t : float
        End time for the "uniform" and "tanh" schedules.

    Returns
    -------
    torch.Tensor
        1-D tensor starting at 0.

    Raises
    ------
    NotImplementedError
        For an unknown ``discretization``.
    """
    if discretization == "uniform":
        return torch.linspace(0, final_t, n_step + 1)
    if discretization == "tanh":
        # NOTE(review): unlike the other schedules this one is not rescaled, so it
        # ends at tanh(pi * final_t) < 1 — confirm that is intended.
        return torch.tanh(torch.pi * torch.linspace(0, final_t, n_step + 1))
    if discretization == "exp":
        grid = 1 - torch.exp(-torch.linspace(0, 7, n_step + 1))
        return grid / grid[-1]
    if discretization == "exp_2":
        # Cubing concentrates more points near t = 0; rescale so the grid ends at 1.
        grid = (1 - torch.exp(-torch.linspace(0, 5, n_step + 1))) ** 3
        return grid / grid[-1]
    raise NotImplementedError(f"unknown discretization: {discretization!r}")


ts = make_time_grid(discretization, n_step, final_t)

# Sampling configuration.
L = 15          # lattice side length (same value as set when loading the model)
bs = 100        # chains per batch
n_batches = 5   # batches for the ESS estimate; the commented-out "00" suggests 500 for a full run
final_t = 1.0

# The step count as a tensor on `device`; passed positionally as the integrator's second argument.
eps = torch.tensor(n_step).to(device)

jit = DiscreteJarzynskiIntegrator(
    model.Energy,
    eps,
    ts,
    Qt_net=model.net,
    transport=True,
    # No MCMC during inference here (in a real use-case you probably would enable it).
    n_mcmc_per_net=0,
    n_save=n_step,
    resample=False,
    resample_thres=0.7,  # NOTE(review): presumably only relevant when resample=True — confirm
    compute_is_weights=True,
    q=3,
    model_class='potts',
)

## Evaluate effective sample size (ESS)

In [20]:
# Roll out `n_batches` independent batches. For each one, keep only the last saved step
# (states and log importance weights) and accumulate it on the CPU.
# NOTE(review): rollout runs with autograd enabled (results are .detach()-ed); if
# jit.rollout does not need gradients, wrapping this loop in torch.no_grad() would
# reduce device memory — confirm before changing.
sigmas_list = []
As_list = []
for _ in tqdm(range(n_batches)):
    # Uniform random initial configurations over the 3 Potts states (q=3 in the integrator).
    sigma_vec = torch.randint(0, 3, size=(bs, L, L)).to(device)
    traj_sigmas, traj_As = jit.rollout(sigma_vec)

    # Slice the last saved step *before* .cpu(), so only that slice is copied off the
    # device instead of the whole trajectory.
    sigmas_list.append(traj_sigmas[-1].detach().cpu())
    As_list.append(traj_As[-1].detach().cpu())
    # Release the full trajectories so they are not kept alive during the next rollout.
    del traj_sigmas, traj_As

sigmas = torch.cat(sigmas_list)
As = torch.cat(As_list)

sigmas_numpy = grab(sigmas)
As_numpy = grab(As)

np.save(f'sigmas_{name}_disc={discretization}_nsteps={n_step}.npy', sigmas_numpy)
np.save(f'As_{name}_disc={discretization}_nsteps={n_step}.npy', As_numpy)

# Centring the log-weights leaves the ESS unchanged (it is invariant to a constant
# shift) and guards against overflow in exp().
ess_val = ess(As - As.mean())
print(f"ESS of {name}: {ess_val:.4f} (NOTE: THIS IS AN ESS ESTIMATE WITH FEW SAMPLES, NEED TO RUN ON MORE SAMPLES TO GET BETTER ESTIMATE)")

100%|██████████| 5/5 [00:43<00:00,  8.66s/it]

ESS of potts_critical: 0.3075 (NOTE: THIS IS AN ESS ESTIMATE WITH FEW SAMPLES, NEED TO RUN ON MORE SAMPLES TO GET BETTER ESTIMATE)





## Simulate full trajectories to plot evolution of samples

In [5]:
# Run the sampler on a fresh batch of uniform random 3-state configurations and keep
# the full trajectories (all saved steps) for plotting.
sigma_vec = torch.randint(0, 3, size=(bs, L, L)).to(device)
sigmas, As = jit.rollout(sigma_vec)

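Resample trajectories using the importance-sampling (IS) weights. Draw `k` indices by multinomial sampling from the normalised final-step weights, then keep only the distinct indices for plotting.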

In [12]:
from torch.distributions import Categorical

# Draw k batch indices with replacement, in proportion to the final-step importance
# weights, then keep only the distinct ones.
k = 100
is_weights = torch.softmax(As[-1], dim=0)  # normalise final-step log-weights over the batch
dist = Categorical(probs=is_weights)
indices = torch.unique(dist.sample((k,)))  # distinct indices, returned in sorted order
# Keep every time step of each chosen sample. Dim 1 indexes the batch, matching the
# (n_timesteps, batch_size, L, L) layout that animate_lattices expects.
sigmas_select = sigmas[:, indices]

## Plot evolution of samples

In [21]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns
from matplotlib.colors import ListedColormap

def animate_lattices(
    phi,
    n_rows=4,
    n_cols=8,
    interval=200,
    start_pause=5.0,
    end_pause=10.0,
    save_path=None,
    palette=None,
    save_speedup=6.0,
):
    """
    Animate the evolution of Potts lattices over time in a grid,
    with an extra hold at the start and end.

    Parameters:
    - phi: numpy array of shape (n_timesteps, batch_size, L, L) with integer
           states in {0, ..., len(palette) - 1} (default palette: 3 states)
    - n_rows, n_cols: grid layout; the first min(n_rows * n_cols, batch_size)
           lattices are drawn and any remaining panels are left blank
    - interval: ms between frames of the returned animation
    - start_pause, end_pause: seconds (at `interval` ms per frame) to hold
           the first/last frame
    - save_path: where to write the GIF (if provided)
    - palette: one colour per state; defaults to ["#440154", "#FDE725", "#ff7f00"]
    - save_speedup: the GIF is written at `save_speedup` times the frame rate
           implied by `interval`, so it plays that much faster (holds included)

    Returns:
    - the matplotlib FuncAnimation (its figure has already been closed)
    """
    if palette is None:
        palette = ["#440154", "#FDE725", "#ff7f00"]
    n_states = len(palette)

    n_timesteps, batch_size, _, _ = phi.shape
    n_plots = min(n_rows * n_cols, batch_size)

    # Repeated frames for the start/end holds. Both use the same seconds -> frames
    # conversion; the end hold previously used 2000 instead of 1000, which doubled it.
    hold_start = int(start_pause * 1000 / interval)
    hold_end = int(end_pause * 1000 / interval)
    frame_seq = [0] * hold_start + list(range(n_timesteps)) + [n_timesteps - 1] * hold_end

    # squeeze=False always returns a 2-D array of axes, so 1x1 / 1xN grids work too.
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(n_cols * 1.5, n_rows * 1.5),
                             squeeze=False)
    axes = axes.ravel()

    # Discrete colour map: one colour per Potts state.
    cmap = ListedColormap(palette)

    # One image per plotted lattice, showing the first lattices of the batch.
    ims = []
    for idx, ax in enumerate(axes[:n_plots]):
        im = ax.imshow(
            phi[0, idx],
            cmap=cmap,
            origin='lower',
            vmin=0, vmax=n_states - 1,
            interpolation='nearest'
        )
        ax.axis('off')
        ims.append(im)
    # Hide panels that have no lattice (batch smaller than the grid).
    for ax in axes[n_plots:]:
        ax.axis('off')

    def update(frame_idx):
        t = frame_seq[frame_idx]
        for idx, im in enumerate(ims):
            im.set_data(phi[t, idx])
        return ims

    anim = animation.FuncAnimation(
        fig,
        update,
        frames=len(frame_seq),
        interval=interval,
        blit=True
    )

    # Close the figure even if writing the GIF fails, so it does not leak.
    try:
        if save_path:
            fps = save_speedup * (1000 / interval)
            anim.save(save_path, writer='pillow', fps=fps)
    finally:
        plt.close(fig)
    return anim


In [None]:
# Write the GIF next to the other figures, creating the output directory if needed
# (otherwise anim.save fails when it does not exist yet).
save_path = "../figures/potts.gif"
save_dir = os.path.dirname(save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
animate_lattices(sigmas_select.cpu().numpy(), save_path=save_path)