In [None]:
"""
=============================================================================
NEURAL STYLE TRANSFER - GOOGLE COLAB
Computer Vision Project - Part 1 (Question 2)
=============================================================================

INSTRUCTIONS:
1. Run the installation cell.
2. Upload the images (content.jpg and style.jpg).
3. Run the main cell.
4. Wait for the results.

Estimated time: 15-20 minutes (with a GPU).

NOTE: this is a Google Colab notebook cell. It uses the IPython `!pip`
shell escape and `google.colab.files`, so it does not run as a plain
Python script.
"""

# =============================================================================
# CELL 1: INSTALLATION AND SETUP
# =============================================================================
print("üîß Instalando bibliotecas necess√°rias...")
# IPython shell escape (notebook-only syntax): quietly installs scikit-image,
# which provides the structural_similarity (SSIM) metric imported below.
!pip install scikit-image -q

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import structural_similarity as ssim
import time
import os
from google.colab import files

# Select the compute device: CUDA when PyTorch reports a usable GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("\n‚úÖ Configura√ß√£o completa!")
print(f"üñ•Ô∏è  Dispositivo: {device}")
if device.type != 'cuda':
    print("‚ö†Ô∏è  GPU n√£o detectada. Recomendo ativar em: Runtime > Change runtime type > GPU")
else:
    print(f"üöÄ GPU detectada: {torch.cuda.get_device_name(0)}")

# =============================================================================
# CELL 2: IMAGE UPLOAD
# =============================================================================
print("\n" + "="*70)
print("üì§ UPLOAD DAS IMAGENS")
print("="*70)


def _upload_and_preview(dest_path, saved_message, preview_title):
    """Upload one file via the Colab widget, save it as *dest_path* and preview it.

    Only the first selected file is used. Nothing happens if the widget
    returns no files.
    """
    uploaded = files.upload()
    if not uploaded:
        return

    filename = next(iter(uploaded))
    if filename != dest_path:
        # os.replace overwrites an existing destination on every platform
        # (os.rename raises on Windows), so re-running the cell is safe.
        os.replace(filename, dest_path)
    print(saved_message)

    # Preview; the context manager closes the file handle PIL opens lazily.
    with Image.open(dest_path) as img:
        plt.figure(figsize=(5, 5))
        plt.imshow(img)
        plt.title(preview_title)
        plt.axis('off')
        plt.show()


def upload_images():
    """Upload the content and style images, saving them as content.jpg and style.jpg.

    Returns:
        True when both content.jpg and style.jpg exist afterwards, False otherwise.
    """
    print("\n1Ô∏è‚É£  Fa√ßa upload da imagem de CONTE√öDO (sua foto/retrato):")
    print("   Pode ser: selfie, foto de algu√©m, paisagem, etc.")
    _upload_and_preview('content.jpg',
                        "   ‚úÖ Conte√∫do salvo: content.jpg",
                        'Imagem de Conte√∫do')

    print("\n2Ô∏è‚É£  Fa√ßa upload da imagem de ESTILO (obra de arte):")
    print("   Exemplos: Noite Estrelada (Van Gogh), Guernica (Picasso), etc.")
    _upload_and_preview('style.jpg',
                        "   ‚úÖ Estilo salvo: style.jpg",
                        'Imagem de Estilo')

    # Check the disk rather than this run's uploads, so files kept from an
    # earlier run still count.
    missing = [path for path in ('content.jpg', 'style.jpg') if not os.path.exists(path)]
    if missing:
        print(f"\nUpload incompleto! Faltando: {', '.join(missing)}. "
              f"Execute esta celula novamente.\n")
        return False

    print("\nüéâ Upload completo! Pronto para processar.\n")
    return True


# Run the upload
upload_images()

