Modèle de diffusion

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from tqdm import tqdm # Import added to ensure it can run standalone if needed


class SimpleMLP(nn.Module):
    """Noise-prediction MLP for 2-D points, conditioned on the diffusion timestep.

    The timestep is turned into a sinusoidal embedding of width ``t_dim``,
    refined by a small two-layer MLP, concatenated to the 2-D input and fed
    through the main network, which outputs a 2-D tensor.

    Args:
        t_dim: Width of the time embedding. Any positive integer is accepted;
            odd widths are zero-padded by one column.
    """

    def __init__(self, t_dim=50):
        super().__init__()
        self.t_dim = t_dim

        # 1. Time embedding refinement (applied on top of the sinusoidal features)
        self.time_embed = nn.Sequential(
            nn.Linear(t_dim, t_dim),
            nn.GELU(),
            nn.Linear(t_dim, t_dim)
        )

        # 2. Main MLP: input is the 2-D point concatenated with the time embedding
        self.net = nn.Sequential(
            nn.Linear(2 + t_dim, 128),
            nn.GELU(),
            nn.Linear(128, 128),
            nn.GELU(),
            nn.Linear(128, 2)
        )

    def get_time_embedding(self, timesteps):
        """Build the learned time embedding for a batch of timesteps.

        Args:
            timesteps: 1-D tensor of timesteps, one per batch element.

        Returns:
            Tensor of shape ``(len(timesteps), t_dim)``.
        """
        half_dim = self.t_dim // 2

        # Geometric frequency ladder from 1 down to 1/10000. The max() guard
        # avoids a division by zero when half_dim == 1 (t_dim of 2 or 3).
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=timesteps.device) * -log_step)

        angles = timesteps[:, None].float() * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

        # sin/cos only produce 2 * half_dim columns; pad one zero column so an
        # odd t_dim still matches the input width of self.time_embed.
        if self.t_dim % 2 == 1:
            emb = F.pad(emb, (0, 1))

        return self.time_embed(emb)

    def forward(self, x, t):
        """Predict the noise for points ``x`` (last dim 2) at timesteps ``t``."""
        t_embed = self.get_time_embedding(t)
        h = torch.cat([x, t_embed], dim=-1)
        return self.net(h)


