In [8]:
import os
import random
import torch
import torch.nn as nn
from torch.nn import utils
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import HTML
import pathlib
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import MinMaxScaler
import spectral.io.envi as envi
from spectral import imshow
import math
import csv

manualSeed = 999
debug = True
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

file_tracker = {
    1: {"counter": 0, "path": 'Preliminary measurements on breast tissue/February 2025/270225/Paziente1/Paziente1M1/Paziente1M1_nrm.hdr'},
    2: {"counter": 0, "path": 'Preliminary measurements on breast tissue/February 2025/270225/Paziente1/Paziente1M2/Paziente1M2_nrm.hdr'},
    3: {"counter": 0, "path": 'Preliminary measurements on breast tissue/February 2025/270225/Paziente2/Paziente2M1_nrm.hdr'},
    4: {"counter": 0, "path": 'Preliminary measurements on breast tissue/February 2025/270225/Paziente3/Paziente3M1/Paziente3M1_nrm.hdr'},
    5: {"counter": 0, "path": 'Preliminary measurements on breast tissue/February 2025/270225/Paziente3/Paziente3M2/Paziente3M2_nrm.hdr'},
    6: {"counter": 0, "path": 'Preliminary measurements on breast tissue/February 2025/270225/Paziente3/Paziente3M3/Paziente3M3_nrm.hdr'},
    7: {"counter": 0, "path": 'Preliminary measurements on breast tissue/July 2024/170724/Paziente1/Sample1_2_M1_nrm.hdr'},
    8: {"counter": 0, "path": 'Preliminary measurements on breast tissue/July 2024/170724/Paziente1/Sample1_2_M2_nrm.hdr'},
    9: {"counter": 0, "path": 'Preliminary measurements on breast tissue/July 2024/170724/Paziente2/Set1/Sample_2_1_M1_nrm.hdr'},
    10: {"counter": 0, "path": 'Preliminary measurements on breast tissue/July 2024/170724/Paziente2/Set1/Sample_2_1_M2_nrm.hdr'},
    11: {"counter": 0, "path": 'Preliminary measurements on breast tissue/July 2024/170724/Paziente2/Set2/Sample_2_2_M1_nrm.hdr'},
    12: {"counter": 0, "path": 'Preliminary measurements on breast tissue/July 2024/170724/Paziente2/Set2/Sample_2_2_M2_nrm.hdr'},
    13: {"counter": 0, "path": 'Preliminary measurements on breast tissue/July 2024/180724/Paziente3/Sample3_M1_nrm.hdr'},
    14: {"counter": 0, "path": 'Preliminary measurements on breast tissue/July 2024/180724/Paziente3/Sample3_M2_nrm.hdr'},
    15: {"counter": 0, "path": 'Preliminary measurements on breast tissue/July 2024/180724/Paziente4/Set1/Sample4_latoA_M1_nrm.hdr'},
    16: {"counter": 0, "path": 'Preliminary measurements on breast tissue/July 2024/180724/Paziente4/Set1/Sample4_latoA_M2_nrm.hdr'},
    17: {"counter": 0, "path": 'Preliminary measurements on breast tissue/July 2024/180724/Paziente4/Set2/Sample4_latoB_M1_nrm.hdr'},
    18: {"counter": 0, "path": 'Preliminary measurements on breast tissue/July 2024/180724/Paziente4/Set2/Sample4_latoB_M2_nrm.hdr'},
    19: {"counter": 0, "path": 'Preliminary measurements on breast tissue/November 2024/141124/141124/Paziente1/Paziente1M1/Paziente1M1_nrm.hdr'},
    20: {"counter": 0, "path": 'Preliminary measurements on breast tissue/November 2024/141124/141124/Paziente1/Paziente1M2/Paziente1M2_nrm.hdr'},
    21: {"counter": 0, "path": 'Preliminary measurements on breast tissue/November 2024/141124/141124/Paziente1/Paziente1M3/Paziente1M3_nrm.hdr'},
    22: {"counter": 0, "path": 'Preliminary measurements on breast tissue/November 2024/141124/141124/Paziente2/Paziente2M1/Paziente2M1_nrm.hdr'},
    23: {"counter": 0, "path": 'Preliminary measurements on breast tissue/November 2024/141124/141124/Paziente2/Paziente2M2/Paziente2M2_nrm.hdr'},
    24: {"counter": 0, "path": 'Preliminary measurements on breast tissue/November 2024/141124/141124/Paziente2/Paziente2M3/Paziente2M3_nrm.hdr'},
    25: {"counter": 0, "path": 'Preliminary measurements on breast tissue/November 2024/141124/141124/Paziente3/Paziente3M1/Paziente3M1_nrm.hdr'},
    26: {"counter": 0, "path": 'Preliminary measurements on breast tissue/November 2024/141124/141124/Paziente3/Paziente3M2/Paziente3M2_nrm.hdr'}   
}

