In [1]:
# https://zhuanlan.zhihu.com/p/47580077
import torch
import pickle
from torch.utils.data import DataLoader
from torch import nn
import time 

worddict_dir = "data/worddict.txt"
data_train_id_dir = "data/train_data_id.pkl"
data_dev_id_dir = "data/dev_data_id.pkl"
embedding_matrix_dir = "data/embedding_matrix.pkl"
model_train_dir = "saved_model/train_model_"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [13]:
from torch.utils.data import Dataset

class SnliDataSet(Dataset):
    def __init__(self, data, max_premise_len=None, max_hypothesis_len=None):
        #序列长度
        self.num_sequence = len(data["premise_id"])
        
        #创建tensor矩阵的尺寸
        self.premise_len = [len(seq) for seq in data["premise_id"]]
        self.max_premise_len = max_premise_len
        if self.max_premise_len is None:
            self.max_premise_len = max(self.premise_len)
        
        self.hypothesis_len = [len(seq) for seq in data["hypothesis_id"]]
        self.max_hypothesis_len = max_hypothesis_len
        if max_hypothesis_len is None:
            self.max_hypothesis_len = max(self.hypothesis_len)

        #转成tensor，封装到data里
        self.data = {
            "premise": torch.zeros((self.num_sequence, self.max_premise_len), dtype=torch.long),
            "hypothesis": torch.zeros((self.num_sequence, self.max_hypothesis_len), dtype=torch.long),
            "labels": torch.tensor(data["labels_id"])
        }
        
        for i, premise in enumerate(data["premise_id"]):
            l = len(data["premise_id"][i])
            self.data["premise"][i][:l] = torch.tensor(data["premise_id"][i][:l])
            l2 = len(data["hypothesis_id"][i])
            self.data["hypothesis"][i][:l2] = torch.tensor(data["hypothesis_id"][i][:l2])
        
    def __len__(self):
        return self.num_sequence
        
    def __getitem__(self, index):
        return { 
            "premise": self.data["premise"][index],
            "premise_len": min(self.premise_len[index], self.max_premise_len),
            "hypothesis": self.data["hypothesis"][index],
            "hypothesis_len": min(self.hypothesis_len[index], self.max_hypothesis_len),
            "labels": self.data["labels"][index]
        }

In [57]:
#超参数
batch_size = 512
patience = 5
hidden_size = 50
dropout = 0.5
num_classes = 3
lr = 0.0004
epochs = 1
max_grad_norm = 10.0

In [58]:
# 加载数据
with open(data_train_id_dir, 'rb') as f:
    train_data = SnliDataSet(pickle.load(f), max_premise_len=None, max_hypothesis_len=None)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

with open(data_dev_id_dir,'rb') as f:
    dev_data = SnliDataSet(pickle.load(f), max_premise_len=None, max_hypothesis_len=None)
dev_loader = DataLoader(dev_data, batch_size=batch_size, shuffle=False)

#加载embedding
with open(embedding_matrix_dir,'rb') as f:
    embeddings = torch.tensor(pickle.load(f), dtype=torch.float32).to(device) # 一定要是torch.float32

In [52]:
embeddings.shape, embeddings.dtype

(torch.Size([33268, 50]), torch.float32)

In [59]:
from model.esim import ESIM

model = ESIM(embeddings.shape[0], # 词汇表中单词个数
             embeddings.shape[1], # 词向量维度
             hidden_size,
             embeddings=embeddings,
             dropout=dropout,
             num_classes=num_classes, # 输出为几类
             device=device).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=0)

In [60]:
def getCorrectNum(probs, targets):
    _, out_classes = probs.max(dim=1) # 值，下标
    correct = (out_classes == targets).sum()
    return correct.item()

def train(model, data_loader, optimizer, criterion, max_gradient_norm):
    model.train()
    device = model.device
    time_epoch_start = time.time()
    running_loss = 0 
    correct_cnt = 0
    batch_cnt = 0

    for index, batch in enumerate(data_loader):
        time_batch_start = time.time()
        premises = batch["premise"].to(device)
        premises_len = batch["premise_len"].to(device)
        hypothesis = batch["hypothesis"].to(device)
        hypothesis_len = batch["hypothesis_len"].to(device)
        labels = batch["labels"].to(device)
        
        optimizer.zero_grad()
        logits,probs = model(premises, premises_len, hypothesis, hypothesis_len)
        loss = criterion(logits, labels)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm)
        optimizer.step()
        
        running_loss += loss.item()
        correct_cnt += getCorrectNum(probs, labels)
        batch_cnt += 1
        print("Training  ------>   Batch count: {:d}/{:d},  batch time: {:.4f}s,  batch average loss: {:.4f}"
              .format(batch_cnt, len(data_loader), time.time() - time_batch_start, running_loss / (index + 1)))
        
    epoch_time = time.time() - time_epoch_start
    epoch_loss = running_loss / len(data_loader)
    epoch_accuracy = correct_cnt / len(data_loader.dataset) 
    return epoch_time, epoch_loss, epoch_accuracy

