In [None]:
# For tips on running notebooks in Google Colab, see
# https://docs.pytorch.org/tutorials/beginner/colab
%matplotlib inline

Transfer Learning for Computer Vision Tutorial
==============================================

**Author**: [Sasank Chilamkurthy](https://chsasank.github.io)

In this tutorial, you will learn how to train a convolutional neural
network for image classification using transfer learning. You can read
more about transfer learning in the [cs231n
notes](https://cs231n.github.io/transfer-learning/).

Quoting these notes,

> In practice, very few people train an entire Convolutional Network
> from scratch (with random initialization), because it is relatively
> rare to have a dataset of sufficient size. Instead, it is common to
> pretrain a ConvNet on a very large dataset (e.g. ImageNet, which
> contains 1.2 million images with 1000 categories), and then use the
> ConvNet either as an initialization or a fixed feature extractor for
> the task of interest.

These two major transfer learning scenarios look as follows:

-   **Finetuning the ConvNet**: Instead of random initialization, we
    initialize the network with a pretrained network, like the one that
    is trained on the ImageNet 1000-class dataset. The rest of the
    training looks as usual.
-   **ConvNet as fixed feature extractor**: Here, we will freeze the
    weights for all of the network except that of the final fully
    connected layer. This last fully connected layer is replaced with a
    new one with random weights and only this layer is trained.


In [None]:
# License: BSD
# Author: Sasank Chilamkurthy

# Import required libraries for transfer learning
import torch                          # PyTorch core
import torch.nn as nn                 # Neural network modules
import torch.optim as optim           # Optimization algorithms (SGD, Adam, etc.)
from torch.optim import lr_scheduler  # Learning rate scheduling
import torch.backends.cudnn as cudnn  # CUDA optimization
import numpy as np                    # Numerical operations
import torchvision                    # Computer vision utilities
from torchvision import datasets, models, transforms  # Datasets, pretrained models, transforms
import matplotlib.pyplot as plt       # Plotting
import time                          # Timing training
import os                            # File operations
from PIL import Image                # Image loading
from tempfile import TemporaryDirectory  # Temporary file storage

# Enable cuDNN benchmarking for faster training (finds optimal algorithms)
cudnn.benchmark = True
# Enable interactive plotting mode
plt.ion()

TensorBoard
================

- https://www.tensorflow.org/tensorboard/get_started


In [None]:
# !tensorboard --logdir runs/transfer_learning

In [None]:
# TensorBoard setup (optional)
# TensorBoard provides visualization for training metrics, model graphs, and more.
# This cell defines the global `writer` that later cells pass to train_model();
# it is set to None when the setup fails, which disables logging.
try:
    # Imported inside the try block so that a failing import is handled by the
    # fallback below instead of aborting the notebook.
    from torch.utils.tensorboard import SummaryWriter
    # Create a writer that logs to runs/transfer_learning directory
    writer = SummaryWriter(log_dir="runs/transfer_learning")
    print("TensorBoard writer created: runs/transfer_learning")
    print("To launch: tensorboard --logdir runs/transfer_learning")
    print("Then open http://localhost:6006 in your browser")
except Exception as e:
    # Deliberately broad: any failure here falls back to running without logging.
    writer = None
    print("TensorBoard not available; logging disabled.")
    print("Reason:", e)

In [None]:
%load_ext tensorboard
%tensorboard --logdir runs/transfer_learning

Load Data
=========

We will use torchvision and torch.utils.data packages for loading the
data.

The problem we're going to solve today is to train a model to classify
**ants** and **bees**. We have about 120 training images each for ants
and bees. There are 75 validation images for each class. Usually, this
is a very small dataset to generalize upon, if trained from scratch.
Since we are using transfer learning, we should be able to generalize
reasonably well.

This dataset is a very small subset of imagenet.

download : https://download.pytorch.org/tutorial/hymenoptera_data.zip


In [None]:
# Download and extract hymenoptera_data without relying on shell tools.
# This works cross-platform (Windows, macOS, Linux) using Python's standard library.
import os
import urllib.request
import zipfile

# Dataset URL and local paths
url = "https://download.pytorch.org/tutorial/hymenoptera_data.zip"
zip_path = "hymenoptera_data.zip"
dst_root = "data"  # matches data_dir = 'data/hymenoptera_data'

# Create data directory if it doesn't exist
os.makedirs(dst_root, exist_ok=True)

# Download dataset if not already present.
# The archive is fetched under a temporary ".part" name and renamed only once
# the transfer has finished, so an interrupted download can never be mistaken
# for a complete zip file on the next run.
if not os.path.exists(zip_path):
    print("Downloading dataset...")
    partial_path = zip_path + ".part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, zip_path)
    finally:
        # No-op after a successful rename; removes the leftover file on failure.
        if os.path.exists(partial_path):
            os.remove(partial_path)
else:
    print("Zip file already exists, skipping download.")

# Extract to data/ directory so final path is data/hymenoptera_data/{train,val}.
# Extraction is repeated on every run; it simply overwrites the existing files.
with zipfile.ZipFile(zip_path, "r") as z:
    print("Extracting...")
    z.extractall(dst_root)

# Verify extraction was successful
expected_root = os.path.join(dst_root, "hymenoptera_data")
print("Extracted to:", expected_root)
print("Train exists:", os.path.isdir(os.path.join(expected_root, "train")))
print("Val exists:", os.path.isdir(os.path.join(expected_root, "val")))

In [None]:
# Preprocessing pipelines for the two dataset splits.
# - 'train' applies random augmentation (random resized crop + horizontal flip).
# - 'val' applies a deterministic resize + center crop.
# Both convert the image to a tensor and normalize it with the ImageNet
# per-channel mean and standard deviation.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

train_pipeline = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])
val_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])
data_transforms = {'train': train_pipeline, 'val': val_pipeline}