SAMPLES_PER_IMAGE = 60
IMG_SIZE = 244
STD_THRESHOLD = 0.15
X_BOUND = 1024
Y_BOUND = 1280
CHANNELS = 3
Z_SIZE_REDUCTION = 100
Z_SIZE = 128
RGB_BANDS = [29, 17, 7]
BATCH_SIZE = 64

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

Random Seed:  999


In [None]:
def get_list_img_samples(file_path, samples_per_img, img_size, x_bound, y_bound, std_threshold, debug=False):
    img = envi.open(file_path)
    data_array = img.load()
    data_array = np.nan_to_num(data_array)
    sample_list = []
    coordinates_list = []    
    while len(sample_list) < samples_per_img:
        lower_x_bound = random.randint(0, x_bound - img_size - 1)
        lower_y_bound = random.randint(0, y_bound - img_size - 1)
        sample_array = data_array[lower_x_bound: lower_x_bound + img_size, lower_y_bound:lower_y_bound + img_size, :]
        spatial_std_per_band = np.std(sample_array, axis=(2))
        if (np.mean(spatial_std_per_band) > std_threshold):
            sample_array = get_pca_sample_array(sample_array)
            sample_list.append(sample_array)
            if (debug):
                coordinates_list.append((lower_x_bound,lower_y_bound))             
    return sample_list

def get_pca_sample_array(sample_array):
    scaler = MinMaxScaler()   

    sample_array = sample_array[:, :, [RGB_BANDS]].squeeze()
    height, width, bands = sample_array.shape 
    sample_array = np.transpose(sample_array, (2, 0, 1))
    return sample_array

def samples_dict_init():
    return {
    1: {"counter": 0, "sample_list": []},
    2: {"counter": 0, "sample_list": []},
    3: {"counter": 0, "sample_list": []},
    4: {"counter": 0, "sample_list": []},
    5: {"counter": 0, "sample_list": []},
    6: {"counter": 0, "sample_list": []},
    7: {"counter": 0, "sample_list": []},
    8: {"counter": 0, "sample_list": []},
    9: {"counter": 0, "sample_list": []},
    10: {"counter": 0, "sample_list": []},
    11: {"counter": 0, "sample_list": []},
    12: {"counter": 0, "sample_list": []},
    13: {"counter": 0, "sample_list": []},
    14: {"counter": 0, "sample_list": []},
    15: {"counter": 0, "sample_list": []},
    16: {"counter": 0, "sample_list": []},
    17: {"counter": 0, "sample_list": []},
    18: {"counter": 0, "sample_list": []},
    19: {"counter": 0, "sample_list": []},
    20: {"counter": 0, "sample_list": []},
    21: {"counter": 0, "sample_list": []},
    22: {"counter": 0, "sample_list": []},
    23: {"counter": 0, "sample_list": []},
    24: {"counter": 0, "sample_list": []},
    25: {"counter": 0, "sample_list": []},
    26: {"counter": 0, "sample_list": []}
} 
        
    

class HSIGANDataset(Dataset):
    def __init__(self, file_tracker, samples_per_img, img_size, std_threshold, x_bound, y_bound, transforms=None):
        self.file_tracker = file_tracker 
        self.samples_per_img = samples_per_img
        self.img_size = img_size
        self.x_bound = x_bound
        self.y_bound = y_bound
        self.file_tracker_len = len(self.file_tracker)
        self.total_len =  self.file_tracker_len * samples_per_img
        self.std_threshold = std_threshold
        self.samples_dict = samples_dict_init()    

    def __len__(self):
        return self.total_len

    def __getitem__(self, idx):
        img_to_sample = (idx % self.file_tracker_len) + 1
        counter = self.samples_dict[img_to_sample]["counter"]
        if (len(self.samples_dict[img_to_sample]["sample_list"]) == 0):
            file_path = self.file_tracker[img_to_sample]["path"]
            self.samples_dict[img_to_sample]["sample_list"] = get_list_img_samples(file_path, self.samples_per_img, self.img_size, self.x_bound, self.y_bound, self.std_threshold)
        
        sample = torch.tensor(self.samples_dict[img_to_sample]["sample_list"][counter])
        #sample = sample.permute(2,0,1).contiguous()
        self.samples_dict[img_to_sample]["counter"] = (counter + 1 ) % self.samples_per_img
        return sample

