In [21]:
from data.multiwoz_loader import create_utterances_dataset, BucketBatchSampler, tokenizer, collator
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import GPT2LMHeadModel
from transformers import AdamW
from transformers import get_scheduler
import torch
from tqdm.auto import tqdm
import re
import json

from model import fit

In [5]:
raw_datasets = load_dataset("multi_woz_v22", ignore_verifications=True)

utterances_dataset_train = create_utterances_dataset(raw_datasets)
utterances_dataset_valid = create_utterances_dataset(raw_datasets,split = "validation")

BATCH_SIZE = 10

bucket_batch_sampler_train = BucketBatchSampler(BATCH_SIZE, utterances_dataset_train)
bucket_batch_sampler_valid = BucketBatchSampler(BATCH_SIZE, utterances_dataset_valid)

train_dataloader = DataLoader(utterances_dataset_train, batch_sampler = bucket_batch_sampler_train, shuffle=False, drop_last=False, collate_fn=collator)
eval_dataloader = DataLoader(utterances_dataset_valid, batch_sampler = bucket_batch_sampler_valid, shuffle=False, drop_last=False, collate_fn=collator)

  0%|          | 0/56776 [00:00<?, ?ex/s]

  0%|          | 0/56776 [00:00<?, ?ex/s]

  0%|          | 0/7374 [00:00<?, ?ex/s]

  0%|          | 0/7374 [00:00<?, ?ex/s]

In [6]:
model = GPT2LMHeadModel.from_pretrained('gpt2')

model.resize_token_embeddings(len(tokenizer))

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)


Downloading:   0%|          | 0.00/523M [00:00<?, ?B/s]

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50261, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )


In [7]:
optimizer = AdamW(model.parameters(), lr=5e-5)

num_epochs = 5
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

fit(num_epochs, model, optimizer, train_dataloader, eval_dataloader, optimizer, lr_scheduler, num_training_steps)

  0%|          | 0/283880 [00:00<?, ?it/s]

Validation loss 1.2420
Validation loss 1.0715
Validation loss 1.5776
Validation loss 1.1670
Validation loss 1.0882


In [8]:
torch.save(model, "5-epoch-run-0-corrected")

Decoding

In [10]:
utterances_dataset_test = create_utterances_dataset(raw_datasets,split = "test")

  0%|          | 0/7372 [00:00<?, ?ex/s]

  0%|          | 0/7372 [00:00<?, ?ex/s]

In [11]:
bucket_batch_sampler_test = BucketBatchSampler(BATCH_SIZE, utterances_dataset_test)

In [12]:
test_dataloader = DataLoader(utterances_dataset_test, batch_sampler = bucket_batch_sampler_test, shuffle=False, drop_last=False, collate_fn=collator)

In [13]:
from belief_parser import BeliefParser

bp = BeliefParser(['hotel :', 'restaurant :', 'attraction :', 'train :', 'taxi :', 'police :', 'hospital :'])

In [17]:
from database.database import MultiWOZDatabase
db = MultiWOZDatabase()

In [None]:
for batch in tqdm(test_dataloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    batch_attention_mask = batch['context_mask'] + batch['belief_mask'] + batch['database_mask'] + batch['utterance_mask']
    batch_ids = batch['labels']
    dialog_state_targets = batch_ids * batch['belief_mask']
    utterance_targets = batch_ids * batch['utterance_mask']

    utterance_targets[batch['utterance_mask'] == 0] = -100
    dialog_state_targets[batch['belief_mask'] == 0] = -100
    for i in range(batch_ids.shape[0]):
        input = {'input_ids':batch['labels'][i,:][batch['context_mask'][i,:]]}
        input['input_ids'] = input['input_ids'].unsqueeze(0)
        output_ids = model.generate(**input, eos_token_id = tokenizer.encode(['}}'])[0], max_length = 200)
        belief_state = tokenizer.batch_decode(output_ids[:,input['input_ids'].shape[1]:])
        parsed_belief_state = bp(belief_state[0])
        db_results = {}
        for domain,constraints in parsed_belief_state.items():
            db_results[domain] = len(db.query(domain, constraints = constraints))

        db_result_encoded = ''.join(('<|database|> ', re.sub('[\[\]{}"]', '', json.dumps(db_results)), '<|endoftext|> '))

        new_prompt = torch.cat((output_ids.squeeze(), tokenizer(db_result_encoded, return_tensors='pt')['input_ids'].squeeze().to(device)))
        output_ids = model.generate(new_prompt.unsqueeze(0), max_length = 400)

        system_utterance = tokenizer.batch_decode(output_ids[:,len(new_prompt):])

        with open('multiwoz_outputs.txt', 'a') as f:
            f.write(f'{belief_state[0]}\t{system_utterance[0]}\n')


In [None]:
for batch in tqdm(test_dataloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    batch_attention_mask = batch['context_mask'] + batch['belief_mask'] + batch['database_mask'] + batch['utterance_mask']
    batch_ids = batch['labels']
    dialog_state_targets = batch_ids * batch['belief_mask']
    utterance_targets = batch_ids * batch['utterance_mask']

    utterance_targets[batch['utterance_mask'] == 0] = -100
    dialog_state_targets[batch['belief_mask'] == 0] = -100
    for i in range(batch_ids.shape[0]):
        input = {'input_ids':batch['labels'][i,:][batch['context_mask'][i,:]]}
        input['input_ids'] = input['input_ids'].unsqueeze(0)
        output_ids = model.generate(**input, eos_token_id = tokenizer.encode(['}}'])[0], max_length = 200)
        belief_state = tokenizer.batch_decode(output_ids[:,input['input_ids'].shape[1]:])
        parsed_belief_state = bp(belief_state[0])
        db_results = {}
        for domain,constraints in parsed_belief_state.items():
            db_results[domain] = len(db.query(domain, constraints = constraints))

        db_result_encoded = ''.join(('<|database|> ', re.sub('[\[\]{}"]', '', json.dumps(db_results)), '<|endoftext|> '))

        new_prompt = torch.cat((output_ids.squeeze(), tokenizer(db_result_encoded, return_tensors='pt')['input_ids'].squeeze().to(device)))
        output_ids = model.generate(new_prompt.unsqueeze(0), max_length = 400)

        system_utterance = tokenizer.batch_decode(output_ids[:,len(new_prompt):])



        with open('multiwoz_outputs.txt', 'a') as f:
            f.write(f'{belief_state[0]}\t{system_utterance[0]}\n')
