In [56]:
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import logging
import os
import time
from einops import rearrange
import torch
from torchsummary import summary
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR

from utils import Metric,get_model_size,test_speed, set_logger,init_weights,set_seed
from models.KANFormer import KANFormer
from utils import get_model_size
from data import ChikuseiDataset

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Data
import rasterio
import spectral

# Specify the path to the ENVI data file and the file with .hdr
file = '/mnt/c/data/chikusei_ENVI/HyperspecVNIR_Chikusei_20140729.bsq'
header_file = '/mnt/c/data/chikusei_ENVI/HyperspecVNIR_Chikusei_20140729.hdr'

# Open the ENVI image using rasterio
with rasterio.open(file) as src:
    # Read the hyperspectral data into a NumPy array
    print("Reading HS data")
    hyperspectral_data = src.read()

    # Display information about the hyperspectral data
    print('Shape of hyperspectral data:', hyperspectral_data.shape)
    print('Number of bands:', src.count)

Reading HS data
Shape of hyperspectral data: (128, 2517, 2335)
Number of bands: 128


In [3]:
#Open the image with spectral
header_spectral = spectral.open_image(header_file)

# Access the wavelengths associated with each band
w_vector = np.array(header_spectral.bands.centers)

# Display information about the hyperspectral data and wavelengths
print('Shape of hyperspectral data:', header_spectral.shape)
print('Number of bands:', header_spectral.shape[2])
print('Wavelengths:', w_vector)

