In [None]:
import torch
from torch import Tensor
import math

class Scheduler:
    def __init__(
        self,
        num_steps : int,
        min_beta : float,
        max_beta : float,
        T : float | None = None
        ):
        self.num_steps = num_steps
        self.timesteps = torch.arange(num_steps + 1)
        self.delta_t = 1 / num_steps
        
        self.timesteps = torch.arange(num_steps + 1)
        betas = torch.zeros(num_steps)
        first_betas_len = math.ceil(num_steps / 2)
        betas[:first_betas_len] = torch.linspace(min_beta, max_beta, first_betas_len)
        betas[-first_betas_len:] = torch.flip(betas[:first_betas_len], [0])
        self.betas = torch.cat([torch.zeros(1), betas])
        self.sigmas_2 = torch.cumsum(self.betas, 0)
        self.sigmas_2_bar = torch.flip(self.sigmas_2, [0])
        
        if T is not None:
            self.betas = self.betas / self.betas.sum() * T
        else:
            self.T = self.betas.sum()
        
        sigmas_2 = torch.cumsum(self.betas, 0)
        self.sigmas_2 = torch.cat([torch.zeros(1), sigmas_2])
        self.sigmas_2_bar = torch.flip(self.sigmas_2, [0])
        
    def shape_for_constant(self, shape : tuple[int]) -> tuple[int]:
        return [-1] + [1] * (len(shape) - 1) 
    
    def to_tensor(self, value : int, batch_size : int) -> Tensor:
        return torch.full((batch_size,), value, dtype=torch.long)
    
    def reshape_constants(self, constants : list[Tensor], shape : tuple[int]) -> list[Tensor]:
        return [constant.view(self.shape_for_constant(shape)) for constant in constants]     
 
    def sample_timestep(self, batch_size : int) -> Tensor:
        return torch.randint(1, self.num_steps + 1, (batch_size,))
    
    def sample_posterior(self, x0 : Tensor, x1 : Tensor, n : int | None = None) -> tuple[Tensor, Tensor]:
        shape = x0.shape
        batch_size = shape[0]
        ns = self.to_tensor(n, batch_size) if n is not None else self.sample_timestep(batch_size)
        sigmas_2 = self.sigmas_2[ns]
        sigmas_2_bar = self.sigmas_2_bar[ns]
        mu_1, mu_2, sigma = self.gaussian_product(sigmas_2_bar, sigmas_2)
        mu_1, mu_2, sigma = self.reshape_constants([mu_1, mu_2, sigma], shape)
        mu = mu_1 * x0 + mu_2 * x1
        std = torch.sqrt(sigma)
        return mu + std * torch.randn_like(mu), ns
        
    def gaussian_product(self, v1 : Tensor, v2 : Tensor) -> tuple[Tensor, Tensor]:
        mu_1 = v1 / (v1 + v2)
        mu_2 = v2 / (v1 + v2)
        sigma = v1 * v2 / (v1 + v2)
        
        return mu_1, mu_2, sigma
    
scheduler = Scheduler(12, 1, 10)

# Boundary sanity check of gaussian_product(sigmas_2_bar, sigmas_2):
# at index 0 sigmas_2 is 0, so the expected result is (1, 0, 0);
# at the last index sigmas_2_bar is 0, so the expected result is (0, 1, 0).
print(
    *(scheduler.gaussian_product(scheduler.sigmas_2_bar[i], scheduler.sigmas_2[i]) for i in (0, -1)),
    sep="\n",
)

(tensor(1.), tensor(0.), tensor(0.))
(tensor(0.), tensor(1.), tensor(0.))


In [4]:
from src.lightning_modules import I2SB
from src.networks import UNet2D
import torch

# 2D UNet with three down blocks and three up blocks, 32 channels each,
# mapping 1-channel inputs to 1-channel outputs.
network = UNet2D(
    in_channels=1,
    out_channels=1,
    block_out_channels=[32] * 3,
    down_block_types=['DownBlock2D'] * 3,
    up_block_types=['UpBlock2D'] * 3,
)

# Lightning module wrapping the network. The schedule keywords share their names
# with the Scheduler constructor arguments; presumably they are forwarded to it
# (confirm in src.lightning_modules).
model = I2SB(
    model=network,
    num_steps=10,
    min_beta=0.1,
    max_beta=1,
    T=1.0,
)

In [75]:
import torch
import math

# Scratch check of the beta-schedule construction for an odd number of steps:
# every schedule array should have num_steps + 1 entries (one per timestep 0..N).
num_steps = 5
timesteps = torch.arange(num_steps + 1)
betas = torch.zeros(num_steps)
first_betas_len = math.ceil(num_steps / 2)
# Linear ramp over the first half, mirrored over the second half; for odd
# num_steps the two halves share the middle element.
betas[:first_betas_len] = torch.linspace(0.1, 1, first_betas_len)
betas[-first_betas_len:] = torch.flip(betas[:first_betas_len], [0])
# Leading zero so betas[n] is the beta of step n.
betas = torch.cat([torch.zeros(1), betas])
sigmas_2 = torch.cumsum(betas, 0)
sigmas_2_bar = torch.flip(sigmas_2, [0])

print("timesteps:", len(timesteps))
print("betas:", len(betas))
print(betas)
print("sigmas_2:", len(sigmas_2))
print("sigmas_2_bar:", len(sigmas_2_bar))

# Invariants the printout above is eyeballing, made explicit:
# arrays aligned with timesteps, symmetric betas, and forward + backward
# cumulative variances summing to the total at every step.
assert len(betas) == len(sigmas_2) == len(sigmas_2_bar) == len(timesteps)
assert torch.equal(betas[1:], torch.flip(betas[1:], [0]))
assert torch.allclose(sigmas_2 + sigmas_2_bar, sigmas_2[-1].expand_as(sigmas_2))

timesteps: 6
betas: 6
tensor([0.0000, 0.1000, 0.5500, 1.0000, 0.5500, 0.1000])
sigmas_2: 6
sigmas_2_bar: 6


In [65]:
import math
# Even step count: ceil(6 / 2) == 3, so the ramp-up half and the mirrored
# ramp-down half of the beta schedule do not overlap.
math.ceil(6 * 0.5)

3