In [None]:
from datasets import load_dataset, DatasetDict, Dataset, concatenate_datasets
import random
import string

from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, DataCollatorWithPadding, AdamW, get_scheduler, MarianMTModel, MarianTokenizer
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output
from torch.nn.utils.rnn import pad_sequence
import gc

In [None]:
# Locations of the raw CSV splits, relative to the notebook directory.
data_files = dict(
    train="../data/raw_dataset/train.csv",
    test="../data/raw_dataset/test.csv",
)
# Read both splits with the generic "csv" loader.
raw_dataset = load_dataset("csv", data_files=data_files)

## Data augmentation and balancing (Tăng cường và cân bằng dữ liệu)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Pivot languages used for back-translation.
languages = ['de', 'es', 'ru', 'fr']
translate_models = {}
translate_tokenizers = {}
# For every pivot language load the en->xx pair (key "to<lang>") and then the
# xx->en pair (key "from<lang>"); each model is moved to the selected device.
for language in languages:
    for direction, checkpoint in (
        ('to', "Helsinki-NLP/opus-mt-en-" + language),
        ('from', "Helsinki-NLP/opus-mt-" + language + "-en"),
    ):
        translate_models[direction + language] = MarianMTModel.from_pretrained(checkpoint).to(device)
        translate_tokenizers[direction + language] = MarianTokenizer.from_pretrained(checkpoint)
# 'en' is added only after loading: back_translate treats it as "leave the text unchanged".
languages.append('en')

In [None]:
# Back-translation helpers
def _translate_batch(texts, model, tokenizer):
    """Translate a list of strings with one MarianMT model/tokenizer pair.

    Inputs are padded to the longest item and truncated to the tokenizer's
    maximum length so that an over-long text cannot exceed the model's
    position limit. Generation runs without gradient tracking.

    Args:
        texts: list of strings to translate.
        model: translation model already placed on ``device``.
        tokenizer: tokenizer matching ``model``.

    Returns:
        list of decoded strings, one per input, with special tokens removed.
    """
    with torch.no_grad():
        inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(device)
        generated = model.generate(**inputs)
    return [tokenizer.decode(token_ids, skip_special_tokens=True) for token_ids in generated]


def back_translate(texts):
    """Paraphrase a batch of English texts by round-tripping through a random language.

    One entry of the module-level ``languages`` list is drawn for the whole
    batch. When ``'en'`` is drawn the batch is returned untouched; otherwise
    the texts are translated English -> pivot -> English.

    Args:
        texts: sequence of English strings.

    Returns:
        The original ``texts`` object when ``'en'`` is drawn, otherwise a new
        list of back-translated strings of the same length. Texts longer than
        the tokenizer's maximum length are truncated before translation.
    """
    language = random.choice(languages)
    if language == 'en':
        return texts
    texts = list(texts)
    if not texts:
        # Nothing to translate; the tokenizer cannot encode an empty batch.
        return texts
    # English -> pivot language
    texts = _translate_batch(texts, translate_models['to' + language], translate_tokenizers['to' + language])
    # Pivot language -> English
    texts = _translate_batch(texts, translate_models['from' + language], translate_tokenizers['from' + language])
    # The intermediate tensors are already out of scope here; hand the cached
    # GPU blocks back once per batch. The models themselves stay loaded because
    # they are still referenced from translate_models.
    torch.cuda.empty_cache()
    return texts
def remove_last_punctuation(text):
    """Strip every trailing ASCII punctuation character from ``text``.

    Args:
        text: input string.

    Returns:
        ``text`` without the run of ``string.punctuation`` characters at its
        end. An empty string, or a string made only of punctuation, yields
        ``""`` instead of raising ``IndexError``.
    """
    return text.rstrip(string.punctuation)
def _parse_aspect_labels(label_string, text):
    """Parse a raw label string such as ``"all-1 amn-3"`` into per-aspect polarities.

    Args:
        label_string: whitespace-separated ``<aspect>-<polarity>`` items.
        text: the row's text, used only in the error message.

    Returns:
        dict mapping each of the six aspect keys to an int polarity in 0..3;
        aspects that are not listed keep the default 0.

    Raises:
        ValueError: if an item is not of the form ``<known aspect>-<0|1|2|3>``.
    """
    polarities = {'all': 0, 'amn': 0, 'ch': 0, 'ppl': 0, 'mgt': 0, 'nat': 0}
    for label in label_string.split():
        key, separator, value = label.partition('-')
        if not separator or key not in polarities or value not in ('0', '1', '2', '3'):
            raise ValueError(f"Unknown label {label!r} in row with text: {text}")
        polarities[key] = int(value)
    return polarities


