# Installation and Importing

In [1]:
# dependencies
import re
import os
import gc
import time
import random
import csv
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from datetime import datetime
from transformers import AutoModel, AutoTokenizer
from google.colab import drive, userdata

# file management: mount Google Drive so the project folder below is reachable
drive.mount('/content/drive')
# root folder on the mounted drive; work_dir() joins relative parts onto it
WORK_DIR = '/content/drive/MyDrive/Projects/skillextraction'

# work dir shortcut function
def work_dir(*args):
    """Return the path obtained by joining ``args`` onto ``WORK_DIR``.

    Called without arguments it returns ``WORK_DIR`` itself.
    """
    parts = (WORK_DIR,) + args
    return os.path.join(*parts)

Mounted at /content/drive


# Configuration

In [2]:
# config container
class C:
    """Static configuration namespace for the whole notebook.

    The logger writes every UPPERCASE attribute of this class into each log row,
    so anything added here automatically becomes a column of the trial log.
    """

    # sentence encoder checkpoint; the commented lines are alternatives that were tried
    #BASE_MODEL = 'sentence-transformers/all-mpnet-base-v2'
    BASE_MODEL = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
    #BASE_MODEL = 'sentence-transformers/labse'
    #BASE_MODEL = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'

    # values of the 'dataset' column that feed each split
    PROXY_SETS = ['label_en'] # sentences embedded as per-skill proxies in the predictor
    TRAIN_SETS = ['train']
    VAL_SETS = ['val']
    TEST_SETS = ['val', 'test', 'tech', 'house', 'techwolf']
    SEQ_LENGTH = 256 # fixed padding/truncation length for proxy, val and test batches (training pads to longest)
    CAP_TEMP = 0.01 # softmax temperature used for the CAP metric
    AVERAGE_EMBEDDINGS = False # whether to average proxy embeddings per skill (relevant for > 1 proxy per skill)
    UPDATE_EMBEDDINGS = True # whether to refresh proxy embeddings each epoch or freeze entirely
    TRAIN_METHOD = 'direct' # 'direct', 'mnr'

    N_LAYERS = 5 # [-n:] layers unfrozen in base model
    LR = 1e-5 # skill_id learning rate (ceiling reached by the per-batch warmup)
    LR_INITIAL = 1e-6 # starting learning rate of the last encoder layer
    LR_LAYER_FACTOR = 0.5 # learning-rate multiplier applied per layer going down from the last one
    LR_REDUCE_FACTOR = 0.5 # multiplier applied to all learning rates after a non-improving epoch
    LR_WARMUP_FACTOR = 1.001 # per-batch learning-rate multiplier during warmup
    LR_NUGGET = 1e-3 # nugget learning rate, static

    EPOCHS = 100
    PER_EPOCH = 5 # training samples per epoch per skill
    BATCH_SIZE = 96

    VAL_STATS = ['skill_id_cap', 'skill_id_rp5', 'skill_id_mrr'] # which metrics to test for improvement
    PATIENCE = 0 # early stopping: number of non-improving epochs tolerated before stopping

    FP_PENALTY = 0.01 # is_skill false positives loss multiplier
    TEMP = 0.1 # skill_id loss temperature (floor of the per-batch decay)
    TEMP_INITIAL = 1.0 # temperature the criterion starts with
    TEMP_FACTOR = 0.99999 # per-batch temperature decay multiplier

    DROPOUT_RATE = 0.25 # is_skill only
    AUGMENT_RATE = 0.0 # skill_id only
    WEIGHT_DECAY_RATE = 0.05 # skill_id only

    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    NUM_WORKERS = 2 # DataLoader worker processes per loader
    PREFETCH_FACTOR = 1 # batches prefetched per worker

    LOG_PATH = work_dir('trials', 'final.csv')
    # checkpoint name carries the unix timestamp taken when this cell runs, so each run gets its own file
    SAVE_PATH = work_dir('trials', 'model' + str(int(time.time())) + '.pth')

# check config-aggregated path
C.SAVE_PATH

'/content/drive/MyDrive/Projects/skillextraction/trials/model1735657541.pth'

# Load Data

In [3]:
# get esco english/danish
esco_en = pd.read_csv(work_dir('ESCO', 'ESCO dataset - v1.1.2 - classification - en - csv', 'skills_en.csv'))
esco_da = pd.read_csv(work_dir('ESCO', 'ESCO dataset - v1.1.2 - classification - da - csv', 'skills_da.csv'))

# English preferred label -> conceptUri, used to attach URIs to the external datasets
label_to_uri = esco_en.set_index('preferredLabel')['conceptUri']


def alt_label_rows(esco):
    """Explode the newline-separated 'altLabels' column into one row per label, dropping blanks."""
    per_label = esco.set_index('conceptUri')['altLabels'].str.split('\n').explode().reset_index()
    return per_label.replace('', pd.NA).dropna()


def as_sentences(frame, column, dataset):
    """Keep conceptUri + ``column``, expose ``column`` as 'sentence' and tag the rows with ``dataset``."""
    return frame[['conceptUri', column]].rename(columns={column: 'sentence'}).assign(dataset=dataset)


def one_per_skill(frame):
    """Draw a single row per conceptUri with a fixed seed."""
    return frame.groupby('conceptUri').sample(1, random_state=7)


# prepare alts
alt_en = alt_label_rows(esco_en)
alt_da = alt_label_rows(esco_da)

