In [3]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments, AutoTokenizer
import re
from tqdm import tqdm
import torch
from torch.nn import functional as F
from torch.optim import AdamW
import matplotlib.pyplot as plt
from torch import nn


def reshape(dataset, limit=None):
    """Render SQuAD-style rows as cleaned ``"C: ... Q: ... A: ..."`` strings.

    Each row becomes ``"C: <context> Q: <question> A: <first answer text>"``.
    Rows whose rendered text is shorter than 50 characters or contains ``'@'``
    are dropped. Surviving strings have every character outside
    ``[a-zA-Z0-9 .:?]`` removed and whitespace runs collapsed to one space.

    Args:
        dataset: Iterable of mappings providing ``"context"``, ``"question"``
            and ``"answers"`` (itself a mapping whose ``"text"`` entry is a
            non-empty sequence).
        limit: Maximum number of strings to return. ``None`` (the default)
            falls back to the module-level ``data_size``, which is what the
            original implementation always used.

    Returns:
        list[str]: The first ``limit`` cleaned strings, in dataset order.
    """
    if limit is None:
        limit = data_size

    # Rows are cleaned independently, so once `limit` rows have been kept the
    # remaining rows cannot affect the result and need not be read at all.
    # A negative limit keeps plain slice semantics, so no early stop there.
    stop_early = limit >= 0

    cleaned = []
    for row in dataset:
        if stop_early and len(cleaned) >= limit:
            break
        text = "C: " + row["context"] + " Q: " + row["question"] + " A: " + row["answers"]["text"][0]
        # The filter is applied to the raw text, before any character is stripped.
        if len(text) < 50 or '@' in text:
            continue
        text = re.sub(r'[^a-zA-Z0-9 .:?]', '', text)
        text = re.sub(r'\s+', ' ', text)
        cleaned.append(text)
    return cleaned[:limit]

def max_length(dataset):
    """Print the length of the longest item in ``dataset`` (0 if it is empty).

    Nothing is returned; the value is only written to stdout.
    """
    longest = max((len(item) for item in dataset), default=0)
    print(longest)

