In [34]:
import os
import json
from pathlib import Path
import copy
import warnings
from collections import defaultdict
from collections import OrderedDict
from collections import Counter
from tqdm import tqdm
import datetime

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import ConcatDataset
from torch.utils.data import DistributedSampler

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from transformers import AutoConfig
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers import GenerationConfig

from huggingface_hub import snapshot_download
from huggingface_hub import login
from safetensors.torch import load_file


## Data Import

In [726]:
# Data import

# Data parameters

sft_filename = 'llm_processed_tickets_df.json'
cpt_filename = 'sage_docs.json'
context_filename = 'embedded_context_df.json'
summary_filename = 'llm_summarized_tickets_df.json'
data_folder = Path.cwd().joinpath('Data')

context_embedding_cols = ['Query embedding', 'Answer embedding', 'context_sims']
# NOTE(review): the name 'contex_dtypes' (sic) is kept unchanged in case later cells reference it
contex_dtypes = {'Query' : str, 'Answer' : str,
                 'Query embedding' : list, 'Answer embedding' : list,
                 'context_ids' : list, 'context_sims' : list}

# Import sage documents (CPT corpus). The file is decoded as UTF-8 explicitly so the
# result does not depend on the platform's default locale encoding.

with open(data_folder.joinpath(cpt_filename), encoding = 'utf-8') as f:
    sage_docs_raw = json.load(f)


def _format_sage_doc(doc: dict) -> str:
    """Flatten one sage document into a single text block.

    Layout: info, Keywords, Product, title, content, separated by blank lines.
    Keywords / Product are taken from doc['metadata'] and are omitted (together with
    their separator) when missing or empty.
    """
    metadata = doc['metadata']
    optional_parts = [metadata.get(key, '') for key in ('Keywords', 'Product')]
    return (doc['info'] + '\n\n'
            + ''.join(part + '\n\n' for part in optional_parts if part)
            + doc['title'] + '\n\n'
            + doc['content'])


sage_docs = [_format_sage_doc(doc) for doc in sage_docs_raw]

# Import context dataset; embedding / similarity columns are converted to float32 arrays

context_tickets_df = pd.read_json(data_folder.joinpath(context_filename),
                                  orient = 'index', typ = 'frame',
                                  dtype = contex_dtypes, precise_float = True)

for col in context_embedding_cols:
    context_tickets_df[col] = context_tickets_df[col].apply(lambda x: np.array(x, dtype = np.float32))

# Import SFT dataset: keep the structured problem and summarized solution as PROBLEM / SOLUTION

processed_tickets_df = (
    pd.read_json(data_folder.joinpath(sft_filename),
                 orient = 'index', typ = 'frame',
                 dtype = str, precise_float = True)
    .drop(columns = ['PROBLEM', 'SOLUTION', 'STRUCTUREDSOLUTION'])
    .rename(columns = {'STRUCTUREDPROBLEM' : 'PROBLEM', 'SUMMARIZEDSOLUTION' : 'SOLUTION'})
)

# Import and integrate ticket history summary dataset (indexed by TICKETID for the join)

summarized_tickets_df = (
    pd.read_json(data_folder.joinpath(summary_filename),
                 orient = 'index', typ = 'frame',
                 dtype = str, precise_float = True)
    .drop(columns = ['PROBLEM', 'SOLUTION', 'STRUCTUREDPROBLEM', 'STRUCTUREDSOLUTION'])
    .set_index('TICKETID', drop = True)
)

context_tickets_df = context_tickets_df.join(summarized_tickets_df)

# Keep the answer text before the status history and append the summarized activities
context_tickets_df['Summarized Answer'] = \
    context_tickets_df['Answer'].str.split('\n\nTicket status history:\n\n').str[0] \
        + '\n\nActivities description:\n\n' \
        + context_tickets_df['SUMMARIZEDSOLUTION']

# Sort both datasets

context_tickets_df = context_tickets_df.sort_index()
processed_tickets_df = processed_tickets_df.sort_values('TICKETID', ignore_index = True)

# Display the dataset

processed_tickets_df.head()


Unnamed: 0,TICKETID,PROBLEM,SOLUTION
0,t6UJ9A00EVYR,Ticket metadata:\n\n Ticket ID: t6UJ9A00EVYR\...,[Urgency: 1]\n\n**Issue Summary** \nThe user ...
1,t6UJ9A00EVZA,Ticket metadata:\n\n Ticket ID: t6UJ9A00EVZA\...,[Urgency: 1]\n\n**Issue Summary** \nA daily c...
2,t6UJ9A00EW08,Ticket metadata:\n\n Ticket ID: t6UJ9A00EW08\...,[Urgency: 3]\n\n**Issue Summary** \nA user is...
3,t6UJ9A00EW0A,Ticket metadata:\n\n Ticket ID: t6UJ9A00EW0A\...,[Urgency: 3]\n\n**Issue Summary** \nThe custo...
4,t6UJ9A00EW0K,Ticket metadata:\n\n Ticket ID: t6UJ9A00EW0K\...,[Urgency: 3]\n\n**Issue Summary** \nA user is...


## Model Import

In [743]:
# Define model configuration and import parameters

PEFT = True
FROM_PRETRAINED = False
IS_DUMMY = True
DUMMY_PARAMETERS = {'hidden_size' : 2, 'intermediate_size' : 4, 'head_dim' : 8}
LOAD_8BIT = False
LOAD_FINETUNED = True
PARALLELIZE = False
USE_CACHE = False

USE_GPU = True
# Prefer MPS, then CUDA, then CPU. torch.backends.mps.is_available() is the documented MPS check.
DEVICE = torch.device('mps' if USE_GPU and torch.backends.mps.is_available()
                      else ('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu'))

# NOTE(review): repo_id points at Meta-Llama-3-8B-Instruct while model_path says
# '3.1-8B-Instruct' — confirm which checkpoint is intended.
MODEL_DICT = \
    {'name' : 'Llama-3-8B-Instruct',
     'repo_id' : 'meta-llama/Meta-Llama-3-8B-Instruct',
     'required_files' : ['config.json', 'generation_config.json', 'model.safetensors',
                         'model-00001-of-00004.safetensors', 'model-00002-of-00004.safetensors',
                         'model-00003-of-00004.safetensors', 'model-00004-of-00004.safetensors',
                         'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json', 'model.safetensors.index.json'],
     'model_path' : ['LLaMa', '3.1-8B-Instruct']}

ROPE_SCALING = None
MAX_CONTEXT_LENGTH = 16000
ROPE_THETA = 1000000
MAX_GENERATION_LENGTH = 10000

ADAPTER_DIM = 128
LORA_RANK = 16
LORA_ALPHA = 16

LORA_PARAMS = ['q_proj', 'k_proj', 'v_proj', 'o_proj']
ORIGINAL_TRAINABLE_PARAMS = ['input_layernorm', 'post_attention_layernorm', 'norm', 'lm_head']
TRAINABLE_ADAPTER = True
TRAINABLE_LORA = True

# Data options

VALIDATION = True
DATASET_TYPE = 'SFT'
NUM_RETRIEVED = 5
INCLUDES_GOLD_PROB = 0.7
RELEVANT_PROB_DIST = [0.5, 0.2, 0.1, 0.1, 0.1]
IRRELEVANT_PROB_DIST = [0.5, 0.15, 0.15, 0.1, 0.05, 0.05]

# Model import

LLM_PATH = Path.cwd().joinpath(*MODEL_DICT['model_path'])
LLM_PATH.mkdir(parents = True, exist_ok = True)
FINETUNED_MODEL_PATH = LLM_PATH.joinpath('CPT', f"Checkpoint_{MODEL_DICT['name']}")

# Learning rate hyperparameters

ATTN_BASE_LR = 5e-5 
FFN_BASE_LR = 2e-5
NORM_BASE_LR = 5e-5 * 0.1 
HEAD_BASE_LR = 5e-5 

FFN_GATE_LR = FFN_BASE_LR * 2
FFN_BIAS_LR = FFN_BASE_LR * 1

ATTN_LORA_A_LR = ATTN_BASE_LR
ATTN_LORA_B_LR = ATTN_BASE_LR * 0.5
ATTN_GATE_LR = ATTN_BASE_LR * 2
ATTN_BIAS_LR = ATTN_BASE_LR * 1

# Weight decay hyperparameters

WEIGHT_DECAY = 0.01
BIAS_DECAY = 0
NORM_DECAY = 0
LR_LAYER_DECAY = 0.95

# Batch size parameters

GLOBAL_BATCH_SIZE = 32
BATCH_SIZE = 4

# Scheduling and regularization hyperparameters

EPOCHS = 20
WARMUP = 0.05

LAMBDA_MAGNITUDE_DRIFT = 1e-4
MAGNITUDE_SHIFT_WARMUP = 0.05

# Loop hyper-parameters

VALIDATION_LOGGING_FACTOR = 1
CHECKPOINT_FREQUENCY_STEPS = 100

TOLERANCE = 0
PATIENCE = 2

GRADIENT_CHECKPOINTING = False

# Optional download of config files and parameters.
# LLM_PATH already ends with MODEL_DICT['model_path'] and is created by mkdir above, so the
# check must look for files inside it. Looking for a nested 'LLaMa' folder does not normally
# succeed, which made snapshot_download run on every execution.

