In [None]:
!pip install -q mlflow paramiko

In [None]:
import getpass
import os
import random
import time
from random import sample

import mlflow
import pandas as pd
import paramiko
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.utils import resample
from torch.utils.data import Dataset, Subset
from torch.utils.data import DataLoader
from torchmetrics.classification import (
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)
from tqdm import tqdm
from transformers import RobertaModel, RobertaTokenizer

In [None]:
import sys
sys.path.append('/kaggle/input/liar-plus-utils')
print(sys.path)
from utils import (
    LABEL_MAPPING,
    ids2labels,
    save_checkpoint,
    load_checkpoint,
    save_best_model,
    load_best_model,
    save_model_remotely
)

In [None]:
# Allowed values of every categorical metadata column that is fed to the model
# as a one-hot vector. The position of a value in its list is its one-hot index.
one_hot_labels = {
    "sentiment": ['negative', 'neutral', 'positive'],
    "question": ['not_question', 'question'],
    "curse": ['curse', 'non-curse'],
    "emotion": ['anger', 'disgust', 'fear', 'joy', 'neutral', 'sadness', 'surprise'],
    "gibberish": ['clean', 'mild gibberish', 'word salad'],
    "offensiveness": ['non-offensive', 'offensive'],
    "political_bias": ['CENTER', 'LEFT', 'RIGHT'],
}

# column name -> {category value -> index inside that column's one-hot vector}
label_to_index = {
    column: {value: position for position, value in enumerate(values)}
    for column, values in one_hot_labels.items()
}

# Width of all one-hot vectors concatenated together (input size of the head).
one_hot_metadata_size = sum(map(len, one_hot_labels.values()))

