# 02. Model Sanity Checks

In [1]:
import torch
import os
import sys
import logging

In [2]:
try:
    project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
    from src.utils import load_config, setup_logging, get_device
    from src.models import get_models # The main function we'll be testing
    from src.data_loader import ColoredMNISTDataset # To understand target shapes
except ImportError as e:
    print(f"ImportError: {e}")
    print("Make sure your notebook is in the 'notebooks' directory of the project,")
    print("and that 'src' is a Python package (contains __init__.py) and accessible.")
    print(f"Current sys.path: {sys.path}")
    print(f"Attempted project_root: {project_root}")
    raise

#### Configuration

In [3]:
setup_logging(log_to_console=True, log_file=None)

CONFIG_FILE_PATH = os.path.join(project_root, "configs", "colored_mnist_default.yaml")

if not os.path.exists(CONFIG_FILE_PATH):
    print(f"ERROR: Configuration file not found at {CONFIG_FILE_PATH}")
    raise FileNotFoundError(f"Config file missing: {CONFIG_FILE_PATH}")

config = load_config(CONFIG_FILE_PATH)

# Use CPU for these sanity checks by default, unless GPU is explicitly desired for testing
# This avoids issues if a GPU isn't available or configured for the notebook environment.
device_name = config['training'].get('device', 'cpu')
if device_name == 'auto': # If auto, default to CPU for notebook stability
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Using device for model checks: {device}")

2025-05-17 19:39:54 - root - INFO - Logging configured.
2025-05-17 19:39:54 - root - INFO - Successfully loaded configuration from: /home/studio-lab-user/learning-not-to-learn/configs/colored_mnist_default.yaml


Using device for model checks: cpu


#### 1. Instantiate Models using get_models

In [4]:
print("\n--- 1. Instantiating Models ---")
try:
    models_dict = get_models(config, device)
    feature_extractor_f = models_dict['feature_extractor_f']
    task_classifier_g = models_dict['task_classifier_g']
    bias_predictor_h = models_dict['bias_predictor_h']
    grl_layer = models_dict['grl'] # Gradient Reversal Layer

    print("Models instantiated successfully:")
    print(f"  Feature Extractor (f): {type(feature_extractor_f).__name__}")
    print(f"  Task Classifier (g): {type(task_classifier_g).__name__}")
    print(f"  Bias Predictor (h): {type(bias_predictor_h).__name__}")
    print(f"  Gradient Reversal Layer (grl): {type(grl_layer).__name__} with lambda={grl_layer.lambda_val}")

except Exception as e:
    print(f"ERROR: Failed to instantiate models using get_models: {e}")
    raise

2025-05-17 19:40:36 - root - INFO - LeNet_F initialized: Input Channels=3, Feature Dim=128
2025-05-17 19:40:36 - root - INFO - MLP_Classifier_G initialized: Input Dim=128, Hidden Dims=[64], Output Dim=10
2025-05-17 19:40:36 - root - INFO - ConvBiasPredictorH initialized: InputFeatureDim=128, NumBiasChannels=3, NumBiasQuantizationBins=8, Output Spatial=(14,14), IntermediateChannels=64, BaseSpatialDimForUpsample=7
2025-05-17 19:40:36 - root - INFO - All models created and moved to device.



--- 1. Instantiating Models ---
Models instantiated successfully:
  Feature Extractor (f): LeNet_F
  Task Classifier (g): MLP_Classifier_G
  Bias Predictor (h): ConvBiasPredictorH
  Gradient Reversal Layer (grl): GradientReversalLayer with lambda=1.0


#### 2. Prepare Dummy Inputs

In [5]:
print("\n--- 2. Preparing Dummy Inputs ---")
# Get parameters from config for shapes
data_cfg = config['data']
model_cfg = config['model']

batch_size = 4 # A small batch for testing
img_channels = data_cfg.get('img_channels', 3)
img_h = data_cfg.get('img_size', 28)
img_w = data_cfg.get('img_size', 28)

feature_dim = model_cfg['feature_extractor_f']['params']['feature_dim']

# Dummy image batch (like from DataLoader)
dummy_images = torch.randn(batch_size, img_channels, img_h, img_w).to(device)
print(f"Dummy images shape: {dummy_images.shape} on device: {dummy_images.device}")

# Dummy feature vector (output of f, input to g and h)
dummy_features = torch.randn(batch_size, feature_dim).to(device)
print(f"Dummy features shape: {dummy_features.shape} on device: {dummy_features.device}")


