In [1]:
import torch
import jax
from equinox import filter_vmap
from equinox.nn import Sequential
from etils.enp.array_types.dtypes import AnyInt
from flax.nnx import display
from jax.example_libraries.optimizers import momentum
from sympy.physics.units import stefan
from sympy.physics.vector.printing import params
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torchvision import transforms
# from torchvision.datasets import ImageNet ## prefer using full size Image net(212x212)
from tiny_imagenet_torch import TinyImageNet ## Import allow TinyImage via torch methods(64x64)
from torchvision.models.resnet import conv3x3

In [2]:
#Confirm cuda is enabled on both jax and torch
print(torch.cuda.is_available())
print(jax.devices())


True


In [3]:

# Simple transformation - just convert to tensor
transform = transforms.ToTensor()

# Create dataset
train_dataset = TinyImageNet(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset = TinyImageNet(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=4)

In [5]:
import jax
# Usage example
for images, labels in train_loader:
    # Your training code here
    print(images.shape, labels.shape)
    print(jax.numpy.array(labels).shape)
    print(jax.nn.one_hot(jax.numpy.array(labels),200).shape)
    print(labels)
    # print(jax.nn.one_hot(jax.numpy.array(labels),200))
    break

torch.Size([8, 3, 64, 64]) torch.Size([8])
(8,)
(8, 200)
tensor([144, 160, 198,  64, 107,   7, 122,  44])


![Resnet Blocks](sources/resnetblocks.png)


# Building Resnet From Scratch using Equinox
Orignal Pytorch Reference : https://github.com/FrancescoSaverioZuppichini/ResNet/blob/master/ResNet.ipynb

In [146]:
import equinox as eqx
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
from functools import partial

In [131]:
SEED =42
key = jrandom.PRNGKey(SEED)

In [None]:
# ===== ResNet Basic Block in Equinox (abstract/final; explicit __init__; no object.__setattr__) =====

from typing import Optional, Callable
import equinox as eqx
import jax.numpy as jnp
import jax.nn as jnn
import jax.random as jr


In [287]:
class Conv7x7(eqx.Module):
    conv:eqx.nn.Conv2d

    def __init__(self,in_channels,out_channels,key):
        self.conv = eqx.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=2, padding=3,key=key)
    def __call__(self,x):
        return self.conv(x)


In [291]:
dummy = jr.normal(key,(512,2,2)) #EUREKA MOMENT LOGGED 11/16/25
gap = eqx.nn.AdaptiveAvgPool2d(target_shape=(1,1))
y = gap(dummy)
print(y.shape)

(512, 1, 1)


In [292]:
class Conv3x3(eqx.Module):
    conv:eqx.nn.Conv2d

    def __init__(self,in_channels,out_channels,downsample,key):
        self.conv = eqx.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=downsample, padding=1,key=key)
    def __call__(self,x):
        return self.conv(x)


__main__.Conv3x3

In [332]:
class Conv_Norm(eqx.Module):
    block:eqx.nn.Conv2d
    bn :eqx.nn.BatchNorm

    def __init__(self,block,bn_channels=64):
        self.block = block
        self.bn = eqx.nn.BatchNorm(bn_channels,axis_name="batch",mode="batch")

    def __call__(self,x,state):
        x = self.block(x)
        x,state = self.bn(x,state)
        return x,state



In [336]:
class ResBasicBlock(eqx.Module):
    conv1: Conv_Norm
    conv2: Conv_Norm
    shortcut:Conv_Norm
    in_channels:int
    out_channels:int

    def __init__(self,in_channels,out_channels,downsample,key):
        self.in_channels = in_channels
        self.out_channels = out_channels
        c1 = Conv3x3(in_channels,out_channels,downsample,key)
        c2 = Conv3x3(out_channels,out_channels,1,key)
        c3 = eqx.nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=2,key=key)
        self.conv1 = Conv_Norm(c1)
        self.conv2 = Conv_Norm(c2)
        self.shortcut = Conv_Norm(c3)

    def __call__(self,x,state):
        residual = x
        x,state = self.conv1(x,state)
        x = jax.nn.relu(x)
        x,state = self.conv2(x,state)
        if self.in_channels != self.out_channels:
            residual,state = self.shortcut(residual,state)

        x = x + residual

        return jax.nn.relu(x),state

class ResBottleNeckBlock(eqx.Module):
    conv1: Conv_Norm
    conv2: Conv_Norm
    conv3: Conv_Norm
    shortcut:Conv_Norm
    in_channels:int
    out_channels:int

    def __init__(self,in_channels,out_channels,downsample,key):
        self.in_channels = in_channels
        self.out_channels = out_channels
        c1 = eqx.nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=downsample,key=key)
        c2 = Conv3x3(out_channels,out_channels,1,key)
        c3 = eqx.nn.Conv2d(out_channels,out_channels*4,kernel_size=1,stride=1,key=key)
        c4 = eqx.nn.Conv2d(in_channels,out_channels*4,kernel_size=1,stride=2,key=key)
        self.conv1 = Conv_Norm(c1,out_channels)
        self.conv2 = Conv_Norm(c2,out_channels)
        self.conv3 = Conv_Norm(c3,out_channels*4)
        self.shortcut = Conv_Norm(c4,out_channels*4)

    def __call__(self,x,state):
        residual = x
        x,state = self.conv1(x,state)
        x = jax.nn.relu(x)
        x,state = self.conv2(x,state)
        x = jax.nn.relu(x)
        x,state = self.conv3(x,state)
        if self.in_channels != self.out_channels*4:
            residual,state = self.shortcut(residual,state)

        x = x + residual

        return jax.nn.relu(x),state



# module,state = eqx.nn.make_with_state(ResBottleNeckBlock)(32,64,2,key)
# dummy = jr.normal(key,(8,32,8,8))
#
# batched_forward = eqx.filter_vmap(module, in_axes=(0, None), axis_name="batch")
# y, state_batched = batched_forward(dummy, state)
# state = jax.tree_util.tree_map(lambda s: s[0], state_batched)
#
# print(y.shape)

(8, 256, 4, 4)


In [344]:
class ResNetLayer(eqx.Module):
    block:[ResBottleNeckBlock,ResBasicBlock]
    blocks:eqx.nn.Sequential
    def __init__(self,in_channels,out_channels,block,key,n=1):
        self.block = block
        block_expansion = 1 if isinstance(block,ResBasicBlock) else 4
        downsample = 2 if in_channels != out_channels else 1
        self.blocks = eqx.nn.Sequential([
            block(in_channels , out_channels,downsample,key),
            *[block(out_channels*block_expansion,out_channels,1,key) for _ in range(n - 1)]
        ])
    def __call__(self, x,state):
        x,state =self.blocks(x,state)
        return x,state



module,state = eqx.nn.make_with_state(ResNetLayer)(32,64,ResBasicBlock,key,n=2)
dummy = jr.normal(key,(8,32,8,8))

batched_forward = eqx.filter_vmap(module, in_axes=(0, None), axis_name="batch")
y, state_batched = batched_forward(dummy, state)
state = jax.tree_util.tree_map(lambda s: s[0], state_batched)

print(y.shape)


TypeError: ResBasicBlock.__call__() got an unexpected keyword argument 'key'