In [1]:
import os
import pickle
import time

import torch
from torch import nn
from torch.utils.data import DataLoader

from model.SnliDataSet import SnliDataSet

# Locations of the input files. The paths are assembled with os.path.join so
# that the separator matches the host OS (the previous hard-coded "\\" only
# resolved correctly on Windows).
worddict_dir = os.path.join("data", "worddict.txt")
data_train_id_dir = os.path.join("data", "train_data_id.pkl")
data_dev_id_dir = os.path.join("data", "dev_data_id.pkl")
embedding_matrix_dir = os.path.join("data", "embedding_matrix.pkl")

# Prefix of the per-epoch checkpoint files; the epoch index and a ".dir"
# suffix are appended when the checkpoint is written.
model_train_dir = os.path.join("saved_model", "train_model_")

# Hyperparameters
batch_size = 1   # NOTE(review): one sample per step; confirm this is intentional
use_gpu = True   # request the GPU; honoured only when CUDA is actually available
patience = 5     # number of non-improving epochs tolerated before early stopping

# Fall back to the CPU when CUDA is missing instead of failing later on the
# first .to(device) call.
device = torch.device("cuda:0" if use_gpu and torch.cuda.is_available() else "cpu")

In [2]:
# Model settings (forwarded to the ESIM constructor)
hidden_size = 50
dropout = 0.5
num_classes = 3

# Optimisation settings
lr = 4e-4             # learning rate handed to Adam
epochs = 2            # number of passes of the training loop
max_grad_norm = 10.0  # threshold passed to gradient-norm clipping in train()

In [3]:
def _read_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this
    at files you created yourself.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Load the data: training batches are shuffled, dev batches keep their order.
train_data = SnliDataSet(_read_pickle(data_train_id_dir),
                         max_premises_len=None,
                         max_hypothesis_len=None)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

dev_data = SnliDataSet(_read_pickle(data_dev_id_dir),
                       max_premises_len=None,
                       max_hypothesis_len=None)
dev_loader = DataLoader(dev_data, batch_size=batch_size, shuffle=False)

# Load the embedding matrix as a float tensor and move it to the target device.
embeddings = torch.tensor(_read_pickle(embedding_matrix_dir), dtype=torch.float)
embeddings = embeddings.to(device)

In [5]:
from model.esim import ESIM

# The first two positional arguments are the sizes of the embedding matrix
# along its first and second axes (presumably vocabulary size and embedding
# dimension -- confirm against the ESIM signature).
n_rows, n_cols = embeddings.shape[:2]

esim_options = {
    "embeddings": embeddings,
    "dropout": dropout,
    "num_classes": num_classes,
    "device": device,
}
model = ESIM(n_rows, n_cols, hidden_size, **esim_options)
model = model.to(device)

In [6]:
# Training setup: loss function, optimizer and learning-rate schedule.
criterion = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# mode="max" means the monitored value is expected to increase; with
# patience=0 the learning rate is halved after any epoch without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    factor=0.5,
    patience=0,
)


In [95]:
def getCorrectNum(probs, targets):
    """Count the rows of ``probs`` whose largest entry sits at the target index.

    Args:
        probs: Tensor of per-class scores; the maximum is taken along dim 1.
        targets: Tensor of class indices compared element-wise with the
            predicted indices.

    Returns:
        The number of matching positions as a plain Python number.
    """
    predicted = probs.max(dim=1).indices
    return predicted.eq(targets).sum().item()

