In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm
from PIL import Image
import os
import torch.nn.functional as F

In [12]:
import random
import numpy as np

def set_random_seeds(seed_value=42):
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)  # if you are using multi-GPU.
    np.random.seed(seed_value)  # Numpy module.
    random.seed(seed_value)  # Python random module.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_random_seeds()

In [13]:
class PACSDataset(Dataset):
    def __init__(self, root_dir, domain, transform=None):
        self.root_dir = root_dir
        self.domain = domain
        self.transform = transform
        self.images, self.labels = self._load_images_labels()

    def _load_images_labels(self):
        image_paths = []
        labels = []
        domain_dir = os.path.join(self.root_dir, self.domain)
        classes = sorted(
            [
                d
                for d in os.listdir(domain_dir)
                if os.path.isdir(os.path.join(domain_dir, d))
            ]
        )

        for label, class_name in enumerate(classes):
            class_dir = os.path.join(domain_dir, class_name)
            for image_name in os.listdir(class_dir):
                if image_name.endswith((".png", ".jpg", ".jpeg")):
                    image_paths.append(os.path.join(class_dir, image_name))
                    labels.append(label)

        return image_paths, labels

    def __len__(self):
        return len(self.images)  # Return the number of images

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


# Function to get DataLoader
def get_dataloader(root_dir, domain, batch_size=32):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    dataset = PACSDataset(root_dir, domain, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [14]:
# Define Encoder, Decoder, Classifier
class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        resnet = models.resnet18(pretrained=True)
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        self.fc_mu = nn.Linear(resnet.fc.in_features, latent_dim)
        self.fc_logvar = nn.Linear(resnet.fc.in_features, latent_dim)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar


class Decoder(nn.Module):
    def __init__(self, latent_dim, num_domains):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim, 512 * 7 * 7)

        # Domain embedding
        self.domain_embedding = nn.Embedding(num_domains, latent_dim)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, z, domain_label):
        # Incorporate domain information
        domain_embed = self.domain_embedding(domain_label)
        z = z + domain_embed  # Combine latent vector with domain embedding
        z = self.fc(z)
        z = z.view(-1, 512, 7, 7)
        return self.decoder(z)


