# Installation and Importing

In [1]:
# dependencies
import re
import os
import gc
import time
import random
import csv
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from datetime import datetime
from transformers import AutoModel, AutoTokenizer
from google.colab import drive, userdata

# file management
drive.mount('/content/drive')
WORK_DIR = '/content/drive/MyDrive/Projects/skillextraction'

# work dir shortcut function
def work_dir(*args):
    return os.path.join(WORK_DIR, *args)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Configuration

In [2]:
# config container
class C:

    # architecture
    BASE_MODEL = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
    SEQ_LENGTH = 256
    IS_SKILL_DIM = 8
    ATP_TEMPERATURE = 0.005

    # training
    N_LAYERS = 5
    LR = 1e-6
    LR_INITIAL = 1e-8
    LR_LAYER_FACTOR = 0.5
    LR_REDUCE_FACTOR = 0.1
    LR_WARMUP_FACTOR = 1.0005
    TRAIN_METHOD = 'mnr' # 'direct', 'mnr'
    EPOCHS = 100
    PER_EPOCH = 10 # training samples per epoch per skill
    BATCH_SIZE = 96
    PATIENCE = 3 # early stopping
    IS_SKILL_WEIGHT = 0.25 # loss coefficient
    IS_SKILL_FP_PENALTY = 0.0005 # false positives loss multipler
    SKILL_ID_TEMP = 0.05 # loss temperature
    SKILL_ID_TEMP_INITIAL = 1.0
    SKILL_ID_TEMP_FACTOR = 0.9999

    # regularization
    DROPOUT_RATE = 0.1
    WEIGHT_DECAY_RATE = 0.1

    # system
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    NUM_WORKERS = 2
    PREFETCH_FACTOR = 1

    # export path
    def PATH(postfix=''):
        return work_dir('experiments', 'llm_data_model' + postfix)

# check config-aggregated path
C.PATH('.csv')

'/content/drive/MyDrive/Projects/skillextraction/experiments/llm_data_model.csv'

# Load Data

In [3]:
# get esco english/danish
esco_en = pd.read_csv(work_dir('ESCO', 'ESCO dataset - v1.1.2 - classification - en - csv', 'skills_en.csv'))
esco_da = pd.read_csv(work_dir('ESCO', 'ESCO dataset - v1.1.2 - classification - da - csv', 'skills_da.csv'))

# get proxies (english labels)
proxies = esco_en[['conceptUri', 'preferredLabel']].rename(columns={'preferredLabel': 'sentence'}).sort_values('conceptUri').reset_index(drop=True)

# get real sentences
reals = pd.concat([pd.read_csv(work_dir('Annotated_data', s))
                   for s in os.listdir(work_dir('Annotated_data')) if re.match(r'^sentences\_[0-9]+\.csv$', s)])

# filter real sentences conservatively in relation to llm instructions
skills = reals[reals['conceptUri'].notna() & (reals['conceptUri'] != '') & (reals['is_skill'] == 1)].copy()
nonskills = reals[(reals['conceptUri'].isna() | (reals['conceptUri'] == '')) & (reals['is_skill'] == 0)].copy()

# sanity
nonskills['conceptUri'] = ''

# get synthetic multi label sentences
synth = pd.concat([pd.read_csv(work_dir('Data', s))
                   for s in os.listdir(work_dir('Data')) if re.match(r'^multi\_sentences\_[0-9]+\.csv$', s)])

# gather all samples
train_samples = pd.concat([
    skills[['conceptUri', 'sentence']].assign(group=1),
    nonskills[['conceptUri', 'sentence']].assign(group=1),
    synth[['conceptUriPrimary', 'sentence']].rename(columns={'conceptUriPrimary': 'conceptUri'}).assign(group=2),
    synth[['conceptUriSecondary', 'sentence']].rename(columns={'conceptUriSecondary': 'conceptUri'}).assign(group=3),
    esco_en[['conceptUri', 'description']].rename(columns={'description': 'sentence'}).assign(group=4),
    esco_da[['conceptUri', 'preferredLabel']].rename(columns={'preferredLabel': 'sentence'}).assign(group=5),
    esco_da[['conceptUri', 'description']].rename(columns={'description': 'sentence'}).assign(group=6),
], ignore_index=True).dropna().reset_index(drop=True)