In [61]:
def validate(model, data_loader, criterion):
    model.eval()
    device = model.device
    time_epoch_start = time.time()
    running_loss = 0 
    correct_cnt = 0
    batch_cnt = 0
    for index, batch in enumerate(data_loader):
        time_batch_start = time.time()
        premises = batch["premise"].to(device)
        premises_len = batch["premise_len"].to(device)
        hypothesis = batch["hypothesis"].to(device)
        hypothesis_len = batch["hypothesis_len"].to(device)
        labels = batch["labels"].to(device)

        logits, probs = model(premises, premises_len, hypothesis, hypothesis_len)
        loss = criterion(logits, labels)
        running_loss += loss.item()
        correct_cnt += getCorrectNum(probs, labels)
        
        batch_cnt += 1
        print("Testing  ------>   Batch count: {:d}/{:d},  batch time: {:.4f}s,  batch average loss: {:.4f}"
              .format(batch_cnt, len(data_loader), time.time() - time_batch_start, running_loss / (index + 1)))

    epoch_time = time.time() - time_epoch_start
    epoch_loss = running_loss / len(data_loader)
    epoch_accuracy = correct_cnt / len(data_loader.dataset) 
    return epoch_time, epoch_loss, epoch_accuracy

In [62]:
#训练过程中的参数
best_score = 0.0
train_losses = []
valid_losses = []
patience_cnt = 0

for epoch in range(epochs):
    print("-"*50, "Training epoch %d" % (epoch), "-"*50)
    epoch_time, epoch_loss, epoch_accuracy = train(model, train_loader, optimizer, criterion,max_grad_norm)
    train_losses.append(epoch_loss)
    print("Training time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%".format(epoch_time, epoch_loss, (epoch_accuracy*100)))
    
    print("-"*50, "Validating epoch %d"%(epoch), "-"*50)
    epoch_time_dev, epoch_loss_dev, epoch_accuracy_dev = validate(model, dev_loader, criterion)
    valid_losses.append(epoch_loss_dev)
    print("Validating time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n".format(epoch_time_dev, epoch_loss_dev, (epoch_accuracy_dev*100)))
    
    #更新学习率
    scheduler.step(epoch_accuracy)
    
    #early stoping
    if epoch_accuracy_dev < best_score:
        patience_cnt += 1
    else:
        best_score = epoch_accuracy_dev
        patience_cnt = 0
    if patience_cnt >= patience:
        print("-"*50, "Early stopping", "-"*50)
        break
        
    #每个epoch都保存模型
    torch.save({"epoch": epoch,
                "model": model.state_dict(),
                "best_score": best_score,
                "train_losses": train_losses,
                "valid_losses": valid_losses},
               model_train_dir + str(epoch))

-------------------------------------------------- Training epoch 0 --------------------------------------------------
Training  ------>   Batch count: 1/1073,  batch time: 0.3425s,  batch average loss: 1.1486
Training  ------>   Batch count: 2/1073,  batch time: 0.3304s,  batch average loss: 1.1418
Training  ------>   Batch count: 3/1073,  batch time: 0.3307s,  batch average loss: 1.1345
Training  ------>   Batch count: 4/1073,  batch time: 0.3248s,  batch average loss: 1.1344
Training  ------>   Batch count: 5/1073,  batch time: 0.3895s,  batch average loss: 1.1350
Training  ------>   Batch count: 6/1073,  batch time: 0.3151s,  batch average loss: 1.1361
Training  ------>   Batch count: 7/1073,  batch time: 0.3178s,  batch average loss: 1.1372
Training  ------>   Batch count: 8/1073,  batch time: 0.3164s,  batch average loss: 1.1367
Training  ------>   Batch count: 9/1073,  batch time: 0.3658s,  batch average loss: 1.1360
Training  ------>   Batch count: 10/1073,  batch time: 0.3938s

