<a href="https://colab.research.google.com/github/ayyucedemirbas/TCGA_contrastive/blob/main/tcga_contrastive.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://huggingface.co/datasets/AIBIC/MLOmics

Cloning into 'MLOmics'...
remote: Enumerating objects: 435, done.[K
remote: Total 435 (delta 0), reused 0 (delta 0), pack-reused 435 (from 1)[K
Receiving objects: 100% (435/435), 57.17 KiB | 5.72 MiB/s, done.
Resolving deltas: 100% (56/56), done.
Filtering content: 100% (235/235), 8.83 GiB | 53.76 MiB/s, done.


In [2]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m38.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.7.0


In [3]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import GATConv, global_mean_pool, global_max_pool
from torch_geometric.data import Data, Batch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import os
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

In [4]:
torch.manual_seed(42)
np.random.seed(42)

In [14]:
class Config:
    # Dataset paths
    dataset_root = 'MLOmics/Main_Dataset/Classification_datasets'
    cancer_type = 'GS-BRCA'  # GS-COAD, GS-GBM, GS-LGG, GS-OV
    data_version = 'Top'  # 'Original', 'Aligned', 'Top'


    hidden_dim = 256
    gat_heads = 4
    num_gat_layers = 3
    dropout = 0.3
    projection_dim = 128


    batch_size = 2
    learning_rate = 0.001
    num_epochs = 100
    patience = 15

    temperature = 0.07
    contrastive_weight = 0.5

    top_k_neighbors = 7


    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    save_dir = 'models'
    model_name = f'best_gat_contrastive_{cancer_type}.pth'



config = Config()
os.makedirs(config.save_dir, exist_ok=True)

In [15]:
class TCGADataLoader:
    def __init__(self, config):
        self.config = config
        self.data_path = os.path.join(
            config.dataset_root,
            config.cancer_type,
            config.data_version
        )
        os.makedirs(self.data_path, exist_ok=False) if not os.path.exists(self.data_path) else None

    def _read_csv_try(self, path, idx_col=0):
        try:
            df = pd.read_csv(path, index_col=idx_col)
            return df
        except Exception:
            df = pd.read_csv(path, header=None)
            return df

    def load_data(self):
        cancer_prefix = config.cancer_type.split('-')[1]
        print(f"Loading from: {self.data_path}")

        def path_for(name):
            return os.path.join(self.data_path, name)

        def _load_matrix(prefix, typ):
            candidates = [
                f"{prefix}_{typ}_{config.data_version.lower()}.csv",
                f"{prefix}_{typ}.csv",
                f"{typ}_{prefix}.csv",
                f"{prefix}_{typ}.CSV"
            ]
            for fname in candidates:
                p = path_for(fname)
                if os.path.exists(p):
                    try:
                        return pd.read_csv(p, index_col=0)
                    except Exception:
                        return pd.read_csv(p, header=None)
            raise FileNotFoundError(f"None of candidate files found for {prefix} {typ}: {candidates}")

        cnv = _load_matrix(cancer_prefix, 'CNV')
        methy = _load_matrix(cancer_prefix, 'Methy')
        mrna = _load_matrix(cancer_prefix, 'mRNA')
        mirna = _load_matrix(cancer_prefix, 'miRNA')

        label_path_candidates = [
            f"{cancer_prefix}_label_num.csv",
            f"{cancer_prefix}_label.csv",
            "label.csv",
            "labels.csv"
        ]
        labels = None
        for fname in label_path_candidates:
            p = path_for(fname)
            if os.path.exists(p):
                try:
                    labels = pd.read_csv(p, index_col=0)
                except Exception:
                    labels = pd.read_csv(p, header=None)
                break
        if labels is None:
            raise FileNotFoundError("Labels file not found in expected locations")

        print(f"Loaded - CNV: {cnv.shape}, Methy: {methy.shape}, mRNA: {mrna.shape}, miRNA: {mirna.shape}, Labels: {labels.shape}")

        label_count = None
        if isinstance(labels, pd.DataFrame) and labels.shape[1] >= 1:
            label_count = labels.shape[0]
        else:
            label_count = labels.shape[0]

        def ensure_samples_rows(df, name):
            if df.shape[0] == label_count:
                print(f"{name} already samples x features: {df.shape}")
                return df
            elif df.shape[1] == label_count:
                print(f"Transposing {name} from (features x samples) to (samples x features): {df.shape} -> {df.T.shape}")
                return df.T
            else:
                sample_index_candidates = set(df.index.astype(str))
                label_index_candidates = set()
                if isinstance(labels, pd.DataFrame):
                    for col in labels.columns:
                        label_index_candidates |= set(labels[col].astype(str))
                label_index_candidates |= set(labels.index.astype(str))

                if len(sample_index_candidates & label_index_candidates) > 0:
                    print(f"Detected sample ID overlap between {name} index and labels; using current orientation: {df.shape}")
                    return df
                if len(set(df.columns.astype(str)) & label_index_candidates) > 0:
                    print(f"Detected sample ID overlap between {name} columns and labels; transposing {name}")
                    return df.T

                if df.shape[0] > df.shape[1]:
                    print(f"Heuristic transpose for {name} (rows > cols). Transposing {df.shape} -> {df.T.shape}")
                    return df.T

                raise ValueError(f"Cannot determine sample axis for {name}. df.shape={df.shape}, label_count={label_count}")

        cnv = ensure_samples_rows(cnv, 'CNV')
        methy = ensure_samples_rows(methy, 'Methy')
        mrna = ensure_samples_rows(mrna, 'mRNA')
        mirna = ensure_samples_rows(mirna, 'miRNA')

        print(f"After ensure_samples_rows - CNV: {cnv.shape}, Methy: {methy.shape}, mRNA: {mrna.shape}, miRNA: {mirna.shape}")

        cnv.index = cnv.index.astype(str).str.strip()
        methy.index = methy.index.astype(str).str.strip()
        mrna.index = mrna.index.astype(str).str.strip()
        mirna.index = mirna.index.astype(str).str.strip()

        if isinstance(labels, pd.DataFrame) and labels.shape[1] == 0:
            labels = labels.copy()
            labels['label'] = labels.index.astype(str)
        if isinstance(labels, pd.DataFrame) and labels.shape[1] == 1:
            col = labels.columns[0]
            if set(labels[col].astype(str)) & set(cnv.index):
                labels = labels.set_index(col)
                if labels.shape[1] == 0:
                    labels['label'] = np.nan
            else:
                if len(labels) == len(cnv):
                    labels.index = cnv.index
                    labels = labels.rename(columns={col: 'label'})
                else:
                    if len(set(labels.index.astype(str)) & set(cnv.index)) == 0:
                        labels = pd.read_csv(os.path.join(self.data_path, f"{cancer_prefix}_label_num.csv"), header=None)
                        if labels.shape[0] == cnv.shape[0]:
                            labels.index = cnv.index
                            labels.columns = ['label']
                        else:
                            labels = pd.DataFrame({'label': np.zeros(len(cnv))}, index=cnv.index)
        else:
            if isinstance(labels, pd.DataFrame) and labels.shape[1] > 1:
                numeric_cols = [c for c in labels.columns if pd.api.types.is_numeric_dtype(labels[c])]
                if len(numeric_cols) >= 1:
                    labels = labels[[numeric_cols[0]]]
                    labels.columns = ['label']
                else:
                    if len(set(labels.index.astype(str)) & set(cnv.index)) == 0:
                        labels = pd.DataFrame({'label': np.zeros(len(cnv))}, index=cnv.index)

        labels.index = labels.index.astype(str).str.strip()
        labels['label'] = labels.iloc[:, 0]
        labels = labels[['label']]

        print(f"Labels processed: {labels.shape}")

        common_samples = list(set(cnv.index) & set(methy.index) & set(mrna.index) & set(mirna.index) & set(labels.index))

        if len(common_samples) == 0:
            print(list(cnv.index[:10]))
            print(list(methy.index[:10]))
            print(list(mrna.index[:10]))
            print(list(mirna.index[:10]))
            print(list(labels.index[:10]))
            raise ValueError(
                "No common samples found across modalities. "
                "Sample IDs don't match. Check the CSV file structure or sample naming conventions."
            )

        common_samples.sort()
        print(f"\nFound {len(common_samples)} common samples")

        cnv = cnv.loc[common_samples].copy()
        methy = methy.loc[common_samples].copy()
        mrna = mrna.loc[common_samples].copy()
        mirna = mirna.loc[common_samples].copy()
        labels = labels.loc[common_samples].copy()

        print(f"\nLabel types before encoding: {labels['label'].dtype}")
        print(f"Unique labels: {labels['label'].unique()}")

        if labels['label'].dtype == 'object' or labels['label'].dtype.name == 'string':
            print("Labels are categorical/strings. Encoding to numeric...")
            label_encoder = LabelEncoder()
            labels['label'] = label_encoder.fit_transform(labels['label'])
            print(f"Label mapping: {dict(zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_)))}")
            self.label_encoder = label_encoder
        else:
            labels['label'] = labels['label'].astype(int)
            self.label_encoder = None

        print(f"Labels after encoding: {labels['label'].dtype}, unique values: {labels['label'].unique()}")

        cnv = cnv.reset_index(drop=True)
        methy = methy.reset_index(drop=True)
        mrna = mrna.reset_index(drop=True)
        mirna = mirna.reset_index(drop=True)
        labels = labels.reset_index(drop=True)

        cnv = cnv.fillna(cnv.mean())
        methy = methy.fillna(methy.mean())
        mrna = mrna.fillna(mrna.mean())
        mirna = mirna.fillna(mirna.mean())

        cnv = cnv.fillna(0)
        methy = methy.fillna(0)
        mrna = mrna.fillna(0)
        mirna = mirna.fillna(0)

        print(f"Final data shapes - CNV: {cnv.shape}, Methy: {methy.shape}, mRNA: {mrna.shape}, miRNA: {mirna.shape}, Labels: {labels.shape}")

        return cnv, methy, mrna, mirna, labels

    def create_train_test_split(self, cnv, methy, mrna, mirna, labels, test_size=0.2):
        indices = np.arange(len(cnv))
        y = labels['label'].values.ravel()
        try:
            train_idx, test_idx = train_test_split(
                indices, test_size=test_size, random_state=42, stratify=y
            )
        except Exception:
            train_idx, test_idx = train_test_split(
                indices, test_size=test_size, random_state=42, stratify=None
            )
        return train_idx, test_idx

In [16]:
class PatientGraphBuilder:

    def __init__(self, top_k=10):
        self.top_k = top_k

    def build_graph(self, cnv_feat, methy_feat, mrna_feat, mirna_feat):
        all_features = np.concatenate([cnv_feat, methy_feat, mrna_feat, mirna_feat])
        num_nodes = len(all_features)

        cnv_size = len(cnv_feat)
        methy_size = len(methy_feat)
        mrna_size = len(mrna_feat)
        mirna_size = len(mirna_feat)

        edge_index = self._build_knn_edges(all_features, self.top_k)

        cross_edges = self._build_cross_omics_edges(
            cnv_size, methy_size, mrna_size, mirna_size
        )
        if cross_edges.size > 0:
            edge_index = np.concatenate([edge_index, cross_edges], axis=1)

        x = torch.FloatTensor(all_features.reshape(-1, 1))
        edge_index = torch.LongTensor(edge_index)

        return x, edge_index

    def _build_knn_edges(self, features, k):
        num_nodes = len(features)
        features_reshaped = features.reshape(-1, 1)

        distances = np.abs(features_reshaped - features_reshaped.T)

        edges = []
        for i in range(num_nodes):
            nearest = np.argsort(distances[i])[:k+1]
            for j in nearest:
                if i != j:
                    edges.append([i, j])

        if len(edges) == 0:
            edges = [[i, i+1] for i in range(num_nodes-1)]

        return np.array(edges).T

    def _build_cross_omics_edges(self, cnv_size, methy_size, mrna_size, mirna_size):
        edges = []

        offsets = [0, cnv_size, cnv_size + methy_size, cnv_size + methy_size + mrna_size]
        sizes = [cnv_size, methy_size, mrna_size, mirna_size]

        for i in range(len(offsets)):
            for j in range(i+1, len(offsets)):
                num_cross_edges = min(5, sizes[i], sizes[j])
                if num_cross_edges <= 0:
                    continue
                for k in range(num_cross_edges):
                    src = offsets[i] + int(k * (sizes[i] / num_cross_edges))
                    dst = offsets[j] + int(k * (sizes[j] / num_cross_edges))
                    edges.append([src, dst])
                    edges.append([dst, src])

        if len(edges) == 0:
            edges = [[0, 1]]

        return np.array(edges).T

In [17]:
class MultiOmicsGraphDataset(Dataset):
    def __init__(self, cnv, methy, mrna, mirna, labels, indices, graph_builder, scalers=None):
        self.cnv = cnv.values[indices]
        self.methy = methy.values[indices]
        self.mrna = mrna.values[indices]
        self.mirna = mirna.values[indices]
        self.labels = labels.values[indices].ravel()
        self.graph_builder = graph_builder

        self.labels = self.labels.astype(int)

        if scalers is None:
            self.scaler_cnv = StandardScaler().fit(self.cnv)
            self.scaler_methy = StandardScaler().fit(self.methy)
            self.scaler_mrna = StandardScaler().fit(self.mrna)
            self.scaler_mirna = StandardScaler().fit(self.mirna)
        else:
            self.scaler_cnv, self.scaler_methy, self.scaler_mrna, self.scaler_mirna = scalers

        self.cnv = self.scaler_cnv.transform(self.cnv)
        self.methy = self.scaler_methy.transform(self.methy)
        self.mrna = self.scaler_mrna.transform(self.mrna)
        self.mirna = self.scaler_mirna.transform(self.mirna)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x, edge_index = self.graph_builder.build_graph(
            self.cnv[idx],
            self.methy[idx],
            self.mrna[idx],
            self.mirna[idx]
        )

        label = int(self.labels[idx])
        label = torch.LongTensor([label])[0]

        data = Data(x=x, edge_index=edge_index, y=label)

        return data

    def get_scalers(self):
        return (self.scaler_cnv, self.scaler_methy, self.scaler_mrna, self.scaler_mirna)

def collate_fn(batch):
    return Batch.from_data_list(batch)

In [18]:
class ProjectionHead(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return F.normalize(self.net(x), dim=1)

class MultiOmicsGAT(nn.Module):

    def __init__(self, input_dim, hidden_dim, num_classes, num_gat_layers=3,
                 heads=4, dropout=0.3, projection_dim=128):
        super().__init__()

        self.input_proj = nn.Linear(input_dim, hidden_dim)

        self.gat_layers = nn.ModuleList()
        for i in range(num_gat_layers):
            self.gat_layers.append(
                GATConv(hidden_dim, hidden_dim // heads, heads=heads, dropout=dropout)
            )

        self.dropout = nn.Dropout(dropout)

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

        self.projection = ProjectionHead(hidden_dim * 2, hidden_dim, projection_dim)

    def forward(self, data, return_embedding=False):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = self.input_proj(x)
        x = F.relu(x)
        for gat in self.gat_layers:
            x = gat(x, edge_index)
            x = F.elu(x)
            x = self.dropout(x)

        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        x = torch.cat([x_mean, x_max], dim=1)

        if return_embedding:
            return x

        out = self.classifier(x)

        z = self.projection(x)

        return out, z


In [19]:
class NTXentLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, z):
        batch_size = z.shape[0]

        sim_matrix = torch.matmul(z, z.T) / self.temperature
        mask = torch.eye(batch_size, device=z.device).bool()
        sim_matrix.masked_fill_(mask, -9e15)

        exp_sim = torch.exp(sim_matrix)
        loss = -torch.log(exp_sim / exp_sim.sum(dim=1, keepdim=True)).mean()

        return loss


In [20]:
class Trainer:
    def __init__(self, model, train_loader, val_loader, config):
        self.model = model.to(config.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config

        self.optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5
        )

        self.ce_loss = nn.CrossEntropyLoss()
        self.contrastive_loss = NTXentLoss(temperature=config.temperature)

        self.best_val_loss = float('inf')
        self.patience_counter = 0

    def train_epoch(self):
        self.model.train()
        total_loss = 0
        all_preds = []
        all_labels = []

        for batch in tqdm(self.train_loader, desc='Training'):
            batch = batch.to(self.config.device)

            self.optimizer.zero_grad()
            logits, z = self.model(batch)

            ce_loss = self.ce_loss(logits, batch.y)
            cont_loss = self.contrastive_loss(z)

            loss = ce_loss + self.config.contrastive_weight * cont_loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()

            total_loss += loss.item()

            preds = torch.argmax(logits, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(batch.y.cpu().numpy())

        avg_loss = total_loss / len(self.train_loader)
        accuracy = accuracy_score(all_labels, all_preds)

        return avg_loss, accuracy

    def validate(self):
        self.model.eval()
        total_loss = 0
        all_preds = []
        all_labels = []
        all_probs = []

        with torch.no_grad():
            for batch in self.val_loader:
                batch = batch.to(self.config.device)

                logits, z = self.model(batch)

                ce_loss = self.ce_loss(logits, batch.y)
                cont_loss = self.contrastive_loss(z)
                loss = ce_loss + self.config.contrastive_weight * cont_loss

                total_loss += loss.item()

                probs = F.softmax(logits, dim=1)
                preds = torch.argmax(logits, dim=1).cpu().numpy()

                all_preds.extend(preds)
                all_labels.extend(batch.y.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())

        avg_loss = total_loss / len(self.val_loader)
        accuracy = accuracy_score(all_labels, all_preds)

        f1 = f1_score(all_labels, all_preds, average='weighted')

        try:
            all_probs = np.array(all_probs)
            if all_probs.shape[1] == 2:
                auc = roc_auc_score(all_labels, all_probs[:, 1])
            else:
                auc = roc_auc_score(all_labels, all_probs, multi_class='ovr')
        except:
            auc = 0.0

        return avg_loss, accuracy, f1, auc

    def train(self):

        for epoch in range(self.config.num_epochs):
            train_loss, train_acc = self.train_epoch()

            val_loss, val_acc, val_f1, val_auc = self.validate()

            self.scheduler.step(val_loss)

            print(f"Epoch {epoch+1}/{self.config.num_epochs}")
            print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
            print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f} | Val AUC: {val_auc:.4f}")

            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.patience_counter = 0
                self.save_model()
            else:
                self.patience_counter += 1
                if self.patience_counter >= self.config.patience:
                    print(f"\nEarly stopping triggered after {epoch+1} epochs")
                    break

            print("-" * 70)

        print(f"Best validation loss: {self.best_val_loss:.4f}")

    def save_model(self):
        save_path = os.path.join(self.config.save_dir, self.config.model_name)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'config': self.config,
        }, save_path)

In [21]:
def main():
    data_loader = TCGADataLoader(config)
    cnv, methy, mrna, mirna, labels = data_loader.load_data()

    train_idx, test_idx = data_loader.create_train_test_split(
        cnv, methy, mrna, mirna, labels
    )

    print(f"\nTrain samples: {len(train_idx)}, Test samples: {len(test_idx)}")

    graph_builder = PatientGraphBuilder(top_k=config.top_k_neighbors)

    train_dataset = MultiOmicsGraphDataset(
        cnv, methy, mrna, mirna, labels, train_idx, graph_builder
    )

    scalers = train_dataset.get_scalers()
    test_dataset = MultiOmicsGraphDataset(
        cnv, methy, mrna, mirna, labels, test_idx, graph_builder, scalers=scalers
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    num_classes = len(np.unique(labels.values))
    print(f"Number of classes: {num_classes}")

    model = MultiOmicsGAT(
        input_dim=1,  # Each node has 1 feature (the omics value)
        hidden_dim=config.hidden_dim,
        num_classes=num_classes,
        num_gat_layers=config.num_gat_layers,
        heads=config.gat_heads,
        dropout=config.dropout,
        projection_dim=config.projection_dim
    )

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total model parameters: {total_params:,}")

    trainer = Trainer(model, train_loader, test_loader, config)
    trainer.train()

In [None]:
if __name__ == "__main__":
    main()

Loading from: MLOmics/Main_Dataset/Classification_datasets/GS-BRCA/Top
Loaded - CNV: (5000, 671), Methy: (5000, 671), mRNA: (5000, 671), miRNA: (366, 671), Labels: (671, 0)
Transposing CNV from (features x samples) to (samples x features): (5000, 671) -> (671, 5000)
Transposing Methy from (features x samples) to (samples x features): (5000, 671) -> (671, 5000)
Transposing mRNA from (features x samples) to (samples x features): (5000, 671) -> (671, 5000)
Transposing miRNA from (features x samples) to (samples x features): (366, 671) -> (671, 366)
After ensure_samples_rows - CNV: (671, 5000), Methy: (671, 5000), mRNA: (671, 5000), miRNA: (671, 366)
Labels processed: (671, 1)

Found 671 common samples

Label types before encoding: object
Unique labels: ['0' '1' '2' '3' '4']
Labels are categorical/strings. Encoding to numeric...
Label mapping: {'0': np.int64(0), '1': np.int64(1), '2': np.int64(2), '3': np.int64(3), '4': np.int64(4)}
Labels after encoding: int64, unique values: [0 1 2 3 4]


Training: 100%|██████████| 268/268 [56:58<00:00, 12.75s/it]


Epoch 1/100
  Train Loss: nan | Train Acc: 0.5261
  Val Loss: nan | Val Acc: 0.5259 | Val F1: 0.3625 | Val AUC: 0.0000
----------------------------------------------------------------------


Training: 100%|██████████| 268/268 [55:40<00:00, 12.46s/it]


Epoch 2/100
  Train Loss: nan | Train Acc: 0.5261
  Val Loss: nan | Val Acc: 0.5259 | Val F1: 0.3625 | Val AUC: 0.0000
----------------------------------------------------------------------


Training:  62%|██████▏   | 167/268 [34:52<20:40, 12.28s/it]