In [None]:
# Automatically reload edited Python modules (e.g. the project's scheduler code) while iterating.
%load_ext autoreload
%autoreload 2


In [None]:
from diffusers.training_utils import compute_loss_weighting_for_sd3
import torch
import matplotlib.pyplot as plt


# Compare diffusers' SD3 loss-weighting schemes as a function of sigma.
# NOTE(review): 'sigma_sqrt' was disabled in the original cell, presumably because its
# weight blows up as sigma -> 0 and flattens the other curves; add it back to compare.
WEIGHTING_SCHEMES = ['cosmap']
sigmas_coarse = torch.linspace(0.0001, 1, 20)
sigmas_fine = torch.linspace(0.0001, 1, 1000)

fig, ax = plt.subplots()
for schema in WEIGHTING_SCHEMES:
    print(schema)
    print(compute_loss_weighting_for_sd3(weighting_scheme=schema, sigmas=sigmas_coarse))
    # Plot against sigma itself so the x axis is meaningful (not the sample index).
    weights_fine = compute_loss_weighting_for_sd3(weighting_scheme=schema, sigmas=sigmas_fine)
    ax.plot(sigmas_fine, weights_fine, label=schema)
ax.set(xlabel='sigma', ylabel='loss weight', title='compute_loss_weighting_for_sd3')
ax.legend();



In [None]:
def compute_flow_matching_min_snr_weights(t, min_snr_gamma):
    """
    Compute min-SNR(gamma)-style loss weights from flow-matching timesteps.

    The SNR is estimated as t**2 / (1 - t)**2 from the interpolation schedule,
    capped at ``min_snr_gamma``, and converted to ``1 / (min(SNR, gamma) + 1)``.
    Weights therefore lie in ``[1 / (gamma + 1), 1]``. They are close to 1 where
    the SNR is ~0 and equal ``1 / (gamma + 1)`` wherever SNR >= gamma.

    NOTE(review): t**2 / (1 - t)**2 is the SNR only if t = 1 means clean data.
    Under the diffusers flow-matching convention (t = sigma is the noise level)
    the SNR would be (1 - t)**2 / t**2 instead. Confirm which convention callers use.
    NOTE(review): this differs from the Min-SNR weighting of Hang et al.
    (min(SNR, gamma) / SNR); confirm the 1 / (min(SNR, gamma) + 1) form is intended.

    Args:
        t: Timesteps [B], expected in [0, 1]. They are clamped to
            [1e-5, 1 - 1e-5] so the SNR stays finite at the boundaries.
        min_snr_gamma: Upper cap applied to the SNR before weighting.

    Returns:
        Weights [B], on the same device as ``t``.
    """
    # Keep t away from 0 and 1 so neither the SNR numerator nor denominator hits zero.
    t_clamped = torch.clamp(t, min=1e-5, max=1 - 1e-5)

    # SNR(t) = signal² / noise² = t² / (1 - t)² for this interpolation schedule.
    snr = (t_clamped ** 2) / ((1 - t_clamped) ** 2)

    # Cap the SNR at gamma. Tensor.clamp takes the Python scalar directly, so no
    # per-call CPU tensor is allocated and the result stays on t's device/dtype.
    snr_capped = snr.clamp(max=min_snr_gamma)

    # Inverse of the capped SNR; the +1 keeps the weight finite (<= 1) when SNR ~ 0.
    return 1.0 / (snr_capped + 1.0)

# Plot the min-SNR weights against t itself (not against the sample index).
min_snr_gamma = 5
t_grid = torch.linspace(0.01, 1, 1000)
fig, ax = plt.subplots()
ax.plot(t_grid, compute_flow_matching_min_snr_weights(t=t_grid, min_snr_gamma=min_snr_gamma))
ax.set(xlabel='t', ylabel='weight', title=f'flow-matching min-SNR weights (gamma={min_snr_gamma})');


In [None]:
from diffusers import StableDiffusionPipeline

import os

# Local checkpoint to load. Set the SD_MODEL_PATH environment variable to use a
# different model instead of editing the machine-specific default path.
MODEL_PATH = os.environ.get('SD_MODEL_PATH', '/workspace/models/sdsf-97k-1e5qrt-12k-4e6sqrt-11k-h')

pipe = StableDiffusionPipeline.from_pretrained(MODEL_PATH)
pipe.to('cuda')

In [None]:
from diffusers import UNet2DConditionModel
import torch
# Smoke test: a single UNet forward pass on random inputs at timestep 1.
unet: UNet2DConditionModel = pipe.unet
# Shapes and dtype come from the UNet itself instead of hard-coded values
# (originally 1024-dim text embeddings and 4 latent channels).
# NOTE(review): assumes config.cross_attention_dim is a single int, as for SD 1.x/2.x UNets.
encoder_hidden_states = torch.randn(
    (1, 77, unet.config.cross_attention_dim), device=pipe.device, dtype=unet.dtype
)
sample = torch.randn((1, unet.config.in_channels, 64, 64), device=pipe.device, dtype=unet.dtype)
t = 1
# Scoped no_grad instead of torch.set_grad_enabled(False): the global switch would
# silently disable autograd for every later cell in this kernel.
with torch.no_grad():
    # Call the module rather than .forward() so any registered hooks still run.
    unet_output = unet(sample, t, encoder_hidden_states)