def create_train_dataset(raw_datasets, batch_size=48, use_back_translate=True):
    """Build the balanced, prompt-prefixed binary training set.

    Every output row is ``"<polarity> <aspect>: <sentence A>, <sentence B>"``
    with a binary label:

    * label 1 ("positive" prompts): at least one sentence has polarity 1 or 3
      for the aspect; "negative" prompts use polarity 2 or 3.
    * label 0 ("positive" prompts): both sentences have polarity 0 or 2 for
      the aspect; "negative" prompts use polarity 0 or 1.

    Label-1 rows are first created once per matching sentence and then
    oversampled so that every prompt has the same number of rows (the largest
    matching pool over all prompts). The same number of label-0 rows is then
    generated per prompt.

    Args:
        raw_datasets: mapping with a ``'train'`` split exposing ``en_text``
            and ``labels`` columns.
        batch_size: number of texts sent to ``back_translate`` at once.
        use_back_translate: when true, texts are paraphrased with
            ``back_translate`` before the pairs are built.

    Returns:
        ``DatasetDict`` with a single ``'train'`` split whose columns are
        ``en_text``, ``labels`` and ``type``.

    Raises:
        ValueError: on an unparsable label, or if back-translation returns a
            different number of texts than it was given.
    """
    aspect_categories = ['all', 'amn', 'ch', 'ppl', 'mgt', 'nat']
    positive_short2full = {
        'all': 'positive all: ',
        'amn': 'positive amenities: ',
        'ch':  'positive cultural heritage: ',
        'ppl': 'positive people: ',
        'mgt': 'positive management: ',
        'nat': 'positive nature: ',
    }
    negative_short2full = {
        'all': 'negative all: ',
        'amn': 'negative amenities: ',
        'ch':  'negative cultural heritage: ',
        'ppl': 'negative people: ',
        'mgt': 'negative management: ',
        'nat': 'negative nature: ',
    }
    # (prompt name, prefix table, polarity that matches the prompt, opposite polarity)
    prompt_families = (
        ('positive', positive_short2full, 1, 2),
        ('negative', negative_short2full, 2, 1),
    )

    shuffled_train = raw_datasets['train'].shuffle()
    texts = list(shuffled_train['en_text'])
    label_strings = list(shuffled_train['labels'])
    if use_back_translate:
        augmented_texts = []
        for start in range(0, len(texts), batch_size):
            augmented_texts.extend(back_translate(texts[start:start + batch_size]))
        if len(augmented_texts) != len(label_strings):
            raise ValueError("back_translate returned a different number of texts than it received")
        texts = augmented_texts

    # Group the texts by "<aspect>-<polarity>" so every pool is built only once.
    texts_by_type = {}
    for text, label_string in zip(texts, label_strings):
        for aspect, polarity in _parse_aspect_labels(label_string, text).items():
            texts_by_type.setdefault(aspect + '-' + str(polarity), []).append(text)

    def pool(aspect, polarity):
        return texts_by_type.get(aspect + '-' + str(polarity), [])

    new_datasets = {"en_text": [], "labels": [], "type": []}

    def add_pair(prefix, first_sentence, second_sentence, label, type_name, random_order):
        # With random_order the two sentences are swapped half of the time, so
        # the informative sentence is not always in the same position.
        if random_order and not random.randint(0, 1):
            first_sentence, second_sentence = second_sentence, first_sentence
        new_datasets['en_text'].append(prefix + remove_last_punctuation(first_sentence) + ", " + second_sentence)
        new_datasets['labels'].append(label)
        new_datasets['type'].append(type_name)

    # Pass 1: one label-1 row per sentence that matches the prompt, paired with
    # any random sentence. Track the largest matching pool as balancing target.
    rows_per_prompt = 0
    for prompt_name, prefixes, matching_polarity, _ in prompt_families:
        for aspect in aspect_categories:
            matching = pool(aspect, matching_polarity) + pool(aspect, 3)
            type_name = prompt_name + '-' + aspect + '-1'
            for first_sentence in matching:
                add_pair(prefixes[aspect], first_sentence, random.choice(texts), 1, type_name, True)
            rows_per_prompt = max(rows_per_prompt, len(matching))

    # Pass 2: oversample label-1 rows until every prompt reaches the target.
    # Drawing uniformly from the merged pool equals picking a sub-pool with
    # probability proportional to its size and then a uniform element of it.
    for prompt_name, prefixes, matching_polarity, _ in prompt_families:
        for aspect in aspect_categories:
            matching = pool(aspect, matching_polarity) + pool(aspect, 3)
            if not matching:
                continue  # nothing to oversample from
            type_name = prompt_name + '-' + aspect + '-1'
            for _ in range(len(matching), rows_per_prompt):
                add_pair(prefixes[aspect], random.choice(matching), random.choice(texts), 1, type_name, True)

    # Pass 3: label-0 rows; both sentences come from the "absent" pool or the
    # opposite-polarity pool, each pool chosen with equal probability.
    for prompt_name, prefixes, _, opposite_polarity in prompt_families:
        for aspect in aspect_categories:
            non_matching_pools = [p for p in (pool(aspect, 0), pool(aspect, opposite_polarity)) if p]
            if not non_matching_pools:
                continue  # no sentence can serve as a negative example
            type_name = prompt_name + '-' + aspect + '-0'
            for _ in range(rows_per_prompt):
                first_sentence = random.choice(random.choice(non_matching_pools))
                second_sentence = random.choice(random.choice(non_matching_pools))
                add_pair(prefixes[aspect], first_sentence, second_sentence, 0, type_name, False)

    clear_output()
    return DatasetDict({"train": Dataset.from_dict(new_datasets)})

