In [1]:
import pandas as pd
import numpy as np
from tqdm.auto import tqdm

import torch
from torch import nn

from transformers import BertTokenizer
from transformers import BertModel
from datasets import Dataset
from transformers import AutoTokenizer, AutoModel, Trainer, TrainingArguments, AutoModelForSequenceClassification
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from catalyst import dl
from catalyst import dl, utils

import warnings
warnings.simplefilter('ignore')
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn import metrics
import transformers
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import DistilBertTokenizer, DistilBertModel
import logging

In [2]:
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    )

In [3]:
from util import *

In [4]:
training_set = pd.read_json("training_set.json.gz", lines=True, orient="records")
testing_set = pd.read_json("testing_set.json.gz", lines=True, orient="records")

In [5]:
subset = list(sorted(set(all_tiers_100)-set(["PersonalizedProduct"])))

In [6]:
model_name = "/home/martin/IdeaProjects/phenetics/bertForPatents/"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModel.from_pretrained(model_name, gradient_checkpointing=True)

In [7]:
training_set['labels']=training_set[subset].astype(int).values.tolist()
testing_set['labels']=testing_set[subset].astype(int).values.tolist()

In [8]:
nice_subset = [tier_translations[x] for x in subset]
nice_subset

['Analysis and Modeling',
 'Analysis and Modeling: 3D Modeling',
 'Anatomical Target',
 'Anatomical Target: Lower Extremity',
 'Anatomical Target: Lower Extremity - Hip',
 'Anatomical Target: Lower Extremity - Knee',
 'Anatomical Target: Torso',
 'Anatomical Target: Torso - Spine',
 'Anatomical Target: Upper Extremity',
 'Anatomical Target: Upper Extremity - Shoulder',
 'Imaging',
 'Imaging: CT',
 'Imaging: MRI',
 'Imaging: Ultrasound',
 'Manufacturing',
 'Manufacturing: Additive Manufacturing',
 'Personalized Product: Guide or Jig',
 'Personalized Product: Implant',
 'Specification of Use',
 'Specification of Use: Disease',
 'Specification of Use: Joint Replacement',
 'Surgical Method']

In [9]:
cpc_embeddings = np.fromfile("/home/martin/patentmark/cpc.node2vec.emb.32d.bin", dtype=np.float32).reshape((-1,32))

import joblib
cpc_labelizer = joblib.load('./node2id.joblib')
cpc_lookup = {c: n for n, c in enumerate(cpc_labelizer.classes_)}

@f.collecting
def convert_cpc_codes(codes):
    for code in codes:
        if code in cpc_lookup:
            yield cpc_lookup[code]
    
def embed_cpc_codes(codes):
    embedding = np.zeros(32)
    converted = convert_cpc_codes(codes)
    
    if not converted:
        return embedding
    
    for code_id in converted:
        embedding = embedding + cpc_embeddings[code_id]
        
    return embedding / len(converted)

training_set['embedded_cpc'] = training_set.cpc_codes.apply(embed_cpc_codes)
training_set.embedded_cpc

testing_set['embedded_cpc'] = testing_set.cpc_codes.apply(embed_cpc_codes)
testing_set.embedded_cpc

