In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import XLNetTokenizer, XLNetForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback
import torch
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.utils.class_weight import compute_class_weight




In [2]:
# Clear CUDA cache
torch.cuda.empty_cache()

# Run garbage collection
import gc
gc.collect()

# Verify the GPU is cleared
torch.cuda.empty_cache()

In [3]:
# Load the dataset
df = pd.read_csv("combined_output.csv", on_bad_lines='skip')

In [4]:
df.head()

Unnamed: 0,original_id,source_name,title,url,body,date_published,language,date_modified,author_list,images,description,sentiment,emotions,entities,quotations,prValues,clipping,label,category
0,'5404c9d2fd24852afa122f2cc01cb3acba3c5d05b682d...,'kabargayo.com,'Kredit sepeda motor bisa terbayar jika Anda m...,'https://www.kabargayo.com/2024/09/19/kredit-s...,"'Jakarta, VIVA – Pembayaran kredit sepeda moto...",19/09/2024 22.32,'id,19/09/2024 22.32,'Aldi Hadad,'https://i1.wp.com/thumb.viva.co.id/media/fron...,,'positive,,"'Hari Pembayaran Berbayar atau Harcilnas 2024,...",,5250000,Adira,,
1,'e1e3f8d68b58568e8217b7562d48de634fceb0d837135...,'viva.co.id,'Kredit Motor Bisa Lunas Jika Bayar Cicilan Te...,'https://www.viva.co.id/otomotif/tips/1753596-...,"'Jakarta, VIVA – Cicilan kredit motor yang ser...",19/09/2024 22.30,'id,19/09/2024 22.30,'Krisna Wicaksono,'https://thumb.viva.co.id/media/frontend/thumb...,,'positive,,"'Harinya Cicilan Lunas,2024,PT Adira Dinamika ...",,5250000,Adira,,
2,'dca74b8fa4eabf60cebfa7b811ecb385872a0fd301eaf...,'kabarmegapolitan.pikiran-rakyat.com,'Adira Finance Umumkan Pemenang HARCILNAS 2024...,'https://kabarmegapolitan.pikiran-rakyat.com/b...,'KABARMEGAPOLITAN.com - PT Adira Dinamika Mult...,19/09/2024 21.45,'id,19/09/2024 21.45,'Yuliansyah,'https://assets.pikiran-rakyat.com/www/network...,,'positive,'HARCILNAS merupakan wujud apresiasi kami kepa...,"'PT Adira Dinamika,Cicilan Lunas HARCILNAS,12 ...","(Person :Tania Endah Budhi ,Quote : HARCILNAS ...",5250000,Adira,,
3,'56c73e2a6d254a17a5cc21dee7ed0b4660c3af70c093b...,'banggairaya.id,"'Dapatkan Promo Menarik, Yamaha Prima Motor Ra...",'https://banggairaya.id/dapatkan-promo-menarik...,'BANGGAI RAYA- Yamaha Prima Motor ramaikan pam...,19/09/2024 19.45,'id,19/09/2024 19.45,'Chikal Connect,'https://i0.wp.com/banggairaya.id/wp-content/u...,,'neutral,,"'RAYA- Yamaha Prima Motor,Banggai Goverment Ex...",,5250000,Adira,,
4,'1cd6c6db60224b6ee5f49cd5d6c62cd2850d9f3255721...,'jakarta.tribunnews.com,"'Sindikat Penipuan Leasing, Satu Bulan Ajukan ...",'https://jakarta.tribunnews.com/2024/09/19/sin...,'Laporan wartawan TribunJakarta.com Yusuf Bach...,19/09/2024 18.43,'id,19/09/2024 18.43,'Yusuf Bachtiar,'https://asset-2.tstatic.net/jakarta/foto/bank...,,'neutral,'Pelaku ini melakukan pembiayaan pembelian ken...,"'Yusuf Bachtiar TRIBUNJAKARTACOM,MEDAN,SATRIA,...","(Person :Dedi ,Quote : Pelaku ini melakukan pe...",5250000,Adira,,


In [5]:
# Clean extra characters from all columns in the DataFrame
for column in df.columns:
    if df[column].dtype == 'object':  # Check if the column is of string type
        df[column] = df[column].str.strip("'")  # Remove extra characters

# Filter relevant columns
df_filtered = df[['body', 'sentiment']].dropna()


In [6]:
df.head()

Unnamed: 0,original_id,source_name,title,url,body,date_published,language,date_modified,author_list,images,description,sentiment,emotions,entities,quotations,prValues,clipping,label,category
0,5404c9d2fd24852afa122f2cc01cb3acba3c5d05b682d4...,kabargayo.com,Kredit sepeda motor bisa terbayar jika Anda me...,https://www.kabargayo.com/2024/09/19/kredit-se...,"Jakarta, VIVA – Pembayaran kredit sepeda motor...",19/09/2024 22.32,id,19/09/2024 22.32,Aldi Hadad,https://i1.wp.com/thumb.viva.co.id/media/front...,,positive,,"Hari Pembayaran Berbayar atau Harcilnas 2024,P...",,5250000,Adira,,
1,e1e3f8d68b58568e8217b7562d48de634fceb0d8371356...,viva.co.id,Kredit Motor Bisa Lunas Jika Bayar Cicilan Tep...,https://www.viva.co.id/otomotif/tips/1753596-k...,"Jakarta, VIVA – Cicilan kredit motor yang seri...",19/09/2024 22.30,id,19/09/2024 22.30,Krisna Wicaksono,https://thumb.viva.co.id/media/frontend/thumbs...,,positive,,"Harinya Cicilan Lunas,2024,PT Adira Dinamika M...",,5250000,Adira,,
2,dca74b8fa4eabf60cebfa7b811ecb385872a0fd301eaf5...,kabarmegapolitan.pikiran-rakyat.com,Adira Finance Umumkan Pemenang HARCILNAS 2024:...,https://kabarmegapolitan.pikiran-rakyat.com/bi...,KABARMEGAPOLITAN.com - PT Adira Dinamika Multi...,19/09/2024 21.45,id,19/09/2024 21.45,Yuliansyah,https://assets.pikiran-rakyat.com/www/network/...,,positive,HARCILNAS merupakan wujud apresiasi kami kepad...,"PT Adira Dinamika,Cicilan Lunas HARCILNAS,12 p...","(Person :Tania Endah Budhi ,Quote : HARCILNAS ...",5250000,Adira,,
3,56c73e2a6d254a17a5cc21dee7ed0b4660c3af70c093b2...,banggairaya.id,"Dapatkan Promo Menarik, Yamaha Prima Motor Ram...",https://banggairaya.id/dapatkan-promo-menarik-...,BANGGAI RAYA- Yamaha Prima Motor ramaikan pame...,19/09/2024 19.45,id,19/09/2024 19.45,Chikal Connect,https://i0.wp.com/banggairaya.id/wp-content/up...,,neutral,,"RAYA- Yamaha Prima Motor,Banggai Goverment Exp...",,5250000,Adira,,
4,1cd6c6db60224b6ee5f49cd5d6c62cd2850d9f3255721d...,jakarta.tribunnews.com,"Sindikat Penipuan Leasing, Satu Bulan Ajukan K...",https://jakarta.tribunnews.com/2024/09/19/sind...,Laporan wartawan TribunJakarta.com Yusuf Bacht...,19/09/2024 18.43,id,19/09/2024 18.43,Yusuf Bachtiar,https://asset-2.tstatic.net/jakarta/foto/bank/...,,neutral,Pelaku ini melakukan pembiayaan pembelian kend...,"Yusuf Bachtiar TRIBUNJAKARTACOM,MEDAN,SATRIA,T...","(Person :Dedi ,Quote : Pelaku ini melakukan pe...",5250000,Adira,,


In [7]:
# Balance dataset classes
class_counts = df_filtered['sentiment'].value_counts()
min_class = class_counts.min()
df_balanced = df_filtered.groupby('sentiment').apply(lambda x: x.sample(min_class)).reset_index(drop=True)

In [8]:
# Ensure balanced classes during split
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df_filtered['body'].tolist(), df_filtered['sentiment'].tolist(),
    test_size=0.2, random_state=42, stratify=df_filtered['sentiment']
)