def train(model, data_loader, optimizer, criterion, max_gradient_norm, log_every=1):
    """Run one training epoch over ``data_loader``, updating ``model`` in place.

    Args:
        model: Module exposing a ``device`` attribute. It is called as
            ``model(premises, premises_len, hypothesis, hypothesis_len)`` and
            must return a ``(logits, probs)`` pair.
        data_loader: Sized iterable of batches with a ``dataset`` attribute.
            Each batch is a mapping with the keys ``"premises"``,
            ``"premises_len"``, ``"hypothesis"``, ``"hypothesis_len"`` and
            ``"labels"``, whose values support ``.to(device)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        criterion: Loss called as ``criterion(logits, labels)``.
        max_gradient_norm: Threshold handed to gradient-norm clipping.
        log_every: Print a progress line every ``log_every`` batches. The
            default of 1 prints one line per batch; ``0`` or ``None``
            disables progress output.

    Returns:
        Tuple ``(epoch_time, epoch_loss, epoch_accuracy)``: elapsed seconds,
        summed batch loss divided by the number of batches, and number of
        correct predictions divided by ``len(data_loader.dataset)``.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no batches.
    """
    model.train()
    device = model.device

    time_epoch_start = time.time()
    running_loss = 0.0
    correct_cnt = 0
    total_batches = len(data_loader)

    for batch_cnt, batch in enumerate(data_loader, start=1):
        time_batch_start = time.time()

        # Pull the tensors out of the batch and move them to the model's device.
        premises = batch["premises"].to(device)
        premises_len = batch["premises_len"].to(device)
        hypothesis = batch["hypothesis"].to(device)
        hypothesis_len = batch["hypothesis_len"].to(device)
        labels = batch["labels"].to(device)

        # Gradients accumulate by default, so clear them before each step.
        optimizer.zero_grad()

        # Forward pass.
        logits, probs = model(premises, premises_len, hypothesis, hypothesis_len)

        # Loss, backward pass, gradient clipping, parameter update.
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm)
        optimizer.step()

        running_loss += loss.item()
        correct_cnt += getCorrectNum(probs, labels)

        if log_every and batch_cnt % log_every == 0:
            print("Training  ------>   Batch count: {:d}/{:d},  batch time: {:.4f}s,  batch average loss: {:.4f}"
                  .format(batch_cnt, total_batches, time.time() - time_batch_start, running_loss / batch_cnt))

    epoch_time = time.time() - time_epoch_start
    epoch_loss = running_loss / total_batches
    epoch_accuracy = correct_cnt / len(data_loader.dataset)
    return epoch_time, epoch_loss, epoch_accuracy



In [99]:
def validate(model, data_loader, criterion, log_every=1):
    """Evaluate ``model`` on ``data_loader`` without updating any parameters.

    The model is switched to evaluation mode and the whole pass runs under
    ``torch.no_grad()`` so that no autograd graph is built for the batches.

    Args:
        model: Module exposing a ``device`` attribute. It is called as
            ``model(premises, premises_len, hypothesis, hypothesis_len)`` and
            must return a ``(logits, probs)`` pair.
        data_loader: Sized iterable of batches with a ``dataset`` attribute.
            Each batch is a mapping with the keys ``"premises"``,
            ``"premises_len"``, ``"hypothesis"``, ``"hypothesis_len"`` and
            ``"labels"``, whose values support ``.to(device)``.
        criterion: Loss called as ``criterion(logits, labels)``.
        log_every: Print a progress line every ``log_every`` batches. The
            default of 1 prints one line per batch; ``0`` or ``None``
            disables progress output.

    Returns:
        Tuple ``(epoch_time, epoch_loss, epoch_accuracy)``: elapsed seconds,
        summed batch loss divided by the number of batches, and number of
        correct predictions divided by ``len(data_loader.dataset)``.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no batches.
    """
    model.eval()
    device = model.device

    time_epoch_start = time.time()
    running_loss = 0.0
    correct_cnt = 0
    total_batches = len(data_loader)

    # Evaluation needs no gradients; disabling autograd saves memory and time.
    with torch.no_grad():
        for batch_cnt, batch in enumerate(data_loader, start=1):
            time_batch_start = time.time()

            # Pull the tensors out of the batch and move them to the model's device.
            premises = batch["premises"].to(device)
            premises_len = batch["premises_len"].to(device)
            hypothesis = batch["hypothesis"].to(device)
            hypothesis_len = batch["hypothesis_len"].to(device)
            labels = batch["labels"].to(device)

            # Forward pass.
            logits, probs = model(premises, premises_len, hypothesis, hypothesis_len)

            # Loss only; there is no backward pass during validation.
            loss = criterion(logits, labels)

            running_loss += loss.item()
            correct_cnt += getCorrectNum(probs, labels)

            if log_every and batch_cnt % log_every == 0:
                print("Testing  ------>   Batch count: {:d}/{:d},  batch time: {:.4f}s,  batch average loss: {:.4f}"
                      .format(batch_cnt, total_batches, time.time() - time_batch_start, running_loss / batch_cnt))

    epoch_time = time.time() - time_epoch_start
    epoch_loss = running_loss / total_batches
    epoch_accuracy = correct_cnt / len(data_loader.dataset)
    return epoch_time, epoch_loss, epoch_accuracy



