In [None]:
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import torch
from typing import List, Union
from transformers import AutoTokenizer, AutoModel

class MyDataset(Dataset):
    """
    In-memory dataset of tokenized speeches with their metadata.

    Every text in ``texts`` is tokenized once at construction time (padded to
    ``max_length``, not truncated). Samples whose token sequence is longer than
    ``max_length`` are dropped, so the dataset may hold fewer items than were
    passed in.

    Each item is the 8-tuple
    ``(id, speaker, sex, text, text_en, input_ids, attention_mask, label)``.
    """

    def __init__(self,
                 ids: List[str],
                 speakers: List[str],
                 sexes: List[str],
                 texts: List[str],
                 texts_en: List[str],
                 labels: List[bool],
                 device: torch.device = torch.device('cpu'),
                 model_name: str = 'distilbert/distilbert-base-uncased-finetuned-sst-2-english',
                 max_length: int = 512
        ):
        assert len(ids) == len(speakers) == len(sexes) == len(texts) == len(texts_en) == len(labels)

        # Parallel per-sample stores; position k in each list describes sample k.
        self.ids, self.speakers, self.sexes = [], [], []
        self.texts, self.texts_en = [], []
        self.embeddings, self.attention_masks, self.labels = [], [], []

        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Tokenizers without a pad token reuse the end-of-sequence token for padding.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        samples = zip(ids, speakers, sexes, texts, texts_en, labels)
        for sample_id, speaker, sex, text, text_en, label in samples:
            encoded = self.tokenizer(text, add_special_tokens=True, return_tensors='pt', padding='max_length', max_length=max_length)
            # No truncation is requested, so over-long texts come back longer
            # than max_length and are skipped here.
            if encoded['input_ids'].shape[1] > max_length:
                continue

            self.ids.append(sample_id)
            self.speakers.append(speaker)
            self.sexes.append(sex)
            self.texts.append(text)
            self.texts_en.append(text_en)
            # input_ids are stored as a 1-D row; the attention mask keeps its
            # leading batch dimension and is unwrapped in __getitem__.
            self.embeddings.append(encoded['input_ids'][0])
            self.attention_masks.append(encoded['attention_mask'])
            self.labels.append(torch.tensor((label), dtype=torch.long))

        print(f'Loaded {len(self.ids)}/{len(ids)} samples.')

    def __getitem__(self, index):
        """Return the sample at ``index``; token ids and mask are moved to ``self.device``."""
        token_ids = self.embeddings[index].to(self.device)
        mask = self.attention_masks[index][0].to(self.device)
        return (self.ids[index], self.speakers[index], self.sexes[index], self.texts[index],
                self.texts_en[index], token_ids, mask, self.labels[index])

    def __len__(self):
        """Number of samples that survived the length filter."""
        return len(self.ids)

    def set_device(self, device: torch.device):
        '''
        Sets the device that __getitem__ moves token ids and masks to.
        '''
        self.device = device

In [None]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torch
from typing import List, Union
from transformers import AutoTokenizer, AutoModel, PreTrainedModel
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, BertForSequenceClassification
import pandas as pd

from transformers import DataCollatorWithPadding

from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

