In [1]:
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
from torchsummaryX import summary

from utils.dataset import load_mat_hsi, sample_gt, HSIDataset
from utils.utils import split_info_print, metrics, show_results
from utils.scheduler import load_scheduler
from models.get_model import get_model
from train import train, test

In [2]:
import torch.optim as optim
from torchsummary import summary
from fvcore.nn import FlopCountAnalysis, parameter_count
import time

In [3]:
if __name__ == "__main__":
    # Entry point of the experiment: parse options, load the hyperspectral
    # image, then for each run split the ground truth, train a model, evaluate
    # it on the held-out pixels and report accuracy, FLOPs, parameter count
    # and wall-clock timings.
    parser = argparse.ArgumentParser(description="run patch-based HSI classification")
    parser.add_argument("--model", type=str, default='cnn3d')
    parser.add_argument("--dataset_name", type=str, default="hsn")
    parser.add_argument("--dataset_dir", type=str, default="./datasets")
    parser.add_argument("--device", type=str, default="0")
    parser.add_argument("--patch_size", type=int, default=12)
    parser.add_argument("--num_run", type=int, default=1)
    parser.add_argument("--epoch", type=int, default=50)
    parser.add_argument("--bs", type=int, default=32)
    parser.add_argument("--ratio", type=float, default=0.1)

    # parse_known_args (instead of parse_args) tolerates the extra argv
    # entries that a notebook kernel injects.
    args, unknown = parser.parse_known_args()
    device = torch.device("cuda:{}".format(args.device) if torch.cuda.is_available() else "cpu")

    # Print parameters
    if device.type == "cuda":
        print(f"experiments will run on GPU device {args.device}")
    else:
        # Do not claim a GPU is used when we silently fell back to the CPU.
        print("experiments will run on CPU (CUDA is not available)")
    print(f"model = {args.model}")
    print(f"dataset = {args.dataset_name}")
    print(f"dataset folder = {args.dataset_dir}")
    print(f"patch size = {args.patch_size}")
    print(f"batch size = {args.bs}")
    print(f"total epoch = {args.epoch}")
    print(f"{args.ratio / 2} for training, {args.ratio / 2} for validation and {1 - args.ratio} for testing")

    # Load data
    image, gt, labels = load_mat_hsi(args.dataset_name, args.dataset_dir)
    num_classes = len(labels)
    # The spectral bands are taken from the last axis of the image array.
    num_bands = image.shape[-1]

    # Random seeds: one per run. The first five values (202201..202205) are
    # the ones that used to be hard-coded; deriving them from the run index
    # keeps those runs unchanged while allowing --num_run > 5 without an
    # IndexError.
    seeds = [202201 + run_idx for run_idx in range(args.num_run)]

    # Empty list to store results
    results = []

    for run in range(args.num_run):
        seed = seeds[run]
        # Seed both NumPy and PyTorch: NumPy alone leaves weight
        # initialisation, DataLoader shuffling and dropout unseeded.
        np.random.seed(seed)
        torch.manual_seed(seed)
        print(f"running an experiment with the {args.model} model")
        print(f"run {run + 1} / {args.num_run}")

        # First split off `ratio` of the labelled pixels, then halve that
        # subset into training and validation; the remainder is the test set.
        trainval_gt, test_gt = sample_gt(gt, args.ratio, seed)
        train_gt, val_gt = sample_gt(trainval_gt, 0.5, seed)

        del trainval_gt

        train_set = HSIDataset(image, train_gt, patch_size=args.patch_size, data_aug=True)
        val_set = HSIDataset(image, val_gt, patch_size=args.patch_size, data_aug=False)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.bs, drop_last=False, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.bs, drop_last=False, shuffle=False)

        model = get_model(args.model, args.dataset_name, args.patch_size).to(device)
        print(model)

        if run == 0:
            split_info_print(train_gt, val_gt, test_gt, labels)
            print("network information:")
            # Summary of the model.
            # NOTE(review): `summary` is imported from both torchsummaryX and
            # torchsummary in earlier cells; the later import (torchsummary)
            # is the one bound here, and the `input_size=` keyword matches
            # that package's signature. Consider dropping the unused import.
            summary(model, input_size=(1, num_bands, args.patch_size, args.patch_size))

        optimizer, scheduler = load_scheduler(args.model, model)
        criterion = nn.CrossEntropyLoss()

        # Where to save checkpoint model
        model_dir = f"./checkpoints/{args.model}/{args.dataset_name}/{run}"

        # Training (an interrupt stops training early but evaluation below
        # still runs).
        start_time = time.time()
        try:
            train(model, optimizer, criterion, train_loader, val_loader, args.epoch, model_dir, device, scheduler)
        except KeyboardInterrupt:
            print('"ctrl+c" is pressed, the training is over')
        training_time = time.time() - start_time

        # Testing: inference over the whole image, timed separately.
        start_time = time.time()
        probabilities = test(model, model_dir, image, args.patch_size, num_classes, device)
        testing_time = time.time() - start_time

        # Class with the highest score along the last axis.
        prediction = np.argmax(probabilities, axis=-1)
        run_results = metrics(prediction, test_gt, n_classes=num_classes)
        results.append(run_results)
        show_results(run_results, label_values=labels)

        del train_set, train_loader, val_set, val_loader

        # Calculate FLOPs and number of parameters on a single dummy patch.
        # Presumably laid out as (batch, channel, bands, height, width) --
        # TODO confirm for models other than cnn3d.
        dummy_input = torch.randn(1, 1, num_bands, args.patch_size, args.patch_size).to(device)
        flops = FlopCountAnalysis(model, dummy_input)
        params = parameter_count(model)
        # Evaluate the FLOP count once and reuse it for printing and storing.
        total_flops = flops.total()
        # The '' key of fvcore's parameter_count result is the whole-model total.
        total_params = params['']

        print(f"FLOPs: {total_flops}")
        print(f"Parameters: {total_params}")
        print(f"Training time: {training_time:.2f} seconds")
        print(f"Testing time: {testing_time:.2f}seconds")

        # Store additional metrics in results. `run_results` is the same
        # object that was appended to `results` above, so these keys are
        # visible through the aggregated list as well.
        run_results["FLOPs"] = total_flops
        run_results["Parameters"] = total_params
        run_results["Training time"] = training_time
        run_results["Testing time"] = testing_time

    if args.num_run > 1:
        show_results(results, label_values=labels, aggregated=True)

