In [1]:
from src.CKANs import *
from src.utils import *
import pickle as pkl
import matplotlib.pyplot as plt
import time
from tqdm.notebook import tqdm
import os

In [2]:
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Pick one architecture: CKANs, CKANs_InceptionBig or CKANs_BigConvs.
# NOTE(review): input_channels=1, yet the sanity-check batch printed by the
# data-loading cell has shape [32, 3, 100, 100] (3 channels) -- confirm the
# model accepts that input.
model = CKANs(num_classes=3, input_channels=1)
model = model.to(device)

In [4]:
# Parameter count: total number of elements across every parameter tensor
# (trainable or frozen).
total_params = sum(map(torch.numel, model.parameters()))
print(f"Total number of parameters: {total_params}")


Total number of parameters: 10743


In [5]:

# Pickled tuple (all_data_params, all_images).
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
DATA_PATH = '/home/frb/0hzy/DM/binned_data_20.pkl'

with open(DATA_PATH, 'rb') as f:
    all_data_params, all_images = pkl.load(f)

# Index ranges of all_images:
#   0:3600 CDM_low+baryons    3600:7200 CDM_hi+baryons    7200:10800 CDM+baryons
#   10800:14400 SIDM0.1       14400:18000 SIDM0.3         18000:21600 SIDM1.0
# (start, stop, class label) for the three classes used here; SIDM0.3 is skipped.
CLASS_SEGMENTS = [
    (0, 10800, 0),       # all three CDM+baryons variants
    (10800, 14400, 1),   # SIDM0.1
    (18000, 21600, 2),   # SIDM1.0
]
# Set to a small number (e.g. 40) to smoke-test on a per-class subset;
# None uses every image in each segment.
MAX_PER_CLASS = None

image_chunks = []
label_chunks = []
for start, stop, label in CLASS_SEGMENTS:
    if MAX_PER_CLASS is not None:
        stop = min(stop, start + MAX_PER_CLASS)
    chunk = all_images[start:stop]
    image_chunks.append(chunk)
    # Labels are sized from the chunk itself, so they cannot drift out of sync
    # with the selected images (float64 labels, one per image).
    label_chunks.append(np.full(len(chunk), label, dtype=np.float64))

selected_images = np.concatenate(image_chunks)
labels = np.concatenate(label_chunks)
assert len(selected_images) == len(labels), "images and labels are misaligned"

X_train, X_val, y_train, y_val = process_data(selected_images, labels)

BATCH_SIZE = 32
# augmentation_factor > 1 presumably multiplies the data via CustomDataset
# (see src.utils); we train without it.
augmentation_factor = 1
train_dataset = CustomDataset(X_train, y_train, transform=transform_resize, augmentation_factor=augmentation_factor)
val_dataset = CustomDataset(X_val, y_val, transform=transform_resize, augmentation_factor=augmentation_factor)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Inspect one training batch. Distinct names keep the `labels` array above
# from being overwritten by a batch tensor.
sample_images, sample_labels = next(iter(train_loader))
print(f"images shape: {sample_images.shape}\n")
print(f"labels shape: {sample_labels.shape}\n")

print(device)



images shape: torch.Size([32, 3, 100, 100])

labels shape: torch.Size([32])

cuda


In [6]:
    
