In [20]:
import torch
import os
import math
from transformers import GPT2LMHeadModel, GPT2Tokenizer, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

In [6]:
data_dir = "data"
train_file = os.path.join(data_dir, "wikitext-103-train-corpus.pt")
train = torch.load(train_file)
train.shape

torch.Size([116635588])

In [3]:
tokenizer = GPT2Tokenizer.from_pretrained("GPT2")
pretrained_model = GPT2LMHeadModel.from_pretrained('GPT2', pad_token_id=tokenizer.eos_token_id).to("cuda:7")

In [4]:
pretrained_model.device

device(type='cuda', index=7)

In [5]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

In [6]:
datasets = load_dataset('wikitext', 'wikitext-103-raw-v1')

Found cached dataset wikitext (/home/james/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
100%|██████████| 3/3 [00:00<00:00, 102.34it/s]


In [7]:
datasets['train']

Dataset({
    features: ['text'],
    num_rows: 1801350
})

In [8]:
def tokenize_function(examples):
    return tokenizer(examples["text"])

In [9]:
tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])

Loading cached processed dataset at /home/james/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-0373ac50873b497f_*_of_00004.arrow
Loading cached processed dataset at /home/james/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-96928dfcd8be926a_*_of_00004.arrow
Loading cached processed dataset at /home/james/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-bf8f68fe7370cf91_*_of_00004.arrow


In [10]:
tokenized_datasets["train"]

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 1801350
})

In [11]:
block_size = 526

def group_texts(examples):
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

In [13]:
lm_datasets = tokenized_datasets.map(
    group_texts, 
    batched=True,
    num_proc=4
)

Loading cached processed dataset at /home/james/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-09d3da81c9b157a4_*_of_00004.arrow
Loading cached processed dataset at /home/james/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-bb5c35209c9889fd_*_of_00004.arrow
Loading cached processed dataset at /home/james/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-ae030aaeea65564e_*_of_00004.arrow


In [22]:
lm_datasets["train"].shard(num_shards=8, index=0)

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 27912
})

In [66]:
num_ensemble = 8
epochs = 10
device = "cuda:7"
train_share = lm_datasets['train'].num_rows // num_ensemble
val_share = lm_datasets['validation'].num_rows // num_ensemble

def accuracy(preds, labels):
    return (preds == labels).mean()

for i in range(8):
    lm_shards = {}
    lm_shards['train'] = lm_datasets['train'].shard(num_shards=num_ensemble, index=i)
    lm_shards['validation'] = lm_datasets['train'].shard(num_shards=num_ensemble, index=i)

    
    train_loader = DataLoader(
    lm_datasets['train'],
    batch_size=4
    )

    validation_loader = DataLoader(
    lm_datasets['train'],
    batch_size=4
    )

    lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0,
    bias="none",
    task_type="CAUSAL_LM"
    )
    lora_model = get_peft_model(pretrained_model, lora_config).to(device)
    print(lora_model.device)

    model_name = f"lora-gpt2-{i}-finetuned-wikitext2"

    for epoch in range(epochs):
        train_loss = []
        train_acc = []
        val_loss = []
        val_acc = []
        for batch in tqdm(train_loader):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = lora_model(batch)
            loss, logits = outputs[:2]
            preds = np.argmax(logits.)

cuda:7


RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/james/priv_pred_ensemble/venv/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/home/james/priv_pred_ensemble/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/james/priv_pred_ensemble/venv/lib/python3.10/site-packages/peft/peft_model.py", line 678, in forward
    return self.base_model(
  File "/home/james/priv_pred_ensemble/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/james/priv_pred_ensemble/venv/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1080, in forward
    transformer_outputs = self.transformer(
  File "/home/james/priv_pred_ensemble/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/james/priv_pred_ensemble/venv/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 846, in forward
    inputs_embeds = self.wte(input_ids)
  File "/home/james/priv_pred_ensemble/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/james/priv_pred_ensemble/venv/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/home/james/priv_pred_ensemble/venv/lib/python3.10/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper_CUDA__index_select)
