# Network Dissection Lite - Google Colab

This notebook runs NetDissect-Lite (Python 3.11 + RTX 4090 compatible version) to analyze interpretability of deep visual representations.

**Note**: Make sure to enable GPU in Runtime > Change runtime type > GPU

**Workflow**:
1. Clone original repo from GitHub
2. Manually copy-paste your upgraded files to replace the original ones
3. Download dataset and run analysis


## Step 1: Install Dependencies


In [1]:
# Install required packages
%pip install torch torchvision imageio scikit-image scipy numpy pandas -q

# Verify GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")


PyTorch version: 2.9.0+cu126
CUDA available: True
CUDA device: Tesla T4
CUDA version: 12.6


## Step 2: Clone Original Repository from GitHub


In [2]:
import os

# Clone the original NetDissect-Lite repository
WORK_DIR = "/content/NetDissect-Lite"
if os.path.exists(WORK_DIR):
    !rm -rf {WORK_DIR}

print("Cloning NetDissect-Lite from GitHub...")
!git clone https://github.com/CSAILVision/NetDissect-Lite.git {WORK_DIR}

os.chdir(WORK_DIR)
print(f"\nWorking directory: {os.getcwd()}")
print("Repository cloned successfully!")


Cloning NetDissect-Lite from GitHub...
Cloning into '/content/NetDissect-Lite'...
remote: Enumerating objects: 147, done.[K
remote: Counting objects: 100% (62/62), done.[K
remote: Compressing objects: 100% (13/13), done.[K
remote: Total 147 (delta 50), reused 49 (delta 49), pack-reused 85 (from 1)[K
Receiving objects: 100% (147/147), 62.56 KiB | 4.17 MiB/s, done.
Resolving deltas: 100% (80/80), done.

Working directory: /content/NetDissect-Lite
Repository cloned successfully!


## Step 3: Replace Files with Upgraded Versions

Copy and paste your upgraded files below. Each cell will replace the corresponding file in the cloned repository.


## Step 4: Download Broden Dataset


In [3]:
import os
import zipfile

# Download broden1_224 dataset (~1GB)
dataset_dir = "dataset/broden1_224"
dataset_zip = "dataset/broden1_224.zip"
dataset_url = "http://netdissect.csail.mit.edu/data/broden1_224.zip"

if not os.path.exists(os.path.join(dataset_dir, "index.csv")):
    print("Downloading broden1_224 dataset (this may take a while ~1GB)...")
    os.makedirs("dataset", exist_ok=True)

    # Download using wget
    !wget --progress=bar {dataset_url} -O {dataset_zip}

    print("Extracting dataset...")
    with zipfile.ZipFile(dataset_zip, 'r') as zip_ref:
        zip_ref.extractall("dataset")

    # Clean up zip file
    os.remove(dataset_zip)
    print("Dataset downloaded and extracted successfully!")
else:
    print("Dataset already exists.")


Downloading broden1_224 dataset (this may take a while ~1GB)...
--2026-01-29 18:30:58--  http://netdissect.csail.mit.edu/data/broden1_224.zip
Resolving netdissect.csail.mit.edu (netdissect.csail.mit.edu)... 128.52.131.63
Connecting to netdissect.csail.mit.edu (netdissect.csail.mit.edu)|128.52.131.63|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 937582103 (894M) [application/zip]
Saving to: ‚Äòdataset/broden1_224.zip‚Äô


2026-01-29 18:31:16 (50.9 MB/s) - ‚Äòdataset/broden1_224.zip‚Äô saved [937582103/937582103]

Extracting dataset...
Dataset downloaded and extracted successfully!


## Step 6: Run NetDissect


In [5]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [6]:
import settings
import os
import torch

# ============================================
# CHECKPOINT CONFIGURATION
# ============================================
epoch = 1  # Change this to the epoch you want to analyze

checkpoint_path = (
    f"/content/drive/MyDrive/semantic_mortality_checkpoints/"
    f"checkpoint_epoch_{epoch}.pth"
)

print("=" * 60)
print("CHECKPOINT VERIFICATION")
print("=" * 60)

# Verify checkpoint exists
if os.path.exists(checkpoint_path):
    print(f"‚úÖ Checkpoint found: {checkpoint_path}")
    file_size = os.path.getsize(checkpoint_path) / (1024 * 1024)  # MB
    print(f"   File size: {file_size:.2f} MB")
    
    # Try to load and inspect checkpoint structure
    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        print(f"\nüì¶ Checkpoint Structure:")
        print(f"   Type: {type(checkpoint).__name__}")
        
        if isinstance(checkpoint, dict):
            print(f"   Keys: {list(checkpoint.keys())[:10]}...")  # Show first 10 keys
            
            # Check for common checkpoint formats
            if 'state_dict' in checkpoint:
                print(f"   ‚úÖ Found 'state_dict' key")
                state_dict_keys = list(checkpoint['state_dict'].keys())
                print(f"   State dict has {len(state_dict_keys)} parameters")
                print(f"   First 5 keys: {state_dict_keys[:5]}")
                
                # Check for DataParallel prefix
                has_module_prefix = any(k.startswith('module.') for k in state_dict_keys)
                if has_module_prefix:
                    print(f"   ‚ö†Ô∏è  Detected 'module.' prefix (DataParallel format)")
                    settings.MODEL_PARALLEL = True
                else:
                    settings.MODEL_PARALLEL = False
                    
            elif 'model' in checkpoint:
                print(f"   ‚úÖ Found 'model' key")
                settings.MODEL_PARALLEL = False
            else:
                # Assume checkpoint is state_dict itself
                print(f"   ‚ÑπÔ∏è  Using checkpoint directly as state_dict")
                # Check for module prefix
                has_module_prefix = any(k.startswith('module.') for k in checkpoint.keys())
                settings.MODEL_PARALLEL = has_module_prefix
                if has_module_prefix:
                    print(f"   ‚ö†Ô∏è  Detected 'module.' prefix (DataParallel format)")
            
            # Check for epoch info
            if 'epoch' in checkpoint:
                print(f"   üìÖ Epoch in checkpoint: {checkpoint['epoch']}")
            if 'loss' in checkpoint:
                print(f"   üìâ Loss in checkpoint: {checkpoint['loss']}")
                
        else:
            print(f"   ‚ÑπÔ∏è  Checkpoint appears to be a model object")
            settings.MODEL_PARALLEL = False
            
        print(f"\n‚úÖ Checkpoint structure verified!")
        
    except Exception as e:
        print(f"\n‚ùå ERROR loading checkpoint: {e}")
        print("   Please check the checkpoint file format.")
        raise
        
else:
    print(f"‚ùå ERROR: Checkpoint not found at: {checkpoint_path}")
    print("\nPlease verify:")
    print("  1. Google Drive is mounted correctly")
    print("  2. The checkpoint file exists at the specified path")
    print("  3. The epoch number is correct")
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

# Set MODEL_FILE
settings.MODEL_FILE = checkpoint_path

# Explicitly enable GPU
settings.GPU = True
print(f"\n‚öôÔ∏è  Configuration:")
print(f"   MODEL_FILE: {settings.MODEL_FILE}")
print(f"   GPU: {settings.GPU}")
print(f"   MODEL_PARALLEL: {settings.MODEL_PARALLEL}")
print(f"   MODEL: {settings.MODEL}")
print(f"   DATASET: {settings.DATASET}")

print("\n" + "=" * 60)
print("‚úÖ Ready to run NetDissect analysis!")
print("=" * 60)


In [None]:
# Run the main script

print("Starting NetDissect analysis...")

!python main.py


Starting NetDissect analysis...
  for c, n in [re.match('^([^(]*)\(([^)]*)\)$', f).groups()
  elif re.match('^\d+$', val):
  elif re.match('^\d+\.\d*$', val):
  if re.match('^\d+$', v):
  return re.sub('[\.-/#?*!\s]+', '-', blob).strip('-')
  return re.sub('[-/#?*!\s]+', '-', blob).strip('-')

DEBUG: primary_categories_per_index() result:
  Shape: (1198,)
  Unique values: [0 1 2 3 4]
  Value counts: [ 11 549 110 468  60]
  First 20 values: [4 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1]
  All values same? False
DEBUG: labelcat after onehot:
  Shape: (1198, 5)
  Unique rows: 5
  First 5 rows:
[[0. 0. 0. 0. 1.]
 [1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]]
  All rows identical? False
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100% 44.7M/44.7M [00:00<00:00, 177MB/s]
file missing, loading from scratch
Extracting features batch 1/495
Extracting features batch 2/495
Extracting featu

## Step 8: Download Results


In [None]:
import os

for root, dirs, files in os.walk("result"):
    print(root)
    for f in files:
        print("  ", f)


result
result/pytorch_resnet18_imagenet_test
result/pytorch_resnet18_imagenet_test/image
result/pytorch_resnet18_imagenet
   tally.csv
   quantile.npy
   feature_size.npy
result/pytorch_resnet18_imagenet/image
result/pytorch_resnet18_imagenet/html
   layer4.html
result/pytorch_resnet18_imagenet/html/image
   layer4-0171.jpg
   layer4-0372.jpg
   layer4-0085.jpg
   layer4-0217.jpg
   layer4-0415.jpg
   layer4-0256.jpg
   layer4-0045.jpg
   layer4-0189.jpg
   layer4-0216.jpg
   layer4-0482.jpg
   layer4-0399.jpg
   layer4-0492.jpg
   layer4-0312.jpg
   layer4-0169.jpg
   layer4-0212.jpg
   layer4-0472.jpg
   layer4-0468.jpg
   layer4-0278.jpg
   layer4-0219.jpg
   layer4-0491.jpg
   layer4-0503.jpg
   layer4-0145.jpg
   layer4-0089.jpg
   layer4-0373.jpg
   layer4-0368.jpg
   layer4-0226.jpg
   layer4-0208.jpg
   layer4-0048.jpg
   layer4-0365.jpg
   layer4-0499.jpg
   layer4-0508.jpg
   layer4-0050.jpg
   layer4-0042.jpg
   layer4-0039.jpg
   layer4-0479.jpg
   layer4-0218.jpg
   layer4

In [None]:
import shutil
from google.colab import files
import os

if os.path.exists("result"):
    print("Zipping entire result directory...")
    shutil.make_archive("netdissect_results", "zip", "result")

    print("Downloading zip...")
    files.download("netdissect_results.zip")
else:
    print("No result directory found.")


Zipping entire result directory...
Downloading zip...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
import os
import shutil

# Destination directory (FIXED path)
dst = "/content/drive/MyDrive/nd_semantic_d_results/epoch0"

# Create directory if it doesn't exist
os.makedirs(dst, exist_ok=True)

# Copy results folder
src = "result"   # Network Dissection output folder
shutil.copytree(src, dst, dirs_exist_ok=True)

print(f"Results copied to Google Drive at: {dst}")



Results copied to Google Drive at: /content/drive/MyDrive/nd_semantic_d_results/epoch0


## Notes and Tips

1. **GPU**: Make sure GPU is enabled (Runtime > Change runtime type > GPU)
2. **TEST_MODE**: Starts with TEST_MODE=True for quick testing. Set to False in settings.py for full dataset
3. **Memory**: If you get OOM errors, reduce BATCH_SIZE in settings.py
4. **Time**: Full dataset analysis takes 2-3 hours, test mode takes 20-30 minutes
5. **Results**: HTML report, CSV files, and visualizations are saved in the result folder

### To run full dataset:
```python
# Edit settings.py
TEST_MODE = False  # Change this
```

### Supported Models:
- ResNet18 (default)
- ResNet50
- DenseNet161
- AlexNet

### Supported Datasets:
- ImageNet (default)
- Places365
