# Packages

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

import numpy as np
import scipy.io as sio
import os, glob

import csv
import matplotlib.pyplot as plt
import random

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from scipy.ndimage import sobel
import lpips
from sklearn.decomposition import PCA



#import matlab.engine
#eng = matlab.engine.start_matlab()



In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)


# Metrics

In [None]:
def _gradient_magnitude(image):
    """Return the Sobel gradient magnitude of a 2-D array."""
    grad_x = sobel(image, axis=1)
    grad_y = sobel(image, axis=0)
    return np.sqrt(grad_x**2 + grad_y**2)


def scc(sr, hr):
    """Spatial correlation coefficient between the edge maps of two images.

    Edges are Sobel gradient magnitudes. For 2-D inputs each band is the
    whole image; 3-D inputs are treated as (H, W, bands) and the gradient is
    taken per band.

    Args:
        sr: super-resolved image, 2-D (H, W) or 3-D (H, W, bands).
        hr: reference image with the same shape as ``sr``.

    Returns:
        ``(scc_value, scc_map)``:

        - ``scc_value``: the global normalised inner product of the two
          gradient-magnitude arrays (1.0 for identical, non-flat images).
        - ``scc_map``: for 2-D inputs, the element-wise products normalised
          by the global norms. For 3-D inputs, the per-pixel cosine
          similarity across bands, with shape (H, W).

        Where a normalising norm is zero (flat image or flat pixel), the
        result is 0 instead of the NaN produced by 0/0.

    Raises:
        ValueError: if the shapes differ or the inputs are not 2-D/3-D.
    """
    sr = np.asarray(sr, dtype=np.float64)
    hr = np.asarray(hr, dtype=np.float64)
    if sr.shape != hr.shape:
        raise ValueError(f"scc: shape mismatch {sr.shape} vs {hr.shape}")
    if sr.ndim not in (2, 3):
        raise ValueError(f"scc: expected a 2-D or 3-D array, got ndim={sr.ndim}")

    if sr.ndim == 2:
        sr_grad = _gradient_magnitude(sr)
        hr_grad = _gradient_magnitude(hr)
        map_num = sr_grad * hr_grad
        map_den = np.full(map_num.shape,
                          np.sqrt(np.sum(sr_grad**2)) * np.sqrt(np.sum(hr_grad**2)))
    else:
        # Gradients are computed spatially, band by band (last axis = bands).
        sr_grad = np.stack([_gradient_magnitude(sr[:, :, b]) for b in range(sr.shape[2])], axis=2)
        hr_grad = np.stack([_gradient_magnitude(hr[:, :, b]) for b in range(hr.shape[2])], axis=2)
        map_num = np.sum(sr_grad * hr_grad, axis=2)
        map_den = np.sqrt(np.sum(sr_grad**2, axis=2)) * np.sqrt(np.sum(hr_grad**2, axis=2))

    # Zero where the denominator is zero rather than emitting NaN.
    scc_map = np.zeros(map_num.shape, dtype=np.float64)
    np.divide(map_num, map_den, out=scc_map, where=map_den != 0)

    denom = np.sqrt(np.sum(sr_grad**2)) * np.sqrt(np.sum(hr_grad**2))
    scc_value = np.sum(sr_grad * hr_grad) / denom if denom > 0 else np.float64(0.0)

    return scc_value, scc_map



# Module-level LPIPS model (AlexNet backbone). metric_s5net passes it to
# calculate_lpips_bandwise. In the visible code it is never moved to `device`,
# so it stays on the default (CPU) device.
loss_fn = lpips.LPIPS(net='alex')
def calculate_lpips_bandwise(sr, hr, loss_fn):
    """Mean LPIPS distance over the spectral bands of two images.

    Each band is scored on its own as a single-channel ``(1, 1, H, W)``
    float32 tensor. The LPIPS package expects inputs scaled to [-1, 1];
    scaling is the caller's responsibility.

    Args:
        sr: super-resolved image, (H, W, bands) or 2-D (H, W). A 2-D input is
            treated as a single band.
        hr: reference image with the same shape as ``sr``.
        loss_fn: callable taking two ``(N, C, H, W)`` tensors and returning a
            one-element tensor, e.g. an ``lpips.LPIPS`` instance.

    Returns:
        The mean of the per-band scores.

    Raises:
        ValueError: on shape mismatch, wrong rank, or zero bands.
    """
    sr = np.asarray(sr)
    hr = np.asarray(hr)
    if sr.shape != hr.shape:
        raise ValueError(f"calculate_lpips_bandwise: shape mismatch {sr.shape} vs {hr.shape}")
    if sr.ndim == 2:
        sr = sr[:, :, np.newaxis]
        hr = hr[:, :, np.newaxis]
    if sr.ndim != 3 or sr.shape[2] == 0:
        raise ValueError(f"calculate_lpips_bandwise: expected (H, W, bands) with bands > 0, got {sr.shape}")

    lpips_bandwise = []
    # Evaluation only: skip building an autograd graph for every band.
    with torch.no_grad():
        for band in range(sr.shape[2]):
            sr_band = torch.from_numpy(np.ascontiguousarray(sr[:, :, band], dtype=np.float32))[None, None]
            hr_band = torch.from_numpy(np.ascontiguousarray(hr[:, :, band], dtype=np.float32))[None, None]
            lpips_bandwise.append(loss_fn(sr_band, hr_band).item())

    return np.mean(lpips_bandwise)


