In [2]:
%%capture
try:
    import jaxtyping
except ImportError:
    !pip install jaxtyping

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms.functional
import torch.optim as optim
import warnings

from copy import deepcopy
from jaxtyping import Float, Int, Bool
from matplotlib.gridspec import GridSpec
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler
from tqdm.auto import trange
from typing import Optional, Any, Dict, Union, Tuple, List, Callable

# Tensors are created on the GPU when CUDA is available, otherwise on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32


def to_t(array):
    """Build a tensor from `array` on the module-level `device` with the module-level `dtype`."""
    return torch.tensor(array, device=device, dtype=dtype)


def from_t(tensor):
    """Move `tensor` to the CPU, detach it, and return it as a numpy array."""
    return tensor.to("cpu").detach().numpy()


def to_odd(x):
    """Return ceil(x) as an int if it is odd, otherwise ceil(x) + 1."""
    half = np.ceil(x) // 2
    return int(half * 2 + 1)



In [1]:
!wget -nc https://github.com/slinderman/ml4nd/raw/refs/heads/main/data/03_pose_tracking/lab3_data.pt

--2025-10-20 20:38:50--  https://github.com/slinderman/ml4nd/raw/refs/heads/main/data/03_pose_tracking/lab3_data.pt
Resolving github.com (github.com)... 140.82.121.3
Connecting to github.com (github.com)|140.82.121.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://media.githubusercontent.com/media/slinderman/ml4nd/refs/heads/main/data/03_pose_tracking/lab3_data.pt [following]
--2025-10-20 20:38:51--  https://media.githubusercontent.com/media/slinderman/ml4nd/refs/heads/main/data/03_pose_tracking/lab3_data.pt
Resolving media.githubusercontent.com (media.githubusercontent.com)... 2606:50c0:8002::154, 2606:50c0:8000::154, 2606:50c0:8001::154, ...
Connecting to media.githubusercontent.com (media.githubusercontent.com)|2606:50c0:8002::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 712328325 (679M) [application/octet-stream]
Saving to: ‘lab3_data.pt’


2025-10-20 20:39:24 (20.4 MB/s) - ‘lab3_data.pt’ saved [712328325/712328325]

In [3]:
# Load the downloaded lab file; the result is indexed by string keys below.
# NOTE: torch.load unpickles its input, so only load files from a trusted source.
data = torch.load("lab3_data.pt")

# Training split: images, keypoint coordinates, and targets
train_images, train_coords, train_targets = (
    data["train_images"],
    data["train_coords"],
    data["train_targets"],
)

# Validation split: same three entries as the training split
val_images, val_coords, val_targets = (
    data["val_images"],
    data["val_coords"],
    data["val_targets"],
)

# Only the images of the test split are read here
test_images = data["test_images"]

# Keypoint names, and how many keypoints there are
keypoint_names = data["keypoint_names"]
num_keypoints = len(keypoint_names)

In [4]:
#@title
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import animation
from IPython.display import HTML
from tempfile import NamedTemporaryFile
import base64

# Use seaborn's "talk" plotting context for all figures in this notebook
sns.set_context("talk")

# Named xkcd colors; the plotting helpers index this palette by keypoint number
palette = sns.xkcd_palette([
    "windows blue", "red", "medium green",
    "dusty purple", "orange", "amber",
    "clay", "pink", "greyish",
])

def gradient_cmap(colors: List,
                  nsteps: int = 256,
                  bounds: Optional[np.ndarray] = None):
    """Build a colormap that interpolates through a sequence of colors.

    Args:
        colors: sequence of colors, each indexable as (r, g, b) or
            (r, g, b, a). A color without exactly four entries gets alpha 1.
        nsteps: passed to `LinearSegmentedColormap` as its third argument.
        bounds: anchor position of each color. Defaults to evenly spaced
            values between 0 and 1, one per color.

    Returns:
        A `LinearSegmentedColormap` named 'grad_colormap'.
    """
    ncolors = len(colors)
    anchors = np.linspace(0, 1, ncolors) if bounds is None else bounds

    # One list of (position, value, value) entries per channel
    segments = {'red': [], 'green': [], 'blue': [], 'alpha': []}
    for anchor, color in zip(anchors, colors):
        opacity = color[3] if len(color) == 4 else 1.
        channel_values = (color[0], color[1], color[2], opacity)
        for channel, value in zip(segments, channel_values):
            segments[channel].append((anchor, value, value))

    cdict = {channel: tuple(entries) for channel, entries in segments.items()}
    return LinearSegmentedColormap('grad_colormap', cdict, nsteps)

