In [54]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer
from torch.optim import AdamW
from tqdm import tqdm
from tokenizers import Tokenizer 
from sklearn.metrics import accuracy_score
import os
import wandb
import toml

In [68]:
 # --- Configuration ---
config_file_path = '/Users/nishitha/Desktop/Learn/NLP/Clinical trials eligibility/config.toml'
if not os.path.exists(config_file_path):
    print(f"Error: Config file not found at {config_file_path}")
else:
    config = toml.load(config_file_path)

    # general parameters
    BATCH_SIZE = config['training']['batch_size']
    NUM_LABELS = config['data']['num_labels']
    LEARNING_RATE = config['training']['learning_rate']

    # For BioBert Transformer 
    EPOCHS_TRANSFORMER = config['training']['transformer']['epochs']
    MAX_LEN_TRANSFORMER = config['training']['transformer']['max_len']

    # For RNN model parameters
    MAX_LEN_RNN = config['data']['max_len']
    EPOCHS_RNN = config['training']['rnn']['epochs']
    EMBEDDING_DIM = config['model']['embedding_dim']
    RNN_HIDDEN_SIZE = config['model']['rnn']['rnn_hidden_size'] 
    RNN_NUM_LAYERS = config['model']['rnn']['rnn_num_layers']   
    RNN_DROPOUT = config['model']['rnn']['rnn_dropout']        

    # device configuration
    DEVICE = torch.device(config['general']['device'])



In [69]:
class DataCreator_RNN(Dataset):
    def __init__(self, df, tokenizer, max_len):
        self.patient_texts = list(df['patient'])
        self.criteria_texts = list(df['criteria'])
        self.labels = list(df['label'])
        self.tokenizer = tokenizer
        self.max_len = max_len
        
        self.pad_token_id = self.tokenizer.token_to_id("[PAD]")
        if self.pad_token_id is None:
            print("Warning: [PAD] token not found in tokenizer. Using 0 for padding token ID in DataCreator.")
            self.pad_token_id = 0

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        patient_text = str(self.patient_texts[index])
        criteria_text = str(self.criteria_texts[index])
        label = torch.tensor(self.labels[index], dtype=torch.long)

        encoding = self.tokenizer.encode(criteria_text, patient_text)

        ids = encoding.ids
        attention_mask = encoding.attention_mask

        if len(ids) > self.max_len:
            ids = ids[:self.max_len]
            attention_mask = attention_mask[:self.max_len]
        else:
            padding_length = self.max_len - len(ids)
            ids += [self.pad_token_id] * padding_length
            attention_mask += [0] * padding_length

        return {
            "input_ids": torch.tensor(ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "label": label,
        }


In [70]:
class DataCreator_Transformer(Dataset):
    def __init__(self, df, tokenizer, max_len):
        self.patient_texts = list(df['patient'])
        self.criteria_texts = list(df['criteria'])
        self.labels = list(df['label'])
        self.tokenizer = tokenizer
        self.max_len = max_len
        
        self.pad_token_id = self.tokenizer.pad_token_id
        if self.pad_token_id is None:
            print("Warning: [PAD] token not found in tokenizer. Using 0 for padding token ID.")
            self.pad_token_id = 0

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        patient_text = str(self.patient_texts[index])
        criteria_text = str(self.criteria_texts[index])
        label = torch.tensor(self.labels[index], dtype=torch.long)

        encoding = self.tokenizer(
            criteria_text,
            patient_text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": label,
        }

In [71]:
class TransformerClassifier(nn.Module):
    def __init__(self, model_name, num_labels):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.encoder.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)
        return logits

In [72]:
class RNNClassifierFromScratch(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, hidden_size: int, num_layers: int, num_labels: int, dropout_rate: float):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        self.lstm = nn.LSTM(embedding_dim, hidden_size, num_layers, 
                            batch_first=True, dropout=dropout_rate if num_layers > 1 else 0)
        
        self.dropout_classifier = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(hidden_size, num_labels)
        
        self.num_layers = num_layers
        self.hidden_size = hidden_size

    def forward(self, input_ids, attention_mask):
        embedded = self.embedding(input_ids) 

        lengths = attention_mask.sum(dim=1)
        
        lengths = lengths.cpu().clamp(min=1) 

        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False 
        )

        packed_output, (hidden, cell) = self.lstm(packed_embedded)
        
        final_hidden_state = hidden[-1, :, :]
        
        pooled_output = self.dropout_classifier(final_hidden_state)
        
        logits = self.classifier(pooled_output)
        
        return logits



