# Finetune models
This notebook takes the models that were pretrained using the superpixels approach and fine-tunes them on their respective datasets.

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import glob
import os
import argparse

from train import *
import network

base_path = '../../data/'

Dask dataframe query planning is disabled because dask-expr is not installed.

You can install it with `pip install dask[dataframe]` or `conda install dask`.
This will raise in a future version.



In [2]:
class Args:
    """Plain container for the dataset settings handed to `TrainDataset`.

    Attributes are stored exactly as given:
        target_pos: position of the target within each instance
        incl_bands: list of band indices to include
        satellite: satellite identifier string (e.g. "sentinel", "landsat")
    """

    def __init__(self, target_pos, incl_bands, satellite):
        self.target_pos, self.incl_bands, self.satellite = (
            target_pos,
            incl_bands,
            satellite,
        )

def freeze_layers(model, unfrozen_layers):
    """Freeze every parameter of `model` except those in the listed layers.

    A parameter stays trainable when any entry of `unfrozen_layers` occurs as
    a substring of its name in `model.named_parameters()`; every other
    parameter gets `requires_grad = False`.
    """
    for param_name, weight in model.named_parameters():
        weight.requires_grad = any(tag in param_name for tag in unfrozen_layers)

In [3]:
def finetune_model(model, name, loader, epochs=10, lr=0.001, model_path='../../models/', valid_loader=None):
    """
    Fine-tune a model on a new dataset and save its weights.

        model: pretrained model, already placed on the device it should train on
        name: name of the model; weights are written to <model_path>/<name>.pth
        loader: iterable of (images, target) batches used for training
        epochs: number of epochs to train
        lr: learning rate for the Adam optimizer
        model_path: directory the state dict is saved in (created if missing)
        valid_loader: optional iterable of (images, target) batches used for the
            per-epoch loss report; defaults to `loader`, so by default the
            reported loss is measured on the training data

    Raises:
        ValueError: if every parameter of the model is frozen.
    """
    criterion = nn.CrossEntropyLoss()

    # Only trainable parameters go to the optimizer; frozen ones never receive gradients.
    trainable = [p for p in model.parameters() if p.requires_grad]
    if not trainable:
        raise ValueError("finetune_model: the model has no trainable parameters")
    optimizer = torch.optim.Adam(trainable, lr=lr)

    # Follow the model's own device instead of hard-coding one backend.
    device = trainable[0].device

    if valid_loader is None:
        valid_loader = loader

    for epoch in range(epochs):
        # NOTE(review): train() also switches BatchNorm layers inside frozen blocks to
        # training mode, so their running statistics keep updating - confirm this is intended.
        model.train()
        for images, target in loader:
            images = images.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        # Calculate the evaluation loss without building autograd graphs
        model.eval()

        valid_loss = 0.0
        n_batches = 0
        with torch.no_grad():
            for images, target in valid_loader:
                images = images.to(device)
                target = target.to(device)

                valid_loss += criterion(model(images), target).item()
                n_batches += 1

        # An empty loader yields NaN rather than a ZeroDivisionError
        valid_loss = valid_loss / n_batches if n_batches else float('nan')
        print(f"Epoch {epoch}: {round(valid_loss, 5)}")

    # Save model; os.path.join works with or without a trailing slash on model_path
    if model_path:
        os.makedirs(model_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(model_path, f'{name}.pth'))

# SWED

In [4]:
# Dataset settings for the SWED ("sentinel") fine-tuning set: all 12 bands are used
target_pos = -1
incl_bands = list(range(12))
satellite = "sentinel"

# Collect every .npy instance in the fine-tuning directory and report how many were found
swed_finetune_file = base_path + 'SWED/finetune/'
swed_finetune_paths = glob.glob(f'{swed_finetune_file}*.npy')
print(len(swed_finetune_paths))

args = Args(target_pos=target_pos, incl_bands=incl_bands, satellite=satellite)

finetune_data = TrainDataset(swed_finetune_paths, args)
finetune_loader = DataLoader(finetune_data, shuffle=False, batch_size=10)

# Sense check the data: show the shapes of the first batch only
for X, y in finetune_loader:
    print(X.shape, y.shape, sep='\n')
    break

100
torch.Size([10, 12, 256, 256])
torch.Size([10, 2, 256, 256])


In [8]:
# Load rough model
device = torch.device('mps')  #UPDATE
print("Using device: {}\n".format(device))
swed_superpixel = "SWED_SUPERPIXELS_26JUL2024.pth"

model = network.U_Net(12,2).to(device)

# Load saved model 
#model = torch.load('../models/LANDSAT-UNET-20JUL23.pth', map_location=torch.device('cpu') )
state_dict = torch.load(f'../../models/{swed_superpixel}', map_location=torch.device('cpu') )
model.load_state_dict(state_dict)
model.eval()
model.to(device)

Using device: mps