def evaluate(dataset: Dataset, model: PreTrainedModel, device: torch.device = torch.device('cpu'), batch_size = 1):
    '''
    Runs the model over the dataset and reports classification metrics.

    Prints the classification report, the accuracy and the confusion matrix.

    Parameters:
        dataset: Dataset
            Yields 8-tuples whose last three entries are the token ids,
            the attention mask and the label.
        model: PreTrainedModel
            The model to evaluate; it is moved to ``device`` and put in eval mode.
        device: torch.device
            The device to run inference on.
        batch_size: int
            Number of samples per batch.

    Returns:
        The classification report as a dict (``output_dict=True``).
    '''
    model.to(device)
    model.eval()

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    y_true, y_pred = [], []

    with torch.no_grad():
        for batch in loader:
            # Only the tensors at the tail of the tuple are needed for scoring.
            _, _, _, _, _, token_ids, mask, target = batch
            token_ids = token_ids.to(device)
            # Drops a singleton dimension at position 1 if the mask carries one.
            mask = mask.to(device).squeeze(1)
            target = target.to(device)

            output = model(input_ids=token_ids, labels=target, attention_mask=mask)
            predicted = output.logits.argmax(dim=1)

            y_true.extend(target.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    accuracy = accuracy_score(y_true, y_pred)
    cls_dict = classification_report(y_true, y_pred, zero_division = 0.0, output_dict=True)
    print(classification_report(y_true, y_pred, zero_division = 0.0))
    print(f'Accuracy: {accuracy}')
    print(f'Confusion matrix:\n{confusion_matrix(y_true, y_pred)}')
    return cls_dict

In [None]:
import types

# Load the combined ("all") validation / train / test splits from disk.
# SECURITY NOTE: torch.load unpickles arbitrary Python objects; only load files
# from a trusted source.
# NOTE(review): these files presumably hold pickled dataset objects, in which
# case their class must be defined before this cell runs - confirm.
dataset_valid = torch.load('/kaggle/input/orientation-dataset/val_dataset_all.pt')
dataset_train = torch.load('/kaggle/input/orientation-dataset/train_dataset_all.pt')
dataset_test = torch.load('/kaggle/input/orientation-dataset/test_dataset_all.pt')

In [None]:
# Class balance of the validation split: the last field of every example is
# its label; label 0 is tallied as "left", anything else as "right".
left_v, right_v = 0, 0
for example in dataset_valid:
    if example[-1] != 0:
        right_v += 1
        continue
    left_v += 1

print(f'Validation set: {left_v} left, {right_v} right')

In [None]:
# Class balance of the training split: the last field of every example is
# its label; label 0 is tallied as "left", anything else as "right".
left_t, right_t = 0, 0
for example in dataset_train:
    if example[-1] != 0:
        right_t += 1
        continue
    left_t += 1

print(f'Training set: {left_t} left, {right_t} right')

In [None]:
# Class balance of the test split: the last field of every example is
# its label; label 0 is tallied as "left", anything else as "right".
left_tt, right_tt = 0, 0
for example in dataset_test:
    if example[-1] != 0:
        right_tt += 1
        continue
    left_tt += 1

print(f'Test set: {left_tt} left, {right_tt} right')

In [None]:
# Overall class counts across the validation, training and test splits.
left_total = sum((left_v, left_t, left_tt))
right_total = sum((right_v, right_t, right_tt))

In [None]:
# Absolute number of left / right examples over all three splits.
print(f'Left: {left_total}, Right: {right_total}')

In [None]:
# Share of each class over all three splits, as a percentage with two decimals.
print(f'Left: {left_total/(left_total+right_total):.2%}, Right: {right_total/(left_total+right_total):.2%}')

In [None]:
import os
DATASET_DIR = '/kaggle/input/orientation-dataset'

# Every ``.pt`` file in DATASET_DIR, grouped by split: c_dict[split][filename] -> dataset.
c_dict = {'train': {}, 'val': {}, 'test': {}}

# os.listdir returns entries in arbitrary order; sorting keeps the dict order
# (and therefore every table and plot built from it) stable across runs.
for filename in sorted(os.listdir(DATASET_DIR)):
    if not filename.endswith(".pt"):
        continue

    file_path = os.path.join(DATASET_DIR, filename)
    # SECURITY NOTE: torch.load unpickles arbitrary Python objects; only load
    # files from a trusted source.
    dataset = torch.load(file_path)

    # The split is taken from the file name; the first matching keyword wins.
    # Files matching none of the keywords are loaded but not stored.
    if 'train' in filename:
        c_dict['train'][filename] = dataset

    elif 'val' in filename:
        c_dict['val'][filename] = dataset

    elif 'test' in filename:
        c_dict['test'][filename] = dataset

In [None]:
# Overview of the loaded data: every split, followed by its files and sizes.
for key, files in c_dict.items():
    print(key)
    for filename in files:
        print(f'    {filename}: {len(files[filename])}')

In [None]:
from transformers import AutoModel
# Load the fine-tuned model object. Mapping to 'cuda:0' unconditionally raises on
# CPU-only machines, so fall back to the CPU when no GPU is available.
# SECURITY NOTE: torch.load unpickles arbitrary Python objects; only load files
# from a trusted source.
model = torch.load(
    '/kaggle/input/distilbert/pytorch/default/1/distilbert_cased_en_novi.pt',
    map_location=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
)

In [None]:
# Display the loaded model's repr as the cell output.
model

In [None]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import torch
from typing import List, Union
from transformers import AutoTokenizer, AutoModel, PreTrainedModel
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, BertForSequenceClassification
import pandas as pd

from transformers import DataCollatorWithPadding

from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

def evaluate(dataset: Dataset, model: PreTrainedModel, device: torch.device = torch.device('cpu'), batch_size=1):
    """
    Scores the model on the dataset, skipping samples longer than 512 tokens.

    Prints the classification report, the accuracy and the confusion matrix.

    Parameters:
        dataset: Dataset
            Yields 8-tuples with token ids at position 5, the attention mask
            at position 6 and the label at position 7.
        model: PreTrainedModel
            The model to evaluate; it is moved to ``device`` and put in eval mode.
        device: torch.device
            The device to run inference on.
        batch_size: int
            Number of samples requested per batch (before filtering).

    Returns:
        The classification report as a dict (``output_dict=True``).
    """
    model.to(device)
    model.eval()

    def filter_and_collate(batch):
        """
        Collate a list of samples into one dict of batched fields.

        Samples with more than 512 token ids are dropped first; returns None
        when nothing is left.
        """
        kept = [sample for sample in batch if len(sample[5]) <= 512]
        if not kept:
            return None

        def column(position):
            # One field of every surviving sample, in batch order.
            return [sample[position] for sample in kept]

        return {
            'ids': column(0),
            'speakers': column(1),
            'sexes': column(2),
            'texts': column(3),
            'texts_en': column(4),
            'input_ids': torch.stack([torch.tensor(values) for values in column(5)]),
            'attention_mask': torch.stack([torch.tensor(values) for values in column(6)]),
            'labels': torch.tensor(column(7)),
        }

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=filter_and_collate)

    y_true, y_pred = [], []

    with torch.no_grad():
        for batch in loader:
            # The collate function returns None for batches that were filtered away entirely.
            if batch is not None:
                token_ids = batch['input_ids'].to(device)
                mask = batch['attention_mask'].to(device)
                target = batch['labels'].to(device)

                output = model(input_ids=token_ids, attention_mask=mask, labels=target)
                predicted = output.logits.argmax(dim=1)

                y_true.extend(target.cpu().numpy())
                y_pred.extend(predicted.cpu().numpy())

    # Metrics over everything that was actually scored.
    accuracy = accuracy_score(y_true, y_pred)
    cls_dict = classification_report(y_true, y_pred, zero_division=0.0, output_dict=True)
    print(classification_report(y_true, y_pred, zero_division=0.0))
    print(f'Accuracy: {accuracy}')
    print(f'Confusion matrix:\n{confusion_matrix(y_true, y_pred)}')

    return cls_dict


