In [1]:
import os
import gc
import sys
import pickle
import random
import warnings
import pickle
import lightgbm as lgb
import numpy as np
import pandas as pd


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder

from callback.lr_scheduler import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from callback.progressbar import ProgressBar
from callback.adversarial import FGM

from tools.common import seed_everything
from tools.common import init_logger, logger

from tqdm import tqdm

from transformers import AutoModel, AutoTokenizer, AutoConfig
from transformers import WEIGHTS_NAME, BertConfig, get_linear_schedule_with_warmup, AdamW, BertTokenizer

from pytorch_tabnet.multitask import TabNetMultiTaskClassifier
from pytorch_tabnet.tab_model import TabNetClassifier

from models.nezha.modeling_nezha import NeZhaForSequenceClassification, NeZhaModel
from models.nezha.configuration_nezha import NeZhaConfig

from pylab import mpl

# Use a CJK-capable font so Chinese labels render in matplotlib figures.
mpl.rcParams['font.sans-serif'] = ['SimHei']

# Pin the job to physical GPU 1, numbered in PCI bus order. These variables
# only take effect if they are set before CUDA is first initialised.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": '1',
})

# NOTE(review): this silences *every* warning for the rest of the session,
# including torch/transformers messages about mismatched weights.
warnings.filterwarnings('ignore')

In [2]:
# Maps CFG.model_name -> (config class, model class, tokenizer class).
# The BERT-style Chinese checkpoints go through AutoModel with a BertConfig;
# NeZha needs its own config/model classes but shares the BERT tokenizer.
MODEL_CLASSES = {
    # bert / ernie / bert_wwm / bert_wwm_ext
    'bert-base-chinese': (BertConfig, AutoModel, BertTokenizer),
    'roberta-base':(BertConfig, AutoModel, BertTokenizer),
    'nezha-cn-base': (NeZhaConfig, NeZhaModel, BertTokenizer),
}

In [3]:
class CFGs:
    """Run-wide settings shared by the loading, model and inference cells.

    Every setting is a plain instance attribute, so later cells can override
    fields in place (model_name, model_path, out_dir, tokenizer, folds, ...).
    """

    def __init__(self):
        super().__init__()

        # Locations of data, outputs and logs.
        self.data_dir = './data/'
        self.out_dir = './output'
        self.log_name = './output'

        # Training length and cross-validation.
        self.epochs = 100
        self.folds = 5

        # Task ('whole' or 'detail') and the pickle it trains from.
        self.task = 'whole'
        self.train_file = '%s.pkl' % self.task

        # Backbone and tokenizer checkpoints.
        self.model_name = 'roberta-base'
        self.tokenizer_path = './prev_trained_model/chinese-roberta-wwm-ext'
        self.model_path = './prev_trained_model/chinese-roberta-wwm-ext-pretrained/maskedLM-pretraining/checkpoint-21000'

        # Runtime.
        self.seed = 42
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.batch_size = 512  # other values noted: 16, 32

        # Model widths and regularisation.
        self.dropout = 0.2
        self.text_dim = 768
        self.img_dim = 2048

        # Optimiser and LR schedule.
        self.scheduler = 'cosine'
        self.transformer_lr = 2e-5
        self.clf_lr = 1e-4
        self.weight_decay = 0.01
        self.eps = 1e-6
        self.betas = (0.9, 0.999)
        self.num_warmup_steps = 0
        self.num_cycles = 0.5
        self.max_norm = 1000
        self.patience = 5

        # Training-time switches (FGM / PGD / FreeLB / EMA).
        self.do_fgm = False
        self.do_pgd = False
        self.do_freelb = False
        self.do_ema = True

        self.overwrite_output_dir = True


CFG = CFGs()

In [4]:
# Resolve the (config, model, tokenizer) classes for CFG.model_name and load them.
# `config` and `model_class` stay as notebook globals: Model.__init__ reads both,
# so they must correspond to the checkpoints being evaluated.
config_class, model_class, tokenizer_class = MODEL_CLASSES[CFG.model_name]
    
config = config_class.from_pretrained(CFG.model_path)
tokenizer = tokenizer_class.from_pretrained(CFG.tokenizer_path)
# NOTE(review): `bert_model` is not used by any visible cell but keeps a full
# backbone in memory; consider dropping this load.
bert_model = model_class.from_pretrained(CFG.model_path, config=config)
# Later cells reach the tokenizer through CFG.tokenizer.
CFG.tokenizer = tokenizer

