In [None]:
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, models
from PIL import Image
import numpy as np
from ultralytics import YOLO
import matplotlib.pyplot as plt
import cv2

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Side length (pixels) of the square both input images are resized to; the
# same value is passed to YOLO as its inference size.
IMG_SIZE = 640

# Input / output locations. NOTE(review): these look like paths on a
# Colab-mounted Google Drive -- adjust when running elsewhere.
CLOTH_IMAGE_PATH = "/content/drive/MyDrive/00070_00.jpg"  # content (garment) image
STYLE_IMAGE_PATH = "/content/drive/MyDrive/art6.jpg"  # style reference image
VGG_WEIGHTS_PATH = "/content/drive/MyDrive/vgg19-dcbb9e9d.pth"  # VGG19 state dict
RESULT_SAVE_PATH = "/content/drive/MyDrive/styled_cloth_result3.jpg"  # blended output

def load_and_resize(path, size=256):
    """Open an image file, convert it to RGB and resize it to a square.

    Args:
        path: Location of the image file to open.
        size: Side length in pixels of the square output (aspect ratio is
            not preserved).

    Returns:
        The resized RGB ``PIL.Image``, resampled with the LANCZOS filter.
    """
    target_dims = (size, size)
    rgb_image = Image.open(path).convert('RGB')
    return rgb_image.resize(target_dims, Image.LANCZOS)

def pil_to_tensor(img):
    """Convert an image to a batched tensor on ``DEVICE``.

    The image goes through torchvision's ``ToTensor`` transform, gets a
    leading batch dimension of size 1 and is moved to the module-level
    ``DEVICE``.
    """
    to_tensor = transforms.ToTensor()
    batched = to_tensor(img).unsqueeze(0)
    return batched.to(DEVICE)

def tensor_to_pil(tensor):
    """Convert an image tensor back into a PIL image.

    The tensor is detached from autograd, copied to the CPU and has a
    leading size-1 batch dimension removed. Values are then scaled by 255,
    clamped to [0, 255] and cast to uint8 (the cast truncates, it does not
    round) before being handed to torchvision's ``ToPILImage``.
    """
    image_data = tensor.detach().clone().cpu().squeeze(0)
    as_bytes = (image_data * 255).clamp(0, 255).byte()
    to_pil = transforms.ToPILImage()
    return to_pil(as_bytes)

def get_features(x, model, layers):
    """Run ``x`` through ``model`` layer by layer and collect chosen outputs.

    Args:
        x: Input passed to the first layer.
        model: Iterable of callables applied in sequence.
        layers: Container of layer positions given as strings (e.g. ``"5"``);
            the output of every layer whose stringified index is in it is kept.

    Returns:
        List of the selected outputs, in the order the layers are applied.

    Note:
        Every layer is always executed, even after the last requested one,
        and outputs are stored without copying. If a subsequent layer works
        in place on its input, the stored object reflects that modification.
    """
    collected = []
    activation = x
    for position, module in enumerate(model):
        activation = module(activation)
        if str(position) not in layers:
            continue
        collected.append(activation)
    return collected

def gram_matrix(t):
    """Return the normalised Gram matrix of a 4-D feature tensor.

    Args:
        t: Tensor of shape ``(b, c, h, w)``.

    Returns:
        A ``(b * c, b * c)`` tensor: the feature map is flattened to
        ``(b * c, h * w)``, multiplied by its own transpose and divided by
        the total number of elements in ``t``.
    """
    b, c, h, w = t.size()
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # transposed or sliced feature maps; for contiguous tensors it is a view.
    f = t.reshape(b * c, h * w)
    return (f @ f.t()) / f.numel()

def resize_mask(mask, target_tensor):
    """Resize ``mask`` to the spatial size of ``target_tensor``.

    The target size is every dimension of ``target_tensor`` after the first
    two; resampling uses nearest-neighbour interpolation.
    """
    spatial_size = target_tensor.shape[2:]
    return F.interpolate(mask, size=spatial_size, mode='nearest')

