In [2]:
!pip install --upgrade transformers -q

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [1]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.core.xla_model as xm

# Global constants
EPOCHS = 4
WARMUP_STEPS = 2
BATCH_SIZE = 32

# Create dummy MNIST data
def create_dummy_mnist_data(num_samples=1000):
    # MNIST images are 28x28 pixels, labels are 0-9
    dummy_data = torch.randn(num_samples, 1, 28, 28)
    dummy_labels = torch.randint(0, 10, (num_samples,))
    return dummy_data, dummy_labels

class DummyDataLoader:
    def __init__(self, data, labels, batch_size):
        self.data = data
        self.labels = labels
        self.batch_size = batch_size
        self.num_samples = len(data)
        self.current_idx = 0
        
    def __iter__(self):
        self.current_idx = 0
        return self
        
    def __next__(self):
        if self.current_idx >= self.num_samples:
            raise StopIteration
            
        end_idx = min(self.current_idx + self.batch_size, self.num_samples)
        batch_data = self.data[self.current_idx:end_idx]
        batch_labels = self.labels[self.current_idx:end_idx]
        self.current_idx = end_idx
        
        return batch_data, batch_labels
        
    def __len__(self):
        return (self.num_samples + self.batch_size - 1) // self.batch_size

# Declare 3-layer MLP for MNIST dataset
class MLP(nn.Module):
    def __init__(self, input_size=28*28, output_size=10, layers=[120, 84]):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, layers[0])
        self.fc2 = nn.Linear(layers[0], layers[1])
        self.fc3 = nn.Linear(layers[1], output_size)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

def main():
    # Create dummy data
    dummy_data, dummy_labels = create_dummy_mnist_data()
    
    # Create dummy data loader
    train_loader = DummyDataLoader(dummy_data, dummy_labels, BATCH_SIZE)
    import pdb; pdb.set_trace()
    
    # Fix the random number generator seeds for reproducibility
    torch.manual_seed(0)
    
    # XLA: Specify XLA device (defaults to a NeuronCore on Trn1 instance)
    device = 'xla'
    
    # Move model to device and declare optimizer and loss function
    model = MLP().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = torch.nn.NLLLoss()
    
    # Run the training loop
    print('----------Training ---------------')
    model.train()
    for epoch in range(EPOCHS):
        start = time.time()
        for idx, (train_x, train_label) in enumerate(train_loader):
            optimizer.zero_grad()
            train_x = train_x.view(train_x.size(0), -1)
            train_x = train_x.to(device)
            train_label = train_label.to(device)
            output = model(train_x)
            loss = loss_fn(output, train_label)
            loss.backward()
            optimizer.step()
            xm.mark_step()  # XLA: collect ops and run them in XLA runtime
            if idx < WARMUP_STEPS:  # skip warmup iterations
                start = time.time()
    
        # Compute statistics for the last epoch
        interval = idx - WARMUP_STEPS  # skip warmup iterations
        throughput = interval / (time.time() - start)
        print(f"Train throughput (iter/sec): {throughput}")
        print(f"Final loss is {loss.detach().to('cpu'):0.4f}")
    
    # Save checkpoint for evaluation
    os.makedirs("checkpoints", exist_ok=True)
    checkpoint = {'state_dict': model.state_dict()}
    xm.save(checkpoint, 'checkpoints/checkpoint.pt')
    
    print('----------End Training ---------------')

main()

