In [2]:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = "1"
import pandas as pd
import numpy as np
import random
from tqdm.auto import tqdm
import time
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, multilabel_confusion_matrix


BERT_MODEL_NAME = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# torch.cuda.set_device()                            #　指定gpu1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


EPOCHS = 1
MAX_TOKEN_COUNT = 128
BATCH_SIZE = 32

In [3]:
from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScaler as GradScaler

In [4]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
# 设置随机数种子
setup_seed(20)

In [5]:
%%time 
train_df=pd.read_feather("./autodl-nas/USPTO-2M_Training.feather")
val_df=pd.read_feather("./autodl-nas/USPTO-2M_Validation.feather")

LABEL_COLUMNS=train_df.columns[11:]
LABEL_COLUMNS

CPU times: user 37.9 s, sys: 2min 51s, total: 3min 29s
Wall time: 58.7 s


Index(['A41D', 'A62B', 'A41B', 'D06N', 'A42B', 'A43B', 'D06B', 'A41F', 'E03D',
       'A47K',
       ...
       'Y02D', 'F24V', 'H04T', 'G16B', 'G16C', 'G16Z', 'G21J', 'G16Y', 'G06J',
       'E99Z'],
      dtype='object', length=664)

In [6]:
class PatentDataset(Dataset):
    """
    Pass pandas dataframe, and tokeizer along with the max token length[128 default]
    
    Example: 
    -------
    train_dataset = ToxicCommentsDataset(
      train_df,
      tokenizer,
      max_token_len=MAX_TOKEN_COUNT
    )

    sample_item = train_dataset[0]
    
    """
    
    
    def __init__(
        self,
        data: pd.DataFrame,
        tokenizer: BertTokenizer,
        max_token_len: int = 512,
        test= False
    ):
        self.data = data
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len
        self.test = test
        
    
    def __len__(self):
        return len(self.data)
    
    
    def __getitem__(self, index: int):
        data_row = self.data.iloc[index]
        comment_text = data_row.abstract

        if not self.test:
            labels = data_row[LABEL_COLUMNS]
        
        encoding = self.tokenizer.encode_plus(
            comment_text,
            max_length=self.max_token_len,
            padding="max_length",
            truncation=True,
            add_special_tokens=True, # [CLS] & [SEP]
            return_token_type_ids=False,
            return_attention_mask=True, #attention_mask
            return_tensors='pt',
        )
        
        if not self.test:
            return dict(
            comment_text=comment_text,
            input_ids = encoding["input_ids"].flatten(),
            attention_mask=encoding["attention_mask"].flatten(),
            labels=torch.FloatTensor(labels)
        )
        else:
            return dict(
                comment_text=comment_text,
                input_ids = encoding["input_ids"].flatten(),
                attention_mask=encoding["attention_mask"].flatten()
            )
        

In [7]:
train_dataset = PatentDataset(
  train_df,
  tokenizer,
  max_token_len=MAX_TOKEN_COUNT
)

val_dataset = PatentDataset(
  val_df,
  tokenizer,
  max_token_len=MAX_TOKEN_COUNT
)

In [8]:
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,drop_last = True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last = True)

In [9]:
class PatentTagger(nn.Module):

    def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
        super().__init__()
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True) #load the pretrained bert model
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes) # add a linear layer to the bert
        self.n_training_steps = n_training_steps
        self.n_warmup_steps = n_warmup_steps
        self.criterion = nn.BCEWithLogitsLoss()
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        
    def forward(self, input_ids, attention_mask, labels=None):
        output = self.bert(input_ids, attention_mask=attention_mask)
        output = self.classifier(self.dropout(output.last_hidden_state[:,0])) 
        loss = 0
        if labels is not None:
            loss = self.criterion(output, labels)
        return loss, output

In [10]:
model = PatentTagger(len(LABEL_COLUMNS)).to(device)

N_EPOCHS = 1

steps_per_epoch=len(train_df) // BATCH_SIZE
total_training_steps = steps_per_epoch * N_EPOCHS
warmup_steps = total_training_steps // 10
warmup_steps, total_training_steps

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


(6376, 63760)

In [11]:
optimizer = AdamW(model.parameters(), lr=5e-5)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_training_steps
)



In [12]:
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_score