In [73]:
def train_and_evaluate_transformer_model(
    model_name: str,
    df: pd.DataFrame,
    num_labels: int,
    max_len: int,
    batch_size: int,
    epochs: int,
    learning_rate: float,
    device: torch.device
):
    print(f"\n--- Starting training for Transformer model: {model_name} ---")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print(f"Tokenizer for {model_name} loaded.")

    df['label'] = pd.to_numeric(df['label'], errors='coerce')
    df.dropna(subset=['label'], inplace=True)
    df['label'] = df['label'].astype(int)

    train_val_df, test_df = train_test_split(
        df, test_size=0.2, stratify=df['label'], random_state=42
    )
    train_df, val_df = train_test_split(
        train_val_df, test_size=0.25, stratify=train_val_df['label'], random_state=42
    )

    print(f"Dataset split: Train={len(train_df)} | Val={len(val_df)} | Test={len(test_df)}")

    train_dataset = DataCreator_Transformer(df=train_df, tokenizer=tokenizer, max_len=max_len)
    val_dataset = DataCreator_Transformer(df=val_df, tokenizer=tokenizer, max_len=max_len)
    test_dataset = DataCreator_Transformer(df=test_df, tokenizer=tokenizer, max_len=max_len)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = TransformerClassifier(model_name, num_labels)
    model.to(device)
    print(f"Model {model_name} initialized")

    optimizer = AdamW(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()

    best_val_accuracy = 0.0
    model_save_name = model_name.replace('/', '_').replace('-', '_')
    model_save_path = f"{model_save_name}_clinical_model.pt"

    
    wandb.init(project='NLP_Project_Clinical_Trials', config=config)

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0

        print(f"\nEpoch {epoch + 1}/{epochs} (Model: {model_name})")
        loop = tqdm(train_dataloader, leave=True)

        for batch in loop:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = loss_fn(outputs, labels)

            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            loop.set_description(f"Epoch {epoch + 1}")
            loop.set_postfix(loss=loss.item())

        avg_train_loss = total_train_loss / len(train_dataloader)
        print(f"Average Training Loss: {avg_train_loss:.4f}")

        model.eval()
        total_val_loss = 0
        correct_val_preds = 0

        with torch.no_grad():
            for batch in val_dataloader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["label"].to(device)

                outputs = model(input_ids, attention_mask)
                loss = loss_fn(outputs, labels)
                total_val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                correct_val_preds += (preds == labels).sum().item()

        avg_val_loss = total_val_loss / len(val_dataloader)
        val_accuracy = correct_val_preds / len(val_dataset)*100

        print(f"Validation loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}")


        wandb.log(data={
            "epoch": epoch+1,
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss,
            "val_accuracy": val_accuracy
        })

        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), model_save_path)
            print(f"Model saved to {model_save_path} (Best validation accuracy: {best_val_accuracy:.4f})")
    
    print(f"\n--- Finished training for {model_name} ---")

    print(f"\n--- Evaluating {model_name} on the TEST SET ---")
    model.eval()
    total_test_loss = 0
    correct_test_preds = 0

    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Test Evaluation"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(input_ids, attention_mask)
            loss = loss_fn(outputs, labels)
            total_test_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)
            correct_test_preds += (preds == labels).sum().item()

    avg_test_loss = total_test_loss / len(test_dataloader)
    test_accuracy = correct_test_preds / len(test_dataset)

    print(f"Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")


    return {
        "model_name": model_name,
        "final_val_loss": avg_val_loss,
        "final_val_accuracy": val_accuracy,
        "best_val_accuracy": best_val_accuracy,
        "final_test_loss": avg_test_loss,
        "final_test_accuracy": test_accuracy,
        "saved_model_path": model_save_path
    }