# =============================================================================
# CELL 3: NEURAL STYLE TRANSFER CLASS
# =============================================================================
class NeuralStyleTransfer:
    """Optimization-based neural style transfer on top of a frozen VGG19.

    Content and style images are loaded as ImageNet-normalized tensors of shape
    (1, 3, img_size, img_size) on the module-level ``device``. The target image
    is optimized directly in that normalized space; ``tensor_to_image`` maps a
    tensor back to a displayable H x W x 3 array in [0, 1].
    """

    # ImageNet channel statistics; must match the Normalize step in load_image.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    # Child index inside torchvision's VGG19 ``features`` Sequential -> layer name.
    _VGG_LAYER_NAMES = {
        '0': 'conv1_1', '5': 'conv2_1', '10': 'conv3_1',
        '19': 'conv4_1', '21': 'conv4_2', '28': 'conv5_1',
    }

    def __init__(self, content_img_path, style_img_path, img_size=512):
        """Load both images and the VGG19 feature extractor.

        Args:
            content_img_path: Path of the content image file.
            style_img_path: Path of the style image file.
            img_size: Both images are resized to img_size x img_size.
        """
        self.img_size = img_size
        self.device = device  # module-level device selected in the setup cell

        self.content_img = self.load_image(content_img_path)
        self.style_img = self.load_image(style_img_path)
        self.model = self.load_vgg19()

        self.content_layers = ['conv4_2']
        self.style_layers = ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1']

    def load_image(self, img_path):
        """Load an image file as a normalized (1, 3, img_size, img_size) tensor.

        The image is converted to RGB, resized to a square (aspect ratio is not
        preserved), scaled to [0, 1] and normalized with the ImageNet statistics.
        """
        transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD),
        ])

        # The context manager closes the file handle PIL keeps open lazily;
        # convert() returns a fully loaded copy that outlives it.
        with Image.open(img_path) as img:
            rgb = img.convert('RGB')
        return transform(rgb).unsqueeze(0).to(self.device)

    def load_vgg19(self):
        """Return the frozen, eval-mode ``features`` part of an ImageNet-pretrained VGG19."""
        weights_enum = getattr(models, 'VGG19_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates `pretrained=True`; IMAGENET1K_V1 is
            # the same checkpoint that flag used to load.
            vgg = models.vgg19(weights=weights_enum.IMAGENET1K_V1)
        else:
            vgg = models.vgg19(pretrained=True)
        vgg = vgg.features.to(self.device).eval()
        for param in vgg.parameters():
            param.requires_grad_(False)
        return vgg

    def get_features(self, image, model, layers):
        """Run *image* through *model* and collect the activations named in *layers*.

        Returns:
            dict mapping each requested layer name (see ``_VGG_LAYER_NAMES``) to
            the output tensor of that child module.
        """
        features = {}
        x = image
        for name, layer in model.named_children():
            x = layer(x)
            layer_name = self._VGG_LAYER_NAMES.get(name)
            if layer_name is not None and layer_name in layers:
                features[layer_name] = x
        return features

    def gram_matrix(self, tensor):
        """Return the (channels, channels) Gram matrix normalized by C*H*W.

        Only batch size 1 is supported: the view below folds the batch away.
        """
        batch_size, channels, height, width = tensor.size()
        flat = tensor.view(channels, height * width)
        gram = torch.mm(flat, flat.t())
        return gram / (channels * height * width)

    def content_loss(self, target_features, content_features):
        """Sum over ``self.content_layers`` of the mean squared feature difference."""
        loss = 0
        for layer in self.content_layers:
            loss += torch.mean((target_features[layer] - content_features[layer]) ** 2)
        return loss

    def style_loss(self, target_features, style_grams):
        """Sum over ``self.style_layers`` of the mean squared Gram-matrix difference."""
        loss = 0
        for layer in self.style_layers:
            target_gram = self.gram_matrix(target_features[layer])
            loss += torch.mean((target_gram - style_grams[layer]) ** 2)
        return loss

    def total_variation_loss(self, img):
        """Squared differences of vertically and horizontally adjacent pixels, per element."""
        batch_size, channels, height, width = img.size()
        tv_h = torch.pow(img[:, :, 1:, :] - img[:, :, :-1, :], 2).sum()
        tv_w = torch.pow(img[:, :, :, 1:] - img[:, :, :, :-1], 2).sum()
        return (tv_h + tv_w) / (batch_size * channels * height * width)

    def _norm_stats(self, on_device):
        """Return the ImageNet mean and std as (3, 1, 1) tensors on *on_device*."""
        mean = torch.tensor(self.IMAGENET_MEAN, device=on_device).view(3, 1, 1)
        std = torch.tensor(self.IMAGENET_STD, device=on_device).view(3, 1, 1)
        return mean, std

    def _clamp_to_image_range(self, tensor):
        """Clamp a normalized image tensor in place so it denormalizes into [0, 1].

        The valid range in normalized space is [(0 - mean) / std, (1 - mean) / std]
        per channel, NOT [0, 1]; clamping to [0, 1] would erase every pixel
        darker than the ImageNet mean.
        """
        mean, std = self._norm_stats(tensor.device)
        low = (0.0 - mean) / std
        high = (1.0 - mean) / std
        with torch.no_grad():
            tensor.copy_(torch.max(torch.min(tensor, high), low))
        return tensor

    def denormalize(self, tensor):
        """Undo the ImageNet normalization and clamp to [0, 1].

        The statistics are created on ``tensor.device``, so CPU copies of CUDA
        tensors (as produced by ``tensor_to_image``) work on a GPU runtime.
        """
        mean, std = self._norm_stats(tensor.device)
        return torch.clamp(tensor * std + mean, 0, 1)

    def tensor_to_image(self, tensor):
        """Convert a normalized (1, 3, H, W) tensor into an H x W x 3 numpy array in [0, 1]."""
        image = tensor.detach().cpu().clone().squeeze(0)
        image = self.denormalize(image)
        return image.numpy().transpose(1, 2, 0)

    def calculate_ssim(self, img1, img2):
        """SSIM between two H x W (x 3) arrays in [0, 1]; RGB inputs are averaged to gray first."""
        if img1.shape[-1] == 3:
            img1_gray = np.mean(img1, axis=2)
            img2_gray = np.mean(img2, axis=2)
        else:
            img1_gray = img1
            img2_gray = img2
        return ssim(img1_gray, img2_gray, data_range=1.0)

    def run_style_transfer(self, num_steps=300, alpha=1, beta=1e6, gamma=1e-4,
                           init_type='content', lr=0.01, show_every=50):
        """Optimize a target image so it matches the content and style features.

        The objective is ``alpha * content + beta * style + gamma * TV``, minimized
        with Adam for ``num_steps`` iterations.

        Args:
            num_steps: Number of Adam iterations.
            alpha, beta, gamma: Weights of the content, style and total-variation losses.
            init_type: 'content' or 'style' starts from a copy of that image; any
                other value starts from Gaussian noise.
            lr: Adam learning rate.
            show_every: Every ``show_every`` iterations the losses, the iteration
                number and a snapshot image are recorded and progress is printed.
                Values below 1 are treated as 1.

        Returns:
            (target, history): the optimized normalized image tensor and a dict of
            lists 'total_loss', 'content_loss', 'style_loss', 'tv_loss', 'images'
            and 'steps' (the iteration number of each recorded entry).
        """
        show_every = max(1, int(show_every))  # avoid modulo-by-zero

        print(f"\n{'='*70}")
        print(f"üé® INICIANDO TRANSFER√äNCIA DE ESTILO")
        print(f"{'='*70}")
        print(f"‚öôÔ∏è  Configura√ß√µes:")
        print(f"   ‚Ä¢ Itera√ß√µes: {num_steps}")
        print(f"   ‚Ä¢ Alpha (conte√∫do): {alpha}")
        print(f"   ‚Ä¢ Beta (estilo): {beta}")
        print(f"   ‚Ä¢ Gamma (suaviza√ß√£o): {gamma}")
        print(f"   ‚Ä¢ Inicializa√ß√£o: {init_type}")
        print(f"   ‚Ä¢ Learning rate: {lr}")
        print(f"{'='*70}\n")

        all_layers = self.content_layers + self.style_layers

        # Fixed reference features; no gradients are needed through them.
        with torch.no_grad():
            content_features = self.get_features(self.content_img, self.model,
                                                 self.content_layers)
            style_features = self.get_features(self.style_img, self.model,
                                               self.style_layers)
            style_grams = {layer: self.gram_matrix(style_features[layer])
                           for layer in self.style_layers}

        if init_type == 'content':
            target = self.content_img.clone()
        elif init_type == 'style':
            target = self.style_img.clone()
        else:
            target = torch.randn_like(self.content_img)
        target.requires_grad_(True)

        optimizer = optim.Adam([target], lr=lr)

        history = {
            'total_loss': [],
            'content_loss': [],
            'style_loss': [],
            'tv_loss': [],
            'images': [],
            'steps': [],
        }

        start_time = time.time()

        for step in range(num_steps):
            record = step % show_every == 0

            def closure():
                optimizer.zero_grad()
                # Keep the target a valid image (in normalized space) before
                # evaluating the loss.
                self._clamp_to_image_range(target)

                target_features = self.get_features(target, self.model, all_layers)

                c_loss = self.content_loss(target_features, content_features)
                s_loss = self.style_loss(target_features, style_grams)
                tv_loss = self.total_variation_loss(target)

                loss = alpha * c_loss + beta * s_loss + gamma * tv_loss
                loss.backward()

                # Adam calls the closure exactly once per step, so these lists
                # stay aligned with history['steps'].
                if record:
                    history['total_loss'].append(loss.item())
                    history['content_loss'].append(c_loss.item())
                    history['style_loss'].append(s_loss.item())
                    history['tv_loss'].append(tv_loss.item())

                return loss

            optimizer.step(closure)

            if record:
                elapsed = time.time() - start_time
                progress = (step / num_steps) * 100
                print(f"üìä Itera√ß√£o {step}/{num_steps} ({progress:.1f}%) - "
                      f"Tempo: {elapsed:.1f}s - "
                      f"Loss: {history['total_loss'][-1]:.2f}")

                history['steps'].append(step)
                history['images'].append(self.tensor_to_image(target))

        self._clamp_to_image_range(target)

        total_time = time.time() - start_time
        print(f"\n{'='*70}")
        print(f"‚úÖ TRANSFER√äNCIA CONCLU√çDA!")
        print(f"‚è±Ô∏è  Tempo total: {total_time:.1f}s ({total_time/60:.1f} minutos)")
        print(f"{'='*70}\n")

        return target, history

    @staticmethod
    def _history_steps(history, count):
        """Iteration numbers for *history* entries; falls back to multiples of 50 for old histories."""
        steps = history.get('steps')
        if steps is None:
            steps = [idx * 50 for idx in range(count)]
        return list(steps)

    def visualize_results(self, target, history, save_path='results.png'):
        """Plot content, style and result images plus the loss curves; save to *save_path*."""
        plt.figure(figsize=(18, 12))

        plt.subplot(2, 3, 1)
        plt.imshow(self.tensor_to_image(self.content_img))
        plt.title('üì∑ Imagem de Conte√∫do', fontsize=14, fontweight='bold')
        plt.axis('off')

        plt.subplot(2, 3, 2)
        plt.imshow(self.tensor_to_image(self.style_img))
        plt.title('üé® Imagem de Estilo', fontsize=14, fontweight='bold')
        plt.axis('off')

        plt.subplot(2, 3, 3)
        plt.imshow(self.tensor_to_image(target))
        plt.title('‚ú® Resultado Final', fontsize=14, fontweight='bold')
        plt.axis('off')

        # Use the recorded iteration numbers instead of assuming show_every == 50.
        steps = np.asarray(self._history_steps(history, len(history['total_loss'])))

        plt.subplot(2, 3, 4)
        plt.plot(steps, history['content_loss'], 'b-', linewidth=2)
        plt.xlabel('Itera√ß√£o')
        plt.ylabel('Perda')
        plt.title('üìâ Perda de Conte√∫do')
        plt.grid(True, alpha=0.3)

        plt.subplot(2, 3, 5)
        plt.plot(steps, history['style_loss'], 'r-', linewidth=2)
        plt.xlabel('Itera√ß√£o')
        plt.ylabel('Perda')
        plt.title('üìâ Perda de Estilo')
        plt.grid(True, alpha=0.3)

        plt.subplot(2, 3, 6)
        plt.plot(steps, history['total_loss'], 'g-', linewidth=2)
        plt.xlabel('Itera√ß√£o')
        plt.ylabel('Perda')
        plt.title('üìâ Perda Total')
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"üíæ Resultados salvos: {save_path}")
        plt.show()

    def visualize_progression(self, history, save_path='progression.png'):
        """Show the recorded snapshots side by side, titled with their iteration; save to *save_path*."""
        num_images = len(history['images'])

        if num_images == 0:
            return

        # squeeze=False always yields a 2-D array of axes, even for one image.
        _, axes = plt.subplots(1, num_images, figsize=(4*num_images, 4), squeeze=False)
        step_labels = self._history_steps(history, num_images)

        for ax, img, step in zip(axes[0], history['images'], step_labels):
            ax.imshow(img)
            ax.set_title(f'Itera√ß√£o {step}', fontweight='bold')
            ax.axis('off')

        plt.tight_layout()
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"üíæ Progress√£o salva: {save_path}")
        plt.show()

    def calculate_metrics(self, target):
        """Compare *target* with the content image.

        Returns:
            dict with 'SSIM' (grayscale SSIM of the two images in [0, 1] pixel
            space) and 'Perceptual Loss' (the content loss at ``self.content_layers``).
        """
        content_np = self.tensor_to_image(self.content_img)
        target_np = self.tensor_to_image(target)

        ssim_value = self.calculate_ssim(content_np, target_np)

        # Evaluation only: skip building an autograd graph through VGG.
        with torch.no_grad():
            target_features = self.get_features(target, self.model, self.content_layers)
            content_features = self.get_features(self.content_img, self.model,
                                                 self.content_layers)
            perceptual_loss = self.content_loss(target_features, content_features).item()

        return {
            'SSIM': ssim_value,
            'Perceptual Loss': perceptual_loss
        }