In [None]:
from torchmetrics.image.fid import FrechetInceptionDistance


def evaluate_and_save(epoch, dim, generator, real_samples, fid_metric, 
                      num_plot_samples=4, num_fid_samples=2000, fid_batch_size=32,
                      output_dir="training_images_FID_244", log_file="fid_scores.csv"):    
    
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        
    generator.eval()
    device = next(generator.parameters()).device    
   
    d = real_samples.shape[0]
    initial_sample_num = random.randint(0, max(0, d - num_plot_samples)) 
    final_sample_num = initial_sample_num + num_plot_samples
    
    real_to_plot = real_samples[initial_sample_num:final_sample_num].detach().cpu().numpy().transpose(0, 2, 3, 1)
    
    with torch.no_grad():
        noise = torch.randn(num_plot_samples, dim).to(device)
        fake_to_plot = generator(noise).detach().cpu().numpy().transpose(0, 2, 3, 1)

    print(f'Epoch {epoch+1} - Plotting and Calculating FID...')
    fig, axes = plt.subplots(2, num_plot_samples, figsize=(12, 8))
    fig.suptitle(f'Epoch {epoch+1}', fontsize=16)

    if num_plot_samples == 1:
        axes = np.expand_dims(axes, axis=1)

    for i in range(num_plot_samples):
        ax_real = axes[0, i]
        ax_fake = axes[1, i]
        
        ax_real.imshow(np.clip(real_to_plot[i], 0, 1))
        ax_real.set_title("Real")
        ax_real.axis('off')

        ax_fake.imshow(np.clip(fake_to_plot[i], 0, 1))
        ax_fake.set_title("Fake")
        ax_fake.axis('off')

    filename = os.path.join(output_dir, f"epoch_{epoch+1:03d}.png")
    fig.savefig(filename)
    plt.close(fig)
    
    fid_metric.reset()
    
    actual_fid_samples = min(num_fid_samples, real_samples.shape[0])
    num_batches = math.ceil(actual_fid_samples / fid_batch_size)

    print(f"Calculating FID on {actual_fid_samples} samples (Batch size: {fid_batch_size})...")

    with torch.no_grad():
        for i in range(num_batches):
            start_idx = i * fid_batch_size
            end_idx = min((i + 1) * fid_batch_size, actual_fid_samples)
            
            real_batch = real_samples[start_idx:end_idx].to(device)
            
            real_batch_uint8 = (real_batch * 255).byte()
            
            fid_metric.update(real_batch_uint8, real=True)
            
            current_batch_size = end_idx - start_idx 
            
            noise_batch = torch.randn(current_batch_size, dim).to(device)
            fake_batch = generator(noise_batch)
            
            fake_batch_uint8 = (fake_batch * 255).clamp(0, 255).byte()
            
            fid_metric.update(fake_batch_uint8, real=False)

    fid_score = fid_metric.compute().item()
    print(f"Epoch {epoch+1} FID Score: {fid_score:.4f}")

    log_path = os.path.join(output_dir, log_file)
    file_exists = os.path.isfile(log_path)
    
    with open(log_path, mode='a', newline='') as f:
        writer = csv.writer(f)
        if not file_exists:
            writer.writerow(['Epoch', 'FID'])
        writer.writerow([epoch + 1, fid_score])

    generator.train()

  _torch_pytree._register_pytree_node(


In [None]:
def plot_samples(epoch, dim, generator, real_samples, num_samples=4, output_dir="training_images_FID_244"):
   
    d, _, _, _ = real_samples.shape
    initial_sample_num = random.randint(0, d -  num_samples) 
    final_sample_num = initial_sample_num + num_samples
    generator.eval()
    
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    real_to_plot = real_samples[initial_sample_num:final_sample_num, :, :, :].cpu().numpy().transpose(0, 2, 3, 1)
    print(f'Epoch {epoch+1} - Real vs. Generated Samples')
    
    with torch.no_grad():
        noise = torch.randn(num_samples, dim).to(device)
        fake_to_plot = generator(noise)[:num_samples, :, :, :].cpu().numpy().transpose(0, 2, 3, 1)
    

    print(f"Fake shape: {fake_to_plot.shape}")
    fig, axes = plt.subplots(2, num_samples, figsize=(12, 8))
    fig.suptitle(f'Epoch {epoch+1}', fontsize=16)   

    for i in range(num_samples):
        if num_samples > 1:
            ax_real = axes[0, i]
            ax_fake = axes[1, i]
        else: 
            ax_real = axes[0]
            ax_fake = axes[1]

        
        img_real = real_to_plot[i][:, :, :]
        img_fake = fake_to_plot[i][:, :, :]
        
        ax_real.imshow(np.clip(img_real, 0, 1))
        ax_real.set_title("Real")
        ax_real.axis('off') # Hide the empty "graph" axes

        ax_fake.imshow(np.clip(img_fake, 0, 1))
        ax_fake.set_title("Fake")
        ax_fake.axis('off')    
   
    
    filename = os.path.join(output_dir, f"epoch_{epoch+1:03d}.png")
    fig.savefig(filename)
    plt.close(fig)
    
    generator.train() 

In [12]:
dataset = HSIGANDataset(file_tracker=file_tracker, samples_per_img=SAMPLES_PER_IMAGE, img_size=IMG_SIZE, std_threshold=STD_THRESHOLD, x_bound=X_BOUND, y_bound=Y_BOUND)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=0)