def polarity2label(check_polarity, polarity):
    """Convert a 0..3 aspect polarity into a binary label for one prompt kind.

    Args:
        check_polarity: 1 to test against polarities 1 and 3, 2 to test
            against polarities 2 and 3.
        polarity: the aspect polarity to test.

    Returns:
        1 when ``polarity`` is accepted for ``check_polarity``, 0 when it is
        not, and ``None`` when ``check_polarity`` is neither 1 nor 2.
    """
    if check_polarity == 1:
        accepted = polarity == 1 or polarity == 3
    elif check_polarity == 2:
        accepted = polarity == 2 or polarity == 3
    else:
        return None
    return 1 if accepted else 0

def create_test_dataset(raw_datasets):
    """Expand every raw row into twelve prompt-prefixed binary examples.

    For each row of the ``'train'`` and ``'test'`` splits, one example is
    emitted per (prompt polarity, aspect) pair: six "positive <aspect>: "
    examples followed by six "negative <aspect>: " examples. The binary label
    comes from ``polarity2label`` and the ``type`` column is
    ``"<positive|negative>-<aspect key>-<label>"``.

    Args:
        raw_datasets: mapping with ``'train'`` and ``'test'`` splits exposing
            ``en_text`` and ``labels`` columns.

    Returns:
        ``DatasetDict`` with ``'train'`` and ``'test'`` splits, each having
        the columns ``en_text``, ``labels`` and ``type``.

    Raises:
        ValueError: if a label item is not ``<known aspect>-<0|1|2|3>``.
    """
    # Aspect key -> wording used inside the prompt prefix.
    aspect_full_names = {
        'all': 'all',
        'amn': 'amenities',
        'ch': 'cultural heritage',
        'ppl': 'people',
        'mgt': 'management',
        'nat': 'nature',
    }
    # (prompt wording, check_polarity argument for polarity2label)
    prompt_kinds = (('positive', 1), ('negative', 2))

    datasets = DatasetDict({
        'train': raw_datasets['train'],
        'test': raw_datasets['test']})
    for name in datasets:
        data_dict = {"en_text": [], "labels": [], "type": []}
        for item in datasets[name]:
            text = item['en_text']
            # Aspects missing from the label string keep polarity 0.
            aspect2label = dict.fromkeys(aspect_full_names, 0)
            for label in item['labels'].split():
                key, separator, value = label.partition('-')
                if not separator or key not in aspect2label or value not in ('0', '1', '2', '3'):
                    raise ValueError(f"Unknown label {label!r} in row with text: {text}")
                aspect2label[key] = int(value)
            for prompt_name, check_polarity in prompt_kinds:
                for aspect, full_name in aspect_full_names.items():
                    binary_label = polarity2label(check_polarity, aspect2label[aspect])
                    data_dict["en_text"].append(prompt_name + " " + full_name + ": " + text)
                    data_dict["labels"].append(binary_label)
                    data_dict["type"].append(prompt_name + '-' + aspect + '-' + str(binary_label))
        datasets[name] = Dataset.from_dict(data_dict)
    return datasets

