In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import numpy as np
import pandas as pd 
import albumentations as A
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler

# Local
from checkpoint_config import load_checkpoint, checkpoint_dir_name, save_checkpoint
from training_and_val import initialize_model, validate_model
from fasterkan import RSF

device = torch.device("cpu") 
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

# ---------------------- Check RBF's Values ----------------------

def check_rbf_parameters(model):
    print("\nRBF Parameters in Model State Dict:")
    print("=" * 40)
    state_dict = model.state_dict()
    for key, value in state_dict.items():
        if "rbf.grid" in key or "rbf.inv_denominator" in key:
            print(f"{key}: {value}")
    print("=" * 40)
    print()


In [2]:
# path_to_pth = r'Training Checkpoints Poco\Pretrained\45482\BCELoss\Adam\ReduceOnPlateau\3.0e-05\[12288,1024,7]\[4]\-2.0e+00\2.5e-01\1.5e+00\epoch_best\model_checkpoint.pth'

dataset_path = 'Dataset'
csv_path = os.path.join(dataset_path, "HAM10000_metadata.csv")
image_dir = os.path.join(dataset_path, "HAM10000_images")
image_test_dir = image_dir
df = pd.read_csv(csv_path)

unique_diagnoses = df['dx'].unique()
num_classes = len(unique_diagnoses)

nv_df = df[df['dx'] == 'nv'].reset_index()
df = df[~(df['dx'] == 'nv')].reset_index()

# Root Directory to save the training checkpoints
root_dir = r"Dataset/FX-Quantizer"
os.makedirs(root_dir, exist_ok=True)

In [3]:

# Input Dimensions for the model:
x_dim, y_dim = 64,64
channel_size = 3
batch_size = 64
seed = 45482

torch.manual_seed(seed=seed)
np.random.seed(seed)

# Hyperparameter sweep configs
grid_sizes = [[4]]
learning_rates = [3.0e-05,]
hidden_layer_configs = [[1024,]]
epochs = 10
criterion_type = 'BCELoss'
optim_type = 'Adam'
sched_type = 'ReduceOnPlateau'

grid_min_list = [-2,]
grid_max_list = [0.25,]
inv_denominator_list = [1.5,2.0,2.5]

probability = 0.25
pretrained = True

pretrained_top_dir = root_dir.replace('FX-Quantizer', 'Pretrained')

In [4]:
class SkinCancerDataset(torch.utils.data.Dataset):
    def __init__(self, root, csv_path, output_classes, transform=None):

        self.root = root
        self.transform = transform
        if isinstance(csv_path, str):
            df = pd.read_csv(csv_path)
        else:
            df = csv_path
        df = df.reindex()
        self.image_files = [os.path.splitext(f)[0] for f in os.listdir(root) if f.endswith('.jpg')]

        assert np.sum(df['image_id'].isin(self.image_files)) == len(df)
        self.image_files = df['image_id'].values.tolist()
        # Map class names to integer indices
        classes = df['dx'].unique()
        self.class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
        
        self.labels = [self.class_to_idx[cls] for cls in df['dx'].values]
        self.output_classes = output_classes

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        img_path = os.path.join(self.root, f'{self.image_files[index]}.jpg')
        image = np.asarray(Image.open(img_path).convert("RGB"))
        label = self.labels[index]
        if self.transform:
            image = self.transform(image = image)
        if isinstance(image, dict):
            image = image['image']
        return image.to(torch.float32), label


# Define transforms
basic_transform = A.Compose([
    A.Resize(y_dim,x_dim),
    *([A.ToGray(channel_size,p=1),] if  channel_size == 1 else []),
    A.Normalize(),
    A.ToTensorV2(),
])

# Define the augmentation pipeline
augmented_transform = A.Compose([
    A.Resize(y_dim,x_dim),
    A.RandomResizedCrop(size=(x_dim,y_dim), scale=(0.08, 1.0), p=probability),
    A.HorizontalFlip(p=probability),
    A.VerticalFlip(p=probability),
    A.RGBShift(p=probability),
    A.RandomSunFlare(p=probability),
    A.RandomBrightnessContrast(p=probability),
    A.HueSaturationValue(p=probability),
    A.ColorJitter(p=probability),
    A.RandomRotate90(p=probability),
    A.Perspective(p=probability),
    A.MotionBlur(p=probability),
    A.ChannelShuffle(p=probability),
    A.ChannelDropout(p=probability),
    *([A.ToGray(channel_size,p=1),] if  channel_size == 1 else []),
    A.Normalize(),
    A.ToTensorV2(),
])

