# Installation and Importing

In [1]:
# dependencies
import os
import gc
import time
import random
import csv
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from datetime import datetime
from transformers import AutoModel, AutoTokenizer
from google.colab import drive, userdata

# file management
drive.mount('/content/drive')
WORK_DIR = '/content/drive/MyDrive/Projects/skillextraction'

# work dir shortcut function
def work_dir(*args):
    return os.path.join(WORK_DIR, *args)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Configuration

In [2]:
# config container
class C:

    # architecture
    BASE_MODEL = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
    PROXY_GROUPS = [1,2,3,4,5,6,7,8,9,10,11] # label_en = [1], label_da = [2], desc_en = [3], desc_da = [4]
    SAMPLE_GROUPS = [1,2,3,4,5,6,7,8,9,10,11] # article_en = [5], article_da = [6], extra_da = [7], alts_en = [8], alts_da = [9], multi_1 = [10], multi_2 = [11]
    AVERAGE_EMBEDDINGS = True
    SEQ_LENGTH = 256 # 99.9th percentile of bench, 100% of others
    IS_SKILL_DIM = 8 # how many dimensions are read out for is_skill
    ATP_TEMPERATURE = 0.005 # average true probability metric temperature

    # training
    N_LAYERS = 5
    LR = 1e-6
    LR_INITIAL = 1e-8
    LR_LAYER_FACTOR = 0.5
    LR_REDUCE_FACTOR = 0.1
    LR_WARMUP_FACTOR = 1.0005
    TRAIN_METHOD = 'mnr' # 'direct', 'mnr'
    EPOCHS = 100
    PER_EPOCH = 50 # training samples per epoch per skill
    BATCH_SIZE = 128
    PATIENCE = 3 # early stopping
    IS_SKILL_WEIGHT = 0.25 # loss coefficient
    IS_SKILL_FP_PENALTY = 0.0005 # false positives loss multipler
    SKILL_ID_TEMP = 0.05 # loss temperature
    SKILL_ID_TEMP_INITIAL = 1.0
    SKILL_ID_TEMP_FACTOR = 0.9999

    # regularization
    AUGMENT_RATE = 1.0
    DROPOUT_RATE = 0.1
    WEIGHT_DECAY_RATE = 0.1

    # system
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    NUM_WORKERS = 2
    PREFETCH_FACTOR = 1

    # export path
    def PATH(postfix=''):
        return work_dir('experiments', '-'.join(str(v).replace('/', '-') for k, v in vars(C).items() if k.isupper() and k != 'PATH') + postfix)

# check config-aggregated path
C.PATH('.csv')

'/content/drive/MyDrive/Projects/skillextraction/experiments/sentence-transformers-paraphrase-multilingual-mpnet-base-v2-[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]-[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]-True-256-8-0.005-5-1e-06-1e-08-0.5-0.1-1.0005-mnr-100-50-128-3-0.25-0.0005-0.05-1.0-0.9999-1.0-0.1-0.1-cuda-2-1.csv'

# Dataframes

In [3]:
# load pre-organized data
skills = pd.read_json(work_dir('Data', 'skills.json'), orient='records', lines=True)
nonskills = pd.read_json(work_dir('Data', 'nonskills.json'), orient='records', lines=True)
bench = pd.read_json(work_dir('Data', 'bench.json'), orient='records', lines=True)

# assign id's to conceptUri's
uri_ids = skills['conceptUri'].unique()
uri_ids = dict(zip(uri_ids, range(len(uri_ids))))

# map id's for skills and bench
skills['id'] = skills['conceptUri'].map(uri_ids)
bench['id'] = bench['conceptUri'].map(uri_ids)

# sort for convenience
skills = skills.sort_values('id').reset_index(drop=True)
bench = bench.sort_values('id').reset_index(drop=True)

# padding id for nonskills
nonskills['id'] = -1

# extract skills from bench for validation and test
validation = bench[bench['group'].isin([1, 3, 6, 8])].reset_index(drop=True)
test = bench[bench['group'].isin([2, 4, 5, 7, 9, 10])].reset_index(drop=True)

# extract and append nonskills for validation from manually annotated
val_nonskills = nonskills[nonskills['group'].isin([2, 3])].sample(len(validation), random_state=7)
nonskills.drop(val_nonskills.index, inplace=True)
validation = pd.concat([validation, val_nonskills], ignore_index=True).reset_index(drop=True)