In [None]:
# Tokenizer and config are taken from the hub checkpoint; the classifier itself
# is restored from the local "../best_weights" directory with a 2-class head.
hub_checkpoint = "khanhtq2802/thesis-model"
tokenizer = AutoTokenizer.from_pretrained(hub_checkpoint)
config = AutoConfig.from_pretrained(hub_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(
    "../best_weights",
    num_labels=2,
    ignore_mismatched_sizes=True,
)

In [None]:
# Only parameters whose name starts with one of these prefixes stay trainable
# (the classifier head and encoder layer 11); everything else is frozen.
trainable_prefixes = (
    "classifier.out_proj",
    "classifier.dense",
    "roberta.encoder.layer.11",
)
for name, param in model.named_parameters():
    param.requires_grad = name.startswith(trainable_prefixes)

# Place the model on the GPU when available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
print(device)

# Report how many scalar weights will receive gradients.
trainable_params = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)
print("Trainable Parameters:", trainable_params)

In [None]:
# All "<polarity> <aspect>: " prompt prefixes that can precede a review.
aspects = [
    "positive all: ",
    "positive amenities: ",
    "positive cultural heritage: ",
    "positive people: ",
    "positive management: ",
    "positive nature: ",
    "negative all: ",
    "negative amenities: ",
    "negative cultural heritage: ",
    "negative people: ",
    "negative management: ",
    "negative nature: ",
]
# Tokenize each prefix and keep the largest token count; the training collator
# uses it as the number of leading tokens that must never be masked.
skip_tokens = 0
for aspect in aspects:
    tokens = tokenizer.tokenize(aspect)
    if len(tokens) > skip_tokens:
        skip_tokens = len(tokens)
print(skip_tokens)
class CustomDataCollator:
    """Collate tokenized examples into a padded batch with random token masking.

    Each token after the protected prefix is replaced by the tokenizer's mask
    token with probability ``mlm_probability``. The first ``skip_tokens + 1``
    positions (the leading special token plus the prompt prefix) and every
    special token are never replaced.

    Args:
        tokenizer: tokenizer exposing ``mask_token_id``; ``pad_token_id`` and
            ``all_special_ids`` are used when present.
        mlm_probability: per-token probability of being masked.
        skip_tokens: number of prompt-prefix tokens to protect from masking.
    """

    def __init__(self, tokenizer, mlm_probability=0.15, skip_tokens=4):
        self.tokenizer = tokenizer
        self.mlm_probability = mlm_probability
        self.skip_tokens = skip_tokens

    def __call__(self, examples):
        """Build one batch.

        Args:
            examples: list of dicts with ``input_ids``, ``attention_mask``
                (lists of ints) and ``labels`` (int).

        Returns:
            dict with tensors ``labels`` (batch,), ``input_ids`` and
            ``attention_mask`` (batch, longest sequence). ``input_ids`` is
            padded with the pad token id and ``attention_mask`` with 0 so
            that padding positions are ignored by the model.
        """
        mask_id = self.tokenizer.mask_token_id
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = 1  # historical default of this collator
        special_ids = set(getattr(self.tokenizer, "all_special_ids", None) or ())

        input_ids = []
        for example in examples:
            # Work on a copy so the caller's example is left untouched.
            token_ids = list(example['input_ids'])
            for position in range(self.skip_tokens + 1, len(token_ids)):
                if token_ids[position] in special_ids:
                    continue  # keep sequence delimiters intact
                if random.random() < self.mlm_probability:
                    token_ids[position] = mask_id
            input_ids.append(torch.tensor(token_ids))
        attention_mask = [torch.tensor(example['attention_mask']) for example in examples]
        labels = [example['labels'] for example in examples]

        return {
            'labels': torch.tensor(labels),
            'input_ids': pad_sequence(input_ids, batch_first=True, padding_value=pad_id),
            'attention_mask': pad_sequence(attention_mask, batch_first=True, padding_value=0),}
