## Setup

In [1]:
import numpy as np
import pandas as pd
import os
from transformers import pipeline, set_seed
import requests
import torch
import torch.nn as nn
import torch.optim as optim
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import time
from sklearn.model_selection import train_test_split

In [2]:
# random seed
# SEED = 1
# torch.backends.cudnn.deterministic = True
# torch.manual_seed(SEED)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
torch.cuda.device_count()

2

In [4]:
DEVICE

device(type='cuda')

In [5]:
home_dir = '/g100/home/userexternal/mhabibi0/'
work_dir = '/g100_work/IscrC_mental'

hdata_dir = os.path.join(home_dir, 'Data')
wdata_dir = os.path.join(work_dir, 'data')
uc_dir = os.path.join(wdata_dir, 'user_classification')
model_dir = os.path.join(home_dir, 'Models', 'Gender')

In [6]:
# user age data
path  = os.path.join(uc_dir, 'data_for_models_train.pkl')
df = pd.read_pickle(path)
df['male'] = df['is_male'].astype(int)
df['text']  = 'bio: ' + df['masked_bio'] + '. ' + 'tweets: ' + df['long_text'] 
df['text'] = df['text'].str.replace('\r|\n', ' ', regex=True)
df_train, df_valid = train_test_split(df[['user_id', 'text', 'male']], test_size=0.1, random_state=42)

In [7]:
X_train = df_train['text'].values
y_train = df_train['male'].values

X_valid = df_valid['text'].values
y_valid = df_valid['male'].values

In [8]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def compute_metrics(model, data_loader, device):
    with torch.no_grad():
        all_predictions = []
        all_labels = []

        for batch_idx, batch in enumerate(data_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            # Get the predicted class labels
            predicted_labels = torch.argmax(logits, dim=1)

            all_predictions.extend(predicted_labels.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

        
        accuracy = accuracy_score(all_labels, all_predictions)
        precision = precision_score(all_labels, all_predictions)
        recall = recall_score(all_labels, all_predictions)
        f1 = f1_score(all_labels, all_predictions)

        metrics = {'accuracy': accuracy,  'f1': f1 }

        return metrics


In [9]:
class TweetDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [10]:

def batch_tokenize(X_text, tokenizer, max_length=512, batch_size=64):

    # Dictionary to hold tokenized batches
    encodings = {}

    # Calculate the number of batches needed
    num_batches = len(X_text) // batch_size + int(len(X_text) % batch_size > 0)

    # Iterate over the data in batches
    for i in range(num_batches):
        batch_start = i * batch_size
        batch_end = min(len(X_text), (i + 1) * batch_size)

        # Tokenize the current batch of texts
        batch_encodings = tokenizer.batch_encode_plus(
            list(X_text[batch_start:batch_end]),
            padding='max_length',
            truncation=True,
            max_length=max_length
        )

        # Merge the batch tokenizations into the main dictionary
        for key, val in batch_encodings.items():
            if key not in encodings:
                encodings[key] = []
            encodings[key].extend(val)

    return encodings


### Twitter XLM Roberta

In [11]:
# setup tokenizer
tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-xlm-roberta-base")

train_encodings = batch_tokenize(X_train, tokenizer)
valid_encodings = batch_tokenize(X_valid, tokenizer)

train_dataset = TweetDataset(train_encodings, y_train)
valid_dataset = TweetDataset(valid_encodings, y_valid)

In [13]:
# data loaders
BATCH_SIZE = 32

train_loader = torch.utils.data.DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True,
    pin_memory=True
)

valid_loader = torch.utils.data.DataLoader(
    valid_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False,
    pin_memory=True
)

#### XLm Roberta + freezing + early stopping + linear schedule

In [15]:
from transformers import get_linear_schedule_with_warmup

MODEL = "cardiffnlp/twitter-xlm-roberta-base"
num_labels = 2
model = AutoModelForSequenceClassification.from_pretrained(MODEL, num_labels=num_labels)
model = torch.nn.DataParallel(model)
model.to(DEVICE)

NUM_EPOCHS = 100
LR = 2e-5

optim = torch.optim.Adam(model.parameters(), lr=LR)

# Create the learning rate scheduler.
total_steps = len(train_loader) * NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps = 0,
                                            num_training_steps = total_steps)

# Freeze all layers except the classifier for the first few epochs
freeze_steps = 2
for name, param in model.named_parameters():
    if 'classifier' not in name: # classifier layer
        param.requires_grad = False
        
# Initialize best accuracy and epochs since improvement
best_f1 = 0.0
epochs_since_improvement = 0