In [None]:
class LiarPlusSingleRobertaDataset(Dataset):
    """LIAR-PLUS CSV rows turned into one RoBERTa input plus metadata tensors.

    Each item joins the statement and the string metadata columns into a single
    tagged text (``[STATEMENT] ... [SUBJECT] ... [SPEAKER] ...``), tokenizes it
    to exactly ``max_length`` tokens (truncating / padding) and returns it
    together with the numeric metadata, the concatenated one-hot categorical
    metadata and the label id.

    Args:
        filepath: CSV file read with ``pandas.read_csv``.
        tokenizer: Hugging Face tokenizer; ``tokenize``,
            ``convert_tokens_to_string`` and ``__call__`` are used.
        str_metadata_cols: columns appended to the text as `` [COLUMN] value``.
        num_metadata_cols: numeric columns returned as a float tensor.
        one_hot_metadata_cols: categorical columns; each must be a key of the
            module-level ``one_hot_labels`` and ``label_to_index`` tables.
        max_length: token budget for the whole encoded text.
    """

    def __init__(
        self,
        filepath: str,
        tokenizer,
        str_metadata_cols: list[str],
        num_metadata_cols: list[str],
        one_hot_metadata_cols: list[str],
        max_length: int = 512,
    ):
        self.df = pd.read_csv(filepath)

        self.str_metadata_cols = str_metadata_cols
        self.num_metadata_cols = num_metadata_cols
        self.one_hot_metadata_cols = one_hot_metadata_cols

        # Cast to str so the tokenizer always receives text
        # (missing values become the string "nan").
        for column in self.str_metadata_cols:
            self.df[column] = self.df[column].astype(str)
        self.df["statement"] = self.df["statement"].astype(str)
        # NOTE: the "justification" and "articles" columns are currently not
        # part of the model input.

        # A quarter of the token budget goes to the statement; the remainder is
        # split evenly across the string metadata columns, but every column gets
        # at least 15 tokens (the final tokenizer call truncates any overflow).
        # max(..., 1) avoids ZeroDivisionError when no string columns are used.
        self.statement_max_len = max_length // 4
        self.str_metadata_max_len = max(
            (max_length - self.statement_max_len) // max(len(str_metadata_cols), 1),
            15,
        )

        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df.index)

    def limit_tokens(self, text, max_length=512):
        """Return ``text`` cut down to its first ``max_length`` tokenizer tokens."""
        tokens = self.tokenizer.tokenize(text)
        return self.tokenizer.convert_tokens_to_string(tokens[:max_length])

    def __getitem__(self, index: int):
        item = self.df.iloc[index]

        # Statement first, then every string metadata column, each truncated to
        # its own token budget before being concatenated.
        text_parts = [
            self.limit_tokens(f"[STATEMENT] {item['statement']}", self.statement_max_len)
        ]
        text_parts.extend(
            self.limit_tokens(f" [{column.upper()}] {item[column]}", self.str_metadata_max_len)
            for column in self.str_metadata_cols
        )
        input_text = "".join(text_parts)

        encoded = self.tokenizer(
            input_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        label = LABEL_MAPPING[item["label"]]

        num_metadata = [item[column] for column in self.num_metadata_cols]

        one_hot_parts = []
        for column in self.one_hot_metadata_cols:
            value = item[column]
            try:
                value_index = label_to_index[column][value]
            except KeyError as exc:
                raise KeyError(
                    f"Unknown value {value!r} for one-hot column {column!r} (row {index})"
                ) from exc
            one_hot_parts.append(
                F.one_hot(torch.tensor(value_index), len(one_hot_labels[column]))
            )
        # torch.cat([]) raises, so an empty column list yields an empty vector.
        one_hot_metadata = (
            torch.cat(one_hot_parts, dim=0).float() if one_hot_parts else torch.zeros(0)
        )

        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "num_metadata": torch.tensor(num_metadata).float(),
            "one_hot_metadata": one_hot_metadata,
            "label": torch.tensor(label)
        }

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class LiarPlusSingleFinetunedRoBERTasClassifier(nn.Module):
    """Encoder pooled output + metadata features -> one GELU hidden layer -> logits.

    Args:
        encoder_model: encoder module with a ``config.hidden_size`` attribute
            whose forward output exposes ``pooler_output`` (e.g. ``RobertaModel``).
        num_metadata_len: number of numeric metadata features per example.
        one_hot_metadata_size: total width of the concatenated one-hot metadata.
        num_hidden: width of the hidden layer.
        num_classes: number of output classes.
    """

    def __init__(
        self, encoder_model, num_metadata_len, one_hot_metadata_size, num_hidden, num_classes
    ):
        super().__init__()
        self.encoder = encoder_model
        head_in_features = (
            self.encoder.config.hidden_size + num_metadata_len + one_hot_metadata_size
        )
        self.hl = nn.Linear(head_in_features, num_hidden)
        self.dropout = nn.Dropout(p=0.1)
        self.fc = nn.Linear(num_hidden, num_classes)

    def forward(self, input_ids, attention_mask, num_metadata, one_hot_metadata):
        """Return unnormalized class logits for a batch.

        ``num_metadata`` and ``one_hot_metadata`` are concatenated to the pooled
        encoder output along dim 1, so they are expected to be 2-D
        ``(batch, features)`` tensors.
        """
        outputs = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        )

        cls_embedding = outputs.pooler_output
        concatted_inputs = torch.cat([cls_embedding, num_metadata, one_hot_metadata], dim=1)

        hl_output = F.gelu(self.hl(concatted_inputs))
        hl_output = self.dropout(hl_output)

        logits = self.fc(hl_output)
        return logits

    def roberta_trainable_state(self):
        """Encoder parameters with ``requires_grad=True``, keyed by parameter name."""
        return {
            name: param for name, param in self.encoder.named_parameters() if param.requires_grad
        }

    def load_roberta_trainable_state(self, state_dict):
        """Load a partial encoder state produced by ``roberta_trainable_state``.

        Loading is non-strict because frozen encoder parameters are never saved,
        so missing keys are expected. Keys that do not exist in the encoder mean
        the checkpoint does not match this architecture; they raise instead of
        being silently ignored.

        Raises:
            RuntimeError: if ``state_dict`` contains keys unknown to the encoder.
        """
        result = self.encoder.load_state_dict(state_dict, strict=False)
        if result.unexpected_keys:
            raise RuntimeError(
                f"Unexpected encoder keys in saved state: {sorted(result.unexpected_keys)}"
            )

    def state_for_save(self):
        """Weights needed to restore the fine-tuned model.

        Contains the classifier head (``hl`` and ``fc``) and the trainable
        (unfrozen) encoder parameters; frozen encoder weights are not included.
        """
        return {
            'hl_state_dict': self.hl.state_dict(),
            'fc_state_dict': self.fc.state_dict(),
            'roberta_trainable': self.roberta_trainable_state(),
        }

    def load_state_from_save(self, state):
        """Restore weights saved by ``state_for_save``.

        The head is always loaded; the trainable encoder weights are loaded when
        present so older head-only checkpoints still work.
        """
        self.hl.load_state_dict(state['hl_state_dict'])
        self.fc.load_state_dict(state['fc_state_dict'])
        if 'roberta_trainable' in state:
            self.load_roberta_trainable_state(state['roberta_trainable'])

