<a href="https://colab.research.google.com/github/kiplangatkorir/Hierarchial-Compression-With-KANs/blob/main/kan_compression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [46]:
pip install torch transformers



In [62]:
import torch
import torch.nn as nn
from transformers import DistilBertModel, DistilBertConfig


In [73]:
class KANLayer(nn.Module):
    """Generate an ``(input_dim, output_dim)`` weight matrix from lookup tables.

    Each of the ``num_terms`` rows of ``psi`` is a table of ``p`` learnable
    values. For every pair ``(i, j)`` and term ``t`` the table position is
    ``clamp(int((phi1[t, i] + phi2[t, j]) * (p - 1)), 0, p - 1)``. The
    resulting weight ``W[i, j]`` is the sum over ``t`` of the looked-up
    ``psi`` values.

    Note: the table position is computed with an integer cast, so no gradient
    flows to ``phi1``/``phi2`` through ``forward``; backprop only trains
    ``psi``.

    Args:
        input_dim: number of rows of the generated matrix.
        output_dim: number of columns of the generated matrix.
        p: size of each ``psi`` lookup table.
        num_terms: number of summed lookup terms (5 in the original design).
    """

    def __init__(self, input_dim, output_dim, p=100, num_terms=5):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.p = p
        self.num_terms = num_terms

        # psi: one lookup table per term; phi1 / phi2: per-row and per-column
        # coordinates whose sum selects the table entry for each (i, j).
        self.psi = nn.Parameter(torch.randn(num_terms, p))
        self.phi1 = nn.Parameter(torch.randn(num_terms, input_dim))
        self.phi2 = nn.Parameter(torch.randn(num_terms, output_dim))

    def forward(self):
        """Return the ``(input_dim, output_dim)`` weight matrix."""
        # Broadcast to (num_terms, input_dim, output_dim).
        phi_sum = self.phi1[:, :, None] + self.phi2[:, None, :]
        # Scale to table positions. .long() truncates toward zero; clamp keeps
        # indices in [0, p - 1], so every negative sum maps to entry 0.
        idx = torch.clamp((phi_sum * (self.p - 1)).long(), 0, self.p - 1)
        # gather needs an index of the same rank as psi: flatten (i, j).
        idx = idx.reshape(self.num_terms, -1)
        weights = self.psi.gather(1, idx).sum(dim=0)
        return weights.view(self.input_dim, self.output_dim)

In [74]:


class KANLinear(nn.Module):
    """Bias-free stand-in for ``nn.Linear`` whose weight comes from a ``KANLayer``.

    Computes ``x @ W`` where ``W`` has shape ``(input_dim, output_dim)``, so
    ``x`` must have ``input_dim`` as its last dimension.
    """

    def __init__(self, input_dim, output_dim, p=100):
        super().__init__()
        self.kan_layer = KANLayer(input_dim, output_dim, p)

    @property
    def weight(self):
        """Current ``(input_dim, output_dim)`` weight, regenerated from ``kan_layer``."""
        return self.kan_layer()

    def forward(self, x):
        # Rebuild the weight on every call. Caching it would freeze it at the
        # first forward's values (optimizer updates to kan_layer would be
        # ignored) and make a second backward() reuse an already-freed graph.
        return torch.matmul(x, self.kan_layer())

def replace_linear_with_kan(module, p=100):
    """Recursively replace every ``nn.Linear`` below *module* with a ``KANLinear``.

    Mutates *module* in place and returns ``None``. Each replacement keeps the
    original ``in_features``/``out_features`` but is freshly initialised: the
    pretrained weight and the bias of the replaced ``nn.Linear`` are discarded.
    Each replacement is moved to the device and dtype of the layer it
    replaces, so this also works on a model that was already moved or cast.

    Args:
        module: root module to rewrite (only its descendants are replaced).
        p: lookup-table size forwarded to ``KANLinear``.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            kan = KANLinear(child.in_features, child.out_features, p)
            kan = kan.to(device=child.weight.device, dtype=child.weight.dtype)
            # Re-assigning an existing child name only swaps the dict value,
            # so it is safe while iterating named_children().
            setattr(module, name, kan)
        else:
            replace_linear_with_kan(child, p)


In [75]:

# Load the pretrained DistilBERT encoder by checkpoint name.
model = DistilBertModel.from_pretrained(
    pretrained_model_name_or_path='distilbert-base-uncased',
)

In [76]:
# Swap every nn.Linear inside the encoder for a KANLinear (p=100, the default).
replace_linear_with_kan(module=model, p=100)

In [77]:
# Example usage: one sequence of 10 random token ids drawn from [0, 1000).
input_ids = torch.randint(low=0, high=1000, size=(1, 10))
outputs = model(input_ids=input_ids)

In [78]:
print("Output shape: " + str(outputs.last_hidden_state.shape))


Output shape: torch.Size([1, 10, 768])


In [79]:
# Fine-tuning setup (example for binary classification)
class DistilBertForSequenceClassification(nn.Module):
    """Linear classification head on the first-position hidden state of an encoder.

    NOTE: this local class shares its name with
    ``transformers.DistilBertForSequenceClassification`` (not imported here);
    it is a separate, minimal implementation.

    Args:
        pretrained_model: encoder called as
            ``pretrained_model(input_ids, attention_mask=...)``. Its output
            must expose ``last_hidden_state`` of shape ``(batch, seq, hidden)``.
        num_labels: number of output classes.
        hidden_size: width of ``last_hidden_state``. Defaults to
            ``pretrained_model.config.dim`` when present, else 768.
    """

    def __init__(self, pretrained_model, num_labels=2, hidden_size=None):
        super().__init__()
        self.distilbert = pretrained_model
        if hidden_size is None:
            # DistilBERT configs store the hidden width as ``dim``. Fall back to
            # 768 (the distilbert-base width) when no such config is present.
            config = getattr(pretrained_model, "config", None)
            hidden_size = getattr(config, "dim", 768)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return classification logits of shape ``(batch, num_labels)``."""
        outputs = self.distilbert(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0]  # Use [CLS] token
        return self.classifier(pooled_output)

In [80]:
# Wrap `model` with a classification head (2 labels, the default).
classification_model = DistilBertForSequenceClassification(
    pretrained_model=model,
    num_labels=2,
)

In [81]:
# Cross-entropy over the class logits; AdamW over every parameter of the
# wrapper (the wrapped encoder plus the classifier head).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=classification_model.parameters(), lr=1e-5)