# OmniField + SOLID/DDPM CIFAR-10 Training & Super-Resolution Notebook

This notebook wires the provided OmniField Perceiver and SOLID/DDPM diffusion U-Net together for CIFAR-10 training. Diffusion stays on the 32×32 training grid, while coordinate-conditioned super-resolution decoding to higher resolutions is done via OmniField coordinate interpolation (no text inputs).

## Setup

- Uses PyTorch + torchvision, plus `einops`, `tqdm`, and `matplotlib` (all assumed preinstalled).
- Training images are normalized to `[-1, 1]`.
- Diffusion runs at the training grid (32×32).
- OmniField decodes to arbitrary resolutions by rebuilding coordinate embeddings while keeping the latent token budget fixed.

In [None]:
import math
import os
import ssl
from functools import wraps

import torch
import torchvision
from einops import rearrange, repeat
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

# Fix for torchvision dataset download issues.
# NOTE(security): this swaps the default HTTPS context factory for the
# unverified one, so certificate checks are skipped for every HTTPS request
# made through `ssl`'s default context in this process, not only the dataset
# download. Review before running outside a throwaway environment.
ssl._create_default_https_context = ssl._create_unverified_context

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

# Hyperparameters
BATCH_SIZE = 128
IMAGE_SIZE = 32
LR = 1e-4
EPOCHS = 2  # increase as needed
DIFFUSION_STEPS = 1000
DDIM_STEPS = 100
SUPERRES_SIZE = 64  # target size for super-res training/sampling

## Model components (OmniField Perceiver + SOLID/DDPM U-Net)

In [None]:
# === Helper utilities ===
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    return val is not None


def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if val is None:
        return d
    return val


def cache_fn(f):
    """Wrap ``f`` so that its first non-``None`` result is remembered.

    The wrapper accepts an extra keyword ``_cache`` (default ``True``). With
    ``_cache=False`` the call goes straight to ``f`` and the stored value is
    neither read nor written. The stored value is shared across all argument
    combinations; a ``None`` result is never treated as cached.
    """
    cache = None

    @wraps(f)
    def cached_fn(*args, _cache=True, **kwargs):
        nonlocal cache
        if not _cache:
            return f(*args, **kwargs)
        if cache is None:
            cache = f(*args, **kwargs)
        return cache

    return cached_fn


# Positional encoding used by SOLID/DDPM U-Net time embedding
class PositionalEncoding(nn.Module):
    """Sinusoidal embedding of a 1-D batch of timesteps.

    Output layout is ``[sin(t * f_0..f_{h-1}), cos(t * f_0..f_{h-1})]`` with
    ``h = dim // 2`` geometrically spaced frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # Log-spacing between consecutive frequencies (base 10000).
        log_step = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=t.device) * -log_step)
        angles = t[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=1)


# === Perceiver core ===
class PreNorm(nn.Module):
    """Apply LayerNorm to the input (and optionally the ``context`` kwarg) before ``fn``."""

    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = None if context_dim is None else nn.LayerNorm(context_dim)

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        if self.norm_context is not None:
            # A context norm was configured, so a `context` kwarg is required here.
            kwargs['context'] = self.norm_context(kwargs['context'])
        return self.fn(normed, **kwargs)


class GEGLU(nn.Module):
    """Gated GELU: split the last dim in two halves and gate the first with GELU of the second."""

    def forward(self, x):
        value, gate = x.chunk(2, dim=-1)
        return value * nn.functional.gelu(gate)


class FeedForward(nn.Module):
    """Two-layer MLP with a GEGLU non-linearity and ``mult``-times hidden expansion."""

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden = dim * mult
        # The first projection is twice as wide because GEGLU halves it again.
        self.net = nn.Sequential(
            nn.Linear(dim, hidden * 2),
            GEGLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    """Multi-head attention; self-attention when no ``context`` is supplied.

    The most recent (detached) attention weights are kept in ``latest_attn``.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64):
        super().__init__()
        inner_dim = dim_head * heads
        if context_dim is None:
            context_dim = query_dim
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)
        self.latest_attn = None

    def forward(self, x, context=None, mask=None):
        heads = self.heads
        if context is None:
            context = x
        queries = self.to_q(x)
        keys, values = self.to_kv(context).chunk(2, dim=-1)

        def split_heads(t):
            # Fold the head axis into the batch axis.
            return rearrange(t, 'b n (h d) -> (b h) n d', h=heads)

        queries, keys, values = split_heads(queries), split_heads(keys), split_heads(values)
        sim = torch.einsum('b i d, b j d -> b i j', queries, keys) * self.scale
        if mask is not None:
            flat_mask = rearrange(mask, 'b ... -> b (...)')
            flat_mask = repeat(flat_mask, 'b j -> (b h) () j', h=heads)
            # Positions where the mask is False get the most negative finite value.
            sim.masked_fill_(~flat_mask, -torch.finfo(sim.dtype).max)
        attn = sim.softmax(dim=-1)
        self.latest_attn = attn.detach()
        out = torch.einsum('b i j, b j d -> b i d', attn, values)
        return self.to_out(rearrange(out, '(b h) n d -> b n (h d)', h=heads))


