In [1]:
import torch
import jax
from torch.utils.data import DataLoader
from torchvision import transforms
from tiny_imagenet_torch import TinyImageNet ## Import allow TinyImage via torch methods(64x64)

In [2]:
#Confirm cuda is enabled on both jax and torch
print(torch.cuda.is_available())
print(jax.devices())


True
[CudaDevice(id=0)]


In [3]:
import os
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] ="True"
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] =".75"
os.environ["TF_GPU_ALLOCATOR"] ="cuda_malloc_async"


In [4]:

# Simple transformation - just convert to tensor
transform = transforms.ToTensor()

# Create dataset
train_dataset = TinyImageNet(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset = TinyImageNet(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

Files already downloaded and verified
Files already downloaded and verified


In [5]:

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=4)

In [6]:
import jax
# Usage example
for images, labels in train_loader:
    # Your training code here
    print(images.shape, labels.shape)
    print(jax.numpy.array(labels).shape)
    print(jax.nn.one_hot(jax.numpy.array(labels),200).shape)
    print(labels)
    # print(jax.nn.one_hot(jax.numpy.array(labels),200))
    break

  self.pid = os.fork()


torch.Size([8, 3, 64, 64]) torch.Size([8])


  self.pid = os.fork()


(8,)
(8, 200)
tensor([  5,  66, 149,  14,  73, 123, 110, 198])


![Resnet Blocks](sources/ResnetPaper.png)


# Building Resnet From Scratch using Equinox
Orignal Pytorch Reference : https://github.com/FrancescoSaverioZuppichini/ResNet/blob/master/ResNet.ipynb

In [7]:
import equinox as eqx
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
from functools import partial
eqx.clear_caches()

In [8]:
SEED =42
key = jrandom.PRNGKey(SEED)

In [9]:
# ===== ResNet Basic Block in Equinox (abstract/final; explicit __init__; no object.__setattr__) =====

from typing import Optional, Callable
import equinox as eqx
import jax.numpy as jnp
import jax.nn as jnn
import jax.random as jr


In [10]:
class Conv7x7(eqx.Module):
    conv:eqx.nn.Conv2d

    def __init__(self,in_channels,out_channels,key):
        self.conv = eqx.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=2, padding=3,key=key)
    def __call__(self,x):
        return self.conv(x)


In [11]:
dummy = jr.normal(key,(512,2,2)) #EUREKA MOMENT LOGGED 11/16/25
gap = eqx.nn.AdaptiveAvgPool2d(target_shape=(1,1))
y = gap(dummy)
print(y.shape)

(512, 1, 1)


In [12]:
class Conv3x3(eqx.Module):
    conv:eqx.nn.Conv2d

    def __init__(self,in_channels,out_channels,downsample,key):
        self.conv = eqx.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=downsample, padding=1,key=key)
    def __call__(self,x):
        return self.conv(x)


In [13]:
class Conv_Norm(eqx.Module):
    block:eqx.nn.Conv2d
    bn :eqx.nn.BatchNorm

    def __init__(self,block,bn_channels=8):
        self.block = block
        self.bn = eqx.nn.BatchNorm(bn_channels,axis_name="batch")

    def __call__(self,x,state):
        x = self.block(x)
        x,state = self.bn(x,state)
        return x,state



![Resnet Blocks](sources/resnetblocks.png)

In [14]:
#Residual Blocks
class ResBasicBlock(eqx.Module):
    conv1: Conv_Norm
    conv2: Conv_Norm
    shortcut:Conv_Norm
    in_channels:int
    out_channels:int
    downsample:int

    def __init__(self,in_channels,out_channels,downsample,key):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample
        keys = jax.random.split(key, 3)
        c1 = Conv3x3(in_channels,out_channels,downsample,keys[0])
        c2 = Conv3x3(out_channels,out_channels,1,keys[1])
        c3 = eqx.nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=2,key=keys[2])
        self.conv1 = Conv_Norm(c1)
        self.conv2 = Conv_Norm(c2)
        self.shortcut = Conv_Norm(c3)

    def __call__(self,x,state):
        residual = x
        x,state = self.conv1(x,state)
        x = jax.nn.relu(x)
        x,state = self.conv2(x,state)
        if self.in_channels != self.out_channels or self.downsample==2:
            residual,state = self.shortcut(residual,state)

        x = x + residual

        return jax.nn.relu(x),state

