In [1]:
import torch
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import numpy as np
from sklearn.metrics import classification_report
from transformers import TOKENIZER_MAPPING, AutoModelForSequenceClassification, AutoTokenizer, AdamW, get_linear_schedule_with_warmup, XLMRobertaTokenizer, XLMRobertaForSequenceClassification
import os
from dataset import Dataset
from util import create_output

In [2]:
# --- Run configuration ------------------------------------------------------
# Hugging Face hub identifier passed to AutoTokenizer.from_pretrained.
TOKENIZER_NAME = "sentence-transformers/paraphrase-xlm-r-multilingual-v1"
# Same hub identifier, kept under a separate name for the model.
MODEL_NAME = "sentence-transformers/paraphrase-xlm-r-multilingual-v1"

# Optimisation hyper-parameters.
LEARNING_RATE = 3e-5
EPOCHS = 4
BATCH_SIZE = 24

# Fixed seed so that shuffling (RandomSampler) and any other torch / numpy
# randomness is repeatable after "Restart Kernel -> Run All".
SEED = 42

# Expose only GPU 0 to this process. This has to be in the environment before
# CUDA is first initialised, so it is set ahead of any torch.cuda query.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Seed the global RNGs of the libraries this notebook imports.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Choose the compute device: CUDA when PyTorch reports a usable GPU, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Report what was selected; the CPU fallback is announced but does not stop the run.
if device.type != "cuda":
    print('NO GPU AVAILABLE ERROR')
else:
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print(f'We will use the GPU:{torch.cuda.get_device_name()} ({device})')
   
# The tokenizer is loaded by hub id (TOKENIZER_NAME); the classifier is loaded
# from the local "../pickles_mixed/" directory rather than from MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

model = AutoModelForSequenceClassification.from_pretrained(
    "../pickles_mixed/",
    num_labels=5,            # five-way classification head
    output_attentions=True,  # ask the model to also return attention weights
)
# Move the weights to the selected device before the optimizer captures the parameters.
model = model.to(device)

# AdamW here is the class imported from transformers at the top of the notebook.
optimizer = AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    no_deprecation_warning=True,
)

# Project-local dataset helper (imported from dataset.py).
data = Dataset()

# get_fire_2022_dataset is unpacked into six values; only the first is kept.
# Judging by its name it is presumably the Tamil 2022 training split -- confirm
# against dataset.py.
(tam_train_2022, _, _, _, _, _) = data.get_fire_2022_dataset(tokenizer, balance=False)

# Draw training examples in random order, BATCH_SIZE at a time.
train_sampler = RandomSampler(tam_train_2022)
train_dataloader = DataLoader(
    tam_train_2022,
    batch_size=BATCH_SIZE,
    sampler=train_sampler,
)

# One scheduler step per batch over every epoch.
total_steps = EPOCHS * len(train_dataloader)

# Linear schedule with no warm-up steps, spanning total_steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)


There are 1 GPU(s) available.
We will use the GPU:Tesla V100-SXM2-32GB (cuda)
Texts: 35575
Label names: Index(['Mixed_feelings', 'Negative', 'Positive', 'not-Tamil', 'unknown_state'], dtype='object')
Texts: 3962
Label names: Index(['Mixed_feelings', 'Negative', 'Positive', 'not-Tamil', 'unknown_state'], dtype='object')
Texts: 5951
Label names: Index(['Mixed feelings', 'Negative', 'Positive', 'not-Kannada',
       'unknown state'],
      dtype='object')
Texts: 691
Label names: Index(['Mixed feelings', 'Negative', 'Positive', 'not-Kannada',
       'unknown state'],
      dtype='object')
Texts: 15726
Label names: Index(['Mixed_feelings', 'Negative', 'Positive', 'not-malayalam',
       'unknown_state'],
      dtype='object')
Texts: 1766
Label names: Index(['Mixed_feelings', 'Negative', 'Positive', 'not-malayalam',
       'unknown_state'],
      dtype='object')


In [3]:
# Run the project's validation routine for the 'tam' dataset, year 2022.
# NOTE(review): the output file name looks like a throwaway placeholder -- confirm
# before relying on the written file.
OUTPUT_FILE = "BIAOJSDOIJASD"
data.fire_validation(
    model,
    tokenizer,
    device,
    dataset='tam',
    year=2022,
    BS=BATCH_SIZE,
    output_file=OUTPUT_FILE,
)

Texts: 35575
Label names: Index(['Mixed_feelings', 'Negative', 'Positive', 'not-Tamil', 'unknown_state'], dtype='object')
Texts: 3962
Label names: Index(['Mixed_feelings', 'Negative', 'Positive', 'not-Tamil', 'unknown_state'], dtype='object')
Texts: 5951
Label names: Index(['Mixed feelings', 'Negative', 'Positive', 'not-Kannada',
       'unknown state'],
      dtype='object')
Texts: 691
Label names: Index(['Mixed feelings', 'Negative', 'Positive', 'not-Kannada',
       'unknown state'],
      dtype='object')
Texts: 15726
Label names: Index(['Mixed_feelings', 'Negative', 'Positive', 'not-malayalam',
       'unknown_state'],
      dtype='object')
Texts: 1766
Label names: Index(['Mixed_feelings', 'Negative', 'Positive', 'not-malayalam',
       'unknown_state'],
      dtype='object')


tam validation:   0%|          | 0/166 [00:00<?, ?it/s]

tam validation: 3984


tam validation: 100%|██████████| 166/166 [00:38<00:00,  4.27it/s]

              precision    recall  f1-score   support

           0       0.22      0.33      0.26       289
           1       0.41      0.44      0.42       453
           2       0.82      0.75      0.78      2477
           3       0.57      0.63      0.60       161
           4       0.45      0.47      0.46       582

    accuracy                           0.64      3962
   macro avg       0.49      0.52      0.51      3962
weighted avg       0.66      0.64      0.65      3962




