In [1]:
import sys
sys.path.insert(0, '../')

import transformers
import torch.nn as nn
from transformers import AlbertModel, AlbertConfig, get_linear_schedule_with_warmup
from transformers.modeling_bert import ACT2FN
import torch
from optimization import Lamb
import argparse
import os
import easydict
from torch.utils.data import DataLoader, ConcatDataset
import pyxis.torch as pxt
from torch.nn import CrossEntropyLoss
from consonant.model.tokenization import NGRAMTokenizer


In [2]:
# Tokenizer used further down to turn id sequences back into text (decode_sent).
# NOTE(review): the argument 3 presumably selects the n-gram size — confirm against NGRAMTokenizer.
tokenizer = NGRAMTokenizer(3)


In [3]:
class AlbertConsonantHead(nn.Module):
    """Token-level output head: hidden_size -> embedding_size -> output_vocab_size.

    For every position it applies a dense projection, the ``config.hidden_act``
    activation, LayerNorm, and finally a linear decoder whose bias is the
    shared ``self.bias`` parameter.
    """

    def __init__(self, config):
        super(AlbertConsonantHead, self).__init__()

        # Attribute names are the state_dict keys the checkpoint is loaded into,
        # so they must not change. Construction order is also kept: dense/decoder
        # draw their initial weights from the global RNG.
        self.LayerNorm = nn.LayerNorm(config.embedding_size)
        self.bias = nn.Parameter(torch.zeros(config.output_vocab_size))
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)
        self.decoder = nn.Linear(config.embedding_size, config.output_vocab_size)
        self.activation = ACT2FN[config.hidden_act]

        # One bias tensor is shared by `self.bias` and `self.decoder.bias`
        # (both names appear in the state_dict).
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Map encoder hidden states to ``output_vocab_size`` logits per position."""
        embedded = self.activation(self.dense(hidden_states))
        return self.decoder(self.LayerNorm(embedded))

class Consonant(nn.Module):
    """ALBERT encoder followed by the token-level AlbertConsonantHead.

    ``forward`` returns a tuple:
      * without ``answer_label``: ``(prediction_scores,) + extra encoder outputs``
      * with ``answer_label``:    ``(loss, prediction_scores) + extra encoder outputs``

    ``prediction_scores`` carries ``config.output_vocab_size`` logits per position.
    """

    def __init__(self, config):
        super(Consonant, self).__init__()
        self.config = config
        self.albert = AlbertModel(config)
        self.predictions = AlbertConsonantHead(config)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, answer_label=None):
        # Keyword arguments so the call does not depend on the positional
        # order of AlbertModel.forward's parameters.
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # outputs[0] is the per-token sequence output. outputs[1] (the pooled
        # output) is not used by this head. Anything after it is passed through.
        sequence_output = outputs[0]
        prediction_scores = self.predictions(sequence_output)

        outputs = (prediction_scores,) + tuple(outputs[2:])

        if answer_label is not None:
            # Label value excluded from the loss. It defaults to CrossEntropyLoss's
            # own default (-100), so behaviour only changes if the config sets
            # `label_ignore_index`.
            # NOTE(review): if label 0 denotes padding, label_ignore_index=0 would keep
            # padded positions out of the loss — confirm before relying on it.
            ignore_index = getattr(self.config, 'label_ignore_index', -100)
            loss_fct = CrossEntropyLoss(ignore_index=ignore_index)
            # Flatten every leading dimension so each position is one example;
            # reshape (unlike view) also copes with non-contiguous tensors.
            consonant_loss = loss_fct(
                prediction_scores.reshape(-1, self.config.output_vocab_size),
                answer_label.reshape(-1),
            )
            outputs = (consonant_loss,) + outputs

        return outputs

In [4]:
# Encoder hyper-parameters. They have to agree with the checkpoint loaded next,
# otherwise load_state_dict (strict by default) rejects the weights.
# `output_vocab_size` appears to be a project-specific extra field: it is read by
# AlbertConsonantHead and Consonant to size the output layer.
ALBERT_CONFIG_KWARGS = dict(
    hidden_size=256,
    embedding_size=64,
    num_attention_heads=4,
    intermediate_size=1024,
    vocab_size=17579,
    max_position_embeddings=100,
    output_vocab_size=589,
    type_vocab_size=1,
)
albert_base_configuration = AlbertConfig(**ALBERT_CONFIG_KWARGS)

model = Consonant(config=albert_base_configuration)

In [5]:
# Restore trained weights from the checkpoint's 'model_state_dict' entry.
state_dic = '../output/baseline_01/ckpt-0012000.bin'
# map_location='cpu' lets the file load no matter which device it was saved
# from; the model is still on the CPU at this point and is moved separately.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(state_dic, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [6]:
# Move all parameters and buffers to the current CUDA device.
model = model.to(torch.device('cuda'))

In [7]:
def val_dataloader(args):
    """Build a DataLoader over every pyxis subset under ``<pretrain_dataset_dir>/val``.

    Each sub-directory of the ``val`` folder is opened as a ``pxt.TorchDataset``
    and the subsets are concatenated. Plain files next to them (e.g. ``*.tar.gz``
    archives) are skipped.

    Args:
        args: attribute-style settings. Required: ``pretrain_dataset_dir``,
            ``train_batch_size``, ``num_workers``. Optional: ``val_batch_size``
            (falls back to ``train_batch_size``) and ``shuffle`` (default True).

    Returns:
        torch.utils.data.DataLoader over the concatenated validation subsets.

    Raises:
        FileNotFoundError: if the ``val`` directory does not exist.
        ValueError: if it contains no subset directories.
    """
    data_dir = os.path.join(args.pretrain_dataset_dir, 'val')
    # Keep only directories, sorted so the concatenation order (and therefore
    # sample indices) does not depend on the arbitrary os.listdir order.
    subset_list = sorted(
        subset_dir for subset_dir in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, subset_dir))
    )
    if not subset_list:
        raise ValueError('No dataset subdirectories found in {!r}'.format(data_dir))

    val_dataset = ConcatDataset(
        [pxt.TorchDataset(os.path.join(data_dir, subset_dir)) for subset_dir in subset_list]
    )

    return DataLoader(
        val_dataset,
        batch_size=getattr(args, 'val_batch_size', None) or args.train_batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        shuffle=getattr(args, 'shuffle', True),
    )

In [8]:
# Settings read by val_dataloader (EasyDict provides attribute access).
# num_workers=0 loads batches in the main process.
args = easydict.EasyDict(
    pretrain_dataset_dir='../dataset/processed/ratings_3_100',
    train_batch_size=128,
    num_workers=0,
)


In [9]:
# Validation DataLoader built from the `args` settings.
valloader = val_dataloader(args=args)

In [10]:
# Run the model on one validation batch for a qualitative look at its output.
# eval() turns off training-only behaviour (e.g. dropout), and no_grad() skips
# building the autograd graph, which pure inference does not need.
model.eval()
batch = next(iter(valloader))
input_ids = batch['head_ids'].long().cuda()
answer_label = batch['midtail_ids'].long().cuda()
attention_mask = batch['attention_masks'].long().cuda()

with torch.no_grad():
    # Because answer_label is passed, output is (loss, prediction_scores, ...).
    output = model(input_ids, attention_mask=attention_mask, token_type_ids=None, answer_label=answer_label)

In [11]:
# output[1] holds the prediction scores (output[0] is the loss, since
# answer_label was given); keep the highest-scoring class at each position.
predict_label = torch.argmax(output[1], dim=2)

In [12]:
# Print every reference sentence next to the model's prediction.
# Positions whose reference label is 0 are zeroed in the prediction too, so both
# strings are decoded over the same positions (applied in place, whole batch at once).
# NOTE(review): presumably label 0 marks positions without a target (padding) — confirm.
predict_label[answer_label == 0] = 0

# Copy the tensors to host memory once rather than once per row.
input_ids_np = input_ids.detach().cpu().numpy()
answer_label_np = answer_label.detach().cpu().numpy()
predict_label_np = predict_label.detach().cpu().numpy()

print('===============')
for i in range(answer_label_np.shape[0]):
    answer_string = tokenizer.decode_sent(input_ids_np[i], answer_label_np[i])
    predict_string = tokenizer.decode_sent(input_ids_np[i], predict_label_np[i])
    # Both labels share the same "\t: " separator so the two lines align.
    print('answer string\t: ' + answer_string)
    print('predict string\t: ' + predict_string)
    print('===============')
    

answer string	: 재밌다 지루한 부분 없이 !
predict string	:재밌다 지루한 부분 없음 !
answer string	: 이런류의 프로 식상합니다
predict string	:이런류의 프로 신수합니다
answer string	: 금성무 사랑함
predict string	:김수미 사랑해
answer string	: 정말 재미없다.
predict string	:정말 재미없다.
answer string	: 강간당했는데 경찰에 신고도안하고 먼가 현실성은 떨어진다
predict string	:김기도했는데 괜찮은 시간대영화고 뭔가 현실성이 떨어진다
answer string	: 원작을 보세요 ~점준 비읍시옷들아
predict string	:예작은 보세요 ~점준 받이수이되요
answer string	: 최악의 애니메이션.. 지루하고 재미없고 스토리마저 진부하다
predict string	:최악의 애니메이션.. 지루하고 재미없고 스토리마중 진부하다
answer string	: 이번시즌이 최고바비야 밥이나 먹으러가자 털업
predict string	:이본사절의 최고부분을 보이는 말이러가지 탑임
answer string	: 크리스찬베일.점
predict string	:크리스츠베음.점
answer string	: 이소룡 배우가 흉륭하다
predict string	:이소를 발우기 훌륭하다
answer string	: 이런게 똥같은 영화지
predict string	:이렇게 똑같은 영화지
answer string	: 최고다 감동적 역시 우리학교 설립자인것 같다
predict string	:최고다 그다지 역상 우리하고 사리지않것 같다
answer string	: 달레이라마의 고달프고도 위대한일대기,상식으로라도보시오
predict string	:독립이랜만에 과대평가된 유덕해이되고,소상으로라다보세요
answer string	: 첫 만 보고 계속보다가는 후회할거다..
predict string	:참 몇 보고 가속보다그는 후회하겠다..
answer string	: 어이없게