In [1]:
import json
import pickle
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Dict

from dataset import SeqClsDataset
from utils import Vocab, Acc_counter
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from model import SeqClassifier

import torch
from tqdm import trange
import torch.optim as optim

import torch.nn.functional as F

# --- Setup: paths, device, cached artefacts, data, model ---------------------
cache_dir = Path('./cache/slot/')
data_dir = Path('./data/slot')

# Weight initialisation and dropout are stochastic; seed once so a
# Restart & Run All gives comparable numbers.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when one is present, otherwise fall back to CPU so the notebook
# still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# SECURITY: pickle.load can execute arbitrary code embedded in the file. Only
# load a vocab.pkl that was produced by this project's own preprocessing.
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

slot_idx_path = cache_dir / "tag2idx.json"
slot2idx: Dict[str, int] = json.loads(slot_idx_path.read_text())

# The train and eval splits are read separately and then concatenated into a
# single flat list, so the model below is fitted on both splits together.
data_paths = {split: data_dir / f"{split}.json" for split in ['train', 'eval']}
data_by_split = {split: json.loads(path.read_text()) for split, path in data_paths.items()}
data = [sample for samples in data_by_split.values() for sample in samples]

# Load onto CPU regardless of the device the tensor was saved from; the model
# (and everything it owns) is moved to `device` further down.
embeddings = torch.load(cache_dir / "embeddings.pt", map_location="cpu")

# The fourth positional argument (36) is presumably the maximum sequence
# length -- TODO confirm against SeqClsDataset's signature.
datasets = SeqClsDataset(data, vocab, slot2idx, 36)

# NOTE(review): shuffle=False means batches are visited in the same order
# every epoch; confirm this is intended for a training loader.
dataloader = DataLoader(
    datasets,
    batch_size=256,
    shuffle=False,
    collate_fn=datasets.collate_fn,
)

model = SeqClassifier(
    embeddings=embeddings,
    hidden_size=128,
    num_layers=2,
    dropout=0.1,
    bidirectional=True,
    num_class=len(slot2idx),
)
# Only parameters that require gradients are handed to the optimizer.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
criterion = torch.nn.CrossEntropyLoss()
model.to(device)
criterion.to(device)
# Kept as the last expression so the cell displays the module summary.
model.train()

SeqClassifier(
  (embed): Embedding(4117, 300)
  (rnn): GRU(300, 128, num_layers=2, dropout=0.1, bidirectional=True)
  (fc): Linear(in_features=256, out_features=9, bias=True)
)

In [None]:
# Scratch cell: open a manual iterator over the loader so that a single batch
# can be pulled with next(d) for inspection.
d=iter(dataloader)

In [None]:
# Scratch cell: take the next batch from the manual iterator and push it
# through the model once, without computing a loss or updating weights.
first_batch = next(d)
y, x, length = first_batch
_y = model(x, length, device)

In [2]:
# Training loop: ten passes over `dataloader`, one optimizer step per batch.
# `f` is not used inside this cell; the assignment is retained because it
# rebinds the notebook-level name `f`.
f = torch.nn.LogSoftmax(dim=1)

for epoch in range(10):
    # Fresh counters every epoch, so the printed figures cover this epoch only.
    token_acc, sent_acc = Acc_counter(), Acc_counter()

    for y, x, length in dataloader:
        _y = model(x, length, device)
        # loss_and_acc returns the value that is back-propagated below; it is
        # also handed both counters (presumably to update them -- confirm in
        # SeqClassifier.loss_and_acc).
        loss = model.loss_and_acc(_y, y, criterion, device, sent_acc, token_acc)

        # Clear stale gradients, back-propagate, then apply the update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # One line per epoch: "<epoch> <sent_acc.out()>/<token_acc.out()>".
    print(f"{epoch} {sent_acc.out()}/{token_acc.out()}")

0 40.2717%/80.7844%
1 55.8224%/90.4740%
2 69.2261%/94.2495%
3 75.2305%/95.7008%
4 79.0272%/96.5082%
5 81.7200%/97.0429%
6 83.7215%/97.4389%
7 85.8079%/97.7989%
8 86.9602%/98.0250%
9 88.4279%/98.2479%
