In [None]:
import pandas as pd
import torch
from torch.nn import BCELoss
from sklearn.metrics import classification_report, f1_score, accuracy_score
from transformers import *
from tqdm.notebook import trange, tqdm

### Select the CUDA device if available

In [None]:
# Prefer CUDA when a GPU is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
n_gpu = torch.cuda.device_count()

# List every visible GPU by name (prints nothing on a CPU-only machine).
for gpu_name in map(torch.cuda.get_device_name, range(n_gpu)):
    print(gpu_name)

### Get labels

In [None]:
from pathlib import Path

# Directory holding the CSV splits -- edit for your setup.
csv_path = Path("path_to_CSVs")
delimiter = '\t'  # the CSVs are read as tab-separated (pandas itself would default to a comma)
df_train = pd.read_csv(csv_path / 'train.csv', delimiter=delimiter)

# Every column after the first three is treated as a label column.
cols = df_train.columns
label_cols = [column for column in cols[3:]]
num_labels = len(label_cols)

### Load the saved dataloaders

In [None]:
# Both values are part of the saved dataloader file names, so they must match
# the settings the loaders were created with.
batch_size, max_length = 8, 512

In [None]:
# The files hold whole DataLoader objects, not just tensors, so they must be
# unpickled with weights_only=False. Since PyTorch 2.6, torch.load defaults to
# weights_only=True and rejects such objects. The weights_only argument needs
# PyTorch >= 1.13.
# SECURITY: unpickling can execute arbitrary code -- only load files you created yourself.
def _load_dataloader(split):
    """Load the pickled DataLoader for `split` ('train', 'validation' or 'test')
    matching the current `batch_size` / `max_length` settings."""
    return torch.load(f'dataloaders/{split}_data_loader-{batch_size}-{max_length}', weights_only=False)


train_dataloader = _load_dataloader('train')
validation_dataloader = _load_dataloader('validation')
test_dataloader = _load_dataloader('test')

In [None]:
# Phase name -> DataLoader; the training loop iterates the 'train' and 'dev' entries.
dataloaders = dict(
    train=train_dataloader,
    dev=validation_dataloader,
    test=test_dataloader,
)

In [None]:
# Names used for logging and for the checkpoint file names written after training.
model_name = "FlauBERT"
dataset_name = "my dataset"
# Hugging Face Hub identifier of the pretrained checkpoint to fine-tune.
huggingFace_model_name = "flaubert/flaubert_base_cased"

### Metrics

In [None]:
def print_results(method, f1, acc, precision, recall):
    """Print a short, human-readable summary of one evaluation.

    Args:
        method: label printed as the section header (must be a str).
        f1, acc, precision, recall: metric values, printed as given.
    """
    print('\n' + method, ':')
    for caption, value in (('Micro F1-Score =', f1), ('Accuracy =', acc)):
        print(caption, value)
    print('Micro Avg : precision =', precision, 'recall =', recall)

In [None]:
def get_metrics(true_bools, pred_bools, target_names=None):
    """Compute multi-label evaluation metrics, each expressed as a percentage.

    Args:
        true_bools: ground-truth label-indicator rows (one per sample, one
            column per label) -- assumed 0/1 or bool; TODO confirm.
        pred_bools: predicted indicator rows with the same layout.
        target_names: names of the label columns, forwarded to
            ``classification_report``. Defaults to the notebook-level
            ``label_cols`` when omitted.

    Returns:
        (f1, acc, precision, recall): micro-averaged F1, accuracy as computed by
        ``accuracy_score`` (exact-match accuracy for multi-label input),
        micro precision and micro recall, all multiplied by 100.
    """
    if target_names is None:
        target_names = label_cols

    report = classification_report(true_bools, pred_bools, target_names=target_names,
                                   digits=5, zero_division=1, output_dict=True)
    micro_avg = report['micro avg']

    # Take F1 from the same report as precision/recall so all three share the
    # zero_division=1 policy and the scores are computed only once.
    f1 = micro_avg['f1-score'] * 100
    acc = accuracy_score(true_bools, pred_bools) * 100
    precision = micro_avg['precision'] * 100
    recall = micro_avg['recall'] * 100

    return f1, acc, precision, recall

### Preparing the model