Training  ------>   Batch count: 176/1073,  batch time: 0.3093s,  batch average loss: 1.1017
Training  ------>   Batch count: 177/1073,  batch time: 0.2878s,  batch average loss: 1.1016
Training  ------>   Batch count: 178/1073,  batch time: 0.3558s,  batch average loss: 1.1016
Training  ------>   Batch count: 179/1073,  batch time: 0.3428s,  batch average loss: 1.1015
Training  ------>   Batch count: 180/1073,  batch time: 0.3388s,  batch average loss: 1.1015
Training  ------>   Batch count: 181/1073,  batch time: 0.2892s,  batch average loss: 1.1014
Training  ------>   Batch count: 182/1073,  batch time: 0.4042s,  batch average loss: 1.1013
Training  ------>   Batch count: 183/1073,  batch time: 0.3241s,  batch average loss: 1.1013
Training  ------>   Batch count: 184/1073,  batch time: 0.3288s,  batch average loss: 1.1012
Training  ------>   Batch count: 185/1073,  batch time: 0.3092s,  batch average loss: 1.1012
Training  ------>   Batch count: 186/1073,  batch time: 0.3608s,  batc

Training  ------>   Batch count: 352/1073,  batch time: 0.3248s,  batch average loss: 1.0873
Training  ------>   Batch count: 353/1073,  batch time: 0.3268s,  batch average loss: 1.0872
Training  ------>   Batch count: 354/1073,  batch time: 0.3602s,  batch average loss: 1.0871
Training  ------>   Batch count: 355/1073,  batch time: 0.3258s,  batch average loss: 1.0870
Training  ------>   Batch count: 356/1073,  batch time: 0.3608s,  batch average loss: 1.0869
Training  ------>   Batch count: 357/1073,  batch time: 0.3038s,  batch average loss: 1.0868
Training  ------>   Batch count: 358/1073,  batch time: 0.3358s,  batch average loss: 1.0868
Training  ------>   Batch count: 359/1073,  batch time: 0.3088s,  batch average loss: 1.0866
Training  ------>   Batch count: 360/1073,  batch time: 0.2771s,  batch average loss: 1.0866
Training  ------>   Batch count: 361/1073,  batch time: 0.3808s,  batch average loss: 1.0865
Training  ------>   Batch count: 362/1073,  batch time: 0.3378s,  batc

Training  ------>   Batch count: 528/1073,  batch time: 0.2969s,  batch average loss: 1.0732
Training  ------>   Batch count: 529/1073,  batch time: 0.3125s,  batch average loss: 1.0731
Training  ------>   Batch count: 530/1073,  batch time: 0.3750s,  batch average loss: 1.0731
Training  ------>   Batch count: 531/1073,  batch time: 0.3437s,  batch average loss: 1.0730
Training  ------>   Batch count: 532/1073,  batch time: 0.3393s,  batch average loss: 1.0729
Training  ------>   Batch count: 533/1073,  batch time: 0.3594s,  batch average loss: 1.0729
Training  ------>   Batch count: 534/1073,  batch time: 0.3125s,  batch average loss: 1.0728
Training  ------>   Batch count: 535/1073,  batch time: 0.3125s,  batch average loss: 1.0728
Training  ------>   Batch count: 536/1073,  batch time: 0.2968s,  batch average loss: 1.0727
Training  ------>   Batch count: 537/1073,  batch time: 0.4375s,  batch average loss: 1.0726
Training  ------>   Batch count: 538/1073,  batch time: 0.3549s,  batc

Training  ------>   Batch count: 704/1073,  batch time: 0.2969s,  batch average loss: 1.0594
Training  ------>   Batch count: 705/1073,  batch time: 0.3281s,  batch average loss: 1.0594
Training  ------>   Batch count: 706/1073,  batch time: 0.3437s,  batch average loss: 1.0593
Training  ------>   Batch count: 707/1073,  batch time: 0.3437s,  batch average loss: 1.0593
Training  ------>   Batch count: 708/1073,  batch time: 0.3782s,  batch average loss: 1.0592
Training  ------>   Batch count: 709/1073,  batch time: 0.3035s,  batch average loss: 1.0591
Training  ------>   Batch count: 710/1073,  batch time: 0.2812s,  batch average loss: 1.0590
Training  ------>   Batch count: 711/1073,  batch time: 0.3125s,  batch average loss: 1.0589
Training  ------>   Batch count: 712/1073,  batch time: 0.2812s,  batch average loss: 1.0589
Training  ------>   Batch count: 713/1073,  batch time: 0.3437s,  batch average loss: 1.0588
Training  ------>   Batch count: 714/1073,  batch time: 0.3750s,  batc

