In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import numpy as np
import pandas as pd
import albumentations as A
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler

# Local (updated imports)
from checkpoint import load_model_checkpoint, make_checkpoint_dir_from_hyperparams, save_model_checkpoint, build_hyperparams_typedict
from training import initialize_kan_model, validate_model, train_and_validate_model, update_logs
from fasterkan import RSF
from mapper import get_optimizer, get_scheduler, get_criterion

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------- Check RBF's Values ----------------------

def check_rbf_parameters(model):
    print("\nRBF Parameters in Model State Dict:")
    print("=" * 40)
    state_dict = model.state_dict()
    for key, value in state_dict.items():
        if "rbf.grid" in key or "rbf.inv_denominator" in key:
            print(f"{key}: {value}")
    print("=" * 40)
    print()


In [None]:
# path_to_pth = r'Training Checkpoints Poco\Pretrained\45482\BCELoss\Adam\ReduceOnPlateau\3.0e-05\[12288,1024,7]\[4]\-2.0e+00\2.5e-01\1.5e+00\epoch_best\model_checkpoint.pth'

dataset_path = r'C:\Users\mrlnp\OneDrive - National and Kapodistrian University of Athens\Υπολογιστής\KANs\configs\SKINCANCER\SkinCancerDataset'
csv_path = os.path.join(dataset_path, "HAM10000_metadata.csv")
image_dir = os.path.join(dataset_path, "HAM10000_images")
image_test_dir = image_dir
df = pd.read_csv(csv_path)

unique_diagnoses = df['dx'].unique()
num_classes = len(unique_diagnoses)

nv_df = df[df['dx'] == 'nv'].reset_index()
df = df[~(df['dx'] == 'nv')].reset_index()

# Root Directory to save the training checkpoints 
root_dir = r"checkpoints"
os.makedirs(root_dir, exist_ok=True)

In [3]:

x_dim, y_dim = 64,64
channel_size = 3
batch_size = 64
seed = 45482

torch.manual_seed(seed=seed)
np.random.seed(seed)

# Hyperparameter sweep configs
grid_sizes = [[4]]
learning_rates = [3.0e-05,]
hidden_layer_configs = [[1024,]]
epochs = 10
criterion_type = 'BCELoss'
optim_type = 'Adam'
sched_type = 'ReduceOnPlateau'

grid_min_list = [-2,]
grid_max_list = [0.25,]
inv_denominator_list = [1.5,2.0,2.5]

probability = 0.25
pretrained = True

In [4]:
class SkinCancerDataset(torch.utils.data.Dataset):
    def __init__(self, root, csv_path, output_classes, transform=None):

        self.root = root
        self.transform = transform
        if isinstance(csv_path, str):
            df = pd.read_csv(csv_path)
        else:
            df = csv_path
        df = df.reindex()
        self.image_files = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.jpg')]

        assert np.sum(df['image_id'].isin(self.image_files)) == len(df)
        self.image_files = df['image_id'].values.tolist()
        # Map class names to integer indices
        classes = df['dx'].unique()
        self.class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
        
        self.labels = [self.class_to_idx[cls] for cls in df['dx'].values]
        self.output_classes = output_classes

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        img_path = os.path.join(self.root, f'{self.image_files[index]}.jpg')
        image = np.asarray(Image.open(img_path).convert("RGB"))
        label = self.labels[index]
        if self.transform:
            image = self.transform(image = image)
        if isinstance(image, dict):
            image = image['image']
        return image.to(torch.float32), label


# Define transforms
basic_transform = A.Compose([
    A.Resize(y_dim,x_dim),
    *([A.ToGray(channel_size,p=1),] if  channel_size == 1 else []),
    A.Normalize(),
    A.ToTensorV2(),
])

# Define the augmentation pipeline
augmented_transform = A.Compose([
    A.Resize(y_dim,x_dim),
    A.RandomResizedCrop(size=(x_dim,y_dim), scale=(0.08, 1.0), p=probability),
    A.HorizontalFlip(p=probability),
    A.VerticalFlip(p=probability),
    A.RGBShift(p=probability),
    A.RandomSunFlare(p=probability),
    A.RandomBrightnessContrast(p=probability),
    A.HueSaturationValue(p=probability),
    A.ColorJitter(p=probability),
    A.RandomRotate90(p=probability),
    A.Perspective(p=probability),
    A.MotionBlur(p=probability),
    A.ChannelShuffle(p=probability),
    A.ChannelDropout(p=probability),
    *([A.ToGray(channel_size,p=1),] if  channel_size == 1 else []),
    A.Normalize(),
    A.ToTensorV2(),
])