In [None]:
def test(
    model: nn.Module,
    best_model_path: str,
    dataloader: DataLoader,
    num_classes=None,
    device=None,
) -> None:
    """Evaluate the best saved checkpoint on ``dataloader`` and log test metrics.

    The weights at ``best_model_path`` are loaded into ``model`` with
    ``load_best_model``. Loss, accuracy and per-class / macro F1, precision and
    recall are then logged to the active MLflow run and printed.

    Args:
        model: classifier called as
            ``model(input_ids, attention_mask, num_metadata, one_hot_metadata)``.
        best_model_path: checkpoint file passed to ``load_best_model``.
        dataloader: yields dict batches with the keys produced by
            ``LiarPlusSingleRobertaDataset``.
        num_classes: number of classes; defaults to ``model.fc.out_features``.
        device: device for inputs and metrics; defaults to the device of the
            model's first parameter.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    criterion = nn.CrossEntropyLoss()

    load_best_model(model, best_model_path)

    # Derive these from the model instead of relying on notebook globals.
    if num_classes is None:
        num_classes = model.fc.out_features
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    f1 = MulticlassF1Score(num_classes, average=None).to(device)
    precision = MulticlassPrecision(num_classes, average=None).to(device)
    recall = MulticlassRecall(num_classes, average=None).to(device)

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            num_metadata = batch["num_metadata"].to(device)
            one_hot_metadata = batch["one_hot_metadata"].to(device)
            labels = batch["label"].to(device)

            outputs = model(input_ids, attention_mask, num_metadata, one_hot_metadata)
            loss = criterion(outputs, labels)
            # Weight by batch size so the average is per sample, not per batch.
            total_loss += loss.item() * input_ids.size(0)

            preds = torch.argmax(outputs, dim=1)
            total_correct += (preds == labels).sum().item()
            total_samples += input_ids.size(0)

            f1.update(preds, labels)
            precision.update(preds, labels)
            recall.update(preds, labels)

    if total_samples == 0:
        raise ValueError("The test dataloader yielded no samples.")

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples

    f1_res = f1.compute()
    precision_res = precision.compute()
    recall_res = recall.compute()

    mlflow.log_metric("test_acc", accuracy)
    mlflow.log_metric("test_loss", avg_loss)

    for i in range(num_classes):
        mlflow.log_metric(f"test_f1_{ids2labels[i]}", f1_res[i].item())
        mlflow.log_metric(f"test_precision_{ids2labels[i]}", precision_res[i].item())
        mlflow.log_metric(f"test_recall_{ids2labels[i]}", recall_res[i].item())

    macro_f1 = f1_res.mean().item()
    macro_precision = precision_res.mean().item()
    macro_recall = recall_res.mean().item()

    mlflow.log_metric("test_f1", macro_f1)
    mlflow.log_metric("test_precision", macro_precision)
    mlflow.log_metric("test_recall", macro_recall)

    print(
        f"Test Loss: {avg_loss:.4f}, "
        f"Test Accuracy: {accuracy:.4f}, "
        f"Test F1: {f1_res} (macro = {macro_f1:.4f}), "
        f"Test Precision: {precision_res} (macro = {macro_precision:.4f}), "
        f"Test Recall: {recall_res} (macro = {macro_recall:.4f})"
    )

In [None]:
def _run_epoch(model, loader, criterion, device, metrics, desc, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean batch loss, accuracy)``.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch; without one it runs in eval mode with gradients disabled.
    Every metric in ``metrics`` is reset first, so its state reflects only this
    pass (training and validation predictions are never mixed).
    """
    training = optimizer is not None
    model.train(training)
    for metric in metrics:
        metric.reset()

    total_loss = 0.0
    total_correct = 0
    with torch.set_grad_enabled(training):
        for batch in tqdm(loader, desc=desc, leave=False):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            num_metadata = batch["num_metadata"].to(device)
            one_hot_metadata = batch["one_hot_metadata"].to(device)
            labels = batch["label"].to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(input_ids, attention_mask, num_metadata, one_hot_metadata)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            preds = torch.argmax(outputs, dim=-1)
            total_correct += (preds == labels).sum().item()
            for metric in metrics:
                metric.update(preds, labels)

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    avg_loss = total_loss / max(len(loader), 1)
    accuracy = total_correct / max(len(loader.dataset), 1)
    return avg_loss, accuracy


