## Setup

In [27]:
import numpy as np
import pandas as pd
import os
from transformers import pipeline, set_seed
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import time
from sklearn.model_selection import train_test_split

In [2]:
# Random seed: seed every RNG training touches (Python `random`, NumPy, torch),
# so classifier-head initialisation and DataLoader shuffling are repeatable
# across kernel restarts. torch.manual_seed also seeds all CUDA devices.
import random

SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Uncomment for bit-exact cuDNN kernels (slower):
# torch.backends.cudnn.deterministic = True

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Number of visible GPUs; the models below are wrapped in nn.DataParallel,
# which uses every visible GPU.
torch.cuda.device_count()

2

In [4]:
# Sanity check: confirm training will run on the GPU when one is available.
DEVICE

device(type='cuda')

In [5]:
# Cluster locations: personal home (model checkpoints) and the project work
# area (shared datasets). Adjust these two roots when running elsewhere.
home_dir = '/g100/home/userexternal/mhabibi0/'
work_dir = '/g100_work/IscrC_mental'

# Work-area data: user-classification inputs and shared trained models.
wdata_dir = os.path.join(work_dir, 'data')
uc_dir = os.path.join(wdata_dir, 'user_classification')

# Home-directory data and the age-model checkpoint folder.
hdata_dir = os.path.join(home_dir, 'Data')
model_dir = os.path.join(home_dir, 'Models', 'Age')

In [6]:
# user age data
# NOTE(review): redundant — the next cell re-reads this same pickle into `df`
# before using it, so this cell can be dropped.
# pd.read_pickle unpickles arbitrary objects: only load trusted files.
path  = os.path.join(uc_dir, 'data_for_models_train.pkl')
df = pd.read_pickle(path)

In [7]:
# Load the training users (presumably one row per user — verify).
# pd.read_pickle unpickles arbitrary objects: only load trusted files.
path = os.path.join(uc_dir, 'data_for_models_train.pkl')
df = pd.read_pickle(path)

# Discretize 'age' into four classes with left-closed bins:
# [0, 19) -> 0, [19, 30) -> 1, [30, 40) -> 2, [40, 100) -> 3.
age_intervals = [0, 19, 30, 40, 100]
age_labels = [0, 1, 2, 3]
df['age_class'] = pd.cut(df['age'], bins=age_intervals, labels=age_labels, right=False)

# Ages outside [0, 100) (or missing) get no class. They would later become NaN
# labels that break training, so fail early with a clear message instead.
n_unlabeled = int(df['age_class'].isna().sum())
assert n_unlabeled == 0, f'{n_unlabeled} users have an age outside [0, 100) or missing'

# Build the model input. A missing bio or tweet text would otherwise turn the
# whole concatenation into NaN and crash the tokenizer.
df['text'] = ('bio: ' + df['masked_bio'].fillna('') + '. '
              + 'tweets: ' + df['long_text'].fillna(''))
df['text'] = df['text'].str.replace('\r|\n', ' ', regex=True)

# 90/10 train/validation split with a fixed seed.
df_train, df_valid = train_test_split(df[['user_id', 'text', 'age_class']], test_size=0.1, random_state=42)

In [8]:
# Unpack inputs (concatenated bio + tweets) and targets (age class) per split.
X_train, y_train = df_train['text'].values, df_train['age_class'].values
X_valid, y_valid = df_valid['text'].values, df_valid['age_class'].values

In [9]:
from sklearn.metrics import f1_score, accuracy_score

def compute_metrics(model, data_loader, device):
    """Score a sequence classifier over every batch of a loader.

    Predictions are the arg-max of the model's ``'logits'`` output along dim 1.
    Gradients are disabled here, but switching the model to eval mode is left
    to the caller.

    Args:
        model: callable taking ``(input_ids, attention_mask=...)`` and returning
            a mapping with a ``'logits'`` entry (class scores along dim 1).
        data_loader: iterable of dicts holding ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        device: device the batch tensors are moved to before the forward pass.

    Returns:
        dict with ``'accuracy'`` and macro-averaged ``'f1'``.
    """
    with torch.no_grad():
        y_pred = []
        y_true = []

        for batch in data_loader:
            input_ids, attention_mask, labels = (
                batch[key].to(device) for key in ('input_ids', 'attention_mask', 'labels')
            )

            logits = model(input_ids, attention_mask=attention_mask)['logits']

            y_pred.extend(logits.argmax(dim=1).cpu().numpy())
            y_true.extend(labels.cpu().numpy())

        return {
            'accuracy': accuracy_score(y_true, y_pred),
            'f1': f1_score(y_true, y_pred, average='macro'),
        }


In [10]:
class TweetDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing pre-tokenized encodings with class labels.

    ``encodings`` maps field names (e.g. ``input_ids``, ``attention_mask``) to
    position-indexable sequences, as produced by ``batch_tokenize``; presumably
    all of the same length as ``labels``. Each item is a dict of tensors, one
    per encoding field, plus a ``'labels'`` tensor.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The dataset size is defined by the number of labels.
        return len(self.labels)

    def __getitem__(self, idx):
        item = {}
        for field in self.encodings:
            item[field] = torch.tensor(self.encodings[field][idx])
        item['labels'] = torch.tensor(self.labels[idx])
        return item

In [11]:

def batch_tokenize(X_text, tokenizer, max_length=512, batch_size=64):
    """Tokenize texts in fixed-size chunks and concatenate the results.

    Every text is padded/truncated to exactly ``max_length`` tokens. Chunking by
    ``batch_size`` bounds the size of each individual tokenizer call.

    Args:
        X_text: sliceable sequence of strings (e.g. a NumPy object array).
        tokenizer: Hugging Face-style tokenizer exposing ``batch_encode_plus``.
        max_length: fixed sequence length after padding/truncation.
        batch_size: number of texts passed to each tokenizer call.

    Returns:
        dict mapping each encoding field (e.g. ``input_ids``) to a flat list,
        in input order (one entry per text, assuming the tokenizer returns one
        per input). Empty input yields ``{}``.
    """
    encodings = {}
    n_texts = len(X_text)

    for start in range(0, n_texts, batch_size):
        chunk = list(X_text[start:start + batch_size])
        chunk_encodings = tokenizer.batch_encode_plus(
            chunk,
            padding='max_length',
            truncation=True,
            max_length=max_length,
        )
        for field, values in chunk_encodings.items():
            encodings.setdefault(field, []).extend(values)

    return encodings


