In [None]:
!pip install -q mlflow paramiko

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from transformers import RobertaModel, RobertaTokenizer
from sklearn.utils import resample
from torchmetrics.classification import (
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)
from tqdm import tqdm
import mlflow
import time
import pandas as pd
import os
import paramiko

In [None]:
# Truthfulness classes ordered by class index (position in the list == class id).
ids2labels = [
    "pants-fire",
    "false",
    "barely-true",
    "half-true",
    "mostly-true",
    "true",
]

# Inverse lookup: label string -> class index.
LABEL_MAPPING = {label: class_id for class_id, label in enumerate(ids2labels)}

In [None]:
def save_checkpoint(model, optimizer, epoch, val_acc, path="checkpoint.pth"):
    """Write the current training state to ``path`` with ``torch.save``.

    Args:
        model: Object exposing ``state_for_save()``; its return value is
            stored under ``"model_state_dict"``.
        optimizer: Object exposing ``state_dict()``; stored under
            ``"optimizer_state_dict"``.
        epoch: Epoch index stored under ``"epoch"``.
        val_acc: Validation accuracy stored under ``"val_acc"``.
        path: Destination file of the checkpoint.
    """
    torch.save(
        dict(
            model_state_dict=model.state_for_save(),
            optimizer_state_dict=optimizer.state_dict(),
            epoch=epoch,
            val_acc=val_acc,
        ),
        path,
    )
    summary = "Checkpoint saved at epoch {} with validation accuracy {:.4f}"
    print(summary.format(epoch, val_acc))


def load_checkpoint(
    model,
    optimizer,
    path="checkpoint.pth",
    resume=False,
    reset_epoch=False
):
    """Optionally restore model/optimizer state from a checkpoint file.

    Args:
        model: Object exposing ``load_state_from_save(state)``.
        optimizer: Object exposing ``load_state_dict(state)``.
        path: Checkpoint file, in the format written by ``save_checkpoint``.
        resume: When False nothing is loaded and training starts fresh.
        reset_epoch: When True the weights and optimizer state are restored,
            but the returned start epoch is 0.

    Returns:
        Tuple ``(start_epoch, val_acc)``. ``(0, 0)`` when ``resume`` is False
        or the file does not exist; otherwise the stored validation accuracy
        paired with either 0 (``reset_epoch``) or the stored epoch + 1.
    """
    if not resume:
        print("Resume is False. Starting from scratch.")
        return 0, 0  # Start fresh

    if not os.path.exists(path):
        print("No checkpoint found. Starting from scratch.")
        return 0, 0  # Start fresh

    # Deserialize onto the CPU so that a checkpoint written on a GPU machine
    # can still be opened on a CPU-only host; the load_state_dict calls below
    # copy the values onto whatever device the live parameters use.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_from_save(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch = checkpoint["epoch"]
    val_acc = checkpoint["val_acc"]

    if reset_epoch:
        # The original message was missing the space between the two
        # f-string fragments and printed "initialepoch".
        print(
            f"Checkpoint loaded: Starting from initial "
            f"epoch, validation accuracy {val_acc:.4f}"
        )
        return 0, val_acc  # Start fresh with existing model

    print(
        f"Checkpoint loaded: Resuming from epoch "
        f"{epoch+1}, validation accuracy {val_acc:.4f}"
    )
    return epoch + 1, val_acc  # Next epoch to train


def save_best_model(model, optimizer, epoch, val_acc, path="best_model.pth"):
    """Persist the best-so-far model snapshot to ``path`` with ``torch.save``.

    The file has the same four keys as a regular checkpoint:
    ``"model_state_dict"`` (from ``model.state_for_save()``),
    ``"optimizer_state_dict"`` (from ``optimizer.state_dict()``),
    ``"epoch"`` and ``"val_acc"``.
    """
    model_state = model.state_for_save()
    optimizer_state = optimizer.state_dict()
    snapshot = {
        "model_state_dict": model_state,
        "optimizer_state_dict": optimizer_state,
        "epoch": epoch,
        "val_acc": val_acc,
    }
    torch.save(snapshot, path)

    epoch_part = f"Best model saved at epoch {epoch} "
    accuracy_part = f"with validation accuracy {val_acc:.4f}"
    print(epoch_part + accuracy_part)


