In [1]:
import torch
import jax

from torch.utils.data import DataLoader
from torchvision import transforms
from tiny_imagenet_torch import TinyImageNet ## Import allow TinyImage via torch methods(64x64)

In [2]:
#Confirm cuda is enabled on both jax and torch
print(torch.cuda.is_available())
print(jax.devices())


False
[CpuDevice(id=0)]


In [2]:

# Simple transformation - just convert to tensor
transform = transforms.ToTensor()

# Create dataset
train_dataset = TinyImageNet(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset = TinyImageNet(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=4)

In [4]:
import jax
# Usage example
for images, labels in train_loader:
    # Your training code here
    print(images.shape, labels.shape)
    print(jax.numpy.array(labels).shape)
    print(jax.nn.one_hot(jax.numpy.array(labels),200).shape)
    print(labels)
    # print(jax.nn.one_hot(jax.numpy.array(labels),200))
    break

torch.Size([8, 3, 64, 64]) torch.Size([8])
(8,)
(8, 200)
tensor([ 75,  49, 192,  62,  28,  16,  30, 158])


![Resnet Blocks](sources/ResnetPaper.png)


# Building Resnet From Scratch using Equinox
Orignal Pytorch Reference : https://github.com/FrancescoSaverioZuppichini/ResNet/blob/master/ResNet.ipynb

In [5]:
import equinox as eqx
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
from functools import partial
eqx.clear_caches()

In [6]:
SEED =42
key = jrandom.PRNGKey(SEED)

In [7]:
# ===== ResNet Basic Block in Equinox (abstract/final; explicit __init__; no object.__setattr__) =====

from typing import Optional, Callable
import equinox as eqx
import jax.numpy as jnp
import jax.nn as jnn
import jax.random as jr


In [8]:
class Conv7x7(eqx.Module):
    conv:eqx.nn.Conv2d

    def __init__(self,in_channels,out_channels,key):
        self.conv = eqx.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=2, padding=3,key=key)
    def __call__(self,x):
        return self.conv(x)


In [9]:
dummy = jr.normal(key,(512,2,2)) #EUREKA MOMENT LOGGED 11/16/25
gap = eqx.nn.AdaptiveAvgPool2d(target_shape=(1,1))
y = gap(dummy)
print(y.shape)

(512, 1, 1)


In [10]:
class Conv3x3(eqx.Module):
    conv:eqx.nn.Conv2d

    def __init__(self,in_channels,out_channels,downsample,key):
        self.conv = eqx.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=downsample, padding=1,key=key)
    def __call__(self,x):
        return self.conv(x)


In [11]:
class Conv_Norm(eqx.Module):
    block:eqx.nn.Conv2d
    bn :eqx.nn.BatchNorm

    def __init__(self,block,bn_channels):
        self.block = block
        self.bn = eqx.nn.BatchNorm(bn_channels,axis_name="batch",mode="batch")

    def __call__(self,x,state):
        x = self.block(x)
        x,state = self.bn(x,state)
        return x,state



![Resnet Blocks](sources/resnetblocks.png)

In [12]:
#Residual Blocks
class ResBasicBlock(eqx.Module):
    conv1: Conv_Norm
    conv2: Conv_Norm
    shortcut: Conv_Norm
    in_channels: int
    out_channels: int
    downsample: int

    def __init__(self, in_channels, out_channels, downsample, key):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample

        k1, k2, k3, _ = jax.random.split(key, 4)

        # Main path
        c1 = Conv3x3(in_channels, out_channels, downsample, k1)
        c2 = Conv3x3(out_channels, out_channels, 1,        k2)

        self.conv1 = Conv_Norm(c1, out_channels)
        self.conv2 = Conv_Norm(c2, out_channels)

        # Shortcut path: only *needed* when channels or stride change,
        # but it's fine to always create it and only sometimes use it.
        c3 = eqx.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=downsample,
            padding=0,
            key=k3,
        )
        self.shortcut = Conv_Norm(c3, out_channels)

    def __call__(self, x, state):
        residual = x

        x, state = self.conv1(x, state)
        x = jax.nn.relu(x)
        x, state = self.conv2(x, state)

        jax.debug.print("Block in channels -> {x}", x=self.in_channels)
        jax.debug.print("Block out channels -> {x}", x=self.out_channels)
        jax.debug.print("downsample -> {x}", x=self.downsample)

        # Use shortcut if shape would differ
        if self.in_channels != self.out_channels or self.downsample == 2:
            jax.debug.print("here")
            residual, state = self.shortcut(residual, state)

        x = x + residual
        return jax.nn.relu(x), state


