In [1]:
import setGPU

import sys
sys.path.append('../')

# import torch
from escnn import gspaces
from escnn import nn

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import torch
from jaxtyping import Array, Float, Int, PyTree, PRNGKeyArray  # https://github.com/google/jaxtyping

from torch.utils.data import Dataset
from torchvision.transforms import RandomRotation
from torchvision.transforms import Pad
from torchvision.transforms import Resize
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose
from torchvision.transforms import InterpolationMode

import numpy as np

from PIL import Image

setGPU: Setting GPU to: 1




In [2]:
# download the dataset
!wget -nc http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip
# uncompress the zip file
!unzip -n mnist_rotation_new.zip -d mnist_rotation_new

class MnistRotDataset(Dataset):
    
    def __init__(self, mode, transform=None):
        assert mode in ['train', 'test']
            
        if mode == "train":
            file = "mnist_rotation_new/mnist_all_rotation_normalized_float_train_valid.amat"
        else:
            file = "mnist_rotation_new/mnist_all_rotation_normalized_float_test.amat"
        
        self.transform = transform

        data = np.loadtxt(file, delimiter=' ')
            
        self.images = data[:, :-1].reshape(-1, 28, 28).astype(np.float32)
        self.labels = data[:, -1].astype(np.int64)
        self.num_samples = len(self.labels)
    
    def __getitem__(self, index):
        image, label = self.images[index], self.labels[index]
        image = Image.fromarray(image, mode='F')
        if self.transform is not None:
            image = self.transform(image)
        return image, label
    
    def __len__(self):
        return len(self.labels)

# images are padded to have shape 29x29.
# this allows to use odd-size filters with stride 2 when downsampling a feature map in the model
pad = Pad((0, 0, 1, 1), fill=0)

# to reduce interpolation artifacts (e.g. when testing the model on rotated images),
# we upsample an image by a factor of 3, rotate it and finally downsample it again
resize1 = Resize(87)
resize2 = Resize(29)

totensor = ToTensor()

File ‘mnist_rotation_new.zip’ already there; not retrieving.

Archive:  mnist_rotation_new.zip


In [3]:
train_transform = Compose([
    pad,
    resize1,
    RandomRotation(180., interpolation=InterpolationMode.BILINEAR, expand=False),
    resize2,
    totensor,
])

mnist_train = MnistRotDataset(mode='train', transform=train_transform)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64)


test_transform = Compose([
    pad,
    totensor,
])
mnist_test = MnistRotDataset(mode='test', transform=test_transform)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=64)

x, y = mnist_train[0]
x = jnp.array(x[None, ...])
y = jnp.array(y[None, ...])
print(x.shape)
print(y.shape)

(1, 1, 29, 29)
(1,)


In [11]:
# %tb
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

class CNN(eqx.Module):
    layers: list
    # layer: eqx.Module
    input_type: nn.FieldType

    def __init__(self, key, n_classes=10):
        keys = jax.random.split(key, 8)
        
        super(CNN, self).__init__()
        self.layers = []
        
        # the model is equivariant under rotations by 45 degrees, modelled by C8
        r2_act = gspaces.rot2dOnR2(N=8)
        
        # the input image is a scalar field, corresponding to the trivial representation
        in_type = nn.FieldType(r2_act, [r2_act.trivial_repr])
        
        # we store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = in_type

        self.layers.extend([
            nn.R2Conv(in_type, out_type, kernel_size=4, padding=0, use_bias=False, key=keys[0]),
            nn.ReLU(out_type),
            nn.GroupPooling(out_type)
        ])


    def __call__(self, input: Array):
        # wrap the input tensor in a GeometricTensor
        # (associate it with the input type)
        in_type = self.input_type
        # in_type = nn.FieldType(r2_act, [r2_act.trivial_repr])
        x = nn.GeometricTensor(input, in_type)
        # x = self.layer(x)
        for layer in self.layers:
            # print(type(layer))
            # print("x", x.shape)
            x = layer(x)

        return x

r2_act = gspaces.rot2dOnR2(N=8)
in_type = nn.FieldType(r2_act, [r2_act.trivial_repr])
out_type = in_type
# out_type = nn.FieldType(r2_act, 3*[r2_act.regular_repr])

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 8)
# model = nn.Linear(in_type, out_type, key=keys[0])
# model = nn.Linear(in_type, out_type, key=keys[0])
# model = nn.R2Conv(in_type, out_type, kernel_size=4, padding=0, use_bias=False, key=keys[0])
model = CNN(key)