def load_best_model(model, path="best_model.pth"):
    """Restore the weights stored by ``save_best_model`` into ``model``.

    Args:
        model: Object exposing ``load_state_from_save(state)``.
        path: File written by ``save_best_model``.

    Returns:
        True when the file existed and its ``"model_state_dict"`` entry was
        handed to the model, False when no file was found (the model is left
        untouched in that case).
    """
    if not os.path.exists(path):
        print("No best model checkpoint found.")
        return False

    # map_location="cpu" lets a checkpoint saved on a GPU be opened on a
    # CPU-only machine; load_state_dict copies the values to the device of
    # the model's own parameters.
    best_model = torch.load(path, map_location="cpu")
    model.load_state_from_save(best_model["model_state_dict"])
    print("Model loaded from best model checkpoint.")
    return True


def save_model_remotely(local_path, remote_path, creds):
    """Upload ``local_path`` to a remote host over SFTP (best effort).

    The file is first uploaded under a ``.tmp`` suffix, any existing remote
    file with the final name is removed, and the temporary file is renamed to
    the final name. Errors raised during the transfer are printed, not
    re-raised, so a failed upload does not abort the caller.

    Args:
        local_path: Path of the local file to send.
        remote_path: Remote directory prefix. The local file's base name is
            appended by plain string concatenation, so the prefix has to end
            with a path separator.
        creds: Mapping with the keys ``"hostname"``, ``"port"``,
            ``"username"`` and ``"password"``.

    Raises:
        KeyError: If ``creds`` lacks one of the required keys (looked up
            before the error-swallowing ``try`` block).
    """
    # SSH settings
    hostname = creds['hostname']
    port = creds['port']
    username = creds['username']
    password = creds['password']

    # Initialised up front so the `finally` clause can always reference them,
    # even when the failure happens before they are created.
    ssh = None
    progress_bar = None

    try:
        # SSH connection
        ssh = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy accepts unknown host keys without
        # verification; consider loading known host keys instead.
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(hostname, port=port, username=username, password=password)

        # Size of the local file, used as the progress bar total
        file_size = os.path.getsize(local_path)

        # Initialise the progress bar
        progress_bar = tqdm(
            total=file_size,
            unit='B',
            unit_scale=True,
            desc=f"Uploading {local_path}",
        )

        # paramiko reports the cumulative number of transferred bytes, while
        # tqdm.update expects an increment.
        def progress_callback(transferred, total):
            progress_bar.update(transferred - progress_bar.n)

        # SFTP transfer with the progress callback
        with ssh.open_sftp() as sftp:
            file_name = os.path.basename(local_path)
            final_remote_path = remote_path + file_name
            temp_remote_path = final_remote_path + ".tmp"

            sftp.put(local_path, temp_remote_path, callback=progress_callback)

            try:
                sftp.remove(final_remote_path)
            except IOError:
                # The file does not exist yet - safe to ignore
                pass

            sftp.rename(temp_remote_path, final_remote_path)

        # Finished: close the bar before printing so the output stays ordered
        progress_bar.close()
        print(f"Plik {os.path.basename(local_path)} został wysłany.")

    except Exception as e:
        print(f"Error: {e}")

    finally:
        # Always release the progress bar (closing twice is harmless) and the
        # SSH connection, whether or not the upload succeeded.
        if progress_bar is not None:
            progress_bar.close()
        if ssh is not None:
            ssh.close()