# extract and append nonskills for test from manually annotated
test_nonskills = nonskills[nonskills['group'].isin([2, 3])].sample(len(test), random_state=7)
nonskills.drop(test_nonskills.index, inplace=True)
test = pd.concat([test, test_nonskills], ignore_index=True).reset_index(drop=True)

# check
print(skills.shape, nonskills.shape, validation.shape, test.shape)
skills.columns, nonskills.columns, validation.columns, test.columns

(990700, 4) (95818, 3) (1110, 4) (7254, 4)


(Index(['conceptUri', 'sentence', 'group', 'id'], dtype='object'),
 Index(['group', 'sentence', 'id'], dtype='object'),
 Index(['conceptUri', 'group', 'sentence', 'id'], dtype='object'),
 Index(['conceptUri', 'group', 'sentence', 'id'], dtype='object'))

# Tokenizer

In [4]:
# initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(C.BASE_MODEL)

# Datasets

In [5]:
# define multi purpose dataset, e.g. proxies for predictor, samples for training, bench for evaluating
class SkillData(Dataset):

    # init that handles different usages, calculates length based on usage (num unique id for proxies)
    def __init__(self, tokenizer, proxies=None, samples=None, nonskills=None, augment_rate=1.0, seq_length=None, per_epoch=1):
        super().__init__()
        self.tokenizer = tokenizer
        self.proxies = proxies.reset_index(drop=True) if proxies is not None else None
        self.samples = samples.reset_index(drop=True) if samples is not None else None
        self.nonskills = nonskills.reset_index(drop=True) if nonskills is not None else None
        self.augment_rate = augment_rate
        self.seq_length = seq_length
        self.per_epoch = per_epoch
        # calculate number of proxies per skill and number of samples (unique for train, all for eval)
        self.n_proxies = self.proxies['id'].value_counts().min() if proxies is not None else 0
        self.n = self.proxies['id'].nunique() if proxies is not None else len(self.samples)

    def __len__(self):
        return self.n * self.per_epoch

    def __getitem__(self, i):
        return i % self.n

    # shortcut tokenize
    def tokenize(self, sentences):
        return tokenizer(sentences,
                         padding='max_length' if self.seq_length else 'longest',
                         truncation=True,
                         max_length=self.seq_length,
                         return_tensors='pt')

    # augment into pairs of sentences, even / odd
    def augment(self, samples):
        if random.random() >= self.augment_rate:
            return self.samples
        s1 = samples['sentence'].iloc[::2].reset_index(drop=True)
        s2 = samples['sentence'].iloc[1::2].reset_index(drop=True)
        s3 = pd.Series()
        if len(s1) > len(s2):
            s3 = s1.iloc[-1:]
            s1 = s1.iloc[:-1]
        return samples.reset_index(drop=True).assign(
            sentence=pd.concat([np.repeat(s1 + ' ' + s2, 2), s3], ignore_index=True).reset_index(drop=True)
        )

    # collate proxy loading
    def proxy_collate(self, ids):
        proxies = self.proxies.loc[self.proxies['id'].isin(ids)].groupby('id').sample(self.n_proxies)
        tokens = self.tokenize(proxies['sentence'].tolist())
        return {'id': torch.tensor(proxies['id'].tolist(), dtype=torch.long),
                'input_ids': tokens['input_ids'],
                'attention_mask': tokens['attention_mask']}

    # collate mnr training
    def mnr_collate(self, ids):
        samples = self.samples.loc[self.samples['id'].isin(ids)].groupby('id').sample(1)
        samples = self.augment(samples)
        samples = pd.concat([samples, self.proxies.loc[self.proxies['id'].isin(ids)].groupby('id').sample(1)], ignore_index=True) # only one proxy for mnr
        samples = pd.concat([samples, self.nonskills.sample(len(samples), random_state=7)], ignore_index=True)
        tokens = self.tokenize(samples['sentence'].tolist())
        sentence_ids = dict(zip(samples['sentence'].unique(), range(len(samples['sentence'].unique()))))
        sentence = samples['sentence'].map(sentence_ids)
        return {'id': torch.tensor(samples['id'].tolist(), dtype=torch.long),
                'input_ids': tokens['input_ids'],
                'attention_mask': tokens['attention_mask'],
                'sentence': torch.tensor(sentence.tolist(), dtype=torch.long)}

    # collate direct training
    def direct_collate(self, ids):
        samples = self.samples.loc[self.samples['id'].isin(ids)].groupby('id').sample(1)
        samples = self.augment(samples)
        samples = pd.concat([samples, self.nonskills.sample(len(samples), random_state=7)], ignore_index=True)
        tokens = self.tokenize(samples['sentence'].tolist())
        sentence_ids = dict(zip(samples['sentence'].unique(), range(len(samples['sentence'].unique()))))
        sentence = samples['sentence'].map(sentence_ids)
        return {'id': torch.tensor(samples['id'].tolist(), dtype=torch.long),
                'input_ids': tokens['input_ids'],
                'attention_mask': tokens['attention_mask'],
                'sentence': torch.tensor(sentence.tolist(), dtype=torch.long)}

    # collate evaluation
    def eval_collate(self, idx):
        samples = self.samples.loc[self.samples.index.isin(idx)]
        tokens = self.tokenize(samples['sentence'].tolist())
        sentence_ids = dict(zip(self.samples['sentence'].unique(), range(len(self.samples['sentence'].unique()))))
        sentence = samples['sentence'].map(sentence_ids)
        group = ((samples['id'] > -1) * samples['group']).tolist()
        return {'id': torch.tensor(samples['id'].tolist(), dtype=torch.long),
                'input_ids': tokens['input_ids'],
                'attention_mask': tokens['attention_mask'],
                'sentence': torch.tensor(sentence.tolist(), dtype=torch.long),
                'group': torch.tensor(group, dtype=torch.long)}