In [9]:

# Load tokenizer and tokenize texts
tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=512)
val_encodings = tokenizer(val_texts, truncation=True, padding=True, max_length=512)


In [10]:
# Map sentiment labels to integers
label_mapping = {'positive': 2, 'neutral': 1, 'negative': 0}
train_labels = [label_mapping[label] for label in train_labels]
val_labels = [label_mapping[label] for label in val_labels]

In [11]:
class SentimentDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


In [12]:
# Create datasets
train_dataset = SentimentDataset(train_encodings, train_labels)
val_dataset = SentimentDataset(val_encodings, val_labels)

In [13]:
# Check if GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [14]:
# Load XLNet model for sequence classification
model = XLNetForSequenceClassification.from_pretrained("xlnet-base-cased", num_labels=3)

Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['logits_proj.bias', 'logits_proj.weight', 'sequence_summary.summary.bias', 'sequence_summary.summary.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
class_weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

# Adjust model's loss function to include class weights
model.config.problem_type = "single_label_classification"
model.config.class_weights = class_weights.tolist()

In [16]:
from sklearn.metrics import accuracy_score
import numpy as np

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    acc = accuracy_score(labels, preds)
    return {'eval_accuracy': acc}


In [17]:
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="steps",  # Evaluate more frequently
    eval_steps=200,              # Adjust based on dataset size
    save_strategy="steps",
    save_steps=200,
    per_device_train_batch_size=8,  # Lower batch size for better generalization
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,  # Simulates a larger batch size
    num_train_epochs=6,            # Train for more epochs
    learning_rate=1e-5,            # Lower learning rate for better fine-tuning
    weight_decay=0.01,
    warmup_steps=500,              # Gradual learning rate increase
    logging_dir="./logs",
    logging_steps=50,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    greater_is_better=True,
    lr_scheduler_type="cosine_with_restarts",  # More effective for fine-tuning
    fp16=True,                   # Enable mixed precision training
)




