In [1]:
import math
import os
from itertools import chain

import numpy as np
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer, TrainingArguments, Trainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Base GPT-2 tokenizer and LM-head model, moved onto GPU 7.
# pad_token_id reuses the EOS id; no separate pad token is configured here.
tokenizer = GPT2Tokenizer.from_pretrained("GPT2")
pretrained_model = GPT2LMHeadModel.from_pretrained(
    "GPT2",
    pad_token_id=tokenizer.eos_token_id,
)
pretrained_model = pretrained_model.to("cuda:7")

In [4]:
# Sanity check: confirm the model landed on the intended GPU.
pretrained_model.device

device(type='cuda', index=7)

In [5]:
def print_trainable_parameters(model):
    """
    Print how many of ``model``'s parameters are trainable.

    Sums ``numel()`` over every parameter from ``model.parameters()`` and over
    the subset with ``requires_grad=True``. Prints both totals and the
    trainable percentage on a single line.

    Args:
        model: Any ``torch.nn.Module``.

    Returns:
        None. The summary is written to stdout.
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # A module with no parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )

In [6]:
# Raw (untokenized) WikiText-103, loaded as a dict of splits.
# NOTE(review): the name `datasets` shadows the `datasets` package name; consider renaming.
datasets = load_dataset(path="wikitext", name="wikitext-103-raw-v1")

Found cached dataset wikitext (/home/james/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
100%|██████████| 3/3 [00:00<00:00, 102.34it/s]


In [7]:
# Peek at the raw training split (columns and row count).
datasets['train']

Dataset({
    features: ['text'],
    num_rows: 1801350
})

In [8]:
def tokenize_function(examples):
    """Tokenize the ``text`` column of a batch with the global ``tokenizer``.

    Intended for ``Dataset.map(batched=True)``. Returns the tokenizer output
    (``input_ids`` and ``attention_mask`` per row). No truncation or padding
    is applied; the rows are re-chunked afterwards by ``group_texts``.
    """
    batch_texts = examples["text"]
    encoded = tokenizer(batch_texts)
    return encoded

In [9]:
# Tokenize every split in 4 worker processes.
# Dropping the raw "text" column leaves only the tokenizer outputs.
tokenized_datasets = datasets.map(
    function=tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=["text"],
)

Loading cached processed dataset at /home/james/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-0373ac50873b497f_*_of_00004.arrow
Loading cached processed dataset at /home/james/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-96928dfcd8be926a_*_of_00004.arrow
Loading cached processed dataset at /home/james/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-bf8f68fe7370cf91_*_of_00004.arrow


In [10]:
# Tokenized training split: same row count, "text" replaced by tokenizer columns.
tokenized_datasets["train"]

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 1801350
})

In [11]:
# Number of tokens in every example produced by `group_texts`.
# NOTE(review): 526 is an unusual length (512 is the common choice; GPT-2's
# context is 1024). Confirm it is intentional. Changing it invalidates the cached map.
block_size = 526

def group_texts(examples):
    """Concatenate a batch of tokenized rows and re-split it into fixed-size blocks.

    Meant for ``Dataset.map(batched=True)``. Every column in ``examples``
    (e.g. ``input_ids``, ``attention_mask``) is flattened into one long list and
    cut into consecutive chunks of ``block_size`` items. The trailing remainder
    that does not fill a whole block is dropped, so a batch with fewer than
    ``block_size`` tokens yields no rows.

    Args:
        examples: Mapping of column name -> list of per-row lists.

    Returns:
        Dict with the same columns, each a list of ``block_size``-length lists,
        plus a ``labels`` column equal to ``input_ids``.
    """
    # chain.from_iterable flattens in linear time. sum(rows, []) re-copies the
    # growing accumulator on every addition and is quadratic in batch size.
    concatenated_examples = {
        k: list(chain.from_iterable(rows)) for k, rows in examples.items()
    }
    total_length = len(concatenated_examples[next(iter(examples.keys()))])
    # Drop the remainder so every block has exactly block_size items.
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    # Copy each row so label lists never alias the input_ids lists.
    result["labels"] = [list(ids) for ids in result["input_ids"]]
    return result

In [13]:
# Pack tokenized rows into fixed-length blocks with `group_texts` (4 worker processes).
lm_datasets = tokenized_datasets.map(function=group_texts, batched=True, num_proc=4)

Loading cached processed dataset at /home/james/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-09d3da81c9b157a4_*_of_00004.arrow
Loading cached processed dataset at /home/james/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-bb5c35209c9889fd_*_of_00004.arrow
Loading cached processed dataset at /home/james/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-ae030aaeea65564e_*_of_00004.arrow


In [22]:
# Inspect the first of 8 shards of the packed training split.
lm_datasets["train"].shard(num_shards=8, index=0)

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 27912
})

In [3]:
from ensemble import Ensemble

# Ensemble configuration.
num_ensemble = 8
model_name = "GPT2"
subset_data = "wikitext-103-raw-v1"

# One fine-tuned LoRA directory per ensemble member,
# e.g. models/lora-GPT2-0-finetuned-wikitext-103-raw-v1
model_paths = []
for i in range(num_ensemble):
    model_paths.append(
        os.path.join("models", f"lora-{model_name}-{i}-finetuned-{subset_data}")
    )

In [4]:
# Build the ensemble from the per-member LoRA directories and the base model.
# NOTE(review): as recorded, this raises
#   TypeError: 'PeftModelForCausalLM' object is not iterable
# from inside `Ensemble` (ensemble.py is not visible here). Check the argument
# order `Ensemble` expects and how it consumes the loaded PEFT models.
# The result is also not bound to a name, so it is discarded after display.
Ensemble(model_paths, pretrained_model)

TypeError: 'PeftModelForCausalLM' object is not iterable