# CNN fine tuning


In this notebook, we will load the CNN, freeze the layers we want to leave untouched, and fine-tune the remaining ones. The image dataset was organized into an appropriate subfolder structure in the notebook called "image_sorting.ipynb".
In this dataset, each image consists of three channels: one actin, one tubulin, and one DAPI, which correspond to different structures of the cells that are stained prior to microscopy image capture. Each of these channels is stored as a single image. We therefore need to reconstitute the 3D tensors from the three channels of these images before training the CNN. 

In [1]:
#let's first import the packages we'll need
import torch

import torch.nn as nn
import os
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import multiprocessing
from PIL import Image

In [3]:
# Load the pre-trained Inception V3 model (torchvision's default weights)
model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
# Switch to evaluation mode; the trailing semicolon is presumably there to
# suppress the notebook's echo of the returned module -- TODO confirm.
model.eval();

# Use 15 threads for both of PyTorch's CPU thread pools.
torch.set_num_threads(15)  # Set the number of intra-op threads
torch.set_num_interop_threads(15)  # Set the number of inter-op threads

We will first try to freeze all layers except for the last fully connected layer, and fine-tune this one only. We will then try an alternative approach: visualizing the features detected by the CNN, freezing the layers that detect high-level features, and fine-tuning those that detect rather low-level ones. The feature visualization is performed in the notebook "CNN_features_visualization". 

In [4]:
# Freeze the whole network: no existing parameter will receive gradients
model.requires_grad_(False)

# Swap the last fully connected layer for a fresh one with 12 outputs
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(in_features=num_ftrs, out_features=12)

# Re-enable gradients on the new fc layer only, so it alone is fine-tuned
model.fc.requires_grad_(True)



We will now define a class that reconstitutes the 3 channels of the picture (Actin, Tubulin and DAPI). In the rest of the code, we will name these channels ATD, including in the variable names, to avoid potential confusion with RGB channels.

In [5]:
class ATDImageDataset(Dataset):
    """Dataset that rebuilds 3-channel images from per-channel image files.

    Expected layout under ``root_dir``::

        <root_dir>/<class>/actin/<name>
        <root_dir>/<class>/tubulin/<name>
        <root_dir>/<class>/dapi/<name>

    The file names found in the ``actin`` folder drive the index; the
    ``tubulin`` and ``dapi`` files are looked up under the same name.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to the merged PIL image.

    Attributes:
        classes: Sorted list of class (sub-directory) names.
        class_to_idx: Mapping from class name to its integer label.
        image_paths: List of ``(actin, tubulin, dapi, class_name)`` tuples.
    """

    # Channel sub-folders, in the order they are merged (A, T, D).
    CHANNELS = ('actin', 'tubulin', 'dapi')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order, so sort to make the
        # class -> label mapping reproducible. Only directories are classes;
        # stray files in root_dir are ignored instead of crashing below.
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.image_paths = []
        for cls in self.classes:
            a_path, t_path, d_path = (
                os.path.join(root_dir, cls, channel) for channel in self.CHANNELS
            )
            # Sorted for a deterministic sample order across runs.
            for img_name in sorted(os.listdir(a_path)):
                self.image_paths.append((os.path.join(a_path, img_name),
                                         os.path.join(t_path, img_name),
                                         os.path.join(d_path, img_name),
                                         cls))

    def __len__(self):
        """Return the number of (actin, tubulin, dapi) image triplets."""
        return len(self.image_paths)

    @staticmethod
    def _load_channel(path):
        """Open one channel file and return it as a grayscale image.

        The file is closed as soon as the pixel data has been converted,
        so handles do not pile up while iterating over the dataset.
        """
        with Image.open(path) as channel_img:
            return channel_img.convert('L')

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The three grayscale channels are merged into one 'RGB'-mode image
        (A -> R, T -> G, D -> B) before the optional transform is applied.
        """
        a_path, t_path, d_path, cls = self.image_paths[idx]
        channels = tuple(self._load_channel(path)
                         for path in (a_path, t_path, d_path))
        img = Image.merge('RGB', channels)
        if self.transform is not None:
            img = self.transform(img)
        # Dict lookup instead of list.index(): O(1) per sample.
        return img, self.class_to_idx[cls]

# Turn each merged PIL image into a tensor
transform = transforms.ToTensor()

# Training set read from 'sorted/train/', served in shuffled batches of 32
dataset = ATDImageDataset('sorted/train/', transform)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=32)


In [6]:
#define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the same device as the batches. Without this, inputs
# sent to CUDA would be fed to CPU-resident weights and the forward pass
# would fail with a device mismatch.
model = model.to(device)

# Define loss function and optimizer (only the fc layer's parameters are
# handed to the optimizer)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

#list to store the loss of every batch, across all epochs
losses = []

#define the number of epochs
num_epochs = 10


# Training loop
for epoch in range(num_epochs):
    # NOTE(review): train mode applies to every layer, including the frozen
    # ones, so BatchNorm running statistics presumably keep updating even
    # though their weights do not -- confirm this is intended.
    model.train()
    running_loss = 0.0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)  # Move data to the appropriate device

        optimizer.zero_grad()
        # The model output is unpacked as a pair; only the first element
        # is used for the loss, the second is ignored.
        outputs, _ = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for both bookkeeping values.
        batch_loss = loss.item()
        losses.append(batch_loss)
        running_loss += batch_loss
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(dataloader)}')