In [18]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,  # Include custom metrics
)


In [19]:
import torch

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

In [20]:
# Train the model
trainer.train()

  0%|          | 0/2700 [00:00<?, ?it/s]

{'loss': 1.1042, 'grad_norm': 23.382308959960938, 'learning_rate': 9.200000000000001e-07, 'epoch': 0.11}
{'loss': 1.0486, 'grad_norm': 32.777462005615234, 'learning_rate': 1.9000000000000002e-06, 'epoch': 0.22}
{'loss': 1.0547, 'grad_norm': 19.98434829711914, 'learning_rate': 2.9e-06, 'epoch': 0.33}
{'loss': 1.0455, 'grad_norm': 20.777698516845703, 'learning_rate': 3.900000000000001e-06, 'epoch': 0.44}


  0%|          | 0/113 [00:00<?, ?it/s]

{'eval_accuracy': 0.5008328706274292, 'eval_loss': 1.0246728658676147, 'eval_runtime': 32.9545, 'eval_samples_per_second': 54.651, 'eval_steps_per_second': 3.429, 'epoch': 0.44}
{'loss': 1.0223, 'grad_norm': 15.300000190734863, 'learning_rate': 4.9000000000000005e-06, 'epoch': 0.56}
{'loss': 1.0248, 'grad_norm': 18.869728088378906, 'learning_rate': 5.9e-06, 'epoch': 0.67}
{'loss': 1.0459, 'grad_norm': 19.955629348754883, 'learning_rate': 6.9e-06, 'epoch': 0.78}
{'loss': 1.0, 'grad_norm': 41.35318374633789, 'learning_rate': 7.9e-06, 'epoch': 0.89}


  0%|          | 0/113 [00:00<?, ?it/s]

{'eval_accuracy': 0.47695724597445865, 'eval_loss': 0.9954295754432678, 'eval_runtime': 32.501, 'eval_samples_per_second': 55.414, 'eval_steps_per_second': 3.477, 'epoch': 0.89}
{'loss': 0.9753, 'grad_norm': 40.34638214111328, 'learning_rate': 8.900000000000001e-06, 'epoch': 1.0}
{'loss': 0.9325, 'grad_norm': 25.36160659790039, 'learning_rate': 9.9e-06, 'epoch': 1.11}
{'loss': 0.9153, 'grad_norm': 30.87482261657715, 'learning_rate': 9.989680231156983e-06, 'epoch': 1.22}
{'loss': 0.8616, 'grad_norm': 20.45682144165039, 'learning_rate': 9.954061643455524e-06, 'epoch': 1.33}


  0%|          | 0/113 [00:00<?, ?it/s]