# Sinusoidal embeddings for Perceiver latent init
def get_sinusoidal_embeddings(n, d):
    """Return an ``(n, d)`` table with sine in even columns and cosine in odd columns."""
    assert d % 2 == 0, 'latent_dim must be even'
    positions = torch.arange(n, dtype=torch.float).unsqueeze(1)
    inv_freq = torch.exp(torch.arange(0, d, 2).float() * -(math.log(10000.0) / d))
    angles = positions * inv_freq
    # Stacking on a trailing axis and flattening interleaves sin/cos per frequency.
    return torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1).flatten(1)


def add_white_noise(coords, scale=0.01):
    """Return ``coords`` perturbed by Gaussian noise multiplied by ``scale``."""
    noise = torch.randn_like(coords)
    return coords + noise * scale


class CascadedBlock(nn.Module):
    """One encoder stage: fixed sinusoidal latents cross-attend to the input,
    optionally add the previous stage's latents, then self-attend and feed-forward.
    """

    def __init__(self, dim, n_latents, input_dim, cross_heads, cross_dim_head, self_heads, self_dim_head, residual_dim=None):
        super().__init__()
        # Latents are a frozen sinusoidal table (requires_grad=False).
        self.latents = nn.Parameter(get_sinusoidal_embeddings(n_latents, dim), requires_grad=False)
        self.cross_attn = PreNorm(
            dim,
            Attention(dim, input_dim, heads=cross_heads, dim_head=cross_dim_head),
            context_dim=input_dim,
        )
        self.self_attn = PreNorm(dim, Attention(dim, heads=self_heads, dim_head=self_dim_head))
        # Only project the incoming residual when its width differs from this stage.
        needs_proj = residual_dim and residual_dim != dim
        self.residual_proj = nn.Linear(residual_dim, dim) if needs_proj else None
        self.ff = PreNorm(dim, FeedForward(dim))

    def forward(self, x, context, mask=None, residual=None):
        # `x` is accepted for call compatibility; the body does not read it.
        batch = context.size(0)
        latents = repeat(self.latents, 'n d -> b n d', b=batch)
        latents = self.cross_attn(latents, context=context, mask=mask) + latents
        if residual is not None:
            if self.residual_proj is not None:
                residual = self.residual_proj(residual)
            latents = latents + residual
        latents = self.self_attn(latents) + latents
        latents = self.ff(latents) + latents
        return latents


class CascadedPerceiverIO(nn.Module):
    """Stack of ``CascadedBlock`` encoders, four latent self-attention layers,
    and an optional query-driven cross-attention decoder.

    ``forward`` returns the latents when ``queries`` is ``None``; otherwise it
    returns ``to_logits`` applied to the decoded queries.
    """

    def __init__(self, *, input_dim, queries_dim, logits_dim=None, latent_dims=(256, 256, 256), num_latents=(256, 256, 256), cross_heads=4, cross_dim_head=128, self_heads=8, self_dim_head=128, decoder_ff=False):
        super().__init__()
        assert len(latent_dims) == len(num_latents), 'latent_dims and num_latents must match'
        self.input_proj = nn.Sequential(nn.Linear(input_dim, input_dim))

        # Each stage receives the previous stage's width as its residual width
        # (None for the first stage).
        residual_dims = (None, *latent_dims[:-1])
        stages = [
            CascadedBlock(
                dim=width,
                n_latents=count,
                input_dim=input_dim,
                cross_heads=cross_heads,
                cross_dim_head=cross_dim_head,
                self_heads=self_heads,
                self_dim_head=self_dim_head,
                residual_dim=incoming,
            )
            for width, count, incoming in zip(latent_dims, num_latents, residual_dims)
        ]
        self.encoder_blocks = nn.ModuleList(stages)

        final_latent_dim = latent_dims[-1]
        self.decoder_cross_attn = PreNorm(
            queries_dim,
            Attention(queries_dim, final_latent_dim, heads=cross_heads, dim_head=cross_dim_head),
            context_dim=final_latent_dim,
        )
        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None
        self.to_logits = nn.Identity() if logits_dim is None else nn.Linear(queries_dim, logits_dim)
        self.self_attn_blocks = nn.Sequential(*[
            nn.Sequential(
                PreNorm(final_latent_dim, Attention(final_latent_dim, heads=self_heads, dim_head=self_dim_head)),
                PreNorm(final_latent_dim, FeedForward(final_latent_dim)),
            )
            for _ in range(4)
        ])

    def forward(self, data, mask=None, queries=None):
        batch = data.size(0)
        tokens = self.input_proj(data)

        latents = None
        for stage in self.encoder_blocks:
            latents = stage(x=latents, context=tokens, mask=mask, residual=latents)

        for attn_layer, ff_layer in self.self_attn_blocks:
            latents = attn_layer(latents) + latents
            latents = ff_layer(latents) + latents

        if queries is None:
            return latents

        if queries.ndim == 2:
            # Unbatched queries are shared across the batch.
            queries = repeat(queries, 'n d -> b n d', b=batch)
        decoded = self.decoder_cross_attn(queries, context=latents)
        decoded = decoded + queries
        if self.decoder_ff is not None:
            decoded = decoded + self.decoder_ff(decoded)
        return self.to_logits(decoded)

