# Repaint: A Tutorial

This notebook introduces inpainting for a simple problem in 2 dimensions. We will use an "analytical model" that requires no
training. This analytical model is not appropriate for a general system: it is only useful for demonstration purposes.
Since there is no training, no actual dataset is needed: the analytical model relies on the assumption that the dataset is normally distributed. Thus, in what follows, we will speak of the "effective" dataset even though it is never instantiated.

In [None]:
# Define the directory where artifacts will be written. Delete it if it exists to start clean.
import shutil
from diffusion_for_multi_scale_molecular_dynamics import TOP_DIR

output_path = TOP_DIR / "tutorials" / "output" / "tutorial_repaint"
shutil.rmtree(output_path, ignore_errors=True)

output_path.mkdir(parents=True, exist_ok=True)

# The effective dataset

We take the points of a regular 2D grid as the equilibrium positions. In the effective dataset, each atom is distributed according to
an isotropic Gaussian of width sigma_d centered on its grid point.
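For illustration only, here is how the effective dataset could be sampled if it were instantiated. This is a minimal NumPy sketch, not library code: the helpers `hypothetical_grid` and `sample_effective_dataset` are hypothetical, and the grid layout (points at the cell centers (i + 0.5)/n, row-major atom ordering) is an assumption about what `get_2d_grid_equilibrium_relative_coordinates` returns. Coordinates are wrapped into [0, 1) because relative coordinates are periodic.

```python
import numpy as np


def hypothetical_grid(n: int) -> np.ndarray:
    """n x n grid of relative coordinates, ASSUMED to sit at the cell centers (i + 0.5) / n."""
    ticks = (np.arange(n) + 0.5) / n
    xs, ys = np.meshgrid(ticks, ticks, indexing="ij")  # row-major: atom index = i * n + j
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)  # shape [n * n, 2]


def sample_effective_dataset(mu: np.ndarray, sigma_d: float, num_samples: int,
                             rng: np.random.Generator) -> np.ndarray:
    """Draw x = mu + sigma_d * z with z ~ N(0, I), wrapped back into the unit cell."""
    noise = rng.standard_normal((num_samples, *mu.shape))
    return np.mod(mu[np.newaxis] + sigma_d * noise, 1.0)  # shape [num_samples, natoms, 2]


# Distinct names are used so that running this sketch does not overwrite the notebook's mu or n.
grid_sketch = hypothetical_grid(n=4)
samples_sketch = sample_effective_dataset(grid_sketch, sigma_d=0.02, num_samples=1000,
                                          rng=np.random.default_rng(seed=0))
print(samples_sketch.shape)  # (1000, 16, 2)
```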

In [None]:
# Basic imports and defining global variables
import torch

from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import NoiseParameters
from utilities import get_2d_grid_equilibrium_relative_coordinates

# Define the regular grid that represents the equilibrium relative coordinates
n = 4

# Define the Gaussian width of the effective dataset.
sigma_d = 0.02

equilibrium_relative_coordinates = get_2d_grid_equilibrium_relative_coordinates(n=n)
mu = torch.tensor(equilibrium_relative_coordinates)

number_of_atoms = len(equilibrium_relative_coordinates)
spatial_dimension = 2
elements = ["X"] # Just a dummy name.


# Define the repaint constraint for the 4 atoms in the center of the grid. The constraint will
# be a clockwise rotation of these points by 45 degrees (theta = pi / 4) about the center of the unit cell.
mask_grid = torch.zeros([n, n], dtype=torch.bool)
mid = n // 2
mask_grid[mid - 1: mid + 1, mid - 1: mid + 1] = True
constrained_indices = torch.arange(n**2)[mask_grid.flatten()]

theta = torch.tensor(torch.pi / 4)

center = torch.tensor([0.5, 0.5])
rotation_matrix = torch.tensor([[ torch.cos(theta), torch.sin(theta)],
                                [-torch.sin(theta), torch.cos(theta)]])

constrained_relative_coordinates = torch.matmul(mu[constrained_indices] - center, rotation_matrix.T) + center
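To make the constraint concrete, the NumPy sketch below recomputes the masked indices and the rotated positions outside of PyTorch. As in the previous sketch, it assumes the grid points sit at (i + 0.5)/n with row-major atom ordering; this is an assumption about `get_2d_grid_equilibrium_relative_coordinates`, not documented behavior. It checks that the mask selects the central 2 x 2 block and that the rotation preserves each point's distance to the cell center.

```python
import numpy as np

grid_size = 4  # same value as n above
ticks = (np.arange(grid_size) + 0.5) / grid_size  # ASSUMED grid layout
xs, ys = np.meshgrid(ticks, ticks, indexing="ij")
grid_points = np.stack([xs.ravel(), ys.ravel()], axis=-1)

# Select the central 2 x 2 block of the grid, flattened in row-major order.
mask = np.zeros((grid_size, grid_size), dtype=bool)
half = grid_size // 2
mask[half - 1: half + 1, half - 1: half + 1] = True
indices = np.arange(grid_size**2)[mask.ravel()]
print(indices.tolist())  # [5, 6, 9, 10]

# Same convention as above: points are rows, so they are multiplied by R^T.
angle = np.pi / 4
rot = np.array([[np.cos(angle), np.sin(angle)],
                [-np.sin(angle), np.cos(angle)]])
cell_center = np.array([0.5, 0.5])
rotated = (grid_points[indices] - cell_center) @ rot.T + cell_center

# A rotation about the cell center preserves each point's distance to that center.
before = np.linalg.norm(grid_points[indices] - cell_center, axis=1)
after = np.linalg.norm(rotated - cell_center, axis=1)
print(np.allclose(before, after))  # True
```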


# The Analytical Score Model

We will not train a learnable score model. Here, in the interest of time, we will use an analytical model that computes the score
exactly. Such a model is not available in general: the score can only be computed in closed form here because the effective dataset is a known Gaussian distribution.
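Why a closed form exists: if the data are Gaussian with width sigma_d and the forward process adds Gaussian noise of width sigma, the noised distribution is again Gaussian with variance sigma_d² + sigma². The hypothetical sketch below illustrates this for a single relative coordinate on the periodic interval [0, 1), summing over periodic images k = -kmax, ..., kmax. It is not the library's implementation, and that `kmax` plays this role in `AnalyticalScoreNetworkParameters` is an assumption. The score is checked against a finite difference of the log density.

```python
import numpy as np


def wrapped_gaussian_log_density(x: float, mu: float, sigma_d: float, sigma: float,
                                 kmax: int = 4) -> float:
    """Unnormalized log density of a Gaussian wrapped on [0, 1), keeping images |k| <= kmax."""
    variance = sigma_d**2 + sigma**2  # data width and noise width add in quadrature
    k = np.arange(-kmax, kmax + 1)
    log_terms = -((x - mu - k) ** 2) / (2.0 * variance)
    return float(np.logaddexp.reduce(log_terms))


def wrapped_gaussian_score(x: float, mu: float, sigma_d: float, sigma: float,
                           kmax: int = 4) -> float:
    """Score d/dx log p(x): a weighted average of the single-Gaussian scores of each image."""
    variance = sigma_d**2 + sigma**2
    k = np.arange(-kmax, kmax + 1)
    displacements = x - mu - k  # displacement from each periodic image of mu
    log_weights = -displacements**2 / (2.0 * variance)
    weights = np.exp(log_weights - log_weights.max())  # numerically stable softmax
    weights /= weights.sum()
    return float(np.sum(weights * (-displacements / variance)))


# Compare the score with a central finite difference of the log density.
x0, mu0, h = 0.425, 0.375, 1e-6
analytic = wrapped_gaussian_score(x0, mu0, sigma_d=0.02, sigma=0.1)
numeric = (wrapped_gaussian_log_density(x0 + h, mu0, 0.02, 0.1)
           - wrapped_gaussian_log_density(x0 - h, mu0, 0.02, 0.1)) / (2 * h)
print(analytic, numeric)
```

Near a cell boundary the periodic images matter: for x = 0.99 and mu = 0.01, the nearest image of mu is at 1.01, so the score points towards larger x.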

In [None]:
from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import \
    AnalyticalScoreNetworkParameters, AnalyticalScoreNetwork

score_network_parameters = AnalyticalScoreNetworkParameters(number_of_atoms=number_of_atoms,
                                                            spatial_dimension=spatial_dimension,
                                                            num_atom_types=1,
                                                            kmax=4,
                                                            equilibrium_relative_coordinates=equilibrium_relative_coordinates,
                                                            sigma_d=sigma_d)

axl_network = AnalyticalScoreNetwork(score_network_parameters)

# Sampling
We can draw samples with the analytical score network. To do so, we create a "generator", which creates new samples by using
the analytical axl_network to denoise random starting points.
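The internals of the generator are not shown in this notebook. As a rough, hypothetical picture of what a predictor-corrector sampler does, the sketch below samples a single non-periodic 1D Gaussian using its exact score. Two things are assumptions: that the "exponential" schedule spaces noise levels geometrically between sigma_max and sigma_min, and that each step is a reverse-diffusion predictor move followed by `number_of_corrector_steps` Langevin corrector moves. The corrector step size (proportional to sigma²) is a common choice and may differ from what `LangevinGenerator` does.

```python
import numpy as np


def exponential_sigmas(total_time_steps: int, sigma_min: float, sigma_max: float) -> np.ndarray:
    """ASSUMED 'exponential' schedule: geometrically spaced levels from sigma_max down to sigma_min."""
    return np.geomspace(sigma_max, sigma_min, total_time_steps)


def gaussian_score(x: np.ndarray, mean: float, data_sigma: float, sigma: float) -> np.ndarray:
    """Exact score of the noised data distribution N(mean, data_sigma**2 + sigma**2)."""
    return -(x - mean) / (data_sigma**2 + sigma**2)


def predictor_corrector_sample(num_samples: int, mean: float, data_sigma: float,
                               sigmas: np.ndarray, number_of_corrector_steps: int,
                               corrector_scale: float, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical 1D predictor-corrector sampler (variance-exploding convention)."""
    # Start from N(mean, sigma_max**2), close to the fully noised distribution when sigma_max >> data_sigma.
    x = mean + sigmas[0] * rng.standard_normal(num_samples)
    for sigma, next_sigma in zip(sigmas[:-1], sigmas[1:]):
        # Predictor: one reverse-diffusion step from noise level sigma to next_sigma.
        delta = sigma**2 - next_sigma**2
        x = (x + delta * gaussian_score(x, mean, data_sigma, sigma)
             + np.sqrt(delta) * rng.standard_normal(num_samples))
        # Corrector: Langevin steps at the new noise level.
        for _ in range(number_of_corrector_steps):
            epsilon = corrector_scale * next_sigma**2
            x = (x + epsilon * gaussian_score(x, mean, data_sigma, next_sigma)
                 + np.sqrt(2.0 * epsilon) * rng.standard_normal(num_samples))
    return x


sketch_sigmas = exponential_sigmas(total_time_steps=25, sigma_min=0.001, sigma_max=0.2)
sketch_samples = predictor_corrector_sample(num_samples=20_000, mean=0.5, data_sigma=0.02,
                                            sigmas=sketch_sigmas, number_of_corrector_steps=1,
                                            corrector_scale=0.1, rng=np.random.default_rng(seed=0))
# Compare with the target mean (0.5) and width (0.02) of the data distribution.
print(sketch_samples.mean(), sketch_samples.std())
```

With only 25 coarse steps, the final width is expected to come out somewhat larger than 0.02; more time steps or corrector steps reduce this discretization error.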

In [None]:
from diffusion_for_multi_scale_molecular_dynamics.generators.sampling_constraint import SamplingConstraint
from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import \
    ConstrainedLangevinGenerator
from diffusion_for_multi_scale_molecular_dynamics.sampling.diffusion_sampling import create_batch_of_samples
from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \
    PredictorCorrectorSamplingParameters
from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import LangevinGenerator


# Define the parameters of the noising process.
noise_parameters = NoiseParameters(total_time_steps=25,
                                   schedule_type="exponential",
                                   sigma_min=0.001,
                                   sigma_max=0.2)


# Sampling parameters
cell_dimensions = [1.0, 1.0]

# Define the sampling parameters. We will draw a single sample, and we will record the corresponding trajectory during 
# diffusion to see what it looks like.
sampling_parameters = PredictorCorrectorSamplingParameters(number_of_samples=1,
                                                           spatial_dimension=spatial_dimension,
                                                           number_of_corrector_steps=1,
                                                           num_atom_types=1,
                                                           number_of_atoms=number_of_atoms,
                                                           use_fixed_lattice_parameters=True,
                                                           cell_dimensions=cell_dimensions,
                                                           record_samples=True)

# Define an unconstrained generator. This should generate samples from the effective dataset distribution.
generator = LangevinGenerator(noise_parameters=noise_parameters,
                              sampling_parameters=sampling_parameters,
                              axl_network=axl_network)

# Define a constrained generator. We specify the "constrained indices" because the model is not permutation
# equivariant: an equivariant model would not "need to know" which indices are constrained.
sampling_constraint = SamplingConstraint(elements=elements,
                                         constrained_relative_coordinates=constrained_relative_coordinates,
                                         constrained_atom_types=torch.zeros_like(constrained_indices),
                                         constrained_indices=constrained_indices)

constrained_generator = ConstrainedLangevinGenerator(noise_parameters=noise_parameters,
                                                     sampling_parameters=sampling_parameters,
                                                     axl_network=axl_network,
                                                     sampling_constraints=sampling_constraint)

# Draw samples, both free and constrained.
with torch.no_grad():
    device = torch.device('cpu')
    samples_batch = create_batch_of_samples(generator=generator,
                                            sampling_parameters=sampling_parameters,
                                            device=device)

    constrained_samples_batch = create_batch_of_samples(generator=constrained_generator,
                                                        sampling_parameters=sampling_parameters,
                                                        device=device)
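For intuition about what the constrained generator has to do, the hypothetical sketch below shows the core idea of RePaint-style inpainting in a simple setting of independent 1D Gaussians: after every denoising step, the constrained coordinates are overwritten by their known values noised to the current noise level, so only the unconstrained coordinates are generated by the score model. Whether `ConstrainedLangevinGenerator` uses exactly this rule (or RePaint's additional resampling steps, omitted here) is not shown in this notebook; the geometric noise schedule is also an assumption.

```python
import numpy as np


def repaint_style_sample(means: np.ndarray, data_sigma: float, sigmas: np.ndarray,
                         known_indices: np.ndarray, known_values: np.ndarray,
                         rng: np.random.Generator) -> np.ndarray:
    """Hypothetical RePaint-style sampler for independent 1D Gaussians (one per coordinate)."""
    x = means + sigmas[0] * rng.standard_normal(means.shape)
    for sigma, next_sigma in zip(sigmas[:-1], sigmas[1:]):
        # Unconstrained reverse-diffusion step on every coordinate, using the exact score.
        delta = sigma**2 - next_sigma**2
        score = -(x - means) / (data_sigma**2 + sigma**2)
        x = x + delta * score + np.sqrt(delta) * rng.standard_normal(means.shape)
        # Inpainting: overwrite the constrained coordinates with known values noised to next_sigma.
        x[known_indices] = known_values + next_sigma * rng.standard_normal(known_values.shape)
    return x


demo_sigmas = np.geomspace(0.2, 0.001, 25)
demo_means = np.array([0.1, 0.3, 0.5, 0.7, 0.9])
demo_result = repaint_style_sample(demo_means, data_sigma=0.02, sigmas=demo_sigmas,
                                   known_indices=np.array([1, 3]),
                                   known_values=np.array([0.35, 0.65]),
                                   rng=np.random.default_rng(seed=0))
print(demo_result)
```

Because the coordinates in this sketch are independent, the constraint only pins the chosen coordinates; with a score that couples atoms, the free coordinates would also adapt to the constrained ones through the score.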

# Visualizing Trajectories

We can gauge the quality of the samples by looking at videos of the sampling trajectories.

The videos show a representation of the effective dataset in terms of isosurfaces (concentric blue circles), as well as the relative
coordinates of the sample as they evolve while time goes from 1 (fully noised configuration) to 0 (the "data space").

The "free sample" aims to look like the effective dataset, whereas the "constrained sample" has its four central atoms constrained to the rotated positions defined above.

In [None]:
from utilities import create_2d_trajectory_video
import einops

list_generators = [generator, constrained_generator]
list_output_filenames = ["free_trajectory_video.mp4", "constrained_trajectory_video.mp4"]

for gen, output_filename in zip(list_generators, list_output_filenames):
    # The trajectory is held internally in the trajectory recorder object
    list_x = []
    for step_dictionary in gen.sample_trajectory_recorder._internal_data['predictor_step']:
        list_x.append(step_dictionary['composition_im1'].X)

    trajectories = einops.rearrange(list_x, "time batch natoms d -> batch time natoms d")

    trajectory = trajectories[0]
    video_output_path = output_path / output_filename
    create_2d_trajectory_video(trajectory, mu, constrained_relative_coordinates, sigma_d, video_output_path)
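The `einops.rearrange` call above stacks the list of per-step arrays along a new leading time axis and then swaps the time and batch axes. The NumPy sketch below shows an equivalent operation on dummy arrays (the shapes and values are made up for illustration): `np.stack` followed by `np.transpose` yields the same [batch, time, natoms, d] layout.

```python
import numpy as np

num_steps, batch_size, num_atoms, dim = 3, 2, 16, 2
# Dummy stand-ins for the per-step coordinate arrays, each of shape [batch, natoms, d].
steps = [np.full((batch_size, num_atoms, dim), float(step)) for step in range(num_steps)]

stacked = np.stack(steps)  # shape [time, batch, natoms, d]
per_sample = np.transpose(stacked, (1, 0, 2, 3))  # shape [batch, time, natoms, d]
print(per_sample.shape)  # (2, 3, 16, 2)
# Each sample's trajectory now runs along the time axis.
print(per_sample[0, :, 0, 0])  # [0. 1. 2.]
```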

In [None]:
from IPython.display import Video

free_trajectory_video = str(output_path / "free_trajectory_video.mp4")
constrained_trajectory_video = str(output_path / "constrained_trajectory_video.mp4")

# We can now visualize what the diffusion trajectories look like.

In [None]:
# For free diffusion
Video(free_trajectory_video, embed=True)

In [None]:
# For constrained diffusion
Video(constrained_trajectory_video, embed=True)