### Twitter XLM-RoBERTa

In [12]:
# Tokenizer matching the XLM-RoBERTa checkpoint fine-tuned below.
tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-xlm-roberta-base")

# Pre-tokenize each split once (fixed 512-token padding) and wrap it as a dataset.
train_encodings = batch_tokenize(X_train, tokenizer)
train_dataset = TweetDataset(train_encodings, y_train)

valid_encodings = batch_tokenize(X_valid, tokenizer)
valid_dataset = TweetDataset(valid_encodings, y_valid)

In [13]:
# Data loaders: training batches are reshuffled every epoch, validation order is fixed.
BATCH_SIZE = 32
loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=True)

train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = torch.utils.data.DataLoader(valid_dataset, shuffle=False, **loader_kwargs)

#### XLM-RoBERTa + freezing + early stopping + linear schedule

In [14]:
from transformers import get_linear_schedule_with_warmup

# Model: pretrained XLM-RoBERTa with a freshly initialised 4-way classification
# head, wrapped in DataParallel to use every visible GPU.
MODEL = "cardiffnlp/twitter-xlm-roberta-base"
num_labels = 4
model = torch.nn.DataParallel(
    AutoModelForSequenceClassification.from_pretrained(MODEL, num_labels=num_labels)
)
model.to(DEVICE)

# Optimisation: Adam, with a linear decay from LR to 0 over NUM_EPOCHS epochs (no warm-up).
NUM_EPOCHS = 100
LR = 2e-5
# NOTE: the name `optim` shadows the `torch.optim as optim` import from the setup
# cell; the training cell below relies on this name.
optim = torch.optim.Adam(model.parameters(), lr=LR)
total_steps = NUM_EPOCHS * len(train_loader)
scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=0,
                                            num_training_steps=total_steps)

# Head-only warm-up: freeze every parameter whose name lacks 'classifier'; the
# training loop unfreezes the whole model when the epoch index reaches
# freeze_steps. Frozen parameters get no gradient, so Adam skips them until then.
freeze_steps = 2
for name, param in model.named_parameters():
    param.requires_grad = 'classifier' in name

# Early-stopping state: best validation macro-F1 so far and epochs without improvement.
best_f1 = 0.0
epochs_since_improvement = 0

