In [1]:
##data
##feature
##Model build
##Loss
##init and optimzator
##train 
##eval

# prepare

In [2]:
import torch
import numpy as np
import random
import datetime
import os
from tensorboardX import SummaryWriter
from tqdm import tqdm
import time

In [3]:
def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True

In [4]:
setup_seed(20)

In [5]:
torch.__version__

'1.10.1+cu102'

In [6]:
dt = datetime.datetime.now()
timestamp = "%s-%s-%s-%s-%s"%(dt.year,dt.month,dt.day,dt.hour,dt.minute)

In [7]:
project_desc ='test_nn_match_model'

In [8]:
file_desc = project_desc+'/'+timestamp

In [9]:
summary_writer = SummaryWriter('./tensorboard_log/'+file_desc,comment='test')

# load data

In [10]:
from datasets import load_from_disk
from torch.utils.data import DataLoader

In [11]:
cache_train = load_from_disk('../cache_data/train/')
cache_test = load_from_disk('../cache_data/test/')

In [12]:
cache_train = cache_train.with_format("torch").remove_columns('sample_id')

In [13]:
cache_test = cache_test.with_format('torch').remove_columns('sample_id')

In [14]:
batch_size =2048

In [15]:
train_loader = DataLoader(dataset = cache_train,batch_size=batch_size,shuffle=True,drop_last = True,num_workers=10)
test_loader = DataLoader(dataset = cache_test,batch_size=batch_size,shuffle=True,drop_last = True,num_workers=10)

In [16]:
batch_step = int(cache_train.num_rows/batch_size)
batch_step

659

# feat_process

In [17]:
feature_name_src = ['sample_key', 'name_len1', 'name_len2',
       'name_val_editdist', 'name_lcs_clr', 'name_lcs_crr', 'name_clr',
       'name_crr', 'adcode_match', 'geo_diffx ', 'geo_diffy', 'tel_match',
       'brand_match', 'cate1', 'cate2', 'cate3', 'cate_index', 'dist',
       'label']

# model part

In [18]:
class match_model(torch.nn.Module):
    def __init__(self,num_of_feat = 10,embed_dim = 764,ad_num = 5,brand_num = 6,tel_num = 6,cate_num = 800,geo_diff_size = 100):
        super(match_model, self).__init__()
        ad_embedding = torch.nn.Embedding(num_embeddings=ad_num,embedding_dim=64)
        brand_embedding = torch.nn.Embedding(num_embeddings=brand_num,embedding_dim=64)
        self.ln1 = torch.nn.Sequential(torch.nn.Linear(num_of_feat,out_features=20),torch.nn.ReLU())
        self.ln2 = torch.nn.Sequential(torch.nn.Linear(20,out_features=20),torch.nn.ELU(),torch.nn.Dropout(p=0.5))
        self.ln3 = torch.nn.Sequential(torch.nn.Linear(20,1),torch.nn.Sigmoid())
        self.post_init()
        
        pass
    def forward(self,feat):
        #assert(17 == feat.shape[1])
        x = self.ln1(feat)
        x = self.ln2(x)
        out = self.ln3(x)
        #out = torch.nn.functional.sigmoid(out)
        return out
        pass

    def post_init(self):
        for m in self.modules():
            if isinstance(m, torch.nn.Linear):
                #print('before',m.weight)
                torch.nn.init.xavier_normal(m.weight)
                #print('after',m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()

In [19]:
#torch.nn.functional.log_softmax?

In [20]:
mm = match_model()



In [21]:
mm(cache_train['norm_feat'][0:2])

tensor([[1.0696e-16],
        [1.0000e+00]], grad_fn=<SigmoidBackward0>)

# loss and optimizer

In [22]:
loss_fn = torch.nn.BCELoss()