> [0;32m/tmp/ipykernel_13040/1057427633.py[0m(69)[0;36mmain[0;34m()[0m
[0;32m     67 [0;31m[0;34m[0m[0m
[0m[0;32m     68 [0;31m    [0;31m# Fix the random number generator seeds for reproducibility[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 69 [0;31m    [0mtorch[0m[0;34m.[0m[0mmanual_seed[0m[0;34m([0m[0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     70 [0;31m[0;34m[0m[0m
[0m[0;32m     71 [0;31m    [0;31m# XLA: Specify XLA device (defaults to a NeuronCore on Trn1 instance)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  next(iter(train_loader))


> [0;32m/tmp/ipykernel_13040/1057427633.py[0m(72)[0;36mmain[0;34m()[0m
[0;32m     70 [0;31m[0;34m[0m[0m
[0m[0;32m     71 [0;31m    [0;31m# XLA: Specify XLA device (defaults to a NeuronCore on Trn1 instance)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 72 [0;31m    [0mdevice[0m [0;34m=[0m [0;34m'xla'[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     73 [0;31m[0;34m[0m[0m
[0m[0;32m     74 [0;31m    [0;31m# Move model to device and declare optimizer and loss function[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  x = next(iter(train_loader))
ipdb>  x


(tensor([[[[-2.1191,  0.0776,  3.4171,  ...,  0.7731, -0.6935, -0.8865],
          [ 0.7186, -0.0556,  0.3989,  ..., -1.1913, -0.3478,  0.7789],
          [-0.4001,  0.1024, -0.1736,  ..., -0.9343,  0.7738,  0.6558],
          ...,
          [ 1.2527,  0.4576, -1.2948,  ..., -2.0772,  0.2891,  1.7844],
          [-1.1152, -0.8433,  0.2706,  ..., -1.4534, -1.0795, -0.4402],
          [ 1.3365,  0.7099,  1.3206,  ..., -0.8835, -1.4699,  1.8144]]],


        [[[ 0.6372,  1.1949,  2.0840,  ...,  0.4330,  0.3890, -0.6343],
          [-1.2200, -2.3117,  0.7436,  ..., -1.0024, -1.8883,  0.8620],
          [-0.9815, -1.9794,  0.1981,  ...,  0.4779, -0.8568, -0.5031],
          ...,
          [ 2.3262,  0.9625,  2.0861,  ...,  1.3275,  0.1223, -0.3024],
          [-0.9390,  0.8797, -0.6360,  ..., -0.6617, -0.6971, -0.5100],
          [-0.0382, -0.5070,  1.0390,  ..., -0.2144, -0.2371,  0.5000]]],


        [[[-1.4908, -0.3226,  1.1499,  ...,  0.4864, -1.3301,  0.6620],
          [-0.8348,  1.81

ipdb>  c


----------Training ---------------
2024-12-01 03:59:46.000119:  13040  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-12-01 03:59:46.000121:  13040  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --target=trn1 --framework=XLA /tmp/ubuntu/neuroncc_compile_workdir/45b0d0d6-e9f1-4140-9abd-0955882e46a2/model.MODULE_4027976367933384269+d7517139.hlo_module.pb --output /tmp/ubuntu/neuroncc_compile_workdir/45b0d0d6-e9f1-4140-9abd-0955882e46a2/model.MODULE_4027976367933384269+d7517139.neff --verbose=35
..
Compiler status PASS
2024-12-01 04:00:08.000856:  13040  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-12-01 04:00:08.000857:  13040  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --target=trn1 --framework=XLA /tmp/ubuntu/neuroncc_compile_workdir/74c80928-6ed8-440c-89b3-a6c1cec7dfa3/model.MODULE_9384584493755895831+d7517139.hlo_module.pb --output /tmp/ubuntu/neuroncc_compile_workdir/

In [3]:
import os
os.environ['USE_TORCH'] = 'True'  # To use transformers library in TPU
os.environ['XLA_USE_BF16'] = 'True'
os.environ['PJRT_DEVICE'] = 'TPU'

In [14]:
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets
from peft import LoraConfig, TaskType, get_peft_model
from safetensors.torch import load_file

import logging
logging.getLogger("datasets").setLevel(logging.WARNING)
logging.getLogger("transformers").setLevel(logging.WARNING)

import os
import contextlib
from dataclasses import dataclass

In [4]:
import os
import contextlib
from dataclasses import dataclass

import torch
import numpy as np
import torch.distributed as dist
import torch_xla.core.xla_model as xm
# import torch_xla.runtime as xr
# xr.use_spmd()

# # import torch_xla.distributed.spmd as xs
# from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
# from torch_xla.experimental.xla_sharding import Mesh

# # import torch_xla.distributed.xla_multiprocessing as xmp
# import torch_xla.distributed.parallel_loader as pl
# import torch_xla.test.test_utils as test_utils

from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets
from peft import LoraConfig, TaskType, get_peft_model
from safetensors.torch import load_file

import logging
logging.getLogger("datasets").setLevel(logging.WARNING)
logging.getLogger("transformers").setLevel(logging.WARNING)

In [5]:
assert xr.is_spmd()==True

In [6]:
import sys
import importlib
sys.path.append('')
# model_partitioning = importlib.import_module('trainer_lib.model_partitioning')
# importlib.reload(model_partitioning)

In [2]:
# This notebook can be used to train any of the 7B, 8B models. Check out the 80B notebook to train bigger model.
supported_models = [
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "meta-llama/Meta-Llama-3.1-8B",
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-2-13b-hf",
    "TinyLlama/TinyLlama-1.1B-step-50K-105b",
]

# Select a supported model from above list to use!
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
HUGGINGFACE_TOKEN = input("Please provide your HUGGINGFACE_TOKEN: ") # YOUR_HF_TOKEN

Please provide your HUGGINGFACE_TOKEN:  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


# Configure LoRA config for your model.
Use the below code to configure the LoRA config for your model.

In [3]:
def apply_lora(*, model, lora_rank=None, lora_alpha=None, lora_dropout=None):
    """Applies LoRA configuration to the model."""
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8 if not lora_rank else lora_rank,
        lora_alpha=32 if not lora_alpha else lora_alpha,
        lora_dropout=0.1 if not lora_dropout else lora_dropout,
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model

In [4]:
def init_model(*, model_name, hugging_face_token):
    """Downloads and initializes the model."""
    config = AutoConfig.from_pretrained(
        model_name, 
        token=hugging_face_token)
    
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, 
        token=hugging_face_token
    )

    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token
        config.pad_token_id = tokenizer.pad_token_id
        
    model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        token=hugging_face_token,
        low_cpu_mem_usage=True
    )

    # model = apply_lora(
    #     model=model,
    #     lora_rank=TRAINER_CONFIG["lora_rank"],
    #     lora_alpha=TRAINER_CONFIG["lora_alpha"],
    #     lora_dropout=TRAINER_CONFIG["lora_dropout"],
    # )

    return model, tokenizer

In [5]:
def apply_spmd(*, model, mesh):
    # Apply on layers within model.
    model_partitioning_util.partition_model(model, mesh)

# Configure dataset pipeline for your model

For this project, we're utilizing the refined **Alpaca dataset**, curated by yahma. This dataset is a carefully filtered selection of 52,000 entries from the original Alpaca collection. Feel free to substitute this section with your own data preparation code if you prefer.

It's crucial to include the EOS_TOKEN (End of Sequence Token) in your tokenized output. Failing to do so may result in endless generation loops.

In [6]:
def get_dataset(*, tokenizer, batch_size=1, seq_length=32, max_examples=None):
    # Define Alpaca prompt template
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
    
    ### Instruction: {}
    
    ### Input: {}
    
    ### Response: {}"""
    
    EOS_TOKEN = tokenizer.eos_token
    
    # Define formatting function.
    def _format_prompts(examples):
        instructions = examples["instruction"]
        inputs = examples["input"]
        outputs = examples["output"]
        texts = []
        for instruction, input, output in zip(instructions, inputs, outputs):
            text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
            texts.append(text)
        return {"text": texts}

    # Tokenize the dataset.
    def _tokenize(examples):
        # Tokenized is list within list. Compute labels for causalLM by shifting input_id; 
        # consequently truncate input_id to penultimate position.
        tokenized = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=seq_length+1)
        labels = tokenized['input_ids'].copy()
        tokenized['labels'] = [label[1:] for label in labels]
        tokenized['input_ids'] = [input_id[:-1] for input_id in tokenized['input_ids']]
        return tokenized

    # Load and preprocess the dataset.
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if max_examples:
        dataset = dataset.select(range(max_examples))
    dataset = dataset.map(_format_prompts, batched=True)

    # Create train and test dataset.
    ds = dataset.train_test_split(test_size=0.15)
    ds['train'] = ds['train'].map(_tokenize, batched=True, remove_columns=dataset.column_names)
    ds['test'] = ds['test'].map(_tokenize, batched=True, remove_columns=dataset.column_names)

    # Create DataLoader
    train_dataloader = torch.utils.data.DataLoader(
        ds['train'],
        shuffle=True,
        batch_size=batch_size,
        collate_fn=default_data_collator,
    )
    
    test_dataloader = torch.utils.data.DataLoader(
        ds['test'],
        shuffle=True,
        batch_size=batch_size,
        collate_fn=default_data_collator,
    )

    return train_dataloader, test_dataloader

# Train the model

Now let's train the model. We are using PyTorch XLA's Fully Sharded Data Parallel (FSDP) to distribute the model across the 8 TPU cores available on TPU v3-8. This approach allows for efficient training on TPU hardware. We also utilize PyTorch/XLA's MpDeviceLoader to efficiently load data onto the TPU cores.

**NOTE:** It's important to note that the **first step of training will be slow**. This is because XLA takes time initially to compile the computational graph. However, once the compilation is complete, subsequent steps will run much faster using compiled+cached graph, and leveraging the full power of the all TPU cores for accelerated training.


In [8]:
def print_training_update(step,
                          loss,
                          epoch=None,
                         ):
    """Prints the training metrics at a given step."""
    if xm.is_master_ordinal():  # Only print on the master device
        update_data = [
            'Training',
            f'Epoch={epoch}' if epoch is not None else 0,
            f'Step={step}',
            f'Loss={loss:.5f}',
        ]
        print(' | '.join(item for item in update_data if item), flush=True)
        print()


In [10]:
model, tokenizer = init_model(
        model_name=MODEL_NAME, hugging_face_token=HUGGINGFACE_TOKEN
)

In [17]:
print(trainer_config)

TrainerConfig(epochs=1, batch_size=32, seq_length=64, learning_rate=0.0001, max_steps=100, max_examples=None, print_every_n_steps=5, lora_rank=8, lora_alpha=32, lora_dropout=0.1)


In [26]:
torch.manual_seed(99)
device = xm.xla_device()
model = model.to(device)


In [27]:
optimizer = torch.optim.Adam(model.parameters(), lr=trainer_config.learning_rate)

In [28]:
train_dataloader, test_dataloader = get_dataset(
    tokenizer=tokenizer,
    batch_size=trainer_config.batch_size,
    seq_length=trainer_config.seq_length,
    max_examples=trainer_config.max_examples,
)
# train_dataloader = pl.MpDeviceLoader(
#     train_dataloader, 
#     device
# )

# test_dataloader = pl.MpDeviceLoader(
#     test_dataloader, 
#     device
# )

Map:   0%|          | 0/43996 [00:00<?, ? examples/s]

Map:   0%|          | 0/7764 [00:00<?, ? examples/s]

In [29]:
batch = next(iter(train_dataloader))

In [25]:
type(labels)

torch.Tensor

In [30]:
optimizer.zero_grad()

input_ids, attention_mask, labels = (
    batch["input_ids"],
    batch["attention_mask"],
    batch["labels"],
)

In [32]:
output = model.forward(
    input_ids=input_ids.to("xla"), 
    attention_mask=attention_mask.to("xla"), 
    labels=labels.to("xla") if labels is not None else None
)

In [33]:
loss = output.loss

In [35]:
print(loss)

tensor(10.9770, device='xla:0', grad_fn=<NllLossBackward0>)


In [None]:
output = model(
    input_ids=input_ids, attention_mask=attention_mask, labels=labels
)
loss = output.loss
# loss.backward()

In [51]:
def train(index):
    global model, tokenizer, trainer_config

    print(trainer_config)
    
    torch.manual_seed(99)
    device = xm.xla_device()
    model = model.to(device)
    
    # Create a mesh for the model partitioning.
    # num_devices = xr.global_runtime_device_count()
    # mesh_shape = (1, num_devices, 1)
    # device_ids = np.array(range(num_devices))
    # mesh = Mesh(device_ids, mesh_shape, ("dp", "fsdp", "mp"))
        
    # Partition the model using SPMD.
    # model_partitioning.partition_model(model=model, mesh=mesh)
    
    # Configure the training loop.
    optimizer = torch.optim.Adam(model.parameters(), lr=trainer_config.learning_rate)
    
    train_dataloader, test_dataloader = get_dataset(
        tokenizer=tokenizer,
        batch_size=trainer_config.batch_size,
        seq_length=trainer_config.seq_length,
        max_examples=trainer_config.max_examples,
    )
    # train_dataloader = pl.MpDeviceLoader(
    #     train_dataloader, 
    #     device
    # )
    
    # test_dataloader = pl.MpDeviceLoader(
    #     test_dataloader, 
    #     device
    # )
    
    should_break = False
    for epoch in range(trainer_config.epochs):
        print(f"Epoch {epoch} train begin..")
        tracker = xm.RateTracker()
        
        model.train()
        for step, batch in enumerate(train_dataloader):
            if trainer_config.max_steps is not None and step > trainer_config.max_steps:
                should_break = True
                break
            
            optimizer.zero_grad()
            
            input_ids, attention_mask, labels = (
                batch["input_ids"],
                batch["attention_mask"],
                batch["labels"],
            )
            input_ids=input_ids.to(device)
            attention_mask=attention_mask.to(device) 
            labels=labels.to(device) if labels is not None else None
            # xs.mark_sharding(input_ids, mesh, (0, 1))
            # xs.mark_sharding(attention_mask, mesh, (0, 1))
            # xs.mark_sharding(labels, mesh, (0, 1))
            
            output = model(
                input_ids=input_ids, attention_mask=attention_mask, labels=labels
            )
            loss = output.loss
            loss.backward()
            
            optimizer.step()
            # xm.mark_step()
            
            if step % trainer_config.print_every_n_steps == 0:
                loss_cpu = loss.item()
                xm.add_step_closure(
                    print_training_update,
                    args=(step, loss_cpu, epoch)
                )
        
        # UNCOMMENT BELOW TO RUN EVAL.
        # model.eval()
        # eval_loss = 0
        # with torch.no_grad():
        #     for step, batch in enumerate(test_dataloader):
        #         input_ids, attention_mask, labels = (
        #             batch["input_ids"],
        #             batch["attention_mask"],
        #             batch["labels"],
        #         )
        #         xs.mark_sharding(input_ids, mesh, (0, 1))
        #         xs.mark_sharding(attention_mask, mesh, (0, 1))
        #         xs.mark_sharding(labels, mesh, (0, 1))
        
        #         output = model(
        #             input_ids=input_ids, attention_mask=attention_mask, labels=labels
        #         )
        #         eval_loss += output.loss.item()
        # avg_eval_loss = eval_loss / len(test_dataloader)
        # xm.add_step_closure(
        #     lambda: print(f"Eval loss: {avg_eval_loss:.4f}"),
        # )
        if should_break:
            break
    result = {'device': xm.get_ordinal(), 'loss': loss.item()}
    return result

In [52]:
@dataclass
class TrainerConfig:    
    epochs: int = 1
    batch_size: int = 32
    seq_length: int = 64
    
    learning_rate: float = 1e-4

    max_steps: int | None = 100
    max_examples: int| None = None
    
    print_every_n_steps: int = 5
    
    lora_rank: int = 8
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    
trainer_config = TrainerConfig()

In [None]:
train(00)

TrainerConfig(epochs=1, batch_size=32, seq_length=64, learning_rate=0.0001, max_steps=100, max_examples=None, print_every_n_steps=5, lora_rank=8, lora_alpha=32, lora_dropout=0.1)


Map:   0%|          | 0/43996 [00:00<?, ? examples/s]

Map:   0%|          | 0/7764 [00:00<?, ? examples/s]

Epoch 0 train begin..
2024-12-01 04:39:20.000726:  13040  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-12-01 04:39:20.000745:  13040  INFO ||NEURON_CC_WRAPPER||: Using a cached neff at /var/tmp/neuron-compile-cache/neuronxcc-2.15.128.0+56dc5a86/MODULE_4793052177575340561+d7517139/model.neff. Exiting with a successfully compiled graph.
2024-12-01 04:39:38.000951:  13040  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-12-01 04:39:38.000956:  13040  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --target=trn1 --framework=XLA /tmp/ubuntu/neuroncc_compile_workdir/1ad304dd-ddc9-412b-a3b0-7cbd8cf2847a/model.MODULE_15522692757083607861+d7517139.hlo_module.pb --output /tmp/ubuntu/neuroncc_compile_workdir/1ad304dd-ddc9-412b-a3b0-7cbd8cf2847a/model.MODULE_15522692757083607861+d7517139.neff --verbose=35
............................................................

In [None]:
# %%timeit
# import time; start_time = time.time()
# try:
#     xmp.spawn(train, args=(), start_method="fork")
# except Exception as e:
#     # Catch the expected error of obtaining results from multiple TPU chips when starting distributed training from a notebook.
#     print()

# end_time = time.time()
# elapsed_time = end_time - start_time
# print(f"Execution time: {elapsed_time:.4f} seconds")

# Export the model to HuggingFace Hub
Uncoment the following cell to push the model to HuggingFace Hub.

In [None]:
HUGGINGFACE_USERNAME = input("Please provide your HUGGINGFACE_USERNAME: ")

model = model.cpu()
merged_model = model.merge_and_unload()

print("Uncomment below code if you want to upload to HF.")
# print("Uploading to HF...")
# merged_model.push_to_hub(
#     f"{HUGGINGFACE_USERNAME}/felafax-llama3-finetuned",  # repo name
#     tokenizer=tokenizer,
#     private=False,
#     create_pr=False,
#     max_shard_size="2GB",
#     token=HUGGINGFACE_TOKEN,
# )