In [1]:
import transformers
import numpy as np
import torch

In [2]:
from transformers import  AutoTokenizer,AutoModel,BertModel

In [3]:
token = AutoTokenizer.from_pretrained('./ft_lm/add_token/')

In [4]:
len(token)

21503

In [5]:
model = BertModel.from_pretrained('./ft_lm/add_token/')

Some weights of BertModel were not initialized from the model checkpoint at ./ft_lm/add_token/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
#model: bert+pool

In [7]:
#model

In [8]:
seq = '北京市朝阳区望京首开广场'

In [9]:
token.encode?

In [10]:
token_seq = token(seq,max_length=20,add_special_tokens=True,padding='max_length',truncation=True,return_tensors='pt')

In [11]:
token_seq

{'input_ids': tensor([[ 101, 1266,  776, 2356, 3308, 7345, 1277, 3307,  776, 7674, 2458, 2408,
         1767,  102,    0,    0,    0,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]])}

In [12]:
token.decode(token(seq)['input_ids'])

'[CLS] 北 京 市 朝 阳 区 望 京 首 开 广 场 [SEP]'

In [13]:
output = model(**token_seq)

In [14]:
output.keys()

odict_keys(['last_hidden_state', 'pooler_output'])

In [15]:
output['last_hidden_state'][0].shape

torch.Size([20, 768])

In [16]:
output['pooler_output'][0].shape

torch.Size([768])

## 2 memnet

In [31]:
class MemNet(torch.nn.Module):
    def __init__(self,num_geo = 10000,num_en = 500,dim_en=768*2):
        super(MemNet, self).__init__()
        ##mem_key and mem_value
        self.num_geo = num_geo
        self.num_en = num_en
        self.dim_en = dim_en
        self.mem_keys = torch.nn.Parameter(torch.Tensor(num_geo,num_en*dim_en))
        self.mem_vals = torch.nn.Parameter(torch.Tensor(num_geo,num_en*dim_en))
        ##weight init
        self.mem_keys.data.normal_(0,std=0.01)
        self.mem_vals.data.normal_(0,std=0.01)
        
        self.sfm = torch.nn.Softmax(dim=-1)
    
    def forward(self,index_geo,query,topK =2,tem_coe = 0.05,debug= False):
        ##[num_en,dim_en]
        mem_key = self.mem_keys[index_geo].view(self.num_en,-1)
        if debug:
            print('mem_key',mem_key.shape,mem_key)
        ##[num_en,dim_en]
        mem_val = self.mem_vals[index_geo].view(self.num_en,-1)
        if debug:
            print('mem_val',mem_val.shape,mem_val)
        ##[batch_size,dim_en]
        query = query.view(-1,self.dim_en)
        if debug:
            print('query',query.shape)
        ##compute softmax
        ##[batch_size,dim_en] [num_en,dim_en]
        ##print(torch.tensordot(query,mem_key,dims=([1],[1])))
        att_weight = torch.tensordot(query,mem_key,dims=([1],[1]))
        if debug:
            print('before top k att_weight:',att_weight)
        att_weight = self.get_topk_mask(att_weight,query,topK)
        att_weight /= tem_coe
        if debug:
            print('after topk att_weight',att_weight)
        att_weight = self.sfm(att_weight)
        if debug:
            print('after topk att_weight',att_weight)
        #print(att_weight.shape)

        att_value = torch.matmul(att_weight,mem_val)
        
        return att_value
    
    def get_topk_mask(self,att_weight,query,topK):
        topv = torch.min(torch.topk(att_weight,topK,sorted=False).values,dim=-1).values.view(query.shape[0],-1)
        att_weight = att_weight* ((att_weight - topv)>=0)- 10**8* ((att_weight - topv)<0)
        #print('pos',att_weight* ((att_weight - topv)>=0))
        
        return att_weight

In [32]:
#memnet = MemNet(2,3,4)

In [33]:
#h1 = memnet(0,query,2)

In [39]:
class net(torch.nn.Module):
    def __init__(self,num_geo = 2,num_en = 3,dim_en=4):
        super().__init__()
        self.memnet = MemNet(num_geo,num_en,dim_en)
        self.pred = torch.nn.Linear(in_features=dim_en,out_features=2)
        
        
    def forward(self,x,tem_coe = 0.05,debug=False):
        h1 = self.memnet(0,x,2,tem_coe,debug)
        out = self.pred(h1)
        return out,self.memnet.mem_keys,self.memnet.mem_vals
        

In [40]:
query = torch.Tensor([[0.1,0.2,0.3,0.4],[2.5,0.1,0.1,0.2]])
query.shape
label = torch.Tensor([0,1]).long()

In [41]:
model = net()

In [42]:
loss_func = torch.nn.CrossEntropyLoss()
optimize = torch.optim.SGD(model.parameters(),lr=1)

In [53]:
##train
for i in range(1):
    print('mem_keys',model.memnet.mem_keys[0])
    print('mem_vals',model.memnet.mem_vals[0])
    pred,memkey,memval  = model(query,debug=True,tem_coe = 0.005)
    loss_tmp = loss_func(pred,label)
    #print(torch.autograd.grad(loss_tmp,memkey,retain_graph=True))
    #print(torch.autograd.grad(loss_tmp,memval,retain_graph=True))
    optimize.zero_grad()
    #loss_tmp.backward()
    #optimize.step()
    #print(pred)
    print(loss_tmp)

mem_keys tensor([ 0.0574, -0.0093, -0.0049, -0.0051, -0.0050,  0.0152, -0.0014, -0.0197,
        -0.0448, -0.0167,  0.0026,  0.0086], grad_fn=<SelectBackward>)
mem_vals tensor([ 0.0005, -0.0108,  0.0239,  0.0463, -0.0069, -0.0102, -0.0152, -0.0015,
        -0.0273, -0.0172,  0.0145,  0.0346], grad_fn=<SelectBackward>)
mem_key torch.Size([3, 4]) tensor([[ 0.0574, -0.0093, -0.0049, -0.0051],
        [-0.0050,  0.0152, -0.0014, -0.0197],
        [-0.0448, -0.0167,  0.0026,  0.0086]], grad_fn=<ViewBackward>)
mem_val torch.Size([3, 4]) tensor([[ 0.0005, -0.0108,  0.0239,  0.0463],
        [-0.0069, -0.0102, -0.0152, -0.0015],
        [-0.0273, -0.0172,  0.0145,  0.0346]], grad_fn=<ViewBackward>)
query torch.Size([2, 4])
before top k att_weight: tensor([[ 0.0004, -0.0058, -0.0036],
        [ 0.1411, -0.0150, -0.1117]], grad_fn=<ViewBackward>)
after topk att_weight tensor([[ 8.0622e-02, -2.0000e+10, -7.2441e-01],
        [ 2.8213e+01, -3.0057e+00, -2.0000e+10]], grad_fn=<DivBackward0>)
after 

In [None]:
model.memnet.mem_keys[0]

In [None]:
model.memnet.mem_vals[0]

In [None]:
loss_tmp

In [None]:
##ouput grad

In [None]:
model

In [None]:
pred,a = model(query)
loss_tmp = loss_func(pred,label)
torch.autograd.grad(loss_tmp,a)

In [None]:
torch.autograd.grad?

In [None]:
a

In [None]:
0.0142-0.0096

In [28]:
model.to?