In [1]:
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import torch
from typing import List, Union
from transformers import AutoTokenizer, AutoModel

class MyDataset(Dataset):
    def __init__(self, 
                ids: List[str], 
                speakers: List[str], 
                sexes: List[str], 
                texts: List[str], 
                texts_en: List[str], 
                labels: List[bool],
                device: torch.device = torch.device('cpu'),
                model_name: str = 'distilbert/distilbert-base-uncased-finetuned-sst-2-english',
                max_length: int = 512
        ):
        assert len(ids) == len(speakers) == len(sexes) == len(texts) == len(texts_en) == len(labels)
        self.ids = []
        self.speakers = []
        self.sexes = []
        self.texts = []
        self.texts_en = []
        self.embeddings = []
        self.attention_masks = []
        self.labels = []
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        for i in range(len(ids)):
            text = texts[i]
            inputs = self.tokenizer(text, add_special_tokens=True, return_tensors='pt', padding='max_length',max_length=max_length)
            if inputs['input_ids'].shape[1] <= max_length:
                self.ids.append(ids[i])
                self.speakers.append(speakers[i])
                self.sexes.append(sexes[i])
                self.texts.append(texts[i])
                self.texts_en.append(texts_en[i])
                self.embeddings.append(inputs['input_ids'][0])
                self.attention_masks.append(inputs['attention_mask'])
                self.labels.append(torch.tensor((labels[i]), dtype=torch.long))
                
        print(f'Loaded {len(self.ids)}/{len(ids)} samples.')

    def __getitem__(self, index):
        return self.ids[index], self.speakers[index], self.sexes[index], self.texts[index], \
                self.texts_en[index], self.embeddings[index][:512].to(self.device), self.attention_masks[index][0][:512].to(self.device), self.labels[index]
            
    def __len__(self):
        return len(self.ids)

    def set_device(self, device: torch.device):
        '''
        Sets the device to the given device.
        '''
        self.device = device

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import types
import sys

dataset_valid = torch.load('D:/fer/9.sem/OPJ/data/torch/orientation/val_dataset_all.pt')
dataset_train = torch.load('D:/fer/9.sem/OPJ/data/torch/orientation/train_dataset_all.pt')
dataset_test = torch.load('D:/fer/9.sem/OPJ/data/torch/orientation/test_dataset_all.pt')

  dataset_valid = torch.load('D:/fer/9.sem/OPJ/data/torch/orientation/val_dataset_all.pt')
  dataset_train = torch.load('D:/fer/9.sem/OPJ/data/torch/orientation/train_dataset_all.pt')
  dataset_test = torch.load('D:/fer/9.sem/OPJ/data/torch/orientation/test_dataset_all.pt')


In [3]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torch
from typing import List, Union
from transformers import AutoTokenizer, AutoModel, PreTrainedModel
from transformers import RobertaTokenizer, RobertaForSequenceClassification
import pandas as pd
from sklearn.metrics import accuracy_score, confusion_matrix

def evaluate(dataset: Dataset, model: PreTrainedModel, device: torch.device = torch.device('cpu'), plot: bool = False):
    '''
    Evaluates the model on the given dataset.
    
    Parameters:
        dataset: Dataset
            The dataset to evaluate on.
        model: PreTrainedModel
            The model to evaluate.
        device: torch.device
            The device to use.
        plot: bool
    '''
    model.to(device)
    model.eval()
    loader = DataLoader(dataset, batch_size=16, shuffle=False)
    correct_labels = []
    model_predictions = []
    probs = []
    attentions = []
    embeddings = []
    texts = []
    with torch.no_grad():
        for batch in loader:
            id_, speaker, sex, text, text_en, embedding, attention_mask, label = batch
            texts.extend(text_en)
            embedding = embedding.to(device)
            attention_mask = attention_mask.to(device).squeeze(1)
            label = label.to(device)
            model_output = model(input_ids=embedding, attention_mask=attention_mask, output_attentions=True)
            embeddings.extend(embedding.cpu())
            
            # Prosječni attention skorovi za posljednji sloj
            attention = torch.mean(model_output.attentions[-1], dim=1).squeeze()[:, 0]
            attentions.extend(attention.cpu().numpy())
            
            logits = model_output.logits
            prob = torch.max(torch.softmax(logits, dim=1), dim=1)
            probs.extend(prob.values.cpu())
            predictions = torch.argmax(logits, dim=1)
            correct_labels.extend(label.cpu().numpy())
            model_predictions.extend(predictions.cpu().numpy())

    accuracy = accuracy_score(correct_labels, model_predictions)
    print(f'Accuracy: {accuracy}')
    print(f'Confusion matrix:\n{confusion_matrix(correct_labels, model_predictions)}')
    
    return correct_labels, model_predictions, probs, attentions, embeddings, texts

In [4]:
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = torch.load('D:/fer/9.sem/OPJ/roberta_base_en.pt', map_location=torch.device('cuda:0'))

  model = torch.load('D:/fer/9.sem/OPJ/roberta_base_en.pt', map_location=torch.device('cuda:0'))


In [5]:
output = evaluate(dataset_train, model, 'cuda:0')



Accuracy: 0.8556500231875097
Confusion matrix:
[[10621  2636]
 [ 2033 17055]]


In [6]:
labels, predictions, probs, attentions, embeddings, old_texts = output

In [7]:
labels = np.array(labels)
predictions = np.array(predictions)
probs = np.array(probs)
attentions = np.array(attentions)
embeddings = np.array(embeddings)

In [8]:
# Filtriranje samo točno klasificiranih primjera
ind = labels == predictions
labels = labels[ind]
predictions = predictions[ind]
probs = probs[ind]
attentions = attentions[ind]
embeddings = embeddings[ind]
texts = [old_texts[i] for i, cond in enumerate(ind) if cond]

