<a href="https://colab.research.google.com/github/mohsenh17/jaxLearning/blob/main/flax/riceTypes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [131]:
pip install --upgrade flax



In [6]:
import kagglehub
path = kagglehub.dataset_download("mssmartypants/rice-type-classification")
print(path)

/root/.cache/kagglehub/datasets/mssmartypants/rice-type-classification/versions/2


In [8]:
! mv /root/.cache/kagglehub/datasets/mssmartypants/rice-type-classification/versions/2/riceClassification.csv dataset

In [2]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np

from flax import nnx
import jax.numpy as jnp

from sklearn.preprocessing import MinMaxScaler


In [3]:
df = pd.read_csv("dataset/riceClassification.csv")
df.head()

Unnamed: 0,id,Area,MajorAxisLength,MinorAxisLength,Eccentricity,ConvexArea,EquivDiameter,Extent,Perimeter,Roundness,AspectRation,Class
0,1,4537,92.229316,64.012769,0.719916,4677,76.004525,0.657536,273.085,0.76451,1.440796,1
1,2,2872,74.691881,51.400454,0.725553,3015,60.471018,0.713009,208.317,0.831658,1.453137,1
2,3,3048,76.293164,52.043491,0.731211,3132,62.296341,0.759153,210.012,0.868434,1.46595,1
3,4,3073,77.033628,51.928487,0.738639,3157,62.5513,0.783529,210.657,0.870203,1.483456,1
4,5,3693,85.124785,56.374021,0.749282,3802,68.571668,0.769375,230.332,0.874743,1.51,1


# Load Dataset

In [4]:
class CustomImageDataset(Dataset):
    def __init__(self, dataset, transform=None, target_transform=None):
        completeDF = pd.read_csv(dataset, index_col=0)
        self.labels = pd.DataFrame(completeDF['Class'])
        features_raw = pd.DataFrame(completeDF.drop('Class', axis=1))
        #self.features = pd.DataFrame(completeDF.drop('Class', axis=1))
        scaler = MinMaxScaler()
        self.features = pd.DataFrame(scaler.fit_transform(features_raw))
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        features = np.array(self.features.iloc[idx, :])
        label = self.labels.iloc[idx, 0]
        if self.transform:
            features = self.transform(features)
        if self.target_transform:
            label = self.target_transform(label)
        return features, label

In [18]:
dataset = CustomImageDataset(dataset="dataset/classification_dataset.csv")
#dataset = CustomImageDataset(dataset="dataset/riceClassification.csv")
train_set, val_set, test_set = torch.utils.data.random_split(dataset, [.7, 0.1,0.2])
data_loader = DataLoader(train_set, batch_size=4, shuffle=True)

for features, labels in data_loader:
    print("Batch of features has shape: ",features.shape)
    print("Batch of labels has shape: ", labels.shape)
    print(features)
    print(labels)
    break

Batch of features has shape:  torch.Size([4, 10])
Batch of labels has shape:  torch.Size([4])
tensor([[0.4127, 0.3476, 0.2428, 0.5258, 0.3201, 0.3007, 0.4489, 0.6791, 0.6110,
         0.7804],
        [0.5114, 0.2991, 0.3795, 0.7041, 0.4306, 0.3606, 0.5095, 0.6207, 0.2086,
         0.5082],
        [0.4826, 0.7206, 0.5078, 0.4966, 0.7598, 0.4393, 0.8265, 0.3862, 0.6264,
         0.4440],
        [0.1707, 0.1989, 0.4746, 0.3370, 0.4702, 0.2779, 0.4934, 0.9108, 0.3687,
         0.3172]], dtype=torch.float64)
tensor([1, 0, 0, 1])


# MLP model

