In [8]:
!git clone https://github.com/jan1na/Neural-Cellular-Automata.git

import os

%cd Neural-Cellular-Automata

Cloning into 'Neural-Cellular-Automata'...
remote: Enumerating objects: 21, done.[K
remote: Counting objects: 100% (21/21), done.[K
remote: Compressing objects: 100% (14/14), done.[K
remote: Total 21 (delta 8), reused 12 (delta 5), pack-reused 0 (from 0)[K
Receiving objects: 100% (21/21), 9.99 KiB | 9.99 MiB/s, done.
Resolving deltas: 100% (8/8), done.
/content/Neural-Cellular-Automata/Neural-Cellular-Automata/Neural-Cellular-Automata/Neural-Cellular-Automata


In [9]:
!pip install -q medmnist scikit-learn

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from medmnist import PathMNIST
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, balanced_accuracy_score, mean_absolute_error
import seaborn as sns
from models import NCA, CNNBaseline

from google.colab import drive

# Google Drive is where the checkpoints are read from and results are written to.
drive.mount('/content/drive')

# Run on the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda


In [None]:
# Interactive upload of files from the local machine (Colab only).
# `uploaded` is consumed by the next cell, which looks filenames up in it
# with `in`, so that cell can only run after this one.
from google.colab import files
uploaded = files.upload() 

In [None]:
import os
import shutil

# Directory in the user's home where the MedMNIST .npz archives are placed.
cache_dir = os.path.expanduser("~/.medmnist")
os.makedirs(cache_dir, exist_ok=True)

# Archive names for the four PathMNIST resolutions (28 has no size suffix).
files_needed = [f"pathmnist{suffix}.npz" for suffix in ("", "_64", "_128", "_224")]

# Relocate every archive that was uploaded; report the ones that were not.
for fname in files_needed:
    if fname not in uploaded:
        print(f"File {fname} was not uploaded.")
        continue

    dest_path = os.path.join(cache_dir, fname)
    # Drop any stale copy first so the fresh upload replaces it.
    if os.path.exists(dest_path):
        os.remove(dest_path)
    shutil.move(fname, dest_path)
    print(f"{fname} moved to {dest_path}")

In [10]:
# Restore the trained weights of both models and put them in inference mode.
# map_location=device remaps the checkpoint tensors onto the selected device,
# so loading also works when `device` has fallen back to the CPU (without it,
# a checkpoint saved from a CUDA run fails to deserialize on a CPU-only host).
nca = NCA().to(device)
nca.load_state_dict(
    torch.load("/content/drive/MyDrive/NCA/best_nca_pathmnist.pth", map_location=device)
)
nca.eval()

cnn = CNNBaseline().to(device)
cnn.load_state_dict(
    torch.load("/content/drive/MyDrive/NCA/best_cnn_pathmnist.pth", map_location=device)
)
cnn.eval()

CNNBaseline(
  (features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
  )
  (pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=128, out_features=9, bias=True)
)

In [11]:
def get_loader(size, batch_size=64, download=False):
    """Build a DataLoader over the PathMNIST test split at one resolution.

    Args:
        size: Image resolution forwarded to ``PathMNIST(size=...)``.
        batch_size: Number of samples per batch.
        download: Forwarded to ``PathMNIST(download=...)``. Defaults to
            ``False`` (the original behaviour), i.e. the dataset archive is
            not fetched by this call.

    Returns:
        A ``DataLoader`` over the test split with ``shuffle=False``.
    """
    transform = transforms.Compose([transforms.ToTensor()])
    # Pass the composed transform object. The original passed `transforms`
    # (the torchvision module itself), which cannot be called on a sample.
    dataset = PathMNIST(split="test", size=size, download=download, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [12]:
@torch.no_grad()
def evaluate(model, loader, name="Model", size=28, save_dir="/content/drive/MyDrive/NCA/results"):
    """Run ``model`` over ``loader``, print metrics and save a confusion matrix.

    Args:
        model: Callable taking an input batch and returning per-class scores
            with classes along dimension 1 (``argmax(dim=1)`` is applied).
        loader: Iterable of ``(x, y)`` batches; ``x`` is moved to the global
            ``device``, ``y`` is read on the CPU.
            NOTE(review): assumes one integer label per sample - confirm.
        name: Label used in the printed header, plot title and file name.
        size: Resolution used only for labelling the output.
        save_dir: Directory (created if missing) for the confusion-matrix PNG.

    Returns:
        dict with keys ``"confusion_matrix"``, ``"balanced_accuracy"``,
        ``"mae"``, ``"report"`` and ``"figure_path"``. The original returned
        ``None``; callers that ignore the result are unaffected.
    """
    all_preds, all_labels = [], []

    for x, y in loader:
        out = model(x.to(device))
        all_preds.extend(out.argmax(dim=1).cpu().tolist())
        # reshape(-1) rather than squeeze(): squeeze() turns a batch holding a
        # single label into a 0-d tensor, and extend() cannot iterate that.
        all_labels.extend(y.reshape(-1).tolist())

    # Metrics
    cm = confusion_matrix(all_labels, all_preds)
    bal_acc = balanced_accuracy_score(all_labels, all_preds)
    # NOTE(review): MAE is computed on class indices, which treats them as
    # ordered values - confirm this is meaningful for the label set in use.
    mae = mean_absolute_error(all_labels, all_preds)
    report = classification_report(all_labels, all_preds, digits=4)

    # Logging
    print(f"\n{name} @ {size}x{size}")
    print("Balanced Accuracy:", f"{bal_acc:.4f}")
    print("Mean Absolute Error (MAE):", f"{mae:.4f}")
    print(report)

    # Save confusion matrix. An explicit figure/axes pair is used so the plot
    # does not depend on (or leak into) pyplot's "current figure" state.
    os.makedirs(save_dir, exist_ok=True)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
    ax.set_title(f"Confusion Matrix: {name} @ {size}x{size}")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fname = f"{save_dir}/cm_{name.replace(' ', '_')}_{size}x{size}.png"
    fig.savefig(fname)
    plt.close(fig)
    print(f"Confusion matrix saved to: {fname}")

    return {
        "confusion_matrix": cm,
        "balanced_accuracy": bal_acc,
        "mae": mae,
        "report": report,
        "figure_path": fname,
    }

In [13]:
# Evaluate both models on the test split at every resolution, CNN first.
for size in [28, 64, 128, 224]:
    print("\n==============================")
    print(f"Resolution: {size}x{size}")
    loader = get_loader(size)

    # The same loader is reused for both models at this resolution.
    for label, net in (("CNN", cnn), ("NCA", nca)):
        print(f"{label}:")
        evaluate(net, loader, name=label, size=size)


🖼️  Resolution: 28x28


100%|██████████| 206M/206M [01:38<00:00, 2.08MB/s]


🔍 CNN:

CNN @ 28 Results
[[1003    1    0    0   37    0    0    0    0]
 [  60  973    2    0   21    0    0    1    0]
 [   2    5  818   10    8   90   17  156   46]
 [   1    0    3 1121    1    0   12    5   13]
 [  10    4    2    1  768    2   44   46   13]
 [  30    2   86    1   14  985    5  225    6]
 [   3    0   11   11  108    5  628   36   75]
 [   0    1   94    2    7   91    3  822   25]
 [   3    1   53    7   21   13   61   83 1190]]
              precision    recall  f1-score   support

           0     0.9020    0.9635    0.9317      1041
           1     0.9858    0.9205    0.9521      1057
           2     0.7652    0.7101    0.7366      1152
           3     0.9722    0.9697    0.9710      1156
           4     0.7797    0.8629    0.8192       890
           5     0.8305    0.7275    0.7756      1354
           6     0.8156    0.7161    0.7626       877
           7     0.5983    0.7866    0.6796      1045
           8     0.8699    0.8310    0.8500      1432



100%|██████████| 1.07G/1.07G [05:12<00:00, 3.43MB/s]


🔍 CNN:

CNN @ 64 Results
[[ 977   21    0    0   42    1    0    0    0]
 [  50  984    2    0   19    0    0    2    0]
 [   2    6  514    0   23  245   35  302   25]
 [   1    1  204  735   12    0  184   13    6]
 [   3    5    3    0  807    1   37   22   12]
 [  12    2   38    0   22 1096   11  173    0]
 [   0    1   12    0  420    0  344   13   87]
 [   0    0   30    0  109  188   21  671   26]
 [   1    2  133    0   33   12  198  207  846]]
              precision    recall  f1-score   support

           0     0.9340    0.9385    0.9363      1041
           1     0.9628    0.9309    0.9466      1057
           2     0.5491    0.4462    0.4923      1152
           3     1.0000    0.6358    0.7774      1156
           4     0.5427    0.9067    0.6790       890
           5     0.7103    0.8095    0.7566      1354
           6     0.4145    0.3922    0.4030       877
           7     0.4783    0.6421    0.5482      1045
           8     0.8443    0.5908    0.6952      1432



 58%|█████▊    | 2.46G/4.26G [15:06<11:00, 2.72MB/s]   


RuntimeError: 
                Automatic download failed! Please download pathmnist_128.npz manually.
                1. [Optional] Check your network connection: 
                    Go to https://github.com/MedMNIST/MedMNIST/ and find the Zenodo repository
                2. Download the npz file from the Zenodo repository or its Zenodo data link: 
                    https://zenodo.org/records/10519652/files/pathmnist_128.npz?download=1
                3. [Optional] Verify the MD5: 
                    ac42d08fb904d92c244187169d1fd1d9
                4. Put the npz file under your MedMNIST root folder: 
                    /root/.medmnist
                