def _log_class_metrics(prefix, metrics, num_classes, step):
    """Log per-class and macro F1/precision/recall under ``prefix``; return the macros."""
    f1, precision, recall = metrics
    f1_res = f1.compute()
    precision_res = precision.compute()
    recall_res = recall.compute()

    for i in range(num_classes):
        mlflow.log_metric(f"{prefix}_f1_{ids2labels[i]}", f1_res[i].item(), step=step)
        mlflow.log_metric(
            f"{prefix}_precision_{ids2labels[i]}", precision_res[i].item(), step=step
        )
        mlflow.log_metric(f"{prefix}_recall_{ids2labels[i]}", recall_res[i].item(), step=step)

    macro_f1 = f1_res.mean().item()
    macro_precision = precision_res.mean().item()
    macro_recall = recall_res.mean().item()

    mlflow.log_metric(f"{prefix}_f1", macro_f1, step=step)
    mlflow.log_metric(f"{prefix}_precision", macro_precision, step=step)
    mlflow.log_metric(f"{prefix}_recall", macro_recall, step=step)
    return macro_f1, macro_precision, macro_recall


def train(
    creds: dict,
    model: nn.Module,
    save_path: str,
    remote_models_path: str,
    best_model_path: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    test_loader: DataLoader,
    batch_size: int,
    num_classes: int,
    lr=1e-3,
    encoder_lr=1e-5,
    epochs=30,
    patience=5,
    resume: bool = False,
    reset_epoch: bool = False,
) -> None:
    """Fine-tune ``model`` with early stopping on validation accuracy.

    Every epoch logs training and validation loss, accuracy and per-class /
    macro F1, precision and recall to MLflow. It then saves a rolling checkpoint,
    and uploads it to the remote host every ``checkpoint_send_interval`` epochs.
    Whenever validation accuracy improves, it saves and uploads the best model.
    Training stops after ``patience`` epochs without improvement.

    ``save_path``, ``test_loader`` and ``batch_size`` are accepted for interface
    compatibility but are not used here.
    """
    dev_name = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device {dev_name}")
    device = torch.device(dev_name)

    # The head and the encoder are both optimized; the encoder gets a lower LR.
    # Frozen encoder parameters never receive gradients, so AdamW skips them.
    # (All encoder parameters are passed so the param groups stay compatible
    # with optimizer states stored in existing checkpoints.)
    optimizer = torch.optim.AdamW([
        {'params': model.encoder.parameters(), 'lr': encoder_lr},
        {'params': model.hl.parameters(), 'lr': lr},
        {'params': model.fc.parameters(), 'lr': lr},
    ])
    criterion = nn.CrossEntropyLoss()

    # NOTE(review): the rolling checkpoint is written relative to the current
    # working directory, not to save_path.
    checkpoint_path = f"checkpoint_{patience}.pth"
    # Upload the rolling checkpoint to the remote host every N epochs.
    checkpoint_send_interval = 5

    start_epoch, best_val_accuracy = load_checkpoint(
        model,
        optimizer,
        checkpoint_path,
        resume,
        reset_epoch
    )

    patience_counter = 0

    metrics = (
        MulticlassF1Score(num_classes, average=None).to(device),
        MulticlassPrecision(num_classes, average=None).to(device),
        MulticlassRecall(num_classes, average=None).to(device),
    )

    for epoch in range(start_epoch, epochs):
        avg_loss, avg_train_accuracy = _run_epoch(
            model, train_loader, criterion, device, metrics,
            desc=f"Epoch {epoch+1}", optimizer=optimizer,
        )
        mlflow.log_metric("train_loss", avg_loss, step=epoch)
        mlflow.log_metric("train_acc", avg_train_accuracy, step=epoch)
        macro_f1, macro_precision, macro_recall = _log_class_metrics(
            "train", metrics, num_classes, epoch
        )
        tqdm.write(
            f"Epoch {epoch+1}: "
            f"Training Loss: {avg_loss}, "
            f"Training Accuracy: {avg_train_accuracy}, "
            f"Training F1: {macro_f1}, "
            f"Training Precision: {macro_precision}, "
            f"Training Recall: {macro_recall}"
        )

        avg_val_loss, avg_val_accuracy = _run_epoch(
            model, val_loader, criterion, device, metrics,
            desc=f"Validation of epoch {epoch + 1}",
        )
        mlflow.log_metric("val_loss", avg_val_loss, step=epoch)
        mlflow.log_metric("val_acc", avg_val_accuracy, step=epoch)
        macro_f1, macro_precision, macro_recall = _log_class_metrics(
            "val", metrics, num_classes, epoch
        )
        print(
            f"Epoch {epoch+1}: "
            f"Validation Loss: {avg_val_loss}, "
            f"Validation Accuracy: {avg_val_accuracy}, "
            f"Validation F1: {macro_f1}, "
            f"Validation Precision: {macro_precision}, "
            f"Validation Recall: {macro_recall}"
        )

        save_checkpoint(
            model, optimizer, epoch, avg_val_accuracy, checkpoint_path
        )
        if (epoch + 1) % checkpoint_send_interval == 0:
            save_model_remotely(checkpoint_path, remote_models_path, creds)

        # Early stopping on validation accuracy.
        if avg_val_accuracy > best_val_accuracy:
            best_val_accuracy = avg_val_accuracy
            patience_counter = 0
            save_best_model(
                model,
                optimizer,
                epoch,
                best_val_accuracy,
                best_model_path
            )
            save_model_remotely(best_model_path, remote_models_path, creds)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Always upload the final rolling checkpoint.
    save_model_remotely(checkpoint_path, remote_models_path, creds)

In [None]:
# --- Tracking, remote storage and reproducibility -----------------------------
mlflow_uri = "http://cimmerian.win:5000"
resume = False
reset_epoch = False
seed = 42

# SSH credentials used by save_model_remotely to upload checkpoints.
# The password is taken from the REMOTE_MODELS_PASSWORD environment variable
# (or prompted for) instead of being hard-coded in the notebook.
creds = {
    'hostname': "cimmerian.win",
    'port': 22,
    'username': "conan",
    'password': os.environ.get("REMOTE_MODELS_PASSWORD")
    or getpass.getpass("SSH password for cimmerian.win: "),
}

mlflow.set_tracking_uri(uri=mlflow_uri)
mlflow.set_experiment("RoBERTaSM")

# Seed the Python and torch RNGs (DataLoader shuffling, head init, dropout).
random.seed(seed)
torch.manual_seed(seed)

# --- Encoder --------------------------------------------------------------------
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
roberta = RobertaModel.from_pretrained("roberta-base")

# Fine-tune only the last transformer layer (encoder.layer.11) and the pooler;
# every other RoBERTa parameter stays frozen.
for name, param in roberta.named_parameters():
    param.requires_grad = name.startswith(("encoder.layer.11", "pooler"))

# --- Hyperparameters ------------------------------------------------------------
num_classes = 6
lr = 1e-3
encoder_lr = 1e-5
epochs = 30
hidden_size = 128
# Number of epochs without validation improvement before stopping.
patience = 6
batch_size = 64

# --- Paths ----------------------------------------------------------------------
save_path = "/kaggle/working"
remote_models_path = "/home/conan/models/final_sm/"
best_model_path = f"{save_path}/best_model_{patience}.pth"
data_dir = "/kaggle/input/liar-plus-final-dataset"

# --- Input columns --------------------------------------------------------------
# Columns could be selected greedily: keep a column only if adding it improves
# the validation score.
text_columns = [
    "subject",
    "speaker",
    "job_title",
    "state",
    "party_affiliation",
    "context"
]
num_metadata_cols = [
    "barely_true_counts",
    "false_counts",
    "half_true_counts",
    "mostly_true_counts",
    "pants_on_fire_counts",
    "grammar_errors",
    "ratio_of_capital_letters"
]
one_hot_cols = [
    "sentiment",
    "question",
    "curse",
    "emotion",
    "gibberish",
    "offensiveness",
    "political_bias"
]

# --- Data -----------------------------------------------------------------------
# For quick experiments, set epochs to 1 and check that the validation loss
# decreases within a single epoch.
training_data, validation_data, test_data = (
    LiarPlusSingleRobertaDataset(
        f"{data_dir}/{split}.csv",
        tokenizer,
        text_columns,
        num_metadata_cols,
        one_hot_cols,
    )
    for split in ("train2", "val2", "test2")
)

train_dataloader = DataLoader(
    training_data, batch_size=batch_size, shuffle=True
)
# Evaluation metrics do not depend on sample order, so no shuffling here.
val_dataloader = DataLoader(
    validation_data, batch_size=batch_size, shuffle=False
)
test_dataloader = DataLoader(
    test_data, batch_size=batch_size, shuffle=False
)

# --- Model ----------------------------------------------------------------------
model = LiarPlusSingleFinetunedRoBERTasClassifier(
    roberta,
    len(num_metadata_cols),
    one_hot_metadata_size,
    hidden_size,
    num_classes,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# --- Run ------------------------------------------------------------------------
start = time.time()
with mlflow.start_run():
    mlflow.log_param("learning_rate", lr)
    mlflow.log_param("encoder_learning_rate", encoder_lr)
    mlflow.log_param("hidden_size", hidden_size)
    mlflow.log_param("batch_size", batch_size)
    mlflow.log_param("epochs", epochs)
    mlflow.log_param("resume", resume)
    mlflow.log_param("reset_epoch", reset_epoch)
    mlflow.log_param("patience", patience)
    mlflow.log_param("seed", seed)

    train(
        creds,
        model,
        save_path,
        remote_models_path,
        best_model_path,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        batch_size,
        num_classes,
        lr,
        encoder_lr,
        epochs,
        patience,
        resume,
        reset_epoch,
    )
    # Evaluate the best checkpoint on the held-out test split.
    test(model, best_model_path, test_dataloader)
end = time.time()
print(f"Total time took training: {end-start}s")