Some weights of the model checkpoint at cardiffnlp/twitter-xlm-roberta-base were not used when initializing XLMRobertaForSequenceClassification: ['lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.decoder.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at cardiffnlp/twitter-xlm-roberta-base and are newly initialized: ['classifier.den

In [16]:
# train
start_time = time.time()

for epoch in range(NUM_EPOCHS):

    model.train()
    # Unfreeze all layers after few epochs
    if epoch == freeze_steps:
        model.requires_grad_(True)
            
    for batch_idx, batch in enumerate(train_loader):
        ### Prepare data
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        ### Forward
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss, logits = outputs['loss'], outputs['logits']

        ### Backward
        optim.zero_grad()
        loss = loss.mean()
        loss.backward()
        optim.step()
        scheduler.step()

        ### Logging
        if not batch_idx % 50:
            print (f'Epoch: {epoch+1:04d}/{NUM_EPOCHS:04d} | '
                   f'Batch {batch_idx:04d}/{len(train_loader):04d} | '
                   f'Loss: {loss:.4f}')
    
    
    model.eval()
    with torch.set_grad_enabled(False):

        print(f'Training metrics: '
              f'{compute_metrics(model, train_loader, DEVICE)}%'
              f'\nValid metrics: '
              f'{compute_metrics(model, valid_loader, DEVICE)}%')

        current_f1 = compute_metrics(model, valid_loader, DEVICE)['f1']
        if current_f1 > best_f1:
            best_f1 = current_f1
            epochs_since_improvement = 0

            # Save the new best model
            path = os.path.join(model_dir ,'XLM_gender.pt')
            torch.save(model, path)

        else:
            epochs_since_improvement += 1

        # Early stopping
        if epochs_since_improvement >= 3:
            print('Early stopping triggered.')
            break

    print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')




Epoch: 0001/0100 | Batch 0000/0540 | Loss: 0.6945
Epoch: 0001/0100 | Batch 0050/0540 | Loss: 0.6478
Epoch: 0001/0100 | Batch 0100/0540 | Loss: 0.6338
Epoch: 0001/0100 | Batch 0150/0540 | Loss: 0.7492
Epoch: 0001/0100 | Batch 0200/0540 | Loss: 0.6664
Epoch: 0001/0100 | Batch 0250/0540 | Loss: 0.6279
Epoch: 0001/0100 | Batch 0300/0540 | Loss: 0.6998
Epoch: 0001/0100 | Batch 0350/0540 | Loss: 0.7307
Epoch: 0001/0100 | Batch 0400/0540 | Loss: 0.6923
Epoch: 0001/0100 | Batch 0450/0540 | Loss: 0.5835
Epoch: 0001/0100 | Batch 0500/0540 | Loss: 0.6340
Training metrics: {'accuracy': 0.6279513888888889, 'f1': 0.7714458388140353}%
Valid metrics: {'accuracy': 0.6239583333333333, 'f1': 0.7684413085311097}%
Time elapsed: 6.67 min




Epoch: 0002/0100 | Batch 0000/0540 | Loss: 0.6620
Epoch: 0002/0100 | Batch 0050/0540 | Loss: 0.6861
Epoch: 0002/0100 | Batch 0100/0540 | Loss: 0.7286
Epoch: 0002/0100 | Batch 0150/0540 | Loss: 0.6604
Epoch: 0002/0100 | Batch 0200/0540 | Loss: 0.6608
Epoch: 0002/0100 | Batch 0250/0540 | Loss: 0.6044
Epoch: 0002/0100 | Batch 0300/0540 | Loss: 0.6467
Epoch: 0002/0100 | Batch 0350/0540 | Loss: 0.6007
Epoch: 0002/0100 | Batch 0400/0540 | Loss: 0.7214
Epoch: 0002/0100 | Batch 0450/0540 | Loss: 0.6424
Epoch: 0002/0100 | Batch 0500/0540 | Loss: 0.6268
Training metrics: {'accuracy': 0.6329282407407407, 'f1': 0.7727907726474907}%
Valid metrics: {'accuracy': 0.6317708333333333, 'f1': 0.7712714331931413}%
Time elapsed: 13.04 min




Epoch: 0003/0100 | Batch 0000/0540 | Loss: 0.5385
Epoch: 0003/0100 | Batch 0050/0540 | Loss: 0.6333
Epoch: 0003/0100 | Batch 0100/0540 | Loss: 0.6467
Epoch: 0003/0100 | Batch 0150/0540 | Loss: 0.6048
Epoch: 0003/0100 | Batch 0200/0540 | Loss: 0.4794
Epoch: 0003/0100 | Batch 0250/0540 | Loss: 0.5871
Epoch: 0003/0100 | Batch 0300/0540 | Loss: 0.5807
Epoch: 0003/0100 | Batch 0350/0540 | Loss: 0.4739
Epoch: 0003/0100 | Batch 0400/0540 | Loss: 0.5869
Epoch: 0003/0100 | Batch 0450/0540 | Loss: 0.3140
Epoch: 0003/0100 | Batch 0500/0540 | Loss: 0.3326
Training metrics: {'accuracy': 0.8575231481481481, 'f1': 0.892526628252139}%
Valid metrics: {'accuracy': 0.8322916666666667, 'f1': 0.8744149765990639}%
Time elapsed: 22.29 min




Epoch: 0004/0100 | Batch 0000/0540 | Loss: 0.2827
Epoch: 0004/0100 | Batch 0050/0540 | Loss: 0.2252
Epoch: 0004/0100 | Batch 0100/0540 | Loss: 0.4072
Epoch: 0004/0100 | Batch 0150/0540 | Loss: 0.5517
Epoch: 0004/0100 | Batch 0200/0540 | Loss: 0.1471
Epoch: 0004/0100 | Batch 0250/0540 | Loss: 0.5093
Epoch: 0004/0100 | Batch 0300/0540 | Loss: 0.2729
Epoch: 0004/0100 | Batch 0350/0540 | Loss: 0.2886
Epoch: 0004/0100 | Batch 0400/0540 | Loss: 0.2462
Epoch: 0004/0100 | Batch 0450/0540 | Loss: 0.2729
Epoch: 0004/0100 | Batch 0500/0540 | Loss: 0.3290
Training metrics: {'accuracy': 0.9111689814814815, 'f1': 0.9306026493060264}%
Valid metrics: {'accuracy': 0.8697916666666666, 'f1': 0.8981255093724532}%
Time elapsed: 31.55 min




Epoch: 0005/0100 | Batch 0000/0540 | Loss: 0.3039
Epoch: 0005/0100 | Batch 0050/0540 | Loss: 0.1552
Epoch: 0005/0100 | Batch 0100/0540 | Loss: 0.1184
Epoch: 0005/0100 | Batch 0150/0540 | Loss: 0.2535
Epoch: 0005/0100 | Batch 0200/0540 | Loss: 0.3498
Epoch: 0005/0100 | Batch 0250/0540 | Loss: 0.3193
Epoch: 0005/0100 | Batch 0300/0540 | Loss: 0.3738
Epoch: 0005/0100 | Batch 0350/0540 | Loss: 0.3839
Epoch: 0005/0100 | Batch 0400/0540 | Loss: 0.6283
Epoch: 0005/0100 | Batch 0450/0540 | Loss: 0.4832
Epoch: 0005/0100 | Batch 0500/0540 | Loss: 0.0721
Training metrics: {'accuracy': 0.937037037037037, 'f1': 0.9507246376811594}%
Valid metrics: {'accuracy': 0.8744791666666667, 'f1': 0.9034068136272544}%
Time elapsed: 40.79 min




Epoch: 0006/0100 | Batch 0000/0540 | Loss: 0.3165
Epoch: 0006/0100 | Batch 0050/0540 | Loss: 0.0856
Epoch: 0006/0100 | Batch 0100/0540 | Loss: 0.2454
Epoch: 0006/0100 | Batch 0150/0540 | Loss: 0.2680
Epoch: 0006/0100 | Batch 0200/0540 | Loss: 0.1836
Epoch: 0006/0100 | Batch 0250/0540 | Loss: 0.2434
Epoch: 0006/0100 | Batch 0300/0540 | Loss: 0.1057
Epoch: 0006/0100 | Batch 0350/0540 | Loss: 0.1160
Epoch: 0006/0100 | Batch 0400/0540 | Loss: 0.1054
Epoch: 0006/0100 | Batch 0450/0540 | Loss: 0.1174
Epoch: 0006/0100 | Batch 0500/0540 | Loss: 0.1456
Training metrics: {'accuracy': 0.9605324074074074, 'f1': 0.9688840222648052}%
Valid metrics: {'accuracy': 0.8776041666666666, 'f1': 0.9045879009338207}%
Time elapsed: 51.21 min




Epoch: 0007/0100 | Batch 0000/0540 | Loss: 0.2584
Epoch: 0007/0100 | Batch 0050/0540 | Loss: 0.2949
Epoch: 0007/0100 | Batch 0100/0540 | Loss: 0.1082
Epoch: 0007/0100 | Batch 0150/0540 | Loss: 0.2079
Epoch: 0007/0100 | Batch 0200/0540 | Loss: 0.0361
Epoch: 0007/0100 | Batch 0250/0540 | Loss: 0.2381
Epoch: 0007/0100 | Batch 0300/0540 | Loss: 0.0925
Epoch: 0007/0100 | Batch 0350/0540 | Loss: 0.0760
Epoch: 0007/0100 | Batch 0400/0540 | Loss: 0.0939
Epoch: 0007/0100 | Batch 0450/0540 | Loss: 0.1831
Epoch: 0007/0100 | Batch 0500/0540 | Loss: 0.3349
Training metrics: {'accuracy': 0.9706597222222222, 'f1': 0.9762829208962904}%
Valid metrics: {'accuracy': 0.8630208333333333, 'f1': 0.8876548483554036}%
Time elapsed: 60.43 min




Epoch: 0008/0100 | Batch 0000/0540 | Loss: 0.3770
Epoch: 0008/0100 | Batch 0050/0540 | Loss: 0.0071
Epoch: 0008/0100 | Batch 0100/0540 | Loss: 0.0620
Epoch: 0008/0100 | Batch 0150/0540 | Loss: 0.0521
Epoch: 0008/0100 | Batch 0200/0540 | Loss: 0.0796
Epoch: 0008/0100 | Batch 0250/0540 | Loss: 0.0817
Epoch: 0008/0100 | Batch 0300/0540 | Loss: 0.0209
Epoch: 0008/0100 | Batch 0350/0540 | Loss: 0.1196
Epoch: 0008/0100 | Batch 0400/0540 | Loss: 0.1203
Epoch: 0008/0100 | Batch 0450/0540 | Loss: 0.2236
Epoch: 0008/0100 | Batch 0500/0540 | Loss: 0.0767
Training metrics: {'accuracy': 0.9844328703703704, 'f1': 0.9875376418809358}%
Valid metrics: {'accuracy': 0.8682291666666667, 'f1': 0.8943632567849688}%
Time elapsed: 69.65 min




Epoch: 0009/0100 | Batch 0000/0540 | Loss: 0.0220
Epoch: 0009/0100 | Batch 0050/0540 | Loss: 0.1150
Epoch: 0009/0100 | Batch 0100/0540 | Loss: 0.0669
Epoch: 0009/0100 | Batch 0150/0540 | Loss: 0.1139
Epoch: 0009/0100 | Batch 0200/0540 | Loss: 0.0133
Epoch: 0009/0100 | Batch 0250/0540 | Loss: 0.0177
Epoch: 0009/0100 | Batch 0300/0540 | Loss: 0.1678
Epoch: 0009/0100 | Batch 0350/0540 | Loss: 0.0264
Epoch: 0009/0100 | Batch 0400/0540 | Loss: 0.0596
Epoch: 0009/0100 | Batch 0450/0540 | Loss: 0.2535
Epoch: 0009/0100 | Batch 0500/0540 | Loss: 0.0072
Training metrics: {'accuracy': 0.9867476851851852, 'f1': 0.9894852839891639}%
Valid metrics: {'accuracy': 0.8697916666666666, 'f1': 0.9004777070063695}%
Early stopping triggered.


In [18]:
from transformers import get_linear_schedule_with_warmup

MODEL = "dbmdz/bert-base-italian-xxl-cased"
num_labels = len(age_labels)
model = AutoModelForSequenceClassification.from_pretrained(MODEL, num_labels=num_labels)
model.to(DEVICE)

NUM_EPOCHS = 100
LR = 2e-5

# Adam optimizer
optim = torch.optim.Adam(model.parameters(), lr=LR)

# Create the learning rate scheduler.
total_steps = len(train_loader) * NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps = 0,
                                            num_training_steps = total_steps)