In [7]:
class MLP(nnx.Module):
    def __init__(self, hidden_dims: list[int], num_classes: int, rngs: nnx.Rngs):
        self.layers = []
        for din, dout in zip(hidden_dims[:-1], hidden_dims[1:]):
            self.layers.append({
                'dropout': nnx.Dropout(rate=0.4, rngs=rngs),
                'linear': nnx.Linear(din, dout, rngs=rngs),
                'batch_norm': nnx.BatchNorm(dout, rngs=rngs)
            })
        # Add the final classification layer
        self.output_layer = nnx.Linear(hidden_dims[-1], num_classes, rngs=rngs)

    def __call__(self, x):
        for layer in self.layers:
            x = layer['dropout'](x)
            x = layer['linear'](x)
            x = nnx.gelu(x)
            x = layer['batch_norm'](x)
        x = self.output_layer(x)
        return nnx.sigmoid(x)

# Instantiate and test the model
#model = MLP([10, 16, 32, 16], 1, rngs=nnx.Rngs(0))
#y = model(x=jnp.ones((3, 10)))

#nnx.display(model)

model = MLP([10, 16, 32, 16], 1, rngs=nnx.Rngs(0))
y = model(x=jnp.ones((3, 10)))

nnx.display(y)



[[0.57413197]
 [0.49904767]
 [0.42679968]]


# Optimizer and metrics

In [51]:
class customMetrics(nnx.metrics.Metric):
    def __init__(self):
        self.true_positives = jnp.array(0)
        self.false_positives = jnp.array(0)
        self.false_negatives = jnp.array(0)

    def update(self, loss, logits, labels):
        logits = jnp.argmax(logits, axis=-1)  # Assuming y_pred are probabilities
        tp = jnp.sum((labels == 1) & (logits == 1))
        fp = jnp.sum((labels == 0) & (logits == 1))
        fn = jnp.sum((labels == 1) & (logits == 0))

        self.true_positives += tp
        self.false_positives += fp
        self.false_negatives += fn

    def result(self):
        precision = self.true_positives / (self.true_positives + self.false_positives + 1e-7)
        recall = self.true_positives / (self.true_positives + self.false_negatives + 1e-7)
        f1_score = 2 * (precision * recall) / (precision + recall + 1e-7)
        return {"f1_score": f1_score, "precision": precision, "recall": recall}

    def reset(self):
        self.true_positives = jnp.array(0)
        self.false_positives = jnp.array(0)
        self.false_negatives = jnp.array(0)

In [8]:
class CustomAccuracy(nnx.metrics.Metric):
    def __init__(self):
        self.correct_count = 0
        self.total_count = 0

    def update(self, loss, logits, labels):
        # Convert logits to binary predictions (0 or 1) based on a 0.5 threshold
        predictions = jnp.where(jnp.array(logits) > 0.5, 1, 0)
        # Flatten if necessary
        predictions = predictions.ravel()
        labels = jnp.array(labels).ravel()

        # Calculate number of correct predictions in the current batch
        self.correct_count += jnp.sum(predictions == labels)
        self.total_count += len(labels)

    def compute(self):
        # Calculate accuracy over all batches seen so far
        if self.total_count == 0:
            return 0  # Avoid division by zero if no samples are seen
        return self.correct_count / self.total_count
    def reset(self):
        # Reset counters
        self.correct_count = 0
        self.total_count = 0



In [19]:
import optax
model = MLP([10, 16], 1, rngs=nnx.Rngs(0))
learning_rate = 0.001
momentum = 0.9

optimizer = nnx.Optimizer(model, optax.adam(learning_rate))
metrics = nnx.MultiMetric(
  accuracy=CustomAccuracy(),
  loss=nnx.metrics.Average('loss'),
)
  #precision_recall_f1=customMetrics(),


#nnx.display(optimizer)

# Train steps

In [23]:
import jax.numpy as jnp
def loss_fn(model: MLP, batch):
  logits = model(batch['features'])
  loss = optax.sigmoid_binary_cross_entropy(
    logits=logits, labels=batch['labels'].reshape(-1, 1)
  ).mean()
  #loss = (logits - batch['labels'])**2
  #print(logits.pval)
  return loss, logits

@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
  """Train for a single step."""
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)

  metrics.update(loss=loss, logits=logits, labels=batch['labels'])  # In-place updates.
  optimizer.update(grads)  # In-place updates.
  return loss, logits

