In [1]:
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
from torchsummaryX import summary

from utils.dataset import load_mat_hsi, sample_gt, HSIDataset
from utils.utils import split_info_print, metrics, show_results
from utils.scheduler import load_scheduler
from models.get_model import get_model
from train import train, test

In [2]:
import torch.optim as optim
from torchsummary import summary
from fvcore.nn import FlopCountAnalysis, parameter_count
import time

In [3]:
# Fixed per-run seeds of the original protocol; run k (0-based) uses entry k.
BASE_SEEDS = (202201, 202202, 202203, 202204, 202205)


def _seed_for_run(run, base_seeds=BASE_SEEDS):
    """Return the random seed for experiment ``run`` (0-based).

    The first ``len(base_seeds)`` runs use the fixed seeds so earlier results
    stay reproducible. Later runs continue the sequence (last seed + 1, + 2, ...)
    instead of raising ``IndexError`` when ``--num_run`` exceeds the list.
    """
    if run < len(base_seeds):
        return base_seeds[run]
    return base_seeds[-1] + (run - len(base_seeds) + 1)


def _seed_everything(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch's RNGs with ``seed``."""
    import random

    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the CPU generator and every CUDA device; it also
    # drives DataLoader shuffling and default weight initialisation.
    # NOTE(review): bit-exact GPU reproducibility would additionally need
    # deterministic cuDNN settings, which are deliberately not forced here.
    torch.manual_seed(seed)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="run patch-based HSI classification")
    parser.add_argument("--model", type=str, default='cnn3d')
    parser.add_argument("--dataset_name", type=str, default="pu")
    parser.add_argument("--dataset_dir", type=str, default="./datasets")
    parser.add_argument("--device", type=str, default="0")
    parser.add_argument("--patch_size", type=int, default=12)
    parser.add_argument("--num_run", type=int, default=1)
    parser.add_argument("--epoch", type=int, default=50)
    parser.add_argument("--bs", type=int, default=32)
    parser.add_argument("--ratio", type=float, default=0.03)

    # parse_known_args tolerates unrecognised arguments, e.g. the ones the
    # Jupyter kernel receives on its own command line.
    args, unknown = parser.parse_known_args()
    device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")

    # Print parameters (report the device actually used, which may be the CPU fallback)
    print(f"experiments will run on device {device}")
    print(f"model = {args.model}")
    print(f"dataset = {args.dataset_name}")
    print(f"dataset folder = {args.dataset_dir}")
    print(f"patch size = {args.patch_size}")
    print(f"batch size = {args.bs}")
    print(f"total epoch = {args.epoch}")
    print(f"{args.ratio / 2} for training, {args.ratio / 2} for validation and {1 - args.ratio} for testing")

    # Load data
    image, gt, labels = load_mat_hsi(args.dataset_name, args.dataset_dir)
    num_classes = len(labels)
    num_bands = image.shape[-1]

    # Single-sample input used only for FLOP counting. Its shape is the same for
    # every run; values are irrelevant and each run reseeds, so creating it here
    # does not affect reproducibility.
    dummy_input = torch.randn(1, 1, num_bands, args.patch_size, args.patch_size).to(device)

    # One metrics entry per run
    results = []

    for run in range(args.num_run):
        seed = _seed_for_run(run)
        _seed_everything(seed)
        print(f"running an experiment with the {args.model} model")
        print(f"run {run + 1} / {args.num_run}")

        # args.ratio of the labelled pixels is split evenly into train and val;
        # the rest is the test set.
        trainval_gt, test_gt = sample_gt(gt, args.ratio, seed)
        train_gt, val_gt = sample_gt(trainval_gt, 0.5, seed)

        del trainval_gt

        train_set = HSIDataset(image, train_gt, patch_size=args.patch_size, data_aug=True)
        val_set = HSIDataset(image, val_gt, patch_size=args.patch_size, data_aug=False)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.bs, drop_last=False, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.bs, drop_last=False, shuffle=False)

        model = get_model(args.model, args.dataset_name, args.patch_size).to(device)
        print(model)

        if run == 0:
            split_info_print(train_gt, val_gt, test_gt, labels)
            print("network information:")
            # NOTE: cell 2 re-imports `summary` from torchsummary, shadowing the
            # torchsummaryX import of cell 1; torchsummary's input_size omits the
            # batch dimension.
            summary(model, input_size=(1, num_bands, args.patch_size, args.patch_size))

        optimizer, scheduler = load_scheduler(args.model, model)
        criterion = nn.CrossEntropyLoss()

        # Where to save checkpoint model
        model_dir = f"./checkpoints/{args.model}/{args.dataset_name}/{run}"

        # Training (Ctrl+C stops training early but still evaluates the saved checkpoint)
        start_time = time.time()
        try:
            train(model, optimizer, criterion, train_loader, val_loader, args.epoch, model_dir, device, scheduler)
        except KeyboardInterrupt:
            print('"ctrl+c" is pressed, the training is over')
        training_time = time.time() - start_time

        # Testing
        start_time = time.time()
        probabilities = test(model, model_dir, image, args.patch_size, num_classes, device)
        testing_time = time.time() - start_time

        prediction = np.argmax(probabilities, axis=-1)
        run_results = metrics(prediction, test_gt, n_classes=num_classes)
        results.append(run_results)
        show_results(run_results, label_values=labels)

        del train_set, train_loader, val_set, val_loader

        # Calculate FLOPs and number of parameters (each computed once)
        flops_total = FlopCountAnalysis(model, dummy_input).total()
        n_params = parameter_count(model)[""]  # "" key holds the whole-model count

        print(f"FLOPs: {flops_total}")
        print(f"Parameters: {n_params}")
        print(f"Training time: {training_time:.2f} seconds")
        print(f"Testing time: {testing_time:.2f} seconds")

        # Store additional metrics in results (same object already in `results`)
        run_results["FLOPs"] = flops_total
        run_results["Parameters"] = n_params
        run_results["Training time"] = training_time
        run_results["Testing time"] = testing_time

    if args.num_run > 1:
        show_results(results, label_values=labels, aggregated=True)

