In [96]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm
from PIL import Image
import os
import torch.nn.functional as F

In [97]:
import random
import numpy as np

def set_random_seeds(seed_value=42):
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)  # if you are using multi-GPU.
    np.random.seed(seed_value)  # Numpy module.
    random.seed(seed_value)  # Python random module.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_random_seeds()

In [98]:
class PACSDataset(Dataset):
    def __init__(self, root_dir, domain, transform=None):
        self.root_dir = root_dir
        self.domain = domain
        self.transform = transform
        self.images, self.labels = self._load_images_labels()

    def _load_images_labels(self):
        image_paths = []
        labels = []
        domain_dir = os.path.join(self.root_dir, self.domain)
        classes = sorted(
            [
                d
                for d in os.listdir(domain_dir)
                if os.path.isdir(os.path.join(domain_dir, d))
            ]
        )

        for label, class_name in enumerate(classes):
            class_dir = os.path.join(domain_dir, class_name)
            for image_name in os.listdir(class_dir):
                if image_name.endswith((".png", ".jpg", ".jpeg")):
                    image_paths.append(os.path.join(class_dir, image_name))
                    labels.append(label)

        return image_paths, labels

    def __len__(self):
        return len(self.images)  # Return the number of images

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


def get_transform():
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0),
    ])


def get_combined_dataloader(root_dir, domains, batch_size=32):
    datasets = [
        PACSDataset(root_dir, domain, transform=get_transform()) for domain in domains
    ]
    combined_dataset = torch.utils.data.ConcatDataset(datasets)
    return DataLoader(combined_dataset, batch_size=batch_size, shuffle=True)