def total_variation_loss(x):
    """Return the anisotropic total-variation penalty of a 4-D tensor.

    The result is the mean absolute difference between neighbours along the
    last axis plus the mean absolute difference between neighbours along the
    third axis.
    """
    horizontal_step = x[:, :, :, 1:] - x[:, :, :, :-1]
    vertical_step = x[:, :, 1:, :] - x[:, :, :-1, :]
    return horizontal_step.abs().mean() + vertical_step.abs().mean()

def blur_mask(mask_tensor, sigma=2.5):
    """Soften a single-image, single-channel mask with a Gaussian blur.

    Args:
        mask_tensor: Mask tensor. Only tensors of shape ``(1, 1, H, W)`` are
            blurred; anything else is returned untouched.
        sigma: Standard deviation handed to ``cv2.GaussianBlur`` (the kernel
            size is passed as ``(0, 0)``, leaving it to OpenCV to derive).

    Returns:
        For a ``(1, 1, H, W)`` input, a new float tensor on the input's
        device, clamped to [0, 1], with the two leading dimensions restored.
        Otherwise the input object itself.
    """
    # Guard clause: anything that is not a lone single-channel mask passes
    # straight through unmodified.
    if mask_tensor.ndim != 4 or tuple(mask_tensor.shape[:2]) != (1, 1):
        return mask_tensor

    # OpenCV works on NumPy arrays, so round-trip through the CPU.
    mask_array = mask_tensor.squeeze().cpu().numpy()
    softened = cv2.GaussianBlur(mask_array, (0, 0), sigma)
    softened_tensor = torch.from_numpy(softened).float().clamp(0, 1)
    softened_tensor = softened_tensor.to(mask_tensor.device)
    return softened_tensor.unsqueeze(0).unsqueeze(0)

# ---------------------------------------------------------------------------
# 1. Load the content (cloth) and style images at a common square size.
# ---------------------------------------------------------------------------
cloth_img = load_and_resize(CLOTH_IMAGE_PATH, IMG_SIZE)
style_img = load_and_resize(STYLE_IMAGE_PATH, IMG_SIZE)
content_tensor = pil_to_tensor(cloth_img)
style_tensor = pil_to_tensor(style_img)

# ---------------------------------------------------------------------------
# 2. Segment the cloth image to obtain the region that will be stylised.
# ---------------------------------------------------------------------------
yolo = YOLO('yolov8l-seg.pt')
results = yolo(cloth_img, imgsz=IMG_SIZE, conf=0.3, iou=0.5)

mask = None
if results[0].masks is not None and len(results[0].masks.data) > 0:
    # Only the first instance mask returned by the model is used.
    mask_np = results[0].masks.data.cpu().numpy()[0]
    mask = torch.from_numpy(mask_np).float().unsqueeze(0).unsqueeze(0).to(DEVICE)
    mask_peak = mask.max()
    if mask_peak > 0:
        # Normalise to a peak of 1, then feather the edge so the final blend
        # has no hard seam.
        mask = mask / mask_peak
        mask = blur_mask(mask, sigma=7.0)
    else:
        # An all-zero mask would turn into NaNs when divided by its maximum
        # and poison every loss term; treat it like "nothing detected".
        mask = None

if mask is None:
    print("Warning: No object detected by YOLOv8. Applying style transfer to the entire image.")
    mask = torch.ones_like(content_tensor[:, :1, :, :])