# initialize datasets

proxy_data = SkillData(tokenizer=tokenizer,
                       proxies=skills.loc[skills['group'].isin(C.PROXY_GROUPS)],
                       seq_length=C.SEQ_LENGTH)

train_data = SkillData(tokenizer=tokenizer,
                       proxies=proxy_data.proxies,
                       samples=skills.loc[skills['group'].isin(C.SAMPLE_GROUPS)],
                       nonskills=nonskills,
                       per_epoch=C.PER_EPOCH,
                       augment_rate=C.AUGMENT_RATE)

val_data = SkillData(tokenizer=tokenizer,
                     samples=validation,
                     seq_length=C.SEQ_LENGTH)

test_data = SkillData(tokenizer=tokenizer,
                      samples=test,
                      seq_length=C.SEQ_LENGTH)

# initialize dataloaders

proxy_loader = DataLoader(proxy_data,
                          batch_size=64, # C.BATCH_SIZE,
                          num_workers=C.NUM_WORKERS,
                          prefetch_factor=C.PREFETCH_FACTOR,
                          collate_fn=proxy_data.proxy_collate,
                          shuffle=False,
                          pin_memory=True,
                          persistent_workers=True)

train_loader = DataLoader(train_data,
                          batch_size=C.BATCH_SIZE // 2 if C.TRAIN_METHOD == 'mnr' else C.BATCH_SIZE,
                          num_workers=C.NUM_WORKERS,
                          prefetch_factor=C.PREFETCH_FACTOR,
                          collate_fn=train_data.mnr_collate if C.TRAIN_METHOD == 'mnr' else train_data.direct_collate,
                          shuffle=True,
                          pin_memory=True,
                          persistent_workers=True)

val_loader = DataLoader(val_data,
                        batch_size=C.BATCH_SIZE,
                        num_workers=C.NUM_WORKERS,
                        prefetch_factor=C.PREFETCH_FACTOR,
                        collate_fn=val_data.eval_collate,
                        shuffle=False,
                        pin_memory=True,
                        persistent_workers=True)

test_loader = DataLoader(test_data,
                         batch_size=C.BATCH_SIZE,
                         num_workers=C.NUM_WORKERS,
                         prefetch_factor=C.PREFETCH_FACTOR,
                         collate_fn=test_data.eval_collate,
                         shuffle=False,
                         pin_memory=True,
                         persistent_workers=True)

# check
proxy_data.n, train_data.n, val_data.n, test_data.n, proxy_data.n_proxies

(13813, 13813, 1110, 7254, 39)

# Base Model

In [6]:
# initialize base model
base_model = AutoModel.from_pretrained(C.BASE_MODEL).to(C.DEVICE)

# Embedder