In [None]:
# Pretrained checkpoint with a sequence-classification head sized for our label set.
model = AutoModelForSequenceClassification.from_pretrained(huggingFace_model_name, num_labels=num_labels)
# Use the device selected earlier rather than .cuda(), so CPU-only machines also work.
model.to(device)

In [None]:
class Classifier(torch.nn.Module):
    """Two-stage head returning classification logits and thresholding probabilities.

    ``dense_1`` maps the input features to ``out_features`` raw logits.
    ``dense_2`` is applied on top of those logits and passed through a sigmoid,
    giving one value in (0, 1) per output.

    The attribute names are kept unchanged for two reasons: saved state dicts
    store the parameters under them, and the optimizer split selects parameters
    by the substring 'dense_2'.
    """

    def __init__(self, input_size, out_features):
        super().__init__()
        self.dense_1 = torch.nn.Linear(in_features=input_size, out_features=out_features, bias=True)
        self.dense_2 = torch.nn.Linear(in_features=out_features, out_features=out_features, bias=True)
        self.activation_2 = torch.nn.Sigmoid()

    def forward(self, x):
        """Return ``(logits, probabilities)``; both keep x's leading dims, last dim = out_features."""
        logits = self.dense_1(x)
        probabilities = self.activation_2(self.dense_2(logits))
        return logits, probabilities

In [None]:
# Replace the summary module of the model's sequence-summary head with the
# two-output Classifier. The encoder width is read from the config instead of
# hard-coding 768, which only matches base-sized checkpoints.
model.sequence_summary.summary = Classifier(input_size=model.config.hidden_size, out_features=num_labels)
# The new layers are created on the CPU; move the whole model to the selected device.
model.to(device)

In [None]:
# Split the model parameters into two optimizer groups:
#   * parameters whose name contains 'dense_2' (the Classifier's thresholding layer);
#   * everything else.
# You may attach a scheduler to either optimizer built from these groups.
param_optimizer = list(model.named_parameters())

backbone_params, thresh_params = [], []
for param_name, param in param_optimizer:
    if 'dense_2' in param_name:
        thresh_params.append(param)
    else:
        backbone_params.append(param)

# every layer except the last one
optimizer_grouped_parameters = [{'params': backbone_params}]

# isolate the last layer
optimizer_grouped_parameters_thresh = [{'params': thresh_params}]

In [None]:
# The classification branch emits raw logits, so it uses BCE with a built-in sigmoid.
# The thresholding branch already ends in a Sigmoid, so it uses plain BCE on probabilities.
classification_criterion = torch.nn.BCEWithLogitsLoss()
threshold_criterion = torch.nn.BCELoss()

# Separate AdamW instances for the two parameter groups, currently with the same
# learning rate (presumably so they can be tuned independently).
optimizer = torch.optim.AdamW(params=optimizer_grouped_parameters, lr=1e-5)
optimizer_thresh = torch.optim.AdamW(params=optimizer_grouped_parameters_thresh, lr=1e-5)

### Metrics logging

In [None]:
threshold = 0.5   # cut-off applied to the thresholding outputs to get 0/1 predictions
best_dev_f1 = -1  # best dev micro-F1 so far; -1 ensures the first dev epoch is saved
epochs = 30       # number of training epochs

In [None]:
import numpy as np
import wandb

# wandb run mode, forwarded to wandb.init:
#   None       -> wandb's default (online, unless the WANDB_MODE env var overrides it)
#   "offline"  -> log locally only
#   "disabled" -> no logging at all
# A wandb account (login) is only needed when logging online.
mode = None
if mode not in ("offline", "disabled"):
    wandb.login()

config = {
    "epochs": epochs,
    "batch_size": batch_size,
    "seq_max_length": max_length,
    # Read the learning rates from the optimizers so the logged config cannot drift
    # from what is actually used.
    "lr_cls": optimizer.param_groups[0]["lr"],
    "lr_thresh": optimizer_thresh.param_groups[0]["lr"],
    "threshold": threshold,
    "optimizer": type(optimizer).__name__,
    "dataset": dataset_name,
}

wandb.init(project="myProject", entity="myEntity", name="RunName", config=config, mode=mode)  # project info must be modified

In [None]:
import os

# Checkpoints are written into state_dicts/. Create it up front so the first
# torch.save cannot fail after a full epoch of training has already run.
os.makedirs('state_dicts', exist_ok=True)