Some weights of the model checkpoint at cardiffnlp/twitter-xlm-roberta-base were not used when initializing XLMRobertaForSequenceClassification: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.bias']
- This IS expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at cardiffnlp/twitter-xlm-roberta-base and are newly initialized: ['classifier.den

In [15]:
# Train with early stopping on validation macro-F1 (patience: 3 epochs).
# Uses model / optim / scheduler / freeze_steps / best_f1 / epochs_since_improvement
# from the setup cell; the best checkpoint is written to model_dir.
start_time = time.time()

for epoch in range(NUM_EPOCHS):

    model.train()
    # After the head-only warm-up epochs, unfreeze the whole network.
    if epoch == freeze_steps:
        model.requires_grad_(True)

    for batch_idx, batch in enumerate(train_loader):
        ### Prepare data
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        ### Forward
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        # DataParallel gathers one loss per GPU replica; average them.
        loss = outputs['loss'].mean()

        ### Backward
        optim.zero_grad()
        loss.backward()
        optim.step()
        scheduler.step()

        ### Logging
        if batch_idx % 50 == 0:
            print(f'Epoch: {epoch+1:04d}/{NUM_EPOCHS:04d} | '
                  f'Batch {batch_idx:04d}/{len(train_loader):04d} | '
                  f'Loss: {loss:.4f}')

    model.eval()
    with torch.set_grad_enabled(False):
        # Evaluate each split once; the validation result feeds both the log
        # line and the early-stopping decision.
        train_metrics = compute_metrics(model, train_loader, DEVICE)
        valid_metrics = compute_metrics(model, valid_loader, DEVICE)
        print(f'Training metrics: {train_metrics}'
              f'\nValid metrics: {valid_metrics}')

        current_f1 = valid_metrics['f1']
        if current_f1 > best_f1:
            best_f1 = current_f1
            epochs_since_improvement = 0

            # Save the new best model as a whole pickled (DataParallel) module;
            # test_eval reloads it with torch.load.
            path = os.path.join(model_dir, 'XLM_age.pt')
            torch.save(model, path)
        else:
            epochs_since_improvement += 1

    print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')

    # Early stopping
    if epochs_since_improvement >= 3:
        print('Early stopping triggered.')
        break




Epoch: 0001/0100 | Batch 0000/0540 | Loss: 1.3829
Epoch: 0001/0100 | Batch 0050/0540 | Loss: 1.3620
Epoch: 0001/0100 | Batch 0100/0540 | Loss: 1.1555
Epoch: 0001/0100 | Batch 0150/0540 | Loss: 1.3557
Epoch: 0001/0100 | Batch 0200/0540 | Loss: 1.2825
Epoch: 0001/0100 | Batch 0250/0540 | Loss: 1.3009
Epoch: 0001/0100 | Batch 0300/0540 | Loss: 1.3325
Epoch: 0001/0100 | Batch 0350/0540 | Loss: 1.1226
Epoch: 0001/0100 | Batch 0400/0540 | Loss: 1.2715
Epoch: 0001/0100 | Batch 0450/0540 | Loss: 1.3482
Epoch: 0001/0100 | Batch 0500/0540 | Loss: 1.2894
Training metrics: {'accuracy': 0.4927662037037037, 'f1': 0.16596942822605834}%
Valid metrics: {'accuracy': 0.4796875, 'f1': 0.1631158399313904}%
Time elapsed: 6.64 min




Epoch: 0002/0100 | Batch 0000/0540 | Loss: 1.0419
Epoch: 0002/0100 | Batch 0050/0540 | Loss: 1.0933
Epoch: 0002/0100 | Batch 0100/0540 | Loss: 1.1582
Epoch: 0002/0100 | Batch 0150/0540 | Loss: 1.0571
Epoch: 0002/0100 | Batch 0200/0540 | Loss: 1.3001
Epoch: 0002/0100 | Batch 0250/0540 | Loss: 1.1779
Epoch: 0002/0100 | Batch 0300/0540 | Loss: 1.2524
Epoch: 0002/0100 | Batch 0350/0540 | Loss: 1.1405
Epoch: 0002/0100 | Batch 0400/0540 | Loss: 1.1088
Epoch: 0002/0100 | Batch 0450/0540 | Loss: 1.1620
Epoch: 0002/0100 | Batch 0500/0540 | Loss: 1.0243
Training metrics: {'accuracy': 0.49693287037037037, 'f1': 0.18529401540661372}%
Valid metrics: {'accuracy': 0.48333333333333334, 'f1': 0.18066017400122505}%
Time elapsed: 12.99 min




Epoch: 0003/0100 | Batch 0000/0540 | Loss: 1.1573
Epoch: 0003/0100 | Batch 0050/0540 | Loss: 0.9898
Epoch: 0003/0100 | Batch 0100/0540 | Loss: 0.9920
Epoch: 0003/0100 | Batch 0150/0540 | Loss: 1.1018
Epoch: 0003/0100 | Batch 0200/0540 | Loss: 0.9470
Epoch: 0003/0100 | Batch 0250/0540 | Loss: 0.9528
Epoch: 0003/0100 | Batch 0300/0540 | Loss: 1.5733
Epoch: 0003/0100 | Batch 0350/0540 | Loss: 0.8567
Epoch: 0003/0100 | Batch 0400/0540 | Loss: 0.9177
Epoch: 0003/0100 | Batch 0450/0540 | Loss: 1.1659
Epoch: 0003/0100 | Batch 0500/0540 | Loss: 0.9696
Training metrics: {'accuracy': 0.6236111111111111, 'f1': 0.4334317300468655}%
Valid metrics: {'accuracy': 0.61875, 'f1': 0.4342393293117124}%
Time elapsed: 22.23 min




Epoch: 0004/0100 | Batch 0000/0540 | Loss: 1.1180
Epoch: 0004/0100 | Batch 0050/0540 | Loss: 1.1963
Epoch: 0004/0100 | Batch 0100/0540 | Loss: 0.8322
Epoch: 0004/0100 | Batch 0150/0540 | Loss: 1.2137
Epoch: 0004/0100 | Batch 0200/0540 | Loss: 1.1202
Epoch: 0004/0100 | Batch 0250/0540 | Loss: 0.9887
Epoch: 0004/0100 | Batch 0300/0540 | Loss: 0.8698
Epoch: 0004/0100 | Batch 0350/0540 | Loss: 1.3320
Epoch: 0004/0100 | Batch 0400/0540 | Loss: 0.9983
Epoch: 0004/0100 | Batch 0450/0540 | Loss: 0.7749
Epoch: 0004/0100 | Batch 0500/0540 | Loss: 1.1106
Training metrics: {'accuracy': 0.6421296296296296, 'f1': 0.47908522410674415}%
Valid metrics: {'accuracy': 0.6083333333333333, 'f1': 0.45220331212912634}%
Time elapsed: 31.47 min




Epoch: 0005/0100 | Batch 0000/0540 | Loss: 0.8789
Epoch: 0005/0100 | Batch 0050/0540 | Loss: 0.9823
Epoch: 0005/0100 | Batch 0100/0540 | Loss: 0.9081
Epoch: 0005/0100 | Batch 0150/0540 | Loss: 0.7682
Epoch: 0005/0100 | Batch 0200/0540 | Loss: 0.6690
Epoch: 0005/0100 | Batch 0250/0540 | Loss: 0.8114
Epoch: 0005/0100 | Batch 0300/0540 | Loss: 0.8254
Epoch: 0005/0100 | Batch 0350/0540 | Loss: 1.1351
Epoch: 0005/0100 | Batch 0400/0540 | Loss: 0.9305
Epoch: 0005/0100 | Batch 0450/0540 | Loss: 0.8068
Epoch: 0005/0100 | Batch 0500/0540 | Loss: 0.8358
Training metrics: {'accuracy': 0.692650462962963, 'f1': 0.5952034335671245}%
Valid metrics: {'accuracy': 0.6213541666666667, 'f1': 0.5090097437781405}%
Time elapsed: 40.71 min




Epoch: 0006/0100 | Batch 0000/0540 | Loss: 0.9637
Epoch: 0006/0100 | Batch 0050/0540 | Loss: 0.7158
Epoch: 0006/0100 | Batch 0100/0540 | Loss: 0.8872
Epoch: 0006/0100 | Batch 0150/0540 | Loss: 0.9394
Epoch: 0006/0100 | Batch 0200/0540 | Loss: 1.1758
Epoch: 0006/0100 | Batch 0250/0540 | Loss: 0.6760
Epoch: 0006/0100 | Batch 0300/0540 | Loss: 0.8018
Epoch: 0006/0100 | Batch 0350/0540 | Loss: 0.8894
Epoch: 0006/0100 | Batch 0400/0540 | Loss: 0.7711
Epoch: 0006/0100 | Batch 0450/0540 | Loss: 0.6981
Epoch: 0006/0100 | Batch 0500/0540 | Loss: 0.6864
Training metrics: {'accuracy': 0.7444444444444445, 'f1': 0.6253151997321698}%
Valid metrics: {'accuracy': 0.6421875, 'f1': 0.5103898457479631}%
Time elapsed: 49.95 min




Epoch: 0007/0100 | Batch 0000/0540 | Loss: 0.7955
Epoch: 0007/0100 | Batch 0050/0540 | Loss: 0.7709
Epoch: 0007/0100 | Batch 0100/0540 | Loss: 0.5515
Epoch: 0007/0100 | Batch 0150/0540 | Loss: 0.7021
Epoch: 0007/0100 | Batch 0200/0540 | Loss: 0.6487
Epoch: 0007/0100 | Batch 0250/0540 | Loss: 0.6602
Epoch: 0007/0100 | Batch 0300/0540 | Loss: 0.6832
Epoch: 0007/0100 | Batch 0350/0540 | Loss: 0.6995
Epoch: 0007/0100 | Batch 0400/0540 | Loss: 0.5912
Epoch: 0007/0100 | Batch 0450/0540 | Loss: 0.5852
Epoch: 0007/0100 | Batch 0500/0540 | Loss: 0.8041
Training metrics: {'accuracy': 0.7733796296296296, 'f1': 0.6756024303078131}%
Valid metrics: {'accuracy': 0.6354166666666666, 'f1': 0.5030266903295681}%
Time elapsed: 59.16 min




Epoch: 0008/0100 | Batch 0000/0540 | Loss: 0.6836
Epoch: 0008/0100 | Batch 0050/0540 | Loss: 0.5062
Epoch: 0008/0100 | Batch 0100/0540 | Loss: 0.5453
Epoch: 0008/0100 | Batch 0150/0540 | Loss: 0.7729
Epoch: 0008/0100 | Batch 0200/0540 | Loss: 0.9313
Epoch: 0008/0100 | Batch 0250/0540 | Loss: 0.4673
Epoch: 0008/0100 | Batch 0300/0540 | Loss: 0.6362
Epoch: 0008/0100 | Batch 0350/0540 | Loss: 0.4056
Epoch: 0008/0100 | Batch 0400/0540 | Loss: 0.6429
Epoch: 0008/0100 | Batch 0450/0540 | Loss: 0.9291
Epoch: 0008/0100 | Batch 0500/0540 | Loss: 0.8379
Training metrics: {'accuracy': 0.8168402777777778, 'f1': 0.7506255269804499}%
Valid metrics: {'accuracy': 0.61875, 'f1': 0.5260190040909731}%
Time elapsed: 68.40 min




Epoch: 0009/0100 | Batch 0000/0540 | Loss: 0.5439
Epoch: 0009/0100 | Batch 0050/0540 | Loss: 0.5876
Epoch: 0009/0100 | Batch 0100/0540 | Loss: 0.4113
Epoch: 0009/0100 | Batch 0150/0540 | Loss: 0.9630
Epoch: 0009/0100 | Batch 0200/0540 | Loss: 0.4217
Epoch: 0009/0100 | Batch 0250/0540 | Loss: 0.4495
Epoch: 0009/0100 | Batch 0300/0540 | Loss: 0.6637
Epoch: 0009/0100 | Batch 0350/0540 | Loss: 0.3469
Epoch: 0009/0100 | Batch 0400/0540 | Loss: 0.7618
Epoch: 0009/0100 | Batch 0450/0540 | Loss: 0.5045
Epoch: 0009/0100 | Batch 0500/0540 | Loss: 0.6551
Training metrics: {'accuracy': 0.8526041666666667, 'f1': 0.796701066213249}%
Valid metrics: {'accuracy': 0.5989583333333334, 'f1': 0.5161281453485982}%
Time elapsed: 77.61 min




Epoch: 0010/0100 | Batch 0000/0540 | Loss: 0.4720
Epoch: 0010/0100 | Batch 0050/0540 | Loss: 0.5126
Epoch: 0010/0100 | Batch 0100/0540 | Loss: 0.5282
Epoch: 0010/0100 | Batch 0150/0540 | Loss: 0.2724
Epoch: 0010/0100 | Batch 0200/0540 | Loss: 0.3552
Epoch: 0010/0100 | Batch 0250/0540 | Loss: 0.2772
Epoch: 0010/0100 | Batch 0300/0540 | Loss: 0.4141
Epoch: 0010/0100 | Batch 0350/0540 | Loss: 0.3837
Epoch: 0010/0100 | Batch 0400/0540 | Loss: 0.5272
Epoch: 0010/0100 | Batch 0450/0540 | Loss: 0.7653
Epoch: 0010/0100 | Batch 0500/0540 | Loss: 0.5062
Training metrics: {'accuracy': 0.907349537037037, 'f1': 0.8730322845044426}%
Valid metrics: {'accuracy': 0.6114583333333333, 'f1': 0.5252348915746545}%
Time elapsed: 87.01 min




Epoch: 0011/0100 | Batch 0000/0540 | Loss: 0.3879
Epoch: 0011/0100 | Batch 0050/0540 | Loss: 0.7963
Epoch: 0011/0100 | Batch 0100/0540 | Loss: 0.3133
Epoch: 0011/0100 | Batch 0150/0540 | Loss: 0.3248
Epoch: 0011/0100 | Batch 0200/0540 | Loss: 0.3674
Epoch: 0011/0100 | Batch 0250/0540 | Loss: 0.4980
Epoch: 0011/0100 | Batch 0300/0540 | Loss: 0.3030
Epoch: 0011/0100 | Batch 0350/0540 | Loss: 0.2211
Epoch: 0011/0100 | Batch 0400/0540 | Loss: 0.2472
Epoch: 0011/0100 | Batch 0450/0540 | Loss: 0.3385
Epoch: 0011/0100 | Batch 0500/0540 | Loss: 0.3872
Training metrics: {'accuracy': 0.9189814814814815, 'f1': 0.8917001868070116}%
Valid metrics: {'accuracy': 0.5817708333333333, 'f1': 0.5118746096345851}%
Early stopping triggered.


In [16]:
# Copy the best XLM-RoBERTa checkpoint into the shared work area
# (<uc_dir>/trained_models/age/). makedirs(exist_ok=True) keeps the cell re-runnable.
import shutil

path = os.path.join(uc_dir, 'trained_models')
sub_path = os.path.join(path, 'age')
os.makedirs(sub_path, exist_ok=True)

# XLM
source_path = os.path.join(model_dir, 'XLM_age.pt')
dest_path = os.path.join(sub_path, 'XLM_age.pt')
shutil.copy(source_path, dest_path)

'/g100_work/IscrC_mental/data/user_classification/trained_models/age/XLM_age.pt'

## bert-tweet-base-italian-uncased

In [24]:
# Tokenizer matching the Italian BERTweet checkpoint fine-tuned below.
tokenizer = AutoTokenizer.from_pretrained("osiria/bert-tweet-base-italian-uncased")

# Pre-tokenize each split once (fixed 512-token padding) and wrap it as a dataset.
train_encodings = batch_tokenize(X_train, tokenizer)
train_dataset = TweetDataset(train_encodings, y_train)

valid_encodings = batch_tokenize(X_valid, tokenizer)
valid_dataset = TweetDataset(valid_encodings, y_valid)

In [25]:
# Data loaders: training batches are reshuffled every epoch, validation order is fixed.
BATCH_SIZE = 32
loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=True)