In [None]:
# === SOLID/DDPM U-Net core ===
class ResnetBlock(nn.Module):
    """GroupNorm -> SiLU -> Conv residual block with optional time conditioning.

    Args:
        dim: input channel count.
        dim_out: output channel count (defaults to ``dim``).
        time_emb_dim: width of the time embedding; when ``None`` no time MLP
            is built and ``forward`` must be called without ``time_emb``.
        dropout: dropout probability, or ``None`` for no dropout layer.
        groups: number of GroupNorm groups (must divide ``dim`` and ``dim_out``).
    """

    def __init__(self, dim, dim_out=None, time_emb_dim=None, dropout=None, groups=32):
        super().__init__()
        dim_out = dim if dim_out is None else dim_out
        self.dim, self.dim_out = dim, dim_out
        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=dim)
        self.activation1 = nn.SiLU()
        self.conv1 = nn.Conv2d(dim, dim_out, kernel_size=3, padding=1)
        self.block1 = nn.Sequential(self.norm1, self.activation1, self.conv1)
        self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out)) if time_emb_dim is not None else None
        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=dim_out)
        self.activation2 = nn.SiLU()
        self.dropout = nn.Dropout(dropout) if dropout is not None else nn.Identity()
        self.conv2 = nn.Conv2d(dim_out, dim_out, kernel_size=3, padding=1)
        self.block2 = nn.Sequential(self.norm2, self.activation2, self.dropout, self.conv2)
        # 1x1 conv on the skip path only when the channel count changes.
        self.residual_conv = nn.Conv2d(dim, dim_out, kernel_size=1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        hidden = self.block1(x)
        if time_emb is not None:
            # Broadcast the per-sample time projection over the spatial dims.
            hidden = hidden + self.mlp(time_emb)[..., None, None]
        hidden = self.block2(hidden)
        return hidden + self.residual_conv(x)


class AttentionBlock(nn.Module):
    """Single-head spatial self-attention over all ``h * w`` positions, with a residual."""

    def __init__(self, dim, groups=32):
        super().__init__()
        self.dim, self.dim_out = dim, dim
        self.scale = dim ** (-0.5)
        self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, kernel_size=1)
        self.to_out = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(self.norm(x)).chunk(3, dim=1)
        q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), qkv)
        sim = torch.einsum('b i c, b j c -> b i j', q, k)
        attn = (sim * self.scale).softmax(dim=-1)
        out = torch.einsum('b i j, b j c -> b i c', attn, v)
        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w)
        return self.to_out(out) + x


class ResnetAttentionBlock(nn.Module):
    """A ``ResnetBlock`` followed by an ``AttentionBlock`` at the output width."""

    def __init__(self, dim, dim_out=None, time_emb_dim=None, dropout=None, groups=32):
        super().__init__()
        self.resnet = ResnetBlock(dim, dim_out, time_emb_dim, dropout, groups)
        # Forward `groups` so the attention norm uses the same grouping as the
        # resnet norms instead of always falling back to 32.
        self.attention = AttentionBlock(dim_out if dim_out is not None else dim, groups)

    def forward(self, x, time_emb=None):
        x = self.resnet(x, time_emb)
        return self.attention(x)


class downSample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, dim_in):
        super().__init__()
        self.downsample = nn.Conv2d(dim_in, dim_in, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.downsample(x)


class upSample(nn.Module):
    """Double the spatial resolution (nearest-neighbour) and smooth with a 3x3 convolution."""

    def __init__(self, dim_in):
        super().__init__()
        self.upsample = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(dim_in, dim_in, kernel_size=3, padding=1))

    def forward(self, x):
        return self.upsample(x)