# get synthetic data from article, keeping only sentences whose skill label exists in ESCO
article = pd.read_csv("hf://datasets/jensjorisdecorte/Synthetic-ESCO-skill-sentences/dataset.csv")
article['conceptUri'] = article['skill'].map(label_to_uri)
article = article[article['conceptUri'].notna()].reset_index(drop=True)

# get benchmark data from article
bench = pd.concat([
    pd.read_csv(f'hf://datasets/jensjorisdecorte/skill-extraction-{v[0]}/{v[1]}.csv').assign(dataset=v[0])
    for v in (('tech', 'validation'), ('tech', 'test'), ('house', 'validation'), ('house', 'test'), ('techwolf', 'test'))
], ignore_index=True)

# map bench label to conceptUri and drop non-labeled
bench['conceptUri'] = bench['label'].map(label_to_uri)
bench = bench[bench['conceptUri'].notna()].reset_index(drop=True)

# get llm-annotated real sentences
reals = pd.read_csv(work_dir('Data', 'reals.csv'), low_memory=False).fillna('')

# get hand-annotated real nonskill sentences
nonskills = pd.read_csv(work_dir('Data', 'nonskills.csv'), low_memory=False).fillna('')

# gather all data (the order of the parts is kept as-is)
parts = [
    as_sentences(esco_en, 'preferredLabel', 'label_en'),
    as_sentences(esco_en, 'description', 'desc_en'),
    as_sentences(esco_da, 'preferredLabel', 'label_da'),
    as_sentences(esco_da, 'description', 'desc_da'),
    as_sentences(alt_en, 'altLabels', 'alt_en'),
    as_sentences(alt_da, 'altLabels', 'alt_da'),
    one_per_skill(as_sentences(alt_en, 'altLabels', 'alt_en_eval')),
    one_per_skill(as_sentences(alt_da, 'altLabels', 'alt_da_eval')),
    as_sentences(article, 'sentence', 'article'),
    one_per_skill(as_sentences(article, 'sentence', 'article_eval')),
    bench[['conceptUri', 'sentence', 'dataset']],
    reals[['conceptUri', 'sentence']].assign(dataset=reals['split']),
    reals.loc[reals['conceptUri'] == '', ['conceptUri', 'sentence']].assign(dataset='llm_nonskills'),
    nonskills[['sentence']].assign(conceptUri='', dataset='nonskills'),
]
all_data = pd.concat(parts, ignore_index=True).dropna().sort_values('conceptUri').reset_index(drop=True)

# check
all_data['dataset'].value_counts()

Unnamed: 0_level_0,count
dataset,Unnamed: 1_level_1
train,992104
llm_nonskills,241219
article,138240
alt_en,97483
test,27148
val,25601
desc_en,13896
label_en,13896
desc_da,13896
label_da,13895


In [4]:
# create splits
def select_sets(dataset_names):
    """Return an independent copy of the rows whose 'dataset' value is in ``dataset_names``."""
    return all_data[all_data['dataset'].isin(dataset_names)].copy()

proxy_df = select_sets(C.PROXY_SETS)
train_df = select_sets(C.TRAIN_SETS)
val_df = select_sets(C.VAL_SETS)
test_df = select_sets(C.TEST_SETS)

# assign id's to conceptUri's (-1 for empty, i.e. nonskill), numbered in proxy order of first appearance
uri_ids = {'': -1}
for position, uri in enumerate(proxy_df['conceptUri'].unique().tolist()):
    uri_ids[uri] = position

# map ids (URIs without a proxy become NaN and are filtered out below)
for frame in (proxy_df, train_df, val_df, test_df):
    frame['id'] = frame['conceptUri'].map(uri_ids)

# adjust splits to proxies and sort for convenience
known_ids = [-1] + proxy_df['id'].tolist()

def restrict_to_proxies(frame):
    """Keep nonskills and skills that have a proxy, ordered by id with a fresh index."""
    return frame[frame['id'].isin(known_ids)].sort_values('id').reset_index(drop=True)

train_df = restrict_to_proxies(train_df)
val_df = restrict_to_proxies(val_df)
test_df = restrict_to_proxies(test_df)

# check
len(proxy_df), len(train_df), len(val_df), len(test_df)

(13896, 992104, 25601, 54868)

# Tokenizer

In [5]:
# initialize tokenizer matching the configured base model checkpoint
tokenizer = AutoTokenizer.from_pretrained(C.BASE_MODEL)

tokenizer_config.json:   0%|          | 0.00/402 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/723 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

# Datasets