U_Net(
  (Maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv1): conv_block(
    (conv): Sequential(
      (0): Conv2d(12, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (Conv2): conv_block(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    

In [9]:
# Freeze all layers except the final convolutional blocks
unfrozen_layers = ['Up5','Up_conv5','Up4','Up_conv4','Up3','Up_conv3','Up2','Up_conv2', 'Conv_1x1']
freeze_layers(model, unfrozen_layers)

In [10]:
# Fine-tune `model` on the SWED loader and save it under the given name
finetune_model(
    model=model,
    name='SWED-FINETUNE-30JUL24',
    loader=finetune_loader,
    epochs=10,
    lr=0.001,
)

Epoch 0: 0.18848
Epoch 1: 0.15869
Epoch 2: 0.13616
Epoch 3: 0.1251
Epoch 4: 0.11853
Epoch 5: 0.11535
Epoch 6: 0.11391
Epoch 7: 0.11138
Epoch 8: 0.10938
Epoch 9: 0.1081


# LICS

In [7]:
# Dataset settings for the LICS ("landsat") fine-tuning set: 7 bands are used
target_pos = -1
incl_bands = list(range(7))
satellite = "landsat"

# Collect every .npy instance in the fine-tuning directory
lics_finetune_file = base_path + 'LICS/finetune/'
lics_finetune_paths = glob.glob(f'{lics_finetune_file}*.npy')

args = Args(target_pos=target_pos, incl_bands=incl_bands, satellite=satellite)

finetune_data = TrainDataset(lics_finetune_paths, args)
finetune_loader = DataLoader(finetune_data, shuffle=False, batch_size=10)

# Sense check the data: show the shapes of the first batch only
for X, y in finetune_loader:
    print(X.shape, y.shape, sep='\n')
    break

torch.Size([10, 7, 256, 256])
torch.Size([10, 2, 256, 256])


In [8]:
# Load rough model
device = torch.device('mps')  #UPDATE
print("Using device: {}\n".format(device))
lics_superpixel = "LICS_SUPERPIXELS_26JUL2024.pth"

model = network.U_Net(7,2).to(device)

# Load saved model 
#model = torch.load('../models/LANDSAT-UNET-20JUL23.pth', map_location=torch.device('cpu') )
state_dict = torch.load(f'../../models/{lics_superpixel}', map_location=torch.device('cpu') )
model.load_state_dict(state_dict)
model.eval()
model.to(device)

Using device: mps



U_Net(
  (Maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv1): conv_block(
    (conv): Sequential(
      (0): Conv2d(7, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (Conv2): conv_block(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )

In [9]:
# Freeze all layers except the final convolutional blocks
unfrozen_layers = ['Up5','Up_conv5','Up4','Up_conv4','Up3','Up_conv3','Up2','Up_conv2', 'Conv_1x1']
freeze_layers(model, unfrozen_layers)

In [10]:
# Fine-tune `model` on the LICS loader and save it under the given name
finetune_model(
    model=model,
    name='LICS_FINETUNE_26JUL24',
    loader=finetune_loader,
    epochs=10,
    lr=0.001,
)

Epoch 0: 0.06655
Epoch 1: 0.05389
Epoch 2: 0.04447
Epoch 3: 0.03856
Epoch 4: 0.0347
Epoch 5: 0.03203
Epoch 6: 0.03009
Epoch 7: 0.02853
Epoch 8: 0.02721
Epoch 9: 0.02604


# Archive 

In [None]:
# Write rotated and flipped copies of every LICS fine-tuning instance to a separate directory
lics_finetune_file = base_path + 'LICS/finetune/'
lics_augment_file = base_path + 'LICS/finetune_augmentation/'
lics_finetune_paths = glob.glob(lics_finetune_file + '*.npy')

# np.save does not create missing directories
os.makedirs(lics_augment_file, exist_ok=True)

for path in lics_finetune_paths:
    original = np.load(path)
    name = os.path.basename(path)

    # Every variant is derived from the untouched original. Chaining the transforms on one
    # variable made the horizontal flip act on the 270-degree rotation, and the vertical
    # flip came out identical to the 90-degree rotation.
    # NOTE(review): rotations and flips act on axes 0 and 1, i.e. those are treated as the
    # spatial axes. An earlier comment gave the image shape as [7, 256, 256], which would
    # conflict with that - confirm the on-disk array layout.
    variants = {
        '': original,                                       # unmodified copy
        'rot90_': np.rot90(original, k=1, axes=(0, 1)),     # rotate 90 degrees
        'rot180_': np.rot90(original, k=2, axes=(0, 1)),    # rotate 180 degrees
        'rot270_': np.rot90(original, k=3, axes=(0, 1)),    # rotate 270 degrees
        'fliph_': np.fliplr(original),                      # flip horizontally (axis 1)
        'fliv_': np.flipud(original),                       # flip vertically (axis 0)
    }

    for prefix, array in variants.items():
        np.save(lics_augment_file + prefix + name, array)
