In [None]:
import os

import numpy as np
import pandas as pd
import torch
from torch.nn import BCELoss
from sklearn.metrics import classification_report, f1_score, accuracy_score
from transformers import *
from tqdm.notebook import trange, tqdm

#### Use the CUDA GPU if available

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Print the name of every visible CUDA device (prints nothing on CPU-only machines).
n_gpu = torch.cuda.device_count()
for gpu_name in map(torch.cuda.get_device_name, range(n_gpu)):
    print(gpu_name)

### Get labels

In [None]:
from pathlib import Path

# Folder holding the CSV splits; only train.csv is read here, to discover the label columns.
csv_path = Path("path_to_CSVs")
delimiter = '\t'  # the CSVs are tab-separated (pandas would otherwise assume commas)
df_train = pd.read_csv(csv_path.joinpath('train.csv'), delimiter=delimiter)

# Every column from the fourth onward is treated as a label column.
# NOTE(review): assumes the first three columns are non-label fields — confirm against the CSV.
cols = df_train.columns
label_cols = [*cols[3:]]
num_labels = len(label_cols)

### Get dataloaders

In [None]:
# DataLoader settings; both values are also embedded in the cached dataloader file names.
batch_size, max_length = 8, 512

In [None]:
def _load_dataloader(split):
    """Load the cached DataLoader for ``split`` built with the current batch_size/max_length.

    The files hold whole pickled DataLoader objects rather than plain tensors, so they must be
    loaded with ``weights_only=False``. PyTorch >= 2.6 defaults to ``weights_only=True`` and
    would refuse to unpickle them.
    SECURITY: unpickling can execute arbitrary code — only load files you created yourself.
    """
    return torch.load(f'dataloaders/{split}_data_loader-{batch_size}-{max_length}', weights_only=False)


train_dataloader = _load_dataloader('train')
validation_dataloader = _load_dataloader('validation')
test_dataloader = _load_dataloader('test')

In [None]:
# Map each phase name (as iterated by the training loop) to its DataLoader.
dataloaders = dict(
    train=train_dataloader,
    dev=validation_dataloader,
    test=test_dataloader,
)

#### Helper functions

In [None]:
def best_threshold(true_labels, pred_labels_sigmoid):
    """Find the single threshold, shared by all labels, that maximises micro F1.

    Every candidate threshold in {0.00, 0.01, ..., 1.00} is applied to all label columns at
    once. The one with the highest micro-averaged F1 is kept; ties go to the lowest threshold,
    because ``np.argmax`` returns the first maximum.

    Args:
        true_labels: sequence of per-sample 0/1 label vectors.
        pred_labels_sigmoid: sequence of per-sample probability vectors, aligned with
            ``true_labels``.

    Returns:
        Tuple ``(threshold, micro_f1, accuracy, precision, recall)``. The four scores are
        percentages obtained with ``threshold`` applied; accuracy, precision and recall come
        from ``get_metrics``.
    """
    # Convert once to arrays so each threshold test is a single vectorised comparison.
    true_bools = np.asarray(true_labels) == 1
    probs = np.asarray(pred_labels_sigmoid)

    candidate_thresholds = np.arange(101) / 100

    f1_results = [
        f1_score(true_bools, probs > th, average='micro') * 100
        for th in tqdm(candidate_thresholds, desc="Best Threshold Calculation", leave=False)
    ]

    best_idx = int(np.argmax(f1_results))
    threshold = candidate_thresholds[best_idx]
    best_f1 = f1_results[best_idx]

    _, best_acc, precision, recall = get_metrics(true_bools, probs > threshold)

    return threshold, best_f1, best_acc, precision, recall

