In [None]:
import datetime
import os

def _prepend_search_path(variable, entries):
    """Put ``entries`` in front of the search path stored in ``os.environ[variable]``.

    Entries are joined with ``os.pathsep``. When the variable is unset or
    empty it ends up holding only ``entries`` (no dangling separator).
    """
    current = os.environ.get(variable, "")
    parts = [*entries, current] if current else list(entries)
    os.environ[variable] = os.pathsep.join(parts)


# Put the CUDA 12.4 directories at the front of the executable search path.
# A separator is required between the last CUDA entry and the existing value;
# without it the two are fused into one bogus entry and the first pre-existing
# PATH directory is lost.
_prepend_search_path(
    "PATH",
    ["/opt/apps/rhel9/cuda-12.4/bin", "/opt/apps/rhel9/cuda-12.4/"],
)
# NOTE(review): shared libraries of a CUDA install usually live under `lib64`
# rather than `bin` -- confirm these are the intended directories. Also,
# presumably this only affects libraries/processes loaded after this cell runs.
_prepend_search_path(
    "LD_LIBRARY_PATH",
    ["/opt/apps/rhel9/cuda-12.4/bin", "/opt/apps/rhel9/cuda-12.4"],
)

from huggingface_hub import HfApi

import numpy as np
import numpy.random as npr
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (  # AdamW,
    BertForSequenceClassification,
    BertModel,
    BertTokenizer,
    get_linear_schedule_with_warmup,
)

# Turn on pandas' copy-on-write mode for the whole notebook.
pd.options.mode.copy_on_write = True

# Hugging Face Hub repository that the model and tokenizer are pushed to.
repo_id = "nickeubank/leaa_grant_subjects"

In [None]:
# Fail fast with a readable message: the model and every batch are later
# moved to a CUDA device.
assert torch.cuda.is_available(), "CUDA is not available; this notebook expects a GPU"

In [3]:
# from google.colab import drive
# drive.mount('/content/gdrive')

In [4]:
# Base location of the project data; later cells also write artifacts under it.
dir = "/hpc/group/ssri/nce8/leaa_subj/"
# dir = "https://github.com/nickeubank/leaa_subj/raw/refs/heads/main/"

#########
# Load the grants and keep the labeled rows
# (these are split into train / test further down)
#########
# One row per distinct description.
grants = pd.read_parquet(dir + "subj_text_and_labels.parquet").drop_duplicates(
    subset="description"
)

# Rows that carry a `label_1` value.
has_label = grants["label_1"].notna()
labeled = grants.loc[has_label]

# Encode the labels as integers. Per the original note this matters little
# for the 1-digit codes, but the 2-digit codes are not sequential, so an
# encoder is used here as well.
label_encoder = LabelEncoder()
labeled = labeled.assign(
    label_1_encoded=label_encoder.fit_transform(labeled["label_1"])
)

In [None]:
# Order the rows by description before splitting (presumably so the split,
# which uses a fixed random_state, sees the rows in a stable order).
labeled = labeled.sort_values(by="description")
# labeled = labeled.sample(frac=0.1)

# 80/20 split stratified on the raw label. train_test_split returns a
# (train, test) pair for each input array, in the order the inputs are given.
split_arrays = train_test_split(
    labeled["label_1_encoded"].values,
    labeled["description"].values,
    test_size=0.2,
    random_state=44,
    stratify=labeled["label_1"],
)
train_label, test_label, train_text, test_text = split_arrays

# Sizes of the training and test partitions, one per line.
print(len(train_label), len(test_label), sep="\n")

In [6]:
########
# Preprocess
########


