In [32]:
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from core.net import build_backbone, conv1x1, Classifier


class MultiLabelMEDAF(nn.Module):
    """
    Multi-Label version of MEDAF (Multi-Expert Diverse Attention Fusion)

    Key changes from original MEDAF:
    1. Support for multi-hot label targets
    2. BCEWithLogitsLoss instead of CrossEntropyLoss
    3. Multi-label attention diversity computation
    4. Per-sample CAM extraction for multiple positive classes

    Structure built in ``__init__``: a shared trunk, three expert branches
    that each end in a 1x1 conv producing one CAM per class, and a gating
    network whose softmax weights fuse the (detached) expert logits.

    ``args`` keys read by ``__init__``:
        img_size, backbone: forwarded to ``build_backbone``.
        gate_temp: temperature dividing the gate scores before the softmax.
        num_classes: number of output labels.
    """

    def __init__(self, args=None):
        super(MultiLabelMEDAF, self).__init__()
        if args is None:
            # Every field below is required; fail with a readable message
            # instead of "'NoneType' object is not subscriptable".
            raise TypeError(
                "MultiLabelMEDAF requires an args dict with 'img_size', "
                "'backbone', 'gate_temp' and 'num_classes'"
            )
        backbone, feature_dim, self.cam_size = build_backbone(
            img_size=args["img_size"],
            backbone_name=args["backbone"],
            projection_dim=-1,
            inchan=3,
        )
        self.img_size = args["img_size"]
        self.gate_temp = args["gate_temp"]
        # Changed from num_known to num_classes for clarity
        self.num_classes = args["num_classes"]
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # NOTE(review): the slice points (-6, -3) depend on the child ordering
        # produced by build_backbone, which is not visible here.
        children = list(backbone.children())

        # Shared layers (L1-L3)
        self.shared_l3 = nn.Sequential(*children[:-6])

        # Expert branch 1
        self.branch1_l4 = nn.Sequential(*children[-6:-3])
        self.branch1_l5 = nn.Sequential(*children[-3])
        self.branch1_cls = conv1x1(feature_dim, self.num_classes)

        # Expert branch 2 (deep copy of branch 1's trunk, fresh classifier)
        self.branch2_l4 = copy.deepcopy(self.branch1_l4)
        self.branch2_l5 = copy.deepcopy(self.branch1_l5)
        self.branch2_cls = conv1x1(feature_dim, self.num_classes)

        # Expert branch 3 (deep copy of branch 1's trunk, fresh classifier)
        self.branch3_l4 = copy.deepcopy(self.branch1_l4)
        self.branch3_l5 = copy.deepcopy(self.branch1_l5)
        self.branch3_cls = conv1x1(feature_dim, self.num_classes)

        # Gating network: its own copy of the full trunk plus a small MLP
        # that scores the 3 experts.
        self.gate_l3 = copy.deepcopy(self.shared_l3)
        self.gate_l4 = copy.deepcopy(self.branch1_l4)
        self.gate_l5 = copy.deepcopy(self.branch1_l5)
        self.gate_cls = nn.Sequential(
            Classifier(feature_dim, int(feature_dim / 4), bias=True),
            Classifier(int(feature_dim / 4), 3, bias=True),  # 3 experts
        )

    def forward(self, x, y=None, return_ft=False):
        """
        Forward pass for multi-label MEDAF

        Args:
            x: Input tensor [B, C, H, W]
            y: Multi-hot labels [B, num_classes] or None
            return_ft: Whether to return features

        Returns:
            Dictionary with:
              "logits": [b1, b2, b3, gate] logits, each [B, num_classes]
              "gate_pred": softmax gating weights [B, 3]
              "fts": sum of the three expert CAMs (detached); present
                     whenever ``return_ft`` is True
              "multi_label_cams": result of ``_extract_multilabel_cams``
                     (None when ``y`` is None); omitted only when
                     ``return_ft`` is True and ``y`` is None
              "cams_list": per-expert CAMs, each [B, num_classes, h, w]
        """
        b = x.size(0)
        ft_till_l3 = self.shared_l3(x)

        experts = (
            (self.branch1_l4, self.branch1_l5, self.branch1_cls),
            (self.branch2_l4, self.branch2_l5, self.branch2_cls),
            (self.branch3_l4, self.branch3_l5, self.branch3_cls),
        )
        cams_list = []
        expert_logits = []
        for l4, l5, cls_head in experts:
            # Each expert gets its own copy of the shared features.
            cams = cls_head(l5(l4(ft_till_l3.clone())))  # [B, num_classes, h, w]
            cams_list.append(cams)
            expert_logits.append(self.avg_pool(cams).view(b, -1))

        # Multi-label CAM extraction for positive classes
        multi_label_cams = (
            self._extract_multilabel_cams(cams_list, y) if y is not None else None
        )

        # Gating network
        gate_l5 = self.gate_l5(self.gate_l4(self.gate_l3(x)))
        gate_pool = self.avg_pool(gate_l5).view(b, -1)
        gate_pred = F.softmax(self.gate_cls(gate_pool) / self.gate_temp, dim=1)

        # Adaptive fusion. Expert logits are detached, so a loss on the fused
        # logits only sends gradient into the gating network.
        stacked = torch.stack([lg.detach() for lg in expert_logits], dim=-1)
        gate_logits = (stacked * gate_pred.unsqueeze(1)).sum(-1)

        logits_list = expert_logits + [gate_logits]

        outputs = {"logits": logits_list, "gate_pred": gate_pred}
        if return_ft:
            # Aggregate (detached) CAMs from all experts; returned whenever
            # requested, including when labels are also passed.
            outputs["fts"] = (
                cams_list[0].detach() + cams_list[1].detach() + cams_list[2].detach()
            )
        if not (return_ft and y is None):
            outputs["multi_label_cams"] = multi_label_cams
        outputs["cams_list"] = cams_list
        return outputs

    def _extract_multilabel_cams(self, cams_list, targets):
        """
        Extract CAMs for all positive classes in multi-label setting

        Args:
            cams_list: List of CAMs from the experts, each [B, num_classes, H, W]
            targets: Multi-hot labels [B, num_classes]

        Returns:
            One list per expert; each holds, per sample, a tensor
            [num_positive, H, W] of that sample's positive-class CAMs, or a
            [1, H, W] zero tensor (same dtype/device as the CAMs) when the
            sample has no positive label.
        """
        batch_size = targets.size(0)
        extracted_cams = []

        for expert_cams in cams_list:
            h, w = expert_cams.shape[-2:]
            per_sample = []
            for batch_idx in range(batch_size):
                positive_classes = torch.where(targets[batch_idx] == 1)[0]
                if positive_classes.numel() > 0:
                    per_sample.append(expert_cams[batch_idx, positive_classes])
                else:
                    # Match the CAM dtype so mixed-precision runs stay consistent.
                    per_sample.append(
                        torch.zeros(
                            1, h, w, device=expert_cams.device, dtype=expert_cams.dtype
                        )
                    )
            extracted_cams.append(per_sample)

        return extracted_cams

    def get_params(self, prefix="extractor"):
        """Get model parameters for different learning rates

        Args:
            prefix: "extractor" / "extract" for the backbone-derived layers
                (shared trunk, expert l4/l5, gate l3-l5), or "classifier" for
                every other parameter (expert 1x1 heads and the gate head).

        Returns:
            A list of parameters.

        Raises:
            ValueError: if ``prefix`` is not one of the values above.
        """
        extractor_modules = (
            self.shared_l3,
            self.branch1_l4,
            self.branch1_l5,
            self.branch2_l4,
            self.branch2_l5,
            self.branch3_l4,
            self.branch3_l5,
            self.gate_l3,
            self.gate_l4,
            self.gate_l5,
        )
        extractor_params = [p for m in extractor_modules for p in m.parameters()]

        if prefix in ("extractor", "extract"):
            return extractor_params
        if prefix == "classifier":
            extractor_ids = {id(p) for p in extractor_params}
            return [p for p in self.parameters() if id(p) not in extractor_ids]
        raise ValueError(
            f"Unknown prefix {prefix!r}; expected 'extractor', 'extract' or 'classifier'"
        )