class ResBottleNeckBlock(eqx.Module):
    conv1: Conv_Norm
    conv2: Conv_Norm
    conv3: Conv_Norm
    shortcut:Conv_Norm
    in_channels:int
    out_channels:int
    downsample:int

    def __init__(self,in_channels,out_channels,downsample,key):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample
        keys = jr.split(key,4)
        c1 = eqx.nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=downsample,key=keys[0])
        c2 = Conv3x3(out_channels,out_channels,1,keys[1])
        c3 = eqx.nn.Conv2d(out_channels,out_channels*4,kernel_size=1,stride=1,key=keys[2])
        c4 = eqx.nn.Conv2d(in_channels,out_channels*4,kernel_size=1,stride=1,key=keys[3])
        self.conv1 = Conv_Norm(c1,out_channels)
        self.conv2 = Conv_Norm(c2,out_channels)
        self.conv3 = Conv_Norm(c3,out_channels*4)
        self.shortcut = Conv_Norm(c4,out_channels*4)

    def __call__(self,x,state):
        residual = x
        x,state = self.conv1(x,state)
        x = jax.nn.relu(x)
        x,state = self.conv2(x,state)
        x = jax.nn.relu(x)
        x,state = self.conv3(x,state)
        # jax.debug.print("in_channel -> {x}",x=self.in_channels)
        # jax.debug.print("out_channel -> {x}",x=self.out_channels)
        # jax.debug.print("downsample -> {x}",x=self.downsample)
        if self.in_channels != self.out_channels*4 or self.downsample==2:
            residual,state = self.shortcut(residual,state)

        x = x + residual

        return jax.nn.relu(x),state

#Non Residual Blocks
class BasicBlock(eqx.Module):
    conv1: Conv_Norm
    conv2: Conv_Norm
    shortcut:Conv_Norm
    in_channels:int
    out_channels:int
    downsample:int

    def __init__(self,in_channels,out_channels,downsample,key):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample
        keys = jax.random.split(key, 3)
        c1 = Conv3x3(in_channels,out_channels,downsample,keys[0])
        c2 = Conv3x3(out_channels,out_channels,1,keys[1])
        c3 = eqx.nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=2,key=keys[2])
        self.conv1 = Conv_Norm(c1)
        self.conv2 = Conv_Norm(c2)
        self.shortcut = Conv_Norm(c3)

    def __call__(self,x,state):
        x,state = self.conv1(x,state)
        x = jax.nn.relu(x)
        x,state = self.conv2(x,state)
        return jax.nn.relu(x),state

class BottleNeckBlock(eqx.Module):
    conv1: Conv_Norm
    conv2: Conv_Norm
    conv3: Conv_Norm
    shortcut:Conv_Norm
    in_channels:int
    out_channels:int
    downsample:int

    def __init__(self,in_channels,out_channels,downsample,key):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample = downsample
        keys = jr.split(key,4)
        c1 = eqx.nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=downsample,key=keys[0])
        c2 = Conv3x3(out_channels,out_channels,1,keys[1])
        c3 = eqx.nn.Conv2d(out_channels,out_channels*4,kernel_size=1,stride=1,key=keys[2])
        c4 = eqx.nn.Conv2d(in_channels,out_channels*4,kernel_size=1,stride=1,key=keys[3])
        self.conv1 = Conv_Norm(c1,out_channels)
        self.conv2 = Conv_Norm(c2,out_channels)
        self.conv3 = Conv_Norm(c3,out_channels*4)
        self.shortcut = Conv_Norm(c4,out_channels*4)

    def __call__(self,x,state):
        x,state = self.conv1(x,state)
        x = jax.nn.relu(x)
        x,state = self.conv2(x,state)
        x = jax.nn.relu(x)
        x,state = self.conv3(x,state)
        return jax.nn.relu(x),state


# module,state = eqx.nn.make_with_state(BottleNeckBlock)(64,64,2,key)
# dummy = jr.normal(key,(8,64,8,8))
#
# batched_forward = eqx.filter_vmap(module, in_axes=(0, None), axis_name="batch")
# y, state_batched = batched_forward(dummy, state)
# state = jax.tree_util.tree_map(lambda s: s[0], state_batched)
#
# print(y.shape)


__main__.BottleNeckBlock


In [29]:
from typing import Union

class ResNetLayer(eqx.Module):
    block: Union[ResBottleNeckBlock,ResBasicBlock,BottleNeckBlock,BasicBlock]
    layer:tuple
    def __init__(self,in_channels,out_channels,block,key,n=1):
        self.block = block
        block_expansion = 1 if (issubclass(block,ResBasicBlock) or issubclass(block,BasicBlock)) else 4 #bottleneck has expansion = 4
        # jax.debug.print("Block Expansion -> {x}",x=block_expansion)
        downsample = 2 if in_channels != out_channels else 1
        keys = jr.split(key,n)


        blocks = [block(in_channels,out_channels,downsample,keys[0])]

        for i in range(1,n):
            blocks.append(block(out_channels*block_expansion,out_channels,1,keys[i]))
        self.layer = tuple(blocks)

    def __call__(self, x,state):
        for blk in self.layer:
            x, state = blk(x, state)
        return x,state


