In [8]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

In [9]:
from config.config import *

In [10]:
texts_absa = pd.read_pickle(os.path.join(OBTAINED_DATA, 'texts_absa.pkl'))

In [11]:
texts_absa.head()

Unnamed: 0,text,targets,categories,polarities
0,The solution given here is fantastic and a bit...,[book],[BOOK#GENERAL],[positive]
1,Ahh [SEP] I wish I had never gotten this stupi...,"[Arthur, Mister Monday, Arthur, Mister Monday,...","[CONTENT#CHARACTERS, BOOK#TITLE, CONTENT#CHARA...","[neutral, neutral, neutral, negative, neutral,..."
2,I loved this book so much [SEP] I couldn't sto...,"[book, twins]","[BOOK#GENERAL, CONTENT#PLOT]","[positive, positive]"
3,"This book is very informative, describing in d...","[book, beadwork, loomwork]","[BOOK#GENERAL, CONTENT#PLOT, CONTENT#PLOT]","[positive, neutral, neutral]"
4,I recommend that anyone looking to have some m...,"[the Budapest, $5, The Budapest, Boryk, book, ...","[BOOK#TITLE, BOOK#PRICE, BOOK#TITLE, BOOK#AUTH...","[neutral, positive, neutral, neutral, positive..."


In [12]:
texts_absa.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 295 entries, 0 to 294
Data columns (total 4 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   text        295 non-null    object
 1   targets     295 non-null    object
 2   categories  295 non-null    object
 3   polarities  295 non-null    object
dtypes: object(4)
memory usage: 9.3+ KB


In [13]:
texts_absa[['text', 'targets', 'polarities']]

Unnamed: 0,text,targets,polarities
0,The solution given here is fantastic and a bit...,[book],[positive]
1,Ahh [SEP] I wish I had never gotten this stupi...,"[Arthur, Mister Monday, Arthur, Mister Monday,...","[neutral, neutral, neutral, negative, neutral,..."
2,I loved this book so much [SEP] I couldn't sto...,"[book, twins]","[positive, positive]"
3,"This book is very informative, describing in d...","[book, beadwork, loomwork]","[positive, neutral, neutral]"
4,I recommend that anyone looking to have some m...,"[the Budapest, $5, The Budapest, Boryk, book, ...","[neutral, positive, neutral, neutral, positive..."
...,...,...,...
290,"I have read and used this book many, many time...","[book, book, children, Native American beadwor...","[positive, positive, positive, neutral]"
291,"Full of atmosphere, world war II feel to it [S...","[world war II, Churchill, novel, historical te...","[positive, positive, neutral, negative]"
292,I didn't like it AS much as some of the other ...,"[Marshall, read, protagonist, ""happy"" person, ...","[neutral, positive, negative, negative, neutra..."
293,Can Arthur save the world with the key as a yo...,"[Arthur, Arthur, key, asthma problems, asthma ...","[neutral, neutral, neutral, neutral, neutral, ..."


Like this book because it is just like my life in a lot of was [SEP] For one I'm a teenage and have a baby [SEP] This book teach me in a lot of way [SEP] Out of all the books I read this is the best [SEP] My darling,My hamburger is a good book if it didn't live you hang at the end what hope to whom and did they  go to collage or not

In [7]:
train_df, temp_df = train_test_split(absa_annotated, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)


In [9]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import XLMRobertaTokenizer, XLMRobertaForTokenClassification, Trainer, TrainingArguments

class ABSA_Dataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, index):
        text = str(self.data.text[index])
        targets = self.data.targets[index]
        polarities = self.data.polarities[index]
        
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors='pt',
        )
        
        input_ids = encoding['input_ids'].flatten()
        attention_mask = encoding['attention_mask'].flatten()
        
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'targets': torch.tensor(targets, dtype=torch.long),
            'polarities': torch.tensor(polarities, dtype=torch.long)
        }

    def __len__(self):
        return self.len

tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
model = XLMRobertaForTokenClassification.from_pretrained('xlm-roberta-large', num_labels=3).to('cuda')

