In [48]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm
from PIL import Image
import os
import torch.nn.functional as F
from torch.nn.modules.loss import _WeightedLoss

In [49]:
import random
import numpy as np

def set_random_seeds(seed_value=42):
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)  # if you are using multi-GPU.
    np.random.seed(seed_value)  # Numpy module.
    random.seed(seed_value)  # Python random module.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_random_seeds()

In [50]:
class PACSDataset(Dataset):
    def __init__(self, root_dir, domain, transform=None):
        self.root_dir = root_dir
        self.domain = domain
        self.transform = transform
        self.images, self.labels = self._load_images_labels()

    def _load_images_labels(self):
        image_paths = []
        labels = []
        domain_dir = os.path.join(self.root_dir, self.domain)
        classes = sorted(
            [
                d
                for d in os.listdir(domain_dir)
                if os.path.isdir(os.path.join(domain_dir, d))
            ]
        )

        for label, class_name in enumerate(classes):
            class_dir = os.path.join(domain_dir, class_name)
            for image_name in os.listdir(class_dir):
                if image_name.endswith((".png", ".jpg", ".jpeg")):
                    image_paths.append(os.path.join(class_dir, image_name))
                    labels.append(label)

        return image_paths, labels

    def __len__(self):
        return len(self.images)  # Return the number of images

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


def get_transform():
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0),
    ])


