In [None]:
from init_notebook import *

In [None]:
# Create the trainer from the baseline diffusion experiment config, requesting the CUDA device.
# NOTE(review): `load_experiment_trainer` comes from `init_notebook`; what it builds is not visible here.
trainer = load_experiment_trainer("../experiments/diffusion/baseline.yml", device="cuda")
# Bare expression as the last line of the cell so the notebook displays the trainer.
trainer

In [None]:
# Print the trainer's checkpoint path before loading, to make it visible which one is used.
print(trainer.checkpoint_path)
# presumably restores the saved state from the path printed above — confirm in the trainer implementation
trainer.load_checkpoint()

In [None]:
# Grid side length: size*size images are requested and passed to make_grid with nrow=size.
size = 4
# Sampling only; no autograd graph is needed.
with torch.no_grad():
    # (1, 128, 128) is presumably (channels, height, width) — confirm against trainer.generate_images.
    images = trainer.generate_images(size*size, (1, 128, 128), steps=2)
# Last expression of the cell: tile the batch into a grid and convert it to a PIL image for display.
# NOTE(review): `VF` and `make_grid` come from `init_notebook`; presumably torchvision's
# `transforms.functional` and `utils.make_grid` — confirm.
VF.to_pil_image(make_grid(images, nrow=size))