m = eqx.nn.Linear(1728, 512, key=keys[1])
# m = Linear(1728, 512, key=keys[1])
#
# print("jax.tree_util.tree_leaves(m)",  len(jax.tree_util.tree_leaves(m)), jax.tree_util.tree_leaves(m))
# print()
# print(hash(tuple(jax.tree_util.tree_leaves(m))))
# for e in tuple(jax.tree_util.tree_leaves(m)):
#     print("e", e)
#     print("hash(e)", hash(e))
m = eqx.filter_jit(m)
# print("jax.tree_util.tree_leaves(m)", jax.tree_util.tree_leaves(m))
# print()
input = jax.random.normal(keys[2], (1728,))
out = m(input)
# print(m(nn.GeometricTensor(input, in_type)))

# input = nn.GeometricTensor(x, in_type)

# out = model(input)
# print(out)
# print(model)
# print()
# print(model.in_type)
# print()
# print(jax.tree_util.tree_leaves(model))
# print()
# for e in tuple(jax.tree_util.tree_leaves(model)):
#     print("e", e)
#     print("hash(e)", hash(e))
# print(len(jax.tree_util.tree_leaves(model)))
# print(hash(tuple(jax.tree_util.tree_leaves(model))))
print("jax.tree_util.tree_leaves(model)", len(jax.tree_util.tree_leaves(model)), jax.tree_util.tree_leaves(model))
print()
model = eqx.filter_jit(model)#(input)
# print("jax.tree_util.tree_leaves(model)", jax.tree_util.tree_leaves(model))
# print()
# set_trace()
# hash(input)
# model(input)
# print(type(x))
# print("in_type", in_type)
model(x)


jax.tree_util.tree_leaves(model) 4 [Array([[[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]]], dtype=float32), Array([0.4864558], dtype=float32), False, [C8_on_R2[(None, 8)]: {irrep_0 (x1)}(1)]]

GroupPooling
self._contiguous {1: True}
1 True


IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

In [5]:
class C8SteerableCNN(eqx.Module):
    layers: list
    input_type: nn.FieldType

    def __init__(self, key, n_classes=10):
        keys = jax.random.split(key, 8)
        layers = []
        
        super(C8SteerableCNN, self).__init__()
        
        # the model is equivariant under rotations by 45 degrees, modelled by C8
        r2_act = gspaces.rot2dOnR2(N=8)
        
        # the input image is a scalar field, corresponding to the trivial representation
        in_type = nn.FieldType(r2_act, [r2_act.trivial_repr])
        
        # we store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = in_type
        
        # convolution 1
        # first specify the output type of the convolutional layer
        # we choose 24 feature fields, each transforming under the regular representation of C8
        out_type = nn.FieldType(r2_act, 24*[r2_act.regular_repr])
        layers.extend([
            nn.MaskModule(in_type, 29, margin=1),
            nn.R2Conv(in_type, out_type, kernel_size=7, padding=1, use_bias=False, key=keys[0]),
            # nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type)
        ])
        
        # convolution 2
        # the old output type is the input type to the next layer
        # in_type = self.block1.out_type
        in_type = out_type
        # the output type of the second convolution layer are 48 regular feature fields of C8
        out_type = nn.FieldType(r2_act, 48*[r2_act.regular_repr])
        layers.extend([
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, use_bias=False, key=keys[1]),
            # nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type)
        ])
        layers.append(nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2))
        
        # convolution 3
        # the old output type is the input type to the next layer
        # in_type = self.block2.out_type
        in_type = out_type
        # the output type of the third convolution layer are 48 regular feature fields of C8
        out_type = nn.FieldType(r2_act, 48*[r2_act.regular_repr])
        layers.extend([
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, use_bias=False, key=keys[2]),
            # nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type)
        ])
        
        # convolution 4
        # the old output type is the input type to the next layer
        # in_type = self.block3.out_type
        in_type = out_type
        # the output type of the fourth convolution layer are 96 regular feature fields of C8
        out_type = nn.FieldType(r2_act, 96*[r2_act.regular_repr])
        layers.extend([
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, use_bias=False, key=keys[3]),
            # nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type)
        ])
        layers.append(
            nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)
        )
        
        # convolution 5
        # the old output type is the input type to the next layer
        # in_type = self.block4.out_type
        in_type = out_type
        # the output type of the fifth convolution layer are 96 regular feature fields of C8
        out_type = nn.FieldType(r2_act, 96*[r2_act.regular_repr])
        layers.extend([
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, use_bias=False, key=keys[4]),
            # nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type)
        ])
        
        # convolution 6
        # the old output type is the input type to the next layer
        # in_type = self.block5.out_type
        in_type = out_type
        # the output type of the sixth convolution layer are 64 regular feature fields of C8
        out_type = nn.FieldType(r2_act, 64*[r2_act.regular_repr])
        layers.extend([
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=1, use_bias=False, key=keys[5]),
            # nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type)
        ])
        layers.append(nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=1, padding=0))
        
        layers.append(nn.GroupPooling(out_type))

        self.layers = layers
        
        # number of output channels
        # c = self.gpool.out_type.size
        c = out_type.size
        
        # Fully Connected
        self.fully_net = eqx.nn.Sequential(
            eqx.nn.Linear(c, 64, key=keys[6]),
            # eqx.nn.BatchNorm() .BatchNorm1d(64),
            jax.nn.elu,
            eqx.nn.Linear(64, n_classes, key=keys[7]),
        )
    
    def __call__(self, input: Array):
        # wrap the input tensor in a GeometricTensor
        # (associate it with the input type)
        x = nn.GeometricTensor(input, self.input_type)
        
        # apply each equivariant block
        for layer in self.layers:
            x = layer(x)

        # unwrap the output GeometricTensor
        # (take the Pytorch tensor and discard the associated representation)
        x = x.tensor
        
        # classify with the final fully connected layers)
        x = self.fully_net(x.reshape(x.shape[0], -1))
        
        return x