In [7]:
# define embedder
class SkillEmbedder(nn.Module):

    # initialize with base model and dropout rate
    def __init__(self, base_model, dropout_rate):
        super().__init__()
        self.base_model = base_model
        self.dropout = nn.Dropout(dropout_rate)

    # embed using batch input_ids and attention_mask (including attention mean pooling!)
    def forward(self, input_ids, attention_mask):
        embeddings = self.base_model(input_ids, attention_mask).last_hidden_state#.mean(dim=1)
        embeddings = (embeddings * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)
        return self.dropout(embeddings)

# init embedder
embedder = SkillEmbedder(base_model=base_model, dropout_rate=C.DROPOUT_RATE).to(C.DEVICE)

# Predictor

In [8]:
# define predictor
class SkillPredictor(nn.Module):

    # initialize embedder, proxy_loader and proxy embeddings
    def __init__(self, embedder, proxy_loader, average, is_skill_dim):

        super().__init__()

        self.embedder = embedder
        self.proxy_loader = proxy_loader
        self.average = average
        self.is_skill_dim = is_skill_dim
        self.embeddings = nn.Parameter(torch.zeros(self.proxy_loader.dataset.n,
                                                  1 if self.average else self.proxy_loader.dataset.n_proxies,
                                                  self.embedder.base_model.config.hidden_size,
                                                  dtype=torch.half),
                                       requires_grad=False)

    # update proxy embeddings
    def update_embeddings(self):

        training = embedder.training
        embedder.eval()
        pbar = tqdm(proxy_loader, desc=f'Updating proxy embeddings', unit='batch')

        with torch.no_grad():
            for batch in pbar:
                with torch.amp.autocast(str(self.embeddings.device)):
                    batch = {k: v.to(self.embeddings.device) for k, v in batch.items()}
                    embeddings = embedder(batch['input_ids'], batch['attention_mask']).to(torch.half)
                    for id in batch['id'].unique():
                        skill_embeddings = embeddings[batch['id'] == id]
                        self.embeddings[id] = skill_embeddings.mean(dim=0, keepdim=True) if self.average else skill_embeddings

        embedder.train(training)

    # predict is_skill from n'th dimension(s) and skill_id from proxy embedding similarity
    def forward(self, embeddings, include='both', logits=False):

        if include in ('both', 'all', 'is_skill'):
            is_skill = embeddings[:, -self.is_skill_dim:].mean(dim=-1)
            is_skill = is_skill if logits else F.sigmoid(is_skill)
            if include == 'is_skill':
                return is_skill

        if include in ('both', 'all', 'skill_id'):
            sims = F.cosine_similarity(embeddings.unsqueeze(1).unsqueeze(1),
                                       self.embeddings,
                                       dim=-1).max(dim=-1)[0]
            skill_id = sims if logits else F.softmax(sims, dim=-1)
            if include == 'skill_id':
                return skill_id

        return is_skill, skill_id

# init predictor
predictor = SkillPredictor(embedder=embedder, proxy_loader=proxy_loader, average=C.AVERAGE_EMBEDDINGS, is_skill_dim=C.IS_SKILL_DIM).to(C.DEVICE)
predictor.update_embeddings()

Updating proxy embeddings: 100%|██████████| 216/216 [03:28<00:00,  1.03batch/s]


# Criterion

In [9]:
# define criterion
class SkillCriterion(nn.Module):
    def __init__(self, is_skill_fp_penalty=0.0, skill_id_temperature=1.0):
        super().__init__()
        self.is_skill_fp_penalty = is_skill_fp_penalty
        self.skill_id_temperature = skill_id_temperature

    # calculate loss for is_skill (with false positives penalty) and skill_id (with temperature)
    def forward(self, is_skill_logits, is_skill_labels, skill_id_logits, skill_id_labels):
        is_skill_loss = F.binary_cross_entropy_with_logits(is_skill_logits.float(),
                                                           is_skill_labels.float(),
                                                           reduction='none')
        is_skill_loss *= 1 + ((is_skill_logits > 0.0) & (~is_skill_labels.bool())).float() * self.is_skill_fp_penalty
        is_skill_loss = is_skill_loss.mean()
        if (skill_id_labels > -1).sum() > 0:
            skill_id_loss = F.cross_entropy(skill_id_logits / self.skill_id_temperature, skill_id_labels, ignore_index=-1)
        else:
            skill_id_loss = torch.tensor(0.0, device=is_skill_logits.device)
        return C.IS_SKILL_WEIGHT * is_skill_loss + skill_id_loss, {'is_skill_loss': is_skill_loss.item(), 'skill_id_loss': skill_id_loss.item()}

