In [None]:
# Install dependencies into the environment of the running kernel (%pip targets the
# kernel, whereas !pip runs whichever pip is first on the shell PATH).
# TODO: pin simpletransformers to a specific version as well for reproducibility.
%pip install simpletransformers transformers==4.40.2

In [None]:

# Load the required packages

# Dataframes
import pandas as pd, numpy as np

# Regular expressions
import re

# Unidecoder
import unicodedata

# Timestamp / time measurment
import time

# for train/test data preparation
from sklearn.model_selection import train_test_split

# Label encode
from sklearn.preprocessing import LabelEncoder

# Class weights
from sklearn.utils.class_weight import compute_class_weight

# Model performance scores
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

# Simpletransformers classifier
from simpletransformers.classification import ClassificationModel, ClassificationArgs

# PyTorch: enable GPU access
import torch

# For logging and wandb
import logging
import wandb

# Mount Google Drive at /content/drive (Colab only) so that the working directory
# selected further down is reachable.
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# required functions
def f1_class(labels, preds):
    """Return the binary-averaged F1 score of ``preds`` against ``labels``.

    Takes ``(labels, preds)`` positionally so it can be handed to
    ``train_model`` / ``eval_model`` as an extra metric callback.
    """
    score = f1_score(y_true=labels, y_pred=preds, average='binary')
    return score
def precision(labels, preds):
    """Return the binary-averaged precision of ``preds`` against ``labels``.

    Takes ``(labels, preds)`` positionally so it can be handed to
    ``eval_model`` as an extra metric callback.
    """
    score = precision_score(y_true=labels, y_pred=preds, average='binary')
    return score
def recall(labels, preds):
    """Return the binary-averaged recall of ``preds`` against ``labels``.

    Takes ``(labels, preds)`` positionally so it can be handed to
    ``eval_model`` as an extra metric callback.
    """
    score = recall_score(y_true=labels, y_pred=preds, average='binary')
    return score


In [None]:
# Switch to the project folder on the mounted Drive; the relative CSV paths used later
# resolve against it. The explicit %cd magic does not rely on automagic being enabled
# or on no Python variable being called `cd`.
%cd /content/drive/MyDrive/your_working_directory

## load training data

In [None]:
# Read the labelled training data and cast the target column to integer in one step.
dat = pd.read_csv('training_data.csv').astype({'final_climate': int})

In [None]:
# Index the frame by the "qs_new" id column. The column is kept alongside the index
# (drop=False) and verify_integrity=True raises if an id occurs more than once.
dat = dat.set_index("qs_new", drop=False, verify_integrity=True)

In [None]:
# Map each distinct final_climate value to an integer category code and store it as
# the "label" column, then display how many rows fall into each class.
dat["label"] = pd.Categorical(dat["final_climate"]).codes
dat.label.value_counts()

In [None]:
# Build one stratification id per (language, label) combination so that the splits
# can be stratified on both at once.
# Approach from https://stackoverflow.com/a/62918682
strata_index = dat.set_index(['language', 'label']).index
strata_codes, _unique_strata = strata_index.factorize()
dat["strata_"] = strata_codes


## make data splits

In [None]:

# Fixed seed so that Restart & Run All reproduces exactly the same partition;
# without it every run draws different train/val/test ids.
SPLIT_SEED = 4

# Step 1: hold out 25% of all ids as the test set, stratified on the
# language x label strata.
train_ids, test_ids = train_test_split(
    dat.index.values,
    test_size=.25,
    stratify=dat.strata_.values,
    random_state=SPLIT_SEED,
)

# Step 2: take 30% of the remaining ids as the validation set, again stratified.
train_ids, val_ids = train_test_split(
    train_ids,
    test_size=.3,
    stratify=dat.loc[train_ids].strata_.values,
    random_state=SPLIT_SEED,
)

# Sanity check: the three id sets together account for every row.
assert len(train_ids) + len(test_ids) + len(val_ids) == len(dat)

print(len(train_ids), "training samples")
print(len(test_ids), "test samples")
print(len(val_ids), "val samples")


In [None]:
# Create the train, test and val frames from their id arrays.
def _split_frame(ids):
    """Return a frame with columns qs_new / text / labels for the rows of ``dat`` in ``ids``."""
    rows = dat.loc[ids]
    records = zip(ids, rows['original_text'].values, rows['label'].values)
    return pd.DataFrame(records, columns=['qs_new', 'text', 'labels'])


train_df = _split_frame(train_ids)
test_df = _split_frame(test_ids)
val_df = _split_frame(val_ids)


In [None]:
# Index every split by its "qs_new" id; the column is kept (drop=False) and
# duplicate ids raise (verify_integrity=True).
for split_df in (train_df, test_df, val_df):
    split_df.set_index("qs_new", drop=False, inplace=True, verify_integrity=True)

In [None]:
# Create the label encoder
label_encoder = LabelEncoder()

# Fit the encoder on the training labels ONLY, then reuse that single mapping for the
# other splits. Re-fitting on each split (fit_transform three times) can assign
# different integers to the same class whenever a split does not contain every class.
# transform() raises ValueError if test/val contain a label unseen in training,
# which surfaces the problem instead of silently mis-encoding.
train_df['labels'] = label_encoder.fit_transform(train_df.labels)
test_df['labels'] = label_encoder.transform(test_df.labels)
val_df['labels'] = label_encoder.transform(val_df.labels)

