In [1]:
import datasets
import pandas as pd
import transformers
import torch
import os
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm as tqdm

from utils import EarlyStopping
import re

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
device = torch.device("cuda")

In [2]:
data = datasets.load_dataset(path='multi_woz_v22', cache_dir='/data/.cache/huggingface/datasets')

No config specified, defaulting to: multi_woz_v22/v2.2_active_only
Found cached dataset multi_woz_v22 (/data/.cache/huggingface/datasets/multi_woz_v22/v2.2_active_only/2.2.0/6719c8b21478299411a0c6fdb7137c3ebab2e6425129af831687fb7851c69eb5)


  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2', cache_dir='/data/.cache/huggingface/transformers', bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|pad|>', return_value='pt')
tokenizer.pad_token = tokenizer.eos_token
model = transformers.AutoModelForTokenClassification.from_pretrained('gpt2', cache_dir='/data/.cache/huggingface/transformers')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of GPT2ForTokenClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.4.attn.masked_bias', 'h.10.attn.masked_bias', 'classifier.weight', 'h.0.attn.masked_bias', 'h.11.attn.masked_bias', 'h.2.attn.masked_bias', 'h.9.attn.masked_bias', 'h.7.attn.masked_bias', 'h.3.attn.masked_bias', 'h.1.attn.masked_bias', 'h.6.attn.masked_bias', 'h.8.attn.masked_bias', 'classifier.bias', 'h.5.attn.masked_bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
class DialogDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer = tokenizer, max_length=None):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = self.tokenizer.model_max_length if max_length is None else max_length
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        output = {}
        utterances = self.data[index]['turns']['utterance']
        each_utt = [re.sub("[^ +a-zA-Z0-9]+", "", x) for x in utterances]
        each_utt = [x.lower() for x in each_utt]
        label = tokenizer(each_utt)['attention_mask']
        for i in range(len(label)):
            label[i][-1] = 0
        token_label = torch.tensor(sum(label, []), dtype=torch.long)[:self.max_length]
        token_label = (~token_label.bool()).float()
        if len(token_label) < self.max_length: # assign padding token label
            token_label = torch.cat([token_label[:self.max_length], torch.zeros(self.max_length - len(token_label))])
        token_dict = tokenizer(' '.join(each_utt), truncation=True, max_length=self.max_length, padding="max_length", return_tensors='pt')
        output['input_ids'], output['attention_mask'] = token_dict['input_ids'], token_dict['attention_mask']
        output['labels'] = token_label.type(torch.LongTensor)
        return output

In [33]:
###########TODO: shape 해결 필요 , squeeze() 하지 않게.
train_dataset[0]['input_ids'].shape, train_dataset[0]['attention_mask'].shape 

(torch.Size([1, 1024]), torch.Size([1, 1024]))

In [5]:
train_dataset = DialogDataset(data['train'], tokenizer)
valid_dataset = DialogDataset(data['validation'], tokenizer)
test_dataset = DialogDataset(data['test'], tokenizer)

In [6]:
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [7]:
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
early_stopping = EarlyStopping(patience=3, verbose=True, path=f'saved/best_model.pt')

In [8]:
epochs = 20

In [9]:
train_loss = 0.0
for each_epoch in range(1, epochs):
    model.train()
    for each_batch in tqdm(train_loader):
        optimizer.zero_grad()
        input_ids = each_batch['input_ids'].to(device)
        attention_mask = each_batch['attention_mask'].to(device)
        labels = each_batch['labels'].to(device)
        out = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = out.loss
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    with torch.no_grad():
        model.eval()
        valid_loss = 0.0
        for each_batch in tqdm(valid_loader):
            input_ids = each_batch['input_ids'].to(device)
            attention_mask = each_batch['attention_mask'].to(device)
            labels = each_batch['labels'].to(device)
            out = model(input_ids, attention_mask=attention_mask, labels=labels)
            valid_loss += out.loss.item()
        print(f'Epoch {each_epoch}: Train Loss: {train_loss / len(train_loader):.4f}, Valid Loss: {valid_loss / len(valid_loader):.4f}')

        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print(f"Early stopping at {each_epoch-early_stopping.patience} !")
            break

  0%|          | 0/528 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1: Train Loss: 0.1925, Valid Loss: 0.1721

Validation loss decreased (inf --> 10.841613).  Saving model ...


  0%|          | 0/528 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 2: Train Loss: 0.3549, Valid Loss: 0.1635

Validation loss decreased (10.841613 --> 10.299261).  Saving model ...


  0%|          | 0/528 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 3: Train Loss: 0.5050, Valid Loss: 0.1616