# Data

In [None]:

def _hr_path_for(lr_file):
    """Return the HR path paired with ``lr_file``: '_LR4' is removed from the file name only."""
    lr_base = os.path.basename(lr_file)
    return lr_file[:len(lr_file) - len(lr_base)] + lr_base.replace('_LR4', '')


def compute_global_metrics(lr_files, hr_files):
    """Pool the 'radiance' pixels of every LR/HR pair and summarise them.

    An LR file is paired with the HR path obtained by deleting '_LR4' from
    its file name (directory components are left alone). LR files without a
    pair in ``hr_files`` are skipped.

    Args:
        lr_files: paths to LR ``.mat`` files containing a ``'radiance'`` array.
        hr_files: paths to HR ``.mat`` files containing a ``'radiance'`` array.

    Returns:
        ``(all_intensity_values, global_min, global_max, global_mean,
        global_median, global_std)``. ``all_intensity_values`` is the 1-D
        concatenation of every paired LR then HR image.

    Raises:
        ValueError: if no LR file has a matching HR file.
    """
    hr_lookup = set(hr_files)  # O(1) membership instead of a list scan per file
    chunks = []

    for lr_file in lr_files:
        hr_file = _hr_path_for(lr_file)
        if hr_file not in hr_lookup:
            continue
        chunks.append(sio.loadmat(lr_file)['radiance'].ravel())
        chunks.append(sio.loadmat(hr_file)['radiance'].ravel())

    if not chunks:
        raise ValueError("compute_global_metrics: no LR file has a matching HR file")

    all_intensity_values = np.concatenate(chunks)

    global_min = np.min(all_intensity_values)
    global_max = np.max(all_intensity_values)
    global_mean = np.mean(all_intensity_values)
    global_median = np.median(all_intensity_values)
    global_std = np.std(all_intensity_values)

    return all_intensity_values, global_min, global_max, global_mean, global_median, global_std

def convert_normalise_meanSTD(image, global_mean, global_std):
    """Z-score an image with dataset-wide statistics and convert it to a tensor.

    Args:
        image: numpy array of raw intensities.
        global_mean: mean subtracted from every pixel.
        global_std: divisor applied after centring; must be non-zero.

    Returns:
        ``(tensor, raw_min, raw_max, raw_range, norm_min, norm_max)``:

        - ``tensor``: the normalised float32 ``torch.Tensor``.
        - ``raw_min``, ``raw_max``, ``raw_range``: statistics of the input
          before normalisation.
        - ``norm_min``, ``norm_max``: Python floats taken from the normalised
          tensor.

    Raises:
        ValueError: if ``global_std`` is zero, which would yield inf/NaN.
    """
    if global_std == 0:
        raise ValueError("convert_normalise_meanSTD: global_std must be non-zero")

    # Local names avoid shadowing the builtins min/max.
    raw_min, raw_max = image.min(), image.max()
    raw_range = raw_max - raw_min

    tensor = torch.tensor(image, dtype=torch.float32)
    tensor = (tensor - global_mean) / global_std

    return tensor, raw_min, raw_max, raw_range, tensor.min().item(), tensor.max().item()




In [None]:

def extract_patches(image, patch_size, stride=16):
    img_h, img_w, bands = image.shape  
    patch_h, patch_w = patch_size

    patches = []
    for i in range(0, img_h - patch_h + 1, stride):
        for j in range(0, img_w - patch_w + 1, stride):
            patch = image[i:i + patch_h, j:j + patch_w, :]  
            patches.append(patch)
    return np.array(patches)


In [None]:
def load_data_with_patches(data_path, patch_size, BAND, global_mean=None, global_std=None, scale=4, lr_stride=16):
    """Load a data split and cut every LR/HR image pair into aligned patches.

    LR patches are ``patch_size`` pixels, taken every ``lr_stride`` pixels.
    HR patches are ``scale`` times larger in each direction and taken every
    ``lr_stride * scale`` pixels. When an HR image is exactly ``scale`` times
    its LR image, patch k of the LR image therefore covers the same area as
    patch k of the HR image.

    Args:
        data_path: directory passed to ``load_normalise_data``.
        patch_size: ``(height, width)`` of the LR patches.
        BAND: file-name filter passed to ``load_normalise_data``.
        global_mean: normalisation mean; computed from this split when it or
            ``global_std`` is ``None``.
        global_std: normalisation std; computed from this split when it or
            ``global_mean`` is ``None``.
        scale: LR-to-HR upscaling factor (default 4, matching the ``_LR4`` files).
        lr_stride: LR patch stride in pixels (default 16).

    Returns:
        ``(lr_patches, hr_patches, global_mean, global_std)``. The patch arrays
        are stacked along axis 0 in (N, h, w, bands) layout, the layout
        extract_patches produces.

    Raises:
        ValueError: if an LR/HR pair yields different patch counts. Silently
            continuing would misalign every following (lr, hr) pair.
    """
    lr_data, hr_data, global_mean, global_std = load_normalise_data(data_path, BAND, global_mean, global_std)

    print(f'LR Data Shape: {lr_data.shape}' )
    print(f'HR Data Shape: {hr_data.shape}')

    hr_patch_size = (patch_size[0] * scale, patch_size[1] * scale)
    hr_stride = lr_stride * scale

    lr_patches = []
    hr_patches = []
    for idx, (lr_img, hr_img) in enumerate(zip(lr_data, hr_data)):
        # Both images are (H, W, bands), as extract_patches expects.
        lr_img_patches = extract_patches(lr_img, patch_size, stride=lr_stride)
        hr_img_patches = extract_patches(hr_img, hr_patch_size, stride=hr_stride)
        if len(lr_img_patches) != len(hr_img_patches):
            raise ValueError(
                f"Patch count mismatch for image {idx}: "
                f"{len(lr_img_patches)} LR vs {len(hr_img_patches)} HR patches "
                f"(LR {lr_img.shape}, HR {hr_img.shape}, scale {scale})"
            )
        lr_patches.extend(lr_img_patches)
        hr_patches.extend(hr_img_patches)

    lr_patches = np.array(lr_patches)
    hr_patches = np.array(hr_patches)

    print(f'LR Patch Shape: {lr_patches.shape}')
    print(f'HR Patch Shape: {hr_patches.shape}')

    return lr_patches, hr_patches, global_mean, global_std