experiments will run on GPU device 0
model = cnn3d
dataset = hsn
dataset folder = ./datasets
patch size = 12
batch size = 32
total epoch = 50
0.05 for training, 0.05 for validation and 0.9 for testing
running an experiment with the cnn3d model
run 1 / 1
CNN3D(
  (conv1): Conv3d(1, 20, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (pool1): Conv3d(20, 20, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
  (conv2): Conv3d(20, 35, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 0, 0))
  (pool2): Conv3d(35, 35, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
  (conv3): Conv3d(35, 35, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0))
  (conv4): Conv3d(35, 35, kernel_size=(2, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=42560, out_features=15, bias=True)
)
class train val test
Healthy grass 63 62 1126
Stressed grass 63 62 1129
Synthetic grass 35 35 627
Trees 62 62 1120
Soil 62 62 1118
Water

training the network:   0%|                              | 0/50 [00:01<?, ?it/s]

train at epoch 1/50, loss=2.417992


training the network:   2%|▍                     | 1/50 [00:01<01:22,  1.67s/it]

epoch = 1: best validation OA = 0.3968


training the network:   4%|▉                     | 2/50 [00:03<01:15,  1.58s/it]

epoch = 2: best validation OA = 0.4687


training the network:   6%|█▎                    | 3/50 [00:04<01:12,  1.55s/it]

epoch = 3: best validation OA = 0.4780


training the network:   8%|█▊                    | 4/50 [00:06<01:10,  1.53s/it]

epoch = 4: best validation OA = 0.5766


training the network:  10%|██▏                   | 5/50 [00:07<01:08,  1.53s/it]

epoch = 5: best validation OA = 0.6831


training the network:  12%|██▋                   | 6/50 [00:09<01:06,  1.52s/it]

epoch = 6: best validation OA = 0.7403


training the network:  18%|███▉                  | 9/50 [00:13<01:02,  1.53s/it]

epoch = 9: best validation OA = 0.7683


training the network:  18%|███▉                  | 9/50 [00:15<01:02,  1.53s/it]

train at epoch 10/50, loss=0.808479


training the network:  20%|████▏                | 10/50 [00:15<01:02,  1.57s/it]

epoch = 10: best validation OA = 0.7790


training the network:  22%|████▌                | 11/50 [00:17<01:02,  1.60s/it]

epoch = 11: best validation OA = 0.7883


training the network:  26%|█████▍               | 13/50 [00:20<01:03,  1.70s/it]

epoch = 13: best validation OA = 0.8415


training the network:  30%|██████▎              | 15/50 [00:24<01:02,  1.78s/it]

epoch = 15: best validation OA = 0.8535


training the network:  38%|███████▉             | 19/50 [00:31<00:56,  1.84s/it]

epoch = 19: best validation OA = 0.8615


training the network:  38%|███████▉             | 19/50 [00:33<00:56,  1.84s/it]

train at epoch 20/50, loss=0.546931


training the network:  42%|████████▊            | 21/50 [00:35<00:53,  1.85s/it]

epoch = 21: best validation OA = 0.8655


training the network:  50%|██████████▌          | 25/50 [00:43<00:46,  1.86s/it]

epoch = 25: best validation OA = 0.8802


training the network:  52%|██████████▉          | 26/50 [00:44<00:44,  1.86s/it]

epoch = 26: best validation OA = 0.9108


training the network:  58%|████████████▏        | 29/50 [00:51<00:38,  1.85s/it]

train at epoch 30/50, loss=0.450068


training the network:  78%|████████████████▍    | 39/50 [01:10<00:20,  1.86s/it]

train at epoch 40/50, loss=0.364301


training the network:  88%|██████████████████▍  | 44/50 [01:18<00:11,  1.87s/it]

epoch = 44: best validation OA = 0.9281


training the network:  98%|████████████████████▌| 49/50 [01:29<00:01,  1.85s/it]

train at epoch 50/50, loss=0.301371


training the network: 100%|█████████████████████| 50/50 [01:29<00:00,  1.79s/it]
inference on the HSI: 20847it [05:08, 67.49it/s]                                


Confusion matrix :
[[1094    0    0   32    0    0    0    0    0    0    0    0    0    0
     0]
 [  22 1104    0    3    0    0    0    0    0    0    0    0    0    0
     0]
 [   0    0  623    0    0    0    0    0    0    0    0    0    0    0
     4]
 [   1    3    0 1116    0    0    0    0    0    0    0    0    0    0
     0]
 [   0    0    0    0 1109    0    0    0    0    0    1    8    0    0
     0]
 [   4    0    0    0    0  248   25    0    0    7    8    0    0    0
     0]
 [   0    3    4    0    0    5 1019   78   15    0   17    0    0    0
     0]
 [   0    0    0    0    9    0   21  954    2   13   14   98    1    0
     8]
 [   0    8    0    0    0    2   10   43  985   12   28   39    0    0
     0]
 [   0    0    0    0    0    0    0   21   11  909   65   98    0    0
     0]
 [   0    2    2    0    0    0   25   13   33   24  974   23   13    3
     0]
 [   0    0    1    0    0    0   11    2    5    0    9 1082    0    0
     0]
 [   0    0    0    0