LOCAL_CHECK_FILES = ['config.json', 'tokenizer.json', 'model.safetensors.index.json']

if not all(LLM_PATH.joinpath(filename).exists() for filename in LOCAL_CHECK_FILES):
    snapshot_download(repo_id = MODEL_DICT['repo_id'], allow_patterns = MODEL_DICT['required_files'],
                      local_dir = LLM_PATH, token = True)


Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

# Fine-tuning Preparation

In [287]:
# Define LoRA adapter class

class LoRALinear(nn.Linear):
    """Drop-in replacement for ``nn.Linear`` that adds a gated low-rank (LoRA) update.

    The weight used in ``forward`` is::

        W + gate(x) * (alpha / rank) * (A @ B)

    ``W`` and the optional bias are copied from the wrapped layer. The low-rank factors are:

    - ``A`` (``lora_weight_a``), of shape ``(out_features, rank)``, initialised to zeros;
    - ``B`` (``lora_weight_b``), of shape ``(rank, in_features)``, drawn from N(0, 0.003**2).

    ``gate`` is a sigmoid of a linear map of the input's mean feature vector, with the gate
    parameters starting at zero. Because ``A`` starts at zero, a freshly wrapped layer
    reproduces the original layer exactly.
    """

    def __init__(self, original_module: nn.Linear, rank: int = 16, alpha: float = 16):

        # Transfer original attributes

        super().__init__(original_module.in_features, original_module.out_features, bias = original_module.bias is not None,
                             device = original_module.weight.device, dtype = original_module.weight.dtype)

        self.weight = copy.deepcopy(original_module.weight)
        if original_module.bias is not None:
            self.bias = copy.deepcopy(original_module.bias)

        self.weight.requires_grad = original_module.weight.requires_grad
        if original_module.bias is not None:
            self.bias.requires_grad = original_module.bias.requires_grad

        # Reference norm for magnitude_drift_penalty. Stored as a non-persistent buffer:
        # - it moves with .to()/.cuda() together with the weights;
        # - it stays out of state_dict, so checkpoint keys are unchanged;
        # - it is detached, so it never holds an autograd graph.
        self.register_buffer('original_weight_norm', self.weight.detach().norm(p = 'fro'), persistent = False)

        # Add LoRA custom attributes

        self.lora_rank = rank
        self.lora_alpha = alpha
        self.lora_scale = alpha / rank
        self.lora_weight_a = nn.Parameter(torch.zeros(self.out_features, self.lora_rank,
            device = self.weight.device, dtype = self.weight.dtype), requires_grad = True)
        self.lora_weight_b = nn.Parameter(torch.normal(mean = 0, std = 0.003, size = (self.lora_rank, self.in_features),
            device = self.weight.device, dtype = self.weight.dtype), requires_grad = True)
        self.lora_gate = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(self.in_features, 1, bias = True, device = self.weight.device, dtype = self.weight.dtype)),
            ('act_fn', nn.Sigmoid())
        ]))

        # Initialize gate (sigmoid(0) = 0.5 for every input at the start)

        for param in self.lora_gate.parameters():
            nn.init.constant_(param, 0.0)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer with the gated low-rank update added to its weight."""

        # Pool the features over every leading dimension. For a 3-D input this equals the former
        # mean(dim=(0, 1)); flattening first also supports 1-D and 2-D inputs, which used to
        # collapse to a scalar and crash. The gate value is shared by all rows of the input.
        pooled = input.reshape(-1, self.in_features).mean(dim = 0)
        delta_weight = self.lora_gate(pooled) * self.lora_scale * torch.matmul(self.lora_weight_a, self.lora_weight_b)

        return nn.functional.linear(input, self.weight + delta_weight, self.bias)

    def magnitude_drift_penalty(self) -> torch.Tensor:
        """Squared change of the Frobenius norm of ``W + (alpha / rank) * (A @ B)`` vs. the reference norm.

        The input-dependent gate cannot be included here (there is no input), but the LoRA
        scale is. The penalty therefore measures the same scaled update that ``forward`` applies.
        """

        updated_weight = self.weight + self.lora_scale * torch.matmul(self.lora_weight_a, self.lora_weight_b)

        return (updated_weight.norm(p = 'fro') - self.original_weight_norm) ** 2

    @staticmethod
    def patch_attention(model, rank, alpha, params):
        """Replace every matching ``nn.Linear`` inside ``model`` with a ``LoRALinear`` wrapper.

        A layer matches when any entry of ``params`` equals one of the last two components of its
        dotted module name. Layers that are already ``LoRALinear`` are skipped, so repeated calls
        do not stack wrappers.
        """

        # Collect the targets first, then replace them. This avoids mutating the module tree while
        # named_modules() walks it, and avoids rebuilding a name->module dict for every match.
        targets = [
            name for name, module in model.named_modules()
            if isinstance(module, nn.Linear) and not isinstance(module, LoRALinear)
            and any(key in name.split('.')[-2:] for key in params)
        ]

        for name in targets:
            # rpartition yields parent_name == '' for a direct child, and get_submodule('') is model itself
            parent_name, _, attr = name.rpartition('.')
            parent = model.get_submodule(parent_name)
            setattr(parent, attr, LoRALinear(getattr(parent, attr), rank = rank, alpha = alpha))
        

In [288]:
# Define Adapter Class

class AdapterLlamaMLP(nn.Module):
    """Gated MLP (gate/up/down projections copied from ``original_module``) plus a parallel gated adapter.

    ``forward(x)`` returns::

        down_proj(act_fn(gate_proj(x)) * up_proj(x)) + adapter(x) * adapter_gate(x)

    ``adapter`` is ``fc1 -> SiLU -> Dropout(0.1) -> fc2`` with bottleneck width ``adapter_dim``.
    ``adapter_gate`` is a sigmoid of a linear map of ``x``, giving one scalar per vector along
    the last dimension. At initialisation the adapter outputs exactly zero, so a freshly
    patched model reproduces the wrapped MLP.
    """

    def __init__(self, original_module, adapter_dim):

        # Original MLP attributes

        super().__init__()
        self.config = original_module.config
        self.hidden_size = original_module.hidden_size
        self.intermediate_size = original_module.intermediate_size

        self.gate_proj = copy.deepcopy(original_module.gate_proj)
        self.up_proj = copy.deepcopy(original_module.up_proj)
        self.down_proj = copy.deepcopy(original_module.down_proj)
        self.act_fn = original_module.act_fn

        # Added adapter attributes (fc2 now takes the dtype of down_proj, like its device)

        self.adapter_dim = adapter_dim
        self.adapter = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(self.up_proj.in_features, self.adapter_dim, bias = True, 
                              device = self.up_proj.weight.device, dtype = self.up_proj.weight.dtype)),
            ('act_fn', nn.SiLU()),
            ('dropout', nn.Dropout(p = 0.1)),
            ('fc2', nn.Linear(self.adapter_dim, self.down_proj.out_features, bias = True,
                              device = self.down_proj.weight.device, dtype = self.down_proj.weight.dtype))
        ]))
        self.adapter_gate = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(self.hidden_size, 1, bias = True, device = self.up_proj.weight.device, dtype = self.up_proj.weight.dtype)),
            ('act_fn', nn.Sigmoid())
        ]))

        # Initialize new adapter weights

        for param in self.adapter_gate.parameters():
            nn.init.constant_(param, 0.0)

        # Only fc2.weight is random; fc1 and every bias start at zero. With fc1 == 0 the
        # hidden activation SiLU(0) is 0, so the adapter output is exactly 0 at init (a random
        # fc2.bias used to shift the pretrained MLP output). fc1 still receives gradient
        # through the random fc2.weight.
        for name, param in self.adapter.named_parameters():
            if name == 'fc2.weight':
                nn.init.normal_(param, mean = 0.0, std = 0.003)
            else:
                nn.init.constant_(param, 0.0)

    def forward(self, x):

        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        adapt_proj = self.adapter(x) * self.adapter_gate(x)

        return down_proj + adapt_proj

    # Special function for initiating adapter injection

    @staticmethod
    def patch_mlp(model, adapter_dim):
        """Wrap the ``mlp`` of every layer in ``model.model.layers``; already wrapped layers are left as is."""

        for layer in model.model.layers:
            if not isinstance(layer.mlp, AdapterLlamaMLP):
                layer.mlp = AdapterLlamaMLP(layer.mlp, adapter_dim = adapter_dim)
  

In [744]:
# Define a single class for proper model loading

class LLMmanager:
    """Load a causal LM (full size or a shrunken "dummy"), its tokenizer and generation config
    from a local directory. Optionally inject LoRA / adapter modules for parameter-efficient
    fine-tuning, and save fine-tuning checkpoints.

    Args:
        model_path: local directory with the base model files. It is read by
            AutoTokenizer / AutoConfig / GenerationConfig and by ``load_safetensors``.
        device: anything accepted by ``torch.device``.
        peft: inject LoRA and adapter modules after loading.
        load_8bit: forwarded as ``load_in_8bit`` on the ``from_pretrained`` path.
        finetuned_model_path: directory of a checkpoint written by ``save_model_checkpoint``.
        from_pretrained: build the model with ``AutoModelForCausalLM.from_pretrained`` instead of
            ``from_config`` plus ``load_safetensors``. Dummy models always use ``from_config``.
        is_dummy: apply ``dummy_parameters`` to the config and load weights non-strictly.
        dummy_parameters: config attribute overrides for the dummy model.
        use_cache, rope_scaling, rope_theta, max_context_length: optional config overrides.
            ``max_context_length`` sets ``max_position_embeddings`` and the generation ``max_length``.
        adapter_dim, lora_rank, lora_alpha, lora_params, original_trainable_params,
        trainable_adapter, trainable_lora: fine-tuning settings. ``None`` selects the
            defaults stored in ``self.finetuning_config``.
    """

    def __init__(self, model_path, device, peft = True, load_8bit = False, finetuned_model_path = None,
                     from_pretrained = False, is_dummy = False, dummy_parameters = None, use_cache = None,
                         rope_scaling = None, rope_theta = None, max_context_length = None,
                             adapter_dim = None, lora_rank = None, lora_alpha = None, lora_params = None,
                                 original_trainable_params = None, trainable_adapter = None, trainable_lora = None):

        # Loading parameters (paths are coerced to Path because .joinpath is used on them)

        self.model_path = Path(model_path)
        self.finetuned_model_path = None if finetuned_model_path is None else Path(finetuned_model_path)
        self.device = torch.device(device)
        self.load_8bit = load_8bit
        self.from_pretrained = from_pretrained
        self.peft = peft
        self.use_cache = use_cache

        # Dummy model control. The overrides are copied so no mutable dict is shared
        # between instances or with the caller.

        self.is_dummy = is_dummy
        self.dummy_parameters = {} if dummy_parameters is None else dict(dummy_parameters)
        self.strict_loading = not self.is_dummy

        # Context length parameters

        self.rope_scaling = rope_scaling
        self.rope_theta = rope_theta
        self.max_context_length = max_context_length

        # Model files

        self.tokenizer = None
        self.model = None
        self.gen_config = None
        self.checkpoint_report_template = None

        # Internal configs

        self.config = None
        self.finetuning_config = {
            'adapter_dim' : 128 if adapter_dim is None else adapter_dim, 
            'lora_rank' : 16 if lora_rank is None else lora_rank,
            'lora_alpha' : 16 if lora_alpha is None else lora_alpha, 
            'lora_params' : ['q_proj', 'k_proj', 'v_proj', 'o_proj'] if lora_params is None else lora_params,
            'original_trainable_params' :  [] if original_trainable_params is None else original_trainable_params,
            'trainable_adapter' : True if trainable_adapter is None else trainable_adapter,
            'trainable_lora' : True if trainable_lora is None else trainable_lora
        }

        # Checkpoint parameters

        self.checkpoint_report_filename = 'finetuning_report.txt'
        self.finetuning_config_filename = 'finetuning_config.json'

    # Main loading function 

    def load_model(self, load_finetuned: bool = False, parallelize: bool = False) \
                        -> tuple[PreTrainedTokenizer, PreTrainedModel, GenerationConfig]:
        """Load tokenizer, generation config and model, and optionally prepare the model for PEFT.

        Args:
            load_finetuned: when ``peft`` is enabled and ``finetuned_model_path`` is set, restore
                the fine-tuning config and all weights from that checkpoint. The base safetensors
                are not loaded in that case.
            parallelize: wrap the model in ``nn.DataParallel`` when the device is ``cuda`` and
                more than one GPU is visible.

        Returns:
            ``(tokenizer, model, gen_config)``; the same objects are also stored on ``self``.
        """

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, local_files_only = True)

        # Padding reuses the EOS token
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        self.gen_config = GenerationConfig.from_pretrained(self.model_path)

        if isinstance(self.max_context_length, int):
            self.gen_config.max_length = self.max_context_length

        template_path = self.model_path.joinpath('checkpoint_report_template.txt')
        if template_path.is_file():
            with open(template_path, 'r', encoding = 'utf-8') as file:
                self.checkpoint_report_template = file.read()

        # The fine-tuned checkpoint is only read by finetune_prepare, i.e. when PEFT is on.
        # In every other case the base weights must be loaded here, otherwise a model built
        # with from_config stays randomly initialised.
        restore_finetuned = load_finetuned and self.peft and self.finetuned_model_path is not None

        if self.from_pretrained and not self.is_dummy:

            auto_dispatch = self.device == torch.device('cuda')

            # Same config overrides as the from_config path below. A None override is no longer
            # passed, so e.g. rope_scaling=None keeps the checkpoint's own rope_scaling.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                device_map = 'auto' if auto_dispatch else None,
                torch_dtype = torch.bfloat16 if auto_dispatch else torch.float32,
                load_in_8bit = self.load_8bit,
                local_files_only = True,
                use_safetensors = True,
                **self._config_overrides()
            )

            # With device_map='auto' the weights are already placed. Moving a dispatched or
            # 8-bit quantized model with .to() is not supported, so only move it otherwise.
            if not auto_dispatch and not self.load_8bit:
                self.model = self.model.to(self.device)

        else:

            # Pull out the config and apply dummy / context-length / cache overrides

            self.config = AutoConfig.from_pretrained(self.model_path, local_files_only = True, use_safetensors = True)

            if self.is_dummy:
                for name, value in self.dummy_parameters.items():
                    setattr(self.config, name, value)

            for name, value in self._config_overrides().items():
                setattr(self.config, name, value)

            # Instantiate the model and load parameters

            self.model = AutoModelForCausalLM.from_config(self.config).to(self.device)

            if not restore_finetuned:
                self.load_safetensors()

        # Fine-tuning preparation

        if self.peft:

            if restore_finetuned:
                config_path = self.finetuned_model_path.joinpath(self.finetuning_config_filename)
                with open(config_path, 'r', encoding = 'utf-8') as file:
                    self.finetuning_config.update(json.load(file))

            self.finetune_prepare(load_weights = load_finetuned)

        # Multi-GPU data parallelism wraps the stored model itself
        if parallelize and torch.cuda.device_count() > 1 and self.device == torch.device('cuda'):
            self.model = nn.DataParallel(self.model.to(self.device))

        return self.tokenizer, self.model, self.gen_config

    def _config_overrides(self) -> dict:
        """Return the config attributes to override: rope_scaling (dict), max_position_embeddings (int),
        rope_theta and use_cache (when not None)."""

        overrides = {}
        if isinstance(self.rope_scaling, dict):
            overrides['rope_scaling'] = self.rope_scaling
        if isinstance(self.max_context_length, int):
            overrides['max_position_embeddings'] = self.max_context_length
        if self.rope_theta is not None:
            overrides['rope_theta'] = self.rope_theta
        if self.use_cache is not None:
            overrides['use_cache'] = self.use_cache
        return overrides

    # Safetensor loading function 

    def load_safetensors(self, model_path = None, model = None):
        """Load safetensors weights from ``model_path`` into ``model``.

        The defaults are ``self.model_path`` and ``self.model``. Sharded checkpoints are resolved
        through ``model.safetensors.index.json``. Without that index, a single ``model.safetensors``
        is expected, which is the layout ``save_pretrained`` uses when it does not shard. Only
        tensors whose key exists in the model's state dict with an identical shape are kept.
        Loading is strict unless the model is a dummy.
        """

        model_path = Path(self.model_path if model_path is None else model_path)

        if model is None:
            model = self.model

        index_path = model_path.joinpath('model.safetensors.index.json')
        if index_path.is_file():
            with open(index_path, 'r', encoding = 'utf-8') as file:
                index = json.load(file)
            # Unique shard file names, in order of first appearance in the weight map
            shard_names = list(dict.fromkeys(index['weight_map'].values()))
        else:
            shard_names = ['model.safetensors']

        # The model's state dict does not change while shards are read, so compute it once
        model_dict = model.state_dict()
        full_state_dict = {}

        for shard_name in shard_names:
            shard_dict = load_file(model_path.joinpath(shard_name))
            full_state_dict.update({
                key : value for key, value in shard_dict.items()
                if key in model_dict and model_dict[key].shape == value.shape
            })

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            model.load_state_dict(full_state_dict, strict = self.strict_loading)

    # Function for fine-tuning preparation

    def finetune_prepare(self, model = None, load_weights = False):
        """Prepare ``model`` (default ``self.model``) for parameter-efficient fine-tuning.

        Steps:
        1. Freeze every parameter.
        2. Inject LoRA and adapter modules.
        3. Optionally load the fine-tuned checkpoint into the same model.
        4. Unfreeze the configured parameters. A parameter becomes trainable when any
           configured name is a substring of its name.
        """

        if model is None:
            model = self.model

        # Freeze original parameters

        for param in model.parameters():
            param.requires_grad = False

        # Inject LoRA modules and adapters

        LoRALinear.patch_attention(model, self.finetuning_config['lora_rank'], self.finetuning_config['lora_alpha'],
                                       self.finetuning_config['lora_params'])
        AdapterLlamaMLP.patch_mlp(model, self.finetuning_config['adapter_dim'])

        if load_weights and self.finetuned_model_path is not None:

            # Load into the model being prepared (not unconditionally into self.model)
            self.load_safetensors(model_path = self.finetuned_model_path, model = model)

            # The LoRA reference norms were taken at injection time, possibly from weights the
            # checkpoint has just overwritten. Recompute them from the loaded weights.
            for module in model.modules():
                if isinstance(module, LoRALinear):
                    module.original_weight_norm = module.weight.detach().norm(p = 'fro')

        # Unfreeze selected modules

        trainable_params = list(self.finetuning_config['original_trainable_params'])
        if self.finetuning_config['trainable_adapter']:
            trainable_params.extend(self.finetuning_config.get('adapter_param_names',
                [name for name, _ in model.named_parameters() if 'adapter' in name]))
        if self.finetuning_config['trainable_lora']:
            trainable_params.extend(self.finetuning_config.get('lora_param_names',
                [name for name, _ in model.named_parameters() if 'lora' in name]))

        for name, param in model.named_parameters():
            param.requires_grad = any(pname in name for pname in trainable_params)

    # Utility for saving checkpoints

    def save_model_checkpoint(self, output_dir, log_param_list: list = None,
                                  model: PreTrainedModel = None, tokenizer: PreTrainedTokenizer = None):
        """Save model, tokenizer, the fine-tuning config and an optional text report to ``output_dir``.

        The saved config records the adapter / LoRA hyper-parameters, the names of the LoRA-wrapped
        projections (``lora_params``) and the full names of all adapter and LoRA parameters.
        ``log_param_list`` fills the positional ``{}`` slots of the checkpoint report template
        (when one was loaded).
        """

        if model is None:
            model = self.model

        if tokenizer is None:
            tokenizer = self.tokenizer

        # DataParallel / DDP wrappers expose the real model as .module. Unwrapping gives
        # save_pretrained and un-prefixed parameter names.
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            model = model.module

        output_dir = Path(output_dir)

        # Collect model parameters

        adapter_dim = None
        lora_rank = None
        lora_alpha = None
        adapter_param_names = []
        lora_param_names = []
        lora_module_names = set()

        found_adapter = False
        found_lora = False

        for mname, module in model.named_modules():

            if isinstance(module, AdapterLlamaMLP):

                if not found_adapter:
                    adapter_dim = module.adapter_dim
                    found_adapter = True

                adapter_param_names.extend(
                    f'{mname}.{pname}' for pname, _ in module.named_parameters()
                    if {'adapter', 'adapter_gate'} & set(pname.split('.')))

            if isinstance(module, LoRALinear):

                if not found_lora:
                    lora_rank = module.lora_rank
                    lora_alpha = module.lora_alpha
                    found_lora = True

                module_lora_names = [f'{mname}.{pname}' for pname, _ in module.named_parameters()
                                     if {'lora_weight_a', 'lora_weight_b', 'lora_gate'} & set(pname.split('.'))]
                lora_param_names.extend(module_lora_names)

                # The attribute name (e.g. 'q_proj') is what patch_attention matches on reload
                if module_lora_names:
                    lora_module_names.add(mname.rsplit('.', 1)[-1])

        finetuning_config = {
            'adapter_dim' : adapter_dim, 
            'lora_rank' : lora_rank,
            'lora_alpha' : lora_alpha, 
            'lora_params' : sorted(lora_module_names),
            'adapter_param_names' : adapter_param_names, 
            'lora_param_names' : lora_param_names
        }

        # Save files

        os.makedirs(output_dir, exist_ok = True)
        model.save_pretrained(output_dir, safe_serialization = True) 
        tokenizer.save_pretrained(output_dir)

        with open(output_dir.joinpath(self.finetuning_config_filename), 'w', encoding = 'utf-8') as f:
            json.dump(finetuning_config, f, indent = 2)

        if log_param_list is not None and self.checkpoint_report_template is not None:
            txt_log = self.checkpoint_report_template.format(*log_param_list)
            with open(output_dir.joinpath(self.checkpoint_report_filename), 'w', encoding = 'utf-8') as f:
                f.write(txt_log)


In [745]:
# Import and prepare the LLM for fine-tuning.
# The constructor options are grouped by concern and passed as keyword arguments.

llm_manager_kwargs = dict(
    # checkpoint / loading mode
    finetuned_model_path = FINETUNED_MODEL_PATH,
    peft = PEFT, from_pretrained = FROM_PRETRAINED, load_8bit = LOAD_8BIT, use_cache = USE_CACHE,
    # dummy-model control
    is_dummy = IS_DUMMY, dummy_parameters = DUMMY_PARAMETERS,
    # context length
    rope_scaling = ROPE_SCALING, rope_theta = ROPE_THETA, max_context_length = MAX_CONTEXT_LENGTH,
    # PEFT modules and trainable subsets
    adapter_dim = ADAPTER_DIM, lora_rank = LORA_RANK, lora_alpha = LORA_ALPHA, lora_params = LORA_PARAMS,
    original_trainable_params = ORIGINAL_TRAINABLE_PARAMS,
    trainable_adapter = TRAINABLE_ADAPTER, trainable_lora = TRAINABLE_LORA,
)

llm_utility = LLMmanager(LLM_PATH, DEVICE, **llm_manager_kwargs)

tokenizer, llm_model, gen_config = llm_utility.load_model(
    load_finetuned = LOAD_FINETUNED,
    parallelize = PARALLELIZE,
)

llm_model


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): LoRALinear(
            in_features=2, out_features=256, bias=False
            (lora_gate): Sequential(
              (fc1): Linear(in_features=2, out_features=1, bias=True)
              (act_fn): Sigmoid()
            )
          )
          (k_proj): LoRALinear(
            in_features=2, out_features=64, bias=False
            (lora_gate): Sequential(
              (fc1): Linear(in_features=2, out_features=1, bias=True)
              (act_fn): Sigmoid()
            )
          )
          (v_proj): LoRALinear(
            in_features=2, out_features=64, bias=False
            (lora_gate): Sequential(
              (fc1): Linear(in_features=2, out_features=1, bias=True)
              (act_fn): Sigmoid()
            )
          )
          (o_proj): LoRALinear(
            in_feature

## Data Loader

In [858]:
# Load the system prompt text and install the generation template as the tokenizer's chat
# template. Both files are decoded as UTF-8 explicitly, so prompt text reads the same on
# every platform regardless of the locale's default encoding.

with open(LLM_PATH.joinpath('system_prompt.txt'), encoding = 'utf-8') as sp, \
    open(LLM_PATH.joinpath('generation_template.txt'), encoding = 'utf-8') as gt:

    system_prompt = sp.read()
    generation_template = gt.read()

tokenizer.chat_template = generation_template

# Declare a special dataset class for CPT (Continued Pre-training) and SFT datasets

class FineTuneDataset(Dataset):
    """Dataset of tokenized samples for continued pre-training (CPT) or supervised fine-tuning (SFT).

    - ``'CPT'``: each item is a document string wrapped in the tokenizer's BOS/EOS tokens; every
      token is a training target.
    - ``'SFT'``: each item is a ticket rendered through the tokenizer's chat template with a system
      prompt, the user query, sampled retrieval context and the assistant answer (prefixed by a
      sampled comment on how relevant the context is); label positions before the answer are masked.

    ``__getitem__`` returns ``(input_ids, labels, instruct_text, target_text, train)`` and
    ``batch_collate`` pads a list of such items into a batch (use it as ``DataLoader(collate_fn=...)``).
    """
    
    def __init__(self, tokenizer, dataset, dataset_type, target_col = None, system_prompt = None, query_col = None, id_col = None,
                     context_dataset = None, query_con_col = None, answer_con_col = None, query_match_col = None, sim_match_col = None,
                         num_retrieved = None, includes_gold_prob = None, relevant_prob_dist = None,
                             irrelevant_prob_dist = None, train = True, max_context_length = 16000, max_generation_length = 1000):
        """Store the tokenizer and settings and pre-compute the sample list.

        Args:
            tokenizer: Hugging Face-style tokenizer (callable, ``decode``, ``apply_chat_template``,
                ``bos_token``/``eos_token``/``pad_token_id``/``eos_token_id``).
            dataset: CPT - indexable collection of document strings; SFT - DataFrame of tickets.
            dataset_type: ``'CPT'`` or ``'SFT'``.
            target_col, query_col: SFT columns holding the answer text and the user query.
            id_col: SFT column joined against ``context_dataset``'s index.
            system_prompt: SFT system message.
            context_dataset: SFT retrieval corpus (DataFrame). Its ``query_match_col`` /
                ``sim_match_col`` columns hold each ticket's ranked match ids and similarities, and
                ``query_con_col`` + ``answer_con_col`` form a context document.
            num_retrieved: context documents per sample (default 5).
            includes_gold_prob: probability that the context is built from the first ranked match
                onwards (default 0.7).
            relevant_prob_dist, irrelevant_prob_dist: sampling weights for how many ranked matches
                are kept in the gold / non-gold case (default uniform).
            train: flag returned unchanged with every item.
            max_context_length: token budget for the full SFT prompt plus answer.
            max_generation_length: stored; not used inside this class.

        Raises:
            ValueError: unknown ``dataset_type``, or SFT without the required columns/context.
        """

        if dataset_type not in ('CPT', 'SFT'):
            raise ValueError(f"dataset_type must be 'CPT' or 'SFT', got {dataset_type!r}")

        if dataset_type == 'SFT' and (query_col is None or context_dataset is None 
                                      or query_match_col is None or sim_match_col is None
                                      or query_con_col is None or answer_con_col is None):
            raise ValueError('SFT dataset must have valid query, context and embeddings')

        self.tokenizer = tokenizer
        self.train = train
        self.dataset_type = dataset_type
        self.max_context_length = max_context_length
        self.max_generation_length = max_generation_length

        # Context evaluation phrases for sampling

        self.relevant_comments = [
            'The retrieved cases are relevant and provide actionable insights applicable to this query.',
            'Relevant prior cases were found and directly informed the suggested actions.',
            'The context from retrieved tickets was useful and guided the resolution strategy.',
            'Prior cases align closely with the current issue and offer practical guidance.',
            'Useful precedents were identified, supporting the recommended troubleshooting steps.'
        ]

        self.semi_relevant_comments = [
            'The retrieved cases are partially relevant. Some suggestions may apply but require adaptation.',
            'Prior tickets contain useful hints but do not fully address the current situation.',
            'Context provides limited insights; additional investigation may be necessary.',
            'The retrieved cases offer partial guidance, but not all information is directly applicable.',
            'Some elements from prior cases can inform the solution, but verification is needed.'
        ]

        self.non_relevant_comments = [
            'The retrieved cases do not appear to be relevant to this query.',
            'No prior cases provide actionable insights for the current issue.',
            'Context from retrieved tickets is not applicable; alternative strategies should be considered.',
            'The prior cases are irrelevant. Guidance must be derived from general troubleshooting principles.',
            'Retrieved tickets do not contribute to resolving this query.'
        ]

        if self.dataset_type == 'CPT':
            self.samples = dataset

        if self.dataset_type == 'SFT':
            
            # Attach each ticket's ranked match ids and similarities by joining id_col against
            # the context dataset's index
            dataset = dataset.merge(context_dataset[[query_match_col, sim_match_col]],
                                        how = 'left', left_on = id_col, right_index = True)
            
            self.system_prompt = system_prompt
            self.system_prompt_length = self.tokenizer(self.system_prompt, add_special_tokens = False,
                                                           return_tensors = 'pt')['attention_mask'].size()[1]
            
            self.context_dataset = context_dataset
            self.context_len = len(context_dataset)
            # Stored but not referenced in this class; the trimming in __getitem__ uses a floor of 400
            self.min_context_len = 300
            self.query_con_col = query_con_col
            self.answer_con_col = answer_con_col
            self.num_retrieved = num_retrieved if num_retrieved is not None else 5
            self.includes_gold_prob = includes_gold_prob if includes_gold_prob is not None else 0.7
            self.relevant_prob_dist = relevant_prob_dist if relevant_prob_dist is not None else [1 / self.num_retrieved] * self.num_retrieved
            self.irrelevant_prob_dist = irrelevant_prob_dist if irrelevant_prob_dist is not None else \
                                            [1 / (self.num_retrieved + 1)] * (self.num_retrieved + 1)
            # Each sample: (query, (match_ids, match_sims), target)
            self.samples = list((query, context, target) for query, context, target in \
                                    zip(dataset[query_col].tolist(),
                                        zip(dataset[query_match_col].tolist(), dataset[sim_match_col].tolist()),
                                            dataset[target_col].tolist()))

    def __len__(self):
        """Number of samples."""
        
        return len(self.samples)

    @staticmethod
    def _remove_suffix(text, suffix):
        """Return ``text`` without one trailing ``suffix``; unlike ``str.rstrip`` it matches the exact string."""

        return text[:-len(suffix)] if suffix and text.endswith(suffix) else text

    def __getitem__(self, idx):
        """Return ``(input_ids, labels, instruct_text, target_text, train)`` for sample ``idx``.

        CPT: ``labels`` equals ``input_ids``; ``instruct_text`` / ``target_text`` are the first and
        second halves of the whitespace-split words (lists of strings).
        SFT: context is sampled by ``retrieve_context``. When the prompt would exceed
        ``max_context_length``, each context document is shortened in proportion to its length
        (floor 400 tokens). The conversation is then rendered with the chat template, and label
        positions before ``mask_start - 1`` are set to -100.
        """

        # Get data item depending on CPT or SFT training

        if self.dataset_type == 'CPT':
            text = self.tokenizer.bos_token + self.samples[idx] + self.tokenizer.eos_token
            whitespace_sep_text = text.split()
            instruct_text = whitespace_sep_text[:int(len(whitespace_sep_text) / 2)]
            target_text = whitespace_sep_text[int(len(whitespace_sep_text) / 2):]

        if self.dataset_type == 'SFT':

            # Retrieve context documents (most similar first) and adjust context length based on global context length

            retrieved_context, comment = self.retrieve_context(idx)
            retrieved_context = [(rank, doc) for rank, doc in sorted(retrieved_context, key = lambda x: x[0], reverse = True)]

            # Drop the literal end-of-turn token from the target. str.rstrip('<|eot_id|>') would instead
            # strip any trailing run of the characters < | e o t _ i d >, eating word endings
            # (e.g. 'restarted' -> 'restar').
            answer = comment + '\n\n' + self._remove_suffix(self.samples[idx][2], '<|eot_id|>')
            
            retrieved_context_tokenized = [
                (rank, self.tokenizer(doc, add_special_tokens = False, return_tensors = 'pt')['input_ids'].squeeze(0))
                    for rank, doc in retrieved_context
            ]

            retrieved_context_lengths = torch.tensor([ids.size(0) for rank, ids in retrieved_context_tokenized])

            query_length = self.tokenizer(self.samples[idx][0], add_special_tokens = False,
                                              return_tensors = 'pt')['attention_mask'].size()[1]
            answer_length = self.tokenizer(answer, add_special_tokens = False,
                                              return_tensors = 'pt')['attention_mask'].size()[1]

            excess_length = (self.max_context_length 
                                 - self.system_prompt_length 
                                 - query_length
                                 - answer_length
                                 - retrieved_context_lengths.sum().item())
            
            # Number of tokens over budget (0 when everything fits)
            excess_length = -excess_length if excess_length < 0 else 0

            if excess_length > 0:
                
                # Remove the excess from each document in proportion to its share of the context
                allowed_context_lengths = torch.clamp(retrieved_context_lengths - 
                                                        torch.ceil((retrieved_context_lengths / retrieved_context_lengths.sum()) 
                                                                       * excess_length).to(torch.long), min = 400)
                
            else:
                
                allowed_context_lengths = retrieved_context_lengths

            # Apply chat templates and distinguish between instructions and generation parts

            retrieved_documents = [
                {'content' : self.tokenizer.decode(ids[:allowed_context_lengths[i].item()], skip_special_tokens = True)}
                     for i, (rank, ids) in enumerate(retrieved_context_tokenized)
            ]
            
            text = self.tokenizer.apply_chat_template(
                conversation = [
                    {'role' : 'system', 'content' : self.system_prompt},
                    {'role' : 'user', 'content' : self.samples[idx][0]},
                    {'role' : 'assistant', 'content' : answer}
                ],
                documents = retrieved_documents,
                add_generation_prompt = False,
                tokenize = False
            ).strip()

            # Same conversation with an empty assistant turn: its length marks where the answer begins
            instruct_text = self.tokenizer.apply_chat_template(
                conversation = [
                    {'role' : 'system', 'content' : self.system_prompt},
                    {'role' : 'user', 'content' : self.samples[idx][0]},
                    {'role' : 'assistant', 'content' : ''}
                ],
                documents = retrieved_documents,
                add_generation_prompt = False,
                tokenize = False
            ).strip()

            # Starts 10 characters before the end of instruct_text
            # NOTE(review): presumably to keep the tail of the assistant header - confirm against the template
            target_text = text[len(instruct_text)-10:]

        # Tokenize text/template
            
        encoding = self.tokenizer(text, return_tensors = 'pt', add_special_tokens = False)
        input_ids = encoding['input_ids'].squeeze(0)

        # Get position to start loss calculation

        if self.dataset_type == 'CPT':

            labels = input_ids.clone()

        if self.dataset_type == 'SFT':
        
            mask_start = self.tokenizer(instruct_text, return_tensors = 'pt', add_special_tokens = False)['input_ids'].squeeze(0).size()[0]

            # Mask an input up to expected generated text
            
            labels = input_ids.clone()
            labels[:mask_start-1] = -100

        return input_ids, labels, instruct_text, target_text, self.train
        
    def retrieve_context(self, idx):
        """Sample ``num_retrieved`` ``(similarity, document)`` pairs for sample ``idx`` plus a comment.

        With probability ``includes_gold_prob`` the first k ranked matches are kept, with
        k in [1, num_retrieved] drawn from ``relevant_prob_dist``. Otherwise k in [0, num_retrieved]
        matches are taken starting from the second one, drawn from ``irrelevant_prob_dist``.
        Remaining slots are random rows of ``context_dataset`` with similarity at most 0.05.
        A document is the query and answer context columns joined by a blank line. The comment
        is relevant (gold case), semi-relevant (max similarity >= 0.6) or non-relevant.
        """

        includes_gold = np.random.rand() <= self.includes_gold_prob

        if includes_gold:

            random_num = self.num_retrieved - np.random.choice(list(range(self.num_retrieved)), p = self.relevant_prob_dist)
            retrieval_idx = [self.samples[idx][1][0][i] if i < random_num
                                 else self.context_dataset.index[np.random.randint(0, self.context_len)]
                                     for i in range(self.num_retrieved)]
            match_sim = [round(self.samples[idx][1][1][i], 3) if i < random_num 
                             else round(np.random.rand() * 0.05, 3)
                                 for i in range(self.num_retrieved)]
            
        else:

            random_num = self.num_retrieved - np.random.choice(list(range(self.num_retrieved + 1)), p = self.irrelevant_prob_dist)
            retrieval_idx = [self.samples[idx][1][0][1:][i] if i < random_num
                                 else self.context_dataset.index[np.random.randint(0, self.context_len)]
                                     for i in range(self.num_retrieved)]
            match_sim = [round(self.samples[idx][1][1][1:][i], 3) if i < random_num
                             else round(np.random.rand() * 0.05, 3)
                                 for i in range(self.num_retrieved)]
            
        content = list(zip(match_sim,
                           list(map(lambda i: self.context_dataset.loc[i][self.query_con_col] + '\n\n' \
                                        + self.context_dataset.loc[i][self.answer_con_col],
                                    retrieval_idx))))
        
        if includes_gold:
            comment = np.random.choice(self.relevant_comments)
        elif np.array(match_sim).max() >= 0.6:
            comment = np.random.choice(self.semi_relevant_comments)
        else:
            comment = np.random.choice(self.non_relevant_comments)
            
        return content, comment

    def batch_collate(self, batch):
        """Pad a list of ``__getitem__`` items into a batch.

        Returns ``(input_ids, attention_mask, labels, instruct_text, target_text, is_train)``:
        ids padded with the tokenizer's pad id (its eos id when no pad token is set), labels
        padded with -100, and a 0/1 attention mask built from each sequence's true length.
        The last three entries are tuples with one element per item.
        """
        
        input_ids, labels, instruct_text, target_text, is_train = zip(*batch)

        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id

        lengths = torch.tensor([ids.size(0) for ids in input_ids])
        input_ids = nn.utils.rnn.pad_sequence(input_ids, batch_first = True, padding_value = pad_id)
        labels = nn.utils.rnn.pad_sequence(labels, batch_first = True, padding_value = -100)

        # Mask by true length rather than `input_ids != pad_id`: when the pad id is also a real
        # token inside the text (e.g. pad set to an eos/eot id), the comparison would hide it.
        positions = torch.arange(input_ids.size(1)).unsqueeze(0)
        attention_mask = (positions < lengths.unsqueeze(1)).to(torch.long)
        
        return input_ids, attention_mask, labels, instruct_text, target_text, is_train


In [859]:
# Instantiate a dataset

# Fraction of tickets held out for validation, and the seed that fixes which tickets they are.
# The seed matters beyond reproducibility: when this file runs as a script, mp.spawn (start
# method 'spawn') re-imports the main module in every worker, so an unseeded draw would give
# each rank a different train/validation split.
VALIDATION_FRACTION = 0.1
VALIDATION_SEED = 42


def _build_finetune_dataset(tickets_df, is_train):
    """Build a FineTuneDataset over ``tickets_df`` with the notebook's shared SFT/retrieval settings."""

    return FineTuneDataset(tokenizer, tickets_df, DATASET_TYPE,
                           target_col = 'SOLUTION', system_prompt = system_prompt, query_col = 'PROBLEM',
                           context_dataset = context_tickets_df, query_con_col = 'Query', id_col = 'TICKETID',
                           answer_con_col = 'Summarized Answer', query_match_col = 'context_ids', sim_match_col = 'context_sims',
                           num_retrieved = NUM_RETRIEVED, includes_gold_prob = INCLUDES_GOLD_PROB,
                           relevant_prob_dist = RELEVANT_PROB_DIST, irrelevant_prob_dist = IRRELEVANT_PROB_DIST,
                           max_context_length = MAX_CONTEXT_LENGTH, max_generation_length = MAX_GENERATION_LENGTH, train = is_train)