In [101]:
# State tracked across epochs
best_score = 0.0    # best validation accuracy seen so far
train_losses = []   # one training loss per completed epoch
valid_losses = []   # one validation loss per completed epoch
patience_cnt = 0    # consecutive epochs in which validation accuracy fell below the best

# Create the checkpoint directory up front so torch.save cannot fail on a
# missing folder after a full epoch of training.
checkpoint_dir = os.path.dirname(model_train_dir)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(epochs):
    # Training
    print("-"*50,"Training epoch %d"%(epoch),"-"*50)
    epoch_time, epoch_loss, epoch_accuracy = train(model, train_loader, optimizer, criterion, max_grad_norm)
    train_losses.append(epoch_loss)
    print("Training time: {:.4f}s, loss :{:.4f}, accuracy: {:.4f}%".format(epoch_time, epoch_loss, (epoch_accuracy*100)))

    # Validation
    print("-"*50,"Validating epoch %d"%(epoch),"-"*50)
    epoch_time_dev, epoch_loss_dev, epoch_accuracy_dev = validate(model, dev_loader, criterion)
    valid_losses.append(epoch_loss_dev)
    print("Validating time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n".format(epoch_time_dev, epoch_loss_dev, (epoch_accuracy_dev*100)))

    # Update the learning rate. The plateau scheduler has to watch the
    # validation accuracy -- the same metric early stopping uses -- not the
    # training accuracy.
    scheduler.step(epoch_accuracy_dev)

    # Early-stopping bookkeeping
    if epoch_accuracy_dev < best_score:
        patience_cnt += 1
    else:
        best_score = epoch_accuracy_dev
        patience_cnt = 0

    # Save a checkpoint for every completed epoch. This runs before the
    # early-stopping check so the final epoch is stored as well.
    torch.save({"epoch": epoch,
                "model": model.state_dict(),
                "best_score": best_score,
                "train_losses": train_losses,
                "valid_losses": valid_losses},
               model_train_dir+str(epoch)+".dir")

    if patience_cnt >= patience:
        print("-"*50,"Early stopping","-"*50)
        break

-------------------------------------------------- Training epoch 0 --------------------------------------------------
Training  ------>   Batch count: 1/549367,  batch time: 0.2719s,  batch average loss: 1.3751
Training  ------>   Batch count: 2/549367,  batch time: 0.0379s,  batch average loss: 1.2290
Training  ------>   Batch count: 3/549367,  batch time: 0.0409s,  batch average loss: 1.2561
Training  ------>   Batch count: 4/549367,  batch time: 0.0389s,  batch average loss: 1.1968
Training  ------>   Batch count: 5/549367,  batch time: 0.0481s,  batch average loss: 1.1857
Training  ------>   Batch count: 6/549367,  batch time: 0.0382s,  batch average loss: 1.1873
Training  ------>   Batch count: 7/549367,  batch time: 0.0389s,  batch average loss: 1.1692
Training  ------>   Batch count: 8/549367,  batch time: 0.0366s,  batch average loss: 1.1703
Training  ------>   Batch count: 9/549367,  batch time: 0.0387s,  batch average loss: 1.1602
Training  ------>   Batch count: 10/549367, 

