<a href="https://colab.research.google.com/github/cyteena/U-net/blob/main/AdaIN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

class Net(nn.Module):
    """AdaIN training model: a frozen, 4-stage encoder plus a trainable decoder.

    The encoder's children are cut into four consecutive stages. The split
    indices assume the child layout of the `vgg` Sequential in this notebook,
    where the stage outputs are relu1_1, relu2_1, relu3_1 and relu4_1.
    forward() returns the (content, style) losses used to train the decoder.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        children = list(encoder.children())
        # Stage boundaries (child indices): [0, 4), [4, 11), [11, 18), [18, 31).
        self.enc_1 = nn.Sequential(*children[0:4])
        self.enc_2 = nn.Sequential(*children[4:11])
        self.enc_3 = nn.Sequential(*children[11:18])
        self.enc_4 = nn.Sequential(*children[18:31])
        self.decoder = decoder
        self.mse_loss = nn.MSELoss()

        # Freeze the encoder: only the decoder's parameters receive gradients.
        for stage in (self.enc_1, self.enc_2, self.enc_3, self.enc_4):
            stage.requires_grad_(False)

    def encode_with_intermediate(self, input):
        """Return the output of every encoder stage, shallowest first (4 tensors)."""
        feats = []
        out = input
        for stage in (self.enc_1, self.enc_2, self.enc_3, self.enc_4):
            out = stage(out)
            feats.append(out)
        return feats

    def encode(self, input):
        """Return only the last encoder stage's output for `input`."""
        return self.enc_4(self.enc_3(self.enc_2(self.enc_1(input))))

    def calc_content_loss(self, input, target):
        """MSE between `input` and a same-sized, gradient-free `target`."""
        assert input.size() == target.size()
        assert target.requires_grad is False
        return self.mse_loss(input, target)

    def calc_style_loss(self, input, target):
        """Sum of MSEs between the channel means and the channel stds of the two maps."""
        assert input.size() == target.size()
        assert target.requires_grad is False
        in_mu, in_sigma = calc_mean_std(input)
        tgt_mu, tgt_sigma = calc_mean_std(target)
        return self.mse_loss(in_mu, tgt_mu) + self.mse_loss(in_sigma, tgt_sigma)

    def forward(self, content, style, alpha=1.0):
        """Compute (content_loss, style_loss) for one content/style batch.

        alpha in [0, 1] blends the AdaIN target with the raw content features
        (1.0 = fully stylized).
        """
        assert 0 <= alpha <= 1
        style_feats = self.encode_with_intermediate(style)
        content_feat = self.encode(content)

        stylized = adaptive_instance_normalization(content_feat, style_feats[-1])
        t = alpha * stylized + (1 - alpha) * content_feat

        g_t = self.decoder(t)
        g_t_feats = self.encode_with_intermediate(g_t)

        # Content loss on the deepest stage only; style loss summed over all stages.
        loss_c = self.calc_content_loss(g_t_feats[-1], t)
        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        for gen_feat, sty_feat in zip(g_t_feats[1:], style_feats[1:]):
            loss_s += self.calc_style_loss(gen_feat, sty_feat)
        return loss_c, loss_s


# Decoder: maps a 512-channel feature map back to a 3-channel image.
# Every 3x3 Conv2d is preceded by a 1-pixel ReflectionPad2d, so the convs keep
# H and W unchanged; each Upsample(scale_factor=2) doubles H and W. The last
# conv has no activation, so the output is not clamped to any range.
# Child order determines the state_dict keys ("0.weight", "1.weight", ...), so
# do not reorder these layers if checkpoints must stay loadable.
# NOTE(review): this stack contains four x2 Upsample layers (x16 in total),
# while the `vgg` encoder below reaches relu4_1 after three max-pools (/8).
# Paired with that encoder, the decoded image is twice the input size and the
# size assert in Net.calc_content_loss fails — confirm which encoder this
# decoder is meant for before training.
decoder = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),  # x2 (1 of 4)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),  # x2 (2 of 4)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),  # x2 (3 of 4)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),  # x2 (4 of 4)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),  # final projection to 3 channels, no activation
)