In [5]:
splits = [0.75, 0.09, 0.16]
full_dataset = SkinCancerDataset(root=image_dir, csv_path=df, output_classes=num_classes,transform=None)
train_dataset, val_dataset, test_dataset = random_split(full_dataset, splits)

if pretrained:
    df = pd.concat([df, nv_df]).reset_index()
    full_dataset = SkinCancerDataset(root=image_dir, csv_path=df, output_classes=num_classes,transform=None)
    
    train_dataset.dataset = full_dataset
    val_dataset.dataset = full_dataset
    test_dataset.dataset = full_dataset
    
    nv_ind = df[df['dx'] == 'nv'].index.to_list()
    tr_nv_ind, val_nv_ind, test_nv_ind = random_split(nv_ind, splits)

    train_dataset.indices += tr_nv_ind.indices
    val_dataset.indices += val_nv_ind.indices
    test_dataset.indices += test_nv_ind.indices
    

# Define the transforms for all splits
train_dataset.dataset.transform = augmented_transform
val_dataset.dataset.transform = basic_transform
test_dataset.dataset.transform = basic_transform

In [6]:
# ---------------------- Load Pre-Trained Floating Point Model ----------------------
models_fp = []
dimension_list = [x_dim * y_dim * channel_size] + hidden_layer_configs[0] + [num_classes]

# Build hyperparams dict for new checkpoint logic
hyperparams = {
    'seed': seed,
    'criterion': criterion_type,
    'optimizer': optim_type,
    'scheduler': sched_type,
    'dim_list': dimension_list,
    'learning_rate': learning_rates[0],
    'grid_size_per_layer': grid_sizes[0],
    'grid_min': grid_min_list[0],
    'grid_max': grid_max_list[0],
    'inv_denominator': inv_denominator_list[0],
}
hyperparams_typedict = build_hyperparams_typedict(
    seed=int, criterion=str, optimizer=str, scheduler=str, dim_list=list, learning_rate=float,
    grid_size_per_layer=list, grid_min=float, grid_max=float, inv_denominator=float)

for inv_denom in inv_denominator_list:
    hyperparams['inv_denominator'] = inv_denom
    checkpoint_dir = make_checkpoint_dir_from_hyperparams(hyperparams_typedict, hyperparams, root_dir)
    checkpoint_path = os.path.join(checkpoint_dir, 'epoch_best', 'model_checkpoint.pth')
    
    model_tmp, _, _, _, _ = load_model_checkpoint(
        model=initialize_kan_model(
            root_dir=root_dir,
            hyperparams=hyperparams,
            x_dim=x_dim,
            y_dim=y_dim,
            channel_size=channel_size,
            hyperparams_typedict=hyperparams_typedict,
            device=device
        )[0],
        device=device,
        checkpoint_path=checkpoint_path,
        optimizer_type=optim_type,
        optimizer_params=None
    )
    models_fp.append(model_tmp)
    check_rbf_parameters(model_tmp)
    print("State dict keys:", list(model_tmp.state_dict().keys()))


