In [1]:
! pip install transformers[torch]
! pip install accelerate

Defaulting to user installation because normal site-packages is not writeable
^C


In [3]:
import pandas as pd
import numpy as np
import torch

from torch.utils.data.dataset import Dataset
from transformers import TrainingArguments, Trainer, AutoTokenizer, AutoModelForSequenceClassification
from sklearn.model_selection import train_test_split

In [4]:
df = pd.read_csv("arxiv_data.csv")

In [5]:
df

Unnamed: 0,titles,summaries,terms
0,Survey on Semantic Stereo Matching / Semantic ...,Stereo matching is one of the widely used tech...,"['cs.CV', 'cs.LG']"
1,FUTURE-AI: Guiding Principles and Consensus Re...,The recent advancements in artificial intellig...,"['cs.CV', 'cs.AI', 'cs.LG']"
2,Enforcing Mutual Consistency of Hard Regions f...,"In this paper, we proposed a novel mutual cons...","['cs.CV', 'cs.AI']"
3,Parameter Decoupling Strategy for Semi-supervi...,Consistency training has proven to be an advan...,['cs.CV']
4,Background-Foreground Segmentation for Interio...,"To ensure safety in automated driving, the cor...","['cs.CV', 'cs.LG']"
...,...,...,...
51769,Hierarchically-coupled hidden Markov models fo...,We address the problem of analyzing sets of no...,"['stat.ML', 'physics.bio-ph', 'q-bio.QM']"
51770,Blinking Molecule Tracking,We discuss a method for tracking individual mo...,"['cs.CV', 'cs.DM']"
51771,Towards a Mathematical Foundation of Immunolog...,We attempt to set a mathematical foundation of...,"['stat.ML', 'cs.LG', 'q-bio.GN']"
51772,A Semi-Automatic Graph-Based Approach for Dete...,Diffusion Tensor Imaging (DTI) allows estimati...,['cs.CV']


In [6]:
label_df = pd.read_csv("category_taxonomy.csv")

In [7]:
label_df

Unnamed: 0,category,code,name
0,Computer Science,cs.AI,Artificial Intelligence
1,Computer Science,cs.AR,Hardware Architecture
2,Computer Science,cs.CC,Computational Complexity
3,Computer Science,cs.CE,"Computational Engineering, Finance, and Science"
4,Computer Science,cs.CG,Computational Geometry
...,...,...,...
149,Physics,physics.optics,Optics
150,Physics,physics.plasm-ph,Plasma Physics
151,Physics,physics.pop-ph,Popular Physics
152,Physics,physics.soc-ph,Physics and Society


In [8]:
codes = label_df["code"].to_list()

In [9]:
codes