In [None]:
# Evaluate the model on every test file and keep its weighted-average F1,
# keyed by file name.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
performance = {}
for filename, dataset_ in c_dict['test'].items():
    print(f'{filename}: {len(dataset_)} examples')
    report = evaluate(dataset_, model, device=device,batch_size=16)
    performance[filename] = report['weighted avg']['f1-score']

In [None]:
import re
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# Ensure inline plotting for Jupyter notebooks
%matplotlib inline

def filter_and_collate(batch, max_length=512):
    """
    Custom collate function that drops samples longer than ``max_length`` tokens.

    Parameters:
        batch: list of 8-tuples
            ``(id, speaker, sex, text, text_en, input_ids, attention_mask, label)``.
        max_length: int
            Samples whose ``input_ids`` (position 5) are longer than this are dropped.

    Returns:
        A dict with list-valued ``ids``, ``speakers``, ``sexes``, ``texts``,
        ``texts_en`` and tensor-valued ``input_ids``, ``attention_mask``,
        ``labels``; or None when no sample remains (including an empty batch).
    """
    # The length check has to run for EVERY item. Previously it sat outside the
    # loop, so only the last sample of each batch was ever kept.
    filtered_batch = [item for item in batch if len(item[5]) <= max_length]

    if not filtered_batch:
        return None  # Nothing valid left in this batch; callers skip None.

    return {
        'ids': [item[0] for item in filtered_batch],
        'speakers': [item[1] for item in filtered_batch],
        'sexes': [item[2] for item in filtered_batch],
        'texts': [item[3] for item in filtered_batch],
        'texts_en': [item[4] for item in filtered_batch],
        # as_tensor accepts tensors and plain sequences alike without the
        # copy-construct warning of torch.tensor(tensor); stack copies anyway.
        'input_ids': torch.stack([torch.as_tensor(item[5]) for item in filtered_batch]),
        'attention_mask': torch.stack([torch.as_tensor(item[6]) for item in filtered_batch]),
        'labels': torch.tensor([item[7] for item in filtered_batch]),
    }