# VGG-19-style convolutional stack (through relu5_4) used as the AdaIN encoder.
# Every 3x3 conv is preceded by an explicit 1-pixel ReflectionPad2d, and the
# stack starts with a 1x1 conv (3 -> 3 channels) applied to the raw input —
# presumably an input-preprocessing layer from a converted checkpoint; TODO confirm.
# Net slices this module's children by index:
#   [0:4] -> relu1_1, [4:11] -> relu2_1, [11:18] -> relu3_1, [18:31] -> relu4_1,
# so the layer order below must not change. Children after index 30 are not
# used by Net.
# NOTE(review): nothing in this cell loads weights into it, so as written the
# layers are randomly initialized.
vgg = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1 (child 3: last layer of Net.enc_1)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1 (child 10: last layer of Net.enc_2)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1 (child 17: last layer of Net.enc_3)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1, this is the last layer used (child 30: end of Net.enc_4)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)


In [None]:
import torch

def calc_mean_std(feat, eps = 1e-5):
    """Return per-sample, per-channel mean and standard deviation of `feat`.

    Args:
        feat: 4-D tensor. Dim 0 is treated as the batch (N), dim 1 as the
            channel (C); the statistics are taken over the remaining two dims.
        eps: small value added to the variance before the square root so the
            returned std is never zero (callers divide by it).

    Returns:
        (mean, std), each of shape (N, C, 1, 1) so they broadcast against feat.
        The variance is torch's default (unbiased) estimate; with a single
        spatial element it is NaN.
    """
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    # reshape (not view) so non-contiguous inputs, e.g. permuted or sliced
    # feature maps, are accepted instead of raising a RuntimeError.
    flat = feat.reshape(N, C, -1)
    feat_var = flat.var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std

def adaptive_instance_normalization(content_feat, style_feat):
    """Give `content_feat` the per-channel statistics of `style_feat` (AdaIN).

    Each (sample, channel) slice of content_feat is standardized with its own
    mean/std from calc_mean_std, then scaled by the style std and shifted by
    the style mean. Only the first two dimensions must agree; the spatial
    sizes of the two inputs may differ. The result has content_feat's shape.
    """
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    s_mean, s_std = calc_mean_std(style_feat)
    c_mean, c_std = calc_mean_std(content_feat)
    # The (N, C, 1, 1) statistics broadcast over the spatial dimensions.
    standardized = (content_feat - c_mean) / c_std
    return standardized * s_std + s_mean

In [None]:
!pip install datasets
!pip install tensorboardX

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m29.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [None]:
# Build the AdaIN model from the encoder/decoder defined above (encoder frozen by Net).
network = Net(encoder=vgg, decoder=decoder)

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.utils.data as data
import torch.optim as optim
from PIL import Image
from pathlib import Path
import torchvision.models as models
from tqdm import tqdm
from tensorboardX import SummaryWriter
import torchvision.datasets as datasets
from datasets import load_dataset
import io

# PIL raises UnidentifiedImageError for undecodable images; my_collate catches it.
from PIL import UnidentifiedImageError

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Squash every image to 512x512 (aspect ratio is not preserved), keep the
# central 256x256 crop and convert it to a float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
])