In [74]:
def train_and_evaluate_rnn_model(
    df: pd.DataFrame,
    num_labels: int,
    max_len: int,
    batch_size: int,
    epochs: int,
    learning_rate: float,
    embedding_dim: int,
    hidden_size: int,
    num_rnn_layers: int,
    dropout: float,
    device: torch.device
):
    print(f"\n--- Starting training for Simple RNN Model ---")

    tokenizer = Tokenizer.from_file("/Users/nishitha/Desktop/Learn/NLP/Clinical trials eligibility/BPE/bpe_tokenizer.json")
    vocab_size = tokenizer.get_vocab_size()
    print(f"Custom tokenizer loaded. Vocabulary size: {vocab_size}")

    df['label'] = pd.to_numeric(df['label'], errors='coerce')
    df.dropna(subset=['label'], inplace=True)
    df['label'] = df['label'].astype(int)

    train_val_df, test_df = train_test_split(
        df, test_size=0.2, stratify=df['label'], random_state=42
    )
    train_df, val_df = train_test_split(
        train_val_df, test_size=0.25, stratify=train_val_df['label'], random_state=42
    )

    print(f"Dataset split: Train={len(train_df)} | Val={len(val_df)} | Test={len(test_df)}")

    train_dataset = DataCreator_RNN(df=train_df, tokenizer=tokenizer, max_len=max_len)
    val_dataset = DataCreator_RNN(df=val_df, tokenizer=tokenizer, max_len=max_len)
    test_dataset = DataCreator_RNN(df=test_df, tokenizer=tokenizer, max_len=max_len)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = RNNClassifierFromScratch(vocab_size, embedding_dim, hidden_size, num_rnn_layers, num_labels, dropout)
    model.to(device)
    print(f"RNNClassifierFromScratch initialized on {device}.")

    optimizer = AdamW(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()

    best_val_accuracy = 0.0
    model_save_path = "scratch_rnn_clinical_model.pt"

    wandb.init(project='NLP_Project_Clinical_Trials', config=config)

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0

        print(f"\nEpoch {epoch + 1}/{epochs} (Model: Simple RNN)")
        loop = tqdm(train_dataloader, leave=True)

        for batch in loop:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = loss_fn(outputs, labels)

            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            loop.set_description(f"Epoch {epoch + 1}")
            loop.set_postfix(loss=loss.item())

        avg_train_loss = total_train_loss / len(train_dataloader)
        print(f"Average Training Loss: {avg_train_loss:.4f}")

        model.eval()
        total_val_loss = 0
        correct_val_preds = 0

        with torch.no_grad():
            for batch in val_dataloader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["label"].to(device)

                outputs = model(input_ids, attention_mask)
                loss = loss_fn(outputs, labels)
                total_val_loss += loss.item()
                preds = torch.argmax(outputs, dim=1)
                correct_val_preds += (preds == labels).sum().item()

        avg_val_loss = total_val_loss / len(val_dataloader)
        val_accuracy = correct_val_preds / len(val_dataset)*100

        print(f"Validation loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}")

        wandb.log(data={
            "epoch": epoch+1,
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss,
            "val_accuracy": val_accuracy
        })


        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), model_save_path)
            print(f"Model saved to {model_save_path} (Best validation accuracy: {best_val_accuracy:.4f})")
    
    print(f"\n--- Finished training for Simple RNN Model ---")

    print(f"\n--- Evaluating Simple RNN Model on the TEST SET ---")
    model.eval()
    total_test_loss = 0
    correct_test_preds = 0

    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Test Evaluation"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(input_ids, attention_mask)
            loss = loss_fn(outputs, labels)
            total_test_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)
            correct_test_preds += (preds == labels).sum().item()

    avg_test_loss = total_test_loss / len(test_dataloader)
    test_accuracy = correct_test_preds / len(test_dataset)

    print(f"Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")


    return {
        "model_name": "Simple RNN Classifier",
        "final_val_loss": avg_val_loss,
        "final_val_accuracy": val_accuracy,
        "best_val_accuracy": best_val_accuracy,
        "final_test_loss": avg_test_loss,
        "final_test_accuracy": test_accuracy,
        "saved_model_path": model_save_path
    }