# Extract country code from file path
def extract_country_code(file_path):
    """
    Derive the country / region code from a dataset file name.

    Names containing ``all`` map to ``'all_combined'``. Otherwise the code is the
    ``xx`` or ``xx-yy`` part of ``...orientation-<code>.pt``; the split keywords
    are tried in the order train, val, test.

    Returns:
        The code as a string, or None (after printing a notice) when nothing matches.
    """
    if 'all' in file_path:
        return 'all_combined'

    # Two lowercase letters, optionally followed by "-" and two more. The dot
    # before "pt" is escaped so that it only matches a literal ".".
    code = r'([a-z]{2}(?:-[a-z]{2})?)\.pt'
    split_patterns = (
        ('train', r'orientation-' + code),
        ('val', r'val_dataset_orientation-' + code),
        ('test', r'orientation-' + code),
    )

    for split, pattern in split_patterns:
        if split not in file_path:
            continue
        match = re.search(pattern, file_path)
        if match:
            return match.group(1)

    print(f"No match found for: {file_path}")
    return None

def get_bert_representations(model, dataset_name, dataset, batch_size=1):
    """
    Run ``model`` over ``dataset`` and collect per-sample outputs.

    Inference runs without gradient tracking and with the model in eval mode;
    the model's previous train/eval mode is restored afterwards.

    Parameters:
        model: model called with ``input_ids``, ``attention_mask`` and
            ``output_hidden_states=True``; its output must expose
            ``hidden_states`` and ``logits``.
        dataset_name: str
            File name from which the country code is derived.
        dataset: dataset batched through ``filter_and_collate``.
        batch_size: int
            Number of samples requested per batch.

    Returns:
        Tuple of numpy arrays ``(representations, labels, countries, sexes, logits)``.
        ``representations`` holds the last hidden state at token position 0;
        ``sexes`` is 0 for ``'F'`` and 1 for anything else.
    """
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=filter_and_collate)

    # Feed the inputs to wherever the model actually lives instead of guessing;
    # fall back to the old heuristic for models without parameters.
    try:
        device = next(model.parameters()).device
    except (StopIteration, AttributeError):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Identical for every sample of this dataset, so resolve it once.
    country = extract_country_code(dataset_name)

    representations = []
    labels = []
    countries = []
    sexes = []
    logits = []

    was_training = getattr(model, 'training', False)
    model.eval()  # dropout must be off for reproducible representations
    try:
        # no_grad avoids building an autograd graph for every batch.
        with torch.no_grad():
            for batch in dataloader:
                if batch is None:
                    continue  # every sample of this batch was filtered out

                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                label = batch['labels']

                outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
                # Last layer's hidden state at the first token position.
                cls_representation = outputs.hidden_states[-1][:, 0, :]

                representations.extend(cls_representation.cpu().numpy())
                labels.extend(label.cpu().numpy())
                countries.extend([country] * len(label))
                sexes.extend(batch['sexes'])
                logits.extend(outputs.logits.cpu().numpy())
    finally:
        if was_training:
            model.train()

    sexes = [0 if sex == 'F' else 1 for sex in sexes]
    return np.array(representations), np.array(labels), np.array(countries), np.array(sexes), np.array(logits)