--- 2. Preparing Dummy Inputs ---
Dummy images shape: torch.Size([4, 3, 28, 28]) on device: cpu
Dummy features shape: torch.Size([4, 128]) on device: cpu


#### 3. Test Forward Pass of Feature Extractor ( $f$ )

In [6]:
print("\n--- 3. Testing Feature Extractor (f) ---")
try:
    feature_extractor_f.eval() # Set to eval mode for sanity check
    with torch.no_grad(): # No need to compute gradients
        output_features_f = feature_extractor_f(dummy_images)
    print(f"Output features shape from f: {output_features_f.shape}")
    assert output_features_f.shape == (batch_size, feature_dim), \
        f"Expected feature_dim {feature_dim}, got {output_features_f.shape[1]}"
    print("Feature Extractor (f) forward pass successful.")
except Exception as e:
    print(f"ERROR during Feature Extractor (f) forward pass: {e}")
    raise


--- 3. Testing Feature Extractor (f) ---
Output features shape from f: torch.Size([4, 128])
Feature Extractor (f) forward pass successful.


#### 4. Test Forward Pass of Task Classifier ( $g$ )

In [7]:
print("\n--- 4. Testing Task Classifier (g) ---")
num_main_classes = data_cfg['num_main_classes']
try:
    task_classifier_g.eval()
    with torch.no_grad():
        # Use features from f's output for a more integrated test, or dummy_features for isolation
        output_task_g = task_classifier_g(output_features_f) 
        # output_task_g = task_classifier_g(dummy_features) # Alternative for isolated test
    print(f"Output logits shape from g: {output_task_g.shape}")
    assert output_task_g.shape == (batch_size, num_main_classes), \
        f"Expected num_main_classes {num_main_classes}, got {output_task_g.shape[1]}"
    print("Task Classifier (g) forward pass successful.")
except Exception as e:
    print(f"ERROR during Task Classifier (g) forward pass: {e}")
    raise


--- 4. Testing Task Classifier (g) ---
Output logits shape from g: torch.Size([4, 10])
Task Classifier (g) forward pass successful.


#### 5. Test Forward Pass of Bias Predictor ( $h$ )

In [8]:
print("\n--- 5. Testing Bias Predictor (h) ---")
# Expected output shape for ConvBiasPredictorH:
# (Batch, NumBiasBins, NumBiasChannels, H_out, W_out)
h_params = model_cfg['bias_predictor_h']['params']
num_bias_quant_bins = h_params['num_bias_quantization_bins']
# num_bias_channels is effectively img_channels for our setup
h_output_h = h_params['output_h']
h_output_w = h_params['output_w']
expected_h_shape = (batch_size, num_bias_quant_bins, img_channels, h_output_h, h_output_w)

try:
    bias_predictor_h.eval()
    with torch.no_grad():
        # Use features from f's output
        output_bias_h = bias_predictor_h(output_features_f)
        # output_bias_h = bias_predictor_h(dummy_features) # Alternative for isolated test
    print(f"Output bias prediction shape from h: {output_bias_h.shape}")
    assert output_bias_h.shape == expected_h_shape, \
        f"Expected shape {expected_h_shape}, got {output_bias_h.shape}"
    print("Bias Predictor (h) forward pass successful.")
except Exception as e:
    print(f"ERROR during Bias Predictor (h) forward pass: {e}")
    raise


--- 5. Testing Bias Predictor (h) ---
Output bias prediction shape from h: torch.Size([4, 8, 3, 14, 14])
Bias Predictor (h) forward pass successful.


#### 6. Test Gradient Reversal Layer (GRL)
GRL's main effect is on the backward pass, but we can check its forward pass (identity) and that it can be applied.

