In [1]:
%%capture
!pip install ipywidgets -q
!pip install torch --index-url https://download.pytorch.org/whl/cpu -q
!pip install --upgrade jax -q 
!pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -q
!pip install "flax[all]" -q
!pip install --upgrade optax==0.2.2
!pip install --upgrade einops
!pip install --no-cache-dir transformers==4.43.3
!pip install --no-cache-dir datasets==2.18.0
!pip install --upgrade tqdm
!pip install --upgrade requests
!pip install --upgrade typing-extensions
!pip install --upgrade mlxu>=0.1.13
!pip install --upgrade sentencepiece
!pip install --upgrade pydantic
!pip install --upgrade fastapi
!pip install --upgrade uvicorn
!pip install --upgrade gradio


In [2]:
import os
os.environ['HF_HUB_CACHE'] = '/mnt/persistent-disk/hf/'
os.environ['HF_HOME'] = '/mnt/persistent-disk/hf/'
!export HF_HUB_CACHE="/mnt/persistent-disk/hf/"
!export HF_HOME="/mnt/persistent-disk/hf/"

In [None]:
# Test if transformers lib is working correctly.
# from transformers import AutoModelForCausalLM
# AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", token="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")

In [3]:
import os
import sys
import importlib
import sys
import os
from types import SimpleNamespace
def import_local_module(module_path: str):
    sys.path.append('')
    module = importlib.import_module(module_path)
    return importlib.reload(module)

# Imports felafax trainer_engine
convert_hf_to_easylm = import_local_module("EasyLM.models.llama.convert_hf_to_easylm")

In [4]:
llama_model = import_local_module("EasyLM.models.llama.llama_model")

In [5]:
# Set up the arguments
args = SimpleNamespace(
    hf_model="meta-llama/Meta-Llama-3-8B",
    output_file="/mnt/persistent-disk/easy/easylm_format.easylm",
    streaming=False,
    float_dtype="bf16"
)


In [6]:
# Set up the FLAGS
convert_hf_to_easylm.FLAGS = args

# Set up the llama configuration
convert_hf_to_easylm.FLAGS.llama = llama_model.LLaMAConfigurator.get_default_config()
convert_hf_to_easylm.FLAGS.llama.base_model = "llama3_8b"

In [9]:
# Call the main function
convert_hf_to_easylm.main([])

INPUT: Please provide your HUGGINGFACE_TOKEN:  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


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Start convert weight to easylm format...
Convert weight to easylm format finished...
Start to save...
Save finished!!! take time: 233.940571308136 save path: /mnt/persistent-disk/easy/easylm_format.easylm


In [10]:
llama_train = import_local_module("EasyLM.models.llama.llama_train")

In [11]:
data = import_local_module("EasyLM.data")
optimizers = import_local_module("EasyLM.optimizers")
checkpoint = import_local_module("EasyLM.checkpoint")
jax_utils = import_local_module("EasyLM.jax_utils")

In [12]:
llama_config = llama_model.LLaMAConfigurator.get_default_config()
llama_config.base_model = "llama3_8b"

In [None]:
# {
#     'path': 'glue',         # specify the dataset path
#     'name': 'sst2',         # specify the dataset name
#     'split': 'train',       # specify the dataset split
#     'seq_length': 128,      # sequence length
#     'batch_size': 16        # batch size for training
# }

In [13]:
# Define default configurations for training and evaluation datasets
default_train_dataset_config = data.DatasetFactory.get_default_config()


In [None]:
# default_train_dataset_config.huggingface_dataset.path = "c4-en-10k"# "tiny_shakespeare"
# default_train_dataset_config.huggingface_dataset.seq_length = 64

In [14]:
default_train_dataset_config.text_processor.fields = "text"
default_train_dataset_config

huggingface_dataset:
  always_start_with_bos: false
  batch_size: 4
  batch_token_dtype: i4
  name: 20220301.en
  path: wikipedia
  seq_length: 512
  split: train
  streaming: true