{'eval_accuracy': 0.6141032759578012, 'eval_loss': 0.8205902576446533, 'eval_runtime': 32.2781, 'eval_samples_per_second': 55.796, 'eval_steps_per_second': 3.501, 'epoch': 1.33}
{'loss': 0.7681, 'grad_norm': 15.865717887878418, 'learning_rate': 9.893198293191539e-06, 'epoch': 1.44}
{'loss': 0.799, 'grad_norm': 26.567745208740234, 'learning_rate': 9.807400326046843e-06, 'epoch': 1.56}
{'loss': 0.7504, 'grad_norm': 21.043394088745117, 'learning_rate': 9.697104948795353e-06, 'epoch': 1.67}
{'loss': 0.6939, 'grad_norm': 36.569339752197266, 'learning_rate': 9.56287420139758e-06, 'epoch': 1.78}


  0%|          | 0/113 [00:00<?, ?it/s]

{'eval_accuracy': 0.6873958911715713, 'eval_loss': 0.7148331999778748, 'eval_runtime': 32.3717, 'eval_samples_per_second': 55.635, 'eval_steps_per_second': 3.491, 'epoch': 1.78}
{'loss': 0.7048, 'grad_norm': 12.172266960144043, 'learning_rate': 9.405392092973824e-06, 'epoch': 1.89}
{'loss': 0.6875, 'grad_norm': 17.792810440063477, 'learning_rate': 9.225461116250483e-06, 'epoch': 2.0}
{'loss': 0.5852, 'grad_norm': 59.38424301147461, 'learning_rate': 9.023998158241067e-06, 'epoch': 2.11}
{'loss': 0.5923, 'grad_norm': 21.02001953125, 'learning_rate': 8.802029828000157e-06, 'epoch': 2.22}


  0%|          | 0/113 [00:00<?, ?it/s]

{'eval_accuracy': 0.7523598001110494, 'eval_loss': 0.6578989028930664, 'eval_runtime': 32.343, 'eval_samples_per_second': 55.684, 'eval_steps_per_second': 3.494, 'epoch': 2.22}
{'loss': 0.5767, 'grad_norm': 10.104955673217773, 'learning_rate': 8.56068722525896e-06, 'epoch': 2.33}
{'loss': 0.5785, 'grad_norm': 15.702975273132324, 'learning_rate': 8.301200176600375e-06, 'epoch': 2.44}
{'loss': 0.586, 'grad_norm': 29.188926696777344, 'learning_rate': 8.024890968544614e-06, 'epoch': 2.56}
{'loss': 0.6245, 'grad_norm': 24.02370262145996, 'learning_rate': 7.73316760948019e-06, 'epoch': 2.67}


  0%|          | 0/113 [00:00<?, ?it/s]

{'eval_accuracy': 0.7640199888950583, 'eval_loss': 0.5926749110221863, 'eval_runtime': 32.2066, 'eval_samples_per_second': 55.92, 'eval_steps_per_second': 3.509, 'epoch': 2.67}
{'loss': 0.579, 'grad_norm': 16.111665725708008, 'learning_rate': 7.427516654775921e-06, 'epoch': 2.78}
{'loss': 0.5794, 'grad_norm': 21.263164520263672, 'learning_rate': 7.109495631635512e-06, 'epoch': 2.89}
{'loss': 0.6066, 'grad_norm': 17.195037841796875, 'learning_rate': 6.780725102295949e-06, 'epoch': 3.0}
{'loss': 0.4771, 'grad_norm': 24.058551788330078, 'learning_rate': 6.442880406013795e-06, 'epoch': 3.11}


  0%|          | 0/113 [00:00<?, ?it/s]