In [9]:
print("\n--- 6. Testing Gradient Reversal Layer (GRL) ---")
try:
    # GRL doesn't have eval/train mode
    # It's a torch.autograd.Function wrapper
    features_through_grl = grl_layer(output_features_f.clone().requires_grad_(True)) # Need requires_grad for backward demo
    print(f"Shape after GRL forward pass: {features_through_grl.shape}")
    assert features_through_grl.shape == output_features_f.shape, "GRL forward pass should be identity."
    print("GRL forward pass behaves as identity (shape check).")

    # Conceptual check of backward pass (requires a subsequent layer and loss)
    if features_through_grl.requires_grad:
        # Simulate a dummy loss and backward pass on GRL's output
        dummy_downstream_loss = features_through_grl.mean() 
        dummy_downstream_loss.backward()
        
        # Check if the original features (before GRL) have gradients
        # The gradient should be reversed and scaled by GRL's lambda
        if output_features_f.grad is not None:
            print(f"Gradient on input to GRL (sample): {output_features_f.grad[0, :5]}")
            print("GRL backward pass conceptually checked (gradient exists).")
        else:
            # This might happen if output_features_f itself didn't require grad or was detached.
            # Let's re-run f with requires_grad on input for a clearer GRL backward demo
            print("Re-running f for clearer GRL backward demo...")
            dummy_images_grad = torch.randn(batch_size, img_channels, img_h, img_w, device=device, requires_grad=True)
            feature_extractor_f.train() # Ensure params have grads
            output_features_f_grad = feature_extractor_f(dummy_images_grad)
            
            # Detach for GRL input to isolate GRL's effect on its input's grad, not f's params
            output_features_f_grad_detached = output_features_f_grad.detach().requires_grad_(True) 
            
            features_through_grl_2 = grl_layer(output_features_f_grad_detached)
            dummy_downstream_loss_2 = features_through_grl_2.sum() # Use sum for non-zero gradients
            dummy_downstream_loss_2.backward()
            
            if output_features_f_grad_detached.grad is not None:
                print(f"Gradient on input to GRL (sample from re-run): {output_features_f_grad_detached.grad[0, :5]}")
                print(f"Note: This gradient should be negative of what it would be without GRL (scaled by lambda={grl_layer.lambda_val}).")
                print("GRL backward pass conceptually checked.")
            else:
                print("WARN: Still no gradient on GRL input after re-run. Check requires_grad flags.")
    else:
        print("WARN: Output of GRL does not require grad. Cannot test backward pass effect directly here.")
        
except Exception as e:
    print(f"ERROR during GRL test: {e}")
    raise


--- 6. Testing Gradient Reversal Layer (GRL) ---
Shape after GRL forward pass: torch.Size([4, 128])
GRL forward pass behaves as identity (shape check).
Re-running f for clearer GRL backward demo...
Gradient on input to GRL (sample from re-run): tensor([-1., -1., -1., -1., -1.])
Note: This gradient should be negative of what it would be without GRL (scaled by lambda=1.0).
GRL backward pass conceptually checked.


#### 7. Conceptual Loss Application (Shape Check)

In [10]:
print("\n--- 7. Conceptual Loss Application (Shape Check) ---")
# Main Task Loss (g)
num_main_classes = data_cfg['num_main_classes']
dummy_main_labels = torch.randint(0, num_main_classes, (batch_size,), device=device, dtype=torch.long)
criterion_main = torch.nn.CrossEntropyLoss()
try:
    loss_g = criterion_main(output_task_g, dummy_main_labels)
    print(f"Conceptual main task loss (g) calculated: {loss_g.item():.4f}")
except Exception as e:
    print(f"ERROR calculating conceptual main task loss: {e}")

# Bias Prediction Loss (h)
# Target shape for bias from data_loader: (B, C, H_out, W_out) with class indices [0, NumBiasBins-1]
# Bias predictor h output shape: (B, NumBiasBins, C, H_out, W_out)
dummy_bias_targets = torch.randint(0, num_bias_quant_bins, 
                                   (batch_size, img_channels, h_output_h, h_output_w), 
                                   device=device, dtype=torch.long)
criterion_bias = torch.nn.CrossEntropyLoss(ignore_index=255) # As used in Trainer
try:
    loss_h = criterion_bias(output_bias_h, dummy_bias_targets)
    print(f"Conceptual bias prediction loss (h) calculated: {loss_h.item():.4f}")
except Exception as e:
    print(f"ERROR calculating conceptual bias prediction loss: {e}")
    print("  Check shapes: output_bias_h (preds) vs dummy_bias_targets (targets)")
    print(f"  output_bias_h shape: {output_bias_h.shape}")
    print(f"  dummy_bias_targets shape: {dummy_bias_targets.shape}")


print("\nModel sanity checks notebook finished.")
print("If all assertions passed and no errors occurred, your model architectures and forward passes are likely correct.")


--- 7. Conceptual Loss Application (Shape Check) ---
Conceptual main task loss (g) calculated: 2.2464
Conceptual bias prediction loss (h) calculated: 2.0811

Model sanity checks notebook finished.
If all assertions passed and no errors occurred, your model architectures and forward passes are likely correct.