In [None]:


def load_normalise_data(data_dir, BAND, global_mean=None, global_std=None):
    """Load paired LR/HR radiance cubes from ``data_dir`` and z-score them.

    LR files end with ``RADIANCE_cropped_hyper_LR4.mat`` and HR files end with
    ``RADIANCE_cropped_hyper.mat``. Both must contain ``BAND`` in their name.
    An LR file is paired with the HR file whose file name is the LR file name
    with '_LR4' removed. LR files without such a partner are skipped.

    Args:
        data_dir: directory holding the ``.mat`` files; each file stores a
            ``'radiance'`` array.
        BAND: substring used to select files, e.g. ``'BAND4'``.
        global_mean: normalisation mean. When it or ``global_std`` is
            ``None``, both are computed from this directory's LR and HR pixels.
        global_std: normalisation std. When it or ``global_mean`` is ``None``,
            both are computed from this directory's LR and HR pixels.

    Returns:
        ``(lr_data, hr_data, global_mean, global_std)``. ``lr_data`` and
        ``hr_data`` stack the normalised float32 images along axis 0. 2-D
        images get a trailing band axis, giving (H, W, 1). 3-D images are
        presumably stored (H, W, bands), the layout extract_patches assumes —
        confirm against the data.
    """
    file_names = os.listdir(data_dir)
    lr_files = sorted(os.path.join(data_dir, f) for f in file_names
                      if f.endswith('RADIANCE_cropped_hyper_LR4.mat') and BAND in f)
    hr_files = sorted(os.path.join(data_dir, f) for f in file_names
                      if f.endswith('RADIANCE_cropped_hyper.mat') and BAND in f)

    if global_mean is None or global_std is None:
        # Only mean/std are needed; the large pooled pixel array is discarded.
        _, _, _, global_mean, _, global_std = compute_global_metrics(lr_files, hr_files)
    print(f"Using Global Mean: {global_mean}, Global Std: {global_std}")

    hr_lookup = set(hr_files)
    lr_data = []
    hr_data = []

    for lr_file in lr_files:
        lr_base = os.path.basename(lr_file)
        # Strip '_LR4' from the file name only, never from directory names.
        hr_file = lr_file[:len(lr_file) - len(lr_base)] + lr_base.replace('_LR4', '')
        if hr_file not in hr_lookup:
            continue

        lr_image = sio.loadmat(lr_file)['radiance']
        hr_image = sio.loadmat(hr_file)['radiance']

        # Single-band images get a *trailing* band axis so LR and HR share the
        # (H, W, bands) layout expected downstream.
        if lr_image.ndim == 2:
            lr_image = lr_image[:, :, np.newaxis]
        if hr_image.ndim == 2:
            hr_image = hr_image[:, :, np.newaxis]

        lr_tensor = convert_normalise_meanSTD(lr_image, global_mean, global_std)[0]
        hr_tensor = convert_normalise_meanSTD(hr_image, global_mean, global_std)[0]

        lr_data.append(lr_tensor.numpy())
        hr_data.append(hr_tensor.numpy())

    return np.array(lr_data), np.array(hr_data), global_mean, global_std

# Architecture 

