In [None]:
import torch
from torch import nn 
from torch import optim
from PIL import Image

from torchvision import transforms
from torchvision import models
from torchvision.utils import save_image

# Architecture only: this copy is just printed in the next cell to look up layer
# indices, and `model` is re-bound to NST_VGG (which loads the pretrained
# weights itself) before anything is computed. Skipping the weights here avoids
# loading them twice.
model = models.vgg19(weights=None).features

In [None]:
# Inspect the layer layout of the VGG19 `features` stack.
print(model)

# Layers used for style transfer: indices 0, 5, 10, 19 and 28 are the first
# conv layer of each of the five VGG19 blocks (conv1_1, conv2_1, conv3_1,
# conv4_1, conv5_1). In this notebook they feed both the content and the
# style loss.

In [None]:
class NST_VGG(nn.Module):
    """Frozen VGG19 feature extractor for neural style transfer.

    ``forward`` runs an image batch through ``self.model`` layer by layer and
    returns the outputs of the layers whose indices are in ``chosen_features``.

    Args:
        chosen_features: indices of the layers to return, given as strings or
            ints (stored as strings). The default 0, 5, 10, 19, 28 are the
            first conv layers of the five VGG19 blocks (conv1_1 ... conv5_1).
        backbone: optional ``nn.Sequential`` to use instead of the pretrained
            torchvision VGG19 ``features`` stack (e.g. for tests or another
            backbone). Its parameters are frozen in place.

    Raises:
        ValueError: if ``chosen_features`` is empty or refers to a layer index
            beyond the end of the backbone.
    """

    def __init__(self, chosen_features=('0', '5', '10', '19', '28'), backbone=None):
        super().__init__()

        self.chosen_features = [str(idx) for idx in chosen_features]
        if not self.chosen_features:
            raise ValueError('chosen_features must not be empty')

        if backbone is None:
            backbone = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features

        # Layers after the deepest chosen one never contribute to the output,
        # so they are cut off (for the defaults this is features[:29]).
        last = max(int(idx) for idx in self.chosen_features)
        if last >= len(backbone):
            raise ValueError(
                f'chosen layer {last} is out of range for a backbone with {len(backbone)} layers'
            )
        self.model = backbone[:last + 1]

        # The network is only a fixed feature extractor; the optimizer updates
        # the generated image alone. Freezing the weights stops every backward
        # pass from also computing and storing gradients for all VGG weights.
        self.model.requires_grad_(False)

    def forward(self, x):
        """Return the activations of the chosen layers, shallowest first.

        Args:
            x: batch fed to the first backbone layer; for VGG19 presumably an
                (N, 3, H, W) RGB tensor — TODO confirm for custom backbones.

        Returns:
            list with one tensor per chosen layer, in layer order.

        NOTE(review): torchvision's VGG uses ``nn.ReLU(inplace=True)``. The conv
        output appended below is then overwritten in place by the following
        ReLU, so layers 0/5/10/19 effectively yield post-ReLU activations while
        layer 28 (the last layer of the slice) stays pre-ReLU. Append
        ``x.clone()`` instead if raw conv outputs are really intended.
        """
        features = []

        for layer_num, layer in enumerate(self.model):
            x = layer(x)

            if str(layer_num) in self.chosen_features:
                features.append(x)

        return features


In [None]:
img_size = 256  # side length (pixels) of the square crop fed to the network

# Run on the GPU whenever one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Image -> tensor pipeline:
#   * CenterCrop keeps the central img_size x img_size patch (torchvision pads
#     with zeros when the image is smaller), so every loaded image is square
#     and the same size;
#   * ToTensor yields a float tensor scaled to [0, 1], (3, H, W) for RGB input.
# NOTE(review): pretrained VGG19 was trained on ImageNet-normalised input
# (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]). No Normalize is
# applied here, which keeps `generated` in [0, 1] pixel space for save_image;
# adding it would require un-normalising before saving.
loader = transforms.Compose([transforms.CenterCrop(img_size), transforms.ToTensor()])


def load_image(img_name):
    """Load an image file as a batched float tensor on ``device``.

    The image is converted to RGB (so grayscale/RGBA files work too), passed
    through the module-level ``loader`` transform and given a leading batch
    dimension, yielding a (1, 3, H, W) tensor.

    Args:
        img_name: path of the image file to open.
    """
    # The context manager closes the underlying file deterministically
    # (Pillow can otherwise keep it open, e.g. for multi-frame images).
    # `convert` returns an independent in-memory image, so it stays valid
    # after the file is closed.
    with Image.open(img_name) as img:
        rgb = img.convert('RGB')
    image = loader(rgb).unsqueeze(0)
    return image.to(device)

In [None]:
# Root of the local dataset.
# NOTE(review): machine-specific absolute path — change it here to run the
# notebook on another machine.
DATA_DIR = '/home/heitor/datasets/filter_network/train'

original_img = load_image(f'{DATA_DIR}/original_images/bn48k_44039_11-00_r.png')  # content image
style_img = load_image(f'{DATA_DIR}/output_cartoon/sd258_069_11-00_latent_bad.png')  # style image

