In [2]:
# import necessary packages
import sys
import torch 
import numpy as np
from accelerate import Accelerator, notebook_launcher
from accelerate.utils import tqdm
from peft import get_peft_model, LoraConfig, TaskType
from functools import partial
from importlib import reload
from transformers import pipeline as pipe
from transformers import (
                          DataCollatorWithPadding,
                          get_scheduler)
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from IPython.display import clear_output

sys.path.append('../')

# custom modules
import utils.preprocessing as pp

# Instantiate Model and Dataset

In [3]:
# options
model_path = "meta-llama/Meta-Llama-3-8B"
dataset_path = "allenai/peS2o"

# Training

In [4]:
def train_model(model_path, dataset_path):

    # for distributed training
    accelerator = Accelerator()

    # load dataset
    raw_dataset = load_dataset(dataset_path, "v2", streaming=True, trust_remote_code=True)

    ## MODEL LOADING
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
    )

    # load tokenizer and model
    pipeline = pipe('text-generation', 
                        model=model_path,
                        model_kwargs={'torch_dtype': torch.bfloat16},
                        device_map = accelerator.device 
                        )

    pipeline.model = get_peft_model(pipeline.model, peft_config)
    pipeline.tokenizer.pad_token = pipeline.tokenizer.eos_token
    pipeline.tokenizer.pad_token_id = pipeline.tokenizer.eos_token_id
    pipeline.model.generation_config.pad_token_id = pipeline.tokenizer.eos_token_id


    ## PREPROCESSING
    # add special tokens to tokenizer
    pipeline.tokenizer.pad_token = pipeline.tokenizer.eos_token
    pipeline.model.resize_token_embeddings(len(pipeline.tokenizer))

    tokenize_fn = partial(pp.tokenize_data, 
                        type = "nextchar",
                        pipeline_name = pipeline,
                        max_length = 100)

    tokenized_dataset = raw_dataset.map(tokenize_fn,
                                        batched=True,
                                        remove_columns=raw_dataset['train'].column_names,)
    tokenized_dataset.with_format("torch")

    # instantiate data collator
    data_collator = DataCollatorWithPadding(tokenizer=pipeline.tokenizer)

    train_dataloader = DataLoader(tokenized_dataset['train'],
                                batch_size=8, 
                                collate_fn=data_collator,
                                num_workers=20)

    val_dataloader = DataLoader(tokenized_dataset['validation'],
                            batch_size=8,
                            collate_fn=data_collator,
                            num_workers=2)
    
    ## TRAINING

    # options
    num_batches = 10
    num_epochs = 1
    checkpoint_path = '../checkpoints/checkpoint_8b_{0}epochs.pt'
    log_path = '../logs/log.csv'

    # init optimizer
    optimizer = AdamW(pipeline.model.parameters(), lr=1e-5)

    # init scheduler
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=1000,
        num_training_steps=num_epochs * num_batches,
    )

    pipeline.model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        pipeline.model, optimizer, train_dataloader, val_dataloader, lr_scheduler)

    # init parameters
    best_val_loss = np.inf

    with open(log_path, 'w') as f: 
        f.write(f'epoch,iter_num,train_loss,val_loss\n')

    # loop
    for epoch in range(num_epochs):
        
        clear_output(wait=False)

        running_train_loss = 0.0
        running_val_loss = 0.0

        accelerator.print("=====================")
        accelerator.print(f"Epoch {epoch + 1}")
        accelerator.print("=====================")

        # loop through train data
        accelerator.print("Training...")
        with tqdm(range(num_batches), disable=not accelerator.is_local_main_process) as pbar:
            for i, (train_batch, val_batch) in enumerate(zip(train_dataloader, val_dataloader)):
                
                ## training
                # set model to train mode
                pipeline.model.train()

                # grab batch and map to device
                train_batch = {k: v.to(accelerator.device) for k, v in train_batch.items()}

                # forward pass
                outputs = pipeline.model(train_batch['input_ids'], 
                                        labels=train_batch['input_ids'],
                                        attention_mask=train_batch['attention_mask'])
                train_loss = outputs.loss

                running_train_loss += train_loss.item()

                # backward pass
                # train_loss.backward()
                accelerator.backward(train_loss)

                # clip gradients
                torch.nn.utils.clip_grad_norm_(pipeline.model.parameters(), 1.0)

                # update optimizer, scheduler
                optimizer.step()
                lr_scheduler.step()

                # zero gradients
                optimizer.zero_grad()
                
                ## validation
                # set model to eval mode
                pipeline.model.eval()
                # loop through val data
                val_batch = {k: v.to(accelerator.device) for k, v in val_batch.items()}
                with torch.no_grad():
                    outputs = pipeline.model(val_batch['input_ids'], 
                                            labels=val_batch['input_ids'],
                                            attention_mask=val_batch['attention_mask'])
                    val_loss = outputs.loss
                    running_val_loss += val_loss.item()
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                
                accelerator.print(f"Train Batch Loss: {train_loss:.4f} | Val Batch Loss: {val_loss:.4f} | Best Val. Loss: {best_val_loss:.4f}\r", end="")

                pbar.update(1)
                
                # write to log
                with open(log_path, 'a') as f: 
                    f.write(f'{epoch},{i},{train_loss},{val_loss}\n')
                
                if i == num_batches:
                    accelerator.print(f"Reached {num_batches} batches; starting next epoch...")
                    
                    # break out of batching loop
                    break
        
        train_loss = running_train_loss / num_batches
        val_loss = running_val_loss / num_batches
        train_loss = running_train_loss / num_batches
        accelerator.print(f"Avg. Train Loss: {train_loss:.4f}, Avg. Val Loss: {val_loss:.4f}")

    accelerator.print(f"Saving model checkpoint to {checkpoint_path.format(epoch)}")
    # save model checkpoint
    checkpoint = {'model': pipeline.model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch,
                    'best_val_loss': best_val_loss,
                }
    accelerator.save(checkpoint, checkpoint_path.format(epoch))

    accelerator.print("Training Complete!")

    # save model to hub after training
    pipeline.model.push_to_hub(f"Semantic-Scholar-{model_path.split('/')[-1]}")
    pipeline.tokenizer.push_to_hub(f"Semantic-Scholar-{model_path.split('/')[-1]}")