In [33]:
k = min(5000, len(probs))
most_left = torch.topk(torch.tensor(probs), k, largest=False).indices.numpy()
counter_l = {}
for ind in most_left:
    most_important = torch.topk(torch.tensor(attentions[ind]), 15).indices.numpy()
    words = tokenizer.decode(embeddings[ind][most_important]).lower().split(' ')
    for word in words:
        if '[SEP]' in word:
            word = word.replace('[SEP]', '')
        if '....' in word:
            word = word.replace('....', '')
        if '...' in word:
            word = word.replace('...', '')
        if '..' in word:
            word = word.replace('..', '')
        if '.' in word:
            word = word.replace('.', '')
        if ',' in word:
            word = word.replace(',', '')
        if '–' in word:
            word = word.replace('–', '')
        if '?' in word:
            word = word.replace('?', '')
        if '!' in word:
            word = word.replace('!', '')
        if '""' in word:
            word = word.replace('""', '')
        if '</s>' in word:
            word = word.replace('</s>', '')
        if '<s>' in word:
            word = word.replace('<s>', '')
        if word not in counter_l:
            counter_l[word] = 1
        else:
            counter_l[word] += 1

most_popular_left = sorted(counter_l.items(), key=lambda x: -x[1])
print("Most influential words for left predictions:")
print(most_popular_left[:100])

Most influential words for left predictions:
[('', 4660), ('the', 2102), ('minister', 1203), ('president', 1090), ('you', 588), ('government', 562), ('mr', 413), ('of', 340), ('this', 327), ('and', 262), ('state', 261), ('secretary', 225), ('house', 209), ('bill', 181), ('council', 176), ('we', 162), ('it', 148), ('he', 144), ('your', 137), ('right', 137), ('members', 131), ('our', 126), ('chairman', 122), ('to', 120), ('chamber', 116), ('today', 107), ('rapport', 107), ('representative', 104), ('a', 103), ('vice', 103), ('amendment', 102), ('prime', 99), ('party', 94), ('floor', 88), ('gentlemen', 85), ('they', 85), ('my', 79), ('in', 76), ('mad', 76), (')', 70), ('here', 69), ('hon', 68), ('—', 63), ('committee', 63), ('mrs', 61), ('gentleman', 60), ('speaker', 58), ('law', 57), ('she', 55), ('will', 54), ('is', 53), ('parliament', 53), ('(<', 52), ('i', 51), ('his', 51), ('but', 49), ('czech', 49), ('member', 49), ('her', 48), ('dear', 48), ('commission', 47), ('on', 47), ('these', 

In [34]:
most_right = torch.topk(torch.tensor(probs), k, largest=True).indices.numpy()
counter_r = {}
for ind in most_right:
    most_important = torch.topk(torch.tensor(attentions[ind]), 15).indices.numpy()
    words = tokenizer.decode(embeddings[ind][most_important]).lower().split(' ')
    for word in words:
        if '[SEP]' in word:
            word = word.replace('[SEP]', '')
        if '....' in word:
            word = word.replace('....', '')
        if '...' in word:
            word = word.replace('...', '')
        if '..' in word:
            word = word.replace('..', '')
        if '.' in word:
            word = word.replace('.', '')
        if ',' in word:
            word = word.replace(',', '')
        if '–' in word:
            word = word.replace('–', '')
        if '?' in word:
            word = word.replace('?', '')
        if '!' in word:
            word = word.replace('!', '')
        if '""' in word:
            word = word.replace('""', '')
        if '</s>' in word:
            word = word.replace('</s>', '')
        if '<s>' in word:
            word = word.replace('<s>', '')
        if word not in counter_r:
            counter_r[word] = 1
        else:
            counter_r[word] += 1

most_popular_right = sorted(counter_r.items(), key=lambda x: -x[1])
print("Most influential words for right predictions:")
print(most_popular_right[:100])

Most influential words for right predictions:
[('the', 6712), ('', 4667), ('our', 2439), ('my', 2405), ('we', 2298), ('government', 1809), ('friend', 1798), ('hon', 1743), ('minister', 1235), ('right', 884), ('of', 819), ('noble', 710), ('this', 683), ('i', 596), ('he', 592), ('are', 535), ('to', 515), ('will', 499), ('secretary', 492), ('is', 447), ('uk', 363), ('state', 359), ('prime', 350), ('his', 344), ('a', 282), ('president', 276), ('she', 266), ('that', 266), ('you', 264), ('her', 264), ('have', 251), ('lady', 247), ('house', 247), ('lords', 246), ('they', 231), ('nation', 218), ('chancellor', 218), ('and', 216), ('department', 208), ('it', 198), ('in', 195), ('high', 193), ('country', 183), ('part', 180), ('bill', 170), ('assembly', 166), ('lord', 165), ('with', 157), ('has', 152), ('turkey', 152), ('their', 137), ('by', 134), ('for', 133), ('there', 132), ('does', 129), ('salute', 128), ('your', 127), ('general', 126), ('thank', 125), ('marshal', 123), ('kingdom', 98), ('mr',

In [18]:
w = 'treasury'
print(counter_l[w])
print(counter_r[w])

5
25


In [19]:
w = 'climate'
print(counter_l[w])
print(counter_r[w])

32
2


In [35]:
w = 'amendment'
print(counter_l[w])
print(counter_r[w])

102
7


In [37]:
w = 'nation'
print(counter_l[w])
print(counter_r[w])

8
218


In [38]:
w = 'school'
print(counter_l[w])
print(counter_r[w])

31
4


In [41]:
w = 'support'
print(counter_l[w])
print(counter_r[w])

3
47