train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = torch.utils.data.DataLoader(valid_dataset, shuffle=False, **loader_kwargs)

#### Bertweet + freeze + early stopping + linear schedule

In [27]:
from transformers import get_linear_schedule_with_warmup

# Model: pretrained Italian BERTweet with a freshly initialised 4-way
# classification head, wrapped in DataParallel to use every visible GPU.
MODEL = "osiria/bert-tweet-base-italian-uncased"
num_labels = 4
model = torch.nn.DataParallel(
    AutoModelForSequenceClassification.from_pretrained(MODEL, num_labels=num_labels)
)
model.to(DEVICE)

# Optimisation: Adam, with a linear decay from LR to 0 over NUM_EPOCHS epochs (no warm-up).
NUM_EPOCHS = 100
LR = 2e-5
# NOTE: the name `optim` shadows the `torch.optim as optim` import from the setup
# cell; the training cell below relies on this name.
optim = torch.optim.Adam(model.parameters(), lr=LR)
total_steps = NUM_EPOCHS * len(train_loader)
scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=0,
                                            num_training_steps=total_steps)

# Head-only warm-up: freeze every parameter whose name lacks 'classifier'; the
# training loop unfreezes the whole model when the epoch index reaches
# freeze_steps. Frozen parameters get no gradient, so Adam skips them until then.
# NOTE(review): the load warning lists newly initialised weights starting with
# 'bert' (output truncated — presumably the pooler). They are frozen during the
# warm-up as well, which may explain the majority-class-only first two epochs. Confirm.
freeze_steps = 2
for name, param in model.named_parameters():
    param.requires_grad = 'classifier' in name

