# In [None]:
# =========================================
# 神经风格迁移 - PyTorch 实现（修复版）
# =========================================
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image
import matplotlib.pyplot as plt
import os

# =========================================
# 配置
# =========================================
class CONFIG:
    """Global settings shared by the whole style-transfer script."""

    # Default side length (pixels) used when resizing/cropping loaded images.
    IMG_SIZE = 512

    # Folder that receives intermediate and final images. It is created as a
    # side effect of executing this class body.
    OUTPUT_DIR = "output"
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Run on the GPU when CUDA is available, otherwise fall back to the CPU.
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# =========================================
# 工具函数 (等价于 nst_utils)
# =========================================
def image_loader(path, imsize=CONFIG.IMG_SIZE):
    """Read an image file and return it as a batched float tensor.

    The image is converted to RGB, resized, centre-cropped to
    ``imsize`` x ``imsize`` and converted to a tensor, then given a leading
    batch dimension (``[1, 3, H, W]``) and moved to ``CONFIG.DEVICE``.

    Args:
        path: Location of the image file to open.
        imsize: Target size passed to both the resize and the centre crop.

    Returns:
        A ``torch.float`` tensor on ``CONFIG.DEVICE``.
    """
    # Build the preprocessing steps first, then run them one after another.
    steps = (
        transforms.Resize(imsize),
        transforms.CenterCrop(imsize),
        transforms.ToTensor(),
    )
    picture = Image.open(path).convert("RGB")
    for step in steps:
        picture = step(picture)

    # Add the batch axis expected by the network.
    batched = picture.unsqueeze(0)
    return batched.to(CONFIG.DEVICE, torch.float)

def imshow(tensor, title=None):
    """Display an image tensor with matplotlib.

    Args:
        tensor: Image tensor; a leading dimension of size 1 is squeezed away
            before conversion to a PIL image.
        title: Optional figure title; skipped when falsy.
    """
    # Detach so tensors that are still part of an autograd graph can be shown,
    # and work on a CPU copy so the caller's tensor is never modified.
    img = tensor.detach().cpu().clone().squeeze(0)

    # The float -> 8-bit conversion done by ToPILImage does not saturate, so
    # values an optimiser pushed outside [0, 1] would wrap around and show up
    # as speckle artifacts. Integer tensors are left untouched.
    if img.is_floating_point():
        img = img.clamp(0.0, 1.0)

    img = transforms.ToPILImage()(img)
    plt.imshow(img)
    if title:
        plt.title(title)
    plt.axis("off")
    plt.show()

def save_image(tensor, path):
    """Write an image tensor to ``path`` via ``torchvision.utils.save_image``.

    Args:
        tensor: Image tensor; a leading dimension of size 1 is squeezed away.
        path: Destination file path.
    """
    # Save from a CPU copy so the caller's tensor is left untouched.
    vutils.save_image(tensor.clone().cpu().squeeze(0), path)

# =========================================
# 模型准备
# =========================================
# Keep only the `features` part of the ImageNet-pretrained VGG19, then move it
# to the configured device and put it in evaluation mode.
cnn = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
cnn = cnn.to(CONFIG.DEVICE).eval()

# Disable in-place ReLU (presumably so that activations captured from the layer
# just before a ReLU are not overwritten by it).
for module in cnn.modules():
    if not isinstance(module, nn.ReLU):
        continue
    module.inplace = False

# =========================================
# 特征提取 & Gram 矩阵
# =========================================
def get_features(x, model, layers):
    """Run ``x`` through ``model`` layer by layer and collect activations.

    Args:
        x: Input passed to the first child module.
        model: Module whose direct children are applied in registration
            order (e.g. an ``nn.Sequential``).
        layers: Iterable of child names whose outputs should be captured.

    Returns:
        Dict mapping each requested child name that exists in ``model`` to the
        output of that child. Names that do not exist are simply absent.
    """
    pending = set(layers)  # O(1) membership, tolerant of duplicate names
    features = {}
    cur = x
    # `_modules` is iterated (rather than `named_children()`) so a module
    # instance registered under several names is still applied every time.
    for name, layer in model._modules.items():
        if not pending:
            # Every requested activation has been captured; the remaining
            # layers would only burn compute and memory.
            break
        cur = layer(cur)
        if name in pending:
            features[name] = cur
            pending.discard(name)
    return features