In [23]:
##optimize
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{"params": [p for n, p in mm.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
{
    "params": [p for n, p in mm.named_parameters() if any(nd in n for nd in no_decay)],
    "weight_decay": 0.0,
},
]
optimizer = torch.optim.AdamW( optimizer_grouped_parameters, 
                               lr=5e-2,betas=(0.9, 0.999),eps=1e-08)


In [24]:
lr_adjust = torch.optim.lr_scheduler.LinearLR(optimizer,total_iters=3,verbose=False)

# train

In [25]:
from sklearn.metrics import f1_score,precision_score,accuracy_score,recall_score

In [26]:
epochs = 3
global_step = 0
best_test_metric = 0
un_up_cnt = 0
threshold = 0.5
continue_train = True

In [27]:
for i in range(epochs):
    if not continue_train:
        break
    beg = time.time()
    print("------%dth epoch------"%(i+1))
   
    train_pred = list()
    train_target = list()
    train_loss = list()
    train_f1 = list()
    train_precision = list()
    train_recall = list()
    for step,mini_train in tqdm(enumerate(train_loader),desc='train-Processing',total=batch_step):
        if not continue_train:
            break
        
        mm.train()
        
        ##get feat and label
        mini_train_feat = mini_train['norm_feat']
        mini_train_label = mini_train['label']
        
        ##forward
        mini_train_pred = mm(mini_train_feat).reshape([-1])
        
        ##backward
        mini_train_loss = loss_fn(mini_train_pred,mini_train_label.float())
        optimizer.zero_grad()
        mini_train_loss.backward()
        ###record by tensorboard
        if(global_step%10==0): 
            summary_writer.add_histogram('ln1.weight.grad',mm.ln1[0].weight.grad,global_step)
            summary_writer.add_histogram('ln2.weight.grad',mm.ln2[0].weight.grad,global_step)
            summary_writer.add_histogram('ln3.weight.grad',mm.ln2[0].weight.grad,global_step)
        ##clip gradient
        torch.nn.utils.clip_grad_norm_(mm.parameters(),1.)
        optimizer.step()
        
        ##eval mini_train
        mini_train_f1 = f1_score(y_true=mini_train_label.detach().numpy(),y_pred=mini_train_pred.detach().numpy()>threshold)
        mini_train_precision = precision_score(y_true=mini_train_label.detach().numpy(),y_pred=mini_train_pred.detach().numpy()>threshold,zero_division=0)
        mini_train_recall = recall_score(y_true=mini_train_label.detach().numpy(),y_pred=mini_train_pred.detach().numpy()>threshold,zero_division=0)
        
        ##collect train data
        train_loss.append(mini_train_loss)
        train_pred.extend(mini_train_pred.detach().numpy())
        train_target.extend(mini_train_label.detach().numpy())
        train_f1.append(mini_train_f1)
        train_precision.append(mini_train_precision)
        train_recall.append(mini_train_recall)
            
        ##record by tensorboard
        if(global_step%10==0): 
            summary_writer.add_scalar('train_batch_loss',mini_train_loss,global_step)
            summary_writer.add_scalar('train_batch_f1',mini_train_f1,global_step)
            summary_writer.add_scalar('train_batch_precision',mini_train_precision,global_step)
            summary_writer.add_scalar('train_batch_recall',mini_train_recall,global_step)
            summary_writer.add_histogram('ln1.weight',mm.ln1[0].weight,global_step)
            summary_writer.add_histogram('ln2.weight',mm.ln2[0].weight,global_step)
            summary_writer.add_histogram('ln3.weight',mm.ln3[0].weight,global_step)
    
        ##tensorboard 变量监视
        if global_step % 100 == 0:
            mm.eval()
            test_pred = list()
            test_target = list()
            test_loss = list()
            test_f1 = list()
            for mini_test in tqdm(test_loader,desc='test-Processing',total=int(cache_test.num_rows/batch_size)):
                ##get mini_test feat and label
                mini_test_feat = mini_test['norm_feat']
                mini_test_label = mini_test['label']
                
                ##forward
                mini_test_pred = mm(mini_test_feat).reshape(-1)
                mini_test_loss = loss_fn(mini_test_pred,mini_test_label.float())
                
                test_loss.append(mini_test_loss)
                test_pred.extend(mini_test_pred.detach().numpy())
                test_target.extend(mini_test_label.detach().numpy())
                
                #break
            #print(test_pred[:100])
            test_pred = np.array(test_pred)
            test_f1 = f1_score(y_true=test_target,y_pred=test_pred>threshold)
            test_precision = precision_score(y_true=test_target,y_pred=test_pred>threshold,zero_division=0)
            test_recall = recall_score(y_true=test_target,y_pred=test_pred>threshold,zero_division=0)
            
            mean_test_loss = sum(test_loss)/len(test_loss)
            summary_writer.add_scalar('test_loss',mean_test_loss,global_step)
            summary_writer.add_scalar('test_f1',test_f1 ,global_step)
            summary_writer.add_scalar('test_precision',test_precision ,global_step)
            summary_writer.add_scalar('test_recall',test_recall ,global_step)
            print('test-loss:%f test-f1:%f test-precison:%f test-recall:%f'%(mean_test_loss,test_f1,test_precision,test_recall))
            if test_f1 >= best_test_metric:
                best_test_metric = test_f1
                un_up_cnt = 0
                torch.save(mm.state_dict(),'./nn_model/'+file_desc+'_'+str(global_step))
                print('/****best_metrics found :\tval: %f  global_step:%d  ****/'%(best_test_metric,global_step))
            else:
                un_up_cnt+=1
                if (un_up_cnt >= 10):
                    continue_train = False
                         
        global_step+=1
        if global_step %100 == 0:
            ##动态学习率调整
            lr_adjust.step()
    
    ##eval train
    mean_train_loss = sum(train_loss)/len(train_loss)
    train_pred = np.array(train_pred)
    train_f1 = f1_score(y_true=train_target,y_pred=train_pred>threshold)
    train_precision = precision_score(y_true=train_target,y_pred=train_pred>threshold)
    train_recall = recall_score(y_true=train_target,y_pred=train_pred>threshold)
    print('train-loss:%f train-f1:%f train-precison:%f trainrecall:%f'%(mean_train_loss,train_f1,train_precision,train_recall))
    summary_writer.add_scalar('train_loss',mini_train_loss,global_step)
    summary_writer.add_scalar('train_f1',train_f1,global_step)
    summary_writer.add_scalar('train_precision',train_precision,global_step)
    summary_writer.add_scalar('train_recall',train_recall,global_step)
    #break
     
    ##record time cost
    end = time.time()
    print("耗时: {:.2f}秒".format(end - beg))
        

------1th epoch------


train-Processing:   0%|          | 0/659 [00:00<?, ?it/s]
test-Processing:   0%|          | 0/4 [00:00<?, ?it/s][A
test-Processing:  25%|██▌       | 1/4 [00:01<00:03,  1.30s/it][A
test-Processing: 100%|██████████| 4/4 [00:01<00:00,  2.72it/s][A
train-Processing:   1%|          | 8/659 [00:02<02:47,  3.89it/s]

test-loss:49.839493 test-f1:0.607861 test-precison:0.440905 test-recall:0.978319
/****best_metrics found :	val: 0.607861  global_step:0  ****/


train-Processing:  15%|█▍        | 97/659 [00:11<00:51, 10.85it/s]
test-Processing:   0%|          | 0/4 [00:00<?, ?it/s][A
test-Processing: 100%|██████████| 4/4 [00:01<00:00,  2.90it/s][A
train-Processing:  16%|█▌        | 105/659 [00:13<01:17,  7.12it/s]

test-loss:0.623753 test-f1:0.086246 test-precison:0.718310 test-recall:0.045877


train-Processing:  30%|███       | 198/659 [00:23<00:55,  8.26it/s]
test-Processing:   0%|          | 0/4 [00:00<?, ?it/s][A
test-Processing:  25%|██▌       | 1/4 [00:01<00:04,  1.33s/it][A
test-Processing: 100%|██████████| 4/4 [00:01<00:00,  2.53it/s][A
train-Processing:  32%|███▏      | 208/659 [00:24<01:02,  7.24it/s]

test-loss:0.639449 test-f1:0.084252 test-precison:0.801075 test-recall:0.044464


train-Processing:  46%|████▌     | 300/659 [00:33<00:31, 11.43it/s]
test-Processing:   0%|          | 0/4 [00:00<?, ?it/s][A
test-Processing: 100%|██████████| 4/4 [00:01<00:00,  2.37it/s][A
train-Processing:  47%|████▋     | 308/659 [00:35<00:51,  6.81it/s]

test-loss:0.636266 test-f1:0.088536 test-precison:0.776119 test-recall:0.046946


train-Processing:  61%|██████    | 400/659 [00:44<00:25, 10.26it/s]
test-Processing:   0%|          | 0/4 [00:00<?, ?it/s][A
test-Processing:  25%|██▌       | 1/4 [00:01<00:04,  1.35s/it][A
test-Processing: 100%|██████████| 4/4 [00:01<00:00,  2.61it/s][A
train-Processing:  61%|██████    | 403/659 [00:46<00:54,  4.66it/s]

test-loss:0.640243 test-f1:0.058874 test-precison:0.822581 test-recall:0.030530


train-Processing:  75%|███████▌  | 495/659 [00:55<00:20,  8.04it/s]
test-Processing:   0%|          | 0/4 [00:00<?, ?it/s][A
test-Processing: 100%|██████████| 4/4 [00:01<00:00,  2.44it/s][A
train-Processing:  77%|███████▋  | 507/659 [00:57<00:19,  7.67it/s]

test-loss:0.642189 test-f1:0.072427 test-precison:0.814103 test-recall:0.037899


train-Processing:  91%|█████████ | 600/659 [01:06<00:05, 11.24it/s]
test-Processing:   0%|          | 0/4 [00:00<?, ?it/s][A
test-Processing:  25%|██▌       | 1/4 [00:01<00:03,  1.31s/it][A
test-Processing: 100%|██████████| 4/4 [00:01<00:00,  2.38it/s][A
train-Processing:  92%|█████████▏| 606/659 [01:08<00:09,  5.82it/s]

test-loss:0.637261 test-f1:0.077497 test-precison:0.849057 test-recall:0.040602


train-Processing: 100%|██████████| 659/659 [01:12<00:00,  9.07it/s]


train-loss:2.249038 train-f1:0.357527 train-precison:0.474747 trainrecall:0.286730
耗时: 78.87秒
------2th epoch------


train-Processing:   6%|▌         | 41/659 [00:05<01:11,  8.61it/s]
test-Processing:   0%|          | 0/4 [00:00<?, ?it/s][A
test-Processing:  25%|██▌       | 1/4 [00:01<00:04,  1.53s/it][A
test-Processing: 100%|██████████| 4/4 [00:01<00:00,  2.30it/s][A
train-Processing:   8%|▊         | 50/659 [00:08<01:37,  6.25it/s]

test-loss:0.635795 test-f1:0.074605 test-precison:0.872483 test-recall:0.038969


train-Processing:  21%|██        | 140/659 [00:18<00:55,  9.27it/s]
test-Processing:   0%|          | 0/4 [00:00<?, ?it/s][A
test-Processing: 100%|██████████| 4/4 [00:01<00:00,  2.12it/s][A
train-Processing:  23%|██▎       | 150/659 [00:20<01:15,  6.77it/s]

test-loss:0.635965 test-f1:0.077626 test-precison:0.844720 test-recall:0.040682


train-Processing:  36%|███▌      | 238/659 [00:31<00:59,  7.09it/s]
test-Processing:   0%|          | 0/4 [00:00<?, ?it/s][A
test-Processing: 100%|██████████| 4/4 [00:01<00:00,  2.36it/s][A
train-Processing:  38%|███▊      | 248/659 [00:33<01:02,  6.58it/s]

test-loss:0.635317 test-f1:0.080849 test-precison:0.892405 test-recall:0.042342


train-Processing:  51%|█████▏    | 339/659 [00:44<00:34,  9.36it/s]
test-Processing:   0%|          | 0/4 [00:00<?, ?it/s][A
test-Processing: 100%|██████████| 4/4 [00:01<00:00,  2.09it/s][A
train-Processing:  52%|█████▏    | 342/659 [00:46<00:42,  7.41it/s]


test-loss:0.639590 test-f1:0.086416 test-precison:0.788660 test-recall:0.045713
train-loss:0.649898 train-f1:0.140151 train-precison:0.570006 trainrecall:0.079898
耗时: 49.32秒


# test_eval

In [28]:
mm.eval()
test_pred = list()
test_target = list()
test_loss = list()
test_f1 = list()
for mini_test in tqdm(test_loader,desc='test-Processing',total=int(cache_test.num_rows/batch_size)):
    ##get mini_test feat and label
    mini_test_feat = mini_test['norm_feat']
    mini_test_label = mini_test['label']

    ##forward
    mini_test_pred = mm(mini_test_feat).reshape(-1)
    mini_test_loss = loss_fn(mini_test_pred,mini_test_label.float())

    test_loss.append(mini_test_loss)
    test_pred.extend(mini_test_pred.detach().numpy())
    test_target.extend(mini_test_label.detach().numpy())

    #break
#print(test_pred[:100])


test-Processing: 100%|██████████| 4/4 [00:01<00:00,  3.28it/s]


In [29]:
test_pred = np.array(test_pred)
test_f1 = f1_score(y_true=test_target,y_pred=test_pred>threshold)
test_precision = precision_score(y_true=test_target,y_pred=test_pred>threshold)
test_recall = recall_score(y_true=test_target,y_pred=test_pred>threshold)
mean_test_loss = sum(test_loss)/len(test_loss)
print('test-f1:%f test-precison:%f test-recall:%f'%(test_f1,test_precision,test_recall))

test-f1:0.086343 test-precison:0.772727 test-recall:0.045726


In [30]:
test_pred[:10]

array([0.4281078 , 0.4281078 , 0.82776827, 0.4281078 , 0.03100571,
       0.4281078 , 0.4281078 , 0.9791509 , 0.4281078 , 0.4281078 ],
      dtype=float32)

In [31]:
test_target[:10]

[0, 1, 1, 0, 0, 1, 0, 1, 0, 1]

# save &load

In [32]:
mm.state_dict()['ln3.0.weight']

tensor([[ 0.0054,  0.0213,  0.0629, -0.0242,  0.0022, -0.0013, -0.0030, -0.0241,
          0.0103,  0.0008,  0.0088, -0.0070, -0.0268, -0.0459, -0.0306, -0.0305,
         -0.0030, -0.0116,  0.0810, -0.0016]])

In [33]:
torch.save(mm.state_dict(),'./nn_model/'+file_desc )

In [34]:
test_mm = match_model()
load_test= torch.load('./nn_model/'+file_desc )
test_mm.load_state_dict(load_test)



<All keys matched successfully>

In [35]:
test_mm.state_dict()['ln3.0.weight']

tensor([[ 0.0054,  0.0213,  0.0629, -0.0242,  0.0022, -0.0013, -0.0030, -0.0241,
          0.0103,  0.0008,  0.0088, -0.0070, -0.0268, -0.0459, -0.0306, -0.0305,
         -0.0030, -0.0116,  0.0810, -0.0016]])

In [36]:
load_test

OrderedDict([('ln1.0.weight',
              tensor([[ 2.6446e-01,  8.7113e-01,  1.2907e+00, -9.4492e-01, -1.3192e+00,
                        1.4850e+00, -1.6454e+00,  1.4339e+00, -4.7500e+00,  1.7237e+00],
                      [-3.9284e-01, -5.2317e-01, -9.4101e-02, -4.1987e-01, -8.8168e-01,
                       -7.1515e-01, -2.9171e-01, -3.0335e-01, -1.4873e-02, -3.1604e-01],
                      [-2.3079e-01, -1.4834e-01,  1.7564e-01, -3.2870e-01, -2.7424e-01,
                       -2.5367e-01,  9.4351e-02, -4.0524e-02, -5.5089e-01, -1.1732e-01],
                      [ 4.7117e-02, -7.1074e-02,  3.2055e-01, -2.8899e-02, -1.9431e-01,
                       -4.1873e-01, -4.8288e-01, -3.4121e-01, -4.0624e-01, -1.6691e-01],
                      [-5.7434e-01, -4.9653e-01, -8.3600e-01, -8.0416e-01, -1.0798e+00,
                       -4.4157e-01,  4.3638e-01, -3.2111e-01, -6.4324e-03, -8.1838e-01],
                      [-6.5204e-01, -7.7991e-01, -1.2579e+00, -9.4577e-01, -4.3919e-0