基于GPT-2预训练模型的prompt learning：通过人工定义prompt template与verbalizer，进行句子情感分类

In [1]:
pip install transformers

Looking in indexes: http://repo.myhuaweicloud.com/repository/pypi/simple
Collecting transformers
  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/5b/0b/e45d26ccd28568013523e04f325432ea88a442b4e3020b757cf4361f0120/transformers-4.30.2-py3-none-any.whl (7.2 MB)
[K     |████████████████████████████████| 7.2 MB 100.0 MB/s eta 0:00:01
Collecting regex!=2019.12.17
  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/9a/05/18911646681dfab0ffb76b4b958356c0a3d211bb08e9a2f33f1e9487977d/regex-2024.4.16-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (761 kB)
[K     |████████████████████████████████| 761 kB 19.9 MB/s eta 0:00:01
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/4d/40/ab3c3c705e0a8cbbe760c49302b407190201d96fe7dfeea37ccafa004da3/tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[K     |████████████████████████████████| 7.8 MB 112.5 MB/s e

In [None]:
import os

import numpy
import torch
from sklearn.model_selection import train_test_split
from pathlib import Path
import codecs
import math
import random
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import GPT2TokenizerFast, GPT2LMHeadModel

from tqdm import tqdm
from easydict import EasyDict as edict

: 

In [3]:

cfg = edict({
    'name': 'movie review',
    'pre_trained': True,
    'num_classes': 2,
    'batch_size': 15,
    'epoch_size': 3,
    'weight_decay': 3e-5,
    'data_path': "./data/prompt tuning/data/",
    'checkpoint_path': 'soft-prompt.pth',
    'device_name':"cuda" if torch.cuda.is_available() else "cpu",
    'gpt2_model':'./gpt2',
    'prompt_len':10,
    'max_len' : 100,
    'classes':[['positive'],['negative']],
    'split': 0.8,
    'device_target': 'Ascend',
    'device_id': 0,
    'keep_checkpoint_max': 1,
    'word_len': 768,
    'vec_length': 40,
})

In [4]:

## load model ##

tokenizer = GPT2TokenizerFast.from_pretrained(
    cfg.gpt2_model, add_prefix_space=True
)
tokenizer.pad_token = tokenizer.eos_token

model = GPT2LMHeadModel.from_pretrained(cfg.gpt2_model)
# 冻结所有参数
for param in model.parameters():
    param.requires_grad = False
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dro

In [10]:
# 数据预览
with open(cfg.data_path + "rt-polarity.neg", 'r', encoding='utf-8') as f:
        print("Negative reivews:")
        for i in range(5):
            print("[{0}]:{1}".format(i,f.readline()))
with open(cfg.data_path + "rt-polarity.pos", 'r', encoding='utf-8') as f:
        print("Positive reivews:")
        for i in range(5):
            print("[{0}]:{1}".format(i,f.readline()))

Negative reivews:
[0]:simplistic , silly and tedious . 

[1]:it's so laddish and juvenile , only teenage boys could possibly find it funny . 

[2]:exploitative and largely devoid of the depth or sophistication that would make watching such a graphic treatment of the crimes bearable . 

[3]:[garbus] discards the potential for pathological study , exhuming instead , the skewed melodrama of the circumstantial situation . 

[4]:a visually flashy but narratively opaque and emotionally vapid exercise in style and mystification . 

Positive reivews:
[0]:the rock is destined to be the 21st century's new " conan " and that he's going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal . 

[1]:the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson's expanded vision of j . r . r . tolkien's middle-earth . 

[2]:effective but too-tepid biopic

In [5]:
class CustomDataset(Dataset):
    #data: list[dict[str, torch.Tensor]]

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [None]:
class MovieDataset:
    '''
    影评数据集
    '''
    def __init__(self, root_dir, maxlen, split):
        '''
        input:
            root_dir: 影评数据目录
            maxlen: 设置句子最大长度
            split: 设置数据集中训练/评估的比例
        '''
        self.path = root_dir
        self.files = []

        self.doConvert = False
        
        mypath = Path(self.path)
        if not mypath.exists() or not mypath.is_dir():
            print("please check the root_dir!")
            raise ValueError

        # 在数据目录中找到文件
        for root,_,filename in os.walk(self.path):
            for each in filename:
                self.files.append(os.path.join(root,each))
            break

        # 确认是否为两个文件.neg与.pos
        if len(self.files) != 2:
            print("There are {} files in the root_dir".format(len(self.files)))
            raise ValueError

        # 读取数据
        self.word_num = 0
        self.maxlen = 0
        self.minlen = float("inf")
        self.maxlen = float("-inf")
        self.Pos = []
        self.Neg = []
        self.sentences = []
        self.isShuffle = True

        for filename in self.files:
            f = codecs.open(filename, 'r')
            ff = f.read()
            file_object = codecs.open(filename, 'w', 'utf-8')
            file_object.write(ff)
            self.read_data(filename)

        self.Pos = self.process_data(self.Pos, cfg.classes[0][0])
        self.Neg = self.process_data(self.Neg, cfg.classes[1][0])
        
        #self.text2vec(maxlen=maxlen)
        self.split_dataset(split=split)

    def read_data(self, filePath):
        with open(filePath,'r') as f:
            for sentence in f.readlines():
                sentence = sentence.replace('\n','')\
                    .replace('"','')\
                    .replace('\'','')\
                    .replace('.','')\
                    .replace(',','')\
                    .replace('[','')\
                    .replace(']','')\
                    .replace('(','')\
                    .replace(')','')\
                    .replace(':','')\
                    .replace('--','')\
                    .replace('-',' ')\
                    .replace('\\','')\
                    .replace('0','')\
                    .replace('1','')\
                    .replace('2','')\
                    .replace('3','')\
                    .replace('4','')\
                    .replace('5','')\
                    .replace('6','')\
                    .replace('7','')\
                    .replace('8','')\
                    .replace('9','')\
                    .replace('`','')\
                    .replace('=','')\
                    .replace('$','')\
                    .replace('/','')\
                    .replace('*','')\
                    .replace(';','')\
                    .replace('<b>','')\
                    .replace('%','')
                if sentence:
                    self.word_num += len(sentence.split(' '))
                    self.maxlen = max(self.maxlen, len(sentence.split(' ')))
                    self.minlen = min(self.minlen, len(sentence.split(' ')))
                    if 'pos' in filePath:
                        self.Pos.append([sentence, self.feelMap['pos']])
                    else:
                        self.Neg.append([sentence, self.feelMap['neg']])

    def process_data(self, data_set, tag):
        ret = []
        for line in data_set:
            res = tokenizer(
                line.strip('\n'),
                return_tensors="pt",
                text_target=tag,
                padding='max_length',
                max_length=cfg.max_len + cfg.prompt_len,
                add_special_tokens=True,
            )
            res['text'] = line
            res['input_ids'] = res['input_ids'].squeeze(0)
            res['labels'] = res['labels'].squeeze(0)
            res['attention_mask'] = res['attention_mask'].squeeze(0)
            res['answer'] = tag
            res['len'] = res['attention_mask'].sum()
            res['attention_mask'][res['len']:res['len'] + cfg.prompt_len] = 1
            ret.append(res)
        return ret

    def split_dataset(self, split):
        '''
        分割为训练集与测试集

        '''

        trunk_pos_size = math.ceil((1-split)*len(self.Pos))
        trunk_neg_size = math.ceil((1-split)*len(self.Neg))
        trunk_num = int(1/(1-split))
        pos_temp=list()
        neg_temp=list()
        for index in range(trunk_num):
            pos_temp.append(self.Pos[index*trunk_pos_size:(index+1)*trunk_pos_size])
            neg_temp.append(self.Neg[index*trunk_neg_size:(index+1)*trunk_neg_size])
        self.test = pos_temp.pop(2)+neg_temp.pop(2)
        self.train = [i for item in pos_temp+neg_temp for i in item]

        random.shuffle(self.train)
        # random.shuffle(self.test)

    def get_dict_len(self):
        '''
        获得数据集中文字组成的词典长度
        '''
        if self.doConvert:
            return len(self.Vocab)
        else:
            print("Haven't finished Text2Vec")
            return -1
        
    def train_dataset(self):
        return CustomDataset(self.train)
    
    def test_dataset(self):
        return CustomDataset(self.test) 



In [7]:
instance = MovieDataset(cfg.data_path, maxlen=cfg.max_len, split = cfg.split)
train_dataset = instance.train_dataset()
test_dataset = instance.test_dataset()

In [14]:
def create_soft_prompt(length):
    prompt = \
        tokenizer('bool Attitude equeals ', max_length=length, padding='max_length',
                  return_tensors='np')[
            'input_ids']
    prompt = numpy.array([prompt, ])
    initial_sp = model.transformer.wte(torch.from_numpy(prompt).to(cfg.device_name))
    # initial_sp = torch.rand([1, length, model.config.n_embd])
    sp = torch.nn.Parameter(initial_sp[0], requires_grad=True)
    return sp

In [15]:
soft_prompt = create_soft_prompt(cfg.prompt_len)
# 生成 soft prompt embeddings
prompt_embeddings = (
        soft_prompt
        .to(cfg.device_name)
    )

In [16]:

# 只优化 soft prompt 的参数
optimizer = AdamW([soft_prompt], lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()


In [17]:
class GPT2WithHardPromptTuning(nn.Module):
    def __init__(self, gpt2_model, tokenizer):
        super().__init__()
        self.gpt2_model = gpt2_model
        self.tokenizer = tokenizer
    def forward(self, batch):
        input_ids = batch["input_ids"].to(cfg.device_name)
        target_ids = batch["labels"].to(cfg.device_name)

        # sentence_embeddings = model.transformer.wte(input_ids)
        sentence_embeddings = model.transformer.wte(input_ids).to(cfg.device_name)

        # 生成 soft prompt embeddings
        prompt_embeddings = (
            soft_prompt
            .to(cfg.device_name)
        )
        
        for i in range(input_ids.shape[0]):
            l = batch["len"][i]
            sentence_embeddings[i, l:l + cfg.prompt_len] = prompt_embeddings

        # 执行前向传递
        output = model(
            inputs_embeds=sentence_embeddings,  # labels=target_ids
            attention_mask=batch["attention_mask"].to(cfg.device_name)
        )

        # 选取最后一个提示词对应的生成词
        AnswerPlace = (batch["len"] + cfg.prompt_len - 1).to(cfg.device_name)

        probabilities = torch.nn.functional.softmax(output.logits[:, :, :], dim=-1)

        answer_pb = probabilities[torch.arange(probabilities.shape[0]), AnswerPlace]
        predicted_tokens = [tokenizer.decode(s).strip() for
                            s
                            in torch.argmax(answer_pb, dim=-1)]

        batch['result'] = predicted_tokens

        # 计算损失
        answer_logits = output.logits[torch.arange(probabilities.shape[0]), AnswerPlace, :]
        loss = loss_fn(answer_logits, target_ids[:, 0])

        return answer_logits, loss

In [18]:
# 初始化带prompt的GPT-2模型
model_p = GPT2WithHardPromptTuning(model, tokenizer).to(cfg.device_name)

def train():
    data_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, )
    for epoch in range(cfg.epoch_size):
        for batch in tqdm(data_loader, desc=f"Epoch {epoch+1}/{cfg.epoch_size}"):
            answer_logits, loss = model_p(batch)

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()
            print(f"Epoch: {epoch}, Loss: {loss.item()}")
        torch.save(soft_prompt, cfg.checkpoint_path)

    if os.path.exists(cfg.checkpoint_path):
        raise RuntimeError('checkpoint already exist!')
    torch.save(soft_prompt, cfg.checkpoint_path)

In [19]:

train()

Epoch 1/3:   0%|          | 1/533 [00:18<2:40:31, 18.10s/it]

Epoch: 0, Loss: 20.74989128112793


Epoch 1/3:   0%|          | 2/533 [00:33<2:26:06, 16.51s/it]

Epoch: 0, Loss: 16.657255172729492


Epoch 1/3:   1%|          | 3/533 [00:47<2:13:43, 15.14s/it]

Epoch: 0, Loss: 11.952482223510742


Epoch 1/3:   1%|          | 4/533 [01:01<2:10:12, 14.77s/it]

Epoch: 0, Loss: 7.702881336212158


Epoch 1/3:   1%|          | 5/533 [01:14<2:05:17, 14.24s/it]

Epoch: 0, Loss: 3.544212818145752


Epoch 1/3:   1%|          | 6/533 [01:28<2:02:51, 13.99s/it]

Epoch: 0, Loss: 1.1515787839889526


Epoch 1/3:   1%|▏         | 7/533 [01:42<2:04:38, 14.22s/it]

Epoch: 0, Loss: 0.8505994081497192


Epoch 1/3:   2%|▏         | 8/533 [01:56<2:04:05, 14.18s/it]

Epoch: 0, Loss: 0.7688920497894287


Epoch 1/3:   2%|▏         | 9/533 [02:11<2:04:27, 14.25s/it]

Epoch: 0, Loss: 0.6671916842460632


Epoch 1/3:   2%|▏         | 10/533 [02:25<2:04:20, 14.26s/it]

Epoch: 0, Loss: 0.739412248134613


Epoch 1/3:   2%|▏         | 11/533 [02:39<2:03:24, 14.18s/it]

Epoch: 0, Loss: 0.7344328761100769


Epoch 1/3:   2%|▏         | 12/533 [02:54<2:05:02, 14.40s/it]

Epoch: 0, Loss: 0.7005836963653564


Epoch 1/3:   2%|▏         | 13/533 [03:08<2:04:01, 14.31s/it]

Epoch: 0, Loss: 0.7248353362083435


Epoch 1/3:   3%|▎         | 14/533 [03:22<2:03:14, 14.25s/it]

Epoch: 0, Loss: 0.7113498449325562


Epoch 1/3:   3%|▎         | 15/533 [03:36<2:03:23, 14.29s/it]

Epoch: 0, Loss: 0.7470254302024841


Epoch 1/3:   3%|▎         | 16/533 [03:50<2:00:18, 13.96s/it]

Epoch: 0, Loss: 0.6204156875610352


Epoch 1/3:   3%|▎         | 17/533 [04:03<1:59:39, 13.91s/it]

Epoch: 0, Loss: 0.7335054874420166


Epoch 1/3:   3%|▎         | 18/533 [04:17<1:59:39, 13.94s/it]

Epoch: 0, Loss: 0.9699222445487976


Epoch 1/3:   4%|▎         | 19/533 [04:31<1:57:15, 13.69s/it]

Epoch: 0, Loss: 1.103454351425171


Epoch 1/3:   4%|▍         | 20/533 [04:44<1:57:19, 13.72s/it]

Epoch: 0, Loss: 0.7959731817245483


Epoch 1/3:   4%|▍         | 21/533 [04:58<1:57:18, 13.75s/it]

Epoch: 0, Loss: 0.7147509455680847


Epoch 1/3:   4%|▍         | 22/533 [05:13<1:58:44, 13.94s/it]

Epoch: 0, Loss: 0.7467682957649231


Epoch 1/3:   4%|▍         | 23/533 [05:27<1:58:38, 13.96s/it]

Epoch: 0, Loss: 0.8469144701957703


Epoch 1/3:   5%|▍         | 24/533 [05:39<1:53:56, 13.43s/it]

Epoch: 0, Loss: 0.7006831765174866


Epoch 1/3:   5%|▍         | 25/533 [05:53<1:55:39, 13.66s/it]

Epoch: 0, Loss: 0.8261553049087524


Epoch 1/3:   5%|▍         | 26/533 [06:07<1:56:33, 13.79s/it]

Epoch: 0, Loss: 0.8403409719467163


Epoch 1/3:   5%|▌         | 27/533 [06:20<1:54:49, 13.61s/it]

Epoch: 0, Loss: 0.8800287842750549


Epoch 1/3:   5%|▌         | 28/533 [06:33<1:52:48, 13.40s/it]

Epoch: 0, Loss: 0.5444158911705017


Epoch 1/3:   5%|▌         | 29/533 [06:46<1:51:02, 13.22s/it]

Epoch: 0, Loss: 0.6494976878166199


Epoch 1/3:   6%|▌         | 30/533 [07:00<1:52:32, 13.43s/it]

Epoch: 0, Loss: 0.6705062985420227


Epoch 1/3:   6%|▌         | 31/533 [07:15<1:56:00, 13.87s/it]

Epoch: 0, Loss: 0.6384264826774597


Epoch 1/3:   6%|▌         | 32/533 [07:29<1:57:22, 14.06s/it]

Epoch: 0, Loss: 0.6418234705924988


Epoch 1/3:   6%|▌         | 33/533 [07:43<1:56:15, 13.95s/it]

Epoch: 0, Loss: 0.6980370283126831


Epoch 1/3:   6%|▋         | 34/533 [07:57<1:56:23, 13.99s/it]

Epoch: 0, Loss: 0.67220139503479


Epoch 1/3:   7%|▋         | 35/533 [08:12<1:57:40, 14.18s/it]

Epoch: 0, Loss: 0.6315585374832153


Epoch 1/3:   7%|▋         | 36/533 [08:25<1:54:59, 13.88s/it]

Epoch: 0, Loss: 0.7056344747543335


Epoch 1/3:   7%|▋         | 37/533 [08:38<1:52:20, 13.59s/it]

Epoch: 0, Loss: 0.6813927888870239


Epoch 1/3:   7%|▋         | 38/533 [08:52<1:53:07, 13.71s/it]

Epoch: 0, Loss: 0.6261942386627197


Epoch 1/3:   7%|▋         | 39/533 [09:06<1:53:06, 13.74s/it]

Epoch: 0, Loss: 0.6647101044654846


Epoch 1/3:   8%|▊         | 40/533 [09:19<1:51:48, 13.61s/it]

Epoch: 0, Loss: 0.6374157667160034


Epoch 1/3:   8%|▊         | 41/533 [09:32<1:50:05, 13.43s/it]

Epoch: 0, Loss: 0.6196144819259644


Epoch 1/3:   8%|▊         | 42/533 [09:45<1:50:16, 13.48s/it]

Epoch: 0, Loss: 0.6314032077789307


Epoch 1/3:   8%|▊         | 43/533 [10:00<1:52:48, 13.81s/it]

Epoch: 0, Loss: 0.6194866895675659


Epoch 1/3:   8%|▊         | 44/533 [10:14<1:53:02, 13.87s/it]

Epoch: 0, Loss: 0.6833812594413757


Epoch 1/3:   8%|▊         | 45/533 [10:28<1:53:51, 14.00s/it]

Epoch: 0, Loss: 0.5746106505393982


Epoch 1/3:   9%|▊         | 46/533 [10:43<1:55:34, 14.24s/it]

Epoch: 0, Loss: 0.6094716191291809


Epoch 1/3:   9%|▉         | 47/533 [10:58<1:56:56, 14.44s/it]

Epoch: 0, Loss: 0.6358164548873901


Epoch 1/3:   9%|▉         | 48/533 [11:12<1:56:22, 14.40s/it]

Epoch: 0, Loss: 0.6883097290992737


Epoch 1/3:   9%|▉         | 49/533 [11:26<1:54:55, 14.25s/it]

Epoch: 0, Loss: 0.7223533391952515


Epoch 1/3:   9%|▉         | 50/533 [11:40<1:53:22, 14.08s/it]

Epoch: 0, Loss: 0.5189816951751709


Epoch 1/3:  10%|▉         | 51/533 [11:54<1:52:26, 14.00s/it]

Epoch: 0, Loss: 0.6285799145698547


Epoch 1/3:  10%|▉         | 52/533 [12:08<1:51:59, 13.97s/it]

Epoch: 0, Loss: 0.6059518456459045


Epoch 1/3:  10%|▉         | 53/533 [12:22<1:53:30, 14.19s/it]

Epoch: 0, Loss: 0.830412745475769


Epoch 1/3:  10%|█         | 54/533 [12:36<1:51:23, 13.95s/it]

Epoch: 0, Loss: 0.7469950318336487


Epoch 1/3:  10%|█         | 55/533 [12:50<1:50:32, 13.88s/it]

Epoch: 0, Loss: 0.7072045803070068


Epoch 1/3:  11%|█         | 56/533 [13:04<1:51:19, 14.00s/it]

Epoch: 0, Loss: 0.6413148641586304


Epoch 1/3:  11%|█         | 57/533 [13:17<1:48:56, 13.73s/it]

Epoch: 0, Loss: 0.6452659368515015


Epoch 1/3:  11%|█         | 58/533 [13:31<1:50:46, 13.99s/it]

Epoch: 0, Loss: 0.6175947785377502


Epoch 1/3:  11%|█         | 59/533 [13:46<1:50:48, 14.03s/it]

Epoch: 0, Loss: 0.7208284735679626


Epoch 1/3:  11%|█▏        | 60/533 [14:01<1:53:05, 14.35s/it]

Epoch: 0, Loss: 0.6048122048377991


Epoch 1/3:  11%|█▏        | 61/533 [14:15<1:53:42, 14.45s/it]

Epoch: 0, Loss: 0.6516791582107544


Epoch 1/3:  12%|█▏        | 62/533 [14:30<1:54:31, 14.59s/it]

Epoch: 0, Loss: 0.7482892274856567


Epoch 1/3:  12%|█▏        | 63/533 [14:45<1:54:59, 14.68s/it]

Epoch: 0, Loss: 0.7443437576293945


Epoch 1/3:  12%|█▏        | 64/533 [15:00<1:53:51, 14.57s/it]

Epoch: 0, Loss: 0.6232579946517944


Epoch 1/3:  12%|█▏        | 65/533 [15:13<1:50:11, 14.13s/it]

Epoch: 0, Loss: 0.6227914094924927


Epoch 1/3:  12%|█▏        | 66/533 [15:25<1:46:22, 13.67s/it]

Epoch: 0, Loss: 0.6435189843177795


Epoch 1/3:  13%|█▎        | 67/533 [15:39<1:45:18, 13.56s/it]

Epoch: 0, Loss: 0.6929176449775696


Epoch 1/3:  13%|█▎        | 68/533 [15:53<1:47:15, 13.84s/it]

Epoch: 0, Loss: 0.5827760696411133


Epoch 1/3:  13%|█▎        | 69/533 [16:07<1:47:37, 13.92s/it]

Epoch: 0, Loss: 0.6789473295211792


Epoch 1/3:  13%|█▎        | 70/533 [16:22<1:48:31, 14.06s/it]

Epoch: 0, Loss: 0.64945387840271


Epoch 1/3:  13%|█▎        | 71/533 [16:37<1:50:27, 14.34s/it]

Epoch: 0, Loss: 0.6324805021286011


Epoch 1/3:  14%|█▎        | 72/533 [16:50<1:48:43, 14.15s/it]

Epoch: 0, Loss: 0.5561641454696655


Epoch 1/3:  14%|█▎        | 73/533 [17:04<1:46:59, 13.96s/it]

Epoch: 0, Loss: 0.6224555373191833


Epoch 1/3:  14%|█▍        | 74/533 [17:18<1:48:13, 14.15s/it]

Epoch: 0, Loss: 0.6670452356338501


Epoch 1/3:  14%|█▍        | 75/533 [17:33<1:49:01, 14.28s/it]

Epoch: 0, Loss: 0.5659826397895813


Epoch 1/3:  14%|█▍        | 76/533 [17:46<1:45:37, 13.87s/it]

Epoch: 0, Loss: 0.5760151743888855


Epoch 1/3:  14%|█▍        | 77/533 [18:00<1:45:56, 13.94s/it]

Epoch: 0, Loss: 0.6555280089378357


Epoch 1/3:  15%|█▍        | 78/533 [18:13<1:43:06, 13.60s/it]

Epoch: 0, Loss: 0.6300207376480103


Epoch 1/3:  15%|█▍        | 79/533 [18:27<1:44:15, 13.78s/it]

Epoch: 0, Loss: 0.412286639213562


Epoch 1/3:  15%|█▌        | 80/533 [18:40<1:42:42, 13.60s/it]

Epoch: 0, Loss: 0.6108880639076233


Epoch 1/3:  15%|█▌        | 81/533 [18:54<1:43:10, 13.69s/it]

Epoch: 0, Loss: 0.6768736243247986


Epoch 1/3:  15%|█▌        | 82/533 [19:10<1:47:00, 14.24s/it]

Epoch: 0, Loss: 0.5519882440567017


Epoch 1/3:  16%|█▌        | 83/533 [19:24<1:47:34, 14.34s/it]

Epoch: 0, Loss: 0.49805179238319397


Epoch 1/3:  16%|█▌        | 84/533 [19:37<1:44:19, 13.94s/it]

Epoch: 0, Loss: 0.6124197840690613


Epoch 1/3:  16%|█▌        | 85/533 [19:49<1:39:30, 13.33s/it]

Epoch: 0, Loss: 0.6355181932449341


Epoch 1/3:  16%|█▌        | 86/533 [20:04<1:42:21, 13.74s/it]

Epoch: 0, Loss: 0.6306524872779846


Epoch 1/3:  16%|█▋        | 87/533 [20:17<1:41:49, 13.70s/it]

Epoch: 0, Loss: 0.6658543348312378


Epoch 1/3:  17%|█▋        | 88/533 [20:31<1:42:28, 13.82s/it]

Epoch: 0, Loss: 0.6592490077018738


Epoch 1/3:  17%|█▋        | 89/533 [20:44<1:40:13, 13.54s/it]

Epoch: 0, Loss: 0.5678051710128784


Epoch 1/3:  17%|█▋        | 90/533 [20:58<1:39:40, 13.50s/it]

Epoch: 0, Loss: 0.49177494645118713


Epoch 1/3:  17%|█▋        | 91/533 [21:12<1:41:40, 13.80s/it]

Epoch: 0, Loss: 0.5648319125175476


Epoch 1/3:  17%|█▋        | 92/533 [21:26<1:42:31, 13.95s/it]

Epoch: 0, Loss: 0.7318971753120422


Epoch 1/3:  17%|█▋        | 93/533 [21:40<1:42:11, 13.94s/it]

Epoch: 0, Loss: 0.5629085898399353


Epoch 1/3:  18%|█▊        | 94/533 [21:54<1:41:13, 13.84s/it]

Epoch: 0, Loss: 0.5670391321182251


Epoch 1/3:  18%|█▊        | 95/533 [22:09<1:42:40, 14.06s/it]

Epoch: 0, Loss: 0.6676802039146423


Epoch 1/3:  18%|█▊        | 96/533 [22:23<1:42:30, 14.07s/it]

Epoch: 0, Loss: 0.5593950152397156


Epoch 1/3:  18%|█▊        | 97/533 [22:37<1:42:20, 14.08s/it]

Epoch: 0, Loss: 0.7344688773155212


Epoch 1/3:  18%|█▊        | 98/533 [22:51<1:43:13, 14.24s/it]

Epoch: 0, Loss: 0.45706814527511597


Epoch 1/3:  19%|█▊        | 99/533 [23:07<1:45:03, 14.53s/it]

Epoch: 0, Loss: 0.542537271976471


Epoch 1/3:  19%|█▉        | 100/533 [23:20<1:41:18, 14.04s/it]

Epoch: 0, Loss: 0.5046890377998352


Epoch 1/3:  19%|█▉        | 101/533 [23:34<1:42:18, 14.21s/it]

Epoch: 0, Loss: 0.4005897641181946


Epoch 1/3:  19%|█▉        | 102/533 [23:49<1:43:19, 14.38s/it]

Epoch: 0, Loss: 0.38264888525009155


Epoch 1/3:  19%|█▉        | 103/533 [24:04<1:43:59, 14.51s/it]

Epoch: 0, Loss: 0.5531531572341919


Epoch 1/3:  20%|█▉        | 104/533 [24:18<1:44:09, 14.57s/it]

Epoch: 0, Loss: 0.3895440101623535


Epoch 1/3:  20%|█▉        | 105/533 [24:32<1:42:53, 14.43s/it]

Epoch: 0, Loss: 0.5780152678489685


Epoch 1/3:  20%|█▉        | 106/533 [24:47<1:42:36, 14.42s/it]

Epoch: 0, Loss: 0.5200769901275635


Epoch 1/3:  20%|██        | 107/533 [25:02<1:43:49, 14.62s/it]

Epoch: 0, Loss: 0.7561447620391846


Epoch 1/3:  20%|██        | 108/533 [25:16<1:42:02, 14.41s/it]

Epoch: 0, Loss: 0.5368738174438477


Epoch 1/3:  20%|██        | 109/533 [25:31<1:42:25, 14.49s/it]

Epoch: 0, Loss: 0.6167789697647095


Epoch 1/3:  21%|██        | 110/533 [25:46<1:44:05, 14.77s/it]

Epoch: 0, Loss: 0.4739689230918884


Epoch 1/3:  21%|██        | 111/533 [26:01<1:44:33, 14.87s/it]

Epoch: 0, Loss: 0.6110771298408508


Epoch 1/3:  21%|██        | 112/533 [26:16<1:44:10, 14.85s/it]

Epoch: 0, Loss: 0.45002180337905884


Epoch 1/3:  21%|██        | 113/533 [26:30<1:43:11, 14.74s/it]

Epoch: 0, Loss: 0.5081257224082947


Epoch 1/3:  21%|██▏       | 114/533 [26:44<1:40:33, 14.40s/it]

Epoch: 0, Loss: 0.5350406169891357


Epoch 1/3:  22%|██▏       | 115/533 [26:58<1:40:06, 14.37s/it]

Epoch: 0, Loss: 0.8842815160751343


Epoch 1/3:  22%|██▏       | 116/533 [27:13<1:39:43, 14.35s/it]

Epoch: 0, Loss: 0.6873986124992371


Epoch 1/3:  22%|██▏       | 117/533 [27:27<1:40:00, 14.42s/it]

Epoch: 0, Loss: 0.6155678033828735


Epoch 1/3:  22%|██▏       | 118/533 [27:41<1:39:06, 14.33s/it]

Epoch: 0, Loss: 0.4947533905506134


Epoch 1/3:  22%|██▏       | 119/533 [27:55<1:38:22, 14.26s/it]

Epoch: 0, Loss: 0.3449285328388214


Epoch 1/3:  23%|██▎       | 120/533 [28:10<1:38:02, 14.24s/it]

Epoch: 0, Loss: 0.9938010573387146


Epoch 1/3:  23%|██▎       | 121/533 [28:24<1:38:19, 14.32s/it]

Epoch: 0, Loss: 0.8130161166191101


Epoch 1/3:  23%|██▎       | 122/533 [28:39<1:38:51, 14.43s/it]

Epoch: 0, Loss: 0.5617355704307556


Epoch 1/3:  23%|██▎       | 123/533 [28:53<1:38:45, 14.45s/it]

Epoch: 0, Loss: 0.5844780206680298


Epoch 1/3:  23%|██▎       | 124/533 [29:08<1:39:50, 14.65s/it]

Epoch: 0, Loss: 0.5138282179832458


Epoch 1/3:  23%|██▎       | 125/533 [29:24<1:41:45, 14.96s/it]

Epoch: 0, Loss: 0.5400484204292297


Epoch 1/3:  24%|██▎       | 126/533 [29:39<1:41:46, 15.00s/it]

Epoch: 0, Loss: 0.6303163766860962


Epoch 1/3:  24%|██▍       | 127/533 [29:55<1:42:32, 15.15s/it]

Epoch: 0, Loss: 0.57998126745224


Epoch 1/3:  24%|██▍       | 128/533 [30:09<1:41:09, 14.99s/it]

Epoch: 0, Loss: 0.5379920601844788


Epoch 1/3:  24%|██▍       | 129/533 [30:23<1:39:06, 14.72s/it]

Epoch: 0, Loss: 0.6142863631248474


Epoch 1/3:  24%|██▍       | 130/533 [30:39<1:40:26, 14.95s/it]

Epoch: 0, Loss: 0.6433258056640625


Epoch 1/3:  25%|██▍       | 131/533 [30:54<1:40:05, 14.94s/it]

Epoch: 0, Loss: 0.5524150729179382


Epoch 1/3:  25%|██▍       | 132/533 [31:08<1:37:45, 14.63s/it]

Epoch: 0, Loss: 0.4959432780742645


Epoch 1/3:  25%|██▍       | 133/533 [31:21<1:35:39, 14.35s/it]

Epoch: 0, Loss: 0.5496160984039307


Epoch 1/3:  25%|██▌       | 134/533 [31:36<1:36:30, 14.51s/it]

Epoch: 0, Loss: 0.569138765335083


Epoch 1/3:  25%|██▌       | 135/533 [31:51<1:36:26, 14.54s/it]

Epoch: 0, Loss: 0.6699613332748413


Epoch 1/3:  26%|██▌       | 136/533 [32:05<1:35:31, 14.44s/it]

Epoch: 0, Loss: 0.3858725130558014


Epoch 1/3:  26%|██▌       | 137/533 [32:19<1:33:26, 14.16s/it]

Epoch: 0, Loss: 0.4793683886528015


Epoch 1/3:  26%|██▌       | 138/533 [32:34<1:34:40, 14.38s/it]

Epoch: 0, Loss: 0.5495122671127319


Epoch 1/3:  26%|██▌       | 139/533 [32:48<1:34:39, 14.42s/it]

Epoch: 0, Loss: 0.5260503888130188


Epoch 1/3:  26%|██▋       | 140/533 [33:02<1:33:13, 14.23s/it]

Epoch: 0, Loss: 0.46835049986839294


Epoch 1/3:  26%|██▋       | 141/533 [33:16<1:33:18, 14.28s/it]

Epoch: 0, Loss: 0.3780846893787384


Epoch 1/3:  26%|██▋       | 141/533 [33:32<1:33:13, 14.27s/it]


KeyboardInterrupt: 

In [None]:
def test():
    global soft_prompt
    soft_prompt = torch.load(cfg.checkpoint_path)
    print('model loaded')
    cfg.batch_size = 1
    data_loader = DataLoader(CustomDataset(test_dataset), batch_size=cfg.batch_size, )
    total = 0
    correct = 0
    for batch in data_loader:
        answer_logits, loss = model_p(batch)
        for i in range(cfg.batch_size):
            print(f'text: {batch["text"][i].strip()}')
            print(f'result/answer: {batch["result"][i].strip()}/{batch["answer"][i].strip()}')
            print()
            total += 1
            if batch["result"][i] == batch["answer"][i]:
                correct += 1

        print(f'correct: {correct}/{total} = {correct / total}')


In [None]:

test()