Shape of hyperspectral data: (2517, 2335, 128)
Number of bands: 128
Wavelengths: [0.36259 0.36775 0.3729  0.37807 0.38323 0.38839 0.39355 0.39871 0.40387
 0.40903 0.41419 0.41936 0.42452 0.42968 0.43484 0.44    0.44516 0.45032
 0.45548 0.46064 0.4658  0.47096 0.47612 0.48129 0.48645 0.49161 0.49677
 0.50193 0.50709 0.51225 0.51741 0.52257 0.52773 0.53289 0.53806 0.54321
 0.54838 0.55354 0.5587  0.56386 0.56902 0.57418 0.57934 0.5845  0.58966
 0.59483 0.59999 0.60514 0.61031 0.61547 0.62063 0.62579 0.63095 0.63611
 0.64127 0.64643 0.65159 0.65675 0.66192 0.66707 0.67224 0.6774  0.68256
 0.68772 0.69288 0.69804 0.7032  0.70836 0.71352 0.71868 0.72385 0.72901
 0.73417 0.73933 0.74449 0.74965 0.75481 0.75997 0.76513 0.77029 0.77545
 0.78061 0.78578 0.79094 0.7961  0.80126 0.80642 0.81158 0.81674 0.8219
 0.82706 0.83223 0.83738 0.84254 0.84771 0.85287 0.85803 0.86319 0.86835
 0.87351 0.87867 0.88383 0.88899 0.89416 0.89931 0.90448 0.90964 0.9148
 0.91996 0.92512 0.93028 0.93544 0.9406  0.94

In [59]:
full_image = rearrange(hyperspectral_data,'c h w -> h w c')
chikusei_data = ChikuseiDataset(full_image=full_image,training_zone=[128,128,1024,1024],wave_vector=w_vector,scale=4,gt_size=64)

In [62]:
for idx,dataload in enumerate(chikusei_data):
    print(dataload[1])

tensor([[[ 105.,  319.,  454.,  ..., 1324., 1261., 1310.],
         [1304., 1318., 1387.,  ..., 1347., 1315., 1255.],
         [1264., 1275., 1293.,  ..., 1448., 1446., 1459.],
         ...,
         [ 752.,  744.,  748.,  ...,  743.,  755.,  769.],
         [ 732.,  757.,  759.,  ...,  744.,  739.,  732.],
         [ 716.,  699.,  681.,  ...,  440.,  433.,  428.]],

        [[  72.,  178.,  210.,  ...,  387.,  374.,  404.],
         [ 402.,  405.,  437.,  ...,  465.,  467.,  466.],
         [ 486.,  508.,  531.,  ...,  576.,  576.,  582.],
         ...,
         [1180., 1155., 1174.,  ..., 1134., 1127., 1137.],
         [1105., 1121., 1122.,  ..., 1096., 1087., 1077.],
         [1052., 1027.,  998.,  ...,  682.,  683.,  690.]],

        [[  54.,  211.,  284.,  ...,  499.,  472.,  495.],
         [ 495.,  505.,  540.,  ...,  581.,  571.,  556.],
         [ 569.,  584.,  601.,  ...,  698.,  699.,  710.],
         ...,
         [1943., 1916., 1927.,  ..., 1893., 1910., 1941.],
         [

In [47]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# torch.cuda.empty_cache()

# Model
HSI_bands = full_image.shape[2]
MSI_bands = 4
chikusei_KAN = KANFormer(HSI_bands=HSI_bands,MSI_bands=MSI_bands,hidden_dim=256,scale=4,depth=4,image_size=64)
chikusei_KAN = chikusei_KAN.to(device)

# Training params
start_epoch = 0
epochs = 1000
batch_size = 1
lr = 4e-4
loss_func = torch.nn.L1Loss()
optimizer = torch.optim.Adam(lr=lr,params=chikusei_KAN.parameters())
scheduler = StepLR(optimizer=optimizer,step_size=100,gamma=0.1)

full_image = rearrange(hyperspectral_data,'c h w -> h w c')
chikusei_data = ChikuseiDataset(full_image=full_image,training_zone=[128,128,1024,1024],wave_vector=w_vector,scale=4,gt_size=64)
train_dataloader = DataLoader(chikusei_data,batch_size=batch_size,drop_last=True,shuffle=True)

model_name = 'KANFormer'
scale = 4
dataset = 'chikusei'

# Logs
# inference_time,flops,params = test_speed(chikusei_KAN,device,HSI_bands,scale=scale,channels=MSI_bands)
# now = str(datetime.now().replace(minute=0,second=0,microsecond=0)) # current date and time
# log_dir = f'./trained_models/{model_name}_x{scale}_{dataset},{now}'
# log_out = 1
# if not os.path.exists(log_dir) and log_out == 1:
#         os.mkdir(log_dir)
# logger = set_logger(model_name, log_dir, log_out)
# model_size = get_model_size(chikusei_KAN)
# logger.info(f'[model:{model_name}_x{scale},dataset:{dataset}],model_size:{params}M,inference_time:{inference_time:.6f}S,FLOPs:{flops}G')


4
4


In [44]:
chikusei_KAN.eval()
random_lr_hs = np.random.random((1,128,16,16))
random_hr_ms = np.random.random((1,4,64,64))
random_lr_hs_torch = torch.tensor(random_lr_hs, device=device, dtype=torch.float32)
random_hr_ms_torch = torch.tensor(random_hr_ms, device=device, dtype=torch.float32)


In [45]:
output = chikusei_KAN(random_lr_hs_torch,random_hr_ms_torch)

Forward pass


In [46]:
output.cpu().detach().numpy()

array([[[[0.9836007 , 0.9027434 , 0.8181177 , ..., 0.3975624 ,
          0.30005556, 0.22465308],
         [0.77070993, 0.7391971 , 0.72735035, ..., 0.4323566 ,
          0.32076403, 0.2180154 ],
         [0.5122876 , 0.5584173 , 0.63046765, ..., 0.4946113 ,
          0.3462749 , 0.23693904],
         ...,
         [0.6343538 , 0.66093355, 0.69285804, ..., 0.3999169 ,
          0.35734385, 0.3281936 ],
         [0.5770959 , 0.6115285 , 0.6665974 , ..., 0.43915442,
          0.4184043 , 0.4026086 ],
         [0.5338577 , 0.5836755 , 0.6443608 , ..., 0.44038844,
          0.4529551 , 0.45796743]],

        [[0.76427495, 0.6543199 , 0.5001998 , ..., 0.840094  ,
          0.73968923, 0.68453735],
         [0.6525148 , 0.5574255 , 0.45907897, ..., 0.7669441 ,
          0.6737646 , 0.6057033 ],
         [0.49803022, 0.46980253, 0.4323375 , ..., 0.7270578 ,
          0.5963886 , 0.50608927],
         ...,
         [0.8418702 , 0.73337084, 0.5909009 , ..., 0.54738   ,
          0.4857471 , 0.4

In [23]:
print(torch.cuda.get_device_properties(0).total_memory)
print(torch.cuda.max_memory_allocated())
print(torch.cuda.memory_allocated()) # Memory allocated
print(torch.cuda.memory_reserved())  # Total memory reserved in the CUDA pool
print(torch.cuda.memory_summary(device=0, abbreviated=False))

8585216000
944743936
48275456
1405091840
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  47144 KiB |    900 MiB |  12293 GiB |  12293 GiB |
|       from large pool |  40576 KiB |    876 MiB |  12208 GiB |  12208 GiB |
|       from small pool |   6568 KiB |     31 MiB |     84 GiB |     84 GiB |
|---------------------------------------------------------------------------|
| Active memory         |  47144 KiB |    900 MiB |  12293 GiB |  12293 GiB |
|       from large pool |  40576 KiB |    876 MiB |  12208 GiB |  12208 GiB |
|       from small pool |   6568 KiB |     31 MiB |     84 GiB |     84 GiB |
|----------------------

In [49]:
def train(epochs: int,model: torch.nn.Module):
    model.train()
    hist_loss = []
    for epoch in range(epochs):
        loss_list = []
        start_time = time.time()
        print("Epoch: ", epoch)
        for idx,loader_data in enumerate(train_dataloader):
            GT,LRHSI,HRMSI = loader_data[0].to(device),loader_data[1].to(device),loader_data[2].to(device)
            preHSI = chikusei_KAN(LRHSI,HRMSI)
            print(preHSI)
            loss = loss_func(GT,preHSI) #+chikusei_KAN.regularization_loss()
            print("Loss: ", loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())
        scheduler.step()
        print(loss_list)
        logging.info(f'Epoch:{epoch},loss:{np.mean(loss_list)},time:{time.time()-start_time:.2f}s')
        


In [50]:
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = 
train(1,chikusei_KAN)

Epoch:  0
Forward pass
tensor([[[[ 3.5775e+26, -1.9814e+26,  5.3473e+26,  ...,  4.0954e+26,
           -2.3707e+26, -6.3795e+26],
          [ 7.6291e+26, -2.4323e+26,  2.5767e+26,  ...,  3.5932e+26,
           -2.0919e+25, -1.4076e+27],
          [ 7.0899e+26,  1.9403e+26,  1.7866e+26,  ...,  6.2443e+26,
            1.3337e+26, -1.6107e+27],
          ...,
          [ 4.5967e+26,  2.5596e+26, -3.1017e+26,  ...,  1.0145e+27,
           -1.9279e+26, -2.4729e+27],
          [ 3.1127e+26,  3.0762e+26,  1.0939e+26,  ...,  5.1516e+26,
            4.2523e+26, -2.2931e+27],
          [-1.4302e+26,  5.1817e+26, -3.7587e+26,  ..., -7.1815e+26,
           -7.1957e+26, -1.9399e+27]],

         [[-1.4410e+27, -7.6992e+26, -1.0444e+27,  ..., -1.4627e+27,
           -1.4958e+27, -3.4951e+26],
          [-1.4639e+27, -1.0547e+27, -8.4259e+26,  ..., -1.4351e+27,
           -1.2180e+27, -1.2327e+27],
          [-1.2994e+27, -1.4384e+27, -1.1326e+27,  ..., -1.6789e+27,
           -1.3914e+27, -1.4774e+27