class Unet(nn.Module):
    """Time-conditioned U-Net noise predictor.

    Args:
        dim: base channel width; must be divisible by ``groups``.
        image_size: input resolution, used to decide which levels get attention.
        dim_multiply: per-level channel multipliers applied to ``dim``.
        channel: number of output channels (and, by default, input channels).
        num_res_blocks: residual blocks per down level (up levels use one more).
        attn_resolutions: resolutions at which attention blocks are used.
        dropout: dropout probability forwarded to every residual block.
        device: stored on ``self.device`` only; it does not place parameters.
        groups: number of GroupNorm groups.
        in_channels: channel count expected by the stem convolution. Defaults
            to ``channel`` so the network accepts the plain noised image that
            ``forward`` receives; pass e.g. ``channel * 3`` when the caller
            concatenates extra conditioning channels onto ``x`` itself.
    """

    def __init__(self, dim, image_size, dim_multiply=(1, 2, 4, 8), channel=3, num_res_blocks=2, attn_resolutions=(16,), dropout=0, device='cuda', groups=32, in_channels=None):
        super().__init__()
        assert dim % groups == 0, 'dim must be divisible by groups'
        self.dim = dim
        self.channel = channel
        self.in_channels = channel if in_channels is None else in_channels
        self.time_emb_dim = 4 * self.dim
        self.num_resolutions = len(dim_multiply)
        self.device = device
        self.resolution = [int(image_size / (2 ** i)) for i in range(self.num_resolutions)]
        self.hidden_dims = [self.dim, *map(lambda x: x * self.dim, dim_multiply)]
        self.num_res_blocks = num_res_blocks

        positional_encoding = PositionalEncoding(self.dim)
        self.time_mlp = nn.Sequential(positional_encoding, nn.Linear(self.dim, self.time_emb_dim), nn.SiLU(), nn.Linear(self.time_emb_dim, self.time_emb_dim))

        self.down_path = nn.ModuleList([])
        self.up_path = nn.ModuleList([])
        # Channel widths of every activation pushed onto the skip stack, in
        # push order; the up path pops them to size its concatenated inputs.
        concat_dim = list()

        # The stem must match the channel count of the tensor passed to
        # forward(); a mismatch makes the very first convolution fail.
        self.init_conv = nn.Conv2d(self.in_channels, self.dim, kernel_size=3, padding=1)
        concat_dim.append(self.dim)

        for level in range(self.num_resolutions):
            d_in, d_out = self.hidden_dims[level], self.hidden_dims[level + 1]
            for block in range(num_res_blocks):
                d_in_ = d_in if block == 0 else d_out
                if self.resolution[level] in attn_resolutions:
                    self.down_path.append(ResnetAttentionBlock(d_in_, d_out, self.time_emb_dim, dropout, groups))
                else:
                    self.down_path.append(ResnetBlock(d_in_, d_out, self.time_emb_dim, dropout, groups))
                concat_dim.append(d_out)
            if level != self.num_resolutions - 1:
                self.down_path.append(downSample(d_out))
                concat_dim.append(d_out)

        mid_dim = self.hidden_dims[-1]
        self.middle_resnet_attention = ResnetAttentionBlock(mid_dim, mid_dim, self.time_emb_dim, dropout, groups)
        self.middle_resnet = ResnetBlock(mid_dim, mid_dim, self.time_emb_dim, dropout, groups)

        for level in reversed(range(self.num_resolutions)):
            d_out = self.hidden_dims[level + 1]
            for block in range(num_res_blocks + 1):
                # The first block of a level receives the (wider) output of the
                # level below, except at the bottom level which follows the middle.
                d_in = self.hidden_dims[level + 2] if block == 0 and level != self.num_resolutions - 1 else d_out
                d_in = d_in + concat_dim.pop()
                if self.resolution[level] in attn_resolutions:
                    self.up_path.append(ResnetAttentionBlock(d_in, d_out, self.time_emb_dim, dropout, groups))
                else:
                    self.up_path.append(ResnetBlock(d_in, d_out, self.time_emb_dim, dropout, groups))
            if level != 0:
                self.up_path.append(upSample(d_out))
        assert not concat_dim, 'Concat mismatch'

        final_ch = self.hidden_dims[1]
        self.final_norm = nn.GroupNorm(groups, final_ch)
        self.final_activation = nn.SiLU()
        self.final_conv = nn.Conv2d(final_ch, channel, kernel_size=3, padding=1)

    def forward(self, x, time, sparse_input=None, mask=None, x_coarse=None):
        """Predict a ``channel``-channel output for input ``x`` at timesteps ``time``.

        ``sparse_input``, ``mask`` and ``x_coarse`` are accepted for signature
        compatibility but are not read by this implementation.
        """
        t = self.time_mlp(time)
        concat = list()
        x = self.init_conv(x)
        concat.append(x)
        for layer in self.down_path:
            # Resampling layers take no time embedding.
            x = layer(x, t) if not isinstance(layer, (upSample, downSample)) else layer(x)
            concat.append(x)
        x = self.middle_resnet_attention(x, t)
        x = self.middle_resnet(x, t)
        for layer in self.up_path:
            if not isinstance(layer, upSample):
                # Every residual block in the up path consumes one skip tensor.
                x = torch.cat((x, concat.pop()), dim=1)
            x = layer(x, t) if not isinstance(layer, (upSample, downSample)) else layer(x)
        x = self.final_activation(self.final_norm(x))
        return self.final_conv(x)