# Optimisation starts from a copy of the content image; this tensor's pixels
# are what the optimizer updates.
generated = original_img.clone().requires_grad_(True)


In [None]:
# Optimisation hyperparameters.
total_steps = 6000  # number of Adam updates applied to the generated image
lr = 1e-3           # Adam learning rate; image tensors start in [0, 1]
alpha = 1           # weight of the content (original image) loss
beta = 0.01         # weight of the style (Gram matrix) loss

# Feature extractor in eval mode, on the same device as the images.
# (nn.Module.eval() and .to() both return the module itself.)
model = NST_VGG().eval().to(device)

# Only `generated` is handed to the optimizer, so the VGG weights are never updated.
optimizer = optim.Adam([generated], lr=lr)

In [None]:
# Print the loss and save a snapshot every SAVE_EVERY steps (and after the last step).
SAVE_EVERY = 200


def gram_matrix(feature):
    """Return the (C, C) Gram matrix of a single-image (1, C, H, W) feature map.

    The map is flattened with its *own* spatial size, so the style image's
    features need not share the generated image's height/width.
    """
    _, channel, height, width = feature.shape
    flat = feature.view(channel, height * width)
    return flat.mm(flat.t())


# The content and style targets never change, so run VGG on both images once
# instead of at every step. They are constants of the loss, so computing them
# under no_grad leaves the gradient w.r.t. `generated` unchanged.
with torch.no_grad():
    original_img_features = model(original_img)
    style_grams = [gram_matrix(feature) for feature in model(style_img)]

for step in range(total_steps):
    generated_features = model(generated)

    original_loss = 0
    style_loss = 0
    for gen_feature, orig_feature, style_gram in zip(
        generated_features, original_img_features, style_grams
    ):
        # Content loss: match the activations of the original image.
        original_loss += torch.mean((gen_feature - orig_feature) ** 2)
        # Style loss: match channel co-activation statistics (Gram matrices).
        style_loss += torch.mean((gram_matrix(gen_feature) - style_gram) ** 2)

    total_loss = alpha * original_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Also save after the final step; otherwise up to SAVE_EVERY - 1 of the
    # last updates would never reach disk.
    if step % SAVE_EVERY == 0 or step == total_steps - 1:
        print(total_loss.item())
        save_image(generated, f'generated_{step}.png')

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

def visualize_features(feature_list, input_tensor, layer_names=None, topk=8):
    """Plot the most active channels of each feature map as an image grid.

    For every tensor in ``feature_list`` only the first batch element is used.
    Its channels are ranked by mean absolute activation; the ``topk`` strongest
    are bilinearly upsampled to the spatial size of ``input_tensor``, min-max
    scaled jointly to [0, 1] and shown in one figure, 4 per row.

    Args:
        feature_list: sequence of (N, C, h, w) activation tensors, e.g. the
            output of ``NST_VGG.forward``.
        input_tensor: (N, C, H, W) image the features were computed from; only
            its H and W are used.
        layer_names: optional titles, one per entry of ``feature_list``;
            defaults to ``layer_0``, ``layer_1``, ...
        topk: maximum number of channels plotted per layer; layers for which
            this leaves nothing to plot are skipped.
    """
    H, W = input_tensor.shape[2:]
    if layer_names is None:
        layer_names = [f"layer_{i}" for i in range(len(feature_list))]

    for idx, feat in enumerate(feature_list):
        layer_name = layer_names[idx]
        # Detach so plotting also works outside torch.no_grad()
        # (.numpy() refuses tensors that require grad).
        fmap = feat[0].detach()

        # Rank channels by mean absolute activation
        scores = fmap.abs().mean(dim=(1, 2))
        k = min(topk, fmap.shape[0])
        if k <= 0:
            continue  # nothing to plot; min()/max() below would fail on an empty tensor
        top_idx = torch.topk(scores, k).indices

        # Upsample each selected channel (as a 1-channel image) to input size.
        fmap_up = F.interpolate(fmap[top_idx].unsqueeze(1), size=(H, W), mode='bilinear', align_corners=False)

        # Scale jointly over all shown channels, so their relative strength is preserved.
        fmap_up = (fmap_up - fmap_up.min()) / (fmap_up.max() - fmap_up.min() + 1e-8)

        # make_grid expands 1-channel images to 3 channels, so grid is (3, H', W').
        grid = make_grid(fmap_up, nrow=4, normalize=False, pad_value=1.0)
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.imshow(grid.permute(1, 2, 0).cpu().numpy(), interpolation='nearest')
        ax.axis('off')
        ax.set_title(f"Top {len(top_idx)} channels from {layer_name}")
        plt.show()
        plt.close(fig)  # free the figure; one is created per layer


# Activations of the content image at the chosen layers (inference only).
with torch.no_grad():
    features = model(original_img)

# Show the 8 strongest feature maps of every chosen layer, titled by layer index.
visualize_features(features, original_img, model.chosen_features, 8)