<a href="https://colab.research.google.com/github/mohsenh17/jaxLearning/blob/main/flax/riceTypes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
pip install --upgrade flax

Collecting flax
  Downloading flax-0.10.1-py3-none-any.whl.metadata (11 kB)
Downloading flax-0.10.1-py3-none-any.whl (419 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m419.3/419.3 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: flax
  Attempting uninstall: flax
    Found existing installation: flax 0.8.5
    Uninstalling flax-0.8.5:
      Successfully uninstalled flax-0.8.5
Successfully installed flax-0.10.1


In [None]:
import shutil
from pathlib import Path

import kagglehub

# Fetch the Kaggle dataset; `path` is the value returned by kagglehub
# (presumably the local download directory -- confirm against kagglehub docs).
path = kagglehub.dataset_download("mssmartypants/rice-type-classification")

# The cells below read "dataset/riceClassification.csv" relative to the working
# directory, so mirror any downloaded CSV files into ./dataset. Existing files
# are left untouched so a manually prepared copy is never overwritten.
data_dir = Path("dataset")
data_dir.mkdir(parents=True, exist_ok=True)
for csv_file in Path(path).rglob("*.csv"):
    target = data_dir / csv_file.name
    if not target.exists():
        shutil.copy(csv_file, target)

In [12]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np

from flax import nnx
import jax.numpy as jnp


In [8]:
# Load the raw table and show its first five rows as a sanity check.
df = pd.read_csv(filepath_or_buffer="dataset/riceClassification.csv")
df.head(n=5)

Unnamed: 0,id,Area,MajorAxisLength,MinorAxisLength,Eccentricity,ConvexArea,EquivDiameter,Extent,Perimeter,Roundness,AspectRation,Class
0,1,4537,92.229316,64.012769,0.719916,4677,76.004525,0.657536,273.085,0.76451,1.440796,1
1,2,2872,74.691881,51.400454,0.725553,3015,60.471018,0.713009,208.317,0.831658,1.453137,1
2,3,3048,76.293164,52.043491,0.731211,3132,62.296341,0.759153,210.012,0.868434,1.46595,1
3,4,3073,77.033628,51.928487,0.738639,3157,62.5513,0.783529,210.657,0.870203,1.483456,1
4,5,3693,85.124785,56.374021,0.749282,3802,68.571668,0.769375,230.332,0.874743,1.51,1


# Load Dataset

In [9]:
class CustomImageDataset(Dataset):
    """Map-style dataset over a CSV file with a ``Class`` label column.

    The first CSV column is consumed as the DataFrame index; ``Class`` is the
    target and every remaining column is a feature.

    Args:
        dataset: Path (or buffer) handed to ``pd.read_csv``.
        transform: Optional callable applied to the feature array of a sample.
        target_transform: Optional callable applied to the label of a sample.
    """

    def __init__(self, dataset, transform=None, target_transform=None):
        table = pd.read_csv(dataset, index_col=0)
        self.transform = transform
        self.target_transform = target_transform
        # Keep both pieces as DataFrames: one label column, N feature columns.
        self.features = table.drop(columns='Class')
        self.labels = table[['Class']]

    def __len__(self):
        """Number of samples (rows) in the table."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for position ``idx`` after optional transforms."""
        row = self.features.iloc[idx, :].to_numpy(copy=True)
        target = self.labels.iloc[idx, 0]
        row = self.transform(row) if self.transform else row
        target = self.target_transform(target) if self.target_transform else target
        return row, target

In [10]:
# Seed the shuffling so the batch order (and the batch printed below) is the
# same on every Restart & Run All.
shuffle_rng = torch.Generator().manual_seed(0)

dataset = CustomImageDataset(
    dataset="dataset/riceClassification.csv"
)
data_loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=shuffle_rng)

# Pull a single batch to sanity-check shapes and values.
features, labels = next(iter(data_loader))
print("Batch of features has shape: ", features.shape)
print("Batch of labels has shape: ", labels.shape)
print(features)
print(labels)

