In [1]:
import time, numpy as np, torch, torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ---- load scripted model (no class/arch needed) ----
# map_location places the archive's tensors on `device` while loading.
model = torch.jit.load("unet_tf_parity_scripted.pt", map_location=device)
model.eval()        # inference mode
model.to(device)

# ---- pick your loss (same as train) ----
# If you trained with Focal-Tversky on probs:
class FocalTverskyLoss(torch.nn.Module):
    """Focal-Tversky loss computed over every element of the inputs at once.

    With soft counts
        TP = sum(p * t),  FP = sum(p * (1 - t)),  FN = sum((1 - p) * t)
    the Tversky index and the loss are
        TI   = (TP + eps) / (TP + alpha * FN + beta * FP + eps),  beta = 1 - alpha
        loss = (1 - TI) ** gamma

    Args:
        alpha: weight applied to the false-negative term; ``beta = 1 - alpha``
            is applied to the false-positive term.
        gamma: exponent applied to ``1 - TI``.
        eps: constant added to numerator and denominator so the ratio is
            defined when all counts are zero.
    """

    def __init__(self, alpha=0.95, gamma=3.1, eps=1.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.eps = eps
        self.beta = 1 - alpha

    def forward(self, p, t):
        """Return the scalar loss for predictions ``p`` against targets ``t``.

        Args:
            p: predictions; the formula treats them as probabilities
                (values in [0, 1] -- not checked here).
            t: targets with the same number of elements as ``p``. Float
                tensors are used as-is; bool/integer masks are cast to
                ``p``'s dtype.

        Returns:
            A 0-dim tensor holding ``(1 - TI) ** gamma``.
        """
        # reshape (unlike view) also accepts non-contiguous tensors.
        p = p.reshape(-1)
        t = t.reshape(-1)
        if not t.is_floating_point():
            # `1 - t` is not defined for bool tensors; cast masks to p's dtype.
            t = t.to(p.dtype)

        true_pos = (p * t).sum()
        false_pos = (p * (1 - t)).sum()
        false_neg = ((1 - p) * t).sum()

        tversky = (true_pos + self.eps) / (
            true_pos + self.alpha * false_neg + self.beta * false_pos + self.eps
        )
        return torch.pow(1 - tversky, self.gamma)

# Module-level loss instance used by evaluate(); eps keeps its default of 1.0.
criterion = FocalTverskyLoss(alpha=0.95, gamma=3.1)

# ---- mask resize helper (your TF-parity step) ----
def resize_masks_to(pred, masks):
    """Resize ``masks`` to the spatial size of ``pred`` and re-binarize them.

    The masks are converted to float, bilinearly interpolated to the last two
    dimensions of ``pred`` (``align_corners=False``), rounded up with ``ceil``
    and limited to the range [0, 1], so any output pixel with a positive
    interpolated value becomes 1.

    Args:
        pred: tensor whose trailing two dimensions give the target (H, W).
        masks: mask tensor to resize.

    Returns:
        Float tensor of resized masks with values clamped to [0, 1].
    """
    target_hw = tuple(pred.shape[-2:])
    resized = F.interpolate(
        masks.to(torch.float32),
        size=target_hw,
        mode="bilinear",
        align_corners=False,
    )
    rounded_up = resized.ceil()
    return torch.clamp(rounded_up, min=0, max=1)

# ---- evaluation loop (AUC computed once at end) ----
def _precision_recall_f1(tp, fp, fn, eps=1e-8):
    """Return (precision, recall, F1) from confusion counts; eps avoids 0/0."""
    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    return prec, rec, f1


@torch.no_grad()
def evaluate(loader, tag="Test", print_every=10):
    """Evaluate the module-level ``model`` on ``loader`` and print/return metrics.

    For each batch the model output is compared against the masks resized by
    ``resize_masks_to``; the loss comes from the module-level ``criterion``.
    Precision/recall/F1 are accumulated element-wise at a 0.5 threshold, and
    ROC-AUC is computed once at the end over all collected elements.

    Args:
        loader: iterable of ``(inputs, masks)`` batches that exposes
            ``loader.dataset`` with a length (used for the progress display).
        tag: label printed in front of every progress/summary line.
        print_every: print a progress line every this many batches.

    Returns:
        Tuple ``(loss, auc, prec, rec, f1)``. ``loss`` is the batch-size
        weighted mean over the samples actually processed (NaN if none were);
        ``auc`` is NaN when it cannot be computed.
    """
    start = time.time()
    model.eval()
    total = len(loader.dataset)
    seen = 0
    running_loss = 0.0
    tp = fp = fn = 0.0
    preds_all, targs_all = [], []

    for b, (xb, yb) in enumerate(loader, start=1):
        xb, yb = xb.to(device), yb.to(device)
        out = model(xb)                 # expected to be sigmoid probabilities
        yb_r = resize_masks_to(out, yb)
        loss = criterion(out, yb_r)

        bs = xb.size(0)
        seen += bs
        running_loss += float(loss.item()) * bs

        # reshape (not view) so a non-contiguous model output cannot raise here
        p = out.detach().cpu().reshape(-1)
        t = yb_r.detach().cpu().reshape(-1)
        preds_all.append(p.numpy())
        targs_all.append(t.numpy())

        # stream PRF1
        p_bin = (p >= 0.5).float()
        tp += float((p_bin * t).sum())
        fp += float((p_bin * (1 - t)).sum())
        fn += float(((1 - p_bin) * t).sum())

        if (b % print_every == 0) or (seen == total):
            prec, rec, f1 = _precision_recall_f1(tp, fp, fn)
            avg_loss = running_loss / seen
            print(f"\r[{tag}] {seen}/{total} ex | loss={avg_loss:.4f} | F1 {f1:.4f} | P {prec:.4f} | R {rec:.4f} | {time.time()-start:.1f}s",
                  end='', flush=True)

    print()  # newline
    # AUC is computed once per split over every collected element.
    try:
        P = np.concatenate(preds_all, 0)
        T = np.concatenate(targs_all, 0)
        auc = roc_auc_score(T, P)
    except Exception as exc:
        # Best effort: keep the other metrics, but say why AUC is missing
        # (e.g. no batches, or only one class present in the targets).
        print(f"[{tag}] AUC unavailable ({type(exc).__name__}: {exc})")
        auc = float('nan')

    prec, rec, f1 = _precision_recall_f1(tp, fp, fn)
    # Average over the samples actually processed: `total` can differ from
    # `seen` (drop_last, subset samplers), and `seen` may be 0 for an empty loader.
    loss = running_loss / seen if seen else float('nan')
    print(f"[{tag}] Loss {loss:.4f} | AUC {auc:.4f} | F1 {f1:.4f} | P {prec:.4f} | R {rec:.4f}")
    return loss, auc, prec, rec, f1


ValueError: The provided filename unet_tf_parity_scripted.pt does not exist

In [2]:
# NOTE: a whole-module checkpoint is a pickle that stores only a reference to
# its class, so `UNetTFParity` must be defined or imported in this session
# before this line runs; otherwise unpickling fails with
# "AttributeError: Can't get attribute 'UNetTFParity' on <module '__main__'>".
# SECURITY: weights_only=False unpickles arbitrary objects and can execute
# code embedded in the file -- only load checkpoints from a trusted source.
# map_location remaps stored tensors onto an available device (same rule as
# the `device` selection) so a GPU-saved checkpoint loads on a CPU-only machine.
model = torch.load(
    "./model1.pt",
    map_location="cuda" if torch.cuda.is_available() else "cpu",
    weights_only=False,
)

AttributeError: Can't get attribute 'UNetTFParity' on <module '__main__'>