class ResBottleNeckBlock(eqx.Module):
    conv1: Conv_Norm
    conv2: Conv_Norm
    conv3: Conv_Norm
    shortcut:Conv_Norm
    in_channels:int
    out_channels:int
    downsample:int

    def __init__(self,in_channels,out_channels,downsample,key):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample
        keys = jr.split(key,3)
        c1 = eqx.nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=downsample,key=keys[0])
        c2 = Conv3x3(out_channels,out_channels,1,keys[1])
        c3 = eqx.nn.Conv2d(out_channels,out_channels*4,kernel_size=1,stride=1,key=keys[2])
        c4 = eqx.nn.Conv2d(in_channels,out_channels*4,kernel_size=1,stride=downsample,key=keys[3])
        self.conv1 = Conv_Norm(c1,out_channels)
        self.conv2 = Conv_Norm(c2,out_channels)
        self.conv3 = Conv_Norm(c3,out_channels*4)
        self.shortcut = Conv_Norm(c4,out_channels*4)


    def __call__(self,x,state):
        residual = x
        x,state = self.conv1(x,state)
        x = jax.nn.relu(x)
        x,state = self.conv2(x,state)
        x = jax.nn.relu(x)
        x,state = self.conv3(x,state)
        # jax.debug.print("in_channel -> {x}",x=self.in_channels)
        # jax.debug.print("out_channel -> {x}",x=self.out_channels)
        # jax.debug.print("downsample -> {x}",x=self.downsample)
        if self.in_channels != self.out_channels*4 or self.downsample==2:
            residual,state = self.shortcut(residual,state)

        x = x + residual

        return jax.nn.relu(x),state

In [13]:

#Non Residual Blocks
class BasicBlock(eqx.Module):
    conv1: Conv_Norm
    conv2: Conv_Norm
    shortcut:Conv_Norm
    in_channels:int
    out_channels:int
    downsample:int

    def __init__(self,in_channels,out_channels,downsample,key):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample
        keys = jax.random.split(key, 3)
        c1 = Conv3x3(in_channels,out_channels,downsample,keys[0])
        c2 = Conv3x3(out_channels,out_channels,1,keys[1])
        c3 = eqx.nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=downsample,key=keys[2])
        self.conv1 = Conv_Norm(c1,out_channels)
        self.conv2 = Conv_Norm(c2,out_channels)
        self.shortcut = Conv_Norm(c3,out_channels)

    def __call__(self,x,state):
        x,state = self.conv1(x,state)
        x = jax.nn.relu(x)
        x,state = self.conv2(x,state)
        return jax.nn.relu(x),state

class BottleNeckBlock(eqx.Module):
    conv1: Conv_Norm
    conv2: Conv_Norm
    conv3: Conv_Norm
    shortcut:Conv_Norm
    in_channels:int
    out_channels:int
    downsample:int

    def __init__(self,in_channels,out_channels,downsample,key):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample
        keys = jr.split(key,4)
        c1 = eqx.nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=downsample,key=keys[0])
        c2 = Conv3x3(out_channels,out_channels,1,keys[1])
        c3 = eqx.nn.Conv2d(out_channels,out_channels*4,kernel_size=1,stride=1,key=keys[2])
        c4 = eqx.nn.Conv2d(in_channels,out_channels*4,kernel_size=1,stride=downsample,key=keys[3])
        self.conv1 = Conv_Norm(c1,out_channels)
        self.conv2 = Conv_Norm(c2,out_channels)
        self.conv3 = Conv_Norm(c3,out_channels*4)
        self.shortcut = Conv_Norm(c4,out_channels*4)

    def __call__(self,x,state):
        x,state = self.conv1(x,state)
        x = jax.nn.relu(x)
        x,state = self.conv2(x,state)
        x = jax.nn.relu(x)
        x,state = self.conv3(x,state)
        return jax.nn.relu(x),state

eqx.tree_pprint(BottleNeckBlock)

# module,state = eqx.nn.make_with_state(BottleNeckBlock)(64,64,2,key)
# dummy = jr.normal(key,(8,64,8,8))
#
# batched_forward = eqx.filter_vmap(module, in_axes=(0, None), axis_name="batch")
# y, state_batched = batched_forward(dummy, state)
# state = jax.tree_util.tree_map(lambda s: s[0], state_batched)
#
# print(y.shape)


__main__.BottleNeckBlock


In [14]:
from typing import Union

class ResNetLayer(eqx.Module):
    block: Union[ResBottleNeckBlock,ResBasicBlock,BottleNeckBlock,BasicBlock]
    layer:tuple
    def __init__(self,in_channels,out_channels,block,key,n=1):
        self.block = block
        block_expansion = 1 if issubclass(block,(ResBasicBlock,BasicBlock))  else 4 #bottleneck has expansion = 4
        # jax.debug.print("Block Expansion -> {x}",x=block_expansion)
        downsample = 2 if in_channels != out_channels*block_expansion else 1
        keys = jr.split(key,n)

        jax.debug.print("Layer Downsample -> {x}",x=downsample)
        blocks = [block(in_channels,out_channels,downsample,keys[0])]

        for i in range(1,n):
            blocks.append(block(out_channels*block_expansion,out_channels,1,keys[i]))
        self.layer = tuple(blocks)

    def __call__(self, x,state):
        for blk in self.layer:
            x, state = blk(x, state)
        return x,state