json_dataset:
  always_start_with_bos: false
  batch_size: 8
  example_index_at_start: 0
  path: ''
  seq_length: 1024
  start_seek_loc: 0
  throughput_average_window_size: 200
  tokenizer_parallel_batch_size: 1024
  tokenizer_parallel_chunk_size: 32
  tokenizer_processes: 1
  tokens_count_at_start: 0
text_processor:
  add_bos_token: true
  add_eos_token: true
  base64_token_dtype: i4
  fields: text
  fields_from_example: ''
  prepend_text: ''
  subfield_separator: ' '
type: huggingface

In [15]:
import mlxu
train_args = SimpleNamespace(
    seed=42,
    mesh_dim='1,-1,1',
    dtype='fp32',
    param_dtype='fp32',
    total_steps=100,
    load_llama_config='',
    update_llama_config='',
    
    load_checkpoint='flax_params::/mnt/persistent-disk/easy/easylm_format.easylm',
    
    load_dataset_state='',
    log_freq=50,
    save_model_freq=0,
    save_milestone_freq=0,
    eval_steps=0,
    tokenizer='openlm-research/open_llama_3b_v2',
    train_dataset=default_train_dataset_config, # data.DatasetFactory.get_default_config(),
    # eval_dataset=data.DatasetFactory.get_default_config(),
    optimizer=optimizers.OptimizerFactory.get_default_config(),
    checkpointer=checkpoint.StreamingCheckpointer.get_default_config(),
    llama=llama_config,
    logger=mlxu.WandBLogger.get_default_config(),
    log_all_worker=False,
    jax_distributed=jax_utils.JaxDistributedConfig.get_default_config(),
)

In [16]:
llama_train.FLAGS = train_args

In [17]:
llama_train.main([])