class Classifier(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(latent_dim, num_classes)

    def forward(self, z):
        return self.fc(z)

In [15]:
# Reparameterization trick
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std


# VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    MSE = F.mse_loss(recon_x, x, reduction="sum")
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return MSE + KLD


def compute_loss(
    reconstructed_imgs_list,
    original_imgs,
    mu,
    logvar,
    predicted_labels,
    true_labels,
    clf_loss_fn,
    alpha=1.0,
    beta=1.0,
    gamma=1.0,
):
    # Reconstruction Loss
    recon_loss = sum(
        F.mse_loss(recon, original_imgs, reduction="sum")
        for recon in reconstructed_imgs_list
    )
    
    # KL Divergence Loss
    kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    # Classification Loss
    clf_loss = clf_loss_fn(predicted_labels, true_labels)
    
    # Total Loss with weights
    total_loss = alpha * recon_loss + beta * clf_loss + gamma * kld_loss
    return total_loss, recon_loss.item(), clf_loss.item(), kld_loss.item()

In [16]:
def train_model(
    encoder,
    decoders,
    classifier,
    source_domain,
    target_domains,
    dataloader,
    optimizer,
    num_epochs=10,
    device="cuda",
):
    clf_loss_fn = nn.CrossEntropyLoss()
    domain_to_idx = {domain: idx for idx, domain in enumerate(target_domains)}

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        encoder.train()
        classifier.train()
        for decoder in decoders.values():
            decoder.train()

        running_loss = 0.0
        for inputs, labels in tqdm(dataloader, desc=f"Training on {source_domain}"):
            inputs, labels = inputs.to(device), labels.to(device)

            mu, logvar = encoder(inputs)
            z = reparameterize(mu, logvar)

            reconstructed_imgs_list = []
            for domain in target_domains:
                domain_label = torch.tensor([domain_to_idx[domain]] * inputs.size(0), device=device)
                reconstructed_imgs = decoders[domain](z, domain_label)
                reconstructed_imgs_list.append(reconstructed_imgs)

            predicted_labels = classifier(z)

            loss, recon_loss, clf_loss, kld_loss = compute_loss(
                reconstructed_imgs_list,
                inputs,
                mu,
                logvar,
                predicted_labels,
                labels,
                clf_loss_fn,
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

In [17]:
def evaluate_model(encoder, classifier, decoder, dataloader, device, domain_label):
    encoder.eval()
    classifier.eval()
    decoder.eval()
    total_clf_loss = 0.0
    total_recon_loss = 0.0
    correct = 0
    total = 0
    clf_loss_fn = nn.CrossEntropyLoss()
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)
            batch_size = inputs.size(0)
            mu, logvar = encoder(inputs)
            z = reparameterize(mu, logvar)
            outputs = classifier(z)
            # Chuyển domain_label thành tensor và lặp lại cho mỗi mẫu trong batch
            domain_labels = torch.full((batch_size,), domain_label, device=device)
            reconstructed_imgs = decoder(z, domain_labels)
            
            # Classification accuracy
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            # Losses
            clf_loss = clf_loss_fn(outputs, labels)
            recon_loss = F.mse_loss(reconstructed_imgs, inputs, reduction="sum")
            total_clf_loss += clf_loss.item()
            total_recon_loss += recon_loss.item()

    accuracy = correct / total
    avg_clf_loss = total_clf_loss / len(dataloader.dataset)
    avg_recon_loss = total_recon_loss / len(dataloader.dataset)
    return accuracy, avg_clf_loss, avg_recon_loss

In [20]:
def evaluate_on_all_domains(encoder, classifier, decoders, domains, data_path, device):
    print("\nFinal Evaluation on All Domains\n")
    for domain in domains:
        print(f'\nEvaluting on domain {domain}')
        eval_dataloader = get_dataloader(data_path, domain)
        domain_label = domains.index(domain)
        accuracy, avg_clf_loss, avg_recon_loss = evaluate_model(
            encoder,
            classifier,
            decoders[domain],
            eval_dataloader,
            device,
            domain_label,
        )
        print(f"Domain: {domain}")
        print(f"  Accuracy: {accuracy * 100:.2f}%")
        print(f"  Avg Classification Loss: {avg_clf_loss:.4f}")
        print(f"  Avg Reconstruction Loss: {avg_recon_loss:.4f}\n")

In [19]:
# Main training and evaluation script
DATA_PATH = "/kaggle/input/pacs-dataset/kfold"  # Update this path to your dataset location
latent_dim = 256
num_classes = 7  # Update this according to your PACS dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Domains in PACS dataset
domains = ["art_painting", "cartoon", "photo", "sketch"]

# Initialize models outside the loop
encoder = Encoder(latent_dim).to(device)
decoders = {domain: Decoder(latent_dim, len(domains)).to(device) for domain in domains}
classifier = Classifier(latent_dim, num_classes).to(device)

# Optimizer
params = list(encoder.parameters()) + list(classifier.parameters())
for decoder in decoders.values():
    params += list(decoder.parameters())
optimizer = optim.Adam(params, lr=1e-4)

for source_domain in domains:
    print(f"\nTraining on source domain: {source_domain}\n")

    target_domains = [d for d in domains if d != source_domain]

    # Dataloader for source domain
    dataloader = get_dataloader(DATA_PATH, source_domain)

    # Train model
    train_model(
        encoder,
        decoders,
        classifier,
        source_domain,
        target_domains,
        dataloader,
        optimizer,
        num_epochs=10,
        device=device,
    )

    # Evaluate model
    for eval_domain in target_domains:
        eval_dataloader = get_dataloader(DATA_PATH, eval_domain)
        domain_label = domains.index(eval_domain)
        accuracy, avg_clf_loss, avg_recon_loss = evaluate_model(
            encoder,
            classifier,
            decoders[eval_domain],
            eval_dataloader,
            device,
            domain_label,
        )
        print(f"Accuracy on target domain '{eval_domain}': {accuracy * 100:.2f}%")
        print(f"Avg Classification Loss: {avg_clf_loss:.4f}")
        print(f"Avg Reconstruction Loss: {avg_recon_loss:.4f}")

# Final evaluation on all domains
evaluate_on_all_domains(encoder, classifier, decoders, domains, DATA_PATH, device)

Using device: cuda

Training on source domain: art_painting

Epoch 1/10


Training on art_painting: 100%|██████████| 64/64 [00:28<00:00,  2.21it/s]


Epoch 1, Loss: 23728388.8438
Epoch 2/10


Training on art_painting: 100%|██████████| 64/64 [00:28<00:00,  2.27it/s]


Epoch 2, Loss: 21367374.5938
Epoch 3/10


Training on art_painting: 100%|██████████| 64/64 [00:28<00:00,  2.25it/s]


Epoch 3, Loss: 19837006.9844
Epoch 4/10


Training on art_painting: 100%|██████████| 64/64 [00:28<00:00,  2.25it/s]


Epoch 4, Loss: 19033064.9844
Epoch 5/10


Training on art_painting: 100%|██████████| 64/64 [00:27<00:00,  2.29it/s]


Epoch 5, Loss: 18444755.2969
Epoch 6/10


Training on art_painting: 100%|██████████| 64/64 [00:28<00:00,  2.29it/s]


Epoch 6, Loss: 17795944.3906
Epoch 7/10


Training on art_painting: 100%|██████████| 64/64 [00:27<00:00,  2.29it/s]


Epoch 7, Loss: 17513212.9375
Epoch 8/10


Training on art_painting: 100%|██████████| 64/64 [00:28<00:00,  2.22it/s]


Epoch 8, Loss: 17236755.0625
Epoch 9/10


Training on art_painting: 100%|██████████| 64/64 [00:27<00:00,  2.29it/s]


Epoch 9, Loss: 17038124.6250
Epoch 10/10


Training on art_painting: 100%|██████████| 64/64 [00:27<00:00,  2.30it/s]


Epoch 10, Loss: 16966236.1250


Evaluating: 100%|██████████| 74/74 [00:23<00:00,  3.13it/s]


Accuracy on target domain 'cartoon': 20.09%
Avg Classification Loss: 0.0611
Avg Reconstruction Loss: 250982.6343


Evaluating: 100%|██████████| 53/53 [00:19<00:00,  2.68it/s]


Accuracy on target domain 'photo': 31.62%
Avg Classification Loss: 0.0585
Avg Reconstruction Loss: 184420.6982


Evaluating: 100%|██████████| 123/123 [00:35<00:00,  3.51it/s]


Accuracy on target domain 'sketch': 19.22%
Avg Classification Loss: 0.0608
Avg Reconstruction Loss: 274486.5206

Training on source domain: cartoon

Epoch 1/10


Training on cartoon: 100%|██████████| 74/74 [00:30<00:00,  2.43it/s]


Epoch 1, Loss: 27017963.3243
Epoch 2/10


Training on cartoon: 100%|██████████| 74/74 [00:29<00:00,  2.53it/s]


Epoch 2, Loss: 24826419.6351
Epoch 3/10


Training on cartoon: 100%|██████████| 74/74 [00:29<00:00,  2.48it/s]


Epoch 3, Loss: 23949692.7500
Epoch 4/10


Training on cartoon: 100%|██████████| 74/74 [00:29<00:00,  2.50it/s]


Epoch 4, Loss: 23280479.7500
Epoch 5/10


Training on cartoon: 100%|██████████| 74/74 [00:29<00:00,  2.51it/s]


Epoch 5, Loss: 22916207.2635
Epoch 6/10


Training on cartoon: 100%|██████████| 74/74 [00:29<00:00,  2.51it/s]


Epoch 6, Loss: 22655107.7162
Epoch 7/10


Training on cartoon: 100%|██████████| 74/74 [00:29<00:00,  2.50it/s]


Epoch 7, Loss: 22527150.2905
Epoch 8/10


Training on cartoon: 100%|██████████| 74/74 [00:29<00:00,  2.55it/s]


Epoch 8, Loss: 21974861.1351
Epoch 9/10


Training on cartoon: 100%|██████████| 74/74 [00:29<00:00,  2.49it/s]


Epoch 9, Loss: 22252388.3716
Epoch 10/10


Training on cartoon: 100%|██████████| 74/74 [00:29<00:00,  2.50it/s]


Epoch 10, Loss: 21927625.9054


Evaluating: 100%|██████████| 64/64 [00:21<00:00,  2.94it/s]


Accuracy on target domain 'art_painting': 27.64%
Avg Classification Loss: 0.0607
Avg Reconstruction Loss: 169889.6614


Evaluating: 100%|██████████| 53/53 [00:17<00:00,  3.05it/s]


Accuracy on target domain 'photo': 30.84%
Avg Classification Loss: 0.0638
Avg Reconstruction Loss: 201114.0524


Evaluating: 100%|██████████| 123/123 [00:29<00:00,  4.12it/s]


Accuracy on target domain 'sketch': 21.33%
Avg Classification Loss: 0.0594
Avg Reconstruction Loss: 281876.7296

Training on source domain: photo

Epoch 1/10


Training on photo: 100%|██████████| 53/53 [00:23<00:00,  2.27it/s]


Epoch 1, Loss: 17302811.0425
Epoch 2/10


Training on photo: 100%|██████████| 53/53 [00:23<00:00,  2.23it/s]


Epoch 2, Loss: 16289796.3443
Epoch 3/10


Training on photo: 100%|██████████| 53/53 [00:22<00:00,  2.31it/s]


Epoch 3, Loss: 16085445.5142
Epoch 4/10


Training on photo: 100%|██████████| 53/53 [00:23<00:00,  2.28it/s]


Epoch 4, Loss: 16026506.6887
Epoch 5/10


Training on photo: 100%|██████████| 53/53 [00:23<00:00,  2.28it/s]


Epoch 5, Loss: 15820948.0660
Epoch 6/10


Training on photo: 100%|██████████| 53/53 [00:24<00:00,  2.19it/s]


Epoch 6, Loss: 15869980.8538
Epoch 7/10


Training on photo: 100%|██████████| 53/53 [00:23<00:00,  2.30it/s]


Epoch 7, Loss: 15571356.0142
Epoch 8/10


Training on photo: 100%|██████████| 53/53 [00:22<00:00,  2.31it/s]


Epoch 8, Loss: 15664235.4009
Epoch 9/10


Training on photo: 100%|██████████| 53/53 [00:23<00:00,  2.29it/s]


Epoch 9, Loss: 15735767.0660
Epoch 10/10


Training on photo: 100%|██████████| 53/53 [00:22<00:00,  2.37it/s]


Epoch 10, Loss: 15876578.2830


Evaluating: 100%|██████████| 64/64 [00:21<00:00,  3.04it/s]


Accuracy on target domain 'art_painting': 21.04%
Avg Classification Loss: 0.0619
Avg Reconstruction Loss: 167008.2710


Evaluating: 100%|██████████| 74/74 [00:21<00:00,  3.39it/s]


Accuracy on target domain 'cartoon': 25.17%
Avg Classification Loss: 0.0608
Avg Reconstruction Loss: 230779.5275


Evaluating: 100%|██████████| 123/123 [00:29<00:00,  4.12it/s]


Accuracy on target domain 'sketch': 14.35%
Avg Classification Loss: 0.0671
Avg Reconstruction Loss: 268746.9290

Training on source domain: sketch

Epoch 1/10


Training on sketch: 100%|██████████| 123/123 [00:43<00:00,  2.84it/s]


Epoch 1, Loss: 28974552.4878
Epoch 2/10


Training on sketch: 100%|██████████| 123/123 [00:43<00:00,  2.82it/s]


Epoch 2, Loss: 26806441.5772
Epoch 3/10


Training on sketch: 100%|██████████| 123/123 [00:43<00:00,  2.85it/s]


Epoch 3, Loss: 26073576.7805
Epoch 4/10


Training on sketch: 100%|██████████| 123/123 [00:42<00:00,  2.87it/s]


Epoch 4, Loss: 26117777.0894
Epoch 5/10


Training on sketch: 100%|██████████| 123/123 [00:43<00:00,  2.86it/s]


Epoch 5, Loss: 25471309.1545
Epoch 6/10


Training on sketch: 100%|██████████| 123/123 [00:43<00:00,  2.82it/s]


Epoch 6, Loss: 25485403.3171
Epoch 7/10


Training on sketch: 100%|██████████| 123/123 [00:42<00:00,  2.88it/s]


Epoch 7, Loss: 25741038.9919
Epoch 8/10


Training on sketch: 100%|██████████| 123/123 [00:43<00:00,  2.80it/s]


Epoch 8, Loss: 25534879.5285
Epoch 9/10


Training on sketch: 100%|██████████| 123/123 [00:43<00:00,  2.86it/s]


Epoch 9, Loss: 25268298.9919
Epoch 10/10


Training on sketch: 100%|██████████| 123/123 [00:43<00:00,  2.86it/s]


Epoch 10, Loss: 25224703.5610


Evaluating: 100%|██████████| 64/64 [00:21<00:00,  3.02it/s]


Accuracy on target domain 'art_painting': 10.40%
Avg Classification Loss: 0.2640
Avg Reconstruction Loss: 230264.0049


Evaluating: 100%|██████████| 74/74 [00:21<00:00,  3.46it/s]


Accuracy on target domain 'cartoon': 20.52%
Avg Classification Loss: 0.1155
Avg Reconstruction Loss: 256191.2451


Evaluating: 100%|██████████| 53/53 [00:17<00:00,  3.06it/s]


Accuracy on target domain 'photo': 12.46%
Avg Classification Loss: 0.2692
Avg Reconstruction Loss: 302364.4510

Final Evaluation on All Domains



Evaluating: 100%|██████████| 64/64 [00:21<00:00,  3.02it/s]


Domain: art_painting
  Accuracy: 10.30%
  Avg Classification Loss: 0.2660
  Avg Reconstruction Loss: 231631.0410



Evaluating: 100%|██████████| 123/123 [00:29<00:00,  4.11it/s]


Domain: sketch
  Accuracy: 42.10%
  Avg Classification Loss: 0.0456
  Avg Reconstruction Loss: 360168.9527



In [21]:
# Final evaluation on all domains
evaluate_on_all_domains(encoder, classifier, decoders, domains, DATA_PATH, device)


Final Evaluation on All Domains


Evaluting on domain art_painting


Evaluating: 100%|██████████| 64/64 [00:22<00:00,  2.90it/s]


Domain: art_painting
  Accuracy: 10.55%
  Avg Classification Loss: 0.2622
  Avg Reconstruction Loss: 230154.9268


Evaluting on domain cartoon


Evaluating: 100%|██████████| 74/74 [00:21<00:00,  3.49it/s]


Domain: cartoon
  Accuracy: 20.86%
  Avg Classification Loss: 0.1150
  Avg Reconstruction Loss: 252531.8223


Evaluting on domain photo


Evaluating: 100%|██████████| 53/53 [00:17<00:00,  3.11it/s]


Domain: photo
  Accuracy: 12.34%
  Avg Classification Loss: 0.2696
  Avg Reconstruction Loss: 300694.8666


Evaluting on domain sketch


Evaluating: 100%|██████████| 123/123 [00:30<00:00,  4.07it/s]

Domain: sketch
  Accuracy: 42.86%
  Avg Classification Loss: 0.0457
  Avg Reconstruction Loss: 362928.5531