# =============================================================================
# CELL 4: RUN THE EXPERIMENTS
# =============================================================================
_RULE = "=" * 70


def _print_banner(title):
    """Print *title* between two 70-character '=' rules, preceded by a blank line."""
    print("\n" + _RULE)
    print(title)
    print(_RULE)


def _print_metrics(metrics):
    """Print each metric as an indented bullet with four decimal places."""
    for metric_name, metric_value in metrics.items():
        print(f"   ‚Ä¢ {metric_name}: {metric_value:.4f}")


print("\n" + _RULE)
print("üöÄ NEURAL STYLE TRANSFER - EXECU√á√ÉO")
print(_RULE + "\n")

# Build the style-transfer object from the two uploaded images.
nst = NeuralStyleTransfer('content.jpg', 'style.jpg', img_size=512)

# Experiment 1: balanced configuration.
_print_banner("üß™ EXPERIMENTO 1: Configura√ß√£o Balanceada")

experiment1_params = {
    'num_steps': 100,
    'alpha': 1,
    'beta': 1e6,
    'gamma': 1e-4,
    'init_type': 'content',
    'lr': 0.03,
    'show_every': 20,  # log progress and keep a snapshot every 20 iterations
}
target1, history1 = nst.run_style_transfer(**experiment1_params)