# One colormap per palette color, running from fully transparent white
# (alpha 0) to the palette color at alpha 1
cmaps = [
    gradient_cmap([np.array([1, 1, 1, 0]), np.append(color, 1)])
    for color in palette
]


_VIDEO_TAG = """<video controls>
 <source src="data:video/x-m4v;base64,{0}" type="video/mp4">
 Your browser does not support the video tag.
</video>"""

def _anim_to_html(anim: animation.Animation,
                  fps: int = 20):
    """Encode an animation as an MP4 and embed it in an HTML <video> tag.

    The base64-encoded video is cached on the animation object as
    `anim._encoded_video`. Later calls with the same object reuse that cache,
    so `fps` only takes effect the first time an animation is encoded.

    Args:
        anim: the animation to encode.
        fps: frames per second passed to `anim.save`.

    Returns:
        A string of HTML containing the video as a base64 data URI.
    """
    if not hasattr(anim, '_encoded_video'):
        with NamedTemporaryFile(suffix='.mp4') as f:
            anim.save(f.name, fps=fps, extra_args=['-vcodec', 'libx264'])
            # Read the encoded file back through a context manager so the
            # handle is closed deterministically instead of being leaked.
            with open(f.name, "rb") as video_file:
                video = video_file.read()
        anim._encoded_video = base64.b64encode(video)

    return _VIDEO_TAG.format(anim._encoded_video.decode('ascii'))

def _display_animation(anim: animation.Animation,
                       fps: int = 30,
                       start: int = 0,
                       stop: Optional[int] = None):
    """Close the animation's figure and return the animation as an HTML video.

    `start` and `stop` are accepted but not used by this function.
    """
    # Close the figure so it is not also shown as a static plot
    plt.close(anim._fig)
    video_html = _anim_to_html(anim, fps=fps)
    return HTML(video_html)

def play(movie: Int[Tensor, "num_frames height width 3"],
         fps: int = 30,
         speedup: int = 1,
         fig_height: int = 6,
         show_time: bool = False):
    """Render a movie as an HTML video snippet for inline display.

    Args:
        movie: frames indexed along the first axis; the next two axes are
            read as height and width.
        fps: frame rate of the encoded video. Also used to convert frame
            indices to seconds for the optional timestamp.
        speedup: only every `speedup`-th frame is shown.
        fig_height: figure height; the width follows the frame aspect ratio.
        show_time: if True, draw a "t=...s" timestamp on each frame.

    Returns:
        An `IPython.display.HTML` object wrapping the encoded video.
    """
    # First set up the figure, the axis, and the plot element we want to animate
    T, Py, Px = movie.shape[:3]
    fig, ax = plt.subplots(1, 1, figsize=(fig_height * Px/Py, fig_height))
    im = ax.imshow(movie[0], interpolation='None', cmap=plt.cm.gray)

    if show_time:
        tx = ax.text(0.75, 0.05, 't={:.3f}s'.format(0),
                     color='white',
                     fontdict=dict(size=12),
                     horizontalalignment='left',
                     verticalalignment='center',
                     transform=ax.transAxes)
    ax.axis('off')

    def animate(i):
        im.set_data(movie[i * speedup])
        if show_time:
            tx.set_text("t={:.3f}s".format(i * speedup / fps))
            # With blit=True every artist that changed must be returned
            return im, tx
        return im,

    # call the animator.  blit=True means only re-draw the parts that have changed.
    anim = animation.FuncAnimation(fig, animate,
                                   frames=T // speedup,
                                   interval=1,
                                   blit=True)
    plt.close(anim._fig)

    # return an HTML video snippet, encoded at the requested frame rate
    # (previously hard-coded to 30, which ignored the `fps` argument)
    print("Preparing animation. This may take a minute...")
    return HTML(_anim_to_html(anim, fps=fps))