# Freeze all layers except the classifier for the first few epochs
freeze_steps = 2
for name, param in model.named_parameters():
    if 'classifier' not in name: # classifier layer
        param.requires_grad = False
        
# Initialize best accuracy and epochs since improvement
best_accuracy = 0.0
epochs_since_improvement = 0

Some weights of the model checkpoint at dbmdz/bert-base-italian-xxl-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification w

In [19]:
# train
start_time = time.time()

for epoch in range(NUM_EPOCHS):

    model.train()
    # Unfreeze all layers after few epochs
    if epoch == freeze_steps:
        for param in model.parameters():
            param.requires_grad = True
            
    for batch_idx, batch in enumerate(train_loader):
        ### Prepare data
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        ### Forward
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss, logits = outputs['loss'], outputs['logits']

        ### Backward
        optim.zero_grad()
        loss.backward()
        optim.step()
        scheduler.step()

        ### Logging
        if not batch_idx % 50:
            print (f'Epoch: {epoch+1:04d}/{NUM_EPOCHS:04d} | '
                   f'Batch {batch_idx:04d}/{len(train_loader):04d} | '
                   f'Loss: {loss:.4f}')
    
    
    model.eval()
    with torch.set_grad_enabled(False):

        print(f'Training metrics: '
              f'{compute_metrics(model, train_loader, DEVICE)}%'
              f'\nValid metrics: '
              f'{compute_metrics(model, valid_loader, DEVICE)}%')

        current_accuracy = compute_metrics(model, valid_loader, DEVICE)['accuracy']
        if current_accuracy > best_accuracy:
            best_accuracy = current_accuracy
            epochs_since_improvement = 0

            # Save the new best model
            path = os.path.join(model_dir ,'bert_italian_mod.pt')
            torch.save(model, path)

        else:
            epochs_since_improvement += 1

        # Early stopping
        if epochs_since_improvement >= 3:
            print('Early stopping triggered.')
            break

    print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')