if VALIDATION:

    split_rng = np.random.default_rng(VALIDATION_SEED)
    validation_idx = split_rng.choice(processed_tickets_df['TICKETID'].to_numpy(),
                                      int(len(processed_tickets_df.index) * VALIDATION_FRACTION), replace = False)
    validation_mask = processed_tickets_df['TICKETID'].isin(validation_idx)

    # Persist the split so evaluation can reuse exactly the same tickets
    processed_tickets_df[~validation_mask].to_json(data_folder.joinpath('llm_processed_tickets_df_train.json'),
                                                       orient = 'index', double_precision = 15, index = True)
    processed_tickets_df[validation_mask].to_json(data_folder.joinpath('llm_processed_tickets_df_val.json'),
                                                       orient = 'index', double_precision = 15, index = True)

    dataset_train = _build_finetune_dataset(processed_tickets_df[~validation_mask], is_train = True)
    dataset_val = _build_finetune_dataset(processed_tickets_df[validation_mask], is_train = False)
    
else:

    dataset_train = _build_finetune_dataset(processed_tickets_df, is_train = True)
    

## Fine-tuning

In [52]:
# Construction of parameter-specific learning rates and weight decays

def create_parameter_groups(model):
    """Return AdamW parameter groups with per-component learning rates and layer-wise LR decay.

    One group ``{'params': p, 'weight_decay': ..., 'lr': ...}`` is produced per matched parameter,
    in ``model.named_modules()`` order:

    - modules under ``model.layers.<i>``:
        * ``AdapterLlamaMLP``: ``adapter`` weights/biases (``FFN_BASE_LR`` / ``FFN_BIAS_LR``, layer-decayed)
          and ``adapter_gate`` weights/biases (``FFN_GATE_LR`` / ``FFN_BIAS_LR``, not decayed);
        * ``LoRALinear``: ``lora_weight_a`` / ``lora_weight_b`` (``ATTN_LORA_A_LR`` / ``ATTN_LORA_B_LR``,
          layer-decayed) and ``lora_gate`` weight/bias (``ATTN_GATE_LR`` / ``ATTN_BIAS_LR``);
        * modules of the same class as ``model.model.norm`` (``NORM_BASE_LR``, layer-decayed);
    - norm modules outside the layer stack (``NORM_BASE_LR``) and ``lm_head`` (``HEAD_BASE_LR``).

    Layer decay multiplies a base rate by ``LR_LAYER_DECAY ** (num_layers - (i + 1))``, so the
    last layer keeps the base rate (and, assuming ``LR_LAYER_DECAY < 1``, earlier layers get
    geometrically smaller rates). Parameters are collected whether or not they require gradients.
    """

    param_groups = []
    total_layers = len(model.model.layers)
    norm_type = model.model.norm.__class__

    def add_group(param, decay, lr):
        param_groups.append({'params' : param, 'weight_decay': decay, 'lr': lr})

    for module_name, module in model.named_modules():

        module_parts = module_name.split('.')
        outside_layers = 'layers' not in module_parts

        if len(module_parts) >= 3 and not outside_layers:

            # Number of layers above this one; 0 for the last layer
            layers_above = total_layers - (int(module_parts[2]) + 1)

            if isinstance(module, AdapterLlamaMLP):
                for param_name, param in module.named_parameters():
                    tokens = param_name.split('.')
                    has_weight, has_bias = 'weight' in tokens, 'bias' in tokens
                    is_adapter, is_gate = 'adapter' in tokens, 'adapter_gate' in tokens
                    if has_weight and is_adapter:
                        add_group(param, WEIGHT_DECAY, FFN_BASE_LR * (LR_LAYER_DECAY ** layers_above))
                    elif has_weight and is_gate:
                        add_group(param, WEIGHT_DECAY, FFN_GATE_LR)
                    elif has_bias and is_adapter:
                        add_group(param, BIAS_DECAY, FFN_BIAS_LR * (LR_LAYER_DECAY ** layers_above))
                    elif has_bias and is_gate:
                        add_group(param, BIAS_DECAY, FFN_BIAS_LR)

            if isinstance(module, LoRALinear):
                for param_name, param in module.named_parameters():
                    tokens = param_name.split('.')
                    is_gate = 'lora_gate' in tokens
                    if 'lora_weight_a' in tokens:
                        add_group(param, WEIGHT_DECAY, ATTN_LORA_A_LR * (LR_LAYER_DECAY ** layers_above))
                    elif 'lora_weight_b' in tokens:
                        add_group(param, WEIGHT_DECAY, ATTN_LORA_B_LR * (LR_LAYER_DECAY ** layers_above))
                    elif 'weight' in tokens and is_gate:
                        add_group(param, WEIGHT_DECAY, ATTN_GATE_LR)
                    elif 'bias' in tokens and is_gate:
                        add_group(param, BIAS_DECAY, ATTN_BIAS_LR)

            if isinstance(module, norm_type):
                for _, param in module.named_parameters():
                    add_group(param, NORM_DECAY, NORM_BASE_LR * (LR_LAYER_DECAY ** layers_above))

        if outside_layers and isinstance(module, norm_type):
            for _, param in module.named_parameters():
                add_group(param, NORM_DECAY, NORM_BASE_LR)

        if outside_layers and 'lm_head' in module_parts:
            for _, param in module.named_parameters():
                add_group(param, WEIGHT_DECAY, HEAD_BASE_LR)

    return param_groups