In [None]:
class DSC(nn.Module):
    """Depthwise-separable convolution block.

    The block applies, in order: a per-band 3x3 depthwise conv, a 1x1
    pointwise conv that mixes bands into ``out_channels`` maps, BatchNorm, and
    ReLU. Spatial size is preserved (stride 1, padding 1).

    ``in_channels``, ``upsample_scale`` and ``mode`` are accepted for interface
    compatibility but are not used. The depthwise conv reads
    ``num_spectral_bands`` input channels.
    """

    def __init__(self, in_channels, out_channels, num_spectral_bands, depth_multiplier=1, upsample_scale=2, mode='bilinear'):
        super().__init__()
        expanded = num_spectral_bands * depth_multiplier

        # groups == number of input bands, so each band is filtered independently.
        self.depthwise_conv = nn.Conv2d(
            in_channels=num_spectral_bands,
            out_channels=expanded,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=num_spectral_bands,
            bias=False,
        )
        # 1x1 conv combines the per-band responses into out_channels maps.
        self.pointwise_conv = nn.Conv2d(in_channels=expanded, out_channels=out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        # x: (batch, num_spectral_bands, height, width)
        for stage in (self.depthwise_conv, self.pointwise_conv, self.bn, self.relu):
            x = stage(x)
        return x


class S5_DSCR_S(nn.Module):
    """Single-block S5 super-resolution network.

    ``forward(x)`` returns ``dsc_block(up(x)) + up(x)``, where ``up`` is
    bicubic upsampling by ``upsample_scale``. The residual sum requires
    ``out_channels == num_spectral_bands``.

    ``mode`` is accepted but not used; interpolation is always bicubic. The
    training/test cells in this notebook pass ``mode='convtranspose'``, which
    is not an interpolation mode, so it must not be forwarded as-is.
    """

    def __init__(self, in_channels, out_channels, num_spectral_bands, depth_multiplier=1, upsample_scale=2, mode='bilinear'):
        super(S5_DSCR_S, self).__init__()
        self.interpolation = nn.Upsample(scale_factor=upsample_scale, mode='bicubic', align_corners=False)
        self.dsc_block = DSC(in_channels, out_channels, num_spectral_bands, depth_multiplier)

    def forward(self, x, target_size=None):
        """Upsample ``x``, refine it, and add the refinement back.

        Args:
            x: input batch, presumably (N, num_spectral_bands, h, w).
            target_size: optional ``(H, W)``. When given, both the upsampled
                input and the refinement are bicubically resized to it before
                the sum.
        """
        interpolated = self.interpolation(x)
        refined = self.dsc_block(interpolated)

        if target_size is not None:
            interpolated = F.interpolate(interpolated, size=target_size, mode='bicubic', align_corners=False)
            refined = F.interpolate(refined, size=target_size, mode='bicubic', align_corners=False)
        elif refined.shape[2:] != interpolated.shape[2:]:
            # The DSC block is size-preserving (3x3/pad 1 then 1x1). A
            # same-size bicubic resize is an identity, so resize only when the
            # sizes genuinely differ.
            refined = F.interpolate(refined, size=interpolated.shape[2:], mode='bicubic', align_corners=False)

        output = refined + interpolated
        return output


In [None]:
class ImprovedDSC_2(nn.Module):
    """Stack of ``num_layers`` depthwise-separable conv stages.

    Each stage applies, in order: a depthwise conv (``kernel_size``,
    ``groups`` = its input channels, channel multiplier ``depth_multiplier``),
    a 1x1 pointwise conv to ``out_channels``, BatchNorm, and ReLU. The first
    stage reads ``num_spectral_bands`` channels. Every later stage reads the
    previous stage's ``out_channels``, so the stack also works when
    ``out_channels != num_spectral_bands``. When the two are equal, the module
    layout matches the original design, so existing checkpoints still load.

    Padding is ``kernel_size // 2``, so spatial size is preserved for odd
    kernel sizes. ``in_channels`` is accepted for interface compatibility but
    not used.
    """

    def __init__(self, in_channels, out_channels, num_spectral_bands, depth_multiplier=1, num_layers=3, kernel_size=3):
        super(ImprovedDSC_2, self).__init__()

        layers = []
        stage_in = num_spectral_bands
        for _ in range(num_layers):
            layers.append(nn.Conv2d(
                in_channels=stage_in,
                out_channels=stage_in * depth_multiplier,
                kernel_size=kernel_size,
                stride=1,
                padding=kernel_size // 2,
                groups=stage_in,
                bias=False,
            ))
            layers.append(nn.Conv2d(
                in_channels=stage_in * depth_multiplier,
                out_channels=out_channels,
                kernel_size=1,
                bias=False,
            ))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
            # The next stage consumes this stage's pointwise output.
            stage_in = out_channels

        self.conv_layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv_layers(x)


class S5_DSCR(nn.Module):
    """Residual super-resolution network: bicubic upsample + ImprovedDSC_2 refinement.

    ``forward(x)`` returns ``dsc_block(up(x)) + up(x)``, where ``up`` is
    bicubic upsampling by ``upsample_scale``. The residual sum requires
    ``out_channels == num_spectral_bands``.
    """

    def __init__(self, in_channels, out_channels, num_spectral_bands, depth_multiplier=1, num_layers=3, kernel_size=3, upsample_scale=2):
        super().__init__()
        self.interpolation = nn.Upsample(scale_factor=upsample_scale, mode='bicubic', align_corners=False)
        self.dsc_block = ImprovedDSC_2(
            in_channels,
            out_channels,
            num_spectral_bands,
            depth_multiplier,
            num_layers,
            kernel_size,
        )

    def forward(self, x):
        upsampled = self.interpolation(x)
        correction = self.dsc_block(upsampled)
        return correction + upsampled


# Model training


In [None]:
class Arguments:
    """Plain container for training/evaluation hyper-parameters and paths.

    Every constructor argument is stored unchanged as an attribute of the
    same name.
    """

    def __init__(self,
                 nepochs=100,
                 report_step=5,
                 batch_size=1,
                 validation=0.3,
                 net_loss='MSE',
                 net_opt='Adam',
                 net_lr=1e-2,
                 net_name='S5Net',
                 save_dir='',
                 ftrain='',
                 ftest='',
                 fvalid='',
                 pretrain='',
                 save_prefix=''):
        # Training schedule
        self.nepochs, self.report_step, self.batch_size = nepochs, report_step, batch_size
        self.validation = validation
        # Network / optimiser choices
        self.net_loss, self.net_opt, self.net_lr, self.net_name = net_loss, net_opt, net_lr, net_name
        # Input / output locations
        self.save_dir = save_dir
        self.ftrain, self.ftest, self.fvalid = ftrain, ftest, fvalid
        self.pretrain, self.save_prefix = pretrain, save_prefix

In [None]:
# Root of the datasets and of every output written by this notebook.
BASE_PATH = os.path.expanduser("~/data_sets/")
SAVE_DIR = os.path.join(BASE_PATH, 'outputs')

# Optimiser name (e.g. 'adadelta' as an alternative). The training functions
# below always construct optim.Adam, so this value only affects SAVE_PREFIX.
net_optimizer = 'Adam'
net_lr = 1e-3
net_loss = 'MSE'  # or 'L1norm'
net_name = 'BAND4_hyper'

# Prefix shared by checkpoint, loss-CSV and TensorBoard paths.
SAVE_PREFIX = f"{net_optimizer}_lr{net_lr}_{net_loss}_{net_name}"

train_image_names, valid_image_names, test_image_names = (
    os.path.join(BASE_PATH, split) for split in ('train_hyper', 'valid_hyper', 'test_hyper')
)

params = Arguments(
    nepochs=50,
    report_step=5,
    batch_size=1,
    validation=0.2,
    net_loss=net_loss,
    net_opt=net_optimizer,
    net_lr=net_lr,
    net_name=net_name,
    save_dir=SAVE_DIR,
    ftrain=train_image_names,
    ftest=test_image_names,
    fvalid=valid_image_names,
    pretrain='',
    save_prefix=SAVE_PREFIX,
)

SAVE_PREFIX

# Testing

In [None]:
# Expose the configuration as `args`, the name the results-CSV path below is built from.
args = params


In [None]:
csv_file = os.path.join(args.save_dir, args.save_prefix, "results.csv")


def metric_s5net(model, test_loader, device, network_name, csv_filename= csv_file):
    """Evaluate ``model`` on ``test_loader`` and append averaged metrics to a CSV.

    For every sample, the SR output and the HR target are each min-max scaled
    to [0, 1]. They are then scored with PSNR, SSIM, SCC, and band-wise LPIPS
    (LPIPS on the [0, 1] data mapped to [-1, 1]). For every sample, a PCA
    false-colour comparison of LR / HR / SR is also drawn with
    ``plot_hyperspectral_images_false_color_global2``. That function is not
    defined in this cell, presumably elsewhere in the notebook.

    Args:
        model: network mapping an LR batch to an SR batch.
        test_loader: iterable of ``(lr, hr)`` tensor batches. As built in this
            notebook they are (batch, bands, height, width).
        device: torch device used for the model and batches.
        network_name: label for the CSV row and the plots.
        csv_filename: results file. Its directory is created if missing, and a
            header row is written when the file does not exist yet.

    Returns:
        ``(avg_psnr, avg_scc, avg_ssim, avg_lpips, lr_images, hr_images,
        sr_images)``. The image lists hold per-sample numpy arrays: LR raw,
        HR/SR scaled to [0, 1].
    """

    def _minmax(arr):
        # Scale to [0, 1]; a constant array maps to zeros instead of 0/0 = NaN.
        lo, hi = arr.min(), arr.max()
        span = hi - lo
        return (arr - lo) / span if span > 0 else np.zeros_like(arr)

    def _to_hwc(arr):
        # scc / calculate_lpips_bandwise index bands on the last axis
        # (H, W, C); the network tensors are (C, H, W).
        return np.moveaxis(arr, 0, -1) if arr.ndim == 3 else arr

    def _pca_project(pca, img):
        # (bands, H, W) -> pixels x bands -> 3 components -> (3, H, W)
        flat = img.reshape(img.shape[0], -1).T
        return pca.transform(flat).T.reshape(3, img.shape[1], img.shape[2])

    def _standardise(arr):
        return (arr - arr.mean()) / arr.std()

    model = model.to(device)  # once, not per batch
    model.eval()

    psnr_values, scc_values = [], []
    ssim_values, lpips_values = [], []
    lr_images, hr_images, sr_images = [], [], []
    sample_idx = 0

    with torch.no_grad():
        for lr, hr in test_loader:
            lr, hr = lr.to(device), hr.to(device)
            output = model(lr)

            for i in range(output.shape[0]):
                # NOTE(review): SR and HR are scaled independently, so a global
                # gain/offset error in SR is not penalised by these metrics.
                sr = _minmax(output[i].cpu().numpy().squeeze())
                hr_img = _minmax(hr[i].cpu().numpy().squeeze())
                lr_img = lr[i].cpu().numpy().squeeze()

                # LPIPS expects inputs in [-1, 1]: map from the [0, 1] range above.
                lpips_value = calculate_lpips_bandwise(_to_hwc(2 * sr - 1), _to_hwc(2 * hr_img - 1), loss_fn)
                psnr_value = psnr(hr_img, sr, data_range=1)  # [0, 1] range
                scc_value, _ = scc(_to_hwc(sr), _to_hwc(hr_img))
                # No channel_axis: skimage treats every axis (bands included) as spatial.
                ssim_value = ssim(hr_img, sr, data_range=1)  # [0, 1] range

                psnr_values.append(psnr_value)
                scc_values.append(scc_value)
                ssim_values.append(ssim_value)
                lpips_values.append(lpips_value)

                lr_images.append(lr_img)
                hr_images.append(hr_img)
                sr_images.append(sr)

                sample_idx += 1

                # False-colour view: PCA fitted on the HR pixels, applied to all three.
                pca = PCA(n_components=3)
                pca.fit(hr_img.reshape(hr_img.shape[0], -1).T)

                lr_pca = _standardise(_pca_project(pca, lr_img))
                hr_pca = _standardise(_pca_project(pca, hr_img))
                sr_pca = _standardise(_pca_project(pca, sr))

                plot_hyperspectral_images_false_color_global2(lr_pca, hr_pca, sr_pca, sample_idx, network_name, bands=[1,0, 2], cmap='viridis')

    avg_psnr = np.mean(psnr_values)
    avg_scc = np.mean(scc_values)
    avg_ssim = np.mean(ssim_values)
    avg_lpips = np.mean(lpips_values)

    print(f"PSNR: {avg_psnr:.4f}")
    print(f"SCC: {avg_scc:.4f}")
    print(f"SSIM: {avg_ssim:.4f}")
    print(f"LPIPS: {avg_lpips:.4e}")

    rounded_values = [
        network_name,
        round(avg_psnr, 4),
        round(avg_scc, 4),
        round(avg_ssim, 4),
        f"{avg_lpips:.4e}",
    ]

    # The default path sits in <save_dir>/<save_prefix>/, which may not exist yet.
    csv_dir = os.path.dirname(csv_filename)
    if csv_dir:
        os.makedirs(csv_dir, exist_ok=True)

    file_exists = os.path.isfile(csv_filename)
    with open(csv_filename, mode='a', newline='') as csvfile:
        csv_writer = csv.writer(csvfile)
        if not file_exists:
            csv_writer.writerow(["Network Name", "PSNR", "SCC", "SSIM", "LPIPS"])
        csv_writer.writerow(rounded_values)

    return avg_psnr, avg_scc, avg_ssim, avg_lpips, lr_images, hr_images, sr_images


## Data loaders and parameters

In [None]:
args = params

# Training split. The global mean/std computed here are reused for the
# validation and test splits below.
lr_patches, hr_patches, global_mean, global_std = load_data_with_patches(args.ftrain, (64, 64), 'BAND4')

# Patches are (N, H, W, bands); nn.Conv2d wants (N, bands, H, W).
lr_patches = np.transpose(lr_patches, (0, 3, 1, 2))
hr_patches = np.transpose(hr_patches, (0, 3, 1, 2))

# One (lr, hr) float32 tensor pair per patch.
train_data = list(zip(torch.tensor(lr_patches, dtype=torch.float32),
                      torch.tensor(hr_patches, dtype=torch.float32)))
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)

