In [2]:
from sklearn.metrics import roc_auc_score
import torch, numpy as np
from torch.utils.data import DataLoader


In [22]:
import torch, numpy as np
from types import SimpleNamespace

# ------------------------------------------------------------------
# 1.  Stub-out torch.distributed for single-process testing
# ------------------------------------------------------------------
class _FakeDist:
    def get_world_size(self): return 1
    def all_reduce(self, *_, **__): pass
    def all_gather(self, out_list, tensor): out_list[0].copy_(tensor)
    def broadcast_object_list(self, *_, **__): pass
import types, torch.distributed as dist
for fn in ["get_world_size", "all_reduce", "all_gather", "broadcast_object_list"]:
    setattr(dist, fn, getattr(_FakeDist(), fn))

# ------------------------------------------------------------------
# 2.  Minimal dummy model that outputs a 1×6 logit vector
# ------------------------------------------------------------------
class DummyModel(torch.nn.Module):
    def __init__(self, num_tasks=6):
        super().__init__()
        self.num_tasks = num_tasks
    def forward(self, x, spacing=None):
        B = 1          # your code slices to [1, tasks]
        return torch.randn(B, self.num_tasks, device=x.device)

# ------------------------------------------------------------------
# 3.  Build a fake val_loader producing 8 samples
# ------------------------------------------------------------------
def make_sample(target_vec):
    """
    target_vec: list/array of length 6 containing {0,1,-1,2}
    returns (chunks, labels, mask, spacing)
    """
    chunks  = torch.randn(3, 3, 448, 448)          # [D=3, 3, 448, 448]
    labels  = np.array(target_vec, dtype=np.float32)
    mask    = labels != -1                         # False where label == -1
    labels[labels == -1] = 0                       # placeholder, never used when mask False
    spacing = 1.0
    return chunks, labels, mask, spacing

val_samples = [
    make_sample([0, 0, 1, 0, -1, -1]),
    make_sample([0, 0, 0, 1, 0,  -1]),
    make_sample([0, 1, 0, 0, -1, -1]),
    make_sample([0, 0, 1, 0, 0,   1]),
    make_sample([0, -1,-1,-1,-1, -1]),
    make_sample([-1,-1,-1,-1,-1, -1]),    # all missing
    make_sample([1, 0, 1, 0, 1,  0]),     # <- has a spurious "2" for task2
]

val_loader = DataLoader(val_samples, batch_size=1, shuffle=False, collate_fn=lambda b: b[0])

