# Style Blend Experiments
Blend an expressive style into a neutral one while preserving the neutral style's lip-sync.

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Load neutral (torch-saved .npz): despite the suffix, this file is read with torch.load.
target_style_neutral = torch.load("/mnt/fasttalk/demo/styles/style_2.npz", map_location=device)

# Load expressive (npz with exp/pose/jaw).
# Alternative source: /mnt/Datasets/expressive_ft/synthetic_dataset_v2/npz/video_000.npz
# The context manager closes the archive's file handle once the arrays are read
# (a bare np.load on an .npz keeps the file open).
with np.load("/mnt/Datasets/expressive_ft/synthetic_dataset/npz/video_015.npz") as style:
    exp = style["exp"].reshape(-1, 50)    # expression coefficients -> [T, 50]
    gpose = style["pose"].reshape(-1, 3)  # global pose -> [T, 3]
    jaw = style["jaw"].reshape(-1, 3)     # jaw pose -> [T, 3]

# Fail with a readable message instead of a bare concatenate error.
if not (exp.shape[0] == gpose.shape[0] == jaw.shape[0]):
    raise ValueError(
        f"frame count mismatch: exp={exp.shape[0]}, pose={gpose.shape[0]}, jaw={jaw.shape[0]}"
    )

# eyelids: all ones (the archive provides no eyelid channels)
eyelids = np.ones((exp.shape[0], 2), dtype=np.float32)

# concat to [T, 58] = exp(50) | pose(3) | jaw(3) | eyelids(2)
concat = np.concatenate([exp, gpose, jaw, eyelids], axis=1)

# pytorch tensor with a leading batch dim
target_style_tensor = torch.from_numpy(concat).float()
target_style_expressive = target_style_tensor.unsqueeze(0).to(device)  # [1,T,58]

def to_numpy(x):
    """Return ``x`` as a NumPy array.

    Torch tensors are detached from autograd and copied to host memory first;
    any other value is passed through ``np.asarray``.
    """
    if not isinstance(x, torch.Tensor):
        return np.asarray(x)
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()

def _squeeze_style(arr):
    """Normalize a style array to a 2-D [T, D] matrix.

    A leading batch axis of size 1 (e.g. [1, T, 58]) is dropped; any array
    that is still more than 2-D is flattened to (arr.shape[0], -1).
    """
    if arr.ndim == 3 and arr.shape[0] == 1:
        arr = arr[0]
    if arr.ndim > 2:
        arr = arr.reshape(arr.shape[0], -1)
    return arr


neutral = _squeeze_style(to_numpy(target_style_neutral))
expressive = _squeeze_style(to_numpy(target_style_expressive))

print('neutral:', neutral.shape, 'expressive:', expressive.shape)

In [None]:
# Auto-discover which expression components affect mouth vs brows
import pickle
import os
import sys
from pathlib import Path

PROJECT_ROOT = Path('/mnt/fasttalk')
# Make the project's packages (e.g. flame_model, renderer) importable.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from flame_model.FLAME import FLAMEModel

mask_path = "/mnt/fasttalk/flame_model/assets/FLAME_masks.pkl"
# Named region masks keyed by region name. encoding="latin1" lets Python 3 read
# byte strings stored in Python-2-era pickles. pickle can execute arbitrary code
# on load, so only point this at a trusted asset file.
with open(mask_path, "rb") as f:
    masks = pickle.load(f, encoding="latin1")

def _find_mask_key(keywords):
    """Return the name of the first mask in ``masks`` matching ``keywords``, or None.

    Keywords are tried in the order given, so earlier keywords take priority
    (e.g. ``["forehead", "eye_region"]`` prefers a forehead mask even if an
    eye-region mask comes first in the dict). Each keyword is matched as a
    substring of the lower-cased mask name; within one keyword, mask names are
    scanned in dict order.
    """
    lowered = {key: key.lower() for key in masks}
    for word in keywords:
        for key, lower_key in lowered.items():
            if word in lower_key:
                return key
    return None

# Keyword lists used to locate the mouth and brow regions among the mask names.
_MOUTH_WORDS = ["lip", "mouth"]
_BROW_WORDS = ["brow", "eyebrow"]
# Proxy regions used when no explicit brow mask exists.
_BROW_FALLBACK_WORDS = ["forehead", "eye_region", "eye region"]

mouth_key = _find_mask_key(_MOUTH_WORDS)
brow_key = _find_mask_key(_BROW_WORDS)
if brow_key is None:
    brow_key = _find_mask_key(_BROW_FALLBACK_WORDS)

print("mask keys:", [name for name in masks])
print("mouth_key:", mouth_key, "brow_key:", brow_key)

def _to_index_list(mask):
    """Convert a region mask into a flat integer index array.

    Boolean masks become the positions of their True entries (along the first
    axis); any other array-like is cast to int and flattened.
    """
    values = np.asarray(mask)
    if values.dtype != bool:
        return values.astype(int).ravel()
    return np.nonzero(values)[0]

if mouth_key and brow_key:
    mouth_verts = _to_index_list(masks[mouth_key])
    brow_verts = _to_index_list(masks[brow_key])
    flame_tmp = FLAMEModel(n_shape=300, n_exp=50).to(device)
    # Assumes the expression basis is stored in the last 50 blendshape slots: (V, 3, 50).
    exp_dirs = flame_tmp.shapedirs[:, :, -50:]

    def _region_energy(vert_idx):
        """Mean per-vertex displacement norm of every expression direction over a region.

        (R, 3, 50) -> L2 norm over xyz -> (R, 50) -> mean over vertices -> (50,).
        Returned as float64 NumPy, matching the former per-component .item() values.
        """
        idx = torch.as_tensor(vert_idx, dtype=torch.long, device=exp_dirs.device)
        per_component = exp_dirs[idx].norm(dim=1).mean(dim=0)
        return per_component.detach().cpu().numpy().astype(np.float64)

    # One vectorized reduction per region instead of a 50-iteration Python loop.
    with torch.no_grad():
        mouth_energy = _region_energy(mouth_verts)
        brow_energy = _region_energy(brow_verts)
    ratio = mouth_energy / (brow_energy + 1e-8)
    # Ascending order: brow-dominant components first, mouth-dominant last.
    order = np.argsort(ratio)
    mouth_exp_idx = order[-10:]
    brow_exp_idx = order[:10]
    print("mouth_exp_idx:", mouth_exp_idx)
    print("brow_exp_idx:", brow_exp_idx)
else:
    # No usable masks: downstream cells fall back to their own defaults.
    mouth_exp_idx = np.array([], dtype=int)
    brow_exp_idx = np.array([], dtype=int)

In [None]:
# Align lengths for neutral and expressive base styles: truncate both to the
# shorter sequence so they can be compared frame by frame.
min_len = min(len(neutral), len(expressive))
neutral_aligned, expressive_aligned = neutral[:min_len], expressive[:min_len]

for _label, _arr in (("neutral_aligned", neutral_aligned), ("expressive_aligned", expressive_aligned)):
    print(f"{_label}:", _arr.shape)

In [None]:
# Compare the first (up to) 250 frames and first (up to) 20 params, neutral vs expressive.
# The *_250_10 names are kept because later cells read them; they hold
# n_frames x n_params slices.
n_frames = min(250, neutral.shape[0], expressive.shape[0])
# Clamp to the available columns so a narrower style cannot cause an IndexError below.
n_params = min(20, neutral.shape[1], expressive.shape[1])

neutral_250_10 = neutral[:n_frames, :n_params]
expressive_250_10 = expressive[:n_frames, :n_params]

# squeeze=False always returns a 2-D axes grid, so one parameter needs no special case.
fig, axes = plt.subplots(n_params, 1, figsize=(10, 2 * n_params), sharex=True, squeeze=False)
axes = axes[:, 0]
for i, ax in enumerate(axes):
    ax.plot(neutral_250_10[:, i], label='neutral', alpha=0.8)
    ax.plot(expressive_250_10[:, i], label='expressive', alpha=0.8)
    ax.set_ylabel(f'p{i}')
axes[-1].set_xlabel('frame')
axes[0].legend(loc='upper right')
plt.tight_layout()

In [None]:
# Summary stats for first 250 frames, first 20 params (neutral vs expressive)
def stats(x):
    """Return the column-wise (axis 0) mean and standard deviation of array ``x``."""
    column_mean = x.mean(axis=0)
    column_std = x.std(axis=0)
    return column_mean, column_std

n_mean, n_std = stats(neutral_250_10)
e_mean, e_std = stats(expressive_250_10)

# One summary line per style.
for _label, _mean, _std in (("Neutral", n_mean, n_std), ("Expressive", e_mean, e_std)):
    print(f'{_label} mean/std:', _mean, _std)

In [None]:
# Bar chart: mean ± std per parameter (neutral vs expressive)
# Derive the parameter count from the stats themselves, so the axis and title
# always match the data actually summarized.
params = np.arange(len(n_mean))
x = params
width = 0.35

fig, ax = plt.subplots(figsize=(12, 4))
ax.bar(x - width / 2, n_mean, width, yerr=n_std, capsize=3, label="neutral")
ax.bar(x + width / 2, e_mean, width, yerr=e_std, capsize=3, label="expressive")
ax.set_xticks(x)
ax.set_xticklabels([f"p{i}" for i in params])
ax.set_ylabel("value")
ax.set_title(f"Mean ± Std per parameter (first {len(params)})")
ax.legend()
plt.tight_layout()

In [None]:
# Setup renderer utilities (ensure render() is available)

from renderer.renderer import Renderer
from pytorch3d.transforms import matrix_to_euler_angles
import matplotlib.animation as animation
import os
from pathlib import Path
# Change base folder one level up (presumably the project root, so relative
# asset paths resolve — TODO confirm). Guarded by a flag so re-running this
# cell does not keep climbing the directory tree; a kernel restart resets both.
if not globals().get("_BASE_DIR_CHANGED", False):
    os.chdir('../')
    _BASE_DIR_CHANGED = True

# Build the FLAME model and renderer once; re-runs reuse the existing objects.
if "flame" not in globals():
    flame = FLAMEModel(n_shape=300, n_exp=50).to(device)
if "renderer" not in globals():
    renderer = Renderer(render_full_head=True).to(device)

def get_vertices_from_blendshapes(expr, gpose, jaw, eyelids):
    """Run FLAME on per-frame parameters and return the detached vertices.

    Args:
        expr: expression coefficients, one row per frame (50 columns here).
        gpose: global pose per frame (3 columns).
        jaw: jaw pose per frame (3 columns).
        eyelids: eyelid parameters per frame. NOTE(review): not forwarded to
            FLAME (the previous version also dropped it); confirm whether
            FLAMEModel.forward should receive eyelid parameters.

    Returns:
        The first output of ``flame.forward`` (the mesh vertices), detached.
    """
    expr_tensor = expr.to(device)
    n_frames = expr_tensor.shape[0]

    # Zero identity shape: only expression and pose vary between frames.
    shape_params = torch.zeros(n_frames, 300, device=device)
    # Both eyes at rest: the XYZ Euler angles of the identity rotation are zero,
    # so the 6 eye-pose values (right eye, then left eye) are all zeros.
    eye_pose = torch.zeros(n_frames, 6, device=device)
    pose = torch.cat([gpose.to(device), jaw.to(device)], dim=-1)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        vertices, _ = flame.forward(
            shape_params=shape_params,
            expression_params=expr_tensor,
            pose_params=pose,
            eye_pose_params=eye_pose,
        )
    return vertices.detach()

def render(blendshapes_pred, name="mixed_style", fps=25, output_dir="/mnt/fasttalk/demo/style_previews"):
    """Render per-frame FLAME parameters to an mp4 preview.

    Args:
        blendshapes_pred: tensor [T, D] with columns laid out as expression
            [0:50], global pose [50:53], jaw [53:56], eyelids [56:].
        name: output file stem (``<output_dir>/<name>.mp4``).
        fps: frame rate of the saved video (default 25, as before).
        output_dir: destination directory; created if missing.

    Raises:
        ValueError: if the renderer produced no frames.
    """
    expr_pr = blendshapes_pred[:, :50]
    gpose_pr = blendshapes_pred[:, 50:53]
    jaw_pr = blendshapes_pred[:, 53:56]
    eyelids_pr = blendshapes_pred[:, 56:]

    verts_pr = get_vertices_from_blendshapes(expr_pr, gpose_pr, jaw_pr, eyelids_pr)
    # Same fixed camera for every frame.
    cam = torch.tensor([5, 0, 0], dtype=torch.float32, device=verts_pr.device).unsqueeze(0)
    cam = cam.expand(verts_pr.shape[0], -1)
    frames_pr = renderer.forward(verts_pr, cam)["rendered_img"]
    n_frames = frames_pr.shape[0]
    if n_frames == 0:
        raise ValueError("render() received zero frames")

    os.makedirs(output_dir, exist_ok=True)
    video_file = f"{output_dir}/{name}.mp4"

    def to_uint8(frame):
        # Frames are channel-first; move channels last for imshow. Clip to [0, 1]
        # before scaling so out-of-range values don't wrap around in the uint8 cast.
        img = frame.detach().cpu().numpy().transpose(1, 2, 0)
        return (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.axis("off")
    image = ax.imshow(to_uint8(frames_pr[0]))

    def update(frame_idx):
        # Swap pixel data instead of clearing and re-creating the axes each frame.
        image.set_data(to_uint8(frames_pr[frame_idx]))
        return (image,)

    ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=1000 / fps)
    try:
        ani.save(video_file, writer='ffmpeg', fps=fps)
    finally:
        # Close the figure even when saving fails, so figures don't accumulate.
        plt.close(fig)
    print(f"Saved render to {video_file}")

In [None]:
# Lip-safe microexpression transfer with automatic scoring + alpha fitting
# Re-derive the aligned pair if the alignment cell was skipped.
if not all(_name in globals() for _name in ("neutral_aligned", "expressive_aligned")):
    min_len = min(neutral.shape[0], expressive.shape[0])
    neutral_aligned = neutral[:min_len]
    expressive_aligned = expressive[:min_len]

assert neutral_aligned.shape == expressive_aligned.shape
T, D = neutral_aligned.shape

# Channel layout: expression [0..49], global pose [50..52], jaw [53..55], eyelids [56..57]
_n_exp_dims = min(50, D)
all_idx = np.arange(D)
pose_jaw_idx = np.arange(50, min(56, D))
exp_idx = np.arange(_n_exp_dims)

# Mouth-dominant expression channels (auto-discovered, else a fixed fallback set).
if "mouth_exp_idx" in globals() and len(mouth_exp_idx) > 0:
    mouth_exp_idx_safe = np.array([c for c in mouth_exp_idx if 0 <= c < _n_exp_dims], dtype=int)
else:
    _fallback_mouth = np.array([0, 4, 11], dtype=int)
    mouth_exp_idx_safe = _fallback_mouth[_fallback_mouth < _n_exp_dims]

# Brows/non-mouth expression channels are primary injection targets
if "brow_exp_idx" in globals() and len(brow_exp_idx) > 0:
    brow_exp_idx_safe = np.array([c for c in brow_exp_idx if 0 <= c < _n_exp_dims], dtype=int)
else:
    brow_exp_idx_safe = np.setdiff1d(exp_idx, mouth_exp_idx_safe)[:10]

# Mouth expression + pose/jaw channels stay neutral; everything else may be injected.
protected_idx = np.unique(np.concatenate([mouth_exp_idx_safe, pose_jaw_idx]))
inject_idx = np.setdiff1d(all_idx, protected_idx)

print("Protected dims:", protected_idx.tolist())
print("Inject dims count:", len(inject_idx))

def moving_average_1d(x, win=25):
    """Centered box-filter smoothing of a 1-D signal with edge padding.

    An even ``win`` is bumped to the next odd size so the window stays
    centered; ``win <= 1`` returns an unmodified copy. The output has the
    same length as ``x``.
    """
    if win > 1:
        win = win + 1 if win % 2 == 0 else win
        half = win // 2
        padded = np.pad(x, ((half, half),), mode="edge")
        box = np.ones(win, dtype=np.float32) / float(win)
        return np.convolve(padded, box, mode="valid")
    return x.copy()


def moving_average_2d(arr, win=25):
    """Apply ``moving_average_1d`` independently to every column of 2-D ``arr``.

    The result has the same shape and dtype as ``arr``.
    """
    smoothed = np.zeros_like(arr)
    for col, series in enumerate(arr.T):
        smoothed[:, col] = moving_average_1d(series, win=win)
    return smoothed

# High-frequency residual from expressive = micro-dynamics carrier
expr_low = moving_average_2d(expressive_aligned, win=31)
expr_res = expressive_aligned - expr_low
expr_res = moving_average_2d(expr_res, win=5)  # denoise tiny spikes

# Per-channel normalization. Named neutral_std (not n_std) so this cell does not
# overwrite the n_std from the summary-stats cell, which the bar chart reads.
neutral_std = neutral_aligned.std(axis=0) + 1e-6
e_res_std = expr_res.std(axis=0) + 1e-6
norm_gain = np.clip(neutral_std / e_res_std, 0.2, 3.0)

# Extra boost only on brow/non-mouth expression channels
pop_gain = np.ones(D, dtype=np.float32)
pop_gain[brow_exp_idx_safe] = 2.2

global_cap = 3.5

# Slow (window-41) drift of the expressive style relative to neutral on the
# injectable channels. It does not depend on the alphas, so it is computed once
# here rather than on every build_mixed() call of the grid search.
low_freq_offset = (
    moving_average_2d(expressive_aligned[:, inject_idx], win=41)
    - moving_average_2d(neutral_aligned[:, inject_idx], win=41)
)


def build_mixed(alpha_micro, alpha_low):
    """Blend expressive dynamics into a copy of the neutral style.

    Adds the alpha_micro-scaled, gain-normalized high-frequency residual
    (zeroed on protected channels and clipped to ±global_cap neutral stds per
    channel), then alpha_low * low_freq_offset on the injectable channels.
    Returns an array shaped like neutral_aligned.
    """
    mixed_tmp = neutral_aligned.copy()

    delta = alpha_micro * expr_res * norm_gain[None, :] * pop_gain[None, :]
    delta[:, protected_idx] = 0.0
    delta = np.clip(delta, -global_cap * neutral_std[None, :], global_cap * neutral_std[None, :])
    mixed_tmp += delta

    mixed_tmp[:, inject_idx] += alpha_low * low_freq_offset
    return mixed_tmp

def rms(x):
    """Root-mean-square of ``x`` along axis 0 (one value per column)."""
    mean_square = np.mean(x * x, axis=0)
    return np.sqrt(mean_square)

# Metrics:
# - keep mouth close to neutral (lower is better)
# - increase brow/non-mouth dynamics over neutral (higher is better)
def score_candidate(mixed_tmp):
    """Score a blended style against the neutral baseline.

    Motion is the per-channel RMS of frame-to-frame differences. Each ratio
    is mean(candidate motion / neutral motion) over a channel set (0.0 when
    the set is empty). Returns ``(objective, lip_deviation_ratio,
    express_gain_ratio)`` with
    ``objective = express_gain_ratio - 0.8 * |lip_deviation_ratio - 1|``.

    NOTE(review): build_mixed() never changes the mouth channels (they are in
    protected_idx), so for its candidates lip_deviation_ratio is ~1.0 every
    time and the lip term does not influence the grid search.
    """
    def motion_ratio(channels):
        if not len(channels):
            cand_motion, base_motion = np.array([0.0]), np.array([1.0])
        else:
            cand_motion = rms(np.diff(mixed_tmp[:, channels], axis=0))
            base_motion = rms(np.diff(neutral_aligned[:, channels], axis=0))
        return float(np.mean(cand_motion / (base_motion + 1e-6)))

    lip_deviation_ratio = motion_ratio(mouth_exp_idx_safe)
    express_gain_ratio = motion_ratio(brow_exp_idx_safe)
    objective = express_gain_ratio - 0.8 * abs(lip_deviation_ratio - 1.0)
    return objective, lip_deviation_ratio, express_gain_ratio

alpha_micro_grid = np.array([0.35, 0.50, 0.70, 0.90, 1.10, 1.30], dtype=np.float32)
alpha_low_grid = np.array([0.05, 0.08, 0.12, 0.16], dtype=np.float32)


def _evaluate_alphas(a_m, a_l):
    """Build one (alpha_micro, alpha_low) candidate and return its score record."""
    candidate = build_mixed(alpha_micro=float(a_m), alpha_low=float(a_l))
    objective, lip_r, expr_r = score_candidate(candidate)
    return {
        "alpha_micro": float(a_m),
        "alpha_low": float(a_l),
        "objective": float(objective),
        "lip_ratio": float(lip_r),
        "expr_ratio": float(expr_r),
        "mixed": candidate,
    }


# Exhaustive grid search (alpha_micro-major order); on ties the earliest candidate wins.
best = max(
    (_evaluate_alphas(a_m, a_l) for a_m in alpha_micro_grid for a_l in alpha_low_grid),
    key=lambda record: record["objective"],
)

alpha_micro, alpha_low, mixed_micro = best["alpha_micro"], best["alpha_low"], best["mixed"]

print("Best alpha_micro:", alpha_micro, "alpha_low:", alpha_low)
print("lip_ratio (target~1.0):", round(best["lip_ratio"], 3), "expr_ratio (>1 better):", round(best["expr_ratio"], 3))
print("mixed_micro:", mixed_micro.shape)

# Optional hard push if still not expressive enough: scale the brow deviation
# from neutral by hard_boost (modifies mixed_micro in place).
force_pop = True
if force_pop:
    hard_boost = 1.35
    _brow_base = neutral_aligned[:, brow_exp_idx_safe]
    _brow_dev = mixed_micro[:, brow_exp_idx_safe] - _brow_base
    mixed_micro[:, brow_exp_idx_safe] = _brow_base + hard_boost * _brow_dev
    print("Applied hard brow boost:", hard_boost)

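## Alternative: lip-safe microexpression transfer
Keep the neutral lip-sync channels fixed (the mouth-dominant expression components, global pose, and jaw). Inject expressive dynamics only into the remaining channels: mainly high-frequency expressive residuals, plus a small low-frequency offset weighted by `alpha_low`.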

In [None]:
# Render neutral, expressive, and micro-enhanced styles
def _to_device_tensor(arr):
    """Convert a NumPy style array to a float32 tensor on ``device``."""
    return torch.from_numpy(arr).float().to(device)


neutral_tensor, expressive_tensor, mixed_micro_tensor = map(
    _to_device_tensor, (neutral_aligned, expressive_aligned, mixed_micro)
)

for _tensor, _style_name in (
    (neutral_tensor, "neutral_style"),
    (expressive_tensor, "expressive_style"),
    (mixed_micro_tensor, "micro_style"),
):
    render(_tensor, name=_style_name)

In [None]:
# Save micro-enhanced style as a plain NumPy .npy array.
_out_dir = Path('/mnt/fasttalk/demo/new_styles')
_out_dir.mkdir(parents=True, exist_ok=True)

_out_file = _out_dir / 'micro_style.npy'
np.save(_out_file, mixed_micro)

print(f'Saved: {_out_file}')