# Learning rate scheduler instantiation

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles = 0.5, min_lambda_lr = 0.1):
    """Return a ``LambdaLR`` with linear warm-up followed by a floored cosine decay.

    While ``step < num_warmup_steps`` the LR multiplier rises linearly from 0 towards 1.
    Afterwards it is ``0.5 * (1 + cos(2 * pi * num_cycles * progress))``, where ``progress`` is
    the fraction of post-warm-up steps elapsed, and it never drops below ``min_lambda_lr``.
    With the default ``num_cycles = 0.5`` this is a half cosine from 1 down to the floor.
    """

    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):

        if current_step >= num_warmup_steps:
            progress = float(current_step - num_warmup_steps) / decay_span
            cosine = torch.cos(torch.tensor(num_cycles * torch.pi * 2.0 * progress))
            return max(min_lambda_lr, 0.5 * (1.0 + cosine).item())

        return float(current_step) / warmup_span

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Define loss function

def llm_finetuning_loss_fn(logits, labels, lora_modules, lambda_magnitude_drift,
                               vocab_size, ignore_index = -100, label_smoothing = 0):
    """Cross-entropy over token predictions plus a weighted LoRA magnitude-drift penalty.

    Args:
        logits: prediction scores, flattened to ``(-1, vocab_size)``.
        labels: target ids, flattened to ``(-1,)``; positions equal to ``ignore_index`` are skipped.
        lora_modules: modules exposing ``magnitude_drift_penalty()`` (returning a tensor).
        lambda_magnitude_drift: weight of the mean penalty; when 0 the penalty is not computed.
        vocab_size: size of the last logits dimension.
        ignore_index: label value excluded from the loss (default -100).
        label_smoothing: passed to ``nn.CrossEntropyLoss``.

    Returns:
        Scalar tensor: mean cross-entropy + ``lambda_magnitude_drift`` * mean penalty.
    """
    
    # Honour the ignore_index argument (it was previously hard-coded to -100)
    cross_entropy = nn.CrossEntropyLoss(ignore_index = ignore_index, reduction = 'mean', label_smoothing = label_smoothing) \
                        (logits.reshape(-1, vocab_size), labels.reshape(-1))

    lora_modules = list(lora_modules)

    # Nothing to regularise: torch.stack([]) would raise, and a zero weight makes the term vanish
    if not lora_modules or lambda_magnitude_drift == 0:
        return cross_entropy

    magnitude_regularization = torch.stack([module.magnitude_drift_penalty() for module in lora_modules]).mean()
    
    return cross_entropy + lambda_magnitude_drift * magnitude_regularization