Batch of features has shape:  torch.Size([4, 10])
Batch of labels has shape:  torch.Size([4])
tensor([[7.0200e+03, 1.7620e+02, 5.1844e+01, 9.5574e-01, 7.1950e+03, 9.4542e+01,
         5.6668e-01, 3.8011e+02, 6.1055e-01, 3.3987e+00],
        [6.7170e+03, 1.6200e+02, 5.3661e+01, 9.4355e-01, 6.9130e+03, 9.2479e+01,
         5.7030e-01, 3.6401e+02, 6.3702e-01, 3.0190e+00],
        [5.8750e+03, 1.5418e+02, 4.9532e+01, 9.4699e-01, 6.0280e+03, 8.6489e+01,
         4.5332e-01, 3.3939e+02, 6.4095e-01, 3.1127e+00],
        [8.9620e+03, 1.5679e+02, 7.3572e+01, 8.8308e-01, 9.1460e+03, 1.0682e+02,
         5.8529e-01, 3.7555e+02, 7.9850e-01, 2.1312e+00]], dtype=torch.float64)
tensor([1, 1, 1, 0])


# MLP model

In [29]:
class MLP(nnx.Module):
    """Fully connected network with a sigmoid output.

    Each hidden stage applies Dropout -> Linear -> GELU -> BatchNorm, and a
    final Linear layer followed by a sigmoid produces the output.

    Args:
        hidden_dims: Layer widths. The first entry is the input feature size;
            each consecutive pair ``(din, dout)`` defines one hidden stage.
        num_classes: Width of the output layer.
        rngs: NNX RNG container used for parameter init and dropout.
        dropout_rate: Dropout probability used in every hidden stage
            (defaults to 0.4, the value previously hard-coded).
    """

    def __init__(self, hidden_dims: list[int], num_classes: int, rngs: nnx.Rngs,
                 dropout_rate: float = 0.4):
        self.layers = []
        for din, dout in zip(hidden_dims[:-1], hidden_dims[1:]):
            # Construction order (dropout, linear, batch_norm) is kept fixed so
            # parameter initialisation draws from `rngs` in a stable sequence.
            self.layers.append({
                'dropout': nnx.Dropout(rate=dropout_rate, rngs=rngs),
                'linear': nnx.Linear(din, dout, rngs=rngs),
                'batch_norm': nnx.BatchNorm(dout, rngs=rngs)
            })
        # Add the final classification layer
        self.output_layer = nnx.Linear(hidden_dims[-1], num_classes, rngs=rngs)

    def __call__(self, x):
        """Run the forward pass and return sigmoid activations of the output layer."""
        for layer in self.layers:
            x = layer['dropout'](x)
            x = layer['linear'](x)
            x = nnx.gelu(x)
            x = layer['batch_norm'](x)
        x = self.output_layer(x)
        return nnx.sigmoid(x)

# Build a 10 -> 16 -> 32 -> 16 -> 1 network, push a dummy batch of three
# all-ones rows through it as a smoke test, then render the module tree.
model = MLP(hidden_dims=[10, 16, 32, 16], num_classes=1, rngs=nnx.Rngs(0))
y = model(jnp.ones((3, 10)))

nnx.display(model)


