In [6]:
import torch
from torch.distributed.fsdp import MixedPrecision, FullyShardedDataParallel as FSDP, FullStateDictConfig, StateDictType
from torch.distributed.fsdp.api import CPUOffload, ShardingStrategy
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
import lightning as L
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.load import _lazy_load as lazy_load
from transformers import HfArgumentParser
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer

from genc.trainer.load_model import load_model
from genc.args import DataArguments, ModelArguments, TrainingArguments, ValidationArgument
from genc.trainer.trainer_utils import (
    choose_logger,
    get_default_supported_precision,
    get_wrapping_policy,
)

# Both the training config and the checkpoint to export live under the same
# training-run directory; edit `run_dir` to export a different run.
run_dir = "output/checkpoints/7b-esft_msmarco"
config_file = f"{run_dir}/config.yaml"
checkpoint_path = f"{run_dir}/checkpoints_1/step_50.ckpt"

# Re-create the exact argument dataclasses used for training from the saved
# YAML config, in the order the four dataclasses are passed to the parser.
(
    data_args,
    model_args,
    training_args,
    validation_args,
) = HfArgumentParser(
    (DataArguments, ModelArguments, TrainingArguments, ValidationArgument)
).parse_yaml_file(yaml_file=config_file)
parser = None  # NOTE(review): original kept the parser object; see below

# Rebuild the model/tokenizer exactly as they were configured for training,
# so the saved state dict keys line up with the module structure.
load_kwargs = dict(
    # Base weights and embedding-head options.
    model_weights_name_or_path=model_args.model_name_or_path,
    use_bidirectional=model_args.use_bidirectional,
    normalized=model_args.normalized,
    pooling_method=model_args.pooling_method,
    loss_gen_type=model_args.loss_gen_type,
    temperature=model_args.temperature,
    quantization=model_args.quantization,
    # Adapter (LoRA) options; target modules are forced to "all" here rather
    # than read from model_args.
    use_lora=model_args.use_lora,
    train_adapter_name=model_args.train_adapter_name,
    lora_weights_name_or_path=model_args.lora_weights_name_or_path,
    lora_target_modules=["all"],
    lora_r=model_args.lora_r,
    lora_alpha=model_args.lora_alpha,
    lora_dropout=model_args.lora_dropout,
    # Runtime / precision options for a single-process (rank 0) load.
    inference=False,
    low_memory=training_args.low_memory,
    torch_dtype=torch.bfloat16,
    compute_dtype=torch.bfloat16,
    precision=training_args.precision,
    rank=0,
    local_rank=0,
    gradient_checkpointing=training_args.gradient_checkpointing,
    attn_implementation=model_args.attn_implementation,
)
model, tokenizer = load_model(**load_kwargs)

# Restore the fine-tuned weights from the training checkpoint. Reuse the
# `checkpoint_path` defined with the config instead of repeating the literal,
# so the exported step only has to be changed in one place.
# NOTE(review): torch.load unpickles arbitrary objects; this assumes the file
# is trusted output of our own training run.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
# load_state_dict copies the tensors into the model, so drop the CPU copy of
# the checkpoint now to lower peak memory during merge and save.
del checkpoint

# Fold the adapter weights into the base model so the export is a plain,
# adapter-free model (presumably a PEFT LoRA wrapper, per use_lora above).
model = model.merge_and_unload()

# Write the merged model and its tokenizer side by side.
output_dir = "checkpoint/7b-esft_msmarco-50"
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)




Loading checkpoint shards: 100%|██████████| 3/3 [00:05<00:00,  1.76s/it]
You provided 'all' as target modules, we will use all the model to which LoRA can be applied.


Rank 0: Model created: 0.000 GiB
trainable params: 37,748,736 || all params: 7,279,480,832 || trainable%: 0.518563574397499


('checkpoint/7b-esft_msmarco-50/tokenizer_config.json',
 'checkpoint/7b-esft_msmarco-50/special_tokens_map.json',
 'checkpoint/7b-esft_msmarco-50/tokenizer.model',
 'checkpoint/7b-esft_msmarco-50/added_tokens.json',
 'checkpoint/7b-esft_msmarco-50/tokenizer.json')