In [None]:
def best_thresholds(true_labels, pred_labels_sigmoid):
    """Find one decision threshold per label column, each maximising that label's own F1.

    For every label column independently, thresholds {0.00, 0.01, ..., 1.00} are scanned and
    the one with the highest binary F1 on that column is kept (first maximum wins). The tuned
    thresholds are then applied column-wise and the result is scored with ``get_metrics``.

    ``average='binary'`` scores the positive class of the column. ``average='micro'`` on a
    single binary column would micro-average over both classes, which equals accuracy, not F1.

    Args:
        true_labels: sequence of per-sample 0/1 label vectors.
        pred_labels_sigmoid: sequence of per-sample probability vectors, aligned with
            ``true_labels``.

    Returns:
        Tuple ``(thresholds, micro_f1, accuracy, precision, recall)``. ``thresholds`` is a list
        with one threshold per label column; the scores are percentages from ``get_metrics``.
    """
    true_bools = np.asarray(true_labels) == 1
    probs = np.asarray(pred_labels_sigmoid)
    candidate_thresholds = np.arange(101) / 100

    thresholds = []
    for i in trange(probs.shape[1], desc="Best Thresholds Calculation", leave=False):
        # These columns do not depend on the threshold, so extract them once per label.
        true_col = true_bools[:, i]
        prob_col = probs[:, i]
        # zero_division=1: a label with no positives that is never predicted scores as perfect.
        f1_results = [
            f1_score(true_col, prob_col > th, average='binary', zero_division=1)
            for th in candidate_thresholds
        ]
        thresholds.append(candidate_thresholds[int(np.argmax(f1_results))])

    # Broadcast the per-column thresholds across all samples.
    best_pred_bools = probs > np.asarray(thresholds)

    micro_f1_score, accuracy, precision, recall = get_metrics(true_bools, best_pred_bools)
    return thresholds, micro_f1_score, accuracy, precision, recall

### Metrics

In [None]:
def print_results(method, f1, acc, precision, recall):
    """Print one block of micro-averaged scores under a ``method`` heading.

    Args:
        method: heading for the block (a str; it is concatenated into the heading line).
        f1, acc, precision, recall: scores to display, rendered with ``str()``.
    """
    report_lines = [
        '\n' + method + ' :',
        f'Micro F1-Score = {f1!s}',
        f'Accuracy = {acc!s}',
        f'Micro Avg : precision = {precision!s} recall = {recall!s}',
    ]
    print('\n'.join(report_lines))

In [None]:
def get_metrics(true_bools, pred_bools):
    """Compute micro-averaged multilabel metrics, all expressed as percentages.

    Args:
        true_bools: ground-truth label-indicator rows (bool or 0/1), one row per sample.
        pred_bools: predicted label-indicator rows, same shape as ``true_bools``.

    Returns:
        Tuple ``(micro_f1, accuracy, micro_precision, micro_recall)``, each in percent.
        Accuracy comes from ``accuracy_score``, i.e. exact-match accuracy for multilabel input.

    Raises:
        ValueError: from ``classification_report`` when ``len(label_cols)`` does not match the
            number of label columns.
    """
    # digits is not passed: it is ignored when output_dict=True.
    report = classification_report(true_bools, pred_bools, target_names=label_cols,
                                   zero_division=0, output_dict=True)
    micro_avg = report['micro avg']

    # F1 comes from the same micro-average as precision/recall, so all three share
    # zero_division=0. A separate f1_score call would recompute it with the default 'warn'
    # policy and only add UndefinedMetricWarning noise.
    f1 = micro_avg['f1-score'] * 100
    acc = accuracy_score(true_bools, pred_bools) * 100
    precision = micro_avg['precision'] * 100
    recall = micro_avg['recall'] * 100

    return f1, acc, precision, recall

### Preparing the model

In [None]:
# Hugging Face hub checkpoint to fine-tune.
huggingFace_model_name = "flaubert/flaubert_base_cased"

# Display names; both are embedded in the saved checkpoint file names.
model_name, dataset_name = "FlauBERT", "my dataset"  # dataset_name must be changed

In [None]:
# Pretrained encoder with a freshly initialised classification head of num_labels outputs.
model = AutoModelForSequenceClassification.from_pretrained(huggingFace_model_name, num_labels=num_labels)
# Move the model to the device selected earlier; the training batches go to the same `device`.
# A hard-coded .cuda() would crash on CPU-only machines, where `device` falls back to "cpu".
model = model.to(device)

In [None]:
# BCELoss expects probabilities, which is why the training loop applies a sigmoid to the logits.
classification_criterion = BCELoss()
# Keep this learning rate in sync with "lr_cls" in the wandb config.
optimizer_classification = torch.optim.AdamW(params=model.parameters(), lr=2e-5)

### Metrics logging

In [None]:
epochs = 30  # number of train + dev passes

# Best dev micro-F1 seen so far for each thresholding scheme. Scores are percentages in
# [0, 100], so -1.0 guarantees the first dev evaluation is checkpointed.
best_dev_f1_sgo = best_dev_f1_si = -1.0

In [None]:
import numpy as np
import wandb

wandb.login()

# Run hyper-parameters recorded alongside the wandb run.
config = {
    "epochs": epochs,
    "batch_size": batch_size,
    "seq_max_length": max_length,
    "lr_cls": 2e-5,  # must match the lr passed to AdamW
    "optimizer": "AdamW",
    "dataset": dataset_name,
}