In [None]:
# === Gaussian Diffusion & DDIM ===
class GaussianDiffusion(nn.Module):
    """DDPM wrapper around a noise-prediction network.

    Args:
        model: network called as ``model(x_t, t)``; must expose ``channel``
            and have at least one parameter.
        image_size: side length training images must have.
        time_step: number of diffusion steps.
        loss_type: ``'l1'``, ``'l2'`` or ``'huber'``.

    ``forward(img)`` returns the training loss; ``sample`` runs the full
    reverse chain and returns images rescaled to ``[0, 1]``.
    """

    def __init__(self, model, image_size, time_step=DIFFUSION_STEPS, loss_type='l2'):
        super().__init__()
        self.unet = model
        self.channel = self.unet.channel
        self.image_size = image_size
        self.time_step = time_step
        self.loss_type = loss_type

        beta = self.linear_beta_schedule()
        alpha = 1. - beta
        alpha_bar = torch.cumprod(alpha, dim=0)
        # alpha_bar shifted right by one step, with 1.0 for the step before t=0.
        alpha_bar_prev = nn.functional.pad(alpha_bar[:-1], pad=(1, 0), value=1.)

        self.register_buffer('beta', beta)
        self.register_buffer('alpha', alpha)
        self.register_buffer('alpha_bar', alpha_bar)
        self.register_buffer('alpha_bar_prev', alpha_bar_prev)
        self.register_buffer('sqrt_alpha_bar', torch.sqrt(alpha_bar))
        self.register_buffer('sqrt_one_minus_alpha_bar', torch.sqrt(1 - alpha_bar))
        self.register_buffer('beta_tilde', beta * ((1. - alpha_bar_prev) / (1. - alpha_bar)))
        self.register_buffer('mean_tilde_x0_coeff', beta * torch.sqrt(alpha_bar_prev) / (1 - alpha_bar))
        self.register_buffer('mean_tilde_xt_coeff', torch.sqrt(alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar))
        self.register_buffer('sqrt_recip_alpha_bar', torch.sqrt(1. / alpha_bar))
        self.register_buffer('sqrt_recip_alpha_bar_min_1', torch.sqrt(1. / alpha_bar - 1))
        self.register_buffer('sqrt_recip_alpha', torch.sqrt(1. / alpha))
        self.register_buffer('beta_over_sqrt_one_minus_alpha_bar', beta / torch.sqrt(1. - alpha_bar))

    @property
    def device(self):
        """Device of the wrapped network's parameters, looked up on every access.

        Resolving it lazily keeps it correct after the module is moved with
        ``.to(...)``; a value captured once in ``__init__`` would go stale.
        """
        return next(self.unet.parameters()).device

    def linear_beta_schedule(self):
        """Return ``time_step`` linearly spaced betas, rescaled so 1000 steps span 1e-4..2e-2."""
        scale = 1000 / self.time_step
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return torch.linspace(beta_start, beta_end, self.time_step, dtype=torch.float32)

    def q_sample(self, x0, t, noise):
        """Noise ``x0`` to step ``t`` (per-sample index tensor) using the given ``noise``."""
        return self.sqrt_alpha_bar[t][:, None, None, None] * x0 + self.sqrt_one_minus_alpha_bar[t][:, None, None, None] * noise

    def forward(self, img):
        """Return the noise-prediction loss for a batch of images at random timesteps.

        Raises:
            AssertionError: if the spatial size differs from ``image_size``.
            NotImplementedError: if ``loss_type`` is not one of the supported names.
        """
        b, c, h, w = img.shape
        assert h == self.image_size and w == self.image_size
        t = torch.randint(0, self.time_step, (b,), device=img.device).long()
        noise = torch.randn_like(img)
        noised_image = self.q_sample(img, t, noise)
        model_input = noised_image
        predicted_noise = self.unet(model_input, t)
        if self.loss_type == 'l1':
            loss = nn.functional.l1_loss(noise, predicted_noise)
        elif self.loss_type == 'l2':
            loss = nn.functional.mse_loss(noise, predicted_noise)
        elif self.loss_type == 'huber':
            loss = nn.functional.smooth_l1_loss(noise, predicted_noise)
        else:
            raise NotImplementedError()
        return loss

    @torch.inference_mode()
    def p_sample(self, xt, t, clip=True):
        """Take one reverse step from ``xt`` at integer timestep ``t``.

        With ``clip=True`` the x0 estimate is clamped to ``[-1, 1]`` before the
        posterior mean is formed; otherwise the mean is computed directly from
        the predicted noise.
        """
        # Build the timestep batch on the input's device so it always matches xt.
        batched_time = torch.full((xt.shape[0],), t, device=xt.device, dtype=torch.long)
        pred_noise = self.unet(xt, batched_time)
        if clip:
            x0 = self.sqrt_recip_alpha_bar[t] * xt - self.sqrt_recip_alpha_bar_min_1[t] * pred_noise
            x0.clamp_(-1., 1.)
            mean = self.mean_tilde_x0_coeff[t] * x0 + self.mean_tilde_xt_coeff[t] * xt
        else:
            mean = self.sqrt_recip_alpha[t] * (xt - self.beta_over_sqrt_one_minus_alpha_bar[t] * pred_noise)
        variance = self.beta_tilde[t]
        # No noise is added on the final step (t == 0).
        noise = torch.randn_like(xt) if t > 0 else 0.
        return mean + torch.sqrt(variance) * noise

    @torch.inference_mode()
    def sample(self, batch_size=16, return_all_timestep=False, clip=True):
        """Run the full reverse chain from Gaussian noise.

        Returns the final images, or every intermediate state stacked on
        dim 1 when ``return_all_timestep`` is set, clamped and mapped to ``[0, 1]``.
        """
        xT = torch.randn([batch_size, self.channel, self.image_size, self.image_size], device=self.device)
        denoised = [xT]
        xt = xT
        for t in tqdm(reversed(range(0, self.time_step)), desc='DDPM Sampling', total=self.time_step, leave=False):
            xt = self.p_sample(xt, t, clip=clip)
            denoised.append(xt)
        images = xt if not return_all_timestep else torch.stack(denoised, dim=1)
        images.clamp_(min=-1.0, max=1.0)
        images = (images + 1.0) * 0.5
        return images


