In [1]:
import sys
sys.path.insert(0, '../')

import transformers
import torch.nn as nn
from transformers import AlbertModel, AlbertConfig
import torch
import argparse
import os
import easydict
from torch.utils.data import DataLoader, ConcatDataset
import pyxis.torch as pxt

from consonant.model.tokenization import NGRAMTokenizer
from consonant.model.modeling import Consonant


In [2]:
def load_tokenizer_model(ckpt):
    """Restore the tokenizer and the Consonant model stored in a checkpoint.

    Args:
        ckpt: Path (or file-like object) accepted by ``torch.load``. The loaded
            object is indexed with the keys ``'ngram'``, ``'config_dict'`` and
            ``'model_state_dict'``.

    Returns:
        A ``(tokenizer, model)`` tuple: an ``NGRAMTokenizer`` built from
        ``state['ngram']`` and a ``Consonant`` model created from
        ``AlbertConfig(**state['config_dict'])`` with the checkpoint weights
        loaded into it. Moving the model to a device is left to the caller.

    Raises:
        KeyError: If one of the expected keys is missing from the checkpoint.
    """
    # Map every stored tensor to the CPU so a checkpoint written during a GPU
    # run can also be opened on a machine without CUDA, and so loading does not
    # allocate GPU memory for tensors that are copied into the model anyway.
    # NOTE(review): torch.load unpickles the file -- only open trusted checkpoints.
    state = torch.load(ckpt, map_location='cpu')
    tokenizer = NGRAMTokenizer(state['ngram'])

    config = AlbertConfig(**state['config_dict'])
    model = Consonant(config)
    model.load_state_dict(state['model_state_dict'])
    return tokenizer, model

# Checkpoint file to inspect (path is relative to the notebook's working directory).
ckpt = '../output/comment_baseline_b390_half/ckpt-0000100.bin'
# Rebuild the tokenizer and the model from that checkpoint.
tokenizer, model = load_tokenizer_model(ckpt)


In [3]:
def val_dataloader(args):
    """Build a DataLoader over all validation subsets of the dataset.

    Every sub-directory of ``<args.pretrain_dataset_dir>/val`` is opened with
    ``pxt.TorchDataset`` and the resulting datasets are concatenated.

    Args:
        args: Object exposing the attributes ``pretrain_dataset_dir``,
            ``train_batch_size`` and ``num_workers``.

    Returns:
        A ``DataLoader`` that iterates the concatenated dataset without
        shuffling, using pinned memory.
    """
    data_dir = os.path.join(args.pretrain_dataset_dir, 'val')

    # Keep only directories (the folder may also hold *.tar.gz archives).
    # os.listdir returns entries in arbitrary order, so sort them: with
    # shuffle=False this makes the sample order reproducible across runs.
    subset_list = sorted(
        subset_dir
        for subset_dir in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, subset_dir))
    )
    val_dataset = ConcatDataset(
        [pxt.TorchDataset(os.path.join(data_dir, subset_dir)) for subset_dir in subset_list]
    )

    # For quick debugging, a small slice can be taken instead, e.g.
    # torch.utils.data.Subset(val_dataset, range(0, 100)).

    return DataLoader(
        val_dataset,
        batch_size=args.train_batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        shuffle=False,
    )

In [4]:
# Settings read by val_dataloader: dataset root, batch size and the number of
# DataLoader worker processes.
args = easydict.EasyDict(
    pretrain_dataset_dir='../dataset/processed/ratings_3_100',
    train_batch_size=1,
    num_workers=0,
)


In [5]:
# Build the validation DataLoader from the settings in `args`.
valloader = val_dataloader(args)
# Move the model to the GPU. As the last expression of the cell, the call's
# return value is what the notebook displays (the module structure below).
model.cuda()

Consonant(
  (albert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(17579, 128, padding_idx=0)
      (position_embeddings): Embedding(100, 128)
      (token_type_embeddings): Embedding(1, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): AlbertTransformer(
      (embedding_hidden_mapping_in): Linear(in_features=128, out_features=512, bias=True)
      (albert_layer_groups): ModuleList(
        (0): AlbertLayerGroup(
          (albert_layers): ModuleList(
            (0): AlbertLayer(
              (full_layer_layer_norm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
              (attention): AlbertAttention(
                (query): Linear(in_features=512, out_features=512, bias=True)
                (key): Linear(in_features=512, out_features=512, bias=True)
                (value): Linear(in_features=512, out_features=512, bias=True)
          

In [6]:
# This cell only runs inference, so put the model in evaluation mode (disables
# training-only behaviour such as dropout) and skip autograd bookkeeping.
model.eval()

# Only the first validation batch is inspected; next() raises StopIteration
# if the loader is empty.
batch = next(iter(valloader))
input_ids = batch['head_ids'].long().cuda()
answer_label = batch['midtail_ids'].long().cuda()
attention_mask = batch['attention_masks'].long().cuda()

with torch.no_grad():
    output = model(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=None,
        answer_label=answer_label,
    )

In [7]:
# Index of the largest value along dimension 2 of the model's second output
# (presumably per-position class scores -- confirm against Consonant.forward).
predict_label = torch.argmax(output[1], dim=2)

In [8]:
def _to_numpy(tensor):
    """Detach a tensor from the graph and return it as a NumPy array on the host."""
    return tensor.detach().cpu().numpy()


print('===============')
for i in range(answer_label.shape[0]):
    # Zero the prediction wherever the reference label id is 0 (presumably
    # padding -- confirm), so both decoded strings cover the same positions.
    # The write goes through to `predict_label` itself.
    zero_label_positions = answer_label[i] == 0
    predict_label[i][zero_label_positions] = 0

    answer_string = tokenizer.decode_sent(
        _to_numpy(input_ids[i]), _to_numpy(answer_label[i])
    )
    predict_string = tokenizer.decode_sent(
        _to_numpy(input_ids[i]), _to_numpy(predict_label[i])
    )

    print('answer string\t:' + answer_string)
    print('predict string\t:' + predict_string)
    print('===============')
    

answer string	:우아....새벽에보니까 무섭네 미스터리와에로스와 스릴러를 겸비한 영화!
predict string	:이이....시비이비니까 미시나 미사타리이이리시이 시리리리 기비히 이하!
