In [None]:
import math, random, numpy as np
import torch, torch.nn.functional as F
from torchvision import models
import torchvision.transforms.functional as TF
from PIL import Image

device = 'cpu'  # this cell pins the model and all tensors to the CPU

# ---- 1) Model (GoogLeNet ~ Inception-family) ----
# Pretrained classifier used purely as a fixed scoring function (eval mode).
weights = models.GoogLeNet_Weights.IMAGENET1K_V1
net = models.googlenet(weights=weights, aux_logits=True).to(device).eval()
# Only the image parameters are optimised, so gradients w.r.t. the weights are disabled.
for p in net.parameters(): p.requires_grad_(False)

# Per-channel statistics indexed to shape [3,1,1] so they broadcast over [N,3,H,W] batches.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=device)[:, None, None]
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225], device=device)[:, None, None]
CATEGORIES = weights.meta["categories"]  # list of 1000 labels ordered by index

def class_idx_from_name(name, categories=None):
    """Resolve a human-readable label to its index in `categories`.

    Matching is attempted from most to least specific:
      1. exact match of `name` as given (so labels that differ only in case stay distinct),
      2. exact match of the lower-cased name,
      3. case-insensitive exact match,
      4. first category that contains the lower-cased name as a substring.

    Args:
        name: label to look up, e.g. "dumbbell".
        categories: list of label strings ordered by class index; defaults to the
            module-level CATEGORIES, looked up at call time.

    Returns:
        The integer index of the matched category.

    Raises:
        ValueError: if `name` is empty or matches no category.
    """
    if categories is None:
        categories = CATEGORIES
    if name in categories:
        return categories.index(name)
    query = name.lower()
    if query in categories:
        return categories.index(query)
    lowered = [label.lower() for label in categories]
    if query in lowered:
        return lowered.index(query)
    # An empty query is a substring of every label; treat it as "no match" instead.
    hits = [i for i, label in enumerate(lowered) if query in label] if query else []
    if not hits:
        raise ValueError(f"'{query}' not found in ImageNet categories.")
    return hits[0]

# ---- 2) Fourier parameterization helpers ----
def radial_freq(h, w, device):
    """Radial frequency magnitude for every bin of an rFFT2 spectrum.

    Returns a tensor of shape [h, w//2 + 1] on `device`; values are floored at 1e-6 so
    the zero-frequency bin can be used as a divisor.
    """
    rows = torch.fft.fftfreq(h, d=1.0).to(device).unsqueeze(1)   # [h, 1]
    cols = torch.fft.rfftfreq(w, d=1.0).to(device).unsqueeze(0)  # [1, w//2 + 1]
    magnitude = (cols * cols + rows * rows).sqrt()
    return magnitude.clamp(min=1e-6)

def make_fft_params(h=384, w=384):
    """Create the trainable spectrum and its matching frequency grid.

    Returns:
        (spec, freqs): `spec` is a random-normal leaf tensor of shape [1, 3, h, w//2+1, 2]
        (last axis holds real/imaginary parts) with requires_grad=True; `freqs` is the
        output of radial_freq(h, w, device).
    """
    half_width = w // 2 + 1
    spectrum = torch.randn(1, 3, h, half_width, 2, device=device, requires_grad=True)
    return spectrum, radial_freq(h, w, device)

def spectrum_to_image(spec, freqs, decay_power=1.5):
    """Inverse FFT with 1/f^decay frequency decay -> natural image bias, map to [0,1].

    Args:
        spec: real tensor [1, 3, H, W//2+1, 2] holding real/imaginary spectrum parts.
        freqs: radial frequency grid [H, W//2+1]; bin [0, 0] is the zero-frequency (DC) bin.
        decay_power: exponent of the 1/f^p amplitude falloff.

    Returns:
        Tensor [1, 3, H, W] with values in [0, 1], differentiable w.r.t. `spec`.
    """
    # shape: [1,3,H,W//2+1,2] -> complex
    complex_spec = torch.view_as_complex(spec)
    scale = 1.0 / (freqs ** decay_power)  # low-freq emphasis
    # The DC frequency is clamped to a tiny value upstream, so its scale is enormous
    # (~1e9 for decay_power=1.5). Left in, it turns the image into a huge constant that
    # saturates tanh and zeroes every gradient. Drop the DC term instead; the image is
    # then zero-mean before the contrast normalisation below.
    scale[..., 0, 0] = 0.0
    scaled = complex_spec * scale
    img = torch.fft.irfft2(scaled, s=(spec.shape[2], (spec.shape[3]-1)*2), norm='ortho')
    img = img / (img.std(dim=(-2, -1), keepdim=True) + 1e-8)   # stabilize contrast
    img = torch.tanh(img) * 0.5 + 0.5                          # [0,1] with gradient
    return img

# ---- 3) Priors ----
def tv_loss(x):
    """Squared total-variation penalty: mean squared difference between horizontally
    adjacent pixels plus the same for vertically adjacent pixels (last two axes)."""
    horizontal = x[..., :, 1:] - x[..., :, :-1]
    vertical = x[..., 1:, :] - x[..., :-1, :]
    return horizontal.pow(2).mean() + vertical.pow(2).mean()

def blur3(x):
    """3x3 box blur (stride 1, one pixel of padding) that keeps the spatial size;
    used to suppress ringing."""
    kernel = 3
    return F.avg_pool2d(x, kernel, stride=1, padding=kernel // 2)

# ---- 4) Data augmentation for transform robustness ----
def random_view(x, out_hw=(224, 224), jitter=24, rot=15, scale_range=(0.92, 1.08), flip=True, noise_std=0.02):
    """Produce one randomly augmented, ImageNet-normalised view of `x`.

    Args:
        x: image tensor [1, 3, H, W].
        out_hw: output (height, width) after bilinear resizing.
        jitter: maximum circular shift in pixels along each spatial axis.
        rot: maximum absolute rotation in degrees.
        scale_range: (low, high) range of the affine scale factor.
        flip: if True, mirror horizontally with probability 0.5.
        noise_std: standard deviation of additive Gaussian noise (0 disables it).

    Returns:
        Tensor [1, 3, out_h, out_w], mean/std normalised.
    """
    _n, _c, height, width = x.shape

    # Draw the geometric parameters (NumPy RNG for the roll, `random` for the affine).
    shift_a, shift_b = np.random.randint(-jitter, jitter + 1, 2)
    angle = float(random.uniform(-rot, rot))
    zoom = float(random.uniform(*scale_range))
    offset_x = int(random.uniform(-0.06 * width, 0.06 * width))
    offset_y = int(random.uniform(-0.06 * height, 0.06 * height))

    # Circular jitter, then a small rotation / scale / translation.
    view = torch.roll(x, shifts=(shift_a, shift_b), dims=(-1, -2))
    view = TF.affine(view, angle=angle, translate=(offset_x, offset_y), scale=zoom, shear=[0.0, 0.0])

    if flip and random.random() < 0.5:
        view = TF.hflip(view)

    # Resize to the resolution fed to the model.
    view = F.interpolate(view, size=out_hw, mode='bilinear', align_corners=False)

    if noise_std > 0:
        noisy = view + noise_std * torch.randn_like(view)
        view = noisy.clamp(0, 1)

    return (view - IMAGENET_MEAN) / IMAGENET_STD

# ---- 5) Main optimization ----
@torch.no_grad()
def to_pil(x):
    """Convert an image tensor to a PIL image: squeeze dim 0, clamp to [0, 1], move to CPU."""
    image = x.squeeze(0)
    image = image.clamp(0, 1)
    return TF.to_pil_image(image.cpu())

def synthesize_class(
    class_name="dumbbell",
    steps=700,
    n_views=8,
    size=384,
    lr=0.08,
    tv_w=1e-4,
    l2_w=1e-6,
    decay_power=1.5,
    blur_every=60,
    seed=None
):
    """Synthesize an image that maximises one class logit of `net` by gradient ascent.

    The image is parameterised in Fourier space (see make_fft_params / spectrum_to_image)
    and scored on a batch of randomly augmented views at every step.

    Args:
        class_name: label resolved with class_idx_from_name.
        steps: number of Adam updates.
        n_views: augmented views per step.
        size: side length of the square canvas in pixels.
        lr: Adam learning rate for the spectrum.
        tv_w: weight of the total-variation prior.
        l2_w: weight of the squared distance from mid-grey (0.5).
        decay_power: 1/f^p exponent passed to spectrum_to_image.
        blur_every: blur the rendered image on every `blur_every`-th step (falsy disables).
        seed: optional int; when given, seeds `random`, NumPy and torch so the run
            (initial spectrum and augmentations) is repeatable.

    Returns:
        A PIL image rendered from the final spectrum.
    """
    if seed is not None:
        # random_view draws from all three generators, so seed each of them.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    target = class_idx_from_name(class_name)  # e.g., "dumbbell" -> correct ImageNet index
    spec, freqs = make_fft_params(size, size)
    opt = torch.optim.Adam([spec], lr=lr)

    for t in range(steps):
        img = spectrum_to_image(spec, freqs, decay_power=decay_power)

        # The blur only affects this step's forward pass (and therefore its gradient).
        if blur_every and t > 0 and t % blur_every == 0:
            img = blur3(img)

        # build a batch of randomly transformed views
        batch = torch.cat([random_view(img) for _ in range(n_views)], dim=0)
        logits = net(batch)
        cls = logits[:, target].mean()

        # natural image priors
        tv = tv_loss(img)
        l2 = ((img - 0.5) ** 2).mean()

        # maximize class score with priors
        loss = cls - tv_w * tv - l2_w * l2

        opt.zero_grad(set_to_none=True)
        (-loss).backward()   # gradient ascent
        opt.step()

        # gentle spectrum damping prevents exploding amplitudes
        with torch.no_grad():
            spec.mul_(0.995)

    final = spectrum_to_image(spec, freqs, decay_power=decay_power)
    return to_pil(final)

# ---- Example:
# Alternative settings kept for reference (shorter run, stronger frequency decay):
# out = synthesize_class("dumbbell", steps=100, n_views=8, size=448, lr=0.01, decay_power=1.7)
# Runs the full optimisation when the cell executes, then writes the PIL image to disk.
out = synthesize_class("dumbbell", steps=1_000, n_views=8, size=448, lr=0.01, decay_power=0.5)
out.save("dumbbell_classviz.png")




In [None]:
import math, random, numpy as np
import torch, torch.nn.functional as F
from torchvision import models
import torchvision.transforms.functional as TF

device = 'cuda' if torch.cuda.is_available() else 'cpu'  # prefer the GPU when one is visible

# ===== Model (GoogLeNet / Inception-family) =====
# Pretrained classifier used purely as a fixed scoring function (eval mode).
weights = models.GoogLeNet_Weights.IMAGENET1K_V1
net = models.googlenet(weights=weights, aux_logits=True).to(device).eval()
# Only the image parameters are optimised, so gradients w.r.t. the weights are disabled.
for p in net.parameters(): p.requires_grad_(False)

# Per-channel statistics indexed to shape [3,1,1] so they broadcast over [N,3,H,W] batches.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=device)[:, None, None]
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225], device=device)[:, None, None]
CATEGORIES = weights.meta["categories"]  # label strings, indexed by class id