class Trainer:
    """Train and validate a classifier, checkpoint the best weights, plot curves.

    Each epoch records the mean training loss/accuracy and the validation
    loss/accuracy. Whenever validation loss improves, a detached copy of the
    weights is kept in ``best_model_wts`` and saved to ``checkpoint_path``.
    After a run that completes all epochs, those best weights are loaded back
    into ``model``.

    Args:
        model: the ``nn.Module`` to train; it is updated in place.
        device: device that every batch is moved to.
        lr: Adam learning rate.
        checkpoint_path: file for the best weights. If it already exists, its
            state dict is loaded into ``model`` when the Trainer is created.
        plot_path: file that ``plot_metrics`` saves the figure to.
    """

    def __init__(self, model, device, lr=1e-3,
                 checkpoint_path='your_checkpoint_path', plot_path='your_plot_path'):
        self.model = model
        self.device = device
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
        self.best_val_loss = float('inf')
        self.best_model_wts = None
        self.checkpoint_path = checkpoint_path
        self.plot_path = plot_path

        # Resume from an existing checkpoint. map_location lets weights saved
        # from a GPU run load on a CPU-only machine (and vice versa).
        if os.path.exists(self.checkpoint_path):
            state = torch.load(self.checkpoint_path, map_location=self.device)
            self.model.load_state_dict(state)
            print("Loaded existing model parameters.")

    @staticmethod
    def _snapshot_state(model):
        """Return a deep copy of ``model.state_dict()``.

        ``state_dict()`` returns references to the live tensors. Keeping it
        without copying would let later optimizer steps overwrite the "best"
        weights in place.
        """
        import copy
        return copy.deepcopy(model.state_dict())

    def train(self, train_loader, val_loader, epochs=100):
        """Train for ``epochs`` epochs, validating after each one.

        Args:
            train_loader: iterable of ``(images, labels)`` batches that supports
                ``len()`` (used for the progress bar).
            val_loader: iterable of ``(images, labels)`` validation batches.
            epochs: number of passes over ``train_loader``.

        Ctrl-C stops training early. The history recorded so far is kept, but
        the best weights are not reloaded. Elapsed time is always printed.

        Raises:
            ValueError: if either loader yields no samples.
        """
        start_time = time.time()
        try:
            for epoch in range(epochs):
                epoch_loss, epoch_acc = self._train_one_epoch(train_loader, epoch, epochs)
                self.train_losses.append(epoch_loss)
                self.train_accuracies.append(epoch_acc)

                val_loss, val_acc = self.validate(val_loader)
                self.val_losses.append(val_loss)
                self.val_accuracies.append(val_acc)

                # Keep (and checkpoint) the weights with the lowest validation loss.
                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    self.best_model_wts = self._snapshot_state(self.model)
                    torch.save(self.best_model_wts, self.checkpoint_path)

                print(f'Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

            self.plot_metrics()

            # Restore the best weights. There is nothing to restore if no epoch
            # ran (epochs=0) or validation loss never beat the initial +inf.
            if self.best_model_wts is not None:
                self.model.load_state_dict(self.best_model_wts)

        except KeyboardInterrupt:
            print('Training interrupted.')

        finally:
            elapsed_time = time.time() - start_time
            print(f'Training finished in {elapsed_time // 60:.0f}m {elapsed_time % 60:.0f}s.')

    def _train_one_epoch(self, train_loader, epoch, epochs):
        """Run one optimisation pass and return (mean per-sample loss, accuracy)."""
        # validate() leaves the model in its prior mode, but make sure
        # dropout/batch-norm are in training mode here.
        self.model.train()
        running_loss = 0.0
        running_corrects = 0
        seen = 0
        # The context manager closes the progress bar even if a batch raises.
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}", leave=False) as batch_progress_bar:
            for images, labels in train_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels.long())
                loss.backward()
                self.optimizer.step()

                # The criterion averages over the batch. Re-weight by batch size
                # and count the samples actually seen, so the epoch mean is per sample.
                batch_size = images.size(0)
                running_loss += loss.item() * batch_size
                running_corrects += (outputs.argmax(dim=1) == labels).sum().item()
                seen += batch_size
                batch_progress_bar.update(1)
        if seen == 0:
            raise ValueError("train_loader yielded no samples")
        return running_loss / seen, running_corrects / seen

    def validate(self, val_loader):
        """Return (mean per-sample loss, accuracy) over ``val_loader``.

        Runs in eval mode without gradients and restores the model's previous
        train/eval mode afterwards.

        Raises:
            ValueError: if ``val_loader`` yields no samples.
        """
        was_training = self.model.training
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        try:
            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(self.device), labels.to(self.device)
                    outputs = self.model(images)
                    loss = self.criterion(outputs, labels.long())

                    # Weight by batch size so a short final batch is not
                    # over-counted; this matches the training-loss average.
                    batch_size = labels.size(0)
                    total_loss += loss.item() * batch_size
                    total += batch_size
                    correct += (outputs.argmax(dim=1) == labels).sum().item()
        finally:
            self.model.train(was_training)

        if total == 0:
            raise ValueError("val_loader yielded no samples")
        return total_loss / total, correct / total

    def plot_metrics(self):
        """Plot the loss and accuracy curves side by side and save them to ``plot_path``."""
        epochs = range(1, len(self.train_losses) + 1)
        fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(14, 5))

        ax_loss.plot(epochs, self.train_losses, label='Train Loss')
        ax_loss.plot(epochs, self.val_losses, label='Validation Loss')
        ax_loss.set(title='Loss', xlabel='Epochs', ylabel='Loss')
        ax_loss.legend()

        ax_acc.plot(epochs, self.train_accuracies, label='Train Accuracy')
        ax_acc.plot(epochs, self.val_accuracies, label='Validation Accuracy')
        ax_acc.set(title='Accuracy', xlabel='Epochs', ylabel='Accuracy')
        ax_acc.legend()

        fig.tight_layout()
        fig.savefig(self.plot_path)

    def save_metrics(self, path='metrics.txt'):
        """Write the recorded histories to ``path``, one comma-separated line per metric."""
        rows = [
            ('Train Losses', self.train_losses),
            ('Validation Losses', self.val_losses),
            ('Train Accuracies', self.train_accuracies),
            ('Validation Accuracies', self.val_accuracies),
        ]
        with open(path, 'w') as f:
            for name, values in rows:
                f.write(name + ': ' + ','.join(map(str, values)) + '\n')