Training  ------>   Batch count: 880/1073,  batch time: 0.4076s,  batch average loss: 1.0459
Training  ------>   Batch count: 881/1073,  batch time: 0.3437s,  batch average loss: 1.0459
Training  ------>   Batch count: 882/1073,  batch time: 0.2969s,  batch average loss: 1.0458
Training  ------>   Batch count: 883/1073,  batch time: 0.2824s,  batch average loss: 1.0458
Training  ------>   Batch count: 884/1073,  batch time: 0.3281s,  batch average loss: 1.0457
Training  ------>   Batch count: 885/1073,  batch time: 0.3593s,  batch average loss: 1.0456
Training  ------>   Batch count: 886/1073,  batch time: 0.3549s,  batch average loss: 1.0456
Training  ------>   Batch count: 887/1073,  batch time: 0.3281s,  batch average loss: 1.0455
Training  ------>   Batch count: 888/1073,  batch time: 0.2969s,  batch average loss: 1.0454
Training  ------>   Batch count: 889/1073,  batch time: 0.2969s,  batch average loss: 1.0453
Training  ------>   Batch count: 890/1073,  batch time: 0.3437s,  batc

Training  ------>   Batch count: 1055/1073,  batch time: 0.3246s,  batch average loss: 1.0333
Training  ------>   Batch count: 1056/1073,  batch time: 0.3750s,  batch average loss: 1.0333
Training  ------>   Batch count: 1057/1073,  batch time: 0.2969s,  batch average loss: 1.0332
Training  ------>   Batch count: 1058/1073,  batch time: 0.3125s,  batch average loss: 1.0332
Training  ------>   Batch count: 1059/1073,  batch time: 0.3281s,  batch average loss: 1.0331
Training  ------>   Batch count: 1060/1073,  batch time: 0.2969s,  batch average loss: 1.0330
Training  ------>   Batch count: 1061/1073,  batch time: 0.3080s,  batch average loss: 1.0329
Training  ------>   Batch count: 1062/1073,  batch time: 0.3125s,  batch average loss: 1.0329
Training  ------>   Batch count: 1063/1073,  batch time: 0.4326s,  batch average loss: 1.0328
Training  ------>   Batch count: 1064/1073,  batch time: 0.3448s,  batch average loss: 1.0327
Training  ------>   Batch count: 1065/1073,  batch time: 0.2

KeyboardInterrupt: 

In [12]:
class RNNDropout(nn.Dropout): # 这个继承自带了self.training
    """
    Dropout layer for the inputs of RNNs.

    Apply the same dropout mask to all the elements of the same sequence in
    a batch of sequences of size (batch, sequences_length, embedding_dim).
    """
    def forward(self, sequences_batch):
        """
        Apply dropout to the input batch of sequences.

        Args:
            sequences_batch: A batch of sequences of vectors that will serve
                as input to an RNN.
                Tensor of size (batch, sequences_length, embedding_dim).

        Returns:
            A new tensor on which dropout has been applied.
        """
        ones = sequences_batch.data.new_ones(sequences_batch.shape[0], 1, sequences_batch.shape[-1])
        print("ones", ones)
        dropout_mask = nn.functional.dropout(ones, self.p, self.training, inplace=False) # 最好用nn.Dropout
        print("drop", dropout_mask)
        return dropout_mask * sequences_batch

In [13]:
rnn = RNNDropout()

In [17]:
a = torch.randn(1,2,5)

In [18]:
rnn(a)

ones tensor([[[1., 1., 1., 1., 1.]]])
drop tensor([[[0., 0., 0., 2., 0.]]])


tensor([[[-0.0000, -0.0000, -0.0000,  2.3526, -0.0000],
         [-0.0000, -0.0000, -0.0000, -0.6429, -0.0000]]])

In [24]:
a = torch.tensor([[1,3],[4,2],[5,6]])

In [25]:
a.sort(0)

torch.return_types.sort(
values=tensor([[1, 2],
        [4, 3],
        [5, 6]]),
indices=tensor([[0, 1],
        [1, 0],
        [2, 2]]))