In [1]:
# Installing last versions of everything
!pip3 install zstandard jsonlines datasets sentencepiece langchain
!pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -q
!pip install git+https://github.com/huggingface/transformers.git -q
!pip install git+https://github.com/deepmind/optax.git -q
!pip install jax jax-smi einops flax optax fire mypy wandb matplotlib tqdm types-tqdm pdoc3 tf-nightly accelerate sentencepiece protobuf~=3.20.0 -q
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -q # It's actually a nightly version. JAX doesn't have a stable version since it's so new

Collecting zstandard
  Downloading zstandard-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.4/5.4 MB[0m [31m42.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting jsonlines
  Downloading jsonlines-4.0.0-py3-none-any.whl (8.7 kB)
Collecting datasets
  Downloading datasets-2.16.1-py3-none-any.whl (507 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m36.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m55.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting langchain
  Downloading langchain-0.1.5-py3-none-any.whl (806 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m806.7/806.7 kB[0m [31m34.2 MB/s[0m eta [36m0:00:00[0m
Collec

In [2]:
!git clone https://github.com/defdet/falcon-jax
!cp -R /kaggle/working/falcon-2-jax/lib .

Cloning into 'llama-2-jax'...
remote: Enumerating objects: 928, done.[K
remote: Counting objects: 100% (361/361), done.[K
remote: Compressing objects: 100% (175/175), done.[K
remote: Total 928 (delta 269), reused 188 (delta 186), pack-reused 567[K
Receiving objects: 100% (928/928), 275.73 KiB | 7.66 MiB/s, done.
Resolving deltas: 100% (536/536), done.


In [3]:
import wandb
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
wandb_api = user_secrets.get_secret("wandb_api")
!wandb login '11111111111111111' # Use your own token

from typing import NamedTuple, Tuple

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [4]:
from lib.proc_init_utils import initialise_tpu

import einops as op
from functools import partial
import jax
from jax import Array
from jax.experimental.multihost_utils import process_allgather
import jax.numpy as jnp
import jax.random as rand
import jax_smi
import math
import optax
import signal
import time
from transformers import FalconTokenizer
from tqdm import tqdm
from typing import Any, Callable
import wandb

from lib.data import TrainData
from lib.dataloader import FalconDataLoader
from lib.gsm_data import GSMDataset, gsm_collate_fn_train
from lib.falcon import Falcon, RotaryValues, forward_falcon, init_falcon, make_rotary_values
# from lib.falcon import model_config_dummy as model_config
from lib.falcon import model_config_falcon_7B as model_config
from lib.loss import cross_entropy_loss
from lib.multihost_utils import shard_model_params
from lib.param_utils import load_params, save_params

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def load_params_from_disk():
    cpu_device = jax.devices('cpu')[0]
    with jax.default_device(cpu_device):
        params = load_params('/kaggle/input/pure-falcon-7b-jax/Falcon-7B-JAX.pickle')
        params = jax.tree_map(lambda x: x.astype(jnp.bfloat16), params)
    params = shard_model_params(params)
    return params

def set_save_params_signal():
    signal.signal(signal.SIGINT, save_params_signal_handler)
    signal.signal(signal.SIGTERM, save_params_signal_handler)

def unset_save_params_signal():
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    signal.signal(signal.SIGTERM, signal.SIG_IGN)

def save_params_to_disk() -> None:
    unset_save_params_signal()
    gathered_params = process_allgather(params)
    if is_process_0:
        save_params(gathered_params, f'{wandb.run.name}.pickle')
    set_save_params_signal()

def save_params_signal_handler(signum, frame):
    save_params_to_disk()
    print(f'Signal {signum} received. Model params have been successfully saved to disk.')
    exit(-1)

@jax.value_and_grad
def train_forward(params: Falcon, rotary_values: RotaryValues, data_batch: TrainData, *, key: Array):
    seq, seq_mask, labels = data_batch
    qk_mask = op.rearrange(jnp.tril(op.einsum(seq_mask, seq_mask, 'B L1, B L2 -> B L1 L2')), 'B L1 L2 -> B 1 1 L1 L2')  # causal QK mask
    logits, _ = forward_falcon(params, seq, qk_mask, rotary_values=rotary_values, key=key, model_config=model_config)
    loss = cross_entropy_loss(logits, labels, mask=seq_mask)
    return loss

@jax.jit
def train_step(params: Falcon, opt_state: Any, rotary_values: RotaryValues, total_loss: Array, data_batch: TrainData, key: Array) -> tuple[Falcon, Any, Array, Array, Array]:
    key, subkey = rand.split(key)
    loss, grads = train_forward(params, rotary_values, data_batch, key=subkey)
    total_loss += loss
    updates, opt_state = optimize(grads, opt_state, params)  # type: ignore
    params = optax.apply_updates(params, updates)
    return params, opt_state, total_loss, loss, key

In [6]:
import pandas as pd
import torch
from datasets import Dataset, load_dataset, concatenate_datasets
from transformers import DataCollatorForLanguageModeling
lr = 0.00005
batch_size = 1
n_accumulation_steps = 1
max_len = 512
n_epochs = 7
seed = 3407

df = pd.read_csv('/kaggle/input/corpus-of-russian-news-articles-from-lenta/lenta-ru-news.csv', low_memory=False)[:100000].drop(columns=['url', 'tags', 'title', 'topic', 'date']).astype('str')
def filter_word_count(row):
    words = row.split()
    word_count = len(words)
    for word in words:
        if "https" in word:
            return False
    if word_count <= 32:
        return True
    else:
        return False

df = df[df['text'].apply(filter_word_count)].dropna().reset_index(drop=True)
dataset_sum = Dataset.from_pandas(df).train_test_split(test_size=0.1)

jax_smi.initialise_tracking()

tokenizer = FalconTokenizer.from_pretrained('tiiuae/falcon-7b', padding_side='left')
tokenizer.pad_token = tokenizer.eos_token

def preprocess_function(example):
    text_tokens = tokenizer(example["text"], truncation=True, max_length=max_len, padding='max_length', return_tensors='jax')
    return {
        "input_ids": text_tokens.input_ids.flatten(),
        "attention_mask": text_tokens.attention_mask.flatten(),
    }

dataset_sum['train'] = dataset_sum['train'].map(preprocess_function, batched=False, num_proc=1).remove_columns(['text'])
training_loader = torch.utils.data.DataLoader(dataset_sum['train'], batch_size=1, collate_fn=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors='np'))

    
    
wandb.init(project='falcon-finetuning', config=dict(learning_rate=lr, batch_size=batch_size * n_accumulation_steps, n_epochs=n_epochs, optimiser='adamw'))




E0202 01:25:17.014160888     649 oauth2_credentials.cc:238]            oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:"2024-02-02T01:25:17.014140284+00:00"}
tokenizer_config.json: 100%|██████████| 967/967 [00:00<00:00, 4.31MB/s]
tokenizer.model: 100%|██████████| 493k/493k [00:00<00:00, 4.17MB/s]
special_tokens_map.json: 100%|██████████| 72.0/72.0 [00:00<00:00, 293kB/s]
tokenizer.json: 100%|██████████| 1.80M/1.80M [00:00<00:00, 4.21MB/s]
Map: 100%|██████████| 444/444 [00:03<00:00, 120.73 examples/s]
[34m[1mwandb[0m: Currently logged in as: [33mbossprocool[0m ([33mmemers[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
key = rand.key(seed, impl='rbg')
params = load_params_from_disk()
print('Loaded parameters')
set_save_params_signal()

Loaded parameters


In [None]:
from lib.falcon import Falcon, FalconModel
from lib.falcon.attention import Attention
from lib.falcon.decoder import Decoder
n_steps = math.ceil(len(training_loader) / n_accumulation_steps)
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.,
    peak_value=lr,
    warmup_steps=n_steps,
    decay_steps=n_steps + 1,
    end_value=lr,
)
optimizer = optax.adamw(learning_rate=schedule)
optimizer = optax.multi_transform(
                    {'train': optimizer, 
                     'freeze': optax.set_to_zero()
                     },

                    Falcon(
                        model=FalconModel(
                            embedding='freeze',
                            decoder=Decoder(
                                input_norm='freeze',
                                attention=Attention(q_proj='freeze', k_proj='freeze', v_proj='freeze', out_proj='train'),
                                post_attn_norm='freeze',
                                gate_proj='freeze',
                                up_proj='freeze',
                                down_proj='train',
                            ),
                            norm='freeze',
                        ),
                        lm_head='freeze'))
optimize = optimizer.update
opt_state = optimizer.init(params)

rotary_values = make_rotary_values(None, batch_size, max_len, model_config=model_config)

for _ in range(n_epochs):
    step_loss = 0.0
    total_loss = jnp.zeros(())

    def report_to_wandb(start_time, opt_state, loss):
        wandb.log({'train loss': loss.item() / n_accumulation_steps, 'time': time.time() - start_time})

    for step, data_batch in enumerate(training_loader):
        data_batch = (data_batch.input_ids.astype(jnp.uint16), data_batch.attention_mask.astype(jnp.bool_), data_batch.labels.astype(jnp.uint16))
        start_time = time.time()
        params, opt_state, total_loss, loss, key = train_step(params, opt_state, rotary_values, total_loss, data_batch, key)
        jax.debug.callback(report_to_wandb, start_time, opt_state, loss)
    wandb.log({'epoch loss': total_loss.item() / (step + 1)})
    print('epoch loss, ', total_loss / (step + 1))




  9%|▉         | 41/444 [00:18<03:05,  2.17it/s]

  0%|          | 1/444 [00:00<02:55,  2.52it/s][A
  0%|          | 2/444 [00:00<01:44,  4.24it/s][A
  1%|          | 3/444 [00:00<01:21,  5.42it/s][A
  1%|          | 4/444 [00:00<01:10,  6.23it/s][A
  1%|          | 5/444 [00:00<01:04,  6.81it/s][A
  1%|▏         | 6/444 [00:01<01:00,  7.22it/s][A
  2%|▏         | 7/444 [00:01<00:58,  7.51it/s][A
  2%|▏         | 8/444 [00:01<00:56,  7.73it/s][A
  2%|▏         | 9/444 [00:01<00:55,  7.89it/s][A
  2%|▏         | 10/444 [00:01<00:54,  8.00it/s][A
  2%|▏         | 11/444 [00:01<00:53,  8.08it/s][A
  3%|▎         | 12/444 [00:01<00:53,  8.10it/s][A
  3%|▎         | 13/444 [00:01<00:52,  8.14it/s][A
  3%|▎         | 14/444 [00:01<00:52,  8.17it/s][A
  3%|▎         | 15/444 [00:02<00:52,  8.19it/s][A
  4%|▎         | 16/444 [00:02<00:52,  8.21it/s][A
  4%|▍         | 17/444 [00:02<00:51,  8.23it/s][A
  4%|▍         | 18/444 [00:02<00:51,  8.24it/s][A
  4%|▍         | 19/44

epoch loss,  5.910086


100%|██████████| 444/444 [00:54<00:00,  8.14it/s]
100%|██████████| 444/444 [00:53<00:00,  8.33it/s]

epoch loss,  5.3959656



100%|██████████| 444/444 [00:53<00:00,  8.26it/s]

  0%|          | 1/444 [00:00<00:53,  8.26it/s][A
  0%|          | 2/444 [00:00<00:53,  8.24it/s][A
  1%|          | 3/444 [00:00<00:53,  8.23it/s][A
  1%|          | 4/444 [00:00<00:53,  8.23it/s][A
  1%|          | 5/444 [00:00<00:53,  8.23it/s][A
  1%|▏         | 6/444 [00:00<00:53,  8.22it/s][A
  2%|▏         | 7/444 [00:00<00:53,  8.22it/s][A
  2%|▏         | 8/444 [00:00<00:53,  8.21it/s][A
  2%|▏         | 9/444 [00:01<00:53,  8.20it/s][A
  2%|▏         | 10/444 [00:01<00:52,  8.21it/s][A
  2%|▏         | 11/444 [00:01<00:52,  8.22it/s][A
  3%|▎         | 12/444 [00:01<00:52,  8.21it/s][A
  3%|▎         | 13/444 [00:01<00:52,  8.22it/s][A
  3%|▎         | 14/444 [00:01<00:52,  8.22it/s][A
  3%|▎         | 15/444 [00:01<00:52,  8.23it/s][A
  4%|▎         | 16/444 [00:01<00:52,  8.23it/s][A
  4%|▍         | 17/444 [00:02<00:52,  8.21it/s][A
  4%|▍         | 18/444 [00:02<00:51,  8.20it/s][A
  4%|▍         | 19/4

epoch loss,  4.5851655


100%|██████████| 444/444 [00:54<00:00,  8.18it/s]
 66%|██████▌   | 293/444 [00:35<00:18,  8.32it/s]

In [None]:
save_params_to_disk()

In [None]:
import gc
gc.collect()