def class_idx_from_name(name, categories=None):
    """Resolve a human-readable label to its index in `categories`.

    Matching order: exact match of `name` as given, exact match of the lower-cased name,
    case-insensitive exact match, then the first category containing the lower-cased
    name as a substring.

    Args:
        name: label to look up, e.g. "dumbbell".
        categories: list of label strings ordered by class index; defaults to the
            module-level CATEGORIES, looked up at call time.

    Returns:
        The integer index of the matched category.

    Raises:
        ValueError: if `name` is empty or matches no category.
    """
    if categories is None:
        categories = CATEGORIES
    if name in categories:
        return categories.index(name)
    query = name.lower()
    if query in categories:
        return categories.index(query)
    lowered = [label.lower() for label in categories]
    if query in lowered:
        return lowered.index(query)
    # An empty query is a substring of every label; treat it as "no match" instead.
    hits = [i for i, label in enumerate(lowered) if query in label] if query else []
    if not hits:
        raise ValueError(f"'{query}' not found in ImageNet categories.")
    return hits[0]

# ===== Fourier parameterization =====
def radial_freq(h, w, device):
    """Radial frequency magnitude for every bin of an rFFT2 spectrum.

    Returns a tensor of shape [h, w//2 + 1] on `device`, floored at 1e-6 so the
    zero-frequency bin can be used as a divisor.
    """
    rows = torch.fft.fftfreq(h, d=1.0).to(device).unsqueeze(1)   # [h, 1]
    cols = torch.fft.rfftfreq(w, d=1.0).to(device).unsqueeze(0)  # [1, w//2 + 1]
    magnitude = (cols * cols + rows * rows).sqrt()
    return magnitude.clamp(min=1e-6)

def make_fft_params(h=448, w=448):
    """Create the trainable spectrum and its matching frequency grid.

    Returns:
        (spec, freqs): `spec` is a random-normal leaf tensor [1, 3, h, w//2+1, 2]
        (real/imaginary parts on the last axis) with requires_grad=True; `freqs` is
        radial_freq(h, w, device).
    """
    half_width = w // 2 + 1
    spectrum = torch.randn(1, 3, h, half_width, 2, device=device, requires_grad=True)
    return spectrum, radial_freq(h, w, device)

def spectrum_to_image(spec, freqs, decay_power=1.8, contrast=1.8):
    """Render the spectrum as an image in [0, 1].

    Applies a 1/f^decay_power amplitude falloff, inverts the rFFT2, removes each
    channel's mean and squashes with sigmoid(contrast * x).

    Args:
        spec: real tensor [1, 3, H, W//2+1, 2] of real/imaginary parts.
        freqs: radial frequency grid broadcastable to [H, W//2+1].
        decay_power: exponent of the frequency falloff.
        contrast: gain applied before the sigmoid.

    Returns:
        Tensor [1, 3, H, 2*(W//2)] with values in (0, 1).
    """
    height = spec.shape[2]
    width = 2 * (spec.shape[3] - 1)
    falloff = 1.0 / (freqs ** decay_power)
    shaped = torch.view_as_complex(spec) * falloff
    pixels = torch.fft.irfft2(shaped, s=(height, width), norm='ortho')
    # Per-channel centring keeps the DC term from drifting the brightness.
    centred = pixels - pixels.mean(dim=(-2, -1), keepdim=True)
    return torch.sigmoid(centred * contrast)

# ===== Core priors =====
def tv_l2(x):
    """Squared total-variation penalty: mean squared difference between horizontally
    adjacent pixels plus the same for vertically adjacent pixels (last two axes)."""
    horizontal = x[..., :, 1:] - x[..., :, :-1]
    vertical = x[..., 1:, :] - x[..., :-1, :]
    return horizontal.pow(2).mean() + vertical.pow(2).mean()

def gaussian_kernel1d(sigma):
    """Build a normalised 1-D Gaussian kernel on the module-level `device`.

    Returns:
        (kernel, radius): kernel has shape [1, 1, 2*radius + 1] and sums to 1;
        radius is round(3*sigma), but at least 1.
    """
    radius = max(1, int(round(3*sigma)))
    offsets = torch.arange(-radius, radius+1, device=device, dtype=torch.float32)
    taps = torch.exp(-0.5 * (offsets/sigma)**2)
    taps = taps / taps.sum()
    return taps.view(1, 1, -1), radius

def gaussian_blur(x, sigma):
    """Separable, depthwise Gaussian blur with reflect padding.

    Returns `x` unchanged when sigma <= 0; otherwise a tensor of the same shape.
    """
    if sigma <= 0:
        return x
    kernel, pad = gaussian_kernel1d(sigma)
    channels = x.shape[1]
    # One copy of the 1-D kernel per channel (groups=channels -> no channel mixing).
    row_kernel = kernel.view(1, 1, 1, -1).repeat(channels, 1, 1, 1)
    col_kernel = kernel.view(1, 1, -1, 1).repeat(channels, 1, 1, 1)
    # Horizontal pass, then vertical pass.
    blurred = F.conv2d(F.pad(x, (pad, pad, 0, 0), mode='reflect'), row_kernel, groups=channels)
    blurred = F.conv2d(F.pad(blurred, (0, 0, pad, pad), mode='reflect'), col_kernel, groups=channels)
    return blurred

def corr_loss(x, sigma=1.2):
    """Mean squared difference between `x` and its Gaussian-blurred copy; small values
    mean neighbouring pixels are strongly correlated."""
    residual = x - gaussian_blur(x, sigma)
    return (residual ** 2).mean()

def laplacian_energy(x):
    """Mean squared response of the 4-neighbour Laplacian, applied per channel with
    reflect padding (a second-derivative smoothness penalty)."""
    stencil = torch.tensor(
        [[0, 1, 0],
         [1, -4, 1],
         [0, 1, 0]],
        dtype=torch.float32, device=device,
    )
    channels = x.shape[1]
    kernel = stencil.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
    response = F.conv2d(F.pad(x, (1, 1, 1, 1), mode='reflect'), kernel, groups=channels)
    return response.pow(2).mean()

# ===== Transform-robust multi-view =====
def normalize_for_model(v):
    """Apply ImageNet mean/std normalisation to a [0, 1] image batch.

    The previous version skipped this step whenever `net.transform_input` was True.
    In torchvision's GoogLeNet that flag does not replace external normalisation: it
    takes an ImageNet-normalised tensor and rescales it to the [-1, 1] convention of
    the original weights. Feeding raw [0, 1] pixels therefore gives the network
    inputs at the wrong scale, so normalisation is always applied here.

    Args:
        v: tensor [N, 3, H, W] with values in [0, 1].

    Returns:
        (v - IMAGENET_MEAN) / IMAGENET_STD, same shape as `v`.
    """
    return (v - IMAGENET_MEAN) / IMAGENET_STD

def random_view(x, out_hw=(224, 224), jitter=24, rot=15, scale_range=(0.92, 1.08),
                flip=True, noise_std=0.012):
    """Produce one randomly augmented view of `x`, prepared with normalize_for_model.

    Args:
        x: image tensor [1, 3, H, W].
        out_hw: output (height, width) after bilinear resizing.
        jitter: maximum circular shift in pixels along each spatial axis.
        rot: maximum absolute rotation in degrees.
        scale_range: (low, high) range of the affine scale factor.
        flip: if True, mirror horizontally with probability 0.5.
        noise_std: standard deviation of additive Gaussian noise (0 disables it).
    """
    _n, _c, height, width = x.shape

    # Draw the geometric parameters (NumPy RNG for the roll, `random` for the affine).
    shift_a, shift_b = np.random.randint(-jitter, jitter + 1, 2)
    angle = float(random.uniform(-rot, rot))
    zoom = float(random.uniform(*scale_range))
    offset_x = int(random.uniform(-0.05 * width, 0.05 * width))
    offset_y = int(random.uniform(-0.05 * height, 0.05 * height))

    view = torch.roll(x, shifts=(shift_a, shift_b), dims=(-1, -2))
    view = TF.affine(view, angle=angle, translate=(offset_x, offset_y), scale=zoom, shear=[0.0, 0.0])

    if flip and random.random() < 0.5:
        view = TF.hflip(view)

    view = F.interpolate(view, size=out_hw, mode='bilinear', align_corners=False)

    if noise_std > 0:
        noisy = view + noise_std * torch.randn_like(view)
        view = noisy.clamp(0, 1)

    return normalize_for_model(view)

def cosine_anneal(start, end, t, T):
    """Half-cosine interpolation from `start` (t = 0) to `end` (t = T).

    T is floored at 1 to avoid dividing by zero.
    """
    horizon = max(T, 1)
    weight = 1 + math.cos(math.pi * t / horizon)
    return end + (start - end) * 0.5 * weight

# ===== Main optimization with correlation priors =====
@torch.no_grad()
def to_pil(x):
    """Convert an image tensor to a PIL image: squeeze dim 0, clamp to [0, 1], move to CPU."""
    image = x.squeeze(0).clamp(0, 1)
    return TF.to_pil_image(image.cpu())

def synthesize_class(
    class_name="dumbbell",
    steps=1200, n_views=10, size=512, lr=0.08,
    decay_power=2.0, contrast=1.6,
    # Prior strengths (start -> end)
    corr_w_start=0.03, corr_w_end=0.006, corr_sigma_start=2.0, corr_sigma_end=0.6,
    lap_w_start=0.004, lap_w_end=0.0,
    tv_w=5e-4, l2_w=1e-6,
    blur_every=100,
    seed=None
):
    """Synthesize an image that maximises one class logit of `net`, with smoothness priors.

    The image is parameterised in Fourier space and scored on randomly augmented views.
    Correlation and Laplacian prior weights (and the correlation blur sigma) are
    cosine-annealed from their *_start to *_end values over the run.

    Args:
        class_name: label resolved with class_idx_from_name.
        steps, n_views, size, lr: number of Adam updates, views per step, square canvas
            side in pixels, and learning rate.
        decay_power, contrast: forwarded to spectrum_to_image.
        corr_w_*, corr_sigma_*: schedule for the blur-consistency prior (corr_loss).
        lap_w_*: schedule for the Laplacian prior.
        tv_w, l2_w: fixed weights of the TV prior and the distance-from-0.5 prior.
        blur_every: blur the rendered image on every `blur_every`-th step (falsy disables).
        seed: optional int; when given, seeds `random`, NumPy and torch so the run
            (initial spectrum and augmentations) is repeatable.

    Returns:
        A PIL image rendered from the final spectrum.
    """
    if seed is not None:
        # random_view draws from all three generators, so seed each of them.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    target = class_idx_from_name(class_name)
    spec, freqs = make_fft_params(size, size)
    # Clear the random DC bin up front. The frequency grid clamps DC to a tiny value, so
    # the 1/f^p scaling would blow it up far beyond float32 precision on the first
    # render (and on the only render when steps == 0). The loop keeps it at zero anyway.
    with torch.no_grad():
        spec[:, :, 0, 0, :] = 0
    opt = torch.optim.Adam([spec], lr=lr)

    for t in range(steps):
        img = spectrum_to_image(spec, freqs, decay_power=decay_power, contrast=contrast)
        # The blur only affects this step's forward pass (and therefore its gradient).
        if blur_every and t and t % blur_every == 0:
            img = gaussian_blur(img, sigma=0.8)

        # Build multi-view batch
        batch = torch.cat([random_view(img) for _ in range(n_views)], dim=0)
        logits = net(batch)
        cls = logits[:, target].mean()

        # Priors + schedules
        cw  = cosine_anneal(corr_w_start, corr_w_end, t, steps)
        lw  = cosine_anneal(lap_w_start,  lap_w_end,  t, steps)
        cs  = cosine_anneal(corr_sigma_start, corr_sigma_end, t, steps)

        loss_corr = corr_loss(img, sigma=cs)
        loss_lap  = laplacian_energy(img)
        loss_tv   = tv_l2(img)
        loss_l2   = ((img - 0.5)**2).mean()

        loss = cls - cw*loss_corr - lw*loss_lap - tv_w*loss_tv - l2_w*loss_l2

        opt.zero_grad(set_to_none=True)
        (-loss).backward()           # gradient ascent on `loss`
        opt.step()

        # Stabilizers in spectrum space
        with torch.no_grad():
            spec[:, :, 0, 0, :] = 0      # no DC drift
            spec.mul_(0.9995)            # very gentle amplitude damping

    final = spectrum_to_image(spec, freqs, decay_power=decay_power, contrast=contrast)
    return to_pil(final)