['cs.AI',
 'cs.AR',
 'cs.CC',
 'cs.CE',
 'cs.CG',
 'cs.CL',
 'cs.CR',
 'cs.CV',
 'cs.CY',
 'cs.DB',
 'cs.DC',
 'cs.DL',
 'cs.DM',
 'cs.DS',
 'cs.ET',
 'cs.FL',
 'cs.GL',
 'cs.GR',
 'cs.GT',
 'cs.HC',
 'cs.IR',
 'cs.IT',
 'cs.LG',
 'cs.LO',
 'cs.MA',
 'cs.MM',
 'cs.MS',
 'cs.NA',
 'cs.NE',
 'cs.NI',
 'cs.OH',
 'cs.OS',
 'cs.PF',
 'cs.PL',
 'cs.RO',
 'cs.SC',
 'cs.SD',
 'cs.SE',
 'cs.SI',
 'cs.SY',
 'econ.EM',
 'econ.GN',
 'econ.TH',
 'eess.AS',
 'eess.IV',
 'eess.SP',
 'eess.SY',
 'math.AC',
 'math.AG',
 'math.AP',
 'math.AT',
 'math.CA',
 'math.CO',
 'math.CT',
 'math.CV',
 'math.DG',
 'math.DS',
 'math.FA',
 'math.GM',
 'math.GN',
 'math.GR',
 'math.GT',
 'math.HO',
 'math.IT',
 'math.KT',
 'math.LO',
 'math.MG',
 'math.MP',
 'math.NA',
 'math.NT',
 'math.OA',
 'math.OC',
 'math.PR',
 'math.QA',
 'math.RA',
 'math.RT',
 'math.SG',
 'math.SP',
 'math.ST',
 'q-bio.BM',
 'q-bio.CB',
 'q-bio.GN',
 'q-bio.MN',
 'q-bio.NC',
 'q-bio.OT',
 'q-bio.PE',
 'q-bio.QM',
 'q-bio.SC',
 'q-bio.TO',
 '

In [10]:

new_df = df["terms"].str.replace("[",'').str.replace("]","").str.replace("'","").str.replace(" ","").str.get_dummies(",")
df

Unnamed: 0,titles,summaries,terms
0,Survey on Semantic Stereo Matching / Semantic ...,Stereo matching is one of the widely used tech...,"['cs.CV', 'cs.LG']"
1,FUTURE-AI: Guiding Principles and Consensus Re...,The recent advancements in artificial intellig...,"['cs.CV', 'cs.AI', 'cs.LG']"
2,Enforcing Mutual Consistency of Hard Regions f...,"In this paper, we proposed a novel mutual cons...","['cs.CV', 'cs.AI']"
3,Parameter Decoupling Strategy for Semi-supervi...,Consistency training has proven to be an advan...,['cs.CV']
4,Background-Foreground Segmentation for Interio...,"To ensure safety in automated driving, the cor...","['cs.CV', 'cs.LG']"
...,...,...,...
51769,Hierarchically-coupled hidden Markov models fo...,We address the problem of analyzing sets of no...,"['stat.ML', 'physics.bio-ph', 'q-bio.QM']"
51770,Blinking Molecule Tracking,We discuss a method for tracking individual mo...,"['cs.CV', 'cs.DM']"
51771,Towards a Mathematical Foundation of Immunolog...,We attempt to set a mathematical foundation of...,"['stat.ML', 'cs.LG', 'q-bio.GN']"
51772,A Semi-Automatic Graph-Based Approach for Dete...,Diffusion Tensor Imaging (DTI) allows estimati...,['cs.CV']


In [11]:
final_df = pd.concat([df,new_df], axis=1)

final_df

Unnamed: 0,titles,summaries,terms,astro-ph.CO,astro-ph.EP,astro-ph.GA,astro-ph.HE,astro-ph.IM,astro-ph.SR,cond-mat.dis-nn,...,q-fin.PR,q-fin.RM,q-fin.ST,q-fin.TR,stat.AP,stat.CO,stat.ME,stat.ML,stat.OT,stat.TH
0,Survey on Semantic Stereo Matching / Semantic ...,Stereo matching is one of the widely used tech...,"['cs.CV', 'cs.LG']",0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,FUTURE-AI: Guiding Principles and Consensus Re...,The recent advancements in artificial intellig...,"['cs.CV', 'cs.AI', 'cs.LG']",0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,Enforcing Mutual Consistency of Hard Regions f...,"In this paper, we proposed a novel mutual cons...","['cs.CV', 'cs.AI']",0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,Parameter Decoupling Strategy for Semi-supervi...,Consistency training has proven to be an advan...,['cs.CV'],0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,Background-Foreground Segmentation for Interio...,"To ensure safety in automated driving, the cor...","['cs.CV', 'cs.LG']",0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
51769,Hierarchically-coupled hidden Markov models fo...,We address the problem of analyzing sets of no...,"['stat.ML', 'physics.bio-ph', 'q-bio.QM']",0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
51770,Blinking Molecule Tracking,We discuss a method for tracking individual mo...,"['cs.CV', 'cs.DM']",0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
51771,Towards a Mathematical Foundation of Immunolog...,We attempt to set a mathematical foundation of...,"['stat.ML', 'cs.LG', 'q-bio.GN']",0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
51772,A Semi-Automatic Graph-Based Approach for Dete...,Diffusion Tensor Imaging (DTI) allows estimati...,['cs.CV'],0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [12]:
def concatenate_title_summary(val1, val2):
    return f"Title: {val1}. Summary: {val2}"

final_df['title_and_summary'] = final_df.apply(lambda row: concatenate_title_summary(row['titles'], row['summaries']), axis=1)

final_df
final_df.to_csv("data_cleaned.csv", index=False)

In [13]:
final_df = pd.read_csv("data_cleaned.csv")
final_df




Unnamed: 0,titles,summaries,terms,astro-ph.CO,astro-ph.EP,astro-ph.GA,astro-ph.HE,astro-ph.IM,astro-ph.SR,cond-mat.dis-nn,...,q-fin.RM,q-fin.ST,q-fin.TR,stat.AP,stat.CO,stat.ME,stat.ML,stat.OT,stat.TH,title_and_summary
0,Survey on Semantic Stereo Matching / Semantic ...,Stereo matching is one of the widely used tech...,"['cs.CV', 'cs.LG']",0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Title: Survey on Semantic Stereo Matching / Se...
1,FUTURE-AI: Guiding Principles and Consensus Re...,The recent advancements in artificial intellig...,"['cs.CV', 'cs.AI', 'cs.LG']",0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Title: FUTURE-AI: Guiding Principles and Conse...
2,Enforcing Mutual Consistency of Hard Regions f...,"In this paper, we proposed a novel mutual cons...","['cs.CV', 'cs.AI']",0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Title: Enforcing Mutual Consistency of Hard Re...
3,Parameter Decoupling Strategy for Semi-supervi...,Consistency training has proven to be an advan...,['cs.CV'],0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Title: Parameter Decoupling Strategy for Semi-...
4,Background-Foreground Segmentation for Interio...,"To ensure safety in automated driving, the cor...","['cs.CV', 'cs.LG']",0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Title: Background-Foreground Segmentation for ...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
51769,Hierarchically-coupled hidden Markov models fo...,We address the problem of analyzing sets of no...,"['stat.ML', 'physics.bio-ph', 'q-bio.QM']",0,0,0,0,0,0,0,...,0,0,0,0,0,0,1,0,0,Title: Hierarchically-coupled hidden Markov mo...
51770,Blinking Molecule Tracking,We discuss a method for tracking individual mo...,"['cs.CV', 'cs.DM']",0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Title: Blinking Molecule Tracking. Summary: We...
51771,Towards a Mathematical Foundation of Immunolog...,We attempt to set a mathematical foundation of...,"['stat.ML', 'cs.LG', 'q-bio.GN']",0,0,0,0,0,0,0,...,0,0,0,0,0,0,1,0,0,Title: Towards a Mathematical Foundation of Im...
51772,A Semi-Automatic Graph-Based Approach for Dete...,Diffusion Tensor Imaging (DTI) allows estimati...,['cs.CV'],0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,Title: A Semi-Automatic Graph-Based Approach f...


In [15]:
# Efficiently sample 10,000 rows from final_df for training
if len(final_df) > 5000:
    final_df = final_df.sample(n=5000, random_state=42).reset_index(drop=True)

test_split = 0.2
train_df, test_df = train_test_split(
    final_df,
    test_size=test_split,
)
print(f"Number of rows in training set: {len(train_df)}")
print(f"Number of rows in test set: {len(test_df)}")

Number of rows in training set: 4000
Number of rows in test set: 1000


# Approach 1: Assume that every class is a unique class of it's own

In [16]:
# Assume 154 classes. So device a classification model with 154 classes

## Approach 1a: Using only Title

In [17]:
not_chosen_columns = ['titles', 'summaries','terms', "title_and_summary"]

In [18]:
label_columns = [col for col in final_df.columns if col not in not_chosen_columns]

In [19]:
label_columns

['astro-ph.CO',
 'astro-ph.EP',
 'astro-ph.GA',
 'astro-ph.HE',
 'astro-ph.IM',
 'astro-ph.SR',
 'cond-mat.dis-nn',
 'cond-mat.mtrl-sci',
 'cond-mat.soft',
 'cond-mat.stat-mech',
 'cond-mat.str-el',
 'cs.AI',
 'cs.AR',
 'cs.CC',
 'cs.CE',
 'cs.CG',
 'cs.CL',
 'cs.CR',
 'cs.CV',
 'cs.CY',
 'cs.DB',
 'cs.DC',
 'cs.DL',
 'cs.DM',
 'cs.DS',
 'cs.ET',
 'cs.FL',
 'cs.GR',
 'cs.GT',
 'cs.HC',
 'cs.IR',
 'cs.IT',
 'cs.LG',
 'cs.LO',
 'cs.MA',
 'cs.MM',
 'cs.MS',
 'cs.NA',
 'cs.NE',
 'cs.NI',
 'cs.OS',
 'cs.PF',
 'cs.PL',
 'cs.RO',
 'cs.SC',
 'cs.SD',
 'cs.SE',
 'cs.SI',
 'cs.SY',
 'econ.EM',
 'econ.GN',
 'econ.TH',
 'eess.AS',
 'eess.IV',
 'eess.SP',
 'eess.SY',
 'gr-qc',
 'hep-ex',
 'hep-ph',
 'hep-th',
 'math-ph',
 'math.AC',
 'math.AG',
 'math.AP',
 'math.AT',
 'math.CA',
 'math.CO',
 'math.CT',
 'math.CV',
 'math.DG',
 'math.DS',
 'math.FA',
 'math.GR',
 'math.GT',
 'math.HO',
 'math.IT',
 'math.LO',
 'math.MG',
 'math.MP',
 'math.NA',
 'math.NT',
 'math.OA',
 'math.OC',
 'math.PR',
 'math

In [20]:
df_labels_train = train_df[label_columns]
df_labels_test = test_df[label_columns]

In [21]:
df_labels_train["cs.CV"]

3560    1
4830    1
3751    1
4453    0
2276    0
       ..
2042    0
2054    0
4260    0
3529    0
4531    1
Name: cs.CV, Length: 4000, dtype: int64

In [22]:
labels_list_train = df_labels_train.values.tolist()
labels_list_test = df_labels_test.values.tolist()

In [23]:
len(labels_list_test), len(labels_list_test[0])

(1000, 135)

In [12]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

train_texts = train_df['titles'].tolist()
train_labels = labels_list_train

eval_texts = test_df['titles'].tolist()
eval_labels = labels_list_test

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

train_encodings = tokenizer(train_texts, padding="max_length", truncation=True, max_length=512)
eval_encodings = tokenizer(eval_texts, padding="max_length", truncation=True, max_length=512)


class TextClassifierDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]).to(device) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float).to(device)
        return item

