# In [None]:
import pandas as pd
import numpy as np
import re
from cleantext import clean
import torch
from transformers import TrainingArguments, Trainer, pipeline
from transformers import XLNetTokenizer, XLNetForSequenceClassification
import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import random
import evaluate

# Load the three CSV files. They are merged into a single frame here and
# re-partitioned later with train_test_split, so the on-disk partitioning is
# not what the model ends up being trained/evaluated on.
test_data = pd.read_csv("./emotions_data/emotion-labels-test.csv")
train_data = pd.read_csv("./emotions_data/emotion-labels-train.csv")
val_data = pd.read_csv("./emotions_data/emotion-labels-val.csv")

# ignore_index=True gives the merged frame a unique 0..n-1 index instead of
# repeating each file's own row numbers.
data = pd.concat([test_data, train_data, val_data], ignore_index=True)

data.head()

# First cleaning pass: run `clean` with emoji removal switched on.
# NOTE(review): any further normalisation comes from `clean`'s own defaults,
# which are not visible in this file -- confirm they are what is wanted.
data["clean-text"] = data["text"].apply(lambda x: clean(x, no_emoji=True))

# Second cleaning pass: replace every character that is neither a word
# character nor whitespace with a single space.  The pattern is a raw string
# so that \w and \s reach the regex engine without relying on Python passing
# unknown string escapes through (which emits a warning on current versions).
data["clean-text"] = data["clean-text"].apply(
    lambda x: re.sub(r"[^\w\s]", " ", x)
)

# Class distribution before balancing.
data["label"].value_counts().plot(kind="bar")

# Size of the rarest class; every class is down-sampled to this many rows.
min_label = int(data.groupby("label").size().min())

# GroupBy.sample draws `n` rows from each label group and keeps all columns
# (including "label").  random_state=42 matches the seed used by the later
# train_test_split / shuffle calls so the whole pipeline is reproducible.
data = (
    data.groupby("label")
    .sample(n=min_label, random_state=42)
    .reset_index(drop=True)
)

# Class distribution after balancing (all bars should now be equal).
data["label"].value_counts().plot(kind="bar")

# Map the string labels to integer ids.  The fitted encoder is kept in
# `encode` so the id <-> name mapping can be recovered from encode.classes_.
encode = LabelEncoder()

data["label-int"] = encode.fit_transform(data["label"])

data.head()

def _text_label_frame(split):
    """Return a two-column frame ("text", "label") built from one split."""
    return pd.DataFrame(
        {"text": split["clean-text"], "label": split["label-int"]}
    )


# Hold out 20% of the rows as the test split, then carve 10% of the remainder
# off as the validation split.  Both calls use the same fixed seed.
train_val_data, test_data = train_test_split(
    data, test_size=0.2, random_state=42
)
train_data, val_data = train_test_split(
    train_val_data, test_size=0.1, random_state=42
)

# Report the split sizes in the order train, test, validation.
for _split in (train_data, test_data, val_data):
    print(len(_split))

# Keep only the cleaned text and the integer label for each split.
train_df = _text_label_frame(train_data)
test_df = _text_label_frame(test_data)
val_df = _text_label_frame(val_data)

from datasets import Dataset, DatasetDict

# Wrap each text/label frame in a `datasets.Dataset`.
train_dataset, test_dataset, val_dataset = (
    Dataset.from_dict(frame) for frame in (train_df, test_df, val_df)
)

# Bundle the three splits under the keys used by the rest of the script.
final_dict = DatasetDict(
    {
        "train": train_dataset,
        "test": test_dataset,
        "validation": val_dataset,
    }
)
final_dict

# Tokenizer matching the "xlnet-base-cased" checkpoint that is fine-tuned below.
tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")


def tokenization_function(example):
    """Tokenize the "text" field of a batch with the module-level tokenizer.

    Args:
        example: Mapping with a "text" entry; the entry is passed to the
            tokenizer unchanged.

    Returns:
        The tokenizer's output, with inputs truncated and padded to a fixed
        length of 128.
    """
    texts = example["text"]
    encoded = tokenizer(
        texts,
        max_length=128,
        truncation=True,
        padding="max_length",
    )
    return encoded
    