# Example:
# Runs the full optimisation when the cell executes, then writes the PIL image to disk.
out = synthesize_class("dumbbell", steps = 800, lr = 0.05, decay_power=.01, corr_w_start=0.08, corr_sigma_start = 2, blur_every=100, contrast=0.5)
# Same call with all defaults, kept for reference:
# out = synthesize_class("dumbbell")
out.save("dumbbell_last.png")


In [None]:
import math, random, numpy as np
import torch, torch.nn.functional as F
from torchvision import models
import torchvision.transforms.functional as TF

device = 'cuda' if torch.cuda.is_available() else 'cpu'  # prefer the GPU when one is visible

# ===== Model (GoogLeNet) =====
# NOTE(review): an earlier header here claimed this setup "avoids double-normalization".
# torchvision's `transform_input` appears to remap ImageNet-normalised input to [-1, 1]
# rather than normalise raw pixels — confirm against the torchvision GoogLeNet docs
# before relying on the skip in normalize_for_model.
weights = models.GoogLeNet_Weights.IMAGENET1K_V1
net = models.googlenet(weights=weights, aux_logits=True).to(device).eval()
# Only the image parameters are optimised, so gradients w.r.t. the weights are disabled.
for p in net.parameters(): p.requires_grad_(False)

# Per-channel statistics indexed to shape [3,1,1] so they broadcast over [N,3,H,W] batches.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=device)[:, None, None]
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225], device=device)[:, None, None]
CATEGORIES = weights.meta["categories"]  # label strings, indexed by class id

def class_idx_from_name(name, categories=None):
    """Resolve a human-readable label to its index in `categories`.

    Matching order: exact match of `name` as given, exact match of the lower-cased name,
    case-insensitive exact match, then the first category containing the lower-cased
    name as a substring.

    Args:
        name: label to look up, e.g. "dumbbell".
        categories: list of label strings ordered by class index; defaults to the
            module-level CATEGORIES, looked up at call time.

    Returns:
        The integer index of the matched category.

    Raises:
        ValueError: if `name` is empty or matches no category.
    """
    if categories is None:
        categories = CATEGORIES
    if name in categories:
        return categories.index(name)
    query = name.lower()
    if query in categories:
        return categories.index(query)
    lowered = [label.lower() for label in categories]
    if query in lowered:
        return lowered.index(query)
    # An empty query is a substring of every label; treat it as "no match" instead.
    hits = [i for i, label in enumerate(lowered) if query in label] if query else []
    if not hits:
        raise ValueError(f"'{query}' not in ImageNet categories.")
    return hits[0]

# ===== FFT parameterization with progressive low-pass =====
def radial_freq(h, w, device):
    """Radial frequency magnitude for every bin of an rFFT2 spectrum.

    Returns a tensor of shape [h, w//2 + 1] on `device`, floored at 1e-6 so the
    zero-frequency bin can be used as a divisor.
    """
    rows = torch.fft.fftfreq(h, d=1.0).to(device).unsqueeze(1)   # [h, 1]
    cols = torch.fft.rfftfreq(w, d=1.0).to(device).unsqueeze(0)  # [1, w//2 + 1]
    magnitude = (cols * cols + rows * rows).sqrt()
    return magnitude.clamp(min=1e-6)

def make_fft_params(h=512, w=512):
    """Create the trainable spectrum and its matching frequency grid.

    Returns:
        (spec, freqs): `spec` is a random-normal leaf tensor [1, 3, h, w//2+1, 2]
        (real/imaginary parts on the last axis) with requires_grad=True; `freqs` is
        radial_freq(h, w, device).
    """
    half_width = w // 2 + 1
    spectrum = torch.randn(1, 3, h, half_width, 2, device=device, requires_grad=True)
    return spectrum, radial_freq(h, w, device)

def radial_lowpass_mask(freqs, cutoff, rolloff=8.0):
    """Butterworth-style low-pass gain, 1 / (1 + (r / cutoff)^rolloff).

    `r` is `freqs` divided by its own maximum, so `cutoff` is a fraction of the largest
    radial frequency present in the grid; `rolloff` sets how steep the transition is.
    `cutoff` is floored at 1e-6.
    """
    safe_cutoff = max(cutoff, 1e-6)
    ratio = (freqs / freqs.max()) / safe_cutoff
    return 1.0 / (1.0 + ratio ** rolloff)

def spectrum_to_image(spec, freqs, decay_power=2.2, contrast=1.4,
                      lp_cutoff=0.35, lp_rolloff=8.0):
    """Render the spectrum as an image in [0, 1] through a radial low-pass filter.

    Applies a 1/f^decay_power falloff and the mask from radial_lowpass_mask, inverts
    the rFFT2, removes each channel's mean and squashes with sigmoid(contrast * x).

    Args:
        spec: real tensor [1, 3, H, W//2+1, 2] of real/imaginary parts.
        freqs: radial frequency grid [H, W//2+1].
        decay_power: exponent of the frequency falloff.
        contrast: gain applied before the sigmoid.
        lp_cutoff, lp_rolloff: forwarded to radial_lowpass_mask.
    """
    height = spec.shape[2]
    width = 2 * (spec.shape[3] - 1)
    lowpass = radial_lowpass_mask(freqs, cutoff=lp_cutoff, rolloff=lp_rolloff)  # [H,W/2+1]
    falloff = 1.0 / (freqs ** decay_power)
    shaped = torch.view_as_complex(spec) * falloff * lowpass                    # [1,3,H,W/2+1]
    pixels = torch.fft.irfft2(shaped, s=(height, width), norm='ortho')
    # Per-channel centring keeps the DC term from drifting the brightness.
    centred = pixels - pixels.mean(dim=(-2, -1), keepdim=True)
    return torch.sigmoid(centred * contrast)

# ===== Priors =====
def tv_l2(x):
    """Squared total-variation penalty: mean squared difference between horizontally
    adjacent pixels plus the same for vertically adjacent pixels (last two axes)."""
    horizontal = x[..., :, 1:] - x[..., :, :-1]
    vertical = x[..., 1:, :] - x[..., :-1, :]
    return horizontal.pow(2).mean() + vertical.pow(2).mean()

def gaussian_kernel1d(sigma):
    """Build a normalised 1-D Gaussian kernel on the module-level `device`.

    Returns:
        (kernel, radius): kernel has shape [1, 1, 2*radius + 1] and sums to 1;
        radius is round(3*sigma), but at least 1.
    """
    radius = max(1, int(round(3*sigma)))
    offsets = torch.arange(-radius, radius+1, device=device, dtype=torch.float32)
    taps = torch.exp(-0.5 * (offsets/sigma)**2)
    taps = taps / taps.sum()
    return taps.view(1, 1, -1), radius

def gaussian_blur(x, sigma):
    """Separable, depthwise Gaussian blur with reflect padding.

    Returns `x` unchanged when sigma <= 0; otherwise a tensor of the same shape.
    """
    if sigma <= 0:
        return x
    kernel, pad = gaussian_kernel1d(sigma)
    channels = x.shape[1]
    # One copy of the 1-D kernel per channel (groups=channels -> no channel mixing).
    row_kernel = kernel.view(1, 1, 1, -1).repeat(channels, 1, 1, 1)
    col_kernel = kernel.view(1, 1, -1, 1).repeat(channels, 1, 1, 1)
    # Horizontal pass, then vertical pass.
    blurred = F.conv2d(F.pad(x, (pad, pad, 0, 0), mode='reflect'), row_kernel, groups=channels)
    blurred = F.conv2d(F.pad(blurred, (0, 0, pad, pad), mode='reflect'), col_kernel, groups=channels)
    return blurred

def corr_loss(x, sigma=1.2):
    """Neighbour-correlation prior: mean squared difference between `x` and its
    Gaussian-blurred copy."""
    residual = x - gaussian_blur(x, sigma)
    return (residual ** 2).mean()

def laplacian_energy(x):
    """Mean squared response of the 4-neighbour Laplacian, applied per channel with
    reflect padding (a second-derivative smoothness penalty)."""
    stencil = torch.tensor(
        [[0, 1, 0],
         [1, -4, 1],
         [0, 1, 0]],
        dtype=torch.float32, device=device,
    )
    channels = x.shape[1]
    kernel = stencil.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
    response = F.conv2d(F.pad(x, (1, 1, 1, 1), mode='reflect'), kernel, groups=channels)
    return response.pow(2).mean()

# --- YUV utilities for chroma regularization ---
def rgb_to_yuv(x):
    """Split an [N, >=3, H, W] RGB tensor into luma and two colour-difference planes.

    Returns:
        (y, u, v), each [N, 1, H, W]: y = 0.299 R + 0.587 G + 0.114 B,
        u = B - y and v = R - y (unscaled differences).
    """
    red = x[:, 0:1]
    green = x[:, 1:2]
    blue = x[:, 2:3]
    luma = 0.299*red + 0.587*green + 0.114*blue
    return luma, blue - luma, red - luma

def chroma_saturation_loss(x, thresh=0.30):
    """Penalise only excess saturation: squared amount by which |u| and |v| (from
    rgb_to_yuv) exceed `thresh`, averaged per plane and summed."""
    _, u, v = rgb_to_yuv(x)
    u_excess = F.relu(u.abs() - thresh)
    v_excess = F.relu(v.abs() - thresh)
    return u_excess.pow(2).mean() + v_excess.pow(2).mean()

def chroma_tv_loss(x):
    """Squared total-variation penalty on the two colour-difference planes returned by
    rgb_to_yuv (luma is ignored)."""
    def plane_tv(plane):
        horizontal = plane[..., :, 1:] - plane[..., :, :-1]
        vertical = plane[..., 1:, :] - plane[..., :-1, :]
        return horizontal.pow(2).mean() + vertical.pow(2).mean()

    _, u, v = rgb_to_yuv(x)
    return plane_tv(u) + plane_tv(v)

