## Data Preparation

In [1]:
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.transforms.functional as F
from tqdm import tqdm
from PIL import Image
import shutil
import random
import os
import glob

def style_sampling(base_path, dest_path, n_samples=100, split=(0.8, 0.1, 0.1)):
    """
    L·∫•y m·∫´u ·∫£nh t·ª´ m·ªói style v√† chia ƒë·ªÅu train/valid/test.
    ƒê·∫∑t t√™n file d·∫°ng style_001.jpg, style_002.jpg...
    
    Args:
        base_path (str): th∆∞ m·ª•c ch·ª©a c√°c folder style
        dest_path (str): th∆∞ m·ª•c l∆∞u k·∫øt qu·∫£
        n_samples (int): s·ªë ·∫£nh l·∫•y m·∫´u cho m·ªói style
        split (tuple): t·ªâ l·ªá chia train/valid/test
    """
    random.seed(2025)
    
    subsets = ["train", "valid", "test"]
    for subset in subsets:
        os.makedirs(os.path.join(dest_path, subset), exist_ok=True)
    
    styles = [d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d))]
    total_copied = 0

    for style in tqdm(styles, desc="Sampling styles"):
        style_path = os.path.join(base_path, style)
        # L·ªçc file ·∫£nh
        image_files = [f for f in os.listdir(style_path) 
                       if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff'))]
        if not image_files:
            continue

        # L·∫•y m·∫´u ng·∫´u nhi√™n
        sample_imgs = random.sample(image_files, min(len(image_files), n_samples))
        n = len(sample_imgs)

        # S·ªë l∆∞·ª£ng cho train/valid/test (l√†m tr√≤n xu·ªëng)
        n_train = int(n * split[0])
        n_valid = int(n * split[1])
        n_test = n - n_train - n_valid

        # Chia ·∫£nh
        split_dict = {
            "train": sample_imgs[:n_train],
            "valid": sample_imgs[n_train:n_train+n_valid],
            "test": sample_imgs[n_train+n_valid:]
        }

        # Copy ·∫£nh sang folder t∆∞∆°ng ·ª©ng v·ªõi t√™n style_index.jpg
        for subset, imgs in split_dict.items():
            dest_folder = os.path.join(dest_path, subset)
            for idx, img_name in enumerate(imgs, start=1):
                src = os.path.join(style_path, img_name)
                dst_name = f"{style}_{idx:03d}.jpg"  # style_001.jpg, style_002.jpg...
                dst = os.path.join(dest_folder, dst_name)
                shutil.copy(src, dst)
                total_copied += 1

    print(f"ƒê√£ sao ch√©p {total_copied} ·∫£nh t·ª´ {len(styles)} style v√†o {dest_path}/train, valid, test")


# ------------------- Transform -------------------
class TransformImageNet:
    def __init__(self, target_long=512, min_short=256, crop_size=None, gray_ratio=0.0):
        """
        target_long: c·∫°nh l·ªõn c·ªßa ·∫£nh sau resize
        min_short: n·∫øu c·∫°nh nh·ªè < min_short, s·∫Ω padding
        crop_size: n·∫øu mu·ªën crop, None = kh√¥ng crop
        gray_ratio: x√°c su·∫•t chuy·ªÉn ·∫£nh sang grayscale
        """
        self.target_long = target_long
        self.min_short = min_short
        self.crop_size = crop_size
        self.gray_ratio = gray_ratio
        self.to_tensor = T.ToTensor()
        self.normalize = T.Normalize(mean=[0.485,0.456,0.406],
                                     std=[0.229,0.224,0.225])

    def resize_and_pad(self, img):
        w, h = img.size
        # scale t·ªâ l·ªá theo c·∫°nh l·ªõn
        if w > h:
            new_w = self.target_long
            new_h = int(h * self.target_long / w)
        else:
            new_h = self.target_long
            new_w = int(w * self.target_long / h)
        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)

        # padding n·∫øu c·∫°nh nh·ªè < min_short
        pad_w = max(0, self.min_short - new_w)
        pad_h = max(0, self.min_short - new_h)
        if pad_w > 0 or pad_h > 0:
            img = F.pad(img, (0,0,pad_w,pad_h), fill=0)
        return img

    def __call__(self, img):
        # Grayscale augmentation
        if random.random() < self.gray_ratio:
            img = img.convert("L").convert("RGB")

        img = self.resize_and_pad(img)

        # RandomCrop n·∫øu mu·ªën
        if self.crop_size:
            img = T.RandomCrop(self.crop_size)(img)

        img = self.to_tensor(img)
        img = self.normalize(img)
        return img