In [13]:
class HyperSpectralDiscriminator(nn.Module):
    def __init__(self, channels, img_size, num_classes):
        super().__init__()

        self.main = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=6, stride=2, padding=2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.25),
            
            nn.Conv2d(32, 64, kernel_size=6, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.25),
            
            nn.Conv2d(64, 128, kernel_size=6, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.25),
            
            nn.Conv2d(128, 256, kernel_size=6, stride=2, padding=2, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),
            
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Flatten()           
        )
        
        with torch.no_grad():
            dummy_input = torch.zeros(1, channels, img_size, img_size)
            encoder_output_dim = self.main(dummy_input).shape[1]

        self.final = nn.Sequential(
            nn.Linear(encoder_output_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
            )
        
    def forward(self, x):
        x = self.main(x)
        x = self.final(x)
        return x

In [None]:
class HyperSpectralGenerator(nn.Module):
    def __init__(self, channels, img_size, z_dim=128):
        super().__init__()
        
        self.init_size = (img_size - 20) // 16      

        self.fc = nn.Linear(z_dim, 256 * self.init_size * self.init_size)
        
        self.main = nn.Sequential(
            nn.Unflatten(1, (256, self.init_size, self.init_size)),
            
            nn.ConvTranspose2d(256, 128, kernel_size=6, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            
            nn.ConvTranspose2d(128, 64, kernel_size=6, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2, padding=2, output_padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),            
            
            nn.ConvTranspose2d(32, channels, kernel_size=10, stride=2, padding=0, output_padding=0),
            nn.Sigmoid() 
        )

    def forward(self, x):
        x = self.fc(x)
        x = self.main(x)
        return x

In [None]:
real_label = 1
fake_label = 0
g_lr = 0.0002
d_lr = 0.0002
noise_size = Z_SIZE
LOG_INTERVAL = 20    
PLOT_INTERVAL = 100 #to test   

netD = HyperSpectralDiscriminator(channels=CHANNELS, img_size=IMG_SIZE, num_classes=1).to(device)
netD.apply(weights_init)
netG = HyperSpectralGenerator(channels=CHANNELS, img_size=IMG_SIZE, z_dim=noise_size).to(device)
netG.apply(weights_init)

fixed_noise = torch.randn(64, noise_size, device=device)

NUM_FID_SAMPLES = 2048
all_real_samples = []


temp_loader = iter(dataloader) 
count = 0

while count < NUM_FID_SAMPLES:
    try:       
        batch = next(temp_loader)
        
        if isinstance(batch, list) or isinstance(batch, tuple):
            batch = batch[0]
            
        all_real_samples.append(batch)
        count += batch.shape[0]
        
    except StopIteration:
        temp_loader = iter(dataloader)

real_samples_large = torch.cat(all_real_samples, dim=0)

real_samples_large = real_samples_large[:NUM_FID_SAMPLES]
criterion = nn.BCEWithLogitsLoss()

optimizerD = optim.Adam(netD.parameters(), lr=d_lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=g_lr, betas=(0.5, 0.999))



In [None]:
# Training Loop

img_list = []
G_losses = []
D_losses = []
D_accuracies = []
iters = 0
num_epochs =  2000 

fid = FrechetInceptionDistance(feature=2048).to(device)

print("Starting Training Loop...")

for epoch in range(num_epochs):
    netD.train()
    for i, samples in enumerate(dataloader):        
        samples = samples.to(device)        
        batch_size = samples.size(0)
        real_labels_D = torch.full((batch_size, 1), 0.9, device=device)
        real_labels_G = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        optimizerD.zero_grad()

        real_output = netD(samples) 

        d_loss_real = criterion(real_output, real_labels_D)
        real_acc = ((real_output > 0.5).float() == real_labels_G).float().mean()
       
        noise = torch.randn(batch_size, noise_size, device=device)
        fake_spectra = netG(noise).squeeze()
        
        fake_output = netD(fake_spectra.detach())        
        d_loss_fake = criterion(fake_output, fake_labels)

        fake_acc = (fake_output < 0.5).float().mean()

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        
        optimizerD.step()

        d_accuracy = (real_acc + fake_acc) / 2

        optimizerG.zero_grad()
        fake_spectra_for_g = netG(noise).squeeze()
        output = netD(fake_spectra_for_g)

        g_loss = criterion(output, real_labels_G)

        g_loss.backward()
        optimizerG.step()

        if (i + 1) % LOG_INTERVAL == 0:
            print(
                f"[Epoch {epoch+1}/{num_epochs}] [Batch {i+1}/{len(dataloader)}] "
                f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}] "
                f"[D Acc: {d_accuracy.item():.2%}]"
            )
            D_losses.append(d_loss.item())
            G_losses.append(g_loss.item())
            D_accuracies.append(d_accuracy.item())
        
    if (epoch + 1) % PLOT_INTERVAL == 0:
        print(f"--- Generating plot for epoch {epoch+1} ---")
        #plot_samples(epoch=epoch, generator=netG, real_samples=fixed_real_samples,dim=noise_size)
        evaluate_and_save(epoch=epoch,dim=noise_size,generator=netG,real_samples=real_samples_large,fid_metric=fid, num_plot_samples=4, num_fid_samples=1000 # Use at least 1000 for decent FID estimation
)


