## ESIM 모델을 PyTorch 버전으로 간단하게 구현해보기

In [1]:
# Mount Google Drive; the data paths configured below live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# NOTE(review): presumably this makes the local `data_iterator` module imported below importable — confirm.
%cd /content/drive/MyDrive/Projs/ESIM/test

/content/drive/MyDrive/Projs/ESIM/test


### Import Required Modules

In [4]:
import os
import collections

import numpy as np
import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from data_iterator import TextIterator

In [5]:
# ---- Paths -------------------------------------------------------------
data_path = '/content/drive/MyDrive/Data/ESIM_data/'
train_file = os.path.join(data_path, 'ubuntu_data_concat/train.txt')
valid_file = os.path.join(data_path, 'ubuntu_data_concat/valid.txt')
test_file = os.path.join(data_path, 'ubuntu_data_concat/test.txt')
vocab_file = os.path.join(data_path, 'ubuntu_data_concat/vocab.txt')
embedding_file = os.path.join(data_path, 'embedding_w2v_d300.txt')
# NOTE(review): absolute path at the filesystem root — confirm this is intended.
output_dir = '/rst'

# ---- Device: use a GPU whenever torch can see one -----------------------
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# ---- Model / optimisation hyper-parameters ------------------------------
hidden_size = 300
dim_word = 300          # word-embedding width
vocab_size = 100000
num_labels = 2
batch_size = 16
epochs = 50
learning_rate = 2e-4
patience = 1

# ---- Data-iterator / padding limits -------------------------------------
max_buffer_size = 128 * 20
maxlen_1 = 400          # max context tokens kept by prepare_data
maxlen_2 = 150          # max response tokens kept by prepare_data

### Helper Functions

In [6]:
def prepare_data(transformed_samples, maxlen_1=maxlen_1, maxlen_2=maxlen_2):
    """Truncate and pad one minibatch into time-major numpy arrays.

    Args:
        transformed_samples: iterable of samples indexed as
            ``sample[0]`` = label, ``sample[1]`` = context token ids,
            ``sample[2]`` = response token ids.
        maxlen_1: maximum context length; a longer context keeps its LAST
            ``maxlen_1`` tokens.
        maxlen_2: maximum response length; a longer response keeps its FIRST
            ``maxlen_2`` tokens.

    Return:
        None if ``transformed_samples`` is empty, otherwise
        ``(c, c_mask, r, r_mask, l)``:
        c: int64 numpy.array of shape [seq_length_x, batch_size].
        c_mask: float32 numpy.array of shape [seq_length_x, batch_size].
        r: int64 numpy.array of shape [seq_length_y, batch_size].
        r_mask: float32 numpy.array of shape [seq_length_y, batch_size].
        l: int64 numpy.array of shape [batch_size, ].
        seq_length_* is the longest (truncated) sequence in the batch; each
        mask column is 1 over the real tokens followed by 0 over the padding.
    """
    seqs_c = []
    seqs_r = []
    labels = []

    for sample in transformed_samples:
        labels.append(sample[0])
        seq_c, seq_r = sample[1], sample[2]
        # Explicit start index: `seq_c[-maxlen_1:]` would keep the WHOLE
        # context when maxlen_1 == 0 (because -0 == 0).
        seqs_c.append(seq_c[max(len(seq_c) - maxlen_1, 0):])
        seqs_r.append(seq_r[:maxlen_2])

    if not labels:
        return None

    lengths_c = [len(s) for s in seqs_c]
    lengths_r = [len(s) for s in seqs_r]
    n_samples = len(labels)
    len_c = max(lengths_c)
    len_r = max(lengths_r)

    # Token ids and labels are int64 (as documented): nn.Embedding needs
    # integer index tensors and nn.CrossEntropyLoss needs Long targets.
    # Masks stay float32 so they can be multiplied into float activations.
    c = np.zeros((len_c, n_samples), dtype="int64")
    r = np.zeros((len_r, n_samples), dtype="int64")
    c_mask = np.zeros((len_c, n_samples), dtype="float32")
    r_mask = np.zeros((len_r, n_samples), dtype="float32")
    l = np.zeros((n_samples,), dtype="int64")

    for idx, (s_c, s_r, label) in enumerate(zip(seqs_c, seqs_r, labels)):
        c[:lengths_c[idx], idx] = s_c
        c_mask[:lengths_c[idx], idx] = 1.
        r[:lengths_r[idx], idx] = s_r
        r_mask[:lengths_r[idx], idx] = 1.
        l[idx] = label

    return (c, c_mask, r, r_mask, l)

