In [1]:
import os
import gc
import sys
import pickle
import random
import warnings
import pickle
import lightgbm as lgb
import numpy as np
import pandas as pd


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder

from callback.lr_scheduler import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from callback.progressbar import ProgressBar
from callback.adversarial import FGM

from tools.common import seed_everything
from tools.common import init_logger, logger

from tqdm import tqdm

from transformers import AutoModel, AutoTokenizer, AutoConfig
from transformers import WEIGHTS_NAME, BertConfig, get_linear_schedule_with_warmup, AdamW, BertTokenizer

from pytorch_tabnet.multitask import TabNetMultiTaskClassifier
from pytorch_tabnet.tab_model import TabNetClassifier

from models.nezha.modeling_nezha import NeZhaForSequenceClassification, NeZhaModel
from models.nezha.configuration_nezha import NeZhaConfig

from pylab import mpl

# Use a CJK-capable font so Chinese labels render in matplotlib figures.
mpl.rcParams['font.sans-serif'] = ['SimHei']

# Expose only GPU 1, counting GPUs in PCI bus order.
# NOTE(review): these variables only take effect if CUDA has not been
# initialised yet; setting them before `import torch` would be more robust.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    'CUDA_VISIBLE_DEVICES': '1',
})

# Silence every Python warning (this also hides library deprecation notices).
warnings.simplefilter('ignore')

In [2]:
# Registry: pretrained model name -> (config class, model class, tokenizer class).
# The BERT entry was noted as usable for bert / ernie / bert_wwm / bert_wwm_ext.
MODEL_CLASSES = dict([
    ('bert-base-chinese', (BertConfig, AutoModel, BertTokenizer)),
    ('nezha-cn-base', (NeZhaConfig, NeZhaModel, BertTokenizer)),
])

In [3]:
class CFGs:
    """Hyper-parameters and paths for the NeZha + image-feature matcher.

    Everything is a plain attribute so later cells can override values
    (e.g. ``out_dir``, ``folds``) or attach objects (e.g. ``tokenizer``).
    """

    def __init__(self):
        super().__init__()

        # Paths and output handling.
        self.data_dir = './data/'
        self.out_dir = './output'
        self.log_name = './output'
        self.overwrite_output_dir = True

        # Task selection; the data file names are derived from it.
        self.task = 'whole'  # alternative: 'detail'
        self.train_file = f'{self.task}.pkl'
        self.test_file = f'test_{self.task}.pkl'

        # Pretrained text encoder and its tokenizer (same directory).
        self.model_name = 'nezha-cn-base'
        self.model_path = './prev_trained_model/nezha-cn-base'
        self.tokenizer_path = self.model_path

        # Runtime.
        self.seed = 42
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Training schedule.
        self.epochs = 100
        self.folds = 5
        self.batch_size = 128  # other values noted: 16, 32
        self.patience = 5
        self.scheduler = 'cosine'
        self.num_warmup_steps = 0
        self.num_cycles = 0.5

        # Sampling rates.
        self.sample_pos_rate = 0.1
        self.sample_neg_rate = 1

        # Model widths and regularisation.
        self.text_dim = 768
        self.img_dim = 2048
        self.dropout = 0.2

        # Optimisation.
        self.transformer_lr = 2e-5
        self.clf_lr = 1e-4
        self.weight_decay = 0.01
        self.eps = 1e-6
        self.betas = (0.9, 0.999)
        self.max_norm = 1000


CFG = CFGs()

In [4]:
 # Resolve the (config, model, tokenizer) classes for the configured checkpoint.
config_class, model_class, tokenizer_class = MODEL_CLASSES[CFG.model_name]

config = config_class.from_pretrained(CFG.model_path)
tokenizer = tokenizer_class.from_pretrained(CFG.tokenizer_path)
# NOTE(review): `bert_model` is not referenced in the visible cells (each
# `Model` loads its own transformer), so this load appears to only cost memory.
bert_model = model_class.from_pretrained(CFG.model_path, config=config)
# The tokenizer is carried on the config object; the global name is dropped.
CFG.tokenizer = tokenizer

del tokenizer
gc.collect();  # trailing ';' suppresses the collected-object count echo

Some weights of NeZhaModel were not initialized from the model checkpoint at ./prev_trained_model/nezha-cn-base and are newly initialized: []
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