def gram_of_feature(feat):
    """Compute the (unnormalised) Gram matrix of a 4-D feature map.

    Args:
        feat: Tensor of shape ``[b, C, H, W]``. It may be non-contiguous.

    Returns:
        Tuple ``(G, C, H, W)``. For ``b == 1`` ``G`` has shape ``[C, C]``;
        for larger batches one Gram matrix is computed per sample and ``G``
        has shape ``[b, C, C]``.
    """
    b, C, H, W = feat.shape
    # reshape (unlike view) also accepts non-contiguous inputs, copying only
    # when it has to.
    flat = feat.reshape(b, C, H * W)
    if b == 1:
        rows = flat[0]
        G = torch.mm(rows, rows.t())
    else:
        G = torch.bmm(flat, flat.transpose(1, 2))
    return G, C, H, W

# =========================================
# 内容 & 风格层定义
# =========================================
# Names of the child modules whose activations feed the style loss.
# NOTE(review): in torchvision's VGG19 `features` these indices should be
# conv1_1, conv2_1, conv3_1, conv4_1 and conv5_1 — confirm against the model.
STYLE_LAYER_IDS = [str(idx) for idx in (0, 5, 10, 19, 28)]

# Name of the child module whose activation feeds the content loss
# (presumably conv4_2 — confirm against the model).
CONTENT_LAYER_ID = str(21)

# =========================================
# 损失函数
# =========================================
def compute_content_cost(a_C, a_G):
    """Return the mean squared difference between two activation tensors.

    Args:
        a_C: Activations of the content image.
        a_G: Activations of the generated image.

    Returns:
        Scalar tensor holding the element-wise mean of ``(a_C - a_G) ** 2``.
    """
    difference = a_C - a_G
    return difference.pow(2).mean()

def style_loss_from_gram(G_target, G_current, C, H, W):
    """Return the normalised squared distance between two Gram matrices.

    Args:
        G_target: Gram matrix of the style image.
        G_current: Gram matrix of the generated image.
        C: Number of channels of the underlying feature map.
        H: Height of the underlying feature map.
        W: Width of the underlying feature map.

    Returns:
        ``sum((G_target - G_current) ** 2) / (4 * C**2 * (H*W)**2)``.
    """
    spatial_size = H * W
    normaliser = 4.0 * (C ** 2) * spatial_size ** 2
    squared_error = (G_target - G_current).pow(2)
    return squared_error.sum() / normaliser

def total_cost(J_content, J_style, alpha=10, beta=40):
    """Combine the content and style costs into one weighted objective.

    Args:
        J_content: Content cost.
        J_style: Style cost.
        alpha: Weight applied to the content cost.
        beta: Weight applied to the style cost.

    Returns:
        ``alpha * J_content + beta * J_style``.
    """
    weighted_content = alpha * J_content
    weighted_style = beta * J_style
    return weighted_content + weighted_style

# =========================================
# 预处理 target
# =========================================
def prepare_targets(cnn, content_img, style_img):
    """Precompute the fixed optimisation targets for style transfer.

    Args:
        cnn: Feature-extraction network handed to ``get_features``.
        content_img: Content image tensor.
        style_img: Style image tensor.

    Returns:
        Tuple ``(content_target, style_targets)``. ``content_target`` is the
        content image's activation at ``CONTENT_LAYER_ID``. ``style_targets``
        maps each id in ``STYLE_LAYER_IDS`` to a dict with the style image's
        Gram matrix under ``'G'`` and the feature-map dimensions under
        ``'C'``, ``'H'`` and ``'W'``.
    """
    # Targets are constants of the optimisation, so no autograd graph is
    # needed for them, even if the caller has not frozen the network.
    with torch.no_grad():
        # Ask each image only for the layers it actually contributes.
        content_feats = get_features(content_img, cnn, [CONTENT_LAYER_ID])
        style_feats = get_features(style_img, cnn, STYLE_LAYER_IDS)

        # Clone so the target can never alias a tensor owned by the caller.
        content_target = content_feats[CONTENT_LAYER_ID].detach().clone()

        style_targets = {}
        for lid in STYLE_LAYER_IDS:
            G, C, H, W = gram_of_feature(style_feats[lid])
            style_targets[lid] = {'G': G.detach(), 'C': C, 'H': H, 'W': W}

    return content_target, style_targets

