In [45]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext import data, datasets
import random

import models.transformer as transformer

In [46]:
from argparse import Namespace

config = {
    'train_ratio': .8,
    'batch_size': 64,
    'num_heads' : 8,
    'hidden_size' : 768,
    'n_enc_block' : 6,
    'n_dec_block' : 6
}

config = Namespace(**config)



SEED = 777
torch.manual_seed(SEED)
random.seed(SEED)

device = torch.device('cuda:{}'.format(torch.cuda.current_device())) if torch.cuda.is_available() else torch.device('cpu')
# device = torch.device('cpu') if config.gpu_id < 0 else torch.device('cuda:{}'.format(config.gpu_id))

# Download and prepare dataset

from https://pytorch.org/text/stable/datasets.html#imdb

In [3]:
TEXT = data.Field(batch_first=True, lower=True, use_vocab=True)
LABEL = data.Field(sequential=False, use_vocab=True, unk_token=None)

In [4]:
trainset, testset = datasets.IMDB.splits(TEXT, LABEL)

In [5]:
print('trainset의 구성 요소 출력 : ', trainset.fields)

trainset의 구성 요소 출력 :  {'text': <torchtext.data.field.Field object at 0x7fdb13ead820>, 'label': <torchtext.data.field.Field object at 0x7fdb13ead790>}


In [6]:
print('testset의 구성 요소 출력 : ', testset.fields)

testset의 구성 요소 출력 :  {'text': <torchtext.data.field.Field object at 0x7fdb13ead820>, 'label': <torchtext.data.field.Field object at 0x7fdb13ead790>}


In [7]:
print(vars(trainset[0]))

{'text': ['i', 'happened', 'to', 'see', 'this', 'movie', 'twice', 'or', 'more', 'and', 'found', 'it', 'well', 'made!', 'wwii', 'had', 'freshly', 'ended', 'and', 'the', 'so-called', '"cold', 'war"', 'was', 'about', 'to', 'begin.', 'this', 'movie', 'could,', 'therefore,', 'be', 'defined', 'as', 'one', 'of', 'the', 'best', '"propaganda",', 'patriotic', 'movies', 'preparing', 'americans', 'and,', 'secondly,', 'people', 'from', 'the', 'still', 'to', 'be', 'formed', '"western', 'nato', 'block"', 'of', 'countries', 'to', 'face', 'the', 'next', 'coming', 'menace.', 'the', 'movie', 'celebrates', 'the', 'might', 'of', 'the', 'us,', 'through', 'the', 'centuries,', 'while', 'projecting', 'itself', 'onwards', 'to', 'the', 'then', 'present', 'war,', 'which', 'had', 'just', 'ended.', 'nice', 'and', 'funny', 'is', 'the', 'way', 'of', 'describing', 'the', 'discovering', 'of', 'the', 'american', 'continent', 'by', 'columbus', 'and', 'pretty', 'the', '"espisode"', 'of', 'new', 'amsterdam', 'and', 'the', 

In [8]:
# vocabulary 생성
TEXT.build_vocab(trainset, min_freq=5)
LABEL.build_vocab(trainset)

In [9]:
vocab_size = len(TEXT.vocab)
n_classes = len(LABEL.vocab)
print('단어 집합의 크기 : {}'.format(vocab_size))
print('클래스의 개수 : {}'.format(n_classes))

단어 집합의 크기 : 46159
클래스의 개수 : 2


In [10]:
# token set 확인
print(TEXT.vocab.stoi)



In [11]:
# train, valid set split

trainset, valset = trainset.split(config.train_ratio)

In [None]:
# loader = DataLoader(
#         config.train,                           # Train file name except extention, which is language.
#         config.valid,                           # Validation file name except extension.
#         (config.lang[:2], config.lang[-2:]),    # Source and target language.
#         batch_size=config.batch_size,
#         device=-1,                              # Lazy loading
#         max_length=config.max_length,           # Loger sequence will be excluded.
#         dsl=False,                              # Turn-off Dual-supervised Learning mode.
#     )

# input_size, output_size = len(loader.src.vocab), len(loader.tgt.vocab)

In [12]:
# train, valid, test loader 생성

train_iter, val_iter, test_iter = data.BucketIterator.splits(
        (trainset, valset, testset), batch_size=config.batch_size,
        shuffle=True, repeat=False)

In [13]:
print('훈련 데이터의 미니 배치의 개수 : {}'.format(len(train_iter)))
print('테스트 데이터의 미니 배치의 개수 : {}'.format(len(test_iter)))
print('검증 데이터의 미니 배치의 개수 : {}'.format(len(val_iter)))

훈련 데이터의 미니 배치의 개수 : 313
테스트 데이터의 미니 배치의 개수 : 391
검증 데이터의 미니 배치의 개수 : 79


# Implement a Transformer block as a layer & embedding layer

In [42]:
input_size = vocab_size
hidden_size = config.hidden_size
output_size = n_classes
n_splits = config.num_heads

model = transformer.Transformer(input_size,
                                     hidden_size,
                                     output_size,
                                     n_splits).to(device)

In [43]:
model

Transformer(
  (emb_enc): Embedding(46159, 768)
  (emb_dec): Embedding(2, 768)
  (emb_dropout): Dropout(p=0.1, inplace=False)
  (encoder): MySequential(
    (0): EncoderBlock(
      (attn): MultiHead(
        (Q_linear): Linear(in_features=768, out_features=768, bias=False)
        (K_linear): Linear(in_features=768, out_features=768, bias=False)
        (V_linear): Linear(in_features=768, out_features=768, bias=False)
        (linear): Linear(in_features=768, out_features=768, bias=False)
        (attn): Attention(
          (softmax): Softmax(dim=-1)
        )
      )
      (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn_dropout): Dropout(p=0.1, inplace=False)
      (fc): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): ReLU()
        (2): Linear(in_features=3072, out_features=768, bias=True)
      )
      (fc_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (fc_dropout): Dropout(p=0.1, inplac

In [47]:
optimizer = torch.optim.Adam(model.parameters())

In [64]:
# huggingface -> bert tokenizer 를 받아서 tokenizer하면 (maxlength 등을 옵셔널로 추가 가능) -> 해결 가능 할 것

cnt = 0
for batch in val_iter:
    x = batch.text.to(device)
    y = batch.label.unsqueeze(1).to(device)
    
#     x = torch.nn.utils.rnn.pack_padded_sequence(x, )
    
    print(x.shape)
    
    print(y.shape)
    
    pred = model(batch.text.to(device), batch.label.to(device))
    print(pred)
    print(pred.max(1))
    print(pred.max(1)[1])
    
    cnt += 1
    
    if cnt==3:
        break

torch.Size([64, 45])
torch.Size([64, 1])


RuntimeError: expand(CUDABoolType{[45, 1, 19490]}, size=[45, 19490]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (3)