In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!/opt/bin/nvidia-smi

Thu Apr 30 15:12:57 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   52C    P8    28W / 149W |      0MiB / 11441MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

In [3]:
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchtext
from torchtext import data, datasets

import numpy as np


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [0]:
import os 
os.chdir('/content/drive/My Drive/EISM')

## 数据

In [0]:
TEXT = data.Field(sequential=True, lower=True, batch_first=True, fix_length=50)
LABEL = data.Field(sequential=False)

train, valid, test = datasets.SNLI.splits(TEXT, LABEL)

TEXT.build_vocab(train, valid, test, vectors='glove.6B.100d')
LABEL.build_vocab(train)

In [0]:
batch_size = 128
train_iter, valid_iter, test_iter = data.BucketIterator.splits(
    datasets=(train, valid, test),
    batch_sizes=(batch_size, batch_size, batch_size),
    shuffle=True)


## 相关配置参数 

In [7]:
class Config:
    def __init__(self):
        self.batch_size = batch_size
        # embedding
        print(TEXT.vocab.vectors.size())
        self.vocab_size = TEXT.vocab.vectors.size()[0]
        self.embedding_dim = TEXT.vocab.vectors.size()[1]
        # lstm
        self.hidden_dim = 200
        self.num_layers = 1
        # fc
        self.linear_size = 200
        self.dropout = 0.3
        self.output_dim = len(LABEL.vocab)

        # train
        self.learning_rate = 1e-3
        self.epochs = 5

        # model
        self.model_path = '.model.pkl'

args = Config()

torch.Size([57324, 100])


In [0]:
class ESIM(nn.Module):
    def __init__(self, args):
        super(ESIM, self).__init__()
        self.args = args
        self.embedding = nn.Embedding(self.args.vocab_size, self.args.embedding_dim).from_pretrained(TEXT.vocab.vectors)
        self.lstm1 = nn.LSTM(
            self.args.embedding_dim,
            self.args.hidden_dim, 
            num_layers=self.args.num_layers, 
            batch_first=True, 
            bidirectional=True)
        self.lstm2 = nn.LSTM(
            self.args.hidden_dim*8, 
            self.args.hidden_dim, 
            num_layers=self.args.num_layers, 
            batch_first=True, 
            bidirectional=True)
        self.fc = nn.Sequential(
            nn.BatchNorm1d(self.args.hidden_dim*8),
            nn.Linear(self.args.hidden_dim*8, self.args.linear_size),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(self.args.linear_size),
            nn.Dropout(self.args.dropout),
            nn.Linear(self.args.linear_size, self.args.linear_size),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(self.args.linear_size),
            nn.Dropout(self.args.dropout),
            nn.Linear(self.args.linear_size, self.args.output_dim),
            nn.Softmax(dim=-1)
            )
        
    def submul(self, x1, x2):
        sub = x1 - x2
        mul = x1 * x2
        return torch.cat([sub, mul], dim=-1)
    
    def apply_multiple(self, x):
        # input [batch_size, sequence_length, 2*hidden_dim]
        p1 = F.avg_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)
        p2 = F.max_pool1d(x.transpose(1, 2), x.size(1)).squeeze(-1)
        # output [batch_size, 4*hidden_dim]
        return torch.cat([p1, p2], dim=1)
    
    def soft_attention_align(self, x1, x2, mask1, mask2):
        # x1 x2 [batch_size, sequence_length, 2*hidden_dim]

        # attention [batch_size, sequence_length, sequence_length]
        attention = torch.matmul(x1, x2.transpose(1, 2))
       
        # 放置softmax时出现异常值
        mask1 = mask1.float().masked_fill_(mask1, float('-inf'))
        mask2 = mask2.float().masked_fill_(mask2, float('-inf'))

        # weight [batch_size, sequence_length, sequence_length]
        # x_align [batch_size, sequence_length, 2*hidden_dim]
        weight1 = F.softmax(attention + mask2.unsqueeze(1), dim=-1)
        x1_align = torch.matmul(weight1, x2)
        weight2 = F.softmax(attention.transpose(1,2) + mask1.unsqueeze(1), dim=-1)
        x2_align = torch.matmul(weight2, x1)

        return x1_align, x2_align
    
    def forward(self, sequence1, sequence2):
        # input sequence [batch_size, sequence_length]

        # x [batch_size, sequence_length, embedding_dim]
        x1, x2 = self.embedding(sequence1), self.embedding(sequence2)

        mask1, mask2 = sequence1.eq(0), sequence2.eq(0)
        # print(mask1)
        # print(mask2)
        # out1 out2 [batch_size, sequence_length, 2*hidden_dim]
        out1, _ = self.lstm1(x1)
        out2, _ = self.lstm1(x2)

        # x_align [batch_size, sequence_length, 2*hidden_dim]
        x1_align, x2_align = self.soft_attention_align(out1, out2, mask1, mask2)
    
        # x1 x2 [batch_size, sequence_length, 8*hidden_dim]
        x1 = torch.cat([out1, x1_align, self.submul(out1, x1_align)], dim=-1)
        x2 = torch.cat([out2, x2_align, self.submul(out2, x2_align)], dim=-1)

        # out1 out2 [batch_size, sequence_length, hidden_dim]
        out1, _ = self.lstm2(x1)
        out2, _ = self.lstm2(x2)
        
        # x1 x2 [batch_size, 4*hidden_dim]
        x1 = self.apply_multiple(out1)
        x2 = self.apply_multiple(out2)

        # out [batch_szie, num_classes]
        out = self.fc(torch.cat([x1, x2], dim=-1))
        return out