# function for evaluating the model
def evaluate(mydataloader):

    print("\nEvaluating...")
    #t0 = time.time()
    # deactivate dropout layers
    model.eval()

    total_loss, total_accuracy = 0, 0

    # empty list to save the model predictions
    total_preds = []
    total_labels = []

    # iterate over batches
    for step,batch in tqdm(enumerate(mydataloader),total=len(mydataloader)):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)   

        with autocast():
            with torch.no_grad():
                loss, outputs = model(input_ids, attention_mask, labels)
                outputs = torch.sigmoid(outputs)
                
        total_loss = total_loss + loss.float().item()
        outputs = outputs.detach().float().cpu().numpy()
        labels = labels.detach().float().cpu().numpy()
        total_preds.append(outputs)
        total_labels.append(labels)

    avg_loss = total_loss / len(mydataloader)

    total_preds  = np.concatenate(total_preds, axis=0)
    total_labels = np.concatenate(total_labels, axis=0)

    print(f"Evaluate loss {total_loss / len(mydataloader)}")
    model.train()
    return avg_loss, total_preds, total_labels

In [13]:
# function to train the model
def train():
    
    now=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    best_valid_loss = float('inf')
    model.train()

    total_loss, total_accuracy = 0, 0
    avg_loss = 0
    scaler = GradScaler()

    # iterate over batches
    for step,batch in tqdm(enumerate(train_dataloader),total=len(train_dataloader),desc="Train"):
        
        if step%5000 == 0 and step!=0:
            valid_loss,_,_ = evaluate(val_dataloader)
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(model.state_dict(), f"./model/Classfication_abstract_model{now}.pt")

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)     

        model.zero_grad() 
        with autocast():
            loss, _ = model(input_ids, attention_mask, labels)
            
        scaler.scale(loss).backward()
        
        if step%200 == 0 :
            print(f"step: {step} loss: {loss}")
        # add on to the total loss
        total_loss = total_loss + loss.float().item()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scheduler.step()
        scaler.update()

    # compute the training loss of the epoch
    avg_loss = total_loss / len(train_dataloader)

    return avg_loss

In [14]:
import time
from tqdm.auto import tqdm

now=time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
train_loss = train()

Train:   0%|          | 0/63760 [00:00<?, ?it/s]

step: 0 loss: 0.013012336567044258
step: 200 loss: 0.011891783215105534
step: 400 loss: 0.009343056008219719
step: 600 loss: 0.009999590925872326
step: 800 loss: 0.011931704357266426
step: 1000 loss: 0.01122819259762764
step: 1200 loss: 0.009375101886689663
step: 1400 loss: 0.00942588783800602
step: 1600 loss: 0.009285036474466324
step: 1800 loss: 0.010173646733164787
step: 2000 loss: 0.009351005777716637
step: 2200 loss: 0.007503061089664698
step: 2400 loss: 0.006873880047351122
step: 2600 loss: 0.0077779246494174
step: 2800 loss: 0.006556471809744835
step: 3000 loss: 0.008434687741100788
step: 3200 loss: 0.006792544387280941
step: 3400 loss: 0.007407904136925936
step: 3600 loss: 0.008498935960233212
step: 3800 loss: 0.0066587477922439575
step: 4000 loss: 0.00875947531312704
step: 4200 loss: 0.007214851677417755
step: 4400 loss: 0.008076674304902554
step: 4600 loss: 0.006552125792950392
step: 4800 loss: 0.005467524752020836

Evaluating...


  0%|          | 0/1634 [00:00<?, ?it/s]

Evaluate loss 0.006815069757323288
step: 5000 loss: 0.008045199327170849
step: 5200 loss: 0.0067377216182649136
step: 5400 loss: 0.006507044658064842
step: 5600 loss: 0.008261529728770256
step: 5800 loss: 0.005452746991068125
step: 6000 loss: 0.007953728549182415
step: 6200 loss: 0.005271097645163536
step: 6400 loss: 0.006013098638504744
step: 6600 loss: 0.007203992921859026
step: 6800 loss: 0.005347018130123615
step: 7000 loss: 0.006658430211246014
step: 7200 loss: 0.00477381469681859
step: 7400 loss: 0.00764570664614439
step: 7600 loss: 0.006752817425876856
step: 7800 loss: 0.006136501207947731
step: 8000 loss: 0.00664351275190711
step: 8200 loss: 0.007451686076819897
step: 8400 loss: 0.005945117678493261
step: 8600 loss: 0.005248998291790485
step: 8800 loss: 0.0047660102136433125
step: 9000 loss: 0.005224619060754776
step: 9200 loss: 0.008080413565039635
step: 9400 loss: 0.006035952363163233
step: 9600 loss: 0.005716169252991676
step: 9800 loss: 0.005177508573979139