{'eval_accuracy': 0.7684619655746807, 'eval_loss': 0.5953294038772583, 'eval_runtime': 32.3215, 'eval_samples_per_second': 55.721, 'eval_steps_per_second': 3.496, 'epoch': 3.11}
{'loss': 0.4997, 'grad_norm': 16.495521545410156, 'learning_rate': 6.097683121920373e-06, 'epoch': 3.22}
{'loss': 0.4951, 'grad_norm': 23.68714141845703, 'learning_rate': 5.746892296249126e-06, 'epoch': 3.33}
{'loss': 0.5417, 'grad_norm': 25.39212989807129, 'learning_rate': 5.392295478639226e-06, 'epoch': 3.44}
{'loss': 0.4782, 'grad_norm': 19.695693969726562, 'learning_rate': 5.035699613192348e-06, 'epoch': 3.56}


  0%|          | 0/113 [00:00<?, ?it/s]

{'eval_accuracy': 0.7773459189339256, 'eval_loss': 0.6148065328598022, 'eval_runtime': 32.1335, 'eval_samples_per_second': 56.047, 'eval_steps_per_second': 3.517, 'epoch': 3.56}
{'loss': 0.4779, 'grad_norm': 20.864213943481445, 'learning_rate': 4.678921830699724e-06, 'epoch': 3.67}
{'loss': 0.4783, 'grad_norm': 11.84264850616455, 'learning_rate': 4.323780188960156e-06, 'epoch': 3.78}
{'loss': 0.5104, 'grad_norm': 22.134048461914062, 'learning_rate': 3.972084408374198e-06, 'epoch': 3.89}
{'loss': 0.4491, 'grad_norm': 19.352046966552734, 'learning_rate': 3.6256266500238312e-06, 'epoch': 4.0}


  0%|          | 0/113 [00:00<?, ?it/s]

{'eval_accuracy': 0.7656857301499167, 'eval_loss': 0.6262772083282471, 'eval_runtime': 32.2565, 'eval_samples_per_second': 55.834, 'eval_steps_per_second': 3.503, 'epoch': 4.0}
{'loss': 0.4101, 'grad_norm': 30.803600311279297, 'learning_rate': 3.286172383230388e-06, 'epoch': 4.11}
{'loss': 0.4047, 'grad_norm': 21.316173553466797, 'learning_rate': 2.955451389127567e-06, 'epoch': 4.22}
{'loss': 0.3955, 'grad_norm': 22.751741409301758, 'learning_rate': 2.6351489460932815e-06, 'epoch': 4.33}
{'loss': 0.4079, 'grad_norm': 18.1988525390625, 'learning_rate': 2.326897241957348e-06, 'epoch': 4.44}


  0%|          | 0/113 [00:00<?, ?it/s]

{'eval_accuracy': 0.77623542476402, 'eval_loss': 0.6572760939598083, 'eval_runtime': 32.2918, 'eval_samples_per_second': 55.773, 'eval_steps_per_second': 3.499, 'epoch': 4.44}
{'loss': 0.4461, 'grad_norm': 23.991914749145508, 'learning_rate': 2.0322670567464304e-06, 'epoch': 4.56}
{'loss': 0.3862, 'grad_norm': 12.776251792907715, 'learning_rate': 1.7527597583490825e-06, 'epoch': 4.67}
{'loss': 0.4158, 'grad_norm': 27.843671798706055, 'learning_rate': 1.4897996518891328e-06, 'epoch': 4.78}
{'loss': 0.3869, 'grad_norm': 14.365665435791016, 'learning_rate': 1.2447267217932508e-06, 'epoch': 4.89}


  0%|          | 0/113 [00:00<?, ?it/s]

{'eval_accuracy': 0.7695724597445863, 'eval_loss': 0.6653987169265747, 'eval_runtime': 32.3748, 'eval_samples_per_second': 55.63, 'eval_steps_per_second': 3.49, 'epoch': 4.89}
{'loss': 0.3835, 'grad_norm': 19.2109375, 'learning_rate': 1.0187898035374683e-06, 'epoch': 5.0}
{'loss': 0.413, 'grad_norm': 11.97502613067627, 'learning_rate': 8.131402198678373e-07, 'epoch': 5.11}
{'loss': 0.3191, 'grad_norm': 33.33523941040039, 'learning_rate': 6.28825913923638e-07, 'epoch': 5.22}
{'loss': 0.3366, 'grad_norm': 25.26031494140625, 'learning_rate': 4.667861091593434e-07, 'epoch': 5.33}


  0%|          | 0/113 [00:00<?, ?it/s]