# assign id's to conceptUri's (-1 for empty, i.e. nonskill)
uri_ids = [''] + proxies['conceptUri'].unique().tolist()
uri_ids = dict(zip(uri_ids, range(-1, len(uri_ids)-1)))

# map id's
proxies['id'] = proxies['conceptUri'].map(uri_ids)
train_samples['id'] = train_samples['conceptUri'].map(uri_ids)

# sort by id for convenience
train_samples.sort_values('id', inplace=True)

# test split (with respect to multi labels represented as single labels with multi samples)
test_samples = train_samples[train_samples['group'] == 1].groupby('id').sample(frac=0.1, random_state=7)
test_samples = train_samples[train_samples['sentence'].isin(test_samples['sentence'])]
train_samples = train_samples[~train_samples['sentence'].isin(test_samples['sentence'])].reset_index()

# validation split
val_samples = train_samples[train_samples['group'] == 1].groupby('id').sample(frac=0.1, random_state=7)
val_samples = train_samples[train_samples['sentence'].isin(val_samples['sentence'])]
train_samples = train_samples[~train_samples['sentence'].isin(val_samples['sentence'])]

# check
print('Sizes:', proxies.shape, train_samples.shape, val_samples.shape, test_samples.shape)
print('Average labels per sentences:',
      (train_samples['id'] > -1).sum() / train_samples['sentence'].nunique(),
      (val_samples['id'] > -1).sum() / val_samples['sentence'].nunique(),
      (test_samples['id'] > -1).sum() / test_samples['sentence'].nunique())

Sizes: (13896, 3) (398574, 5) (18779, 5) (25017, 4)
Average labels per sentences: 1.4659312885361129 1.3400981052226604 1.5327504098680615


# Tokenizer

In [4]:
# initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(C.BASE_MODEL)

# Datasets

In [5]:
# define multi purpose dataset, e.g. proxies for predictor, samples for training, bench for evaluating
class SkillData(Dataset):

    # init that handles different usages, calculates length based on usage (num unique id for proxies)
    def __init__(self, tokenizer, proxies=None, samples=None, seq_length=None, per_epoch=1):
        super().__init__()
        self.tokenizer = tokenizer
        self.proxies = proxies.reset_index(drop=True) if proxies is not None else None
        self.samples = samples.reset_index(drop=True) if samples is not None else None
        self.seq_length = seq_length
        self.per_epoch = per_epoch
        self.n = self.proxies['id'].nunique() if proxies is not None else len(self.samples)

    def __len__(self):
        return self.n * self.per_epoch

    def __getitem__(self, i):
        return i % self.n

    # tokenize
    def tokenize(self, sentences):
        return tokenizer(sentences,
                         padding='max_length' if self.seq_length else 'longest',
                         truncation=True,
                         max_length=self.seq_length,
                         return_tensors='pt')

    # collate proxy loading
    def proxy_collate(self, ids):
        proxies = self.proxies.loc[self.proxies['id'].isin(ids)].groupby('id').head(1)
        tokens = self.tokenize(proxies['sentence'].tolist())
        return {'id': torch.tensor(proxies['id'].tolist(), dtype=torch.long),
                'input_ids': tokens['input_ids'],
                'attention_mask': tokens['attention_mask']}

    # collate mnr training
    def mnr_collate(self, ids):
        proxies = self.proxies.loc[self.proxies['id'].isin(ids)].groupby('id').head(1)
        samples = self.samples.loc[self.samples['id'].isin(ids)].groupby('id').sample(1)
        nonskills = self.samples.loc[self.samples['id'] == -1].sample(len(samples))
        samples = pd.concat([proxies, samples, nonskills], ignore_index=True)
        tokens = self.tokenize(samples['sentence'].tolist())
        sentence_ids = dict(zip(samples['sentence'].unique(), range(len(samples['sentence'].unique()))))
        sentence_ids = samples['sentence'].map(sentence_ids)
        return {'id': torch.tensor(samples['id'].tolist(), dtype=torch.long),
                'input_ids': tokens['input_ids'],
                'attention_mask': tokens['attention_mask'],
                'sentence': torch.tensor(sentence_ids.tolist(), dtype=torch.long)}

    # collate direct training
    def direct_collate(self, ids):
        samples = self.samples.loc[self.samples['id'].isin(ids)].groupby('id').sample(1)
        if (samples['id'] > -1).sum() > 1:
            samples.loc[samples['id'] > -1] = self.augment(samples.loc[samples['id'] > -1])
        if (samples['id'] == -1).sum() == 0:
            samples.loc[samples['id'] == -1] = self.augment(samples.loc[samples['id'] == -1])
        tokens = self.tokenize(samples['sentence'].tolist())
        sentence_ids = dict(zip(samples['sentence'].unique(), range(len(samples['sentence'].unique()))))
        sentence_ids = samples['sentence'].map(sentence_ids)
        return {'id': torch.tensor(samples['id'].tolist(), dtype=torch.long),
                'input_ids': tokens['input_ids'],
                'attention_mask': tokens['attention_mask'],
                'sentence': torch.tensor(sentence_ids.tolist(), dtype=torch.long)}

    # collate evaluation
    def eval_collate(self, idx):
        samples = self.samples.loc[self.samples.index.isin(idx)]
        tokens = self.tokenize(samples['sentence'].tolist())
        sentence_ids = dict(zip(self.samples['sentence'].unique(), range(len(self.samples['sentence'].unique()))))
        sentence_ids = samples['sentence'].map(sentence_ids)
        return {'id': torch.tensor(samples['id'].tolist(), dtype=torch.long),
                'input_ids': tokens['input_ids'],
                'attention_mask': tokens['attention_mask'],
                'sentence': torch.tensor(sentence_ids.tolist(), dtype=torch.long)}