In [6]:
class CNN(eqx.Module):
    layers: list
    input_type: nn.FieldType
    fully_net: list

    def __init__(self, key, n_classes=10):
        keys = jax.random.split(key, 8)
        self.layers = []
        
        super(CNN, self).__init__()
        
        # the model is equivariant under rotations by 45 degrees, modelled by C8
        r2_act = gspaces.rot2dOnR2(N=8)
        
        # the input image is a scalar field, corresponding to the trivial representation
        in_type = nn.FieldType(r2_act, [r2_act.trivial_repr])
        
        # we store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = in_type
        
        # convolution 1
        # first specify the output type of the convolutional layer
        # we choose 24 feature fields, each transforming under the regular representation of C8
        out_type = nn.FieldType(r2_act, 3*[r2_act.regular_repr])
        self.layers.extend([
            # nn.MaskModule(in_type, 29, margin=1),
            nn.R2Conv(in_type, out_type, kernel_size=4, padding=0, use_bias=False, key=keys[0]),
            nn.ReLU(out_type),
            nn.GroupPooling(out_type)
        ])

        # block1 = nn.SequentialModule(*[
        #     nn.MaskModule(in_type, 29, margin=1),
        #     nn.R2Conv(in_type, out_type, kernel_size=4, padding=0, use_bias=False, key=keys[0]),
        #     nn.ReLU(out_type),
        #     nn.GroupPooling(out_type)
        # ])
        # self.layers.append(block1)
      
        # # number of output channels
        # # c = self.gpool.out_type.size
        # c = block1.out_type.size
        # print("c", c)
        c = out_type.size
        print("c", c)
        c = 2028
        print("c", c)
        
        # Fully Connected
        # self.fully_net = eqx.nn.Sequential([
        self.fully_net = [
            eqx.nn.Linear(c, 64, key=keys[6]),
            # eqx.nn.BatchNorm() .BatchNorm1d(64),
            jax.nn.elu,
            eqx.nn.Linear(64, n_classes, key=keys[7]),
            jax.nn.log_softmax
        ]
    
    def __call__(self, input: Array):
        # wrap the input tensor in a GeometricTensor
        # (associate it with the input type)
        x = nn.GeometricTensor(input, self.input_type)
        # print("x", x.shape)
        
        # apply each equivariant block
        for layer in self.layers:
            # print(type(layer))
            # print("x", x.shape)
            x = layer(x)

        # unwrap the output GeometricTensor
        # (take the Pytorch tensor and discard the associated representation)
        x = x.tensor
        # print("x", x.shape)
        x = x.reshape(x.shape[0], -1)
        
        # classify with the final fully connected layers)
        # x = self.fully_net(x.reshape(x.shape[0], -1))
        for layer in self.fully_net:
            # print(type(layer))
            # print("x", x.shape)
            x = jax.vmap(layer)(x)
        
        return x

