In [13]:
import torch
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import numpy as np
from sklearn.metrics import classification_report
from transformers import TOKENIZER_MAPPING, AutoModelForSequenceClassification, AutoTokenizer, AdamW, get_linear_schedule_with_warmup, XLMRobertaTokenizer, XLMRobertaForSequenceClassification
import os
from dataset import Dataset

In [14]:
TOKENIZER_NAME = "sentence-transformers/paraphrase-xlm-r-multilingual-v1"
MODEL_NAME = "sentence-transformers/paraphrase-xlm-r-multilingual-v1"
LEARNING_RATE = 3e-5

#OUTPUT_FILE = "NODUP-paraphrase-roberta-kan-pickle.md"

EPOCHS = 4
BATCH_SIZE = 24 
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"

if torch.cuda.is_available():    
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())    
    print(f'We will use the GPU:{torch.cuda.get_device_name()} ({device})')

else:
    print('NO GPU AVAILABLE ERROR')
    device = torch.device("cpu")
   
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
model = AutoModelForSequenceClassification.from_pretrained("../../task_a/pickles_mal/", num_labels=5, output_attentions=True)
model_phobia = AutoModelForSequenceClassification.from_pretrained("../pickles_mal/", num_labels=3, output_attentions=True) 
model.to(device)
model_phobia.to(device)
optimizer = AdamW(model.parameters(), lr = LEARNING_RATE, no_deprecation_warning=True)

data = Dataset()
_, _, _, _, datatrain, _, _, _ = data.get_phobia_dataset(tokenizer, balance=False)
#_,_, kan_train_2022, _, _,_ = data.get_fire_2022_dataset(tokenizer, balance=False)

train_dataloader = DataLoader(
            datatrain,
            sampler = RandomSampler(datatrain),
            batch_size = BATCH_SIZE)

total_steps = len(train_dataloader) * EPOCHS

scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps = 0,
                                            num_training_steps = total_steps)


There are 2 GPU(s) available.
We will use the GPU:Tesla V100-SXM2-32GB (cuda)
Texts: 3128
Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic'], dtype='object')
Texts: 790
Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic'], dtype='object')
Texts: 1750
Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic'], dtype='object')
Texts: 621
Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic'], dtype='object')
Texts: 2521
Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic'], dtype='object')
Texts: 837
Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic'], dtype='object')
Texts: 3807
Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic'], dtype='object')
Texts: 962
Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic'], dtype='object')


In [11]:
def inference_validation(model, tokenizer, device, output_file, dataset, BS=16):
    _, eng_val, _, tam_val, _, mal_val, _, eng_tam_val = data.get_phobia_dataset(tokenizer, balance=False)

    if dataset == 'tam':
        loader = DataLoader(tam_val, sampler = SequentialSampler(tam_val), batch_size=BS)
    elif dataset == 'eng':
        loader = DataLoader(eng_val, sampler = SequentialSampler(eng_val), batch_size=BS) 
    elif dataset == 'mal':
        loader = DataLoader(mal_val, sampler = SequentialSampler(mal_val), batch_size=BS) 
    elif dataset == 'eng_tam':
        loader = DataLoader(eng_tam_val, sampler = SequentialSampler(eng_tam_val), batch_size=BS) 

    print(f"{dataset} validation: {len(loader) * BS}")
    
    vbar = tqdm(enumerate(loader), total=len(loader), desc= dataset + " validation")

    model.eval()
    
    true_labels = []
    pred_labels = []
    #total_eval_loss = 0
    
    # Label names: Index(['Mixed_feelings', 'Negative', 'Positive', 'not-Tamil', 'unknown_state']
    # Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic']
    
    for step, batch in vbar:
        b_input_ids = batch[0].to(device)
        b_masks = batch[1].to(device)
        b_labels = batch[2].to(device)

        with torch.no_grad(): 
            outputs = model(input_ids=b_input_ids, attention_mask=b_masks,
                                            labels=b_labels)
            outputs_two = model_phobia(input_ids=b_input_ids, attention_mask=b_masks,
                                            labels=b_labels)
            
            #total_eval_loss += outputs.loss.item()
            logits = outputs.logits.detach().cpu().numpy().tolist()
            logits_two = outputs_two.logits.detach().cpu().numpy().tolist() 
            label_ids = b_labels.to('cpu').numpy().tolist()

            true_labels.extend(label_ids)
            for i in logits:
                # If it is more negative than positive
                if i[1] > i[2]:
                    pred_labels.append(0)
                else:
                    pred_labels.append(1)
            #print(f"Predicted: {pred_labels[-1]}, {np.argmax(i)}")
            #print(f"{pred_labels[-24:]},{np.argmax(logits_two, axis=1)}")
            #pred_labels.extend(np.argmax(logits,axis=1))
        
        #f = open("../outputs/sentiment", 'a')
        #f.write(str(pred_labels))
    print(classification_report(pred_labels, true_labels))
    
    model.train()

In [12]:
OUTPUT_FILE = "BIAOJSDOIJASD"
inference_validation(model=model, tokenizer=tokenizer, device=device, output_file=OUTPUT_FILE, BS=BATCH_SIZE, dataset='mal')

Texts: 3128
Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic'], dtype='object')
Texts: 790
Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic'], dtype='object')
Texts: 1750
Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic'], dtype='object')
Texts: 621
Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic'], dtype='object')
Texts: 2521
Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic'], dtype='object')
Texts: 837
Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic'], dtype='object')
Texts: 3807
Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic'], dtype='object')
Texts: 962
Label names: Index(['Homophobic', 'Non-anti-LGBT+ content', 'Transphobic'], dtype='object')


