In [1]:
from datasets import load_dataset
ds = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1")

In [2]:
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments, AutoTokenizer
device='cuda'
# モデルの準備
teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

In [3]:
train_dataset=ds["train"].shuffle(seed=42).select(range(2000))
validation_dataset=ds["validation"].shuffle(seed=42).select(range(300))
train_dataset = train_dataset["text"]
train_dataset = [item for item in train_dataset if item != '']
validation_dataset=validation_dataset["text"]
validation_dataset = [item for item in validation_dataset if item != '']


In [4]:
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

# 入力とラベルを設定
train_data = []
for text in tqdm(train_dataset, desc="Tokenizing dataset"):
    tokenized = tokenizer(text, padding="max_length", max_length=64, truncation=True, return_tensors="pt")
    input_ids = tokenized['input_ids'].squeeze().tolist()
    # 次の単語のインデックスをラベルとして追加
    labels = input_ids[1:] + [tokenizer.pad_token_id]  # 最初の単語を除いて次の単語をラベルにする
    train_data.append({"input_ids": input_ids, "labels": labels})

train_data


Tokenizing dataset: 100%|██████████| 1271/1271 [00:00<00:00, 2479.90it/s]


[{'input_ids': [128000,
   284,
   284,
   284,
   12877,
   6460,
   284,
   284,
   284,
   720,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009],
  'labels': [284,
   284,
   284,
   12877,
   6460,
   284,
   284,
   284,
   720,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
   128009,
 

In [5]:
import torch
input_ids = [item["input_ids"] for item in train_data]
labels = [item["labels"] for item in train_data]

In [6]:
input_ids_tensor = torch.tensor(input_ids, dtype=torch.long)
labels_tensor = torch.tensor(labels, dtype=torch.long)

In [7]:
class MyDataset(Dataset):
    def __init__(self, input_ids, labels):
        self.input_ids = input_ids
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.labels[idx]

# Datasetインスタンスの作成
dataset = MyDataset(input_ids_tensor, labels_tensor)

In [8]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

In [9]:
from torch.optim import AdamW

# オプティマイザの定義
optimizer = AdamW(teacher_model.parameters(), lr=5e-5)


In [10]:
print(dataloader)

<torch.utils.data.dataloader.DataLoader object at 0x7fbea1ee0be0>


In [None]:
epochs = 3  # エポック数
device='cuda'
teacher_model.to(device)

for epoch in range(epochs):
    for batch in dataloader:
        
        input_ids, attention_mask = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        # 勾配の初期化
        optimizer.zero_grad()
        
        # モデルの出力と損失計算
        outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        # 逆伝播とパラメータの更新
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1} completed. Loss: {loss.item()}")


KeyboardInterrupt: 

In [14]:
input_ids[3]

[128000,
 362,
 10007,
 75662,
 2543,
 65546,
 574,
 1176,
 19144,
 555,
 97847,
 304,
 220,
 3753,
 22,
 1174,
 279,
 1890,
 892,
 439,
 2380,
 1023,
 2543,
 79,
 12732,
 430,
 1053,
 13967,
 3719,
 279,
 71141,
 1174,
 63226,
 300,
 675,
 1174,
 323,
 3842,
 38988,
 65320,
 2543,
 79,
 12732,
 662,
 82493,
 3284,
 8406,
 826,
 1051,
 57098,
 1603,
 97847,
 23183,
 389,
 279,
 220,
 4161,
 339,
 6825,
 46979,
 439,
 279,
 2816,
 315]

In [None]:
outputs = teacher_model(input_ids=input_ids[0], attention_mask=attention_mask, labels=labels)