# ===== Transform‑robust multi‑view (with mild defocus) =====
def normalize_for_model(v):
    """Apply ImageNet mean/std normalisation to a [0, 1] image batch.

    The previous version skipped this step whenever `net.transform_input` was True.
    In torchvision's GoogLeNet that flag does not replace external normalisation: it
    takes an ImageNet-normalised tensor and rescales it to the [-1, 1] convention of
    the original weights. Feeding raw [0, 1] pixels therefore gives the network
    inputs at the wrong scale, so normalisation is always applied here.

    Args:
        v: tensor [N, 3, H, W] with values in [0, 1].

    Returns:
        (v - IMAGENET_MEAN) / IMAGENET_STD, same shape as `v`.
    """
    return (v - IMAGENET_MEAN) / IMAGENET_STD

def random_view(x, out_hw=(224,224), jitter=16, rot=10, scale_range=(0.94,1.06),
                flip=True, noise_std=0.0, blur_sigma=(0.0,1.2), grayscale_p=0.15):
    """Produce one randomly augmented view of `x`, prepared with normalize_for_model.

    Args:
        x: image tensor [1, 3, H, W].
        out_hw: output (height, width) after bilinear resizing.
        jitter: maximum circular shift in pixels along each spatial axis.
        rot: maximum absolute rotation in degrees.
        scale_range: (low, high) range of the affine scale factor.
        flip: if True, mirror horizontally with probability 0.5.
        noise_std: standard deviation of additive Gaussian noise (0 disables it).
        blur_sigma: (low, high) range of a per-view Gaussian defocus; falsy or a
            non-positive upper bound disables it.
        grayscale_p: probability of replacing the view by its luma in all 3 channels.
    """
    _n, _c, height, width = x.shape

    # Draw the geometric parameters (NumPy RNG for the roll, `random` for the affine).
    shift_a, shift_b = np.random.randint(-jitter, jitter + 1, 2)
    angle = float(random.uniform(-rot, rot))
    zoom = float(random.uniform(*scale_range))
    offset_x = int(random.uniform(-0.04 * width, 0.04 * width))
    offset_y = int(random.uniform(-0.04 * height, 0.04 * height))

    view = torch.roll(x, shifts=(shift_a, shift_b), dims=(-1, -2))
    view = TF.affine(view, angle=angle, translate=(offset_x, offset_y), scale=zoom, shear=[0.0, 0.0])

    if flip and random.random() < 0.5:
        view = TF.hflip(view)

    # mild, random defocus per view
    if blur_sigma and blur_sigma[1] > 0:
        defocus = random.uniform(blur_sigma[0], blur_sigma[1])
        if defocus > 1e-3:
            view = gaussian_blur(view, defocus)

    view = F.interpolate(view, size=out_hw, mode='bilinear', align_corners=False)

    if grayscale_p > 0 and random.random() < grayscale_p:
        luma, _, _ = rgb_to_yuv(view)
        view = torch.cat([luma, luma, luma], dim=1)

    if noise_std > 0:
        noisy = view + noise_std * torch.randn_like(view)
        view = noisy.clamp(0, 1)

    return normalize_for_model(view)

def cosine_anneal(start, end, t, T):
    """Half-cosine interpolation from `start` (t = 0) to `end` (t = T).

    T is floored at 1 to avoid dividing by zero.
    """
    horizon = max(T, 1)
    weight = 1 + math.cos(math.pi * t / horizon)
    return end + (start - end) * 0.5 * weight

# ===== Main optimization =====
@torch.no_grad()
def to_pil(x):
    """Convert an image tensor to a PIL image: squeeze dim 0, clamp to [0, 1], move to CPU."""
    image = x.squeeze(0).clamp(0, 1)
    return TF.to_pil_image(image.cpu())

def synthesize_class_natural(
    class_name="dumbbell",
    steps=1400, n_views=12, size=512, lr=0.05,
    # Fourier & low-pass schedule
    decay_power=2.4, contrast_start=1.1, contrast_end=1.6,
    lp_cutoff_start=0.20, lp_cutoff_end=0.95, lp_rolloff=10.0,
    # Priors (start -> end)
    corr_w_start=0.035, corr_w_end=0.008, corr_sigma_start=2.5, corr_sigma_end=0.7,
    lap_w_start=0.005,  lap_w_end=0.0,
    tv_w=5e-4, l2_w=1e-6,
    chroma_sat_w_start=0.030, chroma_sat_w_end=0.006, chroma_tv_w=0.004,
    blur_every=120,
    seed=None
):
    """Synthesize a class image with a progressive low-pass schedule and chroma priors.

    The image is parameterised in Fourier space and scored by `net` on randomly
    augmented views. The low-pass cutoff, contrast and several prior weights are
    cosine-annealed from their *_start to *_end values over the run; the final image is
    rendered with the *_end cutoff and contrast.

    Args:
        class_name: label resolved with class_idx_from_name.
        steps, n_views, size, lr: number of Adam updates, views per step, square canvas
            side in pixels, and learning rate.
        decay_power, lp_rolloff: forwarded to spectrum_to_image.
        contrast_*, lp_cutoff_*: schedules for the sigmoid gain and low-pass cutoff.
        corr_w_*, corr_sigma_*: schedule for the blur-consistency prior (corr_loss).
        lap_w_*: schedule for the Laplacian prior.
        tv_w, l2_w: fixed weights of the TV prior and the distance-from-0.5 prior.
        chroma_sat_w_*: schedule for chroma_saturation_loss.
        chroma_tv_w: fixed weight of chroma_tv_loss.
        blur_every: blur the rendered image on every `blur_every`-th step (falsy disables).
        seed: optional int; when given, seeds `random`, NumPy and torch so the run
            (initial spectrum and augmentations) is repeatable.

    Returns:
        A PIL image rendered from the final spectrum.
    """
    if seed is not None:
        # random_view draws from all three generators, so seed each of them.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    target = class_idx_from_name(class_name)
    spec, freqs = make_fft_params(size, size)
    # Clear the random DC bin up front. The frequency grid clamps DC to a tiny value, so
    # the 1/f^p scaling would blow it up far beyond float32 precision on the first
    # render (and on the only render when steps == 0). The loop keeps it at zero anyway.
    with torch.no_grad():
        spec[:, :, 0, 0, :] = 0
    opt = torch.optim.Adam([spec], lr=lr)

    for t in range(steps):
        # progressive low-pass + contrast schedule
        cutoff_t   = cosine_anneal(lp_cutoff_start, lp_cutoff_end, t, steps)
        contrast_t = cosine_anneal(contrast_start, contrast_end, t, steps)
        img = spectrum_to_image(spec, freqs, decay_power=decay_power,
                                contrast=contrast_t, lp_cutoff=cutoff_t, lp_rolloff=lp_rolloff)

        # The blur only affects this step's forward pass (and therefore its gradient).
        if blur_every and t and t % blur_every == 0:
            img = gaussian_blur(img, sigma=0.8)

        # multi-view objective
        batch = torch.cat([random_view(img) for _ in range(n_views)], dim=0)
        logits = net(batch)
        cls = logits[:, target].mean()

        # schedules
        cw   = cosine_anneal(corr_w_start,      corr_w_end,      t, steps)
        lw   = cosine_anneal(lap_w_start,       lap_w_end,       t, steps)
        csig = cosine_anneal(corr_sigma_start,  corr_sigma_end,  t, steps)
        csw  = cosine_anneal(chroma_sat_w_start, chroma_sat_w_end, t, steps)

        # losses
        loss_corr  = corr_loss(img, sigma=csig)
        loss_lap   = laplacian_energy(img)
        loss_tv    = tv_l2(img)
        loss_l2    = ((img - 0.5)**2).mean()
        loss_csat  = chroma_saturation_loss(img)   # penalize neon colors
        loss_ctv   = chroma_tv_loss(img)           # keep chroma smooth

        loss = (cls
                - cw*loss_corr - lw*loss_lap - tv_w*loss_tv - l2_w*loss_l2
                - csw*loss_csat - chroma_tv_w*loss_ctv)

        opt.zero_grad(set_to_none=True)
        (-loss).backward()             # gradient ascent on `loss`
        opt.step()

        # stabilize spectrum
        with torch.no_grad():
            spec[:, :, 0, 0, :] = 0    # remove DC drift
            spec.mul_(0.9995)          # gentle damping

    final = spectrum_to_image(spec, freqs, decay_power=decay_power,
                              contrast=contrast_end, lp_cutoff=lp_cutoff_end, lp_rolloff=lp_rolloff)
    return to_pil(final)

# Example: run the pipeline defined in this cell.
# The previous call used synthesize_class(), which this cell does not define; it only
# worked when an earlier cell had been executed first, and it bypassed the low-pass and
# chroma schedules even though the result is saved as "dumbbell_natural.png".
# A band-pass variant would additionally need hp_cutoff_start / hp_cutoff_end parameters,
# which synthesize_class_natural does not accept.
out = synthesize_class_natural(
    class_name="dumbbell",
    steps=1300, n_views=12, size=512, lr=0.01,
    decay_power=0.25,
    contrast_start=1.2, contrast_end=1.6,
    lp_cutoff_start=0.35, lp_cutoff_end=0.95, lp_rolloff=10.0,
    corr_w_start=0.02, corr_w_end=0.006, corr_sigma_start=2.0, corr_sigma_end=0.6,
    lap_w_start=0.003, lap_w_end=0.0,
    tv_w=2e-4, l2_w=1e-6,
    chroma_sat_w_start=0.012, chroma_sat_w_end=0.004, chroma_tv_w=0.003,
    blur_every=140
)

out.save("dumbbell_natural.png")