# ------------------- Dataset -------------------
class CustomImageDataset(Dataset):
    def __init__(self, content_folder, style_folder, subset,
                 transform=None, gray_ratio=0.2,
                 valid_ext=('.jpg', '.jpeg', '.png', '.bmp', '.tiff')):
        self.content_folder = os.path.join(content_folder, subset)
        self.style_folder = os.path.join(style_folder, subset)

        self.content_files = []
        self.style_files = []

        for ext in valid_ext:
            self.content_files.extend(glob.glob(os.path.join(self.content_folder, f"*{ext}")))
            self.style_files.extend(glob.glob(os.path.join(self.style_folder, f"*{ext}")))

        self.content_files = sorted(self.content_files)
        self.style_files = sorted(self.style_files)

        if len(self.content_files) == 0:
            raise RuntimeError(f"No content images found in {self.content_folder}")
        if len(self.style_files) == 0:
            raise RuntimeError(f"No style images found in {self.style_folder}")

        self.transform = transform
        self.gray_ratio = gray_ratio

    def __len__(self):
        return len(self.content_files)

    def __getitem__(self, idx):
        # Content image
        content_path = self.content_files[idx]
        content_img = Image.open(content_path).convert("RGB")
        
        # Style image (random)
        style_path = random.choice(self.style_files)
        style_img = Image.open(style_path).convert("RGB")
        
        if self.transform:
            content_img = self.transform(content_img)
            style_img = self.transform(style_img)

        return content_img, style_img

# ------------------- DataLoader factory -------------------
def get_dataloaders(content_folder, style_folder,
                    batch_size=8, num_workers=4, 
                    persistent_workers=False, gray_ratio=0.2,
                    target_long=512, min_short=256, crop_size=256):
    """
    T·∫°o DataLoader cho train/valid/test.
    Gi·∫£ s·ª≠ content_folder v√† style_folder ƒë√£ c√≥ subfolder 'train', 'valid', 'test'.
    C√≥ th·ªÉ b·∫≠t tqdm ƒë·ªÉ quan s√°t ti·∫øn tr√¨nh load d·ªØ li·ªáu.
    """
    transform = TransformImageNet(
        target_long=target_long,
        min_short=min_short,
        crop_size=crop_size,
        gray_ratio=gray_ratio
    )

    loaders = {}
    for subset in ["train", "valid", "test"]:
        dataset = CustomImageDataset(
            content_folder,
            style_folder,
            subset=subset,
            transform=transform,
            gray_ratio=gray_ratio,
        )
        shuffle = (subset == "train")

        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=persistent_workers
        )

        loaders[subset] = loader

    return loaders

## AdaIn

In [2]:
import torch
import torch.nn as nn
import torchvision.models as models

class VGGEncoder(nn.Module):
    def __init__(self, path_vgg_weights=None, device='cpu'):
        super(VGGEncoder, self).__init__()
        self.device = device  # l∆∞u device ƒë·ªÉ d√πng khi forward
        if path_vgg_weights is None:
            vgg19 = models.vgg19(pretrained=True)
        else:
            vgg19 = models.vgg19()
            vgg19.load_state_dict(torch.load(path_vgg_weights, map_location=device))

        # Ch·ªâ l·∫•y feature t·ªõi conv4_3
        self.encoder_layers = nn.Sequential(*list(vgg19.features.children())[:21])

        # Freeze weights
        for param in self.encoder_layers.parameters():
            param.requires_grad = False

        # Chuy·ªÉn encoder sang device
        self.encoder_layers.to(device)  

    def forward(self, x):
        x = x.to(self.device)
        x = self.encoder_layers(x)
        return x

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3)
        )

    def forward(self, x):
        return x + self.block(x)

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, out_channels, 3),
            nn.ReLU(inplace=True),
            ResidualBlock(out_channels)
        )

    def forward(self, x):
        return self.block(x)
    