# initialize datasets

proxy_data = SkillData(tokenizer=tokenizer,
                       proxies=proxies,
                       seq_length=C.SEQ_LENGTH)

train_data = SkillData(tokenizer=tokenizer,
                       proxies=proxies,
                       samples=train_samples,
                       seq_length=C.SEQ_LENGTH,
                       per_epoch=C.PER_EPOCH)

val_data = SkillData(tokenizer=tokenizer,
                     samples=val_samples,
                     seq_length=C.SEQ_LENGTH)

test_data = SkillData(tokenizer=tokenizer,
                      samples=test_samples,
                      seq_length=C.SEQ_LENGTH)

# initialize dataloaders

proxy_loader = DataLoader(proxy_data,
                          batch_size=256,
                          num_workers=C.NUM_WORKERS,
                          prefetch_factor=C.PREFETCH_FACTOR,
                          collate_fn=proxy_data.proxy_collate,
                          shuffle=False,
                          pin_memory=True,
                          persistent_workers=True)

train_loader = DataLoader(train_data,
                          batch_size=C.BATCH_SIZE,
                          num_workers=C.NUM_WORKERS,
                          prefetch_factor=C.PREFETCH_FACTOR,
                          collate_fn=train_data.mnr_collate if C.TRAIN_METHOD == 'mnr' else train_data.direct_collate,
                          shuffle=True,
                          pin_memory=True,
                          persistent_workers=True)

val_loader = DataLoader(val_data,
                        batch_size=256,
                        num_workers=C.NUM_WORKERS,
                        prefetch_factor=C.PREFETCH_FACTOR,
                        collate_fn=val_data.eval_collate,
                        shuffle=False,
                        pin_memory=True,
                        persistent_workers=True)

test_loader = DataLoader(test_data,
                         batch_size=256,
                         num_workers=C.NUM_WORKERS,
                         prefetch_factor=C.PREFETCH_FACTOR,
                         collate_fn=test_data.eval_collate,
                         shuffle=False,
                         pin_memory=True,
                         persistent_workers=True)

# check
proxy_data.n, train_data.n, val_data.n, test_data.n

(13896, 13896, 18779, 25017)

# Base Model

In [6]:
# initialize base model
base_model = AutoModel.from_pretrained(C.BASE_MODEL).to(C.DEVICE)

# Embedder

In [7]:
# define embedder
class SkillEmbedder(nn.Module):

    # initialize with base model and dropout rate
    def __init__(self, base_model, dropout_rate):
        super().__init__()
        self.base_model = base_model
        self.dropout = nn.Dropout(dropout_rate)

    # embed using batch input_ids and attention_mask (including attention mean pooling!)
    def forward(self, input_ids, attention_mask):
        embeddings = self.base_model(input_ids, attention_mask).last_hidden_state#.mean(dim=1)
        embeddings = (embeddings * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)
        return self.dropout(embeddings)

