In [1]:
!pip install pytorch_pretrained_bert
!pip install pyrouge
# !pip install tensorboardX
# !pip install multiprocess

Collecting pytorch_pretrained_bert
[?25l  Downloading https://files.pythonhosted.org/packages/d7/e0/c08d5553b89973d9a240605b9c12404bcf8227590de62bae27acbcfe076b/pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123kB)
[K     |██▋                             | 10kB 27.6MB/s eta 0:00:01[K     |█████▎                          | 20kB 2.0MB/s eta 0:00:01[K     |████████                        | 30kB 2.7MB/s eta 0:00:01[K     |██████████▋                     | 40kB 2.9MB/s eta 0:00:01[K     |█████████████▎                  | 51kB 2.5MB/s eta 0:00:01[K     |███████████████▉                | 61kB 2.7MB/s eta 0:00:01[K     |██████████████████▌             | 71kB 3.0MB/s eta 0:00:01[K     |█████████████████████▏          | 81kB 3.3MB/s eta 0:00:01[K     |███████████████████████▉        | 92kB 3.5MB/s eta 0:00:01[K     |██████████████████████████▌     | 102kB 3.3MB/s eta 0:00:01[K     |█████████████████████████████▏  | 112kB 3.3MB/s eta 0:00:01[K     |██████████████████████

In [2]:
import pytorch_pretrained_bert
import math

import torch
import torch.nn as nn
import os
#from tensorboardX import SummaryWriter

import time
#import distributed

import gc
import glob
#import hashlib
#import itertools
import json
import re
import subprocess
import time
from os.path import join as pjoin

#from multiprocess import Pool
from pytorch_pretrained_bert import BertTokenizer
import numpy as np
import random 


In [3]:
!mkdir bert_data
!mkdir results
!mkdir tmp


# Encoder

In [4]:
'''
# verified
torch.seed()
pe = PositionalEncoding(4,.8,2)
a=torch.tensor([[1,0,0,1],[1,1,0,0]])
pe(a),pe.getPositonalEncoding(a)
torch.seed()
pe = PositionalEncoding2(.8,4,2)
a=torch.tensor([[1,0,0,1],[1,1,0,0]])
pe(a),pe.get_emb(a)
'''
class PositionalEncoding(nn.Module):
    def __init__(self,embeddingDim,dropout,maxWordLen=5000):
        positionalEncoding = torch.zeros(maxWordLen,embeddingDim)
        # keep dim 1 array a singleton
        position = torch.arange(0,maxWordLen).unsqueeze(1)
        # positional encoding is defined to be 
        # PE(pos,2i)=sin(pos/1e4^(2i/embedding_dim))
        # PE(pos,2i+1)=cos(pos/1e4^(2i/embedding_dim))
        exponentTerm = torch.arange(0,embeddingDim,2, dtype=torch.float)\
                *(-math.log(1e4*1.0)/embeddingDim)
        divisionTerm = torch.exp(exponentTerm)
        # all even indices
        positionalEncoding[:,0::2] = torch.sin(position.float()*divisionTerm)
        # all odd indices
        positionalEncoding[:,1::2] = torch.cos(position.float()*divisionTerm)
        # keep dim 0 array size 1 --> a single array
        positionalEncoding = positionalEncoding.unsqueeze(0)
        super(PositionalEncoding, self).__init__()
        self.register_buffer('positionalEncoding', positionalEncoding)
        #dropout is probability of dropout
        self.dropout = nn.Dropout(p=dropout)
        self.embeddingDim = embeddingDim
        
    def forward(self,embedding):
        # optional: add step -- not sure what it does
        # for dropout?
        embedding = embedding * math.sqrt(self.embeddingDim)
        embedding = embedding + self.positionalEncoding[:, :embedding.size(1)]
        embedding = self.dropout(embedding)
        return embedding
    def getPositonalEncoding(self,embedding):
        return self.positionalEncoding[:, :embedding.size(1)]

In [5]:
'''
# correctness confirmed
torch.random.manual_seed(42)
pff=PositionwiseFeedForward2(4,2,.1)
a=torch.tensor([[1.0,0,0,1],[1.0,1,0,0]])
pff(a)
torch.random.manual_seed(42)
pff=PositionwiseFeedForward(4,2,.1)
a=torch.tensor([[1.0,0,0,1],[1.0,1,0,0]])
pff(a)
'''
class PositionwiseFeedForward(nn.Module):
    '''
    A two-layer Feed-Forward-Network with residual layer norm.
    '''
    
    def __init__(self, modelDim, feedforwardDim, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(modelDim, feedforwardDim)
        self.linear2 = nn.Linear(feedforwardDim, modelDim)
        self.layerNorm = nn.LayerNorm(modelDim, eps=1e-6)
        # activation function
        self.gelu = lambda x: \
                0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        layerNorm = self.layerNorm(x)
        linear1 = self.linear1(layerNorm)
        gelu = self.gelu(linear1)
        hidden = self.dropout1(gelu)
        #hidden = self.dropout1(self.gelu(self.linear1(self.layerNorm(x))))
        output = self.dropout2(self.linear2(hidden))
        return output + x

In [6]:
def exampleOfMultiHeadedAttention():
    torch.random.manual_seed(42)
    m=MultiHeadedAttention(2,4)
    a=torch.tensor([[1.0,2,3,1]])
    b=torch.tensor([[1.0,0.5,0.3,1]])
    c=torch.tensor([[1.0,0.5,0.7,1]])
    d={'selfKeys': torch.tensor([[[[ 1.7056, -0.2705]],\
    [[ 0.8711,  1.2850]]]]), 'selfValues': torch.tensor([[[[ 0.0890, -0.4489]],\
    [[-0.4343,  0.5302]]]])}
    #m(a,b,c,layerCache={"memoryKeys":None},types="context")
    m(a,b,c,layerCache=d,types="self")
    torch.random.manual_seed(42)
    m=MultiHeadedAttention2(2,4)
    a=torch.tensor([[1.0,2,3,1]])
    b=torch.tensor([[1.0,0.5,0.3,1]])
    c=torch.tensor([[1.0,0.5,0.7,1]])
    d={'self_keys': torch.tensor([[[[ 1.7056, -0.2705]],\
    [[ 0.8711,  1.2850]]]]), 'self_values': torch.tensor([[[[ 0.0890, -0.4489]],\
    [[-0.4343,  0.5302]]]])}
    m(a,b,c,layer_cache=d,type="self")


    
    torch.random.manual_seed(42)
    m=MultiHeadedAttention(2,4)
    a=torch.tensor([[1.0,2,3,1]])
    b=torch.tensor([[1.0,0.5,0.3,1]])
    c=torch.tensor([[1.0,0.5,0.7,1]])
    d={'memoryKeys': torch.tensor([[[[ 1.7056, -0.2705]],\
    [[ 0.8711,  1.2850]]]]), 'memoryValues': torch.tensor([[[[ 0.0890, -0.4489]],\
    [[-0.4343,  0.5302]]]])}
    #m(a,b,c,layerCache={"memoryKeys":None},types="context")
    m(a,b,c,layerCache=d,types="context")
    torch.random.manual_seed(42)
    m=MultiHeadedAttention2(2,4)
    a=torch.tensor([[1.0,2,3,1]])
    b=torch.tensor([[1.0,0.5,0.3,1]])
    c=torch.tensor([[1.0,0.5,0.7,1]])
    d={'memory_keys': torch.tensor([[[[ 1.7056, -0.2705]],\
    [[ 0.8711,  1.2850]]]]), 'memory_values': torch.tensor([[[[ 0.0890, -0.4489]],\
    [[-0.4343,  0.5302]]]])}
    m(a,b,c,layer_cache=d,type="context")
    
    
    torch.random.manual_seed(42)
    m=MultiHeadedAttention2(2,4)
    a=torch.tensor([[1.0,2,3,1]])
    b=torch.tensor([[1.0,0.5,0.3,1]])
    c=torch.tensor([[1.0,0.5,0.7,1]])
    d={'memory_keys': None, 'memory_values': None}
    m(a,b,c,layer_cache=d,type="context")
    torch.random.manual_seed(42)
    m=MultiHeadedAttention(2,4)
    a=torch.tensor([[1.0,2,3,1]])
    b=torch.tensor([[1.0,0.5,0.3,1]])
    c=torch.tensor([[1.0,0.5,0.7,1]])
    d={'memoryKeys': None, 'memoryValues': None}
    m(a,b,c,layerCache=d,types="context")
    
    
    torch.random.manual_seed(42)
    m=MultiHeadedAttention2(2,4)
    a=torch.tensor([[1.0,2,3,1]])
    b=torch.tensor([[1.0,0.5,0.3,1]])
    c=torch.tensor([[1.0,0.5,0.7,1]])
    d={'self_keys': None, 'self_values': None}
    m(a,b,c,layer_cache=d,type="self")
    torch.random.manual_seed(42)
    m=MultiHeadedAttention(2,4)
    a=torch.tensor([[1.0,2,3,1]])
    b=torch.tensor([[1.0,0.5,0.3,1]])
    c=torch.tensor([[1.0,0.5,0.7,1]])
    d={'selfKeys': None, 'selfValues': None}
    m(a,b,c,layerCache=d,types="self")
    
    torch.random.manual_seed(42)
    m=MultiHeadedAttention2(2,4)
    a=torch.tensor([[1.0,2,3,1]])
    b=torch.tensor([[1.0,0.5,0.3,1]])
    c=torch.tensor([[1.0,0.5,0.7,1]])
    d={'self_keys': None, 'self_values': None}
    torch.random.manual_seed(42)
    m=MultiHeadedAttention(2,4)
    a=torch.tensor([[1.0,2,3,1]])
    b=torch.tensor([[1.0,0.5,0.3,1]])
    c=torch.tensor([[1.0,0.5,0.7,1]])

In [7]:
# example above verified 5 cases

class MultiHeadedAttention(nn.Module):
    def __init__(self,numHeads, modelDim, dropout=0.1, isFinalLinear=True):
        assert modelDim%numHeads == 0
        self.dimPerHead =modelDim//numHeads 
        self.modelDim = modelDim
        
        super(MultiHeadedAttention,self).__init__()
        self.numHeads = numHeads
        self.linearKeys = nn.Linear(modelDim,numHeads*self.dimPerHead)
        self.linearValues = nn.Linear(modelDim,numHeads*self.dimPerHead)
        self.linearQuery = nn.Linear(modelDim,numHeads*self.dimPerHead)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.isFinalLinear = isFinalLinear
        if self.isFinalLinear:
            self.finalLayer = nn.Linear(modelDim,modelDim)
    def forward(self,key,value,query,mask = None,
               layerCache = None, types = None):

        batchSize = key.size(0) # size(0) is batch size for all three vectors
        dimPerHead = self.dimPerHead
        numHeads = self.numHeads

        
        shape = lambda x: x.view(batchSize,-1,numHeads,dimPerHead).transpose(1,2)
        # why contiguous: https://discuss.pytorch.org/t/when-and-why-do-we-use-contiguous/47588
        # we might not need it
        # apparently in the old version, transpose only changes the view of data, but not data itself
        # so to force it to change the data, use contiguous 
        unshape = lambda x: x.transpose(1,2).contiguous().view(batchSize,-1,numHeads*dimPerHead)
        
        # get key, value, and query
        if layerCache is None:
     
            key = self.linearKeys(key)
            value = self.linearValues(value)
            query = self.linearQuery(query)
            key = shape(key)
            value = shape(value)
            
        else:
  
            # Note: this is different from the original code. the original code has 
            # if statement that is already tested, and else statements that will 
            # never get use
            
            # concatenate to variable key" and "value" to their respective caches. 
            if types == "self":
        
                # all query??
                key = self.linearKeys(query)
                value = self.linearValues(query)
                query = self.linearQuery(query)
                key = shape(key)
                value = shape(value)
                device = key.device
                
                itemPairsToUpdate = [[key,"selfKeys"],[value,"selfValues"]]
                
                
                for i in range(2):
                    variable, variableName = itemPairsToUpdate[i]
                    if layerCache[variableName] is not None:
                        itemPairsToUpdate[i][0] = torch.cat((layerCache[variableName].to(device),variable), dim=2)
                    layerCache[variableName] = itemPairsToUpdate[i][0]
                key, value = itemPairsToUpdate[0][0],itemPairsToUpdate[1][0]
                
            elif types == "context":
               
                # if no cache, create the cache, 
                # else copy the cache to the variables 
                query = self.linearQuery(query)
                if layerCache["memoryKeys"] is None:
                    # checked!
                    key = self.linearKeys(key)
                    value = self.linearValues(value)
                    key = shape(key)
                    value = shape(value)

                else:
                    key, value = layerCache["memoryKeys"], layerCache["memoryValues"]
                layerCache["memoryKeys"] = key
                layerCache["memoryValues"] = value

        query = shape(query)
        
        '''
        # possibly for debugging purpose
        keyLength = key.size(2)
        queryLength = query.size(2)
        '''
        
        # compute and scale the scores
        
        # why sqrt?
    
        query = query / math.sqrt(dimPerHead)
    
        scores = torch.matmul(query, key.transpose(2, 3))
        
        if mask is not None:
            mask = mask.unsqueeze(1).expand_as(scores)
            scores = scores.masked_fill(mask,-1e18) # negative infinity 
            
        # apply attention dropout and compute context vectors 
        attention = self.softmax(scores)
        attentionDropout = self.dropout(attention)
        
        if self.isFinalLinear:
            context = unshape(torch.matmul(attentionDropout,value))
            output = self.finalLayer(context)
            return output
        else:
            context = torch.matmul(attentionDropout,value)
            return context


In [8]:
class MultiHeadedAttention2(nn.Module):
    """
    Multi-Head Attention module from
    "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.
    Similar to standard `dot` attention but uses
    multiple attention distributions simulataneously
    to select relevant items.
    .. mermaid::
       graph BT
          A[key]
          B[value]
          C[query]
          O[output]
          subgraph Attn
            D[Attn 1]
            E[Attn 2]
            F[Attn N]
          end
          A --> D
          C --> D
          A --> E
          C --> E
          A --> F
          C --> F
          D --> O
          E --> O
          F --> O
          B --> O
    Also includes several additional tricks.
    Args:
       head_count (int): number of parallel heads
       model_dim (int): the dimension of keys/values/queries,
           must be divisible by head_count
       dropout (float): dropout parameter
    """

    def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True):
        assert model_dim % head_count == 0
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim

        super(MultiHeadedAttention2, self).__init__()
        self.head_count = head_count

        self.linear_keys = nn.Linear(model_dim,
                                     head_count * self.dim_per_head)
        self.linear_values = nn.Linear(model_dim,
                                       head_count * self.dim_per_head)
        self.linear_query = nn.Linear(model_dim,
                                      head_count * self.dim_per_head)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.use_final_linear = use_final_linear
        if (self.use_final_linear):
            self.final_linear = nn.Linear(model_dim, model_dim)

    def forward(self, key, value, query, mask=None,
                layer_cache=None, type=None, predefined_graph_1=None):
        """
        Compute the context vector and the attention vectors.
        Args:
           key (`FloatTensor`): set of `key_len`
                key vectors `[batch, key_len, dim]`
           value (`FloatTensor`): set of `key_len`
                value vectors `[batch, key_len, dim]`
           query (`FloatTensor`): set of `query_len`
                 query vectors  `[batch, query_len, dim]`
           mask: binary mask indicating which keys have
                 non-zero attention `[batch, query_len, key_len]`
        Returns:
           (`FloatTensor`, `FloatTensor`) :
           * output context vectors `[batch, query_len, dim]`
           * one of the attention vectors `[batch, query_len, key_len]`
        """

        # CHECKS
        # batch, k_len, d = key.size()
        # batch_, k_len_, d_ = value.size()
        # aeq(batch, batch_)
        # aeq(k_len, k_len_)
        # aeq(d, d_)
        # batch_, q_len, d_ = query.size()
        # aeq(batch, batch_)
        # aeq(d, d_)
        # aeq(self.model_dim % 8, 0)
        # if mask is not None:
        #    batch_, q_len_, k_len_ = mask.size()
        #    aeq(batch_, batch)
        #    aeq(k_len_, k_len)
        #    aeq(q_len_ == q_len)
        # END CHECKS

        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count
        key_len = key.size(1)
        query_len = query.size(1)

        def shape(x):
            """  projection """
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            """  compute context """
            return x.transpose(1, 2).contiguous() \
                .view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
        if layer_cache is not None:
            if type == "self":
                query, key, value = self.linear_query(query), \
                                    self.linear_keys(query), \
                                    self.linear_values(query)

                key = shape(key)
                value = shape(value)

                if layer_cache is not None:
                    device = key.device
                    if layer_cache["self_keys"] is not None:
                        key = torch.cat(
                            (layer_cache["self_keys"].to(device), key),
                            dim=2)
                    if layer_cache["self_values"] is not None:
                        value = torch.cat(
                            (layer_cache["self_values"].to(device), value),
                            dim=2)
                    layer_cache["self_keys"] = key
                    layer_cache["self_values"] = value
            elif type == "context":
                query = self.linear_query(query)
                if layer_cache is not None:
                    if layer_cache["memory_keys"] is None:
                        key, value = self.linear_keys(key), \
                                     self.linear_values(value)
                        key = shape(key)
                        value = shape(value)
                    else:
                        key, value = layer_cache["memory_keys"], \
                                     layer_cache["memory_values"]
                    layer_cache["memory_keys"] = key
                    layer_cache["memory_values"] = value
                else:
                    key, value = self.linear_keys(key), \
                                 self.linear_values(value)
                    key = shape(key)
                    value = shape(value)
        else:
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)

        query = shape(query)

        key_len = key.size(2)
        query_len = query.size(2)

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)
        scores = torch.matmul(query, key.transpose(2, 3))

        if mask is not None:
            mask = mask.unsqueeze(1).expand_as(scores)
            
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.

        attn = self.softmax(scores)

        if (not predefined_graph_1 is None):
            attn_masked = attn[:, -1] * predefined_graph_1
            attn_masked = attn_masked / (torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9)

            attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1)

        drop_attn = self.dropout(attn)
        if (self.use_final_linear):
            context = unshape(torch.matmul(drop_attn, value))
            output = self.final_linear(context)
            return output
        else:
            context = torch.matmul(drop_attn, value)
            return context

In [9]:
class PositionwiseFeedForward2(nn.Module):
    """ A two-layer Feed-Forward-Network with residual layer norm.
    Args:
        d_model (int): the size of input for the first-layer of the FFN.
        d_ff (int): the hidden layer size of the second-layer
            of the FNN.
        dropout (float): dropout probability in :math:`[0, 1)`.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward2, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.actv = gelu
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x))))
        output = self.dropout_2(self.w_2(inter))
        return output + x
def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

In [10]:
'''
torch.random.manual_seed(42)
t = TransformerEncoderLayer2(4,2,2,.1)
a=torch.tensor([[1.0,2,3,1],[1.0,2,3,1]])
b=torch.tensor([[1.0,0.5,0.3,1],[1.0,0.5,0.3,1]])
c=torch.tensor([[1.0,0.5,0.7,1]])
d=torch.tensor([[True]])
t(0,a,b,d)
torch.random.manual_seed(42)
t = TransformerEncoderLayer(2,4,2,.1)
a=torch.tensor([[1.0,2,3,1],[1.0,2,3,1]])
b=torch.tensor([[1.0,0.5,0.3,1],[1.0,0.5,0.3,1]])
c=torch.tensor([[1.0,0.5,0.7,1]])
d=torch.tensor([[True]])
t(0,a,b,d)
'''
class TransformerEncoderLayer(nn.Module):
    def __init__(self, numHeads, modelDim, feedForwardDim, dropout):
        super(TransformerEncoderLayer, self).__init__()
        self.selfAttention = MultiHeadedAttention(numHeads,modelDim,dropout=dropout)
        self.feedForward = PositionwiseFeedForward(modelDim,feedForwardDim,dropout)
        # output = (input - Expectation[input])/sqrt(variance(input)+eps)*gamma+beta
        # gamma and beta learnable 
        # normalizing over the last dim
        self.layerNorm = nn.LayerNorm(modelDim,eps=1e-6)
        #dropout is probability of dropout
        self.dropout = nn.Dropout(dropout)
        
    def forward(self,iteration,query,inputs,mask):
        # if it is not the first iteration, then normalize the layer 
        inputsNorm = self.layerNorm(inputs) if iteration != 0 else inputs 
        # keep dim 1 a singleton
        mask=mask.unsqueeze(1)
        # why?
        contextEncoding = self.selfAttention(inputsNorm,inputsNorm,inputsNorm,mask=mask)
        # why do we add input back?
        contextEncodingWDropout = self.dropout(contextEncoding) + inputs 
        return self.feedForward(contextEncodingWDropout)

In [11]:
class PositionalEncoding2(nn.Module):

    def __init__(self, dropout, dim, max_len=5000):
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
                              -(math.log(10000.0) / dim)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0)
        super(PositionalEncoding2, self).__init__()
        self.register_buffer('pe', pe)
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim
    def forward(self, emb, step=None):
        emb = emb * math.sqrt(self.dim)
        if (step):
            emb = emb + self.pe[:, step][:, None, :]

        else:
            emb = emb + self.pe[:, :emb.size(1)]
        emb = self.dropout(emb)
        return emb

    def get_emb(self, emb):
        return self.pe[:, :emb.size(1)]


In [None]:
torch.random.manual_seed(42)
t = InterSentencesTransformerEncoderLayer(4,2,2,.1,1)
a=torch.tensor([[[1.0,2,3,1],[1.0,2,3,1]]])
d=torch.tensor([[True]])
t(a,d)

tensor([[[-1.3448,  0.3613,  1.3798, -0.3963],
         [-1.2437, -0.6751,  1.2393,  0.6795]]],
       grad_fn=<NativeLayerNormBackward>) tensor([[[ 0.4179,  0.2501],
         [ 0.4903, -0.0544]]], grad_fn=<AddBackward0>) tensor([[[ 0.2767,  0.1498],
         [ 0.3374, -0.0260]]], grad_fn=<MulBackward0>) tensor([[[ 0.3074,  0.1664],
         [ 0.3749, -0.0289]]], grad_fn=<MulBackward0>)


tensor([[0.4031, 0.3847]], grad_fn=<MulBackward0>)

In [12]:
#inter-sentence transformer focuses on learning relationship between sentences to produce a document level summary
# the input to this transformer is output of Bert, which can be viewed as contextual encoding. 
# modelDim is the output of the Bert's hidden size, or Bert's transformer's model dimension size 
# (model dimension size is the word used in the self hidden is all you need)
class InterSentencesTransformerEncoderLayer(nn.Module):
    def __init__(self, modelDim, feedforwardDim, numHeads, dropout,numInterSentencesLayers=0):
        super(InterSentencesTransformerEncoderLayer, self).__init__()
        self.modelDim = modelDim
        self.numInterSentencesLayers =  numInterSentencesLayers
        self.positionalEmbedding = PositionalEncoding(modelDim,dropout)
        self.interSentencesTransformers = nn.ModuleList(
            [TransformerEncoderLayer(numHeads, modelDim, feedforwardDim, dropout)
             for _ in range(numInterSentencesLayers)])
        self.dropout = nn.Dropout(dropout)
        self.layerNorm = nn.LayerNorm(modelDim, eps=1e-6)
        self.linearLayer = nn.Linear(modelDim, 1, bias=True)
        self.sigmoid = nn.Sigmoid()
    # top_vector in the original code: top vectors? topic vectors?
    # topicVectors is the output of Bert, some sort of contextual embedding
    # we will use contextualEncoding instead of top_vector
    def forward(self,contextualEncoding,mask):
        batchSize, nSentences = contextualEncoding.size(0), contextualEncoding.size(1)
        positionalEmbedding = self.positionalEmbedding.positionalEncoding[:,:nSentences]
        # mask takes [:,:,None] to account for batches?
        # x will be the contextualEncoding undergoing transformer operations
        x = contextualEncoding * mask[:,:,None].float() + positionalEmbedding
        for i in range(self.numInterSentencesLayers):
            x = self.interSentencesTransformers[i](i,x,x,~mask)
        x = self.layerNorm(x)
        sentencesScores = self.sigmoid(self.linearLayer(x))
        sentencesScores = sentencesScores.squeeze(-1)*mask.float()
        return sentencesScores

In [None]:
t=InterSentencesTransformerEncoderLayer(2,4,2,.1,1)
print(t)

InterSentencesTransformerEncoderLayer(
  (positionalEmbedding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (interSentencesTransformers): ModuleList(
    (0): TransformerEncoderLayer(
      (selfAttention): MultiHeadedAttention(
        (linearKeys): Linear(in_features=2, out_features=2, bias=True)
        (linearValues): Linear(in_features=2, out_features=2, bias=True)
        (linearQuery): Linear(in_features=2, out_features=2, bias=True)
        (softmax): Softmax(dim=-1)
        (dropout): Dropout(p=0.1, inplace=False)
        (finalLayer): Linear(in_features=2, out_features=2, bias=True)
      )
      (feedForward): PositionwiseFeedForward(
        (linear1): Linear(in_features=2, out_features=4, bias=True)
        (linear2): Linear(in_features=4, out_features=2, bias=True)
        (layerNorm): LayerNorm((2,), eps=1e-06, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
  

# Training

In [13]:
def build_optim(args, model, checkpoint):
    """ Build optimizer """
    saved_optimizer_state_dict = None
    optim = None

    if args["train_from"] != '':
        print("We made a checkpoint")
        optim = checkpoint['optim']
        saved_optimizer_state_dict = optim.optimizer.state_dict()
    else:
        print("we created an optimizer")
        optim = Optimizer(
            args["optim"], args["lr"], args["max_grad_norm"],
            beta1=args["beta1"], beta2=args["beta2"],
            decay_method=args["decay_method"],
            warmup_steps=args["warmup_steps"])

    optim.set_parameters(list(model.named_parameters()))

    if args["train_from"] != '':
        optim.optimizer.load_state_dict(saved_optimizer_state_dict)
        if args["visible_gpus"] != '-1':
            for state in optim.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.cuda()

        if (optim.method == 'adam') and (len(optim.optimizer.state) < 1):
            raise RuntimeError(
                "Error: loaded Adam optimizer from existing model" +
                " but optimizer state is empty")

    return optim

In [14]:
from pytorch_pretrained_bert import BertModel, BertConfig

class Bert(nn.Module):
    def __init__(self, temp_dir, load_pretrained_bert, bert_config):
        super(Bert, self).__init__()
        if(load_pretrained_bert):
            self.model = BertModel.from_pretrained('bert-base-uncased', cache_dir=temp_dir)
        else:
            self.model = BertModel(bert_config)

    def forward(self, x, segs, mask):
        encoded_layers, _ = self.model(x, segs, attention_mask =mask)
        contextualEncoding = encoded_layers[-1]
        return contextualEncoding
class Summarizer(nn.Module):
    def __init__(self, args, device, load_pretrained_bert = False, bert_config = None):
        super(Summarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args["temp_dir"], load_pretrained_bert, bert_config)
        self.encoder = InterSentencesTransformerEncoderLayer(self.bert.model.config.hidden_size, args["ff_size"], args["heads"],
                                                   args["dropout"], args["inter_layers"])
        if args["param_init"] != 0.0:
            for p in self.encoder.parameters():
                p.data.uniform_(-args["param_init"], args["param_init"])
        if args["param_init_glorot"]:
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        self.to(device)
    def load_cp(self, pt):
        self.load_state_dict(pt['model'], strict=True)

    def forward(self, x, segs, clss, mask, mask_cls, sentence_range=None):

        contextualEncoding = self.bert(x, segs, mask)
        sents_vec = contextualEncoding[torch.arange(contextualEncoding.size(0)).unsqueeze(1), clss]
        sents_vec = sents_vec * mask_cls[:, :, None].float()
        sent_scores = self.encoder(sents_vec, mask_cls).squeeze(-1)
        return sent_scores, mask_cls

# Data Builder

In [15]:
def cal_rouge(evaluated_ngrams, reference_ngrams):
    reference_count = len(reference_ngrams)
    evaluated_count = len(evaluated_ngrams)

    overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
    overlapping_count = len(overlapping_ngrams)

    if evaluated_count == 0:
        precision = 0.0
    else:
        precision = overlapping_count / evaluated_count

    if reference_count == 0:
        recall = 0.0
    else:
        recall = overlapping_count / reference_count

    f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))
    return {"f": f1_score, "p": precision, "r": recall}

def greedy_selection(doc_sent_list, abstract_sent_list, summary_size):
    def _rouge_clean(s):
        return re.sub(r'[^a-zA-Z0-9 ]', '', s)

    max_rouge = 0.0
    abstract = sum(abstract_sent_list, [])
    abstract = _rouge_clean(' '.join(abstract)).split()
    sents = [_rouge_clean(' '.join(s)).split() for s in doc_sent_list]
    evaluated_1grams = [_get_word_ngrams(1, [sent]) for sent in sents]
    reference_1grams = _get_word_ngrams(1, [abstract])
    evaluated_2grams = [_get_word_ngrams(2, [sent]) for sent in sents]
    reference_2grams = _get_word_ngrams(2, [abstract])

    selected = []
    for s in range(summary_size):
        cur_max_rouge = max_rouge
        cur_id = -1
        for i in range(len(sents)):
            if (i in selected):
                continue
            c = selected + [i]
            candidates_1 = [evaluated_1grams[idx] for idx in c]
            candidates_1 = set.union(*map(set, candidates_1))
            candidates_2 = [evaluated_2grams[idx] for idx in c]
            candidates_2 = set.union(*map(set, candidates_2))
            rouge_1 = cal_rouge(candidates_1, reference_1grams)['f']
            rouge_2 = cal_rouge(candidates_2, reference_2grams)['f']
            rouge_score = rouge_1 + rouge_2
            if rouge_score > cur_max_rouge:
                cur_max_rouge = rouge_score
                cur_id = i
        if (cur_id == -1):
            return selected
        selected.append(cur_id)
        max_rouge = cur_max_rouge

    return sorted(selected)

def combination_selection(doc_sent_list, abstract_sent_list, summary_size):
    def _rouge_clean(s):
        return re.sub(r'[^a-zA-Z0-9 ]', '', s)

    max_rouge = 0.0
    max_idx = (0, 0)
    abstract = sum(abstract_sent_list, [])
    abstract = _rouge_clean(' '.join(abstract)).split()
    sents = [_rouge_clean(' '.join(s)).split() for s in doc_sent_list]
    evaluated_1grams = [_get_word_ngrams(1, [sent]) for sent in sents]
    reference_1grams = _get_word_ngrams(1, [abstract])
    evaluated_2grams = [_get_word_ngrams(2, [sent]) for sent in sents]
    reference_2grams = _get_word_ngrams(2, [abstract])

    impossible_sents = []
    for s in range(summary_size + 1):
        combinations = itertools.combinations([i for i in range(len(sents)) if i not in impossible_sents], s + 1)
        for c in combinations:
            candidates_1 = [evaluated_1grams[idx] for idx in c]
            candidates_1 = set.union(*map(set, candidates_1))
            candidates_2 = [evaluated_2grams[idx] for idx in c]
            candidates_2 = set.union(*map(set, candidates_2))
            rouge_1 = cal_rouge(candidates_1, reference_1grams)['f']
            rouge_2 = cal_rouge(candidates_2, reference_2grams)['f']

            rouge_score = rouge_1 + rouge_2
            if (s == 0 and rouge_score == 0):
                impossible_sents.append(c[0])
            if rouge_score > max_rouge:
                max_idx = c
                max_rouge = rouge_score
    return sorted(list(max_idx))

# Data Loader / BATCH


In [16]:

class Batch(object):
    def _pad(self, data, pad_id, width=-1):
        if (width == -1):
            width = max(len(d) for d in data)
        rtn_data = [d + [pad_id] * (width - len(d)) for d in data]
        return rtn_data

    def __init__(self, data=None, device=None,  is_test=False):
        """Create a Batch from a list of examples."""
        if data is not None and data != []:
            self.bad_batch = False
            self.batch_size = len(data)
            pre_src = [x[0] for x in data]
            pre_labels = [x[1] for x in data]
            pre_segs = [x[2] for x in data]
            pre_clss = [x[3] for x in data]

            src = torch.tensor(self._pad(pre_src, 0))

            labels = torch.tensor(self._pad(pre_labels, 0))
            segs = torch.tensor(self._pad(pre_segs, 0))
            mask = ~(src == 0)

            clss = torch.tensor(self._pad(pre_clss, -1))
            mask_cls = ~(clss == -1)
            clss[clss == -1] = 0

            setattr(self, 'clss', clss.to(device))
            setattr(self, 'mask_cls', mask_cls.to(device))
            setattr(self, 'src', src.to(device))
            setattr(self, 'labels', labels.to(device))
            setattr(self, 'segs', segs.to(device))
            setattr(self, 'mask', mask.to(device))

            if (is_test):
                src_str = [x[-2] for x in data]
                setattr(self, 'src_str', src_str)
                tgt_str = [x[-1] for x in data]
                setattr(self, 'tgt_str', tgt_str)
        else:
          self.bad_batch = True

    def __len__(self):
        return self.batch_size


def batch(data, batch_size):
    """Yield elements from data in chunks of batch_size."""
    minibatch, size_so_far = [], 0
    for ex in data:
        minibatch.append(ex)
        size_so_far = simple_batch_size_fn(ex, len(minibatch))
        if size_so_far == batch_size:
            yield minibatch
            minibatch, size_so_far = [], 0
        elif size_so_far > batch_size:
            yield minibatch[:-1]
            minibatch, size_so_far = minibatch[-1:], simple_batch_size_fn(ex, 1)
    if minibatch:
        yield minibatch


def load_dataset(args, corpus_type, shuffle, bert_data_path="bert_data/", pre="cnndm"):
    """
    Dataset generator. Don't do extra stuff here, like printing,
    because they will be postponed to the first loading time.
    Args:
        corpus_type: 'train' or 'valid'
    Returns:
        A list of dataset, the dataset(s) are lazily loaded.
    """
    assert corpus_type in ["train", "valid", "test"]

    def _lazy_dataset_loader(pt_file, corpus_type):
        dataset = torch.load(pt_file)
        print('Loading %s dataset from %s, number of examples: %d' %
                    (corpus_type, pt_file, len(dataset)))
        return dataset

    # Sort the glob output by file name (by increasing indexes).
    allFiles = os.listdir(bert_data_path)
    pts = [bert_data_path+file for file in allFiles if file[-3:] == '.pt']
    if pts:
        if (shuffle):
            random.shuffle(pts)

        for pt in pts:
            print(pt)
            yield _lazy_dataset_loader(pt, corpus_type)
    else:
        # Only one inputters.*Dataset, simple!
        pt = bert_data_path + pre + '.' + corpus_type + '.pt'
        yield _lazy_dataset_loader(pt, corpus_type)


def simple_batch_size_fn(new, count):
    src, labels = new[0], new[1]
    global max_n_sents, max_n_tokens, max_size
    if count == 1:
        max_size = 0
        max_n_sents=0
        max_n_tokens=0
    max_n_sents = max(max_n_sents, len(src))
    max_size = max(max_size, max_n_sents)
    src_elements = count * max_size
    return src_elements

# Data Loader Class + Iterators

In [17]:
class Dataloader(object):
    def __init__(self, args, datasets,  batch_size,
                 device, shuffle, is_test):
        self.args = args
        self.datasets = datasets
        self.batch_size = batch_size
        self.device = device
        self.shuffle = shuffle
        self.is_test = is_test
        self.cur_iter = self._next_dataset_iterator(datasets)

        assert self.cur_iter is not None

    def __iter__(self):
        dataset_iter = (d for d in self.datasets)
        while self.cur_iter is not None:
            for batch in self.cur_iter:
                yield batch
            self.cur_iter = self._next_dataset_iterator(dataset_iter)


    def _next_dataset_iterator(self, dataset_iter):
        try:
            # Drop the current dataset for decreasing memory
            if hasattr(self, "cur_dataset"):
                self.cur_dataset = None
                gc.collect()
                del self.cur_dataset
                gc.collect()

            self.cur_dataset = next(dataset_iter)
        except StopIteration:
            return None

        return DataIterator(args = self.args,
            dataset=self.cur_dataset,  batch_size=self.batch_size,
            device=self.device, shuffle=self.shuffle, is_test=self.is_test)


class DataIterator(object):
    def __init__(self, args, dataset,  batch_size,  device=None, is_test=False,
                 shuffle=True):
        self.args = args
        self.batch_size, self.is_test, self.dataset = batch_size, is_test, dataset
        self.iterations = 0
        self.device = device
        self.shuffle = shuffle

        self.sort_key = lambda x: len(x[1])

        self._iterations_this_epoch = 0

    def data(self):
        if self.shuffle:
            random.shuffle(self.dataset)
        xs = self.dataset
        return xs


    def preprocess(self, ex, is_test):
        src = ex['src']
        if('labels' in ex):
            labels = ex['labels']
        else:
            labels = ex['src_sent_labels']

        segs = ex['segs']
        if(not self.args["use_interval"]):
            segs=[0]*len(segs)
        clss = ex['clss']
        src_txt = ex['src_txt']
        tgt_txt = ex['tgt_txt']

        if(is_test):
            return src,labels,segs, clss, src_txt, tgt_txt
        else:
            return src,labels,segs, clss

    def batch_buffer(self, data, batch_size):
        minibatch, size_so_far = [], 0
        for ex in data:
            if(len(ex['src'])==0):
                continue
            ex = self.preprocess(ex, self.is_test)
            if(ex is None):
                continue
            minibatch.append(ex)
            size_so_far = simple_batch_size_fn(ex, len(minibatch))
            if size_so_far == batch_size:
                yield minibatch
                minibatch, size_so_far = [], 0
            elif size_so_far > batch_size:
                yield minibatch[:-1]
                minibatch, size_so_far = minibatch[-1:], simple_batch_size_fn(ex, 1)
        if minibatch:
            yield minibatch

    def create_batches(self):
        """ Create batches """
        data = self.data()
        for buffer in self.batch_buffer(data, self.batch_size * 50):

            p_batch = sorted(buffer, key=lambda x: len(x[3]))
            p_batch = batch(p_batch, self.batch_size)

            p_batch = list(p_batch)
            if (self.shuffle):
                random.shuffle(p_batch)
            for b in p_batch:
                yield b

    def __iter__(self):
        while True:
            self.batches = self.create_batches()
            for idx, minibatch in enumerate(self.batches):
                # fast-forward if loaded from state
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                batch = Batch(minibatch, self.device, self.is_test)

                yield batch
            return

# Optimizer

In [18]:
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

""" Optimizers class """
def use_gpu(opt):
    """
    Creates a boolean if gpu used
    """
    return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \
           (hasattr(opt, 'gpu') and opt.gpu > -1)

def build_optim2(model, opt, checkpoint):
    """ Build optimizer """
    saved_optimizer_state_dict = None

    if opt.train_from:
        optim = checkpoint['optim']
        # We need to save a copy of optim.optimizer.state_dict() for setting
        # the, optimizer state later on in Stage 2 in this method, since
        # the method optim.set_parameters(model.parameters()) will overwrite
        # optim.optimizer, and with ith the values stored in
        # optim.optimizer.state_dict()
        saved_optimizer_state_dict = optim.optimizer.state_dict()
    else:
        optim = Optimizer(
            opt.optim, opt.learning_rate, opt.max_grad_norm,
            lr_decay=opt.learning_rate_decay,
            start_decay_steps=opt.start_decay_steps,
            decay_steps=opt.decay_steps,
            beta1=opt.adam_beta1,
            beta2=opt.adam_beta2,
            adagrad_accum=opt.adagrad_accumulator_init,
            decay_method=opt.decay_method,
            warmup_steps=opt.warmup_steps)

    # Stage 1:
    # Essentially optim.set_parameters (re-)creates and optimizer using
    # model.paramters() as parameters that will be stored in the
    # optim.optimizer.param_groups field of the torch optimizer class.
    # Importantly, this method does not yet load the optimizer state, as
    # essentially it builds a new optimizer with empty optimizer state and
    # parameters from the model.
    optim.set_parameters(model.named_parameters())

    if opt.train_from:
        # Stage 2: In this stage, which is only performed when loading an
        # optimizer from a checkpoint, we load the saved_optimizer_state_dict
        # into the re-created optimizer, to set the optim.optimizer.state
        # field, which was previously empty. For this, we use the optimizer
        # state saved in the "saved_optimizer_state_dict" variable for
        # this purpose.
        # See also: https://github.com/pytorch/pytorch/issues/2830
        optim.optimizer.load_state_dict(saved_optimizer_state_dict)
        # Convert back the state values to cuda type if applicable
        if use_gpu(opt):
            for state in optim.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.cuda()

        # We want to make sure that indeed we have a non-empty optimizer state
        # when we loaded an existing model. This should be at least the case
        # for Adam, which saves "exp_avg" and "exp_avg_sq" state
        # (Exponential moving average of gradient and squared gradient values)
        if (optim.method == 'adam') and (len(optim.optimizer.state) < 1):
            raise RuntimeError(
                "Error: loaded Adam optimizer from existing model" +
                " but optimizer state is empty")

    return optim


class MultipleOptimizer(object):
    """ Implement multiple optimizers needed for sparse adam """

    def __init__(self, op):
        """ ? """
        self.optimizers = op

    def zero_grad(self):
        """ ? """
        for op in self.optimizers:
            op.zero_grad()

    def step(self):
        """ ? """
        for op in self.optimizers:
            op.step()

    @property
    def state(self):
        """ ? """
        return {k: v for op in self.optimizers for k, v in op.state.items()}

    def state_dict(self):
        """ ? """
        return [op.state_dict() for op in self.optimizers]

    def load_state_dict(self, state_dicts):
        """ ? """
        assert len(state_dicts) == len(self.optimizers)
        for i in range(len(state_dicts)):
            self.optimizers[i].load_state_dict(state_dicts[i])


class Optimizer(object):
    """
    Controller class for optimization. Mostly a thin
    wrapper for `optim`, but also useful for implementing
    rate scheduling beyond what is currently available.
    Also implements necessary methods for training RNNs such
    as grad manipulations.
    Args:
      method (:obj:`str`): one of [sgd, adagrad, adadelta, adam]
      lr (float): learning rate
      lr_decay (float, optional): learning rate decay multiplier
      start_decay_steps (int, optional): step to start learning rate decay
      beta1, beta2 (float, optional): parameters for adam
      adagrad_accum (float, optional): initialization parameter for adagrad
      decay_method (str, option): custom decay options
      warmup_steps (int, option): parameter for `noam` decay
    We use the default parameters for Adam that are suggested by
    the original paper https://arxiv.org/pdf/1412.6980.pdf
    These values are also used by other established implementations,
    e.g. https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
    https://keras.io/optimizers/
    Recently there are slightly different values used in the paper
    "Attention is all you need"
    https://arxiv.org/pdf/1706.03762.pdf, particularly the value beta2=0.98
    was used there however, beta2=0.999 is still arguably the more
    established value, so we use that here as well
    """

    def __init__(self, method, learning_rate, max_grad_norm,
                 lr_decay=1, start_decay_steps=None, decay_steps=None,
                 beta1=0.9, beta2=0.999,
                 adagrad_accum=0.0,
                 decay_method=None,
                 warmup_steps=4000
                 ):
        self.last_ppl = None
        self.learning_rate = learning_rate
        self.original_lr = learning_rate
        self.max_grad_norm = max_grad_norm
        self.method = method
        self.lr_decay = lr_decay
        self.start_decay_steps = start_decay_steps
        self.decay_steps = decay_steps
        self.start_decay = False
        self._step = 0
        self.betas = [beta1, beta2]
        self.adagrad_accum = adagrad_accum
        self.decay_method = decay_method
        self.warmup_steps = warmup_steps

    def set_parameters(self, params):
        """ ? """
        self.params = []
        self.sparse_params = []
        for k, p in params:
            if p.requires_grad:
                if self.method != 'sparseadam' or "embed" not in k:
                    self.params.append(p)
                else:
                    self.sparse_params.append(p)
        if self.method == 'sgd':
            self.optimizer = optim.SGD(self.params, lr=self.learning_rate)
        elif self.method == 'adagrad':
            self.optimizer = optim.Adagrad(self.params, lr=self.learning_rate)
            for group in self.optimizer.param_groups:
                for p in group['params']:
                    self.optimizer.state[p]['sum'] = self.optimizer\
                        .state[p]['sum'].fill_(self.adagrad_accum)
        elif self.method == 'adadelta':
            self.optimizer = optim.Adadelta(self.params, lr=self.learning_rate)
        elif self.method == 'adam':
            self.optimizer = optim.Adam(self.params, lr=self.learning_rate,
                                        betas=self.betas, eps=1e-9)
        elif self.method == 'sparseadam':
            self.optimizer = MultipleOptimizer(
                [optim.Adam(self.params, lr=self.learning_rate,
                            betas=self.betas, eps=1e-8),
                 optim.SparseAdam(self.sparse_params, lr=self.learning_rate,
                                  betas=self.betas, eps=1e-8)])
        else:
            raise RuntimeError("Invalid optim method: " + self.method)

    def _set_rate(self, learning_rate):
        self.learning_rate = learning_rate
        if self.method != 'sparseadam':
            self.optimizer.param_groups[0]['lr'] = self.learning_rate
        else:
            for op in self.optimizer.optimizers:
                op.param_groups[0]['lr'] = self.learning_rate

    def step(self):
        """Update the model parameters based on current gradients.
        Optionally, will employ gradient modification or update learning
        rate.
        """
        self._step += 1

        # Decay method used in tensor2tensor.
        if self.decay_method == "noam":
            self._set_rate(
                self.original_lr *

                 min(self._step ** (-0.5),
                     self._step * self.warmup_steps**(-1.5)))

            # self._set_rate(self.original_lr *self.model_size ** (-0.5) *min(1.0, self._step / self.warmup_steps)*max(self._step, self.warmup_steps)**(-0.5))
        # Decay based on start_decay_steps every decay_steps
        else:
            if ((self.start_decay_steps is not None) and (
                     self._step >= self.start_decay_steps)):
                self.start_decay = True
            if self.start_decay:
                if ((self._step - self.start_decay_steps)
                   % self.decay_steps == 0):
                    self.learning_rate = self.learning_rate * self.lr_decay

        if self.method != 'sparseadam':
            self.optimizer.param_groups[0]['lr'] = self.learning_rate

        if self.max_grad_norm:
            clip_grad_norm_(self.params, self.max_grad_norm)
        self.optimizer.step()


# Eval

In [19]:
class Rouge155(object):
    """
    This is a wrapper for the ROUGE 1.5.5 summary evaluation package.
    This class is designed to simplify the evaluation process by:
        1) Converting summaries into a format ROUGE understands.
        2) Generating the ROUGE configuration file automatically based
            on filename patterns.
    This class can be used within Python like this:
    rouge = Rouge155()
    rouge.system_dir = 'test/systems'
    rouge.model_dir = 'test/models'
    # The system filename pattern should contain one group that
    # matches the document ID.
    rouge.system_filename_pattern = 'SL.P.10.R.11.SL062003-(\d+).html'
    # The model filename pattern has '#ID#' as a placeholder for the
    # document ID. If there are multiple model summaries, pyrouge
    # will use the provided regex to automatically match them with
    # the corresponding system summary. Here, [A-Z] matches
    # multiple model summaries for a given #ID#.
    rouge.model_filename_pattern = 'SL.P.10.R.[A-Z].SL062003-#ID#.html'
    rouge_output = rouge.evaluate()
    print(rouge_output)
    output_dict = rouge.output_to_dict(rouge_ouput)
    print(output_dict)
    ->    {'rouge_1_f_score': 0.95652,
         'rouge_1_f_score_cb': 0.95652,
         'rouge_1_f_score_ce': 0.95652,
         'rouge_1_precision': 0.95652,
        [...]
    To evaluate multiple systems:
        rouge = Rouge155()
        rouge.system_dir = '/PATH/TO/systems'
        rouge.model_dir = 'PATH/TO/models'
        for system_id in ['id1', 'id2', 'id3']:
            rouge.system_filename_pattern = \
                'SL.P/.10.R.{}.SL062003-(\d+).html'.format(system_id)
            rouge.model_filename_pattern = \
                'SL.P.10.R.[A-Z].SL062003-#ID#.html'
            rouge_output = rouge.evaluate(system_id)
            print(rouge_output)
    """

    def __init__(self, rouge_dir=None, rouge_args=None, temp_dir = None):
        """
        Create a Rouge155 object.
            rouge_dir:  Directory containing Rouge-1.5.5.pl
            rouge_args: Arguments to pass through to ROUGE if you
                        don't want to use the default pyrouge
                        arguments.
        """
        self.temp_dir=temp_dir
        self.log = log.get_global_console_logger()
        self.__set_dir_properties()
        self._config_file = None
        self._settings_file = self.__get_config_path()
        self.__set_rouge_dir(rouge_dir)
        self.args = self.__clean_rouge_args(rouge_args)
        self._system_filename_pattern = None
        self._model_filename_pattern = None

    def save_home_dir(self):
        config = ConfigParser()
        section = 'pyrouge settings'
        config.add_section(section)
        config.set(section, 'home_dir', self._home_dir)
        with open(self._settings_file, 'w') as f:
            config.write(f)
        self.log.info("Set ROUGE home directory to {}.".format(self._home_dir))

    @property
    def settings_file(self):
        """
        Path of the setttings file, which stores the ROUGE home dir.
        """
        return self._settings_file

    @property
    def bin_path(self):
        """
        The full path of the ROUGE binary (although it's technically
        a script), i.e. rouge_home_dir/ROUGE-1.5.5.pl
        """
        if self._bin_path is None:
            raise Exception(
                "ROUGE path not set. Please set the ROUGE home directory "
                "and ensure that ROUGE-1.5.5.pl exists in it.")
        return self._bin_path

    @property
    def system_filename_pattern(self):
        """
        The regular expression pattern for matching system summary
        filenames. The regex string.
        E.g. "SL.P.10.R.11.SL062003-(\d+).html" will match the system
        filenames in the SPL2003/system folder of the ROUGE SPL example
        in the "sample-test" folder.
        Currently, there is no support for multiple systems.
        """
        return self._system_filename_pattern

    @system_filename_pattern.setter
    def system_filename_pattern(self, pattern):
        self._system_filename_pattern = pattern

    @property
    def model_filename_pattern(self):
        """
        The regular expression pattern for matching model summary
        filenames. The pattern needs to contain the string "#ID#",
        which is a placeholder for the document ID.
        E.g. "SL.P.10.R.[A-Z].SL062003-#ID#.html" will match the model
        filenames in the SPL2003/system folder of the ROUGE SPL
        example in the "sample-test" folder.
        "#ID#" is a placeholder for the document ID which has been
        matched by the "(\d+)" part of the system filename pattern.
        The different model summaries for a given document ID are
        matched by the "[A-Z]" part.
        """
        return self._model_filename_pattern

    @model_filename_pattern.setter
    def model_filename_pattern(self, pattern):
        self._model_filename_pattern = pattern

    @property
    def config_file(self):
        return self._config_file

    @config_file.setter
    def config_file(self, path):
        config_dir, _ = os.path.split(path)
        verify_dir(config_dir, "configuration file")
        self._config_file = path

    def split_sentences(self):
        """
        ROUGE requires texts split into sentences. In case the texts
        are not already split, this method can be used.
        """
        from pyrouge.utils.sentence_splitter import PunktSentenceSplitter
        self.log.info("Splitting sentences.")
        ss = PunktSentenceSplitter()
        sent_split_to_string = lambda s: "\n".join(ss.split(s))
        process_func = partial(
            DirectoryProcessor.process, function=sent_split_to_string)
        self.__process_summaries(process_func)

    @staticmethod
    def convert_summaries_to_rouge_format(input_dir, output_dir):
        """
        Convert all files in input_dir into a format ROUGE understands
        and saves the files to output_dir. The input files are assumed
        to be plain text with one sentence per line.
            input_dir:  Path of directory containing the input files.
            output_dir: Path of directory in which the converted files
                        will be saved.
        """
        DirectoryProcessor.process(
            input_dir, output_dir, Rouge155.convert_text_to_rouge_format)

    @staticmethod
    def convert_text_to_rouge_format(text, title="dummy title"):
        """
        Convert a text to a format ROUGE understands. The text is
        assumed to contain one sentence per line.
            text:   The text to convert, containg one sentence per line.
            title:  Optional title for the text. The title will appear
                    in the converted file, but doesn't seem to have
                    any other relevance.
        Returns: The converted text as string.
        """
        # sentences = text.split("\n")
        sentences = text.split("<q>")
        sent_elems = [
            "<a name=\"{i}\">[{i}]</a> <a href=\"#{i}\" id={i}>"
            "{text}</a>".format(i=i, text=sent)
            for i, sent in enumerate(sentences, start=1)]
        html = """<html>
<head>
<title>{title}</title>
</head>
<body bgcolor="white">
{elems}
</body>
</html>""".format(title=title, elems="\n".join(sent_elems))

        return html

    @staticmethod
    def write_config_static(system_dir, system_filename_pattern,
                            model_dir, model_filename_pattern,
                            config_file_path, system_id=None):
        """
        Write the ROUGE configuration file, which is basically a list
        of system summary files and their corresponding model summary
        files.
        pyrouge uses regular expressions to automatically find the
        matching model summary files for a given system summary file
        (cf. docstrings for system_filename_pattern and
        model_filename_pattern).
            system_dir:                 Path of directory containing
                                        system summaries.
            system_filename_pattern:    Regex string for matching
                                        system summary filenames.
            model_dir:                  Path of directory containing
                                        model summaries.
            model_filename_pattern:     Regex string for matching model
                                        summary filenames.
            config_file_path:           Path of the configuration file.
            system_id:                  Optional system ID string which
                                        will appear in the ROUGE output.
        """
        system_filenames = [f for f in os.listdir(system_dir)]
        system_models_tuples = []

        system_filename_pattern = re.compile(system_filename_pattern)
        for system_filename in sorted(system_filenames):
            match = system_filename_pattern.match(system_filename)
            if match:
                id = match.groups(0)[0]
                model_filenames = [model_filename_pattern.replace('#ID#',id)]
                # model_filenames = Rouge155.__get_model_filenames_for_id(
                #     id, model_dir, model_filename_pattern)
                system_models_tuples.append(
                    (system_filename, sorted(model_filenames)))
        if not system_models_tuples:
            raise Exception(
                "Did not find any files matching the pattern {} "
                "in the system summaries directory {}.".format(
                    system_filename_pattern.pattern, system_dir))

        with codecs.open(config_file_path, 'w', encoding='utf-8') as f:
            f.write('<ROUGE-EVAL version="1.55">')
            for task_id, (system_filename, model_filenames) in enumerate(
                    system_models_tuples, start=1):

                eval_string = Rouge155.__get_eval_string(
                    task_id, system_id,
                    system_dir, system_filename,
                    model_dir, model_filenames)
                f.write(eval_string)
            f.write("</ROUGE-EVAL>")

    def write_config(self, config_file_path=None, system_id=None):
        """
        Write the ROUGE configuration file, which is basically a list
        of system summary files and their matching model summary files.
        This is a non-static version of write_config_file_static().
            config_file_path:   Path of the configuration file.
            system_id:          Optional system ID string which will
                                appear in the ROUGE output.
        """
        if not system_id:
            system_id = 1
        if (not config_file_path) or (not self._config_dir):
            self._config_dir = mkdtemp(dir=self.temp_dir)
            config_filename = "rouge_conf.xml"
        else:
            config_dir, config_filename = os.path.split(config_file_path)
            verify_dir(config_dir, "configuration file")
        self._config_file = os.path.join(self._config_dir, config_filename)
        Rouge155.write_config_static(
            self._system_dir, self._system_filename_pattern,
            self._model_dir, self._model_filename_pattern,
            self._config_file, system_id)
        self.log.info(
            "Written ROUGE configuration to {}".format(self._config_file))

    def evaluate(self, system_id=1, rouge_args=None):
        """
        Run ROUGE to evaluate the system summaries in system_dir against
        the model summaries in model_dir. The summaries are assumed to
        be in the one-sentence-per-line HTML format ROUGE understands.
            system_id:  Optional system ID which will be printed in
                        ROUGE's output.
        Returns: Rouge output as string.
        """
        self.write_config(system_id=system_id)
        options = self.__get_options(rouge_args)
        command = [self._bin_path] + options
        self.log.info(
            "Running ROUGE with command {}".format(" ".join(command)))
        rouge_output = check_output(command).decode("UTF-8")
        return rouge_output

    def convert_and_evaluate(self, system_id=1,
                             split_sentences=False, rouge_args=None):
        """
        Convert plain text summaries to ROUGE format and run ROUGE to
        evaluate the system summaries in system_dir against the model
        summaries in model_dir. Optionally split texts into sentences
        in case they aren't already.
        This is just a convenience method combining
        convert_summaries_to_rouge_format() and evaluate().
            split_sentences:    Optional argument specifying if
                                sentences should be split.
            system_id:          Optional system ID which will be printed
                                in ROUGE's output.
        Returns: ROUGE output as string.
        """
        if split_sentences:
            self.split_sentences()
        self.__write_summaries()
        rouge_output = self.evaluate(system_id, rouge_args)
        return rouge_output

    def output_to_dict(self, output):
        """
        Convert the ROUGE output into python dictionary for further
        processing.
        """
        #0 ROUGE-1 Average_R: 0.02632 (95%-conf.int. 0.02632 - 0.02632)
        pattern = re.compile(
            r"(\d+) (ROUGE-\S+) (Average_\w): (\d.\d+) "
            r"\(95%-conf.int. (\d.\d+) - (\d.\d+)\)")
        results = {}
        for line in output.split("\n"):
            match = pattern.match(line)
            if match:
                sys_id, rouge_type, measure, result, conf_begin, conf_end = \
                    match.groups()
                measure = {
                    'Average_R': 'recall',
                    'Average_P': 'precision',
                    'Average_F': 'f_score'
                    }[measure]
                rouge_type = rouge_type.lower().replace("-", '_')
                key = "{}_{}".format(rouge_type, measure)
                results[key] = float(result)
                results["{}_cb".format(key)] = float(conf_begin)
                results["{}_ce".format(key)] = float(conf_end)
        return results

    ###################################################################
    # Private methods

    def __set_rouge_dir(self, home_dir=None):
        """
        Verfify presence of ROUGE-1.5.5.pl and data folder, and set
        those paths.
        """
        if not home_dir:
            self._home_dir = self.__get_rouge_home_dir_from_settings()
        else:
            self._home_dir = home_dir
            self.save_home_dir()
        self._bin_path = os.path.join(self._home_dir, 'ROUGE-1.5.5.pl')
        self.data_dir = os.path.join(self._home_dir, 'data')
        if not os.path.exists(self._bin_path):
            raise Exception(
                "ROUGE binary not found at {}. Please set the "
                "correct path by running pyrouge_set_rouge_path "
                "/path/to/rouge/home.".format(self._bin_path))

    def __get_rouge_home_dir_from_settings(self):
        config = ConfigParser()
        with open(self._settings_file) as f:
            if hasattr(config, "read_file"):
                config.read_file(f)
            else:
                # use deprecated python 2.x method
                config.readfp(f)
        rouge_home_dir = config.get('pyrouge settings', 'home_dir')
        return rouge_home_dir

    @staticmethod
    def __get_eval_string(
            task_id, system_id,
            system_dir, system_filename,
            model_dir, model_filenames):
        """
        ROUGE can evaluate several system summaries for a given text
        against several model summaries, i.e. there is an m-to-n
        relation between system and model summaries. The system
        summaries are listed in the <PEERS> tag and the model summaries
        in the <MODELS> tag. pyrouge currently only supports one system
        summary per text, i.e. it assumes a 1-to-n relation between
        system and model summaries.
        """
        peer_elems = "<P ID=\"{id}\">{name}</P>".format(
            id=system_id, name=system_filename)

        model_elems = ["<M ID=\"{id}\">{name}</M>".format(
            id=chr(65 + i), name=name)
            for i, name in enumerate(model_filenames)]

        model_elems = "\n\t\t\t".join(model_elems)
        eval_string = """
    <EVAL ID="{task_id}">
        <MODEL-ROOT>{model_root}</MODEL-ROOT>
        <PEER-ROOT>{peer_root}</PEER-ROOT>
        <INPUT-FORMAT TYPE="SEE">
        </INPUT-FORMAT>
        <PEERS>
            {peer_elems}
        </PEERS>
        <MODELS>
            {model_elems}
        </MODELS>
    </EVAL>
""".format(
            task_id=task_id,
            model_root=model_dir, model_elems=model_elems,
            peer_root=system_dir, peer_elems=peer_elems)
        return eval_string

    def __process_summaries(self, process_func):
        """
        Helper method that applies process_func to the files in the
        system and model folders and saves the resulting files to new
        system and model folders.
        """
        temp_dir = mkdtemp(dir=self.temp_dir)
        new_system_dir = os.path.join(temp_dir, "system")
        os.mkdir(new_system_dir)
        new_model_dir = os.path.join(temp_dir, "model")
        os.mkdir(new_model_dir)
        self.log.info(
            "Processing summaries. Saving system files to {} and "
            "model files to {}.".format(new_system_dir, new_model_dir))
        process_func(self._system_dir, new_system_dir)
        process_func(self._model_dir, new_model_dir)
        self._system_dir = new_system_dir
        self._model_dir = new_model_dir

    def __write_summaries(self):
        self.log.info("Writing summaries.")
        self.__process_summaries(self.convert_summaries_to_rouge_format)

    @staticmethod
    def __get_model_filenames_for_id(id, model_dir, model_filenames_pattern):
        pattern = re.compile(model_filenames_pattern.replace('#ID#', id))
        model_filenames = [
            f for f in os.listdir(model_dir) if pattern.match(f)]
        if not model_filenames:
            raise Exception(
                "Could not find any model summaries for the system"
                " summary with ID {}. Specified model filename pattern was: "
                "{}".format(id, model_filenames_pattern))
        return model_filenames

    def __get_options(self, rouge_args=None):
        """
        Get supplied command line arguments for ROUGE or use default
        ones.
        """
        if self.args:
            options = self.args.split()
        elif rouge_args:
            options = rouge_args.split()
        else:
            options = [
                '-e', self._data_dir,
                '-c', 95,
                # '-2',
                # '-1',
                # '-U',
                '-m',
                # '-v',
                '-r', 1000,
                '-n', 2,
                # '-w', 1.2,
                '-a',
                ]
            options = list(map(str, options))




        options = self.__add_config_option(options)
        return options

    def __create_dir_property(self, dir_name, docstring):
        """
        Generate getter and setter for a directory property.
        """
        property_name = "{}_dir".format(dir_name)
        private_name = "_" + property_name
        setattr(self, private_name, None)

        def fget(self):
            return getattr(self, private_name)

        def fset(self, path):
            verify_dir(path, dir_name)
            setattr(self, private_name, path)

        p = property(fget=fget, fset=fset, doc=docstring)
        setattr(self.__class__, property_name, p)

    def __set_dir_properties(self):
        """
        Automatically generate the properties for directories.
        """
        directories = [
            ("home", "The ROUGE home directory."),
            ("data", "The path of the ROUGE 'data' directory."),
            ("system", "Path of the directory containing system summaries."),
            ("model", "Path of the directory containing model summaries."),
            ]
        for (dirname, docstring) in directories:
            self.__create_dir_property(dirname, docstring)

    def __clean_rouge_args(self, rouge_args):
        """
        Remove enclosing quotation marks, if any.
        """
        if not rouge_args:
            return
        quot_mark_pattern = re.compile('"(.+)"')
        match = quot_mark_pattern.match(rouge_args)
        if match:
            cleaned_args = match.group(1)
            return cleaned_args
        else:
            return rouge_args

    def __add_config_option(self, options):
        return options + [self._config_file]

    def __get_config_path(self):
        if platform.system() == "Windows":
            parent_dir = os.getenv("APPDATA")
            config_dir_name = "pyrouge"
        elif os.name == "posix":
            parent_dir = os.path.expanduser("~")
            config_dir_name = ".pyrouge"
        else:
            parent_dir = os.path.dirname(__file__)
            config_dir_name = ""
        config_dir = os.path.join(parent_dir, config_dir_name)
        if not os.path.exists(config_dir):
            os.makedirs(config_dir)
        return os.path.join(config_dir, 'settings.ini')

#Training

In [20]:
def IterFunc(corpus_type='train'):
        return Dataloader(args, load_dataset(args, corpus_type, shuffle=True), args["batch_size"], device,
                                                 shuffle=True, is_test=False)

In [21]:
loss = torch.nn.BCELoss(reduction='none')

# def train(model, optim, trainSteps, trainIterFunc, validIterFunc, savePath, seed=42):
def train(args, model, savePath, seed=42):
    '''
    optim : optimizer
    savePath: where the model of the last will be saved.
    '''
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def trainIterFunc():
        tmp = load_dataset(args, corpus_type='train', shuffle=True)
        return Dataloader(args, tmp, args["batch_size"], device,
                                                 shuffle=True, is_test=False)
    def validIterFunc():
        tmp = load_dataset(args, corpus_type='valid', shuffle=True)
        return Dataloader(args, tmp, args["batch_size"], device,
                                                 shuffle=True, is_test=False)
        
    torch.manual_seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    

    
    model = Summarizer(args, device, load_pretrained_bert=True)
    n_gpu = torch.cuda.current_device()
    
    
    if args["train_from"] != '':
        print('Loading checkpoint from %s' % args["train_from"])
        checkpoint = torch.load(args["train_from"],
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        for k in opt.keys():
            if (k in model_flags):
                setattr(args, k, opt[k])
        model.load_cp(checkpoint)
        optim = build_optim(args, model, checkpoint)
    else:
        optim = build_optim(args, model, None)
    
    
    for step in range(optim._step + 1, args["train_steps"]):
        print("step", step)
        trainIter = trainIterFunc()
        # optim.zero_grad()
        trainLoss, trainBatchSize = 0, 0
        validLoss, validBatchSize = 0, 0
        timeStart = time.time()
        
        # training 
        model.train()
        for i, batch in enumerate(trainIter):
            if batch.bad_batch:
              print("This is a bad batch")
              continue
            print("i (batch step)", i)
            labels = batch.labels
            segs = batch.segs
            cls = batch.clss
            mask = batch.mask
            mask_cls = batch.mask_cls
            src = batch.src


            sentenceScores, mask = model(src, segs, cls, mask, mask_cls)
            lossVec = loss(sentenceScores, labels.float())
            lossMag = (lossVec*mask.float()).sum()
            (lossMag/lossMag.numel()).backward()
            optim.step()
            lossMag.detach()
            
            trainLoss += lossMag
            trainBatchSize += batch.batch_size
        
        timeEnded_T = time.time()
        
        # validating 
        model.eval()
        print("starting to validate")
        with torch.no_grad():
            validIter = validIterFunc() 
            for batch in validIter:
                if batch.bad_batch:
                  print("This is a bad batch")
                  continue
                labels = batch.labels
                segs = batch.segs
                cls = batch.clss
                mask = batch.mask
                mask_cls = batch.mask_cls
                src = batch.src
                
                sentenceScores, mask = model(src, segs, cls, mask, mask_cls)
                lossVec = loss(sentenceScores, labels.float())
                lossMag = (lossVec*mask.float()).sum()
                validLoss += lossMag
                validBatchSize += batch.batch_size
        
        timeEnded_V = time.time()
        trainTime = timeStart   - timeEnded_T
        validTime = timeEnded_T - timeEnded_V
        trainLoss /= trainBatchSize
        validLoss /= validBatchSize
        print("Step %s; lr: %7.7f; training loss: %4.2f;" +
             " trained %6.0f sec; valid loss: %4.2f; validated %6.0f sec", step, 
              step, optim.learning_rate, trainLoss, trainTime, validLoss, 
             validTime)
    # saving the model 
    modelStateDict  = model.state_dict()
    # don't have args, check if it's okay when loaded
    checkpoint = {
        'model' : modelStateDict,
        'optim' : optim
    }
    
    checkpointPath = os.path.join(savePath, 'model_step_last.pt')
    print("Saving checkpoint at", checkpointPath)
    torch.save(checkpoint, checkpointPath)
    return checkpoint, checkpointPath
    
            

In [22]:
torch.cuda.empty_cache()
import torch.cuda as cutorch

print(torch.cuda.memory_stats())

OrderedDict([('active.all.allocated', 0), ('active.all.current', 0), ('active.all.freed', 0), ('active.all.peak', 0), ('active.large_pool.allocated', 0), ('active.large_pool.current', 0), ('active.large_pool.freed', 0), ('active.large_pool.peak', 0), ('active.small_pool.allocated', 0), ('active.small_pool.current', 0), ('active.small_pool.freed', 0), ('active.small_pool.peak', 0), ('active_bytes.all.allocated', 0), ('active_bytes.all.current', 0), ('active_bytes.all.freed', 0), ('active_bytes.all.peak', 0), ('active_bytes.large_pool.allocated', 0), ('active_bytes.large_pool.current', 0), ('active_bytes.large_pool.freed', 0), ('active_bytes.large_pool.peak', 0), ('active_bytes.small_pool.allocated', 0), ('active_bytes.small_pool.current', 0), ('active_bytes.small_pool.freed', 0), ('active_bytes.small_pool.peak', 0), ('allocated_bytes.all.allocated', 0), ('allocated_bytes.all.current', 0), ('allocated_bytes.all.freed', 0), ('allocated_bytes.all.peak', 0), ('allocated_bytes.large_pool.all

In [24]:
# python train.py -mode train -encoder classifier -dropout 0.1 
# -bert_data_path ../bert_data/cnndm -model_path ../models/bert_classifier -lr 2e-3
# -visible_gpus 0,1,2  -gpu_ranks 0,1,2 -world_size 3 -report_every 50 -save_checkpoint_steps 1000 
# -batch_size 3000 -decay_method noam -train_steps 50000 -accum_count 2 -log_file ../logs/bert_classifier -use_interval true -warmup_steps 10000

# torch.random.manual_seed(42)
# t = TransformerInterEncoder2(4,2,2,.1,1)
# a=torch.tensor([[[1.0,2,3,1],[1.0,2,3,1]]])
# d=torch.tensor([[True]])
# t(a,d)


argDict = {
    "temp_dir": "tmp",
    "ff_size": 2,
    "heads": 1,
    "dropout": 0.1,
    "inter_layers": 1,
    "param_init": 0,
    "param_init_glorot": False,
    "train_from": "",
    "train_steps":1, 
    "use_interval": True, 
    "batch_size": 2500,
    "lr":1,
    "beta1":.9,
    "beta2":.999,
    "max_grad_norm": 0,
    "visible_gpus": -1, 
    "warmup_steps": 10,
    "decay_method":'',
    "optim": 'adam'
}
device = "cpu" 
savePath = "results/"
model = Summarizer(argDict, device, load_pretrained_bert=True)
blah = train(argDict, model,savePath)
# optim = build_optim(argDict, model, None)
# config = BertConfig.from_json_file("testBert.json")
# summ = Summarizer(argDict, device, bert_config=config)
# model = Summarizer(argDict, device, load_pretrained_bert=True)


# train(model, optim, trainSteps, trainIterFunc, validIterFunc, savePath)


we created an optimizer
Saving checkpoint at results/model_step_last.pt