# Apply the tokenizer to every split, batch-wise.
tokenizer_set = final_dict.map(tokenization_function, batched=True)

# Peek at the first training row: raw text first ...
sample_text = tokenizer_set["train"][0]["text"]
sample_text

# ... then the fields the tokenizer added for that same row.
for _field in ("input_ids", "token_type_ids", "attention_mask"):
    print(tokenizer_set["train"][0][_field])

# Small fixed-seed subsets (100 rows each) used for the training run below.
# NOTE(review): the evaluation subset is drawn from the "test" split; the
# "validation" split is not referenced anywhere in this script -- confirm
# that this is intended.
_subset_rows = range(100)
small_train_dataset = tokenizer_set["train"].shuffle(seed=42).select(_subset_rows)
small_eval_dataset = tokenizer_set["test"].shuffle(seed=42).select(_subset_rows)

for _subset in (small_train_dataset, small_eval_dataset):
    print(_subset)

from transformers import XLNetForSequenceClassification

# Derive the id <-> name mapping from the fitted LabelEncoder rather than
# hard-coding it.  LabelEncoder numbers classes in sorted order (exposed as
# encode.classes_), so a hand-written table can silently disagree with the
# integer labels the model is trained on -- e.g. "anger" sorts before "fear",
# so a table that puts "fear" at 0 would swap those two names in every
# prediction.
id2label = {index: str(name) for index, name in enumerate(encode.classes_)}
label2id = {name: index for index, name in id2label.items()}

# Pretrained XLNet with a classification head sized to the number of classes.
model = XLNetForSequenceClassification.from_pretrained(
    "xlnet-base-cased",
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)


import evaluate
import numpy as np

# Accuracy metric object shared by every call to compute_pred.
metric = evaluate.load("accuracy")


def compute_pred(eval_pred):
    """Score one evaluation pass with the module-level accuracy metric.

    Args:
        eval_pred: Pair that unpacks into (logits, labels).

    Returns:
        The result of ``metric.compute`` for the arg-max predictions (taken
        over the last axis of the logits) against the reference labels.
    """
    scores, gold = eval_pred
    predicted = np.argmax(scores, axis=-1)
    return metric.compute(references=gold, predictions=predicted)

import inspect

from transformers import TrainingArguments

# The "evaluate once per epoch" option was renamed from `evaluation_strategy`
# to `eval_strategy` in newer transformers releases, and the old spelling is
# deprecated.  Pick whichever keyword the installed version accepts so the
# script runs on both.
_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments.__init__).parameters
    else "evaluation_strategy"
)

# Training configuration: checkpoints go to ./fresh_xlnet_base_model, the
# model is evaluated at the end of every epoch, and training runs 3 epochs.
trainee = TrainingArguments(
    output_dir="./fresh_xlnet_base_model",
    num_train_epochs=3,
    **{_eval_strategy_kwarg: "epoch"},
)

from transformers import Trainer

# Everything the Trainer needs: the model, the training configuration, the
# two 100-row subsets, and the accuracy callback.
_trainer_kwargs = {
    "model": model,
    "args": trainee,
    "train_dataset": small_train_dataset,
    "eval_dataset": small_eval_dataset,
    "compute_metrics": compute_pred,
}
trainer = Trainer(**_trainer_kwargs)

# Fine-tune, then run one more evaluation pass on the evaluation subset.
trainer.train()

trainer.evaluate()

# Directory the fine-tuned weights are written to and read back from.
_finetuned_dir = "my_finetuned_model"

# Persist the model with safetensors serialization switched off ...
model.save_pretrained(_finetuned_dir, safe_serialization=False)

# ... and load it again from disk as a fresh object.
fine_tuned_model = XLNetForSequenceClassification.from_pretrained(_finetuned_dir)


from transformers import pipeline

# Text-classification pipeline around the reloaded model and the tokenizer
# loaded earlier in the script.
clf = pipeline(
    task="text-classification",
    tokenizer=tokenizer,
    model=fine_tuned_model,
)

# Smoke test on a single sentence.
# NOTE(review): top_k=None is presumably meant to return a score for every
# label rather than only the best one -- confirm against the pipeline docs.
_probe_sentence = "This is a random test sentence"
clf(_probe_sentence, top_k=None)