0      [0.09129103335241477, -0.8074875394503276, -0....
1      [-0.0626441298850945, -0.8264780470303127, -0....
2      [-0.2087969978650411, -0.8326806823412577, -0....
3      [0.020394775830209256, -0.8215901732444764, -0...
4      [-0.26043402403593063, -0.6891247034072876, -0...
                             ...                        
238    [-0.23802674313386282, -0.628900408744812, -0....
239    [-0.3754243354002635, -0.6894144614537557, -0....
240    [-0.12913421913981438, -0.6960149183869362, -0...
241    [-0.3880331997688, -0.702021429171929, -0.2717...
242    [-0.2977850042283535, -0.7015813589096069, -0....
Name: embedded_cpc, Length: 243, dtype: object

In [10]:
MAX_LEN_CLAIMS = 512
MAX_LEN_ABSTRACT = 160
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 4
EPOCHS = 100
LEARNING_RATE = 1e-5
SEED = 17
PRED_THRES = 0.4
ACCUM_STEPS = 8

In [11]:
from sklearn.feature_extraction.text import CountVectorizer
#cpc_coder = CountVectorizer(analyzer=cpc_split, min_df=5)

In [12]:
training_set['citing'] = training_set[['citations', 'cited_by']].apply(
        lambda row: list(set(row['citations']+row['cited_by'])), axis=1)

In [13]:
testing_set['citing'] = testing_set[['citations', 'cited_by']].apply(
        lambda row: list(set(row['citations']+row['cited_by'])), axis=1)

In [14]:
training_set['people'] = training_set[['assignees', 'inventors']].apply(lambda row: list(set(row['assignees']+row['inventors'])), axis=1)

In [15]:
testing_set['people'] = testing_set[['assignees', 'inventors']].apply(lambda row: list(set(row['assignees']+row['inventors'])), axis=1)

In [16]:
def format(t):
    CORP_TYPES = set(
        [
            "INC",
            "LLC" "CORP",
            "KK",
            "SA",
            "SRL",
            "LTD",
            "NL",
            "PTY",
            "AG",
            "GMBH",
            "KG",
            "OG",
            "LIMITED",
            "SARL",
            "BM",
            "PLC",
            "LP",
            "IP",
            "DBA",
            "CORP",
            "CO",
        ]
    )

    tokenized = strip_non_alphanum(strip_punctuation(t)).upper().split(" ")
    cleaned = [t for t in tokenized if t not in CORP_TYPES]
    return "".join(cleaned)


people_coder = CountVectorizer(analyzer=lambda x: map(format, x), min_df=2)

In [17]:
citing_coder = CountVectorizer(analyzer=lambda x: x, min_df=4)

In [18]:
citing_coder.fit(training_set.citing)
len(citing_coder.vocabulary_)

2076

In [19]:
people_coder.fit(training_set.people)

CountVectorizer(analyzer=<function <lambda> at 0x7f8000ea6310>, min_df=2)

In [20]:
len(people_coder.vocabulary_)

506

In [21]:
#training_set['cpc_vec'] = list(cpc_coder.transform(training_set.cpc_codes).todense())
#testing_set['cpc_vec'] = list(cpc_coder.transform(testing_set.cpc_codes).todense())
training_set['people_vec'] = list(np.array(people_coder.transform(training_set.people).todense()))
testing_set['people_vec'] = list(np.array(people_coder.transform(testing_set.people).todense()))
training_set['citing_vec'] = list(np.array(citing_coder.transform(training_set.citing).todense()))
testing_set['citing_vec'] = list(np.array(citing_coder.transform(testing_set.citing).todense()))

In [22]:
training_set.labels

0      [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, ...
1      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
2      [1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...
3      [1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, ...
4      [0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
                             ...                        
967    [1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, ...
968    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, ...
969    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, ...
970    [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, ...
971    [0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, ...
Name: labels, Length: 972, dtype: object

In [23]:
# training_set = training_set.explode('labels').reset_index()
# testing_set = testing_set.explode('labels').reset_index()

In [24]:
#training_set.labels.str.len().describe()

In [25]:
possible_labels = set([tier_translations[x] for x in subset])

In [26]:
possible_labels

{'Analysis and Modeling',
 'Analysis and Modeling: 3D Modeling',
 'Anatomical Target',
 'Anatomical Target: Lower Extremity',
 'Anatomical Target: Lower Extremity - Hip',
 'Anatomical Target: Lower Extremity - Knee',
 'Anatomical Target: Torso',
 'Anatomical Target: Torso - Spine',
 'Anatomical Target: Upper Extremity',
 'Anatomical Target: Upper Extremity - Shoulder',
 'Imaging',
 'Imaging: CT',
 'Imaging: MRI',
 'Imaging: Ultrasound',
 'Manufacturing',
 'Manufacturing: Additive Manufacturing',
 'Personalized Product: Guide or Jig',
 'Personalized Product: Implant',
 'Specification of Use',
 'Specification of Use: Disease',
 'Specification of Use: Joint Replacement',
 'Surgical Method'}

In [27]:
def tokenize(tokenizer, text, max_len):
        text = str(text)
        text = " ".join(text.split())

        inputs = tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=max_len,
            pad_to_max_length=True,
            return_token_type_ids=True
        )
        
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs["token_type_ids"]

        return {
            f"input_ids": torch.tensor(ids, dtype=torch.long),
            f"attention_mask": torch.tensor(mask, dtype=torch.long),
            f"token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
        }
    
def tokenize_list(tokenizer, text_list, max_len):
        
        inputs = tokenizer.batch_encode_plus(
            text_list,
            #None,
            #add_special_tokens=True,
            max_length=max_len,
            pad_to_max_length=True,
            return_token_type_ids=True
        )
        
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs["token_type_ids"]

        return {
            f"input_ids": torch.tensor(ids, dtype=torch.long),
            f"attention_mask": torch.tensor(mask, dtype=torch.long),
            f"token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
        }

In [28]:
subset_tokenized = tokenize_list(tokenizer, nice_subset, 56)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [29]:
subset_tokenized['input_ids']

tensor([[    2,  3771,  1663,  ...,     0,     0,     0],
        [    2,  3771,  1663,  ...,     0,     0,     0],
        [    2, 27806,  4204,  ...,     0,     0,     0],
        ...,
        [    2, 12492,  1662,  ...,     0,     0,     0],
        [    2, 12492,  1662,  ...,     0,     0,     0],
        [    2, 11372,  3783,  ...,     0,     0,     0]])

In [30]:
# import fse
# import gensim.downloader as api
# glove = api.load("glove-wiki-gigaword-100")

In [31]:
# from fse import IndexedList, SplitIndexedList
# from fse.models import uSIF
# s = SplitIndexedList(nice_subset)
# label_model = uSIF(glove, workers=32, lang_freq="en")
# label_model.train(s)

In [32]:
#subset_embeddings = label_model.infer(s)

In [33]:
#subset_embeddings.shape

In [34]:
class MultiLabelDataset(Dataset):

    def __init__(self, dataframe, tokenizer):
        self.tokenizer = tokenizer
        self.data = dataframe
        
        self.claims = dataframe.claims
        self.abstracts = dataframe.abstract
        
        self.labels = dataframe.labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        
        abstract = tokenize(tokenizer, self.abstracts[index], max_len=MAX_LEN_ABSTRACT)
        claims = tokenize(tokenizer, self.claims[index], MAX_LEN_CLAIMS)
        
        labels = torch.tensor(np.array(self.labels[index]), dtype=torch.float)
        
        people = torch.tensor(np.array(self.data.people_vec[index]), dtype=torch.float)
        citing = torch.tensor(np.array(self.data.citing_vec[index]), dtype=torch.float)
        embedded_cpc = torch.tensor(np.array(self.data.embedded_cpc[index]), dtype=torch.float)        
                
        return {"abstract": abstract, 
                "claims": claims, 
                
                #'cpcs': cpcs,
                 'people': people,
                 'citing': citing,
                 'embedded_cpc': embedded_cpc,
                 'labels': labels}

In [35]:
training_dataset = MultiLabelDataset(training_set, tokenizer)
testing_dataset = MultiLabelDataset(testing_set, tokenizer)

In [36]:
train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

test_params = {'batch_size': VALID_BATCH_SIZE,
                'shuffle': False,
                'num_workers': 0
                }

training_loader = DataLoader(training_dataset, **train_params)
testing_loader = DataLoader(testing_dataset, **test_params)

In [37]:
import torch.nn.functional as F

from pytorch_metric_learning import miners, losses

NUM_LABELS = len(nice_subset)
import catalyst.contrib as contrib
device = utils.get_device()
from datetime import datetime
logdir="/var/patentmark/logdir/fit2/" + datetime.now().strftime("%Y%m%d-%H%M%S")
#%load_ext tensorboard
#%tensorboard --logdir /var/patentmark/logdir/fit/ --bind_all

class PatentModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.text_embedder = AutoModel.from_pretrained(model_name, gradient_checkpointing=True)
        
        self.people_embedder = torch.nn.Linear(len(people_coder.vocabulary_), 64)
        self.citing_embedder = torch.nn.Linear(len(citing_coder.vocabulary_), 64)
        
        total_embedding_size = self.text_embedder.pooler.dense.out_features*2+32+64*2
        output_size = 256 #self.text_embedder.pooler.dense.out_features
        
        self.dropout1 = nn.Dropout(0.1)
        
        self.dense1 = nn.Linear(total_embedding_size, output_size)
        self.dense1label = nn.Linear(self.text_embedder.pooler.dense.out_features, output_size)
        
        self.categorizer = nn.Linear(output_size, NUM_LABELS)

    
    def encode_label(self, label):
        label_emb = self.text_embedder(input_ids=label["input_ids"].to(device), attention_mask=label["attention_mask"].to(device))
        label_emb = label_emb[0][:,0]
        
        x = self.dropout1(label_emb)
        x = F.elu(self.dense1label(x))
        return x
    
    def predict_classes(self, embeddings):
        x = self.dropout1(embeddings)
        x = self.categorizer(x)
        return x
        
        
    def encode_patent(self, abstract, claims, embedded_cpc, people, citing):
        
        abstract_emb = self.text_embedder(input_ids=abstract["input_ids"].to(device), 
                                          attention_mask=abstract["attention_mask"].to(device))
        abstract_emb = abstract_emb[0][:, 0]
        
        claim_emb = self.text_embedder(input_ids=claims["input_ids"].to(device), 
                                       attention_mask=claims["attention_mask"].to(device))
        claim_emb = claim_emb[0][:, 0]
        
        people_emb = F.elu(self.people_embedder(people.to(device)))
        citing_emb = F.elu(self.citing_embedder(citing.to(device)))
    
        x = torch.cat((abstract_emb, claim_emb, embedded_cpc.to(device), people_emb, citing_emb), 1)
        x = self.dropout1(x)
        x = F.elu(self.dense1(x))
        
        return x

model = PatentModel().to(device)

In [38]:
from sklearn.metrics import f1_score

In [39]:
def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):
    acc_list = []
    for i in range(y_true.shape[0]):
        set_true = set( np.where(y_true[i])[0] )
        set_pred = set( np.where(y_pred[i])[0] )
        tmp_a = None
        if len(set_true) == 0 and len(set_pred) == 0:
            tmp_a = 1
        else:
            tmp_a = len(set_true.intersection(set_pred))/\
                    float( len(set_true.union(set_pred)) )
        acc_list.append(tmp_a)
    return np.mean(acc_list)

optimizer = torch.optim.AdamW(params =  model.parameters(), lr=LEARNING_RATE)
#loss_function = torch.nn.CosineEmbeddingLoss()

miner = miners.MultiSimilarityMiner()
loss_function = losses.TripletMarginLoss()
classifier_function = torch.nn.BCEWithLogitsLoss()


from torch.utils.tensorboard import SummaryWriter


writer = SummaryWriter(log_dir=logdir, flush_secs=30)
model.train()
optimizer.zero_grad()
step = 0
for epoch in range(EPOCHS):
    running_loss = 0
    for batch_no, batch in enumerate(training_loader):
        
        abstract = batch['abstract']#.to(device)
        claims = batch['claims']#.to(device)
        people = batch['people']#.to(device)
        citing = batch['citing']#.to(device)
        embedded_cpc = batch['embedded_cpc']#.to(device)
        labels = batch['labels'].to(device)

        patent_emb = model.encode_patent(abstract=abstract, claims=claims, embedded_cpc=embedded_cpc, people=people, citing=citing)
        label_embeddings = model.encode_label(subset_tokenized)
        
        #hard_pairs = miner(patent_emb, labels)
        #loss = loss_function(patent_emb, labels, hard_pairs)
    
        loss = 0
        for label_idx in range(NUM_LABELS):
            #label_emb = label_embeddings[label_idx].repeat(4,1)
            current_labels = labels[:, label_idx]
            hard_pairs = miner(patent_emb, current_labels)
            loss += loss_function(patent_emb, current_labels, hard_pairs)
            #current_labels[current_labels==0] = -1
            #loss += loss_function(patent_emb, label_emb, current_labels)
        
        running_loss += loss
        avg_loss = running_loss / (batch_no+1)
        
        print(f"step_no: {step}, epoch: {epoch}, batch_no: {batch_no}, loss: {loss}, avg_loss: {avg_loss}")
        writer.add_scalar('loss/embedding', loss, step)
        loss.backward()
        
#         optimizer.step()
#         optimizer.zero_grad()   
        
#         predictions = model.predict_classes(patent_emb)
#         classifier_loss = classifier_function(labels, predictions)
#         writer.add_scalar('loss/classifier', classifier_loss, step)
#         classifier_loss.backward()
        
        optimizer.step()
        optimizer.zero_grad()       
        step = step + 1
        
        if step % 100 == 0:
            model.eval()
            f1s = []
            with torch.no_grad():
                for val_batch_no, val_batch in enumerate(testing_loader):
                    raw_predictions = torch.sigmoid(model.predict_classes(patent_emb)).cpu()
                    #print(f"raw_predictions: {raw_predictions.shape}")
                    predictions = raw_predictions > 0.5
                    #print(f"predictions: {predictions.shape}: {predictions}")
                    #print(f"val_batch: {val_batch['labels'].shape}: {val_batch['labels']}")
                    if val_batch['labels'].shape[0] == 3:
                        continue
                    f1s.append(f1_score(val_batch['labels'], predictions, average='samples'))
            f1_avg = np.mean(f1s)
            writer.add_scalar('f1/valid', f1_avg, step)
            print(f1_avg)
            model.train()
                    
                    
                    
                    
            
        
        

step_no: 0, epoch: 0, batch_no: 0, loss: 1.1900696754455566, avg_loss: 1.1900696754455566
step_no: 1, epoch: 0, batch_no: 1, loss: 0.7996355295181274, avg_loss: 0.994852602481842
step_no: 2, epoch: 0, batch_no: 2, loss: 1.1026958227157593, avg_loss: 1.0308003425598145
step_no: 3, epoch: 0, batch_no: 3, loss: 1.6710759401321411, avg_loss: 1.1908692121505737
step_no: 4, epoch: 0, batch_no: 4, loss: 1.3944175243377686, avg_loss: 1.2315788269042969
step_no: 5, epoch: 0, batch_no: 5, loss: 3.112457513809204, avg_loss: 1.5450586080551147
step_no: 6, epoch: 0, batch_no: 6, loss: 0.647013783454895, avg_loss: 1.4167665243148804
step_no: 7, epoch: 0, batch_no: 7, loss: 1.9238325357437134, avg_loss: 1.480149745941162
step_no: 8, epoch: 0, batch_no: 8, loss: 1.6977035999298096, avg_loss: 1.5043224096298218
step_no: 9, epoch: 0, batch_no: 9, loss: 0.6345551013946533, avg_loss: 1.417345643043518
step_no: 10, epoch: 0, batch_no: 10, loss: 1.1138719320297241, avg_loss: 1.3897571563720703
step_no: 11, 

step_no: 90, epoch: 0, batch_no: 90, loss: 1.1137853860855103, avg_loss: 0.9178000688552856
step_no: 91, epoch: 0, batch_no: 91, loss: 0.9466953873634338, avg_loss: 0.9181141257286072
step_no: 92, epoch: 0, batch_no: 92, loss: 0.31674885749816895, avg_loss: 0.9116478562355042
step_no: 93, epoch: 0, batch_no: 93, loss: 0.6594262719154358, avg_loss: 0.9089645743370056
step_no: 94, epoch: 0, batch_no: 94, loss: 0.7873609066009521, avg_loss: 0.9076846241950989
step_no: 95, epoch: 0, batch_no: 95, loss: 1.043910264968872, avg_loss: 0.9091035723686218
step_no: 96, epoch: 0, batch_no: 96, loss: 1.1328743696212769, avg_loss: 0.9114104509353638
step_no: 97, epoch: 0, batch_no: 97, loss: 0.7867861986160278, avg_loss: 0.9101387858390808
step_no: 98, epoch: 0, batch_no: 98, loss: 0.5760913491249084, avg_loss: 0.906764566898346
step_no: 99, epoch: 0, batch_no: 99, loss: 1.7689217329025269, avg_loss: 0.9153860807418823
0.38159776180678046
step_no: 100, epoch: 0, batch_no: 100, loss: 0.85255843400955

step_no: 178, epoch: 0, batch_no: 178, loss: 0.7836066484451294, avg_loss: 0.8977039456367493
step_no: 179, epoch: 0, batch_no: 179, loss: 0.7782843708992004, avg_loss: 0.897040605545044
step_no: 180, epoch: 0, batch_no: 180, loss: 0.8342980146408081, avg_loss: 0.896694004535675
step_no: 181, epoch: 0, batch_no: 181, loss: 0.8028872609138489, avg_loss: 0.8961785435676575
step_no: 182, epoch: 0, batch_no: 182, loss: 0.6592616438865662, avg_loss: 0.8948838710784912
step_no: 183, epoch: 0, batch_no: 183, loss: 0.8032845854759216, avg_loss: 0.8943860530853271
step_no: 184, epoch: 0, batch_no: 184, loss: 0.8286000490188599, avg_loss: 0.8940304517745972
step_no: 185, epoch: 0, batch_no: 185, loss: 0.7493975758552551, avg_loss: 0.8932528495788574
step_no: 186, epoch: 0, batch_no: 186, loss: 0.8990655541419983, avg_loss: 0.8932839632034302
step_no: 187, epoch: 0, batch_no: 187, loss: 0.7727659940719604, avg_loss: 0.8926428556442261
step_no: 188, epoch: 0, batch_no: 188, loss: 0.596378624439239

step_no: 266, epoch: 1, batch_no: 23, loss: 0.8488627672195435, avg_loss: 0.7101106643676758
step_no: 267, epoch: 1, batch_no: 24, loss: 0.7835418581962585, avg_loss: 0.7130479216575623
step_no: 268, epoch: 1, batch_no: 25, loss: 1.218063235282898, avg_loss: 0.7324716448783875
step_no: 269, epoch: 1, batch_no: 26, loss: 0.830669641494751, avg_loss: 0.7361086010932922
step_no: 270, epoch: 1, batch_no: 27, loss: 0.8893259763717651, avg_loss: 0.7415806651115417
step_no: 271, epoch: 1, batch_no: 28, loss: 0.4052910804748535, avg_loss: 0.7299844026565552
step_no: 272, epoch: 1, batch_no: 29, loss: 0.5410018563270569, avg_loss: 0.7236850261688232
step_no: 273, epoch: 1, batch_no: 30, loss: 0.7951177954673767, avg_loss: 0.7259892821311951
step_no: 274, epoch: 1, batch_no: 31, loss: 0.659998893737793, avg_loss: 0.7239271402359009
step_no: 275, epoch: 1, batch_no: 32, loss: 0.7608404755592346, avg_loss: 0.7250458002090454
step_no: 276, epoch: 1, batch_no: 33, loss: 0.41050446033477783, avg_loss

step_no: 355, epoch: 1, batch_no: 112, loss: 0.675605833530426, avg_loss: 0.7204902172088623
step_no: 356, epoch: 1, batch_no: 113, loss: 0.8341785073280334, avg_loss: 0.7214874625205994
step_no: 357, epoch: 1, batch_no: 114, loss: 0.891735315322876, avg_loss: 0.7229679226875305
step_no: 358, epoch: 1, batch_no: 115, loss: 0.9173758625984192, avg_loss: 0.7246438264846802
step_no: 359, epoch: 1, batch_no: 116, loss: 0.510428786277771, avg_loss: 0.7228130102157593
step_no: 360, epoch: 1, batch_no: 117, loss: 0.8179741501808167, avg_loss: 0.7236194014549255
step_no: 361, epoch: 1, batch_no: 118, loss: 0.9140125513076782, avg_loss: 0.7252193689346313
step_no: 362, epoch: 1, batch_no: 119, loss: 0.6838849186897278, avg_loss: 0.7248749136924744
step_no: 363, epoch: 1, batch_no: 120, loss: 0.6136227250099182, avg_loss: 0.7239554524421692
step_no: 364, epoch: 1, batch_no: 121, loss: 0.918286919593811, avg_loss: 0.7255483269691467
step_no: 365, epoch: 1, batch_no: 122, loss: 0.8455997109413147,

step_no: 443, epoch: 1, batch_no: 200, loss: 0.9538229703903198, avg_loss: 0.7402440905570984
step_no: 444, epoch: 1, batch_no: 201, loss: 0.7576296329498291, avg_loss: 0.740330159664154
step_no: 445, epoch: 1, batch_no: 202, loss: 0.6623415350914001, avg_loss: 0.7399459481239319
step_no: 446, epoch: 1, batch_no: 203, loss: 0.5099999308586121, avg_loss: 0.7388187646865845
step_no: 447, epoch: 1, batch_no: 204, loss: 0.6903181672096252, avg_loss: 0.738582193851471
step_no: 448, epoch: 1, batch_no: 205, loss: 0.5856930017471313, avg_loss: 0.7378399968147278
step_no: 449, epoch: 1, batch_no: 206, loss: 0.7413152456283569, avg_loss: 0.7378568053245544
step_no: 450, epoch: 1, batch_no: 207, loss: 0.6501511931419373, avg_loss: 0.7374351620674133
step_no: 451, epoch: 1, batch_no: 208, loss: 0.5072299242019653, avg_loss: 0.7363336682319641
step_no: 452, epoch: 1, batch_no: 209, loss: 0.6761978268623352, avg_loss: 0.7360473275184631
step_no: 453, epoch: 1, batch_no: 210, loss: 0.795041680335998

step_no: 531, epoch: 2, batch_no: 45, loss: 1.0367101430892944, avg_loss: 0.6810451745986938
step_no: 532, epoch: 2, batch_no: 46, loss: 0.7688047885894775, avg_loss: 0.6829123497009277
step_no: 533, epoch: 2, batch_no: 47, loss: 0.7732954025268555, avg_loss: 0.6847953796386719
step_no: 534, epoch: 2, batch_no: 48, loss: 0.8130615949630737, avg_loss: 0.6874130368232727
step_no: 535, epoch: 2, batch_no: 49, loss: 0.6091747283935547, avg_loss: 0.6858482360839844
step_no: 536, epoch: 2, batch_no: 50, loss: 0.8783462047576904, avg_loss: 0.689622700214386
step_no: 537, epoch: 2, batch_no: 51, loss: 0.5226906538009644, avg_loss: 0.6864124536514282
step_no: 538, epoch: 2, batch_no: 52, loss: 0.692175567150116, avg_loss: 0.6865212321281433
step_no: 539, epoch: 2, batch_no: 53, loss: 0.7111523151397705, avg_loss: 0.6869773268699646
step_no: 540, epoch: 2, batch_no: 54, loss: 0.48283851146698, avg_loss: 0.6832656860351562
step_no: 541, epoch: 2, batch_no: 55, loss: 0.5364618301391602, avg_loss: 

step_no: 619, epoch: 2, batch_no: 133, loss: 0.8936689496040344, avg_loss: 0.7165082693099976
step_no: 620, epoch: 2, batch_no: 134, loss: 0.9214319586753845, avg_loss: 0.7180262207984924
step_no: 621, epoch: 2, batch_no: 135, loss: 0.8089454174041748, avg_loss: 0.7186947464942932
step_no: 622, epoch: 2, batch_no: 136, loss: 0.49909889698028564, avg_loss: 0.7170918583869934
step_no: 623, epoch: 2, batch_no: 137, loss: 0.6097792983055115, avg_loss: 0.7163142561912537
step_no: 624, epoch: 2, batch_no: 138, loss: 0.5904964804649353, avg_loss: 0.7154091000556946
step_no: 625, epoch: 2, batch_no: 139, loss: 0.6403546333312988, avg_loss: 0.7148730158805847
step_no: 626, epoch: 2, batch_no: 140, loss: 0.8489046096801758, avg_loss: 0.7158235907554626
step_no: 627, epoch: 2, batch_no: 141, loss: 0.6174882650375366, avg_loss: 0.7151311039924622
step_no: 628, epoch: 2, batch_no: 142, loss: 0.6728795170783997, avg_loss: 0.7148356437683105
step_no: 629, epoch: 2, batch_no: 143, loss: 0.719224691390

step_no: 707, epoch: 2, batch_no: 221, loss: 0.6530042886734009, avg_loss: 0.7251921892166138
step_no: 708, epoch: 2, batch_no: 222, loss: 0.8682955503463745, avg_loss: 0.7258339524269104
step_no: 709, epoch: 2, batch_no: 223, loss: 0.8069874048233032, avg_loss: 0.7261962294578552
step_no: 710, epoch: 2, batch_no: 224, loss: 0.6614363789558411, avg_loss: 0.7259083986282349
step_no: 711, epoch: 2, batch_no: 225, loss: 0.692332923412323, avg_loss: 0.7257598638534546
step_no: 712, epoch: 2, batch_no: 226, loss: 0.7294690012931824, avg_loss: 0.7257761359214783
step_no: 713, epoch: 2, batch_no: 227, loss: 1.1299498081207275, avg_loss: 0.7275488376617432
step_no: 714, epoch: 2, batch_no: 228, loss: 0.5374091267585754, avg_loss: 0.726718544960022
step_no: 715, epoch: 2, batch_no: 229, loss: 0.4349885582923889, avg_loss: 0.7254500985145569
step_no: 716, epoch: 2, batch_no: 230, loss: 0.6222731471061707, avg_loss: 0.7250034809112549
step_no: 717, epoch: 2, batch_no: 231, loss: 0.743110239505767

step_no: 796, epoch: 3, batch_no: 67, loss: 0.3514384925365448, avg_loss: 0.7436053156852722
step_no: 797, epoch: 3, batch_no: 68, loss: 0.4727587401866913, avg_loss: 0.7396800518035889
step_no: 798, epoch: 3, batch_no: 69, loss: 0.7615422010421753, avg_loss: 0.7399923801422119
step_no: 799, epoch: 3, batch_no: 70, loss: 0.9693300127983093, avg_loss: 0.7432224750518799
0.40918319162620304
step_no: 800, epoch: 3, batch_no: 71, loss: 0.7591568827629089, avg_loss: 0.7434437870979309
step_no: 801, epoch: 3, batch_no: 72, loss: 1.0319501161575317, avg_loss: 0.747395932674408
step_no: 802, epoch: 3, batch_no: 73, loss: 0.6341413259506226, avg_loss: 0.7458654642105103
step_no: 803, epoch: 3, batch_no: 74, loss: 0.7233393788337708, avg_loss: 0.7455651164054871
step_no: 804, epoch: 3, batch_no: 75, loss: 0.84920334815979, avg_loss: 0.7469287514686584
step_no: 805, epoch: 3, batch_no: 76, loss: 0.8114935159683228, avg_loss: 0.7477672696113586
step_no: 806, epoch: 3, batch_no: 77, loss: 0.4632052

step_no: 884, epoch: 3, batch_no: 155, loss: 0.6801006197929382, avg_loss: 0.7418916821479797
step_no: 885, epoch: 3, batch_no: 156, loss: 0.5648625493049622, avg_loss: 0.7407641410827637
step_no: 886, epoch: 3, batch_no: 157, loss: 0.6691088676452637, avg_loss: 0.7403106093406677
step_no: 887, epoch: 3, batch_no: 158, loss: 0.7336923480033875, avg_loss: 0.7402689456939697
step_no: 888, epoch: 3, batch_no: 159, loss: 0.6505017876625061, avg_loss: 0.7397079467773438
step_no: 889, epoch: 3, batch_no: 160, loss: 0.6854408979415894, avg_loss: 0.7393708825111389
step_no: 890, epoch: 3, batch_no: 161, loss: 0.4531833529472351, avg_loss: 0.7376043200492859
step_no: 891, epoch: 3, batch_no: 162, loss: 0.3367767930030823, avg_loss: 0.7351452112197876
step_no: 892, epoch: 3, batch_no: 163, loss: 0.5806828141212463, avg_loss: 0.7342033386230469
step_no: 893, epoch: 3, batch_no: 164, loss: 0.8747270107269287, avg_loss: 0.7350550293922424
step_no: 894, epoch: 3, batch_no: 165, loss: 0.4885284304618

step_no: 972, epoch: 4, batch_no: 0, loss: 0.6469544768333435, avg_loss: 0.6469544768333435
step_no: 973, epoch: 4, batch_no: 1, loss: 0.548209011554718, avg_loss: 0.5975817441940308
step_no: 974, epoch: 4, batch_no: 2, loss: 0.5684092044830322, avg_loss: 0.5878576040267944
step_no: 975, epoch: 4, batch_no: 3, loss: 0.7276942729949951, avg_loss: 0.6228167414665222
step_no: 976, epoch: 4, batch_no: 4, loss: 1.0956902503967285, avg_loss: 0.7173914313316345
step_no: 977, epoch: 4, batch_no: 5, loss: 0.4638321101665497, avg_loss: 0.6751315593719482
step_no: 978, epoch: 4, batch_no: 6, loss: 0.6775146126747131, avg_loss: 0.6754720211029053
step_no: 979, epoch: 4, batch_no: 7, loss: 0.607276439666748, avg_loss: 0.6669475436210632
step_no: 980, epoch: 4, batch_no: 8, loss: 1.0398412942886353, avg_loss: 0.7083801627159119
step_no: 981, epoch: 4, batch_no: 9, loss: 1.0365979671478271, avg_loss: 0.7412019968032837
step_no: 982, epoch: 4, batch_no: 10, loss: 0.5283460021018982, avg_loss: 0.721851

step_no: 1060, epoch: 4, batch_no: 88, loss: 1.0750116109848022, avg_loss: 0.7374035120010376
step_no: 1061, epoch: 4, batch_no: 89, loss: 0.8467787504196167, avg_loss: 0.7386188507080078
step_no: 1062, epoch: 4, batch_no: 90, loss: 0.5917951464653015, avg_loss: 0.7370054125785828
step_no: 1063, epoch: 4, batch_no: 91, loss: 0.5788564085960388, avg_loss: 0.7352864146232605
step_no: 1064, epoch: 4, batch_no: 92, loss: 0.6679075956344604, avg_loss: 0.7345618605613708
step_no: 1065, epoch: 4, batch_no: 93, loss: 1.1160869598388672, avg_loss: 0.7386206388473511
step_no: 1066, epoch: 4, batch_no: 94, loss: 0.5661671161651611, avg_loss: 0.7368054389953613
step_no: 1067, epoch: 4, batch_no: 95, loss: 0.48950129747390747, avg_loss: 0.734229326248169
step_no: 1068, epoch: 4, batch_no: 96, loss: 0.5294654965400696, avg_loss: 0.7321183085441589
step_no: 1069, epoch: 4, batch_no: 97, loss: 0.5469132661819458, avg_loss: 0.7302284836769104
step_no: 1070, epoch: 4, batch_no: 98, loss: 0.4693747162818

step_no: 1147, epoch: 4, batch_no: 175, loss: 0.7188620567321777, avg_loss: 0.7268074154853821
step_no: 1148, epoch: 4, batch_no: 176, loss: 0.8472546935081482, avg_loss: 0.7274878621101379
step_no: 1149, epoch: 4, batch_no: 177, loss: 0.7318625450134277, avg_loss: 0.7275124192237854
step_no: 1150, epoch: 4, batch_no: 178, loss: 0.7517955303192139, avg_loss: 0.7276480793952942
step_no: 1151, epoch: 4, batch_no: 179, loss: 0.7554243206977844, avg_loss: 0.7278023958206177
step_no: 1152, epoch: 4, batch_no: 180, loss: 0.9979268908500671, avg_loss: 0.7292947769165039
step_no: 1153, epoch: 4, batch_no: 181, loss: 0.8716061115264893, avg_loss: 0.7300767302513123
step_no: 1154, epoch: 4, batch_no: 182, loss: 0.5345759391784668, avg_loss: 0.7290083765983582
step_no: 1155, epoch: 4, batch_no: 183, loss: 0.6644953489303589, avg_loss: 0.7286577820777893
step_no: 1156, epoch: 4, batch_no: 184, loss: 0.43473389744758606, avg_loss: 0.7270690202713013
step_no: 1157, epoch: 4, batch_no: 185, loss: 0.6

step_no: 1234, epoch: 5, batch_no: 19, loss: 0.7980104088783264, avg_loss: 0.7008013129234314
step_no: 1235, epoch: 5, batch_no: 20, loss: 0.4737129509449005, avg_loss: 0.6899875402450562
step_no: 1236, epoch: 5, batch_no: 21, loss: 0.8897343873977661, avg_loss: 0.6990669369697571
step_no: 1237, epoch: 5, batch_no: 22, loss: 0.8033981323242188, avg_loss: 0.7036030888557434
step_no: 1238, epoch: 5, batch_no: 23, loss: 1.037950873374939, avg_loss: 0.7175342440605164
step_no: 1239, epoch: 5, batch_no: 24, loss: 0.8391400575637817, avg_loss: 0.7223984599113464
step_no: 1240, epoch: 5, batch_no: 25, loss: 0.5159204006195068, avg_loss: 0.7144570350646973
step_no: 1241, epoch: 5, batch_no: 26, loss: 0.78443443775177, avg_loss: 0.7170487642288208
step_no: 1242, epoch: 5, batch_no: 27, loss: 0.5769082903862, avg_loss: 0.712043821811676
step_no: 1243, epoch: 5, batch_no: 28, loss: 0.8654413819313049, avg_loss: 0.7173333764076233
step_no: 1244, epoch: 5, batch_no: 29, loss: 0.6895055770874023, av

KeyboardInterrupt: 

In [None]:
batch["labels"]

In [None]:
model.eval()
f1s = []
with torch.no_grad():
    for batch_no, batch in enumerate(testing_loader):
        predictions = torch.sigmoid(model.predict_classes(patent_emb)).cpu() > 0.5
        f1s.append(f1_score(batch['labels'], predictions, average='samples'))
f1_avg = np.mean(f1)
print(f1_avg)
model.train()

In [None]:
criterion = torch.nn.BCEWithLogitsLoss()
#criterion = contrib.nn.criterion.LovaszLossMultiLabel()
#scheduler = contrib.nn.OneCycleLRWithWarmup(optimizer, num_steps=500, lr_range=(1e-4, 1e-8), init_lr=1e-9, warmup_fraction=0.2)
#scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [2])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
#lrfinder = dl.LRFinder(final_lr=1)

runner = dl.SupervisedRunner(input_key=("abstract", "claims", "embedded_cpc", "people", "citing"))
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    logdir=logdir,
    num_epochs=EPOCHS,
    callbacks=[
               dl.MultiLabelAccuracyCallback(threshold=PRED_THRES, activation="None"),
               dl.EarlyStoppingCallback(patience=3, metric="multi_label_accuracy", minimize=False),
               dl.TensorboardLogger(),
               #dl.CheckpointCallback(),
               dl.OptimizerCallback(accumulation_steps=ACCUM_STEPS),
               dl.ValidationManagerCallback(),
               ],
               #dl.MetricManagerCallback(num_classes=len(subset), )],
    
    fp16=True,
    verbose=True
)

In [None]:
logdir

In [None]:
predictions = np.vstack(list(map(
    lambda x: x["logits"].cpu().numpy(), 
    runner.predict_loader(loader=loaders["valid"], resume=f"{logdir}/checkpoints/best.pth" )
)))

In [None]:
binary_predictions = torch.sigmoid(torch.from_numpy(predictions)) > 0.5

In [None]:
from sklearn.metrics import *
print(classification_report(testing_set[subset].astype(int), binary_predictions, target_names=subset))

In [None]:
precision    recall  f1-score   support

     SpecificationofUse_JointReplacement       0.21      0.32      0.25        44
                                 Imaging       0.55      1.00      0.71       133
                          SurgicalMethod       0.00      0.00      0.00        40
     Manufacturing_AdditiveManufacturing       0.00      0.00      0.00        38
                      Imaging_Ultrasound       0.00      0.00      0.00        32
                             Imaging_MRI       0.34      0.20      0.26        59
AnatomicalTarget_UpperExtremity_Shoulder       0.00      0.00      0.00        23
              SpecificationofUse_Disease       0.00      0.00      0.00        30
             PersonalizedProduct_Implant       0.51      1.00      0.68       124
                           Manufacturing       0.34      0.90      0.49        83
         AnatomicalTarget_UpperExtremity       0.00      0.00      0.00        31
                     AnalysisAndModeling       0.36      0.96      0.52        84
         AnatomicalTarget_LowerExtremity       0.47      1.00      0.63       113
                      SpecificationofUse       0.34      0.99      0.50        79
                        AnatomicalTarget       0.67      1.00      0.81       164
           PersonalizedProduct_Guide/Jig       0.49      1.00      0.66       120
            AnatomicalTarget_Torso_Spine       0.00      0.00      0.00        21
                              Imaging_CT       0.29      0.31      0.30        59
          AnalysisAndModeling_3DModeling       0.30      0.93      0.46        71
    AnatomicalTarget_LowerExtremity_Knee       0.34      0.78      0.48        82
                  AnatomicalTarget_Torso       0.00      0.00      0.00        35
     AnatomicalTarget_LowerExtremity_Hip       0.00      0.00      0.00        40

                               micro avg       0.43      0.71      0.54      1505
                               macro avg       0.24      0.47      0.31      1505
                            weighted avg       0.36      0.71      0.47      1505
                             samples avg       0.43      0.74      0.52      1505

In [None]:
# Longformer base (claims + abstract)
                                            precision    recall  f1-score   support

                     AnalysisAndModeling       0.35      1.00      0.51        84
                        AnatomicalTarget       0.67      1.00      0.81       164
            AnatomicalTarget_Torso_Spine       0.00      0.00      0.00        21
     AnatomicalTarget_LowerExtremity_Hip       0.00      0.00      0.00        40
                             Imaging_MRI       0.00      0.00      0.00        59
                                 Imaging       0.55      1.00      0.71       133
                           Manufacturing       0.34      0.99      0.50        83
             PersonalizedProduct_Implant       0.51      1.00      0.68       124
              SpecificationofUse_Disease       0.00      0.00      0.00        30
                      SpecificationofUse       0.34      0.89      0.49        79
     SpecificationofUse_JointReplacement       0.00      0.00      0.00        44
                  AnatomicalTarget_Torso       0.00      0.00      0.00        35
         AnatomicalTarget_UpperExtremity       0.00      0.00      0.00        31
                      Imaging_Ultrasound       0.00      0.00      0.00        32
                              Imaging_CT       0.32      0.25      0.28        59
          AnalysisAndModeling_3DModeling       0.28      0.80      0.42        71
                          SurgicalMethod       0.00      0.00      0.00        40
AnatomicalTarget_UpperExtremity_Shoulder       0.00      0.00      0.00        23
    AnatomicalTarget_LowerExtremity_Knee       0.34      1.00      0.51        82
           PersonalizedProduct_Guide/Jig       0.49      1.00      0.66       120
         AnatomicalTarget_LowerExtremity       0.47      1.00      0.63       113
     Manufacturing_AdditiveManufacturing       0.00      0.00      0.00        38

                               micro avg       0.44      0.69      0.54      1505
                               macro avg       0.21      0.45      0.28      1505
                            weighted avg       0.34      0.69      0.45      1505
                             samples avg       0.44      0.73      0.52      1505

In [None]:
    #Albert base w/ 256 length sequences (claims + abstract)                
    
    precision    recall  f1-score   support

         AnatomicalTarget_LowerExtremity       0.47      1.00      0.63       113
     Manufacturing_AdditiveManufacturing       0.67      0.05      0.10        38
                                 Imaging       0.55      1.00      0.71       133
                          SurgicalMethod       0.00      0.00      0.00        40
AnatomicalTarget_UpperExtremity_Shoulder       0.18      0.13      0.15        23
              SpecificationofUse_Disease       0.00      0.00      0.00        30
    AnatomicalTarget_LowerExtremity_Knee       0.45      0.40      0.43        82
                      SpecificationofUse       0.35      0.95      0.52        79
         AnatomicalTarget_UpperExtremity       0.00      0.00      0.00        31
            AnatomicalTarget_Torso_Spine       0.00      0.00      0.00        21
             PersonalizedProduct_Implant       0.51      1.00      0.68       124
                     AnalysisAndModeling       0.38      0.65      0.48        84
          AnalysisAndModeling_3DModeling       0.33      0.68      0.44        71
                  AnatomicalTarget_Torso       0.00      0.00      0.00        35
     SpecificationofUse_JointReplacement       0.18      0.68      0.28        44
                        AnatomicalTarget       0.67      1.00      0.81       164
                           Manufacturing       0.32      0.87      0.47        83
                             Imaging_MRI       0.26      0.15      0.19        59
                      Imaging_Ultrasound       0.20      0.12      0.15        32
                              Imaging_CT       0.32      0.34      0.33        59
     AnatomicalTarget_LowerExtremity_Hip       0.14      0.03      0.04        40
           PersonalizedProduct_Guide/Jig       0.50      1.00      0.66       120

                               micro avg       0.43      0.67      0.52      1505
                               macro avg       0.29      0.46      0.32      1505
                            weighted avg       0.39      0.67      0.47      1505
                             samples avg       0.43      0.70      0.51      1505

In [None]:
hamming_loss(testing_set[subset], binary_predictions)