In [33]:
import json
import pickle
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Dict

import torch
from tqdm import trange

from dataset import SeqClsDataset
from utils import Vocab

from torch.utils.data import DataLoader, SubsetRandomSampler
from model import SeqClassifier
import torch.optim as optim
import torch.backends.cudnn as cudnn

import csv

TRAIN = "train"
DEV = "eval"
SPLITS = [TRAIN, DEV]

# Input locations, relative to the notebook's working directory.
cache_dir = Path("./cache/intent/")   # holds vocab.pkl, intent2idx.json, embeddings.pt
data_dir = Path("./data/intent/")
test_file = Path("./data/intent/test.json")

# Trained weights evaluated by this notebook (previously a bare `Path()` placeholder).
ckpt_path = Path("./ckpt/intent/intent_best_model.pth")

# Use the GPU when one is present, otherwise fall back to CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
# Token vocabulary produced by preprocessing and cached on disk.
# NOTE(review): pickle.loads can execute arbitrary code — only ever load a
# vocab.pkl you generated yourself, never one from an untrusted source.
vocab: Vocab = pickle.loads((cache_dir / "vocab.pkl").read_bytes())

In [10]:
# Intent name -> class index mapping stored next to the vocab.
intent_idx_path = cache_dir / "intent2idx.json"
# JSON is UTF-8 by specification; pass the encoding explicitly so the platform
# default (e.g. cp1252 on Windows) cannot mis-decode non-ASCII text.
intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text(encoding="utf-8"))
data = json.loads(test_file.read_text(encoding="utf-8"))

# NOTE(review): the 4th positional argument (32) is presumably the maximum
# sequence length — later cells encode to `dataset.max_len` — confirm against
# SeqClsDataset's signature.
dataset = SeqClsDataset(data, vocab, intent2idx, 32)

In [12]:
# Embedding weight table (the model repr shows a 6491 x 300 Embedding).
# map_location="cpu" lets a tensor saved from a GPU session load on any host;
# the whole model is moved to `device` in a later cell.
embeddings = torch.load(cache_dir / "embeddings.pt", map_location="cpu")

# Architecture hyper-parameters. Hidden size, layer count and directionality
# must match the trained checkpoint, or load_state_dict will reject it.
HIDDEN_SIZE = 1024
NUM_LAYERS = 3
DROPOUT = 0.01
BIDIRECTIONAL = True

model = SeqClassifier(
    embeddings,
    HIDDEN_SIZE,
    NUM_LAYERS,
    DROPOUT,
    BIDIRECTIONAL,
    dataset.num_classes,
)

In [14]:
# Restore the trained weights. Loading onto the CPU works regardless of where
# the checkpoint was saved; load_state_dict copies the values into the model's
# existing parameters, which are moved to `device` afterwards.
ckpt_path = Path("./ckpt/intent/intent_best_model.pth")
ckpt = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(ckpt)

<All keys matched successfully>

In [15]:
# Put the parameters on the target device and switch to inference mode
# (disables dropout). Both calls return the module itself, so the chained
# expression still displays the model repr as this cell's output.
model.to(device).eval()

SeqClassifier(
  (embed): Embedding(6491, 300)
  (rnn): GRU(300, 1024, num_layers=3, dropout=0.01, bidirectional=True)
  (fc): Linear(in_features=1024, out_features=150, bias=True)
)

In [19]:
# Whitespace-tokenise every test utterance. The comprehension variable is
# `entry`, so the parsed test split bound to `data` is no longer overwritten
# by the last example.
test_tokens = [entry["text"].split() for entry in dataset.data]

# Map tokens to ids with sequence length `dataset.max_len`.
# NOTE(review): padding/truncation behaviour lives in Vocab.encode_batch — confirm there.
test_ids = dataset.vocab.encode_batch(batch_tokens=test_tokens, to_len=dataset.max_len)
x = torch.tensor(test_ids, dtype=torch.int64, device=device)

In [30]:
# Rows per forward pass. Pushing the whole test set through the 3-layer,
# 1024-unit bidirectional GRU at once can exhaust accelerator memory.
PRED_BATCH_SIZE = 256

# Predict one class index per utterance.
# - no_grad(): pure inference, so don't record an autograd graph (the original
#   kept every activation alive for a backward pass that never happens).
# - log-softmax is omitted: it is monotonic within each row, so argmax over the
#   raw model outputs selects the same class.
with torch.no_grad():
    p_label = torch.cat(
        [model(chunk).argmax(dim=1) for chunk in x.split(PRED_BATCH_SIZE)]
    )

In [49]:
# Copy predictions to host memory once. `Tensor.to` returns a new tensor
# rather than converting in place, so its result must be used; converting up
# front also avoids a device->host sync for every row.
pred_indices = p_label.cpu().tolist()
assert len(pred_indices) == len(dataset.data), "one prediction per test example expected"

with open("test.csv", "w", newline="", encoding="utf-8") as csv_file:
    writer = csv.writer(csv_file)
    writer.writerow(["id", "intent"])
    for n, (entry, label_idx) in enumerate(zip(dataset.data, pred_indices)):
        # Use the example's own id when it carries one, else the positional id.
        # NOTE(review): presence of an "id" key in test.json is assumed — confirm.
        example_id = entry.get("id", f"test-{n}")
        writer.writerow([example_id, dataset.idx2label(label_idx)])