# Finetune models
This script takes the models that were pretrained with the superpixel approach and fine-tunes each of them on its respective dataset.

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import glob
import os
import argparse

from train import *
import network

base_path = '../../data/'

In [None]:
class Args:
    """
    Plain container for the dataset settings handed to TrainDataset.
        target_pos: stored unchanged as `self.target_pos`
        incl_bands: stored unchanged as `self.incl_bands`
        satellite: stored unchanged as `self.satellite`
    """

    def __init__(self, target_pos, incl_bands, satellite):
        # No validation or copying: the attributes reference the given objects
        self.target_pos, self.incl_bands, self.satellite = (
            target_pos,
            incl_bands,
            satellite,
        )

def freeze_layers(model, unfrozen_layers):
    """
    Freeze every parameter of `model` except those of the listed layers.
        model: network whose parameters are (un)frozen in place
        unfrozen_layers: substrings; a parameter stays trainable when its
            name contains at least one of them, otherwise it is frozen
    """
    for param_name, weights in model.named_parameters():
        # Substring match against the dotted parameter name
        weights.requires_grad = any(tag in param_name for tag in unfrozen_layers)

In [None]:
def finetune_model(model, name, loader, epochs=10, lr=0.001, model_path='../../models/',
                   valid_loader=None):
    """
    Fine-tune a model on a new dataset and save its weights.

    Trains with Adam and cross-entropy loss, prints the mean evaluation loss
    after every epoch and finally writes the model's state dict to
    `{model_path}{name}.pth`. The model is left in eval mode.

        model: pretrained model; batches are moved to the device its
            parameters live on
        name: name of the model, used as the file name of the saved weights
        loader: dataloader for the new dataset, yielding (images, target)
        epochs: number of epochs to train
        lr: learning rate
        model_path: prefix of the output file; it is concatenated with `name`
            as-is, so a directory must end with a path separator
        valid_loader: optional dataloader for the per-epoch loss report;
            defaults to `loader`, in which case the printed value is the
            loss on the training data

    Raises ValueError if the model has no trainable parameters or if the
    evaluation loader yields no batches.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    if not trainable:
        raise ValueError("model has no trainable parameters to fine-tune")

    criterion = nn.CrossEntropyLoss()
    # Frozen parameters never receive gradients, so leave them out of Adam
    optimizer = torch.optim.Adam(trainable, lr=lr)
    # Follow the model instead of assuming a specific accelerator
    device = next(model.parameters()).device
    eval_loader = loader if valid_loader is None else valid_loader

    for epoch in range(epochs):
        # NOTE(review): train() also switches frozen layers to training mode,
        # so any BatchNorm/Dropout inside them still behaves as in training
        # -- confirm this is intended.
        model.train()
        for images, target in loader:
            images = images.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        # Calculate the evaluation loss without building autograd graphs
        model.eval()
        total_loss = 0.0
        n_batches = 0
        with torch.no_grad():
            for images, target in eval_loader:
                images = images.to(device)
                target = target.to(device)

                total_loss += criterion(model(images), target).item()
                n_batches += 1

        if n_batches == 0:
            raise ValueError("evaluation loader yielded no batches")

        valid_loss = total_loss / n_batches
        print(f"Epoch {epoch}: {round(valid_loss, 5)}")

    # Save model
    torch.save(model.state_dict(), f'{model_path}{name}.pth')

# SWED

In [None]:
# Dataset settings for SWED, passed to TrainDataset through Args
target_pos = -1
incl_bands = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
satellite = "sentinel"

swed_finetune_file = base_path + 'SWED/finetune/'
# glob returns paths in arbitrary order; sort them so the unshuffled loader
# below sees the files in a reproducible order.
swed_finetune_paths = sorted(glob.glob(swed_finetune_file + '*.npy'))
print(len(swed_finetune_paths))

args = Args(target_pos, incl_bands, satellite)

finetune_data = TrainDataset(swed_finetune_paths, args)
# NOTE(review): shuffle=False means every epoch visits the batches in the
# same order -- confirm this is intended for fine-tuning.
finetune_loader = DataLoader(finetune_data, batch_size=10, shuffle=False)

# Sense check the data: print the shapes of the first batch only
for X, y in finetune_loader:
    print(X.shape)
    print(y.shape)
    break

In [None]:
# Load rough model
# Prefer Apple's MPS backend, then CUDA, and fall back to the CPU so the cell
# also runs on machines without an MPS device.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device: {}\n".format(device))
swed_superpixel = "SWED_SUPERPIXELS_26JUL2024.pth"

# presumably (input channels, output classes); 12 matches the number of
# SWED bands selected in incl_bands -- verify against network.U_Net
model = network.U_Net(12, 2)

# Load the saved weights on the CPU first, then move the model once
state_dict = torch.load(f'../../models/{swed_superpixel}', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()
model.to(device)

In [None]:
# Only parameters whose names contain one of these substrings stay trainable;
# every other parameter of the model is frozen.
unfrozen_layers = [
    'Up5', 'Up_conv5',
    'Up4', 'Up_conv4',
    'Up3', 'Up_conv3',
    'Up2', 'Up_conv2',
    'Conv_1x1',
]
freeze_layers(model=model, unfrozen_layers=unfrozen_layers)

In [None]:
# Fine-tune the SWED model; the weights are saved under the given name
finetune_model(
    model,
    'SWED-FINETUNE-30JUL24',
    finetune_loader,
    epochs=10,
    lr=0.001,
)

# LICS

In [None]:
# Dataset settings for LICS, passed to TrainDataset through Args
target_pos = -1
incl_bands = [0, 1, 2, 3, 4, 5, 6]
satellite = "landsat"

lics_finetune_file = base_path + 'LICS/finetune/'
# glob returns paths in arbitrary order; sort them so the unshuffled loader
# below sees the files in a reproducible order.
lics_finetune_paths = sorted(glob.glob(lics_finetune_file + '*.npy'))
# Report how many files were found, as the SWED cell does
print(len(lics_finetune_paths))

args = Args(target_pos, incl_bands, satellite)

finetune_data = TrainDataset(lics_finetune_paths, args)
# NOTE(review): shuffle=False means every epoch visits the batches in the
# same order -- confirm this is intended for fine-tuning.
finetune_loader = DataLoader(finetune_data, batch_size=10, shuffle=False)

# Sense check the data: print the shapes of the first batch only
for X, y in finetune_loader:
    print(X.shape)
    print(y.shape)
    break

In [None]:
# Load rough model
# Prefer Apple's MPS backend, then CUDA, and fall back to the CPU so the cell
# also runs on machines without an MPS device.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device: {}\n".format(device))
lics_superpixel = "LICS_SUPERPIXELS_26JUL2024.pth"

# presumably (input channels, output classes); 7 matches the number of
# LICS bands selected in incl_bands -- verify against network.U_Net
model = network.U_Net(7, 2)

# Load the saved weights on the CPU first, then move the model once
state_dict = torch.load(f'../../models/{lics_superpixel}', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()
model.to(device)

In [None]:
# Only parameters whose names contain one of these substrings stay trainable;
# every other parameter of the model is frozen.
unfrozen_layers = [
    'Up5', 'Up_conv5',
    'Up4', 'Up_conv4',
    'Up3', 'Up_conv3',
    'Up2', 'Up_conv2',
    'Conv_1x1',
]
freeze_layers(model=model, unfrozen_layers=unfrozen_layers)

In [None]:
# Fine-tune the LICS model; the weights are saved under the given name
finetune_model(
    model,
    'LICS_FINETUNE_26JUL24',
    finetune_loader,
    epochs=10,
    lr=0.001,
)

# Archive 

In [None]:
# Copy every LICS fine-tuning array into the augmentation directory together
# with five augmented versions (three rotations and two flips).
lics_finetune_file = base_path + 'LICS/finetune/'
lics_augment_file = base_path + 'LICS/finetune_augmentation/'
lics_finetune_paths = glob.glob(lics_finetune_file + '*.npy')

# np.save does not create missing directories
os.makedirs(lics_augment_file, exist_ok=True)

for path in lics_finetune_paths:
    file = np.load(path)
    name = os.path.basename(path)

    # Each variant is derived from the unmodified array. Chaining the
    # transforms instead would make the last one, flipud(fliplr(rot270)),
    # identical to the 90 degree rotation and store a duplicate sample.
    #
    # NOTE(review): every transform acts on axes 0 and 1 (np.fliplr flips
    # axis 1, np.flipud flips axis 0). The earlier comment here gave the
    # image shape as [7, 256, 256]; if the arrays really are channel-first,
    # the spatial axes would be (1, 2) instead -- confirm the stored layout.
    variants = {
        '': file,                                      # unchanged copy
        'rot90_': np.rot90(file, k=1, axes=(0, 1)),    # rotate 90 degrees
        'rot180_': np.rot90(file, k=2, axes=(0, 1)),   # rotate 180 degrees
        'rot270_': np.rot90(file, k=3, axes=(0, 1)),   # rotate 270 degrees
        'fliph_': np.fliplr(file),                     # flip horizontally
        # flip vertically; prefix spelling kept so file names stay stable
        'fliv_': np.flipud(file),
    }

    # Save all variants to the new directory
    for prefix, array in variants.items():
        np.save(lics_augment_file + prefix + name, array)