unet_output

In [None]:
from diffusers import FlowMatchEulerDiscreteScheduler, DDPMScheduler

# Inspect the cumulative alpha products of a default-configured DDPMScheduler.
ddpm_scheduler = DDPMScheduler()
ddpm_scheduler.alphas_cumprod


In [None]:
# Inspect the sigmas of a freshly constructed, default-configured FlowMatchEulerDiscreteScheduler
# (before set_timesteps has been called on it).
flow_match_scheduler = FlowMatchEulerDiscreteScheduler()
flow_match_scheduler.sigmas


In [None]:
# Recompute the scheduler's timesteps/sigmas for a 999-step inference schedule.
flow_match_scheduler.set_timesteps(num_inference_steps=999)

In [None]:
# Sigmas as they stand now (after set_timesteps when cells are run in order).
flow_match_scheduler.sigmas

In [None]:
# Training-timestep count stored in the scheduler config.
flow_match_scheduler.config.num_train_timesteps

In [None]:
import torch
from flow_match_model import TrainFlowMatchScheduler

# Sanity-check TrainFlowMatchScheduler.add_noise on tiny tensors with all-zero latents.
# NOTE(review): presumably this isolates the noise contribution at each timestep;
# that depends on add_noise's interpolation, which is not visible here.
fm_test = TrainFlowMatchScheduler()
# A local seeded generator makes the output reproducible across re-runs
# without touching torch's global RNG state.
noise_gen = torch.Generator().manual_seed(0)
latents = torch.zeros((2, 2, 4, 4))
noise = torch.randn((2, 2, 4, 4), generator=noise_gen)
timesteps = torch.tensor([25, 800]).float()  # two timesteps for the batch of two
fm_test.add_noise(latents, noise=noise, timesteps=timesteps)

In [None]:
from diffusers import FlowMatchEulerDiscreteScheduler

# Build a flow-matching Euler scheduler from the pipeline's current scheduler config.
# NOTE(review): presumably keys the flow-matching scheduler doesn't know are ignored and
# flow-matching options (e.g. shift) keep their defaults — confirm. Re-running this after
# the swap below reads the flow-matching scheduler's own config instead.
scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)

In [None]:
# Install the flow-matching scheduler in the pipeline.
pipe.scheduler = scheduler

# NOTE(review): pipe was already moved to 'cuda' when it was loaded; this repeat
# is presumably only here to display the pipeline.
pipe.to('cuda')

In [None]:
# Text-to-image sample using the flow-matching scheduler.
prompt = "A cat in a forest"
pipe: StableDiffusionPipeline
# Patch the scheduler with identity values for the two hooks the pipeline uses here.
# NOTE(review): presumably FlowMatchEulerDiscreteScheduler doesn't provide these in the
# form StableDiffusionPipeline expects — confirm against the installed diffusers version.
pipe.scheduler.init_noise_sigma = 1
pipe.scheduler.scale_model_input = lambda x, t: x
# Fixed seed so the generated image is reproducible across re-runs.
generator = torch.Generator(device=pipe.device).manual_seed(0)
images = pipe(
    prompt=prompt,
    negative_prompt="ugly, pointillism",
    height=768,
    width=768,
    num_inference_steps=30,
    generator=generator,
)


In [None]:
# First generated image from the pipeline output (presumably a PIL image with the default output_type).
images.images[0]

In [None]:
import math
def _get_polynomial_decay_schedule_with_warmup_adj(
    lr_init: float,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: int = 1,
    lr_end: float = 1e-7,
    power: float = 1.0,
    last_epoch: int = -1,
):
    """
    Adapted from diffusers get_polynomial_decay_schedule_with_warmup to remove the restrictive check on strictly decreasing LR

    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
        num_cycles (`int`, *optional*, defaults to 1):
            How many times to repeat the cycle of warmup/cooldown during training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    """

    num_warmup_steps_cycle = math.ceil(num_warmup_steps / num_cycles)
    num_training_steps_cycle = math.ceil(num_training_steps / num_cycles)

    def lr_lambda_cycleinternal(current_cycle_step: int):
        if current_cycle_step < num_warmup_steps_cycle:
            return float(current_cycle_step) / float(max(1, num_warmup_steps_cycle))
        elif current_cycle_step > num_training_steps_cycle:
            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
        else:
            lr_range = lr_init - lr_end
            decay_steps = num_training_steps_cycle - num_warmup_steps_cycle
            pct_remaining = 1 - (current_cycle_step - num_warmup_steps_cycle) / decay_steps
            print('pct_remaining', pct_remaining)
            decay = lr_range * pct_remaining**power + lr_end
            print('lr_range', lr_range, 'power', power, 'decay', decay)
            return decay / lr_init  # as LambdaLR multiplies by lr_init

    def lr_lambda(current_step: int):
        current_cycle_step = current_step % int(num_warmup_steps_cycle + num_training_steps_cycle)
        return lr_lambda_cycleinternal(current_cycle_step)

    return lr_lambda