In [None]:
class LiarPlusDataset(Dataset):
    """Map-style dataset over a CSV file with text fields and numeric columns.

    Every sample is a dict with:
      * ``"input_ids"`` / ``"attention_mask"``: the per-column tokenizer
        outputs concatenated along dim 0 (one entry per name in ``columns``),
      * ``"num_metadata"``: float tensor built from ``num_metadata_cols``,
      * ``"label"``: tensor holding ``LABEL_MAPPING[row["label"]]``.
    """

    def __init__(
        self,
        filepath: str,
        tokenizer,
        columns: list[str],
        num_metadata_cols: list[str],
        max_length: int = 128,
    ):
        """Read ``filepath`` with pandas and cast every text column to ``str``."""
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.columns = columns
        self.num_metadata_cols = num_metadata_cols

        self.df = pd.read_csv(filepath)
        for text_column in self.columns:
            self.df[text_column] = self.df[text_column].astype(str)

    def __len__(self):
        """Number of rows in the underlying frame."""
        return len(self.df.index)

    def _encode(self, text):
        """Tokenize one field, truncated/padded to ``max_length``."""
        return self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

    def __getitem__(self, index: int):
        """Build the sample dict for the row at position ``index``."""
        row = self.df.iloc[index]

        encodings = [self._encode(row[name]) for name in self.columns]
        input_ids = torch.cat([enc["input_ids"] for enc in encodings], dim=0)
        attention_mask = torch.cat(
            [enc["attention_mask"] for enc in encodings], dim=0
        )

        class_id = LABEL_MAPPING[row["label"]]
        numeric_values = [row[name] for name in self.num_metadata_cols]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "num_metadata": torch.tensor(numeric_values).float(),
            "label": torch.tensor(class_id),
        }


In [None]:
class LiarPlusDatasetSubset(Dataset):
    def __init__(
        self,
        total_size: int,
        filepath: str,
        tokenizer,
        columns: list[str],
        num_metadata_cols: list[str],
        random_state: int | None = None,
        max_length: int = 128,
    ):
        num_classes = 6
        df = pd.read_csv(filepath)

        if total_size != -1:
            desired_count = total_size // num_classes

            self.df = pd.concat(
                [
                    resample(
                        group,
                        replace=False,
                        n_samples=desired_count,
                        random_state=random_state,
                    )
                    for _, group in df.groupby("label")
                ]
            )
        else:
            self.df = df

        self.columns = columns
        self.num_metadata_cols = num_metadata_cols

        for column in self.columns:
            self.df[column] = self.df[column].astype(str)

        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, index: int):
        item = self.df.iloc[index]

        input_ids = []
        attention_mask = []

        for column in self.columns:
            encoded = self.tokenizer(
                item[column],
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            )
            input_ids.append(encoded["input_ids"])
            attention_mask.append(encoded["attention_mask"])

        input_ids = torch.cat(input_ids, dim=0)
        attention_mask = torch.cat(attention_mask, dim=0)

        label = LABEL_MAPPING[item["label"]]

        metadata = [item[column] for column in self.num_metadata_cols]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "num_metadata": torch.tensor(metadata).float(),
            "label": torch.tensor(label),
        }

