In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from random import sample
from torch.utils.data import DataLoader
from transformers import RobertaModel, RobertaTokenizer
from sklearn.utils import resample
from torchmetrics.classification import (
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
)
from tqdm import tqdm
import mlflow
import time
import pandas as pd
import os
import paramiko

In [2]:
import os
import torch
import paramiko  # type: ignore
from tqdm import tqdm


# Class names in id order: a label's position in this list is its class id.
ids2labels = [
    "pants-fire",
    "false",
    "barely-true",
    "half-true",
    "mostly-true",
    "true",
]

# Inverse lookup (label string -> class id), derived from the list above so
# the two tables cannot drift apart.
LABEL_MAPPING = {name: class_id for class_id, name in enumerate(ids2labels)}


def save_checkpoint(model, optimizer, epoch, val_acc, path="checkpoint.pth"):
    """Write a resumable training checkpoint to ``path``.

    Args:
        model: Object exposing ``state_for_save()``; the returned value is
            stored under ``"model_state_dict"``.
        optimizer: Object exposing ``state_dict()``; stored under
            ``"optimizer_state_dict"``.
        epoch: Epoch number recorded in the checkpoint.
        val_acc: Validation accuracy recorded in the checkpoint and echoed
            in the confirmation message.
        path: Destination file handed to ``torch.save``.
    """
    torch.save(
        dict(
            model_state_dict=model.state_for_save(),
            optimizer_state_dict=optimizer.state_dict(),
            epoch=epoch,
            val_acc=val_acc,
        ),
        path,
    )
    message = (
        f"Checkpoint saved at epoch {epoch} "
        f"with validation accuracy {val_acc:.4f}"
    )
    print(message)


