# CNN fine tuning


In this notebook, we will load the CNN , freeze the layers we want to leave untouched, and finetune the remaining ones. The image dataset was organized into an appropriate subfolder structure in the notbook called "image_sorting.ipynb".
In this dataset, each image consists in three channels: one actin, one tubulin, and one DAPI, which correspond to different structures of the cells that get coloured prior to the micriscopy image capturing. Each of these channel is stored as a single image. We therefore need to reconstitute the 3D tensors from the three channels of these images before training the CNN. 

In [35]:
#let's first import the packages we'll need
import torch
import pickle
import torch.nn as nn
import matplotlib.pyplot as plt
import os
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import multiprocessing
from PIL import Image

In [29]:
# Load the pre-trained Inception V3 model
model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
model.eval();

torch.set_num_threads(15)  # Set the number of intra-op threads
# torch.set_num_interop_threads(15)  # Set the number of inter-op threads

We will first try to freeze all layers except for the last fully connected layer, and finetune this one only. We will then try an alternative approach, by visualizing the features detected by the CNN, freeze layers that detect high level features, and fine tune those which detect rather low level ones. The features visualization is performed in notebook "CNN_features_visualization". 

In [54]:
#freeze all parameters
for param in model.parameters():
    param.requires_grad = False

#replace the last fc layer
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 13) 

#set parameters of last fc open for fine tunning
for param in model.fc.parameters():
    param.requires_grad = True



We will now define a class that aims at reconsituting the 3 channels of the picture (Actin, Tubulin and DAPI). In the rest of the code, we will name these channels ATD, including in the variable names, to not cause potential confusion with RGB channels.

In [58]:
class ATDImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = os.listdir(root_dir)
        
        self.image_paths = []
        for cls in self.classes:
            a_path = os.path.join(root_dir, cls, 'actin')
            t_path = os.path.join(root_dir, cls, 'tubulin')
            d_path = os.path.join(root_dir, cls, 'dapi')
            for img_name in os.listdir(a_path):
                self.image_paths.append((os.path.join(a_path, img_name),
                                         os.path.join(t_path, img_name),
                                         os.path.join(d_path, img_name),
                                         cls))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        a_path, t_path, d_path, cls = self.image_paths[idx]
        a_img = Image.open(a_path)
        t_img = Image.open(t_path)
        d_img = Image.open(d_path)

        # Normalize pixel values from 0-65535 to 0-255
        a_img = a_img.point(lambda p: p * (255.0 / 65535.0))
        t_img = t_img.point(lambda p: p * (255.0 / 65535.0))
        d_img = d_img.point(lambda p: p * (255.0 / 65535.0))

        # Convert images to grayscale
        a_img = a_img.convert('L')
        t_img = t_img.convert('L')
        d_img = d_img.convert('L')
        
        img = Image.merge('RGB', (a_img, t_img, d_img))
        if self.transform:
            img = self.transform(img)
            
        return img, self.classes.index(cls)

transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])


dataset = ATDImageDataset(root_dir='images/sorted_reduced/train', transform = transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)




In [59]:
# Initialize variables to store the sum and std of pixel values
mean = 0.0
std = 0.0
num_pixels = 0

# Mean: 9.687756374887613e-09
# Std: 1.1472706340498462e-08

# Iterate through the dataset
for images, _ in dataloader:
    batch_size, num_channels, height, width = images.shape
    num_pixels += batch_size * height * width
    mean += images.mean(axis=(0, 2, 3)).sum()
    std += images.std(axis=(0, 2, 3)).sum()

# Calculate the mean and std
mean /= num_pixels
std /= num_pixels

print(f'Mean: {mean}')
print(f'Std: {std}')

# Updated transform with normalization
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

# Reload dataset with normalization
dataset = ATDImageDataset(root_dir='images/sorted_reduced/train', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

Mean: 9.687756374887613e-09
Std: 1.1472706340498462e-08


In [67]:
def train_model(lr, momentum):
    print('training model with lr = {}, momentum = {}'.format(lr, momentum))
    # Define  loss function and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

    #define the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #list to store losses
    losses = []

    #define the number of epochs
    num_epochs = 10

    # Training loop

    for epoch in range(num_epochs):  
        model.train()
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(dataloader):  
            inputs, labels = inputs.to(device), labels.to(device)  # Move data to the appropriate device

            optimizer.zero_grad()
            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)
            
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i%10 == 0:
                print('done for img {}, epoch {}'.format(i, epoch+1))
        losses.append(running_loss / len(dataloader))
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(dataloader)}')

    return model, losses

for lr in [0.001, 0.01, 0.1, 1]:
    model_name = 1
    model, losses = train_model(lr, 0.9)
    try:
        torch.save(model.state_dict(), 'models/model{}/cnn'.format(model_name))
    except:
        os.makedirs(os.path.dirname('models/model{}/cnn'.format(model_name)))
        torch.save(model.state_dict(), 'models/model{}/cnn'.format(model_name))
    with open('models/model{}/losses'.format(model_name), "wb") as fp:   #Pickling
        pickle.dump(losses, fp)
    model_name += 1


training model with lr = 0.001, momentum = 0.9
done for img 0, epoch 1
done for img 10, epoch 1
done for img 20, epoch 1
done for img 30, epoch 1
Epoch [1/10], Loss: 2.5516644653521086


In [19]:
losses

[2.3362345695495605,
 2.2832021713256836,
 2.6072685718536377,
 2.216726541519165,
 2.365002155303955,
 2.576688528060913,
 2.507329225540161,
 2.329982042312622,
 2.2601375579833984,
 2.437709331512451,
 2.433056592941284,
 2.370053768157959,
 2.40954852104187,
 2.4517765045166016,
 2.2286007404327393,
 2.378108501434326,
 2.5210256576538086,
 2.4442620277404785,
 2.568758010864258,
 2.4824845790863037,
 2.4054033756256104,
 2.383650064468384,
 2.459444522857666,
 2.612511157989502,
 2.5745246410369873,
 2.4844417572021484,
 2.4969847202301025,
 2.463207244873047,
 2.477919816970825,
 2.461578607559204,
 2.5261142253875732,
 2.5680058002471924,
 2.560798406600952,
 2.3730664253234863,
 2.3560421466827393,
 2.4384782314300537,
 2.4363255500793457,
 2.6722397804260254,
 2.5517780780792236,
 2.4916865825653076,
 2.394786834716797,
 2.641256332397461,
 2.452791690826416,
 2.399306535720825,
 2.3660695552825928,
 2.3384218215942383,
 2.221323251724243,
 2.5113577842712402,
 2.2957293987274