431

In [5]:
class MyDataset(Dataset):
    """Row-wise view over a DataFrame of titles and image features.

    Each item is ``(title, feature_tensor)``. When the frame has a ``label``
    column, an int64 label tensor is appended as a third element.
    """

    def __init__(self, df):
        super().__init__()
        self.text = df['text'].values
        self.feature = df['feature'].values
        self.label = df['label'].values if 'label' in df.columns else None

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        title = self.text[index]
        feature = torch.tensor(self.feature[index])
        if self.label is None:
            return title, feature
        return title, feature, torch.tensor(self.label[index]).long()

In [6]:
class FuseLayer(nn.Module):
    """Fuse a text embedding with an image feature vector.

    The image vector is projected to ``text_dim``, concatenated in front of
    the text vector, batch-normalised, and projected back to ``text_dim``.

    Args:
        text_dim: width of the text embedding and of the fused output.
        img_dim: width of the raw image feature.
        dropout: dropout probability applied after each projection.
    """

    def __init__(self, text_dim, img_dim, dropout):
        super().__init__()
        # Normalises the concatenation [img_proj, text], which has
        # 2 * text_dim features (must follow text_dim, not a fixed 768).
        self.bn = nn.BatchNorm1d(2 * text_dim)
        self.fc1 = nn.Sequential(
            nn.Linear(img_dim, text_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(2 * text_dim, text_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, text, img):
        """Fuse ``text`` of shape (batch, text_dim) with ``img`` of shape (batch, img_dim).

        Returns a (batch, text_dim) tensor.
        """
        img = self.fc1(img)
        # Keep image-first order: bn / fc2 weights are tied to feature positions.
        concat = torch.cat((img, text), dim=1)
        return self.fc2(self.bn(concat))
    
class Model(nn.Module):
    """Pretrained text encoder fused with image features, followed by a 13-way head.

    The text encoder is built from the module-level ``model_class`` and
    ``config``.
    """

    def __init__(self, CFG):
        super().__init__()
        p_drop = CFG.dropout
        hidden = CFG.text_dim

        self.transformer = model_class.from_pretrained(CFG.model_path, config=config)
        self.dropout = nn.Dropout(p_drop)
        self.fuse = FuseLayer(hidden, CFG.img_dim, p_drop)
        # NOTE(review): `clf` is never used in forward(); presumably kept so the
        # saved state_dict keys still line up.
        self.clf = nn.Linear(hidden, 2)
        # NOTE(review): there is no activation between these Linear layers, so
        # they compose to one affine map. Changing that would shift Sequential
        # indices and break existing checkpoints.
        self.clf1 = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.Linear(256, 64),
            nn.Linear(64, 13),
        )

    def forward(self, text, img):
        """Return raw 13-dim scores for a tokenised batch and its image features.

        ``text`` is unpacked as keyword arguments into the transformer.
        """
        # Element [1] of the encoder output: presumably the pooled [CLS]
        # vector of width text_dim. TODO confirm against NeZhaModel.
        encoded = self.transformer(**text)[1]
        return self.clf1(self.fuse(self.dropout(encoded), img))

In [7]:
device = CFG.device
task_name = '_pretrained_6000_2_1.5_1.5_5_shuffle_0.3_fold5_修正数据集'
# Remember the base directory once so that re-running this cell does not keep
# nesting the experiment name (./output/<name>/<name>/...).
if not hasattr(CFG, 'base_out_dir'):
    CFG.base_out_dir = CFG.out_dir
CFG.out_dir = os.path.join(CFG.base_out_dir, f'{CFG.model_name}{task_name}')

In [8]:
import itertools
import json
import re

# Test-A input: JSON lines carrying img_name, title, query and feature.
data_dir = './data/preliminary_testA.txt'
def load_attr_dict(file):
    """Load the attribute -> candidate-values dictionary.

    The JSON maps each attribute name to a list of strings. Any string may
    hold several '='-separated synonyms, which are flattened into one list.

    Args:
        file: path to the UTF-8 JSON file.

    Returns:
        dict mapping attribute name -> flat list of value strings.
    """
    # Explicit UTF-8: the file holds Chinese text, and the platform default
    # encoding is not guaranteed to be UTF-8 (e.g. on Windows).
    with open(file, 'r', encoding='utf-8') as f:
        raw = json.load(f)
    return {
        attr: list(itertools.chain.from_iterable(v.split('=') for v in attrval_list))
        for attr, attrval_list in raw.items()
    }

# Attribute name -> candidate values ('='-separated synonyms are flattened).
attr_dict_file = "./data/attr_to_attrvals.json"
attr_dict = load_attr_dict(file=attr_dict_file)

def extract_key_attr(title, attr, attr_dict):
    """Find the first value of attribute ``attr`` that occurs in ``title``.

    Args:
        title: product title to search.
        attr: attribute name, e.g. '袖长'. The special attribute '图文'
            always matches.
        attr_dict: mapping attribute name -> list of candidate values, as
            returned by ``load_attr_dict``.

    Returns:
        ``('图文', '符合')`` for the '图文' attribute; otherwise
        ``(attr, value)`` for the leftmost value found in ``title``, or
        ``('N', '')`` when none is found.

    Raises:
        KeyError: if ``attr`` (other than '图文') is missing from ``attr_dict``.
    """
    if attr == '图文':
        return '图文', '符合'
    # Escape values so regex metacharacters ('+', '(', '.') match literally,
    # and drop empty strings, which would otherwise match every title.
    candidates = [re.escape(v) for v in attr_dict[attr] if v]
    if not candidates:
        return 'N', ''
    hit = re.search('|'.join(candidates), title)
    if hit:
        return attr, hit.group(0)
    return 'N', ''


def extract_all_key_attr(text, attrs=None, attr_values=None):
    """Collect every attribute whose value appears in ``text``.

    Args:
        text: product title to search.
        attrs: attribute names to try, in order. Defaults to the module-level
            ``class_name``, looked up at call time.
        attr_values: mapping attribute -> candidate values, passed to
            ``extract_key_attr``. Defaults to the module-level ``attr_dict``.

    Returns:
        dict such as ``{'衣长': '中长款'}`` mapping each matched attribute to
        its matched value, or the string ``'无'`` if nothing matched. Because
        ``extract_key_attr`` always matches '图文', the result is a dict
        whenever ``attrs`` contains '图文'.
    """
    if attrs is None:
        attrs = class_name
    if attr_values is None:
        attr_values = attr_dict
    key_attr = {}
    for attr in attrs:
        ret_attr, class_label = extract_key_attr(text, attr, attr_values)
        if ret_attr != 'N':
            key_attr[ret_attr] = class_label
    if not key_attr:
        return '无'
    return key_attr

# Read the JSON-lines test file. Each line is one object with img_name,
# title, query (attributes to judge) and feature (image vector).
img_name = []
img_features = []
texts = []
querys = []
class_name = ['图文', '版型', '裤型', '袖长', '裙长', '领型', '裤门襟', '鞋帮高度', '穿着方式', '衣长', '闭合方式', '裤长', '类别']

# Explicit UTF-8: titles and queries are Chinese.
with open(data_dir, 'r', encoding='utf-8') as f:
    for line in tqdm(f):
        if not line.strip():
            continue  # tolerate blank or trailing empty lines
        record = json.loads(line)
        img_features.append(np.array(record['feature']).astype(np.float32))
        img_name.append(record['img_name'])
        texts.append(record['title'])
        querys.append(record['query'])

# Build the frame in one step; this also works for an empty file.
df = pd.DataFrame({
    'img_name': img_name,
    'feature': img_features,
    'text': texts,
    'querys': querys,
})

4000it [00:01, 2519.91it/s]


In [9]:
# Peek at the first rows of the loaded test set.
df.head(n=5)

Unnamed: 0,img_name,feature,text,querys
0,test000000,"[-1.5868584, -0.008788199, 0.058588244, -0.416...",七分袖开衫男装2020年冬季牛仔蓝针织衫,"[图文, 穿着方式, 袖长]"
1,test000001,"[0.24715243, 0.026822748, 0.02925557, -0.32535...",2021年秋冬低帮登山鞋松紧带防水棕色,"[图文, 鞋帮高度]"
2,test000002,"[-0.59509367, -0.007138365, 0.013907805, -0.63...",系带女士休闲鞋低帮2021年冬季,"[图文, 闭合方式]"
3,test000003,"[-0.6622405, 0.048829023, 0.09136587, -0.23281...",高领2021年秋季童装长袖蓝粉色儿童T恤,"[图文, 袖长, 领型]"
4,test000004,"[-0.8325352, -0.0028627717, 0.056332473, -1.18...",连帽牛油果绿女装2021年秋季开衫长袖女士卫衣,"[图文, 穿着方式, 袖长, 领型]"


In [10]:
@torch.no_grad()
def evaluate(df, CFG, fold):
    """Score ``df`` with the checkpoint of one cross-validation fold.

    Args:
        df: frame with ``text`` and ``feature`` columns (see ``MyDataset``).
        CFG: config; reads ``batch_size``, ``device``, ``out_dir``, ``task``
            and ``tokenizer``.
        fold: fold index; loads ``{out_dir}/model_{task}_fold{fold}.pth``.

    Returns:
        float64 ndarray of shape (len(df), 13) with the raw model outputs
        (no sigmoid applied).
    """
    run_device = CFG.device
    loader = DataLoader(
        MyDataset(df),
        batch_size=CFG.batch_size,
        # Pinned host memory only helps host->GPU copies.
        pin_memory=torch.cuda.is_available(),
    )

    model = Model(CFG).to(run_device)
    # torch.load unpickles the file, so only load trusted checkpoints.
    state = torch.load(f'{CFG.out_dir}/model_{CFG.task}_fold{fold}.pth', map_location='cpu')
    # strict=False tolerates minor key drift. If the keys do not line up at
    # all, though, the model silently keeps un-finetuned weights, so report
    # any mismatch.
    incompatible = model.load_state_dict(state, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f'fold {fold}: missing keys {incompatible.missing_keys}, '
              f'unexpected keys {incompatible.unexpected_keys}')
    model.eval()

    batches = []
    for text, feature in tqdm(loader):
        encoded = CFG.tokenizer(text, return_tensors='pt', add_special_tokens=True, padding=True)
        # Follow CFG.device rather than hard-coding .cuda(), so CPU runs work too.
        inputs = {k: v.to(run_device) for k, v in encoded.items()}
        outputs = model(inputs, feature.to(run_device))
        batches.append(outputs.cpu().numpy())

    if not batches:
        return np.empty((0, 13))
    # Concatenate once (not per batch) and return float64 as before.
    return np.concatenate(batches).astype(np.float64)

In [11]:
# Ensemble: average the raw outputs of every fold's checkpoint; the sigmoid
# is applied in later cells.
CFG.folds = 5
fold_preds = [evaluate(df, CFG, fold) for fold in range(CFG.folds)]
pred = np.mean(fold_preds, axis=0)

Some weights of NeZhaModel were not initialized from the model checkpoint at ./prev_trained_model/nezha-cn-base and are newly initialized: []
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 32/32 [00:14<00:00,  2.15it/s]
Some weights of NeZhaModel were not initialized from the model checkpoint at ./prev_trained_model/nezha-cn-base and are newly initialized: []
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 32/32 [00:03<00:00, 10.09it/s]
Some weights of NeZhaModel were not initialized from the model checkpoint at ./prev_trained_model/nezha-cn-base and are newly initialized: []
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 32/32 [00:03<00:00, 10.11it/s]
Some weights of NeZhaModel were not initialized from the model checkpoint at ./prev_trained_model

In [12]:
# Inspect the ensembled outputs after an element-wise sigmoid.
torch.from_numpy(pred).sigmoid().numpy()

array([[5.41623139e-13, 3.95441178e-09, 1.10741809e-14, ...,
        3.63742127e-16, 3.61294971e-12, 9.46279919e-18],
       [3.51845436e-02, 9.83736950e-06, 1.10167489e-04, ...,
        3.48486702e-02, 7.43422413e-06, 4.89141166e-06],
       [9.55888802e-01, 4.47966654e-08, 6.18413133e-07, ...,
        9.99939239e-01, 7.60262865e-09, 2.31381898e-07],
       ...,
       [1.26584477e-15, 1.48179089e-06, 4.66360959e-13, ...,
        5.25257846e-14, 2.05747207e-12, 9.70711811e-19],
       [1.15642000e-14, 1.07466582e-09, 3.41610951e-13, ...,
        3.33628788e-15, 8.00228638e-11, 1.83273456e-17],
       [9.99987229e-01, 2.55586909e-10, 2.25325870e-07, ...,
        1.33181259e-07, 1.31956611e-05, 9.99998571e-01]])

In [13]:
# One sigmoid vector per row, indexed like `class_name` (as `function` expects).
df['pred'] = list(torch.from_numpy(pred).sigmoid().numpy())

In [14]:
# Check the newly added `pred` column.
df.head(n=5)

Unnamed: 0,img_name,feature,text,querys,pred
0,test000000,"[-1.5868584, -0.008788199, 0.058588244, -0.416...",七分袖开衫男装2020年冬季牛仔蓝针织衫,"[图文, 穿着方式, 袖长]","[5.416231385244068e-13, 3.95441177849989e-09, ..."
1,test000001,"[0.24715243, 0.026822748, 0.02925557, -0.32535...",2021年秋冬低帮登山鞋松紧带防水棕色,"[图文, 鞋帮高度]","[0.03518454364884706, 9.837369502446041e-06, 0..."
2,test000002,"[-0.59509367, -0.007138365, 0.013907805, -0.63...",系带女士休闲鞋低帮2021年冬季,"[图文, 闭合方式]","[0.9558888018281618, 4.4796665437059203e-08, 6..."
3,test000003,"[-0.6622405, 0.048829023, 0.09136587, -0.23281...",高领2021年秋季童装长袖蓝粉色儿童T恤,"[图文, 袖长, 领型]","[0.9956064700869293, 5.273069499614336e-07, 1...."
4,test000004,"[-0.8325352, -0.0028627717, 0.056332473, -1.18...",连帽牛油果绿女装2021年秋季开衫长袖女士卫衣,"[图文, 穿着方式, 袖长, 领型]","[0.004565992422371046, 7.185137647272583e-07, ..."


In [15]:
# Candidate values per attribute. '图文' presumably denotes the overall
# image-text match, with values '符合' / '不符合'.
class_dict = {
    '图文': ['符合', '不符合'],
    '版型': ['修身型', '宽松型', '标准型'],
    '裤型': ['微喇裤', '小脚裤', '哈伦裤', '直筒裤', '阔腿裤', '铅笔裤', 'O型裤',
             '灯笼裤', '锥形裤', '喇叭裤', '工装裤', '背带裤', '紧身裤'],
    '袖长': ['长袖', '短袖', '七分袖', '五分袖', '无袖', '九分袖'],
    '裙长': ['中长裙', '短裙', '超短裙', '中裙', '长裙'],
    '领型': ['半高领', '高领', '翻领', 'POLO领', '立领', '连帽', '娃娃领', 'V领',
             '圆领', '西装领', '荷叶领', '围巾领', '棒球领', '方领', '可脱卸帽',
             '衬衫领', 'U型领', '堆堆领', '一字领', '亨利领', '斜领', '双层领'],
    '裤门襟': ['系带', '松紧', '拉链'],
    '鞋帮高度': ['低帮', '高帮', '中帮'],
    '穿着方式': ['套头', '开衫'],
    '衣长': ['常规款', '中长款', '长款', '短款', '超短款', '超长款'],
    '闭合方式': ['系带', '套脚', '一脚蹬', '松紧带', '魔术贴', '搭扣', '套筒', '拉链'],
    '裤长': ['九分裤', '长裤', '五分裤', '七分裤', '短裤'],
    '类别': ['单肩包', '斜挎包', '双肩包', '手提包'],
}
# Attribute order used to index the model's 13 outputs; it is exactly the
# insertion order of `class_dict`.
class_name = list(class_dict)

In [16]:
# Re-display the frame (df has not changed since the previous look).
df.head(n=5)

Unnamed: 0,img_name,feature,text,querys,pred
0,test000000,"[-1.5868584, -0.008788199, 0.058588244, -0.416...",七分袖开衫男装2020年冬季牛仔蓝针织衫,"[图文, 穿着方式, 袖长]","[5.416231385244068e-13, 3.95441177849989e-09, ..."
1,test000001,"[0.24715243, 0.026822748, 0.02925557, -0.32535...",2021年秋冬低帮登山鞋松紧带防水棕色,"[图文, 鞋帮高度]","[0.03518454364884706, 9.837369502446041e-06, 0..."
2,test000002,"[-0.59509367, -0.007138365, 0.013907805, -0.63...",系带女士休闲鞋低帮2021年冬季,"[图文, 闭合方式]","[0.9558888018281618, 4.4796665437059203e-08, 6..."
3,test000003,"[-0.6622405, 0.048829023, 0.09136587, -0.23281...",高领2021年秋季童装长袖蓝粉色儿童T恤,"[图文, 袖长, 领型]","[0.9956064700869293, 5.273069499614336e-07, 1...."
4,test000004,"[-0.8325352, -0.0028627717, 0.056332473, -1.18...",连帽牛油果绿女装2021年秋季开衫长袖女士卫衣,"[图文, 穿着方式, 袖长, 领型]","[0.004565992422371046, 7.185137647272583e-07, ..."


In [17]:
def function(x, threshold=None, labels=None):
    """Turn one row's per-attribute probabilities into 0/1 match decisions.

    Args:
        x: a row (Series or mapping) with ``querys`` (the attribute names to
            judge) and ``pred`` (probabilities indexed like ``labels``).
        threshold: probability that must be strictly exceeded for a match.
            Defaults to the module-level ``Threshold``.
        labels: ordered attribute names aligned with ``pred``. Defaults to the
            module-level ``class_name``.

    Returns:
        dict mapping each queried attribute to 1 (match) or 0 (no match).

    Raises:
        ValueError: if a queried attribute is not in ``labels``.
    """
    if threshold is None:
        threshold = Threshold
    if labels is None:
        labels = class_name
    queries = x['querys']
    probs = x['pred']
    return {que: 1 if probs[labels.index(que)] > threshold else 0 for que in queries}

In [18]:
# Probability that must be exceeded for an attribute to count as a match.
Threshold = 0.5
df['match'] = df.apply(function, axis='columns')

In [19]:
# Inspect the per-query 0/1 decisions in the `match` column.
df.head(n=5)

Unnamed: 0,img_name,feature,text,querys,pred,match
0,test000000,"[-1.5868584, -0.008788199, 0.058588244, -0.416...",七分袖开衫男装2020年冬季牛仔蓝针织衫,"[图文, 穿着方式, 袖长]","[5.416231385244068e-13, 3.95441177849989e-09, ...","{'图文': 0, '穿着方式': 0, '袖长': 0}"
1,test000001,"[0.24715243, 0.026822748, 0.02925557, -0.32535...",2021年秋冬低帮登山鞋松紧带防水棕色,"[图文, 鞋帮高度]","[0.03518454364884706, 9.837369502446041e-06, 0...","{'图文': 0, '鞋帮高度': 1}"
2,test000002,"[-0.59509367, -0.007138365, 0.013907805, -0.63...",系带女士休闲鞋低帮2021年冬季,"[图文, 闭合方式]","[0.9558888018281618, 4.4796665437059203e-08, 6...","{'图文': 1, '闭合方式': 1}"
3,test000003,"[-0.6622405, 0.048829023, 0.09136587, -0.23281...",高领2021年秋季童装长袖蓝粉色儿童T恤,"[图文, 袖长, 领型]","[0.9956064700869293, 5.273069499614336e-07, 1....","{'图文': 1, '袖长': 1, '领型': 1}"
4,test000004,"[-0.8325352, -0.0028627717, 0.056332473, -1.18...",连帽牛油果绿女装2021年秋季开衫长袖女士卫衣,"[图文, 穿着方式, 袖长, 领型]","[0.004565992422371046, 7.185137647272583e-07, ...","{'图文': 0, '穿着方式': 0, '袖长': 1, '领型': 1}"


In [20]:
# One JSON object per line: {"img_name": ..., "match": {attribute: 0/1, ...}}.
# Build a fresh dict per row instead of mutating a shared template, and zip
# the two columns instead of iterating rows with iterrows().
submit = [
    json.dumps({'img_name': name, 'match': match}, ensure_ascii=False) + '\n'
    for name, match in zip(df['img_name'], df['match'])
]

In [22]:
# Write the submission as UTF-8. The lines were serialised with
# ensure_ascii=False, so they contain raw Chinese characters, and the
# platform default encoding may not handle them.
os.makedirs(CFG.out_dir, exist_ok=True)
with open(os.path.join(CFG.out_dir, 'w_1.txt'), 'w', encoding='utf-8') as f:
    f.writelines(submit)