# init embedder
embedder = SkillEmbedder(base_model=base_model, dropout_rate=C.DROPOUT_RATE).to(C.DEVICE)

In [8]:
# loading embedder state
state_dicts = torch.load(C.PATH('.pth'), weights_only=False)
embedder.load_state_dict(state_dicts['embedder_state_dict'])

<All keys matched successfully>

# Predictor

In [9]:
# define predictor
class SkillPredictor(nn.Module):

    # initialize embedder, proxy_loader and proxy embeddings
    def __init__(self, embedder, proxy_loader, is_skill_dim):

        super().__init__()

        self.embedder = embedder
        self.proxy_loader = proxy_loader
        self.is_skill_dim = is_skill_dim
        self.embeddings = nn.Parameter(torch.zeros(self.proxy_loader.dataset.n,
                                                  1,
                                                  self.embedder.base_model.config.hidden_size,
                                                  dtype=torch.half),
                                       requires_grad=False)

    # update proxy embeddings
    def update_embeddings(self):

        training = embedder.training
        embedder.eval()
        pbar = tqdm(proxy_loader, desc=f'Updating proxy embeddings', unit='batch')

        with torch.no_grad():
            for batch in pbar:
                with torch.amp.autocast(str(self.embeddings.device)):
                    batch = {k: v.to(self.embeddings.device) for k, v in batch.items()}
                    embeddings = embedder(batch['input_ids'], batch['attention_mask']).to(torch.half)
                    for id in batch['id'].unique():
                        skill_embeddings = embeddings[batch['id'] == id]
                        self.embeddings[id] = skill_embeddings

        embedder.train(training)

    # predict is_skill from n'th dimension(s) and skill_id from proxy embedding similarity
    def forward(self, embeddings, include='both', logits=False):

        if include in ('both', 'all', 'is_skill'):
            is_skill = embeddings[:, -self.is_skill_dim:].mean(dim=-1)
            is_skill = is_skill if logits else F.sigmoid(is_skill)
            if include == 'is_skill':
                return is_skill

        if include in ('both', 'all', 'skill_id'):
            sims = F.cosine_similarity(embeddings.unsqueeze(1).unsqueeze(1),
                                       self.embeddings,
                                       dim=-1).max(dim=-1)[0]
            skill_id = sims if logits else F.softmax(sims, dim=-1)
            if include == 'skill_id':
                return skill_id

        return is_skill, skill_id

# init predictor
predictor = SkillPredictor(embedder=embedder, proxy_loader=proxy_loader, is_skill_dim=C.IS_SKILL_DIM).to(C.DEVICE)
predictor.update_embeddings()

Updating proxy embeddings: 100%|██████████| 55/55 [00:08<00:00,  6.79batch/s]


# Criterion

In [10]:
# define criterion
class SkillCriterion(nn.Module):
    def __init__(self, is_skill_fp_penalty=0.0, skill_id_temperature=1.0):
        super().__init__()
        self.is_skill_fp_penalty = is_skill_fp_penalty
        self.skill_id_temperature = skill_id_temperature

    # calculate loss for is_skill (with false positives penalty) and skill_id (with temperature)
    def forward(self, is_skill_logits, is_skill_labels, skill_id_logits, skill_id_labels):
        is_skill_loss = F.binary_cross_entropy_with_logits(is_skill_logits.float(),
                                                           is_skill_labels.float(),
                                                           reduction='none')
        is_skill_loss *= 1 + ((is_skill_logits > 0.0) & (~is_skill_labels.bool())).float() * self.is_skill_fp_penalty
        is_skill_loss = is_skill_loss.mean()
        if (skill_id_labels > -1).sum() > 0:
            skill_id_loss = F.cross_entropy(skill_id_logits / self.skill_id_temperature, skill_id_labels, ignore_index=-1)
        else:
            skill_id_loss = torch.tensor(0.0, device=is_skill_logits.device)
        return C.IS_SKILL_WEIGHT * is_skill_loss + skill_id_loss, {'is_skill_loss': is_skill_loss.item(), 'skill_id_loss': skill_id_loss.item()}

# init criterion
criterion = SkillCriterion(is_skill_fp_penalty=C.IS_SKILL_FP_PENALTY, skill_id_temperature=C.SKILL_ID_TEMP_INITIAL).to(C.DEVICE)