train_dataset = TextClassifierDataset(train_encodings, train_labels)
eval_dataset = TextClassifierDataset(eval_encodings, eval_labels)

title_model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    problem_type="multi_label_classification",
    num_labels=len(label_columns),
)

title_model = title_model.to(device)

training_arguments = TrainingArguments(
    output_dir=".",
    dataloader_pin_memory=False,
    eval_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
)

trainer = Trainer(
    model=title_model,
    args=training_arguments,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()




Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/2500 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


{'loss': 0.0906, 'grad_norm': 0.07519041001796722, 'learning_rate': 4e-05, 'epoch': 1.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.03673635795712471, 'eval_runtime': 8.8706, 'eval_samples_per_second': 112.732, 'eval_steps_per_second': 14.092, 'epoch': 1.0}
{'loss': 0.0328, 'grad_norm': 0.1309128701686859, 'learning_rate': 3e-05, 'epoch': 2.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.029773978516459465, 'eval_runtime': 8.6757, 'eval_samples_per_second': 115.264, 'eval_steps_per_second': 14.408, 'epoch': 2.0}
{'loss': 0.028, 'grad_norm': 0.10054031759500504, 'learning_rate': 2e-05, 'epoch': 3.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.028206994757056236, 'eval_runtime': 8.7041, 'eval_samples_per_second': 114.888, 'eval_steps_per_second': 14.361, 'epoch': 3.0}
{'loss': 0.026, 'grad_norm': 0.07997693866491318, 'learning_rate': 1e-05, 'epoch': 4.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.02796301431953907, 'eval_runtime': 8.6015, 'eval_samples_per_second': 116.258, 'eval_steps_per_second': 14.532, 'epoch': 4.0}
{'loss': 0.0248, 'grad_norm': 0.12335339933633804, 'learning_rate': 0.0, 'epoch': 5.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 0.027576344087719917, 'eval_runtime': 8.5862, 'eval_samples_per_second': 116.466, 'eval_steps_per_second': 14.558, 'epoch': 5.0}
{'train_runtime': 753.3915, 'train_samples_per_second': 26.547, 'train_steps_per_second': 3.318, 'train_loss': 0.04043272857666016, 'epoch': 5.0}


TrainOutput(global_step=2500, training_loss=0.04043272857666016, metrics={'train_runtime': 753.3915, 'train_samples_per_second': 26.547, 'train_steps_per_second': 3.318, 'total_flos': 5268505006080000.0, 'train_loss': 0.04043272857666016, 'epoch': 5.0})

In [None]:

from transformers.convert_graph_to_onnx import convert

title_model.save_pretrained("./models/title_model")
convert(
    framework="pt",
    model="./models/title_model",
    output="./models/title_model.onnx",
    opset=12,
    tokenizer=tokenizer,
)

In [30]:
title_model = AutoModelForSequenceClassification.from_pretrained(
    "./models/title_model",
    problem_type="multi_label_classification",
    num_labels=len(label_columns),
)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

train_texts = train_df['titles'].tolist()
train_labels = labels_list_train

eval_texts = test_df['titles'].tolist()
eval_labels = labels_list_test
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
title_model = title_model.to(device)


In [40]:
from torch.utils.data import DataLoader, Dataset
import torch
import numpy as np

class InferenceDataset(Dataset):
    """A simple Dataset that wraps a list of raw strings."""
    def __init__(self, texts, tokenizer, max_length=512):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # tokenize on the fly
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            padding="max_length"
        )
        # squeeze out the batch dim
        return {k: v.squeeze(0) for k, v in encoding.items()}