# Early-stopping state: best validation macro-F1 so far and epochs without improvement.
best_f1 = 0.0
epochs_since_improvement = 0

Some weights of the model checkpoint at osiria/bert-tweet-base-italian-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at osiria/bert-tweet-base-italian-uncased and are newly initialized: ['bert

In [29]:
# Train with early stopping on validation macro-F1 (patience: 3 epochs).
# Uses model / optim / scheduler / freeze_steps / best_f1 / epochs_since_improvement
# from the setup cell; the best checkpoint is written to model_dir.
start_time = time.time()

for epoch in range(NUM_EPOCHS):

    model.train()
    # After the head-only warm-up epochs, unfreeze the whole network.
    if epoch == freeze_steps:
        model.requires_grad_(True)

    for batch_idx, batch in enumerate(train_loader):
        ### Prepare data
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        ### Forward
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        # DataParallel gathers one loss per GPU replica; average them.
        loss = outputs['loss'].mean()

        ### Backward
        optim.zero_grad()
        loss.backward()
        optim.step()
        scheduler.step()

        ### Logging
        if batch_idx % 50 == 0:
            print(f'Epoch: {epoch+1:04d}/{NUM_EPOCHS:04d} | '
                  f'Batch {batch_idx:04d}/{len(train_loader):04d} | '
                  f'Loss: {loss:.4f}')

    model.eval()
    with torch.set_grad_enabled(False):
        # Evaluate each split once; the validation result feeds both the log
        # line and the early-stopping decision.
        train_metrics = compute_metrics(model, train_loader, DEVICE)
        valid_metrics = compute_metrics(model, valid_loader, DEVICE)
        print(f'Training metrics: {train_metrics}'
              f'\nValid metrics: {valid_metrics}')

        current_f1 = valid_metrics['f1']
        if current_f1 > best_f1:
            best_f1 = current_f1
            epochs_since_improvement = 0

            # Save the new best model as a whole pickled (DataParallel) module;
            # test_eval reloads it with torch.load.
            path = os.path.join(model_dir, 'bertweet_italian_age.pt')
            torch.save(model, path)
        else:
            epochs_since_improvement += 1

    print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')

    # Early stopping
    if epochs_since_improvement >= 3:
        print('Early stopping triggered.')
        break




Epoch: 0001/0100 | Batch 0000/0540 | Loss: 1.4175
Epoch: 0001/0100 | Batch 0050/0540 | Loss: 1.3475
Epoch: 0001/0100 | Batch 0100/0540 | Loss: 1.3008
Epoch: 0001/0100 | Batch 0150/0540 | Loss: 1.2773
Epoch: 0001/0100 | Batch 0200/0540 | Loss: 1.1173
Epoch: 0001/0100 | Batch 0250/0540 | Loss: 1.2485
Epoch: 0001/0100 | Batch 0300/0540 | Loss: 1.3703
Epoch: 0001/0100 | Batch 0350/0540 | Loss: 1.2290
Epoch: 0001/0100 | Batch 0400/0540 | Loss: 1.1605
Epoch: 0001/0100 | Batch 0450/0540 | Loss: 1.4217
Epoch: 0001/0100 | Batch 0500/0540 | Loss: 1.1619
Training metrics: {'accuracy': 0.4925347222222222, 'f1': 0.16499941840176804}%
Valid metrics: {'accuracy': 0.4791666666666667, 'f1': 0.1619718309859155}%
Time elapsed: 5.03 min




Epoch: 0002/0100 | Batch 0000/0540 | Loss: 1.1757
Epoch: 0002/0100 | Batch 0050/0540 | Loss: 1.0646
Epoch: 0002/0100 | Batch 0100/0540 | Loss: 1.1214
Epoch: 0002/0100 | Batch 0150/0540 | Loss: 1.4042
Epoch: 0002/0100 | Batch 0200/0540 | Loss: 1.2427
Epoch: 0002/0100 | Batch 0250/0540 | Loss: 1.2560
Epoch: 0002/0100 | Batch 0300/0540 | Loss: 1.4770
Epoch: 0002/0100 | Batch 0350/0540 | Loss: 1.3571
Epoch: 0002/0100 | Batch 0400/0540 | Loss: 1.1915
Epoch: 0002/0100 | Batch 0450/0540 | Loss: 1.1563
Epoch: 0002/0100 | Batch 0500/0540 | Loss: 1.2304
Training metrics: {'accuracy': 0.4925347222222222, 'f1': 0.1650058162078325}%
Valid metrics: {'accuracy': 0.4791666666666667, 'f1': 0.1619718309859155}%
Time elapsed: 9.92 min




Epoch: 0003/0100 | Batch 0000/0540 | Loss: 1.0302
Epoch: 0003/0100 | Batch 0050/0540 | Loss: 1.0997
Epoch: 0003/0100 | Batch 0100/0540 | Loss: 1.0758
Epoch: 0003/0100 | Batch 0150/0540 | Loss: 1.0697
Epoch: 0003/0100 | Batch 0200/0540 | Loss: 1.0629
Epoch: 0003/0100 | Batch 0250/0540 | Loss: 0.8222
Epoch: 0003/0100 | Batch 0300/0540 | Loss: 1.1140
Epoch: 0003/0100 | Batch 0350/0540 | Loss: 1.0684
Epoch: 0003/0100 | Batch 0400/0540 | Loss: 0.9221
Epoch: 0003/0100 | Batch 0450/0540 | Loss: 1.0059
Epoch: 0003/0100 | Batch 0500/0540 | Loss: 1.0371
Training metrics: {'accuracy': 0.6186921296296296, 'f1': 0.4319826400562383}%
Valid metrics: {'accuracy': 0.5953125, 'f1': 0.4125443041992446}%
Time elapsed: 17.87 min




Epoch: 0004/0100 | Batch 0000/0540 | Loss: 0.7283
Epoch: 0004/0100 | Batch 0050/0540 | Loss: 0.9104
Epoch: 0004/0100 | Batch 0100/0540 | Loss: 1.0668
Epoch: 0004/0100 | Batch 0150/0540 | Loss: 0.8832
Epoch: 0004/0100 | Batch 0200/0540 | Loss: 0.9238
Epoch: 0004/0100 | Batch 0250/0540 | Loss: 0.9687
Epoch: 0004/0100 | Batch 0300/0540 | Loss: 1.0565
Epoch: 0004/0100 | Batch 0350/0540 | Loss: 0.6881
Epoch: 0004/0100 | Batch 0400/0540 | Loss: 0.9972
Epoch: 0004/0100 | Batch 0450/0540 | Loss: 0.7962
Epoch: 0004/0100 | Batch 0500/0540 | Loss: 1.0730
Training metrics: {'accuracy': 0.6798032407407407, 'f1': 0.5195867016162029}%
Valid metrics: {'accuracy': 0.640625, 'f1': 0.4740931779618318}%
Time elapsed: 25.79 min




Epoch: 0005/0100 | Batch 0000/0540 | Loss: 0.9594
Epoch: 0005/0100 | Batch 0050/0540 | Loss: 0.6779
Epoch: 0005/0100 | Batch 0100/0540 | Loss: 0.8352
Epoch: 0005/0100 | Batch 0150/0540 | Loss: 0.9399
Epoch: 0005/0100 | Batch 0200/0540 | Loss: 0.7773
Epoch: 0005/0100 | Batch 0250/0540 | Loss: 0.6127
Epoch: 0005/0100 | Batch 0300/0540 | Loss: 1.0567
Epoch: 0005/0100 | Batch 0350/0540 | Loss: 0.7110
Epoch: 0005/0100 | Batch 0400/0540 | Loss: 0.7260
Epoch: 0005/0100 | Batch 0450/0540 | Loss: 0.7137
Epoch: 0005/0100 | Batch 0500/0540 | Loss: 0.9893
Training metrics: {'accuracy': 0.7183449074074074, 'f1': 0.5681179676382347}%
Valid metrics: {'accuracy': 0.6473958333333333, 'f1': 0.48953533400125804}%
Time elapsed: 33.71 min




Epoch: 0006/0100 | Batch 0000/0540 | Loss: 0.9362
Epoch: 0006/0100 | Batch 0050/0540 | Loss: 0.4671
Epoch: 0006/0100 | Batch 0100/0540 | Loss: 0.8211
Epoch: 0006/0100 | Batch 0150/0540 | Loss: 0.6166
Epoch: 0006/0100 | Batch 0200/0540 | Loss: 0.6039
Epoch: 0006/0100 | Batch 0250/0540 | Loss: 0.9591
Epoch: 0006/0100 | Batch 0300/0540 | Loss: 0.8554
Epoch: 0006/0100 | Batch 0350/0540 | Loss: 0.8246
Epoch: 0006/0100 | Batch 0400/0540 | Loss: 0.6421
Epoch: 0006/0100 | Batch 0450/0540 | Loss: 0.9264
Epoch: 0006/0100 | Batch 0500/0540 | Loss: 0.8052
Training metrics: {'accuracy': 0.7668402777777777, 'f1': 0.6636790528764708}%
Valid metrics: {'accuracy': 0.6354166666666666, 'f1': 0.5024553474457967}%
Time elapsed: 41.63 min




Epoch: 0007/0100 | Batch 0000/0540 | Loss: 0.9467
Epoch: 0007/0100 | Batch 0050/0540 | Loss: 0.7749
Epoch: 0007/0100 | Batch 0100/0540 | Loss: 0.7063
Epoch: 0007/0100 | Batch 0150/0540 | Loss: 0.5497
Epoch: 0007/0100 | Batch 0200/0540 | Loss: 0.9245
Epoch: 0007/0100 | Batch 0250/0540 | Loss: 0.7558
Epoch: 0007/0100 | Batch 0300/0540 | Loss: 0.5222
Epoch: 0007/0100 | Batch 0350/0540 | Loss: 0.8637
Epoch: 0007/0100 | Batch 0400/0540 | Loss: 0.6743
Epoch: 0007/0100 | Batch 0450/0540 | Loss: 0.8855
Epoch: 0007/0100 | Batch 0500/0540 | Loss: 0.6149
Training metrics: {'accuracy': 0.8184606481481481, 'f1': 0.746813287203581}%
Valid metrics: {'accuracy': 0.6322916666666667, 'f1': 0.5364795008912656}%
Time elapsed: 49.55 min




Epoch: 0008/0100 | Batch 0000/0540 | Loss: 1.1604
Epoch: 0008/0100 | Batch 0050/0540 | Loss: 0.5597
Epoch: 0008/0100 | Batch 0100/0540 | Loss: 0.4237
Epoch: 0008/0100 | Batch 0150/0540 | Loss: 0.6647
Epoch: 0008/0100 | Batch 0200/0540 | Loss: 0.6939
Epoch: 0008/0100 | Batch 0250/0540 | Loss: 0.4937
Epoch: 0008/0100 | Batch 0300/0540 | Loss: 0.7225
Epoch: 0008/0100 | Batch 0350/0540 | Loss: 0.6767
Epoch: 0008/0100 | Batch 0400/0540 | Loss: 0.2811
Epoch: 0008/0100 | Batch 0450/0540 | Loss: 0.5824
Epoch: 0008/0100 | Batch 0500/0540 | Loss: 0.4015
Training metrics: {'accuracy': 0.8623842592592592, 'f1': 0.8157309846582538}%
Valid metrics: {'accuracy': 0.6057291666666667, 'f1': 0.5330890378702301}%
Time elapsed: 57.46 min




Epoch: 0009/0100 | Batch 0000/0540 | Loss: 0.5141
Epoch: 0009/0100 | Batch 0050/0540 | Loss: 0.2400
Epoch: 0009/0100 | Batch 0100/0540 | Loss: 0.5066
Epoch: 0009/0100 | Batch 0150/0540 | Loss: 0.5469
Epoch: 0009/0100 | Batch 0200/0540 | Loss: 0.3639
Epoch: 0009/0100 | Batch 0250/0540 | Loss: 0.4560
Epoch: 0009/0100 | Batch 0300/0540 | Loss: 0.5901
Epoch: 0009/0100 | Batch 0350/0540 | Loss: 0.4849
Epoch: 0009/0100 | Batch 0400/0540 | Loss: 0.5011
Epoch: 0009/0100 | Batch 0450/0540 | Loss: 0.2755
Epoch: 0009/0100 | Batch 0500/0540 | Loss: 0.2781
Training metrics: {'accuracy': 0.8854166666666666, 'f1': 0.8432749079863538}%
Valid metrics: {'accuracy': 0.6057291666666667, 'f1': 0.5251013036546821}%
Time elapsed: 65.38 min




Epoch: 0010/0100 | Batch 0000/0540 | Loss: 0.3254
Epoch: 0010/0100 | Batch 0050/0540 | Loss: 0.5289
Epoch: 0010/0100 | Batch 0100/0540 | Loss: 0.3251
Epoch: 0010/0100 | Batch 0150/0540 | Loss: 0.3055
Epoch: 0010/0100 | Batch 0200/0540 | Loss: 0.3582
Epoch: 0010/0100 | Batch 0250/0540 | Loss: 0.1760
Epoch: 0010/0100 | Batch 0300/0540 | Loss: 0.2200
Epoch: 0010/0100 | Batch 0350/0540 | Loss: 0.4280
Epoch: 0010/0100 | Batch 0400/0540 | Loss: 0.1992
Epoch: 0010/0100 | Batch 0450/0540 | Loss: 0.2806
Epoch: 0010/0100 | Batch 0500/0540 | Loss: 0.2380
Training metrics: {'accuracy': 0.9251157407407408, 'f1': 0.8956377862939804}%
Valid metrics: {'accuracy': 0.6125, 'f1': 0.5190090215460894}%
Early stopping triggered.


In [None]:
# Copy the best BERTweet checkpoint into the shared work area
# (<uc_dir>/trained_models/age/). makedirs(exist_ok=True) keeps the cell re-runnable.
import shutil

path = os.path.join(uc_dir, 'trained_models')
sub_path = os.path.join(path, 'age')
os.makedirs(sub_path, exist_ok=True)

# BERTweet (Italian)
source_path = os.path.join(model_dir, 'bertweet_italian_age.pt')
dest_path = os.path.join(sub_path, 'bertweet_italian_age.pt')
shutil.copy(source_path, dest_path)

## Test Set Performance

In [12]:
# Load the held-out test users.
# pd.read_pickle unpickles arbitrary objects: only load trusted files.
path = os.path.join(uc_dir, 'data_for_models_test.pkl')
df_test = pd.read_pickle(path)

# Same preprocessing as the training data. Left-closed age bins:
# [0, 19) -> 0, [19, 30) -> 1, [30, 40) -> 2, [40, 100) -> 3.
age_intervals = [0, 19, 30, 40, 100]
age_labels = [0, 1, 2, 3]
df_test['age_class'] = pd.cut(df_test['age'], bins=age_intervals, labels=age_labels, right=False)

# Ages outside [0, 100) (or missing) get no class and would become NaN labels.
n_unlabeled = int(df_test['age_class'].isna().sum())
assert n_unlabeled == 0, f'{n_unlabeled} test users have an age outside [0, 100) or missing'

# A missing bio or tweet text would otherwise turn the whole input into NaN.
df_test['text'] = ('bio: ' + df_test['masked_bio'].fillna('') + '. '
                   + 'tweets: ' + df_test['long_text'].fillna(''))
df_test['text'] = df_test['text'].str.replace('\r|\n', ' ', regex=True)

X_test = df_test['text'].values
y_test = df_test['age_class'].values

In [13]:
# One tokenizer per fine-tuned model: each checkpoint must see its own vocabulary.
xlm_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-xlm-roberta-base")
btwt_tokenizer = AutoTokenizer.from_pretrained("osiria/bert-tweet-base-italian-uncased")

# Tokenize the test texts once per tokenizer; both datasets share the same labels.
xlm_test_encodings = batch_tokenize(X_test, xlm_tokenizer)
xlm_test_dataset = TweetDataset(xlm_test_encodings, y_test)

btwt_test_encodings = batch_tokenize(X_test, btwt_tokenizer)
btwt_test_dataset = TweetDataset(btwt_test_encodings, y_test)

In [14]:
# Test loaders: fixed order (shuffle=False), so predicted probabilities line up with df_test rows.
BATCH_SIZE = 32
test_loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)