{'eval_accuracy': 0.7729039422543031, 'eval_loss': 0.6425918340682983, 'eval_runtime': 32.3772, 'eval_samples_per_second': 55.626, 'eval_steps_per_second': 3.49, 'epoch': 5.33}
{'loss': 0.3254, 'grad_norm': 16.377967834472656, 'learning_rate': 3.2784652327718544e-07, 'epoch': 5.44}
{'loss': 0.3603, 'grad_norm': 10.130207061767578, 'learning_rate': 2.1271516055901774e-07, 'epoch': 5.56}
{'loss': 0.335, 'grad_norm': 19.347028732299805, 'learning_rate': 1.2197870403878375e-07, 'epoch': 5.67}
{'loss': 0.3595, 'grad_norm': 29.1643009185791, 'learning_rate': 5.6099525900252805e-08, 'epoch': 5.78}


  0%|          | 0/113 [00:00<?, ?it/s]

{'eval_accuracy': 0.7712382009994447, 'eval_loss': 0.6537501811981201, 'eval_runtime': 32.3606, 'eval_samples_per_second': 55.654, 'eval_steps_per_second': 3.492, 'epoch': 5.78}
{'loss': 0.3643, 'grad_norm': 10.634465217590332, 'learning_rate': 1.541333133436018e-08, 'epoch': 5.89}
{'loss': 0.3826, 'grad_norm': 15.367070198059082, 'learning_rate': 1.2744786250407092e-10, 'epoch': 6.0}
{'train_runtime': 2740.192, 'train_samples_per_second': 15.765, 'train_steps_per_second': 0.985, 'train_loss': 0.6107909446292453, 'epoch': 6.0}


TrainOutput(global_step=2700, training_loss=0.6107909446292453, metrics={'train_runtime': 2740.192, 'train_samples_per_second': 15.765, 'train_steps_per_second': 0.985, 'total_flos': 1.2306930130944e+16, 'train_loss': 0.6107909446292453, 'epoch': 6.0})

In [21]:
print(f"Number of training examples: {len(train_dataset)}")
print(f"Batch size: {training_args.per_device_train_batch_size}")


Number of training examples: 7200
Batch size: 8


In [22]:
# Save the trained model and tokenizer
model.save_pretrained('./xlnet_model3')
tokenizer.save_pretrained('./xlnet_model3')

('./xlnet_model3\\tokenizer_config.json',
 './xlnet_model3\\special_tokens_map.json',
 './xlnet_model3\\spiece.model',
 './xlnet_model3\\added_tokens.json')

In [23]:
# Evaluate the model
results = trainer.evaluate()

  0%|          | 0/113 [00:00<?, ?it/s]

In [24]:
print(results)

{'eval_accuracy': 0.7773459189339256, 'eval_loss': 0.6148065328598022, 'eval_runtime': 32.4119, 'eval_samples_per_second': 55.566, 'eval_steps_per_second': 3.486, 'epoch': 6.0}


In [25]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report

# After training, get predictions on the validation dataset
predictions = trainer.predict(val_dataset)

# Get the predicted logits
logits = predictions.predictions

# Convert logits to predicted class labels
predicted_labels = np.argmax(logits, axis=1)

# Get the true labels
true_labels = val_labels

# Calculate evaluation metrics
accuracy = accuracy_score(true_labels, predicted_labels)
f1 = f1_score(true_labels, predicted_labels, average='weighted')
precision = precision_score(true_labels, predicted_labels, average='weighted')
recall = recall_score(true_labels, predicted_labels, average='weighted')

# Display metrics
print(f"Accuracy: {accuracy:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")

# Detailed classification report
print("\nClassification Report:")
print(classification_report(true_labels, predicted_labels))


  0%|          | 0/113 [00:00<?, ?it/s]

Accuracy: 0.7773
F1 Score: 0.7757
Precision: 0.7781
Recall: 0.7773

Classification Report:
              precision    recall  f1-score   support

           0       0.79      0.63      0.70       338
           1       0.79      0.81      0.80       836
           2       0.76      0.81      0.78       627

    accuracy                           0.78      1801
   macro avg       0.78      0.75      0.76      1801
weighted avg       0.78      0.78      0.78      1801



In [26]:
# Check if GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda
