## import

In [5]:
import os
import torch
import argparse
import numpy as np
import torch.nn as nn
import torch.utils.data
import seaborn as sns
import imageio

from utils.dataset import load_mat_hsi, sample_gt, HSIDataset
from utils.utils import split_info_print, metrics, show_results
from utils.scheduler import load_scheduler
from models.get_model import get_model
from train import train, test
import pandas as pd

## argument

In [6]:
# Options for patch-based HSI classification, as (flag, type, default).
# They are registered on the parser in the order listed here.
_OPTION_SPECS = (
    ("--model", str, 'speformer'),
    ("--dataset_name", str, "gs"),
    ("--dataset_file_name", str, "202307_downsampled_gongsan.h5"),
    ("--dataset_dir", str, "/home1/jmt30269/DSNet/data/"),
    ("--device", str, "1"),  # formatted into "cuda:<device>" further down
    ("--patch_size", int, 15),
    ("--num_run", int, 1),
    ("--epoch", int, 200),
    ("--bs", int, 128),  # bs = batch size
    ("--ratio", float, 0.06),
    ("--weights", str, "/home1/jmt30269/Group-Aware-Hierarchical-Transformer/checkpoints/speformer/gs/0.05/7_0/"),
    ("--outputs", str, "./results"),
)

parser = argparse.ArgumentParser(description="run patch-based HSI classification")
for _flag, _kind, _default in _OPTION_SPECS:
    parser.add_argument(_flag, type=_kind, default=_default)

# Parsing an explicit empty argument list keeps argparse away from the
# process's own command line, so every option takes its default value.
opts = parser.parse_args(args=[])

## main

In [7]:
# Select the CUDA device named by --device and echo the run configuration.
device = torch.device(f"cuda:{opts.device}")
print(device)

# print parameters
print(f"experiments will run on GPU device {opts.device}")
print(f"model = {opts.model}")
print(f"dataset = {opts.dataset_name}")
print(f"dataset folder = {opts.dataset_dir}")
print(f"patch size = {opts.patch_size}")
print(f"batch size = {opts.bs}")
print(f"total epoch = {opts.epoch}")
# The labelled fraction `ratio` is reported as two equal halves (train and
# validation); the remainder is reported as the test share.
print(f"{opts.ratio / 2} for training, {opts.ratio / 2} for validation and {1 - opts.ratio} testing")

# load data
image, gt, labels = load_mat_hsi(
    opts.dataset_name,
    opts.dataset_dir,
    opts.dataset_file_name,
)

# One class per label entry; spectral bands are the last axis of the image.
num_classes = len(labels)
num_bands = image.shape[-1]

model_list = ['ssftt', 'rssan', 'proposed', 'speformer']

# random seeds
seeds = [
    20250411,
    20250402,
    20250403,
    20250404,
    20250405,
    20250406,
    20250407,
    20250408,
    20250409,
    20250410,
]

cuda:1
experiments will run on GPU device 1
model = speformer
dataset = gs
dataset folder = /home1/jmt30269/DSNet/data/
patch size = 15
batch size = 128
total epoch = 200
0.03 for training, 0.03 for validation and 0.94 testing


## train

In [None]:
import os

# Request synchronous CUDA kernel launches to make device-side errors easier
# to locate.  NOTE(review): presumably this only takes effect when set before
# the CUDA context is created -- confirm for this kernel session.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Collected test metrics: one entry per (patch size, model) combination.
results = []
model_list = ['ssftt', 'rssan', 'proposed', 'speformer']
patch_list = [9]

# Nothing below depends on the model or the patch size, so it is prepared
# once here instead of being rebuilt on every iteration of the sweep.
device = torch.device("cuda:{}".format(opts.device))
image, gt, labels = load_mat_hsi(opts.dataset_name, opts.dataset_dir, opts.dataset_file_name)
num_classes = len(labels)
num_bands = image.shape[-1]
# random seeds
seeds = [20250411, 20250402, 20250403, 20250404, 20250405,
         20250406, 20250407, 20250408, 20250409, 20250410]