def extract_representations(dataset_dict):
    """
    Collect model outputs for every dataset in ``dataset_dict`` and concatenate them.

    Uses the notebook-level ``model`` and a batch size of 16.

    Parameters:
        dataset_dict: dict mapping file name -> dataset.

    Returns:
        Tuple ``(representations, labels, countries, sexes, logits)`` of numpy
        arrays concatenated over all datasets, in dict order. All five are
        empty arrays when no dataset produced any sample.
    """
    chunks = []
    for file_path, dataset in dataset_dict.items():
        chunk = get_bert_representations(model, file_path, dataset, 16)
        print("shape: ", chunk[0].shape)
        # A dataset that yields no samples comes back as 1-D empty arrays, which
        # cannot be concatenated with 2-D results; leave it out.
        if chunk[0].shape[0] > 0:
            chunks.append(chunk)
        print(f"Processed {file_path}")

    if not chunks:
        return np.array([]), np.array([]), np.array([]), np.array([]), np.array([])

    # Concatenate each field once instead of re-copying the running arrays
    # after every dataset.
    representations, labels, countries, sexes, logits = (
        np.concatenate(field) for field in zip(*chunks)
    )
    return representations, labels, countries, sexes, logits

def calculate_mean_by_country(countries, labels, representations, threshold=1.0):
    """
    Average the representations of every (country, label) group.

    Parameters:
        countries, labels: 1-D arrays aligned with the rows of ``representations``.
        representations: 2-D array with one row per sample.
        threshold: countries with fewer samples than this are skipped
            (their name is printed).

    Returns:
        Dict mapping ``(country, label)`` to the mean vector of that group.
        Groups without any sample are omitted, so no NaN vectors are returned.
    """
    mean_representations = {}
    unique_labels = np.unique(labels)
    for country in np.unique(countries):
        in_country = countries == country
        if in_country.sum() < threshold:
            print(country)
            continue
        for label in unique_labels:
            reps = representations[in_country & (labels == label)]
            # The mean of an empty slice is NaN, which breaks the PCA downstream.
            if reps.shape[0] == 0:
                continue
            mean_representations[(country, label)] = np.mean(reps, axis=0)

    return mean_representations

def calculate_mean_by_country_sex(countries, labels, representations, sexes, threshold=1.0):
    """
    Average the representations of every (country, sex, label) group.

    Parameters:
        countries, labels, sexes: 1-D arrays aligned with the rows of ``representations``.
        representations: 2-D array with one row per sample.
        threshold: countries with fewer samples than this are skipped.

    Returns:
        Dict mapping ``(country, sex, label)`` to the mean vector of that group.
        Groups without any sample map to a zero vector of the same width as
        ``representations``.
    """
    mean_representations = {}
    unique_sexes = np.unique(sexes)
    unique_labels = np.unique(labels)
    # Width of the filler vector follows the data instead of assuming 768.
    empty_shape = representations.shape[1:]

    for country in np.unique(countries):
        in_country = countries == country
        if in_country.sum() < threshold:
            continue
        for sex in unique_sexes:
            for label in unique_labels:
                reps = representations[in_country & (sexes == sex) & (labels == label)]
                if reps.size:
                    mean_representations[(country, sex, label)] = np.mean(reps, axis=0)
                else:
                    # NOTE(review): the zero filler still enters the PCA and the
                    # scatter plot as a point; consider omitting empty groups.
                    mean_representations[(country, sex, label)] = np.zeros(empty_shape)

    return mean_representations

def perform_PCA(vectors):
    """Fit a two-component PCA on ``vectors`` and return the projected points."""
    return PCA(n_components=2).fit_transform(vectors)
    