results_png = 'resultado_experimento1.png'
progression_png = 'progressao_experimento1.png'
nst.visualize_results(target1, history1, results_png)
nst.visualize_progression(history1, progression_png)

metrics1 = nst.calculate_metrics(target1)
print("\nüìä M√©tricas Experimento 1:")
_print_metrics(metrics1)

# Experiments 2 and 3 were removed to keep the run short; only experiment 1 runs.

# Final summary.
_print_banner("üìä M√âTRICAS FINAIS")
print("\nüß™ Experimento 1 (Configura√ß√£o Balanceada):")
print("   Configura√ß√£o: Alpha=1, Beta=1e6")
_print_metrics(metrics1)

_print_banner("‚úÖ PROJETO CONCLU√çDO COM SUCESSO!")

# Download the generated figures through the Colab browser download.
print("\nüì• Baixando resultados...")
for output_path in (results_png, progression_png):
    files.download(output_path)

print("\nüéâ Todos os arquivos foram baixados!")
print("‚úÖ Projeto Neural Style Transfer conclu√≠do!")

🔧 Instalando bibliotecas necessárias...

✅ Configuração completa!
🖥️  Dispositivo: cpu
⚠️  GPU não detectada. Recomendo ativar em: Runtime > Change runtime type > GPU

📤 UPLOAD DAS IMAGENS

1️⃣  Faça upload da imagem de CONTEÚDO (sua foto/retrato):
   Pode ser: selfie, foto de alguém, paisagem, etc.


KeyboardInterrupt: 