# In [None]:
from typing import Optional, Tuple, Union

import torch
from diffusers import (
    ConfigMixin,
    DDIMScheduler,
    DDPMScheduler,
    DiffusionPipeline,
    ImagePipelineOutput,
)
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

from ProteinDiffusion.models.UNet1DProtein import UNet1DProtein

class VaeProteinProcessor(ConfigMixin):
    """Processor stub for protein latents.

    Modeled on the diffusers image processor linked below. It currently
    holds no configuration and no state; it only initializes its base class.
    """

    # https://github.com/huggingface/diffusers/blob/main/src/diffusers/image_processor.py#L60
    def __init__(self):
        super().__init__()
        # The previous `self.register_modules(unet=unet, scheduler=scheduler)`
        # call was removed: neither `unet` nor `scheduler` exists in this
        # scope (NameError on instantiation), and `register_modules` is a
        # `DiffusionPipeline` method rather than a `ConfigMixin` one.

def latent_to_seq(latents: torch.Tensor) -> torch.Tensor:
    """Clamp ``latents`` to ``[-1, 1]`` and rescale the result to ``[0, 1]``.

    The input tensor is not modified; a new tensor is returned.
    """
    # Out-of-place chain: clamp -> shift by +1 -> halve.
    return latents.clamp(min=-1, max=1).add(1).div(2)

class ProteinDiffusionPipeline(DiffusionPipeline):
    """Unconditional sampling pipeline built around a 1D protein UNet.

    Starting from Gaussian noise of shape
    ``(batch_size, unet.config.in_channels, seq_length)``, the pipeline runs
    the scheduler's denoising loop, rescales the result with
    ``latent_to_seq`` and optionally decodes it with ``aa_decoder``.
    """

    def __init__(
        self,
        unet: UNet1DProtein,
        scheduler: Union[DDPMScheduler, DDIMScheduler],
        aa_encoder: PreTrainedTokenizerFast,
        aa_decoder: Tokenizer,
    ):
        """Store the model, the scheduler and both tokenizers.

        Args:
            unet: Denoising network; called as ``unet(latents, t).sample``.
            scheduler: Scheduler providing ``set_timesteps``, ``timesteps``
                and ``step``.
            aa_encoder: Tokenizer kept on the pipeline; not used by
                ``__call__``.
            aa_decoder: Tokenizer used to decode sampled token ids when
                ``output_type == "aa"``.
        """
        super().__init__()

        self.register_modules(unet=unet, scheduler=scheduler)

        # Previously the tokenizers were accepted but never stored, so
        # `self.aa_decoder` raised AttributeError in `__call__`. They are kept
        # as plain attributes rather than registered modules.
        # NOTE(review): register them via `register_modules` instead if they
        # should be saved/loaded together with the pipeline — confirm.
        self.aa_encoder = aa_encoder
        self.aa_decoder = aa_decoder

    @torch.no_grad()
    def __call__(
        self,
        seq_length: Optional[int] = 256,
        num_inference_steps: Optional[int] = 50,
        generator: Optional[torch.Generator] = None,
        batch_size: Optional[int] = 1,
        output_type: Optional[str] = "aa",
        return_dict: bool = True,
        **kwargs,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """Sample a batch of sequences.

        Args:
            seq_length: Size of the last axis of the sampled latents.
            num_inference_steps: Number of steps passed to
                ``scheduler.set_timesteps``.
            generator: Optional RNG used to draw the initial noise.
            batch_size: Number of samples to draw.
            output_type: ``"aa"`` returns decoded strings; ``"pt"`` returns
                the rescaled tensor produced by ``latent_to_seq``.
            return_dict: When ``False`` a one-element tuple is returned
                instead of an ``ImagePipelineOutput``.
            **kwargs: Accepted for call compatibility; ignored.

        Returns:
            ``ImagePipelineOutput`` whose ``images`` field carries the
            result, or ``(result,)`` when ``return_dict`` is ``False``.

        Raises:
            ValueError: If ``output_type`` is not ``"aa"`` or ``"pt"``.
        """
        # Fail before the (expensive) denoising loop rather than after it.
        if output_type not in ("aa", "pt"):
            raise ValueError(
                f"Unsupported output_type {output_type!r}; expected 'aa' or 'pt'."
            )

        shape = (batch_size, self.unet.config.in_channels, seq_length)
        # torch.randn requires the output device to match the generator's
        # device, so draw the noise there and move it afterwards.
        noise_device = generator.device if generator is not None else None
        latents = torch.randn(shape, generator=generator, device=noise_device)
        latents = latents.to(self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            # predict the noise residual
            noise_pred = self.unet(latents, t).sample

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        latents = latent_to_seq(latents)

        if output_type == "aa":
            # The tokenizer decodes integer ids, not a float tensor.
            # NOTE(review): assumes the channel axis (dim=1) indexes tokenizer
            # vocabulary ids, so the per-position argmax is the token id —
            # TODO confirm against the training-time encoding.
            token_ids = latents.argmax(dim=1).tolist()
            seq = self.aa_decoder.decode_batch(token_ids)
        else:
            seq = latents

        if not return_dict:
            return (seq,)

        # The original referenced an undefined `image` here.
        return ImagePipelineOutput(images=seq)