def get_dataloader(root_dir, domain, batch_size=32):
    dataset = PACSDataset(root_dir, domain, transform=get_transform())
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [99]:
# Define Encoder, Decoder, Classifier
class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()

        # Sử dụng EfficientNet-B1
        self.efficientnet = models.efficientnet_b1(pretrained=True)

        # Freeze EfficientNet layers
        for param in self.efficientnet.parameters():
            param.requires_grad = False

        # Lấy số features từ lớp cuối cùng của EfficientNet-B1
        in_features = self.efficientnet.classifier[1].in_features

        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Linear(in_features, in_features // 16),
            nn.ReLU(),
            nn.Linear(in_features // 16, in_features),
            nn.Sigmoid(),
        )

        # Mean (mu) and log-variance (logvar) layers
        self.fc_mu = nn.Linear(in_features, latent_dim)
        self.fc_logvar = nn.Linear(in_features, latent_dim)

    def forward(self, x):
        # Pass input through EfficientNet feature extractor
        features = self.efficientnet.features(x)
        x = self.efficientnet.avgpool(features)
        x = torch.flatten(x, 1)

        # Apply attention
        attention_weights = self.attention(x)
        x = x * attention_weights

        # Compute mu and logvar
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(residual)
        out = F.relu(out)
        return out


class Decoder(nn.Module):
    def __init__(self, latent_dim, num_domains):
        super(Decoder, self).__init__()

        self.domain_embedding = nn.Embedding(num_domains, latent_dim)

        self.fc = nn.Linear(latent_dim, 512 * 7 * 7)

        self.decoder = nn.Sequential(
            ResidualBlock(512, 256),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            ResidualBlock(128, 128),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            ResidualBlock(64, 64),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()  # Thêm Tanh để đảm bảo đầu ra trong khoảng [-1, 1]
        )

        # Attention mechanism
        self.attention = nn.Sequential(nn.Conv2d(3, 1, kernel_size=1), nn.Sigmoid())

    def forward(self, z, domain_label):
        domain_embed = self.domain_embedding(domain_label)
        z = z + domain_embed

        x = self.fc(z)
        x = x.view(-1, 512, 7, 7)
        x = self.decoder(x)

        # Apply attention
        attention_map = self.attention(x)
        x = x * attention_map

        return x


class Classifier(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(latent_dim, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, z):
        z = self.dropout(z)
        return self.fc(z)

In [100]:
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std


class LabelSmoothingLoss(_WeightedLoss):
    def __init__(self, weight=None, reduction="mean", smoothing=0.0):
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing
        self.weight = weight
        self.reduction = reduction

    def k_one_hot(self, targets: torch.Tensor, n_classes: int, smoothing=0.0):
        with torch.no_grad():
            targets = (
                torch.empty(size=(targets.size(0), n_classes), device=targets.device)
                .fill_(smoothing / (n_classes - 1))
                .scatter_(1, targets.data.unsqueeze(1), 1.0 - smoothing)
            )
        return targets

    def reduce_loss(self, loss):
        return (
            loss.mean()
            if self.reduction == "mean"
            else loss.sum() if self.reduction == "sum" else loss
        )

    def forward(self, inputs, targets):
        assert 0 <= self.smoothing < 1

        targets = self.k_one_hot(targets, inputs.size(-1), self.smoothing)
        log_preds = F.log_softmax(inputs, -1)

        if self.weight is not None:
            log_preds = log_preds * self.weight.unsqueeze(0)

        return self.reduce_loss(-(targets * log_preds).sum(dim=-1))


def compute_loss(
    reconstructed_imgs_list,
    original_imgs,
    mu,
    logvar,
    predicted_labels,
    true_labels,
    clf_loss_fn,
    epoch,
    total_epochs,
    alpha=1.0,
    beta=1.0,
    gamma=1.0,
):
    recon_loss = sum(
        F.mse_loss(recon, original_imgs, reduction="sum")
        for recon in reconstructed_imgs_list
    )

    # KL Divergence annealing
    annealing_factor = min(1.0, epoch / (total_epochs * 0.2))
    kld_loss = (
        -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) * annealing_factor
    )

    clf_loss = clf_loss_fn(predicted_labels, true_labels)

    total_loss = alpha * recon_loss + beta * clf_loss + gamma * kld_loss
    return total_loss, recon_loss.item(), clf_loss.item(), kld_loss.item()

In [101]:
def mixup_data(x, y, alpha=1.0, device="cuda"):
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


def train_model_progressive(
    encoder,
    decoders,
    classifier,
    domains,
    dataloader,
    optimizer,
    scheduler,
    num_epochs=100,
    device="cuda",
    patience=10,
):
    clf_loss_fn = LabelSmoothingLoss(smoothing=0.1)
    domain_to_idx = {domain: idx for idx, domain in enumerate(domains)}

    best_loss = float("inf")
    patience_counter = 0

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        encoder.train()
        classifier.train()
        for decoder in decoders.values():
            decoder.train()

        running_loss = 0.0
        for inputs, labels in tqdm(dataloader, desc="Training"):
            inputs, labels = inputs.to(device), labels.to(device)

            # Apply mixup
            inputs, labels_a, labels_b, lam = mixup_data(
                inputs, labels, alpha=0.2, device=device
            )

            mu, logvar = encoder(inputs)
            z = reparameterize(mu, logvar)

            reconstructed_imgs_list = []
            for domain in domains:
                domain_label = torch.tensor(
                    [domain_to_idx[domain]] * inputs.size(0), device=device
                )
                reconstructed_imgs = decoders[domain](z, domain_label)
                reconstructed_imgs_list.append(reconstructed_imgs)

            predicted_labels = classifier(z)

            loss, recon_loss, clf_loss, kld_loss = compute_loss(
                reconstructed_imgs_list,
                inputs,
                mu,
                logvar,
                predicted_labels,
                labels,
                lambda pred, target: mixup_criterion(
                    clf_loss_fn, pred, labels_a, labels_b, lam
                ),
                epoch,
                num_epochs,
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

        scheduler.step(avg_loss)

        if avg_loss < best_loss:
            best_loss = avg_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch + 1} epochs")
                break

In [102]:
def evaluate_model(encoder, classifier, decoder, dataloader, device, domain_label):
    encoder.eval()
    classifier.eval()
    decoder.eval()
    total_clf_loss = 0.0
    total_recon_loss = 0.0
    correct = 0
    total = 0
    clf_loss_fn = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)
            batch_size = inputs.size(0)
            mu, logvar = encoder(inputs)
            z = reparameterize(mu, logvar)
            outputs = classifier(z)
            
            # Chuyển domain_label thành tensor và lặp lại cho mỗi mẫu trong batch
            domain_labels = torch.full((batch_size,), domain_label, device=device)
            reconstructed_imgs = decoder(z, domain_labels)

            # Classification accuracy
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Losses
            clf_loss = clf_loss_fn(outputs, labels)
            recon_loss = F.mse_loss(reconstructed_imgs, inputs, reduction="sum")
            total_clf_loss += clf_loss.item()
            total_recon_loss += recon_loss.item()

    accuracy = correct / total
    avg_clf_loss = total_clf_loss / len(dataloader.dataset)
    avg_recon_loss = total_recon_loss / len(dataloader.dataset)
    return accuracy, avg_clf_loss, avg_recon_loss