def collate_fn(batch):
    # Stack dicts of tensors into a single dict of batched tensors
    return {k: torch.stack([d[k] for d in batch]) for k in batch[0]}

inference_dataset = InferenceDataset(eval_texts, tokenizer)
dataloader = DataLoader(
        inference_dataset,
        batch_size=8,          # tune this to fill your GPU without OOM
        collate_fn=collate_fn,  # so varying lengths get padded properly
        pin_memory=False
    )
all_preds = []
with torch.no_grad():
    for batch in dataloader:
        # move to device
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = title_model(**batch)
        preds = outputs.logits.sigmoid()
        preds = preds.cpu().numpy()
        all_preds.extend(preds)
probability = all_preds
from sklearn.metrics import f1_score,  recall_score, precision_score, multilabel_confusion_matrix
threshold = 0.5
targets = eval_labels
outputs = np.array(probability) >= threshold
recall_micro = recall_score(targets,outputs, average="micro", zero_division=np.nan)
precision_micro = precision_score(targets, outputs, average="micro")


conf_matrix = multilabel_confusion_matrix(targets, outputs)


print("Recall", recall_micro)

print("Precision", precision_micro)

Recall 0.5873810716074112
Precision 0.8555798687089715


## Approach 1b: Using both Title and Summaries

In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