def plot_country(pca_result, countries_labels):
    """
    Scatter the projected (country, label) means and link each country's pair.

    Parameters:
        pca_result: 2-D array; row k holds the two coordinates of ``countries_labels[k]``.
        countries_labels: list of ``(country, label)`` tuples.

    Points with label 0 are drawn red, all others blue. For every country that
    has both a label-0 and a label-1 point, a black line joins them and the
    country code is written at its midpoint.
    """
    plt.figure(figsize=(25, 14))

    for (_, label), point in zip(countries_labels, pca_result):
        plt.scatter(point[0], point[1], color='red' if label == 0 else 'blue')

    # Join the two points of each country and label the connecting line.
    for country in np.unique([c for c, l in countries_labels]):
        left_rows = [row for row, (c, l) in enumerate(countries_labels) if c == country and l == 0]
        right_rows = [row for row, (c, l) in enumerate(countries_labels) if c == country and l == 1]
        if not left_rows or not right_rows:
            continue

        left, right = pca_result[left_rows[0]], pca_result[right_rows[0]]
        plt.plot([left[0], right[0]], [left[1], right[1]], 'k-')
        mid_x = (left[0] + right[0]) / 2
        mid_y = (left[1] + right[1]) / 2
        plt.text(mid_x, mid_y, country, fontsize=12, color='black', ha='center', va='center')

    plt.xlabel('PCA Component 1')
    plt.ylabel('PCA Component 2')
    plt.title('PCA of Mean Representations for Left-Wing and Right-Wing Speeches by Country')
    plt.grid(True)
    plt.show()

def plot_country_sex(pca_result, cs_labels):
    """
    Scatter the projected (country, sex, label) means and link matching pairs.

    Parameters:
        pca_result: 2-D array; row k holds the two coordinates of ``cs_labels[k]``.
        cs_labels: list of ``(country, sex, label)`` tuples.

    Label 0 is drawn red, other labels blue; sex 1 uses a triangle marker,
    other values a circle. For every (country, sex) that has both a label-0 and
    a label-1 point, a black line joins them and is captioned
    ``"<country> - Female"`` (sex 0) or ``"<country> - Male"``.
    """
    plt.figure(figsize=(25, 14))

    for (_, sex, label), point in zip(cs_labels, pca_result):
        plt.scatter(
            point[0],
            point[1],
            color='red' if label == 0 else 'blue',
            marker='v' if sex == 1 else 'o',
        )

    # Join the two label points of each (country, sex) and caption the line.
    all_countries = np.unique([c for (c, s, l) in cs_labels])
    for country in all_countries:
        for sex in np.unique([s for (c, s, l) in cs_labels]):
            has_left = (country, sex, 0) in cs_labels
            if not has_left or (country, sex, 1) not in cs_labels:
                continue

            left_rows = [row for row, (c, s, l) in enumerate(cs_labels) if c == country and l == 0 and s == sex]
            right_rows = [row for row, (c, s, l) in enumerate(cs_labels) if c == country and l == 1 and s == sex]
            if not left_rows or not right_rows:
                continue

            left, right = pca_result[left_rows[0]], pca_result[right_rows[0]]
            plt.plot([left[0], right[0]], [left[1], right[1]], 'k-')
            mid_x = (left[0] + right[0]) / 2
            mid_y = (left[1] + right[1]) / 2
            sex_text = "Female" if sex == 0 else "Male"
            text = country + " - " + sex_text
            plt.text(mid_x, mid_y, text, fontsize=12, color='black', ha='center', va='center')

    plt.xlabel('PCA Component 1')
    plt.ylabel('PCA Component 2')
    plt.title('PCA of Mean Representations for Left-Wing and Right-Wing Speeches by Country')
    plt.grid(True)
    plt.show()

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator
# before the representation extraction below.
torch.cuda.empty_cache()

In [None]:
import matplotlib