@nnx.jit
def eval_step(model: MLP, metrics: nnx.MultiMetric, batch):
  loss, logits = loss_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['labels'])  # In-place updates.


# Train and Eval

In [20]:
def custom_collate_fn(batch):
    transposed_data = list(zip(*batch))

    labels =  np.array(transposed_data[1])
    features = np.array(transposed_data[0])

    return {"features":features, "labels":labels}

train_set, val_set, test_set = torch.utils.data.random_split(dataset, [.7, 0.1,0.2])

train_ds = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True, collate_fn=custom_collate_fn)
val_set = DataLoader(val_set, batch_size=4, shuffle=True, drop_last=True, collate_fn=custom_collate_fn)
test_ds = DataLoader(test_set, batch_size=4, shuffle=True, drop_last=True, collate_fn=custom_collate_fn)

# test
batch_data = next(iter(train_ds))
imgs = batch_data['features']
lbls = batch_data['labels']
print(imgs.shape, imgs[0].dtype, lbls.shape, lbls[0].dtype)
print(lbls)

#loss = train_step(model, optimizer, metrics, batch_data)
loss, logits = loss_fn(model, batch_data)
print(loss.shape, logits.shape)
print(f'{loss = }')
#print(f'{logits = }')
print(f'{optimizer.step.value = }')


(64, 10) float64 (64,) int64
[1 0 0 1 1 0 0 1 1 1 0 1 0 0 0 1 1 1 0 0 1 0 0 0 1 1 0 0 1 0 0 1 1 0 1 1 1
 0 0 1 1 1 0 0 1 0 1 0 0 1 1 1 1 0 1 1 1 1 0 1 1 0 1 1]
() (64, 1)
loss = Array(0.68516624, dtype=float32)
optimizer.step.value = Array(0, dtype=uint32)


In [21]:
import jax
metrics_history = {
  'train_loss': [],
  'train_accuracy': [],
  'train_precision': [],
  'train_recall': [],
  'train_f1': [],
  'val_loss': [],
  'val_accuracy': [],
  'val_precision': [],
  'val_recall': [],
  'val_f1': [],
  'test_loss': [],
  'test_accuracy': [],
  'test_precision': [],
  'test_recall': [],
  'test_f1': [],
}

num_epochs = 20
for epoch in range(num_epochs):
    for batch in train_ds:
      loss, logits = train_step(model, optimizer, metrics, batch)

    for metric, value in metrics.compute().items():
      metrics_history[f'train_{metric}'].append(value)

    metrics.reset()
    for val_batch in val_set:
      eval_step(model, metrics, val_batch)
    for metric, value in metrics.compute().items():
      metrics_history[f'val_{metric}'].append(value)
    metrics.reset()
    print(
      f"[train] epoch: {epoch}, "
      f"loss: {metrics_history['train_loss'][-1]}, "
      f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
      #f"accuracy: {metrics_history['train_precision'][-1] * 100}"
      #f"accuracy: {metrics_history['train_recall'][-1] * 100}"
      #f"accuracy: {metrics_history['train_f1'][-1] * 100}"

    )
    """print(
      f"[test] epoch: {epoch}, "
      f"loss: {metrics_history['val_loss'][-1]}, "
      f"accuracy: {metrics_history['val_accuracy'][-1] * 100}"
      #f"accuracy: {metrics_history['val_precision'][-1] * 100}"
      #f"accuracy: {metrics_history['val_recall'][-1] * 100}"
      #f"accuracy: {metrics_history['val_f1'][-1] * 100}"
    )"""
"""for test_batch in test_ds:
      eval_step(model, metrics, test_batch)
"""