# Metrics

In [11]:
# define metrics
class SkillMetrics(nn.Module):
    def __init__(self, atp_temperature=1.0):
        super().__init__()
        self.atp_temperature = atp_temperature

    # calculate metrics
    def forward(self, is_skill_logits, is_skill_labels, skill_id_logits, skill_id_labels, sentences):
        with torch.no_grad():

            # calculate is_skill precision and recall
            tp = ((is_skill_logits > 0.0) & is_skill_labels).sum().item()
            fp = ((is_skill_logits > 0.0) & ~is_skill_labels).sum().item()
            fn = ((is_skill_logits <= 0.0) & is_skill_labels).sum().item()
            is_skill_precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            is_skill_recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0

            # group entries by unique sentences
            sentence_to_indices = {}
            sentence_list = sentences.tolist()
            for i, sid in enumerate(sentence_list):
                sentence_to_indices.setdefault(sid, []).append(i)

            # initialize skill_id metrics
            mrr_sum, rp5_sum, atp_sum, count = 0.0, 0.0, 0.0, 0

            # softmax probabilities and ranking
            probs = F.softmax(skill_id_logits / self.atp_temperature, dim=1)
            ranks = torch.argsort(skill_id_logits, dim=1, descending=True)

            # process metrics per unique sentence
            for sid, indices in sentence_to_indices.items():

                # collect skill labels for this sentence
                sentence_labels = {skill_id_labels[i].item() for i in indices if skill_id_labels[i] > -1}

                # check for any labels or skip
                if len(sentence_labels) == 0:
                    continue

                # ATP calculation (Average True Probability)
                sentence_atp = probs[indices[0], list(sentence_labels)].sum().item()
                atp_sum += sentence_atp / len(sentence_labels)

                # MRR calculation (use only max ranking):
                sentence_mrr = 0.0
                for pos, pred in enumerate(ranks[indices[0]].tolist(), start=1):
                    if pred in sentence_labels:
                        sentence_mrr = 1.0 / pos
                        break
                mrr_sum += sentence_mrr

                # RP@5 calculation
                top_5_preds = set(ranks[indices[0], :5].tolist())
                top_k_correct = len(sentence_labels & top_5_preds)
                rp5_sum += top_k_correct / min(5, len(sentence_labels))

                count += 1

            # finalize metrics
            if count > 0:
                skill_id_atp = atp_sum / count
                skill_id_rp5 = rp5_sum / count
                skill_id_mrr = mrr_sum / count
            else:
                skill_id_atp, skill_id_rp5, skill_id_mrr = None, None, None

        return {
            'is_skill_pre': is_skill_precision,
            'is_skill_rec': is_skill_recall,
            'skill_id_atp': skill_id_atp,
            'skill_id_rp5': skill_id_rp5,
            'skill_id_mrr': skill_id_mrr
        }

# init metrics
metrics = SkillMetrics(atp_temperature=C.ATP_TEMPERATURE).to(C.DEVICE)

# Optimizer

In [12]:
# set n last layers trainable
for idx, layer in enumerate(base_model.encoder.layer):
    for param in layer.parameters():
        param.requires_grad = idx >= len(base_model.encoder.layer) - C.N_LAYERS

# calculate layers to optimize learning for
n_layers = C.N_LAYERS - len(base_model.encoder.layer)

# create param groups for optimizer with layer-wise learning rate  (factor applied per layer)
param_groups = reversed([
    {'params': base_model.encoder.layer[i].parameters(), 'lr': C.LR_INITIAL * C.LR_LAYER_FACTOR**-(i + 1)} for i in range(n_layers, 0)
])

# adam with weight decay, default settings
optimizer = torch.optim.AdamW(param_groups, weight_decay=C.WEIGHT_DECAY_RATE)

# mixed precision scaler
scaler = torch.amp.GradScaler(str(C.DEVICE))

# collection of mixed precision backward pass calls
def backward(loss, optimizer, scaler):
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    clip_grad_norm_(embedder.parameters(), max_norm=1.0)
    scaler.unscale_(optimizer)
    scaler.step(optimizer)
    scaler.update()

# Logger