def batch(input, size, batch_size=4):
    """Group the first ``size * batch_size`` items of ``input`` into batches.

    Args:
        input: Sequence of items (supports ``len`` and slicing).
        size: Number of batches to produce.
        batch_size: Items per batch. Defaults to 4, the value that was
            previously hard-coded.

    Returns:
        list[list]: ``size`` lists of ``batch_size`` consecutive items each.
        Trailing items that do not fit into a batch are ignored.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        IndexError: If ``input`` holds fewer than ``size * batch_size`` items
            (the same exception type the index-based original raised).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    needed = size * batch_size
    if needed > len(input):
        raise IndexError(
            f"need {needed} items for {size} batches of {batch_size}, got {len(input)}"
        )
    return [list(input[start:start + batch_size]) for start in range(0, needed, batch_size)]

def make_data(data, max_tokens=512, ignore_index=128001, answer_marker=" A: "):
    """Tokenize the reshaped dataset into next-token training examples.

    Every string produced by ``reshape(data)`` is tokenized with the
    module-level ``tokenizer``, padded/truncated to ``max_tokens``, and paired
    with labels that are the input ids shifted left by one position. Label
    positions belonging to the "C: ... Q: ..." prompt, to padding, and the
    final position (which has no next token) are set to ``ignore_index``.

    Args:
        data: Dataset passed straight through to ``reshape``.
        max_tokens: Fixed sequence length for padding/truncation (was the
            literal 512).
        ignore_index: Label value marking positions to ignore (was the
            literal 128001).
        answer_marker: Separator that introduces the answer. The *last*
            occurrence is used, so an "A:" appearing inside the context or
            question no longer cuts the prompt short.

    Returns:
        list[dict]: One dict per example with ``"input_ids"``, ``"labels"``
        and ``"attention_mask"``, each a list of ``max_tokens`` ints.
    """
    examples = []
    for text in tqdm(reshape(data), desc="Tokenizing dataset"):
        marker_at = text.rfind(answer_marker)
        # Keep the space in front of "A:" in the prompt, exactly as the
        # original slice did. If the marker is missing, the whole text is
        # treated as prompt so the example ends up fully masked.
        prompt = text if marker_at == -1 else text[:marker_at + 1]
        cq_len = len(tokenizer(prompt)['input_ids'])

        tokenized = tokenizer(text, padding="max_length", max_length=max_tokens, truncation=True, return_tensors="pt")
        # A single string yields a batch of one; take that row explicitly.
        input_ids = tokenized['input_ids'][0].tolist()
        attention_mask = tokenized['attention_mask'][0].tolist()

        # Shift left by one for next-token prediction. Padding is identified
        # through the attention mask rather than by comparing against the pad
        # token id, so it is ignored whatever the pad id happens to be.
        labels = [
            token if is_real else ignore_index
            for token, is_real in zip(input_ids[1:], attention_mask[1:])
        ] + [ignore_index]

        # The -2 offset is kept from the original; presumably it accounts for
        # the BOS token and the one-step label shift — TODO confirm.
        masked = max(0, min(cq_len - 2, len(labels)))
        labels[:masked] = [ignore_index] * masked

        examples.append({"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask})

    return examples

def make_tensor(data, type, size):
    """Pack one field of every example into a batched ``torch.long`` tensor.

    Args:
        data: Sequence of dicts, each containing the key ``type``.
        type: Name of the field to extract from every dict.
        size: Number of batches, forwarded to ``batch``.

    Returns:
        torch.Tensor: The grouped field values as a ``torch.long`` tensor.
    """
    column = [example[type] for example in data]
    grouped = batch(column, size)
    return torch.tensor(grouped, dtype=torch.long)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Label value that marks positions excluded from the loss. It must match the
# value make_data() writes into the labels.
IGNORE_INDEX = 128001
data_size = 1000   # read by reshape() as the number of examples to keep
epochs = 1
lr = 1e-4
# Fall back to CPU so the cell still runs (slowly) on a machine without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ---------------------------------------------------------------------------
# Dataset, models and tokenizer
# ---------------------------------------------------------------------------
ds = load_dataset("rajpurkar/squad")
student_model = AutoModelForCausalLM.from_pretrained("../model/initialized_distill_model2")
teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
# Reuse EOS as the padding token.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# ---------------------------------------------------------------------------
# Tokenize and batch
# ---------------------------------------------------------------------------
train_dataset=ds["train"].shuffle(seed=42).select(range(10000))

data = make_data(train_dataset)

# batch() groups four examples per batch. Derive the batch count from the
# number of examples actually produced: reshape() filters rows, so it can
# return fewer than data_size and a count based on data_size would overrun.
size = len(data) // 4

input_ids_tensor = make_tensor(data, "input_ids", size)
labels_tensor = make_tensor(data, "labels", size)
attention_mask_tensor = make_tensor(data, "attention_mask", size)


vocab_size = student_model.config.vocab_size
# NOTE(review): criterion is not used anywhere in the visible cells; it is
# kept because other cells may rely on it.
criterion = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

criterion.to(device)
input_ids_tensor=input_ids_tensor.to(device)
labels_tensor=labels_tensor.to(device)
attention_mask_tensor=attention_mask_tensor.to(device)

student_model.to(device)
teacher_model.to(device)

# ---------------------------------------------------------------------------
# Distillation step
# ---------------------------------------------------------------------------
student_model.train()
# The teacher only provides targets (its forward pass runs under no_grad), so
# it belongs in eval mode: train-only layers such as dropout are disabled.
teacher_model.eval()

for j in range(epochs):
    # A fresh optimizer per epoch picks up the learning rate decayed below.
    optimizer = AdamW(student_model.parameters(), lr=lr)

    # NOTE(review): only the first batch is used in each epoch — confirm this
    # is intentional before treating the result as a full training run.
    i=0

    input_ids=input_ids_tensor[i]
    labels=labels_tensor[i]
    attention_mask=attention_mask_tensor[i]
    optimizer.zero_grad()
    student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
    student_logits = student_outputs.logits
    # F.kl_div expects log-probabilities as input and probabilities as target.
    student_prob=F.log_softmax(student_logits, dim=-1)
    student_prob_view = student_prob.view(-1, vocab_size)

    with torch.no_grad():
        teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
        teacher_logits = teacher_outputs.logits
        teacher_prob = F.softmax(teacher_logits, dim=-1)
        teacher_prob_view = teacher_prob.view(-1, vocab_size)

    # Keep only the positions whose label is not ignored. The previous
    # per-position Python loop replaced ignored rows with all-zero rows, which
    # contribute nothing to the summed KL divergence; selecting the remaining
    # rows with a boolean mask gives the same sum without one device
    # round-trip per position.
    answer_mask = labels.view(-1) != IGNORE_INDEX
    sec_student = student_prob_view[answer_mask]
    sec_teacher = teacher_prob_view[answer_mask]

    kldiv_loss=F.kl_div(sec_student, sec_teacher, reduction="none")
    kl_div_answer = kldiv_loss.sum(dim=-1)   # KL divergence per kept position
    kl_loss=kl_div_answer.sum()

    # Fixed down-scaling of the summed loss (value kept from the original).
    loss= kl_loss/10
    loss.backward()
    optimizer.step()

    print("done: ", j+1, "/", epochs)
    # Decay the learning rate tenfold for the next epoch's optimizer.
    lr/=10

Tokenizing dataset: 100%|██████████| 1000/1000 [00:00<00:00, 1076.93it/s]


done:  1 / 1


In [4]:
# Show the summed KL divergence computed by the training cell.
print(kl_loss)

tensor(155.9477, device='cuda:0', grad_fn=<SumBackward0>)


In [7]:
# Peek at the first flattened label of the batch used by the training cell;
# 128001 is the value make_data writes over the leading (prompt) positions.
labels.view(-1)[0]

tensor(128001, device='cuda:0')

In [15]:
# NOTE(review): `logits` is not assigned in any visible cell, so this cell
# depends on leftover kernel state and fails on a fresh Restart & Run All.
logits_view = logits.view(-1, vocab_size)

# Zero the row of every position whose label is 128001 and print the flat
# index of every position that is kept.
for i in range(labels.view(-1).size(0)):
    if labels.view(-1)[i] == 128001:
        # In-place on a view: presumably this also zeroes the matching row of
        # `logits` — TODO confirm this is intended.
        logits_view[i].zero_()
    else:
        print(i)

138
139
140
141
635
636
637
1166
1167
1168
1169
1677
1678
1679
1680
1681


In [14]:
# Spot-check a single row of logits_view; index 2045 is not among the kept
# positions printed by the zeroing cell, so an all-zero row is expected.
logits_view[2045]

tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [None]:
# Temperature-scaled distillation loss. The temperature has to divide the raw
# logits *before* the (log-)softmax so that both arguments remain normalised
# distributions; dividing the already-normalised (log-)probabilities does not
# produce valid inputs for F.kl_div.
# NOTE(review): `temperature` is not assigned in any visible cell — define it
# before running this cell.
kldiv_loss = F.kl_div(
    F.log_softmax(student_logits / temperature, dim=-1),
    F.softmax(teacher_logits / temperature, dim=-1),
    reduction="none",
)

In [5]:
# Sanity check of str.find: it returns the index of the first occurrence.
text = "banana"
index = text.find("na")
print(index)  # Output: 2

2