def tokenize_function(example):
    """Tokenize the ``en_text`` field of ``example`` with the module-level tokenizer.

    Truncation is enabled; no padding is applied here (the collators pad each
    batch).
    """
    texts = example["en_text"]
    encoded = tokenizer(texts, truncation=True)
    return encoded
# Training batches: random masking at a 15% rate, never touching the prompt prefix.
train_collator_options = dict(
    tokenizer=tokenizer,
    mlm_probability=0.15,
    skip_tokens=skip_tokens,
)
data_collator_train = CustomDataCollator(**train_collator_options)
# Evaluation batches: padding only, no masking.
data_collator_test = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
# Load the loss history written by a previous run, if there is one.
def load_loss_history(path):
    """Read ``training_loss,test_loss`` pairs, one per line, from ``path``.

    Blank lines are skipped. Reading stops at the first malformed line, and a
    pair is stored only when both of its values parse, so the two returned
    lists always have the same length. A missing or unreadable file yields
    two empty lists. Problems are reported with ``print`` instead of being
    raised.

    Args:
        path: text file to read.

    Returns:
        ``(training_losses, test_losses)`` as two lists of floats.
    """
    loaded_training, loaded_test = [], []
    try:
        with open(path, "r") as f:
            for line_number, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    training_text, test_text = line.split(",")
                    pair = (float(training_text), float(test_text))
                except ValueError as exc:
                    print(f"error when load training history: line {line_number} of {path} is malformed ({exc})")
                    break
                loaded_training.append(pair[0])
                loaded_test.append(pair[1])
    except (OSError, UnicodeDecodeError) as exc:
        print(f"error when load training history: {exc}")
    return loaded_training, loaded_test


# NOTE(review): the training cell writes its history to "losses.txt", not
# "losses_2.txt" — confirm that reading a different file here is intentional.
training_losses, test_losses = load_loss_history("losses_2.txt")

In [None]:
# Fixed evaluation loader over the prompt-expanded 'test' split.
test_split = (
    create_test_dataset(raw_dataset)
    .map(tokenize_function, batched=True)
    .remove_columns(['en_text', 'type'])['test']
)
test_dataloader = DataLoader(
    test_split,
    shuffle=False,
    batch_size=24,
    collate_fn=data_collator_test,
)

batch_size = 12


def build_train_dataloader():
    """Regenerate the back-translated, balanced training set and wrap it in a shuffling loader."""
    train_split = (
        create_train_dataset(raw_dataset, use_back_translate=True)
        .map(tokenize_function, batched=True)
        .remove_columns(["en_text", "type"])["train"]
    )
    return DataLoader(
        train_split,
        shuffle=True,
        batch_size=batch_size,
        collate_fn=data_collator_train,
    )


train_dataloader = build_train_dataloader()

optimizer = AdamW(model.parameters(), lr=3e-5)  # previously 1e-5
num_epochs = 100
steps_per_epoch = len(train_dataloader)
num_training_steps = num_epochs * steps_per_epoch
decay = "cosine"  # schedules tried: constant, cosine, linear

# One epoch's worth of steps is used as warm-up.
lr_scheduler = get_scheduler(
    decay,
    optimizer=optimizer,
    num_warmup_steps=steps_per_epoch,
    num_training_steps=num_training_steps,
)