class DDIM_Sampler(nn.Module):
    """DDIM sampler that reuses a ``GaussianDiffusion`` model on a subset of its timesteps.

    Args:
        diffusion_model: object exposing ``time_step``, ``alpha_bar``, ``unet``,
            ``channel``, ``image_size``, ``device`` and the ``sqrt_recip_*`` buffers.
        ddim_sampling_steps: number of sampling steps (at most ``time_step``).
        eta: scale of the per-step noise term; 0 adds none.
        clip: stored on the instance; the ``sample``/``ddim_p_sample`` calls
            take their own ``clip`` argument.
    """

    def __init__(self, diffusion_model, ddim_sampling_steps=DDIM_STEPS, eta=0, clip=True):
        super().__init__()
        self.ddim_steps = ddim_sampling_steps
        self.eta = eta
        self.clip = clip
        self.model = diffusion_model
        ddpm_steps = diffusion_model.time_step
        assert self.ddim_steps <= ddpm_steps
        alpha_bar = diffusion_model.alpha_bar
        # Timestep indices visited by the sampler; the leading -1 is dropped so
        # the last entry is ddpm_steps - 1.
        self.register_buffer('tau', torch.linspace(-1, ddpm_steps - 1, steps=self.ddim_steps + 1, dtype=torch.long)[1:])
        alpha_tau_i = alpha_bar[self.tau]
        # alpha_bar at the previous sampled step, with 1.0 before the first one.
        alpha_tau_i_min_1 = nn.functional.pad(alpha_bar[self.tau[:-1]], pad=(1, 0), value=1.)
        self.register_buffer('sigma', eta * (((1 - alpha_tau_i_min_1) / (1 - alpha_tau_i) * (1 - alpha_tau_i / alpha_tau_i_min_1)).sqrt()))
        self.register_buffer('coeff', (1 - alpha_tau_i_min_1 - self.sigma ** 2).sqrt())
        self.register_buffer('sqrt_alpha_i_min_1', alpha_tau_i_min_1.sqrt())
        # At the first sampled step the update must reduce to returning x0.
        assert self.coeff[0] == 0.0 and self.sqrt_alpha_i_min_1[0] == 1.0

    @torch.inference_mode()
    def ddim_p_sample(self, xt, i, clip=True):
        """Take one DDIM step at sampler index ``i`` (an index into ``tau``)."""
        t = self.tau[i]
        # Build the timestep batch on the input's device so it always matches xt.
        batched_time = torch.full((xt.shape[0],), t, device=xt.device, dtype=torch.long)
        pred_noise = self.model.unet(xt, batched_time)
        x0 = self.model.sqrt_recip_alpha_bar[t] * xt - self.model.sqrt_recip_alpha_bar_min_1[t] * pred_noise
        if clip:
            x0.clamp_(-1., 1.)
            # Re-derive the noise so it stays consistent with the clamped x0.
            pred_noise = (self.model.sqrt_recip_alpha_bar[t] * xt - x0) / self.model.sqrt_recip_alpha_bar_min_1[t]
        mean = self.sqrt_alpha_i_min_1[i] * x0 + self.coeff[i] * pred_noise
        noise = torch.randn_like(xt) if i > 0 else 0.
        return mean + self.sigma[i] * noise

    @torch.inference_mode()
    def sample(self, batch_size, noise=None, return_all_timestep=False, clip=True):
        """Run all DDIM steps, starting from ``noise`` when given, else fresh Gaussian noise.

        Returns the final images, or every intermediate state stacked on
        dim 1 when ``return_all_timestep`` is set, clamped and mapped to ``[0, 1]``.
        """
        xT = torch.randn([batch_size, self.model.channel, self.model.image_size, self.model.image_size], device=self.model.device) if noise is None else noise.to(self.model.device)
        denoised = [xT]
        xt = xT
        for i in tqdm(reversed(range(0, self.ddim_steps)), desc='DDIM Sampling', total=self.ddim_steps, leave=False):
            xt = self.ddim_p_sample(xt, i, clip=clip)
            denoised.append(xt)
        images = xt if not return_all_timestep else torch.stack(denoised, dim=1)
        images.clamp_(min=-1.0, max=1.0)
        images = (images + 1.0) * 0.5
        return images