Epoch: 0001/0100 | Batch 0000/0170 | Loss: 1.3928
Epoch: 0001/0100 | Batch 0050/0170 | Loss: 1.3165
Epoch: 0001/0100 | Batch 0100/0170 | Loss: 1.3359
Epoch: 0001/0100 | Batch 0150/0170 | Loss: 1.3451
Training metrics: {'accuracy': 0.3507834101382489, 'macro_f1': 0.1875464430350761}%
Valid metrics: {'accuracy': 0.3544973544973545, 'macro_f1': 0.19339588251402448}%
Time elapsed: 2.18 min
Epoch: 0002/0100 | Batch 0000/0170 | Loss: 1.3295
Epoch: 0002/0100 | Batch 0050/0170 | Loss: 1.3283
Epoch: 0002/0100 | Batch 0100/0170 | Loss: 1.2494
Epoch: 0002/0100 | Batch 0150/0170 | Loss: 1.3359
Training metrics: {'accuracy': 0.35889400921658987, 'macro_f1': 0.21016802056656747}%
Valid metrics: {'accuracy': 0.3637566137566138, 'macro_f1': 0.21248051048822164}%
Time elapsed: 4.13 min
Epoch: 0003/0100 | Batch 0000/0170 | Loss: 1.2263
Epoch: 0003/0100 | Batch 0050/0170 | Loss: 1.3815
Epoch: 0003/0100 | Batch 0100/0170 | Loss: 1.3420
Epoch: 0003/0100 | Batch 0150/0170 | Loss: 1.0971
Training metrics: {'

##   bert-tweet-base-italian-uncased

In [21]:
# setup tokenizer
tokenizer = AutoTokenizer.from_pretrained("osiria/bert-tweet-base-italian-uncased")

train_encodings = batch_tokenize(X_train, tokenizer)
valid_encodings = batch_tokenize(X_valid, tokenizer)

train_dataset = TweetDataset(train_encodings, y_train)
valid_dataset = TweetDataset(valid_encodings, y_valid)

In [22]:
# data loaders
BATCH_SIZE = 32

train_loader = torch.utils.data.DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True,
    pin_memory=True
)