In [103]:
def evaluate_on_all_domains(encoder, classifier, decoders, domains, data_path, device):
    print("\nFinal Evaluation on All Domains\n")
    for domain in domains:
        eval_dataloader = get_dataloader(data_path, domain)
        domain_label = domains.index(domain)
        accuracy, avg_clf_loss, avg_recon_loss = evaluate_model(
            encoder,
            classifier,
            decoders[domain],
            eval_dataloader,
            device,
            domain_label,
        )
        print(f"Domain: {domain}")
        print(f"  Accuracy: {accuracy * 100:.2f}%")
        print(f"  Avg Classification Loss: {avg_clf_loss:.4f}")
        print(f"  Avg Reconstruction Loss: {avg_recon_loss:.4f}\n")

In [104]:
# Main training and evaluation script
DATA_PATH = (
    "/kaggle/input/pacs-dataset/kfold"  # Update this path to your dataset location
)
latent_dim = 256
num_classes = 7  # Update this according to your PACS dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Domains in PACS dataset
domains = ["art_painting", "cartoon", "photo", "sketch"]

# Initialize models
encoder = Encoder(latent_dim).to(device)
decoders = {domain: Decoder(latent_dim, len(domains)).to(device) for domain in domains}
classifier = Classifier(latent_dim, num_classes).to(device)

# Optimizer and Scheduler
params = list(encoder.parameters()) + list(classifier.parameters())
for decoder in decoders.values():
    params += list(decoder.parameters())
optimizer = optim.AdamW(params, lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

# Create a combined DataLoader for all domains
combined_dataloader = get_combined_dataloader(DATA_PATH, domains)

# Train model using progressive domain training
train_model_progressive(
  encoder,
  decoders,
  classifier,
  domains,
  combined_dataloader,
  optimizer,
  scheduler,
  num_epochs=100,
  device=device,
  patience=10
)

# Final evaluation on all domains
evaluate_on_all_domains(encoder, classifier, decoders, domains, DATA_PATH, device)

Using device: cuda
Epoch 1/100


Training: 100%|██████████| 313/313 [02:52<00:00,  1.82it/s]


Epoch 1, Loss: 44819868.0096
Epoch 2/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.41it/s]


Epoch 2, Loss: 37957506.8179
Epoch 3/100


Training: 100%|██████████| 313/313 [02:12<00:00,  2.35it/s]


Epoch 3, Loss: 35689299.5048
Epoch 4/100


Training: 100%|██████████| 313/313 [02:14<00:00,  2.32it/s]


Epoch 4, Loss: 34525172.2556
Epoch 5/100


Training: 100%|██████████| 313/313 [02:10<00:00,  2.40it/s]


Epoch 5, Loss: 33574778.1997
Epoch 6/100


Training: 100%|██████████| 313/313 [02:08<00:00,  2.43it/s]


Epoch 6, Loss: 32987901.3642
Epoch 7/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.41it/s]


Epoch 7, Loss: 31990916.3882
Epoch 8/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.41it/s]


Epoch 8, Loss: 31945094.9481
Epoch 9/100


Training: 100%|██████████| 313/313 [02:10<00:00,  2.39it/s]


Epoch 9, Loss: 31503885.2748
Epoch 10/100


Training: 100%|██████████| 313/313 [02:10<00:00,  2.41it/s]


Epoch 10, Loss: 31011889.5559
Epoch 11/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.42it/s]


Epoch 11, Loss: 30379192.5911
Epoch 12/100


Training: 100%|██████████| 313/313 [02:08<00:00,  2.44it/s]


Epoch 12, Loss: 30133699.9185
Epoch 13/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.42it/s]


Epoch 13, Loss: 29827970.6326
Epoch 14/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.42it/s]


Epoch 14, Loss: 29001233.0703
Epoch 15/100


Training: 100%|██████████| 313/313 [02:08<00:00,  2.44it/s]


Epoch 15, Loss: 28628659.2500
Epoch 16/100


Training: 100%|██████████| 313/313 [02:08<00:00,  2.44it/s]


Epoch 16, Loss: 28149144.6597
Epoch 17/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.41it/s]


Epoch 17, Loss: 28263901.5575
Epoch 18/100


Training: 100%|██████████| 313/313 [02:07<00:00,  2.45it/s]


Epoch 18, Loss: 28417963.1246
Epoch 19/100


Training: 100%|██████████| 313/313 [02:08<00:00,  2.44it/s]


Epoch 19, Loss: 28203282.2604
Epoch 20/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.42it/s]


Epoch 20, Loss: 27136287.1917
Epoch 21/100


Training: 100%|██████████| 313/313 [02:07<00:00,  2.46it/s]


Epoch 21, Loss: 27500083.1118
Epoch 22/100


Training: 100%|██████████| 313/313 [02:06<00:00,  2.48it/s]


Epoch 22, Loss: 27069740.1118
Epoch 23/100