Validation loss decreased (10.299261 --> 10.179914).  Saving model ...


  0%|          | 0/528 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 4: Train Loss: 0.6434, Valid Loss: 0.1668
EarlyStopping counter: 1 out of 3


  0%|          | 0/528 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 5: Train Loss: 0.7695, Valid Loss: 0.1629
EarlyStopping counter: 2 out of 3


  0%|          | 0/528 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

Epoch 6: Train Loss: 0.8832, Valid Loss: 0.1773
EarlyStopping counter: 3 out of 3


NameError: name 'epoch' is not defined

In [26]:
out['logits'].shape

torch.Size([8, 1, 1024, 2])

In [107]:
def test_model(model, test_loader):
	model.load_state_dict(torch.load(f'saved/best_model.pt'))
	model.eval()
	logit_test = []
	pred_test = []
	real_test = []
	with torch.no_grad():
		for each_batch in tqdm(test_loader):
			input_ids = each_batch['input_ids'].to(device)
			attention_mask = each_batch['attention_mask'].to(device)
			out = model(input_ids, attention_mask=attention_mask)
			logit_test.append(out.logits.cpu())
			pred_test.append(torch.argmax(out.logits, dim=-1).cpu())
			real_test.append(each_batch['labels'].cpu())
	return torch.cat(logit_test), torch.cat(pred_test), torch.cat(real_test)

In [38]:
def get_pred(pred):
	prediction = []
	for each_pred in pred:
		prediction.append(each_pred[0])
	return prediction

In [108]:
test_logit, test_pred, test_real = test_model(model, test_loader)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [109]:
test_logit.shape, test_pred.shape, test_real.shape

(torch.Size([1000, 1, 1024, 2]),
 torch.Size([1000, 1, 1024]),
 torch.Size([1000, 1024]))

In [110]:
# Recall w.r.t. class==1 (끝 지점)
test_recall = ((test_pred.squeeze()==1) & (test_real==1)).sum() / (test_real==1).sum()
test_recall

tensor(0.3055)

In [111]:
# Precision w.r.t. class==1 (끝 지점)
test_precision = ((test_pred.squeeze()==1) & (test_real==1)).sum() / (test_pred.squeeze()==1).sum()
test_precision

tensor(0.4907)

In [112]:
# f1 w.r.t. class==1 (끝 지점)
test_f1_turn = 2 * test_precision * test_recall / (test_precision + test_recall)
test_f1_turn

tensor(0.3765)

In [113]:
def extract_sample(sample, sample_logits):
    end_count = 0
    for each_sample_id, each_sample_logit in zip(sample['input_ids'].squeeze(), sample_logits.squeeze()):
        print(tokenizer.decode(each_sample_id), end= " ")
        if each_sample_logit.argmax() ==1:   # threshold 없이 argmax가 1인 지점
            print()
            print(f'        %Prob% : {torch.sigmoid(each_sample_logit[1]):.4f}')
            print(f'        %Logit% : {each_sample_logit[1].item()}')
        if each_sample_id==50256:
            end_count+=1
        else:
            end_count=0
        if end_count>5:
            break

epoch 3 (=early stopping)

In [114]:
extract_sample(test_dataset[0], test_logit[0])

i  need  train  reservations  from  nor wich  to  cam bridge 
        %Prob% : 0.4344
        %Logit% : -0.26409125328063965
 i  have  133  trains  matching  your  request  is  there  a  specific  day  and  time  you  would  like  to  travel 
        %Prob% : 0.5647
        %Logit% : 0.26006728410720825
 id  like  to  leave  on  m onday  and  arrive  by  1800  there  are  12  trains  for  the  day  and  time  you  request  would  you  like  to  book  it  now 
        %Prob% : 0.5528
        %Logit% : 0.21185562014579773
 before  booking  i  would  also  like  to  know  the  travel  time  price  and  departure  time  please 
        %Prob% : 0.6617
        %Logit% : 0.6708519458770752
 there  are  12  trains  meeting  your  needs  with  the  first  leaving  at  05 16  and  the  last  one  leaving  at  16 16  do  you  want  to  book  one  of  these  no  hold  off  on  booking  for  now  can  you  help  me  find  an  attraction  called  c in eworld  cinema 
        %Prob% : 0.5407
       

In [98]:
data['test'][0]['turns']['utterance']

['I need train reservations from norwich to cambridge',
 'I have 133 trains matching your request. Is there a specific day and time you would like to travel?',
 "I'd like to leave on Monday and arrive by 18:00.",
 'There are 12 trains for the day and time you request. Would you like to book it now?',
 'Before booking, I would also like to know the travel time, price, and departure time please.',
 'There are 12 trains meeting your needs with the first leaving at 05:16 and the last one leaving at 16:16. Do you want to book one of these?',
 'No hold off on booking for now. Can you help me find an attraction called cineworld cinema?',
 'Yes it is a cinema located in the south part of town what information would you like on it?',
 'Yes, that was all I needed. Thank you very much!',
 'Thank you for using our system.']

