In [1]:
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data
from torchsummaryX import summary

from utils.dataset import load_mat_hsi, sample_gt, HSIDataset
from utils.utils import split_info_print, metrics, show_results
from utils.scheduler import load_scheduler
from models.get_model import get_model
from train import train, test

In [2]:
import random
import time

import torch.optim as optim
from fvcore.nn import FlopCountAnalysis, parameter_count
# NOTE: this rebinds `summary`, shadowing torchsummaryX.summary imported in the
# first cell; the experiment cell calls torchsummary's `summary(model, input_size=...)`.
from torchsummary import summary

In [3]:
if __name__ == "__main__":
    # Patch-based hyperspectral image (HSI) classification experiment:
    # for each run, split the labelled pixels into train/val/test, train the
    # selected model, evaluate it on the test pixels, and report accuracy
    # metrics together with FLOPs, parameter count and wall-clock timings.
    parser = argparse.ArgumentParser(description="run patch-based HSI classification")
    parser.add_argument("--model", type=str, default='cnn3d')
    parser.add_argument("--dataset_name", type=str, default="sa")
    parser.add_argument("--dataset_dir", type=str, default="./datasets")
    parser.add_argument("--device", type=str, default="0")
    parser.add_argument("--patch_size", type=int, default=12)
    parser.add_argument("--num_run", type=int, default=1)
    parser.add_argument("--epoch", type=int, default=50)
    parser.add_argument("--bs", type=int, default=32)
    parser.add_argument("--ratio", type=float, default=0.03)

    # parse_known_args (not parse_args) so extra argv injected by the Jupyter
    # kernel (e.g. "-f <connection file>") does not abort the cell.
    args, unknown = parser.parse_known_args()
    device = torch.device("cuda:{}".format(args.device) if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        # Make the requested GPU the *current* CUDA device so helpers that allocate
        # on the default device (torchsummary's dummy input) land next to the model.
        torch.cuda.set_device(device)

    def _sync():
        """Wait for queued CUDA work so wall-clock timings cover the real compute."""
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    # Print parameters
    if device.type == "cuda":
        print(f"experiments will run on GPU device {args.device}")
    else:
        print("experiments will run on CPU")
    print(f"model = {args.model}")
    print(f"dataset = {args.dataset_name}")
    print(f"dataset folder = {args.dataset_dir}")
    print(f"patch size = {args.patch_size}")
    print(f"batch size = {args.bs}")
    print(f"total epoch = {args.epoch}")
    # `ratio` of the labelled pixels is sampled for train+val and then halved
    # between them (see the two sample_gt calls below); the rest is the test set.
    print(f"{args.ratio / 2} for training, {args.ratio / 2} for validation and {1 - args.ratio} for testing")

    # Load data
    image, gt, labels = load_mat_hsi(args.dataset_name, args.dataset_dir)
    num_classes = len(labels)
    num_bands = image.shape[-1]  # spectral bands are the last axis of `image`

    # One seed per run: 202201, 202202, ... The first five match the original
    # hard-coded list, and --num_run is no longer capped at five.
    seeds = [202201 + i for i in range(args.num_run)]

    # Per-run metric dicts, aggregated at the end when num_run > 1
    results = []

    for run, seed in enumerate(seeds):
        # Seed every RNG the run touches: numpy, Python's `random`, and torch
        # (weight init, dropout, DataLoader shuffling) so runs are repeatable.
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)
        print(f"running an experiment with the {args.model} model")
        print(f"run {run + 1} / {args.num_run}")

        trainval_gt, test_gt = sample_gt(gt, args.ratio, seed)
        train_gt, val_gt = sample_gt(trainval_gt, 0.5, seed)

        del trainval_gt

        train_set = HSIDataset(image, train_gt, patch_size=args.patch_size, data_aug=True)
        val_set = HSIDataset(image, val_gt, patch_size=args.patch_size, data_aug=False)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.bs, drop_last=False, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.bs, drop_last=False, shuffle=False)

        model = get_model(args.model, args.dataset_name, args.patch_size).to(device)
        print(model)

        if run == 0:
            split_info_print(train_gt, val_gt, test_gt, labels)
            print("network information:")
            # torchsummary prepends its own batch dimension to `input_size`;
            # `device` must be "cuda" or "cpu" to match where the model lives.
            summary(model, input_size=(1, num_bands, args.patch_size, args.patch_size), device=device.type)

        optimizer, scheduler = load_scheduler(args.model, model)
        criterion = nn.CrossEntropyLoss()

        # Where to save checkpoint model
        model_dir = f"./checkpoints/{args.model}/{args.dataset_name}/{run}"

        # Training (Ctrl+C stops training early but still evaluates the checkpoint)
        _sync()
        start_time = time.perf_counter()
        try:
            train(model, optimizer, criterion, train_loader, val_loader, args.epoch, model_dir, device, scheduler)
        except KeyboardInterrupt:
            print('"ctrl+c" is pressed, the training is over')
        _sync()
        training_time = time.perf_counter() - start_time

        # Testing
        start_time = time.perf_counter()
        probabilities = test(model, model_dir, image, args.patch_size, num_classes, device)
        _sync()
        testing_time = time.perf_counter() - start_time

        prediction = np.argmax(probabilities, axis=-1)
        run_results = metrics(prediction, test_gt, n_classes=num_classes)
        results.append(run_results)
        show_results(run_results, label_values=labels)

        del train_set, train_loader, val_set, val_loader

        # Calculate FLOPs and number of parameters on a single-patch input.
        # Note: fvcore counts one fused multiply-add as one "flop".
        dummy_input = torch.randn(1, 1, num_bands, args.patch_size, args.patch_size, device=device)
        total_flops = FlopCountAnalysis(model, dummy_input).total()
        total_params = parameter_count(model)['']  # '' key holds the model-wide total

        print(f"FLOPs: {total_flops}")
        print(f"Parameters: {total_params}")
        print(f"Training time: {training_time:.2f} seconds")
        print(f"Testing time: {testing_time:.2f} seconds")

        # Store additional metrics in results (run_results is the same dict
        # object already appended to `results`, so the aggregate sees them too)
        run_results["FLOPs"] = total_flops
        run_results["Parameters"] = total_params
        run_results["Training time"] = training_time
        run_results["Testing time"] = testing_time

    if args.num_run > 1:
        show_results(results, label_values=labels, aggregated=True)