Training  ------>   Batch count: 86/549367,  batch time: 0.0380s,  batch average loss: 1.1175
Training  ------>   Batch count: 87/549367,  batch time: 0.0416s,  batch average loss: 1.1164
Training  ------>   Batch count: 88/549367,  batch time: 0.0379s,  batch average loss: 1.1160
Training  ------>   Batch count: 89/549367,  batch time: 0.0417s,  batch average loss: 1.1155
Training  ------>   Batch count: 90/549367,  batch time: 0.0435s,  batch average loss: 1.1155
Training  ------>   Batch count: 91/549367,  batch time: 0.0521s,  batch average loss: 1.1147
Training  ------>   Batch count: 92/549367,  batch time: 0.0417s,  batch average loss: 1.1151
Training  ------>   Batch count: 93/549367,  batch time: 0.0396s,  batch average loss: 1.1154
Training  ------>   Batch count: 94/549367,  batch time: 0.0387s,  batch average loss: 1.1148
Training  ------>   Batch count: 95/549367,  batch time: 0.0428s,  batch average loss: 1.1147
Training  ------>   Batch count: 96/549367,  batch time: 0.0

Testing  ------>   Batch count: 72/9842,  batch time: 0.0130s,  batch average loss: 1.0959
Testing  ------>   Batch count: 73/9842,  batch time: 0.0156s,  batch average loss: 1.0955
Testing  ------>   Batch count: 74/9842,  batch time: 0.0160s,  batch average loss: 1.0959
Testing  ------>   Batch count: 75/9842,  batch time: 0.0140s,  batch average loss: 1.0958
Testing  ------>   Batch count: 76/9842,  batch time: 0.0146s,  batch average loss: 1.0957
Testing  ------>   Batch count: 77/9842,  batch time: 0.0136s,  batch average loss: 1.0958
Testing  ------>   Batch count: 78/9842,  batch time: 0.0125s,  batch average loss: 1.0965
Testing  ------>   Batch count: 79/9842,  batch time: 0.0136s,  batch average loss: 1.0964
Testing  ------>   Batch count: 80/9842,  batch time: 0.0146s,  batch average loss: 1.0963
Testing  ------>   Batch count: 81/9842,  batch time: 0.0134s,  batch average loss: 1.0970
Testing  ------>   Batch count: 82/9842,  batch time: 0.0145s,  batch average loss: 1.0969

Training  ------>   Batch count: 57/549367,  batch time: 0.0479s,  batch average loss: 1.0853
Training  ------>   Batch count: 58/549367,  batch time: 0.0378s,  batch average loss: 1.0856
Training  ------>   Batch count: 59/549367,  batch time: 0.0415s,  batch average loss: 1.0859
Training  ------>   Batch count: 60/549367,  batch time: 0.0392s,  batch average loss: 1.0871
Training  ------>   Batch count: 61/549367,  batch time: 0.0381s,  batch average loss: 1.0860
Training  ------>   Batch count: 62/549367,  batch time: 0.0378s,  batch average loss: 1.0852
Training  ------>   Batch count: 63/549367,  batch time: 0.0407s,  batch average loss: 1.0847
Training  ------>   Batch count: 64/549367,  batch time: 0.0365s,  batch average loss: 1.0843
Training  ------>   Batch count: 65/549367,  batch time: 0.0417s,  batch average loss: 1.0850
Training  ------>   Batch count: 66/549367,  batch time: 0.0403s,  batch average loss: 1.0829
Training  ------>   Batch count: 67/549367,  batch time: 0.0

Testing  ------>   Batch count: 42/9842,  batch time: 0.0140s,  batch average loss: 1.0991
Testing  ------>   Batch count: 43/9842,  batch time: 0.0165s,  batch average loss: 1.0992
Testing  ------>   Batch count: 44/9842,  batch time: 0.0116s,  batch average loss: 1.0983
Testing  ------>   Batch count: 45/9842,  batch time: 0.0142s,  batch average loss: 1.0985
Testing  ------>   Batch count: 46/9842,  batch time: 0.0132s,  batch average loss: 1.0987
Testing  ------>   Batch count: 47/9842,  batch time: 0.0140s,  batch average loss: 1.0983
Testing  ------>   Batch count: 48/9842,  batch time: 0.0130s,  batch average loss: 1.0979
Testing  ------>   Batch count: 49/9842,  batch time: 0.0118s,  batch average loss: 1.0983
Testing  ------>   Batch count: 50/9842,  batch time: 0.0130s,  batch average loss: 1.0982
Testing  ------>   Batch count: 51/9842,  batch time: 0.0170s,  batch average loss: 1.0975
Testing  ------>   Batch count: 52/9842,  batch time: 0.0128s,  batch average loss: 1.0976