# Load the pretrained VGG19 convolutional feature stack and keep its first 31
# children (Net only uses children [0:31]).
# NOTE(review): Net slices encoder children at [:4], [4:11], [11:18], [18:31],
# which fits the custom `vgg` layout (leading 1x1 conv plus explicit
# ReflectionPad2d layers), not torchvision's vgg19().features. There,
# relu1_1/relu2_1/relu3_1/relu4_1 are children 1/6/11/20, so the four stages
# built from this model end at relu1_2, conv3_1 (pre-ReLU), relu3_4 and
# conv5_2 (pre-ReLU, after four max-pools). torchvision's weights also expect
# ImageNet mean/std-normalized input, which `transform` does not apply.
# Re-slicing the encoder changes the bottleneck resolution, so it must be done
# together with the decoder's upsampling — confirm before changing either.
vgg = models.vgg19(pretrained=True).features
vgg = nn.Sequential(*list(vgg.children())[:31])

network = Net(vgg, decoder)
network.train()
network.to(device)

content_tf = transform
style_tf = transform

# Content dataset: COCO, streamed from the Hugging Face Hub.
content_dataset = load_dataset("detection-datasets/coco", split='train', streaming=True)

# Style dataset: WikiArt, streamed from the Hugging Face Hub.
style_dataset = load_dataset("Artificio/WikiArt", split='train', streaming=True)


def my_collate(batch):
    """Decode, transform and stack dataset rows into one image batch.

    Rows whose image cannot be read or decoded are skipped. Returns a float
    tensor of shape (B, 3, 256, 256) with B <= len(batch). When no row could
    be decoded it returns an empty tensor, so callers can test
    ``.nelement() == 0`` and ``.to(device)`` still works.
    """
    images = []
    for item in batch:
        try:
            img = item['image'].convert('RGB')
            img = transform(img)
        except (IOError, OSError, UnidentifiedImageError):
            continue
        images.append(img)

    if not images:
        return torch.empty(0)
    return torch.stack(images)


# Batches of 4 rows; decoding and transforms happen in my_collate.
content_loader = data.DataLoader(content_dataset, batch_size=4, collate_fn=my_collate)
style_loader = data.DataLoader(style_dataset, batch_size=4, collate_fn=my_collate)

# Only the decoder is optimized; Net freezes the encoder stages.
optimizer = optim.Adam(network.decoder.parameters(), lr=3e-4)

num_epochs = 2
num_iterations = 1000  # fixed number of iterations per epoch (the datasets are streamed)
save_every = num_epochs  # save the decoder every `save_every` epochs, i.e. after the last one here

log_dir = Path('/content/logs')
log_dir.mkdir(exist_ok=True, parents=True)
writer = SummaryWriter(log_dir=str(log_dir))

model_dir = Path('/content/models')
model_dir.mkdir(exist_ok=True, parents=True)

try:
    for epoch in range(num_epochs):
        # Each epoch restarts both streams from their beginning.
        content_iter = iter(content_loader)
        style_iter = iter(style_loader)
        for i in range(num_iterations):
            try:
                content_images = next(content_iter).to(device)
                style_images = next(style_iter).to(device)
            except StopIteration:
                print(f"Data stream exhausted after {i} iterations in epoch {epoch}")
                break

            # Skipped rows can leave the two batches with different sizes, but
            # AdaIN requires matching (N, C); trim both to the common size.
            batch_size = min(content_images.size(0), style_images.size(0))
            if batch_size == 0:
                print("Skipping iteration due to empty batch")
                continue
            content_images = content_images[:batch_size]
            style_images = style_images[:batch_size]

            network.zero_grad()
            loss_c, loss_s = network(content_images, style_images)
            loss = loss_c + loss_s
            loss.backward()
            optimizer.step()

            step = epoch * num_iterations + i
            writer.add_scalar('loss_content', loss_c.item(), step)
            writer.add_scalar('loss_style', loss_s.item(), step)

        # Checkpoint the decoder once at the end of a qualifying epoch (CPU
        # tensors so the file loads on machines without a GPU).
        if (epoch + 1) % save_every == 0:
            state_dict = network.decoder.state_dict()
            for key in state_dict:
                state_dict[key] = state_dict[key].to(torch.device('cpu'))
            torch.save(state_dict, model_dir / "model.pth")
finally:
    writer.close()


