In [1]:
import json
import pickle
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Dict

from dataset import SeqClsDataset
from utils import Vocab, Acc_counter
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from model import SeqClassifier

import torch
from tqdm import trange
import torch.optim as optim

import torch.nn.functional as F

cache_dir = Path('./cache/slot/')
data_dir = Path('./data/slot')

device = torch.device('cuda')

with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

slot_idx_path = cache_dir / "tag2idx.json"
slot2idx: Dict[str, int] = json.loads(slot_idx_path.read_text())

data_paths = {split: data_dir / f"{split}.json" for split in ['train','eval']}
data = {split: json.loads(path.read_text()) for split, path in data_paths.items()}
data=[y for x in data.keys() for y in data[x] ]

embeddings = torch.load(cache_dir / "embeddings.pt")

datasets = SeqClsDataset(data, vocab, slot2idx, 36)
dataloader=DataLoader(datasets,batch_size=2,shuffle=False,collate_fn=datasets.collate_fn)

test_file = data_dir / "test.json"
test_datasets = SeqClsDataset(json.loads(test_file.read_text()), vocab, slot2idx, 36)

model = SeqClassifier(embeddings=embeddings,hidden_size=128,num_layers=2,dropout=0.1,bidirectional=True,num_class=len(slot2idx))
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
criterion = torch.nn.CrossEntropyLoss()
model.to(device)
criterion.to(device)
model.train()

SeqClassifier(
  (embed): Embedding(4117, 300)
  (rnn): GRU(300, 128, num_layers=2, dropout=0.1, bidirectional=True)
  (fc): Linear(in_features=256, out_features=9, bias=True)
)

In [2]:
for y, x, l in dataloader:
    out=model(x,l,device)

In [3]:
y

[[5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 8, 0, 5], [5, 5, 5, 5]]

In [5]:
out.shape

torch.Size([2, 36, 9])

In [6]:
[1,2,3,4,5][:3]

[1, 2, 3]

In [7]:
datasets.data

[{'tokens': ['i', 'have', 'three', 'people', 'for', 'august', 'seventh'],
  'tags': ['O', 'O', 'B-people', 'I-people', 'O', 'B-date', 'O'],
  'id': 'train-0'},
 {'tokens': ['do', 'you', 'have', 'highchairs', 'for', 'my', '4', 'kids'],
  'tags': ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'],
  'id': 'train-1'},
 {'tokens': ['i', 'want', 'the', 'west', 'central', 'neighborhood'],
  'tags': ['O', 'O', 'O', 'O', 'O', 'O'],
  'id': 'train-2'},
 {'tokens': ['a',
   'table',
   'for',
   '2',
   'adults',
   'and',
   '4',
   'children',
   'please'],
  'tags': ['O',
   'O',
   'O',
   'B-people',
   'I-people',
   'I-people',
   'I-people',
   'I-people',
   'O'],
  'id': 'train-3'},
 {'tokens': ["i'd",
   'like',
   'information',
   'for',
   'the',
   'royal',
   'george',
   'strand',
   'restaurant',
   'please'],
  'tags': ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'],
  'id': 'train-4'},
 {'tokens': ['date', 'the', '1st', 'of', 'june'],
  'tags': ['O', 'O', 'B-date', 'I-date', 'I-date