xlm_loader = torch.utils.data.DataLoader(xlm_test_dataset, **test_loader_kwargs)
btwt_loader = torch.utils.data.DataLoader(btwt_test_dataset, **test_loader_kwargs)

In [32]:
def test_eval(model_path, data_loader, device):
    """Load a saved classifier and evaluate it on a data loader in a single pass.

    Args:
        model_path: path to a whole model pickled with ``torch.save(model, path)``,
            as the training cells do. ``torch.load`` unpickles arbitrary code, so
            only load checkpoints you trust.
        data_loader: iterable of dicts holding ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors. Keep it unshuffled if
            the returned probabilities must align with the input rows.
        device: device to run inference on.

    Returns:
        ``(metrics, class_probs)``. ``metrics`` is ``{'accuracy': ..., 'f1': ...}``,
        with F1 macro-averaged (the same definition as ``compute_metrics``).
        ``class_probs`` is a NumPy array with one row of softmax probabilities
        per example.
    """
    # map_location lets a checkpoint saved on GPU be loaded onto `device`
    # (e.g. on a CPU-only machine).
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which refuses whole
    # pickled modules; pass weights_only=False there (trusted files only).
    model = torch.load(model_path, map_location=device)
    model = model.to(device)

    # Set the model to evaluation mode (disables dropout).
    model.eval()

    class_probs, all_predictions, all_labels = [], [], []
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            logits = model(input_ids, attention_mask=attention_mask)['logits']

            # Probabilities and hard predictions come from the same forward
            # pass, so the test set is only run through the model once.
            class_probs.extend(F.softmax(logits, dim=1).cpu().numpy().tolist())
            all_predictions.extend(torch.argmax(logits, 1).cpu().numpy())
            all_labels.extend(batch['labels'].cpu().numpy())

    metrics = {
        'accuracy': accuracy_score(all_labels, all_predictions),
        'f1': f1_score(all_labels, all_predictions, average='macro'),
    }

    return metrics, np.array(class_probs)