In [1]:
import ast
import json
from pathlib import Path
from typing import Optional, Union

import numpy as np
import pandas as pd
import torch

import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
from sklearn.datasets import make_multilabel_classification
from sklearn.preprocessing import StandardScaler

In [24]:
import os

# Project root. Defaults to the original author's checkout; set the MEDAF_ROOT
# environment variable to run the notebook from a different location.
CURRENT_DIR = os.environ.get("MEDAF_ROOT", "/home/s2320437/WORK/aidan-medaf/")
# Capture the absolute root before chdir so a relative MEDAF_ROOT still works.
_PROJECT_ROOT = Path(os.path.abspath(CURRENT_DIR))
os.chdir(CURRENT_DIR)
print(f"Current working directory: {os.getcwd()}")


# Known-label vocabulary; list order is the index order used for multi-hot vectors.
KNOWN_LABELS = [
    "Atelectasis",
    "Cardiomegaly",
    "Effusion",
    "Infiltration",
    "Mass",
    "Nodule",
    "Pneumonia",
    "Pneumothorax",
]

# Build paths with pathlib joins instead of string concatenation.
_NIH_DIR = _PROJECT_ROOT / "datasets" / "data" / "chestxray" / "NIH"
DEFAULT_IMAGE_ROOT = _NIH_DIR / "images-224"
DEFAULT_KNOWN_CSV = _NIH_DIR / "chestxray_train_known.csv"