# Global font settings for all following figures.
# 'normal' is not a font family name (it made matplotlib warn and fall back to
# its default font); 'sans-serif' selects the default sans-serif family explicitly.
font = {'family' : 'sans-serif',
        'weight' : 'bold',
        'size'   : 15}

matplotlib.rc('font', **font)

In [None]:
# Run the model over every validation file and collect, per sample, the
# representation, label, country code, sex and logits.
val_representations, val_labels, val_countries, val_sexes, val_logits = extract_representations(c_dict['val'])

In [None]:
# Validation split: mean representation per (country, label).
# Countries with fewer than 4 samples are skipped.
mean_representations_cval = calculate_mean_by_country(val_countries, val_labels, val_representations, threshold=4)
print("Calculated all representations!")

# Prepare data for PCA: stack the mean vectors; the keys stay aligned with the rows.
mean_reps_cval = np.array(list(mean_representations_cval.values()))
countries_labels_cval = list(mean_representations_cval.keys())

# Project the means onto two principal components.
pca_result_cval = perform_PCA(mean_reps_cval)

# Plot the projected means.
plot_country(pca_result_cval, countries_labels_cval)

In [None]:
# Validation split: mean representation per (country, sex, label).
# Countries with fewer than 100 samples are skipped.
mean_representations_csval = calculate_mean_by_country_sex(val_countries, val_labels, val_representations, val_sexes, threshold=100)
print("Calculated all representations!")
# The full dict of mean vectors is no longer printed here: it flooded the cell
# output, and the sibling cells for the other splits do not print it either.

# Prepare data for PCA: stack the mean vectors; the keys stay aligned with the rows.
mean_reps_csval = np.array(list(mean_representations_csval.values()))
countries_labels_csval = list(mean_representations_csval.keys())

# Project the means onto two principal components.
pca_result_csval = perform_PCA(mean_reps_csval)

# Plot the projected means.
plot_country_sex(pca_result_csval, countries_labels_csval)

In [None]:
# Run the model over every training file and collect, per sample, the
# representation, label, country code, sex and logits.
tr_representations, tr_labels, tr_countries, tr_sexes, tr_logits = extract_representations(c_dict['train'])

In [None]:
# Training split: mean representation per (country, label).
# Countries with fewer than 30 samples are skipped.
mean_representations_ctr = calculate_mean_by_country(tr_countries, tr_labels, tr_representations, threshold=30)
print("Calculated all representations!")

# Prepare data for PCA: stack the mean vectors; the keys stay aligned with the rows.
mean_reps_ctr = np.array(list(mean_representations_ctr.values()))
countries_labels_ctr = list(mean_representations_ctr.keys())

# Project the means onto two principal components.
pca_result_ctr = perform_PCA(mean_reps_ctr)

# Plot the projected means.
plot_country(pca_result_ctr, countries_labels_ctr)

In [None]:
# Training split: mean representation per (country, sex, label).
# Countries with fewer than 500 samples are skipped.
mean_representations_cstr = calculate_mean_by_country_sex(tr_countries, tr_labels, tr_representations, tr_sexes, threshold=500)
print("Calculated all representations!")

# Prepare data for PCA: stack the mean vectors; the keys stay aligned with the rows.
mean_reps_cstr = np.array(list(mean_representations_cstr.values()))
countries_labels_cstr = list(mean_representations_cstr.keys())

# Project the means onto two principal components.
pca_result_cstr = perform_PCA(mean_reps_cstr)

# Plot the projected means.
plot_country_sex(pca_result_cstr, countries_labels_cstr)

In [None]:
# Run the model over every test file and collect, per sample, the
# representation, label, country code, sex and logits.
ts_representations, ts_labels, ts_countries, ts_sexes, ts_logits = extract_representations(c_dict['test'])

In [None]:
# Test split: mean representation per (country, label).
# Countries with fewer than 100 samples are skipped.
mean_representations_cts = calculate_mean_by_country(ts_countries, ts_labels, ts_representations, threshold=100)
print("Calculated all representations!")

# Prepare data for PCA: stack the mean vectors; the keys stay aligned with the rows.
mean_reps_cts = np.array(list(mean_representations_cts.values()))
countries_labels_cts = list(mean_representations_cts.keys())