In [None]:
class LiarPlusMultipleRoBERTasClassifier(nn.Module):
    """Classification head on top of a frozen text encoder.

    Every text field of a sample is encoded separately by ``encoder_model``;
    the first-token embeddings of all fields are concatenated with the
    numeric metadata and passed through a hidden layer (``hl``, GELU) and a
    final linear layer (``fc``) that produces the class logits.

    Args:
        encoder_model: Encoder exposing ``config.hidden_size`` and returning
            an object with ``last_hidden_state`` when called with
            ``input_ids`` and ``attention_mask``.
        inputs: Number of text fields per sample.
        num_metadata_len: Number of numeric metadata features per sample.
        num_hidden: Width of the hidden layer.
        num_classes: Number of output classes.
    """

    def __init__(
        self, encoder_model, inputs, num_metadata_len, num_hidden, num_classes
    ):
        super().__init__()
        self.encoder = encoder_model
        self.hl = nn.Linear(
            self.encoder.config.hidden_size * inputs + num_metadata_len,
            num_hidden,
        )
        self.fc = nn.Linear(num_hidden, num_classes)

    def train(self, mode: bool = True):
        """Set the training mode of the head while keeping the encoder in eval.

        The encoder is only used as a fixed feature extractor (it is called
        under ``torch.no_grad()``), so train-mode-only layers such as dropout
        must stay disabled in it; otherwise the head would be trained on
        noisy embeddings but validated on deterministic ones.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, input_ids, attention_mask, num_metadata):
        """Compute class logits.

        Args:
            input_ids: Tensor of shape (batch_size, num_fields, max_length).
            attention_mask: Tensor with the same shape as ``input_ids``.
            num_metadata: Tensor of shape (batch_size, num_metadata_len).

        Returns:
            Logits of shape (batch_size, num_classes).
        """
        batch_size, num_fields, max_length = input_ids.shape

        # (batch_size, num_fields, max_length) -> (batch_size * num_fields, max_length)
        # reshape (unlike view) also accepts non-contiguous inputs.
        flat_input_ids = input_ids.reshape(batch_size * num_fields, max_length)
        flat_attention_mask = attention_mask.reshape(
            batch_size * num_fields, max_length
        )

        with torch.no_grad():  # Ensure encoder remains frozen
            outputs = self.encoder(
                input_ids=flat_input_ids, attention_mask=flat_attention_mask
            )

        # First-token embedding of every field:
        # shape (batch_size * num_fields, hidden_size)
        cls_embeddings = outputs.last_hidden_state[:, 0, :]

        # (batch_size * num_fields, hidden_size) -> (batch_size, num_fields, hidden_size)
        cls_reshaped = cls_embeddings.reshape(batch_size, num_fields, -1)

        # (batch_size, num_fields, hidden_size) -> (batch_size, num_fields * hidden_size),
        # i.e. the per-field embeddings concatenated for the classifier.
        flattened_cls = torch.flatten(cls_reshaped, start_dim=1)

        concatted_inputs = torch.cat([flattened_cls, num_metadata], dim=1)

        # Hidden layer with GELU activation
        hl_output = F.gelu(self.hl(concatted_inputs))

        # Classification layer
        logits = self.fc(hl_output)

        return logits

    def state_for_save(self):
        """Return only the head weights (``hl`` and ``fc``); the encoder is not saved."""
        return {
            'hl_state_dict': self.hl.state_dict(),
            'fc_state_dict': self.fc.state_dict(),
        }

    def load_state_from_save(self, state):
        """Load head weights from a dict produced by ``state_for_save``."""
        self.hl.load_state_dict(state['hl_state_dict'])
        self.fc.load_state_dict(state['fc_state_dict'])

In [None]:
def test(
    model: nn.Module,
    best_model_path: str,
    dataloader: DataLoader,
    num_classes: int | None = None,
    device=None,
) -> None:
    """Evaluate the best saved model on ``dataloader`` and log to MLflow.

    Logs ``test_acc``, ``test_loss``, per-class ``test_f1_*`` /
    ``test_precision_*`` / ``test_recall_*`` and their macro averages, then
    prints a summary line.

    Args:
        model: Model whose weights are replaced by the ones in
            ``best_model_path`` (via ``load_best_model``) before evaluation.
        best_model_path: File written by ``save_best_model``.
        dataloader: Yields dict batches with ``"input_ids"``,
            ``"attention_mask"``, ``"num_metadata"`` and ``"label"``.
        num_classes: Number of classes; defaults to ``len(ids2labels)``.
        device: Device the batches are moved to; defaults to the device of
            the model's parameters.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    # Previously read from notebook-level globals, which made the function
    # depend on the execution order of unrelated cells.
    if num_classes is None:
        num_classes = len(ids2labels)
    if device is None:
        device = next(model.parameters()).device

    # Define loss function
    criterion = nn.CrossEntropyLoss()

    load_best_model(model, best_model_path)

    model.eval()  # Set model to evaluation mode
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    f1 = MulticlassF1Score(num_classes, average=None).to(device)
    precision = MulticlassPrecision(num_classes, average=None).to(device)
    recall = MulticlassRecall(num_classes, average=None).to(device)

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            num_metadata = batch["num_metadata"].to(device)
            labels = batch["label"].to(device)

            outputs = model(input_ids, attention_mask, num_metadata)
            loss = criterion(outputs, labels)
            # Weight the batch-mean loss by batch size so that the final
            # average is per sample even when the last batch is smaller.
            total_loss += loss.item() * input_ids.size(0)

            preds = torch.argmax(outputs, dim=1)
            total_correct += (preds == labels).sum().item()
            total_samples += input_ids.size(0)

            f1.update(preds, labels)
            precision.update(preds, labels)
            recall.update(preds, labels)

    if total_samples == 0:
        raise ValueError("Cannot evaluate on an empty dataloader.")

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples

    f1_res = f1.compute()
    precision_res = precision.compute()
    recall_res = recall.compute()

    mlflow.log_metric("test_acc", accuracy)
    # Fixed: the accuracy used to be logged under "test_loss".
    mlflow.log_metric("test_loss", avg_loss)

    # Metric values are converted to plain floats before being sent to MLflow.
    for i in range(num_classes):
        label = ids2labels[i]
        mlflow.log_metric(f"test_f1_{label}", f1_res[i].item())
        mlflow.log_metric(f"test_precision_{label}", precision_res[i].item())
        mlflow.log_metric(f"test_recall_{label}", recall_res[i].item())

    macro_f1 = f1_res.mean().item()
    macro_precision = precision_res.mean().item()
    macro_recall = recall_res.mean().item()

    mlflow.log_metric("test_f1", macro_f1)
    mlflow.log_metric("test_precision", macro_precision)
    mlflow.log_metric("test_recall", macro_recall)

    print(
        f"Test Loss: {avg_loss:.4f}, "
        f"Test Accuracy: {accuracy:.4f}, "
        f"Test F1: {f1_res} (macro = {macro_f1:.4f}), "
        f"Test Precision: {precision_res} (macro = {macro_precision:.4f}), "
        f"Test Recall: {recall_res} (macro = {macro_recall:.4f}), "
    )