def load_checkpoint(
    model, optimizer, path="checkpoint.pth", resume=False, reset_epoch=False
):
    """Optionally restore model/optimizer state from a training checkpoint.

    Args:
        model: Object exposing ``load_state_from_save(state)``; receives the
            checkpoint's ``"model_state_dict"`` entry.
        optimizer: Object exposing ``load_state_dict(state)``; receives the
            checkpoint's ``"optimizer_state_dict"`` entry.
        path: Checkpoint file to read.
        resume: When False nothing is read and training starts from scratch.
        reset_epoch: When True the stored weights are restored but the epoch
            counter restarts at 0.

    Returns:
        Tuple ``(start_epoch, val_acc)``. ``(0, 0)`` when ``resume`` is False
        or ``path`` does not exist; ``(0, val_acc)`` when ``reset_epoch`` is
        True; otherwise ``(stored_epoch + 1, val_acc)``.
    """
    if not resume:
        print("Resume is False. Starting from scratch.")
        return 0, 0  # Start fresh

    if not os.path.exists(path):
        print("No checkpoint found. Starting from scratch.")
        return 0, 0  # Start fresh

    # Deserialize onto the CPU so a checkpoint written on a GPU machine can
    # still be restored on a CPU-only host; the load_* calls below copy the
    # values into the tensors that already live on the target device.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (or pass weights_only=True once the stored values allow it).
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_from_save(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch = checkpoint["epoch"]
    val_acc = checkpoint["val_acc"]

    if reset_epoch:
        # The original message lacked the space between "initial" and "epoch".
        print(
            f"Checkpoint loaded: Starting from initial "
            f"epoch, validation accuracy {val_acc:.4f}"
        )
        return 0, val_acc  # Start fresh with existing model

    print(
        f"Checkpoint loaded: Resuming from epoch "
        f"{epoch+1}, validation accuracy {val_acc:.4f}"
    )
    return epoch + 1, val_acc  # Next epoch to train


def save_best_model(model, optimizer, epoch, val_acc, path="best_model.pth"):
    """Persist the best-so-far model snapshot to ``path``.

    The file layout matches the regular training checkpoint: the value of
    ``model.state_for_save()``, the optimizer state dict, the epoch and the
    validation accuracy.

    Args:
        model: Object exposing ``state_for_save()``.
        optimizer: Object exposing ``state_dict()``.
        epoch: Epoch at which this snapshot was taken.
        val_acc: Validation accuracy achieved at that epoch.
        path: Destination file handed to ``torch.save``.
    """
    snapshot = dict(
        model_state_dict=model.state_for_save(),
        optimizer_state_dict=optimizer.state_dict(),
        epoch=epoch,
        val_acc=val_acc,
    )
    torch.save(snapshot, path)
    message = (
        f"Best model saved at epoch {epoch} "
        f"with validation accuracy {val_acc:.4f}"
    )
    print(message)


def load_best_model(model, path="best_model.pth"):
    """Restore the model state stored in a best-model file, if it exists.

    Args:
        model: Object exposing ``load_state_from_save(state)``; receives the
            file's ``"model_state_dict"`` entry.
        path: File to read.

    When ``path`` does not exist the model is left untouched and only a
    message is printed, so callers that need the weights should check the
    printed status before evaluating.
    """
    if not os.path.exists(path):
        print("No best model checkpoint found.")
        return

    # Deserialize onto the CPU so a file written on a GPU machine also loads
    # on a CPU-only host; load_state_from_save copies the values into the
    # model's existing tensors on whatever device they live on.
    # NOTE(review): torch.load unpickles the file; only load trusted files.
    best_model = torch.load(path, map_location="cpu")
    model.load_state_from_save(best_model["model_state_dict"])
    print("Model loaded from best model checkpoint.")

def save_model_remotely(local_path, remote_path, creds):
    """Upload a local file to a remote host over SFTP, showing a progress bar.

    The file is first uploaded as ``<remote_path><basename>.tmp``; any
    existing ``<remote_path><basename>`` is then removed and the temporary
    file is renamed into place, so a partially transferred file is never
    stored under the final name.

    Args:
        local_path: Path of the local file to upload.
        remote_path: Remote directory prefix. It is concatenated with the
            file name as-is, so it must already end with a path separator.
        creds: Mapping with the keys ``"hostname"``, ``"port"``,
            ``"username"`` and ``"password"``.

    Raises:
        KeyError: If ``creds`` lacks one of the required keys.

    Connection and transfer errors are not raised; they are reported with
    ``print`` (best-effort upload).
    """
    # SSH settings
    hostname = creds["hostname"]
    port = creds["port"]
    username = creds["username"]
    password = creds["password"]

    # Create the client before the ``try`` so ``ssh`` is always bound when the
    # ``finally`` clause closes it.
    ssh = paramiko.SSHClient()
    try:
        # NOTE(security): AutoAddPolicy trusts unknown host keys, which leaves
        # the connection open to man-in-the-middle attacks; prefer known
        # host keys when the target host is fixed.
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(hostname, port=port, username=username, password=password)

        # Size of the local file drives the progress bar total.
        file_size = os.path.getsize(local_path)

        file_name = os.path.basename(local_path)
        final_remote_path = remote_path + file_name
        temp_remote_path = final_remote_path + ".tmp"

        # Context managers guarantee that the progress bar and the SFTP
        # session are closed even when the transfer fails midway.
        with tqdm(
            total=file_size,
            unit="B",
            unit_scale=True,
            desc=f"Uploading {local_path}",
        ) as progress_bar:

            def progress_callback(transferred, total):
                # The callback receives the cumulative byte count, while
                # tqdm.update expects an increment.
                progress_bar.update(transferred - progress_bar.n)

            with ssh.open_sftp() as sftp:
                sftp.put(
                    local_path, temp_remote_path, callback=progress_callback
                )

                try:
                    sftp.remove(final_remote_path)
                except IOError:
                    # The file does not exist yet - safe to ignore.
                    pass

                sftp.rename(temp_remote_path, final_remote_path)

        print(f"Plik {os.path.basename(local_path)} został wysłany.")

    except Exception as e:
        print(f"Error: {e}")

    finally:
        # Always release the SSH connection.
        ssh.close()


In [3]:
class LiarPlusBaseRobertaDataset(Dataset):
    def __init__(
        self,
        filepath: str,
        tokenizer,
        max_length: int = 512,
    ):
        self.df = pd.read_csv(filepath)
        self.df["statement"] = self.df["statement"].astype(str)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df.index)

    def __getitem__(self, index: int):
        statement = self.df.iloc[index]["statement"]
        label_str = self.df.iloc[index]["label"]

        encoded = self.tokenizer(
            statement,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        label = LABEL_MAPPING[label_str]
        
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "label": torch.tensor(label),
        }

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class LiarPlusBaseRoBERTasClassifier(nn.Module):
    """Linear classification head on top of an encoder run without autograd.

    Only the ``fc`` layer is saved and restored by the persistence helpers.
    """

    def __init__(
        self, encoder_model, num_classes
    ):
        """Wrap ``encoder_model`` and add a linear layer with ``num_classes`` outputs.

        The head's input width is read from ``encoder_model.config.hidden_size``.
        """
        super().__init__()
        self.encoder = encoder_model
        hidden_size = self.encoder.config.hidden_size
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return class logits for a batch of encoded inputs."""
        # The encoder call is excluded from autograd, so only ``fc`` can
        # receive gradients.
        with torch.no_grad():
            encoded = self.encoder(
                input_ids=input_ids, attention_mask=attention_mask
            )

        # Hidden state of the first token of every sequence.
        first_token = encoded.last_hidden_state[:, 0, :]
        return self.fc(first_token)

    def state_for_save(self):
        """Return the state to persist: only the classifier layer's weights."""
        return dict(fc_state_dict=self.fc.state_dict())

    def load_state_from_save(self, state):
        """Restore the classifier layer from a ``state_for_save`` result."""
        head_weights = state['fc_state_dict']
        self.fc.load_state_dict(head_weights)