#@title Helper functions to play movies with overlays { display-mode: "form" }
def play_overlay(movie: Int[Tensor, "num_frames height width 3"],
                 probs: Float[Tensor, "num_frames num_keypoints height width"],
                 fps: int = 30,
                 speedup: int = 1,
                 fig_height: int = 6,
                 show_time: bool = False,
                 prob_clip: float = 0.1):
    """Render a movie with per-keypoint probability maps overlaid on it.

    Each keypoint's map is drawn on top of the frame with its own colormap
    from the module-level `cmaps`, stretched over the full frame extent.

    Args:
        movie: frames indexed along the first axis; the next two axes are
            read as height and width.
        probs: one map per frame and keypoint; its first axis must match the
            number of frames in `movie`.
        fps: frame rate of the encoded video. Also used to convert frame
            indices to seconds for the optional timestamp.
        speedup: only every `speedup`-th frame is shown.
        fig_height: figure height; the width follows the frame aspect ratio.
        show_time: if True, draw a "t=...s" timestamp on each frame.
        prob_clip: upper color limit (`vmax`) for the overlays.

    Returns:
        An `IPython.display.HTML` object wrapping the encoded video.
    """
    # First set up the figure, the axis, and the plot element we want to animate
    T, Py, Px = movie.shape[:3]
    assert probs.shape[0] == T
    num_keypoints = probs.shape[1]

    fig, ax = plt.subplots(1, 1, figsize=(fig_height * Px/Py, fig_height))
    im = ax.imshow(movie[0], interpolation='None', cmap=plt.cm.gray)

    # One overlay per keypoint actually present in `probs`, rather than per
    # entry of the global keypoint list. zip stops at the shorter input, so
    # keypoints beyond the number of available colormaps are not drawn.
    p_ims = [
        ax.imshow(probs[0, k],
                  extent=(0, Px, Py, 0),
                  cmap=cmap, vmin=0, vmax=prob_clip)
        for k, cmap in zip(range(num_keypoints), cmaps)
    ]

    if show_time:
        tx = ax.text(0.75, 0.05, 't={:.3f}s'.format(0),
                     color='white',
                     fontdict=dict(size=12),
                     horizontalalignment='left',
                     verticalalignment='center',
                     transform=ax.transAxes)
    ax.axis('off')

    def animate(i):
        im.set_data(movie[i * speedup])
        for k, p_im in enumerate(p_ims):
            p_im.set_data(probs[i * speedup, k])
        artists = [im] + p_ims
        if show_time:
            tx.set_text("t={:.3f}s".format(i * speedup / fps))
            artists.append(tx)
        return artists

    # call the animator.  blit=True means only re-draw the parts that have changed.
    anim = animation.FuncAnimation(fig, animate,
                                   frames=T // speedup,
                                   interval=1,
                                   blit=True)
    plt.close(anim._fig)

    # return an HTML video snippet, encoded at the requested frame rate
    # (previously hard-coded to 30, which ignored the `fps` argument)
    print("Preparing animation. This may take a minute...")
    return HTML(_anim_to_html(anim, fps=fps))

def plot_example(image: Int[Tensor, "height width 3"],
                 coords: Float[Tensor, "num_keypoints 2"],
                 targets: Bool[Tensor, "num_keypoints height width"],
                 title: Optional[str] = None,
                 downsample: int = 1,
                 dilation: int = 1,
                 panel_size: int = 5,
                 plot_grid: bool = False):
    """Plot an image with its keypoints next to one target mask per keypoint.

    The first panel shows the image with each keypoint marked and labelled.
    Each remaining panel shows one keypoint's target mask.

    Args:
        image: 2d (grayscale) or 3d (color) image.
        coords: per-keypoint coordinates; column 0 is plotted on the x axis
            and column 1 on the y axis, both divided by `downsample`.
        targets: one 2d mask per keypoint.
        title: optional title for the image panel.
        downsample: factor the coordinates are divided by before plotting.
        dilation: side length of the square structuring element used to
            dilate each mask for display.
        panel_size: size of each panel.
        plot_grid: if True, draw white lines along the pixel boundaries of
            each mask.

    Raises:
        ValueError: if `image` is neither 2d nor 3d.
    """
    from scipy.ndimage import binary_dilation

    # Validate before creating a figure so a bad input leaves no empty figure
    if image.ndim not in (2, 3):
        raise ValueError("image must be 2d grayscale or 3d rgb")

    # Size the figure from the inputs rather than from a module-level count
    num_keypoints = len(targets)
    figsize = (panel_size * (num_keypoints + 1), panel_size)
    fig, axs = plt.subplots(1, 1 + num_keypoints, figsize=figsize)

    if image.ndim == 2:
        axs[0].imshow(image, cmap="Greys_r")
    else:
        axs[0].imshow(image)
    for k in range(num_keypoints):
        axs[0].plot(coords[k, 0] / downsample,
                    coords[k, 1] / downsample,
                    marker='o', color=palette[k], ls='none',
                    label=keypoint_names[k])

    axs[0].legend(loc="lower right", fontsize=8)
    axs[0].set_axis_off()
    if title: axs[0].set_title(title)

    for k, target in enumerate(targets):
        axs[k + 1].imshow(binary_dilation(target, np.ones((dilation, dilation))))
        if plot_grid:
            height, width = target.shape
            # Pixel j covers [j - 0.5, j + 0.5], so the image spans
            # [-0.5, size - 0.5]; end the grid lines exactly at that edge.
            axs[k + 1].hlines(-0.5 + np.arange(height + 1), -0.5, width - 0.5, 'w', lw=0.5)
            axs[k + 1].vlines(-0.5 + np.arange(width + 1), -0.5, height - 0.5, 'w', lw=0.5)
        axs[k + 1].set_title("Target: {}".format(keypoint_names[k]))
        axs[k + 1].set_axis_off()