[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.


FileNotFoundError: Checkpoint file not found at path: checkpoints\45482\BCELoss\Adam\ReduceOnPlateau\[12288, 1024, 7]\3e-05\[4]\-2.0\0.25\1.5\epoch_best\model_checkpoint.pth

In [None]:
classes = full_dataset.class_to_idx.keys()

# Calculate class weights using the training set 
train_indices = train_dataset.indices # Get indices of the training samples
train_labels = df.loc[train_indices, 'dx']

class_sample_counts = train_labels.value_counts().reindex(classes, fill_value=0).values
class_weights = 1.0 / torch.tensor(class_sample_counts, dtype=torch.float)
train_labels = train_labels.tolist()

dataset_len = int(len(classes) / np.sqrt(np.mean((1 / class_sample_counts) ** 2)))

# Map class names to indices
label_to_idx = full_dataset.class_to_idx
train_label_indices = [label_to_idx[label] for label in train_labels]

num_workers = os.cpu_count()
# num_workers = 0

torch.manual_seed(seed=seed)
sample_weights = [class_weights[label_idx].item() for label_idx in train_label_indices]
sampler = WeightedRandomSampler(sample_weights, num_samples=dataset_len, replacement=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [28]:
# 89.69%
calibration_set_size = 1000
calibration_indices = list(range(calibration_set_size))
calibration_subset = torch.utils.data.Subset(train_dataset, calibration_indices)
calibration_loader = DataLoader(calibration_subset, batch_size=64, shuffle=False)

# 89.57%
# calibration_set_size = 1000
# torch.manual_seed(seed=seed)
# calibration_sampler = WeightedRandomSampler(sample_weights, num_samples=calibration_set_size, replacement=False) # Use a WeightedRandomSampler to select calibration indices
# calibration_loader = DataLoader(train_dataset, batch_size=64, sampler=calibration_sampler, num_workers=num_workers)

# 89.51%
# # Create a calibration subset with equal parts per class (stratified sampling)
# from collections import defaultdict

# calibration_set_size = 1000  # total calibration samples
# num_classes = len(classes)
# samples_per_class = calibration_set_size // num_classes

# # Find indices for each class in the training set
# class_to_indices = defaultdict(list)
# for idx in train_dataset.indices:
#     label = df.loc[idx, 'dx']
#     class_to_indices[label].append(idx)

# # Sample equal number of indices per class
# calibration_indices = []
# for cls in classes:
#     indices = class_to_indices[cls]
#     calibration_indices.extend(indices[:samples_per_class])

# # Create the calibration subset and loader
# calibration_subset = torch.utils.data.Subset(train_dataset.dataset, calibration_indices)
# calibration_loader = DataLoader(calibration_subset, batch_size=batch_size, shuffle=False, num_workers=num_workers) 

# ---------------------- Quantization of RSF Module ----------------------
import copy
from torch.ao.quantization import quantize_fx, get_default_qconfig_mapping

class RSFQuant(nn.Module):
    def __init__(
        self,
        rsf_module
    ):
        super(RSFQuant, self).__init__()
        self.grid = torch.nn.Parameter(rsf_module.grid.detach().clone(), requires_grad=False)
        self.inv_denominator = torch.nn.Parameter(torch.tensor(rsf_module.inv_denominator.detach().clone(), dtype=torch.float32, device=device), requires_grad=False)  # Cache the inverse of the denominator
        self.tanh = nn.Tanh()
        
    def forward(self, x):
        # Compute the forward pass
        diff_mul = (x[..., None] - self.grid) * self.inv_denominator
        tanh_diff = self.tanh(diff_mul)
        tanh_diff_deriviative = 1. - tanh_diff ** 2  # sech^2(x) = 1 - tanh^2(x)
        
        return tanh_diff_deriviative

def convert_to_quantizable(model):
    for name, module in model.named_children():
        if isinstance(module, RSF):
            setattr(model, name, RSFQuant(module)) # Replace RSF with RSFQuant and copy parameters
        else:
            convert_to_quantizable(module) # Recursively process submodules
    return model

# ---------------------- Quantization Pipeline ----------------------

def quantize_model_and_calibrate(model_fp, calibration_loader, device=device):
    """
    Perform quantization on a given model.
    
    Parameters:
    - model_fp: The original full-precision model.
    - calibration_loader: DataLoader for calibration.
    - device: Device to perform quantization on.
    """

    # Create a copy and convert to a quantizable model
    model_to_quantize = copy.deepcopy(model_fp)
    model_to_quantize = convert_to_quantizable(model_to_quantize)  # Switch RSF with RSFQuant
    
    model_fp.eval()
    model_to_quantize.eval()
    
    # Orepare, Calibrate, Convert
    model_prepared = quantize_fx.prepare_fx(model_to_quantize, get_default_qconfig_mapping("fbgemm"), calibration_loader)  
    
    with torch.no_grad():
        for inputs, _ in calibration_loader:
            _ = model_prepared(inputs.to(device))
    
    model_quantized_static = quantize_fx.convert_fx(model_prepared)
    torch.save(model_quantized_static.state_dict(), "model_checkpoint_quantized.pth")
    
    return model_quantized_static


In [29]:
def model_size(mdl):
    torch.save(mdl.state_dict(), "tmp.pt")
    size = os.path.getsize("tmp.pt") / 1e6
    os.remove("tmp.pt")
    return f"{size:.2f} MB"

# Quantize all models in models_fp
model_quantized_static_list = [
    quantize_model_and_calibrate(model_fp=mdl, calibration_loader=calibration_loader)
    for mdl in models_fp
]

# Evaluate model sizes
print("\nModel Sizes")
print("=" * 40)
for i, (fp, q) in enumerate(zip(models_fp, model_quantized_static_list)):
    print(f"Model {i} size (floating point): {model_size(fp)}")
    print(f"Model {i} size (quantized): {model_size(q)}")
print("=" * 40)

# Evaluate accuracy for all models
print("\nModel Evaluation Results")
print("=" * 40)
for i, (fp, q) in enumerate(zip(models_fp, model_quantized_static_list)):
    val_loss_fp, accuracy_fp = validate_model(model=fp, val_loader=test_loader, criterion=criterion_type, device=device, metrics_flag=True)
    val_loss_q, accuracy_q = validate_model(model=q, val_loader=test_loader, criterion=criterion_type, device=device, metrics_flag=True)
    print(f"Model {i} Floating Point:\n  - Validation Loss: {val_loss_fp:.4f}\n  - Accuracy: {accuracy_fp:.2f}%")
    print(f"Model {i} Quantized:\n  - Validation Loss: {val_loss_q:.4f}\n  - Accuracy: {accuracy_q:.2f}%")
    print("-" * 40)
print("=" * 40)


  self.inv_denominator = torch.nn.Parameter(torch.tensor(rsf_module.inv_denominator.detach().clone(), dtype=torch.float32, device=device), requires_grad=False)  # Cache the inverse of the denominator



Model Sizes
Model 0 size (floating point): 201.45 MB
Model 0 size (quantized): 50.39 MB
Model 1 size (floating point): 201.45 MB
Model 1 size (quantized): 50.39 MB
Model 2 size (floating point): 201.45 MB
Model 2 size (quantized): 50.39 MB

Model Evaluation Results


Validating:   0%|          | 0/26 [00:00<?, ?batch/s]

F1 Score: 0.8982
Recall: 0.8982
Accuracy: 89.82%
Confusion Matrix:
[[289   2  23   1  12  11  11]
 [  4  28   2   0   4   0   0]
 [ 30   2 324   5   4   5   1]
 [  2   0   0  34   0   0   0]
 [  5   3   4   0 157   2   1]
 [  3   2   4   0   7  80   1]
 [  5   3   2   0   0   2 526]]


Validating:   0%|          | 0/26 [00:00<?, ?batch/s]

F1 Score: 0.8968
Recall: 0.8969
Accuracy: 89.69%
Confusion Matrix:
[[294   0  27   0  10   7  11]
 [  4  28   3   0   3   0   0]
 [ 31   1 324   5   3   5   2]
 [  1   0   1  34   0   0   0]
 [  9   1   4   0 155   2   1]
 [  5   2   4   0   7  78   1]
 [  8   3   2   0   0   2 523]]
Model 0 Floating Point:
  - Validation Loss: 0.0882
  - Accuracy: 89.82%
Model 0 Quantized:
  - Validation Loss: 0.0908
  - Accuracy: 89.69%
----------------------------------------


Validating:   0%|          | 0/26 [00:00<?, ?batch/s]

F1 Score: 0.9064
Recall: 0.9063
Accuracy: 90.63%
Confusion Matrix:
[[301   0  24   0   7   7  10]
 [  6  28   1   0   3   0   0]
 [ 27   2 328   3   3   5   3]
 [  2   0   0  33   1   0   0]
 [  4   3   4   0 156   4   1]
 [  1   2   3   0   7  83   1]
 [  7   3   2   0   1   3 522]]


Validating:   0%|          | 0/26 [00:00<?, ?batch/s]

F1 Score: 0.9065
Recall: 0.9063
Accuracy: 90.63%
Confusion Matrix:
[[303   0  26   0   8   2  10]
 [  6  28   1   0   3   0   0]
 [ 27   0 336   3   1   3   1]
 [  3   0   0  32   0   0   1]
 [  6   1   4   0 158   2   1]
 [  2   2   5   0   5  82   1]
 [ 13   3   3   0   3   4 512]]
Model 1 Floating Point:
  - Validation Loss: 0.0820
  - Accuracy: 90.63%
Model 1 Quantized:
  - Validation Loss: 0.0850
  - Accuracy: 90.63%
----------------------------------------


Validating:   0%|          | 0/26 [00:00<?, ?batch/s]

F1 Score: 0.9123
Recall: 0.9126
Accuracy: 91.26%
Confusion Matrix:
[[303   0  23   0   8   1  14]
 [  6  28   1   0   3   0   0]
 [ 23   0 336   3   3   3   3]
 [  5   0   0  30   0   0   1]
 [  2   1   4   0 160   5   0]
 [  1   2   4   0   7  82   1]
 [  6   3   1   0   3   3 522]]


Validating:   0%|          | 0/26 [00:00<?, ?batch/s]

F1 Score: 0.9058
Recall: 0.9057
Accuracy: 90.57%
Confusion Matrix:
[[309   0  25   0   8   1   6]
 [  6  28   1   0   3   0   0]
 [ 22   0 345   0   2   2   0]
 [  3   0   2  26   2   0   3]
 [  2   1   5   0 162   2   0]
 [  3   2   5   0   6  80   1]
 [ 17   2   9   0   7   3 500]]
Model 2 Floating Point:
  - Validation Loss: 0.0727
  - Accuracy: 91.26%
Model 2 Quantized:
  - Validation Loss: 0.0806
  - Accuracy: 90.57%
----------------------------------------