Evaluating...


  0%|          | 0/1634 [00:00<?, ?it/s]

Evaluate loss 0.005953713519825866
step: 10000 loss: 0.007268458139151335
step: 10200 loss: 0.0066675590351223946
step: 10400 loss: 0.005758641287684441
step: 10600 loss: 0.007122973911464214
step: 10800 loss: 0.0056346808560192585
step: 11000 loss: 0.005984003655612469
step: 11200 loss: 0.005813052412122488
step: 11400 loss: 0.006619095336645842
step: 11600 loss: 0.0058076344430446625
step: 11800 loss: 0.005591823719441891
step: 12000 loss: 0.005130874924361706
step: 12200 loss: 0.004638967104256153
step: 12400 loss: 0.005352185107767582
step: 12600 loss: 0.004926243796944618
step: 12800 loss: 0.0036805050913244486
step: 13000 loss: 0.006126330234110355
step: 13200 loss: 0.003931519575417042
step: 13400 loss: 0.006870080716907978
step: 13600 loss: 0.005966171156615019
step: 13800 loss: 0.005540113430470228
step: 14000 loss: 0.00439738342538476
step: 14200 loss: 0.007802925538271666
step: 14400 loss: 0.004566485062241554
step: 14600 loss: 0.007029811851680279
step: 14800 loss: 0.006176

  0%|          | 0/1634 [00:00<?, ?it/s]

Evaluate loss 0.005603812352242347
step: 15000 loss: 0.004332813899964094
step: 15200 loss: 0.006021414417773485
step: 15400 loss: 0.007142563816159964
step: 15600 loss: 0.006049643736332655
step: 15800 loss: 0.0063386945985257626
step: 16000 loss: 0.008851079270243645
step: 16200 loss: 0.0046615912579
step: 16400 loss: 0.0046090297400951385
step: 16600 loss: 0.007102504372596741
step: 16800 loss: 0.0035947379656136036
step: 17000 loss: 0.006364994682371616
step: 17200 loss: 0.005194767378270626
step: 17400 loss: 0.0051585836336016655
step: 17600 loss: 0.004730544053018093
step: 17800 loss: 0.004860556684434414
step: 18000 loss: 0.00579141266644001
step: 18200 loss: 0.0035311204846948385
step: 18400 loss: 0.0062245214357972145
step: 18600 loss: 0.003741051536053419
step: 18800 loss: 0.0049375686794519424
step: 19000 loss: 0.00640818290412426
step: 19200 loss: 0.005645744036883116
step: 19400 loss: 0.005596559029072523
step: 19600 loss: 0.0067101591266691685
step: 19800 loss: 0.00567932

  0%|          | 0/1634 [00:00<?, ?it/s]

Evaluate loss 0.005389747303748684
step: 20000 loss: 0.006330818869173527
step: 20200 loss: 0.004385828971862793
step: 20400 loss: 0.007502546068280935
step: 20600 loss: 0.0041683632880449295
step: 20800 loss: 0.0033431784249842167
step: 21000 loss: 0.004914254881441593
step: 21200 loss: 0.004825596231967211
step: 21400 loss: 0.0045168763026595116
step: 21600 loss: 0.005233236588537693
step: 21800 loss: 0.005451945587992668
step: 22000 loss: 0.0057573216035962105
step: 22200 loss: 0.005049029365181923
step: 22400 loss: 0.004306865390390158
step: 22600 loss: 0.0040792734362185
step: 22800 loss: 0.004799045622348785
step: 23000 loss: 0.005034009926021099
step: 23200 loss: 0.006072778720408678
step: 23400 loss: 0.004694879520684481
step: 23600 loss: 0.005839328747242689
step: 23800 loss: 0.004436340648680925
step: 24000 loss: 0.004902059677988291
step: 24200 loss: 0.006690016482025385
step: 24400 loss: 0.0039156232960522175
step: 24600 loss: 0.004624576773494482
step: 24800 loss: 0.004340

  0%|          | 0/1634 [00:00<?, ?it/s]

