In [1]:
import json
import time
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import utils as lora_utils
from mlx.utils import tree_flatten
from models import LoRALinear
from mlx_lm import load, generate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_path = "mlx-community/gemma-2-9b-it-4bit"
data_folder = "data/"
lora_layers = 4
batch_size = 8
iters = 100
steps_per_report = 2
steps_per_eval = 20
val_batches = 8
learning_rate = 1e-4
seed = 0
save_every = 10

In [3]:
adapter_file = f"{time.strftime('%Y%m%d-%H%M%S')}-adapters-{model_path.split('/')[-1]}.npz"
adapter_file

'20240719-223132-adapters-gemma-2-9b-it-4bit.npz'

In [4]:
print("Loading pretrained model")
model, tokenizer = load(model_path)

Loading pretrained model


Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 81180.08it/s]


In [5]:
example_prompt = "Hasta:Ankilozan Spondilit ve omurilik ve göğüs kafesi kemikleri birbirine girdi ve kaynamış biz bu hastalığın tedavisi var mı? Lütfen türkçe cevap ver, tercüme etme"
verbose = True
top_p = 0.8
temperature = 0.7
repetition_penalty = 1.05
max_tokens = 512

In [6]:
# response = generate(model, tokenizer, prompt=example_prompt, verbose=True, top_p=0.8, temp=0.7, repetition_penalty=1.05, max_tokens=512)
response = generate(model, tokenizer, prompt=example_prompt, verbose=True, top_p=top_p, temp=temperature, repetition_penalty=repetition_penalty, max_tokens=max_tokens)

Prompt: Hasta:Ankilozan Spondilit ve omurilik ve göğüs kafesi kemikleri birbirine girdi ve kaynamış biz bu hastalığın tedavisi var mı? Lütfen türkçe cevap ver, tercüme etme
yin.

Ankilozan spondilit ciddi bir romatizmal hastalıktır ve kemiklerin birbirine yapışmasına neden olur.  Omurganın, göğüs kafesinin ve pelvisin kemikleri etkilenir. Bu durum hareket kısıtlılığına, ağrı ve diğer problemlere yol açabilir.

**İyi haber şu ki, ankilozan spondilit için etkili tedavi yöntemleri mevcuttur.** 

Tedavi hedefi semptomları yönetmek, ağrıyı azaltmak, eklem hareketini korumak ve hastalığın ilerlemesini yavaşlatmaktır. 

**Tedavi planı şunları içerebilir:**

* **İlaçlar:** Ağrı kesiciler, anti-inflamatuar ilaçlar (AINS), bisfosfonatlar ve immün baskılayıcılar gibi ilaçlar kullanılır.
* **Fizik Tedavi:** Egzersizler, esneme hareketleri ve su terapisi ağrıyı azaltır, kas gücünü artırır ve hareket kabiliyetini geliştirir.
* **Diğer Tedaviler:** Obezite varsa kilo kontrolü, düzenli uyku ve stres y

In [7]:
# Freeze all layers other than LORA linears
model.freeze()

Model(
  (model): GemmaModel(
    (embed_tokens): QuantizedEmbedding(256000, 3584, group_size=64, bits=4)
    (layers.0): TransformerBlock(
      (self_attn): Attention(
        (q_proj): QuantizedLinear(input_dims=3584, output_dims=4096, bias=False,group_size=64, bits=4)
        (k_proj): QuantizedLinear(input_dims=3584, output_dims=2048, bias=False,group_size=64, bits=4)
        (v_proj): QuantizedLinear(input_dims=3584, output_dims=2048, bias=False,group_size=64, bits=4)
        (o_proj): QuantizedLinear(input_dims=4096, output_dims=3584, bias=False,group_size=64, bits=4)
        (rope): RoPE(256, traditional=False)
      )
      (mlp): MLP(
        (gate_proj): QuantizedLinear(input_dims=3584, output_dims=14336, bias=False,group_size=64, bits=4)
        (down_proj): QuantizedLinear(input_dims=14336, output_dims=3584, bias=False,group_size=64, bits=4)
        (up_proj): QuantizedLinear(input_dims=3584, output_dims=14336, bias=False,group_size=64, bits=4)
      )
      (input_layerno

In [8]:
for l in model.model.layers[len(model.model.layers) - lora_layers :]:
    l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
    l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
    if hasattr(l, "block_sparse_moe"):
        l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)

p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
print(f"Total parameters {p:.3f}M")
p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
print(f"Trainable parameters {p:.3f}M")

Total parameters 1444.954M
Trainable parameters 0.426M


In [9]:
class Dataset:
    """
    Light-weight wrapper to hold lines from a jsonl file
    """

    def __init__(self, path: Path, key: str = "text"):
        if not path.exists():
            self._data = None
        else:
            with open(path, "r") as fid:
                self._data = [json.loads(l) for l in fid]
        self._key = key

    def __getitem__(self, idx: int):
        return self._data[idx][self._key]

    def __len__(self):
        return len(self._data)

In [10]:
def load(data_folder: str, training: bool = False, validation: bool = False, testing: bool = False):
    def load_and_check(name):
        dataset_path = Path(data_folder) / f"{name}.jsonl"
        try:
            return Dataset(dataset_path)
        except Exception as e:
            print(f"Unable to build dataset {dataset_path} ({e})")
            raise

    names = ("train", "valid", "test")
    train, valid, test = (load_and_check(n) for n in names)

    if training and len(train) == 0:
        raise ValueError(
            "Training set not found or empty. Must provide training set for fine-tuning."
        )
    if validation and len(valid) == 0:
        raise ValueError(
            "Validation set not found or empty. Must provide validation set for fine-tuning."
        )
    if testing and len(test) == 0:
        raise ValueError(
            "Test set not found or empty. Must provide test set for evaluation."
        )
    return train, valid, test

In [11]:
print("Loading datasets")
train_set, valid_set, test_set = load(data_folder, training=True)
print(f"Training set: {len(train_set)}, Validation set: {len(valid_set)}, Test set: {len(test_set)}")

Loading datasets
Training set: 1000, Validation set: 100, Test set: 100


In [12]:
def iterate_batches(dset, tokenizer, batch_size, train=False):
    # Shuffle indices
    while True:
        indices = np.arange(len(dset))
        if train:
            indices = np.random.permutation(indices)

        # Collect batches from dataset
        for i in range(0, len(indices) - batch_size + 1, batch_size):
            # Encode batch
            batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
            lengths = [len(x) for x in batch]

            # Check if any sequence is longer than 2048 tokens
            if max(lengths) > 2048:
                print(
                    "[WARNING] Some sequences are longer than 2048 tokens. "
                    "Consider pre-splitting your data to save memory."
                )

            # Pad to the max length
            batch_arr = np.zeros((batch_size, max(lengths)), np.int32)

            for j in range(batch_size):
                batch_arr[j, : lengths[j]] = batch[j]
            batch = mx.array(batch_arr)
            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)

        if not train:
            break


In [13]:
def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches):
    all_losses = []
    ntokens = 0
    for it, batch in zip(
        range(num_batches),
        iterate_batches(dataset, tokenizer, batch_size),
    ):
        losses, toks = loss(model, *batch)
        all_losses.append((losses * toks).item())
        ntokens += toks.item()

    return np.sum(all_losses) / ntokens

In [14]:
def train(model, train_set, val_set, optimizer, loss, tokenizer):
    # Create value and grad function for loss
    loss_value_and_grad = nn.value_and_grad(model, loss)

    losses = []
    n_tokens = 0

    # Main training loop
    start = time.perf_counter()
    for it, batch in zip(
        range(iters),
        iterate_batches(train_set, tokenizer, batch_size, train=True),
    ):
        # Forward and backward pass
        (lvalue, toks), grad = loss_value_and_grad(model, *batch)

        # Model update
        optimizer.update(model, grad)
        mx.eval(model.parameters(), optimizer.state, lvalue)

        # Record loss
        losses.append(lvalue.item())
        n_tokens += toks.item()

        # Report training loss if needed
        if (it + 1) % steps_per_report == 0:
            train_loss = np.mean(losses)

            stop = time.perf_counter()
            print(
                f"Iter {it + 1}: Train loss {train_loss:.3f}, "
                f"It/sec {steps_per_report / (stop - start):.3f}, "
                f"Tokens/sec {float(n_tokens) / (stop - start):.3f}"
            )
            losses = []
            n_tokens = 0
            start = time.perf_counter()

        # Report validation loss if needed
        if it == 0 or (it + 1) % steps_per_eval == 0:
            stop = time.perf_counter()
            val_loss = evaluate(
                model, val_set, loss, tokenizer, batch_size, val_batches
            )
            print(
                f"Iter {it + 1}: "
                f"Val loss {val_loss:.3f}, "
                f"Val took {(time.perf_counter() - stop):.3f}s"
            )

            start = time.perf_counter()

        # Save adapter weights if needed
        if (it + 1) % save_every == 0:
            mx.savez(
                adapter_file, **dict(tree_flatten(model.trainable_parameters()))
            )
            print(f"Iter {it + 1}: Saved adapter weights to {adapter_file}.")


In [15]:
def loss(model, inputs, targets, lengths):
    # Run model on inputs
    logits, _ = model(inputs)
    logits = logits.astype(mx.float32)

    # Mask padding tokens
    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]

    # Calculate the loss
    ce = nn.losses.cross_entropy(logits, targets) * length_mask
    ntoks = length_mask.sum()
    ce = ce.sum() / ntoks
    return ce, ntoks

In [16]:
print("Training")

np.random.seed(seed)

opt = optim.Adam(learning_rate=learning_rate)

# Train model
train(model, train_set, valid_set, opt, loss, tokenizer)

# Save adapter weights
mx.savez(adapter_file, **dict(tree_flatten(model.trainable_parameters())))

Training


ValueError: too many values to unpack (expected 2)