In [66]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from tqdm import tqdm
from PIL import Image
import os
import torch.nn.functional as F
from torch.nn.modules.loss import _WeightedLoss
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [67]:
import random
import numpy as np

def set_random_seeds(seed_value=42):
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)  # if you are using multi-GPU.
    np.random.seed(seed_value)  # Numpy module.
    random.seed(seed_value)  # Python random module.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_random_seeds()

In [None]:
class PACSDataset(Dataset):
    def __init__(self, root_dir, domains, transform=None):
        self.root_dir = root_dir
        self.domains = domains
        self.transform = transform
        self.images = []
        self.labels = []
        self._load_images_labels()

    def _load_images_labels(self):
        for domain in self.domains:
            domain_dir = os.path.join(self.root_dir, domain)
            classes = sorted(
                [
                    d
                    for d in os.listdir(domain_dir)
                    if os.path.isdir(os.path.join(domain_dir, d))
                ]
            )

            for label, class_name in enumerate(classes):
                class_dir = os.path.join(domain_dir, class_name)
                for image_name in os.listdir(class_dir):
                    if image_name.endswith((".png", ".jpg", ".jpeg")):
                        self.images.append(os.path.join(class_dir, image_name))
                        self.labels.append(label)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


def get_transform():
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0),
    ])

In [68]:
def get_dataloader(
    root_dir,
    domains,
    batch_size=32,
    train_size=0.7,
    val_size=0.15,
    test_size=0.15,
    random_seed=42,
):
    # Ensure the sizes sum to 1
    assert train_size + val_size + test_size == 1.0

    # Load all image paths and their corresponding labels from all domains
    image_paths = []
    labels = []
    for domain in domains:
        domain_dir = os.path.join(root_dir, domain)
        classes = sorted(
            [
                d
                for d in os.listdir(domain_dir)
                if os.path.isdir(os.path.join(domain_dir, d))
            ]
        )

        for label, class_name in enumerate(classes):
            class_dir = os.path.join(domain_dir, class_name)
            for image_name in os.listdir(class_dir):
                if image_name.endswith((".png", ".jpg", ".jpeg")):
                    image_paths.append(os.path.join(class_dir, image_name))
                    labels.append(label)

    # Split the combined dataset into train, val, and test
    train_paths, temp_paths, train_labels, temp_labels = train_test_split(
        image_paths,
        labels,
        train_size=train_size,
        stratify=labels,
        random_state=random_seed,
    )
    val_paths, test_paths, val_labels, test_labels = train_test_split(
        temp_paths,
        temp_labels,
        train_size=val_size / (val_size + test_size),
        stratify=temp_labels,
        random_state=random_seed,
    )

    # Create datasets
    train_dataset = PACSDataset(root_dir, domains, transform=get_transform())
    train_dataset.images = train_paths
    train_dataset.labels = train_labels

    val_dataset = PACSDataset(root_dir, domains, transform=get_transform())
    val_dataset.images = val_paths
    val_dataset.labels = val_labels

    test_dataset = PACSDataset(root_dir, domains, transform=get_transform())
    test_dataset.images = test_paths
    test_dataset.labels = test_labels

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

In [69]:
# Define Encoder, Decoder, Classifier
class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()

        # Sử dụng EfficientNet-B1
        self.efficientnet = models.efficientnet_b1(weights=models.EfficientNet_B1_Weights.IMAGENET1K_V1)

        # Freeze EfficientNet layers
        for param in self.efficientnet.parameters():
            param.requires_grad = False

        # Lấy số features từ lớp cuối cùng của EfficientNet-B1
        in_features = self.efficientnet.classifier[1].in_features

        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Linear(in_features, in_features // 16),
            nn.ReLU(),
            nn.Linear(in_features // 16, in_features),
            nn.Sigmoid(),
        )

        # Mean (mu) and log-variance (logvar) layers
        self.fc_mu = nn.Linear(in_features, latent_dim)
        self.fc_logvar = nn.Linear(in_features, latent_dim)

        self.dropout = nn.Dropout(0.5)  # Add dropout

    def forward(self, x):
        # Pass input through EfficientNet feature extractor
        features = self.efficientnet.features(x)
        x = self.efficientnet.avgpool(features)
        x = torch.flatten(x, 1)

        x = self.dropout(x)  # Apply dropout

        # Apply attention
        attention_weights = self.attention(x)
        x = x * attention_weights

        # Compute mu and logvar
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(residual)
        out = F.relu(out)
        return out


class Decoder(nn.Module):
    def __init__(self, latent_dim, num_domains):
        super(Decoder, self).__init__()

        self.domain_embedding = nn.Embedding(num_domains, latent_dim)

        self.fc = nn.Linear(latent_dim, 512 * 7 * 7)

        self.decoder = nn.Sequential(
            ResidualBlock(512, 256),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            ResidualBlock(128, 128),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            ResidualBlock(64, 64),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()  # Thêm Tanh để đảm bảo đầu ra trong khoảng [-1, 1]
        )

        # Attention mechanism
        self.attention = nn.Sequential(nn.Conv2d(3, 1, kernel_size=1), nn.Sigmoid())

    def forward(self, z, domain_label):
        domain_embed = self.domain_embedding(domain_label)
        z = z + domain_embed

        x = self.fc(z)
        x = x.view(-1, 512, 7, 7)
        x = self.decoder(x)

        # Apply attention
        attention_map = self.attention(x)
        x = x * attention_map

        return x


class Classifier(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(latent_dim, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, z):
        z = self.dropout(z)
        return self.fc(z)

In [70]:
class LabelSmoothingLoss(_WeightedLoss):
    def __init__(self, weight=None, reduction="mean", smoothing=0.0):
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing
        self.weight = weight
        self.reduction = reduction

    def k_one_hot(self, targets: torch.Tensor, n_classes: int, smoothing=0.0):
        with torch.no_grad():
            targets = (
                torch.empty(size=(targets.size(0), n_classes), device=targets.device)
                .fill_(smoothing / (n_classes - 1))
                .scatter_(1, targets.data.unsqueeze(1), 1.0 - smoothing)
            )
        return targets

    def reduce_loss(self, loss):
        return (
            loss.mean()
            if self.reduction == "mean"
            else loss.sum() if self.reduction == "sum" else loss
        )

    def forward(self, inputs, targets):
        assert 0 <= self.smoothing < 1

        targets = self.k_one_hot(targets, inputs.size(-1), self.smoothing)
        log_preds = F.log_softmax(inputs, -1)

        if self.weight is not None:
            log_preds = log_preds * self.weight.unsqueeze(0)

        return self.reduce_loss(-(targets * log_preds).sum(dim=-1))

class DynamicWeightBalancer:
    def __init__(self, init_alpha=0.5, init_beta=2.0, init_gamma=0.1, patience=5, scaling_factor=0.8):
        self.alpha = init_alpha  # Reconstruction loss weight
        self.beta = init_beta    # Classification loss weight
        self.gamma = init_gamma  # KL divergence weight
        self.patience = patience
        self.scaling_factor = scaling_factor
        self.best_loss = float('inf')
        self.counter = 0

    def update(self, current_loss, recon_loss, clf_loss, kl_loss):
        if current_loss < self.best_loss:
            self.best_loss = current_loss
            self.counter = 0
        else:
            self.counter += 1

        if self.counter >= self.patience:
            self.counter = 0
            # Increase classification weight and decrease others
            self.beta /= self.scaling_factor
            self.alpha *= self.scaling_factor
            self.gamma *= self.scaling_factor

        # Ensure classification loss weight is always significantly larger
        total_weight = self.alpha + self.beta + self.gamma
        self.alpha = max(0.1, min(0.3, self.alpha / total_weight))
        self.beta = max(0.6, min(0.8, self.beta / total_weight))
        self.gamma = 1 - self.alpha - self.beta

        return self.alpha, self.beta, self.gamma

In [None]:
def reparameterize(mu, logvar, dropout_rate=0.5):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = mu + eps * std
    z = F.dropout(z, p=dropout_rate, training=True)  # Apply dropout
    return z

def compute_loss(reconstructed_imgs_list, original_imgs, mu, logvar, predicted_labels, true_labels, clf_loss_fn, epoch, total_epochs, balancer):
    recon_loss = sum(
        F.mse_loss(recon, original_imgs, reduction="mean")
        for recon in reconstructed_imgs_list
    ) / len(reconstructed_imgs_list)

    kld_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    clf_loss = clf_loss_fn(predicted_labels, true_labels)

    alpha, beta, gamma = balancer.update(recon_loss + clf_loss + kld_loss, recon_loss, clf_loss, kld_loss)

    total_loss = alpha * recon_loss + beta * clf_loss + gamma * kld_loss
    return total_loss, recon_loss.item(), clf_loss.item(), kld_loss.item(), alpha, beta, gamma

In [71]:
def mixup_data(x, y, alpha=1.0, device="cuda"):
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

In [None]:
def train_model_progressive(
    encoder,
    decoders,
    classifier,
    domains,
    dataloader,
    val_loaders,
    optimizer,
    scheduler,
    num_epochs=100,
    device="cuda",
    patience=10,
):  
    print("Training model with progressive domain adaptation")
    print(f"Number of epochs: {num_epochs}")
    print(f"Patience: {patience}")
    print(f"Domains: {domains}")
    print(f"Device: {device}")
    print(f'Number of training samples: {len(dataloader.dataset)}')
    print(f'Number of validation samples: {len(val_loaders[domains[0]].dataset)}')

    clf_loss_fn = LabelSmoothingLoss(smoothing=0.1)
    domain_to_idx = {domain: idx for idx, domain in enumerate(domains)}

    best_loss = float("inf")
    patience_counter = 0

    balancer = DynamicWeightBalancer()

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        encoder.train()
        classifier.train()
        for decoder in decoders.values():
            decoder.train()

        running_loss = 0.0
        running_recon_loss = 0.0
        running_clf_loss = 0.0
        running_kl_loss = 0.0

        for inputs, labels in tqdm(dataloader, desc="Training"):
            inputs, labels = inputs.to(device), labels.to(device)

            inputs, labels_a, labels_b, lam = mixup_data(
                inputs, labels, alpha=0.2, device=device
            )

            mu, logvar = encoder(inputs)
            z = reparameterize(mu, logvar)

            reconstructed_imgs_list = []
            for domain in domains:
                domain_label = torch.tensor(
                    [domain_to_idx[domain]] * inputs.size(0), device=device
                )
                reconstructed_imgs = decoders[domain](z, domain_label)
                reconstructed_imgs_list.append(reconstructed_imgs)

            predicted_labels = classifier(z)

            loss, recon_loss, clf_loss, kl_loss, alpha, beta, gamma = compute_loss(
                reconstructed_imgs_list,
                inputs,
                mu,
                logvar,
                predicted_labels,
                labels,
                lambda pred, target: mixup_criterion(
                    clf_loss_fn, pred, labels_a, labels_b, lam
                ),
                epoch,
                num_epochs,
                balancer,
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_recon_loss += recon_loss
            running_clf_loss += clf_loss
            running_kl_loss += kl_loss

        avg_loss = running_loss / len(dataloader)
        avg_recon_loss = running_recon_loss / len(dataloader)
        avg_clf_loss = running_clf_loss / len(dataloader)
        avg_kl_loss = running_kl_loss / len(dataloader)

        print(
            f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}, Recon: {avg_recon_loss:.4f}, Clf: {avg_clf_loss:.4f}, KL: {avg_kl_loss:.4f}"
        )
        print(f"Weights - Alpha: {alpha:.4f}, Beta: {beta:.4f}, Gamma: {gamma:.4f}")

        if (epoch + 1) % 10 == 0:
            # Evaluate on all domains
            encoder.eval()
            classifier.eval()
            for domain in domains:
                val_loader = val_loaders[domain]
                accuracy, _, _ = evaluate_model(
                    encoder,
                    classifier,
                    decoders[domain],
                    val_loader,
                    device,
                    domain_to_idx[domain],
                )
                print(f"Validation Accuracy on {domain}: {accuracy * 100:.2f}%")

        scheduler.step(avg_loss)

        if avg_loss < best_loss:
            best_loss = avg_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch + 1} epochs")
                break

In [72]:
def evaluate_model(encoder, classifier, decoder, dataloader, device, domain_label):
    encoder.eval()
    classifier.eval()
    decoder.eval()
    total_clf_loss = 0.0
    total_recon_loss = 0.0
    correct = 0
    total = 0
    clf_loss_fn = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)
            batch_size = inputs.size(0)
            mu, logvar = encoder(inputs)
            z = reparameterize(mu, logvar)
            outputs = classifier(z)
            
            # Chuyển domain_label thành tensor và lặp lại cho mỗi mẫu trong batch
            domain_labels = torch.full((batch_size,), domain_label, device=device)
            reconstructed_imgs = decoder(z, domain_labels)

            # Classification accuracy
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Losses
            clf_loss = clf_loss_fn(outputs, labels)
            recon_loss = F.mse_loss(reconstructed_imgs, inputs, reduction="sum")
            total_clf_loss += clf_loss.item()
            total_recon_loss += recon_loss.item()

    accuracy = correct / total
    avg_clf_loss = total_clf_loss / len(dataloader.dataset)
    avg_recon_loss = total_recon_loss / len(dataloader.dataset)
    return accuracy, avg_clf_loss, avg_recon_loss

In [73]:
# Main training and evaluation script
DATA_PATH = "/kaggle/input/pacs-dataset/kfold"  # Update this path to your dataset location
latent_dim = 256
num_classes = 7  # Update this according to your PACS dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Domains in PACS dataset
domains = ["art_painting", "cartoon", "photo", "sketch"]

# Initialize models
encoder = Encoder(latent_dim).to(device)
decoders = {domain: Decoder(latent_dim, len(domains)).to(device) for domain in domains}
classifier = Classifier(latent_dim, num_classes).to(device)

# Optimizer and Scheduler
params = list(encoder.parameters()) + list(classifier.parameters())
for decoder in decoders.values():
    params += list(decoder.parameters())
optimizer = optim.AdamW(params, lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
   optimizer, mode="min", factor=0.1, patience=5
)

# Create a single DataLoader for all domains
train_loader, val_loader, test_loader = get_dataloader(DATA_PATH, domains)

# Create a dictionary for validation loaders if needed
val_loaders = {domain: val_loader for domain in domains}

# Train model using the combined DataLoader
train_model_progressive(
  encoder,
  decoders,
  classifier,
  domains,
  train_loader,
  val_loaders,  # Pass the dictionary of validation loaders
  optimizer,
  scheduler,
  num_epochs=100,
  device=device,
  patience=10,
)

Using device: cuda
Epoch 1/100


Training: 100%|██████████| 219/219 [01:27<00:00,  2.51it/s]


Epoch 1, Loss: 1.8766, Recon: 2.4858, Clf: 2.0176, KL: 0.0253
Weights - Alpha: 0.1000, Beta: 0.8000, Gamma: 0.1000


Evaluating: 100%|██████████| 47/47 [00:21<00:00,  2.23it/s]


Validation Accuracy on art_painting: 36.49%


Evaluating: 100%|██████████| 47/47 [00:14<00:00,  3.29it/s]


Validation Accuracy on cartoon: 36.56%


Evaluating: 100%|██████████| 47/47 [00:13<00:00,  3.43it/s]


Validation Accuracy on photo: 37.22%


Evaluating: 100%|██████████| 47/47 [00:13<00:00,  3.54it/s]


Validation Accuracy on sketch: 37.56%
Epoch 2/100


Training: 100%|██████████| 219/219 [01:28<00:00,  2.47it/s]


Epoch 2, Loss: 1.5592, Recon: 2.1286, Clf: 1.6680, KL: 0.1189
Weights - Alpha: 0.1000, Beta: 0.8000, Gamma: 0.1000


Evaluating: 100%|██████████| 47/47 [00:15<00:00,  3.12it/s]


Validation Accuracy on art_painting: 51.57%


Evaluating: 100%|██████████| 47/47 [00:13<00:00,  3.46it/s]


Validation Accuracy on cartoon: 54.70%


Evaluating: 100%|██████████| 47/47 [00:13<00:00,  3.48it/s]


Validation Accuracy on photo: 53.17%


Evaluating: 100%|██████████| 47/47 [00:13<00:00,  3.51it/s]


Validation Accuracy on sketch: 53.77%
Epoch 3/100


Training: 100%|██████████| 219/219 [01:28<00:00,  2.47it/s]


Epoch 3, Loss: 1.4371, Recon: 1.9769, Clf: 1.5209, KL: 0.2272
Weights - Alpha: 0.1000, Beta: 0.8000, Gamma: 0.1000


Evaluating: 100%|██████████| 47/47 [00:14<00:00,  3.34it/s]


Validation Accuracy on art_painting: 57.77%


Evaluating: 100%|██████████| 47/47 [00:13<00:00,  3.36it/s]


Validation Accuracy on cartoon: 59.64%


Evaluating: 100%|██████████| 47/47 [00:13<00:00,  3.51it/s]


Validation Accuracy on photo: 59.51%


Evaluating: 100%|██████████| 47/47 [00:13<00:00,  3.47it/s]


Validation Accuracy on sketch: 58.24%
Epoch 4/100


Training: 100%|██████████| 219/219 [01:28<00:00,  2.49it/s]


Epoch 4, Loss: 1.3827, Recon: 1.8973, Clf: 1.4546, KL: 0.2930
Weights - Alpha: 0.1000, Beta: 0.8000, Gamma: 0.1000


Evaluating: 100%|██████████| 47/47 [00:14<00:00,  3.24it/s]


Validation Accuracy on art_painting: 60.24%


Evaluating: 100%|██████████| 47/47 [00:13<00:00,  3.51it/s]


Validation Accuracy on cartoon: 61.57%


Evaluating: 100%|██████████| 47/47 [00:14<00:00,  3.24it/s]


Validation Accuracy on photo: 61.71%


Evaluating: 100%|██████████| 47/47 [00:13<00:00,  3.36it/s]


Validation Accuracy on sketch: 62.04%
Epoch 5/100


Training:   3%|▎         | 6/219 [00:02<01:39,  2.15it/s]


KeyboardInterrupt: 