<a href="https://colab.research.google.com/github/fay23-dam/cloud_computing_tugas/blob/main/testing_style_transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from datasets import load_dataset
import torch_xla.core.xla_model as xm
import matplotlib.pyplot as plt

# Load the two Hugging Face datasets used as content and style sources;
# only their "train" split is kept.
realistic_ds, vangogh_ds = (
    load_dataset(repo)["train"]
    for repo in ("jlbaker361/flickr_humans_5k", "jlbaker361/flickr_humans_5k_vangogh")
)

# Preprocessing pipeline: resize every image to 256x256, convert it to a
# tensor, then normalize each channel with the ImageNet mean/std constants.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

# Dataset Class
class ImageDataset(Dataset):
    """Map-style PyTorch dataset over the ``"image"`` entry of each item.

    Args:
        dataset: Indexable collection (e.g. a Hugging Face ``datasets`` split)
            whose items are mappings containing an ``"image"`` key.
        transform: Optional callable applied to each image after RGB
            coercion, e.g. a torchvision transform pipeline.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the image at ``idx``, coerced to RGB and then transformed."""
        image = self.dataset[idx]["image"]  # image taken directly from the 'image' column
        # Grayscale ("L") or RGBA images would become 1- or 4-channel tensors,
        # which break the 3-channel Normalize and batching. Coerce PIL-like
        # images (anything with a non-RGB `mode` and a `convert` method) to RGB.
        if hasattr(image, "convert") and getattr(image, "mode", "RGB") != "RGB":
            image = image.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

# Wrap each split in an ImageDataset and serve it one shuffled image at a time.
realistic_loader, vangogh_loader = (
    DataLoader(ImageDataset(split, transform=transform), batch_size=1, shuffle=True)
    for split in (realistic_ds, vangogh_ds)
)

# Display function
def im_convert(tensor):
    """Convert a normalized image tensor into a displayable (H, W, 3) NumPy array.

    Undoes the per-channel mean/std normalization used by the preprocessing
    pipeline and clamps the result to [0, 1] so it can go straight to
    ``plt.imshow``.

    Args:
        tensor: Image tensor shaped (3, H, W) or with leading batch
            dimensions, e.g. (B, 3, H, W). When leading dimensions are
            present, only the first image is converted.

    Returns:
        numpy.ndarray of shape (H, W, 3) with values in [0, 1].
    """
    # Arithmetic below is out-of-place, so no clone is needed to protect `tensor`.
    image = tensor.detach().cpu()
    # Take the first image along any leading dimensions. Unlike `squeeze()`,
    # this never drops a genuine size-1 height/width and also works for B > 1.
    while image.dim() > 3:
        image = image[0]
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    image = image * std + mean
    return image.clamp(0, 1).permute(1, 2, 0).numpy()

# Load the pretrained VGG-19 convolutional stack as a frozen feature extractor.
# `weights=` replaces the deprecated `pretrained=True`; DEFAULT selects the
# ImageNet weights that `pretrained=True` used to load.
vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
vgg.requires_grad_(False)  # only the target image is optimized, never the network
vgg = vgg.eval().to(xm.xla_device())

# Feature extraction function
def get_features(image, model, layers=None):
    """Run ``image`` through ``model`` and collect selected intermediate outputs.

    Args:
        image: Input passed to the first child module of ``model``.
        model: Container (e.g. ``nn.Sequential``) whose registered child
            modules are applied in order.
        layers: Mapping from child-module name (the index as a string for
            ``nn.Sequential``) to the key under which that child's output is
            stored. Defaults to the VGG-19 ``features`` indices for
            conv1_1 ... conv5_1 plus conv4_2.

    Returns:
        dict mapping each requested key to the corresponding activation.
        Modules after the last requested layer are not executed.
    """
    if layers is None:
        layers = {
            '0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',  # Content layer
            '28': 'conv5_1'
        }
    features = {}
    remaining = len(layers)  # requested layers not yet reached
    x = image
    # Iterate `_modules` (as nn.Sequential.forward does) rather than
    # `named_children()`, which would skip a module instance registered twice.
    for name, layer in model._modules.items():
        if remaining == 0:
            break  # everything requested is captured; deeper layers are wasted work
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
            remaining -= 1
    return features

# Gram Matrix
def gram_matrix(tensor):
    """Return the channel-by-channel Gram matrix of one feature map.

    Args:
        tensor: 4-D activation tensor (N, C, H, W). The ``view`` below only
            succeeds when N == 1 and the layout is viewable as (C, H*W).

    Returns:
        (C, C) tensor whose (i, j) entry is the dot product of the
        flattened channels i and j.
    """
    _, channels, height, width = tensor.size()
    flat = tensor.view(channels, height * width)  # one row per channel
    return flat @ flat.t()

# Load one content image and one style image onto the XLA device, and
# initialize the optimized image from the content image.
device = xm.xla_device()
content_image = next(iter(realistic_loader)).to(device)
style_image = next(iter(vangogh_loader)).to(device)
# No trailing `.to()`: content_image is already on the device, and a `.to()`
# that actually copied would return a non-leaf tensor, which Adam rejects.
target = content_image.clone().requires_grad_(True)

# Extract reference features once; they stay fixed during optimization.
content_features = get_features(content_image, vgg)
style_features = get_features(style_image, vgg)
# conv4_2 serves as the content layer in the loss below. Only the other
# extracted layers (conv1_1 ... conv5_1) form the style representation.
style_grams = {
    layer: gram_matrix(features)
    for layer, features in style_features.items()
    if layer != 'conv4_2'
}

# Loss weights
content_weight = 1e5
style_weight = 1e10

# Optimize the pixels of `target` directly; the network stays frozen.
optimizer = optim.Adam([target], lr=0.003)

# Training loop
epochs = 300
for epoch in range(epochs):
    target_features = get_features(target, vgg)

    # Content loss: mean squared error between conv4_2 activations.
    content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)

    # Style loss: per layer, the mean squared Gram-matrix difference,
    # additionally divided by (number of channels) ** 2.
    style_loss = 0
    for layer, style_gram in style_grams.items():
        target_gram = gram_matrix(target_features[layer])
        layer_style_loss = torch.mean((target_gram - style_gram) ** 2)
        style_loss += layer_style_loss / (target_features[layer].shape[1] ** 2)

    # Total loss
    total_loss = content_weight * content_loss + style_weight * style_loss

    # Backpropagation
    optimizer.zero_grad()
    total_loss.backward()
    # barrier=True makes optimizer_step issue xm.mark_step(), cutting the
    # lazily traced XLA graph every iteration. This manual loop has no
    # ParallelLoader to do it, so without the barrier ops from successive
    # iterations pile up into one large graph.
    xm.optimizer_step(optimizer, barrier=True)

    if epoch % 50 == 0 or epoch == epochs - 1:
        print(f"Epoch {epoch}, Total Loss: {total_loss.item()}")

# Show the content image, the style image and the optimized result side by side.
plt.figure(figsize=(15, 5))
for position, (panel_title, panel_tensor) in enumerate(
    (
        ("Content Image", content_image),
        ("Style Image", style_image),
        ("Stylized Output", target),
    ),
    start=1,
):
    plt.subplot(1, 3, position)
    plt.title(panel_title)
    plt.imshow(im_convert(panel_tensor))
plt.show()


ModuleNotFoundError: No module named 'datasets'