In [2]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments, AutoTokenizer
import re
from tqdm import tqdm
import torch
from torch.optim import AdamW
import matplotlib.pyplot as plt
from torch import nn

# Load the raw WikiText-103 corpus; the visible cells only use its validation split.
ds = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1")
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Prepare the models: the hub checkpoint and a locally stored one
# (presumably the fine-tuned variant of the same model - confirm).
model_before = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
model_after = AutoModelForCausalLM.from_pretrained("./model/1Btuned_model")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
# Use the EOS token for padding; padding="max_length" is requested when tokenizing below.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# Number of cleaned texts kept by reshape().
data_size = 766
# Number of 4-example groups that fit into data_size (integer division).
size = data_size // 4

# Fixed seed so the same validation texts are selected on every run.
validation_dataset=ds["validation"].shuffle(seed=42)

def reshape(dataset, limit=None):
    """Filter and normalise the texts of a dataset split.

    Texts shorter than 50 characters (which includes empty strings) or
    containing '@' are dropped. In the remaining texts every character other
    than ASCII letters, digits, space and '?' is removed, and runs of
    whitespace are collapsed to a single space. The number of texts that
    survive the filter is printed before truncation.

    Args:
        dataset: Object whose ``"text"`` entry is a sequence of strings.
        limit: Maximum number of cleaned texts to return. ``None`` (default)
            uses the module-level ``data_size``, matching the original behaviour.

    Returns:
        list[str]: The cleaned texts in their original order, at most ``limit`` of them.
    """
    if limit is None:
        limit = data_size
    # len(text) >= 50 already rejects the empty string, so no separate check is needed.
    kept = [text for text in dataset["text"] if len(text) >= 50 and '@' not in text]
    cleaned = []
    for text in kept:
        text = re.sub(r'[^a-zA-Z0-9 ?]', '', text)
        cleaned.append(re.sub(r'\s+', ' ', text))
    print(len(cleaned))
    return cleaned[:limit]

def max_length(dataset):
    """Print and return the length of the longest item in ``dataset``.

    Returns 0 when ``dataset`` is empty.
    """
    longest = max((len(entry) for entry in dataset), default=0)
    print(longest)
    return longest

def calc_length(arr, pad_token_id=None):
    """Count the tokens that precede the first padding token in ``arr``.

    The original implementation returned ``len(arr) - 1`` when no padding
    token was present (and ``-1`` for an empty sequence); a sequence without
    padding now reports its full length.

    Args:
        arr: Iterable of token ids.
        pad_token_id: Id treated as padding. ``None`` (default) uses
            ``tokenizer.pad_token_id`` from the notebook namespace.

    Returns:
        int: Number of ids before the first padding id, or the total number
        of ids when ``arr`` contains no padding.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    count = 0
    for token in arr:
        if token == pad_token_id:
            return count
        count += 1
    return count

# Clean the shuffled validation split; reshape() returns at most data_size texts.
dataset=reshape(validation_dataset)
# Prints the character length of the longest cleaned text (the return value is unused here).
max_length(dataset)

def batch(input, batch_size=4):
    """Group consecutive items of ``input`` into lists of ``batch_size`` items.

    The number of groups is derived from ``len(input)`` instead of the
    module-level ``size``, so an input shorter than expected no longer raises
    ``IndexError``. Trailing items that do not fill a complete group are
    dropped, as before. For the notebook's input (``data_size`` items) the
    result is identical to the original.

    Args:
        input: Sequence of items supporting ``len`` and slicing.
        batch_size: Number of items per group; defaults to the original 4.

    Returns:
        list[list]: ``len(input) // batch_size`` groups in input order.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    full_groups = len(input) // batch_size
    return [
        list(input[start * batch_size:(start + 1) * batch_size])
        for start in range(full_groups)
    ]

def accuracy(top_preds, labels, ignore_index, verbose=True):
    """Fraction of non-ignored positions whose label is among the candidates.

    Args:
        top_preds: Tensor indexed by position; ``top_preds[i]`` holds the
            candidate token ids for position ``i``.
        labels: 1-D tensor of target token ids, one per position.
        ignore_index: Label value to skip (such positions count as neither
            hit nor miss).
        verbose: When True (default, the original behaviour) print the label
            and the candidates for every hit.

    Returns:
        float: hits / counted positions, or 0.0 when every label equals
        ``ignore_index`` (the original raised ``ZeroDivisionError`` there).
    """
    counted = 0
    hits = 0
    for i in range(labels.size(0)):
        if labels[i] == ignore_index:
            continue
        counted += 1
        if torch.any(top_preds[i] == labels[i]):
            if verbose:
                print("labels: ", labels[i].item(), "top_preds: ", top_preds[i])
            hits += 1
    # Guard against a batch that consists only of ignored (padding) labels.
    if counted == 0:
        return 0.0
    return hits / counted



