In [4]:
!git clone https://github.com/jan1na/Neural-Cellular-Automata.git
!ls
import os
print(os.getcwd())
%cd Neural-Cellular-Automata
print(os.getcwd())

Cloning into 'Neural-Cellular-Automata'...
remote: Enumerating objects: 9, done.[K
remote: Counting objects: 100% (9/9), done.[K
remote: Compressing objects: 100% (5/5), done.[K
remote: Total 9 (delta 2), reused 6 (delta 2), pack-reused 0 (from 0)[K
Receiving objects: 100% (9/9), 4.69 KiB | 960.00 KiB/s, done.
Resolving deltas: 100% (2/2), done.
eval.ipynb		  README.md	   train_nca.ipynb
Neural-Cellular-Automata  train_cnn.ipynb  train_utils.py
/content/Neural-Cellular-Automata
/content/Neural-Cellular-Automata/Neural-Cellular-Automata
/content/Neural-Cellular-Automata/Neural-Cellular-Automata


In [6]:
!pip install medmnist torch torchvision tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
from medmnist import PathMNIST
import medmnist
from tqdm import tqdm
from train_utils import train, evaluate

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')



NameError: name 'torch' is not defined

In [None]:
medmnist.INFO['pathmnist']['download'] = True

transform = transforms.Compose([transforms.ToTensor()])

train_dataset = PathMNIST(split='train', transform=transform, download=True)
val_dataset = PathMNIST(split='val', transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

NUM_CLASSES = 9

In [None]:
class NCA(nn.Module):
    def __init__(self, state_dim=16, num_classes=9, num_steps=8):
        super().__init__()
        self.state_dim = state_dim
        self.num_steps = num_steps
        self.perceive = nn.Conv2d(state_dim, 128, kernel_size=3, padding=1)
        self.update = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(128, state_dim, kernel_size=1)
        )
        self.readout = nn.Sequential(
            nn.Conv2d(state_dim, 64, 1),
            nn.ReLU(),
            nn.Conv2d(64, num_classes, 1)
        )

    def forward(self, x):
        B, C, H, W = x.shape
        state = torch.zeros(B, self.state_dim, H, W, device=x.device)
        state[:, :C] = x

        for _ in range(self.num_steps):
            y = self.perceive(state)
            dx = self.update(y)
            state = state + dx

        out = self.readout(state)
        out = out.mean(dim=(2, 3))
        return out


In [None]:
def train(model, loader, optimizer, criterion, device):
    model.train()
    total_loss, total_correct = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.squeeze().to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        total_correct += (out.argmax(1) == y).sum().item()
    return total_loss / len(loader.dataset), total_correct / len(loader.dataset)

def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss, total_correct = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.squeeze().to(device)
            out = model(x)
            loss = criterion(out, y)
            total_loss += loss.item() * x.size(0)
            total_correct += (out.argmax(1) == y).sum().item()
    return total_loss / len(loader.dataset), total_correct / len(loader.dataset)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = NCA().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

best_acc = 0
model_path = '/content/drive/MyDrive/NCA/best_nca_pathmnist.pth'

for epoch in range(1, 51):
    train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)
    print(f"Epoch {epoch:02d}: Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), model_path)
        print(f"🔥 New best model saved to {model_path}")

    if val_acc > 0.90:
        print("Converged (validation accuracy > 90%)")
        break

print("🎉 Training completed.")