MLP(
  layers=[{'batch_norm': BatchNorm(
    mean=BatchStat(
      value=Array(shape=(16,), dtype=float32)
    ),
    var=BatchStat(
      value=Array(shape=(16,), dtype=float32)
    ),
    scale=Param(
      value=Array(shape=(16,), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(16,), dtype=float32)
    ),
    num_features=16,
    use_running_average=False,
    axis=-1,
    momentum=0.99,
    epsilon=1e-05,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    use_bias=True,
    use_scale=True,
    bias_init=<function zeros at 0x7bc9981b5bd0>,
    scale_init=<function ones at 0x7bc9981b5d80>,
    axis_name=None,
    axis_index_groups=None,
    use_fast_variance=True
  ), 'dropout': Dropout(rate=0.4, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(
    default=RngStream(
      key=RngKey(
        value=Array((), dtype=key<fry>) overlaying:
        [0 0],
        tag='default'
      ),
      count=RngCount(
        value=Array(17, d

# Optimizer and metrics

In [31]:
class customMetrics(nnx.metrics.Metric):
    """Streaming precision, recall and F1 for the positive class (label 1).

    Counts of true positives, false positives and false negatives are kept in
    ``MetricState`` variables so they are handled as NNX state.

    Args:
        threshold: Cut-off applied to floating-point predictions that carry a
            single score per sample; scores strictly above it count as class 1.
    """

    def __init__(self, threshold: float = 0.5):
        self.threshold = threshold
        self.true_positives = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))
        self.false_positives = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))
        self.false_negatives = nnx.metrics.MetricState(jnp.array(0, dtype=jnp.int32))

    def update(self, y_pred=None, y_true=None, *, logits=None, labels=None, **_):
        """Accumulate confusion counts from one batch.

        Args:
            y_pred: Predictions. If they have one more axis than ``y_true`` the
                last axis is reduced: a size-1 axis is squeezed, otherwise
                ``argmax`` picks the class. Floating-point scores that remain
                are binarised with ``self.threshold``.
            y_true: Ground-truth labels (0 or 1).
            logits: Used as ``y_pred`` when ``y_pred`` is not given, so the
                metric can share keyword arguments with ``nnx.metrics.Accuracy``.
            labels: Used as ``y_true`` when ``y_true`` is not given.
            **_: Extra keywords (e.g. ``loss``) are ignored; ``nnx.MultiMetric``
                forwards every keyword to every metric it holds.

        Raises:
            ValueError: If predictions or targets are missing.
        """
        preds = y_pred if y_pred is not None else logits
        targets = y_true if y_true is not None else labels
        if preds is None or targets is None:
            raise ValueError(
                "customMetrics.update needs y_pred/y_true (or logits/labels)."
            )
        preds = jnp.asarray(preds)
        targets = jnp.asarray(targets)

        if preds.ndim > targets.ndim:
            if preds.shape[-1] == 1:
                # Single sigmoid column: argmax would always return 0.
                preds = jnp.squeeze(preds, axis=-1)
            else:
                preds = jnp.argmax(preds, axis=-1)
        if jnp.issubdtype(preds.dtype, jnp.floating):
            preds = (preds > self.threshold).astype(jnp.int32)

        self.true_positives.value += jnp.sum((targets == 1) & (preds == 1))
        self.false_positives.value += jnp.sum((targets == 0) & (preds == 1))
        self.false_negatives.value += jnp.sum((targets == 1) & (preds == 0))

    def compute(self):
        """Return a dict with ``f1_score``, ``precision`` and ``recall``.

        A small epsilon in each denominator avoids division by zero when no
        positives have been seen.
        """
        tp = self.true_positives.value
        fp = self.false_positives.value
        fn = self.false_negatives.value
        precision = tp / (tp + fp + 1e-7)
        recall = tp / (tp + fn + 1e-7)
        f1_score = 2 * (precision * recall) / (precision + recall + 1e-7)
        return {"f1_score": f1_score, "precision": precision, "recall": recall}

    def result(self):
        """Backward-compatible alias for :meth:`compute`."""
        return self.compute()

    def reset(self):
        """Zero all accumulated counts."""
        self.true_positives.value = jnp.array(0, dtype=jnp.int32)
        self.false_positives.value = jnp.array(0, dtype=jnp.int32)
        self.false_negatives.value = jnp.array(0, dtype=jnp.int32)

In [32]:
import optax

# AdamW hyperparameters. `momentum` fills AdamW's second positional slot,
# i.e. the first-moment decay `b1`; it is passed by keyword here for clarity.
learning_rate = 0.005
momentum = 0.9

optimizer = nnx.Optimizer(
    model,
    optax.adamw(learning_rate=learning_rate, b1=momentum),
)

# Metrics tracked together during training / evaluation.
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average('loss'),
    precision_recall_f1=customMetrics(),
)

nnx.display(optimizer)

Optimizer(
  step=OptState(
    value=Array(0, dtype=uint32)
  ),
  model=MLP(
    layers=[{'batch_norm': BatchNorm(
      mean=BatchStat(
        value=Array(shape=(16,), dtype=float32)
      ),
      var=BatchStat(
        value=Array(shape=(16,), dtype=float32)
      ),
      scale=Param(
        value=Array(shape=(16,), dtype=float32)
      ),
      bias=Param(
        value=Array(shape=(16,), dtype=float32)
      ),
      num_features=16,
      use_running_average=False,
      axis=-1,
      momentum=0.99,
      epsilon=1e-05,
      dtype=None,
      param_dtype=<class 'jax.numpy.float32'>,
      use_bias=True,
      use_scale=True,
      bias_init=<function zeros at 0x7bc9981b5bd0>,
      scale_init=<function ones at 0x7bc9981b5d80>,
      axis_name=None,
      axis_index_groups=None,
      use_fast_variance=True
    ), 'dropout': Dropout(rate=0.4, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(
      default=RngStream(
        key=RngKey(
          v