In [118]:
extract_sample(test_dataset[1], test_logit[1])

hello  i  am  looking  for  a  restaurant  in  cam bridge  i  believe  it  is  called  golden  w ok  it  is  located  at  191  hist on  road  chest erton  can  you  book  me  a  table  for  1100  on  fr iday  yes  i  can  table  for  1  actually  for  4  please  okay  your  booking  was  successful  the  reference  number  is  m uf c my ff    the  table  will  be  reserved  for  15  minutes  great 
        %Prob% : 0.5345
        %Logit% : 0.13834477961063385
 can  you  also  get  me  information  or  architecture  in  the  area  sure  there  are  several  churches  and  an  old  schools  attraction  all  in  the  centre  area  do  you  have  a  preference  what  do  you  recommend  old  schools  is  lovely  they  are  on  tr inity  lane  and  free  admission  can  i  get  the  post code  for  that  i  also  need  to  book  a  taxi  to  the  golden  w ok  the  post code  is  c b 21 tt  are  you  looking  for  a  taxi  from  old  schools  to  the  golden  w ok  yes  i  do  id  like  to 

In [119]:
data['test'][1]['turns']['utterance']

['Hello, I am looking for a restaurant in Cambridge. I believe it is called Golden Wok.',
 'It is located at 191 Histon Road Chesterton',
 'Can you book me a table for 11:00 on Friday?',
 'Yes I can! Table for 1?',
 'Actually, for 4, please.',
 'Okay, your booking was successful! The reference number is MUFCMYFF . The table will be reserved for 15 minutes.',
 'Great, can you also get me information or architecture in the area',
 'Sure. There are several churches and an old schools attraction, all in the centre area. Do you have a preference?',
 'What do you recommend?',
 'old schools is lovely, they are on trinity lane and free admission',
 'Can I get the postcode for that? I also need to book a taxi to the Golden Wok.',
 'The postcode is cb21tt. Are you looking for a taxi from Old Schools to the Golden Wok?',
 "Yes I do. I'd like to make sure I arrive at the restaurant by the booked time. Can you check?",
 'What time do you want to leave?',
 'Actually all you have to do is set the tax

epoch 6 (= final training) version.

In [50]:
test_recall = ((test_pred.squeeze()==1) & (test_real==1)).sum() / (test_real==1).sum()
test_recall

tensor(0.3969)

In [51]:
test_precision = ((test_pred.squeeze()==1) & (test_real==1)).sum() / (test_pred.squeeze()==1).sum()
test_precision

tensor(0.1838)

In [53]:
test_f1_turn = 2 * test_precision * test_recall / (test_precision + test_recall)
test_f1_turn

tensor(0.2512)

In [104]:
extract_sample(test_dataset[0], test_logit[0])

i  need  train  reservations  from  nor wich  to  cam bridge  i  have  133  trains  matching  your  request  is  there  a  specific  day  and  time  you  would  like  to  travel 
        %Prob% : 0.7557
        %Logit% : 1.1293197870254517
 id  like  to  leave  on  m onday  and  arrive  by  1800 
        %Prob% : 0.4216
        %Logit% : -0.316301167011261
 there  are  12  trains  for  the  day  and  time  you  request 
        %Prob% : 0.4253
        %Logit% : -0.30112752318382263
 would  you  like  to  book  it  now 
        %Prob% : 0.5885
        %Logit% : 0.3577260971069336
 before  booking  i  would  also  like  to  know  the  travel  time  price 
        %Prob% : 0.5151
        %Logit% : 0.060555506497621536
 and  departure 
        %Prob% : 0.4736
        %Logit% : -0.10589227080345154
 time  please 
        %Prob% : 0.6415
        %Logit% : 0.5819669365882874
 there  are  12  trains  meeting  your  needs  with  the  first  leaving  at  05 16  and  the  last  one  leaving  at  

https://huggingface.co/transformers/v3.0.2/_modules/transformers/modeling_bert.html#BertForTokenClassification

AutoModelForTokenClassification: 
1. sequence_output = AutoModel 의 output(각 토큰의 vector)
2. sequence_output = dropout(config.hidden_dropout_prob)
3. logits = Linear(token_dim -> num_labels)
4. loss = CrossEntropy 인데, attention_mask가 1인 토큰만 계산. logits, labels 모두