# Lesson 5. Model training

In [ ]:
import warnings

# Suppress every warning category so the notebook output stays uncluttered.
warnings.filterwarnings(action="ignore")

# 1. Load the model to be trained

In [ ]:
import torch
from transformers import AutoModelForCausalLM

# Model that will be trained in this lesson. The "-init" suffix suggests
# freshly initialised weights — NOTE(review): confirm.
model_path = "./models/upstage/TinySolar-308m-4k-init"
load_kwargs = dict(
    device_map="cpu",            # keep all weights on the CPU
    torch_dtype=torch.bfloat16,  # load parameters in bfloat16
    use_cache=False,             # disable the generation KV cache for training
)
pretrained_model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
pretrained_model

# 2. Load the dataset

In [ ]:
import datasets
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Map-style dataset serving pre-tokenized rows from a parquet file.

    Every row must contain an ``input_ids`` column. Each sample is returned as
    ``{"input_ids": ..., "labels": ...}`` where ``labels`` holds the same ids.
    """

    def __init__(self, args, split="train"):
        """Load the parquet data referenced by ``args.dataset_name``.

        Args:
            args: Configuration object exposing ``dataset_name``, which is
                forwarded as ``data_files`` to ``datasets.load_dataset``.
            split: Name of the split to keep from the loaded data.
        """
        self.args = args
        self.dataset = datasets.load_dataset(
            "parquet",
            data_files=args.dataset_name,
            split=split,
        )

    def __len__(self):
        """Return the number of rows in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as int64 ``input_ids`` / ``labels`` tensors.

        The row is read from the dataset only once; ``labels`` is an
        independent copy of ``input_ids``.
        """
        # Read the row a single time: each index into the underlying dataset
        # decodes the row again, so the previous double lookup did twice the work.
        row = self.dataset[idx]
        input_ids = torch.tensor(row["input_ids"], dtype=torch.long)
        # Labels mirror the inputs; any next-token shift is presumably applied
        # inside the causal-LM loss — NOTE(review): confirm. Clone so in-place
        # edits to one tensor (e.g. masking labels) never affect the other.
        labels = input_ids.clone()
        return {"input_ids": input_ids, "labels": labels}

## 3. Configure Training Arguments

In [ ]:
from dataclasses import dataclass, field
import transformers

@dataclass
class CustomArguments(transformers.TrainingArguments):
    """Training configuration for this lesson.

    Adds a few dataset-related options to ``transformers.TrainingArguments``
    and overrides several inherited defaults (e.g. ``max_steps=30``) so a short
    training run works without extra command-line flags.
    """

    # --- Dataset options (new fields) ---
    # Forwarded by CustomDataset as ``data_files`` to ``datasets.load_dataset``.
    dataset_name: str = "./parquet/packaged_pretrain_dataset.parquet"
    # Not read by the visible code — NOTE(review): confirm where it is used.
    num_proc: int = 1
    # Not read by the visible code — NOTE(review): confirm where it is used.
    max_seq_length: int = 32

    # --- Run length, reproducibility and batch size ---
    seed: int = 0
    optim: str = "adamw_torch"
    max_steps: int = 30
    per_device_train_batch_size: int = 2

    # --- Optimizer, schedule and execution settings ---
    learning_rate: float = 5e-5
    weight_decay: float = 0
    warmup_steps: int = 10
    lr_scheduler_type: str = "linear"
    gradient_checkpointing: bool = True
    dataloader_num_workers: int = 2
    bf16: bool = True
    gradient_accumulation_steps: int = 1

    # --- Logging ---
    logging_steps: int = 3
    report_to: str = "none"

In [ ]:
# Build the arguments from CustomArguments' defaults; only --output_dir is
# passed explicitly.
parser = transformers.HfArgumentParser(CustomArguments)
(args,) = parser.parse_args_into_dataclasses(args=["--output_dir", "output"])

train_dataset = CustomDataset(args=args)
first_example = train_dataset[0]
print("Input shape: ", first_example['input_ids'].shape)

## 4. Run the trainer and monitor the loss

In [ ]:
from transformers import TrainerCallback

class LossLoggingCallback(TrainerCallback):
    """Trainer callback that records every ``logs`` dict it is handed.

    Entries accumulate in ``self.logs`` in the order ``on_log`` is called, so
    the history (e.g. the loss) can be inspected after training.
    """

    def __init__(self):
        self.logs = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append ``logs`` to ``self.logs``; calls without logs are ignored."""
        if logs is None:
            return
        self.logs.append(logs)


# NOTE(review): no Trainer is constructed in the visible cells; this instance
# presumably gets passed as ``Trainer(..., callbacks=[loss_logging_callback])``.
loss_logging_callback = LossLoggingCallback()

# 5. Check the performance of an intermediate checkpoint

In [ ]:
from transformers import AutoTokenizer

# Tokenizer used to encode prompts for the checkpoint below. It is loaded from
# the TinySolar-248m-4k directory and presumably shares that model's vocabulary
# — NOTE(review): confirm.
model_name_or_path = "./models/upstage/TinySolar-248m-4k"
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_name_or_path
)

In [ ]:
import torch
from transformers import AutoModelForCausalLM, TextStreamer

# Load the checkpoint stored at ./models/output/checkpoint-10000 in bfloat16,
# letting device_map="auto" decide where the weights are placed.
model_name_or_path = "./models/output/checkpoint-10000"
model2 = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)


In [ ]:
prompt = "I am an engineer. I love"

# Tokenize the prompt and move the tensors to the device model2 lives on.
inputs = tokenizer(prompt, return_tensors="pt").to(model2.device)

# Stream generated text to stdout, omitting the prompt and special tokens.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Sample up to 64 new tokens at temperature 1.0.
generation_kwargs = dict(
    streamer=streamer,
    use_cache=True,
    max_new_tokens=64,
    do_sample=True,
    temperature=1.0,
)
outputs = model2.generate(**inputs, **generation_kwargs)