In [21]:
from typing import Callable

# Core dependencies
import jax
import jax.numpy as jnp

# pcax
import pcax as px
import pcax.predictive_coding as pxc
import pcax.nn as pxnn
import pcax.utils as pxu


class BasicBlock(pxc.EnergyModule):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None, act_fn=Callable[[jax.Array], jax.Array]) -> None:
        super().__init__()

        self.conv1 = pxnn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = pxnn.BatchNorm(out_channels)
        self.act_fn = px.static(act_fn)
        self.conv2 = pxnn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = pxnn.BatchNorm(out_channels)

    def __call__(self, x: jax.Array) -> jax.Array:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act_fn(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.act_fn(out)

        return out


class ResNet(pxc.EnergyModule):
    def __init__(
        self,
        block,
        layers,
        num_classes=1000,
        act_fn=Callable[[jax.Array], jax.Array]
    ) -> None:
        super().__init__()

        self.in_channels = 64
        self.act_fn = px.static(act_fn)
        self.conv1 = pxnn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
        self.bn1 = pxnn.BatchNorm(64)
        self.maxpool = pxnn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = pxnn.AvgPool2d((1, 1))
        self.fc = pxnn.Linear(512, num_classes)

        self.vodes = [
            pxc.Vode((64, 56, 56)), 
            pxc.Vode((64, 56, 56)), 
            pxc.Vode((128, 28, 28)),
            pxc.Vode((256, 14, 14)), 
            pxc.Vode((512, 7, 7)), 
            pxc.Vode((512, 1, 1)), 
            pxc.Vode((num_classes,), energy_fn=pxc.ce_energy)
        ]

        self.vodes[-1].h.frozen = True

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = pxnn.Layer(
                pxnn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride),
                pxnn.BatchNorm(out_channels)
            )

        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample, self.act_fn))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels, act_fn=self.act_fn))

    def __call__(self, x: jax.Array, y: jax.Array):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act_fn(x)
        x = self.maxpool(x)
        x = self.vodes[0](x)

        x = self.layer1(x)
        x = self.vodes[1](x)
        x = self.layer2(x)
        x = self.vodes[2](x)
        x = self.layer3(x)
        x = self.vodes[3](x)
        x = self.layer4(x)
        x = self.vodes[4](x)

        x = self.avgpool(x)
        x = x.flatten(1)
        x = self.vodes[5](x)
        x = self.fc(x)
        x = self.vodes[6](x)

        if y is not None:
            self.vodes[-1].set("h", y)

        return self.vodes[-1].get("u")



In [14]:

import torch
import numpy as np

# This is a simple collate function that stacks numpy arrays used to interface
# the PyTorch dataloader with JAX. In the future we hope to provide custom dataloaders
# that are independent of PyTorch.

def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)


# The dataloader assumes cuda is being used, as such it sets 'pin_memory = True' and
# 'prefetch_factor = 2'. Note that the batch size should be constant during training, so
# we set 'drop_last = True' to avoid having to deal with variable batch sizes. 
class TorchDataloader(torch.utils.data.DataLoader):
    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=None,
        sampler=None,
        batch_sampler=None,
        num_workers=1,
        pin_memory=True,
        timeout=0,
        worker_init_fn=None,
        persistent_workers=True,
        prefetch_factor=2,
    ):
        super(self.__class__, self).__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=numpy_collate,
            pin_memory=pin_memory,
            drop_last=True if batch_sampler is None else None,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            persistent_workers=persistent_workers,
            prefetch_factor=prefetch_factor,
        )



def resnet18(num_classes=1000, act_fn=jax.nn.leaky_relu):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes, act_fn)


In [22]:
import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch
from torch.utils.data import Dataset

# Define a simple synthetic dataset
class RandomDataset(Dataset):
    def __init__(self, num_samples, num_classes):
        self.num_samples = num_samples
        self.num_classes = num_classes
        self.data = np.random.rand(num_samples, 3, 64, 64).astype(np.float32)
        self.labels = np.random.randint(0, num_classes, size=num_samples).astype(np.int64)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Generate synthetic data
num_samples = 100
num_classes = 10
dataset = RandomDataset(num_samples, num_classes)
dataloader = TorchDataloader(dataset, batch_size=16, shuffle=True)

# Define the model and optimizer
model = resnet18(num_classes=num_classes)
optimizer = optax.adam(learning_rate=0.001)

# Initialize model parameters and optimizer state
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3, 64, 64)), jnp.ones((1, num_classes)))
opt_state = optimizer.init(params)

# Define the loss function
def loss_fn(params, x, y):
    logits = model.apply(params, x, y)
    one_hot = jax.nn.one_hot(y, num_classes)
    loss = optax.softmax_cross_entropy(logits, one_hot).mean()
    return loss

# Define the training step
@jax.jit
def train_step(params, opt_state, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    for x_batch, y_batch in dataloader:
        x_batch = jnp.array(x_batch)
        y_batch = jnp.array(y_batch)
        params, opt_state = train_step(params, opt_state, x_batch, y_batch)
    print(f"Epoch {epoch + 1} completed.")

print("Training completed.")


TypeError: pcax.nn._layer.Layer.__init__() got multiple values for keyword argument 'key'

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np

In [2]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


In [3]:
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])


In [4]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)


0.7%

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100.0%


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = ResNet18().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

def train(epoch):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        print(f'Train: Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}% ({correct}/{total})')

def test(epoch):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            print(f'Test: Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}% ({correct}/{total})')

    # Save checkpoint.
    torch.save(net.state_dict(), './checkpoint.pth')

for epoch in range(0, 10):
    train(epoch)
    test(epoch)


Train: Loss: 2.377 | Acc: 10.156% (13/128)
Train: Loss: 2.529 | Acc: 9.375% (24/256)
Train: Loss: 3.116 | Acc: 12.760% (49/384)
Train: Loss: 3.306 | Acc: 13.281% (68/512)
Train: Loss: 3.547 | Acc: 14.219% (91/640)
Train: Loss: 3.485 | Acc: 13.281% (102/768)
Train: Loss: 3.429 | Acc: 12.500% (112/896)
Train: Loss: 3.377 | Acc: 12.109% (124/1024)
Train: Loss: 3.268 | Acc: 12.240% (141/1152)
Train: Loss: 3.206 | Acc: 12.266% (157/1280)
Train: Loss: 3.143 | Acc: 12.500% (176/1408)
Train: Loss: 3.112 | Acc: 12.891% (198/1536)
Train: Loss: 3.119 | Acc: 13.281% (221/1664)
Train: Loss: 3.093 | Acc: 13.672% (245/1792)
Train: Loss: 3.075 | Acc: 14.010% (269/1920)
Train: Loss: 3.042 | Acc: 13.965% (286/2048)
Train: Loss: 3.075 | Acc: 13.603% (296/2176)
Train: Loss: 3.029 | Acc: 13.976% (322/2304)
Train: Loss: 3.008 | Acc: 14.186% (345/2432)
Train: Loss: 2.986 | Acc: 14.492% (371/2560)
Train: Loss: 2.958 | Acc: 14.546% (391/2688)
Train: Loss: 2.951 | Acc: 14.631% (412/2816)
Train: Loss: 2.931 | Ac