train_texts = train_df['title_and_summary'].tolist()
train_labels = labels_list_train

eval_texts = test_df['title_and_summary'].tolist()
eval_labels = labels_list_test

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

train_encodings = tokenizer(train_texts, padding="max_length", truncation=True, max_length=512)
eval_encodings = tokenizer(eval_texts, padding="max_length", truncation=True, max_length=512)


class TextClassifierDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]).to(device) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float).to(device)
        return item

train_dataset = TextClassifierDataset(train_encodings, train_labels)
eval_dataset = TextClassifierDataset(eval_encodings, eval_labels)

title_summary_model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    problem_type="multi_label_classification",
    num_labels=len(label_columns),
)

title_summary_model = title_summary_model.to(device)

training_arguments = TrainingArguments(
    output_dir=".",
    dataloader_pin_memory=False,
    eval_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
)

trainer = Trainer(
    model=title_summary_model,
    args=training_arguments,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/2500 [00:00<?, ?it/s]

In [None]:
from torch.utils.data import DataLoader, Dataset
import torch
import numpy as np

class InferenceDataset(Dataset):
    """A simple Dataset that wraps a list of raw strings."""
    def __init__(self, texts, tokenizer, max_length=512):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # tokenize on the fly
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            padding="max_length"
        )
        # squeeze out the batch dim
        return {k: v.squeeze(0) for k, v in encoding.items()}


def collate_fn(batch):
    # Stack dicts of tensors into a single dict of batched tensors
    return {k: torch.stack([d[k] for d in batch]) for k in batch[0]}

inference_dataset = InferenceDataset(eval_texts, tokenizer)
dataloader = DataLoader(
        inference_dataset,
        batch_size=8,          # tune this to fill your GPU without OOM
        collate_fn=collate_fn,  # so varying lengths get padded properly
        pin_memory=False
    )
all_preds = []
with torch.no_grad():
    for batch in dataloader:
        # move to device
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = title_summary_model(**batch)
        preds = outputs.logits.sigmoid()
        preds = preds.cpu().numpy()
        all_preds.extend(preds)
probability = all_preds
from sklearn.metrics import f1_score,  recall_score, precision_score, multilabel_confusion_matrix
threshold = 0.5
targets = eval_labels
outputs = np.array(probability) >= threshold
recall_micro = recall_score(targets,outputs, average="micro", zero_division=np.nan)
precision_micro = precision_score(targets, outputs, average="micro")


conf_matrix = multilabel_confusion_matrix(targets, outputs)


print("Recall", recall_micro)

print("Precision", precision_micro)

In [None]:
# Inference

In [None]:
from transformers.convert_graph_to_onnx import convert

model.save_pretrained("./bert_imdb")
convert(
    framework="pt",
    model="./bert_imdb",
    output="./bert_imdb.onnx",
    opset=12,
    tokenizer=tokenizer,
)

# Approach 2: Feed multilabel classification data from parent to child

# Approach 3: Have 2 models separate

# Approach 4: GNNs?