class GaussianDiffusion(nn.Module):
    """Discrete-time Gaussian diffusion on 2-D points with an MLP noise predictor.

    All schedule coefficients are registered as buffers, so they follow the
    module across ``.to(device)`` calls.

    Args:
        timesteps: Number of diffusion steps.
        beta_min: First value of the linear beta schedule.
        beta_max: Last value of the linear beta schedule.

    Note:
        A discrete schedule needs every beta in (0, 1); otherwise
        ``1 - beta`` goes negative and the square-root coefficients are NaN.
        If the requested linear schedule contains a value >= 1 (this is the
        case for the defaults), the values are interpreted as continuous-time
        rates and divided by ``timesteps``. Schedules that are already below 1
        are used exactly as given.
    """

    def __init__(self, timesteps=100, beta_min=0.1, beta_max=5.0):
        super().__init__()
        self.timesteps = timesteps

        # 1. Beta schedule
        betas = torch.linspace(beta_min, beta_max, timesteps)
        if bool((betas >= 1.0).any()):
            # Invalid as per-step variances: discretize the rates instead.
            # The clamp covers very small `timesteps`, where the division
            # alone may still leave values >= 1.
            betas = (betas / timesteps).clamp(max=0.999)
        self.register_buffer('betas', betas)

        # 2. Key pre-calculated terms
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))

        # 3. Reverse process variance (alpha_bar_{t-1} is 1 at t = 0,
        #    which makes the variance at t = 0 equal to 0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        betas_tilde = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.register_buffer('betas_tilde', betas_tilde)

        # 4. Noise prediction model
        self.model = SimpleMLP()
        # Cast parameters and floating-point buffers to float32.
        self.to(torch.float)

    def _extract(self, a, t, x_shape):
        """Pick ``a[t]`` per batch element and reshape it for broadcasting.

        Args:
            a: 1-D tensor of per-timestep coefficients.
            t: 1-D long tensor of timestep indices, one per batch element.
            x_shape: Shape of the tensor the result will be broadcast against.

        Returns:
            Tensor on ``t.device`` of shape ``(batch, 1, ..., 1)`` with as many
            dimensions as ``x_shape``.
        """
        b = t.shape[0]
        # Gather on the index's device: `a` is usually a registered buffer
        # living on the module's device, so indexing it with a CPU copy of `t`
        # raises a device-mismatch error as soon as the module is on a GPU.
        out = a.to(t.device).gather(-1, t)
        return out.reshape(b, *((1,) * (len(x_shape) - 1)))

    def forward_diffusion(self, x_start, t, noise=None):
        """Sample ``x_t`` from the forward process q(x_t | x_0).

        Args:
            x_start: Clean samples ``x_0``.
            t: 1-D long tensor of timesteps, one per batch element.
            noise: Optional noise with the shape of ``x_start``; standard
                normal noise is drawn when omitted.

        Returns:
            ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alpha_bar = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alpha_bar = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)

        x_t = sqrt_alpha_bar * x_start + sqrt_one_minus_alpha_bar * noise
        return x_t

    def forward(self, x_start, noise):
        """Noise ``x_start`` at uniformly random timesteps and predict the noise.

        Args:
            x_start: Clean samples, one row per batch element.
            noise: Noise injected by the forward process (the regression
                target for the returned prediction).

        Returns:
            The model's noise prediction; the sampled timesteps are not returned.
        """
        b = x_start.shape[0]
        device = x_start.device

        t = torch.randint(0, self.timesteps, (b,), device=device).long()

        x_t = self.forward_diffusion(x_start, t, noise)
        pred_noise = self.model(x_t, t)

        return pred_noise

    @torch.no_grad()
    def sampling(self, n_samples, device):
        """Generate samples by running the reverse process from pure noise.

        Args:
            n_samples: Number of 2-D points to generate.
            device: Device on which the samples are created and denoised.

        Returns:
            Tensor of shape ``(n_samples, 2)``.
        """
        shape = (n_samples, 2)
        x_t = torch.randn(shape, device=device)

        # `reversed(range(...))` has no len(); pass the total explicitly so the
        # progress bar can show completion.
        for t in tqdm(reversed(range(self.timesteps)), desc="Sampling", total=self.timesteps):
            t_tensor = torch.full((n_samples,), t, device=device, dtype=torch.long)

            pred_noise = self.model(x_t, t_tensor)

            beta_t = self._extract(self.betas, t_tensor, shape)
            alpha_t = 1. - beta_t
            sqrt_one_minus_alpha_bar_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t_tensor, shape)

            # Posterior mean expressed through the predicted noise.
            mu_t = (
                (1. / torch.sqrt(alpha_t)) * (x_t - (beta_t / sqrt_one_minus_alpha_bar_t) * pred_noise)
            )

            # Posterior variance; its square root scales the injected noise.
            sigma_t = self._extract(self.betas_tilde, t_tensor, shape)

            if t > 0:
                z = torch.randn_like(x_t)
                x_t = mu_t + torch.sqrt(sigma_t) * z
            else:
                # The last step is deterministic: no noise is added at t = 0.
                x_t = mu_t

        return x_t

1. Méthode de calcul analytique

In [4]:
import torch
import torch.nn as nn

def count_flops_manual(model):
    """Print a per-layer and total FLOP table for the nn.Linear layers of `model`.

    Each linear layer is counted as 2 * in_features * out_features FLOPs
    (one multiply and one add per weight). Modules of any other type are
    ignored. The table is written to stdout; nothing is returned.
    """
    rule = "-" * 70

    # Keep only the linear layers, in the order named_modules() yields them.
    linear_layers = [
        (layer_name, layer)
        for layer_name, layer in model.named_modules()
        if isinstance(layer, nn.Linear)
    ]

    print(f"{'Layer':<30} | {'In':<8} | {'Out':<8} | {'FLOPs':<15}")
    print(rule)

    running_total = 0
    for layer_name, layer in linear_layers:
        n_in, n_out = layer.in_features, layer.out_features
        layer_flops = 2 * n_in * n_out
        running_total += layer_flops
        print(f"{layer_name:<30} | {n_in:<8} | {n_out:<8} | {layer_flops:<15,}")

    print(rule)
    print(f"TOTAL FLOPs  : {running_total:,}")
    print(f"TOTAL GFLOPs : {running_total / 1e9:.9f}")
    print(rule)


In [3]:
# Build the full diffusion wrapper; only its noise-prediction network (`model.model`) is profiled below.
model = GaussianDiffusion(timesteps=100)

print("FLOPs for noise prediction network (SimpleMLP):\n")
# Analytic count: prints a per-layer table for every nn.Linear, counting 2 * in_features * out_features each.
count_flops_manual(model.model)


FLOPs for noise prediction network (SimpleMLP):

Layer                          | In       | Out      | FLOPs          
----------------------------------------------------------------------
time_embed.0                   | 50       | 50       | 5,000          
time_embed.2                   | 50       | 50       | 5,000          
net.0                          | 52       | 128      | 13,312         
net.2                          | 128      | 128      | 32,768         
net.4                          | 128      | 2        | 512            
----------------------------------------------------------------------
TOTAL FLOPs  : 56,592
TOTAL GFLOPs : 0.000056592
----------------------------------------------------------------------


2. Méthode automatique

In [7]:
from thop import profile
import torch

# Profile a freshly built noise-prediction network with thop on a single sample.
# NOTE: this rebinds `model` to the inner network (the `.model` attribute), not the GaussianDiffusion wrapper.
model = GaussianDiffusion().model
x = torch.randn(1, 2)  # one 2-D input point
t = torch.randint(0, 100, (1,))  # one random integer timestep in [0, 100)

# `profile` returns the multiply-accumulate count and the parameter count for this forward pass.
macs, params = profile(model, inputs=(x, t), verbose=False)
print("MACs:", macs)
# Reported FLOPs use 2 FLOPs per MAC, the same convention as the analytic count.
print("FLOPs:", 2 * macs)


MACs: 28296.0
FLOPs: 56592.0