In [None]:
def _run_epoch(
    model, loader, criterion, device, metrics, optimizer=None, desc=""
):
    """Run one pass over ``loader``; trains when ``optimizer`` is given.

    All ``metrics`` are reset first and updated with every batch.

    Returns:
        Tuple ``(mean batch loss, accuracy over len(loader.dataset))``.
    """
    is_training = optimizer is not None
    model.train(is_training)
    for metric in metrics:
        metric.reset()

    loss_sum = 0.0
    correct = 0

    with torch.set_grad_enabled(is_training):
        for batch in tqdm(loader, desc=desc, leave=False):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            num_metadata = batch["num_metadata"].to(device)
            labels = batch["label"].to(device)

            if is_training:
                optimizer.zero_grad()

            outputs = model(input_ids, attention_mask, num_metadata)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()

            preds = torch.argmax(outputs, dim=-1)
            correct += (preds == labels).sum().item()
            for metric in metrics:
                metric.update(preds, labels)

    return loss_sum / len(loader), correct / len(loader.dataset)


def _log_epoch_metrics(split, avg_loss, accuracy, metrics, num_classes, step):
    """Log loss, accuracy and per-class/macro F1, precision, recall for ``split``.

    Returns:
        Tuple ``(macro_f1, macro_precision, macro_recall)`` as floats.
    """
    f1, precision, recall = metrics
    f1_res = f1.compute()
    precision_res = precision.compute()
    recall_res = recall.compute()

    mlflow.log_metric(f"{split}_loss", avg_loss, step=step)
    mlflow.log_metric(f"{split}_acc", accuracy, step=step)

    # Values are converted to plain floats before being sent to MLflow.
    for i in range(num_classes):
        label = ids2labels[i]
        mlflow.log_metric(f"{split}_f1_{label}", f1_res[i].item(), step=step)
        mlflow.log_metric(
            f"{split}_precision_{label}", precision_res[i].item(), step=step
        )
        mlflow.log_metric(
            f"{split}_recall_{label}", recall_res[i].item(), step=step
        )

    macro_f1 = f1_res.mean().item()
    macro_precision = precision_res.mean().item()
    macro_recall = recall_res.mean().item()

    mlflow.log_metric(f"{split}_f1", macro_f1, step=step)
    mlflow.log_metric(f"{split}_precision", macro_precision, step=step)
    mlflow.log_metric(f"{split}_recall", macro_recall, step=step)

    return macro_f1, macro_precision, macro_recall