# This sweep performs a single run per configuration, always using seeds[0].
run = 0

for patch_size in patch_list:
    # Keep `opts` in sync: later cells read opts.patch_size / opts.model.
    opts.patch_size = patch_size
    for model_name in model_list:
        opts.model = model_name

        print(device)
        # print parameters
        print("experiments will run on GPU device {}".format(opts.device))
        print("model = {}".format(opts.model))
        print("dataset = {}".format(opts.dataset_name))
        print("dataset folder = {}".format(opts.dataset_dir))
        print("patch size = {}".format(opts.patch_size))
        print("batch size = {}".format(opts.bs))
        print("total epoch = {}".format(opts.epoch))
        print("{} for training, {} for validation and {} testing".format(opts.ratio / 2, opts.ratio / 2, 1 - opts.ratio))

        # Re-seed NumPy for every configuration so each one starts from the
        # same random state.
        np.random.seed(seeds[0])
        print("running an experiment with the {} model".format(opts.model))
        print("run {} / {}".format(run + 1, opts.num_run))

        # get train_gt, val_gt and test_gt: first split off the labelled
        # fraction `ratio`, then halve it into train and validation.
        trainval_gt, test_gt = sample_gt(gt, opts.ratio, seeds[run])
        train_gt, val_gt = sample_gt(trainval_gt, 0.5, seeds[run])
        del trainval_gt

        # Augmentation is enabled for the training patches only.
        train_set = HSIDataset(image, train_gt, patch_size=opts.patch_size, data_aug=True)
        val_set = HSIDataset(image, val_gt, patch_size=opts.patch_size, data_aug=False)

        train_loader = torch.utils.data.DataLoader(train_set, opts.bs, drop_last=False, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_set, opts.bs, drop_last=False, shuffle=False)

        # load model and loss
        model = get_model(opts.model, opts.dataset_name, opts.patch_size)

        if run == 0:
            split_info_print(train_gt, val_gt, test_gt, labels)

        model = model.to(device)

        optimizer, scheduler = load_scheduler(opts.model, model)

        criterion = nn.CrossEntropyLoss()

        # where to save checkpoint model
        model_dir = f"./checkpoints/{opts.model}/{opts.dataset_name}_2407/{opts.ratio}/{opts.patch_size}_{run}_1"
        print(f'model save dir : {model_dir}')
        try:
            train(model, optimizer, criterion, train_loader, val_loader, opts.epoch, model_dir, device, scheduler)
        except KeyboardInterrupt:
            # An interrupt only stops training; the model is still tested.
            print('"ctrl+c" is pused, the training is over')

        # test the model
        probabilities = test(model, model_dir, image, opts.patch_size, num_classes, device)

        prediction = np.argmax(probabilities, axis=-1)

        # computing metrics
        run_results = metrics(prediction, test_gt, n_classes=num_classes)  # only for test set
        results.append(run_results)
        show_results(run_results, label_values=labels)

        # Drop references to the per-configuration objects before the next
        # model is built.
        del model, train_set, train_loader, val_set, val_loader

    # NOTE(review): `results` holds one entry per model at this point, so this
    # aggregate would combine different models -- confirm that is intended
    # before enabling num_run > 1.
    if opts.num_run > 1:
        show_results(results, label_values=labels, agregated=True)

## eval

In [8]:
def color_results(arr2d, palette):
    """Paint a 2-D label map as an RGB image.

    Args:
        arr2d: array of class indices; its first two dimensions give the
            height and width of the output.
        palette: mapping from class index to a colour that can be assigned
            to one RGB pixel (e.g. a 3-tuple).

    Returns:
        A ``uint8`` array of shape ``(rows, cols, 3)``. Pixels whose value
        has no entry in ``palette`` stay ``(0, 0, 0)``.
    """
    n_rows = arr2d.shape[0]
    n_cols = arr2d.shape[1]
    rgb = np.zeros((n_rows, n_cols, 3), dtype=np.uint8)
    for class_id in palette:
        # Boolean mask of every pixel carrying this class index.
        rgb[arr2d == class_id] = palette[class_id]
    return rgb