valid_loader = torch.utils.data.DataLoader(
    valid_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False,
    pin_memory=True
)

#### Bertweet + freeze + early stopping + linear schedule

In [25]:
from transformers import get_linear_schedule_with_warmup

MODEL = "osiria/bert-tweet-base-italian-uncased"
num_labels = 2
model = AutoModelForSequenceClassification.from_pretrained(MODEL, num_labels=num_labels)
model = torch.nn.DataParallel(model)
model.to(DEVICE)

NUM_EPOCHS = 100
LR = 2e-5

# Adam optimizer
optim = torch.optim.Adam(model.parameters(), lr=LR)

# Create the learning rate scheduler.
total_steps = len(train_loader) * NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps = 0,
                                            num_training_steps = total_steps)

# Freeze all layers except the classifier for the first few epochs
freeze_steps = 2
for name, param in model.named_parameters():
    if 'classifier' not in name: # classifier layer
        param.requires_grad = False
        
# Initialize best accuracy and epochs since improvement
best_f1 = 0.0
epochs_since_improvement = 0

Some weights of the model checkpoint at osiria/bert-tweet-base-italian-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at osiria/bert-tweet-base-italian-uncased and are newly initialized: ['clas

In [26]:
# train
start_time = time.time()

for epoch in range(NUM_EPOCHS):

    model.train()
    # Unfreeze all layers after few epochs
    if epoch == freeze_steps:
        model.requires_grad_(True)
            
    for batch_idx, batch in enumerate(train_loader):
        ### Prepare data
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        ### Forward
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss, logits = outputs['loss'], outputs['logits']

        ### Backward
        optim.zero_grad()
        loss = loss.mean()
        loss.backward()
        optim.step()
        scheduler.step()

        ### Logging
        if not batch_idx % 50:
            print (f'Epoch: {epoch+1:04d}/{NUM_EPOCHS:04d} | '
                   f'Batch {batch_idx:04d}/{len(train_loader):04d} | '
                   f'Loss: {loss:.4f}')
    
    
    model.eval()
    with torch.set_grad_enabled(False):

        print(f'Training metrics: '
              f'{compute_metrics(model, train_loader, DEVICE)}%'
              f'\nValid metrics: '
              f'{compute_metrics(model, valid_loader, DEVICE)}%')

        current_f1 = compute_metrics(model, valid_loader, DEVICE)['f1']
        if current_f1 > best_f1:
            best_f1 = current_f1
            epochs_since_improvement = 0

            # Save the new best model
            path = os.path.join(model_dir ,'bertweet_italian_mod.pt')
            torch.save(model, path)

        else:
            epochs_since_improvement += 1

        # Early stopping
        if epochs_since_improvement >= 3:
            print('Early stopping triggered.')
            break

    print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')


Epoch: 0001/0100 | Batch 0000/0540 | Loss: 0.7959
Epoch: 0001/0100 | Batch 0050/0540 | Loss: 0.6919
Epoch: 0001/0100 | Batch 0100/0540 | Loss: 0.6879
Epoch: 0001/0100 | Batch 0150/0540 | Loss: 0.6742
Epoch: 0001/0100 | Batch 0200/0540 | Loss: 0.6685
Epoch: 0001/0100 | Batch 0250/0540 | Loss: 0.7083
Epoch: 0001/0100 | Batch 0300/0540 | Loss: 0.6494
Epoch: 0001/0100 | Batch 0350/0540 | Loss: 0.6638
Epoch: 0001/0100 | Batch 0400/0540 | Loss: 0.6609
Epoch: 0001/0100 | Batch 0450/0540 | Loss: 0.7390
Epoch: 0001/0100 | Batch 0500/0540 | Loss: 0.6568
Training metrics: {'accuracy': 0.6261574074074074, 'f1': 0.7698446629613795}%
Valid metrics: {'accuracy': 0.6234375, 'f1': 0.7675988428158148}%
Time elapsed: 4.91 min




Epoch: 0002/0100 | Batch 0000/0540 | Loss: 0.6754
Epoch: 0002/0100 | Batch 0050/0540 | Loss: 0.6671
Epoch: 0002/0100 | Batch 0100/0540 | Loss: 0.5582
Epoch: 0002/0100 | Batch 0150/0540 | Loss: 0.6466
Epoch: 0002/0100 | Batch 0200/0540 | Loss: 0.8088
Epoch: 0002/0100 | Batch 0250/0540 | Loss: 0.6261
Epoch: 0002/0100 | Batch 0300/0540 | Loss: 0.6633
Epoch: 0002/0100 | Batch 0350/0540 | Loss: 0.7375
Epoch: 0002/0100 | Batch 0400/0540 | Loss: 0.6848
Epoch: 0002/0100 | Batch 0450/0540 | Loss: 0.6685
Epoch: 0002/0100 | Batch 0500/0540 | Loss: 0.7270
Training metrics: {'accuracy': 0.6274305555555556, 'f1': 0.7709711846318036}%
Valid metrics: {'accuracy': 0.6239583333333333, 'f1': 0.7682926829268294}%
Time elapsed: 9.81 min




Epoch: 0003/0100 | Batch 0000/0540 | Loss: 0.5864
Epoch: 0003/0100 | Batch 0050/0540 | Loss: 0.6693
Epoch: 0003/0100 | Batch 0100/0540 | Loss: 0.6063
Epoch: 0003/0100 | Batch 0150/0540 | Loss: 0.6274
Epoch: 0003/0100 | Batch 0200/0540 | Loss: 0.7238
Epoch: 0003/0100 | Batch 0250/0540 | Loss: 0.7661
Epoch: 0003/0100 | Batch 0300/0540 | Loss: 0.6638
Epoch: 0003/0100 | Batch 0350/0540 | Loss: 0.6677
Epoch: 0003/0100 | Batch 0400/0540 | Loss: 0.7035
Epoch: 0003/0100 | Batch 0450/0540 | Loss: 0.6897
Epoch: 0003/0100 | Batch 0500/0540 | Loss: 0.6378
Training metrics: {'accuracy': 0.6280671296296296, 'f1': 0.771533184031851}%
Valid metrics: {'accuracy': 0.6244791666666667, 'f1': 0.7688361654376403}%
Time elapsed: 17.71 min




Epoch: 0004/0100 | Batch 0000/0540 | Loss: 0.7346
Epoch: 0004/0100 | Batch 0050/0540 | Loss: 0.6220
Epoch: 0004/0100 | Batch 0100/0540 | Loss: 0.6634
Epoch: 0004/0100 | Batch 0150/0540 | Loss: 0.7231
Epoch: 0004/0100 | Batch 0200/0540 | Loss: 0.4587
Epoch: 0004/0100 | Batch 0250/0540 | Loss: 0.6083
Epoch: 0004/0100 | Batch 0300/0540 | Loss: 0.5420
Epoch: 0004/0100 | Batch 0350/0540 | Loss: 0.5853
Epoch: 0004/0100 | Batch 0400/0540 | Loss: 0.4371
Epoch: 0004/0100 | Batch 0450/0540 | Loss: 0.6375
Epoch: 0004/0100 | Batch 0500/0540 | Loss: 0.5388
Training metrics: {'accuracy': 0.7006365740740741, 'f1': 0.7283801522709372}%
Valid metrics: {'accuracy': 0.7083333333333334, 'f1': 0.7388059701492536}%
Time elapsed: 25.59 min




Epoch: 0005/0100 | Batch 0000/0540 | Loss: 0.5784
Epoch: 0005/0100 | Batch 0050/0540 | Loss: 0.6426
Epoch: 0005/0100 | Batch 0100/0540 | Loss: 0.4216
Epoch: 0005/0100 | Batch 0150/0540 | Loss: 0.5253
Epoch: 0005/0100 | Batch 0200/0540 | Loss: 0.5357
Epoch: 0005/0100 | Batch 0250/0540 | Loss: 0.3689
Epoch: 0005/0100 | Batch 0300/0540 | Loss: 0.2988
Epoch: 0005/0100 | Batch 0350/0540 | Loss: 0.5290
Epoch: 0005/0100 | Batch 0400/0540 | Loss: 0.3966
Epoch: 0005/0100 | Batch 0450/0540 | Loss: 0.5680
Epoch: 0005/0100 | Batch 0500/0540 | Loss: 0.3909
Training metrics: {'accuracy': 0.8372106481481482, 'f1': 0.8690348712696122}%
Valid metrics: {'accuracy': 0.7973958333333333, 'f1': 0.8374425407438362}%
Time elapsed: 33.56 min




Epoch: 0006/0100 | Batch 0000/0540 | Loss: 0.3559
Epoch: 0006/0100 | Batch 0050/0540 | Loss: 0.4581
Epoch: 0006/0100 | Batch 0100/0540 | Loss: 0.3671
Epoch: 0006/0100 | Batch 0150/0540 | Loss: 0.4718
Epoch: 0006/0100 | Batch 0200/0540 | Loss: 0.5333
Epoch: 0006/0100 | Batch 0250/0540 | Loss: 0.3611
Epoch: 0006/0100 | Batch 0300/0540 | Loss: 0.3494
Epoch: 0006/0100 | Batch 0350/0540 | Loss: 0.2886
Epoch: 0006/0100 | Batch 0400/0540 | Loss: 0.3979
Epoch: 0006/0100 | Batch 0450/0540 | Loss: 0.3670
Epoch: 0006/0100 | Batch 0500/0540 | Loss: 0.4053
Training metrics: {'accuracy': 0.8926504629629629, 'f1': 0.9162641628673318}%
Valid metrics: {'accuracy': 0.825, 'f1': 0.8650602409638554}%
Time elapsed: 41.47 min




Epoch: 0007/0100 | Batch 0000/0540 | Loss: 0.4923
Epoch: 0007/0100 | Batch 0050/0540 | Loss: 0.4560
Epoch: 0007/0100 | Batch 0100/0540 | Loss: 0.1444
Epoch: 0007/0100 | Batch 0150/0540 | Loss: 0.2568
Epoch: 0007/0100 | Batch 0200/0540 | Loss: 0.3323
Epoch: 0007/0100 | Batch 0250/0540 | Loss: 0.1916
Epoch: 0007/0100 | Batch 0300/0540 | Loss: 0.3878
Epoch: 0007/0100 | Batch 0350/0540 | Loss: 0.2723
Epoch: 0007/0100 | Batch 0400/0540 | Loss: 0.3427
Epoch: 0007/0100 | Batch 0450/0540 | Loss: 0.4365
Epoch: 0007/0100 | Batch 0500/0540 | Loss: 0.3216
Training metrics: {'accuracy': 0.9177662037037037, 'f1': 0.935820423648435}%
Valid metrics: {'accuracy': 0.8291666666666667, 'f1': 0.8692185007974481}%
Time elapsed: 49.37 min




Epoch: 0008/0100 | Batch 0000/0540 | Loss: 0.1603
Epoch: 0008/0100 | Batch 0050/0540 | Loss: 0.1259
Epoch: 0008/0100 | Batch 0100/0540 | Loss: 0.3004
Epoch: 0008/0100 | Batch 0150/0540 | Loss: 0.1821
Epoch: 0008/0100 | Batch 0200/0540 | Loss: 0.2238
Epoch: 0008/0100 | Batch 0250/0540 | Loss: 0.2234
Epoch: 0008/0100 | Batch 0300/0540 | Loss: 0.4191
Epoch: 0008/0100 | Batch 0350/0540 | Loss: 0.2302
Epoch: 0008/0100 | Batch 0400/0540 | Loss: 0.2550
Epoch: 0008/0100 | Batch 0450/0540 | Loss: 0.4305
Epoch: 0008/0100 | Batch 0500/0540 | Loss: 0.2485
Training metrics: {'accuracy': 0.9487847222222222, 'f1': 0.95871623827961}%
Valid metrics: {'accuracy': 0.8458333333333333, 'f1': 0.8756302521008403}%
Time elapsed: 57.27 min




Epoch: 0009/0100 | Batch 0000/0540 | Loss: 0.2650
Epoch: 0009/0100 | Batch 0050/0540 | Loss: 0.1095
Epoch: 0009/0100 | Batch 0100/0540 | Loss: 0.1863
Epoch: 0009/0100 | Batch 0150/0540 | Loss: 0.1272
Epoch: 0009/0100 | Batch 0200/0540 | Loss: 0.2077
Epoch: 0009/0100 | Batch 0250/0540 | Loss: 0.2080
Epoch: 0009/0100 | Batch 0300/0540 | Loss: 0.3144
Epoch: 0009/0100 | Batch 0350/0540 | Loss: 0.5061
Epoch: 0009/0100 | Batch 0400/0540 | Loss: 0.1703
Epoch: 0009/0100 | Batch 0450/0540 | Loss: 0.1008
Epoch: 0009/0100 | Batch 0500/0540 | Loss: 0.1985
Training metrics: {'accuracy': 0.9491319444444445, 'f1': 0.9583708264267108}%
Valid metrics: {'accuracy': 0.81875, 'f1': 0.8477690288713912}%
Time elapsed: 65.16 min




Epoch: 0010/0100 | Batch 0000/0540 | Loss: 0.0369
Epoch: 0010/0100 | Batch 0050/0540 | Loss: 0.1665
Epoch: 0010/0100 | Batch 0100/0540 | Loss: 0.2317
Epoch: 0010/0100 | Batch 0150/0540 | Loss: 0.1235
Epoch: 0010/0100 | Batch 0200/0540 | Loss: 0.0907
Epoch: 0010/0100 | Batch 0250/0540 | Loss: 0.0995
Epoch: 0010/0100 | Batch 0300/0540 | Loss: 0.0924
Epoch: 0010/0100 | Batch 0350/0540 | Loss: 0.0880
Epoch: 0010/0100 | Batch 0400/0540 | Loss: 0.0983
Epoch: 0010/0100 | Batch 0450/0540 | Loss: 0.2983
Epoch: 0010/0100 | Batch 0500/0540 | Loss: 0.2550
Training metrics: {'accuracy': 0.9758680555555556, 'f1': 0.9809597735263229}%
Valid metrics: {'accuracy': 0.834375, 'f1': 0.8746056782334386}%
Time elapsed: 73.05 min




Epoch: 0011/0100 | Batch 0000/0540 | Loss: 0.1784
Epoch: 0011/0100 | Batch 0050/0540 | Loss: 0.1247
Epoch: 0011/0100 | Batch 0100/0540 | Loss: 0.0853
Epoch: 0011/0100 | Batch 0150/0540 | Loss: 0.1776
Epoch: 0011/0100 | Batch 0200/0540 | Loss: 0.1725
Epoch: 0011/0100 | Batch 0250/0540 | Loss: 0.2205
Epoch: 0011/0100 | Batch 0300/0540 | Loss: 0.0680
Epoch: 0011/0100 | Batch 0350/0540 | Loss: 0.0268
Epoch: 0011/0100 | Batch 0400/0540 | Loss: 0.0997
Epoch: 0011/0100 | Batch 0450/0540 | Loss: 0.0396
Epoch: 0011/0100 | Batch 0500/0540 | Loss: 0.1452
Training metrics: {'accuracy': 0.9824652777777778, 'f1': 0.9860670437301696}%
Valid metrics: {'accuracy': 0.8385416666666666, 'f1': 0.8759007205764612}%
Time elapsed: 80.95 min




Epoch: 0012/0100 | Batch 0000/0540 | Loss: 0.0498
Epoch: 0012/0100 | Batch 0050/0540 | Loss: 0.0256
Epoch: 0012/0100 | Batch 0100/0540 | Loss: 0.0495
Epoch: 0012/0100 | Batch 0150/0540 | Loss: 0.0497
Epoch: 0012/0100 | Batch 0200/0540 | Loss: 0.1172
Epoch: 0012/0100 | Batch 0250/0540 | Loss: 0.0461
Epoch: 0012/0100 | Batch 0300/0540 | Loss: 0.2662
Epoch: 0012/0100 | Batch 0350/0540 | Loss: 0.0915
Epoch: 0012/0100 | Batch 0400/0540 | Loss: 0.0552
Epoch: 0012/0100 | Batch 0450/0540 | Loss: 0.0431
Epoch: 0012/0100 | Batch 0500/0540 | Loss: 0.2111
Training metrics: {'accuracy': 0.9813078703703704, 'f1': 0.9849578540492712}%
Valid metrics: {'accuracy': 0.8203125, 'f1': 0.8514851485148516}%
Time elapsed: 88.84 min




Epoch: 0013/0100 | Batch 0000/0540 | Loss: 0.0172
Epoch: 0013/0100 | Batch 0050/0540 | Loss: 0.0169
Epoch: 0013/0100 | Batch 0100/0540 | Loss: 0.1370
Epoch: 0013/0100 | Batch 0150/0540 | Loss: 0.0096
Epoch: 0013/0100 | Batch 0200/0540 | Loss: 0.0661
Epoch: 0013/0100 | Batch 0250/0540 | Loss: 0.0183
Epoch: 0013/0100 | Batch 0300/0540 | Loss: 0.0323
Epoch: 0013/0100 | Batch 0350/0540 | Loss: 0.0353
Epoch: 0013/0100 | Batch 0400/0540 | Loss: 0.1551
Epoch: 0013/0100 | Batch 0450/0540 | Loss: 0.0150
Epoch: 0013/0100 | Batch 0500/0540 | Loss: 0.0171
Training metrics: {'accuracy': 0.990625, 'f1': 0.9925318089618292}%
Valid metrics: {'accuracy': 0.8333333333333334, 'f1': 0.8686371100164204}%
Time elapsed: 96.93 min




Epoch: 0014/0100 | Batch 0000/0540 | Loss: 0.0462
Epoch: 0014/0100 | Batch 0050/0540 | Loss: 0.1521
Epoch: 0014/0100 | Batch 0100/0540 | Loss: 0.1544
Epoch: 0014/0100 | Batch 0150/0540 | Loss: 0.0408
Epoch: 0014/0100 | Batch 0200/0540 | Loss: 0.1000
Epoch: 0014/0100 | Batch 0250/0540 | Loss: 0.0293
Epoch: 0014/0100 | Batch 0300/0540 | Loss: 0.0048
Epoch: 0014/0100 | Batch 0350/0540 | Loss: 0.1256
Epoch: 0014/0100 | Batch 0400/0540 | Loss: 0.0114
Epoch: 0014/0100 | Batch 0450/0540 | Loss: 0.0062
Epoch: 0014/0100 | Batch 0500/0540 | Loss: 0.0364
Training metrics: {'accuracy': 0.9939814814814815, 'f1': 0.9952082565425728}%
Valid metrics: {'accuracy': 0.8333333333333334, 'f1': 0.8697068403908794}%
Early stopping triggered.
