diff --git a/tutorials/optimviz/clip/LinearProbeFacetTraining_OptimViz.ipynb b/tutorials/optimviz/clip/LinearProbeFacetTraining_OptimViz.ipynb new file mode 100644 index 0000000000..e60f5c1f74 --- /dev/null +++ b/tutorials/optimviz/clip/LinearProbeFacetTraining_OptimViz.ipynb @@ -0,0 +1,896 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "LinearProbeFacetTraining_OptimViz.ipynb", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Training Linear Probes For Faceted Feature Visualization\n", + "\n", + "This tutorial demonstrates how to train linear probes for use in faceted feature visualization, as described in the Faceted Feature Visualization section of the Multimodal Neurons in Artificial Neural Networks research paper [here](https://distill.pub/2021/multimodal-neurons/#faceted-feature-visualization)." + ], + "metadata": { + "id": "Cf21lrB3QtMU" + } + }, + { + "cell_type": "code", + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import copy\n", + "import time\n", + "from collections import Counter\n", + "from typing import Dict, List, Optional, Tuple, Union\n", + "\n", + "import captum.optim as opt\n", + "import torch\n", + "import torchvision\n", + "\n", + "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")" + ], + "metadata": { + "id": "wt80XBrVGKgw" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Setup\n", + "\n", + "Before we can start training the linear probes, we'll need to do a bit of setup first. Below we define a helper function for balancing the classes of image datasets, and an optional transform that pads input images to squares for datasets requiring more spatial similarity." + ], + "metadata": { + "id": "ocLIFwn8iXGa" + } + }, + { + "cell_type": "code", + "source": [ + "def balance_training_classes(\n", + " dataloader: torch.utils.data.DataLoader, num_classes: int = 2\n", + ") -> List[float]:\n", + " \"\"\"\n", + " Calculate balancing weights for a given dataloader instance.\n", + "\n", + " Args:\n", + "\n", + " dataloader (torch.utils.data.DataLoader): A dataloader instance to count the\n", + " number of images in each class for.\n", + " num_classes (int, optional): The number of classes used in the dataset.\n", + " Default: 2\n", + "\n", + " Returns:\n", + " weights (list of float): A list of values for balancing the classes.\n", + " \"\"\"\n", + " train_class_counts = dict(\n", + " Counter(sample_tup[1] for sample_tup in dataloader.dataset)\n", + " )\n", + " train_class_counts = dict(sorted(train_class_counts.items()))\n", + " train_weights = [\n", + " 1.0 / train_class_counts[class_id] for class_id in range(num_classes)\n", + " ]\n", + " return train_weights\n", + "\n", + "\n", + "class PadToSquare(torch.nn.Module):\n", + " \"\"\"\n", + " Transform for padding rectangular shaped inputs to squares without messing up the\n", + " aspect ratio.\n", + " \"\"\"\n", + "\n", + " __constants__ = [\"padding_value\"]\n", + "\n", + " def __init__(self, padding_value: float = 0.0) -> None:\n", + " \"\"\"\n", + " Args:\n", + "\n", + " padding_value (float, optional): The value to use for the constant\n", + " padding.\n", + " Default: 0.0\n", + " \"\"\"\n", + " super().__init__()\n", + " self.padding_value = padding_value\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " assert x.dim() == 4 or x.dim() == 3\n", + " if x.dim() == 4:\n", + " C, H, W = x.shape[1:]\n", + " elif x.dim() == 3:\n", + " C, H, W = x.shape\n", + " top, left = [(max(H, W) - d) // 2 for d in [H, W]]\n", + " bottom, right = [max(H, W) - (d + pad) for d, pad in zip([H, W], [top, left])]\n", + "\n", + " padding = [left, right, top, bottom]\n", + " if x.dim() == 3:\n", + " return torch.nn.functional.pad(\n", + " x[None, :], padding, value=self.padding_value, mode=\"constant\"\n", + " )[0]\n", + " else:\n", + " return torch.nn.functional.pad(\n", + " x, padding, value=self.padding_value, mode=\"constant\"\n", + " )\n", + "\n", + "\n", + "def get_dataset_indices(dataset_path: str) -> Dict[str, int]:\n", + " \"\"\"\n", + " If you are not sure what the class indices are for your training images & the\n", + " generic natural images, then you can use this handy helper function that\n", + " replicates the ordering used by `torchvision.datasets.ImageFolder`.\n", + "\n", + " Args:\n", + "\n", + " dataset_path (str): The path to your image dataset that is using the standard\n", + " ImageFolder structure.\n", + "\n", + "\n", + " Returns\n", + " class_and_idx (dict of str and int): The folder names and corresponding class\n", + " indices.\n", + " \"\"\"\n", + " import os\n", + "\n", + " classes = [d.name for d in os.scandir(dataset_path) if d.is_dir()]\n", + " classes.sort()\n", + " return {cls_name: i for i, cls_name in enumerate(classes)}" + ], + "metadata": { + "id": "0EzmQvA9x4vt" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Dataset Setup\n", + "\n", + "\n", + "For the purpose of this tutorial we demonstrate setting up a basic dataset utilizing Torchvision's [ImageFolder](https://pytorch.org/vision/stable/_modules/torchvision/datasets/folder.html#ImageFolder). However you can use whatever dataset you like, provided of course it works with [`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), otherwise you may have to modify the training function to support your dataset.\n", + "\n", + "The authors of the research paper recommend that image datasets should contain a minimum of 2 classes, where one class is composed of generic natural images and the other class or classes contain the desired themes / concepts. The basic idea behind the image dataset class structure is to train the model to separate out a theme / concept from unrelated stuff." + ], + "metadata": { + "id": "fVIzo7g4Q9ic" + } + }, + { + "cell_type": "markdown", + "source": [ + "**Spatial information in your dataset**\n", + "\n", + "In the research paper, the authors trained some of the facets on images where the features in each image in the dataset were in roughly the same locations. This is important to note only if you are trying to create similar facets where you want more spatially coherent shapes like those of the `face` facet used in other tutorials." + ], + "metadata": { + "id": "QxyGxILRMVC8" + } + }, + { + "cell_type": "code", + "source": [ + "def create_dataloaders(\n", + " dataset_path: str,\n", + " batch_size: int = 32,\n", + " val_percent: float = 0.0,\n", + " training_transforms: torch.nn.Module = None,\n", + " validation_transforms: Optional[torch.nn.Module] = None,\n", + " balance_classes: bool = False,\n", + " num_classes: int = 2,\n", + ") -> Dict[str, Union[torch.utils.data.DataLoader, List[float]]]:\n", + " \"\"\"\n", + " Create one or more dataloader instances with optional balancing weights for a\n", + " given image dataset, with Torchvision's ImageFolder directory format.\n", + "\n", + " https://pytorch.org/vision/stable/_modules/torchvision/datasets/folder.html#ImageFolder\n", + "\n", + " Args:\n", + "\n", + " dataset_path (str): The path to the image dataset to use for torchvision's\n", + " ImageFolder dataset. See above for more details.\n", + " batch_size (int, optional): The batch size to use.\n", + " Default: 32\n", + " val_percent (float, optional): The percentage of the dataset to use for\n", + " validation. If set to 0 then no validation dataset will be created.\n", + " Default: 0.0\n", + " training_transforms (nn.Module): Transforms to use for training the linear\n", + " probes.\n", + " validation_transforms (nn.Module, optional): Transforms to use for validation,\n", + " if validation is enabled.\n", + " balance_classes (bool, optional): Whether or not to calculate weights for\n", + " balancing the training classes.\n", + " Default: False\n", + " num_classes (int, optional): If balance_classes is set to True, then this\n", + " variable provides the number of classes in the dataset to use in the\n", + " balancing calculations.\n", + " Default: 2\n", + "\n", + " Returns:\n", + " dataloaders (dict of dataloader and list of float): A dictionary containing\n", + " the training dataloader, with optional validation dataloader and balancing\n", + " weights for the training dataloader.\n", + " \"\"\"\n", + " full_dataset = torchvision.datasets.ImageFolder(\n", + " root=dataset_path,\n", + " )\n", + "\n", + " if val_percent > 0.0:\n", + " assert validation_transforms is not None\n", + " n = len(full_dataset)\n", + " lengths = [round(n * (1 - val_percent)), round(n * val_percent)]\n", + "\n", + " t_data, v_data = torch.utils.data.random_split(full_dataset, lengths)\n", + " t_data = copy.deepcopy(t_data)\n", + "\n", + " t_data.dataset.transform = training_transforms\n", + " v_data.dataset.transform = validation_transforms\n", + "\n", + " t_dataloader = torch.utils.data.DataLoader(\n", + " t_data,\n", + " batch_size=batch_size,\n", + " shuffle=True,\n", + " )\n", + " v_dataloader = torch.utils.data.DataLoader(\n", + " v_data, batch_size=batch_size, shuffle=True\n", + " )\n", + " dataloader = {\"train\": t_dataloader, \"val\": v_dataloader}\n", + " else:\n", + " t_dataset = torch.utils.data.Subset(\n", + " copy.deepcopy(full_dataset), range(0, len(full_dataset))\n", + " )\n", + " t_dataset.dataset.transform = training_transforms\n", + " t_dataloader = torch.utils.data.DataLoader(\n", + " t_dataset, batch_size=batch_size, shuffle=True\n", + " )\n", + " dataloader = {\"train\": t_dataloader}\n", + "\n", + " if balance_classes:\n", + " train_weights = balance_training_classes(dataloader[\"train\"], num_classes)\n", + " dataloader[\"train_weights\"] = train_weights\n", + " return dataloader" + ], + "metadata": { + "id": "8zl0aQdnF7fW" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Training Function\n", + "\n", + "The model training function's `dataloaders` variable requires training dataloaders to be organized in into dictionaries containing the following keys and values:\n", + "\n", + "* `train`: The training dataloader.\n", + "* `val`: Optionally include validation dataloader. If this key doesn't exist in the dict, then no validation phase will be performed.\n", + "* `train_weights`: Optionally include a list of training weights to balance the classes during training.\n", + "\n", + "\n", + "Linear probes are implemented as [`nn.LazyLinear`](https://pytorch.org/docs/stable/generated/torch.nn.LazyLinear.html) layers with a reshaping operation between them and the target layer." + ], + "metadata": { + "id": "6gnSpoNhiRpD" + } + }, + { + "cell_type": "code", + "source": [ + "def train_linear_probes(\n", + " model: torch.nn.Module,\n", + " target_layers: List[torch.nn.Module],\n", + " dataloaders: Dict[str, Union[torch.utils.data.DataLoader, List[float]]],\n", + " out_features: int = 2,\n", + " num_epochs: int = 10,\n", + " lr: float = 1.0,\n", + " l1_weight: float = 0.0,\n", + " l2_weight: float = 0.0,\n", + " use_optimizer: str = \"lbfgs\",\n", + " device: torch.device = torch.device(\"cpu\"),\n", + " save_epoch: Optional[int] = None,\n", + " save_path: str = \"epoch_\",\n", + " verbose: bool = True,\n", + " show_progress: bool = False,\n", + ") -> Tuple[List[torch.Tensor]]:\n", + " \"\"\"\n", + " Train linear probes on target layers of a specified model, for use as faceted\n", + " feature visualization facet weights.\n", + "\n", + " Args:\n", + "\n", + " model (nn.Module): An PyTorch model instance.\n", + " target_layers (nn.Module): A list of model targets to train linear probes for.\n", + " dataloaders (dict of torch.utils.data.DataLoader): A dictionary of PyTorch\n", + " Dataloader instances for training and optionally for validation.\n", + " num_epochs (int, optional): The number of epochs to train for.\n", + " Default: 10\n", + " l1_weight (float, optional): The desired l1 penalty weight to use.\n", + " Default: 0.0\n", + " l2_weight (float, optional): The desired l2 penalty weight to use.\n", + " Default: 0.0\n", + " lr (float, optional): The desired learning rate to use with the optimizer.\n", + " Default: 1.0\n", + " use_optimizer (str, optional): The optimizer to use. Choices are: \"sgd\" or\n", + " \"lbfgs\".\n", + " Default: \"lbfgs\"\n", + " device (torch.device, optional): The device to place training inputs on before\n", + " sending them through the model.\n", + " Default: torch.device(\"cpu\")\n", + " save_epoch (int, optional): Save the best model weights every save_epoch\n", + " epochs. Set to None to not save any epochs.\n", + " Default: None\n", + " save_path (str, optional): If save_epoch is not None, save model weights with\n", + " the path / name: .\n", + " Default: \"epoch_\"\n", + " verbose (bool, optional): Whether or not to print loss and accuracy after\n", + " every epoch.\n", + " Default: True\n", + "\n", + " Returns:\n", + " weights (list of torch.Tensor): The weights of the best scoring models from\n", + " the training session. The order of the weights corresponds to\n", + " `target_layers`.\n", + " best_acc (list of float): The training accuracies for the returned weights.\n", + " The order corresponds to `weights`.\n", + " \"\"\"\n", + " assert use_optimizer in [\"lbfgs\", \"sgd\"]\n", + " assert \"train\" in dataloaders\n", + "\n", + " phases = [\"train\", \"val\"] if \"val\" in dataloaders else [\"train\"]\n", + "\n", + " # Optionally balance classes if provided with weight balancing tensor\n", + " if \"train_weights\" in dataloaders:\n", + " crit_weights = torch.FloatTensor(dataloaders[\"train_weights\"])\n", + " criterion = torch.nn.CrossEntropyLoss(weight=crit_weights).to(device)\n", + " else:\n", + " criterion = torch.nn.CrossEntropyLoss()\n", + "\n", + " # Create Linear Probes using LazyLinear so that we don't need to specify an input size\n", + " layer_probes = [\n", + " torch.nn.LazyLinear(out_features, bias=False).to(device).train()\n", + " for _ in target_layers\n", + " ]\n", + " num_probes = len(target_layers)\n", + "\n", + " # Setup model saving\n", + " best_models = [None for _ in layer_probes]\n", + " best_accs = [0.0] * num_probes\n", + "\n", + " # Setup optimizer\n", + " parameters = []\n", + " for p in layer_probes:\n", + " parameters += list(p.parameters())\n", + " if use_optimizer == \"lbfgs\":\n", + " optimizer = torch.optim.LBFGS(\n", + " parameters, lr=lr, max_iter=1, tolerance_change=-1, tolerance_grad=-1\n", + " )\n", + " else:\n", + " optimizer = torch.optim.SGD(parameters, lr=lr, momentum=0.0, weight_decay=0.0)\n", + "\n", + " # Get dataset lengths beforehand to speed things up\n", + " val_length = 0 if \"val\" not in dataloaders else len(dataloaders[\"val\"].dataset)\n", + " dataset_length = {\"train\": len(dataloaders[\"train\"].dataset), \"val\": val_length}\n", + "\n", + " start_time = time.time()\n", + " for epoch in range(num_epochs):\n", + " if verbose:\n", + " print(\"Epoch {}/{}\".format(epoch + 1, num_epochs))\n", + " print(\"-\" * 12)\n", + "\n", + " for phase in phases:\n", + " if phase == \"train\":\n", + " [layer_probes[i].train() for i in range(num_probes)]\n", + " else:\n", + " [layer_probes[i].eval() for i in range(num_probes)]\n", + "\n", + " phase_stats = {\n", + " \"epoch_acc\": [0.0] * num_probes,\n", + " \"epoch_loss\": [0.0] * num_probes,\n", + " }\n", + "\n", + " for inputs, labels in dataloaders[phase]:\n", + " inputs, labels = inputs.to(device), labels.to(device)\n", + "\n", + " with torch.set_grad_enabled(phase == \"train\"):\n", + " if use_optimizer == \"lbfgs\":\n", + " # Training with torch.optim.LBFGS\n", + "\n", + " def closure() -> torch.Tensor:\n", + " optimizer.zero_grad()\n", + " # Collect outputs for target layers\n", + " probe_inputs = opt.models.collect_activations(\n", + " model, target_layers, inputs\n", + " )\n", + " outputs = [probe_inputs[target] for target in target_layers]\n", + "\n", + " # Send layer outputs through linear probes\n", + " outputs = [\n", + " probe(x.reshape(x.shape[0], -1))\n", + " for x, probe in zip(outputs, layer_probes)\n", + " ]\n", + "\n", + " probe_losses = [\n", + " criterion(outputs[i], labels) for i in range(num_probes)\n", + " ]\n", + " preds = [\n", + " torch.max(outputs[i], 1)[1] for i in range(num_probes)\n", + " ]\n", + " loss = sum(probe_losses)\n", + "\n", + " if phase == \"train\":\n", + "\n", + " # Apply optional L1 or L2 penalties\n", + " if l1_weight != 0.0 or l2_weight != 0.0:\n", + " if l1_weight != 0.0:\n", + " l1_penalty = sum(\n", + " [\n", + " l1_weight * p.weight.abs().sum()\n", + " for p in layer_probes\n", + " ]\n", + " )\n", + " loss = loss + l1_penalty\n", + " if l2_weight != 0.0:\n", + " l2_penalty = l2_weight * sum(\n", + " [\n", + " (p.weight**2).sum()\n", + " for p in layer_probes\n", + " ]\n", + " )\n", + " loss = loss + l2_penalty\n", + "\n", + " loss.backward()\n", + "\n", + " with torch.no_grad():\n", + " phase_stats[\"epoch_loss\"] = [\n", + " phase_stats[\"epoch_loss\"][i]\n", + " + l.detach().item() * inputs.size(0)\n", + " for i, l in enumerate(probe_losses)\n", + " ]\n", + " phase_stats[\"epoch_acc\"] = [\n", + " phase_stats[\"epoch_acc\"][i]\n", + " + torch.sum(p == labels).detach().item()\n", + " for i, p in enumerate(preds)\n", + " ]\n", + " return loss\n", + "\n", + " optimizer.step(closure)\n", + " else:\n", + " # Training with torch.optim.SGD\n", + "\n", + " optimizer.zero_grad()\n", + " # Collect outputs for target layers\n", + " probe_inputs = opt.models.collect_activations(\n", + " model, target_layers, inputs\n", + " )\n", + " outputs = [probe_inputs[target] for target in target_layers]\n", + "\n", + " # Send layer outputs through linear probes\n", + " outputs = [\n", + " probe(x.reshape(x.shape[0], -1))\n", + " for x, probe in zip(outputs, layer_probes)\n", + " ]\n", + "\n", + " probe_losses = [\n", + " criterion(outputs[i], labels)\n", + " for i in range(len(layer_probes))\n", + " ]\n", + " preds = [\n", + " torch.max(outputs[i], 1)[1]\n", + " for i in range(len(layer_probes))\n", + " ]\n", + "\n", + " loss = sum(probe_losses)\n", + "\n", + " if phase == \"train\":\n", + "\n", + " # Apply optional L1 or L2 penalties\n", + " if l1_weight != 0.0:\n", + " l1_penalty = sum(\n", + " [\n", + " l1_weight * p.weight.abs().sum()\n", + " for p in layer_probes\n", + " ]\n", + " )\n", + " loss = loss + l1_penalty\n", + " if l2_weight != 0.0:\n", + " l2_penalty = l2_weight * sum(\n", + " [(p.weight**2).sum() for p in layer_probes]\n", + " )\n", + " loss = loss + l2_penalty\n", + "\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " with torch.no_grad():\n", + " phase_stats[\"epoch_loss\"] = [\n", + " phase_stats[\"epoch_loss\"][i]\n", + " + l.detach().item() * inputs.size(0)\n", + " for i, l in enumerate(probe_losses)\n", + " ]\n", + " phase_stats[\"epoch_acc\"] = [\n", + " phase_stats[\"epoch_acc\"][i]\n", + " + torch.sum(p == labels).detach().item()\n", + " for i, p in enumerate(preds)\n", + " ]\n", + "\n", + " phase_stats[\"epoch_loss\"] = [\n", + " phase_stats[\"epoch_loss\"][i] / dataset_length[phase]\n", + " for i in range(num_probes)\n", + " ]\n", + " phase_stats[\"epoch_acc\"] = [\n", + " phase_stats[\"epoch_acc\"][i] / dataset_length[phase]\n", + " for i in range(num_probes)\n", + " ]\n", + "\n", + " # Make sure we keep the best model weights\n", + " if phase == \"val\" or \"val\" not in phases:\n", + " for i, acc in enumerate(phase_stats[\"epoch_acc\"]):\n", + " if acc > best_accs[i]:\n", + " best_accs[i] = acc\n", + " best_models[i] = layer_probes[i].weight.clone().detach().cpu()\n", + "\n", + " if verbose:\n", + " print(\n", + " \"{} Loss: {:.4f} Acc: {:.4f}\".format(\n", + " phase,\n", + " sum(phase_stats[\"epoch_loss\"]) / num_probes,\n", + " sum(phase_stats[\"epoch_acc\"]) / num_probes,\n", + " )\n", + " )\n", + " print(\" Loss: \", [round(v, 4) for v in phase_stats[\"epoch_loss\"]])\n", + " print(\" Acc: \", [round(acc, 4) for acc in phase_stats[\"epoch_acc\"]])\n", + " time_elapsed = time.time() - start_time\n", + " print(\n", + " \"Time Elapsed {:.0f}m {:.0f}s\".format(\n", + " time_elapsed // 60, time_elapsed % 60\n", + " )\n", + " )\n", + " if epoch + 1 != num_epochs:\n", + " print()\n", + "\n", + " if save_epoch and (epoch + 1) % save_epoch == 0 and (epoch + 1) != num_epochs:\n", + " facet_weights = [w.clone().cpu().detach() for w in best_models]\n", + " filename = save_path + str(epoch + 1) + \".pt\"\n", + " torch.save([w.cpu() for w in facet_weights], filename)\n", + "\n", + " return best_models, best_accs" + ], + "metadata": { + "id": "0EHyeCMKiIi1" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Load Model & Dataset" + ], + "metadata": { + "id": "0ds-L3I8okgX" + } + }, + { + "cell_type": "markdown", + "source": [ + "Now that we have the required classes and functions defined, we load the ResNet 50x4 image model without `RedirectedReLU`." + ], + "metadata": { + "id": "X6l71TR0fTKj" + } + }, + { + "cell_type": "code", + "source": [ + "# Load image model\n", + "clip_model = (\n", + " opt.models.clip_resnet50x4_image(\n", + " pretrained=True, replace_relus_with_redirectedrelu=False\n", + " )\n", + " .eval()\n", + " .to(device)\n", + ")" + ], + "metadata": { + "id": "BYGdvCKMFxbc" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Next we load our dataset's dataloaders for training. Remember that our dataloader creation function uses Torchvision's ImageFolder, and thus different datasets may need their own setup functions." + ], + "metadata": { + "id": "8Q9i7KYBfxp4" + } + }, + { + "cell_type": "code", + "source": [ + "dataset_path = \"my_dataset\" # Path to dataset\n", + "num_classes = 2 # Number of classes in our dataset\n", + "\n", + "# Setup transforms for training\n", + "training_transforms = torchvision.transforms.Compose(\n", + " [\n", + " torchvision.transforms.ToTensor(),\n", + " # PadToSquare(1.0),\n", + " torchvision.transforms.Resize((288, 288), antialias=True),\n", + " ]\n", + ")\n", + "\n", + "dataloaders = create_dataloaders(\n", + " dataset_path,\n", + " batch_size=16,\n", + " val_percent=0.0,\n", + " training_transforms=training_transforms,\n", + " balance_classes=True,\n", + " num_classes=num_classes,\n", + ")" + ], + "metadata": { + "id": "48fVVUXmfu4E" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Training The Linear Probes" + ], + "metadata": { + "id": "CJsBWsMuUZzx" + } + }, + { + "cell_type": "markdown", + "source": [ + "We can now begin training the linear probes on the target layers! Below we train linear probes on the same 5 lower layers as the researchers did in the paper.\n", + "\n", + "Note that using the [L-BFGS optimizer](https://pytorch.org/docs/stable/generated/torch.optim.LBFGS.html) will generally produce the best quality facets, but it will also use more memory than the [SGD optimizer](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html). Memory usage can also be reduced by training fewer linear probes at once.\n", + "\n", + "Note that you may have to adjust the default parameters for training for custom datasets and models." + ], + "metadata": { + "id": "3NwqlpzkfdeB" + } + }, + { + "cell_type": "code", + "source": [ + "# Layers to train linear probes for\n", + "target_layers = [\n", + " clip_model.layer3[0].relu3,\n", + " clip_model.layer3[2].relu3,\n", + " clip_model.layer3[4].relu3,\n", + " clip_model.layer3[6].relu3,\n", + " clip_model.layer3[8].relu3,\n", + "]\n", + "\n", + "\n", + "# The L-BFGS optimizer will use more memory than the SGD optimizer\n", + "use_optimizer = \"lbfgs\" # Whether to optimize with \"lbfgs\" or \"sgd\"\n", + "\n", + "# Optimizer specific param setup\n", + "if use_optimizer == \"lbfgs\":\n", + " l2_weight = 0.0\n", + " lr = 1.0\n", + "else:\n", + " l2_weight = 0.316\n", + " lr = 0.0001\n", + "\n", + "# Train linear probes\n", + "weights, weight_accs = train_linear_probes(\n", + " model=clip_model,\n", + " target_layers=target_layers,\n", + " dataloaders=dataloaders,\n", + " # This should be the same as the number of classes in the dataset\n", + " out_features=num_classes,\n", + " num_epochs=5,\n", + " lr=lr,\n", + " l2_weight=l2_weight,\n", + " use_optimizer=use_optimizer,\n", + " device=device,\n", + ")" + ], + "metadata": { + "id": "a0yFS4JQ4zY_", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "bc4a51c3-2e69-4ab5-a265-4c3e3db9f27d" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 1/5\n", + "------------\n", + "train Loss: 390337.9189 Acc: 0.9715\n", + " Loss: [56043.4749, 1363915.4473, 124310.3623, 168846.0195, 238574.2905]\n", + " Acc: [0.9718, 0.966, 0.9722, 0.9705, 0.9771]\n", + "Time Elapsed 3m 14s\n", + "\n", + "Epoch 2/5\n", + "------------\n", + "train Loss: 16781.2769 Acc: 0.9976\n", + " Loss: [14076.3319, 31218.2309, 6106.3447, 19327.1426, 13178.3344]\n", + " Acc: [0.9958, 0.9979, 0.9986, 0.9969, 0.999]\n", + "Time Elapsed 6m 31s\n", + "\n", + "Epoch 3/5\n", + "------------\n", + "train Loss: 329.2152 Acc: 0.9994\n", + " Loss: [689.9083, 327.7661, 481.1846, 147.2171, 0.0]\n", + " Acc: [0.9982, 0.9997, 0.9994, 0.9994, 1.0]\n", + "Time Elapsed 9m 48s\n", + "\n", + "Epoch 4/5\n", + "------------\n", + "train Loss: 468.3097 Acc: 0.9989\n", + " Loss: [546.3372, 485.5594, 319.5212, 988.2269, 1.9037]\n", + " Acc: [0.9987, 0.999, 0.9993, 0.9978, 0.9999]\n", + "Time Elapsed 13m 5s\n", + "\n", + "Epoch 5/5\n", + "------------\n", + "train Loss: 100.6919 Acc: 0.9997\n", + " Loss: [236.6766, 138.6808, 78.6038, 49.4981, 0.0]\n", + " Acc: [0.9994, 0.9997, 0.9997, 0.9997, 1.0]\n", + "Time Elapsed 16m 21s\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "Now that we have our trained weights, we can slice out the batch dimensions that correspond to the predicted theme / concept that we are training on while ignoring the batch dimension for the generic natural images. For this tutorial we were only training 1 class in addition to the generic natural images, so we only have one index of weights to collect." + ], + "metadata": { + "id": "YIb8Swx-e0Oi" + } + }, + { + "cell_type": "code", + "source": [ + "# Uncomment to get dataset class indices for ImageFolder datasets\n", + "# print(get_dataset_indices(dataset_path))" + ], + "metadata": { + "id": "8cTCnWIPySRS" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# We only need the theme / concept part of the weights\n", + "theme_idx = 0 # Class idx for the target theme / concept\n", + "facet_weights = [w[theme_idx : theme_idx + 1] for w in weights]" + ], + "metadata": { + "id": "QnX-gDLqUeq_" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "The `nn.LazyLinear` layers used to train the probes require 2D inputs, and thus 4D layer targets like `nn.Conv2d` layers need to be reshaped back to their 4D output shapes after training. For this tutorial, all layer targets have an output shape of: `[N, 1280, 18, 18]`." + ], + "metadata": { + "id": "WOvE54Sk2KEJ" + } + }, + { + "cell_type": "code", + "source": [ + "# Uncomment to view the shape of each layer\n", + "# out_dict = opt.models.collect_activations(\n", + "# clip_model, target_layers, torch.zeros(1, 3, 288, 288)\n", + "# )\n", + "# print([out_dict[t].shape for t in target_layers])" + ], + "metadata": { + "id": "o9n1yOfTDyR3" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Each probe weight can be reshaped to match its corresponding model layer\n", + "facet_weights = [w.reshape(1, 1280, 18, 18) for w in facet_weights]" + ], + "metadata": { + "id": "p6nyJuLW2JW1" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "We can now save our facet weights as they are ready for use in faceted feature visualization!" + ], + "metadata": { + "id": "HdCZlPxAfL5D" + } + }, + { + "cell_type": "code", + "source": [ + "# Save the trained weights\n", + "torch.save([w.cpu() for w in facet_weights], \"my_facet_weights.pt\")\n", + "\n", + "# Then the weights can be loaded like this\n", + "# facet_weights = torch.load(\"my_facet_weights.pt\")" + ], + "metadata": { + "id": "VlKn5QCJUgKA" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "If you trained multiple facet themes at once, then you can save them individually like in the example code below." + ], + "metadata": { + "id": "__NXZJF9Cfl8" + } + }, + { + "cell_type": "code", + "source": [ + "# Uncomment to save multiple facets\n", + "# theme_indices = [0, 1]\n", + "# for idx in theme_indices:\n", + "# facet_weights = [w[idx : idx + 1].reshape(1, 1280, 18, 18) for w in weights]\n", + "# torch.save(\n", + "# [w.cpu() for w in facet_weights], \"my_facet_weights_{}_.pt\".format(idx)\n", + "# )" + ], + "metadata": { + "id": "kcDQ_OetHPsP" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "The facet weights can then be loaded and used for the `FacetLoss` objective's required `facet_weights` variable." + ], + "metadata": { + "id": "o-a5_zOaI5CT" + } + } + ] +} \ No newline at end of file