DEFAULT_CHECKPOINT_DIR = _PROJECT_ROOT / "checkpoints" / "medaf_phase1"


# Global run state populated by later cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
results = {}
train_loader = None
val_loader = None
test_loader = None
dataset_name = None
class_names = KNOWN_LABELS

print()

Current working directory: /home/s2320437/WORK/aidan-medaf



In [30]:
# Dataset: ChestXrayKnownDataset
class ChestXrayKnownDataset(data.Dataset):
    """Dataset that reads the known-label ChestX-ray14 split for Phase 1 training.

    Each item is ``(image, labels)``: the transformed image and a float32
    multi-hot vector over ``KNOWN_LABELS`` built from the CSV's
    ``known_labels`` column.
    """

    def __init__(
        self,
        csv_path: Path,
        image_root: Path,
        img_size: int = 224,
        max_samples: Optional[int] = None,
        transform=None,
    ) -> None:
        """
        Args:
            csv_path: CSV with at least "Image Index" and "known_labels" columns.
            image_root: Directory containing the image files.
            img_size: Side length used by the default resize transform.
            max_samples: If smaller than the CSV, randomly subsample this many
                rows (fixed seed 42).
            transform: Optional image transform; defaults to resize +
                ToTensor + ImageNet normalization.

        Raises:
            FileNotFoundError: if the CSV or image directory does not exist.
            ValueError: if the CSV has no "known_labels" column.
        """
        self.csv_path = Path(csv_path)
        self.image_root = Path(image_root)
        self.img_size = img_size
        self.class_names = KNOWN_LABELS
        self.num_classes = len(self.class_names)

        if not self.csv_path.exists():
            raise FileNotFoundError(f"ChestX-ray CSV not found: {self.csv_path}")
        if not self.image_root.exists():
            raise FileNotFoundError(
                f"ChestX-ray image directory not found: {self.image_root}"
            )

        if transform is None:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((img_size, img_size)),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225],
                    ),
                ]
            )
        else:
            self.transform = transform

        df = pd.read_csv(self.csv_path)
        if "known_labels" not in df.columns:
            raise ValueError(
                "Expected 'known_labels' column in CSV. Run utils/create_chestxray_splits.py first."
            )

        if max_samples is not None and max_samples < len(df):
            df = df.sample(n=max_samples, random_state=42).reset_index(drop=True)
        else:
            df = df.reset_index(drop=True)

        self.records = df.to_dict("records")
        self.label_to_idx = {label: idx for idx, label in enumerate(self.class_names)}
        print(f"Label to index: {self.label_to_idx}")

    @staticmethod
    def _parse_label_list(raw_value):
        """Normalize a raw ``known_labels`` cell into a list of label names.

        Accepts a list/tuple, a Python-literal string such as "['A', 'B']",
        a single quoted name, or a "|"-separated string. Anything else
        (None, NaN, numbers, ...) yields an empty list.
        """
        if isinstance(raw_value, (list, tuple)):
            return list(raw_value)
        if not isinstance(raw_value, str):
            # Covers None / NaN from empty CSV cells without calling pd.isna,
            # which returns an array (and breaks `if`) for list-like inputs.
            return []
        raw_value = raw_value.strip()
        if not raw_value:
            return []
        try:
            parsed = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            parsed = None
        if isinstance(parsed, (list, tuple)):
            return list(parsed)
        if isinstance(parsed, str):
            return [parsed]
        # Fall back to the "A|B|C" format used by the original NIH CSV.
        return [item.strip() for item in raw_value.split("|") if item.strip()]

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx):
        record = self.records[idx]
        image_path = self.image_root / record["Image Index"]
        if not image_path.exists():
            raise FileNotFoundError(f"Missing image: {image_path}")

        # Close the file handle promptly: DataLoader workers open many images
        # and an un-closed Image.open leaks descriptors until GC runs.
        # convert() loads the pixels, so the result outlives the handle.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = self.transform(image)

        labels = torch.zeros(self.num_classes, dtype=torch.float32)
        for label in self._parse_label_list(record.get("known_labels", [])):
            if label in self.label_to_idx:
                labels[self.label_to_idx[label]] = 1.0

        return image, labels

    # class_name
    @staticmethod
    def class_name():
        """Return the known-label vocabulary (``KNOWN_LABELS``)."""
        return KNOWN_LABELS