eval_criterion = torch.nn.CrossEntropyLoss()
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    epoch_losses = []
    for batch in train_dataloader:
        batch = {key: tensor.to(device) for key, tensor in batch.items()}

        loss = model(**batch).loss
        loss.backward()
        epoch_losses.append(loss.item())

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
    training_losses.append(sum(epoch_losses) / len(epoch_losses))

    # ---- evaluation pass (no gradients) ----
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in test_dataloader:
            targets = batch['labels'].to(device)
            features = {key: tensor.to(device) for key, tensor in batch.items() if key != 'labels'}
            logits = model(**features).logits
            total_loss += eval_criterion(logits, targets).item()
    test_losses.append(total_loss / len(test_dataloader))

    # ---- redraw the loss curves in place ----
    clear_output(wait=True)
    plt.plot(training_losses, label="Training Loss")
    plt.plot(test_losses, label="test Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

    # ---- persist the full loss history, one "train,test" pair per line ----
    with open("losses.txt", "w") as f:
        f.writelines(
            f"{training_losses[i]},{test_losses[i]}\n"
            for i in range(len(training_losses))
        )

    # ---- checkpoints: always the latest, plus the best test loss so far ----
    model.save_pretrained("../late_weights")
    if test_losses[-1] == min(test_losses):
        model.save_pretrained("../best_weights")

    # ---- draw a fresh augmented training set when the test loss did not improve ----
    if len(test_losses) >= 2 and test_losses[-1] >= test_losses[-2]:
        torch.cuda.empty_cache()
        train_dataloader = build_train_dataloader()

In [None]:
def evaluate_model(model, raw_datasets, device, name='train'):
    """Evaluate ``model`` on one split and print per-type and overall metrics.

    The split is expanded with ``create_test_dataset``, so every example has
    a ``type`` of the form ``"<positive|negative>-<aspect>-<gold label>"``.
    The loader is not shuffled, which keeps batch rows aligned with that
    ``type`` column. For a misclassified example a false negative is counted
    for its gold type and a false positive for the same prompt with the
    predicted label; the example is also printed.

    Printed output: recall/precision/F1 per type, accuracy, mean batch loss,
    and the unweighted averages of the per-type metrics.

    Args:
        model: sequence-classification model returning ``logits``.
        raw_datasets: raw splits accepted by ``create_test_dataset``.
        device: device the batches are moved to.
        name: split to evaluate (``'train'`` or ``'test'``).

    Returns:
        None.
    """
    dataset = create_test_dataset(raw_datasets)
    example_types = list(dataset[name]['type'])
    if not example_types:
        print("No examples to evaluate in split:", name)
        return

    dataloader = DataLoader(
        dataset.map(tokenize_function, batched=True).remove_columns(['en_text', 'type'])[name],
        shuffle=False,
        batch_size=16,
        collate_fn=data_collator_test)

    type_counts = {}
    for type_ in example_types:
        type_counts[type_] = type_counts.get(type_, 0) + 1
    # Error counters keyed by the same type strings as type_counts.
    fail_count = {"false-positive": {}, "false-negative": {}}

    model.eval()  # Set the model to evaluation mode
    loss_fn = torch.nn.CrossEntropyLoss()
    total_acc, total_loss = 0, 0
    seen = 0  # number of examples consumed so far; index into example_types
    with torch.no_grad():
        for batch in dataloader:
            # Move data to the specified device
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)

            logits = model(**inputs).logits
            total_loss += loss_fn(logits, labels).item()

            pred_labels = torch.argmax(logits, dim=1)
            for i in range(len(labels)):
                if pred_labels[i] == labels[i]:
                    total_acc += 1
                    continue
                gold_type = example_types[seen + i]
                # Same prompt, but with the label the model predicted.
                predicted_type = gold_type.rsplit('-', 1)[0] + '-' + str(pred_labels[i].item())
                fail_count["false-negative"][gold_type] = fail_count["false-negative"].get(gold_type, 0) + 1
                fail_count["false-positive"][predicted_type] = fail_count["false-positive"].get(predicted_type, 0) + 1
                print(tokenizer.decode(batch['input_ids'][i]))
                print("pred: ", pred_labels[i])
                print("labels: ", labels[i])
                print()
            seen += len(labels)

    recalls = []; precisions = []; f1s = []
    # recall, precision, f1
    for key, count in type_counts.items():
        false_negative = fail_count["false-negative"].get(key, 0)
        false_positive = fail_count["false-positive"].get(key, 0)
        true_positive = count - false_negative
        recall = 0; precision = 0; f1 = 0
        if true_positive != 0:
            recall = true_positive/count
            precision = true_positive/(true_positive + false_positive)
            f1 = 2*recall*precision/(recall + precision)
        print(key, "recall=", round(recall, 4), "precision=", round(precision, 4), "f1=", round(f1, 4))
        recalls.append(recall); precisions.append(precision); f1s.append(f1)
    # Accuracy
    print(total_acc, dataloader.dataset.num_rows)
    print("Accuracy:", round(total_acc / dataloader.dataset.num_rows, 5))
    # Loss
    print("Loss:", total_loss / len(dataloader))
    # Unweighted averages of the per-type recall, precision and f1
    print("Marco-recall:", round(sum(recalls)/len(recalls), 5))
    print("Marco-precision:", round(sum(precisions)/len(precisions), 5))
    print("Marco-f1:", round(sum(f1s)/len(f1s), 5))

In [None]:
# Print the metrics of the current model on the 'test' split.
evaluate_model(
    model,
    raw_dataset,
    device,
    name='test',
)