num_bands = lr_patches.shape[1]  # spectral channels per patch


In [None]:
# Validation split, normalised with the training-set statistics.
lr_patches, hr_patches, _, _ = load_data_with_patches(params.fvalid, (64, 64), 'BAND4', global_mean=global_mean, global_std=global_std)

# (N, H, W, bands) -> (N, bands, H, W)
lr_patches = lr_patches.transpose(0, 3, 1, 2)
hr_patches = hr_patches.transpose(0, 3, 1, 2)

valid_data = [(torch.tensor(lr, dtype=torch.float32), torch.tensor(hr, dtype=torch.float32)) for lr, hr in zip(lr_patches, hr_patches)]
# Validation loss is averaged over all batches, so order is irrelevant.
# shuffle=False keeps evaluation deterministic. A shuffling sampler would also
# draw from the global torch RNG each epoch, perturbing training randomness.
valid_loader = DataLoader(valid_data, batch_size=params.batch_size, shuffle=False)

num_bands = lr_patches.shape[1]

In [None]:
# Test split, normalised with the training-set statistics; sample order is kept.
lr_patches, hr_patches, _, _ = load_data_with_patches(params.ftest, (64, 64), 'BAND4', global_mean=global_mean, global_std=global_std)

# (N, H, W, bands) -> (N, bands, H, W)
lr_patches = np.transpose(lr_patches, (0, 3, 1, 2))
hr_patches = np.transpose(hr_patches, (0, 3, 1, 2))