Evaluate loss 0.005251171837807057
step: 25000 loss: 0.00492758909240365
step: 25200 loss: 0.0044488622806966305
step: 25400 loss: 0.006442800164222717
step: 25600 loss: 0.005363112781196833
step: 25800 loss: 0.005274486728012562
step: 26000 loss: 0.005050510633736849
step: 26200 loss: 0.007096807472407818
step: 26400 loss: 0.005813184659928083
step: 26600 loss: 0.005261649377644062
step: 26800 loss: 0.006126993801444769
step: 27000 loss: 0.006089768838137388
step: 27200 loss: 0.006043876986950636
step: 27400 loss: 0.007188380230218172
step: 27600 loss: 0.004124113358557224
step: 27800 loss: 0.004456973634660244
step: 28000 loss: 0.004427823703736067
step: 28200 loss: 0.005609849467873573
step: 28400 loss: 0.005571182817220688
step: 28600 loss: 0.0059096794575452805
step: 28800 loss: 0.003748571965843439
step: 29000 loss: 0.005998808890581131
step: 29200 loss: 0.004879074636846781
step: 29400 loss: 0.005019654519855976
step: 29600 loss: 0.006490442901849747
step: 29800 loss: 0.00506683

  0%|          | 0/1634 [00:00<?, ?it/s]

Evaluate loss 0.0051149421430225946
step: 30000 loss: 0.004640503786504269
step: 30200 loss: 0.004172858782112598
step: 30400 loss: 0.00574825145304203
step: 30600 loss: 0.005326497834175825
step: 30800 loss: 0.006814407650381327
step: 31000 loss: 0.005302960518747568
step: 31200 loss: 0.004806051030755043
step: 31400 loss: 0.004399290308356285
step: 31600 loss: 0.0050444286316633224
step: 31800 loss: 0.004456539172679186
step: 32000 loss: 0.005154454614967108
step: 32200 loss: 0.005295547656714916
step: 32400 loss: 0.0042585572227835655
step: 32600 loss: 0.003636320121586323
step: 32800 loss: 0.005324679426848888
step: 33000 loss: 0.0051211402751505375
step: 33200 loss: 0.004130370449274778
step: 33400 loss: 0.005866975523531437
step: 33600 loss: 0.0037599392235279083
step: 33800 loss: 0.005915570538491011
step: 34000 loss: 0.00476113660261035
step: 34200 loss: 0.003948139492422342
step: 34400 loss: 0.005471607204526663
step: 34600 loss: 0.004420747514814138
step: 34800 loss: 0.003851

  0%|          | 0/1634 [00:00<?, ?it/s]

Evaluate loss 0.005039491854560182
step: 35000 loss: 0.00545640429481864
step: 35200 loss: 0.005337847862392664
step: 35400 loss: 0.0046603321097791195
step: 35600 loss: 0.006296766456216574
step: 35800 loss: 0.0037575361784547567
step: 36000 loss: 0.004595241975039244
step: 36200 loss: 0.0038945460692048073
step: 36400 loss: 0.005975418724119663
step: 36600 loss: 0.005524842534214258
step: 36800 loss: 0.005490224342793226
step: 37000 loss: 0.004979349672794342
step: 37200 loss: 0.0043941340409219265
step: 37400 loss: 0.0036892457865178585
step: 37600 loss: 0.00527940271422267
step: 37800 loss: 0.005663914140313864
step: 38000 loss: 0.0066391220316290855
step: 38200 loss: 0.003922826610505581
step: 38400 loss: 0.00482386676594615
step: 38600 loss: 0.004781036172062159
step: 38800 loss: 0.005157562904059887
step: 39000 loss: 0.0039750817231833935
step: 39200 loss: 0.004949135705828667
step: 39400 loss: 0.006349786184728146
step: 39600 loss: 0.004510027356445789
step: 39800 loss: 0.00511

  0%|          | 0/1634 [00:00<?, ?it/s]

