-
Notifications
You must be signed in to change notification settings - Fork 7
/
sampler.mojo
124 lines (113 loc) · 4.88 KB
/
sampler.mojo
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from helpers.utils import *
from math import round
struct DDPMSampler:
    """Denoising sampler driven by a fixed beta noise schedule.

    Stores the per-training-step `betas`, `alphas = 1 - betas`, their
    cumulative product `alphas_cumprod`, and the `timesteps` to iterate over
    at inference time. Noise is drawn from a `Matrix` seeded with `seed_val`.
    """

    var seed_val: Int
    var num_training_steps: Int
    var betas: Tensor[float_dtype]
    var alphas: Tensor[float_dtype]
    var alphas_cumprod: Tensor[float_dtype]
    var timesteps: Tensor[float_dtype]
    var num_inference_steps: Int
    var start_step: Int

    fn __init__(
        inout self,
        seed_val: Int = 0,
        # Setting this to 10 for illustrative purposes, since we are not
        # interested in training. Typical values would be around 1000
        num_training_steps: Int = 10,
        beta_start: Float32 = 0.00085,
        beta_end: Float32 = 0.0120,
    ):
        """Builds the noise schedule.

        Args:
            seed_val: Seed handed to `init_weights_seed` every time noise is drawn.
            num_training_steps: Number of entries in the beta schedule.
            beta_start: First beta of the schedule.
            beta_end: Last beta of the schedule.
        """
        # Setting this to 1 since I am interested in demonstrating a single forward pass
        self.num_inference_steps = 1
        self.start_step = 0
        self.seed_val = seed_val
        self.num_training_steps = num_training_steps
        # Betas are spaced linearly between sqrt(beta_start) and
        # sqrt(beta_end), then squared.
        self.betas = (
            linspace(beta_start**0.5, beta_end**0.5, num_training_steps) ** 2
        )
        self.alphas = 1.0 - self.betas
        # Presumably alphas_cumprod[t] = alphas[0] * ... * alphas[t]
        # (inclusive running product) — confirm in helpers.utils.
        self.alphas_cumprod = cumprod(self.alphas)
        # NOTE(review): the third `arange` argument presumably requests a
        # descending range — confirm in helpers.utils.
        self.timesteps = arange(0, num_training_steps, True)

    fn set_inference_timesteps(
        inout self,
        num_inference_steps: Int = 1,
    ):
        """Replaces `timesteps` with `num_inference_steps` strided steps.

        The stride is `num_training_steps // num_inference_steps`; each
        step index is multiplied by it and rounded via `round_tensor`.

        Args:
            num_inference_steps: Number of denoising steps. Must be > 0: it is
                used as a divisor here and in `get_previous_timestep`.
        """
        self.num_inference_steps = num_inference_steps
        var step_ratio: Float32 = self.num_training_steps // self.num_inference_steps
        var timesteps = round_tensor(
            arange(0, self.num_inference_steps, True) * step_ratio
        )
        self.timesteps = timesteps

    fn get_previous_timestep(
        inout self,
        timestep: Int,
    ) -> Int:
        """Returns the timestep one inference stride before `timestep`.

        The result is negative when `timestep` is the first step of the
        schedule; callers treat a negative value as "no previous step".
        """
        var prev_t = timestep - self.num_training_steps // self.num_inference_steps
        return prev_t

    fn get_variance(
        inout self,
        timestep: Int,
    ) -> Float32:
        """Returns the variance of the noise added when stepping from `timestep`.

        Computes (1 - abar_prev) / (1 - abar_t) * beta_t, where
        abar = alphas_cumprod, abar_prev = 1 when there is no previous step,
        and beta_t = 1 - abar_t / abar_prev. The result is clamped to at least
        1e-20.

        Args:
            timestep: Index into `alphas_cumprod`; assumed to lie in
                [0, num_training_steps).
        """
        var prev_t = self.get_previous_timestep(timestep)
        var alpha_prod_t = self.alphas_cumprod[timestep]
        # With no previous step, the cumulative product of an empty prefix is 1.
        var alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else 1
        var current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
        var variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
        # Preventing zero values
        variance = variance.max(1e-20)
        return variance[0]

    fn set_strength(inout self, strength: Float32):
        """Skips the first part of the inference schedule.

        Drops the first `num_inference_steps - int(num_inference_steps *
        strength)` entries of `timesteps` and records that count in
        `start_step`. The count is clamped to [0, len(timesteps)].

        Args:
            strength: Fraction of the schedule to keep, nominally in [0, 1].
        """
        var start_step = self.num_inference_steps - int(
            self.num_inference_steps * strength
        )
        var timesteps_length = self.timesteps.num_elements()
        # Guard against strength outside [0, 1] producing an out-of-range start.
        if start_step < 0:
            start_step = 0
        if start_step > timesteps_length:
            start_step = timesteps_length
        # Keep the tail [start_step, timesteps_length). The end bound must be
        # the tensor length; `start_step + timesteps_length` overran the tensor
        # whenever start_step > 0.
        # NOTE(review): assumes get_tensor_values(t, start, end) returns the
        # half-open slice [start, end) — confirm in helpers.utils.
        self.timesteps = get_tensor_values(self.timesteps, start_step, timesteps_length)
        self.start_step = start_step

    fn step(
        inout self,
        timestep: Int,
        latents: Matrix[float_dtype],
        model_output: Matrix[float_dtype],
    ) -> Matrix[float_dtype]:
        """Performs one denoising update from `timestep` to the previous step.

        `model_output` is treated as the predicted noise. The clean sample is
        estimated as (latents - sqrt(1 - abar_t) * model_output) / sqrt(abar_t).
        The previous sample is a weighted mix of that estimate and `latents`.
        When `timestep > 0`, seeded noise scaled by sqrt(get_variance(timestep))
        is added.

        Args:
            timestep: Current step; index into `alphas_cumprod`.
            latents: Current noisy sample.
            model_output: Model prediction with the same dimensions as `latents`.

        Returns:
            The sample for the previous timestep.
        """
        var prev_t = self.get_previous_timestep(timestep)
        var alpha_prod = self.alphas_cumprod[timestep]
        var alpha_prod_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else 1
        var beta_prod = 1 - alpha_prod
        var beta_prod_prev = 1 - alpha_prod_prev
        var current_alpha = alpha_prod / alpha_prod_prev
        var current_beta = 1 - current_alpha
        var alpha_prod_final = alpha_prod[0]
        var beta_prod_final = beta_prod[0]
        # Estimate of the clean (noise-free) sample.
        var pred_original_sample = (
            latents - (model_output) * (beta_prod_final ** (0.5))
        ) / (alpha_prod_final ** (0.5))
        # Weights that combine the clean estimate and the current latents.
        var pred_original_sample_coefficient: Float32 = (
            alpha_prod_prev ** (0.5) * current_beta
        ) / beta_prod
        var current_sample_coefficient = current_alpha ** (
            0.5
        ) * beta_prod_prev / beta_prod
        var pred_previous_sample = pred_original_sample * pred_original_sample_coefficient + latents * current_sample_coefficient
        if timestep > 0:
            # NOTE(review): the noise is re-seeded with the same seed_val on
            # every call, so each step adds the same noise pattern. This looks
            # intentional for reproducibility; confirm.
            var noise = Matrix[float_dtype](
                model_output.dim0, model_output.dim1, model_output.dim2
            )
            noise.init_weights_seed(self.seed_val)
            var multiplier = (self.get_variance(timestep) ** 0.5)
            var variance = noise * multiplier
            pred_previous_sample = pred_previous_sample + variance
        return pred_previous_sample

    fn add_noise(
        inout self, original_samples: Matrix[float_dtype], timestep: Float32
    ) -> Matrix[float_dtype]:
        """Noises `original_samples` to level `timestep`.

        Returns sqrt(abar_t) * original_samples + sqrt(1 - abar_t) * noise,
        where abar = alphas_cumprod. `timestep` is truncated with `int()`
        before indexing. The noise comes from a `Matrix` seeded with
        `seed_val`, presumably filled with random values by
        `init_weights_seed`; confirm in helpers.utils.
        """
        var int_timestep = int(timestep)
        var sqrt_alpha_prod = self.alphas_cumprod[int_timestep] ** 0.5
        var sqrt_alpha_prod_scalar = sqrt_alpha_prod[0]
        var sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[int_timestep]) ** 0.5
        var sqrt_one_minus_alpha_prod_scalar = sqrt_one_minus_alpha_prod[0]
        var noise = Matrix[float_dtype](
            original_samples.dim0, original_samples.dim1, original_samples.dim2
        )
        noise.init_weights_seed(self.seed_val)
        var noisy_samples = original_samples * sqrt_alpha_prod_scalar + noise * sqrt_one_minus_alpha_prod_scalar
        return noisy_samples