experiments will run on GPU device 0
model = cnn3d
dataset = pu
dataset folder = ./datasets
patch size = 12
batch size = 32
total epoch = 50
0.015 for training, 0.015 for validation and 0.97 for testing
running an experiment with the cnn3d model
run 1 / 1
(610, 340)
CNN3D(
  (conv1): Conv3d(1, 20, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (pool1): Conv3d(20, 20, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
  (conv2): Conv3d(20, 35, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 0, 0))
  (pool2): Conv3d(35, 35, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
  (conv3): Conv3d(35, 35, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0))
  (conv4): Conv3d(35, 35, kernel_size=(2, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=31360, out_features=9, bias=True)
)
class train val test
Asphalt 99 100 6432
Meadows 279 280 18090
Gravel 32 31 2036
Trees 46 46 2972
Painted metal sheets 20 20

training the network:   0%|                              | 0/50 [00:00<?, ?it/s]

train at epoch 1/50, loss=1.785324


training the network:   2%|▍                     | 1/50 [00:01<00:57,  1.17s/it]

epoch = 1: best validation OA = 0.4626


training the network:   4%|▉                     | 2/50 [00:02<00:50,  1.05s/it]

epoch = 2: best validation OA = 0.6822


training the network:   6%|█▎                    | 3/50 [00:03<00:47,  1.02s/it]

epoch = 3: best validation OA = 0.7274


training the network:   8%|█▊                    | 4/50 [00:04<00:45,  1.01it/s]

epoch = 4: best validation OA = 0.7617


training the network:  12%|██▋                   | 6/50 [00:06<00:43,  1.02it/s]

epoch = 6: best validation OA = 0.7850


training the network:  14%|███                   | 7/50 [00:06<00:41,  1.03it/s]

epoch = 7: best validation OA = 0.8224


training the network:  16%|███▌                  | 8/50 [00:07<00:40,  1.03it/s]

epoch = 8: best validation OA = 0.8427


training the network:  18%|███▉                  | 9/50 [00:09<00:39,  1.04it/s]

train at epoch 10/50, loss=0.458903


training the network:  20%|████▏                | 10/50 [00:09<00:38,  1.04it/s]

epoch = 10: best validation OA = 0.8723


training the network:  24%|█████                | 12/50 [00:11<00:36,  1.04it/s]

epoch = 12: best validation OA = 0.8723


training the network:  30%|██████▎              | 15/50 [00:14<00:33,  1.04it/s]

epoch = 15: best validation OA = 0.8754


training the network:  38%|███████▉             | 19/50 [00:18<00:30,  1.01it/s]

epoch = 19: best validation OA = 0.9097


training the network:  38%|███████▉             | 19/50 [00:19<00:30,  1.01it/s]

train at epoch 20/50, loss=0.325553


training the network:  48%|██████████           | 24/50 [00:24<00:28,  1.11s/it]

epoch = 24: best validation OA = 0.9206


training the network:  58%|████████████▏        | 29/50 [00:30<00:23,  1.14s/it]

train at epoch 30/50, loss=0.210695


training the network:  60%|████████████▌        | 30/50 [00:31<00:22,  1.15s/it]

epoch = 30: best validation OA = 0.9377


training the network:  78%|████████████████▍    | 39/50 [00:42<00:12,  1.17s/it]

train at epoch 40/50, loss=0.251197


training the network:  98%|████████████████████▌| 49/50 [00:54<00:01,  1.18s/it]

train at epoch 50/50, loss=0.236717


training the network: 100%|█████████████████████| 50/50 [00:54<00:00,  1.09s/it]
inference on the HSI: 6511it [01:08, 94.80it/s]                                 

Confusion matrix :
[[ 5995     0   165     0     0     0   114   157     1]
 [    0 17718     0    28     0   343     0     1     0]
 [  196     0  1340     0     0     0     0   500     0]
 [   26    22     0  2905     0     6     0     0    13]
 [    0     0     0     0  1305     0     0     0     0]
 [    0   781     0     2    12  4079     0     4     0]
 [  188     0    43     0     0     1  1056     2     0]
 [   87    23   277     1     0     8     0  3172     3]
 [    3     0     0     2     1     0     0     0   913]]---
Accuracy : 92.75%
---
class acc :
	Asphalt: 93.21
	Meadows: 97.94
	Gravel: 65.82
	Trees: 97.75
	Painted metal sheets: 100.00
	Bare Soil: 83.62
	Bitumen: 81.86
	Self-Blocking Bricks: 88.83
	Shadows: 99.35
---
AA: 89.82%
Kappa: 90.34

FLOPs: 87971440
Parameters: 312869
Training time: 54.59 seconds
Testing time: 68.75seconds



