In [8]:
!git clone https://github.com/jan1na/Neural-Cellular-Automata.git

%cd Neural-Cellular-Automata

Cloning into 'Neural-Cellular-Automata'...
remote: Enumerating objects: 44, done.[K
remote: Counting objects: 100% (44/44), done.[K
remote: Compressing objects: 100% (32/32), done.[K
remote: Total 44 (delta 23), reused 19 (delta 10), pack-reused 0 (from 0)[K
Receiving objects: 100% (44/44), 29.37 KiB | 4.89 MiB/s, done.
Resolving deltas: 100% (23/23), done.
/content/Neural-Cellular-Automata/Neural-Cellular-Automata


In [9]:
!pip install -q medmnist scikit-learn

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from medmnist import PathMNIST
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, balanced_accuracy_score, mean_absolute_error
import seaborn as sns
from models import NCA, CNNBaseline

from google.colab import drive

# Mount Google Drive; the datasets and checkpoints used below live under /content/drive/MyDrive/NCA.
drive.mount('/content/drive')

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("Using device:", device)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda


In [10]:
from google.colab import drive
import os
import shutil

# Copy the PathMNIST .npz archives from Drive into ~/.medmnist so that
# PathMNIST(..., download=False) can load them locally.
drive_folder = "/content/drive/MyDrive/NCA/DATA"
cache_dir = os.path.expanduser("~/.medmnist")
os.makedirs(cache_dir, exist_ok=True)

# "" selects pathmnist.npz; the suffixes select the larger-resolution archives.
resolutions = ["", "_64", "_128", "_224"]
for res in resolutions:
    filename = f"pathmnist{res}.npz"
    src = os.path.join(drive_folder, filename)
    dst = os.path.join(cache_dir, filename)

    if not os.path.exists(src):
        print(f"File not found in Drive: {filename}")
    elif os.path.exists(dst) and os.path.getsize(dst) == os.path.getsize(src):
        # Re-runs skip the copy when a same-size file is already cached.
        # A partially copied file has a different size and is copied again.
        print(f"Already cached: {filename}")
    else:
        shutil.copyfile(src, dst)
        print(f"Copied {filename} to cache.")

Copied pathmnist.npz to cache.
Copied pathmnist_64.npz to cache.
Copied pathmnist_128.npz to cache.
Copied pathmnist_224.npz to cache.


In [11]:
# Directory holding the trained checkpoints on the mounted Drive.
CKPT_DIR = "/content/drive/MyDrive/NCA"

# map_location=device lets checkpoints saved on a GPU also load on a CPU-only runtime.
# weights_only=True restricts unpickling to tensors and containers, which is all a
# state_dict needs.
nca = NCA().to(device)
nca.load_state_dict(torch.load(f"{CKPT_DIR}/best_nca_pathmnist.pth", map_location=device, weights_only=True))
nca.eval()

cnn = CNNBaseline().to(device)
cnn.load_state_dict(torch.load(f"{CKPT_DIR}/best_cnn_pathmnist.pth", map_location=device, weights_only=True))
cnn.eval()

CNNBaseline(
  (features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
  )
  (pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=128, out_features=9, bias=True)
)

In [12]:
def get_loader(size, batch_size=64, split="test", download=False):
    """Build an ordered DataLoader over a PathMNIST split at the given resolution.

    Args:
        size: Image side length forwarded to ``PathMNIST`` (this notebook uses
            28, 64, 128 and 224).
        batch_size: Number of samples per batch.
        split: Dataset split to load; defaults to ``"test"``.
        download: Forwarded to ``PathMNIST``. Off by default because the .npz
            archives are copied into ``~/.medmnist`` from Drive in an earlier cell.

    Returns:
        A non-shuffled ``DataLoader`` yielding ``(images, labels)`` batches. Images
        are converted with ``transforms.ToTensor()``.
    """
    dataset = PathMNIST(split=split, size=size, download=download, transform=transforms.ToTensor())
    # shuffle=False keeps dataset order so evaluation runs are repeatable.
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [13]:
@torch.no_grad()
def evaluate(model, loader, name="Model", size=28, save_dir="/content/drive/MyDrive/NCA/results", is_NCA=False):
    """Evaluate a classifier on ``loader``, print metrics and save a confusion matrix.

    Predictions are the argmax of the model output along dim 1. The function prints:
    - balanced accuracy,
    - the mean absolute error between label and predicted class ids,
    - a per-class ``classification_report``.
    It then writes a confusion-matrix heatmap PNG into ``save_dir``.

    Args:
        model: ``torch.nn.Module`` producing per-class scores. It is put in eval
            mode for the run and restored to its previous mode afterwards.
        loader: Iterable of ``(images, labels)`` tensor batches. Labels are
            flattened to 1-D; presumably they arrive as (batch, 1) — confirm.
        name: Model label used in the printed header, plot title and file name.
        size: Input resolution; used only to label the output.
        save_dir: Directory for the confusion-matrix PNG (created if missing).
        is_NCA: If True, call ``model(x, is_NCA=True)`` and use the first element
            of the returned pair as the scores.

    Returns:
        dict with keys ``"balanced_accuracy"``, ``"mae"``, ``"confusion_matrix"``,
        ``"report"`` and ``"cm_path"``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    all_preds, all_labels = [], []
    try:
        for x, y in loader:
            x = x.to(device)
            if is_NCA:
                # The second return value (rgb_steps) is not needed for the metrics.
                out, _ = model(x, is_NCA=True)
            else:
                out = model(x)
            all_preds.extend(out.argmax(dim=1).cpu().tolist())
            # reshape(-1) instead of squeeze(): squeeze() turns a final batch of a
            # single sample into a 0-d tensor, which extend() cannot iterate.
            all_labels.extend(y.reshape(-1).tolist())
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    if not all_labels:
        raise ValueError(f"{name} @ {size}x{size}: loader yielded no samples")

    # Metrics
    cm = confusion_matrix(all_labels, all_preds)
    bal_acc = balanced_accuracy_score(all_labels, all_preds)
    # NOTE(review): MAE treats class ids as ordered numbers. It is only meaningful
    # if the label ids are ordinal — confirm for PathMNIST before relying on it.
    mae = mean_absolute_error(all_labels, all_preds)
    # zero_division=0 gives the same 0.0 that sklearn falls back to for classes that
    # are never predicted, without emitting UndefinedMetricWarning.
    report = classification_report(all_labels, all_preds, digits=4, zero_division=0)

    # Logging
    print(f"\n{name} @ {size}x{size}")
    print("Balanced Accuracy:", f"{bal_acc:.4f}")
    print("Mean Absolute Error (MAE):", f"{mae:.4f}")
    print(report)

    # Save confusion matrix
    os.makedirs(save_dir, exist_ok=True)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
    ax.set_title(f"Confusion Matrix: {name} @ {size}x{size}")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fname = os.path.join(save_dir, f"cm_{name.replace(' ', '_')}_{size}x{size}.png")
    fig.savefig(fname)
    # Close this specific figure so figures do not pile up across repeated calls.
    plt.close(fig)
    print(f"Confusion matrix saved to: {fname}")

    return {
        "balanced_accuracy": bal_acc,
        "mae": mae,
        "confusion_matrix": cm,
        "report": report,
        "cm_path": fname,
    }

In [14]:
for size in [28, 64, 128, 224]:
    print("\n==============================")
    print(f"Resolution: {size}x{size}")
    loader = get_loader(size)

    # Score both models on the same test loader at this resolution, CNN first.
    for model_name, model in (("CNN", cnn), ("NCA", nca)):
        print(f"{model_name}:")
        evaluate(model, loader, name=model_name, size=size)


Resolution: 28x28
CNN:

CNN @ 28x28
Balanced Accuracy: 0.7675
Mean Absolute Error (MAE): 0.5617
              precision    recall  f1-score   support

           0     0.9378    0.9686    0.9529      1338
           1     0.8255    1.0000    0.9044       847
           2     0.4222    0.7847    0.5490       339
           3     0.9677    0.7082    0.8179       634
           4     0.9144    0.7845    0.8445      1035
           5     0.8065    0.5912    0.6823       592
           6     0.6764    0.7530    0.7126       741
           7     0.6529    0.4513    0.5337       421
           8     0.8613    0.8662    0.8637      1233

    accuracy                         0.8128      7180
   macro avg     0.7850    0.7675    0.7623      7180
weighted avg     0.8318    0.8128    0.8136      7180

Confusion matrix saved to: /content/drive/MyDrive/NCA/results/cm_CNN_28x28.png
NCA:

NCA @ 28x28
Balanced Accuracy: 0.7784
Mean Absolute Error (MAE): 0.5340
              precision    recall  f1-sco

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



NCA @ 224x224
Balanced Accuracy: 0.2798
Mean Absolute Error (MAE): 1.8765
              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000      1338
           1     0.3238    0.9988    0.4890       847
           2     0.8778    0.4661    0.6089       339
           3     0.0000    0.0000    0.0000       634
           4     0.5315    0.1710    0.2588      1035
           5     0.2188    0.0236    0.0427       592
           6     0.0016    0.0013    0.0015       741
           7     0.1682    0.6105    0.2637       421
           8     0.1655    0.2466    0.1980      1233

    accuracy                         0.2447      7180
   macro avg     0.2541    0.2798    0.2070      7180
weighted avg     0.2127    0.2447    0.1769      7180

Confusion matrix saved to: /content/drive/MyDrive/NCA/results/cm_NCA_224x224.png