In [5]:
# Pretrained tokenizer and encoder, both loaded from the same checkpoint name.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
roberta = RobertaModel.from_pretrained("roberta-base")

# Freeze every encoder parameter in one call.
roberta.requires_grad_(False)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Training split; statements are tokenized item by item inside the dataset.
training_data = LiarPlusBaseRobertaDataset(
    filepath="data/normalized/train2.csv",
    tokenizer=tokenizer,
)
batch_size = 64

# Batches are drawn in shuffled order.
train_dataloader = DataLoader(
    dataset=training_data,
    batch_size=batch_size,
    shuffle=True,
)

In [7]:
num_classes = 6  # one logit per entry of LABEL_MAPPING
# Prefer the GPU when one is visible to torch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LiarPlusBaseRoBERTasClassifier(
    encoder_model=roberta,
    num_classes=num_classes,
)
# Kept as the cell's last expression so the module summary is displayed.
model.to(device)

LiarPlusBaseRoBERTasClassifier(
  (encoder): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
           

In [9]:
# Restore the saved classifier head and switch the model to inference mode.
load_best_model(model, 'results/Base/best_model_10.pth')
model.eval()

# Scan training batches until one contains at least one prediction of
# class 3 ("half-true" in LABEL_MAPPING), print its logits and stop.
with torch.no_grad():
    for batch in tqdm(train_dataloader):
        res = model(
            input_ids=batch['input_ids'].to(device),
            attention_mask=batch['attention_mask'].to(device),
        )
        if not (res.argmax(dim=1) == 3).any():
            continue
        print(res)
        break

  best_model = torch.load(path)


Model loaded from best model checkpoint.


  attn_output = torch.nn.functional.scaled_dot_product_attention(
  0%|                                                                                                                                                                           | 0/161 [00:02<?, ?it/s]

tensor([[-8.9492e-01,  1.0884e-01,  4.4853e-02,  3.1243e-01,  2.1663e-01,
         -2.0351e-01],
        [-5.7822e-01,  1.5024e-01, -1.4315e-01,  1.0387e-01,  3.3188e-01,
         -1.0929e-01],
        [-1.0152e+00, -8.1179e-03, -6.8929e-02,  4.0695e-01,  3.5767e-01,
         -1.8672e-01],
        [-7.5284e-01,  1.0135e-01, -2.3083e-01,  3.0743e-01,  2.3144e-01,
          6.0803e-02],
        [-5.3404e-01,  2.0390e-01, -6.4482e-02,  1.1892e-01,  4.5022e-02,
         -2.2283e-02],
        [-9.8042e-01,  1.5690e-01, -1.8976e-01,  2.8061e-01,  3.2305e-01,
         -1.1606e-03],
        [-8.4892e-01,  2.5314e-01,  9.3451e-02,  3.1589e-01,  1.4447e-01,
         -3.6117e-01],
        [-6.6785e-01,  2.5050e-01,  2.2881e-01,  2.2148e-01,  7.1997e-02,
         -3.9809e-01],
        [-1.0230e+00,  2.0469e-01, -1.1851e-01,  1.5587e-01,  3.6168e-01,
         -2.3021e-03],
        [-8.4438e-01,  3.1092e-02, -1.5821e-01,  5.1605e-01,  2.7108e-01,
         -9.3113e-02],
        [-1.0193e+00,  4.9156e




In [10]:
# Snapshot of the logits printed by the probe cell above, pasted as a literal
# so the inspection cells below can run without re-evaluating the model.
# Shape: 64 samples x 6 classes. The tensor is placed on the first GPU when
# one is available and on the CPU otherwise, so the cell also runs on
# machines without CUDA (the original hard-coded 'cuda:0').
res = torch.tensor([[-8.9492e-01,  1.0884e-01,  4.4853e-02,  3.1243e-01,  2.1663e-01,
         -2.0351e-01],
        [-5.7822e-01,  1.5024e-01, -1.4315e-01,  1.0387e-01,  3.3188e-01,
         -1.0929e-01],
        [-1.0152e+00, -8.1179e-03, -6.8929e-02,  4.0695e-01,  3.5767e-01,
         -1.8672e-01],
        [-7.5284e-01,  1.0135e-01, -2.3083e-01,  3.0743e-01,  2.3144e-01,
          6.0803e-02],
        [-5.3404e-01,  2.0390e-01, -6.4482e-02,  1.1892e-01,  4.5022e-02,
         -2.2283e-02],
        [-9.8042e-01,  1.5690e-01, -1.8976e-01,  2.8061e-01,  3.2305e-01,
         -1.1606e-03],
        [-8.4892e-01,  2.5314e-01,  9.3451e-02,  3.1589e-01,  1.4447e-01,
         -3.6117e-01],
        [-6.6785e-01,  2.5050e-01,  2.2881e-01,  2.2148e-01,  7.1997e-02,
         -3.9809e-01],
        [-1.0230e+00,  2.0469e-01, -1.1851e-01,  1.5587e-01,  3.6168e-01,
         -2.3021e-03],
        [-8.4438e-01,  3.1092e-02, -1.5821e-01,  5.1605e-01,  2.7108e-01,
         -9.3113e-02],
        [-1.0193e+00,  4.9156e-01,  2.7303e-01,  2.7127e-01,  5.4285e-02,
         -4.6673e-01],
        [-7.5651e-01,  2.7133e-01, -1.8897e-01,  2.0437e-01,  7.8001e-02,
          5.5627e-02],
        [-3.1671e-01,  3.5415e-01,  2.2951e-01,  2.2657e-01, -8.8703e-02,
         -5.1328e-01],
        [-1.1096e+00,  2.2036e-01, -3.3765e-01,  6.6647e-02,  4.4673e-01,
          2.7703e-01],
        [-7.8437e-01,  2.5349e-01, -7.9421e-02,  6.8218e-02,  2.2604e-01,
          4.6675e-04],
        [-9.2899e-01,  9.2818e-02, -1.9228e-01,  2.8877e-01,  3.4585e-01,
          6.9131e-03],
        [ 6.1579e-03,  2.8974e-01, -1.4598e-02,  5.9168e-02, -1.1034e-01,
         -2.8899e-01],
        [-1.2656e+00,  1.5415e-01, -3.6471e-02,  4.1501e-01,  4.1370e-01,
         -1.8361e-01],
        [-8.7667e-01,  1.5819e-01, -7.0469e-02,  3.1123e-01,  2.8564e-01,
         -2.0431e-01],
        [-9.8399e-01,  1.7995e-01,  1.0869e-01,  2.7581e-01,  2.6005e-01,
         -2.1152e-01],
        [-9.2883e-01,  1.6767e-01, -1.3804e-01,  2.6834e-01,  2.8097e-01,
         -2.0826e-02],
        [-1.1424e+00, -2.0425e-02, -2.2306e-01,  1.2804e-01,  5.2800e-01,
          3.3787e-01],
        [-7.1579e-01,  2.2190e-01, -1.2868e-01,  4.2523e-02,  2.3369e-01,
          8.0986e-02],
        [-7.3994e-01,  2.8848e-01,  6.1744e-02,  2.8877e-01,  1.0221e-01,
         -3.5819e-01],
        [-3.2972e-01,  2.6573e-01,  1.0854e-01,  2.3736e-01, -4.3108e-02,
         -4.3046e-01],
        [-6.0433e-01,  4.1965e-01,  8.3801e-02,  5.6269e-02,  8.1244e-02,
         -3.5641e-01],
        [-1.0576e+00,  1.3449e-01,  1.8229e-01,  2.0605e-01,  1.5300e-01,
         -6.2429e-02],
        [-6.8446e-01,  2.6023e-01,  1.0897e-01,  3.8950e-01,  1.2426e-01,
         -4.3114e-01],
        [-4.9597e-01,  2.6612e-01,  1.0198e-01,  3.9256e-01,  2.3430e-02,
         -5.3688e-01],
        [-4.0498e-01,  3.1188e-01,  6.1521e-02,  1.5537e-01, -6.9880e-02,
         -1.7717e-01],
        [-1.3986e-01,  2.0157e-01,  7.4124e-02,  1.7165e-01, -1.2255e-01,
         -2.7473e-01],
        [-1.0370e+00,  2.7313e-02, -2.5669e-01,  2.4463e-01,  4.4240e-01,
          1.9330e-01],
        [-1.1657e+00,  2.4522e-01,  1.2501e-01,  2.4754e-01,  2.3645e-01,
         -1.5934e-01],
        [-1.0584e+00,  6.0217e-02,  2.9484e-01,  3.1466e-01,  2.0304e-01,
         -2.6140e-01],
        [-9.2228e-01,  1.3813e-01,  6.3419e-03,  2.1075e-01,  1.9076e-01,
         -2.7842e-02],
        [-7.7608e-01,  5.7938e-02,  1.7765e-01,  3.5973e-01,  1.3194e-01,
         -2.2783e-01],
        [-7.9885e-01,  1.5760e-01,  1.7169e-01,  3.6905e-01,  1.7800e-01,
         -4.5364e-01],
        [-8.0261e-01,  2.9052e-01, -5.6968e-02,  3.9695e-01,  7.3366e-02,
         -1.6737e-01],
        [-9.7667e-01,  6.1660e-02, -1.0220e-02,  3.7519e-01,  2.8191e-01,
         -1.4378e-01],
        [-7.9555e-01,  2.8409e-01, -9.4091e-02,  2.7423e-01,  1.9731e-01,
         -1.7822e-01],
        [-5.3845e-01,  2.0001e-01,  3.4476e-02,  1.9040e-01,  3.3668e-02,
         -1.6421e-01],
        [-1.1210e+00,  2.4885e-02,  2.1462e-01,  3.6780e-01,  2.1189e-01,
         -1.7013e-01],
        [-8.2161e-01,  2.6396e-01,  1.6614e-01,  2.7604e-01,  1.0296e-01,
         -3.5786e-01],
        [-4.1460e-01,  3.5185e-01, -1.3056e-01, -8.7982e-02,  4.1132e-02,
          2.5787e-02],
        [-9.7758e-01,  1.8817e-01, -2.7801e-01,  2.9838e-01,  2.9411e-01,
          1.6623e-01],
        [-7.8810e-01,  3.5444e-01,  2.2991e-02, -6.2273e-02,  1.1073e-01,
          1.0677e-02],
        [-7.1181e-01,  1.9409e-01, -1.6406e-01,  1.9402e-01,  2.1783e-01,
          3.2102e-03],
        [-7.9644e-01,  3.9023e-02,  1.1749e-01,  3.8761e-01,  1.1017e-01,
         -2.4511e-01],
        [-1.3845e+00, -3.4270e-01, -3.0989e-01,  6.4894e-01,  6.7696e-01,
          2.1074e-01],
        [-9.8077e-01,  1.9867e-01, -1.1164e-01,  2.7985e-01,  2.8635e-01,
         -4.0705e-02],
        [-1.4884e+00,  2.2595e-01, -3.4491e-01,  2.7378e-01,  4.8179e-01,
          3.0043e-01],
        [-9.8696e-01,  4.0794e-01, -7.3404e-02,  2.6487e-01,  1.8649e-01,
         -1.9794e-01],
        [-9.7030e-01,  2.8129e-01,  1.3456e-01,  2.2819e-01,  2.2721e-01,
         -3.3626e-01],
        [-6.8581e-01,  1.7882e-01, -1.5964e-01,  3.0582e-01,  2.2853e-01,
         -1.0833e-01],
        [-1.0758e+00,  9.0568e-02, -1.0060e-01,  2.5066e-01,  4.0034e-01,
         -9.9203e-03],
        [-1.6355e-01,  4.0793e-01,  1.9132e-01,  1.7303e-01, -1.9363e-01,
         -6.2665e-01],
        [-3.3660e-01,  2.3119e-01,  2.2071e-02,  1.4100e-01,  3.8306e-02,
         -3.4399e-01],
        [-7.7878e-01,  2.6083e-01,  2.2556e-02,  1.2581e-01,  3.1620e-02,
         -4.4233e-02],
        [-5.9574e-01,  1.8101e-01,  2.3738e-01,  3.0071e-01, -1.6549e-02,
         -4.0756e-01],
        [-4.4696e-01,  3.0733e-01, -1.4029e-01,  1.0085e-01,  1.3581e-01,
         -1.9050e-01],
        [-6.2806e-01,  8.1569e-02,  1.2505e-02,  7.6638e-02,  1.7126e-01,
         -1.3664e-01],
        [-6.9420e-01,  2.4602e-01,  1.2632e-01,  3.3528e-01,  8.3587e-02,
         -3.3860e-01],
        [-1.1653e+00,  1.9088e-01, -2.9227e-01,  1.5721e-01,  3.6096e-01,
          2.3897e-01],
        [-3.4847e-01,  3.6728e-01,  1.5456e-01,  2.6220e-01, -8.3316e-02,
         -4.6575e-01]],
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
)

In [12]:
# Predicted class id for every row of the logits snapshot.
torch.argmax(res, dim=1)

tensor([3, 4, 3, 3, 1, 4, 3, 1, 4, 3, 1, 1, 1, 4, 1, 4, 1, 3, 3, 3, 4, 4, 4, 3,
        1, 1, 3, 3, 3, 1, 1, 4, 3, 3, 3, 3, 3, 3, 3, 1, 1, 3, 3, 1, 3, 1, 4, 3,
        4, 4, 4, 1, 1, 3, 4, 1, 1, 1, 3, 1, 4, 3, 4, 1], device='cuda:0')

In [15]:
# Logits of the first sample in the snapshot batch.
res[0]

tensor([-0.8949,  0.1088,  0.0449,  0.3124,  0.2166, -0.2035], device='cuda:0')