In [10]:
import os

import albumentations as A
import rich.progress as rprogress
import torch
import torchmetrics.functional
from rich.traceback import install

from types import SimpleNamespace

from src.callbacks import ModelCheckpoint
from src.datasets import MultiDomainDataset, DomainRole, PreprocessingPipeline
from src.models import MLDG, BaseLearner, Encoder, Classifier, compute_metrics, ERM
from rich import print


install(show_locals=False)

args = {
    "num_classes": 7,
    "batch_size": 8,
    "nonlinear_classifier": False,
    "dropout": 0.,
    "lr": 1e-5,
    "weight_decay": 0.,
    "beta": 1.0,
}
args = SimpleNamespace(**args)


# Backend: albumentations
transform = A.Compose([
    A.Resize(height=224, width=224),
    A.Normalize(
        mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
    ),
    A.ToTensorV2(),
])

augment_transform = A.Compose([
    A.RandomResizedCrop(size=(224, 224), scale=(0.7, 1.0), p=1.0),
    A.HorizontalFlip(p=0.5),
    A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3, p=0.5),
    A.ToGray(num_output_channels=3, p=0.10),
    A.Normalize(
        mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
    ),
    A.ToTensorV2(),
])


def wrapper_albumentations_transform(transform: A.Compose):
    def transform_(img):
        img_ = img.permute(1, 2, 0).numpy()
        try:
            return transform(image=img_).get("image")
        except Exception as error:
            print(f"Failed to apply transform at image {img_.shape}, {img_.dtype}")
            raise error

    return transform_

pipeline = PreprocessingPipeline(
    source_transform=wrapper_albumentations_transform(augment_transform),
    target_transform=wrapper_albumentations_transform(transform),
)

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device selected: {device}")

In [3]:
domains = ["art_painting", "cartoon", "photo", "sketch"]
roles = [DomainRole.TARGET, DomainRole.TARGET, DomainRole.TARGET, DomainRole.TARGET]


In [11]:
def load_model(checkpoint, args):

    if args.baseline:
        model = ERM(
        network=BaseLearner(
            encoder=Encoder(hparams={}),
            classifier=Classifier(hparams={"num_feats": 2048, "num_classes": args.num_classes, "dropout": args.dropout})
        ),
        device=device,
        num_classes=args.num_classes,
        num_domains=len([role for role in roles if role == DomainRole.TARGET]),
        hparams={
            "lr": args.lr,
            "weight_decay": args.weight_decay,
        }
    )

    else:
        model = MLDG(
        network=BaseLearner(
            encoder=Encoder(hparams={}),
            classifier=Classifier(hparams={"num_feats": 2048, "num_classes": args.num_classes, "dropout": args.dropout})
        ),
        device=device,
        num_classes=args.num_classes,
        num_domains=len([role for role in roles if role == DomainRole.TARGET]),
        hparams={
            "num_meta_test": 1,
            "lr": args.lr,
            "weight_decay": args.weight_decay,
            "beta": args.beta,
            "lr_clone": args.lr,  # TODO: for now, using same value for both
            "weight_decay_clone": args.weight_decay,
        }
    )

    # Test
    model.load_state_dict(
        state_dict=torch.load(
            checkpoint
        )
    )
    model.eval()
    return model


In [5]:
dataset = MultiDomainDataset(
    root="../data/PACS/",
    domains=domains,
    roles=roles,
    transforms=[pipeline] * len(domains),
    seed=0,
    split_ratio=0.80,
)


progress_bar = rprogress.Progress(
    rprogress.SpinnerColumn(),
    *rprogress.Progress.get_default_columns(),
    rprogress.TimeElapsedColumn(),
    rprogress.TextColumn("{task.fields[metrics]}", justify="right"),
)

loaders = dataset.build_target_dataloaders(
    use_splits=False, batch_size=args.batch_size, num_workers=8,
    shuffle=False, persistent_workers=True
)

In [19]:
RESULTS = {
    "CHECKPOINTS": [
        "../dump/dutiful-leaf-69_val.ckpt", # photo
        "../dump/polar-mountain-68_val.ckpt",  # art
        "../dump/hearty-microwave-67_val.ckpt",  # cartoon
        "../dump/vague-armadillo-66_val.ckpt",  # sketch
    ],
    "TEST_ACC": [],
    "TEST_MACRO": [],
    "TEST_ACC_PER_DOMAIN": [],
    "TEST_ACC_PER_CLASS": [],
}

In [20]:
for checkpoint in RESULTS["CHECKPOINTS"]:
    if not os.path.exists(checkpoint):
        print(f"Failed to locate checkpoint {checkpoint}")

In [21]:
args

namespace(num_classes=7,
          batch_size=8,
          nonlinear_classifier=False,
          dropout=0.0,
          lr=1e-05,
          weight_decay=0.0,
          beta=1.0)

