In [212]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import re
from tqdm import tqdm
import torch
from torch.nn import functional as F
from torch.optim import AdamW
import matplotlib.pyplot as plt
from torch import nn

# Use the GPU when one is present, otherwise fall back to CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# SQuAD from the Hugging Face Hub; later cells use its "train" and "validation" splits.
ds = load_dataset("rajpurkar/squad")

In [213]:
# Tokenizer comes from the Hub base model; the model weights come from a local
# checkpoint (presumably fine-tuned from that base — TODO confirm).
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained('../model/normal_model')

# The tokenizer has no pad token of its own here, so reuse EOS for padding.
eos_token, eos_token_id = tokenizer.eos_token, tokenizer.eos_token_id
tokenizer.pad_token = eos_token
tokenizer.pad_token_id = eos_token_id

In [214]:
# Number of formatted examples kept per split.
data_size = 100
# NOTE(review): `reshape` caps every split at `data_size`, so `data_size_v`
# does not currently limit the validation set — TODO wire it through.
data_size_v = 100
# Not used in this section; presumably consumed by later cells — TODO confirm.
size = data_size // 4
size_v = data_size_v // 4

# Select the train subset from the same constant instead of a duplicated literal.
# Filtering in `reshape` may drop some rows, so fewer than `data_size` can remain.
train_dataset = ds["train"].shuffle(seed=42).select(range(data_size))
# The full shuffled validation split; `reshape` keeps only the first `data_size` survivors.
validation_dataset = ds["validation"].shuffle(seed=42)

In [215]:
def reshape(dataset, limit=None):
    """Format SQuAD-style examples as single "C: ... Q: ... A: ..." strings.

    Each example must provide ``example["context"]``, ``example["question"]``
    and ``example["answers"]["text"]``; the first reference answer is used.
    Raw strings shorter than 50 characters or containing '@' are dropped.
    Survivors are stripped of every character outside ``[a-zA-Z0-9 .:?]``,
    and runs of whitespace are collapsed to a single space.

    Args:
        dataset: iterable of example dicts (e.g. a ``datasets.Dataset`` or a list).
        limit: maximum number of strings to return (non-negative). Defaults to
            the global ``data_size``, so existing calls behave as before.

    Returns:
        list[str]: at most ``limit`` formatted strings, in dataset order.
    """
    if limit is None:
        limit = data_size
    formatted = []
    if limit <= 0:
        return formatted
    for example in dataset:
        text = ("C: " + example["context"] + " Q: " + example["question"]
                + " A: " + example["answers"]["text"][0])
        # Filter on the raw text, before cleaning, as the original pipeline did.
        if len(text) < 50 or '@' in text:
            continue
        text = re.sub(r'[^a-zA-Z0-9 .:?]', '', text)
        text = re.sub(r'\s+', ' ', text)
        formatted.append(text)
        # Stop early instead of formatting the whole split (e.g. all of SQuAD validation).
        if len(formatted) >= limit:
            break
    return formatted