class Decoder(nn.Module):
    def __init__(self, out_channels=3):
        super().__init__()
        self.decoder_layers = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(512, 256, 3),
            nn.ReLU(inplace=True),

            ConvBlock(256, 256),
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, 256, 3),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, 128, 3),
            nn.ReLU(inplace=True),

            ConvBlock(128, 128),
            nn.ReflectionPad2d(1),
            nn.Conv2d(128, 64, 3),
            nn.ReLU(inplace=True),

            ConvBlock(64, out_channels)
        )

    def forward(self, x):
        return self.decoder_layers(x)

    
class AdaINet(nn.Module):
    def __init__(self, path_vgg_weights=None, out_channels=3, device='cpu'):
        super(AdaINet, self).__init__()
        self.encoder = VGGEncoder(path_vgg_weights=path_vgg_weights, device=device)
        self.decoder = Decoder(out_channels=out_channels)

    def forward(self, content, style, alpha=1.0):
        # Encode
        content_feat = self.encoder(content)
        style_feat = self.encoder(style)
        # AdaIN
        t = self.adain(content_feat, style_feat)
        t = alpha * t + (1 - alpha) * content_feat
        # Decode
        generated = self.decoder(t)
        return generated, t

    def adain(self, content_feat, style_feat, eps=1e-5):
        c_mean = torch.mean(content_feat, dim=[2, 3], keepdim=True)
        c_std = torch.std(content_feat, dim=[2, 3], keepdim=True) + eps
        s_mean = torch.mean(style_feat, dim=[2, 3], keepdim=True)
        s_std = torch.std(style_feat, dim=[2, 3], keepdim=True) + eps

        normalized_feat = (content_feat - c_mean) / c_std
        stylized_feat = normalized_feat * s_std + s_mean
        return stylized_feat

    
class VGGEncoderMultiLayer(nn.Module):
    def __init__(self, path_vgg_weights=None, device='cpu'):
        super(VGGEncoderMultiLayer, self).__init__()
        self.device = device
        if path_vgg_weights is None:
            vgg19 = models.vgg19(pretrained=True)
        else:
            vgg19 = models.vgg19()
            vgg19.load_state_dict(torch.load(path_vgg_weights, map_location=device))
        
        self.slice1 = nn.Sequential(*list(vgg19.features.children())[:2])   # relu1_1
        self.slice2 = nn.Sequential(*list(vgg19.features.children())[2:7])  # relu2_1
        self.slice3 = nn.Sequential(*list(vgg19.features.children())[7:14]) # relu3_1
        self.slice4 = nn.Sequential(*list(vgg19.features.children())[14:21]) # relu4_1
        # Freeze weights
        for param in self.parameters():
            param.requires_grad = False

        self.to(device)
        
    def forward(self, x):
        x = x.to(self.device)
        relu1_1 = self.slice1(x)
        relu2_1 = self.slice2(relu1_1)
        relu3_1 = self.slice3(relu2_1)
        relu4_1 = self.slice4(relu3_1)
        return {
            'relu1_1': relu1_1,
            'relu2_1': relu2_1,
            'relu3_1': relu3_1,
            'relu4_1': relu4_1
        }
class AdaINLossMultiLayer(nn.Module):
    def __init__(self, encoder, alpha=1.0, beta=0.5, eps=1e-5):
        super().__init__()
        self.encoder = encoder
        self.alpha = alpha
        self.beta = beta
        self.eps = eps
        self.style_layers = ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1']

    def calc_mean_std(self, feat):
        B, C = feat.size(0), feat.size(1)
        feat_reshaped = feat.view(B, C, -1)
        mean = feat_reshaped.mean(dim=2).view(B, C, 1, 1)
        std = feat_reshaped.std(dim=2).view(B, C, 1, 1) + self.eps
        return mean, std

    def content_loss(self, gen_feat, t):
        return nn.functional.mse_loss(gen_feat, t)

    def style_loss(self, gen_feat, style_feat):
        loss = 0.0
        for layer in self.style_layers:
            g_mean, g_std = self.calc_mean_std(gen_feat[layer])
            s_mean, s_std = self.calc_mean_std(style_feat[layer])
            loss += nn.functional.mse_loss(g_mean, s_mean) + nn.functional.mse_loss(g_std, s_std)
        return loss

    def forward(self, generated, t, style):
        gen_feat = self.encoder(generated)
        style_feat = self.encoder(style)
        # resize t cho kh·ªõp v·ªõi gen_feat
        t_resized = nn.functional.interpolate(t, size=gen_feat['relu4_1'].shape[2:], mode='nearest')
        c_loss = self.content_loss(gen_feat['relu4_1'], t_resized)
        s_loss = self.style_loss(gen_feat, style_feat)

        total = self.alpha * c_loss + self.beta * s_loss
        return total, c_loss, s_loss