## Coordinate features for super-resolution

In [None]:
def build_coord_grid(h, w, device, add_r=True):
    """Return an ``(h, w, 2)`` grid of ``(x, y)`` coordinates spanning ``[-1, 1]``.

    With ``add_r=True`` a third channel holding the radius
    ``sqrt(x**2 + y**2)`` is appended, giving shape ``(h, w, 3)``.
    """
    y, x = torch.meshgrid(
        torch.linspace(-1, 1, steps=h, device=device),
        torch.linspace(-1, 1, steps=w, device=device),
        indexing='ij'
    )
    coords = torch.stack([x, y], dim=-1)
    if add_r:
        r = torch.sqrt(x ** 2 + y ** 2)
        coords = torch.cat([coords, r[..., None]], dim=-1)
    return coords  # (h, w, 2 or 3)


# OmniField decoder wraps CascadedPerceiverIO to decode coordinates -> RGB
class OmniFieldDecoder(nn.Module):
    """Decode a coarse image to an arbitrary target grid via coordinate queries.

    Context tokens are the coarse pixels concatenated with their ``(x, y, r)``
    coordinates; queries are the ``(x, y, r)`` coordinates of the target grid.

    Args:
        input_dim: width of each context token. It must equal the coarse
            image's channel count plus 3 coordinate features, so RGB inputs
            need ``input_dim=6``. The default of 5 is kept unchanged for
            backward compatibility.
        queries_dim: width of the projected coordinate queries.
        logits_dim: number of output channels per target pixel.
    """

    def __init__(self, input_dim=5, queries_dim=128, logits_dim=3):
        super().__init__()
        self.perceiver = CascadedPerceiverIO(
            input_dim=input_dim,
            queries_dim=queries_dim,
            logits_dim=logits_dim,
            latent_dims=(256, 256, 256),
            num_latents=(128, 128, 128),
            cross_heads=4,
            cross_dim_head=64,
            self_heads=4,
            self_dim_head=64,
        )
        # Queries are always built from the 3 coordinate features (x, y, r).
        self.query_proj = nn.Sequential(
            nn.Linear(3, queries_dim),
            nn.SiLU(),
            nn.Linear(queries_dim, queries_dim)
        )
        self.context_proj = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.SiLU(),
            nn.Linear(input_dim, input_dim)
        )

    def forward(self, coarse_img, target_h, target_w):
        """Return a ``(b, logits_dim, target_h, target_w)`` image squashed to ``[-1, 1]`` by tanh."""
        b, c, h, w = coarse_img.shape
        coords_coarse = build_coord_grid(h, w, coarse_img.device, add_r=True)
        coords_coarse = coords_coarse.view(1, h * w, -1).repeat(b, 1, 1)
        rgb_tokens = rearrange(coarse_img, 'b c h w -> b (h w) c')
        context = torch.cat([rgb_tokens, coords_coarse], dim=-1)
        context = self.context_proj(context)
        coords_target = build_coord_grid(target_h, target_w, coarse_img.device, add_r=True)
        queries = self.query_proj(coords_target.view(1, target_h * target_w, -1))
        queries = queries.repeat(b, 1, 1)
        logits = self.perceiver(context, queries=queries)
        # Take the channel count from the logits themselves so any
        # `logits_dim` reshapes correctly, not just 3.
        sr = logits.view(b, target_h, target_w, logits.shape[-1]).permute(0, 3, 1, 2)
        sr = sr.tanh()  # keep outputs in [-1,1]
        return sr

## Data loaders (CIFAR-10)

In [None]:
# Resize to the training grid, convert to a tensor, then normalize each channel
# with mean 0.5 / std 0.5.
_channel_mean = (0.5, 0.5, 0.5)
_channel_std = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])

# Settings shared by both loaders; only shuffling differs between them.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)

train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)

test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

# Last expression of the cell: shows both dataset sizes in the notebook output.
len(train_dataset), len(test_dataset)

## Instantiate models

In [None]:
# Diffusion stack: noise-prediction U-Net -> DDPM wrapper -> DDIM sampler.
unet = Unet(dim=128, image_size=IMAGE_SIZE, channel=3)
unet = unet.to(device)

diffusion = GaussianDiffusion(
    unet,
    image_size=IMAGE_SIZE,
    time_step=DIFFUSION_STEPS,
)
diffusion = diffusion.to(device)