In [5]:
splits = [0.75, 0.09, 0.16]
full_dataset = SkinCancerDataset(root=image_dir, csv_path=df, output_classes=num_classes,transform=None)
train_dataset, val_dataset, test_dataset = random_split(full_dataset, splits)

if pretrained:
    df = pd.concat([df, nv_df]).reset_index()
    full_dataset = SkinCancerDataset(root=image_dir, csv_path=df, output_classes=num_classes,transform=None)
    
    train_dataset.dataset = full_dataset
    val_dataset.dataset = full_dataset
    test_dataset.dataset = full_dataset
    
    nv_ind = df[df['dx'] == 'nv'].index.to_list()
    tr_nv_ind, val_nv_ind, test_nv_ind = random_split(nv_ind, splits)

    train_dataset.indices += tr_nv_ind.indices
    val_dataset.indices += val_nv_ind.indices
    test_dataset.indices += test_nv_ind.indices
    

# Define the transforms for all splits
train_dataset.dataset.transform = augmented_transform
val_dataset.dataset.transform = basic_transform
test_dataset.dataset.transform = basic_transform

In [6]:
# ---------------------- Load Pre-Trained Floating Point Model ----------------------
models_fp = []
checkpoints_fp = []
dimension_list = [x_dim * y_dim * channel_size] + hidden_layer_configs[0] + [num_classes]

for inv_denom in inv_denominator_list:
    path_to_pth = checkpoint_dir_name(criterion=criterion_type,
                                    optimizer=optim_type,
                                    scheduler=sched_type,
                                    seed=seed,
                                    dim_list=dimension_list,
                                    grid_size = grid_sizes[0],
                                    grid_min=grid_min_list[0],
                                    grid_max=grid_max_list[0],
                                    inv_denominator=inv_denom,
                                    root_dir=pretrained_top_dir,
                                    learning_rate=learning_rates[0],
                                    )
    
    model_tmp, _ = initialize_model(
        root_dir=None,
        dimension=dimension_list,
        grid_size=grid_sizes[0][0],
        lr=learning_rates[0],
        sched=sched_type,
        optim=optim_type,
        criterion=criterion_type,
        grid_min=grid_min_list[0],
        grid_max=grid_max_list[0],
        inv_denominator=inv_denom,
        x_dim=x_dim,
        y_dim=y_dim,
        channel_size=channel_size,
        seed=seed
    )
    path_to_pth = os.path.join(path_to_pth, 'epoch_best','model_checkpoint.pth')
    model_tmp, *_ = load_checkpoint(model_tmp, optimizer_name=optim_type, checkpoint_path=path_to_pth, device='cpu')
    models_fp.append(model_tmp)
    checkpoint_dir = os.path.dirname(path_to_pth).replace(pretrained_top_dir,root_dir)
    checkpoints_fp.append(checkpoint_dir)
    check_rbf_parameters(model_tmp)

Checkpoint loaded from Dataset/Pretrained/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/1.5e+00/epoch_best/model_checkpoint.pth. Next epoch is 498.

RBF Parameters in Model State Dict:
layers.0.rbf.grid: tensor([-1.8899, -1.1920, -0.5019,  0.2129])
layers.0.rbf.inv_denominator: 1.586333990097046
layers.1.rbf.grid: tensor([-1.8616, -1.3367, -0.5240,  0.1764])
layers.1.rbf.inv_denominator: 1.4424347877502441

Checkpoint loaded from Dataset/Pretrained/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/2.0e+00/epoch_best/model_checkpoint.pth. Next epoch is 488.

RBF Parameters in Model State Dict:
layers.0.rbf.grid: tensor([-1.8921, -1.1864, -0.4808,  0.2277])
layers.0.rbf.inv_denominator: 2.073580026626587
layers.1.rbf.grid: tensor([-1.8617, -1.2484, -0.5892,  0.1851])
layers.1.rbf.inv_denominator: 1.9153474569320679

Checkpoint loaded from Dataset/Pretrained/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00

In [7]:
classes = full_dataset.class_to_idx.keys()

# Calculate class weights using the training set 
train_indices = train_dataset.indices # Get indices of the training samples
train_labels = df.loc[train_indices, 'dx']