## Training

In [5]:
import os
import torch
from tqdm import tqdm
from torchvision.utils import save_image

def train_model(
    train_loader, test_loader,
    model, criterion, optimizer,
    device, num_epochs,
    save_dir="./checkpoints"
):
    os.makedirs(save_dir, exist_ok=True)

    # ====================== LOAD CHECKPOINT ======================
    checkpoint_path = os.path.join(save_dir, "latest_checkpoint.pth")
    start_epoch = 0
    best_val_loss = float("inf")

    if os.path.exists(checkpoint_path):
        print(f"üîÑ Loading checkpoint from {checkpoint_path}")
        ckpt = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        start_epoch = ckpt["epoch"] + 1
        best_val_loss = ckpt["best_val_loss"]
        print(f"‚û° Continue from epoch {start_epoch}, best_val_loss={best_val_loss:.4f}")

    # ====================== TRAIN LOOP ======================
    for epoch in range(start_epoch + 1, num_epochs + 1):
        model.train()
        running_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch} Training")

        for content, style in pbar:
            content, style = content.to(device), style.to(device)

            optimizer.zero_grad()
            generated, t = model(content, style)
            loss, c_loss, s_loss = criterion(generated, t, style)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.set_postfix({
                "Loss": f"{loss.item():.4f}",
                "Content": f"{c_loss.item():.4f}",
                "Style": f"{s_loss.item():.4f}",
            })

        avg_train_loss = running_loss / len(train_loader)

        # ====================== VALIDATION ======================
        model.eval()
        val_loss = 0.0
        sample_saved = False

        with torch.no_grad():
            for batch_idx, (content, style) in enumerate(test_loader):
                content, style = content.to(device), style.to(device)
                generated, target_feat = model(content, style)
                loss, _, _ = criterion(generated, target_feat, style)
                val_loss += loss.item()

                if not sample_saved:
                    img_path = os.path.join(save_dir, f"epoch_{epoch}_sample.png")
                    save_image(generated.clamp(0,1), img_path)
                    sample_saved = True

        avg_val_loss = val_loss / len(test_loader)
        print(f"‚úÖ Epoch {epoch} | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f}")

        # ====================== SAVE BEST MODEL ======================
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_path = os.path.join(save_dir, "best_model.pth")
            torch.save(model.state_dict(), best_path)
            print(f"üèÜ Best model updated! Saved: {best_path}")

        # ====================== SAVE CHECKPOINT (Resume) ======================
        torch.save({
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "best_val_loss": best_val_loss,
        }, checkpoint_path)

    print("üéØ Training Completed!")
    return best_val_loss

In [6]:
import torchvision.transforms.functional as F
loader = get_dataloaders(content_folder="../data/coco2017",
                         style_folder="../data/wikiart_sampled", num_workers=64, batch_size=16, persistent_workers=False)
path_vgg_weights = "../models/vgg19.pth"
device = ("cuda:7" if torch.cuda.is_available() else "cpu")
criterion = AdaINLossMultiLayer(encoder=VGGEncoderMultiLayer(path_vgg_weights=path_vgg_weights, device=device), alpha=1.0, beta=10)
model = AdaINet(out_channels=3, path_vgg_weights=path_vgg_weights, device=device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4*0.5)
best_val_loss = train_model(
    train_loader=loader["train"],
    test_loader=loader["valid"],
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=50,
    save_dir="checkpoints_adaln"
)


Epoch 1 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [44:36<00:00,  2.76it/s, Loss=27.0191, Content=12.3609, Style=1.4658]  


‚úÖ Epoch 1 | Train: 42.3589 | Val: 27.9994
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 2 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [43:55<00:00,  2.81it/s, Loss=21.6986, Content=9.8964, Style=1.1802] 


‚úÖ Epoch 2 | Train: 26.9479 | Val: 23.7101
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 3 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [44:30<00:00,  2.77it/s, Loss=20.3479, Content=10.5261, Style=0.9822]