In [7]:
def load_word_embedding(token_to_idx):
    """Build the word-embedding matrix, overwriting rows with pretrained vectors.

    Rows start as ``0.02 * N(0, 1)`` noise of shape [vocab_size, dim_word]
    (module-level settings). When ``embedding_file`` is set, each of its lines
    is read as ``token v_1 ... v_n`` separated by spaces, and the vector is
    copied into row ``token_to_idx[token]`` if that id is below ``vocab_size``.
    Lines whose vector length is not ``dim_word`` are skipped. Examples are a
    word2vec ``"<count> <dim>"`` header or a blank line. A length-1 vector would
    otherwise be broadcast silently across a whole row, and an empty one would
    raise.

    Args:
        token_to_idx: mapping from token string to integer id (see ``load_vocab``).

    Returns:
        float32 torch.Tensor of shape [vocab_size, dim_word].
    """
    embedding_np = 0.02 * np.random.randn(vocab_size, dim_word).astype("float32")

    if embedding_file:
        with open(embedding_file, "r") as f:
            for line in f:
                tokens = line.strip().split(" ")
                idx = token_to_idx.get(tokens[0])
                # Check vocabulary membership first so floats are parsed only
                # for rows that will actually be written.
                if idx is None or idx >= vocab_size:
                    continue
                values = tokens[1:]
                if len(values) != dim_word:
                    continue
                embedding_np[idx, :] = [float(v) for v in values]

    embedding = torch.Tensor(embedding_np)
    return embedding

In [8]:
def load_vocab(vocab_file):
    """Read a one-token-per-line vocabulary file.

    Args:
        vocab_file: path of the vocabulary text file.

    Returns:
        collections.OrderedDict mapping each whitespace-stripped line to its
        zero-based line number. If a token appears on several lines, its value
        is the last line number, while its position in the dict is that of its
        first occurrence.
    """
    vocab = collections.OrderedDict()
    with open(vocab_file, "r") as handle:
        for line_no, raw_line in enumerate(handle):
            vocab[raw_line.strip()] = line_no
    return vocab

In [101]:
def _masked_normalize(scores, mask):
    """Softmax ``scores`` over the last axis, restricted to keys where ``mask`` != 0.

    Args:
        scores: float Tensor of shape [batch_size, len_q, len_k].
        mask:   Tensor of shape [batch_size, len_k]; non-zero marks a real key.

    Returns:
        float Tensor [batch_size, len_q, len_k]. Masked keys get weight 0, and a
        row whose keys are all masked is all zeros (no NaN).
    """
    key_mask = mask.unsqueeze(1).to(scores.dtype)
    # Take the stabilising max over real keys only. If padded keys took part,
    # a large padded score could drive exp() of every real score to 0.
    shifted = scores.masked_fill(key_mask == 0, torch.finfo(scores.dtype).min)
    shifted = shifted - shifted.max(dim=-1, keepdim=True).values
    weights = torch.exp(shifted) * key_mask
    return weights / (weights.sum(dim=-1, keepdim=True) + 1e-8)


def local_inference(c, c_mask, r, r_mask):
    """Local inference collected over sequences (ESIM soft alignment).

    Dot-product scores between every context step and every response step are
    normalised over the response steps (for ``c_dual``) and over the context
    steps (for ``r_dual``). Padded steps (mask == 0) receive no weight.

    Args:
        c :      float32 Tensor of shape [seq_length1, batch_size, dim]
        c_mask : float32 Tensor of shape [seq_length1, batch_size]
        r :      float32 Tensor of shape [seq_length2, batch_size, dim]
        r_mask : float32 Tensor of shape [seq_length2, batch_size]

    Return:
        c_dual: float32 Tensor of shape [seq_length1, batch_size, dim],
            the attention-weighted sum of ``r`` for each context step.
        r_dual: float32 Tensor of shape [seq_length2, batch_size, dim],
            the attention-weighted sum of ``c`` for each response step.
    """
    # Batch-major views: c_b [B, L1, D], r_b [B, L2, D], masks [B, L].
    c_b = c.transpose(0, 1)
    r_b = r.transpose(0, 1)
    c_mask_b = c_mask.transpose(0, 1)
    r_mask_b = r_mask.transpose(0, 1)

    # attention_weight: [B, L1, L2]
    attention_weight = torch.bmm(c_b, r_b.transpose(1, 2))

    # alpha: [B, L1, L2]. bmm computes sum_j alpha[b, i, j] * r[b, j, :]
    # without materialising a [B, L1, L2, D] intermediate.
    alpha = _masked_normalize(attention_weight, r_mask_b)
    c_dual = torch.bmm(alpha, r_b)

    # beta: [B, L2, L1]
    beta = _masked_normalize(attention_weight.transpose(1, 2), c_mask_b)
    r_dual = torch.bmm(beta, c_b)

    # Back to time-major: [L, B, D]
    return (c_dual.transpose(0, 1), r_dual.transpose(0, 1))