# Project the means onto two principal components.
pca_result_cts = perform_PCA(mean_reps_cts)

# Plot the projected means.
plot_country(pca_result_cts, countries_labels_cts)

In [None]:
# Test split: mean representation per (country, sex, label).
# Countries with fewer than 200 samples are skipped.
mean_representations_csts = calculate_mean_by_country_sex(ts_countries, ts_labels, ts_representations, ts_sexes, threshold=200)
print("Calculated all representations!")

# Prepare data for PCA: stack the mean vectors; the keys stay aligned with the rows.
mean_reps_csts = np.array(list(mean_representations_csts.values()))
countries_labels_csts = list(mean_representations_csts.keys())

# Project the means onto two principal components.
pca_result_csts = perform_PCA(mean_reps_csts)

# Plot the projected means.
plot_country_sex(pca_result_csts, countries_labels_csts)

In [None]:
# Horizontal bar chart of the weighted F1 recorded for every test file,
# with bars named by the country code derived from the file name.
plt.figure(figsize=(15, 9))
c_names = [extract_country_code(test_file) for test_file in performance]
plt.barh(c_names, list(performance.values()))
plt.xlabel('Weighted F1')
plt.title('Weighted F1 by Country')

In [None]:
def calculate_mean_probabilities(countries, predictions, labels, threshold = 1):
    """
    Mean probability the model assigns to the true class, per country.

    Parameters:
        countries: 1-D array of country codes, one per sample.
        predictions: 2-D array of logits with one row per sample;
            column 0 is read for label 0 and column 1 for label 1.
        labels: 1-D array of true labels (0 or 1), one per sample.
        threshold: countries with fewer samples than this are skipped.

    Returns:
        Dict mapping country -> ``(right_mean, left_mean)``: the mean softmax
        probability of class 1 over samples labeled 1, and of class 0 over
        samples labeled 0. A mean is NaN when the country has no sample of
        that class.
    """
    country_means = {}
    for country in np.unique(countries):
        country_indices = countries == country
        country_labels = labels[country_indices]
        if country_labels.shape[0] < threshold:
            continue

        country_logits = predictions[country_indices]
        # Softmax with the row maximum subtracted first: mathematically the same,
        # but exp() can no longer overflow to inf (which produced NaN).
        shifted = country_logits - np.max(country_logits, axis=1, keepdims=True)
        exp_scores = np.exp(shifted)
        probabilities = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)

        right_probs = probabilities[country_labels == 1][:, 1]
        left_probs = probabilities[country_labels == 0][:, 0]
        # Explicit NaN for a missing class instead of np.mean's empty-slice warning.
        right_mean = np.mean(right_probs) if right_probs.size else float('nan')
        left_mean = np.mean(left_probs) if left_probs.size else float('nan')
        country_means[country] = (right_mean, left_mean)
    return country_means

In [None]:
# Per-country mean true-class probabilities on the test split (countries with
# fewer than 7 samples are skipped), unpacked into parallel lists for plotting.
country_means = calculate_mean_probabilities(ts_countries, ts_logits, ts_labels, threshold=7)
countries = list(country_means)
right_means = [means[0] for means in country_means.values()]
left_means = [means[1] for means in country_means.values()]

In [None]:
# Horizontal bar chart, one bar per country: mean probability assigned to
# class 1 over the samples labeled 1.
plt.figure(figsize=(15, 9))
plt.barh(countries, right_means, color='blue')
plt.xlabel('Mean Probability')
plt.title('Right-Wing Mean Probabilities by Country')

In [None]:
# Horizontal bar chart, one bar per country: mean probability assigned to
# class 0 over the samples labeled 0.
plt.figure(figsize=(15, 9))
plt.barh(countries, left_means, color='red')
plt.xlabel('Mean Probability')
plt.title('Left-Wing Mean Probabilities by Country')

# Fit the labels inside the figure area, then render the figure.
plt.tight_layout()
plt.show()