In [None]:
model_list = ['ssftt', 'rssan', 'proposed', 'speformer']

# Colour table for the label maps: entry 0 is black, entries 1.. take the
# remaining "hls" colours.  It does not depend on the model, so it is built
# once, outside the loop.
palette = {0: (0, 0, 0)}
for k, color in enumerate(sns.color_palette("hls", num_classes + 1)):
    palette[k + 1] = tuple(np.asarray(255 * np.array(color), dtype='uint8'))

for model_name in model_list:
    opts.model = model_name
    # NOTE(review): the ratio folder is hard-coded to "0.03" and is not
    # derived from opts.ratio -- confirm it matches the run to evaluate.
    opts.weights = f'/home1/jmt30269/Group-Aware-Hierarchical-Transformer/checkpoints/{opts.model}/gs_2407/0.03/{opts.patch_size}_0_1/'
    weights_file = os.path.join(opts.weights, 'model_best.pth')

    # load model and weights
    model = get_model(opts.model, opts.dataset_name, opts.patch_size)
    # Report exactly the file that is loaded on the next lines.
    print('loading weights from %s' % weights_file)
    model = model.to(device)
    # map_location lets a checkpoint saved on another device load onto the
    # device selected for this session.
    model.load_state_dict(torch.load(weights_file, map_location=device))
    model.eval()

    # testing model: metric for the whole HSI, including train, val, and test
    probabilities = test(model, opts.weights, image, opts.patch_size, num_classes, device=device)
    prediction = np.argmax(probabilities, axis=-1)

    run_results = metrics(prediction, gt, n_classes=num_classes)

    # Give pixels with a negative ground-truth value the prediction -1, so
    # that after the +1 shift below they land on palette entry 0 (black).
    prediction[gt < 0] = -1

    # color results (only needed by the optional PNG export below)
    colored_gt = color_results(gt + 1, palette)
    colored_pred = color_results(prediction + 1, palette)

    outfile = os.path.join(opts.outputs, opts.dataset_name, opts.model)
    os.makedirs(outfile, exist_ok=True)

    # Optional PNG export, currently disabled:
    # imageio.imsave(os.path.join(outfile, opts.dataset_name + '_gt.png'), colored_gt)  # eps or png
    # imageio.imsave(os.path.join(outfile, opts.dataset_name+'_' + opts.model + '_out.png'), colored_pred)  # or png

    # Flatten to one row per pixel for the CSV export.
    # NOTE(review): reshape(-1, 2) assumes the model outputs exactly two
    # class probabilities per pixel -- confirm num_classes == 2.
    prod = probabilities.reshape(-1, 2)
    gtt = gt.reshape(-1) + 1
    predd = prediction.reshape(-1) + 1
    df = pd.DataFrame({
        'gt': gtt,
        'prob_akk': prod[:, 0],  # first probability column
        'prob_back': prod[:, 1],  # second probability column
        'pred': predd
    })
    df.to_csv(os.path.join(outfile, opts.dataset_name+'_2307_'+str(opts.patch_size)+'_' + opts.model+"_result.csv"), index=False)

    show_results(run_results, label_values=labels)
    del model


## train size test

In [None]:
device = torch.device("cuda:{}".format(opts.device))

# Apply the epoch override first so that the summary below reports the value
# that is actually used for training.
opts.epoch = 100

print("model = {}".format(opts.model))
print("dataset = {}".format(opts.dataset_name))
print("dataset folder = {}".format(opts.dataset_dir))
print("patch size = {}".format(opts.patch_size))
print("batch size = {}".format(opts.bs))
print("total epoch = {}".format(opts.epoch))