# Each epoch has two phases:
#   * 'train': weights are updated;
#   * 'dev': evaluation only.
# Both phases are logged to wandb, and the weights with the best dev micro-F1 are kept.
# trange is a tqdm wrapper around the normal python range.
for epoch_num in trange(epochs, desc="Epoch", position=0):

    for phase in tqdm(['train', 'dev'], leave=False, desc='Phases', position=0):

        # Per-phase tracking variables
        true_labels, pred_labels = [], []
        epoch_loss, cls_loss, thresh_loss = 0, 0, 0  # running sums of batch losses
        epoch_steps = 0

        # Training mode for updates, evaluation mode (dropout off) for the dev pass.
        if phase == 'train':
            model.train()

        if phase == 'dev':
            model.eval()

        for batch in tqdm(dataloaders[phase], leave=False, desc=f"{phase.capitalize()} Dataloader", position=0):

            # Add batch to device
            batch = tuple(t.to(device) for t in batch)

            # Unpack the inputs from our dataloader
            b_input_ids, b_input_mask, b_labels = batch

            # Forward pass; autograd records the graph only in the 'train' phase.
            # The patched summary head yields (classification logits, thresholding probabilities).
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(b_input_ids, attention_mask=b_input_mask)[0]
                classification_logits, thresholding_logits = outputs

            # Drop references that are no longer needed. torch.cuda.empty_cache() is
            # deliberately NOT called per batch: it does not make more memory available
            # to PyTorch and only forces the allocator to re-request memory every step.
            del b_input_ids, b_input_mask, outputs

            # Joint loss:
            #   * BCE-with-logits on the raw classification logits;
            #   * plain BCE on the thresholding branch, whose values are already sigmoid outputs.
            classification_loss = classification_criterion(classification_logits, b_labels.type_as(classification_logits))
            thresholding_loss = threshold_criterion(thresholding_logits, b_labels.type_as(thresholding_logits))
            loss = classification_loss + thresholding_loss

            if phase == 'train':

                # Clear out the gradients of both parameter groups
                optimizer.zero_grad()
                optimizer_thresh.zero_grad()

                # Backward pass through the joint loss
                loss.backward()

                # Update both parameter groups using the computed gradients
                optimizer.step()
                optimizer_thresh.step()

            # Update tracking variables
            epoch_loss += loss.item()
            cls_loss += classification_loss.item()
            thresh_loss += thresholding_loss.item()
            epoch_steps += 1

            # Batch predictions (thresholding probabilities) and true labels, moved to the CPU
            pred_label = thresholding_logits.detach().to('cpu').numpy()
            b_labels = b_labels.to('cpu').numpy()

            true_labels.append(b_labels)
            pred_labels.append(pred_label)

        # Mean loss per batch over the phase
        epoch_loss = epoch_loss / epoch_steps
        cls_loss = cls_loss / epoch_steps
        thresh_loss = thresh_loss / epoch_steps

        # Flatten the per-batch arrays into one row per sample
        pred_labels = [item for sublist in pred_labels for item in sublist]
        true_labels = [item for sublist in true_labels for item in sublist]
        true_bools = true_labels

        # Binarize the thresholding probabilities with the fixed cut-off
        pred_bools = [pl > threshold for pl in pred_labels]
        f1_accuracy, acc, precision, recall = get_metrics(true_bools, pred_bools)

        # Log this phase's metrics; the step is committed once per epoch below
        metrics = {
            'F1_score': f1_accuracy,
            'Accuracy': acc,
            'Precision': precision,
            'Recall': recall,
            'epoch_loss': epoch_loss,
            'cls_loss': cls_loss,
            'thresh_loss': thresh_loss
        }
        wandb.log({f'{phase.capitalize()}': metrics}, commit=False)
        # print_results(phase, f1_accuracy, acc, precision, recall)

        # Save model if dev performance improved
        if phase == 'dev':
            if f1_accuracy > best_dev_f1:
                best_dev_f1 = f1_accuracy
                torch.save(model.state_dict(), 'state_dicts/'+model_name+'_'+dataset_name+'_best_model_weights_tl.pt')

    # Commit one wandb step per epoch containing both phases' metrics
    wandb.log(data={}, commit=True)

# save last model
torch.save(model.state_dict(), 'state_dicts/last_'+ model_name +'.pt')