# Memory allocation tracking

def check_memory(device, size, label = '', reporting_threshold_1 = 0.2, reporting_threshold_2 = 0.7):
    """Report CUDA memory usage on ``device`` and dump diagnostics when usage is high.

    Usage is measured as reserved / total device memory.
      * ``>= reporting_threshold_1``: print allocated, reserved and peak GB together with ``size``
        (the shape of the tensor being processed).
      * ``>= reporting_threshold_2``: additionally write ``snapshot.json``, ``memory_summary.txt`` and
        ``memory_stats.txt`` to the current working directory, run Python garbage collection and
        release cached CUDA blocks.

    Args:
        device: torch device whose ``index`` selects the GPU.
        size: value shown in the report (e.g. ``tensor.size()``).
        label: prefix for the report line.
    """

    import gc  # local import so the function does not depend on a notebook-level `import gc`

    gib = 1024 ** 3
    available = torch.cuda.get_device_properties(device.index).total_memory / gib
    allocated = torch.cuda.memory_allocated(device) / gib
    reserved = torch.cuda.memory_reserved(device) / gib
    max_alloc = torch.cuda.max_memory_allocated(device) / gib

    reserved_fraction = reserved / available

    if reserved_fraction >= reporting_threshold_1:
        print(f"[{label}][rank{device}] allocated={allocated:.2f} GB, reserved={reserved:.2f} GB, peak={max_alloc:.2f} GB, tensor={size}")

    if reserved_fraction >= reporting_threshold_2:

        snapshot = torch.cuda.memory_snapshot()
        with open('snapshot.json', 'w') as sn, open('memory_summary.txt', 'w') as msu, open('memory_stats.txt', 'w') as mst:
            json.dump(snapshot, sn, indent = 2)
            msu.write(torch.cuda.memory_summary(device = device, abbreviated = False))
            json.dump(torch.cuda.memory_stats(device), mst, indent = 2)

        # Collect unreachable Python objects first so their tensors return to the caching
        # allocator; only then can empty_cache() release those blocks.
        gc.collect()
        torch.cuda.empty_cache()


