# AutoEncoders using Convolutional KANs

### Importing Libraries

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will print the first file found in each directory under the input directory

import os
# Print one sample file per directory under the Kaggle input root, so the
# output stays short even when a directory holds thousands of images.
for dirname, _, filenames in os.walk('/kaggle/input'):
    if filenames:
        print(os.path.join(dirname, filenames[0]))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

### Install convkan
Library can be found here: https://pypi.org/project/convkan/

In [None]:
!pip install convkan

### Training and Validation Data Augmentation 

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
from convkan import ConvKAN, LayerNorm2D


# Disease categories, named exactly like the sub-directories of the dataset.
# The position of a name in this list is its integer class label.
categories = [
    'Herpes HPV and other STDs Photos',
    'Acne and Rosacea Photos',
    'Light Diseases and Disorders of Pigmentation',
    'Scabies Lyme Disease and other Infestations and Bites',
    'Poison Ivy Photos and other Contact Dermatitis',
    'Vascular Tumors',
    'Psoriasis pictures Lichen Planus and related diseases',
    'Vasculitis Photos',
    'Lupus and other Connective Tissue diseases',
    'Urticaria Hives',
    'Nail Fungus and other Nail Disease',
    'Systemic Disease',
    'Tinea Ringworm Candidiasis and other Fungal Infections',
    'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions',
    'Atopic Dermatitis Photos',
    'Warts Molluscum and other Viral Infections',
    'Melanoma Skin Cancer Nevi and Moles',
    'Hair Loss Photos Alopecia and other Hair Diseases',
    'Cellulitis Impetigo and other Bacterial Infections',
    'Seborrheic Keratoses and other Benign Tumors',
    'Bullous Disease Photos',
    'Eczema Photos',
    'Exanthems and Drug Eruptions',
]

# Lookup table from category name to its integer label (its list position).
category_to_label = dict(zip(categories, range(len(categories))))

# Preprocessing pipeline applied to every PIL image before it reaches the model.
to_grayscale = transforms.Grayscale(num_output_channels=1)  # single-channel image
random_crop = transforms.RandomResizedCrop(112)  # random crop, resized to 112x112
to_tensor = transforms.ToTensor()  # PIL image -> float tensor (C, H, W)
normalize = transforms.Normalize(mean=[0.5], std=[0.5])  # (x - 0.5) / 0.5

transform = transforms.Compose([to_grayscale, random_crop, to_tensor, normalize])


# Custom Dataset class
class SkinDiseaseDataset(Dataset):
    """Image dataset laid out as ``root_dir/<category name>/<image file>``.

    Each sample is an ``(image, label)`` pair, where ``label`` is the integer
    that the label mapping assigns to the image's category directory.
    Directories whose name is not in the mapping are ignored.
    """

    def __init__(self, root_dir, transform=None, label_map=None):
        """
        Args:
            root_dir (str): Root directory with subdirectories for each category.
            transform (callable, optional): Optional transform to be applied on a sample.
            label_map (dict, optional): Mapping from category directory name to
                integer label. Defaults to the module-level ``category_to_label``.
        """
        self.root_dir = root_dir
        self.transform = transform
        # Resolve the default lazily so an explicit mapping never touches the global.
        self.label_map = category_to_label if label_map is None else label_map
        self.annotations = self._generate_annotations()

    def _generate_annotations(self):
        """Scan ``root_dir`` and return a list of ``(relative_path, label)`` pairs.

        Directory listings are sorted because ``os.listdir`` returns entries in
        arbitrary order; sorting keeps the index -> sample mapping reproducible
        across runs and machines.
        """
        annotations = []
        for category in sorted(os.listdir(self.root_dir)):
            category_path = os.path.join(self.root_dir, category)
            if not os.path.isdir(category_path) or category not in self.label_map:
                continue
            label = self.label_map[category]
            for img_name in sorted(os.listdir(category_path)):
                # Skip nested directories: they cannot be opened as images.
                if not os.path.isfile(os.path.join(category_path, img_name)):
                    continue
                # Store the path relative to root_dir together with the label.
                annotations.append((os.path.join(category, img_name), label))
        return annotations

    def __len__(self):
        """Return the number of samples found under ``root_dir``."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is decoded as RGB and then passed through ``transform``
        when one was supplied.
        """
        img_path, label = self.annotations[idx]
        full_img_path = os.path.join(self.root_dir, img_path)

        # The context manager releases the file handle as soon as decoding is done.
        with Image.open(full_img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

# Initialize the training dataset and dataloader
data_dir = '/kaggle/input/dermnet/train'  # Update this path to match your data directory
dataset = SkinDiseaseDataset(root_dir=data_dir, transform=transform)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)

# Initialize the held-out dataset and dataloader.
# NOTE(review): `transform` contains a random crop, so validation inputs differ
# from epoch to epoch; consider a deterministic resize for evaluation.
val_dir = '/kaggle/input/dermnet/test'
val_dataset = SkinDiseaseDataset(root_dir=val_dir, transform=transform)
# Fix: this loader must wrap `val_dataset`. It previously wrapped the training
# `dataset`, so the reported validation loss was computed on training images.
test_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)