class ClassificationDataset(Dataset):
    """Dataset that tokenizes a single text on every lookup.

    Args:
        texts: Indexable collection of input strings.
        labels: Indexable collection of integer class ids, aligned with ``texts``.
        tokenizer: Callable invoked as ``tokenizer(text, **kwargs)``; its result
            must expose ``"input_ids"`` and ``"attention_mask"`` tensors.
        max_len: Length passed to the tokenizer as ``max_length`` (used together
            with ``padding="max_length"`` and ``truncation=True``).
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_len = tokenizer, max_len

    def __len__(self):
        """Number of examples, taken from ``texts``."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``label`` for example ``idx``."""
        text, label = self.texts[idx], self.labels[idx]
        tokenizer_kwargs = dict(
            add_special_tokens=True,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        encoded = self.tokenizer(text, **tokenizer_kwargs)

        # squeeze(0) removes the leading axis of the tokenizer output
        # (presumably a batch axis of size 1) so default batching can stack items.
        sample = {
            key: encoded[key].squeeze(0) for key in ("input_ids", "attention_mask")
        }
        sample["label"] = torch.tensor(label, dtype=torch.long)
        return sample

In [None]:
# Hyperparameters
# Grid-search sketch, kept for reference:
# hypers = {"lr": [], "mlen": [], "batch_size": [], "accuracy": []}
# for mlen in [128, 256, 512]:
#     for batch_size in [8, 16, 32]:
#         for lr in [1e-6, 1e-5, 1e-4, 1e-3, 1e-2]:

MAX_LEN = 256
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 0.000010
SEED = 44  # same value as the random_state of the train/test split

# Seed torch's global RNG before the loaders and the model are created so
# that training-set shuffling and the newly initialised classification head
# are repeatable between runs (GPU kernels may still be non-deterministic).
torch.manual_seed(SEED)

#################
# Start fresh
#################
# Keep the checkpoint name separate from `model`, which holds the network
# itself for the rest of the notebook.
MODEL_NAME = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
train_dataset = ClassificationDataset(train_text, train_label, tokenizer, MAX_LEN)
test_dataset = ClassificationDataset(test_text, test_label, tokenizer, MAX_LEN)

###########################
# Load Model and Tokenizer
###########################
# model = BertForSequenceClassification.from_pretrained(dir + "bert_grant_classifier_1digit")
# tokenizer = BertTokenizer.from_pretrained(dir + "bert_grant_classifier_1digit")
# label_encoder = torch.load(dir + "label_encoder_1digit.pth", weights_only=False)

# Only the training data is shuffled; evaluation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model and Device Setup
assert torch.cuda.is_available(), "CUDA is not available; training expects a GPU"
# The CPU fallback is only reachable when assertions are disabled.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One output unit per distinct (non-null) `label_1` value.
model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=grants["label_1"].nunique()
)
model.to(device)

# Optimizer
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Fine-tune for EPOCHS passes over the training data.
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    # Push a checkpoint to the Hub at the start of every odd-numbered epoch,
    # i.e. once epochs 0, 2, 4, ... have finished. The epoch number in the
    # commit message is the epoch that is about to start.
    if epoch % 2 == 1:
        timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")
        checkpoint_message = f"checkpoint_{timestamp}_epoch{epoch}"

        model.push_to_hub(repo_id, commit_message=checkpoint_message)
        tokenizer.push_to_hub(repo_id, commit_message=checkpoint_message)

    # The description is constant within an epoch, so set it once here
    # instead of on every batch.
    loop = tqdm(train_loader, leave=True, desc=f"Epoch {epoch}")

    for batch in loop:
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        # Passing `labels` makes the model return the training loss.
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss

        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the running total and the bar.
        batch_loss = loss.item()
        total_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    # Mean per-batch loss over the epoch.
    print(f"Epoch {epoch} Loss: {total_loss / len(train_loader)}")

In [None]:
# Publish the trained model and tokenizer to the Hub under one shared commit
# message. `epoch` is the last index left behind by the training loop, so
# this cell must run after that loop.
timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")
final_message = f"trained_{timestamp}_epoch{epoch}"

model.push_to_hub(repo_id, commit_message=final_message)
tokenizer.push_to_hub(repo_id, commit_message=final_message)

In [None]:
# Evaluation
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in val_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predictions = torch.argmax(outputs.logits, dim=1)

        correct += (predictions == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total
print(f"Validation Accuracy: {accuracy:.4f}")

In [None]:
# Save the model and tokenizer into one directory, and the fitted label
# encoder next to it.
# dir = "/content/gdrive/MyDrive/leaa/"
model_dir = dir + "bert_grant_classifier_1digit"
for artifact in (model, tokenizer):
    artifact.save_pretrained(model_dir)

torch.save(label_encoder, dir + "label_encoder_1digit.pth")

In [17]:
# Look at fine tuning
# `hypers` is only filled in by the (currently commented-out) grid search.
# When it is absent, fall back to a single record describing the run above
# so this cell still works after Restart & Run All.
if "hypers" not in globals():
    hypers = {
        "lr": [LEARNING_RATE],
        "mlen": [MAX_LEN],
        "batch_size": [BATCH_SIZE],
        "accuracy": [accuracy],
    }

# Runs ordered from best to worst accuracy, then mean accuracy per learning rate.
df = pd.DataFrame(hypers).sort_values("accuracy", ascending=False)
df.groupby("lr")["accuracy"].mean()