In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm
from PIL import Image
import os
import torch.nn.functional as F

In [24]:
class PACSDataset(Dataset):
    def __init__(self, root_dir, domain, transform=None):
        self.root_dir = root_dir
        self.domain = domain
        self.transform = transform
        self.images, self.labels = self._load_images_labels()

    def _load_images_labels(self):
        image_paths = []
        labels = []
        domain_dir = os.path.join(self.root_dir, self.domain)
        classes = sorted(
            [
                d
                for d in os.listdir(domain_dir)
                if os.path.isdir(os.path.join(domain_dir, d))
            ]
        )

        for label, class_name in enumerate(classes):
            class_dir = os.path.join(domain_dir, class_name)
            for image_name in os.listdir(class_dir):
                if image_name.endswith((".png", ".jpg", ".jpeg")):
                    image_paths.append(os.path.join(class_dir, image_name))
                    labels.append(label)

        return image_paths, labels

    def __len__(self):
        return len(self.images)  # Return the number of images

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = Image.open(image_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


# Function to get DataLoader
def get_dataloader(root_dir, domain, batch_size=32):
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    dataset = PACSDataset(root_dir, domain, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [25]:
# Define Encoder, Decoder, Classifier
class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        resnet = models.resnet18(pretrained=True)
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        self.fc_mu = nn.Linear(resnet.fc.in_features, latent_dim)
        self.fc_logvar = nn.Linear(resnet.fc.in_features, latent_dim)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar


class Decoder(nn.Module):
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim, 512 * 7 * 7)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        z = self.fc(z)
        z = z.view(-1, 512, 7, 7)
        return self.decoder(z)


class Classifier(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(latent_dim, num_classes)

    def forward(self, z):
        return self.fc(z)

In [26]:
# Reparameterization trick
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std


# VAE loss function
def vae_loss(recon_x, x, mu, logvar):
    MSE = F.mse_loss(recon_x, x, reduction="sum")
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return MSE + KLD


def compute_loss(
    reconstructed_imgs_list,
    original_imgs,
    mu,
    logvar,
    predicted_labels,
    true_labels,
    clf_loss_fn,
):
    vae_loss_total = 0.0
    for reconstructed_imgs in reconstructed_imgs_list:
        vae_loss_val = vae_loss(reconstructed_imgs, original_imgs, mu, logvar)
        vae_loss_total += vae_loss_val
    clf_loss = clf_loss_fn(predicted_labels, true_labels)
    total_loss = vae_loss_total + clf_loss
    return total_loss

In [27]:
def train_model(
    encoder,
    decoders,
    classifier,
    source_domain,
    target_domains,
    dataloader,
    optimizer,
    num_epochs=10,
    device="cuda",
):
    clf_loss_fn = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        encoder.train()
        classifier.train()
        for decoder in decoders.values():
            decoder.train()

        running_loss = 0.0
        for inputs, labels in tqdm(dataloader, desc=f"Training on {source_domain}"):
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass through encoder to get mu and logvar
            mu, logvar = encoder(inputs)

            # Reparameterization trick
            z = reparameterize(mu, logvar)

            # Forward through decoders for each target domain
            reconstructed_imgs_list = []
            for domain in target_domains:
                reconstructed_imgs = decoders[domain](z)
                reconstructed_imgs_list.append(reconstructed_imgs)

            # Forward through classifier to predict class
            predicted_labels = classifier(z)

            # Compute loss
            loss = compute_loss(
                reconstructed_imgs_list,
                inputs,
                mu,
                logvar,
                predicted_labels,
                labels,
                clf_loss_fn,
            )

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

In [28]:
def evaluate_model(encoder, classifier, dataloader, device):
    encoder.eval()
    classifier.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)
            mu, logvar = encoder(inputs)
            z = reparameterize(mu, logvar)
            outputs = classifier(z)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

In [29]:
# Main training and evaluation script
DATA_PATH = "/kaggle/input/pacs-dataset/kfold"  # Update this path to your dataset location
latent_dim = 256
num_classes = 7  # Update this according to your PACS dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Domains in PACS dataset
domains = ["art_painting", "cartoon", "photo", "sketch"]

# Perform leave-one-out cross-validation
for leave_out_domain in domains:
    print(f"\nLeaving out domain: {leave_out_domain}\n")

    source_domain = leave_out_domain
    target_domains = [d for d in domains if d != source_domain]

    # Initialize models
    encoder = Encoder(latent_dim).to(device)
    decoders = {domain: Decoder(latent_dim).to(device) for domain in target_domains}
    classifier = Classifier(latent_dim, num_classes).to(device)

    # Optimizer
    optimizer = optim.Adam(
        list(encoder.parameters())
        + list(classifier.parameters())
        + [param for decoder in decoders.values() for param in decoder.parameters()],
        lr=1e-4,
    )

    # Dataloader for source domain
    dataloader = get_dataloader(DATA_PATH, source_domain)

    # Train model
    train_model(
        encoder,
        decoders,
        classifier,
        source_domain,
        target_domains,
        dataloader,
        optimizer,
        num_epochs=10,
        device=device,
    )

    # Evaluate on target domains
    for eval_domain in target_domains:
        eval_dataloader = get_dataloader(DATA_PATH, eval_domain)
        accuracy = evaluate_model(encoder, classifier, eval_dataloader, device)
        print(f"Accuracy on target domain '{eval_domain}': {accuracy * 100:.2f}%")

# Optionally, you can save the models after training
# torch.save(encoder.state_dict(), 'encoder.pth')
# torch.save(classifier.state_dict(), 'classifier.pth')
# for domain, decoder in decoders.items():
#     torch.save(decoder.state_dict(), f'decoder_{domain}.pth')

Using device: cuda

Leaving out domain: art_painting

Epoch 1/10


Training on art_painting: 100%|██████████| 64/64 [00:15<00:00,  4.17it/s]


Epoch 1, Loss: 21511798.4062
Epoch 2/10


Training on art_painting: 100%|██████████| 64/64 [00:14<00:00,  4.42it/s]


Epoch 2, Loss: 18961116.8750
Epoch 3/10


Training on art_painting: 100%|██████████| 64/64 [00:14<00:00,  4.35it/s]


Epoch 3, Loss: 17580366.0156
Epoch 4/10


Training on art_painting: 100%|██████████| 64/64 [00:14<00:00,  4.42it/s]


Epoch 4, Loss: 16753058.9219
Epoch 5/10


Training on art_painting: 100%|██████████| 64/64 [00:14<00:00,  4.44it/s]


Epoch 5, Loss: 16165075.5625
Epoch 6/10


Training on art_painting: 100%|██████████| 64/64 [00:14<00:00,  4.50it/s]


Epoch 6, Loss: 15725671.0156
Epoch 7/10


Training on art_painting: 100%|██████████| 64/64 [00:14<00:00,  4.39it/s]


Epoch 7, Loss: 15431988.7656
Epoch 8/10


Training on art_painting: 100%|██████████| 64/64 [00:14<00:00,  4.38it/s]


Epoch 8, Loss: 15218165.7500
Epoch 9/10


Training on art_painting: 100%|██████████| 64/64 [00:14<00:00,  4.35it/s]


Epoch 9, Loss: 15055520.8906
Epoch 10/10


Training on art_painting: 100%|██████████| 64/64 [00:14<00:00,  4.28it/s]


Epoch 10, Loss: 14883226.8281


Evaluating: 100%|██████████| 74/74 [00:15<00:00,  4.89it/s]


Accuracy on target domain 'cartoon': 18.94%


Evaluating: 100%|██████████| 53/53 [00:07<00:00,  6.69it/s]


Accuracy on target domain 'photo': 23.35%


Evaluating: 100%|██████████| 123/123 [00:16<00:00,  7.26it/s]


Accuracy on target domain 'sketch': 19.06%

Leaving out domain: cartoon

Epoch 1/10


Training on cartoon: 100%|██████████| 74/74 [00:17<00:00,  4.30it/s]


Epoch 1, Loss: 35071793.3784
Epoch 2/10


Training on cartoon: 100%|██████████| 74/74 [00:17<00:00,  4.31it/s]


Epoch 2, Loss: 31609434.5405
Epoch 3/10


Training on cartoon: 100%|██████████| 74/74 [00:16<00:00,  4.36it/s]


Epoch 3, Loss: 29379053.1149
Epoch 4/10


Training on cartoon: 100%|██████████| 74/74 [00:16<00:00,  4.46it/s]


Epoch 4, Loss: 27866682.8986
Epoch 5/10


Training on cartoon: 100%|██████████| 74/74 [00:16<00:00,  4.41it/s]


Epoch 5, Loss: 26862060.9527
Epoch 6/10


Training on cartoon: 100%|██████████| 74/74 [00:16<00:00,  4.41it/s]


Epoch 6, Loss: 26127331.0811
Epoch 7/10


Training on cartoon: 100%|██████████| 74/74 [00:16<00:00,  4.40it/s]


Epoch 7, Loss: 25587596.8986
Epoch 8/10


Training on cartoon: 100%|██████████| 74/74 [00:16<00:00,  4.45it/s]


Epoch 8, Loss: 25190427.0405
Epoch 9/10


Training on cartoon: 100%|██████████| 74/74 [00:16<00:00,  4.40it/s]


Epoch 9, Loss: 24892841.5608
Epoch 10/10


Training on cartoon: 100%|██████████| 74/74 [00:16<00:00,  4.44it/s]


Epoch 10, Loss: 24655867.5135


Evaluating: 100%|██████████| 64/64 [00:08<00:00,  7.79it/s]


Accuracy on target domain 'art_painting': 22.17%


Evaluating: 100%|██████████| 53/53 [00:06<00:00,  8.27it/s]


Accuracy on target domain 'photo': 23.11%


Evaluating: 100%|██████████| 123/123 [00:16<00:00,  7.53it/s]


Accuracy on target domain 'sketch': 17.46%

Leaving out domain: photo

Epoch 1/10


Training on photo: 100%|██████████| 53/53 [00:11<00:00,  4.42it/s]


Epoch 1, Loss: 22369355.5094
Epoch 2/10


Training on photo: 100%|██████████| 53/53 [00:11<00:00,  4.45it/s]


Epoch 2, Loss: 20292857.7689
Epoch 3/10


Training on photo: 100%|██████████| 53/53 [00:11<00:00,  4.44it/s]


Epoch 3, Loss: 18679739.3396
Epoch 4/10


Training on photo: 100%|██████████| 53/53 [00:11<00:00,  4.55it/s]


Epoch 4, Loss: 17426323.5236
Epoch 5/10


Training on photo: 100%|██████████| 53/53 [00:11<00:00,  4.56it/s]


Epoch 5, Loss: 16583701.4481
Epoch 6/10


Training on photo: 100%|██████████| 53/53 [00:12<00:00,  4.36it/s]


Epoch 6, Loss: 15948028.8019
Epoch 7/10


Training on photo: 100%|██████████| 53/53 [00:11<00:00,  4.42it/s]


Epoch 7, Loss: 15508020.0991
Epoch 8/10


Training on photo: 100%|██████████| 53/53 [00:11<00:00,  4.46it/s]


Epoch 8, Loss: 15153117.2500
Epoch 9/10


Training on photo: 100%|██████████| 53/53 [00:11<00:00,  4.47it/s]


Epoch 9, Loss: 14873650.5401
Epoch 10/10


Training on photo: 100%|██████████| 53/53 [00:11<00:00,  4.53it/s]


Epoch 10, Loss: 14644519.4858


Evaluating: 100%|██████████| 64/64 [00:07<00:00,  8.08it/s]


Accuracy on target domain 'art_painting': 20.07%


Evaluating: 100%|██████████| 74/74 [00:09<00:00,  7.94it/s]


Accuracy on target domain 'cartoon': 15.78%


Evaluating: 100%|██████████| 123/123 [00:16<00:00,  7.47it/s]


Accuracy on target domain 'sketch': 12.42%

Leaving out domain: sketch

Epoch 1/10


Training on sketch: 100%|██████████| 123/123 [00:28<00:00,  4.28it/s]


Epoch 1, Loss: 46912727.8699
Epoch 2/10


Training on sketch: 100%|██████████| 123/123 [00:29<00:00,  4.19it/s]


Epoch 2, Loss: 40609708.6829
Epoch 3/10


Training on sketch: 100%|██████████| 123/123 [00:28<00:00,  4.25it/s]


Epoch 3, Loss: 36681193.5772
Epoch 4/10


Training on sketch: 100%|██████████| 123/123 [00:29<00:00,  4.15it/s]


Epoch 4, Loss: 34556304.5041
Epoch 5/10


Training on sketch: 100%|██████████| 123/123 [00:28<00:00,  4.25it/s]


Epoch 5, Loss: 33447207.1707
Epoch 6/10


Training on sketch: 100%|██████████| 123/123 [00:29<00:00,  4.13it/s]


Epoch 6, Loss: 32818574.1138
Epoch 7/10


Training on sketch: 100%|██████████| 123/123 [00:28<00:00,  4.27it/s]


Epoch 7, Loss: 32428526.1789
Epoch 8/10


Training on sketch: 100%|██████████| 123/123 [00:29<00:00,  4.18it/s]


Epoch 8, Loss: 32169752.6992
Epoch 9/10


Training on sketch: 100%|██████████| 123/123 [00:29<00:00,  4.19it/s]


Epoch 9, Loss: 31992522.7967
Epoch 10/10


Training on sketch: 100%|██████████| 123/123 [00:29<00:00,  4.13it/s]


Epoch 10, Loss: 31863278.5691


Evaluating: 100%|██████████| 64/64 [00:08<00:00,  7.67it/s]


Accuracy on target domain 'art_painting': 13.72%


Evaluating: 100%|██████████| 74/74 [00:09<00:00,  8.18it/s]


Accuracy on target domain 'cartoon': 19.11%


Evaluating: 100%|██████████| 53/53 [00:06<00:00,  7.94it/s]

Accuracy on target domain 'photo': 9.70%



