In [1]:
import os, time, pandas as pd
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments

# Append the configuration path
import bert_functions as bf

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Uncomment to preprocess datasets
# encoded_datasets = {name: preprocess_function(data, tokenizer) for name, data in all_datasets.items()}
# for name, dataset in encoded_datasets.items():
#     torch.save(dataset, f"encoded/{name}.pt")

# device = torch.device('mps')

encoded_dir = 'encoded'
loaded_datasets = bf.load_encoded_datasets(encoded_dir, bf.all_datasets)

for dataset_name, dataset in loaded_datasets.items():
    loaded_datasets[dataset_name] = bf.CustomDataset(dataset)

combined_train_dataset = bf.combine_datasets(loaded_datasets, "_train")
combined_val_dataset = bf.combine_datasets(loaded_datasets, "_val")
combined_test_dataset = bf.combine_datasets(loaded_datasets, "_test")

train_size = 16000
train_dataset = bf.CustomDataset(combined_train_dataset.data, size=train_size)
val_dataset = bf.CustomDataset(combined_val_dataset.data, size=train_size)
test_dataset = bf.CustomDataset(combined_test_dataset.data, size=train_size)

model = BertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=bf.compute_metrics,
)

trainer.train()

time_now = time.strftime("%H:%M:%S", time.localtime())
output_dir = os.path.join("models", f"size_{train_size}_{time_now}")
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

test_result = trainer.evaluate(eval_dataset=bf.CustomDataset(combined_test_dataset.data))
test_result = pd.DataFrame(test_result, index=[0])
print(test_result.T)

  from pandas.core import (


Dataset structured_Abt-Buy does not exist
Dataset textual_DBLP-ACM does not exist
Dataset textual_Amazon-Google does not exist
Dataset textual_Walmart-Amazon does not exist
Dataset textual_DBLP-GoogleScholar does not exist
Dataset textual_Fodors-Zagats does not exist
Dataset textual_Beer does not exist
Dataset textual_iTunes-Amazon does not exist
Dataset dirty_Abt-Buy does not exist
Dataset dirty_Amazon-Google does not exist
Dataset dirty_Fodors-Zagats does not exist
Dataset dirty_Beer does not exist


You are using a model of type distilbert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'embeddings.LayerNorm.bias', 'embeddings.LayerNorm.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.word_embeddings.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.self.value.weight', 'enc

  0%|          | 0/21 [00:00<?, ?it/s]

  item = {key: torch.tensor(val[idx]) for key, val in self.data.items()}


{'train_runtime': 53.2719, 'train_samples_per_second': 5.631, 'train_steps_per_second': 0.394, 'train_loss': 0.7007445380801246, 'epoch': 3.0}


  0%|          | 0/124 [00:00<?, ?it/s]

                                  0
eval_loss                  0.692539
eval_accuracy              0.499619
eval_f1                    0.002024
eval_precision             0.500000
eval_recall                0.001014
eval_runtime             601.912900
eval_samples_per_second   13.095000
eval_steps_per_second      0.206000
epoch                      3.000000