In [28]:
 # Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)
# Demo configuration
config = {
    "data_source": "chestxray",
    "known_csv": str(DEFAULT_KNOWN_CSV),
    "image_root": str(DEFAULT_IMAGE_ROOT),
    "batch_size": 32,
    "num_epochs": 5,
    "learning_rate": 1e-4,
    "val_ratio": 0.1,
    "num_workers": 2,
    # "max_samples": None,  # Set to an int for quicker experiments
    "max_samples": 1000,
    "phase1_checkpoint": "medaf_phase1_chestxray.pt",
    "checkpoint_dir": str(DEFAULT_CHECKPOINT_DIR),
    "run_phase2": False,
}

In [31]:
data_source = config.get("data_source", "chestxray").lower()

# Only the ChestX-ray14 loader is implemented. Fail fast with a clear message
# instead of a NameError on `dataset` further down.
if data_source != "chestxray":
    raise ValueError(
        f"Unsupported data_source {data_source!r}; only 'chestxray' is implemented"
    )

csv_path = Path(config.get("known_csv", DEFAULT_KNOWN_CSV))
image_root = Path(config.get("image_root", DEFAULT_IMAGE_ROOT))
max_samples = config.get("max_samples")
if isinstance(max_samples, str):
    max_samples = int(max_samples)
print(f"Loading ChestX-ray14 known-label split from {csv_path}")
dataset = ChestXrayKnownDataset(
    csv_path=csv_path,
    image_root=image_root,
    img_size=config.get("img_size", 224),
    max_samples=max_samples,
)
dataset_name = "ChestX-ray14 (known labels)"
class_names = dataset.class_names
# Later cells build the model from these, so record them in the config.
config["num_classes"] = dataset.num_classes
config["img_size"] = dataset.img_size

# Hold out `val_ratio` of the samples (at least one) for validation.
val_ratio = float(config.get("val_ratio", 0.1))
val_size = max(1, int(len(dataset) * val_ratio))
train_size = len(dataset) - val_size
if train_size < 1:
    raise ValueError(
        f"Not enough samples ({len(dataset)}) for val_ratio={val_ratio}: "
        "the training split would be empty"
    )

train_dataset, val_dataset = data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

batch_size = config.get("batch_size", 16)
num_workers = config.get("num_workers", 4)
pin_memory = torch.cuda.is_available()

train_loader = data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

val_loader = data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

# No separate test split yet: evaluate on the validation split.
test_loader = val_loader

print(
    f"Dataset prepared: {train_size} train / {val_size} val samples ({dataset_name})"
)


Loading ChestX-ray14 known-label split from /home/s2320437/WORK/aidan-medaf/datasets/data/chestxray/NIH/chestxray_train_known.csv
Label to index: {'Atelectasis': 0, 'Cardiomegaly': 1, 'Effusion': 2, 'Infiltration': 3, 'Mass': 4, 'Nodule': 5, 'Pneumonia': 6, 'Pneumothorax': 7}
Dataset prepared: 900 train / 100 val samples (ChestX-ray14 (known labels))