In [20]:
# Best checkpoints written by the two training loops.
path_xlm, path_btwt = (os.path.join(model_dir, fname)
                       for fname in ('XLM_age.pt', 'bertweet_italian_age.pt'))

In [43]:
# XLM-RoBERTa: test-set metrics plus per-user class probabilities.
metrics, xlm_probs = test_eval(path_xlm, xlm_loader, DEVICE)

# Persist the probabilities next to the copied checkpoints (row order follows the unshuffled test loader).
path = os.path.join(uc_dir, 'trained_models', 'age', 'XLM_probs_age.npy')
np.save(path, xlm_probs)

print(metrics)

{'accuracy': 0.6133928571428572, 'f1': 0.5507302799158431}


In [44]:
# Italian BERTweet: test-set metrics plus per-user class probabilities.
metrics, btwt_probs = test_eval(path_btwt, btwt_loader, DEVICE)

# Persist the probabilities next to the copied checkpoints (row order follows the unshuffled test loader).
path = os.path.join(uc_dir, 'trained_models', 'age', 'BerTweet_probs_age.npy')
np.save(path, btwt_probs)

print(metrics)

{'accuracy': 0.60625, 'f1': 0.5330589059575411}


In [30]:
# Ensure <uc_dir>/trained_models/age exists. makedirs(exist_ok=True) is
# idempotent and also creates any missing parent directories.
path = os.path.join(uc_dir, 'trained_models')
sub_path = os.path.join(path, 'age')
os.makedirs(sub_path, exist_ok=True)

In [28]:
# # XLM
# import shutil
# source_path =  os.path.join(model_dir ,'XLM_gender.pt')
# dest_path = os.path.join(uc_dir, 'trained_models', 'gender', 'XLM_gender.pt')
# shutil.copy(source_path, dest_path)

'/g100_work/IscrC_mental/data/user_classification/trained_models/gender/XLM_gender.pt'

In [32]:
# Inspection only: show the directory the checkpoints are read from.
model_dir

'/g100/home/userexternal/mhabibi0/Models/Age'

In [36]:
# # BERTweet
# import shutil
# source_path =  os.path.join(model_dir ,'bertweet_italian_age.pt')
# dest_path = os.path.join(sub_path, 'BerTweet_age.pt')
# shutil.copy(source_path, dest_path)

'/g100_work/IscrC_mental/data/user_classification/trained_models/age/BerTweet_age.pt'

In [34]:
# Inspection only: show the destination of the most recent checkpoint copy.
dest_path

'/g100_work/IscrC_mental/data/user_classification/trained_models/age/BerTweet_age.pt'