# init criterion
criterion = SkillCriterion(is_skill_fp_penalty=C.IS_SKILL_FP_PENALTY, skill_id_temperature=C.SKILL_ID_TEMP_INITIAL).to(C.DEVICE)

# Metrics

In [10]:
# define metrics (needs vectorizing)
class SkillMetrics(nn.Module):
    def __init__(self, atp_temperature=1.0):
        super().__init__()
        self.atp_temperature = atp_temperature

    # calculate metrics
    def forward(self, is_skill_logits, is_skill_labels, skill_id_logits, skill_id_labels, sentences):

        with torch.no_grad():

            # is_skill precision and recall
            tp = ((is_skill_logits > 0.0) & is_skill_labels).sum().item()
            fp = ((is_skill_logits > 0.0) & ~is_skill_labels).sum().item()
            fn = ((is_skill_logits <= 0.0) & is_skill_labels).sum().item()
            is_skill_precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            is_skill_recall = tp / (tp + fn) if (tp + fn) > 0 else 0

            # group entries by unique sentences
            sentence_to_indices = {}
            for i, sentence in enumerate(sentences.tolist()):
                sentence_to_indices.setdefault(sentence, []).append(i)

            # init skill_id metrics
            mrr_sum, rp5_sum, atp_sum, count = 0, 0, 0, 0

            # get probs and ranks
            probs = F.softmax(skill_id_logits / self.atp_temperature, dim=1)
            ranks = torch.argsort(skill_id_logits, dim=1, descending=True)

            # process metrics per unique sentence
            for sentence, indices in sentence_to_indices.items():

                # aggregate skill labels for the sentence, get softmax probs and sorted ranks
                sentence_labels = set(skill_id_labels[i].item() for i in indices if skill_id_labels[i] > -1)

                # check for any labels
                if len(sentence_labels) == 0:
                    continue

                # ATP calculation (average true probability)
                sentence_atp = probs[indices[0], list(sentence_labels)].sum().item()
                atp_sum += sentence_atp / len(sentence_labels)

                # MRR calculation
                sentence_mrr = 0.0
                found_labels = set()
                for pos, pred in enumerate(ranks[indices[0]].tolist(), 1):
                    if pred in sentence_labels:
                        sentence_mrr += 1.0 / (pos - len(found_labels))
                        found_labels.add(pred)
                        if len(found_labels) == len(sentence_labels):
                            break
                mrr_sum += sentence_mrr / len(sentence_labels)

                # RP@5 calculation
                top_k_correct = len(sentence_labels & set(ranks[indices[0], :5].tolist()))
                rp5_sum += top_k_correct / min(5, len(sentence_labels))

                count += 1

        # finalize metrics
        if count:
            skill_id_atp = atp_sum / count
            skill_id_rp5 = rp5_sum / count
            skill_id_mrr = mrr_sum / count
        else:
            skill_id_atp, skill_id_rp5, skill_id_mrr = None, None, None

        return {
            'is_skill_pre': is_skill_precision,
            'is_skill_rec': is_skill_recall,
            'skill_id_atp': skill_id_atp,
            'skill_id_rp5': skill_id_rp5,
            'skill_id_mrr': skill_id_mrr
        }

# init metrics
metrics = SkillMetrics(atp_temperature=C.ATP_TEMPERATURE).to(C.DEVICE)

# Optimizer

In [11]:
# set n last layers trainable
for idx, layer in enumerate(base_model.encoder.layer):
    for param in layer.parameters():
        param.requires_grad = idx >= len(base_model.encoder.layer) - C.N_LAYERS

# calculate layers to optimize learning for
n_layers = C.N_LAYERS - len(base_model.encoder.layer)

# create param groups for optimizer with layer-wise learning rate  (factor applied per layer)
param_groups = reversed([
    {'params': base_model.encoder.layer[i].parameters(), 'lr': C.LR_INITIAL * C.LR_LAYER_FACTOR**-(i + 1)} for i in range(n_layers, 0)
])

# adam with weight decay, default settings
optimizer = torch.optim.AdamW(param_groups, weight_decay=C.WEIGHT_DECAY_RATE)

# mixed precision scaler
scaler = torch.amp.GradScaler(str(C.DEVICE))

# collection of mixed precision backward pass calls
def backward(loss, optimizer, scaler):
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    clip_grad_norm_(embedder.parameters(), max_norm=1.0)
    scaler.unscale_(optimizer)
    scaler.step(optimizer)
    scaler.update()