Evaluate loss 0.004956234713436094
step: 40000 loss: 0.005112936720252037
step: 40200 loss: 0.005008650477975607
step: 40400 loss: 0.004512897226959467
step: 40600 loss: 0.005562123376876116
step: 40800 loss: 0.005592006258666515
step: 41000 loss: 0.00521506555378437
step: 41200 loss: 0.005838092882186174
step: 41400 loss: 0.002935917815193534
step: 41600 loss: 0.005793129559606314
step: 41800 loss: 0.004956494551151991
step: 42000 loss: 0.004780573304742575
step: 42200 loss: 0.0054644192568957806
step: 42400 loss: 0.004579138942062855
step: 42600 loss: 0.003436122788116336
step: 42800 loss: 0.00514821894466877
step: 43000 loss: 0.00443305354565382
step: 43200 loss: 0.004611053038388491
step: 43400 loss: 0.0035223786253482103
step: 43600 loss: 0.00443147961050272
step: 43800 loss: 0.003545920131728053
step: 44000 loss: 0.005077302921563387
step: 44200 loss: 0.005531689152121544
step: 44400 loss: 0.004457431845366955
step: 44600 loss: 0.004503726493567228
step: 44800 loss: 0.00464484049

  0%|          | 0/1634 [00:00<?, ?it/s]

Evaluate loss 0.00488225732501961
step: 45000 loss: 0.005269747693091631
step: 45200 loss: 0.004355178214609623
step: 45400 loss: 0.006859550718218088
step: 45600 loss: 0.006239450071007013
step: 45800 loss: 0.005122290924191475
step: 46000 loss: 0.005141776520758867
step: 46200 loss: 0.004941432736814022
step: 46400 loss: 0.004204892087727785
step: 46600 loss: 0.005522429943084717
step: 46800 loss: 0.00398195581510663
step: 47000 loss: 0.0048233442939817905
step: 47200 loss: 0.006000581197440624
step: 47400 loss: 0.004472291562706232
step: 47600 loss: 0.004456717986613512
step: 47800 loss: 0.005124782212078571
step: 48000 loss: 0.00549608888104558
step: 48200 loss: 0.006880708038806915
step: 48400 loss: 0.005162759684026241
step: 48600 loss: 0.004704514052718878
step: 48800 loss: 0.006212571635842323
step: 49000 loss: 0.005357827991247177
step: 49200 loss: 0.005283455364406109
step: 49400 loss: 0.0057135834358632565
step: 49600 loss: 0.0058306255377829075
step: 49800 loss: 0.004563144

  0%|          | 0/1634 [00:00<?, ?it/s]

Evaluate loss 0.004831094429155668
step: 50000 loss: 0.004313376732170582
step: 50200 loss: 0.004688977729529142
step: 50400 loss: 0.005839629098773003
step: 50600 loss: 0.004835675936192274
step: 50800 loss: 0.005160741042345762
step: 51000 loss: 0.004102484788745642
step: 51200 loss: 0.005151880439370871
step: 51400 loss: 0.004015315789729357
step: 51600 loss: 0.0031323526054620743
step: 51800 loss: 0.003652641549706459
step: 52000 loss: 0.005571655463427305
step: 52200 loss: 0.004910406656563282
step: 52400 loss: 0.0049752127379179
step: 52600 loss: 0.004263041540980339
step: 52800 loss: 0.005450797732919455
step: 53000 loss: 0.0041834949515759945
step: 53200 loss: 0.0061322348192334175
step: 53400 loss: 0.004951458889991045
step: 53600 loss: 0.004839623346924782
step: 53800 loss: 0.004623217508196831
step: 54000 loss: 0.007104769814759493
step: 54200 loss: 0.003943278919905424
step: 54400 loss: 0.005263890605419874
step: 54600 loss: 0.0047859870828688145
step: 54800 loss: 0.0054214

  0%|          | 0/1634 [00:00<?, ?it/s]

Evaluate loss 0.004786799380395019
step: 55000 loss: 0.004281916189938784
step: 55200 loss: 0.004682584665715694
step: 55400 loss: 0.005313456058502197
step: 55600 loss: 0.00568461325019598
step: 55800 loss: 0.004198024515062571
step: 56000 loss: 0.005564683582633734
step: 56200 loss: 0.004909052513539791
step: 56400 loss: 0.004340255167335272
step: 56600 loss: 0.005233125761151314
step: 56800 loss: 0.005554497241973877
step: 57000 loss: 0.004464512690901756
step: 57200 loss: 0.004200706258416176
step: 57400 loss: 0.00333963381126523
step: 57600 loss: 0.004042338114231825
step: 57800 loss: 0.0035102758556604385
step: 58000 loss: 0.005341269541531801
step: 58200 loss: 0.0037936479784548283
step: 58400 loss: 0.00474276440218091
step: 58600 loss: 0.005074033979326487
step: 58800 loss: 0.003639831906184554
step: 59000 loss: 0.004549422767013311
step: 59200 loss: 0.005566252861171961
step: 59400 loss: 0.004764218349009752
step: 59600 loss: 0.0042147389613091946
step: 59800 loss: 0.005770911

  0%|          | 0/1634 [00:00<?, ?it/s]