In [None]:
# Named hyper-parameter presets. Each value is a dict of keyword arguments that the
# driver loop at the bottom of this cell unpacks into synthesize_class_natural().
cv_presets = {
  # 1) Balanced, natural-looking dumbbell (good baseline)
  "balanced_natural": dict(
    steps=1300, n_views=12, size=512, lr=0.05,
    decay_power=0.8, contrast_start=1.2, contrast_end=1.6,
    lp_cutoff_start=0.35, lp_cutoff_end=0.95, lp_rolloff=10.0,
    corr_w_start=0.020, corr_w_end=0.006, corr_sigma_start=2.0, corr_sigma_end=0.6,
    lap_w_start=0.003, lap_w_end=0.0,
    tv_w=2e-4, l2_w=1e-6,
    chroma_sat_w_start=0.012, chroma_sat_w_end=0.004, chroma_tv_w=0.003,
    blur_every=140
  ),

  # 2) “Arms amplifier” (pushes the dataset bias harder; more anatomy-like forms)
  "arms_amplifier": dict(
    steps=1600, n_views=14, size=512, lr=0.06,
    decay_power=0.6, contrast_start=1.1, contrast_end=1.5,
    lp_cutoff_start=0.30, lp_cutoff_end=0.95, lp_rolloff=8.0,
    corr_w_start=0.016, corr_w_end=0.004, corr_sigma_start=1.8, corr_sigma_end=0.6,
    lap_w_start=0.002, lap_w_end=0.0,
    tv_w=1.5e-4, l2_w=1e-6,
    chroma_sat_w_start=0.010, chroma_sat_w_end=0.003, chroma_tv_w=0.002,
    blur_every=160
  ),

  # 3) Soft pastel “dream” (hazy, gentle color; fewer neon artifacts)
  "pastel_soft_dream": dict(
    steps=1200, n_views=10, size=512, lr=0.045,
    decay_power=1.0, contrast_start=1.0, contrast_end=1.35,
    lp_cutoff_start=0.42, lp_cutoff_end=0.90, lp_rolloff=12.0,
    corr_w_start=0.040, corr_w_end=0.010, corr_sigma_start=2.8, corr_sigma_end=0.8,
    lap_w_start=0.004, lap_w_end=0.001,
    tv_w=3e-4, l2_w=1e-6,
    chroma_sat_w_start=0.040, chroma_sat_w_end=0.010, chroma_tv_w=0.006,
    blur_every=80
  ),

  # 4) Shape‑first, detail‑later (big coherent forms; add detail late)
  "shape_then_detail": dict(
    steps=1500, n_views=12, size=512, lr=0.05,
    decay_power=0.9, contrast_start=1.0, contrast_end=1.5,
    lp_cutoff_start=0.28, lp_cutoff_end=0.95, lp_rolloff=12.0,
    corr_w_start=0.030, corr_w_end=0.008, corr_sigma_start=2.4, corr_sigma_end=0.7,
    lap_w_start=0.004, lap_w_end=0.001,
    tv_w=2.5e-4, l2_w=1e-6,
    chroma_sat_w_start=0.020, chroma_sat_w_end=0.005, chroma_tv_w=0.004,
    blur_every=120
  ),

  # 5) Crisp(er) structure, less smoothing (expect sharper metal plates)
  "crisper_structure": dict(
    steps=1100, n_views=10, size=512, lr=0.055,
    decay_power=0.6, contrast_start=1.2, contrast_end=1.7,
    lp_cutoff_start=0.38, lp_cutoff_end=0.95, lp_rolloff=9.0,
    corr_w_start=0.012, corr_w_end=0.004, corr_sigma_start=1.6, corr_sigma_end=0.6,
    lap_w_start=0.001, lap_w_end=0.0,
    tv_w=1.2e-4, l2_w=1e-6,
    chroma_sat_w_start=0.010, chroma_sat_w_end=0.003, chroma_tv_w=0.002,
    blur_every=160
  ),

  # 6) Low‑color (monochrome‑ish metal look). Requires raising grayscale in your aug.
  # If your function doesn’t expose it, set grayscale_p=0.35 inside random_view.
  # NOTE(review): no grayscale setting is passed here, so as written this preset only
  # differs from the others through the keyword values below.
  "low_chroma_metal": dict(
    steps=1400, n_views=12, size=512, lr=0.05,
    decay_power=0.9, contrast_start=1.1, contrast_end=1.4,
    lp_cutoff_start=0.36, lp_cutoff_end=0.92, lp_rolloff=11.0,
    corr_w_start=0.030, corr_w_end=0.010, corr_sigma_start=2.2, corr_sigma_end=0.8,
    lap_w_start=0.003, lap_w_end=0.0,
    tv_w=2e-4, l2_w=1e-6,
    chroma_sat_w_start=0.060, chroma_sat_w_end=0.020, chroma_tv_w=0.008,
    blur_every=100
  ),

  # 7) High‑resolution poster (bigger canvas; smoother dynamics)
  "highres_poster_768": dict(
    steps=1800, n_views=10, size=768, lr=0.045,
    decay_power=0.8, contrast_start=1.2, contrast_end=1.6,
    lp_cutoff_start=0.30, lp_cutoff_end=0.95, lp_rolloff=10.0,
    corr_w_start=0.026, corr_w_end=0.008, corr_sigma_start=2.2, corr_sigma_end=0.7,
    lap_w_start=0.003, lap_w_end=0.0,
    tv_w=2.5e-4, l2_w=1e-6,
    chroma_sat_w_start=0.014, chroma_sat_w_end=0.005, chroma_tv_w=0.003,
    blur_every=180
  ),

  # 8) Fine detail early (let micro‑structure appear; more risk of noise)
  "detail_forward": dict(
    steps=1200, n_views=10, size=512, lr=0.055,
    decay_power=0.5, contrast_start=1.3, contrast_end=1.7,
    lp_cutoff_start=0.45, lp_cutoff_end=0.95, lp_rolloff=8.0,
    corr_w_start=0.010, corr_w_end=0.003, corr_sigma_start=1.4, corr_sigma_end=0.6,
    lap_w_start=0.000, lap_w_end=0.0,
    tv_w=1e-4, l2_w=1e-6,
    chroma_sat_w_start=0.008, chroma_sat_w_end=0.003, chroma_tv_w=0.002,
    blur_every=160
  ),

  # 9) Hard color clamp (for teaching how chroma priors tame neon)
  "color_clamp_strict": dict(
    steps=1300, n_views=12, size=512, lr=0.05,
    decay_power=0.9, contrast_start=1.1, contrast_end=1.5,
    lp_cutoff_start=0.34, lp_cutoff_end=0.92, lp_rolloff=12.0,
    corr_w_start=0.028, corr_w_end=0.010, corr_sigma_start=2.4, corr_sigma_end=0.7,
    lap_w_start=0.003, lap_w_end=0.0,
    tv_w=2.2e-4, l2_w=1e-6,
    chroma_sat_w_start=0.080, chroma_sat_w_end=0.025, chroma_tv_w=0.010,
    blur_every=120
  ),

  # 10) Super‑soft focus (max dreaminess; very gentle detail)
  "super_soft_focus": dict(
    steps=1400, n_views=12, size=512, lr=0.045,
    decay_power=1.2, contrast_start=1.0, contrast_end=1.3,
    lp_cutoff_start=0.26, lp_cutoff_end=0.90, lp_rolloff=14.0,
    corr_w_start=0.050, corr_w_end=0.015, corr_sigma_start=3.0, corr_sigma_end=1.0,
    lap_w_start=0.005, lap_w_end=0.001,
    tv_w=3e-4, l2_w=1e-6,
    chroma_sat_w_start=0.030, chroma_sat_w_end=0.010, chroma_tv_w=0.007,
    blur_every=80
  ),

  # 11) Low‑noise “steel” (subtle colors, mid‑high structure)
  "steel_low_noise": dict(
    steps=1200, n_views=12, size=512, lr=0.05,
    decay_power=0.7, contrast_start=1.15, contrast_end=1.55,
    lp_cutoff_start=0.38, lp_cutoff_end=0.95, lp_rolloff=10.0,
    corr_w_start=0.024, corr_w_end=0.007, corr_sigma_start=2.0, corr_sigma_end=0.7,
    lap_w_start=0.002, lap_w_end=0.0,
    tv_w=2e-4, l2_w=1e-6,
    chroma_sat_w_start=0.020, chroma_sat_w_end=0.006, chroma_tv_w=0.004,
    blur_every=120
  ),

  # 12) Light‑run (good for quick comparisons of settings)
  "light_run": dict(
    steps=700, n_views=8, size=448, lr=0.055,
    decay_power=0.8, contrast_start=1.2, contrast_end=1.6,
    lp_cutoff_start=0.40, lp_cutoff_end=0.95, lp_rolloff=10.0,
    corr_w_start=0.016, corr_w_end=0.006, corr_sigma_start=1.8, corr_sigma_end=0.7,
    lap_w_start=0.002, lap_w_end=0.0,
    tv_w=1.8e-4, l2_w=1e-6,
    chroma_sat_w_start=0.012, chroma_sat_w_end=0.004, chroma_tv_w=0.003,
    blur_every=120
  )
}

# Run every preset in turn and save one image per preset as "cv_<preset name>.png".
# Depends on synthesize_class_natural being defined by an earlier cell.
for name, p in cv_presets.items():
    out = synthesize_class_natural("dumbbell", **p)
    out.save(f"cv_{name}.png")


In [None]:
# --- Class-from-noise with BAND-PASS schedule (GoogLeNet + FFT) ---
# Requirements: torch, torchvision, PIL

import math, random, numpy as np
import torch, torch.nn.functional as F
from torchvision import models
import torchvision.transforms.functional as TF
from PIL import Image

device = 'cuda' if torch.cuda.is_available() else 'cpu'  # prefer the GPU when one is visible

# =========================
# 1) Model (GoogLeNet)
# =========================
# NOTE(review): an earlier comment here said the pretrained GoogLeNet normalises inside the
# model (transform_input=True), so inputs must not be pre-normalised. torchvision's
# `transform_input` appears to remap ImageNet-normalised input to [-1, 1] rather than
# normalise raw pixels — confirm against the torchvision GoogLeNet docs before relying on
# the skip in normalize_for_model.
weights = models.GoogLeNet_Weights.IMAGENET1K_V1
net = models.googlenet(weights=weights, aux_logits=True).to(device).eval()
# Only the image parameters are optimised, so gradients w.r.t. the weights are disabled.
for p in net.parameters(): p.requires_grad_(False)
CATEGORIES = weights.meta["categories"]  # label strings, indexed by class id

def class_idx_from_name(name, categories=None):
    """Resolve a human-readable label to its index in `categories`.

    Matching order: exact match of `name` as given, exact match of the lower-cased name,
    case-insensitive exact match, then the first category containing the lower-cased
    name as a substring.

    Args:
        name: label to look up, e.g. "dumbbell".
        categories: list of label strings ordered by class index; defaults to the
            module-level CATEGORIES, looked up at call time.

    Returns:
        The integer index of the matched category.

    Raises:
        ValueError: if `name` is empty or matches no category.
    """
    if categories is None:
        categories = CATEGORIES
    if name in categories:
        return categories.index(name)
    query = name.lower()
    if query in categories:
        return categories.index(query)
    lowered = [label.lower() for label in categories]
    if query in lowered:
        return lowered.index(query)
    # An empty query is a substring of every label; treat it as "no match" instead.
    hits = [i for i, label in enumerate(lowered) if query in label] if query else []
    if not hits:
        raise ValueError(f"'{query}' not in ImageNet categories")
    return hits[0]

# =========================
# 2) Fourier parameterization + band-pass masks
# =========================
def radial_freq(h, w, device):
    """Radial frequency magnitude for every bin of an rFFT2 spectrum.

    Returns a tensor of shape [h, w//2 + 1] on `device`, floored at 1e-6 so the
    zero-frequency bin can be used as a divisor.
    """
    rows = torch.fft.fftfreq(h, d=1.0).to(device).unsqueeze(1)   # [h, 1]
    cols = torch.fft.rfftfreq(w, d=1.0).to(device).unsqueeze(0)  # [1, w//2 + 1]
    magnitude = (cols * cols + rows * rows).sqrt()
    return magnitude.clamp(min=1e-6)  # avoid divide-by-zero

