In [None]:
# Notebook setup: load the autoreload extension in mode 2 (re-import edited
# modules before running each cell) and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from diffusers import DiffusionPipeline
from scipy import signal
import numpy as np
import torch
import matplotlib.pyplot as plt
import h5py
from tqdne.conf import DATASETDIR
from pathlib import Path
from tqdne.dataset import H5Dataset, RandomDataset
from torch.utils.data import DataLoader

from diffusers import UNet1DModel
from diffusers import DDPMScheduler

from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pathlib import Path
from pytorch_lightning.loggers import WandbLogger

from tqdne.conf import OUTPUTDIR


In [None]:
# Build a small synthetic dataset and wrap it in data loaders.

batch_size = 16
# Signal length in samples: 5501 rounded down to the nearest multiple of 16.
t = 5501 - (5501 % 16)


# Real-data alternative (currently disabled):
# path_train = DATASETDIR / Path("data_train.h5")
# path_test = DATASETDIR / Path("data_test.h5")
# train_dataset = H5Dataset(path_train, cut=t)
# test_dataset = H5Dataset(path_test, cut=t)

train_dataset, test_dataset = RandomDataset(8 * 1024), RandomDataset(512)

# Both loaders share the same batch size; neither shuffles.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=batch_size) for dataset in (train_dataset, test_dataset)
)

In [None]:
# Unpack the first training example into its low/high resolution signals and
# display their shapes (the tuple on the last line is the cell output).
low_res, high_res = train_dataset[0]
low_res.shape, high_res.shape

In [None]:
# Plot the first channel of the example pair over a short time window.
fs = 100  # samples per time unit used to build the axis (seconds, given the x label)
time = np.arange(0, t)/fs

plt.figure(figsize=(6, 3))
# Label both curves: the colours alone are ambiguous, and a later comparison
# figure in this notebook uses the opposite colour assignment.
plt.plot(time, low_res[0].numpy(), 'b', label="low res")
plt.plot(time, high_res[0].numpy(), 'r', label="high res")
plt.xlim(1, 5)
plt.xlabel("Time (s)")
plt.legend();  # trailing ';' hides the Legend repr in the cell output

In [None]:
# Pull one batch from the training loader and display the batched tensor shapes.
batch_low, batch_high = next(iter(train_loader))
batch_low.shape, batch_high.shape

In [None]:
max_epochs = 50

# Unet parameters
unet_params = {
    "sample_size": t,
    # `to_inputs` concatenates the low-res and high-res signals along the channel
    # axis, so the net is fed 2 channels: presumably `in_channels` accounts for the
    # noisy signal and `extra_in_channels` for the conditioning one — TODO confirm.
    "in_channels": 1,
    "out_channels": 1,
    "block_out_channels": (32, 32, 64),
    # NOTE(review): in diffusers the timestep embedding appears to be concatenated
    # only by 'DownBlock1DNoSkip'; with 'DownBlock1D' as the first block the net may
    # never see the timestep — TODO confirm against the installed diffusers version.
    "down_block_types": ('DownBlock1D', 'DownBlock1D', 'AttnDownBlock1D'),
    "up_block_types": ('AttnUpBlock1D', 'UpBlock1D', 'UpBlock1D'),
    "mid_block_type": 'UNetMidBlock1D',
    "extra_in_channels": 1,
}

# Noise schedule used for both training (add_noise) and sampling (step).
scheduler_params = {
    "beta_schedule": "linear",
    "beta_start": 0.0001,
    "beta_end": 0.02,
    "num_train_timesteps": 100,
}

optimizer_params = {
    "learning_rate": 1e-4,
    "lr_warmup_steps": 500,
    # Number of batches per epoch. `len(train_loader)` also counts a trailing
    # partial batch, which `len(train_dataset) // batch_size` silently dropped.
    "n_train": len(train_loader),
    "seed": 0,
    "batch_size": batch_size,
    "max_epochs": max_epochs,
}

trainer_params = {
    # trainer parameters
    "accumulate_grad_batches": 1,
    "gradient_clip_val": 1,
    "precision": "32-true",
    # Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'),
    # 16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed').
    # Can be used on CPU, GPU, TPUs, HPUs or IPUs.
    "max_epochs": max_epochs,
    "accelerator": "auto",
    "devices": "auto",
    "num_nodes": 1,
}


In [None]:

# Instantiate the 1D UNet from `unet_params` and display its config.
net = UNet1DModel(**unet_params)
net.config