tokenizer_config.json:   0%|          | 0.00/593 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/512k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/330 [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggin

Downloading builder script:   0%|          | 0.00/36.7k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/16.0k [00:00<?, ?B/s]

  0% 0/100 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (9629 > 2048). Running this sequence through the model will result in indexing errors


step 0


  1% 1/100 [03:27<5:42:08, 207.36s/it]

loss 9.366291999816895
step 1


  2% 2/100 [06:54<5:38:07, 207.02s/it]

loss 9.418638229370117
step 2


  3% 3/100 [06:54<3:02:15, 112.74s/it]

loss 9.135272979736328
step 3


  4% 4/100 [06:55<1:49:30, 68.45s/it] 

loss 8.720643997192383
step 4


  5% 5/100 [06:55<1:09:37, 43.97s/it]

loss 9.008569717407227
step 5


  6% 6/100 [06:56<45:45, 29.21s/it]  

loss 8.753011703491211
step 6


  7% 7/100 [06:56<30:45, 19.84s/it]

loss 8.761460304260254
step 7


  8% 8/100 [06:57<21:00, 13.70s/it]

loss 9.192926406860352
step 8


  9% 9/100 [06:58<14:32,  9.59s/it]

loss 8.271177291870117
step 9


 10% 10/100 [06:58<10:12,  6.80s/it]

loss 8.075187683105469
step 10


 11% 11/100 [06:59<07:14,  4.89s/it]

loss 7.776384353637695
step 11


 12% 12/100 [06:59<05:14,  3.57s/it]

loss 7.786738395690918
step 12


 13% 13/100 [07:00<03:50,  2.65s/it]

loss 7.938724517822266
step 13


 14% 14/100 [07:00<02:54,  2.03s/it]

loss 8.161386489868164
step 14


 15% 15/100 [07:01<02:14,  1.59s/it]

loss 7.284797191619873
step 15


 16% 16/100 [07:01<01:46,  1.27s/it]

loss 7.311324119567871
step 16


 17% 17/100 [07:02<01:27,  1.06s/it]

loss 7.541593074798584
step 17


 18% 18/100 [07:03<01:14,  1.11it/s]

loss 6.971421241760254
step 18


 19% 19/100 [07:03<01:04,  1.26it/s]

loss 6.7012457847595215
step 19


 20% 20/100 [07:04<00:57,  1.39it/s]

loss 7.0383710861206055
step 20


 21% 21/100 [07:04<00:52,  1.49it/s]

loss 6.748800277709961
step 21


 22% 22/100 [07:05<00:49,  1.58it/s]

loss 6.932526588439941
step 22


 23% 23/100 [07:05<00:47,  1.62it/s]

loss 7.0236053466796875
step 23


 24% 24/100 [07:06<00:45,  1.68it/s]

loss 7.780677795410156
step 24


 25% 25/100 [07:06<00:43,  1.72it/s]

loss 7.769660949707031
step 25


 26% 26/100 [07:07<00:42,  1.75it/s]

loss 7.995835781097412
step 26


 27% 27/100 [07:07<00:41,  1.77it/s]

loss 7.894532680511475
step 27


 28% 28/100 [07:08<00:40,  1.79it/s]

loss 7.25313663482666
step 28


 29% 29/100 [07:09<00:40,  1.74it/s]

loss 7.151447296142578
step 29


 30% 30/100 [07:09<00:39,  1.77it/s]

loss 7.104001522064209
step 30


 31% 31/100 [07:10<00:38,  1.79it/s]

loss 7.240172863006592
step 31


 32% 32/100 [07:10<00:37,  1.80it/s]

loss 7.0530924797058105
step 32


 33% 33/100 [07:11<00:37,  1.81it/s]

loss 7.373514175415039
step 33


 34% 34/100 [07:11<00:36,  1.81it/s]

loss 7.1229047775268555
step 34


 35% 35/100 [07:12<00:35,  1.82it/s]

loss 6.960782527923584
step 35


 36% 36/100 [07:12<00:35,  1.82it/s]

loss 6.654844284057617
step 36


 37% 37/100 [07:13<00:34,  1.82it/s]

loss 6.991496562957764
step 37


 38% 38/100 [07:14<00:34,  1.78it/s]

loss 7.057541847229004
step 38


 39% 39/100 [07:14<00:34,  1.79it/s]

loss 7.590811729431152
step 39


 40% 40/100 [07:15<00:33,  1.80it/s]

loss 7.433245658874512
step 40


 41% 41/100 [07:15<00:32,  1.81it/s]

loss 7.381690979003906
step 41


 42% 42/100 [07:16<00:31,  1.81it/s]

loss 7.5280351638793945
step 42


 43% 43/100 [07:16<00:31,  1.82it/s]

loss 7.300105094909668
step 43


 44% 44/100 [07:17<00:30,  1.82it/s]

loss 7.3164825439453125
step 44


 45% 45/100 [07:17<00:30,  1.82it/s]

loss 7.269549369812012
step 45


 46% 46/100 [07:18<00:29,  1.82it/s]

loss 7.555338382720947
step 46


 47% 47/100 [07:19<00:29,  1.82it/s]

loss 7.693605899810791
step 47


 48% 48/100 [07:19<00:29,  1.79it/s]

loss 7.17784309387207
step 48


 49% 49/100 [07:20<00:28,  1.80it/s]

loss 7.658641815185547
step 49


 50% 50/100 [07:20<00:27,  1.81it/s]

loss 7.560274124145508
step 50


 51% 51/100 [07:21<00:27,  1.81it/s]

loss 7.559930801391602
step 51


 52% 52/100 [07:21<00:26,  1.82it/s]

loss 7.7722272872924805
step 52


 53% 53/100 [07:22<00:25,  1.81it/s]

loss 7.751843452453613
step 53


 54% 54/100 [07:22<00:25,  1.81it/s]

loss 9.026155471801758
step 54


 55% 55/100 [07:23<00:24,  1.82it/s]

loss 6.9355292320251465
step 55


 56% 56/100 [07:24<00:24,  1.82it/s]

loss 9.333514213562012
step 56


 57% 57/100 [07:24<00:23,  1.82it/s]

loss 7.858311176300049
step 57


 58% 58/100 [07:25<00:23,  1.79it/s]

loss 9.67912769317627
step 58


 59% 59/100 [07:25<00:22,  1.80it/s]

loss 9.07669448852539
step 59


 60% 60/100 [07:26<00:22,  1.81it/s]

loss 8.69036865234375
step 60


 61% 61/100 [07:26<00:21,  1.82it/s]

loss 9.00741195678711
step 61


 62% 62/100 [07:27<00:20,  1.82it/s]

loss 8.739542007446289
step 62


 63% 63/100 [07:27<00:20,  1.79it/s]

loss 8.325194358825684
step 63


 64% 64/100 [07:28<00:19,  1.80it/s]

loss 8.12378978729248
step 64


 65% 65/100 [07:28<00:19,  1.81it/s]

loss 7.98336124420166
step 65


 66% 66/100 [07:29<00:18,  1.82it/s]

loss 7.905200481414795
step 66


 67% 67/100 [07:30<00:18,  1.82it/s]

loss 7.280104637145996
step 67


 68% 68/100 [07:30<00:17,  1.81it/s]

loss 7.963953018188477
step 68


 69% 69/100 [07:31<00:17,  1.77it/s]

loss 7.769622802734375
step 69


 70% 70/100 [07:31<00:16,  1.79it/s]

loss 8.638200759887695
step 70


 71% 71/100 [07:32<00:16,  1.80it/s]

loss 8.499345779418945
step 71


 72% 72/100 [07:32<00:15,  1.81it/s]

loss 7.809869766235352
step 72


 73% 73/100 [07:33<00:14,  1.81it/s]

loss 7.954878807067871
step 73


 74% 74/100 [07:33<00:14,  1.82it/s]

loss 7.94611120223999
step 74


 75% 75/100 [07:34<00:13,  1.82it/s]

loss 7.655534744262695
step 75


 76% 76/100 [07:35<00:13,  1.82it/s]

loss 7.470763206481934
step 76


 77% 77/100 [07:35<00:12,  1.82it/s]

loss 7.673531532287598
step 77


 78% 78/100 [07:36<00:12,  1.83it/s]

loss 7.763275623321533
step 78


 79% 79/100 [07:36<00:11,  1.81it/s]

loss 8.179311752319336
step 79


 80% 80/100 [07:37<00:11,  1.81it/s]

loss 8.064218521118164
step 80


 81% 81/100 [07:37<00:10,  1.82it/s]

loss 7.785511493682861
step 81


 82% 82/100 [07:38<00:10,  1.78it/s]

loss 7.789140701293945
step 82


 83% 83/100 [07:38<00:09,  1.80it/s]

loss 8.009418487548828
step 83


 84% 84/100 [07:39<00:08,  1.81it/s]

loss 8.09515380859375
step 84


 85% 85/100 [07:40<00:08,  1.81it/s]

loss 7.935939788818359
step 85


 86% 86/100 [07:40<00:07,  1.82it/s]

loss 11.263519287109375
step 86


 87% 87/100 [07:41<00:07,  1.81it/s]

loss 9.048345565795898
step 87


 88% 88/100 [07:41<00:06,  1.79it/s]

loss 8.921339988708496
step 88


 89% 89/100 [07:42<00:06,  1.80it/s]

loss 9.389142990112305
step 89


 90% 90/100 [07:42<00:05,  1.81it/s]

loss 8.924957275390625
step 90


 91% 91/100 [07:43<00:04,  1.81it/s]

loss 8.497274398803711
step 91


 92% 92/100 [07:43<00:04,  1.82it/s]

loss 8.87980842590332
step 92


 93% 93/100 [07:44<00:03,  1.82it/s]

loss 7.975639343261719
step 93


 94% 94/100 [07:45<00:03,  1.82it/s]

loss 8.706695556640625
step 94


 95% 95/100 [07:45<00:02,  1.82it/s]

loss 8.070947647094727
step 95


 96% 96/100 [07:46<00:02,  1.80it/s]

loss 8.223447799682617
step 96


 97% 97/100 [07:46<00:01,  1.81it/s]

loss 9.000728607177734
step 97


 98% 98/100 [07:47<00:01,  1.81it/s]

loss 8.879977226257324
step 98


 99% 99/100 [07:47<00:00,  1.82it/s]

loss 8.699119567871094
step 99


100% 100/100 [07:48<00:00,  4.68s/it]

loss 8.712419509887695