Evaluate loss 0.004777440300121256
step: 60000 loss: 0.006069455295801163
step: 60200 loss: 0.00526946596801281
step: 60400 loss: 0.00422582495957613
step: 60600 loss: 0.005483999382704496
step: 60800 loss: 0.0050404430367052555
step: 61000 loss: 0.004714380484074354
step: 61200 loss: 0.007539224810898304
step: 61400 loss: 0.004355337470769882
step: 61600 loss: 0.004614986479282379
step: 61800 loss: 0.003914312459528446
step: 62000 loss: 0.005472263321280479
step: 62200 loss: 0.0037036275025457144
step: 62400 loss: 0.005551773123443127
step: 62600 loss: 0.005782519467175007
step: 62800 loss: 0.003914413973689079
step: 63000 loss: 0.005111378617584705
step: 63200 loss: 0.00584035087376833
step: 63400 loss: 0.005093981046229601
step: 63600 loss: 0.0047060418874025345


## 测试

In [14]:
model.load_state_dict(torch.load("./model/Classfication_abstract_model2022-04-23 01:39:03.pt"))

<All keys matched successfully>

In [1]:
valid_loss,total_preds,total_labels = evaluate(val_dataloader)

NameError: name 'evaluate' is not defined

In [15]:
from sklearn.metrics import accuracy_score, roc_curve, auc,precision_score
import matplotlib.pyplot as plt

def evaluate_roc(probs, y_true):
    """
    - Print AUC and accuracy on the test set
    - Plot ROC
    @params    probs (np.array): an array of predicted probabilities with shape (len(y_true), 2)
    @params    y_true (np.array): an array of the true values with shape (len(y_true),)
    """
    preds = probs#[:, 1]
    fpr, tpr, threshold = roc_curve(y_true, preds)
    roc_auc = auc(fpr, tpr)
    print(f'AUC: {roc_auc:.4f}')
       
    # Get accuracy over the test set
    y_pred = np.where(preds >= 0.5, 1, 0)
    accuracy = accuracy_score(y_true, y_pred)
    print(f'Accuracy: {accuracy*100:.2f}%')
    
    precision = precision_score(y_true,y_pred)
    print(f'Precision: {precision*100:.2f}%')
    # Plot ROC AUC
    plt.title('Receiver Operating Characteristic')
    plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
    plt.legend(loc = 'lower right')
    plt.plot([0, 1], [0, 1],'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()

In [16]:
from sklearn.metrics import roc_auc_score
# function for evaluating the model
def evaluate(mydataloader):

    print("\nEvaluating...")
    #t0 = time.time()
    # deactivate dropout layers
    model.eval()

    total_loss, total_accuracy = 0, 0

    # empty list to save the model predictions
    total_preds = []
    total_labels = []

    # iterate over batches
    for step,batch in tqdm(enumerate(mydataloader),total=len(mydataloader)):

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)   
        # deactivate autograd
        with torch.no_grad():
            
            loss, outputs = model(input_ids, attention_mask, labels)

            total_loss = total_loss + loss.item()

            outputs = outputs.detach().cpu().numpy()
            labels = labels.detach().cpu().numpy()
            total_preds.append(outputs)
            total_labels.append(labels)


    # compute the validation loss of the epoch
    avg_loss = total_loss / len(val_dataloader)
    print(f"{step}: {avg_loss}")



    # reshape the predictions in form of (number of samples, no. of classes)
    total_preds  = np.concatenate(total_preds, axis=0)
    total_labels = np.concatenate(total_labels, axis=0)
    true = np.array(total_labels)
    pred = np.array(total_preds>0.5)
    #print(true)
    #print(pred)
    for i, name in enumerate(LABEL_COLUMNS):
        try:
            print(f"{name} roc_auc {roc_auc_score(true[:, i], pred[:, i])}")
        except Exception as e:
            print(e)
            pass
    print(f"Evaluate loss {total_loss / len(val_dataloader)}")
    
    
    total_patent=0
    acc_patent=0
    for pp,pr in zip(pred,true):
        total_patent+=1
        if all(pp==pr):
            acc_patent+=1
    print(f"Predict accuracy num: {acc_patent},total Patent num: {total_patent}, Accuracy: {acc_patent/total_patent*100:.2f}%")
    
    return avg_loss, total_preds, total_labels