In [22]:
import torch.nn.functional as F

def compute_metrics_no_agg(model, batch):
    acc = torch.zeros(model.num_classes).to(model.device)
    count = 0
    assert len(batch) == 1
    for (_, x, y) in batch:
        x, y = x.to(model.device), y.to(model.device)
        logits = model.network(x)
        probs = F.softmax(logits, dim=1)
        labels = torch.argmax(probs, dim=1)
        acc_ = torchmetrics.functional.accuracy(labels, y, task="multiclass", num_classes=model.num_classes, average="none")
        acc += acc_
        count += 1
    acc /= count
    return acc


def compute_metrics_macro(model, batch):
    acc = 0
    count = 0
    assert len(batch) == 1
    for (_, x, y) in batch:
        x, y = x.to(model.device), y.to(model.device)
        logits = model.network(x)
        probs = F.softmax(logits, dim=1)
        labels = torch.argmax(probs, dim=1)
        acc_ = torchmetrics.functional.accuracy(labels, y, task="multiclass", num_classes=model.num_classes, average="macro")
        acc += acc_
        count += 1
    acc /= count
    return acc

In [None]:
with progress_bar:
    for i, CHECKPOINT in enumerate(RESULTS["CHECKPOINTS"]):
        model = load_model(CHECKPOINT, args)


        tasks = [
            progress_bar.add_task(
                description=f"Testing {dataset.domains[i]}",
                total=len(loader),
                metrics=f"acc: --",
            )
            for i, loader in enumerate(loaders)
        ]


        ACC = []
        ACC_MACRO = []
        ACC_PER_DOMAIN = []
        ACC_PER_CLASS = []
        for task, loader in zip(tasks, loaders):
            acc = 0
            acc_macro = 0
            acc_per_class = torch.zeros(model.num_classes)
            count = 0
            for batch in loader:
                acc_ = model.log_metrics((batch,))
                acc_macro_ = compute_metrics_macro(model, (batch, ))
                acc_per_class_ = compute_metrics_no_agg(model, (batch, ))

                # print(f"acc_ {acc_.shape}")
                # print(f"acc_per_class_ {acc_per_class_.shape}")

                count += 1
                acc += acc_
                acc_macro += acc_macro_
                acc_per_class += acc_per_class_.cpu().numpy()

                progress_bar.update(
                    task,
                    advance=1,
                    metrics=f"acc: {acc_:.2%}",
                )

            acc /= count
            acc_macro /= count
            acc_per_class /= count

            progress_bar.update(
                task,
                metrics=f"acc: {acc:.2%}",
            )
            ACC.append(acc)
            ACC_MACRO.append(acc_macro)
            ACC_PER_CLASS.append(acc_per_class)

        RESULTS["TEST_ACC"].append(ACC[i])
        RESULTS["TEST_MACRO"].append(ACC_MACRO[i])
        RESULTS["TEST_ACC_PER_DOMAIN"].append(ACC)
        RESULTS["TEST_ACC_PER_CLASS"].append(ACC_PER_CLASS)

In [24]:
dataset.domains

['art_painting', 'cartoon', 'photo', 'sketch']

In [14]:
acc_per_class

tensor([0.1156, 0.1105, 0.1646, 0.1458, 0.1044, 0.0122, 0.0077])

In [16]:
print(RESULTS)

In [19]:
dlist = dataset.domains

for domain, acc in zip(dlist, RESULTS["TEST_ACC"]):
    print(f"[bold green]{domain}[/bold green]: {acc:.2%}")
print(f"[bold green]Average[/bold green]: {torch.mean(torch.tensor(RESULTS["TEST_ACC"])).item():.2%}")

for domain, acc in zip(dlist, RESULTS["TEST_MACRO"]):
    print(f"[bold green]{domain}[/bold green]: {acc:.2%}")

for domain, acc in zip(dlist, RESULTS["TEST_ACC_PER_DOMAIN"]):
    print(f"[bold green]{domain}[/bold green]: {[f'{a:.2%}' for a in acc]}")

cls_to_idx = dataset.datasets[0].class_to_idx.copy()
idx_to_cls = {v: k for k, v in cls_to_idx.items()}
names = list(map(idx_to_cls.get, range(7)))
print(names)
for j, acc in enumerate(RESULTS["TEST_ACC_PER_CLASS"]):
    vals = {k: f"{v.item():.2%}" for i, (k, v) in enumerate(zip(names, acc[j]))}
    print(f"[bold green]{domains[j]}[/bold green]: {vals}")


In [12]:
import torch
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall
)

# Parameters
num_classes = 7
domains = dataset.domains
print(f"Domains: {domains}")

# Helper
domain_to_idx = {domain.lower(): i for i, domain in enumerate(domains)}