‚úÖ Epoch 3 | Train: 23.7721 | Val: 22.0815
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 4 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [43:41<00:00,  2.82it/s, Loss=24.7104, Content=12.1015, Style=1.2609]


‚úÖ Epoch 4 | Train: 22.0859 | Val: 20.5901
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 5 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [44:30<00:00,  2.77it/s, Loss=23.9760, Content=11.9395, Style=1.2036]


‚úÖ Epoch 5 | Train: 20.9490 | Val: 19.6499
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 6 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [41:57<00:00,  2.94it/s, Loss=16.4754, Content=8.1747, Style=0.8301] 


‚úÖ Epoch 6 | Train: 20.1510 | Val: 19.2875
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 7 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:07<00:00,  2.93it/s, Loss=25.0993, Content=12.3067, Style=1.2793]


‚úÖ Epoch 7 | Train: 19.5458 | Val: 19.0018
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 8 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:35<00:00,  2.89it/s, Loss=19.4890, Content=9.5534, Style=0.9936] 


‚úÖ Epoch 8 | Train: 18.9838 | Val: 18.2381
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 9 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [41:43<00:00,  2.95it/s, Loss=17.0811, Content=8.9079, Style=0.8173] 


‚úÖ Epoch 9 | Train: 18.6621 | Val: 18.2244
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 10 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [41:47<00:00,  2.95it/s, Loss=20.6248, Content=10.2450, Style=1.0380]


‚úÖ Epoch 10 | Train: 18.3973 | Val: 17.8170
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 11 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [41:52<00:00,  2.94it/s, Loss=17.3373, Content=8.8371, Style=0.8500] 


‚úÖ Epoch 11 | Train: 18.0825 | Val: 17.6929
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 12 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:14<00:00,  2.92it/s, Loss=22.4544, Content=11.4177, Style=1.1037]


‚úÖ Epoch 12 | Train: 17.8827 | Val: 17.3491
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 13 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [41:20<00:00,  2.98it/s, Loss=14.9834, Content=7.8763, Style=0.7107] 


‚úÖ Epoch 13 | Train: 17.6824 | Val: 17.3186
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 14 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [41:32<00:00,  2.97it/s, Loss=14.9529, Content=7.8193, Style=0.7134] 


‚úÖ Epoch 14 | Train: 17.5745 | Val: 17.7008


Epoch 15 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:14<00:00,  2.92it/s, Loss=16.9495, Content=8.7697, Style=0.8180] 


‚úÖ Epoch 15 | Train: 17.3898 | Val: 17.1051
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 16 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [43:58<00:00,  2.80it/s, Loss=23.9362, Content=11.7325, Style=1.2204]


‚úÖ Epoch 16 | Train: 17.2330 | Val: 16.8057
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 17 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [43:26<00:00,  2.84it/s, Loss=19.5505, Content=10.1435, Style=0.9407]


‚úÖ Epoch 17 | Train: 17.0967 | Val: 16.9358


Epoch 18 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [43:21<00:00,  2.84it/s, Loss=12.9779, Content=6.7530, Style=0.6225] 


‚úÖ Epoch 18 | Train: 17.0519 | Val: 16.6971
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 19 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [44:07<00:00,  2.79it/s, Loss=14.4173, Content=6.5595, Style=0.7858] 


‚úÖ Epoch 19 | Train: 16.8844 | Val: 16.8819


Epoch 20 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:45<00:00,  2.88it/s, Loss=16.9671, Content=8.4809, Style=0.8486] 


‚úÖ Epoch 20 | Train: 16.8295 | Val: 16.7640


Epoch 21 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:03<00:00,  2.93it/s, Loss=14.8463, Content=7.9576, Style=0.6889] 


‚úÖ Epoch 21 | Train: 16.6649 | Val: 16.2724
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 22 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:53<00:00,  2.87it/s, Loss=14.6827, Content=8.0460, Style=0.6637] 


‚úÖ Epoch 22 | Train: 16.5924 | Val: 16.2699
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 23 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:51<00:00,  2.87it/s, Loss=10.5540, Content=5.4513, Style=0.5103] 


‚úÖ Epoch 23 | Train: 16.6026 | Val: 16.3367


Epoch 24 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [41:21<00:00,  2.98it/s, Loss=17.0318, Content=9.0178, Style=0.8014]  


‚úÖ Epoch 24 | Train: 16.4983 | Val: 16.2744