In [None]:
def to_inputs(low_res, high_res):
    """Concatenate the low- and high-resolution tensors along dim 1 (low_res first) to form the UNet input."""
    channel_stack = (low_res, high_res)
    return torch.cat(channel_stack, dim=1)
# Smoke test: feeding a (low_res, high_res-shaped) input through the UNet must
# give back a tensor shaped like the high-resolution batch.
high_resn = torch.rand(batch_size, 1, t)  # random stand-in for a noisy high-res batch

inputs = to_inputs(batch_low, high_resn)
# Keep the test timestep inside the range configured for the scheduler
# (valid indices are 0 .. num_train_timesteps - 1).
timesteps = torch.full((batch_size,), scheduler_params["num_train_timesteps"] - 1, dtype=torch.long)
print(inputs.shape)
with torch.no_grad():  # shape check only, no autograd graph needed
    output_shape = net(inputs, timesteps).sample.shape
assert output_shape == batch_high.shape, f"unexpected UNet output shape {tuple(output_shape)}"


In [None]:
# Instantiate the DDPM noise scheduler from `scheduler_params` and display its config.
scheduler = DDPMScheduler(**scheduler_params)
scheduler.config

In [None]:
# Forward diffusion demo: corrupt the clean high-res batch at timestep 50 and
# overlay the first noisy signal on its clean counterpart.
timesteps = torch.LongTensor([50 for _ in range(batch_size)])
noise = torch.randn(batch_high.shape)
noisy_sig = scheduler.add_noise(batch_high, noise, timesteps)

fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(time, noisy_sig[0, 0].numpy(), 'b', label="noisy")
ax.plot(time, batch_high[0, 0].numpy(), 'r', label="original")
ax.set_xlim(1, 5)
ax.legend()

In [None]:
# # NOTE: this draft sampling loop is probably wrong: it calls the net on the noisy sample only and never feeds it the low-res conditioning signal.
# import tqdm

# def sample(noise):
#     sample = noise
#     for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)):
#         # 1. predict noise residual
#         with torch.no_grad():
#             residual = net(sample, t).sample
#         # 2. compute less noisy image and set x_t -> x_t-1
#         sample = scheduler.step(residual, t, sample).prev_sample

#     return sample

In [None]:
# import torch.nn.functional as F

# sample = train_dataset[0]
# sig = sample.unsqueeze(0)
# print(sig.shape)
# noise = torch.randn(sig[:,:1].shape)
# timesteps = torch.LongTensor([150])
# noisy_sig = scheduler.add_noise(sig[:,:1], noise, timesteps)
# noisy_sig = torch.concat([noisy_sig, sig[:,1:]], dim=1)
# noise_pred = net(noisy_sig, timesteps).sample
# loss = F.mse_loss(noise_pred, noise)

In [None]:

from typing import List, Optional, Tuple, Union

import torch

from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, AudioPipelineOutput


