<a href="https://colab.research.google.com/github/h40300965/deep-learnin/blob/main/nfnet-%20tensofrflow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install tensorflow-addons



📁 1. Install Required Packages

In [None]:
!pip install sam-tf --quiet

[31mERROR: Could not find a version that satisfies the requirement sam-tf (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for sam-tf[0m[31m
[0m

🧠 2. Import Libraries

In [None]:
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers, models, mixed_precision

# Enable mixed precision (optional). mixed_float16 only pays off on GPUs with
# float16 tensor-core support; on CPU it is significantly slower, so fall back
# to plain float32 when no GPU is visible.
if tf.config.list_physical_devices("GPU"):
    policy = mixed_precision.Policy('mixed_float16')
else:
    policy = mixed_precision.Policy('float32')
mixed_precision.set_global_policy(policy)

import numpy as np

🔧 3. Scaled Weight Standardization Layer

In [None]:
class ScaledWSConv2D(layers.Conv2D):
    """Conv2D with Scaled Weight Standardization (the NFNet convolution).

    On every call the kernel is standardized per output channel (zero mean,
    unit variance over kernel height, width and input channels) and multiplied
    by sqrt(2 / fan_in). The stored trainable kernel itself is not modified.
    Accepts the same constructor arguments as ``layers.Conv2D``; ``strides``,
    ``padding``, ``dilation_rate``, ``data_format``, ``use_bias`` and
    ``activation`` are honoured.
    """

    def build(self, input_shape):
        super().build(input_shape)
        # fan_in = kernel_h * kernel_w * in_channels. Convert to plain Python
        # numbers: in Keras 3 `kernel.shape` is a tuple (no `.as_list()`), and a
        # NumPy float64 scale would become a float64 tensor that cannot be
        # multiplied with float32/float16 weights.
        kernel_shape = [int(dim) for dim in self.kernel.shape]
        self.fan_in = int(np.prod(kernel_shape[:-1]))
        self._ws_scale = float(np.sqrt(2.0 / self.fan_in))

    def call(self, inputs):
        # Standardize in float32: under a mixed_float16 policy the kernel is
        # read as float16, where the 1e-10 epsilon would underflow to zero.
        kernel = tf.cast(self.kernel, tf.float32)
        mean, var = tf.nn.moments(kernel, axes=[0, 1, 2], keepdims=True)
        weight = (kernel - mean) * tf.math.rsqrt(var + 1e-10) * self._ws_scale
        weight = tf.cast(weight, inputs.dtype)

        data_format = "NHWC" if self.data_format == "channels_last" else "NCHW"
        x = tf.nn.conv2d(
            inputs,
            filters=weight,
            strides=self.strides,
            padding=self.padding.upper(),
            data_format=data_format,
            dilations=self.dilation_rate,
        )
        if self.use_bias:
            x = tf.nn.bias_add(x, tf.cast(self.bias, x.dtype), data_format=data_format)
        if self.activation is not None:
            x = self.activation(x)
        return x

⚙️ 4. NFBlock with SkipInit

In [None]:
class NFBlock(layers.Layer):
    """Normalizer-free bottleneck residual block with SkipInit.

    Residual branch: 1x1 conv (stride ``stride``) -> GELU -> 3x3 conv -> GELU
    -> 1x1 conv, multiplied by a learnable scalar ``skip_gain`` initialised to
    0. The shortcut is a strided 1x1 ``ScaledWSConv2D`` projection when the
    stride or the channel count changes, otherwise the identity.
    Output: GELU(branch * skip_gain + shortcut).

    Args:
        channels: number of output channels.
        expansion: bottleneck ratio; the inner convs use ``channels // expansion``.
        stride: spatial stride of the first 1x1 conv and of the projection.
    """

    def __init__(self, channels, expansion=2, stride=1, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.expansion = expansion
        self.stride = stride
        mid_channels = channels // expansion

        self.conv1 = ScaledWSConv2D(mid_channels, 1, strides=stride, use_bias=False)
        self.act1 = layers.Activation('gelu')

        self.conv2 = ScaledWSConv2D(mid_channels, 3, strides=1, padding='same', use_bias=False)
        self.act2 = layers.Activation('gelu')

        self.conv3 = ScaledWSConv2D(channels, 1, strides=1, use_bias=False)
        self.act3 = layers.Activation('gelu')

        # The input channel count is unknown until build(); None means identity.
        self.shortcut = None

    def build(self, input_shape):
        # assumes channels-last inputs (last axis = channels)
        if self.stride != 1 or input_shape[-1] != self.channels:
            self.shortcut = ScaledWSConv2D(self.channels, 1, strides=self.stride, use_bias=False)
        # add_weight (not a raw tf.Variable) so Keras tracks and trains the gain.
        self.skip_gain = self.add_weight(
            name="skip_gain", shape=(), initializer="zeros", trainable=True
        )
        super().build(input_shape)

    def call(self, inputs):
        x = self.act1(self.conv1(inputs))
        x = self.act2(self.conv2(x))
        x = self.conv3(x) * self.skip_gain

        shortcut = inputs if self.shortcut is None else self.shortcut(inputs)
        return self.act3(x + shortcut)

    def get_config(self):
        config = super().get_config()
        config.update(
            {"channels": self.channels, "expansion": self.expansion, "stride": self.stride}
        )
        return config

🧱 5. Stem Network

In [None]:
def get_stem(filters, input_shape=(32, 32, 3)):
    """Build the stem: three 3x3 ScaledWSConv2D + GELU layers, then a stride-2 max-pool.

    Args:
        filters: number of channels of every stem convolution.
        input_shape: (height, width, channels) of the model input; defaults to
            the CIFAR-10 image shape.
    """
    conv_layers = []
    for _ in range(3):
        conv_layers.append(ScaledWSConv2D(filters, 3, strides=1, padding='same'))
        conv_layers.append(layers.Activation('gelu'))

    return models.Sequential(
        [
            layers.Input(input_shape),
            *conv_layers,
            layers.MaxPool2D(3, strides=2, padding='same'),
        ],
        name="stem",
    )

🧬 6. Build Full NFNet Model

In [None]:
def build_nfnet(num_classes=10, stage_channels=(256, 512, 1024), blocks_per_stage=3):
    """Assemble the NFNet-style classifier: stem -> residual stages -> head.

    Stage ``i`` consists of ``blocks_per_stage`` NFBlocks of width
    ``stage_channels[i]``; every stage after the first starts with a stride-2
    block. The defaults reproduce the original 3 x 3 layout (256/512/1024).

    Returns:
        A ``models.Sequential`` producing float32 logits of shape
        (batch, num_classes).
    """
    stem = get_stem(64)

    blocks = []
    for stage_index, channels in enumerate(stage_channels):
        for block_index in range(blocks_per_stage):
            stride = 2 if (stage_index > 0 and block_index == 0) else 1
            blocks.append(NFBlock(channels, stride=stride))

    head = models.Sequential([
        layers.GlobalAveragePooling2D(),
        layers.Dense(num_classes),
        # Under a mixed_float16 policy the Dense output is float16; cast the
        # logits to float32 so the softmax/cross-entropy is numerically stable.
        layers.Activation("linear", dtype="float32"),
    ])

    model = models.Sequential([
        stem,
        *blocks,
        head
    ])
    return model

🔄 7. Data Augmentation & CIFAR-10 Dataset

In [None]:
(train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()

# Normalize uint8 pixels to float32 in [0, 1]
train_images = train_images.astype("float32") / 255.0
test_images = test_images.astype("float32") / 255.0

# Augmentation (training data only)
augmenter = models.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.2)
])

# DataLoader
batch_size = 512
num_epochs = 20
shuffle_buffer = 10000

train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    # Shuffle raw examples first, then batch so augmentation runs vectorized
    # on whole batches (each image still gets its own random transform).
    .shuffle(shuffle_buffer)
    .batch(batch_size)
    # training=True makes sure the random layers are actually applied.
    .map(lambda x, y: (augmenter(x, training=True), y), num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_images, test_labels))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 0us/step


⚙️ 8. SAM Optimizer + Loss Function

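The original plan was to use the third-party `sam-tf` package. It is not published on PyPI: the install in section 1 fails with "No matching distribution found for sam-tf". The snippet below therefore cannot run and is kept for reference only. Note that it imports `SAMModel` but calls `SAM(..., distance=...)`.

```python
from sam_tf import SAMModel

model = build_nfnet(num_classes=10)

# Wrap with SAM
sam_model = SAM(model, distance=0.05)

# Compile
sam_model.compile(
    optimizer=keras.optimizers.AdamW(learning_rate=3e-4, weight_decay=1e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"]
)
```

SAM is therefore implemented manually below.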

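### Manually implement SAM

`SAMModel` wraps a Keras model and performs the two-pass Sharpness-Aware Minimization update inside `train_step`.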

In [None]:
import tensorflow as tf
from tensorflow import keras

class SAMModel(keras.Model):
    """Train a wrapped Keras model with Sharpness-Aware Minimization (SAM).

    Each ``train_step`` performs two forward/backward passes on the batch:

    1. the gradient ``g`` of the loss at the current trainable weights ``w``;
    2. the gradient at the perturbed weights ``w + e`` where
       ``e = rho * g / ||g||`` (global L2 norm over all gradients).

    The trainable weights are then restored to ``w`` and the compiled
    optimizer applies the gradient from pass 2. Written against the Keras 3
    custom-``train_step`` API (``compute_loss`` + explicit metric updates).

    Args:
        base_model: Keras model whose trainable variables are optimized.
        rho: radius of the SAM neighbourhood (the ``distance`` argument of
            the ``sam_tf`` snippet). Must be non-negative.
        **kwargs: forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, base_model, rho=0.05, **kwargs):
        super().__init__(**kwargs)
        if rho < 0:
            raise ValueError(f"rho must be non-negative, got {rho!r}")
        self.base_model = base_model
        self.rho = rho  # Perturbation radius (called "distance" in your code)

    def call(self, inputs, training=False):
        # The default test/predict steps (evaluate(), predict()) run self(x),
        # so the wrapper must forward to the wrapped model.
        return self.base_model(inputs, training=training)

    def _sam_scale_loss(self, loss):
        """Scale ``loss`` with the optimizer's loss scale, if it provides one.

        If the compiled optimizer applies loss scaling (e.g. a
        LossScaleOptimizer under mixed_float16), it unscales gradients in
        ``apply_gradients``, so the loss must be scaled before differentiation.
        """
        scale_loss = getattr(self.optimizer, "scale_loss", None)
        return loss if scale_loss is None else scale_loss(loss)

    def _sam_loss_and_grads(self, images, labels, sample_weight, variables):
        """One forward pass; return (logits, unscaled loss, grads of scaled loss)."""
        with tf.GradientTape() as tape:
            logits = self.base_model(images, training=True)
            loss = self.compute_loss(
                x=images, y=labels, y_pred=logits, sample_weight=sample_weight
            )
            scaled_loss = self._sam_scale_loss(loss)
        return logits, loss, tape.gradient(scaled_loss, variables)

    def train_step(self, data):
        if len(data) == 3:
            images, labels, sample_weight = data
        else:
            images, labels = data
            sample_weight = None
        variables = self.base_model.trainable_variables

        # Pass 1: gradient at the current weights. Only its direction is used,
        # so any loss-scale factor cancels in the normalisation below.
        logits, loss, gradients = self._sam_loss_and_grads(
            images, labels, sample_weight, variables
        )
        grad_norm = tf.linalg.global_norm(gradients)  # None entries are ignored
        scale = self.rho / (grad_norm + 1e-12)

        # Snapshot only the trainable weights: non-trainable state updated by
        # the forward passes must not be rolled back. Restoring from copies
        # (rather than subtracting e) also undoes a non-finite perturbation.
        originals = [tf.identity(v) for v in variables]
        for v, g in zip(variables, gradients):
            if g is not None:
                v.assign_add(tf.convert_to_tensor(g) * scale)

        # Pass 2: gradient at the perturbed weights w + e.
        _, _, sam_gradients = self._sam_loss_and_grads(
            images, labels, sample_weight, variables
        )

        # Restore w, then step with the SAM gradient.
        for v, original in zip(variables, originals):
            v.assign(original)
        self.optimizer.apply_gradients(
            [(g, v) for g, v in zip(sam_gradients, variables) if g is not None]
        )

        # Update the loss tracker with the loss itself (the deprecated
        # compiled_metrics.update_state fed it the logits instead) and the
        # compiled metrics with the predictions at the unperturbed weights.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(labels, logits, sample_weight=sample_weight)
        return {m.name: m.result() for m in self.metrics}

### Build and compile the SAM-wrapped model



In [None]:
from tensorflow import keras

# Replace this with your actual NFNet builder
def build_nfnet(input_shape=(32, 32, 3), num_classes=10):
    """Tiny stand-in classifier: Conv2D(64, 3x3, ReLU) -> global average pool -> Dense.

    NOTE(review): this definition shadows the NFNet builder from section 6,
    so the model trained below is this stand-in, not the NFNet. Remove this
    cell to train the real architecture.

    Args:
        input_shape: (height, width, channels); defaults to CIFAR-10's 32x32x3,
            the data this notebook feeds the model.
        num_classes: number of output logits.
    """
    return keras.Sequential([
        keras.layers.Input(shape=input_shape),
        keras.layers.Conv2D(64, 3, activation='relu'),
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dense(num_classes),
        # float32 logits for a stable loss under a mixed_float16 policy
        keras.layers.Activation("linear", dtype="float32"),
    ])

# Build model and wrap with SAM
base_model = build_nfnet(num_classes=10)
sam_model = SAMModel(base_model, rho=0.05)  # rho = SAM neighbourhood radius ("distance")

# Compile here; training runs in the "Train the Model" cell via
# sam_model.fit(train_dataset, validation_data=test_dataset, ...).
sam_model.compile(
    optimizer=keras.optimizers.AdamW(learning_rate=3e-4, weight_decay=1e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"]
)

🏋️‍♂️ 9. Train the Model

In [None]:
# SAMModel.train_step drives each training batch; validation uses the
# default test step (which calls the wrapped model).
history = sam_model.fit(
    x=train_dataset,
    epochs=num_epochs,
    validation_data=test_dataset,
)


Epoch 1/20


```
for metric in self.metrics:
    metric.update_state(y, y_pred)
```

  return self._compiled_metrics_update_state(


[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m102s[0m 933ms/step - accuracy: 0.1014 - loss: -0.0081 - val_accuracy: 0.1160 - val_loss: 2.2678
Epoch 2/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m99s[0m 956ms/step - accuracy: 0.1286 - loss: -0.0084 - val_accuracy: 0.1537 - val_loss: 2.2374
Epoch 3/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m137s[0m 903ms/step - accuracy: 0.1547 - loss: -2.9738e-04 - val_accuracy: 0.1782 - val_loss: 2.2049
Epoch 4/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m93s[0m 895ms/step - accuracy: 0.1675 - loss: 4.2788e-04 - val_accuracy: 0.1933 - val_loss: 2.1742
Epoch 5/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m98s[0m 944ms/step - accuracy: 0.1764 - loss: -0.0025 - val_accuracy: 0.2034 - val_loss: 2.1495
Epoch 6/20
[1m98/98[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m102s[0m 987ms/step - accuracy: 0.1840 - loss: -0.0109 - val_accuracy: 0.2064 - val_loss: 2.1305
Epoch 7/20
[1m98

📊 10. Evaluate and Save

In [None]:
# Look results up by metric name rather than relying on the list order.
test_results = sam_model.evaluate(test_dataset, return_dict=True)
test_loss, test_acc = test_results["loss"], test_results["accuracy"]
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2%}")




[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 139ms/step - accuracy: 0.2646 - loss: 2.0527
Test Accuracy: 25.92%