In [54]:
# Fine-Tuning Loop

def _save_training_checkpoint(llm_utility, model_module, tokenizer, checkpoint_source, training_steps, epoch,
                              train_losses, val_losses, train_perplex, val_perplex, epoch_train_loss):
    """Build the run-log row for the current training state and save a checkpoint via ``llm_utility``.

    The row layout (timestamp, dataset type, step/epoch counters, loss histories, source tag and
    hyper-parameter constants) is shared by every checkpoint kind; only ``checkpoint_source`` varies.
    """

    log_list = \
        [str(datetime.datetime.now()), dataset_train.dataset_type, training_steps, epoch + 1,
         str(train_losses), str(val_losses), str(train_perplex), str(val_perplex), checkpoint_source, str(epoch_train_loss),
         ATTN_BASE_LR, FFN_BASE_LR, NORM_BASE_LR, HEAD_BASE_LR, WEIGHT_DECAY, BIAS_DECAY, NORM_DECAY,
         LR_LAYER_DECAY, WARMUP, EPOCHS, GLOBAL_BATCH_SIZE, LAMBDA_MAGNITUDE_DRIFT, MAGNITUDE_SHIFT_WARMUP]
    llm_utility.save_model_checkpoint(FINETUNED_MODEL_PATH, log_param_list = log_list,
                                      model = model_module, tokenizer = tokenizer)


