## Setup

In [1]:
import numpy as np
import pandas as pd
import os
from transformers import pipeline, set_seed
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import time
from sklearn.model_selection import train_test_split

In [2]:
# random seed
# SEED = 1
# torch.backends.cudnn.deterministic = True
# torch.manual_seed(SEED)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
torch.cuda.device_count()

2

In [4]:
DEVICE

device(type='cuda')

In [5]:
home_dir = '/g100/home/userexternal/mhabibi0/'
work_dir = '/g100_work/IscrC_mental'
scratch = '/g100_scratch/userexternal/mhabibi0'

hdata_dir = os.path.join(home_dir, 'Data')
wdata_dir = os.path.join(work_dir, 'data')
uc_dir = os.path.join(wdata_dir, 'user_classification')
model_dir = os.path.join(scratch, 'Models', 'Age')

## Utils

In [6]:
from sklearn.metrics import f1_score, accuracy_score

def compute_metrics(model, data_loader, device):
    with torch.no_grad():
        all_predictions = []
        all_labels = []

        for batch_idx, batch in enumerate(data_loader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs['logits']

            predicted_labels = torch.argmax(logits, 1)

            all_predictions.extend(predicted_labels.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

        accuracy = accuracy_score(all_labels, all_predictions)
        macro_f1 = f1_score(all_labels, all_predictions, average='macro')

        metrics = {
            'accuracy': accuracy,
            'f1': macro_f1
        }

        return metrics


In [7]:
class TweetDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [8]:

def batch_tokenize(X_text, tokenizer, max_length=512, batch_size=64):

    # Dictionary to hold tokenized batches
    encodings = {}

    # Calculate the number of batches needed
    num_batches = len(X_text) // batch_size + int(len(X_text) % batch_size > 0)

    # Iterate over the data in batches
    for i in range(num_batches):
        batch_start = i * batch_size
        batch_end = min(len(X_text), (i + 1) * batch_size)

        # Tokenize the current batch of texts
        batch_encodings = tokenizer.batch_encode_plus(
            list(X_text[batch_start:batch_end]),
            padding='max_length',
            truncation=True,
            max_length=max_length
        )

        # Merge the batch tokenizations into the main dictionary
        for key, val in batch_encodings.items():
            if key not in encodings:
                encodings[key] = []
            encodings[key].extend(val)

    return encodings


##  Training

In [9]:
# user age training data 
path  = os.path.join(uc_dir, 'data_for_models_train.pkl')
df = pd.read_pickle(path)


# user info data
path = os.path.join(wdata_dir, 'database', 'user_geocoded.parquet' )
df_users = pd.read_parquet(path)
df_users = df_users[['user_id', 'username', 'full_name', 'location',
          'join_year', 'tweets', 'following', 'followers']]

df_users = df_users[df_users['user_id'].isin(df['user_id'].values)]


# merge 
df = df.merge(df_users, on='user_id', how='left')




# Discretize the 'age' column into six classes
age_intervals = [0, 18, 30, 40, 60, 100]
age_labels = [0, 1, 2, 3, 4]
df = df[df['age']<=99]
df['age_class'] = pd.cut(df['age'], bins=age_intervals, labels=age_labels, right=False).astype(int)

# create input text
# Separating text and numbers with a space
df['username_sep'] = df['username'].str.replace(r'([a-zA-Z])(\d)', r'\1 \2').\
                    str.replace(r'(\d)([a-zA-Z])', r'\1 \2')
# concat info
df['text']  = 'NAME:' + ' "' + df['full_name'] + '". ' +\
                'USERNAME:' + ' "'+  df['username_sep'] + '". ' + \
                'JOINED:' + ' "' + df['join_year'].astype(str) + '". ' +\
                'TWEETS:' + ' "' + df['tweets'].astype(str) + '". ' + \
                'FOLLOWING:' + ' "' + df['following'].astype(str) + '". ' +\
                'FOLLOWERS:' + ' "' + df['followers'].astype(str) + '". ' + \
                'BIO:' + ' "' + df['masked_bio'] + '". ' + \
                'TEXT:' + ' "' + df['long_text'] + '".'

# train valid split
df_train, df_valid = train_test_split(df[['user_id', 'text', 'age_class']], test_size=0.1, random_state=42)

  df['username_sep'] = df['username'].str.replace(r'([a-zA-Z])(\d)', r'\1 \2').\
  str.replace(r'(\d)([a-zA-Z])', r'\1 \2')


In [10]:
X_train = df_train['text'].values
y_train = df_train['age_class'].values

X_valid = df_valid['text'].values
y_valid = df_valid['age_class'].values

In [11]:
# setup tokenizer
tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-xlm-roberta-large-2022")

train_encodings = batch_tokenize(X_train, tokenizer)
valid_encodings = batch_tokenize(X_valid, tokenizer)

train_dataset = TweetDataset(train_encodings, y_train)
valid_dataset = TweetDataset(valid_encodings, y_valid)

In [12]:
# data loaders
BATCH_SIZE = 16

train_loader = torch.utils.data.DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True,
    pin_memory=True
)

valid_loader = torch.utils.data.DataLoader(
    valid_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False,
    pin_memory=True
)

In [13]:
from transformers import get_linear_schedule_with_warmup
from torch.optim.lr_scheduler import ReduceLROnPlateau

MODEL = "cardiffnlp/twitter-xlm-roberta-large-2022"
num_labels = 5
model = AutoModelForSequenceClassification.from_pretrained(MODEL, num_labels=num_labels)
model = torch.nn.DataParallel(model)
model.to(DEVICE)

print('')

Some weights of the model checkpoint at cardiffnlp/twitter-xlm-roberta-large-2022 were not used when initializing XLMRobertaForSequenceClassification: ['lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at cardiffnlp/twitter-xlm-roberta-large-2022 and are newly initialized: ['classifier.out_proj.bias', 'classifier.out_proj.weig




In [14]:
# train
start_time = time.time()

# Freeze all layers except the classifier 
freeze_epochs = 4
frozen = True
for name, param in model.named_parameters():
    if 'classifier' not in name: # classifier layer
        param.requires_grad = False
        
# Initialize best accuracy and epochs since improvement
best_f1 = 0.0
epochs_since_improvement = 0


# opimizer
LR_freezed = 8e-5
LR_unfreezed = 4e-6
optim = torch.optim.AdamW(model.parameters(), lr=LR_freezed)

# learning rate scheduler.
scheduler = ReduceLROnPlateau(optim, mode='max', factor=0.2, patience=2)

                
                          
NUM_EPOCHS = 100
for epoch in range(NUM_EPOCHS):

    model.train()
    
    # Unfreeze all layers after 2 epochs
    if epoch  == freeze_epochs:
        print('Unfreezing the internal layers.')
        for param in model.parameters():
            param.requires_grad = True   
            
        optim = torch.optim.AdamW(model.parameters(), lr=LR_unfreezed)
        scheduler = ReduceLROnPlateau(optim, mode='max', factor=0.2, patience=2)
        
        print('Optimizer and scheduler are reset')

            
    for batch_idx, batch in enumerate(train_loader):
        ### Prepare data
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        ### Forward
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss, logits = outputs['loss'], outputs['logits']

        ### Backward
        optim.zero_grad()
        loss = loss.mean()
        loss.backward()
        optim.step()
        #scheduler.step()

        ### Logging
        if not batch_idx % 50:
            print (f'Epoch: {epoch+1:04d}/{NUM_EPOCHS:04d} | '
                   f'Batch {batch_idx:04d}/{len(train_loader):04d} | '
                   f'Loss: {loss:.4f}')
    
    
    model.eval()
    with torch.set_grad_enabled(False):
        
        print(f'Training metrics: '
              f'{compute_metrics(model, train_loader, DEVICE)}%'
              f'\nValid metrics: '
              f'{compute_metrics(model, valid_loader, DEVICE)}%')

        current_f1 = compute_metrics(model, valid_loader, DEVICE)['f1']
        
        # step the scheduler ReduceLROnPlateau
        scheduler.step(current_f1) 
        
        # check if improved
        if current_f1 > best_f1:
            best_f1 = current_f1
            epochs_since_improvement = 0
            
            # Save the new best model
            path = os.path.join(model_dir ,'XLM2022_age_mentalism_extra_features.pt')
            torch.save(model, path) 
                
        else:
            epochs_since_improvement += 1
            
        # Early stopping
        if epochs_since_improvement >= 3:
            print('Early stopping triggered.')
            break

    print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')




Epoch: 0001/0100 | Batch 0000/1079 | Loss: 1.5291
Epoch: 0001/0100 | Batch 0050/1079 | Loss: 1.4623
Epoch: 0001/0100 | Batch 0100/1079 | Loss: 1.4347
Epoch: 0001/0100 | Batch 0150/1079 | Loss: 1.4833
Epoch: 0001/0100 | Batch 0200/1079 | Loss: 1.6650
Epoch: 0001/0100 | Batch 0250/1079 | Loss: 1.8413
Epoch: 0001/0100 | Batch 0300/1079 | Loss: 1.3416
Epoch: 0001/0100 | Batch 0350/1079 | Loss: 1.4661
Epoch: 0001/0100 | Batch 0400/1079 | Loss: 1.4353
Epoch: 0001/0100 | Batch 0450/1079 | Loss: 1.4625
Epoch: 0001/0100 | Batch 0500/1079 | Loss: 1.3396
Epoch: 0001/0100 | Batch 0550/1079 | Loss: 1.3023
Epoch: 0001/0100 | Batch 0600/1079 | Loss: 1.5569
Epoch: 0001/0100 | Batch 0650/1079 | Loss: 1.4079
Epoch: 0001/0100 | Batch 0700/1079 | Loss: 1.3391
Epoch: 0001/0100 | Batch 0750/1079 | Loss: 1.6926
Epoch: 0001/0100 | Batch 0800/1079 | Loss: 1.3321
Epoch: 0001/0100 | Batch 0850/1079 | Loss: 1.3401
Epoch: 0001/0100 | Batch 0900/1079 | Loss: 1.3735
Epoch: 0001/0100 | Batch 0950/1079 | Loss: 1.2432




Epoch: 0002/0100 | Batch 0000/1079 | Loss: 1.4736
Epoch: 0002/0100 | Batch 0050/1079 | Loss: 1.3340
Epoch: 0002/0100 | Batch 0100/1079 | Loss: 1.2751
Epoch: 0002/0100 | Batch 0150/1079 | Loss: 1.2036
Epoch: 0002/0100 | Batch 0200/1079 | Loss: 1.3686
Epoch: 0002/0100 | Batch 0250/1079 | Loss: 1.3449
Epoch: 0002/0100 | Batch 0300/1079 | Loss: 1.1732
Epoch: 0002/0100 | Batch 0350/1079 | Loss: 1.3347
Epoch: 0002/0100 | Batch 0400/1079 | Loss: 1.2917
Epoch: 0002/0100 | Batch 0450/1079 | Loss: 1.5500
Epoch: 0002/0100 | Batch 0500/1079 | Loss: 1.3027
Epoch: 0002/0100 | Batch 0550/1079 | Loss: 1.2749
Epoch: 0002/0100 | Batch 0600/1079 | Loss: 1.6974
Epoch: 0002/0100 | Batch 0650/1079 | Loss: 1.3625
Epoch: 0002/0100 | Batch 0700/1079 | Loss: 1.2334
Epoch: 0002/0100 | Batch 0750/1079 | Loss: 1.4139
Epoch: 0002/0100 | Batch 0800/1079 | Loss: 1.2639
Epoch: 0002/0100 | Batch 0850/1079 | Loss: 1.5293
Epoch: 0002/0100 | Batch 0900/1079 | Loss: 1.1033
Epoch: 0002/0100 | Batch 0950/1079 | Loss: 1.2402




Epoch: 0003/0100 | Batch 0000/1079 | Loss: 1.3334
Epoch: 0003/0100 | Batch 0050/1079 | Loss: 1.1631
Epoch: 0003/0100 | Batch 0100/1079 | Loss: 1.1263
Epoch: 0003/0100 | Batch 0150/1079 | Loss: 1.2818
Epoch: 0003/0100 | Batch 0200/1079 | Loss: 1.3980
Epoch: 0003/0100 | Batch 0250/1079 | Loss: 1.3213
Epoch: 0003/0100 | Batch 0300/1079 | Loss: 1.3912
Epoch: 0003/0100 | Batch 0350/1079 | Loss: 1.3200
Epoch: 0003/0100 | Batch 0400/1079 | Loss: 1.3359
Epoch: 0003/0100 | Batch 0450/1079 | Loss: 1.2214
Epoch: 0003/0100 | Batch 0500/1079 | Loss: 1.5438
Epoch: 0003/0100 | Batch 0550/1079 | Loss: 1.3795
Epoch: 0003/0100 | Batch 0600/1079 | Loss: 1.5804
Epoch: 0003/0100 | Batch 0650/1079 | Loss: 1.4827
Epoch: 0003/0100 | Batch 0700/1079 | Loss: 1.0985
Epoch: 0003/0100 | Batch 0750/1079 | Loss: 1.0486
Epoch: 0003/0100 | Batch 0800/1079 | Loss: 1.1911
Epoch: 0003/0100 | Batch 0850/1079 | Loss: 1.2790
Epoch: 0003/0100 | Batch 0900/1079 | Loss: 1.2670
Epoch: 0003/0100 | Batch 0950/1079 | Loss: 1.4424




Epoch: 0004/0100 | Batch 0000/1079 | Loss: 1.2360
Epoch: 0004/0100 | Batch 0050/1079 | Loss: 1.3883
Epoch: 0004/0100 | Batch 0100/1079 | Loss: 1.3417
Epoch: 0004/0100 | Batch 0150/1079 | Loss: 1.2971
Epoch: 0004/0100 | Batch 0200/1079 | Loss: 1.2402
Epoch: 0004/0100 | Batch 0250/1079 | Loss: 1.3491
Epoch: 0004/0100 | Batch 0300/1079 | Loss: 1.3716
Epoch: 0004/0100 | Batch 0350/1079 | Loss: 1.3475
Epoch: 0004/0100 | Batch 0400/1079 | Loss: 1.4718
Epoch: 0004/0100 | Batch 0450/1079 | Loss: 1.2090
Epoch: 0004/0100 | Batch 0500/1079 | Loss: 1.5349
Epoch: 0004/0100 | Batch 0550/1079 | Loss: 1.3738
Epoch: 0004/0100 | Batch 0600/1079 | Loss: 1.3432
Epoch: 0004/0100 | Batch 0650/1079 | Loss: 1.3861
Epoch: 0004/0100 | Batch 0700/1079 | Loss: 1.3107
Epoch: 0004/0100 | Batch 0750/1079 | Loss: 1.1243
Epoch: 0004/0100 | Batch 0800/1079 | Loss: 1.0966
Epoch: 0004/0100 | Batch 0850/1079 | Loss: 1.3035
Epoch: 0004/0100 | Batch 0900/1079 | Loss: 1.3273
Epoch: 0004/0100 | Batch 0950/1079 | Loss: 1.2107




Epoch: 0005/0100 | Batch 0000/1079 | Loss: 1.2428
Epoch: 0005/0100 | Batch 0050/1079 | Loss: 1.2353
Epoch: 0005/0100 | Batch 0100/1079 | Loss: 1.1145
Epoch: 0005/0100 | Batch 0150/1079 | Loss: 1.3703
Epoch: 0005/0100 | Batch 0200/1079 | Loss: 1.1818
Epoch: 0005/0100 | Batch 0250/1079 | Loss: 0.9971
Epoch: 0005/0100 | Batch 0300/1079 | Loss: 1.1260
Epoch: 0005/0100 | Batch 0350/1079 | Loss: 1.3950
Epoch: 0005/0100 | Batch 0400/1079 | Loss: 1.4524
Epoch: 0005/0100 | Batch 0450/1079 | Loss: 1.1353
Epoch: 0005/0100 | Batch 0500/1079 | Loss: 1.4976
Epoch: 0005/0100 | Batch 0550/1079 | Loss: 1.1397
Epoch: 0005/0100 | Batch 0600/1079 | Loss: 1.2999
Epoch: 0005/0100 | Batch 0650/1079 | Loss: 1.1949
Epoch: 0005/0100 | Batch 0700/1079 | Loss: 1.1962
Epoch: 0005/0100 | Batch 0750/1079 | Loss: 0.8304
Epoch: 0005/0100 | Batch 0800/1079 | Loss: 1.6045
Epoch: 0005/0100 | Batch 0850/1079 | Loss: 1.7314
Epoch: 0005/0100 | Batch 0900/1079 | Loss: 1.3485
Epoch: 0005/0100 | Batch 0950/1079 | Loss: 1.1982




Epoch: 0006/0100 | Batch 0000/1079 | Loss: 0.9895
Epoch: 0006/0100 | Batch 0050/1079 | Loss: 1.3431
Epoch: 0006/0100 | Batch 0100/1079 | Loss: 0.8798
Epoch: 0006/0100 | Batch 0150/1079 | Loss: 1.0577
Epoch: 0006/0100 | Batch 0200/1079 | Loss: 0.9373
Epoch: 0006/0100 | Batch 0250/1079 | Loss: 1.2596
Epoch: 0006/0100 | Batch 0300/1079 | Loss: 1.4112
Epoch: 0006/0100 | Batch 0350/1079 | Loss: 0.8772
Epoch: 0006/0100 | Batch 0400/1079 | Loss: 0.8310
Epoch: 0006/0100 | Batch 0450/1079 | Loss: 1.2238
Epoch: 0006/0100 | Batch 0500/1079 | Loss: 1.4033
Epoch: 0006/0100 | Batch 0550/1079 | Loss: 1.4515
Epoch: 0006/0100 | Batch 0600/1079 | Loss: 1.4788
Epoch: 0006/0100 | Batch 0650/1079 | Loss: 0.9864
Epoch: 0006/0100 | Batch 0700/1079 | Loss: 0.9317
Epoch: 0006/0100 | Batch 0750/1079 | Loss: 0.9378
Epoch: 0006/0100 | Batch 0800/1079 | Loss: 1.0006
Epoch: 0006/0100 | Batch 0850/1079 | Loss: 0.8337
Epoch: 0006/0100 | Batch 0900/1079 | Loss: 1.3156
Epoch: 0006/0100 | Batch 0950/1079 | Loss: 1.2576




Epoch: 0007/0100 | Batch 0000/1079 | Loss: 0.8725
Epoch: 0007/0100 | Batch 0050/1079 | Loss: 0.6471
Epoch: 0007/0100 | Batch 0100/1079 | Loss: 0.8418
Epoch: 0007/0100 | Batch 0150/1079 | Loss: 1.3544
Epoch: 0007/0100 | Batch 0200/1079 | Loss: 0.9458
Epoch: 0007/0100 | Batch 0250/1079 | Loss: 0.8585
Epoch: 0007/0100 | Batch 0300/1079 | Loss: 1.3754
Epoch: 0007/0100 | Batch 0350/1079 | Loss: 1.0617
Epoch: 0007/0100 | Batch 0400/1079 | Loss: 1.3040
Epoch: 0007/0100 | Batch 0450/1079 | Loss: 2.0273
Epoch: 0007/0100 | Batch 0500/1079 | Loss: 1.2222
Epoch: 0007/0100 | Batch 0550/1079 | Loss: 0.9451
Epoch: 0007/0100 | Batch 0600/1079 | Loss: 1.2362
Epoch: 0007/0100 | Batch 0650/1079 | Loss: 1.0123
Epoch: 0007/0100 | Batch 0700/1079 | Loss: 0.8055
Epoch: 0007/0100 | Batch 0750/1079 | Loss: 0.7398
Epoch: 0007/0100 | Batch 0800/1079 | Loss: 1.1428
Epoch: 0007/0100 | Batch 0850/1079 | Loss: 1.2208
Epoch: 0007/0100 | Batch 0900/1079 | Loss: 1.2469
Epoch: 0007/0100 | Batch 0950/1079 | Loss: 0.8186




Epoch: 0008/0100 | Batch 0000/1079 | Loss: 1.4612
Epoch: 0008/0100 | Batch 0050/1079 | Loss: 1.0421
Epoch: 0008/0100 | Batch 0100/1079 | Loss: 0.8888
Epoch: 0008/0100 | Batch 0150/1079 | Loss: 0.9538
Epoch: 0008/0100 | Batch 0200/1079 | Loss: 0.7105
Epoch: 0008/0100 | Batch 0250/1079 | Loss: 0.7820
Epoch: 0008/0100 | Batch 0300/1079 | Loss: 1.0625
Epoch: 0008/0100 | Batch 0350/1079 | Loss: 0.8905
Epoch: 0008/0100 | Batch 0400/1079 | Loss: 1.0420
Epoch: 0008/0100 | Batch 0450/1079 | Loss: 0.9520
Epoch: 0008/0100 | Batch 0500/1079 | Loss: 0.8210
Epoch: 0008/0100 | Batch 0550/1079 | Loss: 1.1957
Epoch: 0008/0100 | Batch 0600/1079 | Loss: 0.8798
Epoch: 0008/0100 | Batch 0650/1079 | Loss: 1.3651
Epoch: 0008/0100 | Batch 0700/1079 | Loss: 1.2335
Epoch: 0008/0100 | Batch 0750/1079 | Loss: 0.8597
Epoch: 0008/0100 | Batch 0800/1079 | Loss: 0.9508
Epoch: 0008/0100 | Batch 0850/1079 | Loss: 0.8525
Epoch: 0008/0100 | Batch 0900/1079 | Loss: 0.9042
Epoch: 0008/0100 | Batch 0950/1079 | Loss: 1.0948




Epoch: 0009/0100 | Batch 0000/1079 | Loss: 0.9810
Epoch: 0009/0100 | Batch 0050/1079 | Loss: 0.8935
Epoch: 0009/0100 | Batch 0100/1079 | Loss: 0.9552
Epoch: 0009/0100 | Batch 0150/1079 | Loss: 0.9739
Epoch: 0009/0100 | Batch 0200/1079 | Loss: 0.6786
Epoch: 0009/0100 | Batch 0250/1079 | Loss: 1.1057
Epoch: 0009/0100 | Batch 0300/1079 | Loss: 1.1765
Epoch: 0009/0100 | Batch 0350/1079 | Loss: 1.0007
Epoch: 0009/0100 | Batch 0400/1079 | Loss: 1.0148
Epoch: 0009/0100 | Batch 0450/1079 | Loss: 0.8195
Epoch: 0009/0100 | Batch 0500/1079 | Loss: 0.9676
Epoch: 0009/0100 | Batch 0550/1079 | Loss: 0.6957
Epoch: 0009/0100 | Batch 0600/1079 | Loss: 1.0698
Epoch: 0009/0100 | Batch 0650/1079 | Loss: 0.9683
Epoch: 0009/0100 | Batch 0700/1079 | Loss: 0.7005
Epoch: 0009/0100 | Batch 0750/1079 | Loss: 1.1549
Epoch: 0009/0100 | Batch 0800/1079 | Loss: 1.0840
Epoch: 0009/0100 | Batch 0850/1079 | Loss: 1.1707
Epoch: 0009/0100 | Batch 0900/1079 | Loss: 0.6795
Epoch: 0009/0100 | Batch 0950/1079 | Loss: 0.6739




Epoch: 0010/0100 | Batch 0000/1079 | Loss: 1.1125
Epoch: 0010/0100 | Batch 0050/1079 | Loss: 0.8207
Epoch: 0010/0100 | Batch 0100/1079 | Loss: 0.9390
Epoch: 0010/0100 | Batch 0150/1079 | Loss: 0.8335
Epoch: 0010/0100 | Batch 0200/1079 | Loss: 0.7403
Epoch: 0010/0100 | Batch 0250/1079 | Loss: 0.7540
Epoch: 0010/0100 | Batch 0300/1079 | Loss: 0.6373
Epoch: 0010/0100 | Batch 0350/1079 | Loss: 0.6400
Epoch: 0010/0100 | Batch 0400/1079 | Loss: 0.7170
Epoch: 0010/0100 | Batch 0450/1079 | Loss: 1.2640
Epoch: 0010/0100 | Batch 0500/1079 | Loss: 0.3986
Epoch: 0010/0100 | Batch 0550/1079 | Loss: 0.8150
Epoch: 0010/0100 | Batch 0600/1079 | Loss: 1.0462
Epoch: 0010/0100 | Batch 0650/1079 | Loss: 0.8416
Epoch: 0010/0100 | Batch 0700/1079 | Loss: 0.5135
Epoch: 0010/0100 | Batch 0750/1079 | Loss: 1.0651
Epoch: 0010/0100 | Batch 0800/1079 | Loss: 0.6235
Epoch: 0010/0100 | Batch 0850/1079 | Loss: 1.0132
Epoch: 0010/0100 | Batch 0900/1079 | Loss: 0.9822
Epoch: 0010/0100 | Batch 0950/1079 | Loss: 1.0965




Epoch: 0011/0100 | Batch 0000/1079 | Loss: 0.8565
Epoch: 0011/0100 | Batch 0050/1079 | Loss: 0.4793
Epoch: 0011/0100 | Batch 0100/1079 | Loss: 0.6685
Epoch: 0011/0100 | Batch 0150/1079 | Loss: 0.5238
Epoch: 0011/0100 | Batch 0200/1079 | Loss: 1.1236
Epoch: 0011/0100 | Batch 0250/1079 | Loss: 0.6991
Epoch: 0011/0100 | Batch 0300/1079 | Loss: 0.5967
Epoch: 0011/0100 | Batch 0350/1079 | Loss: 0.7529
Epoch: 0011/0100 | Batch 0400/1079 | Loss: 0.7039
Epoch: 0011/0100 | Batch 0450/1079 | Loss: 0.5951
Epoch: 0011/0100 | Batch 0500/1079 | Loss: 0.5359
Epoch: 0011/0100 | Batch 0550/1079 | Loss: 0.7180
Epoch: 0011/0100 | Batch 0600/1079 | Loss: 0.6007
Epoch: 0011/0100 | Batch 0650/1079 | Loss: 1.1209
Epoch: 0011/0100 | Batch 0700/1079 | Loss: 1.0081
Epoch: 0011/0100 | Batch 0750/1079 | Loss: 0.5195
Epoch: 0011/0100 | Batch 0800/1079 | Loss: 0.5510
Epoch: 0011/0100 | Batch 0850/1079 | Loss: 0.7092
Epoch: 0011/0100 | Batch 0900/1079 | Loss: 0.7039
Epoch: 0011/0100 | Batch 0950/1079 | Loss: 0.8103




Epoch: 0012/0100 | Batch 0000/1079 | Loss: 0.5369
Epoch: 0012/0100 | Batch 0050/1079 | Loss: 0.7127
Epoch: 0012/0100 | Batch 0100/1079 | Loss: 0.5432
Epoch: 0012/0100 | Batch 0150/1079 | Loss: 0.6468
Epoch: 0012/0100 | Batch 0200/1079 | Loss: 0.4184
Epoch: 0012/0100 | Batch 0250/1079 | Loss: 0.4716
Epoch: 0012/0100 | Batch 0300/1079 | Loss: 0.8530
Epoch: 0012/0100 | Batch 0350/1079 | Loss: 0.5283
Epoch: 0012/0100 | Batch 0400/1079 | Loss: 1.1281
Epoch: 0012/0100 | Batch 0450/1079 | Loss: 0.5559
Epoch: 0012/0100 | Batch 0500/1079 | Loss: 0.7871
Epoch: 0012/0100 | Batch 0550/1079 | Loss: 0.8743
Epoch: 0012/0100 | Batch 0600/1079 | Loss: 0.5598
Epoch: 0012/0100 | Batch 0650/1079 | Loss: 0.6864
Epoch: 0012/0100 | Batch 0700/1079 | Loss: 0.8918
Epoch: 0012/0100 | Batch 0750/1079 | Loss: 0.6095
Epoch: 0012/0100 | Batch 0800/1079 | Loss: 0.6292
Epoch: 0012/0100 | Batch 0850/1079 | Loss: 0.6359
Epoch: 0012/0100 | Batch 0900/1079 | Loss: 0.5284
Epoch: 0012/0100 | Batch 0950/1079 | Loss: 0.9516




Epoch: 0013/0100 | Batch 0000/1079 | Loss: 0.4223
Epoch: 0013/0100 | Batch 0050/1079 | Loss: 0.3904
Epoch: 0013/0100 | Batch 0100/1079 | Loss: 0.6942
Epoch: 0013/0100 | Batch 0150/1079 | Loss: 0.8111
Epoch: 0013/0100 | Batch 0200/1079 | Loss: 0.8359
Epoch: 0013/0100 | Batch 0250/1079 | Loss: 0.9845
Epoch: 0013/0100 | Batch 0300/1079 | Loss: 0.5932
Epoch: 0013/0100 | Batch 0350/1079 | Loss: 0.3291
Epoch: 0013/0100 | Batch 0400/1079 | Loss: 0.5883
Epoch: 0013/0100 | Batch 0450/1079 | Loss: 0.5227
Epoch: 0013/0100 | Batch 0500/1079 | Loss: 0.9105
Epoch: 0013/0100 | Batch 0550/1079 | Loss: 0.5423
Epoch: 0013/0100 | Batch 0600/1079 | Loss: 0.7736
Epoch: 0013/0100 | Batch 0650/1079 | Loss: 0.6702
Epoch: 0013/0100 | Batch 0700/1079 | Loss: 0.6926
Epoch: 0013/0100 | Batch 0750/1079 | Loss: 0.4628
Epoch: 0013/0100 | Batch 0800/1079 | Loss: 0.8514
Epoch: 0013/0100 | Batch 0850/1079 | Loss: 0.4652
Epoch: 0013/0100 | Batch 0900/1079 | Loss: 0.2387
Epoch: 0013/0100 | Batch 0950/1079 | Loss: 1.0069




Epoch: 0014/0100 | Batch 0000/1079 | Loss: 0.6850
Epoch: 0014/0100 | Batch 0050/1079 | Loss: 0.8247
Epoch: 0014/0100 | Batch 0100/1079 | Loss: 0.6380
Epoch: 0014/0100 | Batch 0150/1079 | Loss: 0.6657
Epoch: 0014/0100 | Batch 0200/1079 | Loss: 0.5309
Epoch: 0014/0100 | Batch 0250/1079 | Loss: 0.2366
Epoch: 0014/0100 | Batch 0300/1079 | Loss: 0.4294
Epoch: 0014/0100 | Batch 0350/1079 | Loss: 0.4415
Epoch: 0014/0100 | Batch 0400/1079 | Loss: 0.7213
Epoch: 0014/0100 | Batch 0450/1079 | Loss: 0.4210
Epoch: 0014/0100 | Batch 0500/1079 | Loss: 0.4092
Epoch: 0014/0100 | Batch 0550/1079 | Loss: 0.3034
Epoch: 0014/0100 | Batch 0600/1079 | Loss: 0.4706
Epoch: 0014/0100 | Batch 0650/1079 | Loss: 0.5994
Epoch: 0014/0100 | Batch 0700/1079 | Loss: 0.5055
Epoch: 0014/0100 | Batch 0750/1079 | Loss: 0.5773
Epoch: 0014/0100 | Batch 0800/1079 | Loss: 0.8204
Epoch: 0014/0100 | Batch 0850/1079 | Loss: 0.6365
Epoch: 0014/0100 | Batch 0900/1079 | Loss: 0.6776
Epoch: 0014/0100 | Batch 0950/1079 | Loss: 0.4643


In [15]:
# # save models in work directory
# path = os.path.join(uc_dir, 'trained_models', 'mentalism')
# if not os.path.exists(path):
#     os.mkdir(path)
    
# sub_path = os.path.join(path, 'age')
# if not os.path.exists(sub_path):
#         os.mkdir(sub_path)

# # XLM
# import shutil
# source_path =  os.path.join(model_dir ,'XLM_age_mentalism.pt')
# dest_path = os.path.join(sub_path, 'XLM_age.pt')
# shutil.copy(source_path, dest_path)

## Test Set Performance

In [16]:
# user age data
path  = os.path.join(uc_dir, 'data_for_models_test.pkl')
df_test = pd.read_pickle(path)

# Discretize the 'age' column into five classes
age_intervals = [0, 18, 30, 40, 60, 100]
age_labels = [0, 1, 2, 3, 4]
df = df_test[df_test['age']<=99]
df_test['age_class'] = pd.cut(df_test['age'], bins=age_intervals, labels=age_labels, right=False).astype(int)


df_test['text']  = 'bio: ' + df_test['masked_bio'] + '. ' + 'tweets: ' + df_test['long_text'] 
df_test['text'] = df_test['text'].str.replace('\r|\n', ' ', regex=True)

X_test = df_test['text'].values
y_test = df_test['age_class'].values

In [17]:
# setup tokenizer
tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-xlm-roberta-large-2022")

# test encodings and dataset
xlm_test_encodings = batch_tokenize(X_test, tokenizer)

xlm_test_dataset = TweetDataset(xlm_test_encodings, y_test)


In [18]:
# data loaders
BATCH_SIZE = 16

xlm_loader = torch.utils.data.DataLoader(
    xlm_test_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False,
    pin_memory=True
)

In [19]:
def test_eval(model_path, data_loader, device):
    
    # Load the saved model
    model = torch.load(model_path)
    model = model.to(device)
    
    # Set the model to evaluation mode
    model.eval()
    
    # store predicted probs
    class_probs = []
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            
            # Forward pass
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs['logits']

            # Convert logits to probabilities
            probabilities = F.softmax(logits, dim=1)
            class_probs.extend(probabilities.cpu().numpy().tolist())

            
    # Compute the metrics
    metrics = compute_metrics(model, data_loader, device)
    
    return metrics, np.array(class_probs)


In [None]:
# XLM
path = os.path.join(model_dir ,'XLM2022_age_mentalism_extra_features.pt')
metrics, xlm_probs = test_eval(path, xlm_loader, DEVICE)

# # save probs
# path = os.path.join(uc_dir, 'trained_models', 'age', 'XLM_probs_age.npy')
# np.save(path, xlm_probs)

print(metrics)