In [13]:
# define module for logging
class Logger(nn.Module):

    def __init__(self, optimizer, path):
        super().__init__()
        self.optimizer = optimizer
        self.path = path

    def forward(self, epoch, data):

        # init log data fields
        log_data = {'datetime': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    'epoch': epoch}

        # gather log data
        log_data |= data

        # init log with header
        if not os.path.exists(self.path):
            with open(self.path, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(log_data.keys())

        # log data for epoch
        with open(self.path, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(log_data.values())
            f.flush()

# initialize logging
logger = Logger(optimizer, path=C.PATH('.csv'))

# Training

In [None]:
# decision variables
max_metrics = 0.0
patience_counter = 0

# run through epochs
for epoch in range(C.EPOCHS):

    # training mode
    embedder.train()

    # init stats and progress bar
    train_stats = {}
    model_stats = {}
    pbar = tqdm(train_loader, desc=f'Training (epoch {epoch+1}/{C.EPOCHS})', unit='batch')

    # train
    for num_batch, batch in enumerate(pbar):
        with torch.amp.autocast(str(C.DEVICE)):

            # send batch to GPU
            batch = {k: v.to(C.DEVICE) for k, v in batch.items()}

            # generate embeddings
            embeddings = embedder(batch['input_ids'], batch['attention_mask'])

            if C.TRAIN_METHOD == 'mnr':

                # separate
                proxies, samples, nonskills = torch.chunk(embeddings, 3)

                # predictions
                is_skill_logits = predictor(torch.cat([samples, nonskills], dim=0), include='is_skill', logits=True)
                skill_id_logits = F.cosine_similarity(samples.unsqueeze(1),
                                                      proxies.unsqueeze(0),
                                                      dim=-1)

                # truth
                is_skill_labels = (batch['id'] > -1)[len(proxies):]
                skill_id_labels = torch.arange(len(samples), device=C.DEVICE)

            else:

                # predictions
                is_skill_logits, skill_id_logits = predictor(embeddings, logits=True)

                # truth
                is_skill_labels, skill_id_labels = batch['id'] > -1, batch['id']

            # run batch
            loss, stats = criterion(is_skill_logits, is_skill_labels, skill_id_logits, skill_id_labels)

        # backward pass
        backward(loss, optimizer, scaler)

        # update progress
        train_stats = {
            'train_loss': [loss.item()] + train_stats.setdefault('loss', []),
            'train_is_skill_loss': [stats['is_skill_loss']] + train_stats.setdefault('is_skill_loss', []),
            'train_skill_id_loss': [stats['skill_id_loss']] + train_stats.setdefault('skill_id_loss', []),
            'lr': [optimizer.param_groups[0]['lr']] + model_stats.setdefault('lr', []),
            'patience': [C.PATIENCE - patience_counter],
            'temperature': [criterion.skill_id_temperature] + model_stats.setdefault('temperature', []),
        }

        pbar.set_postfix({k: sum(v) / len(v) for k, v in train_stats.items()})

        # decay temperature
        criterion.skill_id_temperature = max(criterion.skill_id_temperature * C.SKILL_ID_TEMP_FACTOR, C.SKILL_ID_TEMP)

        # warmup lr
        for param_group in optimizer.param_groups:
            if param_group['lr'] >= C.LR or patience_counter > 0:
                break
            param_group['lr'] = min(param_group['lr'] * C.LR_WARMUP_FACTOR, C.LR)

    # clean up
    del batch, embeddings, is_skill_logits, skill_id_logits, is_skill_labels, skill_id_labels, loss, stats
    torch.cuda.empty_cache()

    # update proxy embeddings after training / before validation
    predictor.update_embeddings()

    # validation mode
    embedder.eval()

    # init stats and progress bar
    val_stats = {}
    pbar = tqdm(val_loader, desc=f'Validation (epoch {epoch+1}/{C.EPOCHS})', unit='batch')

    # results warehouse
    logits_and_labels = dict()

    # validate
    with torch.no_grad():
        for num_batch, batch in enumerate(pbar):

            # send batch to GPU
            batch = {k: v.to(C.DEVICE) for k, v in batch.items()}

            # generate embeddings
            embeddings = embedder(batch['input_ids'], batch['attention_mask'])

            # predictions
            is_skill_logits, skill_id_logits = predictor(embeddings, logits=True)

            # truth
            is_skill_labels, skill_id_labels = batch['id'] > -1, batch['id']

            # save logits and labels
            logits_and_labels.setdefault('is_skill_logits', []).append(is_skill_logits)
            logits_and_labels.setdefault('is_skill_labels', []).append(is_skill_labels)
            logits_and_labels.setdefault('skill_id_logits', []).append(skill_id_logits)
            logits_and_labels.setdefault('skill_id_labels', []).append(skill_id_labels)
            logits_and_labels.setdefault('sentence', []).append(batch['sentence'])

        # run batch
        loss, stats = criterion(*[torch.cat(v) for v in list(logits_and_labels.values())[:4]])
        val_metrics = metrics(*[torch.cat(v) for v in logits_and_labels.values()])
        stats = {'loss': loss.item()} | stats | val_metrics

        # update progress
        val_stats = {k: [v] + val_stats.setdefault(k, []) for k, v in stats.items() if v is not None}
        print('Validation:', ', '.join([f'{k} = {str(sum(v) / len(v))}' for k, v in val_stats.items()]))

    # clean up
    del batch, embeddings, is_skill_logits, skill_id_logits, is_skill_labels, skill_id_labels, logits_and_labels, loss, stats
    torch.cuda.empty_cache()

    # finalize stats
    train_stats = {k: sum(v) / len(v) for k, v in train_stats.items()}
    val_stats = {k: sum(v) / len(v) for k, v in val_stats.items()}

    # logging
    logger(epoch=epoch + 1, data=val_stats | train_stats)

    # save or break (if break, load best weights and restore embeddings)
    avg_metrics = sum([v for k, v in val_stats.items() if k in val_metrics.keys()])
    if avg_metrics > max_metrics:
        patience_counter = 0
        max_metrics = avg_metrics
        torch.save({'embedder_state_dict': embedder.state_dict(),
                    'predictor_state_dict': predictor.state_dict()}, C.PATH('.pth'))
    else:
        patience_counter += 1
        if patience_counter >= C.PATIENCE:
            state_dicts = torch.load(C.PATH('.pth'), weights_only=False)
            embedder.load_state_dict(state_dicts['embedder_state_dict'])
            predictor.load_state_dict(state_dicts['predictor_state_dict'])
            break
        for param_group in optimizer.param_groups:
            param_group['lr'] *= C.LR_REDUCE_FACTOR

Training (epoch 1/100):  69%|██████▉   | 1005/1448 [05:45<02:32,  2.91batch/s, train_loss=4.17, train_is_skill_loss=0.692, train_skill_id_loss=4, lr=1.65e-8, patience=3, temperature=0.904]

# Evaluation

In [None]:
# init stats and progress bar
test_stats = {}
pbar = tqdm(test_loader, desc=f'Test', unit='batch')

# results warehouse
logits_and_labels = dict()

# test
with torch.no_grad():
    for num_batch, batch in enumerate(pbar):

        # test mode
        embedder.eval()

        # send batch to GPU
        batch = {k: v.to(C.DEVICE) for k, v in batch.items()}

        # generate embeddings
        embeddings = embedder(batch['input_ids'], batch['attention_mask'])

        # predictions
        is_skill_logits, skill_id_logits = predictor(embeddings, logits=True)

        # truth
        is_skill_labels, skill_id_labels = batch['id'] > -1, batch['id']

        # save logits and labels
        logits_and_labels.setdefault('is_skill_logits', []).append(is_skill_logits)
        logits_and_labels.setdefault('is_skill_labels', []).append(is_skill_labels)
        logits_and_labels.setdefault('skill_id_logits', []).append(skill_id_logits)
        logits_and_labels.setdefault('skill_id_labels', []).append(skill_id_labels)
        logits_and_labels.setdefault('sentence', []).append(batch['sentence'])

    # concatenate
    logits_and_labels = {k: torch.cat(v) for k, v in logits_and_labels.items()}

    # run batch
    loss, stats = criterion(*list(logits_and_labels.values())[:4])
    stats = {'loss': loss.item()} | stats | metrics(*list(logits_and_labels.values())[:5])

    # update progress
    test_stats = {k: [v] + test_stats.setdefault(k, []) for k, v in stats.items() if v is not None}
    print('Test:', ', '.join([f'{k} = {str(sum(v) / len(v))}' for k, v in test_stats.items()]))

# finalize stats
test_stats = {k: sum(v) / len(v) for k, v in test_stats.items()}

# logging
logger(epoch='test', data=test_stats)