In [75]:
if __name__ == "__main__":
    df_full = pd.read_csv('../Dataset/cleaned_data_3.csv')

    all_results = []


    model_name_to_test = "dmis-lab/biobert-v1.1"
    result = train_and_evaluate_transformer_model(
        model_name=model_name_to_test,
        df=df_full.copy(),
        num_labels=NUM_LABELS,
        max_len=MAX_LEN_TRANSFORMER,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS_TRANSFORMER,
        learning_rate=LEARNING_RATE,
        device=DEVICE
    )
    all_results.append(result)

    rnn_result = train_and_evaluate_rnn_model(
        df=df_full.copy(),
        num_labels=NUM_LABELS,
        max_len=MAX_LEN_RNN,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS_RNN,
        learning_rate=LEARNING_RATE,
        embedding_dim=EMBEDDING_DIM,
        hidden_size=RNN_HIDDEN_SIZE,
        num_rnn_layers=RNN_NUM_LAYERS,
        dropout=RNN_DROPOUT,
        device=DEVICE
    )
    all_results.append(rnn_result)


    print("\n--- Final Comparative Study Summary ---")
    for res in all_results:
        print(f"Model: {res['model_name']}")
        print(f"  Best Val Accuracy: {res['best_val_accuracy']:.4f}")
        print(f"  Final Test Accuracy: {res['final_test_accuracy']:.4f}")
        print(f"  Final Test Loss: {res['final_test_loss']:.4f}")
        print(f"  Saved Model: {res['saved_model_path']}")
        print("-" * 30)



--- Starting training for Transformer model: dmis-lab/biobert-v1.1 ---
Tokenizer for dmis-lab/biobert-v1.1 loaded.
Dataset split: Train=619 | Val=207 | Test=207
Model dmis-lab/biobert-v1.1 initialized


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)



Epoch 1/1 (Model: dmis-lab/biobert-v1.1)


  return forward_call(*args, **kwargs)
Epoch 1:   5%|▌         | 2/39 [04:13<1:06:39, 108.09s/it, loss=nan] wandb-core(73149) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(73158) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Epoch 1:   8%|▊         | 3/39 [04:48<44:41, 74.48s/it, loss=nan]   wandb-core(73170) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(73183) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(73198) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Epoch 1:  10%|█         | 4/39 [05:42<38:49, 66.54s/it, loss=nan]wandb-core(73215) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(73237) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(73241) MallocStackLogging: can't t

Average Training Loss: nan


wandb-core(74354) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(74360) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(74367) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(74368) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(74375) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Validation loss: nan, Accuracy: 27.5362
Model saved to dmis_lab_biobert_v1.1_clinical_model.pt (Best validation accuracy: 27.5362)

--- Finished training for dmis-lab/biobert-v1.1 ---

--- Evaluating dmis-lab/biobert-v1.1 on the TEST SET ---


Test Evaluation:   8%|▊         | 1/13 [00:09<01:49,  9.11s/it]wandb-core(74391) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Test Evaluation:  15%|█▌        | 2/13 [00:16<01:28,  8.04s/it]wandb-core(74405) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Test Evaluation:  23%|██▎       | 3/13 [13:08<59:29, 356.96s/it]wandb-core(74414) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Test Evaluation:  46%|████▌     | 6/13 [53:17<1:46:07, 909.62s/it]wandb-core(74435) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Test Evaluation:  54%|█████▍    | 7/13 [53:24<1:01:27, 614.51s/it]wandb-core(74440) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Test Evaluation:  62%|██████▏   | 8/13 [53:47<35:30, 426.09s/it]  wandb-core(74442) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Test Loss: nan, Test Accuracy: 0.2754

--- Starting training for Simple RNN Model ---
Custom tokenizer loaded. Vocabulary size: 15552
Dataset split: Train=619 | Val=207 | Test=207
RNNClassifierFromScratch initialized on mps.


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁
val_accuracy,▁

0,1
epoch,1.0
train_loss,
val_accuracy,27.53623
val_loss,


wandb-core(74484) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(74485) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(74499) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)



Epoch 1/25 (Model: Simple RNN)


Epoch 1:   3%|▎         | 1/39 [00:11<06:58, 11.01s/it, loss=1.1]wandb-core(74515) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Epoch 1:  10%|█         | 4/39 [05:46<1:17:20, 132.59s/it, loss=1.07]wandb-core(74578) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Epoch 1:  21%|██        | 8/39 [06:04<14:56, 28.92s/it, loss=1.05]   wandb-core(74582) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Epoch 1:  31%|███       | 12/39 [06:20<04:16,  9.51s/it, loss=1.06]wandb-core(74590) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Epoch 1:  41%|████      | 16/39 [06:32<01:49,  4.78s/it, loss=1.07]wandb-core(74594) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Epoch 1:  49%|████▊     | 19/39 [06:49<07:11, 21.57s/it, loss=1.07]


KeyboardInterrupt: 

wandb-core(74605) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(74639) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(74650) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(74676) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(74707) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(74711) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(74730) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(74749) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(74796) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
wandb-core(74845) MallocStackLogging: can't turn off malloc stack logging because 