Epoch 25 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:31<00:00,  2.90it/s, Loss=16.7971, Content=8.7661, Style=0.8031] 


‚úÖ Epoch 25 | Train: 16.3984 | Val: 16.1331
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 26 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:29<00:00,  2.90it/s, Loss=22.0356, Content=9.5426, Style=1.2493] 


‚úÖ Epoch 26 | Train: 16.3901 | Val: 15.9972
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 27 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:52<00:00,  2.87it/s, Loss=14.8152, Content=7.2599, Style=0.7555] 


‚úÖ Epoch 27 | Train: 16.3241 | Val: 16.0570


Epoch 28 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [43:32<00:00,  2.83it/s, Loss=15.1826, Content=8.0524, Style=0.7130] 


‚úÖ Epoch 28 | Train: 16.2939 | Val: 16.0063


Epoch 29 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:47<00:00,  2.88it/s, Loss=16.3475, Content=8.6288, Style=0.7719] 


‚úÖ Epoch 29 | Train: 16.1992 | Val: 15.6524
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 30 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:08<00:00,  2.92it/s, Loss=16.3666, Content=8.2768, Style=0.8090] 


‚úÖ Epoch 30 | Train: 16.1050 | Val: 15.9236


Epoch 31 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:31<00:00,  2.90it/s, Loss=16.6719, Content=8.1018, Style=0.8570] 


‚úÖ Epoch 31 | Train: 16.0638 | Val: 15.6007
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 32 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:06<00:00,  2.93it/s, Loss=13.1379, Content=6.2882, Style=0.6850] 


‚úÖ Epoch 32 | Train: 16.0770 | Val: 15.7848


Epoch 33 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:28<00:00,  2.90it/s, Loss=13.8813, Content=7.4583, Style=0.6423] 


‚úÖ Epoch 33 | Train: 15.9968 | Val: 15.6326


Epoch 34 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:19<00:00,  2.91it/s, Loss=14.7660, Content=7.6256, Style=0.7140] 


‚úÖ Epoch 34 | Train: 15.9689 | Val: 15.9310


Epoch 35 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [43:52<00:00,  2.81it/s, Loss=13.4401, Content=6.9956, Style=0.6444] 


‚úÖ Epoch 35 | Train: 15.9501 | Val: 15.5572
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 36 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [44:34<00:00,  2.76it/s, Loss=12.8020, Content=6.4856, Style=0.6316] 


‚úÖ Epoch 36 | Train: 15.8927 | Val: 16.0707


Epoch 37 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [44:34<00:00,  2.76it/s, Loss=11.8183, Content=6.1942, Style=0.5624] 


‚úÖ Epoch 37 | Train: 15.8000 | Val: 15.8556


Epoch 38 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [44:01<00:00,  2.80it/s, Loss=14.2297, Content=7.1689, Style=0.7061] 


‚úÖ Epoch 38 | Train: 15.7980 | Val: 15.3429
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 39 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:31<00:00,  2.90it/s, Loss=16.1509, Content=8.5662, Style=0.7585] 


‚úÖ Epoch 39 | Train: 15.7719 | Val: 15.5238


Epoch 40 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [41:17<00:00,  2.98it/s, Loss=18.9347, Content=10.0059, Style=0.8929]


‚úÖ Epoch 40 | Train: 15.7473 | Val: 15.4320


Epoch 41 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [41:48<00:00,  2.95it/s, Loss=13.1887, Content=6.8703, Style=0.6318] 


‚úÖ Epoch 41 | Train: 15.7579 | Val: 15.0980
üèÜ Best model updated! Saved: checkpoints_adaln/best_model.pth


Epoch 42 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:22<00:00,  2.91it/s, Loss=26.2230, Content=9.6620, Style=1.6561] 


‚úÖ Epoch 42 | Train: 15.6658 | Val: 15.5563


Epoch 43 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:08<00:00,  2.92it/s, Loss=11.7257, Content=5.6413, Style=0.6084] 


‚úÖ Epoch 43 | Train: 15.6601 | Val: 15.6823


Epoch 44 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:20<00:00,  2.91it/s, Loss=17.4502, Content=9.5318, Style=0.7918] 


‚úÖ Epoch 44 | Train: 15.6549 | Val: 15.5584