# Sanity check: inspect the first training batch only
for image, label in train_loader:
    print(f'Image batch shape: {image.shape}')  # Check the shape (Batch size, Channels, Height, Width)
    print(f'Label batch shape: {len(label)}')  # Check the number of labels
    break

### Model Architecture


In [None]:
import torch
import torch.nn as nn
from convkan import ConvKAN, LayerNorm2D

# Simplified Encoder
class Encoder(nn.Module):
    """Convolutional-KAN encoder that compresses an image into a latent vector.

    Five stride-2 ``ConvKAN`` stages (each followed by ``LayerNorm2D`` and
    ``LeakyReLU``) reduce the spatial size while widening the channels to 32.
    The result is flattened and projected to ``latent_dim`` by a linear layer.

    Args:
        in_channels: Number of channels of the input images.
        latent_dim: Size of the latent vector produced per image.
        input_size: Height/width of the (square) input images. Defaults to
            112, which yields a 4x4 feature map before flattening.
    """

    def __init__(self, in_channels, latent_dim, input_size=112):
        super(Encoder, self).__init__()
        self.conv_layers = nn.Sequential(
            ConvKAN(in_channels, 4, kernel_size=3, stride=2, padding=1),
            LayerNorm2D(4),
            nn.LeakyReLU(0.2, inplace=True),

            ConvKAN(4, 8, kernel_size=3, stride=2, padding=1),
            LayerNorm2D(8),
            nn.LeakyReLU(0.2, inplace=True),

            ConvKAN(8, 16, kernel_size=3, stride=2, padding=1),
            LayerNorm2D(16),
            nn.LeakyReLU(0.2, inplace=True),

            ConvKAN(16, 16, kernel_size=3, stride=2, padding=1),
            LayerNorm2D(16),
            nn.LeakyReLU(0.2, inplace=True),

            ConvKAN(16, 32, kernel_size=3, stride=2, padding=1),
            LayerNorm2D(32),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.flatten = nn.Flatten()

        # Derive the flattened size instead of hard-coding it for 112x112 input.
        # A kernel-3 / stride-2 / padding-1 convolution maps a side length s to
        # (s - 1) // 2 + 1 (assumes ConvKAN follows the nn.Conv2d output-size
        # rule, which matches the previously hard-coded 4x4 map for 112x112:
        # 112 -> 56 -> 28 -> 14 -> 7 -> 4).
        side = input_size
        for _ in range(5):  # five stride-2 stages above
            side = (side - 1) // 2 + 1
        self.feature_size = 32 * side * side  # 32 channels, side x side feature map

        self.fc = nn.Linear(self.feature_size, latent_dim)

    def forward(self, x):
        """Encode a batch of images into latent vectors of size ``latent_dim``."""
        x = self.conv_layers(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

class Decoder(nn.Module):
    """Convolutional-KAN decoder that expands a latent vector into an image.

    A linear layer maps the latent vector to a 32x4x4 feature map. Four
    nearest-neighbour x2 upsampling stages (4 -> 8 -> 16 -> 32 -> 64) follow,
    each with a ``ConvKAN``, ``LayerNorm2D`` and ``LeakyReLU``. A final
    bilinear x1.75 upsampling (64 -> 112), a ``ConvKAN`` to ``out_channels``
    and a ``Tanh`` produce the output.

    Args:
        latent_dim: Size of the incoming latent vector.
        out_channels: Number of channels of the reconstructed image.
    """

    def __init__(self, latent_dim, out_channels):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 32 * 4 * 4)
        self.unflatten = nn.Unflatten(1, (32, 4, 4))

        # (input channels, output channels) of each x2 upsampling stage
        stage_channels = ((32, 16), (16, 8), (8, 8), (8, 4))

        layers = []
        for c_in, c_out in stage_channels:
            layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
            layers.append(ConvKAN(c_in, c_out, kernel_size=3, stride=1, padding=1))
            layers.append(LayerNorm2D(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Output head: 64 -> 112 pixels, then project to the image channels.
        layers.append(nn.Upsample(scale_factor=1.75, mode='bilinear', align_corners=True))
        layers.append(ConvKAN(4, out_channels, kernel_size=3, stride=1, padding=1))
        layers.append(nn.Tanh())

        self.deconv_layers = nn.Sequential(*layers)

    def forward(self, x):
        """Decode a batch of latent vectors into images with values in [-1, 1]."""
        feature_map = self.unflatten(self.fc(x))
        return self.deconv_layers(feature_map)





# Autoencoder
class Autoencoder(nn.Module):
    """Encoder/decoder pair that reconstructs its input through a bottleneck.

    Args:
        in_channels: Channels of the input image; the reconstruction has the same count.
        latent_dim: Size of the latent vector between encoder and decoder.
    """

    def __init__(self, in_channels=3, latent_dim=10):
        super().__init__()
        self.encoder = Encoder(in_channels, latent_dim)
        self.decoder = Decoder(latent_dim, in_channels)

    def forward(self, x):
        """Return the reconstruction of ``x``."""
        return self.decoder(self.encoder(x))

### Testing the shape of the model

In [None]:
# Smoke test: push one random image through the autoencoder and compare shapes.
model = Autoencoder(in_channels=1, latent_dim=10)
sample_input = torch.randn(1, 1, 112, 112)  # Batch size 1, 1 channel, 112x112
# Only the shapes are inspected, so skip building the autograd graph.
with torch.no_grad():
    output = model(sample_input)

print("Input shape:", sample_input.shape)
print("Output shape:", output.shape)


### Optimization and Loss Functions

In [None]:

from torch.cuda.amp import GradScaler,autocast
from torch import optim
# Instantiate the autoencoder used for training
latent_dim = 130  # Size of the bottleneck (latent) vector
in_channels = 1  # Single-channel input: the transform converts images to grayscale
model = Autoencoder(in_channels=in_channels, latent_dim=latent_dim)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Loss function and optimizer: pixel-wise reconstruction error, minimised with Adam
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# Gradient scaler for mixed precision training. It is only enabled when the
# model actually runs on CUDA; when disabled, scale/step/update are pass-throughs,
# so the training loop works unchanged on CPU.
scaler = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda'))

### Training and Validation Loop

In [None]:
from torch.amp import autocast  # Import autocast from torch.amp
import matplotlib.pyplot as plt
# Training loop with mixed precision and memory optimizations.
# Each epoch: train on `train_loader`, evaluate on `test_loader`, then plot the
# first few evaluation images next to their reconstructions.
num_epochs = 50
# Mixed precision is only used on CUDA; on CPU autocast is switched off instead
# of requesting the "cuda" backend that is not there.
use_amp = device.type == 'cuda'
for epoch in range(num_epochs):
    print('training_in_session')
    model.train()
    running_loss = 0.0
    batch_count = len(train_loader)
    for batch_idx, (images, _) in enumerate(train_loader):
        images = images.to(device, non_blocking=True)
        optimizer.zero_grad()
        with autocast(device_type=device.type, enabled=use_amp):
            outputs = model(images)
            # Autoencoder objective: the target is the input itself.
            loss = criterion(outputs, images)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        batch_loss = loss.item()
        # Weight by batch size so the epoch average is per-sample.
        running_loss += batch_loss * images.size(0)

        # Print batch status every 50 batches
        if (batch_idx + 1) % 50 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{batch_count}], Loss: {batch_loss:.4f}')

        # Drop references and release cached GPU blocks before the next batch.
        del images, outputs, loss
        torch.cuda.empty_cache()
    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {epoch_loss:.4f}')

    # Validation phase
    model.eval()
    val_loss = 0.0
    images_plot = None
    outputs_plot = None
    with torch.no_grad():
        batch_count = len(test_loader)
        for idx, (images, _) in enumerate(test_loader):
            images = images.to(device, non_blocking=True)
            with autocast(device_type=device.type, enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, images)
            val_loss += loss.item() * images.size(0)

            # Keep the first batch for plotting. Cast to float32 because
            # autocast may have produced half-precision outputs.
            if images_plot is None and idx == 0:
                images_plot = images.float().cpu()
                outputs_plot = outputs.float().cpu()

            del images, outputs, loss
            torch.cuda.empty_cache()
    val_loss = val_loss / len(test_loader.dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}')

    # Plot reconstructed images after each epoch
    if images_plot is not None:
        # Denormalize images for plotting: [-1, 1] -> [0, 1]
        images_plot = images_plot * 0.5 + 0.5
        outputs_plot = outputs_plot * 0.5 + 0.5

        num_images = min(images_plot.size(0), 6)
        # squeeze=False keeps `axes` two-dimensional even for a single column.
        fig, axes = plt.subplots(2, num_images, figsize=(num_images * 2, 4), squeeze=False)
        for i in range(num_images):
            for row, batch in enumerate((images_plot, outputs_plot)):
                # CHW -> HWC; a single-channel image collapses to HxW so it is
                # drawn as true grayscale on a fixed 0..1 range, which keeps
                # originals and reconstructions visually comparable.
                picture = batch[i].permute(1, 2, 0).squeeze(-1).numpy()
                axes[row, i].imshow(picture, cmap='gray', vmin=0.0, vmax=1.0)
                # Hide ticks only: axis('off') would also hide the row labels.
                axes[row, i].set_xticks([])
                axes[row, i].set_yticks([])
        axes[0, 0].set_ylabel('Original')
        axes[1, 0].set_ylabel('Reconstructed')
        plt.suptitle(f'Epoch {epoch+1}')
        plt.tight_layout()
        plt.show()
        plt.close(fig)  # avoid accumulating one open figure per epoch

        del images_plot, outputs_plot
        torch.cuda.empty_cache()
print('Training completed.')

### Run Garbage Collection in Case of Failure

In [None]:
import gc
# Force a full garbage-collection pass; the call returns the number of
# unreachable objects it found.
gc.collect()