In [5]:
# accelerate notebook launcher to run
args = (model_path, dataset_path)
notebook_launcher(train_model, args = args, num_processes=2)

Epoch 1
Training...


  0%|          | 0/10 [00:00<?, ?it/s]

W1010 15:22:31.582478 140610585851712 torch/multiprocessing/spawn.py:146] Terminating process 393007 via signal SIGTERM
W1010 15:23:01.614235 140610585851712 torch/multiprocessing/spawn.py:154] Unable to shutdown process 393007 via SIGTERM , forcefully exiting via SIGKILL
E1010 15:23:01.620146 140610585851712 torch/distributed/elastic/multiprocessing/api.py:702] failed (exitcode: 1) local_rank: 1 (pid: 393009) of fn: train_model (start_method: fork)
E1010 15:23:01.620146 140610585851712 torch/distributed/elastic/multiprocessing/api.py:702] Traceback (most recent call last):
E1010 15:23:01.620146 140610585851712 torch/distributed/elastic/multiprocessing/api.py:702]   File "/mnt/DGX01/Personal/krusepi/.venv/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 659, in _poll
E1010 15:23:01.620146 140610585851712 torch/distributed/elastic/multiprocessing/api.py:702]     self._pc.join(-1)
E1010 15:23:01.620146 140610585851712 torch/distributed/elastic/multiproc

ChildFailedError: 
============================================================
train_model FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-10-10_15:22:31
  host      : bsd-a100
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 393009)
  error_file: /tmp/torchelastic_5ol151nh/none_dn7ruawi/attempt_0/1/error.json
  traceback : Traceback (most recent call last):
    File "/mnt/DGX01/Personal/krusepi/.venv/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 348, in wrapper
      return f(*args, **kwargs)
    File "/tmp/ipykernel_373195/2832777836.py", line 109, in train_model
      outputs = pipeline.model(train_batch['input_ids'],
    File "/mnt/DGX01/Personal/krusepi/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/mnt/DGX01/Personal/krusepi/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
      return forward_call(*args, **kwargs)
    File "/mnt/DGX01/Personal/krusepi/.venv/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1636, in forward
      else self._run_ddp_forward(*inputs, **kwargs)
    File "/mnt/DGX01/Personal/krusepi/.venv/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1454, in _run_ddp_forward
      return self.module(*inputs, **kwargs)  # type: ignore[index]
    File "/mnt/DGX01/Personal/krusepi/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/mnt/DGX01/Personal/krusepi/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
      return forward_call(*args, **kwargs)
    File "/mnt/DGX01/Personal/krusepi/.venv/lib/python3.9/site-packages/peft/peft_model.py", line 1577, in forward
      return self.base_model(
    File "/mnt/DGX01/Personal/krusepi/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
    File "/mnt/DGX01/Personal/krusepi/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
      return forward_call(*args, **kwargs)
    File "/mnt/DGX01/Personal/krusepi/.venv/lib/python3.9/site-packages/peft/tuners/tuners_utils.py", line 188, in forward
      return self.model.forward(*args, **kwargs)
    File "/mnt/DGX01/Personal/krusepi/.venv/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1161, in forward
      logits = logits.float()
  torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 392.00 MiB. GPU 1 has a total capacity of 39.50 GiB of which 317.81 MiB is free. Process 305467 has 17.94 GiB memory in use. Including non-PyTorch memory, this process has 21.23 GiB memory in use. Of the allocated memory 19.82 GiB is allocated by PyTorch, and 180.85 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
  
============================================================

In [None]:
# run a test prediction
outputs = pipe(
    text,
    max_new_tokens=1024,
    eos_token_id=terminators,
    no_repeat_ngram_size=3,       
    do_sample=True, 
    top_k=100, 
    top_p=0.9,
    temperature=0.6
)
print(outputs[0][0]['generated_text'])