In [7]:

print("Go KANs!!!")

# Train for a single epoch here (Trainer.train defaults to 100). Weights with
# the best validation loss are checkpointed and reloaded at the end.
# NOTE(review): the saved output of this cell is `NameError: name 'os' is not
# defined`, although `os` is imported in the first cell -- the output is stale;
# re-run from a fresh kernel.
trainer = Trainer(model=model, device=device)
trainer.train(train_loader=train_loader, val_loader=val_loader, epochs=1)

Go KANs!!!


NameError: name 'os' is not defined

In [8]:
# Call save_yi() (defined in src.CKANs, not shown here) on one inner KAN layer:
# conv1[1] -> conv3[1] -> convs[0] -> conv.
kan_layer = model.conv1[1].conv3[1].convs[0].conv
kan_layer.save_yi()

In [9]:
# Fit a symbolic formula per input feature of the inner KAN layer
# conv1[1] -> conv3[1] -> convs[0] -> conv.
# NOTE(review): in the saved run every candidate failed with "'KANLinear' object
# has no attribute 'fit_function_to_data_with_symbol'" and the call returned
# expr = 0, r2 = -inf -- looks like src.CKANs is missing that helper.
kan_layer = model.conv1[1].conv3[1].convs[0].conv
expr, r2 = kan_layer.fit_symbolic_for_each_feature()

Error fitting poly_1 for x_0: 'KANLinear' object has no attribute 'fit_function_to_data_with_symbol'
Error fitting poly_2 for x_0: 'KANLinear' object has no attribute 'fit_function_to_data_with_symbol'
Error fitting poly_3 for x_0: 'KANLinear' object has no attribute 'fit_function_to_data_with_symbol'
Error fitting poly_4 for x_0: 'KANLinear' object has no attribute 'fit_function_to_data_with_symbol'
Error fitting 1/x for x_0: 'KANLinear' object has no attribute 'fit_function_to_data_with_symbol'
Error fitting 1/x^2 for x_0: 'KANLinear' object has no attribute 'fit_function_to_data_with_symbol'
Error fitting exp for x_0: 'KANLinear' object has no attribute 'fit_function_to_data_with_symbol'
Error fitting sin for x_0: 'KANLinear' object has no attribute 'fit_function_to_data_with_symbol'
Error fitting tan for x_0: 'KANLinear' object has no attribute 'fit_function_to_data_with_symbol'
Error fitting tanh for x_0: 'KANLinear' object has no attribute 'fit_function_to_data_with_symbol'
Error