# One ImageFolder dataset per split, read from data/hymenoptera_data/{train,val}
splits = ['train', 'val']
data_dir = 'data/hymenoptera_data'
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in splits
}

# One DataLoader per split: batches of 4 images, loaded by 4 worker processes.
# Note that shuffle=True is used for BOTH splits (not only for training), so
# the validation loader also yields images in random order.
dataloaders = {
    split: torch.utils.data.DataLoader(
        image_datasets[split], batch_size=4, shuffle=True, num_workers=4
    )
    for split in splits
}

# Number of samples per split (used to average loss/accuracy over an epoch)
dataset_sizes = {split: len(image_datasets[split]) for split in splits}
# Class names as reported by the training ImageFolder
class_names = image_datasets['train'].classes

# Pick the current accelerator when one is available, otherwise fall back to CPU
if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator().type
else:
    device = "cpu"

print(f"Using {device} device")
print(f"Dataset sizes: {dataset_sizes}")
print(f"Classes: {class_names}")

Visualize a few images
======================

Let's visualize a few training images to understand the data
augmentations.


In [None]:
def imshow(inp, title=None):
    """
    Draw a normalized image tensor on the current matplotlib axes.

    Args:
        inp: Image tensor with three axes; the first axis is moved last
            before plotting (channels-first -> channels-last).
        title: Optional title passed to ``plt.title``.
    """
    # Same per-channel statistics that the data transforms normalize with
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    # (C, H, W) -> (H, W, C), which is the layout plt.imshow expects
    pixels = inp.numpy().transpose((1, 2, 0))

    # Undo the normalization, then clamp to the displayable [0, 1] range
    pixels = np.clip(channel_std * pixels + channel_mean, 0, 1)

    plt.imshow(pixels)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # brief pause so that the plot gets refreshed


# Get a batch of training data to visualize
# (a single batch from the 'train' loader, so the images are already augmented)
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch (arranges multiple images in a grid)
out = torchvision.utils.make_grid(inputs)

# Show the grid with class labels: the title is the list of class names for
# the images in the batch, looked up by label index
imshow(out, title=[class_names[x] for x in classes])