# Set up the inputs and labels
data = []
for text in tqdm(dataset, desc="Tokenizing dataset"):
    # Pad or truncate every text to max_length=256 token ids.
    tokenized = tokenizer(text, padding="max_length", max_length=256, truncation=True, return_tensors="pt")
    # squeeze() presumably removes a leading batch dimension of size 1 - TODO confirm.
    input_ids = tokenized['input_ids'].squeeze().tolist()
    attention_mask = tokenized['attention_mask'].squeeze().tolist()
    # Next-token targets: the input shifted left by one, with the pad id filling the last slot.
    labels = input_ids[1:] + [tokenizer.pad_token_id]
    data.append({"input_ids": input_ids, "labels": labels, "attention_mask":attention_mask})


# NOTE(review): input_ids / labels / attention_mask are rebound here from
# per-example lists (loop above) to whole-dataset lists of lists.
input_ids = [item["input_ids"] for item in data]
labels = [item["labels"] for item in data]
attention_mask = [item["attention_mask"] for item in data]

# Rebound again: batch() nests the examples into groups, adding one more list level.
input_ids = batch(input_ids)
labels = batch(labels)
attention_mask = batch(attention_mask)

# Convert the nested lists into integer (torch.long) tensors.
input_ids_tensor = torch.tensor(input_ids, dtype=torch.long)
labels_tensor = torch.tensor(labels, dtype=torch.long)
attention_mask_tensor = torch.tensor(attention_mask, dtype=torch.long)


# Assumption: define the vocabulary size and the token IDs of frequent words
# NOTE(review): only vocab_size is defined here, and it is not read in the visible cells.
vocab_size = model_after.config.vocab_size

# Move all batched tensors and both models to the selected device.
input_ids_tensor=input_ids_tensor.to(device)
labels_tensor=labels_tensor.to(device)
attention_mask_tensor = attention_mask_tensor.to(device)
model_before.to(device)
model_after.to(device)

# Put both models in evaluation mode before running inference.
model_before.eval()
model_after.eval()

# Index of the batch (first dimension of the *_tensor variables) to evaluate.
i=1
# k for the top-k accuracy: a position is a hit if its label is among the k highest logits.
rank=5