# ---------------------------------------------------------------------------
# 3. Build the frozen VGG19 feature extractor.
# ---------------------------------------------------------------------------
vgg = models.vgg19(weights=None)
# NOTE(security): torch.load unpickles the checkpoint; only point
# VGG_WEIGHTS_PATH at a trusted file (or pass weights_only=True on PyTorch
# versions that support it).
vgg.load_state_dict(torch.load(VGG_WEIGHTS_PATH, map_location=DEVICE))
# Only the image is optimised, so freeze the network: otherwise every
# backward pass would also compute and store gradients for all VGG weights.
vgg_features = vgg.features.to(DEVICE).eval().requires_grad_(False)
content_layers = ["21"]
style_layers = ["5", "10", "19", "28"]
# Every layer needed from the generated image, sorted in network order so a
# single forward pass can serve both the content and the style terms.
gen_layers = sorted(set(content_layers + style_layers), key=int)

# NOTE(review): images are fed to VGG in the range produced by pil_to_tensor,
# without ImageNet mean/std normalisation; the loss weights below were
# presumably tuned for that -- confirm before changing either.
with torch.no_grad():
    content_feat = get_features(content_tensor, vgg_features, content_layers)[0]
    style_feats = get_features(style_tensor, vgg_features, style_layers)
    style_grams = [gram_matrix(f) for f in style_feats]

# ---------------------------------------------------------------------------
# 4. Optimise the pixels of a copy of the content image.
# ---------------------------------------------------------------------------
input_img = content_tensor.clone().requires_grad_(True)
optimizer = optim.Adam([input_img], lr=0.06)
num_steps = 600
# Weights of the content, style and total-variation terms respectively.
alpha, beta, gamma = 1, 3e6, 2e-6

for step in range(num_steps):
    optimizer.zero_grad()
    # One forward pass yields all required activations (get_features returns
    # them in network order, matching gen_layers).
    gen_feats = dict(zip(gen_layers, get_features(input_img, vgg_features, gen_layers)))
    gen_style_feats = [gen_feats[name] for name in style_layers]
    gen_content_feat = gen_feats[content_layers[0]]
    # Content term: compare features only inside the masked region.
    c_mask = resize_mask(mask, gen_content_feat)
    content_loss = alpha * F.mse_loss(gen_content_feat * c_mask, content_feat * c_mask)
    # Style term: Gram matrices of the masked generated features against the
    # Gram matrices of the (unmasked) style image.
    style_loss = 0
    for f, g in zip(gen_style_feats, style_grams):
        s_mask = resize_mask(mask, f)
        style_loss += F.mse_loss(gram_matrix(f * s_mask), g)
    style_loss = beta * style_loss
    tv_loss = gamma * total_variation_loss(input_img)
    total_loss = content_loss + style_loss + tv_loss
    total_loss.backward()
    optimizer.step()
    # Keep the optimised image inside the valid [0, 1] pixel range.
    with torch.no_grad():
        input_img.clamp_(0, 1)
    if step % 20 == 0 or step == num_steps-1:
        print(f'Step {step:3d}: Content {content_loss.item():.2f} | Style {style_loss.item():.2f} | TV {tv_loss.item():.5f}')

# ---------------------------------------------------------------------------
# 5. Blend the stylised pixels into the original image through the mask.
# ---------------------------------------------------------------------------
with torch.no_grad():
    stylized = input_img.detach()
    mask_resz = resize_mask(mask, stylized).expand(-1, 3, -1, -1)
    blended = stylized * mask_resz + content_tensor * (1 - mask_resz)

result_pil = tensor_to_pil(blended)
result_pil.save(RESULT_SAVE_PATH)

# ---------------------------------------------------------------------------
# 6. Show content, style and result side by side.
# ---------------------------------------------------------------------------
plt.figure(figsize=(15, 5))
images = [(cloth_img, "Content Cloth"), (style_img, "Style"), (result_pil, "Stylized Cloth (Enhanced)")]
for idx, (img, title) in enumerate(images):
    plt.subplot(1,3,idx+1)
    plt.imshow(img)
    plt.title(title)
    plt.axis('off')
plt.tight_layout()
plt.show()
print(f"✅ Enhanced stylized cloth saved to {RESULT_SAVE_PATH}")