Epoch 45 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:41<00:00,  2.89it/s, Loss=17.3954, Content=9.0403, Style=0.8355] 


‚úÖ Epoch 45 | Train: 15.5627 | Val: 15.6663


Epoch 46 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [41:14<00:00,  2.99it/s, Loss=15.4773, Content=7.7898, Style=0.7687] 


‚úÖ Epoch 46 | Train: 15.6199 | Val: 15.3248


Epoch 47 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [41:30<00:00,  2.97it/s, Loss=14.7440, Content=7.3743, Style=0.7370] 


‚úÖ Epoch 47 | Train: 15.5511 | Val: 15.4204


Epoch 48 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [41:00<00:00,  3.00it/s, Loss=14.0261, Content=6.9587, Style=0.7067] 


‚úÖ Epoch 48 | Train: 15.5232 | Val: 15.2205


Epoch 49 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [42:27<00:00,  2.90it/s, Loss=15.5791, Content=7.7439, Style=0.7835] 


‚úÖ Epoch 49 | Train: 15.4413 | Val: 15.2572


Epoch 50 Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7393/7393 [43:14<00:00,  2.85it/s, Loss=11.4392, Content=6.2939, Style=0.5145] 


‚úÖ Epoch 50 | Train: 15.5040 | Val: 15.1806
üéØ Training Completed!


## Testing

In [7]:
import os
from torchvision.utils import save_image
from tqdm import tqdm
import cv2
import numpy as np
import os, glob, random
import torch
from tqdm import tqdm
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F

def load_image(path, device):
    """Load 1 ·∫£nh, resize v√† chuy·ªÉn th√†nh tensor [1,3,H,W]"""
    img = Image.open(path).convert("RGB")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
    ])
    img = transform(img).unsqueeze(0).to(device)
    return img

def unnormalize(tensor):
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(1, 3, 1, 1)
    return tensor * std + mean

def test_model(model, content_dir, style_dir, device, checkpoint_path, save_dir="./test_results"):
    os.makedirs(save_dir, exist_ok=True)
    content_paths = sorted(glob.glob(os.path.join(content_dir, "*.jpg")))
    style_paths = sorted(glob.glob(os.path.join(style_dir, "*.jpg")))
    content_paths = content_paths[:20]

    # Load checkpoint
    print(f"üîÑ Loading model weights from {checkpoint_path}")
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt)
    model.to(device)
    model.eval()

    with torch.no_grad():
        for ci, content_path in enumerate(tqdm(content_paths, desc="Testing content")):
            # üëâ ch·ªçn ng·∫´u nhi√™n 1 style cho m·ªói content
            style_path = random.choice(style_paths)

            # Load ·∫£nh
            content = load_image(content_path, device)
            style = load_image(style_path, device)

            # Save k·∫øt qu·∫£
            pair_dir = os.path.join(save_dir, f"pair_{ci + 1}")
            os.makedirs(pair_dir, exist_ok=True)

            save_image(unnormalize(content).clamp(0, 1),
                       os.path.join(pair_dir, "content.jpg"))
            save_image(unnormalize(style).clamp(0, 1),
                       os.path.join(pair_dir, "style.jpg"))
            for alpha in [0.8, 1.0, 1.2]:
                generated, _ = model(content, style, alpha=alpha)
                
                _, _, h, w = content.shape
                # resized = F.interpolate(generated, size=(h, w), mode='bilinear', align_corners=False)
                out = unnormalize(generated)
                # Save
                out_path = os.path.join(pair_dir, f"result_alpha_{alpha:.1f}.jpg")
                save_image(out.clamp(0,1), out_path)

    print(f"‚úÖ Testing completed! Results saved in {save_dir}")


In [8]:
path_vgg_weights = "../models/vgg19.pth"
device = ("cuda:2" if torch.cuda.is_available() else "cpu")
model = AdaINet(out_channels=3, path_vgg_weights=path_vgg_weights, device=device).to(device)
test_model(
    model=model,
    content_dir="../data/coco2017/test",
    style_dir="../data/wikiart_sampled/test",
    device=device,
    checkpoint_path="./checkpoints_adaln/best_model.pth",
    save_dir="../results/adain"
)

üîÑ Loading model weights from ./checkpoints_adaln/best_model.pth


Testing content: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 20/20 [00:07<00:00,  2.57it/s]

‚úÖ Testing completed! Results saved in ../results/adain