mal validation:   0%|          | 0/35 [00:00<?, ?it/s]

mal validation: 840


mal validation:   3%|▎         | 1/35 [00:00<00:17,  2.00it/s]

[1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0],[1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0]


mal validation:   6%|▌         | 2/35 [00:00<00:16,  2.04it/s]

[0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0],[1 1 1 1 1 0 0 1 0 0 1 1 1 1 1 0 1 1 1 1 1 1 1 1]


mal validation:   9%|▊         | 3/35 [00:01<00:15,  2.07it/s]

[1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0],[1 1 1 1 1 1 1 1 1 1 0 0 1 0 1 1 1 1 1 1 1 1 1 1]


mal validation:  11%|█▏        | 4/35 [00:01<00:14,  2.09it/s]

[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0],[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1]


mal validation:  14%|█▍        | 5/35 [00:02<00:14,  2.10it/s]

[1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1],[1 1 1 0 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1]


mal validation:  17%|█▋        | 6/35 [00:02<00:13,  2.12it/s]

[1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0],[1 1 1 1 1 1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1]


mal validation:  20%|██        | 7/35 [00:03<00:13,  2.12it/s]

[0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1],[1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 0 1 1 1 1 1 1 1 1]


mal validation:  23%|██▎       | 8/35 [00:03<00:12,  2.12it/s]

[0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0],[0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1]


mal validation:  26%|██▌       | 9/35 [00:04<00:12,  2.13it/s]

[0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1],[1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


mal validation:  29%|██▊       | 10/35 [00:04<00:11,  2.13it/s]

[1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1],[1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1]


mal validation:  31%|███▏      | 11/35 [00:05<00:11,  2.13it/s]

[1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0],[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1]


mal validation:  34%|███▍      | 12/35 [00:05<00:10,  2.14it/s]

[1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1],[1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1]


mal validation:  37%|███▋      | 13/35 [00:06<00:10,  2.14it/s]

[1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],[1 1 1 1 1 1 1 1 1 0 0 1 1 1 1 0 1 1 1 1 0 1 1 1]


mal validation:  40%|████      | 14/35 [00:06<00:09,  2.14it/s]

[0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1],[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1]


mal validation:  43%|████▎     | 15/35 [00:07<00:09,  2.14it/s]

[0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1],[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1]


mal validation:  46%|████▌     | 16/35 [00:07<00:08,  2.14it/s]

[1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1],[1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


mal validation:  49%|████▊     | 17/35 [00:07<00:08,  2.13it/s]

[1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0],[1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


mal validation:  51%|█████▏    | 18/35 [00:08<00:07,  2.13it/s]

[0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 0 1 0 1]


mal validation:  54%|█████▍    | 19/35 [00:08<00:07,  2.13it/s]

[1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0],[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1]


mal validation:  57%|█████▋    | 20/35 [00:09<00:07,  2.14it/s]

[0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0],[1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


mal validation:  60%|██████    | 21/35 [00:09<00:06,  2.13it/s]

[0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1],[1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1]


mal validation:  63%|██████▎   | 22/35 [00:10<00:06,  2.13it/s]

[1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1],[1 1 1 1 1 1 1 1 1 0 1 1 0 1 1 1 1 1 1 1 1 0 1 1]


mal validation:  66%|██████▌   | 23/35 [00:10<00:05,  2.14it/s]

[1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0],[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


mal validation:  69%|██████▊   | 24/35 [00:11<00:05,  2.14it/s]

[1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1],[1 0 1 0 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


mal validation:  71%|███████▏  | 25/35 [00:11<00:04,  2.13it/s]

[0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1],[1 1 1 0 0 1 1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1 1 1]


mal validation:  74%|███████▍  | 26/35 [00:12<00:04,  2.13it/s]

[1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0],[1 0 0 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


mal validation:  77%|███████▋  | 27/35 [00:12<00:03,  2.14it/s]

[0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0],[0 0 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 0 1 1 1]


mal validation:  80%|████████  | 28/35 [00:13<00:03,  2.13it/s]

[0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0],[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


mal validation:  83%|████████▎ | 29/35 [00:13<00:02,  2.13it/s]

[0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0],[1 1 0 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1]


mal validation:  86%|████████▌ | 30/35 [00:14<00:02,  2.13it/s]

[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0],[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 1 1 1]


mal validation:  89%|████████▊ | 31/35 [00:14<00:01,  2.13it/s]

[0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0],[1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 0 1 1]


mal validation:  91%|█████████▏| 32/35 [00:15<00:01,  2.13it/s]

[1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1],[1 1 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1]


mal validation:  94%|█████████▍| 33/35 [00:15<00:00,  2.12it/s]

[1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1],[1 1 1 0 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0]


mal validation:  97%|█████████▋| 34/35 [00:15<00:00,  2.13it/s]

[1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0],[1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1]


mal validation: 100%|██████████| 35/35 [00:16<00:00,  2.14it/s]

[0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1],[1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 0 1 1 1 1]
              precision    recall  f1-score   support

           0       0.77      0.20      0.31       452
           1       0.50      0.89      0.64       385
           2       0.00      0.00      0.00         0

    accuracy                           0.51       837
   macro avg       0.42      0.36      0.32       837
weighted avg       0.65      0.51      0.46       837




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