# Project/entity/name must be adapted; add mode="disabled" to wandb.init to run without logging.
wandb.init(project="myProject", entity="myEntity", name="RunName", config=config)

### Training the model

In [None]:
# Training loop. Each epoch runs a 'train' pass (weights updated) and a 'dev' pass (no
# gradients). It logs the mean loss and threshold-tuned metrics to wandb, and checkpoints the
# weights whenever a dev F1 improves. SGO metrics use the single global threshold from
# best_threshold; SI metrics use the per-label thresholds from best_thresholds.
# trange/tqdm are progress-bar wrappers around range/iterables.
os.makedirs('state_dicts', exist_ok=True)  # torch.save does not create missing directories

for epoch_num in trange(epochs, desc="Epoch", position=0):

    for phase in tqdm(['train', 'dev'], leave=False, desc='Phases', position=0):
        is_train = phase == 'train'
        phase_name = phase.capitalize()

        # Per-phase tracking variables
        true_labels, pred_labels = [], []
        epoch_loss = 0.0  # running sum of batch losses
        epoch_steps = 0

        # train(True) enables training behaviour such as dropout; train(False) is eval().
        model.train(is_train)

        for batch in tqdm(dataloaders[phase], leave=False, desc=f"{phase_name} Dataloader", position=0):
            # Unpack the inputs and move them to the selected device
            b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)

            # Forward pass and loss, with autograd active only while training
            with torch.set_grad_enabled(is_train):
                logits = model(b_input_ids, attention_mask=b_input_mask)[0]
                # Independent per-label probabilities; BCELoss consumes probabilities, not logits.
                classification_logits = torch.sigmoid(logits)
                classification_loss = classification_criterion(
                    classification_logits, b_labels.type_as(classification_logits))

            if is_train:
                optimizer_classification.zero_grad()
                classification_loss.backward()
                optimizer_classification.step()

            epoch_loss += classification_loss.item()
            epoch_steps += 1

            # Keep per-batch predictions/labels on the CPU for the epoch-level metrics
            pred_labels.append(classification_logits.detach().to('cpu').numpy())
            true_labels.append(b_labels.to('cpu').numpy())

        if epoch_steps == 0:
            raise ValueError(f"The '{phase}' dataloader yielded no batches.")

        # Mean loss over the phase's batches
        epoch_loss = epoch_loss / epoch_steps
        wandb.log({phase_name: {'epoch_loss': epoch_loss}}, commit=False)

        # Flatten the per-batch arrays into one entry per sample
        pred_labels = [row for batch_preds in pred_labels for row in batch_preds]
        true_labels = [row for batch_labels in true_labels for row in batch_labels]

        # Tune the thresholds and score with them. Both helpers already return the metrics
        # obtained with their tuned thresholds, so nothing is recomputed here.
        # NOTE(review): in the 'dev' phase the thresholds are tuned on the same dev labels they
        # are scored against, so dev metrics (and checkpoint selection) are optimistic. The
        # `train_*` names suggest reusing the train-phase thresholds was intended — confirm.
        train_threshold, f1_sgo, acc_sgo, precision_sgo, recall_sgo = best_threshold(true_labels, pred_labels)
        train_thresholds, f1_si, acc_si, precision_si, recall_si = best_thresholds(true_labels, pred_labels)

        # Log epoch metrics (committed together at the end of the epoch)
        metrics_sgo = {
            'F1_score': f1_sgo,
            'Accuracy': acc_sgo,
            'Precision': precision_sgo,
            'Recall': recall_sgo,
        }
        wandb.log({f'{phase_name} SGO': metrics_sgo}, commit=False)

        metrics_si = {
            'F1_score': f1_si,
            'Accuracy': acc_si,
            'Precision': precision_si,
            'Recall': recall_si,
        }
        wandb.log({f'{phase_name} SI': metrics_si}, commit=False)

        # Checkpoint whenever a dev F1 improves on the best seen so far
        if phase == 'dev':
            checkpoint_prefix = f'state_dicts/{model_name}_{dataset_name}_best_model_weights'
            if f1_sgo > best_dev_f1_sgo:
                best_dev_f1_sgo = f1_sgo
                torch.save(model.state_dict(), f'{checkpoint_prefix}_sgo.pt')

            if f1_si > best_dev_f1_si:
                best_dev_f1_si = f1_si
                torch.save(model.state_dict(), f'{checkpoint_prefix}_si.pt')

    # Commit everything logged with commit=False above as this epoch's wandb step
    wandb.log(data={}, commit=True)

# Save the weights from the final epoch
torch.save(model.state_dict(), f'state_dicts/last_{model_name}.pt')