<a href="https://colab.research.google.com/github/mehravehj/Debiased_supernet_sampling/blob/main/Pooling_post_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Post process training

1.   Load the saved checkpoint (for a given test number and epoch)
2.   Compute outputs, feature maps, and gradients on a validation mini-batch
3.   Save them in separate folders



# Definitions

In [1]:
from google.colab import drive

# Mount Google Drive at /content/drive; the checkpoint paths used later in this
# notebook live under /content/drive/MyDrive.
# Pass force_remount=True to drive.mount() if an existing mount must be refreshed.
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Block 1: All Definitions (Complete with modified ResNet20)

import os
import math
import random # For shuffle in data_loader
from datetime import datetime
from itertools import combinations
from typing import Optional, Tuple, List, Dict, Union, Type # Added Type for block_type hint

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision
import torchvision.transforms as transforms
from torch.distributions.categorical import Categorical


# --- From utility_functions.py ---
def string_to_list(x: str, leng: int) -> List[int]:
    """
    Parse a channel specification string into a list of ``leng`` integers.

    A comma-separated string (e.g. ``"16,32,64"``) is split into its integer
    parts and must contain exactly ``leng`` values.  A string without commas
    is read as one integer that is repeated ``leng`` times.

    Raises:
        TypeError: if ``x`` is not a string.
        ValueError: if ``leng`` is not a positive integer, if a token is not
            an integer, or if the number of comma-separated values differs
            from ``leng``.
    """
    if not isinstance(x, str):
        raise TypeError(f"Input 'x' must be a string, got {type(x)}")
    if not isinstance(leng, int) or leng <= 0:
        raise ValueError(f"'leng' must be a positive integer, got {leng}")

    if ',' not in x:
        # Single value: broadcast it to every position.
        try:
            repeated_value = int(x.strip())
        except ValueError:
            raise ValueError(f"Channel string '{x}' is not a valid single integer or a comma-separated list of integers.")
        return [repeated_value] * leng

    # Comma-separated list: one explicit value per position.
    values: List[int] = []
    try:
        for token in x.split(','):
            values.append(int(token.strip()))
    except ValueError as e:
        raise ValueError(f"Invalid number in comma-separated string '{x}': {e}")

    if len(values) != leng:
        raise ValueError(f"Channel string '{x}' provides {len(values)} values, but network depth 'leng' is {leng}. They must match.")
    return values

# --- From search_space_design.py ---
def create_search_space(num_layers: int, num_scales: int) -> Tuple[tuple, int]:
    """
    Enumerate every pooling placement for a network of ``num_layers`` blocks.

    A path is a tuple of ``num_layers`` flags; flag ``i`` set to 1 means a
    pooling operation is applied before block ``i``.  The first flag is always
    0 (no pooling before the first block), and exactly ``num_scales - 1`` of
    the remaining ``num_layers - 1`` flags are set.

    Args:
        num_layers: number of blocks in the network; 0 yields an empty space.
        num_scales: number of resolutions, i.e. number of poolings plus one.

    Returns:
        ``(paths, number_of_paths)`` where ``paths`` is a tuple of path tuples
        in the order produced by ``itertools.combinations``.

    Raises:
        ValueError: if ``num_layers`` is negative, ``num_scales`` is not
            positive, or ``num_scales`` exceeds ``num_layers``.
    """
    if num_layers < 0:
        raise ValueError("num_layers cannot be negative.")
    if num_scales <= 0:
        raise ValueError("num_scales must be positive.")

    if num_layers == 0:
        print('No layers defined (num_layers=0), 0 paths created.')
        return tuple(), 0

    if num_scales > num_layers:
        raise ValueError(f"num_scales ({num_scales}) cannot exceed num_layers ({num_layers}) for this search space design when num_layers > 0.")

    # After the guards above 0 <= num_pooling <= num_slots always holds, so
    # combinations() yields at least one placement and no further range
    # checks are needed.
    num_pooling = num_scales - 1
    num_slots = num_layers - 1  # positions after the first block

    paths_tuple = tuple(
        (0,) + tuple(1 if slot in pooled_slots else 0 for slot in range(num_slots))
        for pooled_slots in combinations(range(num_slots), num_pooling)
    )
    number_paths = len(paths_tuple)

    print(f'All {number_paths} paths created.')
    if 0 < number_paths < 20:  # only list the paths when there are few of them
        print(paths_tuple)
    return paths_tuple, number_paths


def init_path_logit(num_paths: int, initial_logits: float = 1.0) -> torch.Tensor:
    """
    Build the initial per-path logit vector.

    Returns a 1-D float tensor with ``num_paths`` entries, each equal to
    ``initial_logits`` (an empty tensor when ``num_paths`` is 0).

    Raises:
        ValueError: if ``num_paths`` is negative.
    """
    if num_paths < 0:
        raise ValueError("Number of paths cannot be negative.")
    # An empty list gives the same empty tensor the zero-path case needs.
    return torch.FloatTensor([initial_logits] * num_paths)