Training the model
==================

Now, let's write a general function to train a model. Here, we will
illustrate:

-   Scheduling the learning rate
-   Saving the best model

In the following, parameter `scheduler` is an LR scheduler object from
`torch.optim.lr_scheduler`.


In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25, writer=None, tag_prefix=''):
    """
    Train a PyTorch model with validation and checkpointing.

    Relies on the notebook-level globals ``dataloaders``, ``dataset_sizes``
    and ``device``.

    Args:
        model: PyTorch model to train
        criterion: Loss function (e.g., CrossEntropyLoss)
        optimizer: Optimization algorithm (e.g., SGD, Adam)
        scheduler: Learning rate scheduler; stepped once per epoch, after the
            training phase
        num_epochs: Number of training epochs
        writer: TensorBoard SummaryWriter (optional)
        tag_prefix: Optional string prepended to the TensorBoard tags
            (``Loss/<phase>``, ``Accuracy/<phase>``). Pass a distinct value
            such as ``'finetune/'`` per run when several runs share one
            writer, otherwise their curves are logged under the same tags.

    Returns:
        model: Trained model with best weights loaded
        history: Dictionary containing training/validation loss and accuracy per epoch
    """
    since = time.time()

    # Track epoch-wise metrics for plotting later
    history = {
        'train': {'loss': [], 'acc': []},
        'val':   {'loss': [], 'acc': []}
    }

    # Create a temporary directory to save training checkpoints
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')

        # Save initial model state so there is always a checkpoint to reload
        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0  # Track best validation accuracy

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                is_train = phase == 'train'
                if is_train:
                    model.train()  # Set model to training mode
                else:
                    model.eval()   # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0

                # Iterate over data batches
                for inputs, labels in dataloaders[phase]:
                    # Move data to device (GPU/CPU)
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    # Gradients are only produced in the training phase, so
                    # that is the only place they need clearing.
                    if is_train:
                        optimizer.zero_grad()

                    # Forward pass
                    # Only track gradients during training phase
                    with torch.set_grad_enabled(is_train):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)  # Get predicted class
                        loss = criterion(outputs, labels)  # Calculate loss

                        # Backward pass + optimize only if in training phase
                        if is_train:
                            loss.backward()      # Compute gradients
                            optimizer.step()     # Update weights

                    # Accumulate statistics. The loss is weighted by the batch
                    # size so that the epoch average is exact even when the
                    # last batch is smaller.
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels).item()

                # Step learning rate scheduler after training phase
                if is_train:
                    scheduler.step()

                # Calculate epoch metrics
                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]

                # Save history for plotting
                history[phase]['loss'].append(epoch_loss)
                history[phase]['acc'].append(epoch_acc)

                # Log to TensorBoard if available
                if writer is not None:
                    writer.add_scalar(f'{tag_prefix}Loss/{phase}', epoch_loss, epoch)
                    writer.add_scalar(f'{tag_prefix}Accuracy/{phase}', epoch_acc, epoch)

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                # Save model if it has best validation accuracy so far
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_params_path)

            print()

        # Training complete
        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        # '.4f' (four decimals); the previous '4f' was only a minimum field width
        print(f'Best val Acc: {best_acc:.4f}')

        # Load best model weights (from epoch with best validation accuracy).
        # This must happen before the temporary directory is removed.
        model.load_state_dict(torch.load(best_model_params_path, weights_only=True))

    # Ensure TensorBoard events are written to disk
    if writer is not None:
        writer.flush()

    return model, history

In [None]:
# Helper function to plot learning curves
import matplotlib.pyplot as plt