In [6]:
# define multi purpose dataset, e.g. proxies for predictor, samples for training, bench for evaluating
class SkillData(Dataset):
    """Index-only dataset; the ``*_collate`` methods build the actual batches.

    ``__getitem__`` just returns an integer. The DataLoader hands the list of those
    integers to one of the collate methods, which looks up / samples the sentences,
    tokenizes them and returns a dict of tensors.

    Args:
        tokenizer: callable used as ``tokenizer(sentences, padding=..., truncation=...,
            max_length=..., return_tensors='pt')`` returning 'input_ids' and 'attention_mask'.
        proxies: optional frame with 'id', 'conceptUri' and 'sentence' columns.
        samples: optional frame with 'id' and 'sentence' columns (id -1 marks nonskills).
        seq_length: pad/truncate to this length; ``None`` pads to the longest in the batch.
        augment_rate: probability that ``augment`` merges neighbouring sentences.
        per_epoch: multiplier on the dataset length.
    """

    # init that handles different usages, calculates length based on usage (num unique id for proxies)
    def __init__(self, tokenizer, proxies=None, samples=None, seq_length=None, augment_rate=0.0, per_epoch=1):
        super().__init__()
        self.tokenizer = tokenizer
        self.proxies = proxies.reset_index(drop=True) if proxies is not None else None
        self.samples = samples.reset_index(drop=True) if samples is not None else None
        self.seq_length = seq_length
        self.augment_rate = augment_rate
        self.per_epoch = per_epoch
        self.n = self.proxies['id'].nunique() if proxies is not None else len(self.samples)
        # smallest number of proxy sentences any skill has; proxy_collate takes that many per skill
        self.n_proxies = self.proxies['conceptUri'].value_counts().min() if proxies is not None else None
        # dataset-wide sentence -> int map for eval_collate, built lazily on first use
        self._sentence_ids = None

    def __len__(self):
        return self.n * self.per_epoch

    def __getitem__(self, i):
        # indices wrap around so that per_epoch > 1 revisits the same ids
        return i % self.n

    # tokenize
    def tokenize(self, sentences):
        """Tokenize a list of sentences with the tokenizer given to this dataset."""
        # use the instance's tokenizer; relying on a notebook-global made the class
        # silently depend on (and possibly disagree with) whatever `tokenizer` was in scope
        return self.tokenizer(sentences,
                              padding='max_length' if self.seq_length else 'longest',
                              truncation=True,
                              max_length=self.seq_length,
                              return_tensors='pt')

    # augment into pairs of sentences, even / odd
    def augment(self, samples):
        """With probability ``augment_rate`` replace each even/odd pair by 'even, odd' (both rows)."""
        if random.random() >= self.augment_rate:
            return samples
        s1 = samples['sentence'].iloc[::2].reset_index(drop=True)
        s2 = samples['sentence'].iloc[1::2].reset_index(drop=True)
        s3 = pd.Series(dtype=object)  # explicit dtype: holds the unpaired last sentence, if any
        if len(s1) > len(s2):
            s3 = s1.iloc[-1:]
            s1 = s1.iloc[:-1]
        return samples.reset_index(drop=True).assign(
            sentence=pd.concat([np.repeat(s1 + ', ' + s2, 2), s3], ignore_index=True).reset_index(drop=True)
        )

    @staticmethod
    def _index_sentences(sentences):
        """Map each distinct sentence to an int in order of first appearance."""
        unique = sentences.unique()
        return dict(zip(unique, range(len(unique))))

    def _pack(self, samples, sentence_ids=None):
        """Tokenize ``samples`` and assemble the batch dict; adds 'sentence' ids when a map is given."""
        tokens = self.tokenize(samples['sentence'].tolist())
        batch = {'id': torch.tensor(samples['id'].tolist(), dtype=torch.long),
                 'input_ids': tokens['input_ids'],
                 'attention_mask': tokens['attention_mask']}
        if sentence_ids is not None:
            batch['sentence'] = torch.tensor(samples['sentence'].map(sentence_ids).tolist(), dtype=torch.long)
        return batch

    def _with_nonskills(self, samples):
        """Sample as many nonskill rows (id -1) as ``samples`` has rows, augmenting every other one."""
        nonskills = self.samples.loc[self.samples['id'] == -1].sample(len(samples))
        nonskills.iloc[::2] = self.augment(nonskills.copy()).iloc[::2]
        return nonskills

    # collate proxy loading
    def proxy_collate(self, ids):
        """Batch the first ``n_proxies`` proxy sentences of every skill id in ``ids``."""
        proxies = self.proxies.loc[self.proxies['id'].isin(ids)].groupby('id').head(self.n_proxies)
        return self._pack(proxies)

    # collate mnr training
    def mnr_collate(self, ids):
        """Batch laid out as [one proxy per skill | one sample per skill | nonskills]."""
        proxies = self.proxies.loc[self.proxies['id'].isin(ids)].groupby('id').head(1)
        samples = self.samples.loc[self.samples['id'].isin(ids)].groupby('id').sample(1)
        samples = self.augment(samples)
        nonskills = self._with_nonskills(samples)
        samples = pd.concat([proxies, samples, nonskills], ignore_index=True)
        return self._pack(samples, self._index_sentences(samples['sentence']))

    # collate direct training
    def direct_collate(self, ids):
        """Batch laid out as [one sample per skill | nonskills]."""
        samples = self.samples.loc[self.samples['id'].isin(ids)].groupby('id').sample(1)
        samples = self.augment(samples)
        nonskills = self._with_nonskills(samples)
        samples = pd.concat([samples, nonskills], ignore_index=True)
        return self._pack(samples, self._index_sentences(samples['sentence']))

    # collate evaluation
    def eval_collate(self, idx):
        """Batch the sample rows at positions ``idx``; sentence ids are consistent across batches."""
        if self._sentence_ids is None:
            # the map covers the whole split, so build it once instead of twice per batch
            self._sentence_ids = self._index_sentences(self.samples['sentence'])
        samples = self.samples.loc[self.samples.index.isin(idx)]
        return self._pack(samples, self._sentence_ids)

# initialize datasets

proxy_data = SkillData(tokenizer=tokenizer, proxies=proxy_df, seq_length=C.SEQ_LENGTH)

# training pads each batch to its longest sentence (seq_length=None)
train_data = SkillData(tokenizer=tokenizer, proxies=proxy_df, samples=train_df, seq_length=None,
                       augment_rate=C.AUGMENT_RATE, per_epoch=C.PER_EPOCH)