In [25]:
def save_model(model, args, loss_history):
    """Save the Phase 1 checkpoint and a JSON metadata sidecar.

    Reads the module-level ``config`` ("checkpoint_dir", "phase1_checkpoint",
    "num_epochs", ...), ``class_names``, ``dataset_name`` and ``device``.

    Args:
        model: Module whose ``state_dict`` is saved.
        args: Model-construction args stored alongside the weights.
        loss_history: Per-epoch losses; each value is stored as a float.

    Returns:
        Path of the ``.pt`` checkpoint. The metadata is written next to it
        with a ``.json`` suffix.
    """
    ckpt_dir = Path(config["checkpoint_dir"])
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = ckpt_dir / config["phase1_checkpoint"]

    payload = {
        "state_dict": model.state_dict(),
        "args": args,
        "class_names": class_names,
        "dataset": dataset_name,
    }
    torch.save(payload, checkpoint_path)

    # Keep JSON-serialisable scalars. None is kept as well, so settings such as
    # max_samples=None ("use the full dataset") are recorded, not dropped.
    scalar_config = {
        k: v
        for k, v in config.items()
        if v is None or isinstance(v, (int, float, str, bool))
    }

    metadata = {
        "dataset": dataset_name,
        "class_names": class_names,
        "num_epochs": config["num_epochs"],
        "batch_size": config.get("batch_size"),
        "learning_rate": config.get("learning_rate"),
        "loss_history": [float(loss) for loss in loss_history],
        "device": str(device),
        "checkpoint": str(checkpoint_path),
        "config": scalar_config,
    }
    metadata_path = checkpoint_path.with_suffix(".json")
    with metadata_path.open("w", encoding="utf-8") as fp:
        json.dump(metadata, fp, indent=2)

    return checkpoint_path

In [33]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from misc.util import *


def multiLabelAttnDiv(cams_list, targets, eps=1e-6):
    """
    Multi-label attention diversity loss

    Encourages different experts to focus on different spatial regions
    for all positive classes in multi-label setting.

    For every positive (sample, class) pair, each expert's CAM is flattened,
    L2-normalised, mean-centred and ReLU'd. The loss is the mean cosine
    similarity over all expert pairs and all positive (sample, class) pairs,
    computed in one vectorised pass instead of a Python loop per pair.

    Args:
        cams_list: List of CAMs from the experts, each [B, num_classes, H, W]
        targets: Multi-hot labels [B, num_classes] (a label counts as
            positive when it equals 1)
        eps: Small value for numerical stability

    Returns:
        diversity_loss: Scalar tensor representing attention diversity loss
            (0 when there are no positive labels or fewer than two experts)
    """
    device = cams_list[0].device
    if targets is None or targets.sum() == 0:
        return torch.tensor(0.0, device=device)

    # (sample, class) coordinates of every positive label: [P, 2]
    positives = torch.nonzero(targets == 1, as_tuple=False)
    if positives.numel() == 0:
        return torch.tensor(0.0, device=device)
    sample_idx, class_idx = positives[:, 0], positives[:, 1]

    # Gather the positive-class CAMs of every expert: [E, P, H*W]
    expert_cams = torch.stack([cams[sample_idx, class_idx] for cams in cams_list])
    expert_cams = expert_cams.flatten(start_dim=2)
    expert_cams = F.normalize(expert_cams, p=2, dim=-1)

    # Remove mean activation to focus on relative attention patterns
    expert_cams = F.relu(expert_cams - expert_cams.mean(dim=-1, keepdim=True))

    # Pairwise cosine similarity between experts (encourage orthogonality)
    num_experts = expert_cams.size(0)
    similarities = [
        F.cosine_similarity(expert_cams[i], expert_cams[j], dim=-1, eps=eps)
        for i in range(num_experts)
        for j in range(i + 1, num_experts)
    ]
    if not similarities:
        return torch.tensor(0.0, device=device)

    # Average over all (expert pair, positive label) combinations
    return torch.cat(similarities).mean()



In [34]:

def multiLabelAccuracy(predictions, targets, threshold=0.5):
    """
    Compute multi-label accuracy metrics

    Args:
        predictions: Model logits [B, num_classes] (sigmoid is applied here)
        targets: Multi-hot ground truth [B, num_classes]; float, int or bool
        threshold: Probability threshold for binary predictions

    Returns:
        subset_acc: Exact match accuracy (all labels correct)
        hamming_acc: Label-wise accuracy
        precision: Precision, macro-averaged over classes
        recall: Recall, macro-averaged over classes
        f1: F1, macro-averaged over classes

    Note: a class with no true and no predicted positives contributes 0 to the
    macro averages (the 1e-8 guard only prevents division by zero).
    """
    with torch.no_grad():
        # Convert logits to probabilities and then to binary predictions
        probs = torch.sigmoid(predictions)
        pred_binary = (probs > threshold).float()

        # Cast so bool/int multi-hot targets work with the arithmetic below
        # (`1 - bool_tensor` is not supported by torch).
        targets = targets.to(dtype=pred_binary.dtype)

        matches = pred_binary == targets

        # Subset accuracy (exact match)
        subset_acc = matches.all(dim=1).float().mean()

        # Hamming accuracy (label-wise accuracy)
        hamming_acc = matches.float().mean()

        # Per-class confusion counts
        tp = (pred_binary * targets).sum(dim=0)
        fp = (pred_binary * (1 - targets)).sum(dim=0)
        fn = ((1 - pred_binary) * targets).sum(dim=0)

        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1 = 2 * (precision * recall) / (precision + recall + 1e-8)

        # Average across classes
        precision = precision.mean()
        recall = recall.mean()
        f1 = f1.mean()

    return subset_acc, hamming_acc, precision, recall, f1



In [35]:

def _format_meter_report(loss_header, loss_meter, acc_header, acc_meter):
    """Render loss meters (4 decimals) and accuracy meters (1 decimal) as text."""
    text = f"{loss_header}\n"
    text += "".join(f"{k}:{v.value:.4f} " for k, v in loss_meter.items())
    text += f"\n{acc_header}\n"
    text += "".join(f"{k}:{v.value:.1f} " for k, v in acc_meter.items())
    return text


def train_multilabel(train_loader, model, criterion, optimizer, args, device=None):
    """
    Training loop for multi-label MEDAF (one epoch)

    Args:
        train_loader: DataLoader yielding (inputs, multi-hot targets) batches
        model: MultiLabelMEDAF model
        criterion: Dictionary containing loss functions (uses key "bce")
        optimizer: Optimizer
        args: Training arguments; reads "loss_keys", "acc_keys" and
            "loss_wgts" (weights for expert BCE, gate BCE, diversity loss)
        device: Device to run on; when None, the device of the model's first
            parameter is used (CPU if the model has no parameters)

    Returns:
        Average training loss (value of the last loss meter, i.e. "total")
    """
    model.train()

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    loss_keys = args["loss_keys"]  # ["b1", "b2", "b3", "gate", "divAttn", "total"]
    acc_keys = args["acc_keys"]  # ["acc1", "acc2", "acc3", "accGate"]
    w_expert, w_gate, w_div = args["loss_wgts"][:3]

    loss_meter = {p: AverageMeter() for p in loss_keys}
    acc_meter = {p: AverageMeter() for p in acc_keys}
    time_start = time.time()

    # `batch` rather than `data`, which would shadow the torch.utils.data alias.
    for i, batch in enumerate(train_loader):
        inputs = batch[0].to(device)
        targets = batch[1].to(device)  # Multi-hot labels [B, num_classes]

        # Forward pass
        output_dict = model(inputs, targets)
        logits = output_dict["logits"]  # [b1, b2, b3, gate] logits
        cams_list = output_dict["cams_list"]  # CAMs from 3 experts

        # Multi-label classification losses for the expert branches only
        bce_losses = [
            criterion["bce"](logit.float(), targets.float()) for logit in logits[:3]
        ]

        # Gating loss (on fused predictions)
        gate_loss = criterion["bce"](logits[3].float(), targets.float())

        # Multi-label attention diversity loss
        diversity_loss = multiLabelAttnDiv(cams_list, targets)

        total_loss = (
            w_expert * sum(bce_losses) + w_gate * gate_loss + w_div * diversity_loss
        )
        # Order must match loss_keys: b1, b2, b3, gate, divAttn, total
        loss_values = bce_losses + [gate_loss, diversity_loss, total_loss]

        # Subset (exact-match) accuracy per head, in percent
        acc_values = [multiLabelAccuracy(logit, targets)[0] * 100 for logit in logits]

        multi_loss = dict(zip(loss_keys, loss_values))
        train_accs = dict(zip(acc_keys, acc_values))

        update_meter(loss_meter, multi_loss, inputs.size(0))
        update_meter(acc_meter, train_accs, inputs.size(0))

        # Backward pass
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Print progress every 50 batches
        if i % 50 == 0:
            print(
                f"Batch [{i}/{len(train_loader)}] "
                + _format_meter_report(
                    "< Training Loss >", loss_meter, "< Training Accuracy >", acc_meter
                )
            )

    time_elapsed = time.time() - time_start
    print(f"\nEpoch completed in {time_elapsed:.1f}s")

    # Final epoch summary
    print(
        _format_meter_report(
            "< Final Training Loss >",
            loss_meter,
            "< Final Training Accuracy >",
            acc_meter,
        )
    )

    return loss_meter[loss_keys[-1]].value