[train] epoch: 0, loss: 0.713580846786499, accuracy: 57.8125
[train] epoch: 1, loss: 0.7034299373626709, accuracy: 54.6875
[train] epoch: 2, loss: 0.737416684627533, accuracy: 54.6875
[train] epoch: 3, loss: 0.7496273517608643, accuracy: 57.03125
[train] epoch: 4, loss: 0.7125489711761475, accuracy: 69.53125
[train] epoch: 5, loss: 0.7265267372131348, accuracy: 53.125
[train] epoch: 6, loss: 0.7226094007492065, accuracy: 53.125
[train] epoch: 7, loss: 0.7251901626586914, accuracy: 54.6875
[train] epoch: 8, loss: 0.7349185943603516, accuracy: 49.21875
[train] epoch: 9, loss: 0.7223374247550964, accuracy: 46.09375
[train] epoch: 10, loss: 0.7486499547958374, accuracy: 57.8125
[train] epoch: 11, loss: 0.7377750873565674, accuracy: 55.46875
[train] epoch: 12, loss: 0.7167431116104126, accuracy: 53.90625
[train] epoch: 13, loss: 0.7198973298072815, accuracy: 50.78125
[train] epoch: 14, loss: 0.7022953629493713, accuracy: 58.59375
[train] epoch: 15, loss: 0.7176570892333984, accuracy: 54.687

'for test_batch in test_ds:\n      eval_step(model, metrics, test_batch)\n'

In [13]:
print(metrics_history['train_accuracy'])

[Array(0.71898675, dtype=float32), Array(0.79506, dtype=float32), Array(0.8662405, dtype=float32), Array(0.91035354, dtype=float32), Array(0.92889833, dtype=float32), Array(0.94278723, dtype=float32), Array(0.947601, dtype=float32), Array(0.94996846, dtype=float32), Array(0.9566761, dtype=float32), Array(0.9551768, dtype=float32), Array(0.9579388, dtype=float32), Array(0.96030617, dtype=float32), Array(0.962279, dtype=float32), Array(0.96417296, dtype=float32), Array(0.9656724, dtype=float32), Array(0.9661458, dtype=float32), Array(0.96953917, dtype=float32), Array(0.97159094, dtype=float32), Array(0.9699337, dtype=float32), Array(0.96859217, dtype=float32)]


In [14]:
model.eval() # Switch to evaluation mode.

@nnx.jit
def pred_step(model: MLP, batch):
  logits = model(batch['features'])
  return logits

test_ds = DataLoader(test_set, batch_size=32, shuffle=False, drop_last=True, collate_fn=custom_collate_fn)

ypred = []
label = []
for test_batch in test_ds:
  logits = pred_step(model, test_batch)
  #print(np.ravel(logits))
  #print(np.ravel(logits))
  #break
  ypred.extend(np.ravel(logits))
  label.extend(np.ravel(test_batch["labels"]))
  #print(logits, test_batch["labels"])
  #break
binary_ypred = np.where(np.array(ypred) > 0.5, 1, 0)
print(len(label), sum(label))
accuracy = sum([1 for pred, true in zip(binary_ypred, label) if pred == true]) / len(label)
print(f"Accuracy: {accuracy:.4f}")

3616 1992
Accuracy: 0.8999


In [16]:
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score
# Generate a dataset with 100 samples, 10 features, 5 informative, 5 redundant, and 2 classes
X, y = make_classification(n_samples=256, n_features=10, n_informative=8,
                          n_classes=2, random_state=42)
df = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(X.shape[1])])
df['Class'] = y

# Save the DataFrame to a CSV file
mlp = MLPClassifier(hidden_layer_sizes=(64,), max_iter=500, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train the MLP on the data
mlp.fit(X_train, y_train)

# Predict on the test set
y_pred = mlp.predict(X_test)

# Evaluate the model
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.4f}")
df.to_csv('dataset/classification_dataset.csv', index=True)

Accuracy: 0.8654


In [122]:
logits = jax.random.normal(jax.random.key(0), (5, 2))
logits

Array([[-0.3721109 ,  0.26423115],
       [-0.18252768, -0.7368197 ],
       [-0.44030377, -0.1521442 ],
       [-0.67135346, -0.5908641 ],
       [ 0.73168886,  0.5673026 ]], dtype=float32)