In [1]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments, AutoTokenizer
import re
from tqdm import tqdm
import torch
from torch.optim import AdamW

# Target device name (hard-coded; not referenced again in the visible cells).
device = "cuda"

# Raw WikiText-103 corpus from the Hugging Face Hub.
ds = load_dataset(path="Salesforce/wikitext", name="wikitext-103-raw-v1")

# Prepare the models: the tokenizer comes from the Llama-3.2-1B-Instruct
# checkpoint, the teacher model from a local directory.
# NOTE(review): the tokenizer and the teacher are loaded from different
# sources — confirm that their vocabularies match.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
# Reuse the end-of-sequence token as the padding token (both the string and the id).
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

teacher_model = AutoModelForCausalLM.from_pretrained("../distillLLAMA2")

def _prepare_texts(split, sample_size):
    """Sample, filter and normalise the "text" column of one dataset split.

    Steps, in order:
      1. shuffle the split with seed 42 and keep the first `sample_size` rows;
      2. keep only texts that are non-empty, at least 50 characters long and
         contain no '@' character;
      3. remove every character that is not an ASCII letter or a space;
      4. collapse each run of whitespace into a single space.

    Returns a plain list of cleaned strings. The length filter is applied
    before cleaning, so a cleaned string may be shorter than 50 characters.
    """
    sampled_texts = split.shuffle(seed=42).select(range(sample_size))["text"]
    cleaned = []
    for text in sampled_texts:
        if text != '' and len(text) >= 50 and '@' not in text:
            letters_only = re.sub(r'[^a-zA-Z ]', '', text)
            cleaned.append(re.sub(r'\s+', ' ', letters_only))
    return cleaned


train_dataset = _prepare_texts(ds["train"], 120000)
validation_dataset = _prepare_texts(ds["validation"], 3500)

In [2]:
# Set up the inputs and labels.
# Every text is tokenized to exactly 32 tokens (truncated or padded). Labels are
# the input ids shifted left by one position (next-token targets).
# A label is set to -100 (the default ignore_index of PyTorch's cross-entropy
# loss) when:
#   * it is the last position, which has no next token; or
#   * the input position or its target position is padding (attention_mask == 0).
#     The pad token is the EOS token in this notebook, so without this mask the
#     padded tail would be kept as real EOS targets.
# NOTE(review): the labels are pre-shifted here; confirm that the training loop
# does not shift them a second time.
train_data = []
for text in tqdm(train_dataset, desc="Tokenizing dataset"):
    tokenized = tokenizer(text, padding="max_length", max_length=32, truncation=True, return_tensors="pt")
    # The tokenizer returns tensors with a leading batch dimension of 1; drop it.
    input_ids = tokenized['input_ids'].squeeze(0).tolist()
    attention_mask = tokenized['attention_mask'].squeeze(0).tolist()
    labels = [
        target if (current_is_real and target_is_real) else -100
        for target, current_is_real, target_is_real in zip(
            input_ids[1:], attention_mask[:-1], attention_mask[1:]
        )
    ]
    labels.append(-100)  # the final position has no next token to predict
    train_data.append({"input_ids": input_ids, "labels": labels, "attention_mask":attention_mask})

Tokenizing dataset: 100%|██████████| 23979/23979 [00:12<00:00, 1892.72it/s]


In [7]:
# Collect the token-id list of every training example (one list per example).
# Note: this rebinds `input_ids`, which the tokenization loop used for a single
# example's ids.
input_ids = list(map(lambda example: example["input_ids"], train_data))

In [28]:
# Assumption: define the vocabulary size and the token IDs of frequent words.
# The vocabulary size is read from the teacher model's config.
vocab_size = teacher_model.config.vocab_size
# Initialise the weight of every vocabulary entry to 1.
weights = torch.full((vocab_size,), 1.0)

In [29]:
# Add the number of occurrences of each token id in the tokenized training
# set to its weight entry. `input_ids` still contains padding tokens, so those
# are counted as well.
# The counts are computed in one vectorised pass instead of one tensor write
# per token, and `weights` is rebuilt from its all-ones base so that re-running
# this cell does not double-count.
token_counts = torch.bincount(
    torch.tensor(input_ids, dtype=torch.long).reshape(-1),
    minlength=vocab_size,
)
weights = torch.ones(vocab_size) + token_counts

In [34]:
# Replace every entry of `weights` with its reciprocal.
# Caution: this rebinds `weights` from its own previous value, so running the
# cell a second time undoes the inversion.
weights = torch.reciprocal(weights)

In [36]:
# Show every element of `weights` without summarisation.
# A bare `weights` expression in the middle of a cell is not displayed (only a
# cell's last expression is), so the tensor is printed explicitly while the
# unlimited threshold is active.
torch.set_printoptions(threshold=torch.inf)
try:
    print(weights)
finally:
    # Put the summarisation threshold back to 1000 even if printing fails.
    torch.set_printoptions(threshold=1000)