def train(rank, world_size):
    """Run the fine-tuning loop in one process (one process per GPU when distributed).

    Args:
        rank: local process index, also used as the CUDA device index (``mp.spawn`` supplies it).
        world_size: number of processes in the process group.

    Builds the model via ``LLMmanager``, wraps it in DDP when ``world_size > 1`` on CUDA, and
    trains with gradient accumulation towards ``GLOBAL_BATCH_SIZE``. Every
    ``VALIDATION_LOGGING_FACTOR`` epochs it optionally validates, with patience-based early
    stopping. Checkpoints are saved from global rank 0. Datasets and hyper-parameters are read
    from notebook globals.
    """

    import gc  # local import so the function does not depend on a notebook-level `import gc`

    # Initialize process group (if multi-GPU)

    # SLURM_PROCID identifies the launching task (0 outside SLURM) and each spawned process adds
    # its local rank. Reading SLURM_PROCID alone raised KeyError outside SLURM and gave every
    # process started by mp.spawn the same rank.
    # NOTE(review): world_size comes from the local GPU count, so this assumes a single node - confirm for multi-node runs.
    global_rank = int(os.environ.get('SLURM_PROCID', 0)) * world_size + rank

    parallelized = False
    device = DEVICE

    if device == torch.device('cuda'):
        
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        torch.backends.cuda.enable_math_sdp(False)
    
    if world_size > 1 and device == torch.device('cuda'):

        torch.cuda.set_device(rank)
        device = torch.device(f'cuda:{rank}')
        
        dist.init_process_group('nccl', rank = global_rank, world_size = world_size, device_id = device)
        parallelized = True

    # Initiate dynamically updated parameters

    best_score = np.inf
    patience_epochs = 0
    training_steps = 0
    steps = 0
    
    train_losses = []
    val_losses = []
    train_perplex = []
    val_perplex = []

    # Micro-batches per optimizer step so that BATCH_SIZE * world_size * accumulation >= GLOBAL_BATCH_SIZE
    gradient_accumulation = \
        int(np.ceil(GLOBAL_BATCH_SIZE / (BATCH_SIZE * world_size)))

    num_training_steps = int(np.ceil((len(dataset_train) / GLOBAL_BATCH_SIZE) * EPOCHS))
    num_warmup_steps = int(np.ceil(num_training_steps * WARMUP))
        
    # Construct the model

    llm_utility = \
        LLMmanager(LLM_PATH, device, finetuned_model_path = FINETUNED_MODEL_PATH,
                   peft = PEFT, from_pretrained = FROM_PRETRAINED, load_8bit = LOAD_8BIT,
                   is_dummy = IS_DUMMY, dummy_parameters = DUMMY_PARAMETERS,
                   rope_scaling = ROPE_SCALING, rope_theta = ROPE_THETA, max_context_length = MAX_CONTEXT_LENGTH,
                   adapter_dim = ADAPTER_DIM, lora_rank = LORA_RANK, lora_alpha = LORA_ALPHA, lora_params = LORA_PARAMS,
                   original_trainable_params = ORIGINAL_TRAINABLE_PARAMS,
                   trainable_adapter = TRAINABLE_ADAPTER, trainable_lora = TRAINABLE_LORA,
                   use_cache = USE_CACHE)

    tokenizer, model, _ = llm_utility.load_model(load_finetuned = LOAD_FINETUNED)

    # Broadcast parameters from a source process
    
    if parallelized:
        
        dist.barrier()
        
        for param in model.parameters():
            dist.broadcast(param.data, src = 0)

        dist.barrier()

    # Set up dataloaders

    if parallelized:
        sampler_train = DistributedSampler(dataset_train, num_replicas = world_size, rank = global_rank, shuffle = True)
    else:
        sampler_train = None

    dataloader_train = DataLoader(dataset_train, batch_size = BATCH_SIZE, sampler = sampler_train, 
                                      shuffle = (sampler_train is None), collate_fn = dataset_train.batch_collate)

    if VALIDATION:

        if parallelized:
            sampler_val = DistributedSampler(dataset_val, num_replicas = world_size, rank = global_rank, shuffle = False)
        else:
            sampler_val = None
            
        dataloader_val = DataLoader(dataset_val, batch_size = BATCH_SIZE, sampler = sampler_val, 
                                        shuffle = (sampler_val is None), collate_fn = dataset_val.batch_collate)

    # Memory optimizations

    if device.type == 'cuda':
        torch.cuda.empty_cache()
        gc.collect()

    if GRADIENT_CHECKPOINTING:
        model.gradient_checkpointing_enable()
        if hasattr(model, 'enable_input_require_grads'):
            model.enable_input_require_grads()

    # Wrap the model into DDP class and define optimizer and schedule

    if parallelized:
        model = DDP(model, device_ids = [rank], output_device = rank, gradient_as_bucket_view = True)
        model_module = model.module
    else:
        model_module = model

    optimizer = optim.AdamW(create_parameter_groups(model_module))

    scheduler = get_cosine_schedule_with_warmup(
                    optimizer, 
                    num_warmup_steps = num_warmup_steps, 
                    num_training_steps = num_training_steps
    )

    lora_modules = [module for name, module in model_module.named_modules() if isinstance(module, LoRALinear)]

    # Training loop

    for epoch in range(EPOCHS):
    
        # Training

        if sampler_train is not None:
            sampler_train.set_epoch(epoch)
    
        epoch_train_loss = []
        model.train()

        accumulated_loss = torch.zeros(1, device = device) 
        
        for input_ids, attention_mask, labels, _, _, _ in tqdm(dataloader_train):
            
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)
    
            # Next-token prediction: logits at position t are scored against the label at t + 1
            logits = model(input_ids = input_ids, attention_mask = attention_mask).logits
            shifted_logits = logits[:, :-1, :].contiguous()
            shifted_labels = labels[:, 1:].contiguous()
    
            # Ramp the magnitude-drift weight linearly over the first MAGNITUDE_SHIFT_WARMUP fraction of steps
            magnitude_schedule = min((training_steps + 1) / (MAGNITUDE_SHIFT_WARMUP * num_training_steps), 1)
            
            loss = llm_finetuning_loss_fn(shifted_logits, shifted_labels, lora_modules, magnitude_schedule * LAMBDA_MAGNITUDE_DRIFT,
                                             model_module.model.embed_tokens.num_embeddings, ignore_index = -100, label_smoothing = 0.1) \
                                                 / gradient_accumulation

            # Print memory usage statistics
            
            if device.type == 'cuda':
                check_memory(device, input_ids.size(), label = f'forward step {steps}')

            # Loss accumulation
            
            loss.backward()

            with torch.no_grad():
                accumulated_loss += loss.detach()

            # Optimization steps
    
            steps += 1
            if steps % gradient_accumulation == 0:
                
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none = True)

                if device.type == 'cuda':
                    torch.cuda.reset_peak_memory_stats()

                with torch.no_grad():
                    if parallelized:
                        dist.all_reduce(accumulated_loss, op = dist.ReduceOp.SUM)
                        loss_avg = (accumulated_loss / dist.get_world_size()).item()
                    else:
                        loss_avg = accumulated_loss.item()
                
                epoch_train_loss.append(loss_avg)
                accumulated_loss = torch.zeros(1, device = device) 
                training_steps += 1

                # Intermediate checkpointing: evaluated only right after an optimizer step, so each
                # multiple of CHECKPOINT_FREQUENCY_STEPS is saved once. (Checked per micro-batch,
                # it re-saved repeatedly, including before the first optimizer step.)

                if training_steps % CHECKPOINT_FREQUENCY_STEPS == 0 and global_rank == 0:
                    _save_training_checkpoint(llm_utility, model_module, tokenizer, 'Ongoing intermediate checkpointing',
                                              training_steps, epoch, train_losses, val_losses, train_perplex, val_perplex,
                                              epoch_train_loss)
    
        train_losses.append(torch.tensor(epoch_train_loss).mean().item())
        train_perplex.append(torch.exp(torch.tensor(epoch_train_loss).mean()).item())
        if global_rank == 0:
            print(f'Epoch {epoch + 1} train loss: {torch.tensor(epoch_train_loss).mean().item():.4f}')
    
        # Validation
    
        if (epoch + 1) % VALIDATION_LOGGING_FACTOR == 0 and VALIDATION:

            if sampler_val is not None:
                sampler_val.set_epoch(epoch)
                        
            epoch_val_loss = []
            model.eval()
    
            with torch.no_grad():
            
                for input_ids, attention_mask, labels, _, _, _ in dataloader_val:
                    
                    input_ids = input_ids.to(device)
                    attention_mask = attention_mask.to(device)
                    labels = labels.to(device)
            
                    logits = model(input_ids = input_ids, attention_mask = attention_mask).logits
                    shifted_logits = logits[:, :-1, :].contiguous()
                    shifted_labels = labels[:, 1:].contiguous()
                    
                    loss = llm_finetuning_loss_fn(shifted_logits, shifted_labels, lora_modules, 0,
                                                    model_module.model.embed_tokens.num_embeddings,
                                                        ignore_index = -100, label_smoothing = 0.1)

                    loss_tensor = loss.detach()
                    if parallelized:
                        dist.all_reduce(loss_tensor, op = dist.ReduceOp.SUM)
                        loss_avg = (loss_tensor / dist.get_world_size()).item()
                    else:
                        loss_avg = loss_tensor.item()
                    epoch_val_loss.append(loss_avg)

                val_losses.append(torch.tensor(epoch_val_loss).mean().item())
                val_perplex.append(torch.exp(torch.tensor(epoch_val_loss).mean()).item())
                if global_rank == 0:
                    print(f'Epoch {epoch + 1} validation loss: {torch.tensor(epoch_val_loss).mean().item():.4f}')
    
        else:
            
            val_losses.append(np.nan)
            val_perplex.append(np.nan)
    
        # Early stopping (decided on rank 0)

        stop_training = False
    
        if VALIDATION and global_rank == 0:

            # Post-validation checkpoint
    
            if not np.isnan(val_losses[-1]) and val_losses[-1] < best_score * (1 + TOLERANCE):
                best_score = val_losses[-1]
                _save_training_checkpoint(llm_utility, model_module, tokenizer, 'Post validation improvement checkpointing',
                                          training_steps, epoch, train_losses, val_losses, train_perplex, val_perplex,
                                          epoch_train_loss)
                patience_epochs = 0
                
            elif not np.isnan(val_losses[-1]) and val_losses[-1] >= best_score * (1 + TOLERANCE):
                patience_epochs += 1
                if patience_epochs >= PATIENCE:
                    print('Early stopping triggered!')
                    stop_training = True

        # Final checkpoint
        
        elif (epoch + 1) == EPOCHS and global_rank == 0:
            _save_training_checkpoint(llm_utility, model_module, tokenizer, 'Final training checkpoint',
                                      training_steps, epoch, train_losses, val_losses, train_perplex, val_perplex,
                                      epoch_train_loss)

        # Share rank 0's decision with every rank: breaking on rank 0 alone left the other
        # ranks blocked in the next collective operation.
        if parallelized:
            stop_flag = torch.tensor([int(stop_training)], device = device)
            dist.broadcast(stop_flag, src = 0)
            stop_training = bool(stop_flag.item())

        if stop_training:
            break

    # Clean up processes
    
    if parallelized:
        dist.destroy_process_group()
        
def main():
    """Launch ``train``: one spawned process per visible GPU when more than one is available and
    ``DEVICE`` is CUDA, otherwise a single in-process run with ``rank=0, world_size=1``."""

    gpu_count = torch.cuda.device_count()
    use_multi_gpu = gpu_count > 1 and DEVICE == torch.device('cuda')

    if not use_multi_gpu:
        # Single process training
        train(rank = 0, world_size = 1)
        return

    # Spawn processes, one per GPU; mp.spawn passes the process index as train's first argument
    mp.spawn(train, args = (gpu_count, ), nprocs = gpu_count, join = True)

if __name__ == "__main__":
    main()