# MLDG MODEL
# RESULTS = {
#     "CHECKPOINTS": [
#         "../dump/wobbly-disco-103_val.ckpt",  # sketch
#         "../dump/lively-violet-104_val.ckpt",  # photo
#         "../dump/breezy-pond-105_val.ckpt",  # cartoon
#         "../dump/deep-wave-107_val.ckpt",  # art
#     ],
#     "TEST_DOMAIN": [  # to make sure we dont mix up things
#         "Sketch",
#         "Photo",
#         "Cartoon",
#         "Art_Painting"
#     ],
#     "METRIC_SUMMARY": [],
# }


# BASELINE
RESULTS = {
    "CHECKPOINTS": [
        "../dump/devout-rain-112_val.ckpt",  # sketch
        "../dump/wobbly-shape-113_val.ckpt",  # photo
        "../dump/clean-dust-114_val.ckpt",  # cartoon
        "../dump/vital-bird-115_val.ckpt",  # art
    ],
    "TEST_DOMAIN": [  # to make sure we dont mix up things
        "Sketch",
        "Photo",
        "Cartoon",
        "Art_Painting"
    ],
    "METRIC_SUMMARY": [],
}

args.baseline = True

args.nonlinear_classifier = True

print(f"Checkpoints to load: {
    [(check, test) for check, test in zip(RESULTS['CHECKPOINTS'], RESULTS['TEST_DOMAIN'])]
}")
print(args)

In [13]:

with progress_bar:
    for i, (CHECKPOINT, DOMAIN) in enumerate(zip(RESULTS["CHECKPOINTS"], RESULTS["TEST_DOMAIN"])):
        model = load_model(CHECKPOINT, args)
        loader = loaders[domain_to_idx[RESULTS["TEST_DOMAIN"][i].lower()]]
        task = progress_bar.add_task(
            description=f"Testing {dataset.domains[i]}",
            total=len(loader),
            metrics="--"
        )


        # Define metric objects (stateful across batches)
        metric_acc = MulticlassAccuracy(num_classes=num_classes, average='micro')
        metric_f1_macro = MulticlassF1Score(num_classes=num_classes, average='macro')
        metric_f1_weighted = MulticlassF1Score(num_classes=num_classes, average='weighted')
        metric_precision = MulticlassPrecision(num_classes=num_classes, average=None)
        metric_recall = MulticlassRecall(num_classes=num_classes, average=None)
        metric_f1_per_class = MulticlassF1Score(num_classes=num_classes, average=None)

        metric_acc = metric_acc.to(device)
        metric_f1_macro = metric_f1_macro.to(device)
        metric_f1_weighted = metric_f1_weighted.to(device)
        metric_precision = metric_precision.to(device)
        metric_recall = metric_recall.to(device)
        metric_f1_per_class = metric_f1_per_class.to(device)


        for batch in loader:
            with torch.inference_mode():
                # Compute stuff
                (_, x, y_true) = batch
                x, y_true = x.to(model.device), y_true.to(model.device)
                logits = model.network(x)
                probs = torch.nn.functional.softmax(logits, dim=1)
                y_pred = torch.argmax(probs, dim=1)

                 # Update metrics
                metric_acc.update(y_pred, y_true)
                metric_f1_macro.update(y_pred, y_true)
                metric_f1_weighted.update(y_pred, y_true)
                metric_precision.update(y_pred, y_true)
                metric_recall.update(y_pred, y_true)
                metric_f1_per_class.update(y_pred, y_true)

            progress_bar.update(
                task,
                advance=1,
                metrics=f"acc: {metric_acc.compute():.4f}",
            )

        # Compute final results
        acc = metric_acc.compute()
        f1_macro = metric_f1_macro.compute()
        f1_weighted = metric_f1_weighted.compute()
        precision_per_class = metric_precision.compute()
        recall_per_class = metric_recall.compute()
        f1_per_class = metric_f1_per_class.compute()

        RESULTS["METRIC_SUMMARY"].append({
            "acc": acc,
            "f1_macro": f1_macro,
            "f1_weighted": f1_weighted,
            "precision_per_class": precision_per_class,
            "recall_per_class": recall_per_class,
            "f1_per_class": f1_per_class,
        })


        # Display results
        print(f"=== {DOMAIN.lower()} ===")
        print(f"Micro Accuracy:     {acc:.4f}")
        print(f"Macro F1 Score:     {f1_macro:.4f}")
        print(f"Weighted F1 Score:  {f1_weighted:.4f}\n")

        for i in range(num_classes):
            print(f"Class {i}: "
                  f"Precision={precision_per_class[i]:.4f}  "
                  f"Recall={recall_per_class[i]:.4f} "
                  f"F1={f1_per_class[i]:.4f}")


In [8]:
print(RESULTS)


In [9]:
import json
with open("SummaryResults.json", "w") as f:
    json.dump(RESULTS, f)