In [19]:
# Fitted symbolic expression (check the fitting errors above if it is just 0).
print(str(expr))

0


In [20]:
# R^2 of the symbolic fit; -inf means no candidate fit succeeded.
print(str(r2))

-inf


In [12]:
# # Define fit_lib: polynomials (degree 0-6) and elementary functions
# fit_lib = {
#     # 0 到 6 阶多项式
#     'poly_0': lambda x, c0: c0,
#     'poly_1': lambda x, c0, c1: c0 + c1*x,
#     'poly_2': lambda x, c0, c1, c2: c0 + c1*x + c2*x**2,
#     'poly_3': lambda x, c0, c1, c2, c3: c0 + c1*x + c2*x**2 + c3*x**3,
#     'poly_4': lambda x, c0, c1, c2, c3, c4: c0 + c1*x + c2*x**2 + c3*x**3 + c4*x**4,
#     'poly_5': lambda x, c0, c1, c2, c3, c4, c5: c0 + c1*x + c2*x**2 + c3*x**3 + c4*x**4 + c5*x**5,
#     'poly_6': lambda x, c0, c1, c2, c3, c4, c5, c6: c0 + c1*x + c2*x**2 + c3*x**3 + c4*x**4 + c5*x**5 + c6*x**6,
    
#     # Elementary functions
#     '1/x': lambda x: 1/x,
#     '1/x^2': lambda x: 1/x**2,
#     '1/x^3': lambda x: 1/x**3,
#     '1/x^4': lambda x: 1/x**4,
#     'sqrt': lambda x: sp.sqrt(x),
#     '1/sqrt(x)': lambda x: 1/sp.sqrt(x),
#     'exp': lambda x: sp.exp(x),
#     'log': lambda x: sp.log(x),
#     'abs': lambda x: sp.Abs(x),
#     'sin': lambda x: sp.sin(x),
#     'tan': lambda x: sp.tan(x),
#     'tanh': lambda x: sp.tanh(x)
# }


In [13]:
# # now symbolic KANs

# # # set lib or use the default lib
# # lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs','1/x','1/x^2','1/x^3','1/x^4','1/sqrt(x)','0','gaussian','cosh']

# model.kan1.auto_symbolic()
# model.kan2.auto_symbolic()
# model.kan3.auto_symbolic()




In [14]:
# model.kan1.symbolic_formula()[0][0]


In [15]:
# model.kan2.symbolic_formula()[0][0]


In [16]:
# model.kan3.symbolic_formula()[0][0]

In [17]:
# # the convs symbolic 
# # model.conv1[0].fc.auto_symbolic(lib=lib)
# # model.conv2[0].fc.auto_symbolic(lib=lib)
# # model.conv3[0].fc.auto_symbolic(lib=lib)

# model.conv1[0].fc.auto_symbolic()
# model.conv2[0].fc.auto_symbolic()
# model.conv3[0].fc.auto_symbolic()

In [18]:
# # the InceptionCKANs symbolic 
# # model.conv1[0].fc.auto_symbolic(lib=lib)
# # model.conv2[0].fc.auto_symbolic(lib=lib)
# # model.conv3[0].fc.auto_symbolic(lib=lib)

# model.conv1[1].conv1.fc.auto_symbolic()
# model.conv2[1].conv1.fc.auto_symbolic()
# model.conv3[1].conv1.fc.auto_symbolic()

# model.conv1[1].conv3[0].fc.auto_symbolic()
# model.conv2[1].conv3[0].fc.auto_symbolic()
# model.conv3[1].conv3[0].fc.auto_symbolic()

# model.conv1[1].conv3[1].fc.auto_symbolic()
# model.conv2[1].conv3[1].fc.auto_symbolic()
# model.conv3[1].conv3[1].fc.auto_symbolic()