def sample_uniform(sample_weights: torch.Tensor, paths: tuple) -> Tuple[int, tuple]:
    """
    Sample one path according to the categorical distribution given by logits.

    Args:
        sample_weights: logits, one per path.  Singleton dimensions are
            squeezed away, so shapes such as ``(1, N)`` are accepted; a
            scalar or ``(1, 1)`` tensor is treated as a single logit.
        paths: tuple of candidate paths; ``len(paths)`` must equal the number
            of logits.

    Returns:
        ``(path_index, paths[path_index])``.

    Raises:
        ValueError: if ``paths`` or ``sample_weights`` is empty, the logits
            cannot be reduced to one dimension, their count differs from
            ``len(paths)``, or they contain non-finite values.
        IndexError: if the sampled index falls outside ``paths``.
    """
    if not paths:
        raise ValueError("Cannot sample from empty paths tuple.")
    if sample_weights.nelement() == 0:
        raise ValueError("Cannot sample path: sample_weights tensor is empty (likely num_paths was 0).")

    if sample_weights.dim() > 1:
        sample_weights = sample_weights.squeeze()
        if sample_weights.dim() > 1:
            raise ValueError("sample_weights must be a 1D tensor of logits.")
    if sample_weights.dim() == 0:
        # A lone logit (scalar input, or a (1, 1) tensor after squeeze) must
        # be restored to 1-D because Categorical rejects 0-dim logits.
        sample_weights = sample_weights.reshape(1)
    if sample_weights.nelement() != len(paths):
        raise ValueError(f"Number of sample_weights ({sample_weights.nelement()}) must match number of paths ({len(paths)}).")

    try:
        prob_distribution = Categorical(logits=sample_weights)
    except (RuntimeError, ValueError) as e:
        # Argument validation in torch.distributions reports bad parameters
        # as ValueError; give a clearer message for the non-finite case.
        if not torch.isfinite(sample_weights).all():
            raise ValueError("Logits in sample_weights must be finite.") from e
        raise
    path_index = int(prob_distribution.sample().item())
    if not (0 <= path_index < len(paths)):
        raise IndexError(f"Sampled path_index {path_index} is out of bounds for paths list of length {len(paths)}.")
    return path_index, paths[path_index]


