# Chapter 4: Transfer Learning And Other Tricks

In [None]:
# Imports
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
## Settings
# Root folder of the image dataset.
images_path = "../images"

# Where the trained model is meant to be saved (alternative kept for reference).
# model_filename = "/tmp/resnet50"
model_filename = "./saved/resnet50"

# Number of epochs per training run.
epochs = 10

# NOTE: overridden below, in the "Adjust LR" cell.
lr = 1e-3

In [None]:

# Init run
# `first_run` is missing from the namespace only the first time this cell
# executes; once it exists, every re-run of the cell counts as a later run.
first_run = "first_run" not in globals()
print(f"First run: {first_run}")
if first_run:
    # Reset the running totals.
    total_epochs, total_model_params = 0, 0
        

In [None]:
# Misc functions

def count_parameters(model: nn.Module) -> int:
    """Return the total number of elements across all parameters of ``model``."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total

In [None]:
# Load Model
# The model is only created once; later runs keep training the existing one.
if not first_run:
    print("Skipped after first run.")
else:
    transfer_model = models.resnet50(pretrained=True)
    total_model_params = count_parameters(transfer_model)
    print(f'{total_model_params} parameters')
    print(transfer_model)

## Freezing parameters

In [None]:
# Unfreeze all params (skip on 1st run)
if first_run:
    print("Skipped on first run.")
else:
    # Mark every parameter of the model as trainable again.
    for param in transfer_model.parameters():
        param.requires_grad = True

In [None]:
# Unfreeze select layers (skip on 1st run)
if first_run:
    print("Skipped on first run.")
else:
    unfreeze_layers = [transfer_model.layer3, transfer_model.layer4]
    # Gather every parameter tensor of the selected layers, then enable grads.
    selected = [p for layer in unfreeze_layers for p in layer.parameters()]
    for p in selected:
        p.requires_grad = True
    # Counts parameter tensors, not individual scalar weights.
    count_unfrozen = len(selected)
    print(f'Unfrozen {count_unfrozen} parameters')

In [None]:
# Freeze named params (only in not "*bn*" layers)
named_params = list(transfer_model.named_parameters())
# Any parameter whose qualified name lacks the substring "bn" gets frozen.
to_freeze = [param for name, param in named_params if "bn" not in name]
for param in to_freeze:
    param.requires_grad = False
count_all = len(named_params)
count_frozen = len(to_freeze)
print(f'Frozen {count_frozen} of total {count_all} named parameters ({count_frozen/count_all*100:.5f}%)')

## Replacing the classifier

In [None]:
# Swap the final fully connected layer for a small two-output head.
# Only done once: later runs keep the head that is already in place.
if not first_run:
    print("Skipped after first run.")
else:
    fc_in_features = transfer_model.fc.in_features
    transfer_model.fc = nn.Sequential(
        nn.Linear(fc_in_features, 500),
        nn.ReLU(),
        nn.Dropout(),
        nn.Linear(500, 2),
    )

## Custom Transforms

Here we'll create a lambda transform and a custom transform class.

In [None]:
# Custom Transforms

# Convert RGB to HSV color space
def _random_colour_space(x):
    """Return ``x`` converted with ``x.convert("HSV")``.

    The conversion itself is deterministic; nothing in this function is random.
    """
    return x.convert("HSV")
# Wrap the HSV conversion as a transform, then make its application random.
colour_transform = transforms.Lambda(_random_colour_space)
random_colour_transform = transforms.RandomApply(
    [colour_transform],
)

class GaussNoise():
    """Add Gaussian noise to a tensor.

    Args:
        mean: Mean of the noise distribution.
        stddev: Standard deviation of the noise distribution.

    Example:
        >>> transforms.Compose([
        ...     transforms.ToTensor(),
        ...     GaussNoise(0.1, 0.05),
        ... ])

    """
    def __init__(self, mean, stddev):
        self.mean = mean
        self.stddev = stddev

    def __call__(self, tensor):
        """Return a new tensor equal to ``tensor`` plus freshly sampled noise."""
        noise = torch.zeros_like(tensor).normal_(self.mean, self.stddev)
        # Out-of-place add: the caller's tensor must not be modified.
        return tensor + noise

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, stddev={self.stddev})"
    
# custom_transform_pipeline = transforms.Compose([random_colour_transform, transforms.ToTensor(), GaussNoise(0.1, 0.05)])

## Data Prep

In [None]:
# Training Data
def check_image(path):
    """Return True if PIL can open the file at ``path``, otherwise False.

    A failure is reported on stdout and never raised, so this function can be
    used as a best-effort file filter.

    Args:
        path: Path of the candidate image file.

    Returns:
        bool: True when ``Image.open`` succeeds, False when it raises.
    """
    try:
        # The context manager closes the file handle that Image.open creates;
        # only the "can it be opened" outcome is needed here.
        with Image.open(path):
            return True
    except Exception as e:  # deliberately broad: any failure means "invalid"
        print(f'Invalid image file "{path}", error {e}')
        return False

# Deterministic preprocessing applied to every split.
base_transforms = [
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
]

# Training pipeline: after the first run, random augmentation is wrapped
# around the deterministic preprocessing.
transforms_list = []
if not first_run:
    transforms_list += [
        # Data Augmentation (PIL Image space)
        random_colour_transform,
    ]
transforms_list += base_transforms
if not first_run:
    transforms_list += [
        # Data Augmentation (tensor space)
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        # Other augmentations to experiment with:
        # GaussNoise(0.1, 0.05),
        # transforms.RandomRotation(degrees=15, interpolation=transforms.InterpolationMode.NEAREST, expand=False, center=None),
        # transforms.RandomGrayscale(p=0.1),
    ]
img_transforms = transforms.Compose(transforms_list)

# Validation and test data must NOT be randomly augmented, otherwise the
# reported loss/accuracy changes from run to run for the same model.
eval_transforms = transforms.Compose(base_transforms)

train_data_path = os.path.join(images_path, "train")
test_data_path = os.path.join(images_path, "test")
val_data_path = os.path.join(images_path, "val")
train_data = torchvision.datasets.ImageFolder(root=train_data_path, transform=img_transforms, is_valid_file=check_image)
test_data = torchvision.datasets.ImageFolder(root=test_data_path, transform=eval_transforms, is_valid_file=check_image)
val_data = torchvision.datasets.ImageFolder(root=val_data_path, transform=eval_transforms, is_valid_file=check_image)

batch_size = 64
# Only the training loader needs shuffling; evaluation order is irrelevant.
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)
val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)

print(f'Training data length: {len(train_data_loader.dataset)}')
print(f'Test data length: {len(test_data_loader.dataset)}')
print(f'Validation data length: {len(val_data_loader.dataset)}')

## Training

In [None]:
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device="cpu"):
    """Train ``model`` and print training/validation metrics after every epoch.

    Each epoch performs one optimisation pass over ``train_loader`` followed by
    one evaluation pass over ``val_loader`` with gradient tracking disabled.

    Args:
        model: Module to train; expected to already live on ``device``.
        optimizer: Optimizer that updates the model's parameters.
        loss_fn: Callable ``loss_fn(output, targets)`` returning a scalar loss.
        train_loader: Iterable of ``(inputs, targets)`` batches that exposes a
            ``dataset`` attribute supporting ``len()``.
        val_loader: Same contract as ``train_loader``, used for validation.
        epochs: Number of passes over the training data.
        device: Device the batches are moved to.

    Returns:
        None. The model is left in eval mode after the last epoch.
    """
    for epoch in range(1, epochs + 1):
        training_loss = 0.0
        valid_loss = 0.0

        model.train()
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            output = model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample average
            # (assumes loss_fn returns the batch mean).
            training_loss += loss.item() * inputs.size(0)
        training_loss /= len(train_loader.dataset)

        model.eval()
        num_correct = 0
        num_examples = 0
        # No autograd graph is needed for validation.
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                output = model(inputs)
                loss = loss_fn(output, targets)
                valid_loss += loss.item() * inputs.size(0)
                # Softmax is monotonic, so the arg-max of the raw outputs is
                # the predicted class.
                correct = torch.eq(output.argmax(dim=1), targets).view(-1)
                num_correct += torch.sum(correct).item()
                num_examples += correct.shape[0]
        valid_loss /= len(val_loader.dataset)

        print('Epoch: {}, Training Loss: {:.3f}, Validation Loss: {:.3f}, accuracy = {:.3f}'.format(epoch, training_loss,
        valid_loss, num_correct / num_examples))

In [None]:
# Select CUDA vs CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Using CUDA" if use_cuda else "Using CPU")

# Move the model to the chosen device (the cell output shows the model).
transfer_model.to(device)

In [None]:
# Adjust LR
# First run uses 1e-3; later runs use a much smaller rate.
# Alternative later-run value that was tried: lr = 2e-7
lr = 1e-3 if first_run else 1e-6
print(f"Set Learn Rate lr={lr}")

In [None]:
# Optimizer over every parameter of the model at the base learning rate.
optimizer1 = optim.Adam(transfer_model.parameters(), lr=lr)

# Optimizer over layer4 and layer3 only, each with a reduced learning rate
# (base rate divided by 3 and by 9 respectively).
optimizer2 = optim.Adam(
    [
        {'params': block.parameters(), 'lr': lr / divisor}
        for block, divisor in ((transfer_model.layer4, 3), (transfer_model.layer3, 9))
    ],
    lr=lr,
)

In [None]:
# First run trains with optimizer1; every later run uses optimizer2.
optimizer = optimizer2 if not first_run else optimizer1
train(
    transfer_model,
    optimizer,
    torch.nn.CrossEntropyLoss(),
    train_data_loader,
    val_data_loader,
    epochs=epochs,
    device=device,
)
# Keep a running total across repeated executions of this cell.
total_epochs += epochs
print(f'DONE {epochs} epochs, total epochs: {total_epochs}')

## Test Inference

In [None]:
def test(model, loss_fn, test_loader, device="cpu"):
    """Evaluate ``model`` on ``test_loader`` and print loss and accuracy.

    Args:
        model: Module to evaluate; expected to already live on ``device``.
        loss_fn: Callable ``loss_fn(output, targets)`` returning a scalar loss.
        test_loader: Iterable of ``(inputs, targets)`` batches that exposes a
            ``dataset`` attribute supporting ``len()``.
        device: Device the batches are moved to.

    Returns:
        None. The model is left in eval mode.
    """
    test_loss = 0.0
    num_correct = 0
    num_examples = 0
    model.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss = loss_fn(output, targets)
            # Weight by batch size so the final figure is a per-sample average
            # (assumes loss_fn returns the batch mean).
            test_loss += loss.item() * inputs.size(0)
            # Softmax is monotonic, so the arg-max of the raw outputs is the
            # predicted class.
            correct = torch.eq(output.argmax(dim=1), targets).view(-1)
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]
    test_loss /= len(test_loader.dataset)
    print('Test Loss: {:.3f}, accuracy = {:.3f}'.format(test_loss, num_correct / num_examples))

In [None]:
test(transfer_model, torch.nn.CrossEntropyLoss(), test_data_loader, device=device)

In [None]:
# Stop Run Here
# From now on, re-executed cells take their "not first run" branches.
first_run = False
# Raise explicitly instead of `assert False`: assert statements are removed
# when Python runs with -O, which would let execution continue past this cell.
raise AssertionError("Stopping the Run. Nothing to auto-run below")

## LR Finder

In [None]:
def _set_learning_rate(optimizer, lr):
    """Assign ``lr`` to every parameter group of ``optimizer``."""
    for group in optimizer.param_groups:
        group["lr"] = lr


def _trim_lr_curve(lrs, losses):
    """Drop the first 10 and last 5 points when more than 20 were recorded."""
    if len(lrs) > 20:
        return lrs[10:-5], losses[10:-5]
    return lrs, losses


def find_lr(model, loss_fn, optimizer, train_loader, init_value=1e-8, final_value=10.0, device="cpu"):
    """Sweep the learning rate over one pass of ``train_loader``.

    The learning rate starts at ``init_value`` and is multiplied by a constant
    factor after every batch so that it reaches ``final_value`` on the last
    batch. The sweep stops early once a batch loss exceeds four times the
    lowest loss seen so far.

    Note: this trains for real. ``model`` weights, optimizer state and the
    learning rate of every parameter group are modified in place.

    Args:
        model: Module to probe; expected to already live on ``device``.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar loss.
        optimizer: Optimizer whose parameter groups are all swept together.
        train_loader: Iterable of ``(inputs, targets)`` batches supporting
            ``len()``.
        init_value: Learning rate used for the first batch.
        final_value: Learning rate reached on the last batch.
        device: Device the batches are moved to.

    Returns:
        tuple[list[float], list[float]]: Learning rates and matching losses.
        When more than 20 points were recorded, the first 10 and the last 5
        are dropped.
    """
    # One multiplicative step between consecutive batches; a single-batch
    # loader would otherwise divide by zero.
    number_in_epoch = max(len(train_loader) - 1, 1)
    update_step = (final_value / init_value) ** (1 / number_in_epoch)
    lr = init_value
    _set_learning_rate(optimizer, lr)
    best_loss = 0.0
    batch_num = 0
    losses = []
    log_lrs = []
    for inputs, targets in train_loader:
        batch_num += 1
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        # Track plain floats so no autograd graph is kept alive.
        loss_value = loss.item()

        # Crash out if loss explodes
        if batch_num > 1 and loss_value > 4 * best_loss:
            return _trim_lr_curve(log_lrs, losses)

        # Record the best loss
        if batch_num == 1 or loss_value < best_loss:
            best_loss = loss_value

        # Store the values
        losses.append(loss_value)
        log_lrs.append(lr)

        # Do the backward pass and optimize
        loss.backward()
        optimizer.step()

        # Update the lr for the next step
        lr *= update_step
        _set_learning_rate(optimizer, lr)
    return _trim_lr_curve(log_lrs, losses)


In [None]:
# Run the learning-rate sweep and plot loss against learning rate (log x-axis).
lrs, losses = find_lr(
    transfer_model,
    torch.nn.CrossEntropyLoss(),
    optimizer,
    train_data_loader,
    init_value=1e-7,
    final_value=1e-3,
    device=device,
)
plt.plot(lrs, losses)
plt.xlabel("Learning rate")
plt.ylabel("Loss")
plt.xscale("log")
plt.show()

## Ensembles

Given a list of models, we can compute each model's prediction for the same input and then average those predictions to obtain the final prediction.

In [None]:
# Two independently initialised ResNet-50 models, switched to eval mode for
# inference.
models_ensemble = [models.resnet50().to(device).eval() for _ in range(2)]

# Every ensemble member must see the SAME input, otherwise averaging their
# outputs is meaningless.
sample_input = torch.rand(1, 3, 224, 224).to(device)

# Inference only: no autograd graph needed. `dim=1` normalises over the
# class dimension of the (batch, classes) output.
with torch.no_grad():
    predictions = [F.softmax(m(sample_input), dim=1) for m in models_ensemble]

# Average the per-model probabilities, then pick the most likely class.
avg_prediction = torch.stack(predictions).mean(0).argmax()

In [None]:
# Display the index selected by arg-max over the averaged predictions.
avg_prediction

In [None]:
# Display the per-model predictions stacked along a new leading dimension.
torch.stack(predictions)