# Betti Matching Loss 训练方式（项目内实现）

在本项目里，训练逻辑在 `train.py` 中：

- 当 `config.LOSS.USE_LOSS == 'BettiMatching'` 时，会创建 `BettiMatchingLoss(relative=..., filtration=...)`。
- 训练循环中，用 `loss, metrics = loss_function(outputs, labels)` 得到损失与指标字典，然后 `loss.backward()` 反向传播。
- `BettiMatchingLoss` 内部会对 `outputs` 做 `sigmoid`，并对每个样本调用 `BettiMatching(...).loss()` 计算拓扑损失。

下面是一个最小示例：构造一批随机的 logits 和二值标签，计算 Betti matching loss，并完成一次反向传播。由于输入是随机生成的且未固定随机种子，每次运行输出的具体数值会不同。

In [1]:
import torch
from loss_functions import BettiMatchingLoss

# Minimal check of BettiMatchingLoss: one forward and one backward pass on random data.
# Assume the outputs are UNet logits (not yet passed through sigmoid)
batch_size, height, width = 2, 128, 128
# requires_grad=True so that loss.backward() below has a leaf tensor to send gradients to.
logits = torch.randn(batch_size, 1, height, width, requires_grad=True)
# Random binary masks: 1.0 where a uniform sample exceeds 0.5, otherwise 0.0.
labels = (torch.rand(batch_size, 1, height, width) > 0.5).float()

# Mirrors the usage in train.py: loss, metrics = loss_function(outputs, labels)
loss_fn = BettiMatchingLoss(relative=False, filtration="superlevel")
# The call returns a pair: the loss and a mapping of metric names to values.
loss, metrics = loss_fn(logits, labels)
loss.backward()

# Convert to plain Python floats for printing.
print("loss:", float(loss))
print("metrics:", {k: float(v) for k, v in metrics.items()})

  pkg = __import__(module)  # top level module


loss: 1107.451416015625
metrics: {'dice': 0.49937933683395386, 'Betti matching': 1107.451416015625}


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Basic MNIST training script (PyTorch)

def build_loaders(batch_size=128, data_dir="./data", *, num_workers=2, pin_memory=True):
    """Build the MNIST training and test data loaders.

    Both splits are requested with ``download=True`` and share one transform:
    conversion to a tensor followed by normalization with mean 0.1307 and
    standard deviation 0.3081 (presumably the MNIST training-set statistics).

    Args:
        batch_size: Number of samples per batch, used for both loaders.
        data_dir: Root directory passed to ``datasets.MNIST``.
        num_workers: Worker-process count passed to both ``DataLoader``
            instances. Keyword-only; pass ``0`` to load in the main process.
        pin_memory: ``pin_memory`` flag passed to both ``DataLoader``
            instances. Keyword-only; may be turned off on CPU-only hosts.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles its data.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    # Settings shared by both loaders, kept in one place so they cannot drift.
    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }
    train_ds = datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)
    test_ds = datasets.MNIST(root=data_dir, train=False, download=True, transform=transform)
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)
    return train_loader, test_loader


class SimpleMNISTCNN(nn.Module):
    """Small convolutional classifier producing 10 logits per input image.

    Two conv/ReLU/max-pool stages are followed by a two-layer fully
    connected head. The first linear layer is sized for ``64 * 7 * 7``
    features, which matches single-channel 28x28 inputs after the two
    2x poolings.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order in which they are applied.
        conv_stages = [
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        classifier_head = [
            nn.Flatten(),
            nn.Linear(in_features=64 * 7 * 7, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=10),
        ]
        # A single flat Sequential keeps the sub-module indices 0..9.
        self.net = nn.Sequential(*conv_stages, *classifier_head)

    def forward(self, x):
        """Return the raw class logits for the batch ``x``."""
        logits = self.net(x)
        return logits


def train_one_epoch(model, loader, optimizer, loss_fn, device):
    """Train ``model`` for one pass over ``loader`` and return the mean loss.

    Args:
        model: Module to optimize; it is switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer that updates the parameters of ``model``.
        loss_fn: Callable mapping ``(logits, labels)`` to a scalar loss.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The per-sample average of the batch losses as a float, or ``0.0``
        when the loader produced no samples.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        logits = model(images)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        # Weight each batch loss by its size; this assumes loss_fn returns a
        # per-batch mean, as nn.CrossEntropyLoss does by default.
        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size
    if num_samples == 0:
        return 0.0
    # Divide by the samples actually processed rather than len(loader.dataset):
    # the two differ when the loader drops the last batch or uses a sampler.
    return total_loss / num_samples


def evaluate(model, loader, device):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is put in evaluation mode and gradients are disabled. The
    predicted class of each sample is the arg-max of its logits along
    dimension 1.

    Args:
        model: Module mapping a batch of images to per-class logits.
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Fraction of labels predicted correctly, as a float.
    """
    model.eval()
    num_correct, num_seen = 0, 0
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            predicted = torch.argmax(model(batch_images), dim=1)
            num_correct += predicted.eq(batch_labels).sum().item()
            num_seen += batch_labels.numel()
    return num_correct / num_seen


# Driver: train SimpleMNISTCNN on MNIST and report test accuracy after each epoch.
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader, test_loader = build_loaders(batch_size=128)
# The model is moved to `device` before the optimizer is built from its parameters.
model = SimpleMNISTCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# NOTE: rebinds `loss_fn`; unlike the Betti matching example in the first cell,
# this loss returns a single value rather than a (loss, metrics) pair.
loss_fn = nn.CrossEntropyLoss()

num_epochs = 3
# Each epoch: one pass over the training loader, then an evaluation on the test loader.
for epoch in range(1, num_epochs + 1):
    train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
    test_acc = evaluate(model, test_loader, device)
    print(f"epoch {epoch} | train loss {train_loss:.4f} | test acc {test_acc:.4f}")