print("--- Training Finished ---")




Starting Training Loop...
[Epoch 1/2000] [Batch 20/24] [D loss: 0.3499] [G loss: 4.6844] [D Acc: 100.00%]
[Epoch 2/2000] [Batch 20/24] [D loss: 1.0296] [G loss: 5.3108] [D Acc: 51.56%]
[Epoch 3/2000] [Batch 20/24] [D loss: 0.4012] [G loss: 5.1869] [D Acc: 99.22%]
[Epoch 4/2000] [Batch 20/24] [D loss: 0.4375] [G loss: 3.2470] [D Acc: 99.22%]
[Epoch 5/2000] [Batch 20/24] [D loss: 0.7624] [G loss: 3.0543] [D Acc: 67.19%]
[Epoch 6/2000] [Batch 20/24] [D loss: 0.6227] [G loss: 4.1071] [D Acc: 100.00%]
[Epoch 7/2000] [Batch 20/24] [D loss: 0.4792] [G loss: 4.6388] [D Acc: 90.62%]
[Epoch 8/2000] [Batch 20/24] [D loss: 0.9260] [G loss: 0.3810] [D Acc: 60.16%]
[Epoch 9/2000] [Batch 20/24] [D loss: 0.5044] [G loss: 4.1600] [D Acc: 99.22%]
[Epoch 10/2000] [Batch 20/24] [D loss: 0.5260] [G loss: 5.8380] [D Acc: 100.00%]
[Epoch 11/2000] [Batch 20/24] [D loss: 1.4065] [G loss: 5.9342] [D Acc: 64.84%]
[Epoch 12/2000] [Batch 20/24] [D loss: 0.4772] [G loss: 3.0320] [D Acc: 100.00%]
[Epoch 13/2000] [Ba