### Eval Functions

### Model Class

In [16]:
class ESIM(torch.nn.Module):
    """ESIM matching model over a (context, response) pair.

    Inputs are time-major, as produced by ``prepare_data``: token ids of shape
    [seq_len, batch_size] and masks of the same shape (1 = real token).
    ``forward`` returns unnormalised logits [batch_size, num_labels]. Pass them
    to ``nn.CrossEntropyLoss`` directly; apply softmax only to get probabilities.

    Args:
        hidden_size: LSTM units per direction, and width of the projection and
            MLP hidden layers.
        dropout: drop probability used by every ``nn.Dropout`` in the model.
        embedding: float Tensor [vocab_size, emb_dim] of pretrained vectors
            (frozen, the ``nn.Embedding.from_pretrained`` default).
        num_labels: number of output classes.
    """

    def __init__(self, hidden_size, dropout, embedding, num_labels=2):
        super(ESIM, self).__init__()

        self.hidden_size = hidden_size
        self.emb = nn.Embedding.from_pretrained(embedding)
        emb_dim = embedding.size(1)

        # Input encoding: bidirectional, so outputs have 2 * hidden_size features.
        self.encoder = nn.LSTM(input_size=emb_dim, hidden_size=hidden_size, bidirectional=True)

        # Dimension reduction of [a, a~, a * a~, a - a~] (4 * 2h) back to h.
        self.fc = nn.Linear(in_features=8 * hidden_size, out_features=hidden_size)

        # Inference composition runs its own BiLSTM over the projected features.
        self.composition = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, bidirectional=True)

        # One-hidden-layer tanh MLP over [c_mean, c_max, r_mean, r_max] (4 * 2h).
        self.mlp = nn.Linear(in_features=8 * hidden_size, out_features=hidden_size)
        self.classifier = nn.Linear(in_features=hidden_size, out_features=num_labels)

        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    @staticmethod
    def _run_lstm(lstm, x, mask):
        """Run a time-major LSTM over the unpadded steps only and zero the padding.

        Assumes each mask column is a run of 1s followed by 0s, which is how
        prepare_data builds masks. Zero-length columns are run for one (zero) step
        because pack_padded_sequence rejects length 0.
        """
        lengths = mask.sum(0).long().clamp(min=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths, enforce_sorted=False)
        out, _ = lstm(packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(out, total_length=x.size(0))
        return out * mask.unsqueeze(-1)

    @staticmethod
    def _pool(x, mask):
        """Masked mean and max over time: x [T, B, D], mask [T, B] -> [B, 2 * D]."""
        masked = x * mask.unsqueeze(-1)
        mean = masked.sum(0) / mask.sum(0).clamp(min=1.0).unsqueeze(1)
        maximum = masked.max(0).values
        return torch.cat([mean, maximum], dim=1)

    def forward(self, c, r, c_mask, r_mask):
        """Score a batch of (context, response) pairs.

        Args:
            c, r: token-id Tensors [seq_len_c, B] / [seq_len_r, B] (any numeric
                dtype; cast to Long for the embedding lookup).
            c_mask, r_mask: Tensors of the same shapes; 1 marks a real token.

        Returns:
            float Tensor of logits [B, num_labels].
        """
        # Embedding part
        emb_c = self.emb(c.long())
        emb_r = self.emb(r.long())
        c_mask = c_mask.to(emb_c.dtype)
        r_mask = r_mask.to(emb_r.dtype)

        emb_c = self.dropout(emb_c) * c_mask.unsqueeze(-1)
        emb_r = self.dropout(emb_r) * r_mask.unsqueeze(-1)

        # Encoding part
        enc_c = self._run_lstm(self.encoder, emb_c, c_mask)
        enc_r = self._run_lstm(self.encoder, emb_r, r_mask)

        # Matching part: local inference and enhancement
        dual_c, dual_r = local_inference(enc_c, c_mask, enc_r, r_mask)
        match_c = torch.cat([enc_c, dual_c, enc_c * dual_c, enc_c - dual_c], dim=2)
        match_r = torch.cat([enc_r, dual_r, enc_r * dual_r, enc_r - dual_r], dim=2)

        # Dimension reduction with FC + ReLU, then dropout
        proj_c = self.dropout(self.relu(self.fc(match_c)))
        proj_r = self.dropout(self.relu(self.fc(match_r)))

        # Matching composition
        cmp_c = self._run_lstm(self.composition, proj_c, c_mask)
        cmp_r = self._run_lstm(self.composition, proj_r, r_mask)

        # Pooling + prediction MLP
        features = torch.cat([self._pool(cmp_c, c_mask), self._pool(cmp_r, r_mask)], dim=1)
        hidden = self.tanh(self.mlp(self.dropout(features)))
        return self.classifier(self.dropout(hidden))

### Unit Test

In [11]:
# Token -> integer id mapping, one id per line of the vocabulary file.
token_to_idx = load_vocab(vocab_file)

In [12]:
# Embedding matrix [vocab_size, dim_word] whose rows are indexed by the ids above.
embedding = load_word_embedding(token_to_idx)

In [13]:
# Minibatch iterators for each split; only the training iterator shuffles.
train = TextIterator(train_file, token_to_idx,
                     batch_size=batch_size,
                     vocab_size=vocab_size,
                     shuffle=True)
valid = TextIterator(valid_file, token_to_idx,
                     batch_size=batch_size,
                     vocab_size=vocab_size,
                     shuffle=False)
test = TextIterator(test_file, token_to_idx,
                    batch_size=batch_size,
                    vocab_size=vocab_size,
                    shuffle=False)
# Text iterator of training set for evaluation
train_eval = TextIterator(train_file, token_to_idx,
                          vocab_size=vocab_size, batch_size=batch_size, shuffle=False)

In [18]:
# Standalone embedding layer for checking the model pipeline step by step.
test_emb = nn.Embedding.from_pretrained(embedding)

In [27]:
# Pull one training minibatch and pad it into time-major arrays.
test_c, test_c_mask, test_r, test_r_mask, test_target = prepare_data(train.next())

In [29]:
# Sequences/masks are [seq_len, batch_size]; labels are [batch_size].
print(test_c.shape)
print(test_c_mask.shape)
print(test_r.shape)
print(test_r_mask.shape)
print(test_target.shape)

(96, 16)
(96, 16)
(41, 16)
(41, 16)
(16,)


In [47]:
# nn.Embedding needs integer index tensors, hence the explicit integer dtype.
test_c_emb = test_emb(torch.tensor(test_c, dtype=torch.int))
test_r_emb = test_emb(torch.tensor(test_r, dtype=torch.int))

In [48]:
# Embedded shapes: [seq_len, batch_size, dim_word].
print(test_c_emb.shape)
print(test_r_emb.shape)

torch.Size([96, 16, 300])
torch.Size([41, 16, 300])


In [52]:
# nn.Dropout's p is the probability of DROPPING an element. p=0.0 disables
# dropout for these shape checks; p=1.0 would zero every embedding, because a
# freshly built module is in training mode.
test_dropout = nn.Dropout(p=0.0)
test_c_emb_drop = test_dropout(test_c_emb)
test_r_emb_drop = test_dropout(test_r_emb)

In [54]:
# Zero the embeddings at padded positions (mask is a numpy array, so wrap it first).
test_c_emb_last = test_c_emb_drop * torch.unsqueeze(torch.tensor(test_c_mask, dtype=torch.float32), -1)
test_r_emb_last = test_r_emb_drop * torch.unsqueeze(torch.tensor(test_r_mask, dtype=torch.float32), -1)

In [55]:
# Masking keeps the [seq_len, batch_size, dim_word] shape.
print(test_c_emb_last.shape)
print(test_r_emb_last.shape)

torch.Size([96, 16, 300])
torch.Size([41, 16, 300])


In [62]:
# torch.Size converted to a plain tuple.
print(tuple(test_c_emb_last.shape))

(96, 16, 300)


In [76]:
# Bidirectional LSTM with hidden_size * 2 units per direction; outputs have
# 2 * (hidden_size * 2) features (1200 in the recorded output below).
test_bilstm = nn.LSTM(dim_word, hidden_size=hidden_size * 2, bidirectional=True)

In [84]:
# Encode both sequences with the same (shared-weight) BiLSTM; the final states are discarded.
test_c_enc, _ = test_bilstm(test_c_emb_last)
test_r_enc, _ = test_bilstm(test_r_emb_last)

In [83]:
# Display the context encoding shape.
test_c_enc.shape

torch.Size([96, 16, 1200])

In [85]:
# Zero the encoder outputs at padded positions.
test_c_enc = test_c_enc * torch.unsqueeze(torch.tensor(test_c_mask, dtype=torch.float32), -1)
test_r_enc = test_r_enc * torch.unsqueeze(torch.tensor(test_r_mask, dtype=torch.float32), -1)

In [86]:
# Encoded shapes: [seq_len, batch_size, 2 * hidden_size * 2].
print(test_c_enc.shape)
print(test_r_enc.shape)

torch.Size([96, 16, 1200])
torch.Size([41, 16, 1200])


In [93]:
# Reproduce the first steps of local_inference by hand to inspect the attention shape.
# NOTE: this rebinds the module-level names c, c_mask, r and r_mask.
c = torch.permute(test_c_enc, [1, 0, 2]).contiguous()
c_mask = torch.permute(torch.tensor(test_c_mask), [1, 0]).contiguous()
r = torch.permute(test_r_enc, [1, 0, 2]).contiguous()
r_mask = torch.permute(torch.tensor(test_r_mask), [1, 0]).contiguous()

# attention_weight: [batch_size, seq_length1, seq_length2]
attention_weight = torch.matmul(c, torch.permute(r, [0, 2, 1]).contiguous())

In [94]:
# Display the attention shape: [batch_size, seq_length1, seq_length2].
attention_weight.shape

torch.Size([16, 96, 41])

In [100]:
# torch.max(input, dim, keepdim=True) reduces over `dim` and returns (values, indices).
# Reducing dim 0 (batch) gives the recorded [1, 96, 41] shape.
torch.max(attention_weight, 0, keepdim=True).values.shape

torch.Size([1, 96, 41])

In [90]:
# The encodings are torch tensors but the masks are still numpy arrays, so
# the masks need torch.tensor(...) before being passed to torch ops.
print(type(test_c_enc))
print(type(test_c_mask))
print(type(test_r_enc))
print(type(test_r_mask))

<class 'torch.Tensor'>
<class 'numpy.ndarray'>
<class 'torch.Tensor'>
<class 'numpy.ndarray'>


In [None]:
# Run local_inference on the encoded test batch (numpy masks wrapped as tensors).
test_dual1, test_dual2 = local_inference(test_c_enc, torch.tensor(test_c_mask), test_r_enc, torch.tensor(test_r_mask))

In [43]:
# Show which GPU this runtime was given (errors if no CUDA device is available).
torch.cuda.get_device_name(0)

'Tesla P100-PCIE-16GB'

### Training

In [None]:
# Declare the model, loss function and optimizer.
# NOTE(review): no dropout probability is set in the config cell; 0.5 is a
# placeholder — tune/confirm.
dropout_prob = 0.5
model = ESIM(hidden_size, dropout_prob, embedding).to(device)
# ESIM returns unnormalised logits; CrossEntropyLoss applies log-softmax itself.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


def _next_batch(iterator):
    """Return the next raw minibatch from ``iterator``, or None at end of data.

    NOTE(review): TextIterator comes from data_iterator, which is not shown
    here. This helper treats StopIteration, None and an empty batch all as end
    of data — confirm which one TextIterator uses, and whether it resets
    itself for the next epoch.
    """
    try:
        batch = iterator.next()
    except StopIteration:
        return None
    if batch is None or len(batch) == 0:
        return None
    return batch


for epoch in range(epochs):
    model.train()
    while True:
        samples = _next_batch(train)
        if samples is None:
            break

        # Pad the raw batch, then move every array onto the training device.
        c, c_mask, r, r_mask, l = (torch.as_tensor(x, device=device)
                                   for x in prepare_data(samples))

        # ESIM.forward takes (c, r, c_mask, r_mask) in that order.
        pred = model(c, r, c_mask, r_mask)
        # CrossEntropyLoss expects Long class-index targets.
        loss = criterion(pred, l.long())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