# Test schedule that *increases* from 1e-3 to 1e-1 over 200 steps with no warmup,
# exercising the relaxed (non-decreasing) LR check.
ll = _get_polynomial_decay_schedule_with_warmup_adj(
    lr_init=1e-3,
    num_warmup_steps=0,
    num_training_steps=200,
    lr_end=1e-1,
    power=2,
)

In [None]:
# LR-finder ramp settings: the LR sweeps from min_lr up to max_lr over
# num_warmup_steps steps and then stays at max_lr.
num_warmup_steps = 1000
max_lr = 1e-4
min_lr = 1e-8


def findlr_lambda(current_step: int) -> float:
    """Exponential LR ramp from ``min_lr`` to ``max_lr`` over ``num_warmup_steps``, then ``max_lr``.

    Returns an absolute learning rate. NOTE(review): if this is used as a LambdaLR
    ``lr_lambda``, the optimizer's base LR must be 1.0 for these values to apply as-is.
    """
    if current_step >= num_warmup_steps:
        return max_lr
    # Linear interpolation in log space: each step multiplies the LR by the same factor.
    t = current_step / num_warmup_steps
    return min_lr * (max_lr / min_lr) ** t


def findlr_lambda_(current_step: int, power: float = 1.0) -> float:
    """Polynomial LR ramp ``min_lr + pos**power * (max_lr - min_lr)``, then constant ``max_lr``.

    Args:
        current_step: Optimizer step index.
        power: Ramp exponent (1.0 = linear).
    """
    # `power` is a parameter because it was previously read from an undefined
    # global, so any call raised NameError.
    if current_step >= num_warmup_steps:
        return max_lr
    pos = current_step / num_warmup_steps
    return min_lr + (pos ** power) * (max_lr - min_lr)

from matplotlib import pyplot as plt
# Plot the exponential LR-finder ramp on a log y-axis. There it shows as a straight
# rising line up to num_warmup_steps and then flat; on a linear axis it hugs zero.
lr_steps = range(2000)
fig, ax = plt.subplots()
ax.plot(lr_steps, [findlr_lambda(step) for step in lr_steps])
ax.set_yscale('log')
ax.set(xlabel='step', ylabel='learning rate', title='findlr_lambda');

In [None]:
# LR multiplier of the test schedule at step 0.
ll(0)

In [None]:
# LR multiplier at step 199, the last step of the 200-step test schedule.
ll(199)

In [None]:
# NOTE(review): autoreload is already loaded in the first cell; reloading it presumably just
# prints an "already loaded" notice.
# NOTE(review): an earlier cell already imports flow_match_model, before this %cd runs, so a
# fresh top-to-bottom run presumably only works if the kernel started in this directory.
# Consider moving the %cd to the top. The absolute path is machine-specific.
%load_ext autoreload
%autoreload 2
%cd /workspace/EveryDream2trainer-remote

In [None]:
from flow_match_model import TrainFlowMatchScheduler
# Default-constructed flow-matching scheduler, passed as the first argument to
# get_training_noise_scheduler in the next cell. NOTE(review): a more descriptive name than `s` would help.
s = TrainFlowMatchScheduler()


In [None]:
from model.training_model import get_training_noise_scheduler

# Build the training noise scheduler for flow matching, with timestep shift 3.
flow_match_shift = 3
noise_scheduler = get_training_noise_scheduler(
    s,
    "flow-matching",
    trained_betas=[],
    rescale_betas_zero_snr=False,
    flow_match_shift=flow_match_shift,
)


In [None]:
from model.training_model import TrainFlowMatchScheduler
# Standalone flow-matching scheduler with shift=3; this replaces the `noise_scheduler` assigned earlier.
# NOTE(review): this cell imports TrainFlowMatchScheduler from model.training_model while earlier
# cells import it from flow_match_model — confirm they are the same class.
noise_scheduler = TrainFlowMatchScheduler(shift=3)
# Residual between sigmas rescaled to the 0-1000 range and the timesteps.
# NOTE(review): presumably expected to be ~0 everywhere; consider asserting that instead of eyeballing.
noise_scheduler.sigmas * 1000 - noise_scheduler.timesteps


In [None]:
import torch
# Probe get_exact_timesteps at both ends of the integer timestep range (0, 1, 998, 999).
# NOTE(review): the meaning of the returned values is defined in TrainFlowMatchScheduler, not visible here.
ex = noise_scheduler.get_exact_timesteps(torch.tensor([0, 1, 998, 999]))
ex