In [5]:
from src.CKANs import *
from src.utils import *
import pickle as pkl
import matplotlib.pyplot as plt
import time
from tqdm.notebook import tqdm

In [6]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
# Example usage: pick one of the model classes CKANs / CKANs_InceptionBig / CKANs_BigConvs
# (presumably all exported by src.CKANs -- confirm) and instantiate it here.
# The network takes single-channel inputs and predicts one of three classes.
model = CKANs(input_channels=1, num_classes=3, device=device)

In [8]:
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (plus all CUDA devices when CUDA is available), then configures cuDNN for
    deterministic behaviour with benchmark auto-tuning disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = [np.random.seed, random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.extend([torch.cuda.manual_seed, torch.cuda.manual_seed_all])
    for seed_fn in seeders:
        seed_fn(seed)

    # Trade speed for reproducibility in cuDNN kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [9]:

# Load the pickled (parameters, images) pair from disk.
# NOTE: pickle can execute arbitrary code while loading -- only open files you trust.
with open('binned_data_20.pkl', 'rb') as data_file:
    all_data_params, all_images = pkl.load(data_file)






    

In [None]:
# Blind test on SIDM0.3: that slice is left out of the training selection below.
#
# Layout of all_images:
#   0:3600      CDM_low+baryons   3600:7200   CDM_hi+baryons   7200:10800  CDM+baryons
#   10800:14400 SIDM0.1           14400:18000 SIDM0.3          18000:21600 SIDM1.0
cdm_images = all_images[:10800]
sidm0_1_images = all_images[10800:14400]
sidm1_0_images = all_images[18000:21600]
selected_images = np.concatenate((cdm_images, sidm0_1_images, sidm1_0_images))

# Class ids, in the same order as the concatenation above:
# 0 for the first 10800 images, 1 for the next 3600, 2 for the last 3600.
labels = np.repeat([0.0, 1.0, 2.0], [10800, 3600, 3600])



In [None]:
# # test CDM_hi

# labels = np.zeros(14400)
# labels[:7200] = 0
# labels[7200:10800] = 1
# labels[10800:] = 2

# # 0:3600___CDM_low+baryons 3600:7200___CDM_hi+baryons 7200:10800___CDM+baryons 
# # 10800:14400___SIDM0.1 14400:18000___SIDM0.3 18000:21600___SIDM1.0
# selected_images = np.concatenate((all_images[:3600],all_images[7200:10800],all_images[10800:14400], all_images[18000:21600]))

In [None]:

# Produce the training / validation features and labels from the selected images.
# NOTE(review): process_data is not defined in this notebook (presumably it comes from
# src.utils); any preprocessing and the split ratio it applies should be confirmed there.
X_train, X_val, y_train, y_val = process_data(selected_images, labels)

In [10]:
def compute_accuracy(model, data_loader, device):
    """Return the fraction of samples in ``data_loader`` that ``model`` classifies correctly.

    The model is switched to evaluation mode for the duration of the call so
    that layers such as dropout and batch normalisation behave
    deterministically, and its previous train/eval mode is restored afterwards
    so a surrounding training loop is unaffected.

    Args:
        model: A ``torch.nn.Module`` whose output has one score per class along dim 1.
        data_loader: Iterable yielding ``(images, labels)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Accuracy as a float in ``[0, 1]``; ``0.0`` when the loader yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in data_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)
    # Avoid ZeroDivisionError on an empty loader.
    return correct / total if total else 0.0

In [11]:
class Trainer:
    """Small supervised training loop with best-on-validation checkpointing.

    Uses Adam (learning rate 1e-3) and cross-entropy loss. After training, the
    model holds the weights from the epoch with the highest validation accuracy.
    """

    def __init__(self, model, train_loader, val_loader, device='cpu'):
        """Store the loaders, build optimizer and loss, and move the model to ``device``.

        Args:
            model: ``torch.nn.Module`` to train; it is modified in place.
            train_loader: Iterable of ``(images, labels)`` batches with a ``.dataset`` attribute.
            val_loader: Iterable of ``(images, labels)`` batches used for model selection.
            device: Device on which training and validation run.
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        self.loss_fn = nn.CrossEntropyLoss()
        self.model.to(device)

    def train(self, num_epochs=10):
        """Train for ``num_epochs`` epochs and restore the best validation weights.

        Each epoch prints the mean training loss (summed per-sample loss divided
        by the size of the training dataset) and the validation accuracy. The
        model is left in evaluation mode when ``num_epochs`` is at least 1.

        Args:
            num_epochs: Number of passes over ``train_loader``.
        """
        best_model_wts = copy.deepcopy(self.model.state_dict())
        best_acc = 0.0
        for epoch in range(num_epochs):
            # Set the mode explicitly: the model may have been left in eval mode
            # by validation in the previous epoch or by earlier inference.
            self.model.train()
            running_loss = 0.0
            for images, labels in tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
                images, labels = images.to(self.device), labels.to(self.device)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                # CrossEntropyLoss needs integer class indices.
                loss = self.loss_fn(outputs, labels.long())
                loss.backward()
                self.optimizer.step()
                # loss is a batch mean; weight it by batch size to sum per-sample loss.
                running_loss += loss.item() * images.size(0)
            epoch_loss = running_loss / len(self.train_loader.dataset)

            # Validate with dropout / batch-norm in inference behaviour.
            self.model.eval()
            val_acc = compute_accuracy(self.model, self.val_loader, self.device)
            print(f'Loss: {epoch_loss:.4f} | Val Accuracy: {val_acc:.4f}')
            if val_acc > best_acc:
                best_acc = val_acc
                best_model_wts = copy.deepcopy(self.model.state_dict())
        self.model.load_state_dict(best_model_wts)
        print('Training complete.')

In [None]:
# Monte Carlo ensemble: train one independently seeded model per iteration, then record
# (a) its raw outputs on each validation class and (b) its raw outputs on the held-out
# SIDM0.3 images, and finally pickle everything.

# # Set plotting style and parameters
# plt.style.use(["science", "grid"])
# NOTE(review): "CAKNs" looks like a typo of "CKANs"; kept as-is because it determines
# the output filename that downstream code may already rely on.
model_name = "CAKNs"
epochs = 3
monte_carlo = 15
simulationNames = ['CDM', 'SIDM0.1', 'SIDM1']
nDM_Models = len(simulationNames)
monte_carlo_tests = []
monte_carlo_histories = []
all_probs = []
prob_sidm0_3 = []
batch_size = 32


for i in tqdm(range(monte_carlo)):

    set_seed(i)

    # Start every run from freshly (seeded) initialised weights. Re-using one model object
    # would make run k continue training from run k-1, so the runs would not be independent.
    # type(model) keeps whichever architecture was chosen when `model` was first created;
    # the constructor arguments mirror that first instantiation.
    model = type(model)(input_channels=1, num_classes=3, device=device)

    train_loader, val_loader = getGenerators(X_train, X_val, y_train, y_val)

    trainer = Trainer(model, train_loader, val_loader, device=device)
    trainer.train(num_epochs=epochs)

    # All forward passes below are inference only.
    model.eval()

    monte_carlo_tests.append(val_loader)
    monte_carlo_histories.append({'epochs': epochs})

    # Raw model outputs on the validation samples of each training class.
    iDM_model_probs = []
    for iDM_model in range(nDM_Models):
        this_dm_model_test_feat = X_val[y_val == iDM_model]
        this_dm_model_test_feat = torch.tensor(this_dm_model_test_feat, dtype=torch.float32).to(device)
        # Move the last axis to position 1 (assumes X_val is channels-last -- TODO confirm).
        this_dm_model_test_feat = this_dm_model_test_feat.permute(0, 3, 1, 2)

        with torch.no_grad():
            output = model(this_dm_model_test_feat)
        iDM_model_probs.append(output.cpu().numpy())
    all_probs.append(iDM_model_probs)

    # Raw model outputs on the held-out SIDM0.3 images, evaluated in mini-batches.
    # NOTE(review): if transform_resize is deterministic this tensor is loop-invariant and
    # could be built once before the loop -- confirm before hoisting.
    SIDM0_3_images = all_images[14400:18000]
    SIDM0_3_images_transformed = torch.stack(
        [transform_resize(img) for img in SIDM0_3_images]
    ).to(device)

    SIDM0_3_probs = []
    with torch.no_grad():
        # `start` (not `i`) so the Monte Carlo index is not shadowed.
        for start in range(0, len(SIDM0_3_images_transformed), batch_size):
            batch = SIDM0_3_images_transformed[start:start + batch_size]
            output = model(batch)
            SIDM0_3_probs.append(output.cpu().numpy())

    prob_sidm0_3.append(np.concatenate(SIDM0_3_probs, axis=0))


all_probs_nested = [[np.array(prob) for prob in probs] for probs in all_probs]
prob_sidm0_3 = np.array(prob_sidm0_3)

# Save the final results (written once, through a context manager so the file is closed).
with open("final_model_%s.pkl" % model_name, "wb") as f:
    pkl.dump([all_probs_nested, prob_sidm0_3, monte_carlo_histories], f)

In [None]:
import shutil
import subprocess

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from matplotlib.gridspec import GridSpec

# Blind-test figure for SIDM0.3.
# Top panel: Gaussian summary of the inferred cross-section for three sample sizes.
# Bottom panel: mean error as a function of sample size.

plt.figure(figsize=(5, 5))
gs = GridSpec(20, 1)

ax = plt.subplot(gs[:14, 0])
c = ['r', 'b', 'g', 'c']
cross_sections = np.array([0., 0.1, 1.0])

# Softmax over the last axis of the stored model outputs. Subtracting the per-row maximum
# leaves the result unchanged mathematically but prevents overflow in exp.
shifted = prob_sidm0_3 - np.max(prob_sidm0_3, axis=-1, keepdims=True)
these_probs = np.exp(shifted)
these_probs = these_probs / np.sum(these_probs, axis=-1, keepdims=True)

x = np.linspace(0., 0.7, 1000)
dx = x[1] - x[0]
peak = 0.0  # tallest plotted curve, used for the y-axis limit

for iTelescope, n_samples_per_subset in enumerate([10, 100, 1000]):
    prediction, err = get_predictions_per_subset(these_probs, n_samples_per_subset, cross_sections=cross_sections)
    mean = np.nanmedian(prediction)
    std = np.nanmean(err)
    pdf = norm.pdf(x, mean, std)

    # Renormalise so the curve integrates to 1 over the plotted range.
    pdf_sum = np.sum(pdf) * dx
    if np.isfinite(pdf_sum) and pdf_sum > 0:
        pdf = pdf / pdf_sum
    else:
        # Degenerate Gaussian (zero / non-finite area): fall back to a flat curve of unit area.
        pdf = np.ones_like(x) / (len(x) * dx)

    if np.any(np.isfinite(pdf)):
        peak = max(peak, float(np.nanmax(pdf)))

    ax.plot(x, pdf, color=c[iTelescope])

    # Shade the +/- 1 sigma interval.
    within = (x > mean - std) & (x < mean + std)
    ax.fill_between(x[within], 0, pdf[within], color=c[iTelescope], alpha=0.1)

    # Vertical markers at mean -/+ std; heights are read off the renormalised curve so
    # they end exactly on it.
    lower, upper = mean - std, mean + std
    ax.plot([lower, lower], [0, np.interp(lower, x, pdf)], '-', color=c[iTelescope], label='Sample Size: %i' % n_samples_per_subset)
    ax.plot([upper, upper], [0, np.interp(upper, x, pdf)], '-', color=c[iTelescope])

subsets = np.logspace(1, 2.5, 20)
stds = [np.nanmean(get_predictions_per_subset(these_probs, int(n), cross_sections=cross_sections)[1]) for n in subsets]
ax1 = plt.subplot(gs[16:, 0])
ax1.set_yscale('log')
ax1.plot(subsets, stds)
ax1.set_xlabel('Sample Size', fontsize=12)
# Raw string: same text as before, without the invalid "\s" escape sequence.
ax1.set_ylabel(r'Error in $\sigma_{\rm DM}$', fontsize=12)
ax1.set_xscale('log')

# Make sure the y-axis limits are sensible: fit the tallest plotted curve with a small margin.
if np.isfinite(peak) and peak > 0:
    ax.set_ylim(0, 1.05 * peak)
ax.set_xlim(0.1, 0.7)
ax.legend(loc=1)
ax.set_xlabel(r'$\sigma_{\rm DM}/m$ [cm$^2$/g]', fontsize=12)
ax.set_ylabel(r'$p(\sigma_{\rm DM}/m)$', fontsize=12)

filename = "sidm_0p3_blind_test.pdf"
plt.savefig(filename)

plt.show()

# Crop the PDF margins in place when the external pdfcrop tool is installed.
pdfcrop = shutil.which("pdfcrop")
if pdfcrop is not None:
    subprocess.run([pdfcrop, filename, filename], check=False)
else:
    print("pdfcrop not found; leaving %s uncropped" % filename)