From 91e6e9e494f38b17fe3c4a8711fcfeb43a2f511b Mon Sep 17 00:00:00 2001
From: bingyic <107590227+bingyic@users.noreply.github.com>
Date: Thu, 14 May 2026 13:35:00 -0700
Subject: [PATCH] Revert "Add decoder inference Colab (segmentation, depth,
 normals)"

---
 pytorch/TIPS_decoder_inference.ipynb | 468 ---------------------------
 1 file changed, 468 deletions(-)
 delete mode 100644 pytorch/TIPS_decoder_inference.ipynb

diff --git a/pytorch/TIPS_decoder_inference.ipynb b/pytorch/TIPS_decoder_inference.ipynb
deleted file mode 100644
index b2cab29..0000000
--- a/pytorch/TIPS_decoder_inference.ipynb
+++ /dev/null
@@ -1,468 +0,0 @@
-{
-  "cells": [
-    {
-      "id": "94b870ed",
-      "cell_type": "code",
-      "source": [
-        "# Copyright 2025 Google LLC.\n",
-        "#\n",
-        "# SPDX-License-Identifier: Apache-2.0"
-      ],
-      "metadata": {},
-      "execution_count": null
-    },
-    {
-      "id": "05dc14da",
-      "cell_type": "markdown",
-      "source": [
-        "# TIPSv2: Segmentation, Depth \u0026 Surface Normals Inference\n",
-        "\n",
-        "This notebook demonstrates how to run **dense prediction** inference using\n",
-        "the TIPSv2 DPT decoders:\n",
-        "\n",
-        "1. **Semantic Segmentation** (ADE20K, 150 classes)\n",
-        "2. **Monocular Depth Estimation** (classification-based, 256 bins)\n",
-        "3. **Surface Normal Estimation**\n",
-        "\n",
-        "The DPT decoder heads take intermediate ViT features from the TIPS vision\n",
-        "encoder and produce pixel-level predictions.\n",
-        "\n",
-        "**Requirements:** GPU runtime recommended for faster inference."
-      ],
-      "metadata": {},
-      "execution_count": null
-    },
-    {
-      "id": "3dcd30e6",
-      "cell_type": "markdown",
-      "source": [
-        "## Setup"
-      ],
-      "metadata": {},
-      "execution_count": null
-    },
-    {
-      "id": "16cac140",
-      "cell_type": "code",
-      "source": [
-        "# @title Install dependencies and clone TIPS repo.\n",
-        "import os\n",
-        "import sys\n",
-        "\n",
-        "ROOT_DIR = os.getcwd()\n",
-        "TIPS_DIR = os.path.join(ROOT_DIR, 'tips')\n",
-        "\n",
-        "# Install required packages.\n",
-        "!pip install -q torch torchvision torchaudio\n",
-        "!pip install -q tensorflow_text scikit-learn\n",
-        "\n",
-        "# Clone the TIPS repository.\n",
-        "if not os.path.exists(TIPS_DIR):\n",
-        "  !git clone https://github.com/google-deepmind/tips.git {TIPS_DIR}\n",
-        "\n",
-        "# Add the root directory to PYTHONPATH so that `tips.*` imports work.\n",
-        "if ROOT_DIR not in sys.path:\n",
-        "  sys.path.insert(0, ROOT_DIR)\n",
-        "\n",
-        "print(f'ROOT_DIR: {ROOT_DIR}')\n",
-        "print(f'TIPS_DIR: {TIPS_DIR}')\n",
-        "print('Installation complete!')"
-      ],
-      "metadata": {},
-      "execution_count": null
-    },
-    {
-      "id": "985f995d",
-      "cell_type": "code",
-      "source": [
-        "# @title Download checkpoints and sample images.\n",
-        "import urllib.request\n",
-        "import zipfile\n",
-        "\n",
-        "variant = 'L'  # @param [\"B\", \"L\", \"So\", \"g\"]\n",
-        "\n",
-        "VISION_CKPT_URL = 'https://storage.googleapis.com/tips_data/v2_0/checkpoints/pytorch'\n",
-        "DPT_CKPT_URL = 'https://storage.googleapis.com/tips_data/v2_0/checkpoints/scenic'\n",
-        "NYU_URL = 'http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bedrooms_part6.zip'\n",
-        "ADE20K_URL = 'http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip'\n",
-        "\n",
-        "CKPT_DIR = os.path.join(ROOT_DIR, 'checkpoints')\n",
-        "os.makedirs(CKPT_DIR, exist_ok=True)\n",
-        "\n",
-        "# Checkpoint naming maps.\n",
-        "V2_CKPT_BASENAME_MAP = {\n",
-        "    'B': 'tips_v2_oss_b14', 'L': 'tips_v2_oss_l14',\n",
-        "    'So': 'tips_v2_oss_so14', 'g': 'tips_v2_oss_g14',\n",
-        "}\n",
-        "V2_DPT_BASENAME_MAP = {\n",
-        "    'B': 'tips_v2_b14', 'L': 'tips_v2_l14',\n",
-        "    'So': 'tips_v2_so400m14', 'g': 'tips_v2_g14',\n",
-        "}\n",
-        "ckpt_basename = V2_CKPT_BASENAME_MAP[variant]\n",
-        "dpt_basename = V2_DPT_BASENAME_MAP[variant]\n",
-        "\n",
-        "# Download vision encoder checkpoint.\n",
-        "vision_ckpt_name = f'{ckpt_basename}_vision.npz'\n",
-        "image_encoder_checkpoint = os.path.join(CKPT_DIR, vision_ckpt_name)\n",
-        "if not os.path.exists(image_encoder_checkpoint):\n",
-        "  print(f'Downloading vision encoder...')\n",
-        "  urllib.request.urlretrieve(f'{VISION_CKPT_URL}/{vision_ckpt_name}', image_encoder_checkpoint)\n",
-        "\n",
-        "# Download DPT checkpoints (Segmentation, Depth, Normals).\n",
-        "# These use Scenic-format checkpoints (Flax .npy arrays in a zip).\n",
-        "dpt_tasks = ['segmentation', 'depth', 'normals']\n",
-        "dpt_checkpoint_paths = {}\n",
-        "for task in dpt_tasks:\n",
-        "  dpt_zip_name = f'{dpt_basename}_{task}_dpt.zip'\n",
-        "  dpt_zip_path = os.path.join(CKPT_DIR, dpt_zip_name)\n",
-        "  if not os.path.exists(dpt_zip_path):\n",
-        "    print(f'Downloading DPT {task} checkpoint...')\n",
-        "    try:\n",
-        "      urllib.request.urlretrieve(f'{DPT_CKPT_URL}/{dpt_zip_name}', dpt_zip_path)\n",
-        "    except Exception as e:\n",
-        "      print(f'  Failed: {e}')\n",
-        "  dpt_checkpoint_paths[task] = dpt_zip_path\n",
-        "\n",
-        "# Download NYU depth dataset (for depth \u0026 normals demo).\n",
-        "NYU_IMG_DIR = os.path.join(ROOT_DIR, 'nyu_images')\n",
-        "if not os.path.isdir(NYU_IMG_DIR):\n",
-        "  print('Downloading NYU dataset...')\n",
-        "  nyu_tmp = os.path.join(ROOT_DIR, 'bedrooms_part6.zip')\n",
-        "  urllib.request.urlretrieve(NYU_URL, nyu_tmp)\n",
-        "  os.makedirs(NYU_IMG_DIR, exist_ok=True)\n",
-        "  with zipfile.ZipFile(nyu_tmp, 'r') as z:\n",
-        "    z.extractall(NYU_IMG_DIR)\n",
-        "  os.remove(nyu_tmp)\n",
-        "\n",
-        "# Download ADE20K dataset (for segmentation demo).\n",
-        "ADE20K_DIR = os.path.join(ROOT_DIR, 'ADEChallengeData2016')\n",
-        "if not os.path.isdir(ADE20K_DIR):\n",
-        "  print('Downloading ADE20K dataset...')\n",
-        "  ade_tmp = os.path.join(ROOT_DIR, 'ADEChallengeData2016.zip')\n",
-        "  urllib.request.urlretrieve(ADE20K_URL, ade_tmp)\n",
-        "  with zipfile.ZipFile(ade_tmp, 'r') as z:\n",
-        "    z.extractall(ROOT_DIR)\n",
-        "  os.remove(ade_tmp)\n",
-        "\n",
-        "print('All downloads complete!')"
-      ],
-      "metadata": {},
-      "execution_count": null
-    },
-    {
-      "id": "d1d3260a",
-      "cell_type": "markdown",
-      "source": [
-        "## Load Models"
-      ],
-      "metadata": {},
-      "execution_count": null
-    },
-    {
-      "id": "1462be26",
-      "cell_type": "code",
-      "source": [
-        "# @title Load the TIPS vision encoder and all three DPT decoders.\n",
-        "import numpy as np\n",
-        "import torch\n",
-        "from tips.pytorch import image_encoder\n",
-        "from tips.pytorch.decoders import (\n",
-        "    Decoder, SegmentationDecoder, DepthDecoder, NormalsDecoder,\n",
-        "    load_decoder_weights,\n",
-        ")\n",
-        "\n",
-        "image_size = 448  # @param {type: \"number\"}\n",
-        "PATCH_SIZE = 14\n",
-        "\n",
-        "# Model configs per variant.\n",
-        "MODEL_CONSTRUCTOR_MAP = {\n",
-        "    'B': 'vit_base', 'L': 'vit_large', 'So': 'vit_so400m', 'g': 'vit_giant2',\n",
-        "}\n",
-        "EMBED_DIM_MAP = {'B': 768, 'L': 1024, 'So': 1152, 'g': 1536}\n",
-        "INTERMEDIATE_LAYERS_MAP = {\n",
-        "    'B': [2, 5, 8, 11], 'L': [5, 11, 17, 23],\n",
-        "    'So': [6, 13, 20, 26], 'g': [9, 19, 29, 39],\n",
-        "}\n",
-        "\n",
-        "vit_constructor = getattr(image_encoder, MODEL_CONSTRUCTOR_MAP[variant])\n",
-        "embed_dim = EMBED_DIM_MAP[variant]\n",
-        "intermediate_layers = INTERMEDIATE_LAYERS_MAP[variant]\n",
-        "post_process_channels = (embed_dim // 8, embed_dim // 4, embed_dim // 2, embed_dim)\n",
-        "ffn_layer = 'swiglu' if variant == 'g' else 'mlp'\n",
-        "\n",
-        "# --- Vision Encoder ---\n",
-        "weights_image = dict(np.load(image_encoder_checkpoint, allow_pickle=False))\n",
-        "for key in weights_image:\n",
-        "  weights_image[key] = torch.tensor(weights_image[key])\n",
-        "\n",
-        "with torch.no_grad():\n",
-        "  model_image = vit_constructor(\n",
-        "      img_size=image_size, patch_size=PATCH_SIZE, ffn_layer=ffn_layer,\n",
-        "      block_chunks=0, init_values=1.0,\n",
-        "      interpolate_antialias=True, interpolate_offset=0.0,\n",
-        "  )\n",
-        "  model_image.load_state_dict(weights_image)\n",
-        "  model_image.eval()\n",
-        "print(f'✓ Vision encoder loaded ({variant})')\n",
-        "\n",
-        "# --- Segmentation Decoder ---\n",
-        "with torch.no_grad():\n",
-        "  seg_model = SegmentationDecoder(\n",
-        "      num_classes=150, input_embed_dim=embed_dim,\n",
-        "      post_process_channels=post_process_channels,\n",
-        "  )\n",
-        "  load_decoder_weights(seg_model, dpt_checkpoint_paths['segmentation'])\n",
-        "  seg_model.eval()\n",
-        "\n",
-        "# --- Depth Decoder ---\n",
-        "with torch.no_grad():\n",
-        "  depth_model = DepthDecoder(\n",
-        "      input_embed_dim=embed_dim,\n",
-        "      post_process_channels=post_process_channels,\n",
-        "  )\n",
-        "  load_decoder_weights(depth_model, dpt_checkpoint_paths['depth'])\n",
-        "  depth_model.eval()\n",
-        "\n",
-        "# --- Normals Decoder ---\n",
-        "with torch.no_grad():\n",
-        "  normals_model = NormalsDecoder(\n",
-        "      input_embed_dim=embed_dim,\n",
-        "      post_process_channels=post_process_channels,\n",
-        "  )\n",
-        "  load_decoder_weights(normals_model, dpt_checkpoint_paths['normals'])\n",
-        "  normals_model.eval()\n",
-        "\n",
-        "print('✓ All decoders loaded')"
-      ],
-      "metadata": {},
-      "execution_count": null
-    },
-    {
-      "id": "06b04c61",
-      "cell_type": "code",
-      "source": [
-        "# @title Define helper: extract ViT features.\n",
-        "import torchvision.transforms as TVT\n",
-        "import PIL.Image\n",
-        "\n",
-        "transform = TVT.Compose([TVT.Resize((image_size, image_size)), TVT.ToTensor()])\n",
-        "\n",
-        "def extract_features(img_path):\n",
-        "  \"\"\"Load image and extract intermediate ViT features.\"\"\"\n",
-        "  img = PIL.Image.open(img_path).convert(\"RGB\")\n",
-        "  tensor = transform(img).unsqueeze(0)\n",
-        "  device = next(model_image.parameters()).device\n",
-        "  tensor = tensor.to(device)\n",
-        "  with torch.no_grad():\n",
-        "    features = model_image.get_intermediate_layers(\n",
-        "        tensor, n=intermediate_layers, reshape=True,\n",
-        "        return_class_token=True, norm=True,\n",
-        "    )\n",
-        "  # Reorder: (feat, cls) -\u003e (cls, feat)\n",
-        "  features = [(cls, feat) for feat, cls in features]\n",
-        "  return img, features"
-      ],
-      "metadata": {},
-      "execution_count": null
-    },
-    {
-      "id": "a4f824d8",
-      "cell_type": "code",
-      "source": [
-        "# @title Define ADE20K class names and color palette.\n",
-        "import colorsys\n",
-        "\n",
-        "ADE20K_CLASSES = (\n",
-        "    'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road',\n",
-        "    'bed', 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person',\n",
-        "    'earth', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair',\n",
-        "    'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea',\n",
-        "    'mirror', 'rug', 'field', 'armchair', 'seat', 'fence', 'desk',\n",
-        "    'rock', 'wardrobe', 'lamp', 'bathtub', 'railing', 'cushion',\n",
-        "    'base', 'box', 'column', 'signboard', 'chest of drawers',\n",
-        "    'counter', 'sand', 'sink', 'skyscraper', 'fireplace',\n",
-        "    'refrigerator', 'grandstand', 'path', 'stairs', 'runway',\n",
-        "    'case', 'pool table', 'pillow', 'screen door', 'stairway',\n",
-        "    'river', 'bridge', 'bookcase', 'blind', 'coffee table',\n",
-        "    'toilet', 'flower', 'book', 'hill', 'bench', 'countertop',\n",
-        "    'stove', 'palm', 'kitchen island', 'computer', 'swivel chair',\n",
-        "    'boat', 'bar', 'arcade machine', 'hovel', 'bus',\n",
-        "    'towel', 'light', 'truck', 'tower', 'chandelier', 'awning',\n",
-        "    'streetlight', 'booth', 'television', 'airplane', 'dirt track',\n",
-        "    'apparel', 'pole', 'land', 'bannister', 'escalator', 'ottoman',\n",
-        "    'bottle', 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain',\n",
-        "    'conveyer belt', 'canopy', 'washer', 'plaything', 'swimming pool',\n",
-        "    'stool', 'barrel', 'basket', 'waterfall', 'tent', 'bag',\n",
-        "    'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',\n",
-        "    'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',\n",
-        "    'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',\n",
-        "    'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier',\n",
-        "    'crt screen', 'plate', 'monitor', 'bulletin board', 'shower',\n",
-        "    'radiator', 'glass', 'clock', 'flag',\n",
-        ")\n",
-        "\n",
-        "def _generate_ade20k_palette(n=150):\n",
-        "  palette = np.zeros((n, 3), dtype=np.uint8)\n",
-        "  for i in range(n):\n",
-        "    hue = i / n\n",
-        "    saturation = 0.7 + 0.3 * ((i * 7) % 10) / 10\n",
-        "    value = 0.6 + 0.4 * ((i * 3) % 10) / 10\n",
-        "    r, g, b = colorsys.hsv_to_rgb(hue, saturation, value)\n",
-        "    palette[i] = [int(r * 255), int(g * 255), int(b * 255)]\n",
-        "  return palette\n",
-        "\n",
-        "ADE20K_PALETTE = _generate_ade20k_palette()\n",
-        "print(f'Defined {len(ADE20K_CLASSES)} ADE20K classes')"
-      ],
-      "metadata": {},
-      "execution_count": null
-    },
-    {
-      "id": "9d2e893b",
-      "cell_type": "markdown",
-      "source": [
-        "## Run Inference\n",
-        "\n",
-        "### Semantic Segmentation (ADE20K)"
-      ],
-      "metadata": {},
-      "execution_count": null
-    },
-    {
-      "id": "51628b05",
-      "cell_type": "code",
-      "source": [
-        "# @title Run segmentation on an ADE20K sample image.\n",
-        "import matplotlib.pyplot as plt\n",
-        "\n",
-        "ade_img_dir = os.path.join(ADE20K_DIR, 'images', 'validation')\n",
-        "ade_images = sorted([\n",
-        "    os.path.join(ade_img_dir, f)\n",
-        "    for f in os.listdir(ade_img_dir) if f.endswith('.jpg')\n",
-        "])\n",
-        "\n",
-        "image_path = ade_images[1]  # A castle scene\n",
-        "print(f'Image: {image_path}')\n",
-        "\n",
-        "img, features = extract_features(image_path)\n",
-        "\n",
-        "with torch.no_grad():\n",
-        "  seg_logits = seg_model(features, image_size=(image_size, image_size))\n",
-        "  seg_map = seg_logits.argmax(dim=1).squeeze(0).cpu().numpy()\n",
-        "\n",
-        "colored_seg = ADE20K_PALETTE[seg_map]\n",
-        "\n",
-        "# Print top classes found.\n",
-        "unique_classes, counts = np.unique(seg_map, return_counts=True)\n",
-        "top_idx = np.argsort(-counts)[:5]\n",
-        "print('Top classes:')\n",
-        "for idx in top_idx:\n",
-        "  cls_id = unique_classes[idx]\n",
-        "  pct = 100 * counts[idx] / seg_map.size\n",
-        "  print(f'  {ADE20K_CLASSES[cls_id]:20s} ({pct:.1f}%)')\n",
-        "\n",
-        "plt.figure(figsize=(12, 5))\n",
-        "plt.subplot(1, 2, 1)\n",
-        "plt.imshow(img.resize((image_size, image_size)))\n",
-        "plt.title('Input Image')\n",
-        "plt.axis('off')\n",
-        "plt.subplot(1, 2, 2)\n",
-        "plt.imshow(colored_seg)\n",
-        "plt.title('Semantic Segmentation')\n",
-        "plt.axis('off')\n",
-        "plt.tight_layout()\n",
-        "plt.show()"
-      ],
-      "metadata": {},
-      "execution_count": null
-    },
-    {
-      "id": "d296d2e5",
-      "cell_type": "markdown",
-      "source": [
-        "### Depth Estimation \u0026 Surface Normals (NYU)"
-      ],
-      "metadata": {},
-      "execution_count": null
-    },
-    {
-      "id": "7796c341",
-      "cell_type": "code",
-      "source": [
-        "# @title Run depth and normals inference on NYU images.\n",
-        "import torch.nn.functional as F\n",
-        "\n",
-        "# Collect NYU sample images.\n",
-        "valid_extensions = ('.ppm', '.jpg', '.jpeg', '.png')\n",
-        "nyu_images = []\n",
-        "for root, dirs, files in os.walk(NYU_IMG_DIR):\n",
-        "  for file in files:\n",
-        "    if file.lower().endswith(valid_extensions):\n",
-        "      nyu_images.append(os.path.join(root, file))\n",
-        "\n",
-        "selected_images = nyu_images[:3]\n",
-        "\n",
-        "for i, image_path in enumerate(selected_images):\n",
-        "  print(f'Processing image {i+1}/{len(selected_images)}: {os.path.basename(image_path)}')\n",
-        "  img, features = extract_features(image_path)\n",
-        "\n",
-        "  with torch.no_grad():\n",
-        "    # --- Depth ---\n",
-        "    depth_map = depth_model(features, image_size=(image_size, image_size))\n",
-        "    depth_np = depth_map.squeeze().cpu().numpy()\n",
-        "    # Normalize for visualization.\n",
-        "    depth_np = (depth_np - depth_np.min()) / (depth_np.max() - depth_np.min() + 1e-8)\n",
-        "\n",
-        "    # --- Normals ---\n",
-        "    # Get raw low-res output first.\n",
-        "    normals_map = normals_model(features)\n",
-        "    # L2 normalize.\n",
-        "    normals_map = F.normalize(normals_map, dim=1)\n",
-        "    # Upsample with bicubic for smooth results.\n",
-        "    normals_map = F.interpolate(\n",
-        "        normals_map, size=(image_size, image_size),\n",
-        "        mode='bicubic', align_corners=False,\n",
-        "    )\n",
-        "    # Re-normalize after upsampling.\n",
-        "    normals_map = F.normalize(normals_map, dim=1)\n",
-        "    normals_np = normals_map.squeeze(0).cpu().numpy().transpose(1, 2, 0)\n",
-        "    # Map [-1, 1] -\u003e [0, 1] for display.\n",
-        "    normals_np = np.clip((normals_np + 1.0) / 2.0, 0.0, 1.0)\n",
-        "\n",
-        "  # Visualize.\n",
-        "  plt.figure(figsize=(15, 5))\n",
-        "  plt.subplot(1, 3, 1)\n",
-        "  plt.imshow(img.resize((image_size, image_size)))\n",
-        "  plt.title(f'Input ({i+1})')\n",
-        "  plt.axis('off')\n",
-        "\n",
-        "  plt.subplot(1, 3, 2)\n",
-        "  plt.imshow(depth_np, cmap='turbo')\n",
-        "  plt.title(f'Depth ({i+1})')\n",
-        "  plt.axis('off')\n",
-        "\n",
-        "  plt.subplot(1, 3, 3)\n",
-        "  plt.imshow(normals_np)\n",
-        "  plt.title(f'Surface Normals ({i+1})')\n",
-        "  plt.axis('off')\n",
-        "\n",
-        "  plt.tight_layout()\n",
-        "  plt.show()"
-      ],
-      "metadata": {},
-      "execution_count": null
-    }
-  ],
-  "metadata": {
-    "kernelspec": {
-      "display_name": "Python 3",
-      "language": "python",
-      "name": "python3"
-    },
-    "language_info": {
-      "name": "python"
-    }
-  },
-  "nbformat_minor": 5,
-  "nbformat": 4
-}