# =========================================
# 风格迁移主函数
# =========================================
def run_style_transfer(cnn, content_img, style_img, input_img,
                       num_steps=200, style_weight=40, content_weight=10):
    """Optimise an image so it matches the content of one image and the style of another.

    The pixels of a copy of ``input_img`` are optimised with L-BFGS against a
    weighted sum of a content cost and a style cost computed from ``cnn``
    activations.

    Args:
        cnn: Feature-extraction network. Its parameters are frozen and it is
            switched to eval mode as a side effect.
        content_img: Content image tensor with values in [0, 1].
        style_img: Style image tensor with values in [0, 1].
        input_img: Starting point of the optimisation; it is not modified.
        num_steps: Minimum number of closure evaluations. One
            ``optimizer.step`` may evaluate the closure several times, so the
            final count can overshoot this value.
        style_weight: Weight of the style cost.
        content_weight: Weight of the content cost.

    Returns:
        The optimised image, detached and clamped to [0, 1].
    """
    # Only the image is optimised; the network stays fixed.
    for p in cnn.parameters():
        p.requires_grad = False
    cnn.eval()

    device = CONFIG.DEVICE

    # ImageNet-pretrained torchvision models expect inputs normalised with
    # these per-channel statistics; feeding raw [0, 1] pixels shifts every
    # activation away from what the weights were trained on.
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    def normalize(img):
        return (img.to(device) - mean) / std

    content_target, style_targets = prepare_targets(
        cnn, normalize(content_img), normalize(style_img)
    )

    # detach() guarantees a leaf tensor even if the caller's tensor is part of
    # an autograd graph; the optimiser can only update leaf tensors.
    generated = input_img.detach().clone().to(device)
    generated.requires_grad_(True)

    needed_layers = STYLE_LAYER_IDS + [CONTENT_LAYER_ID]
    layer_coeffs = [0.2] * len(STYLE_LAYER_IDS)

    optimizer = optim.LBFGS([generated], max_iter=20)
    run = [0]  # mutable counter shared with the closure

    while run[0] < num_steps:
        def closure():
            # Keep the pixels a valid image; L-BFGS updates are unconstrained
            # and would otherwise drift outside [0, 1].
            with torch.no_grad():
                generated.clamp_(0, 1)

            optimizer.zero_grad()
            gen_feats = get_features(normalize(generated), cnn, needed_layers)

            a_C = content_target
            a_G = gen_feats[CONTENT_LAYER_ID]
            J_content = compute_content_cost(a_C, a_G)

            J_style = 0.0
            for lid, coeff in zip(STYLE_LAYER_IDS, layer_coeffs):
                G_current, C, H, W = gram_of_feature(gen_feats[lid])
                target = style_targets[lid]['G']
                J_style += coeff * style_loss_from_gram(target, G_current, C, H, W)

            J = total_cost(J_content, J_style, alpha=content_weight, beta=style_weight)
            J.backward()

            # Periodically report progress and dump the current image.
            if run[0] % 20 == 0:
                print(f"Iteration {run[0]}:")
                print(f"  Total loss: {J.item():.4f}, Content: {J_content.item():.4f}, Style: {J_style.item():.4f}")
                save_path = os.path.join(CONFIG.OUTPUT_DIR, f"step_{run[0]}.png")
                save_image(generated.detach(), save_path)
                print(f"  Saved intermediate result to {save_path}")

            run[0] += 1
            return J

        optimizer.step(closure)

    # The last optimiser update happens after the final in-closure clamp.
    with torch.no_grad():
        generated.clamp_(0, 1)

    return generated.detach()

# =========================================
# 主流程
# =========================================
# Load the content and style images.
content_img = image_loader("images/louvre.jpg")
style_img = image_loader("images/monet_800600.jpg")

# Start the optimisation from a copy of the content image.
generated_img = content_img.clone()
output = run_style_transfer(
    cnn,
    content_img,
    style_img,
    generated_img,
    num_steps=200,
)

# Show the result, then write it next to the intermediate snapshots.
final_path = os.path.join(CONFIG.OUTPUT_DIR, "final.png")
imshow(output, title="Final Generated Image")
save_image(output, final_path)
print("最终结果已保存到 output/final.png")