In [7]:
BATCH_SIZE = 64
LEARNING_RATE = 3e-4
STEPS = 300
PRINT_EVERY = 30
SEED = 5678

key = jax.random.PRNGKey(SEED)

key, subkey = jax.random.split(key, 2)
# model = C8SteerableCNN(subkey)
model = CNN(subkey)
# model = jax.jit(model)
print(model)

x, y = mnist_train[0]
x = jnp.array(x[None, ...])
y = jnp.array(y[None, ...])
print(x.shape)
print(y.shape)
eqx.filter_jit(model)(x)

c 24
c 2028
CNN(
  layers=[
    R2Conv(
      in_type=[C8_on_R2[(None, 8)]: {irrep_0 (x1)}(1)],
      out_type=[C8_on_R2[(None, 8)]: {regular (x3)}(24)],
      space=C8_on_R2[(None, 8)],
      d=2,
      kernel_size=(4, 4),
      stride=(1, 1),
      dilation=(1, 1),
      padding=((0, 0), (0, 0)),
      padding_mode='zeros',
      groups=1,
      _reversed_padding_repeated_twice=(0, 0, 0, 0),
      use_bias=False,
      _basisexpansion=<escnn.nn.modules.basismanager.basisexpansion_blocks.BlocksBasisExpansion object at 0x7f9366dd11f0>,
      bias=None,
      filter=f32[24,1,4,4],
      weights=f32[15],
      bias_expansion=None,
      expanded_bias=None,
      inference=False,
      _rings=[0.0, 1.0],
      _sigma=[0.005, 0.4],
      _maximum_frequency=2
    ),
    ReLU(
      in_type=[C8_on_R2[(None, 8)]: {regular (x3)}(24)],
      out_type=[C8_on_R2[(None, 8)]: {regular (x3)}(24)],
      space=C8_on_R2[(None, 8)]
    ),
    GroupPooling(
      in_type=[C8_on_R2[(None, 8)]: {regular (

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

In [8]:
def test_model(model: eqx.Module, x: Image):
    np.set_printoptions(linewidth=10000)
    
    # evaluate the `model` on 8 rotated versions of the input image `x`
    # model = model.eval()
    
    x = resize1(pad(x))
    
    print()
    print('##########################################################################################')
    header = 'angle |  ' + '  '.join(["{:6d}".format(d) for d in range(10)])
    print(header)
    # with torch.no_grad():
    for r in range(8):
        x_transformed = totensor(resize2(x.rotate(r*45., Image.BILINEAR))).reshape(1, 1, 29, 29)
        # print("x_transformed", type(x_transformed), x_transformed)
        x_transformed = jnp.array(x_transformed.numpy())#.to(device)

        y = model(x_transformed)
        # y = y.to('cpu').numpy().squeeze()
        y = np.array(y).squeeze()
        
        angle = r * 45
        print("{:5d} : {}".format(angle, y))
    print('##########################################################################################')
    print()

In [9]:
# build the test set    
raw_mnist_test = MnistRotDataset(mode='test')
# retrieve the first image from the test set
x, y = next(iter(raw_mnist_test))

# evaluate the model
test_model(model, x)


##########################################################################################
angle |       0       1       2       3       4       5       6       7       8       9
    0 : [-2.3039029 -2.3758793 -2.261747  -2.417113  -2.3186145 -2.2233348 -2.3939812 -2.445771  -2.1761634 -2.155946 ]
   45 : [-2.3090162 -2.330779  -2.20519   -2.4289494 -2.3360045 -2.3187938 -2.3918345 -2.3799646 -2.188079  -2.173234 ]
   90 : [-2.2690861 -2.374911  -2.2361343 -2.4105105 -2.3555703 -2.2416685 -2.4006789 -2.4494424 -2.1421843 -2.1945066]
  135 : [-2.3045998 -2.3854792 -2.2522273 -2.4654298 -2.288461  -2.1792316 -2.3682797 -2.4039762 -2.1915684 -2.2281508]
  180 : [-2.3038316 -2.4193363 -2.1862879 -2.431365  -2.294912  -2.28298   -2.3901792 -2.4280572 -2.1476018 -2.193202 ]
  225 : [-2.2785158 -2.3578658 -2.2640278 -2.4308782 -2.2774334 -2.2697644 -2.4253724 -2.395699  -2.1898189 -2.1743972]
  270 : [-2.2624114 -2.3587322 -2.2226477 -2.4149845 -2.3286276 -2.2724388 -2.393867  -2.4467957 -2.

In [121]:
def loss(
    # model, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
    model, x: Float[Array, "batch 1 29 29"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    # Our input has the shape (BATCH_SIZE, 1, 28, 28), but our model operations on
    # a single input input image of shape (1, 28, 28).
    #
    # Therefore, we have to use jax.vmap, which in this case maps our model over the
    # leading (batch) axis.
    # pred_y = jax.vmap(model)(x)
    pred_y = model(x)
    return cross_entropy(y, pred_y)


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)


# Example loss
loss_value = loss(model, x, y)
print(loss_value.shape, loss_value)  # scalar loss
# Example inference
# output = jax.vmap(model)(x)
# output = jax.jit(model)(x)
output = model(x)


# This will work!
# params, static = eqx.partition(model, eqx.is_array)


# def loss2(params, static, x, y):
#     model = eqx.combine(params, static)
#     return loss(model, x, y)


# loss_value, grads = jax.value_and_grad(loss2)(params, static, x, y)
# print(loss_value)

# This will work too!
value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
print(value)

() 2.4092078
2.4092078


In [124]:
# loss = eqx.filter_jit(loss)  # JIT our loss function from earlier!


# @eqx.filter_jit
def compute_accuracy(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """This function takes as input the current model
    and computes the average accuracy on a batch.
    """
    # pred_y = jax.vmap(model)(x)
    pred_y = model(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)

def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
    """This function evaluates the model on the test dataset,
    computing both the average loss and the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in testloader:
        x = jnp.array(x.numpy())
        y = jnp.array(y.numpy())
        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
        # and both have JIT wrappers, so this is fast.
        avg_loss += loss(model, x, y)
        avg_acc += compute_accuracy(model, x, y)
    return avg_loss / len(testloader), avg_acc / len(testloader)

evaluate(model, test_loader)

(Array(2.3148353, dtype=float32), Array(0.09828565, dtype=float32))

In [74]:
import optax
optim = optax.adamw(LEARNING_RATE)

def train(
    model: CNN,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> CNN:
    # Just like earlier: It only makes sense to train the arrays in our model,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    # @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        # x: Float[Array, "batch 1 28 28"],
        x: Float[Array, "batch 1 29 29"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    # Loop over our training dataset as many times as we need.
    def infinite_trainloader():
        while True:
            yield from train_loader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = jnp.array(x.numpy())
        y = jnp.array(y.numpy())
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            test_loss, test_accuracy = evaluate(model, testloader)
            print(
                f"{step=}, train_loss={train_loss.item()}, "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
            )
    return model

In [75]:
model = train(model, train_loader, test_loader, optim, STEPS, PRINT_EVERY)

step=0, train_loss=2.292130947113037, test_loss=2.3077950477600098, test_accuracy=0.12965552508831024
step=30, train_loss=2.1808724403381348, test_loss=2.1475913524627686, test_accuracy=0.27635470032691956
step=60, train_loss=2.042396068572998, test_loss=1.99031400680542, test_accuracy=0.34325048327445984
step=90, train_loss=1.8445414304733276, test_loss=1.8483741283416748, test_accuracy=0.4116847813129425
step=120, train_loss=1.801900863647461, test_loss=1.7456316947937012, test_accuracy=0.4030730426311493
step=150, train_loss=1.606579303741455, test_loss=1.6744093894958496, test_accuracy=0.4048313498497009
step=180, train_loss=1.609127163887024, test_loss=1.6039823293685913, test_accuracy=0.4306465685367584
step=210, train_loss=1.4931309223175049, test_loss=1.5789670944213867, test_accuracy=0.44057703018188477
step=240, train_loss=1.5692706108093262, test_loss=1.5438400506973267, test_accuracy=0.44037723541259766
step=270, train_loss=1.3343498706817627, test_loss=1.529144525527954, t