# ------------------------------------------------------------------
# 4.  Minimal trainer holding only the pieces evaluate() needs
# ------------------------------------------------------------------
class MiniTrainer:
    def __init__(self):
        self.args = SimpleNamespace(max_chunks=4, rank=0)
        self.device = torch.device('cpu')
        self.model  = DummyModel().to(self.device)
        self.crit   = torch.nn.BCEWithLogitsLoss(reduction='none')
        self.label_cols = [f"task{i}" for i in range(6)]
    # --- paste your evaluate() here (no DDP ops will explode) ---
    def evaluate(self, val_loader):
        """Evaluate model on validation set using distributed evaluation"""
        self.model.eval()
        local_loss_sum = 0.0
        local_samples = 0
        local_preds = []
        local_targets = []
        local_masks = []
        
        # Process validation samples assigned to this rank
        for i, (chunks, labels, mask, spacing) in enumerate(val_loader):
            chunks = chunks.squeeze(1).to(self.device)
            
            # Apply max_chunks constraint
            if chunks.size(0) > self.args.max_chunks:
                mid_idx = chunks.size(0) // 2
                start_idx = max(0, mid_idx - self.args.max_chunks // 2)
                end_idx = min(chunks.size(0), start_idx + self.args.max_chunks)
                chunks = chunks[start_idx:end_idx]
            
            # Forward pass
            logits = self.model(chunks, spacing)
            
            # Calculate loss
            target = torch.tensor(labels, dtype=torch.float32, device=self.device)
            mask_tensor = torch.tensor(mask, dtype=torch.bool, device=self.device)
            #print(chunks.shape, logits.shape, target.shape, mask_tensor.shape)#torch.Size([max_chunks, 3, 448, 448]) torch.Size([1, 6]) torch.Size([6]) torch.Size([6])
            print("logit" ,logits )#logit tensor([[ 1.0563,  0.4275,  0.4167,  0.1464, -0.3098, -1.9752]],device='cuda:3')
            print("target" ,target)#target tensor([ 0., -1., -1., -1., -1., -1.], device='cuda:3')
            print("mask_tensor" ,mask_tensor)#mask_tensor tensor([ True, False, False, False, False, False], device='cuda:3')
            loss = 0.0
            for j in range(logits.size(1)):
                if mask_tensor[j]:
                    task_loss = self.crit(logits[0, j:j+1], target[j:j+1])
                    # Mori:the weight in loss is probably not properly applied! 
                    # Maybe later I can change logit where mask is False to 0.0. Actually in training and not here. 
                    loss += task_loss
            
            # Normalize loss
            active_tasks = mask_tensor.sum().item()
            if active_tasks > 0:
                loss = loss / active_tasks
                local_loss_sum += loss.item()
            
            # Store predictions, targets, and masks
            local_preds.append(logits.cpu())
            local_targets.append(torch.tensor(labels, dtype=torch.float32))
            local_masks.append(torch.tensor(mask, dtype=torch.bool))
            
            local_samples += 1
            print(local_preds)
            print(local_targets)
            print(local_masks)
        
        # Convert to tensors for gathering
        if local_preds:
            # Use stack for all tensors to ensure consistent [S, T] shape
            local_preds_tensor = torch.stack(local_preds, dim=0)
            local_targets_tensor = torch.stack(local_targets, dim=0)  
            local_masks_tensor = torch.stack(local_masks, dim=0)
        else:
            # Handle edge case where a rank might not get any validation samples
            num_tasks = len(self.label_cols) 
            local_preds_tensor = torch.zeros((0, num_tasks), dtype=torch.float32)
            local_targets_tensor = torch.zeros((0, num_tasks), dtype=torch.float32)
            local_masks_tensor = torch.zeros((0, num_tasks), dtype=torch.bool)

        
        print(local_preds_tensor.shape, local_targets_tensor.shape, local_masks_tensor.shape)
        #torch.Size([7, 1, 6]) torch.Size([7, 6]) torch.Size([7, 6])
        # Move tensors to device for all_gather
        local_preds_tensor = local_preds_tensor.to(self.device)
        local_targets_tensor = local_targets_tensor.to(self.device)
        local_masks_tensor = local_masks_tensor.to(self.device)
        
        # Compute total loss across all ranks
        total_samples = torch.tensor([local_samples], dtype=torch.float32, device=self.device)
        total_loss_sum = torch.tensor([local_loss_sum], dtype=torch.float32, device=self.device)
        print(local_samples, local_loss_sum)#7 4.781846523284912
        
        dist.all_reduce(total_samples, op=dist.ReduceOp.SUM)
        dist.all_reduce(total_loss_sum, op=dist.ReduceOp.SUM)
        
        avg_loss = total_loss_sum.item() / max(1, total_samples.item())
        total_samples = int(total_samples.item())
        
        # Initialize metrics
        metrics = {
            'samples_evaluated': total_samples,
            'avg_loss': avg_loss,
        }
        
        # Get world size for gathering
        world_size = dist.get_world_size()
        
        # Padding tensors to the same size for all_gather
        # First find the maximum size across all ranks
        local_size = torch.tensor([local_preds_tensor.shape[0]], dtype=torch.long, device=self.device)
        all_sizes = [torch.ones(1, dtype=torch.long, device=self.device) for _ in range(world_size)]
        dist.all_gather(all_sizes, local_size)
        max_size = max([size.item() for size in all_sizes])
        
        
        # Pad local tensors to max_size
        num_tasks = len(self.label_cols)
        if local_preds_tensor.shape[0] < max_size:
            padding_size = max_size - local_preds_tensor.shape[0]
            # Pad with zeros
            preds_padding = torch.zeros((padding_size, num_tasks), dtype=torch.float32, device=self.device)
            targets_padding = torch.zeros((padding_size, num_tasks), dtype=torch.float32, device=self.device)
            masks_padding = torch.zeros((padding_size, num_tasks), dtype=torch.bool, device=self.device)
            
            # Concatenate original tensor with padding
            local_preds_tensor = torch.cat([local_preds_tensor, preds_padding], dim=0)
            local_targets_tensor = torch.cat([local_targets_tensor, targets_padding], dim=0)
            local_masks_tensor = torch.cat([local_masks_tensor, masks_padding], dim=0)
        
        # Use all_gather for efficient collection on GPU
        gathered_preds = [torch.empty_like(local_preds_tensor) for _ in range(world_size)]
        gathered_targets = [torch.empty_like(local_targets_tensor) for _ in range(world_size)]
        gathered_masks = [torch.empty_like(local_masks_tensor) for _ in range(world_size)]
        
        # Gather data from all ranks
        dist.all_gather(gathered_preds, local_preds_tensor)
        dist.all_gather(gathered_targets, local_targets_tensor)
        dist.all_gather(gathered_masks, local_masks_tensor)
        
        print("## 33 ##", gathered_preds[0].shape, gathered_targets[0].shape, gathered_masks[0].shape)
        # torch.Size([7, 1, 6]) torch.Size([7, 6]) torch.Size([7, 6])
        print("## 33 ##", len(gathered_preds), len(gathered_targets), len(gathered_masks))
        # 1 1 1
        print("#########")
        print(all_sizes, local_size) # [tensor([7])] tensor([7])
        # Remove padding and convert to numpy for metric calculation
        valid_preds = []
        valid_targets = []
        valid_masks = []
        
        for i, size in enumerate([size.item() for size in all_sizes]):
            if size > 0:
                valid_preds.append(gathered_preds[i][:size])
                valid_targets.append(gathered_targets[i][:size])
                valid_masks.append(gathered_masks[i][:size])
        print("valid_preds", valid_preds)
        print("valid_targets", valid_targets)
        print("valid_masks", valid_masks)
        # Combine all gathered data
        all_preds = torch.cat(valid_preds, dim=0).cpu().numpy()
        all_targets = torch.cat(valid_targets, dim=0).cpu().numpy()
        all_masks = torch.cat(valid_masks, dim=0).cpu().numpy()

        print("all_preds", all_preds) #all_preds (7, 1, 6)
        print("all_targets", all_targets)#all_targets (7, 6)
        print("all_masks", all_masks)#all_masks (7, 6)
        all_preds = torch.cat(valid_preds, dim=0).squeeze(1).cpu().numpy()
        #print(all_preds.shape) #(7, 1, 6)
        for i in range(all_preds.shape[1]):
            valid_idx = all_masks[:, i] #gather all of the ith pred indices that not masked
            print(valid_idx)
            if valid_idx.sum() == 0:
                continue

            task_preds   = all_preds[valid_idx, i]
            task_targets = all_targets[valid_idx, i]

            # 1️⃣ keep only entries where target is 0 or 1
            keep = (task_targets == 0) | (task_targets == 1)
            task_preds, task_targets = task_preds[keep], task_targets[keep]

            # 2️⃣ debug: if anything was dropped, print the unique values once
            if self.args.rank == 0 and np.unique(task_targets).size < np.unique(all_targets[valid_idx, i]).size:
                bad_vals = set(all_targets[valid_idx, i]) - {0, 1}
                print(f"[WARN] Task {i}: dropped labels {bad_vals}")
            print("task_targets", task_targets)
            print("task_targets__", np.unique(task_targets).size)
            # 3️⃣ need both classes left to compute AUC
            if np.unique(task_targets).size < 2:
                metrics[f'auc_task{i}'] = np.nan
                metrics[f'acc_task{i}'] = np.nan
                continue

            # --- DEBUG BLOCK ----------------------------------------------------
            bad_vals = np.setdiff1d(task_targets, [0, 1])
            if bad_vals.size > 0:
                # Print once per task per validation
                print(f"[Rank {self.args.rank}] Task {i} – bad labels detected: {bad_vals}")
                # Optional: locate the offending rows (expensive, so enable only if needed)
                bad_rows = np.where(~np.isin(task_targets, [0, 1]))[0]
                print(f"Bad rows (local indices): {bad_rows[:20]} ...")
                # You could also log IDs if your dataset returns them.

                # Drop the invalid entries so metrics still compute
                keep = np.isin(task_targets, [0, 1])
                task_preds, task_targets = task_preds[keep], task_targets[keep]
            # --------------------------------------------------------------------
            # 4️⃣ metrics
            auc = roc_auc_score(task_targets, task_preds)
            metrics[f'auc_task{i}'] = auc

            prob = 1 / (1 + np.exp(-task_preds))
            acc  = (prob > 0.5).astype(int).mean()
            metrics[f'acc_task{i}'] = acc
        print(metrics)
trainer = MiniTrainer()
trainer.evaluate(val_loader)      # <-- runs end-to-end


logit tensor([[ 0.6554,  1.6408,  0.5262, -0.8042,  1.4421, -0.1529]])
target tensor([0., 0., 1., 0., 0., 0.])
mask_tensor tensor([ True,  True,  True,  True, False, False])
[tensor([[ 0.6554,  1.6408,  0.5262, -0.8042,  1.4421, -0.1529]])]
[tensor([0., 0., 1., 0., 0., 0.])]
[tensor([ True,  True,  True,  True, False, False])]
logit tensor([[ 0.7631,  0.0692,  0.1290, -0.9055,  0.9705, -1.3907]])
target tensor([0., 0., 0., 1., 0., 0.])
mask_tensor tensor([ True,  True,  True,  True,  True, False])
[tensor([[ 0.6554,  1.6408,  0.5262, -0.8042,  1.4421, -0.1529]]), tensor([[ 0.7631,  0.0692,  0.1290, -0.9055,  0.9705, -1.3907]])]
[tensor([0., 0., 1., 0., 0., 0.]), tensor([0., 0., 0., 1., 0., 0.])]
[tensor([ True,  True,  True,  True, False, False]), tensor([ True,  True,  True,  True,  True, False])]
logit tensor([[-0.4765,  1.3066,  0.6974, -0.1251,  1.5011, -0.7810]])
target tensor([0., 1., 0., 0., 0., 0.])
mask_tensor tensor([ True,  True,  True,  True, False, False])
[tensor([[ 0.655

In [23]:
all_sizes = [torch.tensor(7)] #[tensor([7])]
print(all_sizes)
for i, size in enumerate([size.item() for size in all_sizes]):
    print(i, size)

[tensor(7)]
0 7