In [216]:
def make_data(data, max_length=512):
    """Tokenize formatted QA strings into fixed-length causal-LM examples.

    ``data`` is first passed through ``reshape``. Each resulting string is
    tokenized with the global ``tokenizer``, padded/truncated to ``max_length``.

    ``labels`` is ``input_ids`` shifted left by one: ``labels[i]`` is the token
    that follows position ``i``, and the final slot is the pad id. Label slots
    covering the "C: ... Q: ..." prompt are overwritten with the pad id, so
    only the " A: <answer>" part remains as a target. If no " A: " marker is
    found, the whole text is treated as prompt.

    NOTE(review): the mask value is the pad/EOS id (128001 for this
    tokenizer), not -100. The loss computation must ignore that id explicitly,
    otherwise the model learns to predict EOS at every prompt/padding position.
    Also, if these labels are passed to a HF model as ``labels=``, the model
    shifts them again internally. Confirm both against the training loop.

    Args:
        data: dataset of SQuAD-style examples accepted by ``reshape``.
        max_length: fixed token length of every example (default 512).

    Returns:
        list[dict]: one ``{"input_ids", "labels", "attention_mask"}`` dict of
        int lists (each ``max_length`` long) per example.
    """
    pad_id = tokenizer.pad_token_id
    examples = []
    for text in tqdm(reshape(data), desc="Tokenizing dataset"):
        # Search from the right: the context/question may itself contain "A:"
        # (e.g. "USA:"), while the answer marker is always the last one.
        sep = text.rfind(" A: ")
        # Keep the trailing space so the prompt tokenizes exactly as before.
        prompt = text[:sep + 1] if sep != -1 else text
        cq_len = len(tokenizer(prompt)['input_ids'])

        tokenized = tokenizer(text, padding="max_length", max_length=max_length, truncation=True)
        input_ids = list(tokenized['input_ids'])
        attention_mask = list(tokenized['attention_mask'])

        labels = input_ids[1:] + [pad_id]
        # cq_len counts BOS and (presumably) a standalone trailing-space token
        # that merges into " A" in the full text. Subtracting 2 masks exactly
        # the prompt targets; the decoded targets in this notebook all start with " A:".
        n_masked = max(0, min(cq_len - 2, max_length))
        labels[:n_masked] = [pad_id] * n_masked

        examples.append({"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask})

    return examples

In [217]:
# Tokenize both splits; each is capped by `reshape` (train first, then validation).
data, data_v = make_data(train_dataset), make_data(validation_dataset)

Tokenizing dataset:   0%|          | 0/100 [00:00<?, ?it/s]

Tokenizing dataset: 100%|██████████| 100/100 [00:00<00:00, 1346.89it/s]
Tokenizing dataset: 100%|██████████| 100/100 [00:00<00:00, 1338.46it/s]


In [218]:
# For each training example, keep only the supervised label ids: drop the
# prompt-mask/padding id (the tokenizer's pad id, 128001 here), leaving the
# " A: <answer>" target. Iterate over `data` itself so a split that lost rows
# in filtering (fewer than `data_size`) does not raise IndexError.
labelsdata = [
    [token_id for token_id in example['labels'] if token_id != tokenizer.pad_token_id]
    for example in data
]


In [223]:
# Show each target's raw token ids followed by its decoded text.
for label_ids in labelsdata:
    print(label_ids, tokenizer.decode(label_ids), sep="\n")

[362, 25, 220, 5833]
 A: 84
[362, 25, 6603]
 A: books
[362, 25, 279, 11145]
 A: the executive
[362, 25, 1556, 7910, 299]
 A: Anjiro
[362, 25, 30853]
 A: loops
[362, 25, 220, 17, 13, 17, 7239]
 A: 2.2 billion
[362, 25, 28058, 24245, 315, 279, 549, 815, 13, 99452, 22967]
 A: Military Governor of the U.S. Occupation Zone
[362, 25, 279, 14198, 3026]
 A: the brown men
[362, 25, 5070, 1436, 387, 17550, 311, 279, 10977, 520, 12474, 5326]
 A: resources could be targeted to the communities at greatest risk
[362, 25, 26828, 61495]
 A: honey ants
[362, 25, 578, 356, 3746, 7977]
 A: The Cossacks
[362, 25, 26742, 5346, 285]
 A: verdigris
[362, 25, 220, 4468, 15, 82]
 A: 1970s
[362, 25, 8219, 55551, 329, 647]
 A: Sun Jiadong
[362, 25, 4783, 11060]
 A: House Master
[362, 25, 11888]
 A: nine
[362, 25, 384, 14946, 324, 598, 323, 47715, 1371, 51835]
 A: echiurans and sipunculan
[362, 25, 44193]
 A: Religion
[362, 25, 1370, 5893, 32893, 24569]
 A: paralyzes muscles
[362, 25, 8305, 24520]
 A: tituli
[362,

In [221]:
# Scratch check of list slicing: a zero-length slice yields a new empty list.
datav = [0] * 5
datav2 = datav[0:0]

In [222]:
# Echo the slice built from `datav[:0]` (always an empty list).
datav2

[]