val_data = SkillData(tokenizer=tokenizer, samples=val_df, seq_length=C.SEQ_LENGTH)

# one dataset per benchmark / split present in the test frame
test_data = {dataset: SkillData(tokenizer=tokenizer, samples=test_df[test_df['dataset'] == dataset], seq_length=C.SEQ_LENGTH)
             for dataset in test_df['dataset'].unique()}

# initialize dataloaders

# worker and memory settings shared by every loader
loader_kwargs = dict(num_workers=C.NUM_WORKERS,
                     prefetch_factor=C.PREFETCH_FACTOR,
                     pin_memory=True,
                     persistent_workers=True)

# one skill per batch when averaging, otherwise the batch size rounded down to a multiple of n_proxies
proxy_batch_size = 1 if C.AVERAGE_EMBEDDINGS else C.BATCH_SIZE - (C.BATCH_SIZE % proxy_data.n_proxies)
proxy_loader = DataLoader(proxy_data, batch_size=proxy_batch_size, shuffle=False,
                          collate_fn=proxy_data.proxy_collate, **loader_kwargs)

train_collate = train_data.mnr_collate if C.TRAIN_METHOD == 'mnr' else train_data.direct_collate
train_loader = DataLoader(train_data, batch_size=C.BATCH_SIZE, shuffle=True,
                          collate_fn=train_collate, **loader_kwargs)

val_loader = DataLoader(val_data, batch_size=C.BATCH_SIZE, shuffle=False,
                        collate_fn=val_data.eval_collate, **loader_kwargs)

test_loader = {dataset: DataLoader(data, batch_size=C.BATCH_SIZE, shuffle=False,
                                   collate_fn=data.eval_collate, **loader_kwargs)
               for dataset, data in test_data.items()}
# check
proxy_loader.dataset.n, train_loader.dataset.n, val_loader.dataset.n, {k: v.dataset.n for k, v in test_loader.items()}

(13896,
 13896,
 25601,
 {'val': 25601, 'test': 27148, 'techwolf': 588, 'house': 704, 'tech': 827})

# Base Model

In [7]:
# initialize base model from the configured checkpoint and move it to the training device
base_model = AutoModel.from_pretrained(C.BASE_MODEL).to(C.DEVICE)

model.safetensors:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

# Embedder