In [9]:
model = ESIM(args)
model

ESIM(
  (embedding): Embedding(57324, 100)
  (lstm1): LSTM(100, 200, batch_first=True, bidirectional=True)
  (lstm2): LSTM(1600, 200, batch_first=True, bidirectional=True)
  (fc): Sequential(
    (0): BatchNorm1d(1600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Linear(in_features=1600, out_features=200, bias=True)
    (2): ELU(alpha=1.0, inplace=True)
    (3): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): Dropout(p=0.3, inplace=False)
    (5): Linear(in_features=200, out_features=200, bias=True)
    (6): ELU(alpha=1.0, inplace=True)
    (7): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): Dropout(p=0.3, inplace=False)
    (9): Linear(in_features=200, out_features=4, bias=True)
    (10): Softmax(dim=-1)
  )
)

## 训练

In [0]:
def train(model, data_iter, loss_fn, optimizer):
    model.train()
    total_loss = 0
    total_accuracy = 0
    total_train_num = len(data_iter.dataset)
    for i, batch in enumerate(data_iter):
        x1 = batch.premise.to(device)
        x2 = batch.hypothesis.to(device)
        label = batch.label.to(device)

        y_pred = model(x1, x2)

        optimizer.zero_grad()
        loss = loss_fn(y_pred, label)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        batch_accuracy = (torch.argmax(y_pred, dim=1)==label).sum().item()
        total_accuracy += batch_accuracy

        if i%200==0:
            print('Batch_{}, Train Loss:{}, Accuracy:{}'.format(i, batch_loss/len(label), batch_accuracy/len(label)))

    return total_loss/total_train_num, total_accuracy/total_train_num
        

def valid(model, data_iter, loss_fn):
    model.eval()
    total_loss = 0
    total_accuracy = 0
    total_valid_num = len(data_iter.dataset)
    for i, batch in enumerate(data_iter):
        x1 = batch.premise.to(device)
        x2 = batch.hypothesis.to(device)
        label = batch.label.to(device)

        y_pred = model(x1, x2)

        loss = loss_fn(y_pred, label)
        
        batch_loss = loss.item()
        total_loss += batch_loss
        batch_accuracy = (torch.argmax(y_pred, dim=1)==label).sum().item()
        total_accuracy += batch_accuracy

    return total_loss/total_valid_num, total_accuracy/total_valid_num





In [12]:
model = ESIM(args).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
loss_fn = nn.CrossEntropyLoss()

for e in range(args.epochs):
    print('---------------Epoch_{}---------------'.format(e))
    train_loss, train_accuracy = train(model, train_iter, loss_fn, optimizer)
    valid_loss, valid_accuracy = valid(model, valid_iter, loss_fn)
    print('>>>Epoch_{}\n>>>Train Loss:{}, Accuracy:{}'.format(e, train_loss, train_accuracy))
    print('>>>Valid Loss:{}, Accuracy:{}\n'.format(valid_loss, valid_accuracy))


---------------Epoch_0---------------
Batch_0, Train Loss:0.0109553849324584, Accuracy:0.2421875
Batch_200, Train Loss:0.009341755881905556, Accuracy:0.515625
Batch_400, Train Loss:0.00895842257887125, Accuracy:0.578125
Batch_600, Train Loss:0.008931586518883705, Accuracy:0.6171875
Batch_800, Train Loss:0.00866751279681921, Accuracy:0.6328125
Batch_1000, Train Loss:0.008087384514510632, Accuracy:0.703125
Batch_1200, Train Loss:0.008174305781722069, Accuracy:0.6796875
Batch_1400, Train Loss:0.00828133150935173, Accuracy:0.671875
Batch_1600, Train Loss:0.008321967907249928, Accuracy:0.671875
Batch_1800, Train Loss:0.008685889653861523, Accuracy:0.625
Batch_2000, Train Loss:0.008673866279423237, Accuracy:0.6171875
Batch_2200, Train Loss:0.008319648914039135, Accuracy:0.6796875
Batch_2400, Train Loss:0.008317834697663784, Accuracy:0.6640625
Batch_2600, Train Loss:0.007890625856816769, Accuracy:0.734375
Batch_2800, Train Loss:0.008319636806845665, Accuracy:0.6796875
Batch_3000, Train Loss:0

In [0]:
torch.save(model.state_dict(), args.model_path)

In [0]:
model = ESIM(args)
model.load_state_dict(torch.load(args.model_path))