In [1]:
import json
import pickle
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Dict

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from tqdm import trange

from dataset import SeqClsDataset
from model import SeqClassifier
from utils import Vocab, Acc_counter

# Paths to the preprocessed slot-tagging cache and the raw JSON splits.
cache_dir = Path('./cache/slot/')
data_dir = Path('./data/slot')

# Passed as the 4th argument of SeqClsDataset; presumably the maximum sequence
# length (TODO confirm against dataset.py). The length statistics computed
# further down report a longest utterance of 35 tokens.
MAX_LEN = 36
BATCH_SIZE = 2
SEED = 0

# Seed before building the model so weight initialisation is reproducible.
torch.manual_seed(SEED)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE: pickle.load can execute arbitrary code -- only load caches you built yourself.
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

slot_idx_path = cache_dir / "tag2idx.json"
slot2idx: Dict[str, int] = json.loads(slot_idx_path.read_text())

data_paths = {split: data_dir / f"{split}.json" for split in ['train', 'eval']}
split_data = {split: json.loads(path.read_text()) for split, path in data_paths.items()}
# NOTE(review): train and eval samples are concatenated (train first) and wrapped
# in a single dataset/dataloader -- only appropriate for a final fit on all
# labelled data; keep eval separate when measuring generalisation.
data = [sample for samples in split_data.values() for sample in samples]

# map_location='cpu' lets a tensor saved from a GPU session load on a CPU-only
# machine; the model's Embedding layer is moved by model.to(device) below.
embeddings = torch.load(cache_dir / "embeddings.pt", map_location='cpu')

datasets = SeqClsDataset(data, vocab, slot2idx, MAX_LEN)
dataloader = DataLoader(datasets, batch_size=BATCH_SIZE, shuffle=False, collate_fn=datasets.collate_fn)

test_file = data_dir / "test.json"
test_datasets = SeqClsDataset(json.loads(test_file.read_text()), vocab, slot2idx, MAX_LEN)

model = SeqClassifier(embeddings=embeddings, hidden_size=128, num_layers=2, dropout=0.1,
                      bidirectional=True, num_class=len(slot2idx))
# Move the model to its device *before* constructing the optimizer, as the
# torch.optim documentation recommends.
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
criterion.to(device)
# Only optimise trainable parameters (e.g. a frozen embedding is skipped).
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
model.train()

SeqClassifier(
  (embed): Embedding(4117, 300)
  (rnn): GRU(300, 128, num_layers=2, dropout=0.1, bidirectional=True)
  (fc): Linear(in_features=256, out_features=9, bias=True)
)

In [None]:
# Smoke test: run a forward pass over every batch to check the model accepts
# the collated batches. No optimizer step happens here, so autograd is
# disabled to avoid building computation graphs for nothing.
# NOTE(review): the unpacking order assumes collate_fn yields
# (labels, inputs, lengths) -- confirm against dataset.py.
with torch.no_grad():
    for y, x, l in dataloader:
        out = model(x, l, device)

In [2]:
# Number of tokens in each sample of the merged dataset built in the first cell.
l = list(map(lambda sample: len(sample['tokens']), datasets.data))

In [3]:
# Average number of tokens per sample.
total_tokens = sum(l)
total_tokens / len(l)

7.781538088306648

In [4]:
# Longest sample in tokens -- compare with the 36 passed to SeqClsDataset in the first cell.
longest = max(l)
longest

35

In [5]:
# Same per-sample token counts as above, rebuilt from the dataset as a NumPy
# array so mean/std can be computed directly (numpy is imported in the top cell).
l = np.array([len(sample['tokens']) for sample in datasets.data])

In [9]:
# Mean + 2 standard deviations of the token counts -- presumably a candidate
# padding length that covers most samples.
# NOTE: the second number saved under this cell (~4.44) is stale output from an
# earlier run; it equals l.std() ((16.663 - 7.782) / 2).
len_mean, len_std = l.mean(), l.std()
len_mean + 2 * len_std

16.663303644051194

4.4408827778722735

In [14]:
# Scratch: layer-by-layer summary of a torchvision ResNet-18 on a 3x256x256 input.
# NOTE(review): unrelated to the slot-tagging model above -- consider moving
# these experiments to their own notebook.
import torchvision.models as models
from torchsummary import summary

# Use the notebook's device instead of a hard-coded .cuda() so this also runs
# on CPU-only machines.
m = models.resnet18().to(device)
# torchsummary builds its dummy input on `device` ("cuda" by default), so pass
# the device type the model actually lives on.
summary(m, (3, 256, 256), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           9,408
       BatchNorm2d-2         [-1, 64, 128, 128]             128
              ReLU-3         [-1, 64, 128, 128]               0
         MaxPool2d-4           [-1, 64, 64, 64]               0
            Conv2d-5           [-1, 64, 64, 64]          36,864
       BatchNorm2d-6           [-1, 64, 64, 64]             128
              ReLU-7           [-1, 64, 64, 64]               0
            Conv2d-8           [-1, 64, 64, 64]          36,864
       BatchNorm2d-9           [-1, 64, 64, 64]             128
             ReLU-10           [-1, 64, 64, 64]               0
       BasicBlock-11           [-1, 64, 64, 64]               0
           Conv2d-12           [-1, 64, 64, 64]          36,864
      BatchNorm2d-13           [-1, 64, 64, 64]             128
             ReLU-14           [-1, 64,

In [15]:
# Scratch: same summary for ResNet-34 (relies on `models`/`summary` imported in
# the ResNet-18 cell). Uses the notebook's device rather than a hard-coded
# .cuda(), and tells torchsummary (which defaults to "cuda") where the model is.
m = models.resnet34().to(device)
summary(m, (3, 256, 256), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           9,408
       BatchNorm2d-2         [-1, 64, 128, 128]             128
              ReLU-3         [-1, 64, 128, 128]               0
         MaxPool2d-4           [-1, 64, 64, 64]               0
            Conv2d-5           [-1, 64, 64, 64]          36,864
       BatchNorm2d-6           [-1, 64, 64, 64]             128
              ReLU-7           [-1, 64, 64, 64]               0
            Conv2d-8           [-1, 64, 64, 64]          36,864
       BatchNorm2d-9           [-1, 64, 64, 64]             128
             ReLU-10           [-1, 64, 64, 64]               0
       BasicBlock-11           [-1, 64, 64, 64]               0
           Conv2d-12           [-1, 64, 64, 64]          36,864
      BatchNorm2d-13           [-1, 64, 64, 64]             128
             ReLU-14           [-1, 64,