class_sample_counts = train_labels.value_counts().reindex(classes, fill_value=0).values
class_weights = 1.0 / torch.tensor(class_sample_counts, dtype=torch.float)
train_labels = train_labels.tolist()

dataset_len = int(len(classes) / np.sqrt(np.mean((1 / class_sample_counts) ** 2)))

# Map class names to indices
label_to_idx = full_dataset.class_to_idx
train_label_indices = [label_to_idx[label] for label in train_labels]

num_workers = os.cpu_count()
# num_workers = 0

torch.manual_seed(seed=seed)
sample_weights = [class_weights[label_idx].item() for label_idx in train_label_indices]
sampler = WeightedRandomSampler(sample_weights, num_samples=dataset_len, replacement=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [8]:
# 89.69%
calibration_set_size = 1000
calibration_indices = list(range(calibration_set_size))
calibration_subset = torch.utils.data.Subset(train_dataset, calibration_indices)
calibration_loader = DataLoader(calibration_subset, batch_size=64, shuffle=False)

# 89.57%
# calibration_set_size = 1000
# torch.manual_seed(seed=seed)
# calibration_sampler = WeightedRandomSampler(sample_weights, num_samples=calibration_set_size, replacement=False) # Use a WeightedRandomSampler to select calibration indices
# calibration_loader = DataLoader(train_dataset, batch_size=64, sampler=calibration_sampler, num_workers=num_workers)

# 89.51%
# # Create a calibration subset with equal parts per class (stratified sampling)
# from collections import defaultdict

# calibration_set_size = 1000  # total calibration samples
# num_classes = len(classes)
# samples_per_class = calibration_set_size // num_classes

# # Find indices for each class in the training set
# class_to_indices = defaultdict(list)
# for idx in train_dataset.indices:
#     label = df.loc[idx, 'dx']
#     class_to_indices[label].append(idx)

# # Sample equal number of indices per class
# calibration_indices = []
# for cls in classes:
#     indices = class_to_indices[cls]
#     calibration_indices.extend(indices[:samples_per_class])

# # Create the calibration subset and loader
# calibration_subset = torch.utils.data.Subset(train_dataset.dataset, calibration_indices)
# calibration_loader = DataLoader(calibration_subset, batch_size=batch_size, shuffle=False, num_workers=num_workers) 

# ---------------------- Quantization of RSF Module ----------------------
import copy
from torch.ao.quantization import quantize_fx, get_default_qconfig_mapping

class RSFQuant(nn.Module):
    def __init__(
        self,
        rsf_module
    ):
        super(RSFQuant, self).__init__()
        self.grid = torch.nn.Parameter(rsf_module.grid.data.clone().detach(), requires_grad=False)
        self.inv_denominator = torch.nn.Parameter(torch.tensor(rsf_module.inv_denominator.data.clone().detach(), dtype=torch.float32), requires_grad=False)  # Cache the inverse of the denominator
        self.tanh = nn.Tanh()
        
    def forward(self, x):
        # Compute the forward pass
        diff_mul = (x[..., None] - self.grid) * self.inv_denominator
        tanh_diff = self.tanh(diff_mul)
        tanh_diff_deriviative = 1. - tanh_diff ** 2  # sech^2(x) = 1 - tanh^2(x)
        
        return tanh_diff_deriviative

def convert_to_quantizable(model):
    for name, module in model.named_children():
        if isinstance(module, RSF):
            setattr(model, name, RSFQuant(module)) # Replace RSF with RSFQuant and copy parameters
        else:
            convert_to_quantizable(module) # Recursively process submodules
    return model

# ---------------------- Quantization Pipeline ----------------------

def quantize_model_and_calibrate(model_fp, calibration_loader, checkpoint_dir, device='cpu'):
    """
    Perform quantization on a given model.
    
    Parameters:
    - model_fp: The original full-precision model.
    - calibration_loader: DataLoader for calibration.
    - device: Device to perform quantization on.
    """

    # Create a copy and convert to a quantizable model
    model_to_quantize = copy.deepcopy(model_fp.cpu())
    model_to_quantize = convert_to_quantizable(model_to_quantize)  # Switch RSF with RSFQuant
    
    model_fp.eval()
    model_to_quantize.eval()
    
    # Orepare, Calibrate, Convert
    model_prepared = quantize_fx.prepare_fx(model_to_quantize, get_default_qconfig_mapping("fbgemm"), calibration_loader)  
    model_prepared.to(device)
    
    with torch.no_grad():
        for inputs, _ in calibration_loader:
            inputs = inputs.to(device)
            _ = model_prepared(inputs)
    
    model_quantized_static = quantize_fx.convert_fx(model_prepared.cpu())
    save_checkpoint(checkpoint_dir, model_quantized_static, optimizer=None, epoch='FX-Quantize', loss=None, best_val_loss=None, device=device)
    
    return model_quantized_static.to(device)


In [None]:
def model_size(mdl):
    torch.save(mdl.state_dict(), "tmp.pt")
    size = os.path.getsize("tmp.pt") / 1e6
    os.remove("tmp.pt")
    return f"{size:.2f} MB"

# Quantize all models in models_fp
model_quantized_static_list = [
    quantize_model_and_calibrate(model_fp=mdl, checkpoint_dir=checkpoint_dir, calibration_loader=calibration_loader, device=device)
    for mdl,checkpoint_dir in zip(models_fp,checkpoints_fp)
]


In [10]:
# Evaluate model sizes
print("\nModel Sizes")
print("=" * 40)
for i, (fp, q) in enumerate(zip(models_fp, model_quantized_static_list)):
    print(f"Model {i} size (floating point): {model_size(fp)}")
    print(f"Model {i} size (quantized): {model_size(q)}")
print("=" * 40)

# Evaluate accuracy for all models
print("\nModel Evaluation Results")
print("=" * 40)
for i, (fp, q, cdir) in enumerate(zip(models_fp, model_quantized_static_list,checkpoints_fp)):
    pth = os.path.join(cdir, 'model_checkpoint.pth')
    val_loss_fp, accuracy_fp = validate_model(model=fp, val_loader=test_loader, criterion=criterion_type, checkpoint_path = pth.replace(root_dir,pretrained_top_dir), device=device, metrics_flag=True)
    val_loss_q, accuracy_q = validate_model(model=q, val_loader=test_loader, criterion=criterion_type, checkpoint_path = pth, device='cpu', metrics_flag=True)
    print('Model path :', cdir)
    print(f"Model {i} Floating Point:\n  - Validation Loss: {val_loss_fp:.4f}\n  - Accuracy: {accuracy_fp:.2f}%")
    print(f"Model {i} Quantized:\n  - Validation Loss: {val_loss_q:.4f}\n  - Accuracy: {accuracy_q:.2f}%")
    print("-" * 40)
print("=" * 40)


  self.inv_denominator = torch.nn.Parameter(torch.tensor(rsf_module.inv_denominator.data.clone().detach(), dtype=torch.float32), requires_grad=False)  # Cache the inverse of the denominator



Model Sizes
Model 0 size (floating point): 201.44 MB
Model 0 size (quantized): 50.39 MB
Model 1 size (floating point): 201.44 MB
Model 1 size (quantized): 50.39 MB
Model 2 size (floating point): 201.44 MB
Model 2 size (quantized): 50.39 MB

Model Evaluation Results
Checkpoint loaded from Dataset/Pretrained/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/1.5e+00/epoch_best/model_checkpoint.pth. Next epoch is 498.
Model loaded from checkpoint: Dataset/Pretrained/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/1.5e+00/epoch_best/model_checkpoint.pth


Validating:   0%|          | 0/26 [00:00<?, ?batch/s]

F1 Score: 0.8982
Recall: 0.8982
Accuracy: 89.82%
Confusion Matrix:
[[289   2  23   1  12  11  11]
 [  4  28   2   0   4   0   0]
 [ 30   2 324   5   4   5   1]
 [  2   0   0  34   0   0   0]
 [  5   3   4   0 157   2   1]
 [  3   2   4   0   7  80   1]
 [  5   3   2   0   0   2 526]]
Checkpoint loaded from Dataset/FX-Quantizer/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/1.5e+00/epoch_best/model_checkpoint.pth. Next epoch is FX-Quantize.
Model loaded from checkpoint: Dataset/FX-Quantizer/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/1.5e+00/epoch_best/model_checkpoint.pth


  device=storage.device,


Validating:   0%|          | 0/26 [00:00<?, ?batch/s]

F1 Score: 0.8968
Recall: 0.8969
Accuracy: 89.69%
Confusion Matrix:
[[294   0  27   0  10   7  11]
 [  4  28   3   0   3   0   0]
 [ 31   1 324   5   3   5   2]
 [  1   0   1  34   0   0   0]
 [  9   1   4   0 155   2   1]
 [  5   2   4   0   7  78   1]
 [  8   3   2   0   0   2 523]]
Model path : Dataset/FX-Quantizer/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/1.5e+00/epoch_best
Model 0 Floating Point:
  - Validation Loss: 0.0882
  - Accuracy: 89.82%
Model 0 Quantized:
  - Validation Loss: 0.0907
  - Accuracy: 89.69%
----------------------------------------
Checkpoint loaded from Dataset/Pretrained/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/2.0e+00/epoch_best/model_checkpoint.pth. Next epoch is 488.
Model loaded from checkpoint: Dataset/Pretrained/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/2.0e+00/epoch_best/model_checkpoint.pth


Validating:   0%|          | 0/26 [00:00<?, ?batch/s]

F1 Score: 0.9064
Recall: 0.9063
Accuracy: 90.63%
Confusion Matrix:
[[301   0  24   0   7   7  10]
 [  6  28   1   0   3   0   0]
 [ 27   2 328   3   3   5   3]
 [  2   0   0  33   1   0   0]
 [  4   3   4   0 156   4   1]
 [  1   2   3   0   7  83   1]
 [  7   3   2   0   1   3 522]]
Checkpoint loaded from Dataset/FX-Quantizer/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/2.0e+00/epoch_best/model_checkpoint.pth. Next epoch is FX-Quantize.
Model loaded from checkpoint: Dataset/FX-Quantizer/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/2.0e+00/epoch_best/model_checkpoint.pth


Validating:   0%|          | 0/26 [00:00<?, ?batch/s]

F1 Score: 0.9053
Recall: 0.9051
Accuracy: 90.51%
Confusion Matrix:
[[303   0  26   0   7   3  10]
 [  6  28   1   0   3   0   0]
 [ 27   0 336   3   1   3   1]
 [  3   0   0  32   0   0   1]
 [  6   1   4   0 156   4   1]
 [  1   2   5   0   6  82   1]
 [ 13   3   3   0   3   4 512]]
Model path : Dataset/FX-Quantizer/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/2.0e+00/epoch_best
Model 1 Floating Point:
  - Validation Loss: 0.0820
  - Accuracy: 90.63%
Model 1 Quantized:
  - Validation Loss: 0.0851
  - Accuracy: 90.51%
----------------------------------------
Checkpoint loaded from Dataset/Pretrained/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/2.5e+00/epoch_best/model_checkpoint.pth. Next epoch is 487.
Model loaded from checkpoint: Dataset/Pretrained/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/2.5e+00/epoch_best/model_checkpoint.pth


Validating:   0%|          | 0/26 [00:00<?, ?batch/s]

F1 Score: 0.9123
Recall: 0.9126
Accuracy: 91.26%
Confusion Matrix:
[[303   0  23   0   8   1  14]
 [  6  28   1   0   3   0   0]
 [ 23   0 336   3   3   3   3]
 [  5   0   0  30   0   0   1]
 [  2   1   4   0 160   5   0]
 [  1   2   4   0   7  82   1]
 [  6   3   1   0   3   3 522]]
Checkpoint loaded from Dataset/FX-Quantizer/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/2.5e+00/epoch_best/model_checkpoint.pth. Next epoch is FX-Quantize.
Model loaded from checkpoint: Dataset/FX-Quantizer/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/2.5e+00/epoch_best/model_checkpoint.pth


Validating:   0%|          | 0/26 [00:00<?, ?batch/s]

F1 Score: 0.9064
Recall: 0.9063
Accuracy: 90.63%
Confusion Matrix:
[[310   0  24   0   8   1   6]
 [  6  28   1   0   3   0   0]
 [ 22   0 344   0   2   2   1]
 [  2   0   3  26   2   0   3]
 [  2   1   5   0 162   2   0]
 [  3   2   5   0   6  80   1]
 [ 17   1   8   0   7   4 501]]
Model path : Dataset/FX-Quantizer/45482/BCELoss/Adam/ReduceOnPlateau/3.0e-05/[12288,1024,7]/[4]/-2.0e+00/2.5e-01/2.5e+00/epoch_best
Model 2 Floating Point:
  - Validation Loss: 0.0727
  - Accuracy: 91.26%
Model 2 Quantized:
  - Validation Loss: 0.0806
  - Accuracy: 90.63%
----------------------------------------