In [36]:
"""Demonstrate Phase 1: Basic Multi-Label MEDAF"""
banner = "=" * 60
print("\n" + banner)
print("PHASE 1: Basic Multi-Label MEDAF")
print(banner)

# Phase 1 model / loss configuration
args = {
    "img_size": config["img_size"],
    "backbone": "resnet18",
    "num_classes": config["num_classes"],
    "gate_temp": 100,
    "loss_keys": ["b1", "b2", "b3", "gate", "divAttn", "total"],
    "acc_keys": ["acc1", "acc2", "acc3", "accGate"],
    "loss_wgts": [0.7, 1.0, 0.01],
}

model = MultiLabelMEDAF(args)
model.to(device)

trainable_param_count = sum(
    p.numel() for p in model.parameters() if p.requires_grad
)
print(f"Phase 1 Model Parameters: {trainable_param_count:,}")

# Loss functions and optimizer
criterion = {"bce": nn.BCEWithLogitsLoss()}
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])

# One entry per epoch: the average "total" training loss
phase1_metrics = []
for epoch in range(config["num_epochs"]):
    metrics = train_multilabel(train_loader, model, criterion, optimizer, args, device)
    phase1_metrics.append(metrics)
    if epoch % 2 == 0:
        print(f"Epoch {epoch}: Loss={metrics:.4f}")

if phase1_metrics:
    final_loss = phase1_metrics[-1]
else:
    final_loss = float("nan")
checkpoint_path = save_model(model, args, phase1_metrics)

results["phase1"] = dict(
    model=model,
    final_loss=final_loss,
    metrics_history=phase1_metrics,
    checkpoint=str(checkpoint_path),
)

print(
    f"Phase 1 Final Loss: {final_loss:.4f}"
    if phase1_metrics
    else "Phase 1 completed with zero epochs (no training performed)"
)
print(f"Phase 1 checkpoint saved to: {checkpoint_path}")
print(
    "Use load_phase1_checkpoint(CheckpointPath) to reload this model for evaluation."
)


PHASE 1: Basic Multi-Label MEDAF
Making resnet layer with channel 64 block 2 stride 1
Making resnet layer with channel 128 block 2 stride 2
Making resnet layer with channel 256 block 2 stride 2
Making resnet layer with channel 512 block 2 stride 2
Phase 1 Model Parameters: 44,749,955
Batch [0/29] < Training Loss >
b1:0.8801 b2:0.7273 b3:0.6110 gate:0.7187 divAttn:0.3237 total:2.2748 
< Training Accuracy >
acc1:0.0 acc2:0.0 acc3:0.0 accGate:0.0 

Epoch completed in 37.8s
< Final Training Loss >
b1:0.3716 b2:0.3237 b3:0.2963 gate:0.3232 divAttn:0.4948 total:1.0223 
< Final Training Accuracy >
acc1:51.7 acc2:52.6 acc3:55.6 accGate:54.9 
Epoch 0: Loss=1.0223
Batch [0/29] < Training Loss >
b1:0.2130 b2:0.1943 b3:0.1901 gate:0.1968 divAttn:0.3644 total:0.6186 
< Training Accuracy >
acc1:68.8 acc2:68.8 acc3:68.8 accGate:68.8 

Epoch completed in 33.6s
< Final Training Loss >
b1:0.2495 b2:0.2437 b3:0.2428 gate:0.2440 divAttn:0.2883 total:0.7621 
< Final Training Accuracy >
acc1:58.3 acc2:58.6