In [None]:
import os

# PyTorch decides whether unsupported MPS ops may fall back to the CPU when the torch
# library is first loaded, so this must be set BEFORE torch (or transformers/datasets,
# which import it) is imported; setting it in a later cell does not take effect.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
import torch

In [None]:
# Pick the compute device: Apple-silicon MPS first, then CUDA, otherwise fall back to the CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Using device: {device}")

In [None]:

# Fetch the raw (un-tokenised) WikiText-2 corpus from the Hugging Face Hub.
WIKITEXT_CONFIG = 'wikitext-2-raw-v1'
dataset = load_dataset(path='wikitext', name=WIKITEXT_CONFIG)

# Show the available splits and their row counts.
dataset


In [None]:
# Load the tokenizer from the same checkpoint as the model fine-tuned below
# (distilgpt2 shares GPT-2's BPE vocabulary, so token ids are unchanged).
tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
# GPT-2 tokenizers ship without a padding token; reuse EOS so fixed-length batches can be built.
tokenizer.pad_token = tokenizer.eos_token

MAX_LENGTH = 128  # tokens per training example; longer lines are truncated


def tokenize_function(examples, max_length=MAX_LENGTH):
    """Tokenize a batch of raw text lines for causal-LM fine-tuning.

    Args:
        examples: batch dict passed by ``Dataset.map(..., batched=True)``;
            ``examples['text']`` is a list of strings.
        max_length: every sequence is truncated / padded to exactly this many tokens.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``) plus ``labels``:
        a copy of ``input_ids`` where padding positions are replaced by -100 so the
        language-modelling loss ignores them.
    """
    inputs = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=max_length)
    # The pad token IS the EOS token, so padding cannot be recognised by token id;
    # use the attention mask instead (0 marks a padded slot). -100 is the label
    # value Hugging Face causal-LM models skip when computing the loss.
    inputs['labels'] = [
        [token if keep else -100 for token, keep in zip(ids, mask)]
        for ids, mask in zip(inputs['input_ids'], inputs['attention_mask'])
    ]
    return inputs


# WikiText contains blank separator lines. After padding is masked they would carry no
# loss targets at all, and a batch made only of them could produce an undefined (NaN)
# loss, so drop empty / whitespace-only lines before tokenising.
non_empty_datasets = dataset.filter(lambda example: example['text'].strip() != '')

# The raw 'text' column is not a model input, so drop it from the tokenised splits.
tokenised_datasets = non_empty_datasets.map(tokenize_function, batched=True, remove_columns=['text'])

tokenised_datasets

In [None]:
# Allow ops that the MPS backend does not implement to run on the CPU instead of erroring.
# NOTE(review): PyTorch appears to read this variable when the torch library is first
# loaded, and torch was already imported in the first cell, so setting it here likely has
# no effect — it needs to be set before `import torch`. Confirm on the target PyTorch version.
%env PYTORCH_ENABLE_MPS_FALLBACK=1

In [None]:
# Define training arguments.
# `eval_strategy` requires transformers >= 4.41; the older `evaluation_strategy` spelling
# was deprecated and later removed. The deprecated `use_mps_device` flag is intentionally
# not passed: recent transformers versions use an available MPS device automatically, and
# on a non-MPS machine the flag contradicts the CUDA/CPU choice made in the device cell.
training_args = TrainingArguments(
    output_dir='model/',             # checkpoints are written here
    eval_strategy='epoch',           # evaluate on the validation split at the end of each epoch
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=500,                # linear learning-rate warm-up length, in optimizer steps
    weight_decay=0.01,
    logging_dir='model/logs',
)

# Load the pretrained distilgpt2 weights and place them on the device chosen earlier
# (the Trainer also places the model on `training_args.device` itself).
model = AutoModelForCausalLM.from_pretrained('distilgpt2')
model.to(device)

# Initialize Trainer on the tokenised train / validation splits.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenised_datasets['train'],
    eval_dataset=tokenised_datasets['validation'],
)

# Train the model. To continue an interrupted run instead, call
# trainer.train(resume_from_checkpoint=<checkpoint directory under output_dir>).
trainer.train()



In [None]:
# Save the fine-tuned model and its tokenizer into one directory, so the pair can be
# reloaded together later with `from_pretrained(model_output_dir)`.
model_output_dir = 'model/trained_model'

model.save_pretrained(save_directory=model_output_dir)
tokenizer.save_pretrained(save_directory=model_output_dir)