def make_fft_params(h=512, w=512):
    """Create the trainable spectrum and its matching frequency grid.

    Returns:
        (spec, freqs): `spec` is a random-normal leaf tensor [1, 3, h, w//2+1, 2]
        (real/imaginary parts on the last axis) with requires_grad=True; `freqs` is
        radial_freq(h, w, device).
    """
    half_width = w // 2 + 1
    spectrum = torch.randn(1, 3, h, half_width, 2, device=device, requires_grad=True)
    return spectrum, radial_freq(h, w, device)

def radial_lowpass_mask(freqs, cutoff, rolloff=8.0):
    """Butterworth-style low-pass gain, 1 / (1 + (r / cutoff)^rolloff).

    `r` is `freqs` divided by its own maximum (floored at 1e-6), so `cutoff` is a
    fraction of the largest radial frequency present in the grid. `cutoff` is floored
    at 1e-6.
    """
    relative = (freqs / freqs.max()).clamp(min=1e-6)
    ratio = relative / max(cutoff, 1e-6)
    return 1.0 / (1.0 + ratio ** rolloff)

def radial_highpass_mask(freqs, cutoff, rolloff=8.0):
    """Butterworth-style high-pass gain, 1 / (1 + (cutoff / r)^rolloff).

    `r` is `freqs` divided by its own maximum (floored at 1e-6), so `cutoff` is a
    fraction of the largest radial frequency present in the grid.
    """
    relative = (freqs / freqs.max()).clamp(min=1e-6)
    ratio = cutoff / relative
    return 1.0 / (1.0 + ratio ** rolloff)

def spectrum_to_image_bp(
    spec, freqs,
    decay_power=0.8,         # 0.5–1.0 works well with band-pass
    contrast=1.4,
    lp_cutoff=0.45,          # low-pass cutoff (open highs gradually)
    hp_cutoff=0.10,          # high-pass cutoff (let lows in later)
    rolloff=10.0
):
    """Render the spectrum as an image in [0, 1] through a radial band-pass filter.

    The band-pass gain is the product of radial_lowpass_mask and radial_highpass_mask
    (same `rolloff` for both). After a 1/f^decay_power falloff and the inverse rFFT2,
    each channel is mean-centred and squashed with sigmoid(contrast * x).

    Args:
        spec: real tensor [1, 3, H, W//2+1, 2] of real/imaginary parts.
        freqs: radial frequency grid [H, W//2+1].
    """
    height = spec.shape[2]
    width = 2 * (spec.shape[3] - 1)
    lowpass = radial_lowpass_mask(freqs, cutoff=lp_cutoff, rolloff=rolloff)      # [H,W/2+1]
    highpass = radial_highpass_mask(freqs, cutoff=hp_cutoff, rolloff=rolloff)
    band = lowpass * highpass
    falloff = 1.0 / (freqs ** decay_power)
    shaped = torch.view_as_complex(spec) * falloff * band                        # [1,3,H,W/2+1]
    pixels = torch.fft.irfft2(shaped, s=(height, width), norm='ortho')
    # Remove the per-channel mean, then map to [0,1] with a sigmoid.
    centred = pixels - pixels.mean(dim=(-2, -1), keepdim=True)
    return torch.sigmoid(centred * contrast)  # [1,3,H,W] in [0,1]

# =========================
# 3) Priors (correlation, Laplacian, chroma)
# =========================
def gaussian_kernel1d(sigma, device):
    """Build a normalised 1-D Gaussian kernel on `device`.

    Returns:
        (kernel, radius): kernel has shape [1, 1, 2*radius + 1] and sums to 1;
        radius is round(3*sigma), but at least 1.
    """
    radius = max(1, int(round(3*sigma)))
    offsets = torch.arange(-radius, radius+1, device=device, dtype=torch.float32)
    taps = torch.exp(-0.5*(offsets/sigma)**2)
    taps = taps / taps.sum()
    return taps.view(1, 1, -1), radius

def gaussian_blur(x, sigma):
    """Separable, depthwise Gaussian blur with reflect padding, built on x's device.

    Returns `x` unchanged when sigma <= 0; otherwise a tensor of the same shape.
    """
    if sigma <= 0:
        return x
    kernel, pad = gaussian_kernel1d(sigma, x.device)
    channels = x.shape[1]
    # One copy of the 1-D kernel per channel (groups=channels -> no channel mixing).
    row_kernel = kernel.view(1, 1, 1, -1).repeat(channels, 1, 1, 1)
    col_kernel = kernel.view(1, 1, -1, 1).repeat(channels, 1, 1, 1)
    # Horizontal pass, then vertical pass.
    blurred = F.conv2d(F.pad(x, (pad, pad, 0, 0), mode='reflect'), row_kernel, groups=channels)
    blurred = F.conv2d(F.pad(blurred, (0, 0, pad, pad), mode='reflect'), col_kernel, groups=channels)
    return blurred

def corr_loss(x, sigma=1.2):
    """Neighbour-correlation prior: mean squared difference between `x` and its
    Gaussian-blurred copy."""
    residual = x - gaussian_blur(x, sigma)
    return (residual ** 2).mean()

def laplacian_energy(x):
    """Mean squared response of the 4-neighbour Laplacian, applied per channel with
    reflect padding; the stencil is created on x's device."""
    stencil = torch.tensor(
        [[0, 1, 0],
         [1, -4, 1],
         [0, 1, 0]],
        dtype=torch.float32, device=x.device,
    )
    channels = x.shape[1]
    kernel = stencil.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
    response = F.conv2d(F.pad(x, (1, 1, 1, 1), mode='reflect'), kernel, groups=channels)
    return response.pow(2).mean()

def tv_l2(x):
    """Squared total-variation penalty: mean squared difference between horizontally
    adjacent pixels plus the same for vertically adjacent pixels (last two axes)."""
    horizontal = x[..., :, 1:] - x[..., :, :-1]
    vertical = x[..., 1:, :] - x[..., :-1, :]
    return horizontal.pow(2).mean() + vertical.pow(2).mean()

def rgb_to_yuv(x):
    """Split an [N, >=3, H, W] RGB tensor into luma and two colour-difference planes.

    Returns:
        (y, u, v), each [N, 1, H, W]: y = 0.299 R + 0.587 G + 0.114 B,
        u = B - y and v = R - y (unscaled differences).
    """
    red = x[:, 0:1]
    green = x[:, 1:2]
    blue = x[:, 2:3]
    luma = 0.299*red + 0.587*green + 0.114*blue
    return luma, blue - luma, red - luma

def chroma_saturation_loss(x, thresh=0.30):
    """Penalise only excess saturation: squared amount by which |u| and |v| (from
    rgb_to_yuv) exceed `thresh`, averaged per plane and summed."""
    _, u, v = rgb_to_yuv(x)
    u_excess = F.relu(u.abs() - thresh)
    v_excess = F.relu(v.abs() - thresh)
    return u_excess.pow(2).mean() + v_excess.pow(2).mean()

def chroma_tv_loss(x):
    """Squared total variation of the two chroma planes of ``x``.

    Args:
        x: RGB batch accepted by ``rgb_to_yuv``.

    Returns:
        Scalar tensor: TV energy of the first chroma plane plus that of the second.
    """
    def _squared_diff_energy(plane):
        # Mean squared neighbour difference along the last two axes.
        along_last = plane[..., 1:] - plane[..., :-1]
        along_prev = plane[..., :, 1:, :] - plane[..., :, :-1, :]
        return along_last.pow(2).mean() + along_prev.pow(2).mean()

    _, u, v = rgb_to_yuv(x)
    return _squared_diff_energy(u) + _squared_diff_energy(v)

# =========================
# 4) Multi-view augmentation
# =========================
# Per-channel ImageNet statistics, shaped [3, 1, 1] so they broadcast over
# [N, 3, H, W] batches.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1)
IMAGENET_STD  = torch.tensor([0.229, 0.224, 0.225], device=device).view(3, 1, 1)

def normalize_for_model(v):
    """Standardize a [0, 1] RGB batch with the ImageNet per-channel mean and std.

    torchvision builds GoogLeNet with ``transform_input=True`` when pretrained
    weights are requested. That flag does NOT replace this step: the model's
    internal transform assumes its input is already ImageNet-normalized and
    only re-maps it to the (x - 0.5) / 0.5 scale the original weights were
    trained with. Skipping the normalization here would feed the network
    values in the wrong range, so it is applied unconditionally.

    Args:
        v: Image batch in [0, 1] whose channel axis broadcasts against the
           [3, 1, 1] statistics (e.g. [N, 3, H, W]).

    Returns:
        Tensor with the same shape as ``v``.
    """
    return (v - IMAGENET_MEAN) / IMAGENET_STD

def random_view(
    x, out_hw=(224,224),
    jitter=16, rot=10, scale_range=(0.94,1.06),
    flip=True, noise_std=0.0, blur_sigma=(0.0,1.2), grayscale_p=0.15
):
    """Produce one randomly augmented, model-ready view of ``x``.

    Steps, in order: circular jitter, affine warp (rotation, scale, small
    translation), optional horizontal flip, optional Gaussian defocus,
    bilinear resize to ``out_hw``, optional grayscale conversion, optional
    additive noise, then ``normalize_for_model``. The order of the random
    draws is significant for seeded runs.

    Args:
        x: Image batch of shape [N, C, H, W].
        out_hw: Output (height, width).
        jitter: Maximum circular shift in pixels, drawn per axis from
                [-jitter, jitter].
        rot: Maximum rotation in degrees (uniform in [-rot, rot]).
        scale_range: (min, max) of the uniform zoom factor.
        flip: If true, flip horizontally with probability 0.5.
        noise_std: Std of additive Gaussian noise; when > 0 the result is
                   clamped to [0, 1].
        blur_sigma: (min, max) range for the defocus sigma; falsy or a
                    non-positive max disables the blur.
        grayscale_p: Probability of replacing the view by its luma,
                     repeated over three channels.

    Returns:
        Output of ``normalize_for_model`` for the augmented view.
    """
    _, _, height, width = x.shape

    # Circular shift: the first offset rolls the width axis, the second the height axis.
    shift_x, shift_y = np.random.randint(-jitter, jitter + 1, 2)
    view = torch.roll(x, shifts=(shift_x, shift_y), dims=(-1, -2))

    # Affine parameters are drawn in this order: angle, zoom, x-shift, y-shift.
    angle = float(random.uniform(-rot, rot))
    zoom = float(random.uniform(*scale_range))
    offset_x = int(random.uniform(-0.04 * width, 0.04 * width))
    offset_y = int(random.uniform(-0.04 * height, 0.04 * height))
    # NOTE(review): no interpolation argument is passed, so the library default
    # applies - confirm it is the resampling wanted for a differentiable view.
    view = TF.affine(view, angle=angle, translate=(offset_x, offset_y), scale=zoom, shear=[0.0, 0.0])

    if flip and random.random() < 0.5:
        view = TF.hflip(view)

    # Mild, random defocus per view.
    if blur_sigma and blur_sigma[1] > 0:
        sigma = random.uniform(blur_sigma[0], blur_sigma[1])
        if sigma > 1e-3:
            view = gaussian_blur(view, sigma)

    view = F.interpolate(view, size=out_hw, mode='bilinear', align_corners=False)

    if grayscale_p > 0 and random.random() < grayscale_p:
        luma, _, _ = rgb_to_yuv(view)
        view = torch.cat([luma, luma, luma], dim=1)

    if noise_std > 0:
        view = (view + noise_std * torch.randn_like(view)).clamp(0, 1)

    return normalize_for_model(view)