# Logger

In [12]:
# define module for logging
class Logger(nn.Module):

    def __init__(self, optimizer, path):
        super().__init__()
        self.optimizer = optimizer
        self.path = path

    def forward(self, epoch, data):

        # init log data fields
        log_data = {'datetime': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    'epoch': epoch}

        # gather log data
        log_data |= data

        # init log with header
        if not os.path.exists(self.path):
            with open(self.path, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(log_data.keys())

        # log data for epoch
        with open(self.path, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(log_data.values())
            f.flush()

# initialize logging
logger = Logger(optimizer, path=C.PATH('.csv'))

# Training

In [None]:
# decision variables
max_metrics = 0.0
patience_counter = 0

# run through epochs
for epoch in range(C.EPOCHS):

    # training mode
    embedder.train()

    # init stats and progress bar
    train_stats = {}
    model_stats = {}
    pbar = tqdm(train_loader, desc=f'Training (epoch {epoch+1}/{C.EPOCHS})', unit='batch')

    # train
    for num_batch, batch in enumerate(pbar):
        with torch.amp.autocast(str(C.DEVICE)):

            # send batch to GPU
            batch = {k: v.to(C.DEVICE) for k, v in batch.items()}

            # generate embeddings
            embeddings = embedder(batch['input_ids'], batch['attention_mask'])

            if C.TRAIN_METHOD == 'mnr':

                # predictions
                is_skill_logits = predictor(embeddings, include='is_skill', logits=True)
                skill_id_logits = F.cosine_similarity(embeddings[:len(embeddings)//4].unsqueeze(1), # 1st 4th
                                                      embeddings[len(embeddings)//4:len(embeddings)//2].unsqueeze(0), # 2nd 4th
                                                      dim=-1)

                # truth
                is_skill_labels = batch['id'] > -1
                skill_id_labels = torch.arange(len(embeddings)//4, device=C.DEVICE)

            else:

                # predictions
                is_skill_logits, skill_id_logits = predictor(embeddings, logits=True)

                # truth
                is_skill_labels, skill_id_labels = batch['id'] > -1, batch['id']

            # run batch
            loss, stats = criterion(is_skill_logits, is_skill_labels, skill_id_logits, skill_id_labels)

        # backward pass
        backward(loss, optimizer, scaler)

        # update progress
        train_stats = {
            'train_loss': [loss.item()] + train_stats.setdefault('loss', []),
            'train_is_skill_loss': [stats['is_skill_loss']] + train_stats.setdefault('is_skill_loss', []),
            'train_skill_id_loss': [stats['skill_id_loss']] + train_stats.setdefault('skill_id_loss', []),
            'lr': [optimizer.param_groups[0]['lr']] + model_stats.setdefault('lr', []),
            'patience': [C.PATIENCE - patience_counter],
            'temperature': [criterion.skill_id_temperature] + model_stats.setdefault('temperature', []),
        }

        pbar.set_postfix({k: sum(v) / len(v) for k, v in train_stats.items()})

        # decay temperature
        criterion.skill_id_temperature = max(criterion.skill_id_temperature * C.SKILL_ID_TEMP_FACTOR, C.SKILL_ID_TEMP)

        # warmup lr
        for param_group in optimizer.param_groups:
            if param_group['lr'] >= C.LR or patience_counter > 0:
                break
            param_group['lr'] = min(param_group['lr'] * C.LR_WARMUP_FACTOR, C.LR)

    # clean up
    del batch, embeddings, is_skill_logits, skill_id_logits, is_skill_labels, skill_id_labels, loss, stats
    torch.cuda.empty_cache()

    # update proxy embeddings after training / before validation
    predictor.update_embeddings()

    # validation mode
    embedder.eval()

    # init stats and progress bar
    val_stats = {}
    pbar = tqdm(val_loader, desc=f'Validation (epoch {epoch+1}/{C.EPOCHS})', unit='batch')

    # results warehouse
    logits_and_labels = dict()

    # validate
    with torch.no_grad():
        for num_batch, batch in enumerate(pbar):

            # send batch to GPU
            batch = {k: v.to(C.DEVICE) for k, v in batch.items()}

            # generate embeddings
            embeddings = embedder(batch['input_ids'], batch['attention_mask'])

            # predictions
            is_skill_logits, skill_id_logits = predictor(embeddings, logits=True)

            # truth
            is_skill_labels, skill_id_labels = batch['id'] > -1, batch['id']

            # save logits and labels
            logits_and_labels.setdefault('is_skill_logits', []).append(is_skill_logits)
            logits_and_labels.setdefault('is_skill_labels', []).append(is_skill_labels)
            logits_and_labels.setdefault('skill_id_logits', []).append(skill_id_logits)
            logits_and_labels.setdefault('skill_id_labels', []).append(skill_id_labels)
            logits_and_labels.setdefault('sentence', []).append(batch['sentence'])

        # run batch
        loss, stats = criterion(*[torch.cat(v) for v in list(logits_and_labels.values())[:4]])
        val_metrics = metrics(*[torch.cat(v) for v in logits_and_labels.values()])
        stats = {'loss': loss.item()} | stats | val_metrics

        # update progress
        val_stats = {k: [v] + val_stats.setdefault(k, []) for k, v in stats.items() if v is not None}
        print('Validation:', ', '.join([f'{k} = {str(sum(v) / len(v))}' for k, v in val_stats.items()]))

    # finalize stats
    train_stats = {k: sum(v) / len(v) for k, v in train_stats.items()}
    val_stats = {k: sum(v) / len(v) for k, v in val_stats.items()}

    # logging
    logger(epoch=epoch + 1, data=val_stats | train_stats)

    # save or break (if break, load best weights and restore embeddings)
    avg_metrics = sum([v for k, v in val_stats.items() if k in val_metrics.keys()])
    if avg_metrics > max_metrics:
        patience_counter = 0
        max_metrics = avg_metrics
        torch.save({'embedder_state_dict': embedder.state_dict()}, C.PATH('.pth'))
    else:
        patience_counter += 1
        if patience_counter >= C.PATIENCE:
            embedder.load_state_dict(torch.load(C.PATH('.pth'), weights_only=False)['embedder_state_dict'])
            predictor.update_embeddings()
            break
        for param_group in optimizer.param_groups:
            param_group['lr'] *= C.LR_REDUCE_FACTOR

Training (epoch 1/100): 100%|██████████| 10792/10792 [18:20<00:00,  9.81batch/s, train_loss=2.77, train_is_skill_loss=0.688, train_skill_id_loss=2.6, lr=1e-6, patience=3, temperature=0.34]
Updating proxy embeddings: 100%|██████████| 216/216 [03:29<00:00,  1.03batch/s]
Validation (epoch 1/100): 100%|██████████| 9/9 [00:03<00:00,  2.90batch/s]


Validation: loss = 8.310223579406738, is_skill_loss = 0.6870371699333191, skill_id_loss = 8.138463973999023, is_skill_pre = 0.6977152899824253, is_skill_rec = 0.7153153153153153, skill_id_atp = 0.308489127721702, skill_id_rp5 = 0.5512722646310433, skill_id_mrr = 0.4307749980083098


Training (epoch 2/100): 100%|██████████| 10792/10792 [18:20<00:00,  9.80batch/s, train_loss=2.01, train_is_skill_loss=0.681, train_skill_id_loss=1.84, lr=1e-6, patience=3, temperature=0.116]
Updating proxy embeddings: 100%|██████████| 216/216 [03:28<00:00,  1.04batch/s]
Validation (epoch 2/100): 100%|██████████| 9/9 [00:02<00:00,  3.03batch/s]


Validation: loss = 6.012263298034668, is_skill_loss = 0.6795756816864014, skill_id_loss = 5.842369556427002, is_skill_pre = 0.9210526315789473, is_skill_rec = 0.6936936936936937, skill_id_atp = 0.28346767709384885, skill_id_rp5 = 0.5299618320610686, skill_id_mrr = 0.4082731737982301


Training (epoch 3/100): 100%|██████████| 10792/10792 [18:21<00:00,  9.80batch/s, train_loss=2.34, train_is_skill_loss=0.678, train_skill_id_loss=2.17, lr=1e-6, patience=3, temperature=0.05]
Updating proxy embeddings: 100%|██████████| 216/216 [03:28<00:00,  1.04batch/s]
Validation (epoch 3/100): 100%|██████████| 9/9 [00:02<00:00,  3.03batch/s]


Validation: loss = 5.294389724731445, is_skill_loss = 0.6746786236763, skill_id_loss = 5.125720024108887, is_skill_pre = 0.8465189873417721, is_skill_rec = 0.963963963963964, skill_id_atp = 0.2918497152561288, skill_id_rp5 = 0.5414122137404579, skill_id_mrr = 0.4304522463582193


Training (epoch 4/100): 100%|██████████| 10792/10792 [18:22<00:00,  9.79batch/s, train_loss=2.09, train_is_skill_loss=0.671, train_skill_id_loss=1.93, lr=1e-6, patience=3, temperature=0.05]
Updating proxy embeddings: 100%|██████████| 216/216 [03:28<00:00,  1.04batch/s]
Validation (epoch 4/100): 100%|██████████| 9/9 [00:02<00:00,  3.03batch/s]


Validation: loss = 5.413917541503906, is_skill_loss = 0.6703407764434814, skill_id_loss = 5.246332168579102, is_skill_pre = 0.9142367066895368, is_skill_rec = 0.9603603603603603, skill_id_atp = 0.2803674158300044, skill_id_rp5 = 0.5538167938931298, skill_id_mrr = 0.42351996794072916


Training (epoch 5/100): 100%|██████████| 10792/10792 [18:19<00:00,  9.81batch/s, train_loss=2.05, train_is_skill_loss=0.663, train_skill_id_loss=1.88, lr=1e-6, patience=3, temperature=0.05]
Updating proxy embeddings: 100%|██████████| 216/216 [03:28<00:00,  1.04batch/s]
Validation (epoch 5/100): 100%|██████████| 9/9 [00:02<00:00,  3.03batch/s]


Validation: loss = 5.420886039733887, is_skill_loss = 0.665419340133667, skill_id_loss = 5.254531383514404, is_skill_pre = 0.9492753623188406, is_skill_rec = 0.9441441441441442, skill_id_atp = 0.2749599505463674, skill_id_rp5 = 0.5365139949109414, skill_id_mrr = 0.4150293934712245


Training (epoch 6/100):  80%|████████  | 8679/10792 [14:46<03:28, 10.14batch/s, train_loss=2.14, train_is_skill_loss=0.667, train_skill_id_loss=1.97, lr=1e-7, patience=2, temperature=0.05]

# Evaluation

In [None]:
# init stats and progress bar
test_stats = {}
pbar = tqdm(test_loader, desc=f'Test', unit='batch')

# results warehouse
logits_and_labels = dict()

# test
with torch.no_grad():
    for num_batch, batch in enumerate(pbar):

        # test mode
        embedder.eval()

        # send batch to GPU
        batch = {k: v.to(C.DEVICE) for k, v in batch.items()}

        # generate embeddings
        embeddings = embedder(batch['input_ids'], batch['attention_mask'])

        # predictions
        is_skill_logits, skill_id_logits = predictor(embeddings, logits=True)

        # truth
        is_skill_labels, skill_id_labels = batch['id'] > -1, batch['id']

        # save logits and labels
        logits_and_labels.setdefault('is_skill_logits', []).append(is_skill_logits)
        logits_and_labels.setdefault('is_skill_labels', []).append(is_skill_labels)
        logits_and_labels.setdefault('skill_id_logits', []).append(skill_id_logits)
        logits_and_labels.setdefault('skill_id_labels', []).append(skill_id_labels)
        logits_and_labels.setdefault('sentence', []).append(batch['sentence'])
        logits_and_labels.setdefault('group', []).append(batch['group'])

    # concatenate
    logits_and_labels = {k: torch.cat(v) for k, v in logits_and_labels.items()}

    # run batch
    loss, stats = criterion(*list(logits_and_labels.values())[:4])
    stats = {'loss': loss.item()} | stats | metrics(*list(logits_and_labels.values())[:5])

    # update progress
    test_stats = {k: [v] + test_stats.setdefault(k, []) for k, v in stats.items() if v is not None}
    print('Test:', ', '.join([f'{k} = {str(sum(v) / len(v))}' for k, v in test_stats.items()]))

# finalize stats
test_stats = {k: sum(v) / len(v) for k, v in test_stats.items()}

# logging
logger(epoch='test', data=test_stats)

# group logging
for g in logits_and_labels['group'].unique():
    loss, stats = criterion(*[v[g == logits_and_labels['group']] for v in list(logits_and_labels.values())[:4]])
    stats = {'loss': loss.item()} | stats | metrics(*[v[g == logits_and_labels['group']] for v in list(logits_and_labels.values())[:5]])
    logger(epoch=f'group{g}', data=stats)