Resolving data files:   0%|          | 0/40 [00:00<?, ?it/s]

In [None]:
# 使用cartoon_face作为content dataset
# Content dataset: cartoon faces, streamed from the Hugging Face Hub.
# NOTE(review): if the training cell already ran, content_loader/style_loader
# still wrap the datasets they were built with; rebinding these names does not
# change them. Rebuild the DataLoaders if these datasets should be used.
content_dataset = load_dataset(path="huggan/cartoon-faces", split='train', streaming=True)

# Style dataset: WikiArt, streamed from the Hugging Face Hub.
style_dataset = load_dataset(path="Artificio/WikiArt", split='train', streaming=True)

dataset_infos.json:   0%|          | 0.00/648 [00:00<?, ?B/s]

In [None]:
# Fresh iterators over both loaders (restarts the streams).
content_iter, style_iter = iter(content_loader), iter(style_loader)

In [None]:
# Pull one batch from each stream and move it to the training device.
content_images = next(content_iter).to(device=device)
style_images = next(style_iter).to(device=device)

In [None]:
style_images.size()  # (batch, channels, height, width) of the style batch

torch.Size([4, 3, 256, 256])

In [None]:
content_images.size()  # (batch, channels, height, width) of the content batch

torch.Size([4, 3, 256, 256])

    def forward(self, content, style, alpha=1.0):
        """Compute (content_loss, style_loss) for one content/style batch.

        alpha in [0, 1] blends the AdaIN target with the raw content features
        (1.0 = fully stylized).
        """
        assert 0 <= alpha <= 1
        style_feats = self.encode_with_intermediate(style)
        content_feat = self.encode(content)

        stylized = adaptive_instance_normalization(content_feat, style_feats[-1])
        t = alpha * stylized + (1 - alpha) * content_feat

        g_t = self.decoder(t)
        g_t_feats = self.encode_with_intermediate(g_t)

        # Content loss on the deepest stage only; style loss summed over all stages.
        loss_c = self.calc_content_loss(g_t_feats[-1], t)
        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        for gen_feat, sty_feat in zip(g_t_feats[1:], style_feats[1:]):
            loss_s += self.calc_style_loss(gen_feat, sty_feat)
        return loss_c, loss_s

In [None]:
network.encode_with_intermediate(input=style_images)[-1].shape  # deepest stage of the style batch

torch.Size([4, 512, 32, 32])

In [None]:
len(network.encode_with_intermediate(input=style_images))  # one feature map per encoder stage

4

In [None]:
network.encode(input=content_images).size()  # last-stage features of the content batch

torch.Size([4, 512, 16, 16])

In [None]:
# Encoder features used by the step-by-step walkthrough below.
style_feats = network.encode_with_intermediate(input=style_images)
content_feat = network.encode(input=content_images)

In [None]:
adaptive_instance_normalization(content_feat=content_feat, style_feat=style_feats[-1]).size()  # same shape as content_feat

torch.Size([4, 512, 16, 16])

In [None]:
# AdaIN target features (alpha = 1, i.e. fully stylized).
t = adaptive_instance_normalization(content_feat=content_feat, style_feat=style_feats[-1])
t.shape

torch.Size([4, 512, 16, 16])

In [None]:
# Decode the target features back to image space.
g_t = network.decoder(t)
g_t.shape

torch.Size([4, 3, 256, 256])

In [None]:
# Re-encode the decoded image to compare against t and the style features.
g_t_feats = network.encode_with_intermediate(input=g_t)
g_t_feats[-1].shape

torch.Size([4, 512, 32, 32])

In [None]:
# NOTE(review): calc_content_loss asserts equal sizes; the shapes printed above
# for t (16x16) and g_t_feats[-1] (32x32) differ, so this raises AssertionError
# until the encoder/decoder resolutions agree.
loss_c = network.calc_content_loss(input=g_t_feats[-1], target=t)