In [1]:
import torch

import lightning as L
import numpy as np
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
import sys
import os
# Make the project root (the parent of the notebook's working directory) importable
# so that the `src.modules` imports below resolve. abspath normalises "<cwd>/.." to
# the real parent path, so the membership test also recognises an entry that is
# already on sys.path in normalised form instead of appending a duplicate alias.
base_path = os.path.abspath(os.path.join(os.getcwd(), ".."))

if base_path not in sys.path:
    sys.path.append(base_path)
    
from src.modules.trainer import IsingLightningModule
from src.modules.sampler import DiscreteJarzynskiIntegrator

In [4]:
# Pick the first GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(torch.cuda.is_available())
print(device)

def grab(x):
    """Convert a torch tensor to a detached NumPy array on the host.

    Anything that is not a tensor is passed through untouched.
    """
    if not torch.is_tensor(x):
        return x
    return x.detach().cpu().numpy()

def ess(At):
    """Normalised effective sample size of the importance weights exp(At).

    Computes mean(w)**2 / mean(w**2) with w = exp(At), averaged over all
    elements of At. For finite log-weights the value lies in (0, 1] and equals
    1 when all weights are equal.

    The ratio is unchanged by adding a constant to At. The maximum is therefore
    subtracted before exponentiating, which keeps torch.exp from overflowing
    (inf / inf -> nan) for large log-weights without changing the result.

    Parameters:
    - At: tensor of log-weights.

    Returns:
    - 0-dim tensor holding the ESS fraction.
    """
    w = torch.exp(At - At.max())
    return torch.mean(w)**2 / torch.mean(w**2)

True
cuda:0


## Load model

In [5]:
# Side length of the L x L spin lattices sampled below.
# NOTE: this rebinds `L`, which the first cell bound to the `lightning` module
# (`import lightning as L`); `lightning` is not referenced through `L` after this.
# NOTE(review): presumably this must match the lattice size the checkpoint was
# trained on — confirm.
L        = 15
ckpt = f"{base_path}/ckpts/ising_final.ckpt"
name = "ising_critical"  # tag used in the saved .npy filenames and the ESS report

# map_location places the checkpoint tensors directly on the target device.
# eval() disables training-only behaviour (e.g. dropout / batch-norm updates, if
# the network has any), because the model is only used for inference here.
model = IsingLightningModule.load_from_checkpoint(ckpt, map_location=device).to(device).eval()

Js: tensor(1.) tensor(0.)
mus: tensor(0.) tensor(0.)
Bs: tensor(0.) tensor(0.)
betas: tensor(0.4407) tensor(0.4407)


## Initialize simulator

In [6]:
# Time grid handed to the integrator: n_step uniform steps on [0, final_t].
final_t = 1.0
n_step = 100
ts = torch.linspace(0, final_t, n_step + 1)
# NOTE(review): eps holds the step count n_step rather than a step size —
# confirm this is what DiscreteJarzynskiIntegrator expects.
eps = torch.tensor(n_step).to(device)

# Sampling budget: n_batches rollouts of bs chains each. n_batches is kept small
# so the notebook runs quickly; raise it (e.g. to 50) for a lower-variance ESS
# estimate.
bs = 1000
n_batches = 5

jit = DiscreteJarzynskiIntegrator(
    model.Energy,
    eps,
    ts,
    Qt_net=model.net,
    transport=True,
    n_mcmc_per_net=0,  # no MCMC steps are used at inference time
    n_save=n_step,
    resample=False,
    resample_thres=0.7,
    compute_is_weights=False,
)

## Evaluate effective sample size (ESS)

In [7]:
sigmas_list = []  # final-time spin configurations, one tensor per batch
As_list = []      # matching final-time As values, fed to ess() below
for _ in tqdm(range(n_batches)):
    # Uniformly random ±1 initial spins for bs independent chains.
    sigma_vec = 2 * torch.randint(0, 2, size=(bs, L, L)).float().to(device) - 1
    sigmas, As = jit.rollout(sigma_vec)

    # Only the last saved time slice is needed here. Index it on the device
    # first so the rest of the saved trajectory is never copied to host memory.
    sigmas = sigmas[-1].detach().cpu()
    As = As[-1].detach().cpu()

    sigmas_list.append(sigmas)
    As_list.append(As)

100%|██████████| 5/5 [01:28<00:00, 17.75s/it]


In [8]:
# Stack the per-batch final-time results into single tensors.
sigmas = torch.cat(sigmas_list, dim=0)
As = torch.cat(As_list, dim=0)

# Save NumPy copies in the current working directory, tagged with `name`.
sigmas_numpy, As_numpy = grab(sigmas), grab(As)
np.save(f'sigmas_{name}.npy', sigmas_numpy)
np.save(f'As_{name}.npy', As_numpy)

# The ESS ratio is unchanged by shifting every entry of As by the same constant.
# Centring first reduces the risk of exp() overflowing inside ess().
ess_val = ess(As - As.mean())
print(f"ESS of {name}: {ess_val:.4f} (NOTE: THIS IS AN ESS ESTIMATE WITH FEW SAMPLES, NEED TO RUN ON MORE SAMPLES TO GET BETTER ESTIMATE)")

ESS of ising_critical: 0.6577 (NOTE: THIS IS AN ESS ESTIMATE WITH FEW SAMPLES, NEED TO RUN ON MORE SAMPLES TO GET BETTER ESTIMATE)


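## Simulate full trajectories to plot the evolution of samples

This draws one fresh batch of `bs` chains and keeps the whole trajectory returned by `jit.rollout`. It overwrites `sigmas` and `As` from the ESS evaluation above, so re-run that section if you need those values again.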

In [None]:
# Fresh uniformly random ±1 initial spins; this time the full trajectory
# returned by rollout is kept (its shape is shown in the next cell).
sigma_vec = torch.randint(0, 2, size=(bs, L, L)).float().to(device).mul(2).sub(1)
sigmas, As = jit.rollout(sigma_vec)

In [9]:
sigmas.size()  # expected (n_timesteps, batch_size, L, L), the layout animate_lattices consumes

torch.Size([101, 1000, 15, 15])

## Plot evolution of samples

In [11]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

def animate_lattices(
    phi,
    n_rows=4,
    n_cols=8,
    interval=200,
    start_pause=5.0,
    end_pause=10.0,
    save_path=None
):
    """
    Animate the evolution of lattices over time in a grid and optionally save
    it as a GIF, holding the first and last frames for a while.

    Parameters:
    - phi: numpy array of shape (n_timesteps, batch_size, L, L)
    - n_rows, n_cols: grid layout; the first min(n_rows * n_cols, batch_size)
      lattices are shown and any leftover grid cells are left blank
    - interval: ms between frames (both for playback and in the saved GIF)
    - start_pause, end_pause: seconds to hold the first/last frame
    - save_path: where to write the GIF (if provided)

    Returns:
    - the matplotlib FuncAnimation (its figure has already been closed)
    """
    n_timesteps, batch_size, _, _ = phi.shape
    n_plots = min(n_rows * n_cols, batch_size)

    # Number of extra frames needed to hold the first/last frame for the
    # requested durations, given `interval` ms per frame.
    hold_start = int(start_pause * 1000 / interval)
    hold_end = int(end_pause * 1000 / interval)

    # Frame sequence: [0,...,0, 0,1,2,...,T-1, T-1,...,T-1]
    frame_seq = (
        [0] * hold_start
        + list(range(n_timesteps))
        + [n_timesteps - 1] * hold_end
    )

    # squeeze=False always yields a 2-D array of Axes, so flatten() also works
    # for 1x1 and 1xN grids.
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(n_cols * 1.5, n_rows * 1.5),
                             squeeze=False)
    axes = axes.flatten()
    # Hide every cell, including unused ones when batch_size < n_rows * n_cols.
    for ax in axes:
        ax.axis('off')

    # Fixed, symmetric colour scale shared by all panels and frames; fall back
    # to 1 for an all-zero input so that vmin < vmax.
    v = float(np.max(np.abs(phi))) or 1.0
    ims = [
        ax.imshow(phi[0, idx], cmap='viridis', origin='lower', vmin=-v, vmax=v)
        for idx, ax in zip(range(n_plots), axes)
    ]

    def update(frame_idx):
        actual_t = frame_seq[frame_idx]
        for idx, im in enumerate(ims):
            im.set_array(phi[actual_t, idx])
        return ims

    anim = animation.FuncAnimation(
        fig,
        update,
        frames=len(frame_seq),
        interval=interval,
        blit=True
    )

    if save_path:
        # One GIF frame every `interval` ms, matching hold_start/hold_end above,
        # so the holds really last start_pause/end_pause seconds.
        fps = 1000 / interval
        anim.save(save_path, writer='pillow', fps=fps)

    plt.close(fig)
    return anim

In [12]:
save_path = "../figures/ising.gif"
# Writing the GIF fails if the target directory is missing, so create it first.
out_dir = os.path.dirname(save_path)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
# grab() detaches before converting, so this also works if sigmas tracks gradients.
animate_lattices(grab(sigmas), save_path=save_path);

<matplotlib.animation.FuncAnimation at 0x7fdec5a69430>