In [8]:
# define embedder
class SkillEmbedder(nn.Module):
    """Sentence encoder with an extra one-logit head ("nugget") appended to the embedding.

    The output has ``hidden_size + 1`` columns: the L2-normalised mean-pooled sentence
    embedding followed by the nugget logit.
    """

    # initialize with base model and dropout rate
    def __init__(self, base_model, dropout_rate):
        super().__init__()
        self.base_model = base_model
        hidden = base_model.config.hidden_size
        self.nugget = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(hidden, hidden // 2),
            nn.GELU(),
            nn.Linear(hidden // 2, 1),
        )

    # embed with mean pooling using batch input_ids and attention_mask
    def forward(self, input_ids, attention_mask):
        token_states = self.base_model(input_ids, attention_mask).last_hidden_state
        # masked mean over the sequence dimension: padded positions contribute nothing
        summed = (token_states * attention_mask.unsqueeze(-1)).sum(dim=1)
        token_counts = attention_mask.sum(dim=1, keepdim=True)
        pooled = F.normalize(summed / token_counts, p=2, dim=-1)
        # the head sees a detached copy, so its loss does not backpropagate into the base model
        nugget_logit = self.nugget(pooled.detach())
        return torch.cat([pooled, nugget_logit], dim=-1)

# init embedder around the loaded base model; the dropout rate only applies to the nugget head
embedder = SkillEmbedder(base_model=base_model, dropout_rate=C.DROPOUT_RATE).to(C.DEVICE)

# Predictor

In [9]:
# define predictor
class SkillPredictor(nn.Module):
    """Turns embedder outputs into is_skill and skill_id predictions.

    Keeps a non-trainable half-precision table ``embeddings`` of shape
    ``(n_skills, n_proxies, hidden_size + 1)`` holding the embedded proxy sentences;
    skill_id scores are the dot products between a query and those proxies.
    """

    # initialize embedder, proxy_loader and proxy embeddings
    def __init__(self, embedder, proxy_loader, average):
        """
        Args:
            embedder: module mapping (input_ids, attention_mask) to rows of width hidden_size + 1.
            proxy_loader: iterable of proxy batches with 'id', 'input_ids', 'attention_mask';
                its ``dataset`` must expose ``n`` and ``n_proxies``.
            average: if true, store the mean of a skill's proxy embeddings instead of each one.
        """
        super().__init__()

        self.embedder = embedder
        self.proxy_loader = proxy_loader
        self.average = average
        self.embeddings = nn.Parameter(torch.zeros(self.proxy_loader.dataset.n,
                                                   self.proxy_loader.dataset.n_proxies,
                                                   self.embedder.base_model.config.hidden_size + 1,
                                                   dtype=torch.half),
                                       requires_grad=False)

    # update proxy embeddings
    def update_embeddings(self):
        """Re-embed every proxy sentence with the current embedder weights.

        Uses the embedder and loader given to ``__init__`` (not notebook globals), runs
        the embedder in eval mode and restores its previous train/eval mode afterwards,
        also when embedding fails part-way.
        """
        was_training = self.embedder.training
        device = self.embeddings.device
        pbar = tqdm(self.proxy_loader, desc='Updating proxy embeddings', unit='batch')

        self.embedder.eval()
        try:
            with torch.no_grad():
                for batch in pbar:
                    # autocast expects the device *type* ('cuda'), not str(device) ('cuda:0')
                    with torch.amp.autocast(device.type):
                        batch = {k: v.to(device) for k, v in batch.items()}
                        embeddings = self.embedder(batch['input_ids'], batch['attention_mask']).to(torch.half)
                        for skill in batch['id'].unique():
                            skill_embeddings = embeddings[batch['id'] == skill]
                            # the averaged (1, dim) row broadcasts over the n_proxies slots
                            self.embeddings[skill] = skill_embeddings.mean(dim=0, keepdim=True) if self.average else skill_embeddings
        finally:
            self.embedder.train(was_training)

    # predict is_skill from n'th dimension and skill_id from proxy embedding similarity
    def forward(self, embeddings, include='both', logits=False):
        """
        Args:
            embeddings: embedder output, last column is the is_skill logit.
            include: 'is_skill', 'skill_id', or 'both' / 'all' for the pair.
            logits: return raw scores instead of sigmoid / softmax probabilities.

        Returns:
            The requested tensor, or ``(is_skill, skill_id)`` for 'both' / 'all'.

        Raises:
            ValueError: if ``include`` is not one of the accepted values.
        """
        if include not in ('both', 'all', 'is_skill', 'skill_id'):
            raise ValueError(f"include must be 'both', 'all', 'is_skill' or 'skill_id', got {include!r}")

        is_skill = None
        if include != 'skill_id':
            is_skill = embeddings[:, -1]
            is_skill = is_skill if logits else F.sigmoid(is_skill)
            if include == 'is_skill':
                return is_skill

        # NOTE(review): this broadcast materialises a (batch, n_skills, n_proxies, hidden) tensor
        # before reducing; kept as-is because a matmul would change dtype handling
        sims = (embeddings[:, :-1].unsqueeze(1).unsqueeze(1) * self.embeddings[:, :, :-1]).sum(dim=-1)
        # best-matching proxy per skill
        sims = sims.max(dim=-1)[0]
        skill_id = sims if logits else F.softmax(sims, dim=-1)
        if include == 'skill_id':
            return skill_id

        return is_skill, skill_id

# init predictor and fill its proxy embedding table once with the current (pre-training) embedder
predictor = SkillPredictor(embedder=embedder, proxy_loader=proxy_loader, average=C.AVERAGE_EMBEDDINGS).to(C.DEVICE)
predictor.update_embeddings()

Updating proxy embeddings: 100%|██████████| 145/145 [00:09<00:00, 15.76batch/s]


# Criterion

In [10]:
# define criterion
class SkillCriterion(nn.Module):
    """Sum of the is_skill loss (weighted BCE) and the skill_id loss (temperature-scaled CE)."""

    def __init__(self, fp_penalty=0.0, temperature=1.0):
        super().__init__()
        self.fp_penalty = fp_penalty
        self.temperature = temperature

    # calculate loss for is_skill (with false positives penalty) and skill_id (with temperature)
    def forward(self, is_skill_logits, is_skill_labels, skill_id_logits, skill_id_labels):
        """Return ``(total_loss, {'is_skill_loss': float, 'skill_id_loss': float})``."""
        # per-sample BCE; false positives (positive logit on a negative label) get weight 1 + fp_penalty
        per_sample = F.binary_cross_entropy_with_logits(is_skill_logits.float(),
                                                        is_skill_labels.float(),
                                                        reduction='none')
        false_positive = (is_skill_logits > 0.0) & (~is_skill_labels.bool())
        weights = 1 + false_positive.float() * self.fp_penalty
        is_skill_loss = (per_sample * weights).mean()

        # cross entropy over the rows that carry a skill label; label -1 is ignored
        if (skill_id_labels > -1).any():
            scaled_logits = skill_id_logits / self.temperature
            skill_id_loss = F.cross_entropy(scaled_logits, skill_id_labels, ignore_index=-1)
        else:
            # no labelled row in the batch: contribute a constant zero
            skill_id_loss = torch.tensor(0.0, device=is_skill_logits.device)

        parts = {'is_skill_loss': is_skill_loss.item(), 'skill_id_loss': skill_id_loss.item()}
        return is_skill_loss + skill_id_loss, parts

# init criterion at the initial temperature; the training loop decays it per batch
criterion = SkillCriterion(fp_penalty=C.FP_PENALTY, temperature=C.TEMP_INITIAL).to(C.DEVICE)

# Metrics

In [11]:
# define metrics
class SkillMetrics(nn.Module):
    """Sentence-level evaluation metrics.

    Rows sharing a sentence id are treated as one sentence with several gold skills:
    is_skill precision/recall count each sentence once, and the skill_id metrics
    (CAP, RP@5, MRR) are computed from the sentence's first row against all its labels.
    """

    def __init__(self, cap_temp=1.0):
        super().__init__()
        self.cap_temp = cap_temp

    # calculate metrics
    def forward(self, is_skill_logits, is_skill_labels, skill_id_logits, skill_id_labels, sentences):
        """Return a dict with is_skill_pre / is_skill_rec and skill_id_cap / _rp5 / _mrr.

        The three skill_id values are ``None`` when no sentence carries a skill label.
        """
        with torch.no_grad():

            # row positions per sentence id, in order of first appearance
            rows_by_sentence = {}
            for row, sid in enumerate(sentences.tolist()):
                if sid not in rows_by_sentence:
                    rows_by_sentence[sid] = []
                rows_by_sentence[sid].append(row)

            # is_skill: one vote per sentence, taken from its first row
            first_rows = [rows[0] for rows in rows_by_sentence.values()]
            unique_logits = is_skill_logits[first_rows]
            unique_labels = is_skill_labels[first_rows]

            tp = ((unique_logits > 0.0) & unique_labels).sum().item()
            fp = ((unique_logits > 0.0) & ~unique_labels).sum().item()
            fn = ((unique_logits <= 0.0) & unique_labels).sum().item()
            is_skill_precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            is_skill_recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0

            # skill_id: temperature-scaled probabilities for CAP, raw-logit ordering for the rank metrics
            probs = F.softmax(skill_id_logits / self.cap_temp, dim=1)
            ranks = torch.argsort(skill_id_logits, dim=1, descending=True)

            cap_sum, rp5_sum, mrr_sum, count = 0.0, 0.0, 0.0, 0
            for rows in rows_by_sentence.values():

                # gold skills of this sentence (label -1 means "no skill")
                gold = {skill_id_labels[r].item() for r in rows if skill_id_labels[r] > -1}
                if not gold:
                    continue

                anchor = rows[0]
                ranking = ranks[anchor].tolist()

                # CAP (Correct Assigned Probability): probability mass placed on the gold skills
                cap_sum += probs[anchor, list(gold)].sum().item()

                # MRR: reciprocal rank of the best-ranked gold skill
                mrr_sum += next((1.0 / pos for pos, pred in enumerate(ranking, start=1) if pred in gold), 0.0)

                # RP@5: gold skills found in the top 5, relative to the number that could be found
                rp5_sum += len(gold & set(ranking[:5])) / min(5, len(gold))

                count += 1

            # finalize metrics
            if count > 0:
                skill_id_cap, skill_id_rp5, skill_id_mrr = cap_sum / count, rp5_sum / count, mrr_sum / count
            else:
                skill_id_cap = skill_id_rp5 = skill_id_mrr = None

        return {
            'is_skill_pre': is_skill_precision,
            'is_skill_rec': is_skill_recall,
            'skill_id_cap': skill_id_cap,
            'skill_id_rp5': skill_id_rp5,
            'skill_id_mrr': skill_id_mrr
        }

# init metrics with the configured CAP softmax temperature
metrics = SkillMetrics(cap_temp=C.CAP_TEMP).to(C.DEVICE)

# Optimizer

In [12]:
# freeze the whole base model first: embeddings and pooler are never handed to an optimizer,
# so leaving them trainable only forces backprop through every frozen layer and leaves
# gradients behind that nothing zeroes or unscales
for param in base_model.parameters():
    param.requires_grad = False

# set n last layers trainable
encoder_layers = base_model.encoder.layer
first_trainable = len(encoder_layers) - C.N_LAYERS
for idx, layer in enumerate(encoder_layers):
    for param in layer.parameters():
        param.requires_grad = idx >= first_trainable

# negative index of the first trainable layer, so range(n_layers, 0) walks exactly the unfrozen
# layers [-N_LAYERS, ..., -1] (previously N_LAYERS - len(layers), which also listed frozen layers)
n_layers = -C.N_LAYERS

# optimizer for skill_id with layer wise decay rate: the last layer starts at LR_INITIAL and each
# layer below it at LR_LAYER_FACTOR times the one above; reversed() puts the last layer in
# param_groups[0], which the training loop reads for lr display and warmup
optimizer = torch.optim.AdamW(list(reversed([
    {'params': base_model.encoder.layer[i].parameters(), 'lr': C.LR_INITIAL * C.LR_LAYER_FACTOR**-(i + 1)} for i in range(n_layers, 0)
])), weight_decay=C.WEIGHT_DECAY_RATE)

# optimizer for is_skill
nugget_optimizer = torch.optim.AdamW([
    {'params': embedder.nugget.parameters(), 'lr': C.LR_NUGGET}
])

# mixed precision scaler (takes the device type, e.g. 'cuda')
scaler = torch.amp.GradScaler(C.DEVICE.type)

# collection of mixed precision backward pass calls
def backward(loss, optimizer, nugget_optimizer, scaler, embedder):
    """Run one mixed-precision optimisation step for both optimizers.

    Gradients are unscaled *before* clipping, so the thresholds (1.0 for the base model
    optimizer, 2.0 for the nugget optimizer) apply to the true gradient norms rather
    than to loss-scaled ones. Clipping covers exactly the parameters each optimizer
    owns: gradients of parameters outside the optimizers are neither zeroed nor
    unscaled here and must not influence the clipping coefficient.

    Args:
        loss: scalar loss tensor to backpropagate.
        optimizer: optimizer of the base model layers (clipped at norm 1.0).
        nugget_optimizer: optimizer of the is_skill head (clipped at norm 2.0).
        scaler: gradient scaler providing scale / unscale_ / step / update.
        embedder: kept for interface compatibility; no longer needed to pick the
            parameters to clip.
    """
    optimizer.zero_grad()
    nugget_optimizer.zero_grad()
    scaler.scale(loss).backward()

    # bring gradients back to their real magnitude before measuring their norm
    scaler.unscale_(optimizer)
    scaler.unscale_(nugget_optimizer)
    clip_grad_norm_([p for group in optimizer.param_groups for p in group['params']], max_norm=1.0)
    clip_grad_norm_([p for group in nugget_optimizer.param_groups for p in group['params']], max_norm=2.0)

    # step() skips the update if unscale_ found inf/nan gradients; update() adapts the scale
    scaler.step(optimizer)
    scaler.step(nugget_optimizer)
    scaler.update()

# Logger

In [13]:
# define module for logging
class Logger(nn.Module):
    """Appends one CSV row per call: timestamp, epoch, cardinality, config, then the given stats."""

    def __init__(self, path):
        super().__init__()
        self.path = path

    def forward(self, epoch, cardinality, data):
        """Write a row to ``self.path``; a header row is written first if the file does not exist yet."""

        # fixed leading columns
        row = {
            'datetime': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            'epoch': epoch,
            'cardinality': cardinality,
        }

        # settings: every UPPERCASE attribute of the config class
        for key, value in vars(C).items():
            if key.isupper():
                row[key] = value

        # caller-supplied stats come last (and win on key clashes)
        row.update(data)

        # append mode creates the file when missing, so one open covers header and data
        needs_header = not os.path.exists(self.path)
        with open(self.path, 'a', newline='') as f:
            writer = csv.writer(f)
            if needs_header:
                writer.writerow(row.keys())
            writer.writerow(row.values())
            f.flush()

# initialize logging to the configured CSV path
logger = Logger(path=C.LOG_PATH)

# Training

In [None]:
# decision variables
best_stats = dict()
patience_counter = 0

# run through epochs
for epoch in range(C.EPOCHS):

    # init stats and progress bar; train_stats holds one list of per-batch values per key
    # (key insertion order matters: it becomes the column order of the log row)
    train_stats = dict()
    pbar = tqdm(train_loader, desc=f'Training (epoch {epoch+1}/{C.EPOCHS})', unit='batch')

    # train
    for num_batch, batch in enumerate(pbar):
        with torch.amp.autocast(C.DEVICE.type):

            # send batch to GPU
            batch = {k: v.to(C.DEVICE) for k, v in batch.items()}

            # training mode
            embedder.train()
            predictor.train()

            # generate embeddings
            embeddings = embedder(batch['input_ids'], batch['attention_mask'])

            if C.TRAIN_METHOD == 'mnr':

                # separate (the mnr collate lays the batch out as proxies | samples | nonskills)
                proxies, samples, nonskills = torch.chunk(embeddings, 3)

                # predictions: is_skill on samples + nonskills, skill_id as in-batch sample/proxy similarity
                is_skill_logits = predictor(torch.cat([samples, nonskills], dim=0), include='is_skill', logits=True)
                skill_id_logits = F.cosine_similarity(samples.unsqueeze(1),
                                                      proxies.unsqueeze(0),
                                                      dim=-1)

                # truth: sample i belongs to proxy i
                is_skill_labels = (batch['id'] > -1)[len(proxies):]
                skill_id_labels = torch.arange(len(samples), device=C.DEVICE)

            else:

                # predictions
                is_skill_logits, skill_id_logits = predictor(embeddings, logits=True)

                # truth
                is_skill_labels, skill_id_labels = batch['id'] > -1, batch['id']

            # run batch
            loss, stats = criterion(is_skill_logits, is_skill_labels, skill_id_logits, skill_id_labels)

        # backward pass
        backward(loss, optimizer, nugget_optimizer, scaler, embedder)

        # update progress: append this batch to the epoch history. The previous version rebuilt
        # the dict each batch and looked the losses up under keys that were never stored
        # ('loss' vs 'train_loss'), so the reported train losses were the last batch only.
        train_stats.setdefault('train_loss', []).append(loss.item())
        train_stats.setdefault('train_is_skill_loss', []).append(stats['is_skill_loss'])
        train_stats.setdefault('train_skill_id_loss', []).append(stats['skill_id_loss'])
        train_stats.setdefault('lr', []).append(optimizer.param_groups[0]['lr'])
        train_stats['patience'] = [C.PATIENCE - patience_counter]
        train_stats.setdefault('temperature', []).append(criterion.temperature)

        pbar.set_postfix({k: sum(v) / len(v) for k, v in train_stats.items()})

        # decay temperature and weight
        criterion.temperature = max(criterion.temperature * C.TEMP_FACTOR, C.TEMP)

        # warmup lr: groups are ordered from highest to lowest lr, so once the first one has
        # reached C.LR the whole warmup stops and the layer-wise ratios stay intact
        for param_group in optimizer.param_groups:
            if param_group['lr'] >= C.LR:# or patience_counter > 0:
                break
            param_group['lr'] = min(param_group['lr'] * C.LR_WARMUP_FACTOR, C.LR)

    # clean up
    del batch, embeddings, is_skill_logits, skill_id_logits, is_skill_labels, skill_id_labels, loss, stats
    torch.cuda.empty_cache()

    # update proxy embeddings after training / before validation
    if C.UPDATE_EMBEDDINGS:
        predictor.update_embeddings()

    # init stats and progress bar
    val_stats = dict()
    pbar = tqdm(val_loader, desc=f'Validation (epoch {epoch+1}/{C.EPOCHS})', unit='batch')

    # results warehouse
    logits_and_labels = dict()

    # validate
    with torch.no_grad():
        for num_batch, batch in enumerate(pbar):

            # send batch to GPU
            batch = {k: v.to(C.DEVICE) for k, v in batch.items()}

            # validation mode
            embedder.eval()
            predictor.eval()

            # generate embeddings
            embeddings = embedder(batch['input_ids'], batch['attention_mask'])

            # predictions
            is_skill_logits, skill_id_logits = predictor(embeddings, logits=True)

            # truth
            is_skill_labels, skill_id_labels = batch['id'] > -1, batch['id']

            # save logits and labels (insertion order is relied on when unpacking below)
            logits_and_labels.setdefault('is_skill_logits', []).append(is_skill_logits)
            logits_and_labels.setdefault('is_skill_labels', []).append(is_skill_labels)
            logits_and_labels.setdefault('skill_id_logits', []).append(skill_id_logits)
            logits_and_labels.setdefault('skill_id_labels', []).append(skill_id_labels)
            logits_and_labels.setdefault('sentence', []).append(batch['sentence'])

        # run batch: loss takes the first four entries, metrics additionally the sentence ids
        loss, stats = criterion(*[torch.cat(v) for v in list(logits_and_labels.values())[:4]])
        val_metrics = metrics(*[torch.cat(v) for v in logits_and_labels.values()])
        stats = {'loss': loss.item()} | stats | val_metrics

        # update progress
        val_stats = {k: [v] + val_stats.setdefault(k, []) for k, v in stats.items() if v is not None}
        print('Validation:', ', '.join([f'{k} = {str(sum(v) / len(v))}' for k, v in val_stats.items()]))

    # clean up
    del batch, embeddings, is_skill_logits, skill_id_logits, is_skill_labels, skill_id_labels, logits_and_labels, loss, stats
    torch.cuda.empty_cache()

    # finalize stats
    train_stats = {k: sum(v) / len(v) for k, v in train_stats.items()}
    val_stats = {k: sum(v) / len(v) for k, v in val_stats.items()}

    # logging
    logger(epoch=epoch + 1, cardinality=val_loader.dataset.n, data=val_stats | train_stats)

    # force stat keys
    best_stats = {k: best_stats.setdefault(k, float('inf') if k[-4:] == 'loss' else 0.0) for k in val_stats.keys()}

    # count improved (losses improve downwards, everything else upwards)
    improved_stats = np.array([val_stats[k] < best_stats[k] if k[-4:] == 'loss' else val_stats[k] > best_stats[k] for k in C.VAL_STATS])

    # check overall improvement or not (majority of the watched stats)
    if improved_stats.mean().round():
        patience_counter = 0
        best_stats = {k:val_stats[k] if (val_stats[k] < best_stats[k] and k[-4:] == 'loss') or (val_stats[k] > best_stats[k] and k[-4:] != 'loss') else best_stats[k] for k in val_stats.keys()}
        torch.save({'embedder_state_dict': embedder.state_dict(),
                    'predictor_state_dict': predictor.state_dict()}, C.SAVE_PATH)
    else:
        patience_counter += 1
        if patience_counter > C.PATIENCE:
            # restore the best checkpoint; skip if no epoch ever improved and nothing was saved
            if os.path.exists(C.SAVE_PATH):
                state_dicts = torch.load(C.SAVE_PATH, weights_only=False)
                embedder.load_state_dict(state_dicts['embedder_state_dict'])
                predictor.load_state_dict(state_dicts['predictor_state_dict'])
            break
        for param_group in optimizer.param_groups + nugget_optimizer.param_groups:
            param_group['lr'] *= C.LR_REDUCE_FACTOR

# Evaluation

In [None]:
# test set + bench orig + bench llm translated
for dataset, loader in test_loader.items():

    # test mode
    embedder.eval()
    predictor.eval()

    # results warehouse: one list of per-batch tensors per field, in the order the
    # criterion (first four) and the metrics (all five) expect their arguments
    logits_and_labels = {'is_skill_logits': [],
                         'is_skill_labels': [],
                         'skill_id_logits': [],
                         'skill_id_labels': [],
                         'sentence': []}

    # test
    with torch.no_grad():
        for batch in tqdm(loader, desc=dataset, unit='batch'):

            # send batch to GPU
            batch = {name: tensor.to(C.DEVICE) for name, tensor in batch.items()}

            # embeddings -> predictions
            embeddings = embedder(batch['input_ids'], batch['attention_mask'])
            is_skill_logits, skill_id_logits = predictor(embeddings, logits=True)

            # save logits and labels (id -1 marks a nonskill sentence)
            logits_and_labels['is_skill_logits'].append(is_skill_logits)
            logits_and_labels['is_skill_labels'].append(batch['id'] > -1)
            logits_and_labels['skill_id_logits'].append(skill_id_logits)
            logits_and_labels['skill_id_labels'].append(batch['id'])
            logits_and_labels['sentence'].append(batch['sentence'])

        # concatenate
        logits_and_labels = {k: torch.cat(v) for k, v in logits_and_labels.items()}
        merged = list(logits_and_labels.values())

        # run the whole dataset through loss and metrics at once
        loss, stats = criterion(*merged[:4])
        stats = {'loss': loss.item()} | stats | metrics(*merged[:5])

        # update progress (metrics that could not be computed are None and left out)
        test_stats = {k: [v] for k, v in stats.items() if v is not None}
        print(f'{dataset}:', ', '.join([f'{k} = {str(sum(v) / len(v))}' for k, v in test_stats.items()]))

    # finalize stats
    test_stats = {k: sum(v) / len(v) for k, v in test_stats.items()}

    # logging: the dataset name goes into the 'epoch' column
    logger(epoch=dataset, cardinality=loader.dataset.n, data=test_stats)

In [17]:
# empty placeholder cell: evaluates to None and has no effect
None

In [18]:
# turn off costly GPU: release the Colab runtime once everything above has run
# (Colab-only import, kept next to its single use)
from google.colab import runtime
runtime.unassign()