In [None]:
class DiffusionSampler:
    """
    Iterative sampler around a noise-predicting model.

    The model is called as ``model(images, noise_amounts)`` with ``images`` of
    shape ``(batch, channels, height, width)`` and ``noise_amounts`` of shape
    ``(batch, 1)``; its output is subtracted from the images as the noise
    estimate.
    """

    def __init__(
            self,
            model: nn.Module,
            channels: int,
    ):
        """
        :param model: module called as ``model(images, noise_amounts)``
        :param channels: number of image channels the model expects
        """
        self.model = model
        self.channels = channels
        self._device = None

    @property
    def device(self) -> Optional[torch.device]:
        """
        Device of the model's first parameter, looked up once and then cached.

        Stays ``None`` as long as the model has no parameters.
        """
        if self._device is None:
            first_param = next(self.model.parameters(), None)
            if first_param is not None:
                self._device = first_param.device
        return self._device

    def predict_noise(
            self,
            images: torch.Tensor,
            noise_amounts: Optional[torch.Tensor] = None,
    ):
        """
        Run the model on ``images`` and return its output (the noise estimate).

        :param images: tensor of shape ``(batch, channels, height, width)``;
            values are clamped to [-1, 1] before being passed to the model
        :param noise_amounts: tensor of shape ``(batch, 1)``;
            defaults to a noise amount of 1 for every image
        :return: whatever the model returns for the clamped images
        """
        assert images.ndim == 4, f"Got {images.shape}"
        assert images.shape[1] == self.channels, f"Got {images.shape}"

        if noise_amounts is None:
            noise_amounts = self.create_noise_amount(images.shape[0], 1, 1)

        assert noise_amounts.ndim == 2, f"Got {noise_amounts.shape}"
        assert noise_amounts.shape[0] == images.shape[0], f"Got: {noise_amounts.shape}"

        return self.model(
            images.clamp(-1, 1),
            noise_amounts,
        )

    def _to_generator(self, seed: Union[None, int, torch.Generator]) -> Optional[torch.Generator]:
        """
        Convert ``seed`` to a ``torch.Generator``.

        ``None`` stays ``None`` (torch's global RNG is used), an existing
        generator is passed through so its state keeps advancing across calls,
        and an integer creates a fresh, seeded generator.
        """
        if seed is None:
            return None
        if isinstance(seed, torch.Generator):
            return seed
        return torch.Generator().manual_seed(seed)

    def create_noise_amount(
            self,
            batch_size: int,
            minimum: float = 0.001,
            maximum: float = 1.,
            seed: Union[None, int, torch.Generator] = None,
    ) -> torch.Tensor:
        """
        Draw one noise amount per batch entry, uniformly from [minimum, maximum].

        :param batch_size: number of amounts to create
        :param minimum: lower bound of the range
        :param maximum: upper bound of the range; if equal to ``minimum``
            every entry is exactly that value
        :param seed: ``None``, an integer seed or a ``torch.Generator``
        :return: tensor of shape ``(batch_size, 1)`` on the model's device
        """
        gen = self._to_generator(seed)

        # Sampled on the CPU (matching the CPU generator) and moved afterwards.
        # Random numbers are drawn even for a degenerate range, so a shared
        # generator advances the same way regardless of the bounds.
        amounts = torch.rand((batch_size, 1), generator=gen).to(self.device)

        # Scale to the width of the range first, THEN shift by the minimum.
        return amounts * (maximum - minimum) + minimum

    def create_noise(
            self,
            batch_size: int,
            shape: Tuple[int, int],
            seed: Union[None, int, torch.Generator] = None,
    ) -> torch.Tensor:
        """
        Create a batch of standard-normal noise images.

        :param batch_size: number of images
        :param shape: spatial shape ``(height, width)``
        :param seed: ``None``, an integer seed or a ``torch.Generator``
        :return: tensor of shape ``(batch_size, channels, *shape)`` on the model's device
        """
        gen = self._to_generator(seed)

        return torch.randn((batch_size, self.channels, *shape), generator=gen).to(self.device)

    def denoise_image(
            self,
            images: torch.Tensor,
            noise_amounts: Optional[torch.Tensor] = None,
            strength: float = 1.,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply a single denoising step.

        :param images: tensor of shape ``(batch, channels, height, width)``
        :param noise_amounts: tensor of shape ``(batch, 1)``;
            defaults to a noise amount of 1 for every image
        :param strength: fraction of the estimated noise to subtract; the
            noise amounts are reduced by the same fraction
        :return: tuple of (denoised images, remaining noise amounts)
        """
        assert images.ndim == 4, f"Got {images.shape}"

        if noise_amounts is None:
            noise_amounts = self.create_noise_amount(images.shape[0], 1, 1)

        estimated_noise = self.predict_noise(images, noise_amounts)

        images = images - estimated_noise * strength
        noise_amounts = noise_amounts - noise_amounts * strength

        return images, noise_amounts

    def generate_images(
            self,
            batch_size: int,
            shape: Tuple[int, int],
            steps: int = 10,
            seed: Union[None, int, torch.Generator] = None,
            method: int = 1,
    ) -> torch.Tensor:
        """
        Generate images by repeatedly denoising, starting from pure noise.

        In every step the model's noise estimate is subtracted from the current
        images. In all but the last step fresh noise is added back, scaled
        linearly down from ``1 - 1/steps`` to ``1/steps``.

        :param batch_size: number of images to generate
        :param shape: spatial shape ``(height, width)``
        :param steps: number of denoising steps, must be at least 1
        :param seed: ``None``, an integer seed or a ``torch.Generator``;
            a fixed seed makes the noise sequence reproducible
        :param method: currently unused, kept for interface compatibility
        :return: tensor of shape ``(batch_size, channels, *shape)``
        :raises ValueError: if ``steps`` is smaller than 1
        """
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")

        # One generator for the whole run, so every step draws new noise
        # while the full sequence is determined by the seed.
        gen = self._to_generator(seed)

        noisy_images = self.create_noise(batch_size, shape, seed=gen)
        # The start images are pure noise: noise amount 1 for every entry.
        noise_amounts = self.create_noise_amount(batch_size, 1., 1.)

        for step in range(steps):
            estimated_noise = self.predict_noise(noisy_images, noise_amounts)
            outputs = noisy_images - estimated_noise

            if step < steps - 1:
                noise = self.create_noise(batch_size, shape, seed=gen)
                noisy_images = outputs + noise * (1. - (step + 1) / steps)
                # NOTE(review): this label decays geometrically, by a factor of
                # (1 - 1/steps) per step, while the injected noise above decays
                # linearly - confirm that the mismatch is intended.
                noise_amounts = noise_amounts - noise_amounts / steps

        return outputs

# Wrap the trained model in the sampler defined above; channels=1 matches the
# single-channel images generated in the earlier cell.
sampler = DiffusionSampler(
    model=trainer.model,
    channels=1,
)
# Sampling only: disable autograd so no graph is kept across the denoising steps
# (same as the trainer.generate_images cell above).
with torch.no_grad():
    sampler_images = sampler.generate_images(4, (32, 32))
# Map from [-1, 1] to [0, 1] and clamp: the sampler output is not bounded and
# out-of-range values would wrap around in the uint8 conversion of to_pil_image.
# Kept as the last expression so the notebook displays the image.
VF.to_pil_image(make_grid(
    (sampler_images * .5 + .5).clamp(0, 1)
))