ddim_sampler = DDIM_Sampler(
    diffusion,
    ddim_sampling_steps=DDIM_STEPS,
    eta=0,
    clip=True,
)
ddim_sampler = ddim_sampler.to(device)

# Super-resolution decoder: 3 image channels + 3 coordinate features per token.
omnifield_decoder = OmniFieldDecoder(
    input_dim=6,
    queries_dim=128,
    logits_dim=3,
)
omnifield_decoder = omnifield_decoder.to(device)

# One optimizer per training stage, both at the shared learning rate.
opt = torch.optim.AdamW(diffusion.parameters(), lr=LR)
opt_sr = torch.optim.AdamW(omnifield_decoder.parameters(), lr=LR)

## Training loop (diffusion on 32×32 grid)

In [None]:
# Train the diffusion model; the value printed per epoch is the loss of the
# final batch, not an epoch average.
diffusion.train()
for epoch in range(EPOCHS):
    epoch_label = f'Epoch {epoch+1}/{EPOCHS}'
    pbar = tqdm(train_loader, desc=epoch_label, leave=False)
    for imgs, _ in pbar:
        imgs = imgs.to(device)

        # Clear stale gradients, then run forward and backward.
        opt.zero_grad()
        loss = diffusion(imgs)
        loss.backward()

        # Cap the global gradient norm at 1.0 before the update.
        nn.utils.clip_grad_norm_(diffusion.parameters(), 1.0)
        opt.step()

        pbar.set_postfix(loss=f"{loss.item():.4f}")
    print(f'Epoch {epoch+1} diffusion loss: {loss.item():.4f}')

## Optional: train OmniField decoder for coordinate-conditioned super-resolution

In [None]:
# This stage teaches OmniField to map 32×32 coarse images to SUPERRES_SIZE outputs using only coordinates + coarse RGB.
# It keeps the latent token count fixed while interpolating coordinates at the target resolution.
omnifield_decoder.train()
resize_to_sr = transforms.Resize((SUPERRES_SIZE, SUPERRES_SIZE))
for epoch in range(1):  # increase if you want stronger super-res
    pbar = tqdm(train_loader, desc=f'OmniField SR epoch {epoch+1}', leave=False)
    for imgs, _ in pbar:
        imgs = imgs.to(device)

        # Build the high-resolution target without tracking gradients:
        # map to [0, 1], resize, then map back.
        with torch.no_grad():
            unit_range = (imgs + 1) / 2
            highres_gt = resize_to_sr(unit_range) * 2 - 1  # keep in [-1,1]

        opt_sr.zero_grad()
        sr_pred = omnifield_decoder(imgs, SUPERRES_SIZE, SUPERRES_SIZE)
        loss_sr = nn.functional.l1_loss(sr_pred, highres_gt)
        loss_sr.backward()

        # Cap the global gradient norm at 1.0 before the update.
        nn.utils.clip_grad_norm_(omnifield_decoder.parameters(), 1.0)
        opt_sr.step()

        pbar.set_postfix(sr_l1=f"{loss_sr.item():.4f}")
    print(f'OmniField SR epoch {epoch+1} loss: {loss_sr.item():.4f}')

## Sampling: base 32×32 and super-res decode

In [None]:
# Switch both models to evaluation mode before sampling.
diffusion.eval()
omnifield_decoder.eval()

with torch.inference_mode():
    # The sampler returns images in [0, 1]; the decoder expects [-1, 1].
    base_samples = ddim_sampler.sample(batch_size=8)
    decoder_input = base_samples * 2 - 1
    sr_samples = omnifield_decoder(
        decoder_input,
        target_h=SUPERRES_SIZE,
        target_w=SUPERRES_SIZE,
    )
    # Map the decoder output back to [0, 1] for display.
    sr_samples = (sr_samples + 1) * 0.5

import matplotlib.pyplot as plt


def show_grid(images, title):
    """Display a batch of images as a grid, four per row, under ``title``.

    Images are moved to the CPU and clamped to ``[0, 1]`` before tiling.
    """
    clamped = images.cpu().clamp(0, 1)
    tiled = torchvision.utils.make_grid(clamped, nrow=4)
    plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.title(title)
    # make_grid returns channels-first; imshow wants channels-last.
    plt.imshow(tiled.permute(1, 2, 0))
    plt.show()


show_grid(base_samples, 'DDIM samples @32x32')
show_grid(sr_samples, f'OmniField super-res @{SUPERRES_SIZE}x{SUPERRES_SIZE}')

## Checkpoint saving

In [None]:
# Write both models' state dicts into the local checkpoints directory,
# creating it if needed.
os.makedirs('checkpoints', exist_ok=True)
for _module, _target in (
    (diffusion, 'checkpoints/solid_ddpm_cifar10.pt'),
    (omnifield_decoder, 'checkpoints/omnifield_sr.pt'),
):
    torch.save(_module.state_dict(), _target)