class DDPMPipeline1DCond(DiffusionPipeline):
    r"""
    Pipeline that generates a 1D signal conditioned on a low-resolution signal.

    Sampling starts from Gaussian noise shaped like the conditioning signal. At every scheduler timestep the
    conditioning signal and the current sample are concatenated with `to_inputs`, passed through `unet`, and the
    model output is handed to `scheduler.step` to obtain the next (less noisy) sample.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        unet ([`UNet1DModel`]):
            A `UNet1DModel` called on the concatenated (conditioning, noisy sample) input.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the sample. Can be one of
            [`DDPMScheduler`], or [`DDIMScheduler`].
    """
    model_cpu_offload_seq = "unet"

    def __init__(self, unet, scheduler):
        super().__init__()
        # Read the value through `.config`: accessing config entries as direct scheduler
        # attributes is deprecated in diffusers.
        self.num_inference_steps = scheduler.config.num_train_timesteps
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        low_res: torch.Tensor,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: Optional[int] = None,
        return_dict: bool = True,
    ) -> Union[AudioPipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            low_res (`torch.Tensor`):
                Conditioning signal with three dimensions; the second one is the channel dimension and must
                equal both `unet.config.in_channels` and `unet.config.extra_in_channels`. The generated signal
                has the same shape.
            generator (`torch.Generator`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            num_inference_steps (`int`, *optional*):
                The number of denoising steps. Defaults to the scheduler's `num_train_timesteps` captured when
                the pipeline was built.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return an [`AudioPipelineOutput`] instead of a plain tuple.

        Returns:
            [`AudioPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, an [`AudioPipelineOutput`] whose `audios` field holds the generated
                tensor; otherwise a one-element `tuple` containing that tensor.
        """
        # Unpacking also enforces that the conditioning signal has exactly three dimensions.
        _, channels, _ = low_res.shape
        assert self.unet.config.in_channels == channels, "low_res channels must equal unet.config.in_channels"
        assert (
            self.unet.config.extra_in_channels == channels
        ), "low_res channels must equal unet.config.extra_in_channels"

        # The sample is created on the pipeline device, so the conditioning must live there too
        # or the channel concatenation below fails with a device mismatch.
        low_res = low_res.to(self.device)

        # Sample gaussian noise to begin loop
        sig_shape = low_res.shape
        if self.device.type == "mps":
            # randn does not work reproducibly on mps
            sig = randn_tensor(sig_shape, generator=generator)
            sig = sig.to(self.device)
        else:
            sig = randn_tensor(sig_shape, generator=generator, device=self.device)

        # set step values
        if num_inference_steps is None:
            num_inference_steps = self.num_inference_steps
        self.scheduler.set_timesteps(num_inference_steps)

        for timestep in self.progress_bar(self.scheduler.timesteps):
            # The conditioning is re-attached at every step; only `sig` is updated.
            inputs = to_inputs(low_res, sig)

            # 1. predict noise model_output
            model_output = self.unet(inputs, timestep).sample

            # 2. compute previous sample: x_t -> x_t-1
            sig = self.scheduler.step(model_output, timestep, sig, generator=generator).prev_sample

        if not return_dict:
            return (sig,)

        return AudioPipelineOutput(audios=sig)
    
# Wrap the UNet and the scheduler in the conditional sampling pipeline.
pipeline = DDPMPipeline1DCond(net, scheduler)

In [None]:
def evaluate(low_res, pipeline, seed=None):
    """Generate signals conditioned on `low_res` by running the pipeline (backward diffusion).

    Parameters
    ----------
    low_res
        Conditioning input forwarded to the pipeline as `low_res`.
    pipeline
        Callable accepting `low_res` and `generator` keyword arguments and returning
        an object with an `audios` attribute.
    seed : int, optional
        Seed for the sampling generator. Defaults to `optimizer_params["seed"]`.

    Returns
    -------
    The `audios` attribute of the pipeline output.
    """
    if seed is None:
        seed = optimizer_params["seed"]
    # Use a dedicated generator: `torch.manual_seed` would reseed the *global* RNG
    # as a side effect of every evaluation.
    generator = torch.Generator().manual_seed(seed)
    return pipeline(low_res=low_res, generator=generator).audios

# Generate a high-res batch from a fresh low-res batch and overlay the first
# example of the target, the conditioning and the generated signal.
batch_low, batch_high = next(iter(train_loader))
gen_high = evaluate(batch_low, pipeline)

fig, ax = plt.subplots()
for signal_batch, fmt, extra in (
    (batch_high, 'b', {"label": "high res"}),
    (batch_low, 'r', {"label": "low res"}),
    (gen_high, 'g', {"alpha": 0.5, "label": "generated"}),
):
    ax.plot(time, signal_batch[0, 0].numpy(), fmt, **extra)
ax.legend()
ax.set_xlim(1, 5)


In [None]:
from diffusers.optimization import get_cosine_schedule_with_warmup
from torch.nn import functional as F
import pytorch_lightning as pl

# NOTE(review): an identical `to_inputs` is already defined earlier in this
# notebook; this later definition shadows it — keep the two in sync or drop one.
def to_inputs(low_res, high_res):
    """Return the UNet input: `low_res` followed by `high_res`, joined along dim 1."""
    unet_input = torch.cat([low_res, high_res], dim=1)
    return unet_input

class LightningDDMP(pl.LightningModule):
    """A PyTorch Lightning module for training a conditional diffusion model.

    Each step adds scheduler noise to the high-resolution signal at a random
    timestep, feeds the net the low-resolution signal concatenated with the noisy
    one, and minimises the MSE between the net output and the added noise.

    Parameters
    ----------
    net : torch.nn.Module
        A PyTorch neural network. It is called as ``net(inputs, timesteps, return_dict=False)``
        and the first element of the result is used as the noise prediction.
    noise_scheduler : DDPMScheduler
        A scheduler for adding noise to the clean signals.
    optimizer_params : dict
        Training settings. The keys read by this class are ``"learning_rate"``,
        ``"lr_warmup_steps"``, ``"n_train"`` (batches per epoch), ``"max_epochs"``
        and ``"seed"``.

    """

    def __init__(self, net: torch.nn.Module, noise_scheduler: DDPMScheduler, optimizer_params:dict):
        super().__init__()

        self.net = net
        self.optimizer_params = optimizer_params
        self.noise_scheduler = noise_scheduler
        # Sampling pipeline sharing the very same net and scheduler objects.
        self.pipeline = DDPMPipeline1DCond(self.net, self.noise_scheduler)
        self.save_hyperparameters()

    def evaluate(self, low_res):
        """Sample signals conditioned on `low_res` (this is the backward diffusion process)."""
        # Seed a dedicated generator from this module's own settings: relying on the
        # notebook-global `optimizer_params` and `torch.manual_seed` would tie the
        # module to global state and reseed the global RNG on every call.
        generator = torch.Generator().manual_seed(self.optimizer_params["seed"])
        return self.pipeline(low_res=low_res, generator=generator).audios

    def cross_entropy_loss(self, logits: torch.Tensor, labels: torch.Tensor):
        """Negative log-likelihood loss; not called anywhere in this class."""
        return F.nll_loss(logits, labels)

    def _shared_step(self, batch: List, batch_idx: int, train: bool = False):
        """Compute and log the denoising loss for one batch.

        Deliberately not named ``global_step``: that name is a ``LightningModule``
        property (the optimizer step counter) and must not be shadowed by a method.
        """
        low_res, high_res = batch

        # Sample noise to add to the high_res, directly on its device and with its dtype
        noise = torch.randn_like(high_res)
        batch_size = high_res.shape[0]

        # Sample a random timestep for each signal
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps, (batch_size,), device=high_res.device
        ).long()

        # Add noise to the clean high_res according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_high_res = self.noise_scheduler.add_noise(high_res, noise, timesteps)

        # Predict the noise residual
        inputs = to_inputs(low_res, noisy_high_res)
        noise_pred = self.net(inputs, timesteps, return_dict=False)[0]
        loss = F.mse_loss(noise_pred, noise)
        self.log("train_loss" if train else "val_loss", loss, prog_bar=True)
        return loss

    def training_step(self, train_batch: List, batch_idx: int):
        """Training hook: denoising loss logged as ``train_loss``."""
        return self._shared_step(train_batch, batch_idx, train=True)

    def validation_step(self, val_batch: List, batch_idx: int):
        """Validation hook: denoising loss logged as ``val_loss``."""
        return self._shared_step(val_batch, batch_idx)

    def configure_optimizers(self):
        """Create the AdamW optimizer and a cosine-with-warmup LR schedule."""
        # Optimise this module's own network rather than a notebook-global `net`.
        optimizer = torch.optim.AdamW(self.net.parameters(), lr=self.optimizer_params["learning_rate"])
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=self.optimizer_params["lr_warmup_steps"],
            num_training_steps=(self.optimizer_params["n_train"] * self.optimizer_params["max_epochs"]),
        )
        # Warmup and total length are expressed in optimizer steps, so the scheduler must
        # be stepped per batch; Lightning's default interval is once per epoch.
        return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]

# Build the Lightning module from the UNet, the noise scheduler and the optimizer settings.
model = LightningDDMP(net, scheduler, optimizer_params)

In [None]:
name = '1D-UNET'

# Single location for checkpoints and trainer output; created up front.
output_dir = OUTPUTDIR / Path(name)
output_dir.mkdir(parents=True, exist_ok=True)

# 1. Wandb Logger
wandb_logger = WandbLogger() # add project='projectname' to log to a specific project

# 2. Learning Rate Logger
lr_logger = LearningRateMonitor()
# 3. Set Early Stopping
early_stopping = EarlyStopping('val_loss', mode='min', patience=5)
# 4. saves checkpoints to `output_dir` whenever 'val_loss' has a new min
# The run name is concatenated into the template: a literal '{name}' placeholder is
# treated by ModelCheckpoint as a metric key, not replaced by the Python variable.
checkpoint_callback = ModelCheckpoint(dirpath=output_dir, filename=name + '_{epoch}-{val_loss:.2f}',
                                      monitor='val_loss', mode='min', save_top_k=5)

# Define Trainer
trainer = pl.Trainer(**trainer_params, logger=wandb_logger, callbacks=[lr_logger, early_stopping, checkpoint_callback],
                     default_root_dir=output_dir)

In [None]:
# Train on `train_loader`, using `test_loader` as the validation set.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=test_loader)