In [1]:
!pip install tensorflow tensorflow-model-optimization



In [2]:
import tensorflow as tf
import numpy as np

print("TensorFlow version:", tf.__version__)
print("Num GPUs:", len(tf.config.list_physical_devices('GPU')))

TensorFlow version: 2.19.0
Num GPUs: 1


In [3]:
class LoRADense(tf.keras.layers.Layer):
    def __init__(self, units, rank=8, alpha=32):
        super().__init__()
        self.units = units
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

    def build(self, input_shape):
        input_dim = input_shape[-1]

        # Frozen base weight
        self.W = self.add_weight(
            shape=(input_dim, self.units),
            initializer="glorot_uniform",
            trainable=False,
            name="base_weight"
        )

        # LoRA matrices
        self.A = self.add_weight(
            shape=(input_dim, self.rank),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
            trainable=True,
            name="lora_A"
        )

        self.B = self.add_weight(
            shape=(self.rank, self.units),
            initializer="zeros",
            trainable=True,
            name="lora_B"
        )

        self.bias = self.add_weight(
            shape=(self.units,),
            initializer="zeros",
            trainable=True,
            name="bias"
        )

    def call(self, inputs):
        base = tf.matmul(inputs, self.W)
        lora = tf.matmul(tf.matmul(inputs, self.A), self.B)
        return base + self.scaling * lora + self.bias

In [4]:
inputs = tf.keras.Input(shape=(128,))
x = LoRADense(256, rank=8)(inputs)
x = tf.keras.layers.ReLU()(x)
outputs = tf.keras.layers.Dense(10)(x)

model = tf.keras.Model(inputs, outputs)

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"]
)

model.summary()

In [7]:
x_train = np.random.rand(1000, 128).astype(np.float32)
y_train = np.random.randint(0, 10, size=(1000,))

model.fit(x_train, y_train, epochs=5, batch_size=32)
print(model)

Epoch 1/5
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - accuracy: 0.0965 - loss: 2.3582
Epoch 2/5
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - accuracy: 0.1074 - loss: 2.3238
Epoch 3/5
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - accuracy: 0.1359 - loss: 2.3098
Epoch 4/5
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - accuracy: 0.1403 - loss: 2.2741
Epoch 5/5
[1m32/32[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - accuracy: 0.1578 - loss: 2.2548
<Functional name=functional, built=True>


In [8]:
model.summary()

In [10]:
for w in model.trainable_weights:
    print(w.name, w.shape)

lora_A (128, 8)
lora_B (8, 256)
bias (256,)
kernel (256, 10)
bias (10,)


In [11]:
model.count_params()

38666