experiments will run on GPU device 0
model = cnn3d
dataset = sa
dataset folder = ./datasets
patch size = 12
batch size = 32
total epoch = 50
0.015 for training, 0.015 for validation and 0.97 for testing
running an experiment with the cnn3d model
run 1 / 1
CNN3D(
  (conv1): Conv3d(1, 20, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (pool1): Conv3d(20, 20, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
  (conv2): Conv3d(20, 35, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 0, 0))
  (pool2): Conv3d(35, 35, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
  (conv3): Conv3d(35, 35, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0))
  (conv4): Conv3d(35, 35, kernel_size=(2, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=58240, out_features=16, bias=True)
)
class train val test
Brocoli_green_weeds_1 30 30 1949
Brocoli_green_weeds_2 56 56 3614
Fallow 30 29 1917
Fallow_rough_plow 21 21 1352


training the network:   0%|                              | 0/50 [00:01<?, ?it/s]

train at epoch 1/50, loss=2.256667


training the network:   2%|▍                     | 1/50 [00:02<02:00,  2.45s/it]

epoch = 1: best validation OA = 0.3719


training the network:   4%|▉                     | 2/50 [00:04<01:52,  2.35s/it]

epoch = 2: best validation OA = 0.6182


training the network:   6%|█▎                    | 3/50 [00:06<01:48,  2.31s/it]

epoch = 3: best validation OA = 0.6429


training the network:   8%|█▊                    | 4/50 [00:09<01:45,  2.29s/it]

epoch = 4: best validation OA = 0.6958


training the network:  10%|██▏                   | 5/50 [00:11<01:42,  2.28s/it]

epoch = 5: best validation OA = 0.7512


training the network:  12%|██▋                   | 6/50 [00:13<01:39,  2.27s/it]

epoch = 6: best validation OA = 0.8300


training the network:  18%|███▉                  | 9/50 [00:20<01:37,  2.37s/it]

epoch = 9: best validation OA = 0.8559


training the network:  18%|███▉                  | 9/50 [00:22<01:37,  2.37s/it]

train at epoch 10/50, loss=0.443547


training the network:  20%|████▏                | 10/50 [00:23<01:38,  2.47s/it]

epoch = 10: best validation OA = 0.8793


training the network:  28%|█████▉               | 14/50 [00:34<01:37,  2.70s/it]

epoch = 14: best validation OA = 0.9027


training the network:  34%|███████▏             | 17/50 [00:43<01:30,  2.75s/it]

epoch = 17: best validation OA = 0.9138


training the network:  38%|███████▉             | 19/50 [00:50<01:25,  2.76s/it]

train at epoch 20/50, loss=0.349699


training the network:  40%|████████▍            | 20/50 [00:51<01:23,  2.77s/it]

epoch = 20: best validation OA = 0.9236


training the network:  42%|████████▊            | 21/50 [00:54<01:20,  2.77s/it]

epoch = 21: best validation OA = 0.9261


training the network:  46%|█████████▋           | 23/50 [00:59<01:14,  2.78s/it]

epoch = 23: best validation OA = 0.9397


training the network:  58%|████████████▏        | 29/50 [01:18<00:58,  2.77s/it]

train at epoch 30/50, loss=0.235512


training the network:  62%|█████████████        | 31/50 [01:21<00:52,  2.78s/it]

epoch = 31: best validation OA = 0.9446


training the network:  78%|████████████████▍    | 39/50 [01:45<00:30,  2.75s/it]

train at epoch 40/50, loss=0.221279


training the network:  80%|████████████████▊    | 40/50 [01:46<00:27,  2.75s/it]

epoch = 40: best validation OA = 0.9495


training the network:  98%|████████████████████▌| 49/50 [02:11<00:02,  2.75s/it]

epoch = 49: best validation OA = 0.9532


training the network:  98%|████████████████████▌| 49/50 [02:13<00:02,  2.75s/it]

train at epoch 50/50, loss=0.169025


training the network: 100%|█████████████████████| 50/50 [02:14<00:00,  2.68s/it]
inference on the HSI: 3495it [01:10, 49.72it/s]                                 


Confusion matrix :
[[1943    6    0    0    0    0    0    0    0    0    0    0    0    0
     0    0]
 [   0 3614    0    0    0    0    0    0    0    0    0    0    0    0
     0    0]
 [   0    0 1764    0   19    0    0    0   26    0   93   15    0    0
     0    0]
 [   0    0    0 1341   11    0    0    0    0    0    0    0    0    0
     0    0]
 [   0    0    5   15 2568    0    0    0    5    1    4    0    0    0
     0    0]
 [   0    0    0    0    0 3840    0    0    0    0    0    0    0    0
     0    0]
 [   0    0    0    0    0    1 3471    0    0    0    0    0    0    0
     0    0]
 [   0    0    0    0    0    1    0 9488    0   24    6    0    0    0
  1414    0]
 [   0    0    0    0    0    0    0    0 5986    0   31    0    0    0
     0    0]
 [   0    0    2    0    0    0    0   39   13 3012   86   21    0    7
     0    0]
 [   0    0    0    0    0    0    0    0    0    0 1029    7    0    0
     0    0]
 [   0    0    0    0    0    0    0    0    0