# The data and the seed list do not depend on the ratio, so they are prepared
# once instead of on every iteration of the sweep.
# load data
image, gt, labels = load_mat_hsi(opts.dataset_name, opts.dataset_dir, opts.dataset_file_name)

num_classes = len(labels)
num_bands = image.shape[-1]

# random seeds (one per run; opts.num_run must not exceed len(seeds))
seeds = [202201, 202202, 202203, 202204, 202205]

for i in np.arange(0.05, 0.2, 0.02):
    # A float step in np.arange can leave floating-point noise in the values.
    # The ratio ends up in the checkpoint directory name, so round it to the
    # two decimals the sweep is meant to use.
    opts.ratio = round(float(i), 2)
    print("{} for training, {} for validation and {} testing".format(opts.ratio / 2, opts.ratio / 2, 1 - opts.ratio))

    # Metrics of the runs performed at this ratio.
    results = []

    for run in range(opts.num_run):
        np.random.seed(seeds[run])
        print("running an experiment with the {} model".format(opts.model))
        print("run {} / {}".format(run + 1, opts.num_run))

        # get train_gt, val_gt and test_gt: first split off the labelled
        # fraction `ratio`, then halve it into train and validation.
        trainval_gt, test_gt = sample_gt(gt, opts.ratio, seeds[run])
        train_gt, val_gt = sample_gt(trainval_gt, 0.5, seeds[run])
        del trainval_gt

        # Augmentation is enabled for the training patches only.
        train_set = HSIDataset(image, train_gt, patch_size=opts.patch_size, data_aug=True)
        val_set = HSIDataset(image, val_gt, patch_size=opts.patch_size, data_aug=False)

        train_loader = torch.utils.data.DataLoader(train_set, opts.bs, drop_last=False, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_set, opts.bs, drop_last=False, shuffle=False)

        # load model and loss
        model = get_model(opts.model, opts.dataset_name, opts.patch_size)

        # The split summary is printed once per ratio.
        if run == 0:
            split_info_print(train_gt, val_gt, test_gt, labels)

        model = model.to(device)

        optimizer, scheduler = load_scheduler(opts.model, model)

        criterion = nn.CrossEntropyLoss()

        # where to save checkpoint model
        model_dir = "./checkpoints/" + opts.model + '/' + opts.dataset_name + '/' + str(opts.ratio) + '/' + str(opts.patch_size) + '_' + str(run)
        print(f'model save dir : {model_dir}')
        try:
            train(model, optimizer, criterion, train_loader, val_loader, opts.epoch, model_dir, device, scheduler)
        except KeyboardInterrupt:
            # An interrupt only stops training; the model is still tested.
            print('"ctrl+c" is pused, the training is over')

        # test the model
        probabilities = test(model, model_dir, image, opts.patch_size, num_classes, device)

        prediction = np.argmax(probabilities, axis=-1)

        # computing metrics
        run_results = metrics(prediction, test_gt, n_classes=num_classes)  # only for test set
        results.append(run_results)
        show_results(run_results, label_values=labels)

        # Drop references to the per-run objects before the next run.
        del model, train_set, train_loader, val_set, val_loader

    # Aggregate over the runs of this ratio when more than one was performed.
    if opts.num_run > 1:
        show_results(results, label_values=labels, agregated=True)

## test per size

In [4]:
def color_results(arr2d, palette):
    """Convert a 2-D array of class indices into an RGB ``uint8`` image.

    Args:
        arr2d: array of class indices; the output takes its height and
            width from the first two dimensions.
        palette: mapping ``class index -> colour`` where the colour can be
            assigned to a single RGB pixel (e.g. a 3-tuple).

    Returns:
        Array of shape ``(rows, cols, 3)`` and dtype ``uint8``; values
        missing from ``palette`` are left at ``(0, 0, 0)``.
    """
    height = arr2d.shape[0]
    width = arr2d.shape[1]
    image_rgb = np.zeros((height, width, 3), dtype=np.uint8)
    for label, colour in palette.items():
        selected = arr2d == label  # pixels belonging to this label
        image_rgb[selected] = colour
    return image_rgb

