In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import matplotlib.pyplot as plt

# Phát hiện lỗi gradient
torch.autograd.set_detect_anomaly(True)

# Định nghĩa hàm để load và xử lý ảnh
def load_image(image_path, max_size=400):
    image = Image.open(image_path).convert("RGB")
    size = max(max_size, min(image.size))
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    image = transform(image).unsqueeze(0)
    return image.to(device)

# Hiển thị hình ảnh
def im_convert(tensor):
    image = tensor.cpu().clone().detach().squeeze(0)
    image = transforms.ToPILImage()(image)
    return image

# Gram matrix
def gram_matrix(tensor):
    b, c, h, w = tensor.size()
    features = tensor.view(b * c, h * w)
    gram = torch.mm(features, features.t())
    return gram / (c * h * w)

# Lớp tính toán loss
class StyleContentLoss(nn.Module):
    def __init__(self, vgg, style_layers, content_layers):
        super(StyleContentLoss, self).__init__()
        self.vgg = vgg
        self.style_layers = style_layers
        self.content_layers = content_layers

    def forward(self, x, style_features, content_features):
        style_loss = torch.tensor(0.0, device=x.device)
        content_loss = torch.tensor(0.0, device=x.device)
        for name, layer in self.vgg._modules.items():
            x = layer(x.clone())  # Clone tránh in-place modification
            if name in self.style_layers:
                gram_x = gram_matrix(x)
                style_loss += torch.mean((gram_x - style_features[name]) ** 2)
            if name in self.content_layers:
                content_loss += torch.mean((x - content_features[name]) ** 2)
        return style_loss, content_loss

# Cài đặt thiết bị
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load mô hình VGG đã được huấn luyện trước
vgg = models.vgg19(pretrained=True).features.to(device).eval()

# Xác định các lớp sử dụng cho NST
style_layers = {
    '0': 'conv1_1',
    '5': 'conv2_1',
    '10': 'conv3_1',
    '19': 'conv4_1',
    '28': 'conv5_1',
}
content_layers = {'21': 'conv4_2'}

# Load ảnh style và content
content_image = load_image('./images/tiger3.jpg')
style_image = load_image('./images/flowers.jpg')
target_image = content_image.clone().requires_grad_(True).to(device)

# Trích xuất đặc trưng ảnh style và content
def get_features(image, model, layers):
    features = {}
    x = image
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
    return features

style_features = get_features(style_image, vgg, style_layers)
content_features = get_features(content_image, vgg, content_layers)

# Cấu hình tối ưu hóa và trọng số
optimizer = optim.Adam([target_image], lr=0.003)
style_weight = 1e6
content_weight = 1

# Huấn luyện NST
for step in range(1000):
    optimizer.zero_grad()
    style_loss, content_loss = StyleContentLoss(vgg, style_layers, content_layers)(
        target_image.clone(), style_features, content_features
    )
    total_loss = style_weight * style_loss + content_weight * content_loss
    total_loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f'Step [{step}/1000], Style Loss: {style_loss.item():.4f}, Content Loss: {content_loss.item():.4f}, Total Loss: {total_loss.item():.4f}')

# Hiển thị kết quả
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.title("Content Image")
plt.imshow(im_convert(content_image))

plt.subplot(1, 3, 2)
plt.title("Style Image")
plt.imshow(im_convert(style_image))

plt.subplot(1, 3, 3)
plt.title("Result Image")
plt.imshow(im_convert(target_image))

plt.show()


KeyError: '0'