del tokenizer
gc.collect()

Some weights of the model checkpoint at ./prev_trained_model/chinese-roberta-wwm-ext-pretrained/maskedLM-pretraining/checkpoint-21000 were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ./prev_trained_model/chinese-ro

425

In [5]:
class MyDataset(Dataset):
    """Row-wise dataset over a frame with ``text`` and ``feature`` columns.

    Each item is ``(text, feature_tensor)``. When the frame also has a
    ``label`` column, the item is ``(text, feature_tensor, label_tensor)``,
    with the label cast to int64.
    """

    def __init__(self, df):
        super().__init__()
        self.text = df['text'].values
        self.feature = df['feature'].values
        self.label = df['label'].values if 'label' in df.columns else None

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        sample = (self.text[index], torch.tensor(self.feature[index]))
        if self.label is None:
            return sample
        return sample + (torch.tensor(self.label[index]).long(),)

In [6]:
class FuseLayer(nn.Module):
    """Fuse a text embedding with an image feature vector.

    The image vector is projected to ``text_dim`` and concatenated with the
    text vector (image first). The result is batch-normalised and projected
    back to ``text_dim``.

    Args:
        text_dim: width of the text embedding and of the output.
        img_dim: width of the raw image feature.
        dropout: dropout probability applied after each projection.
    """

    def __init__(self, text_dim, img_dim, dropout):
        super().__init__()
        # The concatenated [img_proj, text] vector is 2 * text_dim wide. This
        # used to be hard-coded as 768 * 2, which broke any other text_dim.
        self.bn = nn.BatchNorm1d(2 * text_dim)
        self.fc1 = nn.Sequential(
            nn.Linear(img_dim, text_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(2 * text_dim, text_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, text, img):
        """Return the fused representation, ``(batch, text_dim)``.

        ``text`` is expected as (batch, text_dim) and ``img`` as
        (batch, img_dim); both are concatenated along dim 1.
        """
        img = self.fc1(img)
        concat = torch.cat((img, text), dim=1)
        concat = self.bn(concat)
        return self.fc2(concat)
    
class Model(nn.Module):
    """Text + image classifier: backbone output -> FuseLayer -> 13 logits.

    NOTE: the backbone is built from the notebook globals ``model_class`` and
    ``config``, not from CFG, so they must be loaded for the matching backbone
    before this class is instantiated.
    """

    def __init__(self, CFG):
        super().__init__()
        p_drop = CFG.dropout
        hidden = CFG.text_dim
        self.transformer = model_class.from_pretrained(CFG.model_path, config=config)
        self.dropout = nn.Dropout(p_drop)
        self.fuse = FuseLayer(hidden, CFG.img_dim, p_drop)
        # Not used in forward(); presumably kept so saved state_dicts still
        # match key-for-key.
        self.clf = nn.Linear(hidden, 2)
        # NOTE(review): there is no activation between these Linear layers, so
        # the stack is one affine map. Changing it would break existing checkpoints.
        self.clf1 = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.Linear(256, 64),
            nn.Linear(64, 13),
        )

    def forward(self, text, img):
        """Return raw logits of shape (batch, 13); callers apply the sigmoid."""
        backbone_out = self.transformer(**text)
        # Element 1 is presumably the pooled sentence vector (BERT-style output).
        pooled = self.dropout(backbone_out[1])
        return self.clf1(self.fuse(pooled, img))

In [7]:
device = CFG.device
task_name = '_pretrained_6000_2_1.5_1.5_5_shuffle_0.3_fold5_修正数据集'
# Build from the fixed root rather than from CFG.out_dir, so that re-running
# this cell does not nest the directory again
# (./output/roberta-base.../roberta-base...).
CFG.out_dir = os.path.join('./output', 'roberta-base' + task_name)

In [8]:
# Test-B input file; it is parsed line by line with json.loads in the loading cell.
data_dir = './data/preliminary_testB.txt'
# NOTE(review): these imports belong in the top import cell.
import json
import itertools
import re
def load_attr_dict(file, encoding='utf-8'):
    """Read the attribute -> candidate-values JSON file.

    Entries written as ``'a=b'`` are split on ``'='`` and flattened, so every
    spelling becomes its own candidate value.

    Args:
        file: path to a JSON object mapping attribute name -> list of strings.
        encoding: text encoding of the file. Defaults to UTF-8 because the file
            holds Chinese text and the platform default may differ.

    Returns:
        dict attribute -> flat list of candidate value strings.
    """
    with open(file, 'r', encoding=encoding) as f:
        raw = json.load(f)
    return {
        attr: list(itertools.chain.from_iterable(v.split('=') for v in attrval_list))
        for attr, attrval_list in raw.items()
    }

# Attribute -> candidate values, consumed by extract_key_attr.
attr_dict_file = "./data/attr_to_attrvals.json"
attr_dict = load_attr_dict(attr_dict_file)

def extract_key_attr(title, attr, attr_dict):
    """Find the first candidate value of ``attr`` that occurs in ``title``.

    Args:
        title: product title to search.
        attr: attribute name. '图文' is special-cased and always reported as '符合'.
        attr_dict: mapping attribute -> list of candidate value strings
            (as produced by load_attr_dict).

    Returns:
        ``(attr, value)`` for the leftmost match, ``('图文', '符合')`` for the
        '图文' attribute, or ``('N', '')`` when nothing matches.

    Raises:
        KeyError: if ``attr`` is not '图文' and is missing from ``attr_dict``.
    """
    if attr == '图文':
        return '图文', '符合'
    # Escape the values so characters such as '+', '(' or '.' match literally.
    # Drop empty values: an empty alternative matches the empty string and
    # would report a bogus (attr, '') hit.
    alternatives = [re.escape(v) for v in attr_dict[attr] if v]
    if not alternatives:
        return 'N', ''
    hit = re.search('|'.join(alternatives), title)
    if hit:
        return attr, hit.group(0)
    return 'N', ''


def extract_all_key_attr(text):
    """Collect every attribute in ``class_name`` whose value appears in ``text``.

    Returns a dict attribute -> matched value (e.g. {'衣长': '中长款'}), or the
    string '无' when nothing matched. With the notebook's class_name, '图文' is
    always reported by extract_key_attr, so the '无' branch is not reached.
    """
    found = {}
    for attr in class_name:
        hit_attr, value = extract_key_attr(text, attr, attr_dict)
        if hit_attr == 'N':
            continue
        found[hit_attr] = value
    return found if found else '无'

img_name = []
img_features = []
texts = []
querys = []
class_name = ['图文', '版型', '裤型', '袖长', '裙长', '领型', '裤门襟', '鞋帮高度', '穿着方式', '衣长', '闭合方式', '裤长', '类别']

# Each line is one JSON record with 'img_name', 'feature', 'title' and 'query'.
# Read it as UTF-8 explicitly: titles are Chinese and the platform default
# encoding is not always UTF-8.
with open(data_dir, 'r', encoding='utf-8') as f:
    for line in tqdm(f):
        if not line.strip():
            continue  # tolerate blank / trailing lines instead of crashing json.loads
        data = json.loads(line)
        img_features.append(np.array(data['feature']).astype(np.float32))
        img_name.append(data['img_name'])
        texts.append(data['title'])
        querys.append(data['query'])

# Create the frame with named columns directly instead of renaming afterwards.
df = pd.DataFrame({'img_name': img_name})
df['feature'] = img_features
df['text'] = texts
df['querys'] = querys

10000it [00:07, 1264.19it/s]


In [9]:
# Peek at the parsed test frame.
df.head()

Unnamed: 0,img_name,feature,text,querys
0,test004000,"[-0.044185117, 0.0056368606, 0.113686286, -1.9...",蓝色衬衫2021年秋季长袖童装,"[图文, 袖长]"
1,test004001,"[0.09333559, -0.008259959, 0.09762427, -0.2455...",短裤2021年冬季女士休闲裤阔腿裤黑色加厚女装,"[图文, 裤型]"
2,test004002,"[0.29037216, -0.0029620992, 0.020722916, -0.88...",紧身裤短裤男装厚度常规2022年春季运动裤,"[图文, 裤型, 裤长]"
3,test004003,"[-0.24379866, -0.008214505, 0.06657183, -0.950...",灰色厚度常规女装2021年春季七分裤女士休闲裤,"[图文, 裤长]"
4,test004004,"[0.2508675, -0.0061487244, -0.037342805, -0.83...",低帮2021年夏季男士休闲鞋一脚蹬灰色,"[图文, 鞋帮高度]"


In [10]:
@torch.no_grad()
def evaluate(df, CFG, fold):
    """Score ``df`` with the checkpoint of one cross-validation fold.

    Builds a fresh ``Model(CFG)`` and loads
    ``{CFG.out_dir}/model_{CFG.task}_fold{fold}.pth`` with strict key matching.
    It then runs ``df`` through the model in batches of ``CFG.batch_size``.

    Args:
        df: frame with ``text`` and ``feature`` columns (see MyDataset).
        CFG: run configuration; ``device``, ``tokenizer``, ``batch_size``,
            ``out_dir`` and ``task`` are read.
        fold: fold index used in the checkpoint file name.

    Returns:
        float64 array of shape (len(df), 13) with sigmoid probabilities, in the
        row order of ``df`` (the loader is not shuffled).
    """
    device = CFG.device
    loader = DataLoader(MyDataset(df), batch_size=CFG.batch_size, pin_memory=True)
    model = Model(CFG).to(device)
    model.eval()
    ckpt_path = f'{CFG.out_dir}/model_{CFG.task}_fold{fold}.pth'
    # map_location keeps GPU-saved checkpoints loadable when CFG.device is the CPU.
    # The inputs below follow CFG.device too, instead of a hard-coded .cuda().
    model.load_state_dict(torch.load(ckpt_path, map_location=device))

    batch_probs = []
    for text, feature in tqdm(loader):
        encoded = CFG.tokenizer(text, return_tensors='pt', add_special_tokens=True, padding=True)
        inputs = {key: tensor.to(device) for key, tensor in encoded.items()}
        logits = model(inputs, feature.to(device))
        batch_probs.append(torch.sigmoid(logits).cpu().numpy())

    if not batch_probs:
        return np.empty((0, 13))
    # Concatenate once (per-batch np.concatenate was quadratic). Keep the
    # float64 dtype that callers received before.
    return np.concatenate(batch_probs).astype(np.float64)

In [11]:
CFG.folds = 5

# One (len(df), 13) probability matrix per RoBERTa fold; they are averaged in
# the next cell.
pred = [evaluate(df, CFG, fold) for fold in range(CFG.folds)]

Some weights of the model checkpoint at ./prev_trained_model/chinese-roberta-wwm-ext-pretrained/maskedLM-pretraining/checkpoint-21000 were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ./prev_trained_model/chinese-ro

In [12]:
# Average the per-fold probability matrices -> (len(df), 13).
# NOTE(review): not idempotent. Re-running this cell alone averages the rows
# of the already-averaged matrix; re-run the fold loop first.
pred = np.mean(pred, axis=0)

In [13]:
# Inspect the fold-averaged RoBERTa probabilities.
pred

array([[9.95482361e-01, 7.59311122e-10, 2.56077689e-09, ...,
        1.92529722e-10, 2.28087361e-07, 1.48916808e-08],
       [6.55258693e-05, 3.28966524e-08, 9.80447233e-01, ...,
        7.32720515e-07, 2.43781182e-05, 1.03547826e-09],
       [2.01985047e-09, 1.86724972e-07, 5.37951998e-03, ...,
        4.71472686e-08, 1.75374046e-07, 6.55442434e-13],
       ...,
       [9.89723551e-01, 5.19829739e-09, 3.40055176e-09, ...,
        8.32683550e-10, 2.74463315e-06, 1.46830635e-08],
       [9.50243222e-13, 1.83735760e-12, 1.49755734e-09, ...,
        1.64256064e-08, 2.82474913e-09, 8.31762159e-18],
       [9.99960852e-01, 7.60867554e-08, 3.52279700e-07, ...,
        1.52549977e-05, 2.07406439e-06, 9.99997687e-01]])

In [14]:
# Switch the shared config to the NeZha backbone (second ensemble member).
CFG.model_name = 'nezha-cn-base'
CFG.tokenizer_path = './prev_trained_model/nezha-cn-base'
CFG.model_path = './prev_trained_model/nezha-cn-base'

In [15]:
# Reload the (config, model, tokenizer) classes for the NeZha backbone.
# This rebinds the globals `config` and `model_class`, which Model.__init__
# reads, so models built after this cell use NeZha.
config_class, model_class, tokenizer_class = MODEL_CLASSES[CFG.model_name]
    
config = config_class.from_pretrained(CFG.model_path)
tokenizer = tokenizer_class.from_pretrained(CFG.tokenizer_path)
# NOTE(review): `bert_model` is not used by any visible cell but keeps a full
# backbone in memory; consider dropping this load.
bert_model = model_class.from_pretrained(CFG.model_path, config=config)
# Later cells reach the tokenizer through CFG.tokenizer.
CFG.tokenizer = tokenizer

del tokenizer
gc.collect()

Some weights of NeZhaModel were not initialized from the model checkpoint at ./prev_trained_model/nezha-cn-base and are newly initialized: []
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


431

In [16]:
# Point the inference cells at the NeZha fold checkpoints.
task_name = '_pretrained_6000_2_1.5_1.5_5_shuffle_0.3_fold5_修正数据集'
CFG.out_dir = os.path.join('./output', f'nezha-cn-base{task_name}')
device = CFG.device

In [17]:
@torch.no_grad()
def evaluate(df, CFG, fold):
    """Score ``df`` with the checkpoint of one cross-validation fold.

    Builds a fresh ``Model(CFG)`` and loads
    ``{CFG.out_dir}/model_{CFG.task}_fold{fold}.pth`` with ``strict=False``.
    It then runs ``df`` through the model in batches of ``CFG.batch_size``.
    Any key mismatch between checkpoint and model is printed.

    Args:
        df: frame with ``text`` and ``feature`` columns (see MyDataset).
        CFG: run configuration; ``device``, ``tokenizer``, ``batch_size``,
            ``out_dir`` and ``task`` are read.
        fold: fold index used in the checkpoint file name.

    Returns:
        float64 array of shape (len(df), 13) with sigmoid probabilities, in the
        row order of ``df`` (the loader is not shuffled).
    """
    device = CFG.device
    loader = DataLoader(MyDataset(df), batch_size=CFG.batch_size, pin_memory=True)
    model = Model(CFG).to(device)
    model.eval()
    ckpt_path = f'{CFG.out_dir}/model_{CFG.task}_fold{fold}.pth'
    state = torch.load(ckpt_path, map_location='cpu')
    # strict=False is kept from the previous version, but the mismatch is no
    # longer silent: every missing key keeps its pretrained/random init, which
    # quietly degrades predictions. print() is used because the notebook
    # disables the warnings module globally.
    result = model.load_state_dict(state, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(f'[fold {fold}] state_dict mismatch in {ckpt_path}: '
              f'missing={list(result.missing_keys)}, '
              f'unexpected={list(result.unexpected_keys)}')

    batch_probs = []
    for text, feature in tqdm(loader):
        encoded = CFG.tokenizer(text, return_tensors='pt', add_special_tokens=True, padding=True)
        # Follow CFG.device instead of a hard-coded .cuda().
        inputs = {key: tensor.to(device) for key, tensor in encoded.items()}
        logits = model(inputs, feature.to(device))
        batch_probs.append(torch.sigmoid(logits).cpu().numpy())

    if not batch_probs:
        return np.empty((0, 13))
    # Concatenate once (per-batch np.concatenate was quadratic). Keep the
    # float64 dtype that callers received before.
    return np.concatenate(batch_probs).astype(np.float64)

In [18]:
CFG.folds = 5

# One (len(df), 13) probability matrix per NeZha fold; averaged in a later cell.
pred_nezha = [evaluate(df, CFG, fold) for fold in range(CFG.folds)]

Some weights of NeZhaModel were not initialized from the model checkpoint at ./prev_trained_model/nezha-cn-base and are newly initialized: []
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 20/20 [00:16<00:00,  1.18it/s]
Some weights of NeZhaModel were not initialized from the model checkpoint at ./prev_trained_model/nezha-cn-base and are newly initialized: []
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 20/20 [00:08<00:00,  2.30it/s]
Some weights of NeZhaModel were not initialized from the model checkpoint at ./prev_trained_model/nezha-cn-base and are newly initialized: []
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 20/20 [00:11<00:00,  1.69it/s]
Some weights of NeZhaModel were not initialized from the model checkpoint at ./prev_trained_model

In [19]:
# Average the NeZha per-fold probability matrices -> (len(df), 13).
# NOTE(review): not idempotent. Re-running this cell alone averages the rows;
# re-run the fold loop first.
pred_nezha = np.mean(pred_nezha, axis=0)

In [20]:
# Third ensemble member: precomputed scores from pred.csv, with columns taken
# in the same order as class_name. Presumably these are probabilities whose
# rows follow df's order. TODO confirm.
pred_cc = (
    pd.read_csv('./data/pred.csv')
    .loc[:, ['图文', '版型', '裤型', '袖长', '裙长', '领型', '裤门襟', '鞋帮高度', '穿着方式', '衣长', '闭合方式', '裤长', '类别']]
    .to_numpy()
)

In [21]:
# Ensemble: unweighted mean of the RoBERTa, NeZha and pred.csv score matrices.
# Validate shapes first. A member that was not fold-averaged (still a list of
# per-fold arrays) or a CSV with a different row count would otherwise fail
# later with a much less clear error.
# NOTE(review): row *order* cannot be checked here; pred.csv must follow df's order.
_ensemble_members = {'pred': pred, 'pred_nezha': pred_nezha, 'pred_cc': pred_cc}
_expected_shape = (len(df), 13)
for _member_name, _member in _ensemble_members.items():
    if np.shape(_member) != _expected_shape:
        raise ValueError(f'{_member_name} has shape {np.shape(_member)}, expected {_expected_shape}')
pred_last = np.mean(list(_ensemble_members.values()), axis=0)

In [22]:
# df['pred'] = list(torch.sigmoid(torch.from_numpy(pred)).cpu().numpy())

In [23]:
# One object column holding each row's 13-way ensemble score vector.
df['pred'] = list(pred_last)

In [24]:
# Check that every row now carries its ensemble score vector.
df.head()

Unnamed: 0,img_name,feature,text,querys,pred
0,test004000,"[-0.044185117, 0.0056368606, 0.113686286, -1.9...",蓝色衬衫2021年秋季长袖童装,"[图文, 袖长]","[0.9935616453488668, 0.0002019983049626054, 1...."
1,test004001,"[0.09333559, -0.008259959, 0.09762427, -0.2455...",短裤2021年冬季女士休闲裤阔腿裤黑色加厚女装,"[图文, 裤型]","[0.0012159751139506618, 0.00017367854506139088..."
2,test004002,"[0.29037216, -0.0029620992, 0.020722916, -0.88...",紧身裤短裤男装厚度常规2022年春季运动裤,"[图文, 裤型, 裤长]","[3.8624471712694975e-05, 8.689490778788192e-05..."
3,test004003,"[-0.24379866, -0.008214505, 0.06657183, -0.950...",灰色厚度常规女装2021年春季七分裤女士休闲裤,"[图文, 裤长]","[2.5088159061977444e-05, 1.8810140521654298e-0..."
4,test004004,"[0.2508675, -0.0061487244, -0.037342805, -0.83...",低帮2021年夏季男士休闲鞋一脚蹬灰色,"[图文, 鞋帮高度]","[0.003223230822989491, 8.103681704613931e-05, ..."


In [25]:
# Attribute order of the 13 model outputs (same list as in the loading cell),
# plus the candidate values of each attribute.
# NOTE(review): class_dict is not referenced by any visible cell.
class_name=['图文','版型', '裤型', '袖长', '裙长', '领型', '裤门襟', '鞋帮高度', '穿着方式', '衣长', '闭合方式', '裤长', '类别']
class_dict={'图文': ['符合','不符合'], 
            '版型': ['修身型', '宽松型', '标准型'], 
            '裤型': ['微喇裤', '小脚裤', '哈伦裤', '直筒裤', '阔腿裤', '铅笔裤', 'O型裤', '灯笼裤', '锥形裤', '喇叭裤', '工装裤', '背带裤', '紧身裤'],
            '袖长': ['长袖', '短袖', '七分袖', '五分袖', '无袖', '九分袖'], 
            '裙长': ['中长裙', '短裙', '超短裙', '中裙', '长裙'], 
            '领型': ['半高领', '高领', '翻领', 'POLO领', '立领', '连帽', '娃娃领', 'V领', '圆领', '西装领', '荷叶领', '围巾领', '棒球领', '方领', '可脱卸帽', '衬衫领', 'U型领', '堆堆领', '一字领', '亨利领', '斜领', '双层领'], 
            '裤门襟': ['系带', '松紧', '拉链'], 
            '鞋帮高度': ['低帮', '高帮', '中帮'], 
            '穿着方式': ['套头', '开衫'], 
            '衣长': ['常规款', '中长款', '长款', '短款', '超短款', '超长款'], 
            '闭合方式': ['系带', '套脚', '一脚蹬', '松紧带', '魔术贴', '搭扣', '套筒', '拉链'], 
            '裤长': ['九分裤', '长裤', '五分裤', '七分裤', '短裤'], 
            '类别': ['单肩包', '斜挎包', '双肩包', '手提包']
            }

In [26]:
# Re-inspect the frame; the class_name / class_dict cell does not modify df.
df.head()

Unnamed: 0,img_name,feature,text,querys,pred
0,test004000,"[-0.044185117, 0.0056368606, 0.113686286, -1.9...",蓝色衬衫2021年秋季长袖童装,"[图文, 袖长]","[0.9935616453488668, 0.0002019983049626054, 1...."
1,test004001,"[0.09333559, -0.008259959, 0.09762427, -0.2455...",短裤2021年冬季女士休闲裤阔腿裤黑色加厚女装,"[图文, 裤型]","[0.0012159751139506618, 0.00017367854506139088..."
2,test004002,"[0.29037216, -0.0029620992, 0.020722916, -0.88...",紧身裤短裤男装厚度常规2022年春季运动裤,"[图文, 裤型, 裤长]","[3.8624471712694975e-05, 8.689490778788192e-05..."
3,test004003,"[-0.24379866, -0.008214505, 0.06657183, -0.950...",灰色厚度常规女装2021年春季七分裤女士休闲裤,"[图文, 裤长]","[2.5088159061977444e-05, 1.8810140521654298e-0..."
4,test004004,"[0.2508675, -0.0061487244, -0.037342805, -0.83...",低帮2021年夏季男士休闲鞋一脚蹬灰色,"[图文, 鞋帮高度]","[0.003223230822989491, 8.103681704613931e-05, ..."


In [27]:
# Score of the '图文' attribute, i.e. column 0 of each row's score vector.
df['pred_0'] = df['pred'].map(lambda scores: scores[0])

In [28]:
def function(x, threshold=None, classes=None):
    """Turn one row's score vector into 0/1 decisions for its queried attributes.

    Args:
        x: row (or mapping) with ``'querys'`` (list of attribute names) and
            ``'pred'`` (per-attribute scores indexed like ``classes``).
        threshold: decision threshold. Defaults to the notebook-global
            ``Threshold``, read at call time as before.
        classes: attribute order of ``x['pred']``. Defaults to the global
            ``class_name``.

    Returns:
        dict attribute -> 1 if its score is strictly above the threshold,
        else 0, in query order.

    Raises:
        ValueError: if a queried attribute is not in ``classes``.
    """
    if threshold is None:
        threshold = Threshold
    if classes is None:
        classes = class_name
    scores = x['pred']
    return {que: int(scores[classes.index(que)] > threshold) for que in x['querys']}

In [29]:
# Decision threshold read by `function`. The two commented lines are an
# earlier rank-based alternative: take the pred_0 value at a fixed sorted rank.
# Threshold = sorted(list(df['pred_0'].values))[2326]
# print(Threshold)
Threshold = 0.5
# 'match' holds, per row, a dict queried attribute -> 0/1.
df['match'] = df.apply(function, axis=1)

In [30]:
def itm_same_as_tuwen(x):
    """Set every attribute to 1 when the row's '图文' decision is 1.

    Args:
        x: row (or mapping) whose ``'match'`` entry is a dict attribute -> 0/1.

    Returns:
        A new dict with the same keys (and order) as ``x['match']``.
    """
    match = x['match']
    # Look '图文' up front instead of relying on it being the first key. The
    # old loop raised KeyError when another attribute came first or '图文'
    # was absent.
    itm = match.get('图文')
    return {key: (1 if key != '图文' and itm == 1 else value) for key, value in match.items()}

In [31]:
def itm_same_as_key_attr(x):
    """Set '图文' to 0 whenever any entry of the row's match dict is 0.

    Args:
        x: row (or mapping) whose ``'match'`` entry is a dict attribute -> 0/1.

    Returns:
        A copy of ``x['match']`` with '图文' forced to 0 if any value is 0.
        '图文' is added as 0 if it was missing and some value is 0.
    """
    match = x['match']
    result = dict(match)
    # Decide after looking at *all* entries. The old loop overwrote the forced
    # 0 again whenever '图文' came after a 0-valued attribute.
    if any(value == 0 for value in match.values()):
        result['图文'] = 0
    return result

In [32]:
# # df['itm'] = df['match'].apply(lambda x:x['图文'])
# df['match'] = df.apply(itm_same_as_key_attr, axis=1)

In [33]:
# Inspect the per-row 0/1 decisions in the 'match' column.
df.head()

Unnamed: 0,img_name,feature,text,querys,pred,pred_0,match
0,test004000,"[-0.044185117, 0.0056368606, 0.113686286, -1.9...",蓝色衬衫2021年秋季长袖童装,"[图文, 袖长]","[0.9935616453488668, 0.0002019983049626054, 1....",0.993562,"{'图文': 1, '袖长': 1}"
1,test004001,"[0.09333559, -0.008259959, 0.09762427, -0.2455...",短裤2021年冬季女士休闲裤阔腿裤黑色加厚女装,"[图文, 裤型]","[0.0012159751139506618, 0.00017367854506139088...",0.001216,"{'图文': 0, '裤型': 1}"
2,test004002,"[0.29037216, -0.0029620992, 0.020722916, -0.88...",紧身裤短裤男装厚度常规2022年春季运动裤,"[图文, 裤型, 裤长]","[3.8624471712694975e-05, 8.689490778788192e-05...",3.9e-05,"{'图文': 0, '裤型': 0, '裤长': 0}"
3,test004003,"[-0.24379866, -0.008214505, 0.06657183, -0.950...",灰色厚度常规女装2021年春季七分裤女士休闲裤,"[图文, 裤长]","[2.5088159061977444e-05, 1.8810140521654298e-0...",2.5e-05,"{'图文': 0, '裤长': 0}"
4,test004004,"[0.2508675, -0.0061487244, -0.037342805, -0.83...",低帮2021年夏季男士休闲鞋一脚蹬灰色,"[图文, 鞋帮高度]","[0.003223230822989491, 8.103681704613931e-05, ...",0.003223,"{'图文': 0, '鞋帮高度': 1}"


In [38]:
def count_key0_itm1(x):
    """Flag rows whose '图文' decision is 1 while some attribute decision is 0.

    Args:
        x: dict attribute -> 0/1 (one entry of the ``match`` column).

    Returns:
        1 if ``x['图文'] == 1`` and at least one other value is 0, else 0.
    """
    # Read '图文' directly. The old loop used an unbound `flag`
    # (UnboundLocalError) when a 0-valued attribute came before '图文' or
    # '图文' was missing.
    if x.get('图文') != 1:
        return 0
    return int(any(value == 0 for key, value in x.items() if key != '图文'))

In [39]:
# 1 for rows whose '图文' decision conflicts with a 0-valued attribute.
df['key0_itm1'] = df['match'].apply(count_key0_itm1)

In [40]:
# Number of rows flagged by count_key0_itm1 (1) vs. not flagged (0).
df['key0_itm1'].value_counts()

0    9967
1      33
Name: key0_itm1, dtype: int64

In [43]:
def image_item(x):
    """Make attribute decisions consistent with a positive '图文' decision.

    When ``x['图文']`` is 1, every other attribute decided as 0 is flipped to 1.
    Otherwise all values are copied unchanged. Key order is preserved.

    Args:
        x: dict attribute -> 0/1 (one entry of the ``match`` column).

    Returns:
        A new dict with the same keys as ``x``.
    """
    # Read '图文' up front. The old loop looked it up in the partially built
    # result, which raised KeyError whenever another attribute came first or
    # '图文' was absent.
    itm_matched = x.get('图文') == 1
    return {
        key: 1 if (key != '图文' and value == 0 and itm_matched) else value
        for key, value in x.items()
    }

In [44]:
# Final per-row decisions: attributes forced to 1 where '图文' is 1.
df['match1'] = df['match'].apply(image_item)

In [47]:
# df[df.key0_itm1 == 1]

In [48]:
# One JSON line per test image: {"img_name": ..., "match": {attribute: 0/1}}.
submit = [
    json.dumps({"img_name": name, "match": decisions}, ensure_ascii=False) + '\n'
    for name, decisions in zip(df['img_name'], df['match1'])
]

In [49]:
# Write the submission explicitly as UTF-8. The lines contain raw Chinese
# (ensure_ascii=False), and the platform default encoding is not always UTF-8.
os.makedirs(CFG.out_dir, exist_ok=True)
with open(os.path.join(CFG.out_dir, 'finally_sub_testb_tuwen_dai_guanjian.txt'), 'w', encoding='utf-8') as f:
    f.writelines(submit)

In [34]:
# Scratch: 4000 * (1 - 0.4186) ≈ 2326, presumably the rank used by the
# commented-out rank-based threshold. TODO confirm, then remove this cell.
4000*(1-0.4186)

2325.5999999999995

In [35]:
# Directory the submission file was written to.
CFG.out_dir

'./output/nezha-cn-base_pretrained_6000_2_1.5_1.5_5_shuffle_0.3_fold5_修正数据集'