test_data = list(zip(torch.tensor(lr_patches, dtype=torch.float32),
                     torch.tensor(hr_patches, dtype=torch.float32)))
test_loader = DataLoader(test_data, batch_size=params.batch_size, shuffle=False)

num_bands = lr_patches.shape[1]

# S5_DSCR_S

In [None]:
def S5_DSCR_S_train(args):
    """Train an S5_DSCR_S model and save its weights and loss history.

    Uses these notebook globals: ``num_bands``, ``device``, ``train_loader``,
    ``valid_loader``, and ``plot_hyperspectral_images_false_color_train``. The
    last one is not defined in this cell, presumably elsewhere in the notebook.

    Args:
        args: an ``Arguments`` instance. Reads ``nepochs``, ``net_loss``
            ('L1norm' selects L1, anything else MSE), ``net_lr``, ``save_dir``
            and ``save_prefix``. ``net_opt`` is not consulted: Adam is always
            used.

    Side effects:
        Writes:

        - TensorBoard logs under ``<save_dir>/<save_prefix>/DSC``;
        - the state dict to
          ``<save_dir>/<save_prefix>_DSC2_updated_hyperspectral_model.pth``;
        - per-epoch losses to ``<save_dir>/<save_prefix>_DSC2_updated_losses.csv``.

        Save failures are printed, not raised.
    """
    # out_channels must equal num_bands for S5_DSCR_S's residual sum.
    model = S5_DSCR_S(in_channels=num_bands,
                      out_channels=497,
                      num_spectral_bands=num_bands,
                      depth_multiplier=1,
                      upsample_scale=4,
                      mode='convtranspose').to(device)

    # torchsummary's summary() defaults to device="cuda"; pass the real device
    # type so this also runs on CPU-only machines.
    summary(model, input_size=(num_bands, 64, 64), device=device.type)

    log_dir = os.path.join(args.save_dir, args.save_prefix, 'DSC')
    writer_tensor = SummaryWriter(log_dir=log_dir)

    criterion = nn.L1Loss() if args.net_loss == 'L1norm' else nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.net_lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.1)

    train_losses, val_losses = [], []

    try:
        for epoch in range(args.nepochs):
            model.train()
            epoch_loss = 0
            for batch_idx, (lr, hr) in enumerate(train_loader):
                lr, hr = lr.to(device), hr.to(device)
                optimizer.zero_grad()
                output = model(lr)
                loss = criterion(output, hr)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

                # Every other epoch, log up to 5 predictions from the first batch.
                if epoch % 2 == 0 and batch_idx == 0:
                    # Convert the batch once rather than once per logged sample.
                    lr_np = lr.cpu().numpy()
                    hr_np = hr.cpu().numpy()
                    pred_np = output.detach().cpu().numpy()
                    for i in range(min(5, len(lr))):
                        step = epoch * len(train_loader) + batch_idx * len(lr) + i
                        fig = plot_hyperspectral_images_false_color_train(lr_np[i], hr_np[i], pred_np[i], idx=step)
                        writer_tensor.add_figure(f'Predictions vs Actuals/Epoch_{epoch}', fig, global_step=step)

            val_loss = 0
            model.eval()
            with torch.no_grad():
                for lr, hr in valid_loader:
                    lr, hr = lr.to(device), hr.to(device)
                    output = model(lr)
                    val_loss += criterion(output, hr).item()

            train_losses.append(epoch_loss / len(train_loader))
            val_losses.append(val_loss / len(valid_loader))

            print(f"Epoch {epoch+1}/{args.nepochs}, Train Loss: {train_losses[-1]}, Validation Loss: {val_losses[-1]}")
            writer_tensor.add_scalar('Loss/Train', train_losses[-1], epoch)
            writer_tensor.add_scalar('Loss/Validation', val_losses[-1], epoch)
            scheduler.step(val_losses[-1])
    finally:
        # Flush TensorBoard events even if training is interrupted.
        writer_tensor.close()

    try:
        torch.save(model.state_dict(), os.path.join(args.save_dir, f"{args.save_prefix}_DSC2_updated_hyperspectral_model.pth"))
        print('Model saved successfully.')
    except Exception as e:
        print(f"Error saving model: {e}")

    try:
        with open(os.path.join(args.save_dir, f"{args.save_prefix}_DSC2_updated_losses.csv"), mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Epoch", "Train Loss", "Validation Loss"])
            for epoch, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses), 1):
                writer.writerow([epoch, train_loss, val_loss])
        print('Loss file saved successfully.')
    except Exception as e:
        print(f"Error saving loss file: {e}")