Training: 100%|██████████| 313/313 [02:07<00:00,  2.46it/s]


Epoch 23, Loss: 26895216.2684
Epoch 24/100


Training: 100%|██████████| 313/313 [02:05<00:00,  2.49it/s]


Epoch 24, Loss: 26856977.1526
Epoch 25/100


Training: 100%|██████████| 313/313 [02:07<00:00,  2.46it/s]


Epoch 25, Loss: 27038827.5623
Epoch 26/100


Training: 100%|██████████| 313/313 [02:10<00:00,  2.40it/s]


Epoch 26, Loss: 26348817.5367
Epoch 27/100


Training: 100%|██████████| 313/313 [02:06<00:00,  2.47it/s]


Epoch 27, Loss: 26180900.3411
Epoch 28/100


Training: 100%|██████████| 313/313 [02:11<00:00,  2.39it/s]


Epoch 28, Loss: 26306832.1510
Epoch 29/100


Training: 100%|██████████| 313/313 [02:10<00:00,  2.40it/s]


Epoch 29, Loss: 26217685.0527
Epoch 30/100


Training: 100%|██████████| 313/313 [02:08<00:00,  2.44it/s]


Epoch 30, Loss: 26117391.4345
Epoch 31/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.42it/s]


Epoch 31, Loss: 26160169.5631
Epoch 32/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.42it/s]


Epoch 32, Loss: 26028985.7029
Epoch 33/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.42it/s]


Epoch 33, Loss: 26249249.4010
Epoch 34/100


Training: 100%|██████████| 313/313 [02:07<00:00,  2.46it/s]


Epoch 34, Loss: 25911353.5911
Epoch 35/100


Training: 100%|██████████| 313/313 [02:11<00:00,  2.38it/s]


Epoch 35, Loss: 25704476.5958
Epoch 36/100


Training: 100%|██████████| 313/313 [02:12<00:00,  2.36it/s]


Epoch 36, Loss: 26234883.5823
Epoch 37/100


Training: 100%|██████████| 313/313 [02:12<00:00,  2.36it/s]


Epoch 37, Loss: 25705016.1422
Epoch 38/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.43it/s]


Epoch 38, Loss: 25820578.7812
Epoch 39/100


Training: 100%|██████████| 313/313 [02:08<00:00,  2.44it/s]


Epoch 39, Loss: 25141059.5224
Epoch 40/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.42it/s]


Epoch 40, Loss: 25394478.7348
Epoch 41/100


Training: 100%|██████████| 313/313 [02:11<00:00,  2.38it/s]


Epoch 41, Loss: 25835922.5990
Epoch 42/100


Training: 100%|██████████| 313/313 [02:11<00:00,  2.38it/s]


Epoch 42, Loss: 25795979.2636
Epoch 43/100


Training: 100%|██████████| 313/313 [02:10<00:00,  2.40it/s]


Epoch 43, Loss: 25757505.9856
Epoch 44/100


Training: 100%|██████████| 313/313 [02:10<00:00,  2.40it/s]


Epoch 44, Loss: 25369754.3946
Epoch 45/100


Training: 100%|██████████| 313/313 [02:08<00:00,  2.44it/s]


Epoch 45, Loss: 25523892.6310
Epoch 46/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.42it/s]


Epoch 46, Loss: 25437369.9665
Epoch 47/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.41it/s]


Epoch 47, Loss: 25240194.3706
Epoch 48/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.42it/s]


Epoch 48, Loss: 25605277.6046
Epoch 49/100


Training: 100%|██████████| 313/313 [02:09<00:00,  2.41it/s]


Epoch 49, Loss: 25951982.9441
Early stopping triggered after 49 epochs

Final Evaluation on All Domains



Evaluating: 100%|██████████| 64/64 [00:23<00:00,  2.68it/s]


Domain: art_painting
  Accuracy: 41.60%
  Avg Classification Loss: 0.0512
  Avg Reconstruction Loss: 182955.4641



Evaluating: 100%|██████████| 74/74 [00:23<00:00,  3.17it/s]


Domain: cartoon
  Accuracy: 45.35%
  Avg Classification Loss: 0.0502
  Avg Reconstruction Loss: 245128.7253



Evaluating: 100%|██████████| 53/53 [00:18<00:00,  2.85it/s]


Domain: photo
  Accuracy: 66.41%
  Avg Classification Loss: 0.0368
  Avg Reconstruction Loss: 177599.3864



Evaluating: 100%|██████████| 123/123 [00:34<00:00,  3.56it/s]

Domain: sketch
  Accuracy: 39.96%
  Avg Classification Loss: 0.0483
  Avg Reconstruction Loss: 258051.7423




