# Small GPT

### 1. Load Model

In [1]:
from transformers import OpenAIGPTTokenizer, OpenAIGPTLMHeadModel

# Load pre-trained tokenizer and model
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')

  from .autonotebook import tqdm as notebook_tqdm
ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.


In [2]:
print(model)

OpenAIGPTLMHeadModel(
  (transformer): OpenAIGPTModel(
    (tokens_embed): Embedding(40478, 768)
    (positions_embed): Embedding(512, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x Block(
        (attn): Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
  (lm_head): Linear(in_features=768, out_features=40478, bias=False)
)


In [3]:
# Calculate the number of parameters
total_params = sum(p.numel() for p in model.parameters())

# Print the number of parameters
print(f"Total number of parameters: {total_params}")

Total number of parameters: 116534784


### 2. Load Fine-tuning Dataset

In [4]:
from datasets import load_dataset

# Load a summarization dataset (CNN/DailyMail)
dataset = load_dataset('cnn_dailymail', '3.0.0')

In [5]:
# Set the `eos_token` as the `pad_token`
# tokenizer.pad_token = tokenizer.eos_token  # or use `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.resize_token_embeddings(len(tokenizer))

# Tokenization of dataset
def tokenize_data(example):
    inputs = tokenizer(
        example['article'],
        padding='max_length',  # Use padding here
        max_length=512,
        truncation=True,
    )
    labels = tokenizer(
        example['highlights'],
        padding='max_length',  # Use padding here
        max_length=512,
        truncation=True,
    )
    # print(f"Input Length: {len(inputs['input_ids'])}, Label Length: {len(labels['input_ids'])}")
    inputs['labels'] = labels['input_ids']
    return inputs


In [6]:
# Tokenize dataset
train_data = dataset['train'].select(range(1000)).map(tokenize_data, batched=True)
val_data = dataset['validation'].select(range(200)).map(tokenize_data, batched=True)

In [7]:
len(train_data[0]['labels'])

512

In [8]:
len(val_data[0]['input_ids'])

512

### 3. Fine-tune the Model

In [9]:
from transformers import Trainer, TrainingArguments

# Training arguments
training_args = TrainingArguments(
    output_dir='./results',            # Directory to save the model
    num_train_epochs=3,                # Number of training epochs
    per_device_train_batch_size=4,     # Batch size for training
    per_device_eval_batch_size=4,      # Batch size for evaluation
    warmup_steps=50,                  # Warmup steps
    weight_decay=0.01,                 # Weight decay
    logging_dir='./logs',              # Directory for logs
    logging_steps=10,
    evaluation_strategy="epoch"        # Evaluate after every epoch
)

# Trainer for fine-tuning
trainer = Trainer(
    model=model,                       # Pre-trained model
    args=training_args,                # Training arguments
    train_dataset=train_data,          # Training dataset
    eval_dataset=val_data              # Evaluation dataset
)

# Fine-tune the model
trainer.train()

  0%|          | 0/750 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [40]:
# Save the fine-tuned model
model.save_pretrained('./fine-tuned-gpt1-summarization')
tokenizer.save_pretrained('./fine-tuned-gpt1-summarization')

('./fine-tuned-gpt1-summarization/tokenizer_config.json',
 './fine-tuned-gpt1-summarization/special_tokens_map.json',
 './fine-tuned-gpt1-summarization/vocab.json',
 './fine-tuned-gpt1-summarization/merges.txt',
 './fine-tuned-gpt1-summarization/added_tokens.json')

### 4. Test: Summarization

In [10]:
# Load the fine-tuned model
fine_tuned_model = OpenAIGPTLMHeadModel.from_pretrained('./fine-tuned-gpt1-summarization')
fine_tuned_tokenizer = OpenAIGPTTokenizer.from_pretrained('./fine-tuned-gpt1-summarization')

# Generate summary
def generate_summary(article):
    inputs = fine_tuned_tokenizer.encode(article, return_tensors='pt', max_length=128, truncation=True)
    # print(f"Length of input_ids after encoding: {inputs.shape}")
    # print(f"Encoded input_ids: {inputs}")
    outputs = fine_tuned_model.generate(inputs, max_new_tokens=128, num_beams=5, early_stopping=True)
    summary = fine_tuned_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return summary

# Test the summarization
sample_article = dataset['test'][0]['article']
summary = generate_summary(sample_article)
print("Generated Summary:")
print(summary)

ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.


Length of input_ids after encoding: torch.Size([1, 128])
Generated Summary:
( cnn ) the palestinian authority officially became the 123rd member of the international criminal court on wednesday , a step that gives the court jurisdiction over alleged crimes in palestinian territories . the formal accession was marked with a ceremony at the hague , in the netherlands , where the court is based . the palestinians signed the icc ' s founding rome statute in january , when they also accepted its jurisdiction over alleged crimes committed " in the occupied palestinian territory , including east jerusalem , since june 13 , 2014 . " later that month , the icc opened a preliminary examination into the situation in palestinian


In [18]:
# Check the model configuration for input length limits
print(f"Model max position embeddings: {fine_tuned_model.config.n_positions}")

Model max position embeddings: 512