In [8]:
import os

# Folder whose entries are the checkpoint directories to evaluate.
base_path = '/home1/jmt30269/Group-Aware-Hierarchical-Transformer/checkpoints/rssan/bs/0.06/'

# Sorted listing, so the entries are processed in a stable order.
list_dir = sorted(os.listdir(base_path))
print(list_dir)

['7_0', '7_1', '7_2', '7_3', '7_4', '7_5', '7_6', '7_7', '7_8', '7_9']


In [None]:
# The data and the colour table do not depend on the checkpoint, so they are
# prepared once instead of on every iteration.
# load data
image, gt, labels = load_mat_hsi(opts.dataset_name, opts.dataset_dir, opts.dataset_file_name)

num_classes = len(labels)
num_bands = image.shape[-1]

# Colour table for the label maps: entry 0 is black, entries 1.. take the
# remaining "hls" colours.
palette = {0: (0, 0, 0)}
for k, color in enumerate(sns.color_palette("hls", num_classes + 1)):
    palette[k + 1] = tuple(np.asarray(255 * np.array(color), dtype='uint8'))

# NOTE(review): base_path points at "rssan/bs" checkpoints while the network
# is built from opts.model / opts.dataset_name -- confirm they match before
# running this cell.
for run_dir in list_dir:
    opts.weights = base_path + str(run_dir)
    print(opts.weights)
    weights_file = os.path.join(opts.weights, 'model_best.pth')

    # load model and weights
    model = get_model(opts.model, opts.dataset_name, opts.patch_size)
    print('loading weights from %s' % weights_file)
    model = model.to(device)
    # map_location lets a checkpoint saved on another device load onto the
    # device selected for this session.
    model.load_state_dict(torch.load(weights_file, map_location=device))
    model.eval()

    # testing model: metric for the whole HSI, including train, val, and test
    probabilities = test(model, opts.weights, image, opts.patch_size, num_classes, device=device)
    prediction = np.argmax(probabilities, axis=-1)

    run_results = metrics(prediction, gt, n_classes=num_classes)

    # Give pixels with a negative ground-truth value the prediction -1, so
    # that after the +1 shift below they land on palette entry 0 (black).
    prediction[gt < 0] = -1

    # color results
    colored_gt = color_results(gt + 1, palette)
    colored_pred = color_results(prediction + 1, palette)

    outfile = os.path.join(opts.outputs, opts.dataset_name, opts.model)
    os.makedirs(outfile, exist_ok=True)

    # Optional ground-truth export, currently disabled:
    # imageio.imsave(os.path.join(outfile, opts.dataset_name + '_gt.png'), colored_gt)  # eps or png
    # Use the folder name verbatim in the file name.  Passing a name such as
    # "7_0" through float() would read it as 70.0 (underscores are accepted
    # as digit separators) and mislabel the image; it would also disagree
    # with the CSV name below.
    imageio.imsave(os.path.join(outfile, opts.dataset_name + '_' + opts.model + "_" + str(run_dir) + '_out.png'), colored_pred)  # or png

    # Flatten to one row per pixel for the CSV export.
    # NOTE(review): reshape(-1, 2) assumes the model outputs exactly two
    # class probabilities per pixel -- confirm num_classes == 2.
    prod = probabilities.reshape(-1, 2)
    gtt = gt.reshape(-1) + 1
    predd = prediction.reshape(-1) + 1
    df = pd.DataFrame({
        'gt': gtt,
        'prob_akk': prod[:, 0],  # first probability column
        'prob_back': prod[:, 1],  # second probability column
        'pred': predd
    })
    df.to_csv(os.path.join(outfile, opts.dataset_name + '_' + opts.model + "_" + str(run_dir) + "_result.csv"), index=False)

    show_results(run_results, label_values=labels)
    del model