# eqx.tree_pprint(ResNetLayer)
# module,state = eqx.nn.make_with_state(ResNetLayer)(64,64,BottleNeckBlock,key,n=3)
# dummy = jr.normal(key,(8,64,8,8))
#
# batched_forward = eqx.filter_vmap(module, in_axes=(0, None), axis_name="batch")
# y, state_batched = batched_forward(dummy, state)
# state = jax.tree_util.tree_map(lambda s: s[0], state_batched)
#
# print(y.shape)
# print(state)

In [15]:

class ResNet(eqx.Module):
    input_size:int
    num_classes:int
    layer_size:tuple
    layers:tuple
    block: Union[ResBasicBlock,ResBottleNeckBlock,BottleNeckBlock,BasicBlock]
    maxpool: eqx.nn.MaxPool2d
    avgpool: eqx.nn.AdaptiveAvgPool2d
    conv1: Conv7x7
    fc: eqx.nn.Linear



    base_channel: int = 64


    def __init__(self,input_size=3,num_classes=200,layer_size=(1,1,1,1),block=ResBasicBlock,key=jax.random.key(0)):
        self.input_size = input_size
        self.num_classes = num_classes
        self.layer_size = layer_size
        self.block = block
        keys = jr.split(key,6)
        block_expansion = 1 if (issubclass(block, ResBasicBlock) or issubclass(block, BasicBlock)) else 4
        self.conv1 = Conv7x7(input_size,self.base_channel,keys[0])
        self.maxpool = eqx.nn.MaxPool2d(kernel_size=(3,3),stride=(2,2),padding=1)

        in_channels = self.base_channel
        out_channels = self.base_channel
        jax.debug.print("Layer0 in channels -> {x}",x=in_channels)
        jax.debug.print("Layer0 out channels -> {x}",x=out_channels)
        layers= [ResNetLayer(in_channels,out_channels,block,keys[1],layer_size[0])]
        # jax.debug.print("Layer0 in channels -> {x}",x=in_channels)
        # jax.debug.print("Layer0 out channels -> {x}",x=out_channels)
        for ii in range(1,len(layer_size)):
            # jax.debug.print("in channels -> {x}",x=in_channels)
            # jax.debug.print("out channels -> {x}",x=out_channels)
            out_channels = 2*in_channels
            jax.debug.print("Layer{i} in channels -> {x}",x=in_channels,i=ii)
            jax.debug.print("Layer{i} out channels -> {x}",x=out_channels,i=ii)
            layers.append(ResNetLayer(in_channels,out_channels,block,keys[ii+1],layer_size[ii]))
            in_channels = out_channels*block_expansion

            # jax.debug.print("Layer{i} in channels -> {x}",x=in_channels,i=ii)
            # jax.debug.print("Layer{i} out channels -> {x}",x=out_channels,i=ii)


            #jax.debug.print("Loop Number-> {x}",x=ii)

        self.layers = tuple(layers)

        self.avgpool = eqx.nn.AdaptiveAvgPool2d(target_shape=(1,1))
        self.fc = eqx.nn.Linear(out_channels,num_classes,key=keys[-1])

    def __call__(self, x,state):
        x = self.conv1(x)
        x = self.maxpool(x)

        for i,layer in enumerate(self.layers):
            # if i >= 1:
                # jax.debug.print("iteration -> {x}",x=i)
                # jax.debug.print("layer -> {x}",x=layer)
            x,state = layer(x,state)
            jax.debug.print("x shape -> {x}",x=x.shape)



        x = self.avgpool(x)
        x = jnp.ravel(x)
        x = self.fc(x)

        return x,state





In [17]:

module,state = eqx.nn.make_with_state(ResNet)(3,200,(2,2,2,2),ResBottleNeckBlock,key)
dummy = jr.normal(key,(1,3,64,64))

batched_forward = eqx.filter_vmap(module, in_axes=(0,None),axis_name="batch")
y, state_batched = batched_forward(dummy, state)
state = jax.tree_util.tree_map(lambda s: s[0], state_batched)

print(y.shape)






Layer0 in channels -> 64
Layer0 out channels -> 64
Layer Downsample -> 2
Layer1 in channels -> 64
Layer1 out channels -> 128
Layer Downsample -> 2
Layer2 in channels -> 512
Layer2 out channels -> 1024
Layer Downsample -> 2
Layer3 in channels -> 4096
Layer3 out channels -> 8192
Layer Downsample -> 2
x shape -> (Array(256, dtype=int32, weak_type=True), Array(8, dtype=int32, weak_type=True), Array(8, dtype=int32, weak_type=True))


ValueError: conv_general_dilated lhs feature dimension size divided by feature_group_count must equal the rhs input feature dimension size, but 256 // 1 != 64.