<a href="https://colab.research.google.com/github/gg5d/Diffusion_models/blob/main/Stop_Gradient.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Stop Gradient example


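A stop gradient operation is a technique used in machine learning—especially in deep learning frameworks like TensorFlow and PyTorch—to prevent gradients from flowing backward through a specific part of a computational graph during backpropagation. In PyTorch this is typically done by running part of the forward pass inside `torch.no_grad()`, by calling `.detach()` on a tensor, or by setting `requires_grad = False` on the parameters that should stay fixed; the TensorFlow counterpart is `tf.stop_gradient`. The cells below show the first and third approach on a pretrained ResNet18.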

In [2]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNet18_Weights

class CustomModel(nn.Module):
    """Frozen pretrained ResNet18 backbone with a small trainable classifier head.

    The backbone is used purely as a fixed feature extractor: its parameters
    have ``requires_grad = False``, its forward pass runs under
    ``torch.no_grad()``, and it is kept permanently in eval mode so that its
    BatchNorm running statistics are not modified either. Only ``classifier``
    receives gradients.

    Args:
        num_classes: Number of output classes for the classifier head.
            Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Load pretrained ResNet18 using the new weights API
        resnet = models.resnet18(weights=ResNet18_Weights.DEFAULT)

        # Keep only the convolutional base: the last two children of the
        # torchvision ResNet (global average pool and fc) are dropped.
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-2])

        # Freeze the feature extractor's weights ...
        for param in self.feature_extractor.parameters():
            param.requires_grad = False
        # ... and its BatchNorm statistics. requires_grad/no_grad alone do not
        # stop BatchNorm from updating running_mean/running_var in train mode.
        self.feature_extractor.eval()

        # Custom classifier head; its input width is read from the original fc
        # layer instead of being hard-coded.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(resnet.fc.in_features, num_classes),
        )

    def train(self, mode=True):
        """Switch the classifier head's mode while keeping the backbone in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        # super().train() recurses into every child, so undo it for the backbone.
        self.feature_extractor.eval()
        return self

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Input batch passed to the ResNet18 convolutional base.

        Returns:
            Tensor of logits with one row per input and ``num_classes`` columns.
        """
        with torch.no_grad():  # Stop gradient here: no graph is built for the backbone
            features = self.feature_extractor(x)

        return self.classifier(features)

# Example usage: push one random batch through the model as a shape check.
model = CustomModel()
input_tensor = torch.randn((8, 3, 224, 224))  # 8 random inputs shaped like 3x224x224 images
output = model(input_tensor)
print(output.shape)  # Should be [8, 10]


torch.Size([8, 10])


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.models import resnet18, ResNet18_Weights
from torch.utils.data import DataLoader

# Fine-tune a pretrained ResNet18 on CIFAR-10, training only layer2, layer3 and
# a new fc head. Everything else (weights and BatchNorm statistics) stays fixed.

# 1. Load a pretrained ResNet18 model with default weights
resnet = resnet18(weights=ResNet18_Weights.DEFAULT)

# 2. Freeze all layers so they won't be updated during training
for param in resnet.parameters():
    param.requires_grad = False

# 3. Unfreeze only layer2 and layer3 so they can be trained
for stage in (resnet.layer2, resnet.layer3):
    for param in stage.parameters():
        param.requires_grad = True

# 4. Replace the final fully connected (fc) layer to match our number of classes (e.g., 10 for CIFAR-10)
num_classes = 10
resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)

# 5. Make sure the new fc layer is trainable (new layers already default to
#    requires_grad=True; kept explicit for clarity)
for param in resnet.fc.parameters():
    param.requires_grad = True

# Submodules that remain frozen and contain BatchNorm layers. requires_grad=False
# does not stop BatchNorm from updating its running statistics in train mode, so
# these are switched back to eval mode at the start of every epoch.
frozen_modules = (resnet.bn1, resnet.layer1, resnet.layer4)

# 6. Define image transformations: resize, convert to tensor, and normalize with
#    the ImageNet statistics the pretrained weights were trained with
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 224x224 (ResNet input size)
    transforms.ToTensor(),          # Convert images to PyTorch tensors in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# 7. Load the CIFAR-10 training dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)  # Load data in batches of 32

# 8. Define the loss function (CrossEntropyLoss for classification)
criterion = nn.CrossEntropyLoss()

# 9. Move model to GPU if available. This is done before the optimizer is
#    constructed, as recommended by the torch.optim documentation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet.to(device)

# 10. Create an optimizer that only updates parameters that require gradients.
#     A list (not a one-shot filter iterator) so it can be inspected or reused.
trainable_params = [p for p in resnet.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-4)  # Adam optimizer with learning rate 0.0001

# 11. Training loop
num_epochs = 5  # Number of times to go through the entire dataset
for epoch in range(num_epochs):
    resnet.train()  # Set model to training mode
    for module in frozen_modules:
        module.eval()  # Keep frozen BatchNorm statistics fixed
    running_loss = 0.0  # Track loss for this epoch

    for images, labels in train_loader:
        # Move data to the same device as the model (GPU or CPU)
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()        # Clear previous gradients
        outputs = resnet(images)     # Forward pass: compute predictions
        loss = criterion(outputs, labels)  # Compute loss
        loss.backward()              # Backward pass: compute gradients
        optimizer.step()             # Update weights

        running_loss += loss.item()  # Accumulate loss

    # Print average loss for this epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")


100%|██████████| 170M/170M [00:02<00:00, 85.1MB/s]