def plot_image_and_probs(image: Int[Tensor, "height width 3"],
                         probs: Float[Tensor, "num_keypoints height width"],
                         probs_max: float = 1.0,
                         panel_size: int = 4):
    """Show an image beside one probability map per keypoint, with a colorbar.

    Args:
        image: image to display in grayscale; a 3d image is averaged over
            its first axis before plotting.
        probs: one 2d map per keypoint, indexed along the first axis.
        probs_max: upper color limit shared by all probability panels.
        panel_size: size of each panel.
    """
    num_keypoints = probs.shape[0]

    # NOTE(review): averaging over axis 0 implies a channel-first image, which
    # disagrees with the "height width 3" annotation - confirm against callers.
    if image.ndim == 3:
        image = image.mean(0)

    # One column for the image, one per keypoint, and a narrow one for the colorbar
    panel_widths = [1] * (1 + num_keypoints) + [0.1]
    fig = plt.figure(figsize=(panel_size * (1 + num_keypoints + .1), panel_size))
    gs = GridSpec(nrows=1, ncols=1 + num_keypoints + 1, figure=fig,
                  width_ratios=panel_widths,)

    # Leftmost panel: the input image in grayscale
    image_ax = plt.subplot(gs[0])
    image_ax.imshow(image, cmap="Greys_r")
    image_ax.axis('off')
    image_ax.set_title("input image $X$")

    # Middle panels: one probability map per keypoint on a shared color scale
    for k in range(num_keypoints):
        prob_ax = plt.subplot(gs[k + 1])
        im = prob_ax.imshow(probs[k], vmin=0, vmax=probs_max)
        prob_ax.axis('off')
        prob_ax.set_title("{} probs".format(keypoint_names[k]))

    # Rightmost panel: colorbar for the last probability map drawn
    colorbar_ax = plt.subplot(gs[-1])
    plt.colorbar(im, cax=colorbar_ax)
    plt.tight_layout()

def plot_model_weights(basic_model: nn.Module,
                       panel_size: int = 4):
    """Plot the `final_conv` weights of a model, one panel per keypoint.

    Panel k shows the slice `weights[k, 0]` on a diverging color scale that
    is symmetric about zero, with a shared colorbar on the right.

    Args:
        basic_model: model with a `final_conv` attribute that has a `weight`.
        panel_size: size of each panel.
    """
    kernels = basic_model.final_conv.weight
    num_keypoints = kernels.shape[0]

    # Symmetric color limits from the largest absolute weight
    vlim = from_t(abs(kernels).max())

    # One column per keypoint plus a narrow column for the colorbar
    fig = plt.figure(figsize=(panel_size * (num_keypoints + .1), panel_size))
    gs = GridSpec(nrows=1, ncols=num_keypoints + 1, figure=fig,
                  width_ratios=[1] * (num_keypoints) + [0.1])

    for k in range(num_keypoints):
        kernel_ax = plt.subplot(gs[k])
        im = kernel_ax.imshow(from_t(kernels[k, 0]),
                              vmin=-vlim, vmax=vlim, cmap="RdBu")
        kernel_ax.set_title("weights: {}".format(keypoint_names[k]))
        kernel_ax.set_axis_off()

    colorbar_ax = plt.subplot(gs[-1])
    plt.colorbar(im, cax=colorbar_ax)

    plt.tight_layout()