In [None]:
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    BitsAndBytesConfig,
    TFAutoModel
)
import torch
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Connect to GPU (make sure you've selected GPU in Colab).
# The model below is trained with TensorFlow/Keras, so TensorFlow itself must see the
# GPU; a CUDA-enabled PyTorch alongside a CPU-only TensorFlow would otherwise train on CPU.
assert torch.cuda.is_available(), "Change runtime type to GPU!"
assert tf.config.list_physical_devices('GPU'), "TensorFlow cannot see a GPU!"
device = torch.device("cuda")
print(f"Using device: {device}")

In [None]:
# Pretrained BERT base (uncased) encoder as a TensorFlow model: 12 encoder layers, ~110M parameters.
BERT_CHECKPOINT = "bert-base-uncased"
model = TFAutoModel.from_pretrained(BERT_CHECKPOINT)

In [None]:
# Tokenizer for the same "bert-base-uncased" checkpoint the encoder was loaded from.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-uncased",
)

In [None]:
# Quick sanity check: show what the tokenizer produces for a tiny example.
sample_text = 'Hello world'
inputs = tokenizer(sample_text)
inputs

In [None]:
DATASET_ID = "Jinyan1/COLING_2025_MGT_en"

# Load the dataset (all splits; later cells use 'train' and 'dev') and print a summary.
dataset = load_dataset(DATASET_ID)
print(dataset)

In [None]:
# Display the loaded dataset again (the previous cell already printed it).
dataset

In [None]:
def tokenize(batch, max_length=256):
    """Tokenize a batch of examples for BERT.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        batch: Batched mapping from ``datasets``; ``batch['text']`` is presumably a
            list of strings (TODO confirm against the dataset schema).
        max_length: Length every sequence is padded and truncated to. Defaults to 256,
            the value this notebook has always used.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask`` and, for BERT,
        ``token_type_ids``) as a ``BatchEncoding`` of plain Python lists.
    """
    # No ``return_tensors``: ``Dataset.map`` stores results in Arrow regardless, so
    # building PyTorch tensors here is wasted work in this TensorFlow pipeline.
    return tokenizer(
        batch['text'],
        padding='max_length',
        truncation=True,
        max_length=max_length,
    )

In [None]:
# Tokenize every split in batches of 1000 rows across 4 worker processes. The raw
# 'text' column is dropped because its content now lives in the encoded columns.
map_options = {
    "batched": True,
    "batch_size": 1000,
    "num_proc": 4,
    "remove_columns": ['text'],
}
dataset_encoded = dataset.map(tokenize, **map_options)

In [None]:
# Display the tokenized dataset to confirm the encoding columns were added and 'text' removed.
dataset_encoded

In [None]:
def _split_to_tf(split_name, shuffle):
    """Convert one tokenized split into a batched ``tf.data.Dataset`` of (features, label).

    Features are the BERT encoding columns; the label comes from the 'label' column.
    """
    return dataset_encoded[split_name].to_tf_dataset(
        columns=['input_ids', 'attention_mask', 'token_type_ids'],
        label_cols=['label'],
        shuffle=shuffle,
        batch_size=32,
    )


# Training batches are shuffled; validation keeps a fixed order.
train_tf_dataset = _split_to_tf('train', shuffle=True)
val_tf_dataset = _split_to_tf('dev', shuffle=False)

In [None]:
import tensorflow as tf  # no-op if the top import cell already imported it