def cosine_anneal(start, end, t, T):
    """Half-cosine interpolation from ``start`` (t = 0) to ``end`` (t = T).

    Args:
        start: Value returned at ``t == 0``.
        end: Value reached at ``t == T``.
        t: Current step.
        T: Total number of steps; values below 1 are treated as 1.

    Returns:
        The interpolated float.
    """
    horizon = max(T, 1)
    phase = math.pi * t / horizon
    blend = 0.5 * (1 + math.cos(phase))  # 1 at t = 0, 0 at t = T
    return end + (start - end) * blend

@torch.no_grad()
def to_pil(x):
    """Convert an image tensor to a PIL image.

    A leading dimension of size 1 is squeezed away, values are clamped to
    [0, 1], and the tensor is moved to the CPU before conversion.
    """
    image = x.squeeze(0).clamp(0, 1).cpu()
    return TF.to_pil_image(image)

# =========================
# 5) Main function (band-pass schedules)
# =========================
def synthesize_class_natural_bp(
    class_name="dumbbell",
    *,
    steps=1300, n_views=12, size=512, lr=0.05, seed=None,
    # FFT & band-pass schedule
    decay_power=0.8,                 # <= key; 0.5–1.0 recommended with band-pass
    contrast_start=1.2, contrast_end=1.6,
    lp_cutoff_start=0.35, lp_cutoff_end=0.95,
    hp_cutoff_start=0.12, hp_cutoff_end=0.00,
    rolloff=10.0,
    # Priors (start -> end)
    corr_w_start=0.020, corr_w_end=0.006, corr_sigma_start=2.0, corr_sigma_end=0.6,
    lap_w_start=0.003,  lap_w_end=0.0,
    tv_w=2e-4, l2_w=1e-6,
    chroma_sat_w_start=0.012, chroma_sat_w_end=0.004, chroma_tv_w=0.003,
    blur_every=140,
    # Aug params exposed for convenience
    view_jitter=16, view_rot=10, view_scale=(0.94,1.06), grayscale_p=0.15
):
    """Synthesize an image that maximizes one class logit of ``net``.

    A Fourier-domain parameter tensor (from ``make_fft_params``) is optimized
    with Adam. Each step renders it through ``spectrum_to_image_bp`` with the
    current band-pass/contrast settings, scores the mean target logit over
    ``n_views`` random views, subtracts the weighted image priors, and takes
    one optimizer step on the negated objective.

    All ``*_start`` / ``*_end`` pairs follow ``cosine_anneal`` over ``steps``.
    Because the loop index stops at ``steps - 1``, the scheduled values
    approach but do not exactly hit the ``*_end`` values inside the loop; the
    final render uses the ``*_end`` values directly.

    Args:
        class_name: Label resolved by ``class_idx_from_name``.
        steps: Number of optimization steps.
        n_views: Random views scored per step.
        size: Side length passed to ``make_fft_params`` (height and width).
        lr: Adam learning rate.
        seed: If not None, seeds ``torch``, ``numpy`` and ``random``.
        decay_power: Exponent of the 1/f spectral scaling in the renderer.
        contrast_start, contrast_end: Schedule for the renderer's contrast.
        lp_cutoff_start, lp_cutoff_end: Schedule for the low-pass cutoff.
        hp_cutoff_start, hp_cutoff_end: Schedule for the high-pass cutoff.
        rolloff: Forwarded unchanged to ``spectrum_to_image_bp``
            (presumably the mask steepness; confirm against that function).
        corr_w_start, corr_w_end: Weight schedule for ``corr_loss``.
        corr_sigma_start, corr_sigma_end: Blur-sigma schedule for ``corr_loss``.
        lap_w_start, lap_w_end: Weight schedule for ``laplacian_energy``.
        tv_w: Constant weight for ``tv_l2``.
        l2_w: Constant weight for the mean squared deviation from 0.5.
        chroma_sat_w_start, chroma_sat_w_end: Weight schedule for
            ``chroma_saturation_loss``.
        chroma_tv_w: Constant weight for ``chroma_tv_loss``.
        blur_every: Every this many steps (excluding step 0) the rendered
            image of that step is blurred with sigma 0.8; falsy disables it.
        view_jitter, view_rot, view_scale, grayscale_p: Forwarded to
            ``random_view``.

    Returns:
        The final render converted by ``to_pil``.

    Raises:
        ValueError: From ``class_idx_from_name`` when no category matches.
    """
    if seed is not None:
        torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)

    target = class_idx_from_name(class_name)
    spec, freqs = make_fft_params(size, size)
    opt = torch.optim.Adam([spec], lr=lr)

    for t in range(steps):
        # Band-pass + contrast schedules
        lp_t = cosine_anneal(lp_cutoff_start, lp_cutoff_end, t, steps)
        hp_t = cosine_anneal(hp_cutoff_start, hp_cutoff_end, t, steps)
        contrast_t = cosine_anneal(contrast_start, contrast_end, t, steps)

        img = spectrum_to_image_bp(
            spec, freqs,
            decay_power=decay_power, contrast=contrast_t,
            lp_cutoff=lp_t, hp_cutoff=hp_t, rolloff=rolloff
        )

        # Periodic blur of the rendered image. It only affects this step's
        # forward pass (views and priors below); `spec` itself is not blurred.
        if blur_every and t and t % blur_every == 0:
            img = gaussian_blur(img, sigma=0.8)

        # Transform-robust objective: average class logit over random views
        batch = torch.cat([
            random_view(
                img, out_hw=(224,224),
                jitter=view_jitter, rot=view_rot, scale_range=view_scale,
                blur_sigma=(0.0,1.2), grayscale_p=grayscale_p
            )
            for _ in range(n_views)
        ], dim=0)
        logits = net(batch)
        cls = logits[:, target].mean()

        # Priors + schedules
        cw   = cosine_anneal(corr_w_start, corr_w_end, t, steps)
        lw   = cosine_anneal(lap_w_start,  lap_w_end,  t, steps)
        csig = cosine_anneal(corr_sigma_start, corr_sigma_end, t, steps)
        csw  = cosine_anneal(chroma_sat_w_start, chroma_sat_w_end, t, steps)

        loss_corr = corr_loss(img, sigma=csig)
        loss_lap  = laplacian_energy(img)
        loss_tv   = tv_l2(img)
        loss_l2   = ((img - 0.5)**2).mean()
        loss_csat = chroma_saturation_loss(img)  # clamp neon
        loss_ctv  = chroma_tv_loss(img)          # smooth chroma

        # `loss` is the quantity to MAXIMIZE: class score minus weighted priors.
        loss = (cls
                - cw*loss_corr - lw*loss_lap - tv_w*loss_tv - l2_w*loss_l2
                - csw*loss_csat - chroma_tv_w*loss_ctv)

        opt.zero_grad(set_to_none=True)
        # Adam minimizes, so back-propagate the negated objective.
        (-loss).backward()
        opt.step()

        # Stabilize spectrum
        with torch.no_grad():
            spec[:, :, 0, 0, :] = 0    # no DC drift
            spec.mul_(0.9995)          # gentle amplitude damping

    # Final image at relaxed band-pass
    final = spectrum_to_image_bp(
        spec, freqs,
        decay_power=decay_power, contrast=contrast_end,
        lp_cutoff=lp_cutoff_end, hp_cutoff=hp_cutoff_end, rolloff=rolloff
    )
    return to_pil(final)

# =========================
# 6) Example usage
# =========================
# Render a single "dumbbell" visualization with a fixed seed and save it as a PNG.
# NOTE(review): decay_power=0.1 lies below the 0.5–1.0 range recommended in the
# signature comment of synthesize_class_natural_bp, and steps / lr / blur_every
# differ from that function's defaults (1300 / 0.05 / 140) — confirm these
# overrides are intentional.
out = synthesize_class_natural_bp(
    "dumbbell",
    steps=1200, n_views=12, size=512, lr=0.01, seed=0,
    decay_power=0.1,
    contrast_start=1.2, contrast_end=1.6,
    lp_cutoff_start=0.35, lp_cutoff_end=0.95,
    hp_cutoff_start=0.12, hp_cutoff_end=0.00,
    rolloff=10.0,
    corr_w_start=0.020, corr_w_end=0.006, corr_sigma_start=2.0, corr_sigma_end=0.6,
    lap_w_start=0.003,  lap_w_end=0.0,
    tv_w=2e-4, l2_w=1e-6,
    chroma_sat_w_start=0.012, chroma_sat_w_end=0.004, chroma_tv_w=0.003,
    blur_every=100,
)
# Written relative to the notebook's working directory.
out.save("dumbbell_bp.png")