module,state = eqx.nn.make_with_state(ResNetLayer)(64,64,BasicBlock,key,n=2)
print(sum(x.size for x in jax.tree_util.tree_leaves(eqx.filter(module, eqx.is_array))))
dummy = jr.normal(key,(8,64,8,8))

batched_forward = eqx.filter_vmap(module, in_axes=(0, None), axis_name="batch")
y, state_batched = batched_forward(dummy, state)
state = jax.tree_util.tree_map(lambda s: s[0], state_batched)

print(y.shape)


156128


TypeError: add got incompatible shapes for broadcasting: (64,), (8,).

In [20]:

class ResNet(eqx.Module):
    input_size:int
    num_classes:int
    layer_size:tuple
    layers:tuple
    block: callable
    maxpool: eqx.nn.MaxPool2d
    avgpool: eqx.nn.AdaptiveAvgPool2d
    conv1: Conv7x7
    fc: eqx.nn.Linear



    base_channel: int = 64


    def __init__(self,input_size=3,num_classes=200,layer_size=(1,1,1,1),block=ResBasicBlock,key=jax.random.key(0)):
        self.input_size = input_size
        self.num_classes = num_classes
        self.layer_size = layer_size
        self.block = block
        keys = jr.split(key,6)
        self.conv1 = Conv7x7(input_size,self.base_channel,keys[0])
        self.maxpool = eqx.nn.MaxPool2d(kernel_size=(3,3),stride=(2,2),padding=1)

        layers= []
        in_channels = self.base_channel
        out_channels = self.base_channel
        for ii in range(len(layer_size)):
            # jax.debug.print("Loop Number-> {x}",x=ii)
            layers.append(ResNetLayer(in_channels,out_channels,block,keys[ii+1],layer_size[ii]))
            out_channels = in_channels
            out_channels = out_channels**(2*ii)
        self.layers = tuple(layers)

        self.avgpool = eqx.nn.AdaptiveAvgPool2d(target_shape=(1,1))
        self.fc = eqx.nn.Linear(out_channels,num_classes,key=keys[-1])

    def __call__(self, x,state):
        x = self.conv1(x)
        x = self.maxpool(x)
        for layer in self.layers:
            x,state = layer(x,state)
        x = self.avgpool(x)
        x = jnp.ravel(x)
        x = self.fc(x)

        return x,state






In [27]:

module,state = eqx.nn.make_with_state(ResNet)(3,200,(1,1,1,1),ResBasicBlock,key)


W1118 04:01:43.179512    1592 bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 36.00GiB (rounded to 38654705664)requested by op 
If the cause is memory fragmentation maybe the environment variable 'TF_GPU_ALLOCATOR=cuda_malloc_async' will improve the situation. 
Current allocation summary follows.
Current allocation summary follows.
W1118 04:01:43.179976    1592 bfc_allocator.cc:512] *************************************************************************************_______________
E1118 04:01:43.180002    1592 pjrt_stream_executor_client.cc:2974] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 38654705664 bytes. [tf-allocator-allocation-error='']


JaxRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 38654705664 bytes.

In [17]:
print(module)
dummy = jr.normal(key,(1,3,64,64))
batched_forward = eqx.filter_vmap(module, in_axes=(0, None), axis_name="batch")
y, state_batched = batched_forward(dummy, state)
state = jax.tree_util.tree_map(lambda s: s[0], state_batched)

print(y.shape)






E1118 03:53:31.365453    1592 gpu_hlo_schedule.cc:817] The byte size of input/output arguments (38654705680) exceeds the base limit (9658466304). This indicates an error in the calculation!
W1118 03:53:31.366958    1592 hlo_rematerialization.cc:3204] Can't reduce memory use below 36.00GiB (38654705664 bytes) by rematerialization; only reduced to 36.00GiB (38654706208 bytes), down from 36.00GiB (38654706208 bytes) originally
W1118 03:53:41.477556    1592 bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 36.00GiB (rounded to 38654705664)requested by op 
If the cause is memory fragmentation maybe the environment variable 'TF_GPU_ALLOCATOR=cuda_malloc_async' will improve the situation. 
Current allocation summary follows.
Current allocation summary follows.
W1118 03:53:41.477815    1592 bfc_allocator.cc:512] ********************________________________________________________________________________________
E1118 03:53:41.477840    1592 pjrt_stream_executor_c

JaxRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 38654705664 bytes.