# NOTE(review): these names are rebound once more, now to the tensors of batch i.
input_ids=input_ids_tensor[i] 
labels=labels_tensor[i]
attention_mask=attention_mask_tensor[i]
# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    # NOTE(review): `labels` is passed to both models, but only `.logits` is read
    # afterwards in this cell; the loss is unused here - confirm whether it is needed.
    outputs_before = model_before(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    outputs_after = model_after(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    logits_before = outputs_before.logits
    logits_after = outputs_after.logits

# Token ids of the `rank` largest logits along the last dimension.
top_preds_before = torch.topk(logits_before, k=rank, dim=-1).indices
top_preds_after = torch.topk(logits_after, k=rank, dim=-1).indices
# Flatten to one row of candidates per position; positions labelled with the pad id are ignored.
acc_before = accuracy(top_preds_before.view(-1, rank), labels.view(-1), tokenizer.pad_token_id)
acc_after = accuracy(top_preds_after.view(-1, rank), labels.view(-1), tokenizer.pad_token_id)

# NOTE(review): both lines print the same "acc" prefix; the first is model_before, the second model_after.
print("acc", acc_before)
print("acc", acc_after)




767
2084


Tokenizing dataset: 100%|██████████| 766/766 [00:00<00:00, 3205.12it/s]


labels:  264 top_preds:  tensor([ 264, 1202,  872,  279,  459], device='cuda:0')
labels:  311 top_preds:  tensor([1990,  520,  311,  389,  369], device='cuda:0')
labels:  709 top_preds:  tensor([ 709, 5352,  279, 2574,  455], device='cuda:0')
labels:  279 top_preds:  tensor([ 369,  279,  872, 1603,  323], device='cuda:0')
labels:  13734 top_preds:  tensor([13734,  7359, 10877, 25946, 35851], device='cuda:0')
labels:  439 top_preds:  tensor([ 439,  369,  311,  304, 1193], device='cuda:0')
labels:  264 top_preds:  tensor([ 279,  330,  264, 1054,  350], device='cuda:0')
labels:  6453 top_preds:  tensor([ 330, 1054,  364, 6453,  350], device='cuda:0')
labels:  2489 top_preds:  tensor([2489, 5634, 1925, 6481, 9248], device='cuda:0')
labels:  311 top_preds:  tensor([311, 279,  13,  11, 627], device='cuda:0')
labels:  279 top_preds:  tensor([  279,   872,   350, 29680, 18396], device='cuda:0')
labels:  1501 top_preds:  tensor([ 1925,  1501,   350, 30950,  1567], device='cuda:0')
labels:  1515

In [13]:
# Show the results stored for model_before and model_after.
# NOTE(review): the saved output under this cell is a list of token ids, whereas accuracy()
# as defined in this notebook returns a float - presumably this ran against a different
# version of accuracy(); confirm before relying on it.
print(acc_before)
print(acc_after)

[264, 311, 709, 279, 13734, 439, 264, 6453, 2489, 311, 279, 1501, 15154, 578, 18079, 4409, 13257, 60565, 91609, 323, 11517, 328, 9068, 2403, 279, 2128, 315, 80792, 83, 323, 6168, 483, 18079, 4409, 13257, 60565, 2834, 279, 220, 315, 47997, 74310, 86361, 323, 25518, 285, 64, 10016, 311, 388, 9240, 4987, 13030, 311, 19874, 279, 315, 7690, 323, 64, 24577, 304, 813, 4632, 279, 40839, 5348, 323, 25518, 285, 1555, 279, 6237, 52694, 430, 279, 4632, 374, 3196, 389, 264, 1972, 10102, 1162, 27313, 1667, 8767, 374, 311, 4048, 430, 279, 3446, 374, 1694, 1534, 369, 279, 315, 2336, 220, 220, 459, 3237, 3195, 1990, 11046, 323, 11188, 304, 220, 4468, 16, 264, 502, 14497, 927, 279, 97496, 11188, 304, 220, 3753, 18, 287, 1831, 315, 279, 4846, 315, 2326, 220, 17, 578, 4376, 574, 4174, 291, 1614, 27834, 3156, 433, 574, 311, 2585, 315, 220, 279, 279, 263, 20467, 19441, 1890, 1060, 279, 2326, 220, 17, 311, 304, 34881, 580, 18054, 433, 505, 279, 220, 2075, 84674, 578, 1051, 1903, 311, 2326, 220, 17, 30158, 11

In [7]:
# Values present in one result but not in the other (set difference in each direction).
# NOTE(review): set() requires iterables; with the float returned by accuracy() as defined
# in this notebook this raises TypeError, so the cell depends on other kernel state - confirm.
only_in_arr1 = list(set(acc_before) -set(acc_after))
only_in_arr2 = list(set(acc_after) -set(acc_before))

In [8]:
# Show the ids unique to each side; they come from sets, so their order carries no meaning.
print(only_in_arr1)
print(only_in_arr2)

[24577, 1162, 7695, 18, 6550, 11543, 9240, 4376, 80792, 2075, 1694, 18079, 17439, 14497, 10016, 3235, 3237, 813, 27313, 11953, 433, 8625, 6453, 1972, 97713, 64, 34881, 580, 9671, 5960, 3400, 459, 4174, 1614, 4048, 2131, 3156, 97496, 8667, 5596, 2397, 1890, 8681, 68078, 369, 755, 20467, 10102, 4987, 3196]
[15140, 8967, 13307, 653, 9454, 9070, 3314, 1139, 274, 7476, 23, 568, 409, 22555, 5054, 763]


In [10]:
# Decode the ids found only in acc_before; the ids are unordered, so the text is not a sentence.
tokenizer.decode(only_in_arr1)

' actress case Il3105 routes Rest half Hoy75 being Motor captured bridge Scott along express his investigated carried it remains dark real Bombaya Ignace Mad San Cal annumber state learn55 untilistique signed parts65 same Great allegiance fordef Bridge murder South based'

In [11]:
# Decode the ids found only in acc_after; the ids are unordered, so the text is not a sentence.
tokenizer.decode(only_in_arr2)

' river meant agents un Frank extension State into sNA8 he de Bureau political In'

In [5]:
# Decode the candidate ids at index [0][0] of top_preds.
# NOTE(review): `top_preds` is not assigned in any visible cell (only top_preds_before and
# top_preds_after are) - this relies on leftover kernel state; confirm its origin.
tokenizer.decode(top_preds[0][0])

' The In After On '

In [6]:
# Display the candidate-id tensor (see the review note on where `top_preds` comes from: it is
# not assigned in any visible cell).
top_preds

tensor([[[  578,   763,  4740,  1952,   220],
         [  220,   315,   358, 14853,  8105],
         [  315,   374,   220,   323,  5727],
         ...,
         [  220,   323,   374,   287,   279],
         [  220,   323,   374,   287,   315],
         [  220,   323,   287,   374,   315]],

        [[  578,   763,  4740,  1952,   220],
         [  279,   220,   264,   420,   832],
         [  605,   717,   845,   508,   806],
         ...,
         [11226, 10007, 10411,  9909,  6460],
         [ 1051,   279, 16654,  1070,  1047],
         [  311,   704,   872,   279,   323]],

        [[  578,   763,  4740,  1952,   220],
         [ 8458,  3842, 10455, 11291,  7043],
         [  220,   473,  1630,   362,   816],
         ...,
         [  220,   287,   278,   291,   357],
         [  220,   278,   287,   291,   357],
         [  278,   220,   287,   291,   357]],

        [[  578,   763,  4740,  1952,   220],
         [  274,  2834,   574,  1051,  1080],
         [ 4846,  2010, 15688,  

In [7]:
# Flatten to rows of 5 candidates each; the 5 presumably mirrors `rank` - confirm.
top_preds.view(-1, 5)

tensor([[  578,   763,  4740,  1952,   220],
        [  220,   315,   358, 14853,  8105],
        [  315,   374,   220,   323,  5727],
        ...,
        [  323,   574,   279,   274,   220],
        [  323,   279,   574,   274,   220],
        [  323,   279,   274,   220,   574]], device='cuda:0')

In [8]:
# Number of label positions after flattening the currently bound `labels` tensor.
labels.view(-1).size()

torch.Size([1024])