# CNN fine-tuning


In this notebook, we load a pre-trained CNN, freeze the layers we want to leave untouched, and fine-tune the remaining ones.

In [9]:
# Let's first import the packages we'll need
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
import multiprocessing
from PIL import Image

In [5]:
print(torch.cuda.is_available())

False


In [12]:
multiprocessing.cpu_count()

20
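
The CPU count above is typically used to choose the `num_workers` argument of a `DataLoader`. The following is a minimal sketch only: a synthetic `TensorDataset` stands in for the real image dataset, which this notebook does not show. The helper `make_loader`, the batch size, and the random tensors are all hypothetical. The 299x299 input size is what Inception V3 expects, and 12 labels matches the classifier head used later in this notebook.

```python
import multiprocessing

import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(dataset, batch_size=4, num_workers=0):
    """Wrap a dataset in a shuffling DataLoader (hypothetical helper)."""
    # Never ask for more worker processes than there are CPU cores.
    num_workers = min(num_workers, multiprocessing.cpu_count())
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers)


# Synthetic stand-in data: 8 random RGB "images" at 299x299, labels in [0, 12).
images = torch.randn(8, 3, 299, 299)
labels = torch.randint(0, 12, (8,))
dataloader = make_loader(TensorDataset(images, labels))

inputs, targets = next(iter(dataloader))
print(inputs.shape, targets.shape)
```

With a real dataset, `num_workers` can be raised above 0 to load batches in parallel; the default of 0 keeps loading in the main process, which is the safest choice inside a notebook.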

In [13]:
# Load the pre-trained Inception V3 model
model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
model.eval();

We will first freeze all layers except the last fully connected layer and fine-tune only that layer. We will then try an alternative approach: visualize the features detected by the CNN, freeze the layers that detect high-level features, and fine-tune those that detect lower-level ones. The feature visualization is performed in the notebook "CNN_features_visualization".
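
Both approaches come down to choosing which parameters keep `requires_grad = True`. One possible way to express that choice is by parameter-name prefix, sketched below on a tiny stand-in network rather than on Inception V3. The helper `freeze_except` and the layer names `stem`, `block`, and `fc` are hypothetical and chosen only for illustration.

```python
import torch.nn as nn


def freeze_except(model, trainable_prefixes):
    """Freeze every parameter whose name does not start with a given prefix."""
    prefixes = tuple(trainable_prefixes)
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(prefixes)


# Tiny stand-in network (hypothetical), used only to show the mechanics.
toy = nn.Sequential()
toy.add_module("stem", nn.Conv2d(3, 8, kernel_size=3))
toy.add_module("block", nn.Conv2d(8, 16, kernel_size=3))
toy.add_module("fc", nn.Linear(16, 12))

# Keep "block" and "fc" trainable, freeze everything else.
freeze_except(toy, ["block.", "fc."])
trainable = [name for name, p in toy.named_parameters() if p.requires_grad]
print(trainable)
```

For the torchvision Inception V3 model, the prefixes would be names of its top-level submodules (for example `"fc."` or `"Mixed_7c."`); which blocks to leave trainable is the decision the feature visualization is meant to inform.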

In [14]:
# Freeze all parameters
for param in model.parameters():
    param.requires_grad = False

# Replace the last fully connected layer with a new 12-class head
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 12)

# Make sure the new fc layer is open for fine-tuning
# (parameters of a freshly created layer are trainable by default)
for param in model.fc.parameters():
    param.requires_grad = True

# Define the loss function and the optimizer (only the fc parameters are optimized)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

# Define the device and move the model to it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Training loop
for epoch in range(num_epochs):  # num_epochs is the number of epochs you want to train for
    model.train()
    running_loss = 0.0
    for inputs, labels in dataloader:  # dataloader is your data loader
        inputs, labels = inputs.to(device), labels.to(device)  # Move data to the appropriate device

        optimizer.zero_grad()
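        outputs = model(inputs)
        # In train mode, Inception V3 returns (logits, aux_logits); keep the main logits only
        if isinstance(outputs, tuple):
            outputs = outputs.logits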
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(dataloader)}')
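
A quick sanity check on the freezing logic is to count how many parameters are still trainable. The sketch below does this on a toy stand-in model rather than on Inception V3; `count_parameters` and the layer sizes are hypothetical, but the same function can be called on any `nn.Module`.

```python
import torch.nn as nn


def count_parameters(model):
    """Return (trainable, total) parameter counts for a model."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable, total


# Toy stand-in (hypothetical): a frozen "backbone" followed by a fresh 12-class head.
toy = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 12))
for param in toy.parameters():
    param.requires_grad = False
toy[2] = nn.Linear(16, 12)  # a newly created layer is trainable by default

trainable, total = count_parameters(toy)
print(f"{trainable} trainable out of {total} parameters")
```

If the trainable count is larger than the size of the replaced head, some layer that was meant to be frozen is still receiving gradient updates.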