# --- From lr_scheduler.py ---
class CosineAnnealingWarmupRestarts(_LRScheduler):
    """
    Cosine-annealing learning-rate schedule with linear warmup and restarts.

    Within a cycle the learning rate rises linearly from ``min_lr`` to the
    cycle's peak over ``warmup_steps`` steps, then follows a half cosine from
    the peak back down to ``min_lr``.  When a cycle ends, the peak is
    multiplied by ``gamma`` and the non-warmup part of the cycle length by
    ``cycle_mult``.

    Args:
        optimizer: wrapped optimizer; its learning rates are overwritten.
        first_cycle_steps: length of the first cycle in steps (must be > 0).
        cycle_mult: growth factor of the cycle length after each restart.
        max_lr: peak learning rate of the first cycle.
        min_lr: learning rate at the start and end of each cycle.
        warmup_steps: number of linear warmup steps at the start of a cycle
            (must satisfy ``0 <= warmup_steps < first_cycle_steps``).
        gamma: factor applied to the peak learning rate after each restart.
        last_epoch: index of the last step taken; ``-1`` starts from scratch.

    Raises:
        ValueError: if ``first_cycle_steps`` or ``warmup_steps`` violate the
            constraints above.
    """

    def __init__(self,
                 optimizer : torch.optim.Optimizer,
                 first_cycle_steps : int,
                 cycle_mult : float = 1.,
                 max_lr : float = 0.1,
                 min_lr : float = 0.001,
                 warmup_steps : int = 0,
                 gamma : float = 1.,
                 last_epoch : int = -1
        ):
        if not first_cycle_steps > 0:
            raise ValueError("first_cycle_steps must be positive.")
        if warmup_steps < 0:
            raise ValueError("warmup_steps cannot be negative.")
        if warmup_steps >= first_cycle_steps:
            raise ValueError("warmup_steps must be less than first_cycle_steps.")

        self.first_cycle_steps = first_cycle_steps
        self.cycle_mult = cycle_mult
        self.base_max_lr = max_lr   # peak of the very first cycle
        self.max_lr = max_lr        # peak of the current cycle
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.gamma = gamma
        self.cur_cycle_steps = first_cycle_steps
        self.cycle = 0
        self.step_in_cycle = last_epoch
        # The base-class constructor performs an initial step(), so every
        # attribute used by step()/get_lr() has to exist before this call.
        super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)
        self.init_lr()

    def init_lr(self):
        """Reset every parameter group (and ``base_lrs``) to ``min_lr``."""
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)

    def get_lr(self):
        """Return the learning rate of each parameter group for the current position in the cycle."""
        if self.step_in_cycle == -1:
            return self.base_lrs
        elif self.step_in_cycle < self.warmup_steps:
            if self.warmup_steps == 0:  # defensive: never divide by a zero warmup length
                return [self.max_lr for _ in self.base_lrs]
            # Linear ramp from base_lr up to max_lr.
            return [(self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps + base_lr
                    for base_lr in self.base_lrs]
        else:
            denominator = self.cur_cycle_steps - self.warmup_steps
            if denominator <= 0:
                # A cycle_mult below 1 can shrink the cycle down to its warmup
                # part; fall back to the floor instead of dividing by zero.
                return [self.min_lr for _ in self.base_lrs]
            # Half cosine from max_lr down to base_lr.
            return [base_lr + (self.max_lr - base_lr) *
                    (1 + math.cos(math.pi * (self.step_in_cycle - self.warmup_steps) / denominator)) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch: Optional[int] = None):
        """
        Advance the schedule and write the new learning rates to the optimizer.

        With ``epoch=None`` the schedule moves forward by one step.  With an
        explicit ``epoch`` the cycle index, the position inside the cycle and
        the cycle length are recomputed from scratch for that epoch.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # Restart: begin the next cycle and rescale its non-warmup part.
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
        elif epoch >= self.first_cycle_steps:
            if self.cycle_mult == 1.:
                self.step_in_cycle = epoch % self.first_cycle_steps
                self.cycle = epoch // self.first_cycle_steps
            elif self.cycle_mult < 1.0:  # non-increasing cycle length
                print(f"Warning: CosineAnnealingWarmupRestarts encountered cycle_mult={self.cycle_mult} <= 1.0. Using simpler cycle calculation.")
                self.cycle = epoch // self.first_cycle_steps
                self.step_in_cycle = epoch % self.first_cycle_steps
                self.cur_cycle_steps = self.first_cycle_steps
            else:  # increasing cycle length
                # Invert the geometric-series sum to find the cycle containing 'epoch'.
                log_arg = epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1
                if log_arg <= 0:  # not expected with cycle_mult > 1 and epoch >= 0
                    print(f"Warning: CosineAnnealingWarmupRestarts encountered non-positive log argument: {log_arg}. Using simpler cycle calculation.")
                    self.cycle = epoch // self.first_cycle_steps
                    self.step_in_cycle = epoch % self.first_cycle_steps
                    self.cur_cycle_steps = self.first_cycle_steps
                else:
                    n = int(math.log(log_arg, self.cycle_mult))
                    self.cycle = n
                    # Steps consumed by the first n cycles: T_0 * (mult^n - 1) / (mult - 1)
                    sum_geometric_progression = self.first_cycle_steps * (self.cycle_mult**n - 1) / (self.cycle_mult - 1)
                    self.step_in_cycle = epoch - int(sum_geometric_progression)
                    self.cur_cycle_steps = int(self.first_cycle_steps * self.cycle_mult**n)
        else:
            # Inside the first cycle.  The cycle index must be reset as well,
            # otherwise jumping back to an early epoch keeps the gamma-decayed
            # peak of whatever cycle was active before.
            self.cycle = 0
            self.cur_cycle_steps = self.first_cycle_steps
            self.step_in_cycle = epoch

        self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr_val in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr_val


# --- From create_model.py ---
def pooling_func(x: torch.Tensor) -> torch.Tensor:
    """Apply 2x2 max pooling to ``x``."""
    return F.max_pool2d(x, kernel_size=2)

class ResBasicBlock(nn.Module):
    """
    Residual block: two 3x3 conv + batch-norm stages with a skip connection.

    The batch-norm layers have no learnable affine parameters.  When
    ``downsample`` is given it is applied to the block input to produce the
    skip branch; otherwise the input itself is added back.
    """

    def __init__(self, in_planes: int, out_planes: int, stride: int = 1, downsample: Optional[nn.Module] = None):
        super().__init__()
        shared_conv_args = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_planes, out_planes, stride=stride, **shared_conv_args)
        self.bn1 = nn.BatchNorm2d(out_planes, affine=False)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, stride=1, **shared_conv_args)
        self.bn2 = nn.BatchNorm2d(out_planes, affine=False)
        self.downsample = downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Main branch: conv -> bn -> relu -> conv -> bn.
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        # Skip branch: identity unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        main += shortcut
        return self.relu(main)

class BasicConvBlock(nn.Module):
    """
    Plain conv -> batch-norm -> ReLU block.

    The convolution uses stride 1 and padding ``kernel_size // 2``, and the
    batch-norm layer has no learnable affine parameters.
    """

    def __init__(self, in_planes: int, out_planes: int, kernel_size: int = 3):
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride=1, padding=half_kernel, bias=False)
        self.bn = nn.BatchNorm2d(out_planes, affine=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.bn(self.conv(x)))

class ResNet20(nn.Module):
    """
    Residual network whose pooling positions are selected at run time.

    The network is a ``BasicConvBlock`` stem followed by ``num_layers - 1``
    residual blocks, a global average pool and a bias-free linear classifier.
    A path (see ``set_path``) holds one flag per block; a truthy flag applies
    ``pooling_func`` to the activations right before that block.
    """

    def __init__(self, block_type: Type[ResBasicBlock], num_layers: int, channels: List[int], num_classes: int = 10):
        super().__init__()
        if num_layers <= 0:
            raise ValueError("num_layers must be positive.")
        if len(channels) != num_layers:
            raise ValueError(f"Length of 'channels' list ({len(channels)}) must match 'num_layers' ({num_layers}).")

        self.num_layers = num_layers

        # Stem consumes the 3-channel image; _make_layer tracks the running
        # channel count through self.current_inplanes.
        stem_width = channels[0]
        blocks = [BasicConvBlock(3, stem_width)]
        self.current_inplanes = stem_width
        for layer_idx in range(1, num_layers):
            blocks.append(self._make_layer(block_type, channels[layer_idx]))

        self.res_blocks = nn.ModuleList(blocks)
        self.adaptive_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # The classifier reads the channel count of the last block.
        self.fc = nn.Linear(self.current_inplanes, num_classes, bias=False)

        self.path_pooling_config: Optional[Tuple[int, ...]] = None

    def _make_layer(self, block_type: Type[ResBasicBlock], planes: int, stride_for_block: int = 1) -> nn.Module:
        """Create one residual block, adding a 1x1 projection shortcut when the width changes."""
        shortcut = None
        if planes != self.current_inplanes:
            shortcut = nn.Sequential(
                nn.Conv2d(self.current_inplanes, planes, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(planes, affine=False),
            )
        block = block_type(self.current_inplanes, planes, stride=stride_for_block, downsample=shortcut)
        self.current_inplanes = planes  # input width of the next block
        return block

    def set_path(self, path: Tuple[int, ...]):
        """Select the pooling configuration used by subsequent forward passes."""
        if len(path) != self.num_layers:
            raise ValueError(f"Path length ({len(path)}) must match number of layers ({self.num_layers}).")
        self.path_pooling_config = path

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network along the current path and return the class logits."""
        if self.path_pooling_config is None:
            raise RuntimeError("Path not set. Call set_path(path) before forward pass.")

        out = x
        for idx in range(self.num_layers):
            pool_first = self.path_pooling_config[idx]
            if pool_first:
                out = pooling_func(out)
            out = self.res_blocks[idx](out)

        pooled = self.adaptive_avg_pool(out)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

    def feature_extractor(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Collect intermediate activations along the current path.

        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            List[torch.Tensor]: detached copies of the output of every block
                                in ``self.res_blocks`` (in order), followed by
                                the output of ``self.adaptive_avg_pool``.
        """
        if self.path_pooling_config is None:
            raise RuntimeError("Path not set. Call set_path(path) before calling feature_extractor.")

        collected: List[torch.Tensor] = []
        out = x
        for idx in range(self.num_layers):
            pool_first = self.path_pooling_config[idx]
            if pool_first:
                out = pooling_func(out)
            out = self.res_blocks[idx](out)
            collected.append(out.detach().clone())

        pooled = self.adaptive_avg_pool(out)
        collected.append(pooled.detach().clone())
        return collected


# --- From data_loader.py ---
def _data_transforms_cifar10() -> Tuple[transforms.Compose, transforms.Compose]:
    """
    Build the CIFAR-10 preprocessing pipelines.

    Returns:
        ``(train_transform, test_transform)``.  Both convert to a tensor and
        normalize with the per-channel statistics below; the training pipeline
        additionally applies a padded random crop and a random horizontal flip
        first.
    """
    channel_mean = [0.49139968, 0.48215827, 0.44653124]
    channel_std = [0.24703233, 0.24348505, 0.26158768]

    def _tensor_and_normalize():
        # Fresh instances per pipeline, shared by train and test.
        return [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]

    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    train_transform = transforms.Compose(augmentation + _tensor_and_normalize())
    test_transform = transforms.Compose(_tensor_and_normalize())
    return train_transform, test_transform

def _validation_set_indices(total_train_samples: int, valid_percent: float) -> Tuple[List[int], List[int]]:
    """
    Randomly split ``range(total_train_samples)`` into train and validation indices.

    ``int(valid_percent * total_train_samples)`` indices go to validation and
    the rest to training.  The shuffle uses the global ``random`` module
    state, so seed it beforehand for a reproducible split.

    Returns:
        ``(train_indices, val_indices)``.

    Raises:
        ValueError: if ``valid_percent`` is outside ``[0.0, 1.0)``.
    """
    if not (0.0 <= valid_percent < 1.0):
        raise ValueError("valid_percent must be between 0.0 and almost 1.0.")

    val_count = int(valid_percent * total_train_samples)
    train_count = total_train_samples - val_count
    print(f'Total original training samples: {total_train_samples}')
    print(f'New training size: {train_count}, Validation size: {val_count}')

    shuffled = list(range(total_train_samples))
    random.shuffle(shuffled)
    # The head of the shuffled order is training, the tail is validation.
    return shuffled[:train_count], shuffled[train_count:]

def get_data_loaders(
    dataset_name: str,
    valid_percent: float,
    batch_size: int,
    dataset_dir: str = '~/data/',
    workers: int = 2
) -> Tuple[DataLoader, Optional[DataLoader], Tuple[List[int], List[int]]]:
    """
    Create training and validation loaders from the CIFAR-10 training split.

    The training split is divided by ``_validation_set_indices``.  The
    training loader uses the augmenting transform and drops the last
    incomplete batch; the validation loader reads the same images through the
    non-augmenting transform.

    Args:
        dataset_name: only ``'CIFAR10'`` is supported.
        valid_percent: fraction of the training split held out for validation.
        batch_size: batch size of both loaders.
        dataset_dir: dataset root; ``~`` is expanded.
        workers: number of loader worker processes.

    Returns:
        ``(train_loader, validation_loader, (train_indices, val_indices))``;
        ``validation_loader`` is ``None`` when no samples are held out.

    Raises:
        Exception: if ``dataset_name`` is not supported.
    """
    dataset_dir = os.path.expanduser(dataset_dir)
    if dataset_name != 'CIFAR10':
        raise Exception(f'Dataset {dataset_name} not supported!')

    augmenting_tf, plain_tf = _data_transforms_cifar10()
    cifar_args = dict(root=dataset_dir, train=True, download=True)
    full_trainset = torchvision.datasets.CIFAR10(transform=augmenting_tf, **cifar_args)
    # Same underlying training images, but without augmentation, for validation.
    valset_for_loader = torchvision.datasets.CIFAR10(transform=plain_tf, **cifar_args)

    train_indices, val_indices = _validation_set_indices(len(full_trainset), valid_percent)

    loader_args = dict(batch_size=batch_size, num_workers=workers, pin_memory=True)
    train_loader = DataLoader(
        full_trainset,
        sampler=SubsetRandomSampler(train_indices),
        drop_last=True,
        **loader_args,
    )

    validation_loader = None
    if valid_percent > 0.0 and val_indices:  # skip when nothing was held out
        validation_loader = DataLoader(
            valset_for_loader,
            sampler=SubsetRandomSampler(val_indices),
            drop_last=False,
            **loader_args,
        )
    return train_loader, validation_loader, (train_indices, val_indices)


# --- From NAS_trainer.py (now ModelTrainer) ---
def create_single_model(layers: int, channels: List[int], num_classes: int = 10) -> nn.Module:
    """Build a ``ResNet20`` of ``layers`` blocks made of ``ResBasicBlock`` with the given per-block channel widths."""
    return ResNet20(
        block_type=ResBasicBlock,
        num_layers=layers,
        channels=channels,
        num_classes=num_classes,
    )

def create_optimizer_and_scheduler(
    sched_type: str,
    model: nn.Module,
    lr: float,
    momentum: float,
    weight_decay: float,
    epochs: int,
    min_lr: float,
    first_cycle_steps: int,
    cycle_mult: float,
    warmup_steps: int,
    gamma: float
) -> Tuple[optim.Optimizer, Optional[_LRScheduler]]:
    """
    Create an SGD optimizer for ``model`` and, optionally, an LR scheduler.

    Args:
        sched_type: ``'cosine_anneal'`` (``CosineAnnealingLR`` with
            ``T_max=epochs``), ``'cosine_anneal_wr'``
            (``CosineAnnealingWarmupRestarts`` with ``max_lr=lr``), or
            ``'none'`` / ``None`` for no scheduler.  Any other value prints a
            warning and also yields no scheduler.
        model: module whose parameters are optimized.
        lr, momentum, weight_decay: SGD hyper-parameters.
        epochs: ``T_max`` of ``CosineAnnealingLR``.
        min_lr: lower learning-rate bound for either scheduler.
        first_cycle_steps, cycle_mult, warmup_steps, gamma: forwarded to
            ``CosineAnnealingWarmupRestarts``.

    Returns:
        ``(optimizer, scheduler)`` where ``scheduler`` may be ``None``.
    """
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

    if sched_type == 'cosine_anneal':
        annealer = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=min_lr)
        return optimizer, annealer

    if sched_type == 'cosine_anneal_wr':
        restarting = CosineAnnealingWarmupRestarts(
            optimizer,
            first_cycle_steps=first_cycle_steps,
            cycle_mult=cycle_mult,
            max_lr=lr,
            min_lr=min_lr,
            warmup_steps=warmup_steps,
            gamma=gamma,
        )
        return optimizer, restarting

    if sched_type == 'none' or sched_type is None:
        print("No learning rate scheduler will be used.")
    else:
        print(f"Warning: Unsupported scheduler type: '{sched_type}'. No scheduler will be used.")
    return optimizer, None

def train_epoch_random_paths(
    model: nn.Module,
    train_queue: DataLoader,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    all_paths: tuple,
    uniform_path_weights: torch.Tensor,
    device: torch.device
) -> Tuple[float, float]:
    """
    Train ``model`` for one epoch, sampling a new path for every mini-batch.

    For each batch a path is drawn with ``sample_uniform`` from ``all_paths``
    using ``uniform_path_weights`` as logits, installed via
    ``model.set_path`` and followed by one optimizer step.

    Returns:
        ``(average_loss, accuracy)``: the mean of the per-batch losses and
        the top-1 accuracy in percent.  ``(0.0, 0.0)`` when ``train_queue``
        is empty.

    Raises:
        TypeError: if ``model`` has no ``set_path`` method.
    """
    model.train()
    if len(train_queue) == 0:
        print("Warning: train_queue is empty. Skipping training for this epoch.")
        return 0.0, 0.0

    # Put the sampling logits on the training device once, up front.
    if uniform_path_weights.device != device:
        uniform_path_weights = uniform_path_weights.to(device)

    running_loss = 0.0
    num_correct = 0
    num_seen = 0
    for batch_inputs, batch_targets in train_queue:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        _, sampled_path = sample_uniform(uniform_path_weights, all_paths)

        if not hasattr(model, 'set_path'):
            raise TypeError("Model does not have a 'set_path' method. Ensure it's the ResNet20 class.")
        model.set_path(sampled_path)

        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        predicted = logits.max(1)[1]
        num_seen += batch_targets.size(0)
        num_correct += predicted.eq(batch_targets).sum().item()

    num_batches = len(train_queue)
    avg_loss = running_loss / num_batches if num_batches > 0 else 0.0
    accuracy = (100. * num_correct / num_seen) if num_seen > 0 else 0.0
    print(f'Training Epoch: Avg Loss: {avg_loss:.4f}, Accuracy: {num_correct}/{num_seen} ({accuracy:.2f}%)')
    return avg_loss, accuracy


# Extractor Functions
Output, feature-map, and gradient extractors

In [3]:
def output_extractor(model: nn.Module, val_inputs: torch.Tensor, device: torch.device) -> torch.Tensor:
    """
    Run ``model`` on ``val_inputs`` without gradient tracking and return the result on the CPU.

    The model is switched to training mode before the forward pass.
    """
    model.train()
    batch = val_inputs.to(device)
    with torch.no_grad():
        logits = model(batch)
    return logits.detach().cpu()

def feature_maps_extractor(model: nn.Module, val_inputs: torch.Tensor, device: torch.device) -> List[torch.Tensor]:
    """
    Return ``model.feature_extractor(val_inputs)`` computed without gradient tracking.

    The model is switched to training mode first and the inputs are moved to
    ``device``; the returned tensors are left wherever ``feature_extractor``
    produced them.
    """
    model.train()
    batch = val_inputs.to(device)
    with torch.no_grad():
        return model.feature_extractor(batch)

def gradients_extractor(model: nn.Module, val_inputs: torch.Tensor, val_targets: torch.Tensor, device: torch.device) -> Dict[str, Optional[torch.Tensor]]:
    """
    Compute the cross-entropy gradients of every named parameter on one batch.

    The model is put in training mode and its gradients are cleared before a
    single forward/backward pass.

    Returns:
        Mapping from parameter name to a detached CPU copy of its gradient,
        or ``None`` for parameters that received no gradient.
    """
    model.train()
    model.zero_grad()

    # Work on private copies so the caller's tensors stay untouched.
    batch = val_inputs.to(device).detach().clone()
    labels = val_targets.to(device).detach().clone()

    logits = model(batch)
    loss_fn = nn.CrossEntropyLoss().to(device)
    loss_fn(logits, labels).backward()

    return {
        name: (param.grad.clone().detach().cpu() if param.grad is not None else None)
        for name, param in model.named_parameters()
    }

Extraction

In [4]:
def validate_model_all_paths_on_batch(
    model: nn.Module,
    all_paths: tuple,
    val_inputs: torch.Tensor,
    val_targets: torch.Tensor,
    criterion: nn.Module,
    device: torch.device
) -> Tuple[Dict[str, torch.Tensor], Dict[str, List[torch.Tensor]], Dict[str, Dict[str, Optional[torch.Tensor]]]]:
    """
    Evaluate every path in ``all_paths`` on a single validation mini-batch.

    The model is deliberately kept in training mode so that batch-norm layers
    run in training mode.  For each path the model output is collected with
    ``output_extractor`` and stored under the key ``str(path)``.

    Feature-map and gradient collection are currently switched off: the second
    and third returned dictionaries are always empty.  To enable them, store
    the results of ``feature_maps_extractor(model, val_inputs, device)`` and
    ``gradients_extractor(model, val_inputs, val_targets, device)`` per path
    inside the loop.  ``criterion`` is accepted for interface compatibility
    but is not used.

    Returns:
        ``(outputs_by_path, feature_maps_by_path, gradients_by_path)``.

    Raises:
        TypeError: if ``all_paths`` is non-empty and ``model`` has no
            ``set_path`` method.
    """
    model.train()
    outputs_by_path: Dict[str, torch.Tensor] = {}
    feature_maps_by_path: Dict[str, List[torch.Tensor]] = {}
    gradients_by_path: Dict[str, Dict[str, Optional[torch.Tensor]]] = {}

    val_inputs = val_inputs.to(device)
    val_targets = val_targets.to(device)

    print(f"Starting validation for {len(all_paths)} paths on a single mini-batch...")
    if not all_paths:
        print("Warning: No paths provided for validation.")
        return outputs_by_path, feature_maps_by_path, gradients_by_path

    if not hasattr(model, 'set_path'):
        raise TypeError("Model does not have a 'set_path' method. Ensure it's the ResNet20 class.")

    for path_tuple in all_paths:
        # Clear any gradients before switching to the next path.
        model.zero_grad()
        model.set_path(path_tuple)

        with torch.no_grad():
            path_outputs = output_extractor(model, val_inputs, device)
        outputs_by_path[str(path_tuple)] = path_outputs

    print(f"Finished validation for all {len(all_paths)} paths.")
    return outputs_by_path, feature_maps_by_path, gradients_by_path

# Load saved file
At a given epoch for a given run

In [5]:
# Number of the run to analyse; it selects the test_<test_number> folder that
# the checkpoint and result paths are built from.
test_number = 3001
# Epoch whose "_begin" checkpoint is loaded.
analysis_epoch = 1

Load and save outputs, etc.

In [6]:
import gc
import os
# Epochs to process.  A longer alternative list is kept here for convenience:
# analysis_epochs = [1, 6, 11, 16, 26, 46, 71, 101, 126, 151, 156, 171, 176, 196]
analysis_epochs = [1]

In [9]:
# Process every requested epoch in turn with load_save_for_epoch.
for analysis_epoch in analysis_epochs:
  print("epoch:", analysis_epoch)
  print('------------------------')
  gc.collect()  # run a garbage-collection pass before handling the next epoch
  load_save_for_epoch(test_number, analysis_epoch)

epoch: 1
------------------------
Successfully loaded checkpoint from /content/drive/MyDrive/paper4/pooling/test_3001/model_checkpoint_epoch_1_begin.pt
All 36 paths created.
Starting validation for 36 paths on a single mini-batch...
Finished validation for all 36 paths.
Outputs saved to /content/drive/MyDrive/paper4/pooling/test_3001/outputs/outputs_epoch_1.pt


In [8]:
def load_save_for_epoch(test_number, analysis_epoch):
    """Validate one checkpoint on the saved mini-batch and save the outputs.

    Loads ``model_checkpoint_epoch_{analysis_epoch}_begin.pt`` from the
    ``test_{test_number}`` folder, restores its ``model_state_dict`` into a
    newly created model, runs ``validate_model_all_paths_on_batch`` over every
    path returned by ``create_search_space(10, 3)`` using the batch stored in
    ``extracted_val_batch.pt``, and writes the per-path outputs to
    ``outputs/outputs_epoch_{analysis_epoch}.pt``.

    The ``outputs``, ``feature_maps`` and ``gradients`` sub-folders are all
    created, but only the outputs are currently written; saving feature maps
    and gradients is left commented out at the end of the function.

    Args:
        test_number: Run identifier; selects the ``test_{test_number}`` folder.
        analysis_epoch: Epoch whose "begin" checkpoint is analysed.

    Returns:
        None. Results are written to disk.
    """
    # NOTE: an earlier version started with `del checkpoint, model, ...` inside
    # a bare try/except. Those names are locals of this function, so the `del`
    # always raised UnboundLocalError and freed nothing; locals from a previous
    # call are already released when that call returns.
    gc.collect()

    base_save_dir = f"/content/drive/MyDrive/paper4/pooling/test_{test_number}/"
    checkpoint_path = os.path.join(
        base_save_dir, f"model_checkpoint_epoch_{analysis_epoch}_begin.pt"
    )

    # Load data.
    # map_location="cpu" keeps the checkpoint tensors off the GPU (they would
    # otherwise be restored onto the device they were saved from, next to the
    # model) and lets the notebook run on a CPU-only runtime.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    print(f"Successfully loaded checkpoint from {checkpoint_path}")

    channels = string_to_list("16,16,16,16,32,32,32,64,64,64", 10)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = create_single_model(10, channels, num_classes=10).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # The weights now live in `model`; drop the checkpoint before validating.
    del checkpoint
    gc.collect()

    # Validate on the saved validation mini-batch.
    val_data = os.path.join(base_save_dir, "extracted_val_batch.pt")
    saved_batch_data = torch.load(val_data, map_location=device)
    inputs = saved_batch_data['inputs']
    targets = saved_batch_data['targets']
    all_paths, _num_paths = create_search_space(10, 3)
    criterion = nn.CrossEntropyLoss().to(device)

    output_vectors_all_paths, feature_maps_all_paths, gradients_all_paths = (
        validate_model_all_paths_on_batch(
            model, all_paths, inputs, targets, criterion, device
        )
    )

    # Save.
    os.makedirs(base_save_dir, exist_ok=True)
    outputs_dir = os.path.join(base_save_dir, "outputs")
    feature_maps_dir = os.path.join(base_save_dir, "feature_maps")
    gradients_dir = os.path.join(base_save_dir, "gradients")
    for sub_dir in (outputs_dir, feature_maps_dir, gradients_dir):
        os.makedirs(sub_dir, exist_ok=True)

    output_save_path = os.path.join(outputs_dir, f"outputs_epoch_{analysis_epoch}.pt")
    torch.save(output_vectors_all_paths, output_save_path)
    print(f"Outputs saved to {output_save_path}")

    # Saving feature maps is currently disabled:
    # feature_maps_save_path = os.path.join(feature_maps_dir, f"feature_maps_epoch_{analysis_epoch}.pt")
    # torch.save(feature_maps_all_paths, feature_maps_save_path)
    # print(f"Feature maps saved to {feature_maps_save_path}")

    # Saving gradients is currently disabled:
    # gradients_save_path = os.path.join(gradients_dir, f"gradients_epoch_{analysis_epoch}.pt")
    # torch.save(gradients_all_paths, gradients_save_path)
    # print(f"Gradients saved to {gradients_save_path}")



Load model

In [7]:
import gc
# prompt: torch.load from the  /content/drive/MyDrive/paper4/pooling/test_"test_number"/model_checkpoint_epoch_"analysis_epoch"_begin.pt

# Drop objects left over from a previous run of this notebook so their memory
# can be reclaimed. Each name is removed on its own: a single
# `del a, b, c` statement stops at the first undefined name and would leave
# the remaining ones alive.
for _stale_name in (
    "checkpoint",
    "model",
    "output_vectors_all_paths",
    "feature_maps_all_paths",
    "gradients_all_paths",
):
    globals().pop(_stale_name, None)

gc.collect()
checkpoint_path = f"/content/drive/MyDrive/paper4/pooling/test_{test_number}/model_checkpoint_epoch_{analysis_epoch}_begin.pt"

# Load the model checkpoint on the CPU: without map_location the tensors are
# restored onto the device they were saved from, which duplicates the weights
# on the GPU and fails outright on a CPU-only runtime.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
print(f"Successfully loaded checkpoint from {checkpoint_path}")

channels = string_to_list("16,16,16,16,32,32,32,64,64,64", 10)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build a fresh model on the target device and copy the saved weights into it.
model = create_single_model(10, channels, num_classes=10).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

Successfully loaded checkpoint from /content/drive/MyDrive/paper4/pooling/test_3001/model_checkpoint_epoch_1_begin.pt


<All keys matched successfully>

Validate and save

In [8]:
import gc

gc.collect()
# Get the validation mini-batch saved for the run being analysed. The path is
# derived from test_number (it used to be hard-coded to test_3000) so that it
# matches the checkpoint folder used by the rest of the notebook.
val_batch_path = f"/content/drive/MyDrive/paper4/pooling/test_{test_number}/extracted_val_batch.pt"
saved_batch_data = torch.load(val_batch_path, map_location=device)
inputs = saved_batch_data['inputs']
targets = saved_batch_data['targets']
all_paths, num_paths = create_search_space(10, 3)
criterion = nn.CrossEntropyLoss().to(device)

# Run every path on the single mini-batch and keep the three result
# dictionaries as notebook globals for the saving and inspection cells.
output_vectors_all_paths, feature_maps_all_paths, gradients_all_paths = validate_model_all_paths_on_batch(model, all_paths, inputs, targets, criterion, device)



All 36 paths created.
Starting validation for 36 paths on a single mini-batch...


OutOfMemoryError: CUDA out of memory. Tried to allocate 250.00 MiB. GPU 0 has a total capacity of 22.16 GiB of which 179.38 MiB is free. Process 233044 has 21.98 GiB memory in use. Of the allocated memory 21.18 GiB is allocated by PyTorch, and 578.80 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [109]:
# prompt: save outputs, feature maps and gradietns in separate folders in the directory   /content/drive/MyDrive/paper4/pooling/test_"test_number"/

import os

# Root folder of the current run; created first so the cell also works on a
# fresh Drive.
base_save_dir = f"/content/drive/MyDrive/paper4/pooling/test_{test_number}/"
os.makedirs(base_save_dir, exist_ok=True)

# One sub-folder per kind of artefact.
outputs_dir, feature_maps_dir, gradients_dir = (
    os.path.join(base_save_dir, sub_name)
    for sub_name in ("outputs", "feature_maps", "gradients")
)
for artefact_dir in (outputs_dir, feature_maps_dir, gradients_dir):
    os.makedirs(artefact_dir, exist_ok=True)

# Per-path output vectors.
output_save_path = os.path.join(outputs_dir, f"outputs_epoch_{analysis_epoch}.pt")
torch.save(output_vectors_all_paths, output_save_path)
print(f"Outputs saved to {output_save_path}")

# Per-path feature maps.
feature_maps_save_path = os.path.join(feature_maps_dir, f"feature_maps_epoch_{analysis_epoch}.pt")
torch.save(feature_maps_all_paths, feature_maps_save_path)
print(f"Feature maps saved to {feature_maps_save_path}")

# Saving gradients is currently disabled:
# gradients_save_path = os.path.join(gradients_dir, f"gradients_epoch_{analysis_epoch}.pt")
# torch.save(gradients_all_paths, gradients_save_path)
# print(f"Gradients saved to {gradients_save_path}")


Feature maps saved to /content/drive/MyDrive/paper4/pooling/test_3002/feature_maps/feature_maps_epoch_196.pt
Gradients saved to /content/drive/MyDrive/paper4/pooling/test_3002/gradients/gradients_epoch_196.pt


# ♒ Extract Val Batch

In [22]:
# Recover the validation indices stored in the checkpoint. Fail with an
# explicit message when the entry is missing instead of the opaque
# "cannot unpack non-iterable NoneType" raised by the tuple unpacking.
data_split_indices = checkpoint.get('data_split_indices')
if data_split_indices is None:
    raise KeyError("checkpoint has no 'data_split_indices' entry; cannot recover the validation split")
_, val_indices = data_split_indices

# The extracted batch is written into the folder of the current run.
output_batch_filename = "extracted_val_batch.pt"
output_batch_filepath = os.path.join(f"/content/drive/MyDrive/paper4/pooling/test_{test_number}", output_batch_filename)

def _data_transforms_cifar10_validation_only() -> transforms.Compose:
    """Return the evaluation preprocessing: tensor conversion, then normalization.

    No augmentation is applied; the pipeline only converts to a tensor and
    normalizes each channel with the fixed mean/std constants below.
    """
    channel_mean = [0.49139968, 0.48215827, 0.44653124]
    channel_std = [0.24703233, 0.24348505, 0.26158768]
    pipeline_steps = [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(pipeline_steps)

# The validation indices point into the CIFAR-10 *training* split, so that
# split is loaded with the evaluation-only transform.
val_transform = _data_transforms_cifar10_validation_only()
full_original_train_dataset = torchvision.datasets.CIFAR10(
    root='./',
    train=True,
    download=True,
    transform=val_transform,
)

# Restrict sampling to the validation indices.
val_sampler = SubsetRandomSampler(val_indices)
validation_batch_loader = DataLoader(
    full_original_train_dataset,
    batch_size=2000,
    sampler=val_sampler,
    num_workers=4,
    pin_memory=False,
)

# Only the first batch produced by the loader is kept.
inputs, targets = next(iter(validation_batch_loader))

# Persist the batch so later analyses can reuse exactly the same samples.
torch.save(dict(inputs=inputs, targets=targets), output_batch_filepath)
print(f"Saved validation batch: {output_batch_filepath}")

Saved validation batch: /content/drive/MyDrive/paper4/pooling/test_3002/extracted_val_batch.pt


Inspection

In [None]:
# prompt: inspect output_vectors_all_paths, feature_maps_all_paths, gradients_all_paths

# Summarise the per-path output dictionary: container type, size, and a peek
# at the entry stored under the first key.
print("\n--- Inspecting output_vectors_all_paths ---")
print(f"Type: {type(output_vectors_all_paths)}")
print(f"Number of paths in dictionary: {len(output_vectors_all_paths)}")
if output_vectors_all_paths:
    # First key in insertion order.
    first_path_key = next(iter(output_vectors_all_paths))
    print(f"Example key (path): {first_path_key}")
    first_output_tensor = output_vectors_all_paths[first_path_key]
    print(f"Value type (for path '{first_path_key}'): {type(first_output_tensor)}")
    # Tensor details are only printed when the stored value really is a tensor.
    if isinstance(first_output_tensor, torch.Tensor):
        print(f"Tensor shape: {first_output_tensor.shape}")
        print(f"Tensor device: {first_output_tensor.device}")
        print(f"Sample tensor content (first 5 values): {first_output_tensor.flatten()[:5].tolist()}")


# Summarise the per-path feature-map dictionary: container type, size, and
# the first and last feature map stored under the first key.
print("\n--- Inspecting feature_maps_all_paths ---")
print(f"Type: {type(feature_maps_all_paths)}")
print(f"Number of paths in dictionary: {len(feature_maps_all_paths)}")
if feature_maps_all_paths:
    # First key in insertion order.
    first_path_key = next(iter(feature_maps_all_paths))
    print(f"Example key (path): {first_path_key}")
    first_feature_list = feature_maps_all_paths[first_path_key]
    print(f"Value type (for path '{first_path_key}'): {type(first_feature_list)}")
    if isinstance(first_feature_list, list):
        print(f"Number of feature maps in the list: {len(first_feature_list)}")
        if first_feature_list:
            # Earliest recorded feature map.
            first_feature_tensor = first_feature_list[0]
            print(f"Type of first feature map: {type(first_feature_tensor)}")
            if isinstance(first_feature_tensor, torch.Tensor):
                print(f"Shape of first feature map: {first_feature_tensor.shape}")
                print(f"Device of first feature map: {first_feature_tensor.device}")
                print(f"Sample tensor content (first 5 values) of first feature map: {first_feature_tensor.flatten()[:5].tolist()}")
            # Final entry of the list (the adaptive average-pool output
            # according to the original note).
            last_feature_tensor = first_feature_list[-1]
            print(f"Type of last feature map: {type(last_feature_tensor)}")
            if isinstance(last_feature_tensor, torch.Tensor):
                print(f"Shape of last feature map: {last_feature_tensor.shape}")
                print(f"Device of last feature map: {last_feature_tensor.device}")
                print(f"Sample tensor content (first 5 values) of last feature map: {last_feature_tensor.flatten()[:5].tolist()}")


# Summarise the per-path gradient dictionary: container type, size, and the
# first few parameter gradients stored under the first key.
print("\n--- Inspecting gradients_all_paths ---")
print(f"Type: {type(gradients_all_paths)}")
print(f"Number of paths in dictionary: {len(gradients_all_paths)}")
if gradients_all_paths:
    # First key in insertion order.
    first_path_key = next(iter(gradients_all_paths))
    print(f"Example key (path): {first_path_key}")
    first_gradients_dict = gradients_all_paths[first_path_key]
    print(f"Value type (for path '{first_path_key}'): {type(first_gradients_dict)}")
    if isinstance(first_gradients_dict, dict):
        print(f"Number of gradient tensors/None in the dictionary: {len(first_gradients_dict)}")
        print("Sample gradient entries:")
        # Only the first five entries are shown to keep the output short.
        for grad_name, grad_tensor in list(first_gradients_dict.items())[:5]:
            print(f"  Parameter name: '{grad_name}'")
            print(f"    Gradient type: {type(grad_tensor)}")
            if isinstance(grad_tensor, torch.Tensor):
                print(f"    Gradient shape: {grad_tensor.shape}")
                # Expected to be cpu according to the original note.
                print(f"    Gradient device: {grad_tensor.device}")
                # Tensor contents are not printed because they might be large.
            elif grad_tensor is None:
                print("    Gradient is None (parameter not used in this path)")




--- Inspecting output_vectors_all_paths ---
Type: <class 'dict'>
Number of paths in dictionary: 36
Example key (path): (0, 1, 1, 0, 0, 0, 0, 0, 0, 0)
Value type (for path '(0, 1, 1, 0, 0, 0, 0, 0, 0, 0)'): <class 'builtin_function_or_method'>

--- Inspecting feature_maps_all_paths ---
Type: <class 'dict'>
Number of paths in dictionary: 36
Example key (path): (0, 1, 1, 0, 0, 0, 0, 0, 0, 0)
Value type (for path '(0, 1, 1, 0, 0, 0, 0, 0, 0, 0)'): <class 'list'>
Number of feature maps in the list: 11
Type of first feature map: <class 'torch.Tensor'>
Shape of first feature map: torch.Size([512, 16, 32, 32])
Device of first feature map: cuda:0
Sample tensor content (first 5 values) of first feature map: [0.5782656073570251, 0.17400991916656494, 0.0, 0.0, 0.0]
Type of last feature map: <class 'torch.Tensor'>
Shape of last feature map: torch.Size([512, 64, 1, 1])
Device of last feature map: cuda:0
Sample tensor content (first 5 values) of last feature map: [0.8793745040893555, 1.1690838336944