In [None]:
# Read the train / test / val splits from CSV files in the working directory.
# NOTE(review): this rebinds train_df, test_df and val_df, discarding any frames of
# the same name built earlier in the session, and no cell visible here writes these
# files -- confirm where they are produced.
train_df, test_df, val_df = (
    pd.read_csv(csv_path)
    for csv_path in ('train_ft.csv', 'test_ft.csv', 'val_ft.csv')
)

In [None]:
# Sweep configuration: the hyperparameter ranges to explore.
# - "values" lists discrete choices; "min"/"max" define a continuous range.
# - "method" is the search strategy ("grid" and "random" are the alternatives).
# - "metric" is the logged quantity the sweep optimises (eval_loss was an open alternative).
sweep_config = {
    "method": "bayes",
    "metric": {
        "name": "f1_eval",
        "goal": "maximize",
    },
    "parameters": {
        # discrete choices
        "num_train_epochs": {"values": [2]},
        "train_batch_size": {"values": [8, 16]},
        # continuous ranges
        "learning_rate": {"min": 1e-5, "max": 9e-4},
        "weight_decay": {"min": 0.0, "max": 0.15},
        # 0 / 1 switch
        "use_class_weights": {"values": [0, 1]},
        # continuous ranges
        "stride": {"min": 0.0, "max": 1.0},
        "hidden_dropout_prob": {"min": 0.1, "max": 0.3},
        "attention_probs_dropout_prob": {"min": 0.1, "max": 0.3},
    },
}


In [None]:
# Create the sweep from sweep_config under the "roberta_sweep" wandb project and keep
# the returned id, which the agent call needs. Requires a wandb API key.
sweep_id = wandb.sweep(sweep_config, project="roberta_sweep")


In [None]:
# Keep the "transformers" logger at WARNING so its INFO chatter does not flood the output.
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Everything else logs at INFO and above.
logging.basicConfig(level=logging.INFO)

In [None]:
# Model family and pretrained checkpoint used for every sweep run.
model_type = "xlmroberta"
model_name = "xlm-roberta-base"

# Fixed training arguments shared by all sweep runs; values that also appear in the
# sweep configuration (learning rate, batch size, epochs, ...) act as defaults here.
model_args = ClassificationArgs()
model_args.reprocess_input_data = True
model_args.overwrite_output_dir = True
model_args.evaluate_during_training = True
# Set on model_args: the original assigned a bare global variable of this name,
# which left the option at its default.
model_args.evaluate_during_training_verbose = True
model_args.manual_seed = 4
model_args.num_train_epochs = 2
model_args.use_multiprocessing = True
model_args.learning_rate = 1e-05
model_args.train_batch_size = 8
#model_args.eval_batch_size = 8
model_args.max_seq_length = 256
model_args.labels_list = [0, 1]  # UPDATE to match the label set of your data
model_args.sliding_window = True
#model_args.stride = 0.6
# Do not write checkpoints or optimizer/scheduler state to disk during the sweep.
model_args.no_save = True
#model_args.save_model_every_epoch=False
model_args.save_optimizer_and_scheduler = False
model_args.wandb_project = "roberta_sweep"

In [None]:
def train():
    """Run a single sweep trial: train, evaluate and log the results to wandb.

    Called without arguments (it is passed to ``wandb.agent``). Reads the module-level
    ``model_type``, ``model_name``, ``model_args``, ``train_df``, ``test_df`` and
    ``val_df``; the trial's hyperparameters come from ``wandb.config``.
    """
    # Initialize a new wandb run; wandb.config is read below for this trial's values.
    wandb.init()

    # Class weights: the original passed an undefined name ``weights`` (NameError).
    # Compute balanced weights from the training labels when the swept
    # ``use_class_weights`` switch is on; otherwise train unweighted.
    if wandb.config.get("use_class_weights", 0):
        train_labels = train_df["labels"].values
        class_weights = compute_class_weight(
            class_weight="balanced",
            classes=np.unique(train_labels),
            y=train_labels,
        ).tolist()
    else:
        class_weights = None

    # Create the classification model with the fixed args plus the swept values.
    model = ClassificationModel(
        model_type,
        model_name,
        num_labels=2,
        weight=class_weights,
        use_cuda=True,
        args=model_args,
        sweep_config=wandb.config,
    )

    # Train the model, specify metrics.
    # NOTE(review): test_df is used for evaluation during training and val_df for the
    # final sweep metric below -- confirm this assignment of splits is intended.
    model.train_model(
        train_df,
        eval_df=test_df,
        accuracy=accuracy_score,
        f1_train=f1_class,
    )

    # Evaluate the model on val_df with the extra metrics.
    eval_res, _, _ = model.eval_model(
        val_df,
        accuracy_eval=accuracy_score,
        recall_eval=recall,
        precision_eval=precision,
        f1_eval=f1_class,
    )

    # Log the sweep's target metric ("f1_eval") and the confusion-matrix counts.
    wandb.log({
        'f1_eval': eval_res['f1_eval'],
        'fp_eval': eval_res['fp'],
        'fn_eval': eval_res['fn'],
        'tn_eval': eval_res['tn'],
        'tp_eval': eval_res['tp'],
    })

    # Close the run and sync it; wandb.finish() replaces the deprecated wandb.join().
    wandb.finish()



In [None]:
# Start a sweep agent that runs train() for the sweep identified by sweep_id.
# NOTE(review): no `count` is passed -- confirm how/when this agent is expected to stop.
wandb.agent(sweep_id, train)

The cell above sends the results of every run to the corresponding wandb sweep page, where you can view and export them to identify the best-performing hyperparameter configuration for the metric you care about.