In [None]:
# Train S5_DSCR_S with the configuration held in `params`.
S5_DSCR_S_train(params)

In [None]:
def S5_DSCR_S_test(params):
    """Load the saved S5_DSCR_S weights and evaluate them on ``test_loader``.

    Reads ``<save_dir>/<save_prefix>_DSC2_updated_hyperspectral_model.pth``,
    the file written by S5_DSCR_S_train. Uses the notebook globals
    ``num_bands``, ``device`` and ``test_loader``.

    Args:
        params: an ``Arguments`` instance; reads ``save_dir`` and ``save_prefix``.

    Returns:
        ``(lr_images, hr_images, sr_images)`` as produced by metric_s5net.
    """
    model = S5_DSCR_S(in_channels=num_bands,
                      out_channels=497,
                      num_spectral_bands=num_bands,
                      depth_multiplier=1,
                      upsample_scale=4,
                      mode='convtranspose').to(device)

    weights_path = os.path.join(params.save_dir, f"{params.save_prefix}_DSC2_updated_hyperspectral_model.pth")
    # map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(weights_path, map_location=device))

    _, _, _, _, lr_images, hr_images, sr_images = metric_s5net(model, test_loader, device, 'S5_DSCR_S')
    return lr_images, hr_images, sr_images
    


In [None]:
# Evaluate the saved S5_DSCR_S checkpoint on the test split; the returned image lists are discarded.
_,_,_ = S5_DSCR_S_test(params)

# S5_DSCR

In [None]:

def S5_DSCR_train(args):
    """Train an S5_DSCR model (5-stage ImprovedDSC_2) and save weights and losses.

    Uses these notebook globals: ``num_bands``, ``device``, ``train_loader``,
    ``valid_loader``, and ``plot_hyperspectral_images_false_color_train``. The
    last one is not defined in this cell, presumably elsewhere in the notebook.

    Args:
        args: an ``Arguments`` instance. Reads ``nepochs``, ``net_loss``
            ('L1norm' selects L1, anything else MSE), ``net_lr``, ``save_dir``
            and ``save_prefix``. ``net_opt`` is not consulted: Adam is always
            used.

    Side effects:
        Writes:

        - TensorBoard logs under ``<save_dir>/<save_prefix>/DSC_residual2``;
        - the state dict to
          ``<save_dir>/<save_prefix>_DSC_residual2_updated_hyperspectral_model.pth``;
        - per-epoch losses to
          ``<save_dir>/<save_prefix>_DSC_residual2_updated_losses.csv``.

        Save failures are printed, not raised.
    """
    # NOTE(review): the network is built for 497 bands; the data must have
    # num_bands == 497 (summary() below and training both feed num_bands channels).
    model = S5_DSCR(
        in_channels=497,
        out_channels=497,
        num_spectral_bands=497,
        depth_multiplier=3,
        num_layers=5,
        kernel_size=5,
        upsample_scale=4)

    model = model.to(device)
    # torchsummary's summary() defaults to device="cuda"; pass the real device
    # type so this also runs on CPU-only machines.
    summary(model, input_size=(num_bands, 64, 64), device=device.type)

    log_dir = os.path.join(args.save_dir, args.save_prefix, 'DSC_residual2')
    writer_tensor = SummaryWriter(log_dir=log_dir)

    criterion = nn.L1Loss() if args.net_loss == 'L1norm' else nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.net_lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.1)

    train_losses, val_losses = [], []

    try:
        for epoch in range(args.nepochs):
            model.train()
            epoch_loss = 0
            for batch_idx, (lr, hr) in enumerate(train_loader):
                lr, hr = lr.to(device), hr.to(device)
                optimizer.zero_grad()
                output = model(lr)
                loss = criterion(output, hr)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

                # Every other epoch, log up to 5 predictions from the first batch.
                if epoch % 2 == 0 and batch_idx == 0:
                    # Convert the batch once rather than once per logged sample.
                    lr_np = lr.cpu().numpy()
                    hr_np = hr.cpu().numpy()
                    pred_np = output.detach().cpu().numpy()
                    for i in range(min(5, len(lr))):
                        step = epoch * len(train_loader) + batch_idx * len(lr) + i
                        fig = plot_hyperspectral_images_false_color_train(lr_np[i], hr_np[i], pred_np[i], idx=step)
                        writer_tensor.add_figure(f'Predictions vs Actuals/Epoch_{epoch}', fig, global_step=step)

            val_loss = 0
            model.eval()
            with torch.no_grad():
                for lr, hr in valid_loader:
                    lr, hr = lr.to(device), hr.to(device)
                    output = model(lr)
                    val_loss += criterion(output, hr).item()

            train_losses.append(epoch_loss / len(train_loader))
            val_losses.append(val_loss / len(valid_loader))

            print(f"Epoch {epoch+1}/{args.nepochs}, Train Loss: {train_losses[-1]}, Validation Loss: {val_losses[-1]}")
            writer_tensor.add_scalar('Loss/Train', train_losses[-1], epoch)
            writer_tensor.add_scalar('Loss/Validation', val_losses[-1], epoch)
            scheduler.step(val_losses[-1])
    finally:
        # Flush TensorBoard events even if training is interrupted.
        writer_tensor.close()

    try:
        torch.save(model.state_dict(), os.path.join(args.save_dir, f"{args.save_prefix}_DSC_residual2_updated_hyperspectral_model.pth"))
        print('Model saved successfully.')
    except Exception as e:
        print(f"Error saving model: {e}")

    try:
        with open(os.path.join(args.save_dir, f"{args.save_prefix}_DSC_residual2_updated_losses.csv"), mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["Epoch", "Train Loss", "Validation Loss"])
            for epoch, (train_loss, val_loss) in enumerate(zip(train_losses, val_losses), 1):
                writer.writerow([epoch, train_loss, val_loss])
        print('Loss file saved successfully.')
    except Exception as e:
        print(f"Error saving loss file: {e}")



In [None]:
# Train S5_DSCR with the configuration held in `params`.
S5_DSCR_train(params)

In [None]:

def S5_DSCR_test(params):
    """Load the saved S5_DSCR weights and evaluate them on ``test_loader``.

    Reads
    ``<save_dir>/<save_prefix>_DSC_residual2_updated_hyperspectral_model.pth``,
    the file written by S5_DSCR_train. Uses the notebook globals ``device``
    and ``test_loader``. metric_s5net moves the model to ``device``.

    Args:
        params: an ``Arguments`` instance; reads ``save_dir`` and ``save_prefix``.

    Returns:
        ``(lr_images, hr_images, sr_images)`` as produced by metric_s5net.
    """
    model = S5_DSCR(
        in_channels=497,
        out_channels=497,
        num_spectral_bands=497,
        depth_multiplier=3,
        num_layers=5,
        kernel_size=5,
        upsample_scale=4)

    weights_path = os.path.join(params.save_dir, f"{params.save_prefix}_DSC_residual2_updated_hyperspectral_model.pth")
    # map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(weights_path, map_location=device))

    _, _, _, _, lr_images, hr_images, sr_images = metric_s5net(model, test_loader, device, 'S5_DSCR')
    return lr_images, hr_images, sr_images


In [None]:
# Evaluate the saved S5_DSCR checkpoint on the test split; the returned image lists are discarded.
_,_,_ = S5_DSCR_test(params)