class BERTForClassification(tf.keras.Model):
    """BERT encoder followed by a softmax classification head on the pooled output.

    Args:
        bert_model: A Hugging Face TensorFlow BERT model (here the ``TFAutoModel``
            loaded earlier) whose output exposes ``pooler_output``.
        num_classes: Number of output classes of the dense head.
    """

    def __init__(self, bert_model, num_classes):
        super().__init__()
        self.bert = bert_model
        # Softmax outputs probabilities, which matches a SparseCategoricalCrossentropy
        # loss with its default from_logits=False.
        self.fc = tf.keras.layers.Dense(num_classes, activation='softmax')

    def call(self, inputs, training=False):
        """Return class probabilities for a batch.

        Args:
            inputs: Dict of encoded features ('input_ids', 'attention_mask',
                'token_type_ids') as produced by ``to_tf_dataset``.
            training: Forwarded to BERT so its dropout is active only while training;
                Keras sets it to True inside ``fit``.

        Returns:
            Tensor of class probabilities, presumably shape (batch, num_classes).
        """
        outputs = self.bert(inputs, training=training)
        # Named field instead of positional ``[1]``: the pooled [CLS] representation.
        return self.fc(outputs.pooler_output)

In [None]:
# Build a 2-class classifier on top of the pretrained encoder, then compile it.
classifier = BERTForClassification(model, num_classes=2)

# A small learning rate is typical when fine-tuning a pretrained encoder.
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
# from_logits=False (the default) because the head already applies softmax.
sparse_ce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

classifier.compile(
    optimizer=adam_optimizer,
    loss=sparse_ce_loss,
    metrics=['accuracy'],
)

In [None]:
# Train the model, checkpointing the weights of the best epoch by validation accuracy.
history = classifier.fit(
    train_tf_dataset,
    validation_data=val_tf_dataset,
    epochs=3,  # You may want to adjust this
    callbacks=[
        # NOTE(review): with patience=3 and only 3 epochs, early stopping can never trigger.
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3),
        # Subclassed Keras models cannot be saved as full HDF5 models (model.save raises),
        # so save weights only; the '.weights.h5' suffix is also what Keras 3 requires.
        tf.keras.callbacks.ModelCheckpoint(
            'best_model.weights.h5',
            monitor='val_accuracy',
            save_best_only=True,
            save_weights_only=True,
        ),
    ],
)

In [None]:
# NOTE(review): this cell only opens an empty 12x4 figure. The training-history plot in
# the next cell creates its own figure, so this looks like a leftover that can be removed.
plt.figure(figsize=(12, 4))

In [None]:
# NOTE(review): this calls fit() again on the *same* classifier, so training continues
# from the weights left by the previous fit cell instead of starting fresh. `history`
# is overwritten and only covers this run.
history = classifier.fit(
    train_tf_dataset,
    validation_data=val_tf_dataset,
    epochs=5,  # You may want to adjust this
    callbacks=[
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3),
        # Subclassed Keras models cannot be saved as full HDF5 models (model.save raises),
        # so save weights only; the '.weights.h5' suffix is also what Keras 3 requires.
        tf.keras.callbacks.ModelCheckpoint(
            'best_model.weights.h5',
            monitor='val_accuracy',
            save_best_only=True,
            save_weights_only=True,
        ),
    ],
)


def plot_history(hist):
    """Plot train vs. validation accuracy and loss side by side.

    Args:
        hist: The ``History`` object returned by ``Model.fit``. Its ``history`` dict
            must contain 'accuracy', 'val_accuracy', 'loss' and 'val_loss'.

    Returns:
        The matplotlib ``Figure`` holding both panels.
    """
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))
    panels = [
        (acc_ax, 'accuracy', 'Model Accuracy', 'Accuracy'),
        (loss_ax, 'loss', 'Model Loss', 'Loss'),
    ]
    for ax, metric, title, ylabel in panels:
        ax.plot(hist.history[metric])
        ax.plot(hist.history[f'val_{metric}'])
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Epoch')
        ax.legend(['Train', 'Validation'], loc='upper left')
    fig.tight_layout()
    return fig


# Plot training history (of this second fit only)
plot_history(history)
plt.show()

In [None]:
# Evaluate on validation set.
# return_dict=True keys the metrics by name, so this does not depend on the order of
# compile()'s loss/metrics.
results = classifier.evaluate(val_tf_dataset, return_dict=True)
print(f"Validation Loss: {results['loss']:.4f}")
print(f"Validation Accuracy: {results['accuracy']:.4f}")