# loading pretrained tokenizer and text classification model

In [1]:
import torch
from transformers import BertTokenizerFast, BertForSequenceClassification

# Load the sequence-classification model and its tokenizer from relative
# paths (presumably local checkpoint directories -- confirm they exist
# relative to the notebook's working directory).
model = BertForSequenceClassification.from_pretrained(
    'model/BertForSequenceClassification/bert-base-uncased-emotion')
tokenizer = BertTokenizerFast.from_pretrained(
    'tokenizer/BertTokenizerFast/bert-base-uncased')
# Prefer the first GPU, but fall back to CPU so that later `.to(device)`
# calls do not fail on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


# defining custom Dataset Class

In [2]:
import torch
from transformers import PreTrainedTokenizerBase
import pandas as pd


class TokenizedDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes a text/label table once, up front.

    The table must have a ``text`` column (handed to the tokenizer as a list)
    and a ``label`` column (converted to a tensor of targets).

    Each item is a dict holding one row of every tokenizer output plus two
    target keys, ``labels`` and ``label``, which both carry the row's real
    label. ``labels`` is the keyword the model's forward pass consumes;
    ``label`` is kept so existing code that reads it keeps working.
    """

    def __init__(self, path: str = None, df: pd.DataFrame = None, tokenizer: "PreTrainedTokenizerBase" = None):
        """Load the table and tokenize all texts.

        Args:
            path: CSV file to read. When given, it takes precedence over ``df``.
            df: Already-loaded table, used only when ``path`` is None.
            tokenizer: Callable invoked as
                ``tokenizer(texts, padding=True, truncation=True, return_tensors='pt')``
                and expected to return a mapping of tensors.

        Raises:
            TypeError: If no tokenizer is supplied.
            ValueError: If neither ``path`` nor ``df`` yields a non-empty table.
        """
        if tokenizer is None:
            raise TypeError("TokenizedDataset requires a tokenizer")
        if path is not None:
            df = pd.read_csv(path)
        # Covers "nothing passed at all" as well as an empty frame / empty CSV.
        if df is None or df.empty:
            raise ValueError("no data: pass a CSV `path` or a non-empty `df`")

        self._encodings = tokenizer(df['text'].to_list(), padding=True, truncation=True, return_tensors='pt')
        # Plain Python values avoid any dependence on the frame's index.
        self._labels = torch.tensor(df['label'].to_list())
        # Expose the real targets under `labels`. A placeholder here would be
        # what the model actually trains on, because `labels` is the key that
        # ends up in the model call.
        self._encodings['labels'] = self._labels

    def __len__(self):
        """Return the number of examples (one per label)."""
        return len(self._labels)

    def __getitem__(self, index):
        """Return the encoded example at ``index`` as a dict of tensors."""
        item = {key: val[index] for key, val in self._encodings.items()}
        item['label'] = self._labels[index]
        return item



In [10]:
from torch.utils.data import DataLoader
# Tokenize the whole CSV once and keep it as the training dataset.
train = TokenizedDataset(
    tokenizer=tokenizer,
    df=pd.read_csv('dataset/dataset.csv'),
)

# Shuffled mini-batches of 64 examples.
train_loader = DataLoader(dataset=train, shuffle=True, batch_size=64)

# Sanity check: dataset size and the keys present in a single item.
print(len(train), train[0].keys())

19135 dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels', 'label'])


# using Trainer/TrainingArguments

In [None]:
import os
from transformers import TrainingArguments, Trainer
import torch_ort

# Wrap the model in ORTModule, move it to the selected device and switch it
# to training mode before handing it to the Trainer.
ort_model = torch_ort.ORTModule(model)
ort_model.to(device=device)
ort_model.train()

# Optional: uncomment to turn off tokenizers' internal parallelism.
# os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Checkpoints and logs are written under "trainer"; run for 50 epochs.
training_args = TrainingArguments(
    output_dir="trainer",
    num_train_epochs=50,
)

# NOTE(review): the same dataset is passed for training and evaluation.
trainer = Trainer(
    model=ort_model,
    args=training_args,
    train_dataset=train,
    eval_dataset=train,
)

trainer.train()


# custom trainer

In [10]:
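# NOTE: this manual training loop is disabled (fully commented out); the
# Trainer-based cell above is the active training path.
# Before re-enabling it: batches from train_loader are never moved to `device`,
# and `criterion`, the popped `label` and `inputs` are assigned but never used.
# import torch_ort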

# ort_model = torch_ort.ORTModule(model)
# ort_model.to(device=device)
# ort_model.train()

# criterion = torch.nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(ort_model.parameters(), lr=1e-4)

# max_epochs = 20

# for epoch in range(max_epochs):
#     for data in train_loader:
#         label = data['label']
#         data.pop('label')
#         inputs = data
#         optimizer.zero_grad()
#         outputs = ort_model(**data)
#         outputs.loss.backward()
#         optimizer.step()
#         print(f'loss : {float(outputs.loss)}')