def train(
    creds: dict,
    model: nn.Module,
    save_path: str,
    remote_models_path: str,
    best_model_path: str,
    train_loader: DataLoader,
    val_loader: DataLoader,
    test_loader: DataLoader,
    batch_size: int,
    num_classes: int,
    lr=1e-3,
    epochs=30,
    patience=5,
    resume: bool = False,
    reset_epoch: bool = False,
) -> None:
    """Train the classifier head with early stopping on validation accuracy.

    Per epoch: one training pass, one validation pass, MLflow logging of
    loss/accuracy/F1/precision/recall for both splits, a local checkpoint
    (uploaded every second epoch), and - when validation accuracy improves -
    a best-model snapshot that is uploaded immediately. Training stops once
    validation accuracy has not improved for ``patience`` consecutive epochs.

    Args:
        creds: SSH credentials forwarded to ``save_model_remotely``.
        model: Model to train; called as
            ``model(input_ids, attention_mask, num_metadata)``.
        save_path: Local directory for the rolling checkpoint.
        remote_models_path: Remote directory prefix for uploads.
        best_model_path: Local file for the best-model snapshot.
        train_loader: Training batches.
        val_loader: Validation batches.
        test_loader: Unused; kept for interface compatibility.
        batch_size: Unused; kept for interface compatibility.
        num_classes: Number of target classes.
        lr: Adam learning rate.
        epochs: Upper bound (exclusive) of the epoch index.
        patience: Allowed number of consecutive epochs without improvement.
        resume: Forwarded to ``load_checkpoint``.
        reset_epoch: Forwarded to ``load_checkpoint``.
    """
    dev_name = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device {dev_name}")
    device = torch.device(dev_name)
    # Batches are moved to `device`, so the model has to live there as well.
    model.to(device)

    # Optimise every parameter that is not frozen. The optimizer used to
    # receive only `model.fc.parameters()`, which left the hidden layer at
    # its random initialisation for the whole run.
    # NOTE: checkpoints written with the old fc-only optimizer have a
    # different optimizer state layout and cannot be resumed from.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Checkpoint path
    checkpoint_path = os.path.join(save_path, f"checkpoint_{patience}.pth")

    # Upload the rolling checkpoint every N epochs
    checkpoint_send_interval = 2

    # Load checkpoint (or start fresh, depending on `resume`).
    # NOTE(review): the restored value is the accuracy stored in the last
    # checkpoint, which is not necessarily the best accuracy of that run.
    start_epoch, best_val_accuracy = load_checkpoint(
        model,
        optimizer,
        checkpoint_path,
        resume,
        reset_epoch
    )

    patience_counter = 0

    metrics = (
        MulticlassF1Score(num_classes, average=None).to(device),
        MulticlassPrecision(num_classes, average=None).to(device),
        MulticlassRecall(num_classes, average=None).to(device),
    )

    # Training loop
    for epoch in range(start_epoch, epochs):
        # Training step. The metrics are reset inside _run_epoch, so the
        # training metrics no longer include the previous epoch's validation
        # predictions.
        avg_loss, avg_train_accuracy = _run_epoch(
            model,
            train_loader,
            criterion,
            device,
            metrics,
            optimizer=optimizer,
            desc=f"Epoch {epoch+1}",
        )
        macro_f1, macro_precision, macro_recall = _log_epoch_metrics(
            "train", avg_loss, avg_train_accuracy, metrics, num_classes, epoch
        )

        tqdm.write(
            f"Epoch {epoch+1}: "
            f"Training Loss: {avg_loss}, "
            f"Training Accuracy: {avg_train_accuracy}, "
            f"Training F1: {macro_f1}, "
            f"Training Precision: {macro_precision}, "
            f"Training Recall: {macro_recall}"
        )

        # Validation step
        avg_val_loss, avg_val_accuracy = _run_epoch(
            model,
            val_loader,
            criterion,
            device,
            metrics,
            desc=f"Validation of epoch {epoch + 1}",
        )
        macro_f1, macro_precision, macro_recall = _log_epoch_metrics(
            "val", avg_val_loss, avg_val_accuracy, metrics, num_classes, epoch
        )

        print(
            f"Epoch {epoch+1}: "
            f"Validation Loss: {avg_val_loss}, "
            f"Validation Accuracy: {avg_val_accuracy}, "
            f"Validation F1: {macro_f1}, "
            f"Validation Precision: {macro_precision}, "
            f"Validation Recall: {macro_recall}"
        )

        save_checkpoint(
            model, optimizer, epoch, avg_val_accuracy, checkpoint_path
        )
        if (epoch + 1) % checkpoint_send_interval == 0:
            save_model_remotely(checkpoint_path, remote_models_path, creds)

        # Check for early stopping
        if avg_val_accuracy > best_val_accuracy:
            best_val_accuracy = avg_val_accuracy
            patience_counter = 0
            # Save the best model
            save_best_model(
                model,
                optimizer,
                epoch,
                best_val_accuracy,
                best_model_path
            )
            save_model_remotely(best_model_path, remote_models_path, creds)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Upload the final checkpoint
    save_model_remotely(checkpoint_path, remote_models_path, creds)