def get_dataloaders(root_dir, domain, batch_size=32):
    train_paths, train_labels, val_paths, val_labels, test_paths, test_labels = (
        split_dataset(root_dir, domain)
    )

    train_dataset = PACSDatasetFromPaths(
        train_paths, train_labels, transform=get_transform()
    )
    val_dataset = PACSDatasetFromPaths(val_paths, val_labels, transform=get_transform())
    test_dataset = PACSDatasetFromPaths(
        test_paths, test_labels, transform=get_transform()
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader


def get_mixed_dataloader(root_dir, domains, batch_size=32):
    datasets = [
        PACSDataset(root_dir, domain, transform=get_transform()) for domain in domains
    ]
    combined_dataset = torch.utils.data.ConcatDataset(datasets)
    return DataLoader(combined_dataset, batch_size=batch_size, shuffle=True)

In [51]:
# Define Encoder, Decoder, Classifier
class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()

        # Sử dụng EfficientNet-B1
        self.efficientnet = models.efficientnet_b1(pretrained=True)

        # Freeze EfficientNet layers
        for param in self.efficientnet.parameters():
            param.requires_grad = False

        # Lấy số features từ lớp cuối cùng của EfficientNet-B1
        in_features = self.efficientnet.classifier[1].in_features

        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Linear(in_features, in_features // 16),
            nn.ReLU(),
            nn.Linear(in_features // 16, in_features),
            nn.Sigmoid(),
        )

        # Mean (mu) and log-variance (logvar) layers
        self.fc_mu = nn.Linear(in_features, latent_dim)
        self.fc_logvar = nn.Linear(in_features, latent_dim)

    def forward(self, x):
        # Pass input through EfficientNet feature extractor
        features = self.efficientnet.features(x)
        x = self.efficientnet.avgpool(features)
        x = torch.flatten(x, 1)

        # Apply attention
        attention_weights = self.attention(x)
        x = x * attention_weights

        # Compute mu and logvar
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(residual)
        out = F.relu(out)
        return out


class Decoder(nn.Module):
    def __init__(self, latent_dim, num_domains):
        super(Decoder, self).__init__()

        self.domain_embedding = nn.Embedding(num_domains, latent_dim)

        self.fc = nn.Linear(latent_dim, 512 * 7 * 7)

        self.decoder = nn.Sequential(
            ResidualBlock(512, 256),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            ResidualBlock(128, 128),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            ResidualBlock(64, 64),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()  # Thêm Tanh để đảm bảo đầu ra trong khoảng [-1, 1]
        )

        # Attention mechanism
        self.attention = nn.Sequential(nn.Conv2d(3, 1, kernel_size=1), nn.Sigmoid())

    def forward(self, z, domain_label):
        domain_embed = self.domain_embedding(domain_label)
        z = z + domain_embed

        x = self.fc(z)
        x = x.view(-1, 512, 7, 7)
        x = self.decoder(x)

        # Apply attention
        attention_map = self.attention(x)
        x = x * attention_map

        return x


class Classifier(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(latent_dim, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, z):
        z = self.dropout(z)
        return self.fc(z)

In [52]:
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std


class LabelSmoothingLoss(_WeightedLoss):
    def __init__(self, weight=None, reduction="mean", smoothing=0.0):
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing
        self.weight = weight
        self.reduction = reduction

    def k_one_hot(self, targets: torch.Tensor, n_classes: int, smoothing=0.0):
        with torch.no_grad():
            targets = (
                torch.empty(size=(targets.size(0), n_classes), device=targets.device)
                .fill_(smoothing / (n_classes - 1))
                .scatter_(1, targets.data.unsqueeze(1), 1.0 - smoothing)
            )
        return targets

    def reduce_loss(self, loss):
        return (
            loss.mean()
            if self.reduction == "mean"
            else loss.sum() if self.reduction == "sum" else loss
        )

    def forward(self, inputs, targets):
        assert 0 <= self.smoothing < 1

        targets = self.k_one_hot(targets, inputs.size(-1), self.smoothing)
        log_preds = F.log_softmax(inputs, -1)

        if self.weight is not None:
            log_preds = log_preds * self.weight.unsqueeze(0)

        return self.reduce_loss(-(targets * log_preds).sum(dim=-1))

class DynamicWeightBalancer:
    def __init__(self, init_alpha=0.5, init_beta=2.0, init_gamma=0.1, patience=5, scaling_factor=0.8):
        self.alpha = init_alpha  # Reconstruction loss weight
        self.beta = init_beta    # Classification loss weight
        self.gamma = init_gamma  # KL divergence weight
        self.patience = patience
        self.scaling_factor = scaling_factor
        self.best_loss = float('inf')
        self.counter = 0

    def update(self, current_loss, recon_loss, clf_loss, kl_loss):
        if current_loss < self.best_loss:
            self.best_loss = current_loss
            self.counter = 0
        else:
            self.counter += 1

        if self.counter >= self.patience:
            self.counter = 0
            # Increase classification weight and decrease others
            self.beta /= self.scaling_factor
            self.alpha *= self.scaling_factor
            self.gamma *= self.scaling_factor

        # Ensure classification loss weight is always significantly larger
        total_weight = self.alpha + self.beta + self.gamma
        self.alpha = max(0.1, min(0.3, self.alpha / total_weight))
        self.beta = max(0.6, min(0.8, self.beta / total_weight))
        self.gamma = 1 - self.alpha - self.beta

        return self.alpha, self.beta, self.gamma

def compute_loss(reconstructed_imgs_list, original_imgs, mu, logvar, predicted_labels, true_labels, clf_loss_fn, epoch, total_epochs, balancer):
    recon_loss = sum(
        F.mse_loss(recon, original_imgs, reduction="mean")
        for recon in reconstructed_imgs_list
    ) / len(reconstructed_imgs_list)

    kld_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    clf_loss = clf_loss_fn(predicted_labels, true_labels)

    alpha, beta, gamma = balancer.update(recon_loss + clf_loss + kld_loss, recon_loss, clf_loss, kld_loss)

    total_loss = alpha * recon_loss + beta * clf_loss + gamma * kld_loss
    return total_loss, recon_loss.item(), clf_loss.item(), kld_loss.item(), alpha, beta, gamma

In [53]:
def mixup_data(x, y, alpha=1.0, device="cuda"):
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

class LossWeightScheduler:
    def __init__(self, init_alpha=0.1, init_beta=1.0, init_gamma=0.1, decay_factor=0.9, decay_epochs=10):
        self.alpha = init_alpha
        self.beta = init_beta
        self.gamma = init_gamma
        self.decay_factor = decay_factor
        self.decay_epochs = decay_epochs

    def step(self, epoch):
        if (epoch + 1) % self.decay_epochs == 0:
            self.alpha *= self.decay_factor
            self.gamma *= self.decay_factor
        return self.alpha, self.beta, self.gamma


def train_model_progressive(
    encoder,
    decoders,
    classifier,
    domains,
    dataloader,
    val_loaders,
    optimizer,
    scheduler,
    num_epochs=100,
    device="cuda",
    patience=10,
):
    clf_loss_fn = LabelSmoothingLoss(smoothing=0.1)
    domain_to_idx = {domain: idx for idx, domain in enumerate(domains)}

    best_loss = float("inf")
    patience_counter = 0

    balancer = DynamicWeightBalancer()

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        encoder.train()
        classifier.train()
        for decoder in decoders.values():
            decoder.train()

        running_loss = 0.0
        running_recon_loss = 0.0
        running_clf_loss = 0.0
        running_kl_loss = 0.0

        for inputs, labels in tqdm(dataloader, desc="Training"):
            inputs, labels = inputs.to(device), labels.to(device)

            inputs, labels_a, labels_b, lam = mixup_data(
                inputs, labels, alpha=0.2, device=device
            )

            mu, logvar = encoder(inputs)
            z = reparameterize(mu, logvar)

            reconstructed_imgs_list = []
            for domain in domains:
                domain_label = torch.tensor(
                    [domain_to_idx[domain]] * inputs.size(0), device=device
                )
                reconstructed_imgs = decoders[domain](z, domain_label)
                reconstructed_imgs_list.append(reconstructed_imgs)

            predicted_labels = classifier(z)

            loss, recon_loss, clf_loss, kl_loss, alpha, beta, gamma = compute_loss(
                reconstructed_imgs_list,
                inputs,
                mu,
                logvar,
                predicted_labels,
                labels,
                lambda pred, target: mixup_criterion(
                    clf_loss_fn, pred, labels_a, labels_b, lam
                ),
                epoch,
                num_epochs,
                balancer,
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_recon_loss += recon_loss
            running_clf_loss += clf_loss
            running_kl_loss += kl_loss

        avg_loss = running_loss / len(dataloader)
        avg_recon_loss = running_recon_loss / len(dataloader)
        avg_clf_loss = running_clf_loss / len(dataloader)
        avg_kl_loss = running_kl_loss / len(dataloader)

        print(
            f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}, Recon: {avg_recon_loss:.4f}, Clf: {avg_clf_loss:.4f}, KL: {avg_kl_loss:.4f}"
        )
        print(f"Weights - Alpha: {alpha:.4f}, Beta: {beta:.4f}, Gamma: {gamma:.4f}")

        # Evaluate on all domains
        encoder.eval()
        classifier.eval()
        for domain in domains:
            val_loader = val_loaders[domain]
            accuracy, _, _ = evaluate_model(
                encoder,
                classifier,
                decoders[domain],
                val_loader,
                device,
                domain_to_idx[domain],
            )
            print(f"Validation Accuracy on {domain}: {accuracy * 100:.2f}%")

        scheduler.step(avg_loss)

        if avg_loss < best_loss:
            best_loss = avg_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch + 1} epochs")
                break

In [54]:
def evaluate_model(encoder, classifier, decoder, dataloader, device, domain_label):
    encoder.eval()
    classifier.eval()
    decoder.eval()
    total_clf_loss = 0.0
    total_recon_loss = 0.0
    correct = 0
    total = 0
    clf_loss_fn = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)
            batch_size = inputs.size(0)
            mu, logvar = encoder(inputs)
            z = reparameterize(mu, logvar)
            outputs = classifier(z)
            
            # Chuyển domain_label thành tensor và lặp lại cho mỗi mẫu trong batch
            domain_labels = torch.full((batch_size,), domain_label, device=device)
            reconstructed_imgs = decoder(z, domain_labels)

            # Classification accuracy
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Losses
            clf_loss = clf_loss_fn(outputs, labels)
            recon_loss = F.mse_loss(reconstructed_imgs, inputs, reduction="sum")
            total_clf_loss += clf_loss.item()
            total_recon_loss += recon_loss.item()

    accuracy = correct / total
    avg_clf_loss = total_clf_loss / len(dataloader.dataset)
    avg_recon_loss = total_recon_loss / len(dataloader.dataset)
    return accuracy, avg_clf_loss, avg_recon_loss

In [56]:
# Main training and evaluation script
DATA_PATH = (
    "/kaggle/input/pacs-dataset/kfold"  # Update this path to your dataset location
)
latent_dim = 256
num_classes = 7  # Update this according to your PACS dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Domains in PACS dataset
domains = ["art_painting", "cartoon", "photo", "sketch"]

# Initialize models
encoder = Encoder(latent_dim).to(device)
decoders = {domain: Decoder(latent_dim, len(domains)).to(device) for domain in domains}
classifier = Classifier(latent_dim, num_classes).to(device)

# Optimizer and Scheduler
params = list(encoder.parameters()) + list(classifier.parameters())
for decoder in decoders.values():
    params += list(decoder.parameters())
optimizer = optim.AdamW(params, lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=5, verbose=True
)

# Create a mixed DataLoader for all domains
mixed_dataloader = get_mixed_dataloader(DATA_PATH, domains)

# Create validation loaders for each domain
val_loaders = {domain: get_dataloader(DATA_PATH, domain)[1] for domain in domains}

# Train model using progressive domain training
train_model_progressive(
    encoder,
    decoders,
    classifier,
    domains,
    mixed_dataloader,
    val_loaders,
    optimizer,
    scheduler,
    num_epochs=100,
    device=device,
    patience=10,
)

# Final evaluation on the test set for each domain
for domain in domains:
    print(f"Evaluating on test set for domain: {domain}")
    evaluate_model(
        encoder,
        classifier,
        decoders[domain],
        test_loaders[domain],
        device,
        domains.index(domain),
    )

Using device: cuda
Epoch 1/100


Training: 100%|██████████| 313/313 [02:07<00:00,  2.46it/s]


Epoch 1, Loss: 1.7842, Recon: 2.4097, Clf: 1.9171, KL: 0.0414
Weights - Alpha: 0.1000, Beta: 0.8000, Gamma: 0.1000
Epoch 2/100


Training: 100%|██████████| 313/313 [02:07<00:00,  2.46it/s]


Epoch 2, Loss: 1.4642, Recon: 2.0374, Clf: 1.5518, KL: 0.1902
Weights - Alpha: 0.1000, Beta: 0.8000, Gamma: 0.1000
Epoch 3/100


Training: 100%|██████████| 313/313 [02:04<00:00,  2.51it/s]


Epoch 3, Loss: 1.3852, Recon: 1.9024, Clf: 1.4568, KL: 0.2952
Weights - Alpha: 0.1000, Beta: 0.8000, Gamma: 0.1000
Epoch 4/100


Training: 100%|██████████| 313/313 [02:03<00:00,  2.54it/s]


Epoch 4, Loss: 1.3501, Recon: 1.8440, Clf: 1.4134, KL: 0.3504
Weights - Alpha: 0.1000, Beta: 0.8000, Gamma: 0.1000
Epoch 5/100


Training: 100%|██████████| 313/313 [02:04<00:00,  2.52it/s]


Epoch 5, Loss: 1.3314, Recon: 1.7929, Clf: 1.3936, KL: 0.3724
Weights - Alpha: 0.1000, Beta: 0.8000, Gamma: 0.1000
Epoch 6/100


Training: 100%|██████████| 313/313 [02:04<00:00,  2.52it/s]


Epoch 6, Loss: 1.3074, Recon: 1.7632, Clf: 1.3661, KL: 0.3812
Weights - Alpha: 0.1000, Beta: 0.8000, Gamma: 0.1000
Epoch 7/100


Training: 100%|██████████| 313/313 [02:03<00:00,  2.53it/s]


Epoch 7, Loss: 1.3169, Recon: 1.7112, Clf: 1.3834, KL: 0.3909
Weights - Alpha: 0.1000, Beta: 0.8000, Gamma: 0.1000
Epoch 8/100


Training: 100%|██████████| 313/313 [02:05<00:00,  2.49it/s]


Epoch 8, Loss: 1.2926, Recon: 1.7082, Clf: 1.3536, KL: 0.3894
Weights - Alpha: 0.1000, Beta: 0.8000, Gamma: 0.1000
Epoch 9/100


Training: 100%|██████████| 313/313 [02:04<00:00,  2.51it/s]


Epoch 9, Loss: 1.2750, Recon: 1.6903, Clf: 1.3344, KL: 0.3847
Weights - Alpha: 0.1000, Beta: 0.8000, Gamma: 0.1000
Epoch 10/100


Training: 100%|██████████| 313/313 [02:04<00:00,  2.52it/s]


Epoch 10, Loss: 1.2685, Recon: 1.6641, Clf: 1.3299, KL: 0.3815
Weights - Alpha: 0.1000, Beta: 0.8000, Gamma: 0.1000
Epoch 11/100


Training:   4%|▍         | 13/313 [00:06<02:26,  2.05it/s]


KeyboardInterrupt: 

In [57]:
evaluate_on_all_domains(encoder, classifier, decoders, domains, DATA_PATH, device)


Final Evaluation on All Domains



Evaluating: 100%|██████████| 64/64 [00:22<00:00,  2.85it/s]


Domain: art_painting
  Accuracy: 59.57%
  Avg Classification Loss: 0.0377
  Avg Reconstruction Loss: 208500.9526



Evaluating: 100%|██████████| 74/74 [00:23<00:00,  3.10it/s]


Domain: cartoon
  Accuracy: 71.20%
  Avg Classification Loss: 0.0305
  Avg Reconstruction Loss: 284906.5195



Evaluating: 100%|██████████| 53/53 [00:18<00:00,  2.88it/s]


Domain: photo
  Accuracy: 87.54%
  Avg Classification Loss: 0.0194
  Avg Reconstruction Loss: 215283.1936



Evaluating: 100%|██████████| 123/123 [00:34<00:00,  3.57it/s]

Domain: sketch
  Accuracy: 67.40%
  Avg Classification Loss: 0.0307
  Avg Reconstruction Loss: 365410.1858