train_dataset = ABSA_Dataset(train_df, tokenizer, max_len=512)
val_dataset = ABSA_Dataset(val_df, tokenizer, max_len=512)
test_dataset = ABSA_Dataset(test_df, tokenizer, max_len=512)

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

trainer.train()


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

Error while downloading from https://huggingface.co/xlm-roberta-large/resolve/main/tokenizer.json: HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out.
Trying to resume download...


tokenizer.json:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/616 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

Error while downloading from https://cdn-lfs.huggingface.co/xlm-roberta-large/2dfa19f172412917cab174da04b46e2134811b723666965fd0aabd97caa6e23b?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.safetensors%3B+filename%3D%22model.safetensors%22%3B&Expires=1721650134&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMTY1MDEzNH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby94bG0tcm9iZXJ0YS1sYXJnZS8yZGZhMTlmMTcyNDEyOTE3Y2FiMTc0ZGEwNGI0NmUyMTM0ODExYjcyMzY2Njk2NWZkMGFhYmQ5N2NhYTZlMjNiP3Jlc3BvbnNlLWNvbnRlbnQtZGlzcG9zaXRpb249KiJ9XX0_&Signature=RA4fKFdJu974SS45BWtEJWHlnmtEjSa%7EHkjXrJTSYcDVAoyMKoZ1HJeWkIG9B5YVh%7EfF9UyMVNtyt01-rhG1Ki-A9jj2RvKD8tQ5FdbRIPS2sTBuUFAx2L5sf7AgYMv%7EPdPazbkwFBBCfxKxYjZGq6gdMZ8Y-HzIEsEcUULYp2hL-1JQ6fqKOnIS%7ECAxBGl2Qv7RLqdXeH5YfsPkBboo7JiPEXZdAQlfNXQe7ZlZ-4OD6Vhunv3r3Ztnf7xTMOPwa%7EEHDS89HShuFlZ1DkRt0uZA3hF8M9SU-c2c1YeOI13AG6U1xKHrbz5eK7oXTsbbpqQn-fRVO6zKWY0kH3OsDA__&Key-Pair-Id=K3ESJI6DHPFC7: HTTPS

ConnectionError: (MaxRetryError('HTTPSConnectionPool(host=\'cdn-lfs.huggingface.co\', port=443): Max retries exceeded with url: /xlm-roberta-large/2dfa19f172412917cab174da04b46e2134811b723666965fd0aabd97caa6e23b?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.safetensors%3B+filename%3D%22model.safetensors%22%3B&Expires=1721650134&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMTY1MDEzNH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby94bG0tcm9iZXJ0YS1sYXJnZS8yZGZhMTlmMTcyNDEyOTE3Y2FiMTc0ZGEwNGI0NmUyMTM0ODExYjcyMzY2Njk2NWZkMGFhYmQ5N2NhYTZlMjNiP3Jlc3BvbnNlLWNvbnRlbnQtZGlzcG9zaXRpb249KiJ9XX0_&Signature=RA4fKFdJu974SS45BWtEJWHlnmtEjSa~HkjXrJTSYcDVAoyMKoZ1HJeWkIG9B5YVh~fF9UyMVNtyt01-rhG1Ki-A9jj2RvKD8tQ5FdbRIPS2sTBuUFAx2L5sf7AgYMv~PdPazbkwFBBCfxKxYjZGq6gdMZ8Y-HzIEsEcUULYp2hL-1JQ6fqKOnIS~CAxBGl2Qv7RLqdXeH5YfsPkBboo7JiPEXZdAQlfNXQe7ZlZ-4OD6Vhunv3r3Ztnf7xTMOPwa~EHDS89HShuFlZ1DkRt0uZA3hF8M9SU-c2c1YeOI13AG6U1xKHrbz5eK7oXTsbbpqQn-fRVO6zKWY0kH3OsDA__&Key-Pair-Id=K3ESJI6DHPFC7 (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x296d32190>: Failed to resolve \'cdn-lfs.huggingface.co\' ([Errno 8] nodename nor servname provided, or not known)"))'), '(Request ID: 5fbb24f1-ddc4-47d4-a8f7-a9c2e623c4f3)')