In [None]:
mlflow_uri = "http://cimmerian.win:5000"
resume = False
reset_epoch = False

# SSH credentials for uploading checkpoints. Each value can be overridden
# through an environment variable.
# NOTE(review): the fallback password is stored in plain text in the notebook;
# it should be rotated and supplied only through MODEL_SSH_PASSWORD (or a
# secrets manager).
creds = {
    'hostname': os.environ.get("MODEL_SSH_HOST", "cimmerian.win"),
    'port': int(os.environ.get("MODEL_SSH_PORT", 22)),
    'username': os.environ.get("MODEL_SSH_USER", "conan"),
    'password': os.environ.get("MODEL_SSH_PASSWORD", "conan"),
}

mlflow.set_tracking_uri(uri=mlflow_uri)

# MLflow experiment setup
mlflow.set_experiment("LiarPlusMultipleRoBERTasClassifier")

# Load RoBERTa tokenizer and model
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
roberta = RobertaModel.from_pretrained("roberta-base")

for param in roberta.parameters():
    param.requires_grad = False  # Freeze all layers

# Hyperparameters
num_classes = 6
lr = 1e-3
epochs = 30
hidden_size = 128
# Number of epochs to wait before stopping if no improvement
patience = 10

# Save path
save_path = "/kaggle/working/"
# Remote models path
remote_models_path = "/home/conan/models/multiple_robertas/"
# Best model path
best_model_path = os.path.join(save_path, f"best_model_{patience}.pth")

# Could be tuned greedily: add a column when it improves the results and
# leave it out when it does not.
text_columns = [
    "statement",
    "subject",
    "speaker",
    "job_title",
    "state",
    "party_affiliation",
    "context",
    "justification",
]
num_metadata_cols = [
    "barely_true_counts",
    "false_counts",
    "half_true_counts",
    "mostly_true_counts",
    "pants_on_fire_counts",
]

subset_size = 1000
random_state = 42

# Seed torch so that weight initialisation and batch shuffling are
# reproducible across runs.
torch.manual_seed(random_state)

# To speed up experiments pass `subset_size` instead of -1 (-1 = full file).
# One can also set epochs to 1 and check whether the validation loss goes
# down within a single epoch.
training_data_subset = LiarPlusDatasetSubset(
    -1,
    "/kaggle/input/liar-plus-final-dataset/train2.csv",
    tokenizer,
    text_columns,
    num_metadata_cols,
    random_state,
)
validation_data = LiarPlusDatasetSubset(
    -1,
    "/kaggle/input/liar-plus-final-dataset/val2.csv",
    tokenizer,
    text_columns,
    num_metadata_cols,
)
test_data = LiarPlusDatasetSubset(
    -1,
    "/kaggle/input/liar-plus-final-dataset/test2.csv",
    tokenizer,
    text_columns,
    num_metadata_cols,
)

batch_size = 64

train_dataloader = DataLoader(
    training_data_subset, batch_size=batch_size, shuffle=True
)
# Evaluation does not depend on sample order, so these loaders are not shuffled.
val_dataloader = DataLoader(
    validation_data, batch_size=batch_size, shuffle=False
)
test_dataloader = DataLoader(
    test_data, batch_size=batch_size, shuffle=False
)

# Instantiate model
model = LiarPlusMultipleRoBERTasClassifier(
    roberta,
    len(text_columns),
    len(num_metadata_cols),
    hidden_size,
    num_classes,
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

start = time.time()
with mlflow.start_run():
    mlflow.log_param("learning_rate", lr)
    mlflow.log_param("batch_size", batch_size)
    mlflow.log_param("epochs", epochs)
    mlflow.log_param("resume", resume)
    mlflow.log_param("reset_epoch", reset_epoch)
    mlflow.log_param("patience", patience)

    # Train the model
    train(
        creds,
        model,
        save_path,
        remote_models_path,
        best_model_path,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        batch_size,
        num_classes,
        lr,
        epochs,
        patience,
        resume,
        reset_epoch,
    )
    # Evaluate on test dataset
    test(model, best_model_path, test_dataloader)
end = time.time()
print(f"Total time took training: {end-start}s")