In [171]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm
from PIL import Image
import os
import torch.nn.functional as F

In [172]:
import random
import numpy as np

def set_random_seeds(seed_value=42):
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)  # if you are using multi-GPU.
    np.random.seed(seed_value)  # Numpy module.
    random.seed(seed_value)  # Python random module.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_random_seeds()

In [173]:
class PACSDataset(Dataset):
    def __init__(self, root_dir, domain, transform=None):
        self.root_dir = root_dir
        self.domain = domain
        self.transform = transform
        self.images, self.labels = self._load_images_labels()

    def _load_images_labels(self):
        image_paths = []
        labels = []
        domain_dir = os.path.join(self.root_dir, self.domain)
        classes = sorted(
            [
                d
                for d in os.listdir(domain_dir)
                if os.path.isdir(os.path.join(domain_dir, d))
            ]
        )

        for label, class_name in enumerate(classes):
            class_dir = os.path.join(domain_dir, class_name)
            for image_name in os.listdir(class_dir):
                if image_name.endswith((".png", ".jpg", ".jpeg")):
                    image_paths.append(os.path.join(class_dir, image_name))
                    labels.append(label)

        return image_paths, labels

    def __len__(self):
        return len(self.images)  # Return the number of images

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


def get_transform():
    return transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )


def get_combined_dataloader(root_dir, domains, batch_size=32):
    datasets = [
        PACSDataset(root_dir, domain, transform=get_transform()) for domain in domains
    ]
    combined_dataset = torch.utils.data.ConcatDataset(datasets)
    return DataLoader(combined_dataset, batch_size=batch_size, shuffle=True)


def get_dataloader(root_dir, domain, batch_size=32):
    dataset = PACSDataset(root_dir, domain, transform=get_transform())
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [174]:
# Define Encoder, Decoder, Classifier
class Encoder(nn.Module):
    def __init__(self, latent_dim, model_name='efficientnet_b0'):
        super(Encoder, self).__init__()

        # Load pretrained EfficientNet
        self.efficientnet = models.efficientnet_b0(pretrained=True)

        # Freeze EfficientNet layers if you do not want to fine-tune
        for param in self.efficientnet.parameters():
            param.requires_grad = False

        # Replace the final classifier with linear layers for mu and logvar
        in_features = self.efficientnet.classifier[1].in_features  # Get the in_features from the classifier

        # Mean (mu) and log-variance (logvar) layers
        self.fc_mu = nn.Linear(in_features, latent_dim)
        self.fc_logvar = nn.Linear(in_features, latent_dim)

    def forward(self, x):
        # Pass input through EfficientNet feature extractor
        x = self.efficientnet.features(x)  # Use EfficientNet features
        x = nn.AdaptiveAvgPool2d(1)(x)  # Adapt to a fixed output size (1x1)
        x = torch.flatten(x, 1)  # Flatten the output from (batch, channels, 1, 1) to (batch, channels)

        # Compute mu and logvar
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar


class Decoder(nn.Module):
    def __init__(self, latent_dim, num_domains, model_name="efficientnet_b0"):
        super(Decoder, self).__init__()

        # Pretrained EfficientNet model
        self.efficientnet = models.efficientnet_b0(pretrained=True)

        # Freeze EfficientNet layers (if you want to fine-tune only the final part)
        for param in self.efficientnet.parameters():
            param.requires_grad = False

        # Domain embedding
        self.domain_embedding = nn.Embedding(num_domains, latent_dim)

        # Replace the first convolutional layer to accept latent_dim channels
        self.efficientnet.features[0][0] = nn.Conv2d(
            latent_dim, 32, kernel_size=3, stride=2, padding=1, bias=False
        )

        # Adjust the classifier to output the correct shape
        num_ftrs = self.efficientnet.classifier[1].in_features
        self.efficientnet.classifier = nn.Sequential(
            nn.Linear(num_ftrs, 512 * 7 * 7), nn.ReLU()
        )

        # Decoder part (deconvolutional layers to generate images)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, z, domain_label):
        # Incorporate domain information
        domain_embed = self.domain_embedding(domain_label)
        z = z + domain_embed  # Combine latent vector with domain embedding

        # Reshape z to match EfficientNet input shape
        z = z.view(-1, z.size(1), 1, 1)
        z = z.expand(-1, -1, 7, 7)  # Expand to spatial dimensions

        # Pass through EfficientNet
        x = self.efficientnet.features(z)
        x = self.efficientnet.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.efficientnet.classifier(x)

        # Reshape and pass through the decoder
        x = x.view(-1, 512, 7, 7)
        return self.decoder(x)


class Classifier(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(latent_dim, num_classes)

    def forward(self, z):
        return self.fc(z)

In [175]:
# Reparameterization trick
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std


# VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    MSE = F.mse_loss(recon_x, x, reduction="sum")
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return MSE + KLD


def compute_loss(
    reconstructed_imgs_list,
    original_imgs,
    mu,
    logvar,
    predicted_labels,
    true_labels,
    clf_loss_fn,
    alpha=1.0,
    beta=1.0,
    gamma=1.0,
):
    # Reconstruction Loss (MSE loss without normalization)
    recon_loss = sum(
        F.mse_loss(recon, original_imgs, reduction="sum")
        for recon in reconstructed_imgs_list
    )

    # KL Divergence Loss
    kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Classification Loss
    clf_loss = clf_loss_fn(predicted_labels, true_labels)

    # Total Loss with weights
    total_loss = alpha * recon_loss + beta * clf_loss + gamma * kld_loss
    return total_loss, recon_loss.item(), clf_loss.item(), kld_loss.item()

In [176]:
def train_model_progressive(
    encoder,
    decoders,
    classifier,
    domains,
    dataloader,
    optimizer,
    num_epochs=10,
    device="cuda",
):
    clf_loss_fn = nn.CrossEntropyLoss()
    domain_to_idx = {domain: idx for idx, domain in enumerate(domains)}

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        encoder.train()
        classifier.train()
        for decoder in decoders.values():
            decoder.train()

        running_loss = 0.0
        for inputs, labels in tqdm(dataloader, desc="Training"):
            inputs, labels = inputs.to(device), labels.to(device)

            mu, logvar = encoder(inputs)
            z = reparameterize(mu, logvar)

            reconstructed_imgs_list = []
            for domain in domains:
                domain_label = torch.tensor(
                    [domain_to_idx[domain]] * inputs.size(0), device=device
                )
                reconstructed_imgs = decoders[domain](z, domain_label)
                reconstructed_imgs_list.append(reconstructed_imgs)

            predicted_labels = classifier(z)

            loss, recon_loss, clf_loss, kld_loss = compute_loss(
                reconstructed_imgs_list,
                inputs,
                mu,
                logvar,
                predicted_labels,
                labels,
                clf_loss_fn,
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

In [178]:
def evaluate_model(encoder, classifier, decoder, dataloader, device, domain_label):
    encoder.eval()
    classifier.eval()
    decoder.eval()
    total_clf_loss = 0.0
    total_recon_loss = 0.0
    correct = 0
    total = 0
    clf_loss_fn = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)
            batch_size = inputs.size(0)
            mu, logvar = encoder(inputs)
            z = reparameterize(mu, logvar)
            outputs = classifier(z)
            
            # Chuyển domain_label thành tensor và lặp lại cho mỗi mẫu trong batch
            domain_labels = torch.full((batch_size,), domain_label, device=device)
            reconstructed_imgs = decoder(z, domain_labels)

            # Classification accuracy
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Losses
            clf_loss = clf_loss_fn(outputs, labels)
            recon_loss = F.mse_loss(reconstructed_imgs, inputs, reduction="sum")
            total_clf_loss += clf_loss.item()
            total_recon_loss += recon_loss.item()

    accuracy = correct / total
    avg_clf_loss = total_clf_loss / len(dataloader.dataset)
    avg_recon_loss = total_recon_loss / len(dataloader.dataset)
    return accuracy, avg_clf_loss, avg_recon_loss

In [179]:
def evaluate_on_all_domains(encoder, classifier, decoders, domains, data_path, device):
    print("\nFinal Evaluation on All Domains\n")
    for domain in domains:
        eval_dataloader = get_dataloader(data_path, domain)
        domain_label = domains.index(domain)
        accuracy, avg_clf_loss, avg_recon_loss = evaluate_model(
            encoder,
            classifier,
            decoders[domain],
            eval_dataloader,
            device,
            domain_label,
        )
        print(f"Domain: {domain}")
        print(f"  Accuracy: {accuracy * 100:.2f}%")
        print(f"  Avg Classification Loss: {avg_clf_loss:.4f}")
        print(f"  Avg Reconstruction Loss: {avg_recon_loss:.4f}\n")

In [180]:
# Main training and evaluation script
DATA_PATH = (
    "/kaggle/input/pacs-dataset/kfold"  # Update this path to your dataset location
)
latent_dim = 256
num_classes = 7  # Update this according to your PACS dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Domains in PACS dataset
domains = ["art_painting", "cartoon", "photo", "sketch"]

# Initialize models outside the loop
encoder = Encoder(latent_dim).to(device)
decoders = {domain: Decoder(latent_dim, len(domains)).to(device) for domain in domains}
classifier = Classifier(latent_dim, num_classes).to(device)

# Optimizer
params = list(encoder.parameters()) + list(classifier.parameters())
for decoder in decoders.values():
    params += list(decoder.parameters())
optimizer = optim.Adam(params, lr=1e-4)

# Create a combined DataLoader for all domains
combined_dataloader = get_combined_dataloader(DATA_PATH, domains)

# Train model using progressive domain training
train_model_progressive(
    encoder,
    decoders,
    classifier,
    domains,
    combined_dataloader,
    optimizer,
    num_epochs=10,
    device=device,
)

# Final evaluation on all domains
evaluate_on_all_domains(encoder, classifier, decoders, domains, DATA_PATH, device)

Using device: cuda
Epoch 1/10


Training: 100%|██████████| 313/313 [02:41<00:00,  1.94it/s]


Epoch 1, Loss: 46980892.0543
Epoch 2/10


Training: 100%|██████████| 313/313 [02:40<00:00,  1.96it/s]


Epoch 2, Loss: 38304866.3195
Epoch 3/10


Training: 100%|██████████| 313/313 [02:39<00:00,  1.96it/s]


Epoch 3, Loss: 35017076.2109
Epoch 4/10


Training: 100%|██████████| 313/313 [02:36<00:00,  1.99it/s]


Epoch 4, Loss: 33432980.3866
Epoch 5/10


Training: 100%|██████████| 313/313 [02:45<00:00,  1.89it/s]


Epoch 5, Loss: 32169800.0000
Epoch 6/10


Training: 100%|██████████| 313/313 [02:38<00:00,  1.98it/s]


Epoch 6, Loss: 31087122.3003
Epoch 7/10


Training: 100%|██████████| 313/313 [03:13<00:00,  1.62it/s]


Epoch 7, Loss: 30591644.7764
Epoch 8/10


Training: 100%|██████████| 313/313 [02:47<00:00,  1.87it/s]


Epoch 8, Loss: 30288445.4393
Epoch 9/10


Training: 100%|██████████| 313/313 [02:42<00:00,  1.93it/s]


Epoch 9, Loss: 29747116.5272
Epoch 10/10


Training: 100%|██████████| 313/313 [02:44<00:00,  1.90it/s]


Epoch 10, Loss: 29391209.7668

Final Evaluation on All Domains



Evaluating: 100%|██████████| 64/64 [00:21<00:00,  2.95it/s]


Domain: art_painting
  Accuracy: 26.27%
  Avg Classification Loss: 0.0583
  Avg Reconstruction Loss: 245778.3926



Evaluating: 100%|██████████| 74/74 [00:25<00:00,  2.92it/s]


Domain: cartoon
  Accuracy: 28.54%
  Avg Classification Loss: 0.0556
  Avg Reconstruction Loss: 276124.8969



Evaluating: 100%|██████████| 53/53 [00:17<00:00,  2.97it/s]


Domain: photo
  Accuracy: 43.47%
  Avg Classification Loss: 0.0513
  Avg Reconstruction Loss: 272931.5403



Evaluating: 100%|██████████| 123/123 [00:31<00:00,  3.84it/s]

Domain: sketch
  Accuracy: 37.64%
  Avg Classification Loss: 0.0489
  Avg Reconstruction Loss: 168129.9263