In [19]:
import numpy as np
import pandas as pd
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score, roc_auc_score)


true = np.array(total_labels)
pred = np.array(total_preds>0.5)

dic = {
    "Accuracy" : accuracy_score(true,pred),
    "Precision-micro" : precision_score(true,pred,average='micro'),
    "Precision-macro" : precision_score(true,pred,average='macro'),
    "recall-micro" : recall_score(true,pred,average='micro'),
    "recall-macro" : recall_score(true,pred,average='macro'),
    "f1_micro" : f1_score(true,pred,average='micro'),
    "f1-macro" : f1_score(true,pred,average='macro')
}

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


In [20]:
# Test scores 
# learning rate 5e-5
# batch_size 32
dic

{'Accuracy': 0.3850360576923077,
 'Precision-micro': 0.8007889877455094,
 'Precision-macro': 0.48546576079778736,
 'recall-micro': 0.5182517436933707,
 'recall-macro': 0.2494437992538163,
 'f1_micro': 0.6292607640354579,
 'f1-macro': 0.31188530500687855}

In [16]:
# Validation scores
# learning_rate 5e-5
dic

{'Accuracy': 0.3974908200734394,
 'Precision-micro': 0.8567705276010238,
 'Precision-macro': 0.3543756250077216,
 'recall-micro': 0.42521033863938285,
 'recall-macro': 0.1348312852097599,
 'f1_micro': 0.5683512067475059,
 'f1-macro': 0.18028400221502028}

In [22]:
# Test 2021-A scores
# learning_rate 5e-5
dic

{'Accuracy': 0.34275610207100593,
 'Precision-micro': 0.8686674889223882,
 'Precision-macro': 0.37088179823084866,
 'recall-micro': 0.3982942216338944,
 'recall-macro': 0.1260824985554961,
 'f1_micro': 0.5461652684154142,
 'f1-macro': 0.17231818742954721}

In [26]:
# Test 2021-A scores
# learning_rate 5e-5
dic

{'Accuracy': 0.3426790557199211,
 'Precision-micro': 0.8687377221435424,
 'Precision-macro': 0.3707975747138592,
 'recall-micro': 0.3982917010518731,
 'recall-macro': 0.1260669178061862,
 'f1_micro': 0.5461767797749323,
 'f1-macro': 0.17230378630870097}

In [32]:
# Test 2021-A-50000 scores
# learning_rate 5e-5
dic

{'Accuracy': 0.3416666666666667,
 'Precision-micro': 0.8671591098021518,
 'Precision-macro': 0.3461082396608983,
 'recall-micro': 0.3967615735709628,
 'recall-macro': 0.12544875490870652,
 'f1_micro': 0.5444256391521215,
 'f1-macro': 0.17069473713464509}

In [21]:
torch.save(model.state_dict(), f"./baseline-Classification-USPTO-2M{now}.pt")

In [15]:
# Test 2021-A
test_df = pd.read_feather("./autodl-nas/2021-sample-50000.feather")

In [16]:
test_dataset = PatentDataset(
  test_df,
  tokenizer,
  max_token_len=128
)

In [17]:
test_dataloader = DataLoader(test_dataset, batch_size=8*BATCH_SIZE, shuffle=True,drop_last = True)

In [18]:
avg_loss, total_preds, total_labels = evaluate(test_dataloader)


Evaluating...


  0%|          | 0/195 [00:00<?, ?it/s]

Evaluate loss 0.005401214983505316


In [30]:
avg_loss, total_preds, total_labels = evaluate(test_dataloader)


Evaluating...


HBox(children=(FloatProgress(value=0.0, max=195.0), HTML(value='')))


Evaluate loss 0.005900688882535084