In [None]:
# Named keyword-argument presets for synthesize_class_natural_bp.
# Every key must match one of that function's keyword parameters, because each
# dict is expanded with ** at the call site. The mask steepness parameter is
# called `rolloff` (not `lp_rolloff`).
cv_bp_presets = {
  # Keeps mid-frequencies early, opens both ends gradually (great for structure)
  "bp_midband_balanced": dict(
    steps=1300, n_views=12, size=512, lr=0.05,
    decay_power=0.8, contrast_start=1.2, contrast_end=1.6,
    lp_cutoff_start=0.35, lp_cutoff_end=0.95, hp_cutoff_start=0.12, hp_cutoff_end=0.00, rolloff=10.0,
    corr_w_start=0.020, corr_w_end=0.006, corr_sigma_start=2.0, corr_sigma_end=0.6,
    lap_w_start=0.003, lap_w_end=0.0,
    tv_w=2e-4, l2_w=1e-6,
    chroma_sat_w_start=0.012, chroma_sat_w_end=0.004, chroma_tv_w=0.003,
    blur_every=140
  ),

  # Strong shape bias (blocks very low & very high freqs initially)
  "bp_shape_first": dict(
    steps=1500, n_views=12, size=512, lr=0.05,
    decay_power=0.9, contrast_start=1.1, contrast_end=1.5,
    lp_cutoff_start=0.32, lp_cutoff_end=0.95, hp_cutoff_start=0.18, hp_cutoff_end=0.02, rolloff=12.0,
    corr_w_start=0.030, corr_w_end=0.008, corr_sigma_start=2.4, corr_sigma_end=0.7,
    lap_w_start=0.004, lap_w_end=0.001,
    tv_w=2.2e-4, l2_w=1e-6,
    chroma_sat_w_start=0.016, chroma_sat_w_end=0.005, chroma_tv_w=0.004,
    blur_every=120
  ),

  # Detail early, but still controlled (less neon than pure low-pass)
  "bp_detail_forward": dict(
    steps=1200, n_views=10, size=512, lr=0.055,
    decay_power=0.6, contrast_start=1.3, contrast_end=1.7,
    lp_cutoff_start=0.45, lp_cutoff_end=0.95, hp_cutoff_start=0.08, hp_cutoff_end=0.00, rolloff=8.0,
    corr_w_start=0.012, corr_w_end=0.004, corr_sigma_start=1.6, corr_sigma_end=0.6,
    lap_w_start=0.001, lap_w_end=0.0,
    tv_w=1.2e-4, l2_w=1e-6,
    chroma_sat_w_start=0.010, chroma_sat_w_end=0.003, chroma_tv_w=0.002,
    blur_every=160
  ),

  # Pastel dream with band-pass
  "bp_pastel": dict(
    steps=1200, n_views=10, size=512, lr=0.045,
    decay_power=1.0, contrast_start=1.0, contrast_end=1.35,
    lp_cutoff_start=0.40, lp_cutoff_end=0.90, hp_cutoff_start=0.15, hp_cutoff_end=0.02, rolloff=12.0,
    corr_w_start=0.040, corr_w_end=0.010, corr_sigma_start=2.8, corr_sigma_end=0.8,
    lap_w_start=0.004, lap_w_end=0.001,
    tv_w=3e-4, l2_w=1e-6,
    chroma_sat_w_start=0.040, chroma_sat_w_end=0.010, chroma_tv_w=0.006,
    blur_every=100
  ),

  # Punchy color but still controlled (looser chroma clamp)
  "bp_vivid": dict(
    steps=1100, n_views=10, size=512, lr=0.055,
    decay_power=0.7, contrast_start=1.25, contrast_end=1.75,
    lp_cutoff_start=0.42, lp_cutoff_end=0.95, hp_cutoff_start=0.10, hp_cutoff_end=0.00, rolloff=9.0,
    corr_w_start=0.010, corr_w_end=0.003, corr_sigma_start=1.6, corr_sigma_end=0.6,
    lap_w_start=0.001, lap_w_end=0.0,
    tv_w=1.2e-4, l2_w=1e-6,
    chroma_sat_w_start=0.008, chroma_sat_w_end=0.003, chroma_tv_w=0.002,
    blur_every=160
  ),

  # Big poster canvas with band-pass
  "bp_highres_768": dict(
    steps=1800, n_views=10, size=768, lr=0.045,
    decay_power=0.8, contrast_start=1.2, contrast_end=1.6,
    lp_cutoff_start=0.32, lp_cutoff_end=0.95, hp_cutoff_start=0.14, hp_cutoff_end=0.00, rolloff=10.0,
    corr_w_start=0.026, corr_w_end=0.008, corr_sigma_start=2.2, corr_sigma_end=0.7,
    lap_w_start=0.003, lap_w_end=0.0,
    tv_w=2.5e-4, l2_w=1e-6,
    chroma_sat_w_start=0.014, chroma_sat_w_end=0.005, chroma_tv_w=0.003,
    blur_every=180
  )
}

# Render and save one "dumbbell" image per preset. Each preset dict is expanded
# into keyword arguments, so its keys must match the parameter names of
# synthesize_class_natural_bp exactly. No seed is passed, so runs are not
# reproducible unless a preset supplies one.
for name, p in cv_bp_presets.items():
    out = synthesize_class_natural_bp("dumbbell", **p)
    out.save(f"cv_{name}.png")

TypeError: synthesize_class_natural_bp() got an unexpected keyword argument 'lp_rolloff'

In [None]:
# deepdream_dumbbells.py
import math, torch, torch.nn.functional as F
from torchvision import models, transforms
from torchvision.transforms.functional import resize
from PIL import Image
import numpy as np
# Make sure autograd is on for the optimization below.
torch.set_grad_enabled(True)

# ---- Config ----
DEVICE       = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): MODEL's parameters are not frozen in this cell, so every
# backward() presumably also accumulates gradients into them — confirm whether
# requires_grad_(False) was intended.
MODEL        = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1).to(DEVICE).eval()
TARGET_CLASS = 543  # "dumbbell" in ImageNet-1k mapping
STEPS        = 600   # per octave
STEP_SIZE    = 0.01  # update magnitude per step (ascend() scales the gradient to unit norm)
TV_WEIGHT    = 1e-4  # weight of the total-variation penalty in ascend()
NUM_OCTAVES  = 5
OCTAVE_SCALE = 1.35  # side-length growth factor between octaves
JITTER       = 16    # max circular shift (pixels) applied before each step
SEED         = 3

# InceptionV3 expects 299x299 and specific normalization
# NOTE(review): `preproc` is not referenced in the visible part of this cell;
# normalization is done manually with `mean` / `std` below.
preproc = models.Inception_V3_Weights.IMAGENET1K_V1.transforms()
# ImageNet channel statistics shaped [1, 3, 1, 1] for broadcasting over a batch.
mean = torch.tensor([0.485, 0.456, 0.406], device=DEVICE).view(1,3,1,1)
std  = torch.tensor([0.229, 0.224, 0.225], device=DEVICE).view(1,3,1,1)

# Dedicated generator so the noise init, jitter and crop offsets are seeded.
g = torch.Generator(device=DEVICE).manual_seed(SEED)
# start from noise near "natural" stats: N(0.5, 0.2^2) clamped to [0, 1]
h, w = 299, 299
img = torch.clamp(torch.randn(1,3,h,w, device=DEVICE, generator=g)*0.2 + 0.5, 0, 1).requires_grad_(True)

def tv_loss(x):
    """Squared total variation of a 4-D image batch.

    Args:
        x: Tensor indexed as [N, C, H, W].

    Returns:
        Scalar tensor: mean squared difference between vertically adjacent
        pixels plus the same for horizontally adjacent pixels.
    """
    row_delta = x[:, :, 1:, :] - x[:, :, :-1, :]
    col_delta = x[:, :, :, 1:] - x[:, :, :, :-1]
    row_energy = row_delta.pow(2).mean()
    col_energy = col_delta.pow(2).mean()
    return row_energy + col_energy

@torch.no_grad()
def to_pil(x):
    """Convert a single-image tensor to a PIL image.

    Values are clamped to [0, 1], scaled by 255 and cast to uint8 (the cast
    truncates rather than rounds). A leading dimension of size 1 is dropped
    and the remaining three axes are reordered from channel-first to
    channel-last before handing the array to PIL.
    """
    scaled = torch.clamp(x, 0, 1).mul(255)
    as_bytes = scaled.byte()
    channel_last = as_bytes.squeeze(0).permute(1, 2, 0)
    return Image.fromarray(channel_last.cpu().numpy())

def ascend(img, steps):
    """Run ``steps`` iterations of normalized gradient ascent on the target class.

    Each iteration circularly jitters the image, maximizes
    ``score - TV_WEIGHT * tv_loss(img)`` (``score`` being the mean
    ``TARGET_CLASS`` logit of ``MODEL``) with a unit-norm gradient step of
    size ``STEP_SIZE``, undoes the jitter and clamps the pixels to [0, 1].

    The previous version back-propagated ``-score + TV`` and then ADDED that
    gradient to the image, which lowered the class score instead of raising
    it. The objective is now written in ascent form so the ``+=`` update
    moves in the intended direction.

    Args:
        img: Leaf tensor [1, 3, H, W] with ``requires_grad=True``; modified
             in place.
        steps: Number of iterations.

    Returns:
        The same ``img`` tensor.
    """
    for _ in range(steps):
        # jitter (random circular shift along height and width)
        ox = int(torch.randint(-JITTER, JITTER+1, (1,), generator=g, device=DEVICE))
        oy = int(torch.randint(-JITTER, JITTER+1, (1,), generator=g, device=DEVICE))
        img.data = torch.roll(img.data, shifts=(ox, oy), dims=(2, 3))

        # forward
        x = (img - mean)/std
        logits = MODEL(x)
        if isinstance(logits, tuple):  # inception_v3 returns (logits, aux) sometimes
            logits = logits[0]
        score = logits[:, TARGET_CLASS].mean()
        # Quantity to maximize: class score minus the smoothness penalty.
        objective = score - TV_WEIGHT*tv_loss(img)

        # backward
        objective.backward()
        with torch.no_grad():
            gnorm = img.grad.norm().clamp(min=1e-8)
            # Step along the (unit-norm) gradient of the objective -> ascent.
            img += STEP_SIZE * img.grad / gnorm
            img.grad.zero_()

        # undo jitter
        img.data = torch.roll(img.data, shifts=(-ox, -oy), dims=(2, 3))
        img.data.clamp_(0,1)
    return img

# Multi-octave (coarse→fine): optimize at the current resolution, then upscale.
# Sizes are always derived from the base h, w, so the image entering octave
# o+1 has side int(base * OCTAVE_SCALE**(o+1)).
for o in range(NUM_OCTAVES):
    img = ascend(img, STEPS)
    # upscale for next octave (skipped after the last one)
    if o < NUM_OCTAVES-1:
        nh, nw = int(h*(OCTAVE_SCALE**(o+1))), int(w*(OCTAVE_SCALE**(o+1)))
        with torch.no_grad():
            # Resizing under no_grad yields a fresh tensor, which is then
            # re-marked as a leaf so ascend() can read its .grad.
            img = resize(img, [nh, nw], antialias=True).requires_grad_(True)

# Save multiple crops (to mimic the grid in the blog)
grid = []
base = img.detach()
for k in range(9):
    # Small random crop back to 299x299 and short refinement
    with torch.no_grad():
        H, W = base.shape[-2:]
        # torch.randint's upper bound is exclusive: valid offsets are
        # 0 .. H-299 inclusive, hence the +1 (max() guards H < 299).
        top  = int(torch.randint(0, max(1, H-299+1), (1,), generator=g, device=DEVICE))
        left = int(torch.randint(0, max(1, W-299+1), (1,), generator=g, device=DEVICE))
        # clone() so refining the crop does not write into `base`.
        crop = base[..., top:top+299, left:left+299].clone().requires_grad_(True)
    crop = ascend(crop, steps=20)
    grid.append(to_pil(crop))

# Make a simple 3x3 grid
def make_grid(images, rows=3, cols=3, pad=4):
    """Paste ``rows * cols`` equally sized images onto a white canvas.

    Tiles are placed in row-major order with ``pad`` pixels between them; the
    tile size is taken from the first image.

    Args:
        images: Sequence of PIL images with at least ``rows * cols`` entries.
        rows: Number of grid rows.
        cols: Number of grid columns.
        pad: Gap in pixels between neighbouring tiles.

    Returns:
        A new RGB PIL image containing the grid.
    """
    tile_w, tile_h = images[0].size
    canvas_size = (cols*tile_w + (cols-1)*pad, rows*tile_h + (rows-1)*pad)
    canvas = Image.new("RGB", canvas_size, (255,255,255))
    for idx in range(rows * cols):
        row, col = divmod(idx, cols)
        canvas.paste(images[idx], (col*(tile_w+pad), row*(tile_h+pad)))
    return canvas

# Assemble the nine refined crops into a 3x3 sheet (6 px gaps) and write it
# to the notebook's working directory.
grid_img = make_grid(grid, 3, 3, 6)
grid_img.save("deepdream_dumbbells.png")
print("Wrote deepdream_dumbbells.png")


Wrote deepdream_dumbbells.png