def plot_history(history, title_prefix=""):
    """
    Draw the loss and accuracy learning curves side by side.

    Args:
        history: Dictionary with 'train' and 'val' keys, each mapping to a
            dictionary holding per-epoch 'loss' and 'acc' lists
        title_prefix: Text put in front of each panel title (e.g., "Finetune: ")
    """
    n_epochs = len(history['train']['loss'])
    epochs = range(1, n_epochs + 1)

    # (history key, axis label / title suffix) for the left and right panel
    panels = (('loss', 'Loss'), ('acc', 'Accuracy'))

    plt.figure(figsize=(10,4))
    for position, (key, pretty) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epochs, history['train'][key], 'b-', label=f'train {key}')
        plt.plot(epochs, history['val'][key], 'r-', label=f'val {key}')
        plt.xlabel('Epoch')
        plt.ylabel(pretty)
        plt.title(f'{title_prefix}{pretty}')
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

In [None]:
# Evaluation metrics: Confusion Matrix and Classification Report
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

def evaluate_model(model, phase='val', class_names=None):
    """
    Evaluate model performance with a confusion matrix and classification report.

    Relies on the notebook-level globals ``dataloaders`` and ``device``.
    The model is switched to evaluation mode for the duration of the call and
    its previous train/eval mode is restored afterwards.

    Args:
        model: Trained PyTorch model
        phase: 'train' or 'val' - which dataset to evaluate on
        class_names: List of class names for display, ordered by label index.
            When given, the confusion matrix and report always cover every
            listed class, even one that does not occur in the evaluated data.
            When None, the labels found in the data are used and the heatmap
            falls back to automatic tick labels.

    Returns:
        y_true: Ground truth labels
        y_pred: Predicted labels
    """
    was_training = model.training
    model.eval()
    y_true = []
    y_pred = []

    print(f"\n{'='*50}")
    print(f"Evaluating on {phase} set")
    print(f"{'='*50}")

    try:
        with torch.no_grad():
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                # Labels are only collected, never used on the device
                y_true.extend(labels.cpu().numpy())
                y_pred.extend(preds.cpu().numpy())
    finally:
        # Leave the model in the mode the caller had it in
        model.train(mode=was_training)

    if class_names is not None:
        # Pin the label set so matrix shape and report rows match class_names
        label_ids = list(range(len(class_names)))
        tick_labels = class_names
    else:
        label_ids = None
        tick_labels = "auto"

    # Confusion Matrix
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_labels, yticklabels=tick_labels)
    plt.title(f'Confusion Matrix - {phase.capitalize()} Set')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.show()

    # Classification Report
    print(f"\nClassification Report - {phase.capitalize()} Set:")
    print("-" * 50)
    report = classification_report(y_true, y_pred,
                                   labels=label_ids,
                                   target_names=class_names,
                                   digits=4)
    print(report)

    # Overall Accuracy
    accuracy = (np.array(y_true) == np.array(y_pred)).sum() / len(y_true)
    print(f"\nOverall Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

    return y_true, y_pred

Visualizing the model predictions
=================================

Generic function to display predictions for a few images


In [None]:
def visualize_model(model, num_images=6):
    """
    Visualize model predictions on validation set.

    Draws up to ``num_images`` validation images in a two-column grid, each
    titled with the predicted class. Relies on the notebook-level globals
    ``dataloaders``, ``device``, ``class_names`` and on ``imshow``. The model's
    train/eval mode is restored on exit, including when an error occurs.

    Args:
        model: Trained PyTorch model
        num_images: Number of images to display (default: 6). Nothing is
            drawn when this is not positive.
    """
    if num_images <= 0:
        return

    was_training = model.training
    model.eval()  # Set to evaluation mode
    images_so_far = 0
    # Two columns; round up so an odd num_images still gets enough rows
    n_rows = (num_images + 1) // 2
    plt.figure()

    try:
        with torch.no_grad():  # Disable gradient computation for inference
            for inputs, _ in dataloaders['val']:
                inputs = inputs.to(device)

                # Get predictions
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                # Display images with predictions
                for j in range(inputs.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(n_rows, 2, images_so_far)
                    ax.axis('off')
                    ax.set_title(f'predicted: {class_names[preds[j]]}')
                    imshow(inputs[j].cpu())

                    if images_so_far == num_images:
                        return
    finally:
        model.train(mode=was_training)  # Restore original mode

1 Finetuning the ConvNet
======================

Load a pretrained model and reset the final fully connected layer.

In [None]:
# Strategy 1: Full Finetuning - Train all layers
# Load pretrained ResNet18 with ImageNet weights
model_ft = models.resnet18(weights='IMAGENET1K_V1')

# Get the number of input features for the final fc layer
num_ftrs = model_ft.fc.in_features  # 512 for ResNet18

# Replace the final fully connected layer
# Original fc: nn.Linear(512, 1000) for ImageNet's 1000 classes
# New fc: nn.Linear(512, 2) for our 2 classes (ants and bees)
model_ft.fc = nn.Linear(num_ftrs, 2)

# Move model to device (GPU/CPU)
model_ft = model_ft.to(device)

# Define loss function for multi-class classification
# NOTE: `criterion`, `num_ftrs` and `exp_lr_scheduler` are reassigned by the
# later strategy cells, so re-run this cell before re-running this strategy's
# training cell; otherwise the scheduler of another optimizer would be used.
criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
# This means we'll update weights in all layers during backpropagation
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Learning rate scheduler: Multiply LR by gamma=0.1 every step_size=7 epochs
# Epoch 0-6: lr=0.001, Epoch 7-13: lr=0.0001, Epoch 14-20: lr=0.00001, etc.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

Train and evaluate
==================

It should take around 15-25 min on CPU. On GPU though, it takes less
than a minute.


In [None]:
# Train the full finetuning model
# All layers will be updated during training
# `epoch` here is the NUMBER of epochs to run (passed as num_epochs).
# NOTE(review): every strategy passes the same `writer`, and train_model logs
# under the fixed tags 'Loss/<phase>' and 'Accuracy/<phase>', so the curves of
# the different runs end up under the same TensorBoard tags.
epoch = 30
model_ft, history_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=epoch, writer=writer)

In [None]:
# Plot learning curves for finetuned model
# (per-epoch train/val loss and accuracy returned by train_model as history_ft)
plot_history(history_ft, title_prefix="Finetune: ")

In [None]:
# Evaluate the finetuned model with confusion matrix and classification report
# on the validation split; the true and predicted labels are kept for later use
y_true_ft, y_pred_ft = evaluate_model(model_ft, phase='val', class_names=class_names)

In [None]:
# Show the finetuned model's predictions for a few validation images
# (num_images defaults to 6; the val loader is shuffled, so the images vary per run)
visualize_model(model_ft)

2 ConvNet as fixed feature extractor
==================================

Here, we need to freeze all the network except the final layer. We need
to set `requires_grad = False` to freeze the parameters so that the
gradients are not computed in `backward()`.

You can read more about this in the documentation
[here](https://pytorch.org/docs/notes/autograd.html#excluding-subgraphs-from-backward).


In [None]:
# Strategy 2: Feature Extractor - Freeze all layers except fc
# Load pretrained ResNet18 model
model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')

# Freeze all convolutional layers (no gradient computation = no weight updates)
# This must happen BEFORE the fc layer is replaced, so the new fc stays trainable.
# NOTE(review): train_model() calls model.train() on the whole network, so the
# BatchNorm layers of the frozen backbone presumably still update their running
# statistics during training - confirm whether that is intended for a "fixed"
# feature extractor.
for param in model_conv.parameters():
    param.requires_grad = False

# Replace the final fc layer
# Parameters of newly constructed modules have requires_grad=True by default
# So only this layer will be trained
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

# Move to device
model_conv = model_conv.to(device)

# Loss function
# (rebinds the `criterion` name already used by the finetuning strategy)
criterion = nn.CrossEntropyLoss()

# Observe that only parameters of final layer are being optimized
# model_conv.fc.parameters() returns only the fc layer parameters
# All other layers are frozen and won't be updated
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Learning rate scheduler
# Same schedule as before (LR x 0.1 every 7 epochs), now bound to optimizer_conv.
# This rebinds `exp_lr_scheduler`, replacing the scheduler of the previous strategy.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

Train and evaluate
==================

On CPU this will take about half the time compared to the previous scenario.
This is expected as gradients don't need to be computed for most of the
network. However, the forward pass does need to be computed.


In [None]:
# Train the feature extractor model (only fc layer is trainable)
# `epoch` here is the NUMBER of epochs to run (passed as num_epochs).
# NOTE(review): this run logs through the same `writer` and the same
# 'Loss/<phase>' / 'Accuracy/<phase>' tags as the other strategies.
epoch = 30
model_conv, history_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=epoch, writer=writer)

In [None]:
# Plot learning curves for feature extractor model
# (per-epoch train/val loss and accuracy returned by train_model as history_conv)
plot_history(history_conv, title_prefix="Feature extractor: ")

In [None]:
# Evaluate the feature extractor model with confusion matrix and classification report
# on the validation split; the true and predicted labels are kept for later use
y_true_conv, y_pred_conv = evaluate_model(model_conv, phase='val', class_names=class_names)

In [None]:
# Show the feature extractor's predictions for a few validation images
visualize_model(model_conv)

# Turn off the interactive plotting mode that was enabled with plt.ion() in
# the imports cell, then display the figure
plt.ioff()
plt.show()

Inference on custom images
==========================

Use the trained model to make predictions on custom images and visualize
the predicted class labels along with the images.


In [None]:
def visualize_model_predictions(model, img_path):
    """
    Make predictions on a single custom image and display result.

    Relies on the notebook-level globals ``data_transforms``, ``device``,
    ``class_names`` and on ``imshow``. The model's train/eval mode is restored
    on exit, including when loading or predicting fails.

    Args:
        model: Trained PyTorch model
        img_path: Path to image file
    """
    was_training = model.training
    model.eval()  # Set to evaluation mode

    try:
        # Load the image and force 3-channel RGB: custom images may be
        # grayscale, palette-based or carry an alpha channel, which would not
        # match the 3-channel normalization in the validation transforms.
        # The context manager closes the underlying file once it is decoded.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        img = data_transforms['val'](img)  # Apply validation transforms
        img = img.unsqueeze(0)  # Add batch dimension: (C, H, W) -> (1, C, H, W)
        img = img.to(device)

        # Make prediction
        with torch.no_grad():
            outputs = model(img)
            _, preds = torch.max(outputs, 1)

        # Display image with prediction
        ax = plt.subplot(2,2,1)
        ax.axis('off')
        ax.set_title(f'Predicted: {class_names[preds[0]]}')
        imshow(img.cpu().data[0])
    finally:
        model.train(mode=was_training)  # Restore original mode

3 Finetuning the ConvNet - last 2 residual blocks (layer4) + fc
===============================================================

Load a pretrained model and unfreeze the last 2 convolutional blocks (layer4) plus the final fc layer.
- ResNet18 architecture: conv1 → layer1 → layer2 → layer3 → layer4 → fc
- We freeze: conv1, layer1, layer2, layer3
- We train: layer4 (last 2 residual blocks) + fc

In [None]:
# Strategy 3: Partial Finetuning - Train last 2 CNN blocks (layer4) + fc
# This is a middle ground between full finetuning and feature extraction
# Load pretrained ResNet18 model
model_partial = torchvision.models.resnet18(weights='IMAGENET1K_V1')

# Freeze all parameters first
# NOTE(review): train_model() calls model.train() on the whole network, so the
# BatchNorm layers of the frozen part presumably still update their running
# statistics during training - confirm whether that is intended.
for param in model_partial.parameters():
    param.requires_grad = False

# Unfreeze layer4 (last 2 CNN blocks) - these are the last convolutional layers before fc
# ResNet18 layer4 contains 2 BasicBlocks, each with 2 conv layers = 4 conv layers total
for param in model_partial.layer4.parameters():
    param.requires_grad = True

# Replace and unfreeze the final fc layer (this is automatic since it's newly created)
num_ftrs = model_partial.fc.in_features
model_partial.fc = nn.Linear(num_ftrs, 2)

# Move to device
model_partial = model_partial.to(device)

# Loss function
# (rebinds the `criterion` name already used by the earlier strategies)
criterion = nn.CrossEntropyLoss()

# Optimize layer4 + fc with different learning rates (two parameter groups)
# Use a smaller lr for layer4 (pretrained features) and normal lr for fc (random init)
# This is a common strategy: fine-tune pretrained layers slowly, train new layers faster
# Each group sets its own 'lr'; momentum=0.9 is shared by both groups.
optimizer_partial = optim.SGD([
    {'params': model_partial.layer4.parameters(), 'lr': 0.0001},  # Lower LR for fine-tuning CNN
    {'params': model_partial.fc.parameters(), 'lr': 0.001}        # Higher LR for new fc
], momentum=0.9)

# Learning rate scheduler (applies to all parameter groups)
# This rebinds `exp_lr_scheduler`, replacing the scheduler of the previous strategy.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_partial, step_size=7, gamma=0.1)

# Print trainable parameter counts for verification
# (only parameters with requires_grad=True are counted as trainable)
print("Trainable parameters:")
print("- layer4 (last 2 CNN blocks):", sum(p.numel() for p in model_partial.layer4.parameters() if p.requires_grad))
print("- fc (final classifier):", sum(p.numel() for p in model_partial.fc.parameters() if p.requires_grad))
print("- Total trainable:", sum(p.numel() for p in model_partial.parameters() if p.requires_grad))
print("- Total parameters:", sum(p.numel() for p in model_partial.parameters()))

Train and evaluate
==================

In [None]:
# Train the partially finetuned model (layer4 + fc are trainable)
# `epoch` here is the NUMBER of epochs to run (passed as num_epochs).
# NOTE(review): this run logs through the same `writer` and the same
# 'Loss/<phase>' / 'Accuracy/<phase>' tags as the other strategies.
epoch = 30
model_partial, history_partial = train_model(model_partial, criterion, optimizer_partial,
                         exp_lr_scheduler, num_epochs=epoch, writer=writer)

In [None]:
# Plot learning curves for partial fine-tuning (layer4 + fc)
# (per-epoch train/val loss and accuracy returned by train_model as history_partial)
plot_history(history_partial, title_prefix="Partial finetune (layer4+fc): ")

In [None]:
# Evaluate the partial finetuned model (layer4 + fc) with confusion matrix and classification report
# on the validation split; the true and predicted labels are kept for later use
y_true_partial, y_pred_partial = evaluate_model(model_partial, phase='val', class_names=class_names)

In [None]:
# Show the partially finetuned model's predictions for a few validation images
visualize_model(model_partial)

# Make sure interactive plotting mode is off, then display the figure
plt.ioff()
plt.show()

Inference on custom images
==========================

Use the trained model to make predictions on custom images and visualize
the predicted class labels along with the images.


In [None]:
# Predict a single image with the partially finetuned model.
# The path points at a file inside the validation 'bees' folder of the
# downloaded dataset; replace it with any image path to try a custom image.
visualize_model_predictions(
    model_partial,
    img_path='data/hymenoptera_data/val/bees/72100438_73de9f17af.jpg'
)

Further Learning
================

If you would like to learn more about the applications of transfer
learning, check out our [Quantized Transfer Learning for Computer Vision
Tutorial](https://pytorch.org/tutorials/intermediate/quantized_transfer_learning_tutorial.html).


In [None]:
# Close TensorBoard writer (optional but good practice)
# This ensures all logged events are written to disk
# `writer` is None when the TensorBoard setup cell failed, in which case
# there is nothing to close.
if writer is not None:
    writer.close()
    print("TensorBoard writer closed.")
    print("\nTo view results, run: tensorboard --logdir runs/transfer_learning")
    print("Then open http://localhost:6006 in your browser")