In [1]:
# Fetch the project repository (presumably the source of the `models` module
# imported further down — confirm against the repo contents).
!git clone https://github.com/jan1na/Neural-Cellular-Automata.git

# Switch the working directory into the checkout for the rest of the notebook.
%cd Neural-Cellular-Automata

Cloning into 'Neural-Cellular-Automata'...
remote: Enumerating objects: 63, done.[K
remote: Counting objects: 100% (63/63), done.[K
remote: Compressing objects: 100% (46/46), done.[K
remote: Total 63 (delta 36), reused 30 (delta 15), pack-reused 0 (from 0)[K
Receiving objects: 100% (63/63), 35.98 KiB | 11.99 MiB/s, done.
Resolving deltas: 100% (36/36), done.
/content/Neural-Cellular-Automata


In [2]:
!pip install -q medmnist scikit-learn

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from medmnist import PathMNIST
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, balanced_accuracy_score, mean_absolute_error
import seaborn as sns
from models import NCA, CNNBaseline

from google.colab import drive
# Mount Google Drive under /content/drive; later cells read files from that tree.
drive.mount('/content/drive')

# Prefer the GPU when PyTorch can see one, otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/87.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m124.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m97.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m65.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.6 MB/s[0m eta [36

In [4]:
from google.colab import drive
import os
import shutil

# Copy any PathMNIST archives found on Drive into ~/.medmnist (presumably the
# directory medmnist reads from when download=False — confirm against medmnist).
drive_folder = "/content/drive/MyDrive/NCA/DATA"
cache_dir = os.path.expanduser("~/.medmnist")
os.makedirs(cache_dir, exist_ok=True)

# One filename suffix per archive variant; the empty suffix is plain pathmnist.npz.
resolutions = ["", "_64", "_128", "_224"]
for res in resolutions:
    filename = "pathmnist" + res + ".npz"
    src = os.path.join(drive_folder, filename)
    dst = os.path.join(cache_dir, filename)

    # Report archives that are missing on Drive and move on to the next one.
    if not os.path.exists(src):
        print(f"File not found in Drive: {filename}")
        continue

    shutil.copyfile(src, dst)
    print(f"Copied {filename} to cache.")

Copied pathmnist.npz to cache.


In [5]:
# Build both models on the selected device and restore their trained weights.
# map_location=device remaps the stored tensors while loading, so a checkpoint
# saved from a GPU run can still be restored when `device` is the CPU.
nca = NCA().to(device)
nca.load_state_dict(
    torch.load("/content/drive/MyDrive/NCA/best_nca_pathmnist.pth", map_location=device)
)
nca.eval()  # switch to inference mode

cnn = CNNBaseline().to(device)
cnn.load_state_dict(
    torch.load("/content/drive/MyDrive/NCA/best_cnn_pathmnist.pth", map_location=device)
)
cnn.eval()  # switch to inference mode; as the last expression it also shows the model repr

CNNBaseline(
  (features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
  )
  (pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=128, out_features=9, bias=True)
)

In [6]:
def get_loader(size, batch_size=64):
    """Build a non-shuffled DataLoader over the PathMNIST test split.

    Args:
        size: Forwarded to ``PathMNIST`` as its ``size`` argument.
        batch_size: Number of samples per batch (default 64).

    Returns:
        A ``DataLoader`` over the test split whose images are converted with
        ``transforms.ToTensor``. The dataset is created with ``download=False``.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    test_set = PathMNIST(
        split="test",
        size=size,
        download=False,
        transform=to_tensor,
    )
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return test_loader

In [10]:
@torch.no_grad()
def evaluate(model, loader, name="Model", size=28, save_dir="/content/drive/MyDrive/NCA/results", is_NCA=False):
    """Score ``model`` on ``loader``, print the metrics and save a confusion matrix.

    Gradients are disabled for the whole call via the ``torch.no_grad`` decorator.

    Args:
        model: Callable model. When ``is_NCA`` is true it is called as
            ``model(x, True)`` and must return a pair whose first element holds
            the class scores; otherwise it is called as ``model(x)``.
        loader: Iterable yielding ``(x, y)`` batches of inputs and labels.
        name: Label used in the printed header, the plot title and the file name.
        size: Side length used only for the printed header, title and file name.
        save_dir: Directory the confusion-matrix PNG is written to; created if
            it does not exist.
        is_NCA: Selects the two-argument call convention described above.

    Returns:
        None. Results are printed and the plot is written to disk.
    """
    all_preds, all_labels = [], []

    for x, y in loader:
        x = x.to(device)
        if is_NCA:
            # Only the class scores are needed; the second returned value is dropped.
            out, _ = model(x, True)
        else:
            out = model(x)
        all_preds.extend(out.argmax(dim=1).cpu().numpy())
        # Flatten rather than squeeze(): squeeze() turns a one-sample batch into
        # a 0-d tensor, and extend() cannot iterate over the resulting 0-d array.
        all_labels.extend(y.reshape(-1).cpu().numpy())

    # Metrics
    cm = confusion_matrix(all_labels, all_preds)
    bal_acc = balanced_accuracy_score(all_labels, all_preds)
    mae = mean_absolute_error(all_labels, all_preds)
    report = classification_report(all_labels, all_preds, digits=4)

    # Logging
    print(f"\n{name} @ {size}x{size}")
    print("Balanced Accuracy:", f"{bal_acc:.4f}")
    print("Mean Absolute Error (MAE):", f"{mae:.4f}")
    print(report)

    # Save confusion matrix. An explicit figure/axes pair is used so the plot
    # never draws onto (or closes) a figure owned by some other cell.
    os.makedirs(save_dir, exist_ok=True)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
    ax.set_title(f"Confusion Matrix: {name} @ {size}x{size}")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fname = os.path.join(save_dir, f"cm_{name.replace(' ', '_')}_{size}x{size}.png")
    fig.savefig(fname)
    plt.close(fig)  # release the figure; this function is called repeatedly in a loop
    print(f"Confusion matrix saved to: {fname}")

In [11]:
# Each entry: (printed label, model object, whether evaluate() must use the NCA call).
models_under_test = (
    ("CNN", cnn, False),
    ("NCA", nca, True),
)

# Evaluate every model at every test resolution, sharing one loader per resolution.
for size in [28, 64, 128, 224]:
    print("\n==============================")
    print(f"Resolution: {size}x{size}")
    loader = get_loader(size)

    for model_label, net, use_nca_call in models_under_test:
        print(f"{model_label}:")
        evaluate(net, loader, name=model_label, size=size, is_NCA=use_nca_call)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
          [-5.4330e-01,  3.5454e-01,  4.6556e-01,  ...,  1.6234e-01,
            2.9816e-01, -1.4971e-01]]],


        [[[ 5.1639e-02, -6.6937e-02, -3.8539e-01,  ...,  6.9746e-01,
            5.3436e-02,  5.8709e-01],
          [ 1.7725e+00,  1.9043e+00,  8.7730e-01,  ..., -1.8960e-01,
            6.8751e-02,  4.2593e-01],
          [ 2.3244e+00,  2.1020e+00,  9.8420e-01,  ..., -2.7689e-01,
            1.2482e+00, -4.5644e-01],
          ...,
          [ 3.2430e-01, -5.0794e-02, -6.6766e-01,  ..., -1.5465e-01,
            7.9860e-01, -2.5822e-01],
          [ 3.7333e-01,  1.5617e+00, -1.4523e+00,  ...,  5.6562e-01,
            1.6840e+00,  1.0899e+00],
          [-3.8288e-02,  7.8865e-01, -8.2769e-01,  ...,  5.2177e-01,
            1.1576e+00,  1.1373e+00]],

         [[-1.0423e+00, -4.9073e-01, -1.2292e+00,  ...,  5.5374e-02,
            1.3386e-01, -3.4926e-01],
          [-6.7228e-01,  4.1750e-01,  2.9171e-01,  ...,  1

KeyboardInterrupt: 