# M3 Skills Comparator Training (Colab + T4 GPU)

Trains a 256-dim projection head on top of **TechWolf/JobBERT-v2** using contrastive learning with `MultipleNegativesRankingLoss`.

## Setup
1. Upload `m3_training_data.zip` (created by the packaging cell below, or pre-packaged locally)
2. Run all cells
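3. Download the trained model: Cell 8 zips `/content/m3_skills_comparator/` and downloads it as `m3_skills_comparator_trained.zip` (unpack it into `training/models/m3_skills_comparator/` in your local checkout)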

In [None]:
# Cell 1: Install dependencies
# The leading `!` runs the line as a shell command in the notebook runtime; `-q` quiets pip output.
# No versions are pinned, so each run installs whatever release is current.
!pip install -q sentence-transformers datasets pandas pyyaml numpy

In [None]:
# Cell 2: Check GPU
# Reports whether CUDA is visible and, if so, the name and total memory of device 0.
import torch

print(f"GPU available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    device_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # The properties object exposes `total_memory` (bytes); `total_mem` raises AttributeError.
    print(f"Memory: {device_props.total_memory / 1e9:.1f} GB")

In [None]:
# Cell 3: Upload and extract training data
# Two ways to get m3_training_data.zip into the runtime; Option B is the active one.
# Option A: Upload from Google Drive
# from google.colab import drive
# drive.mount('/content/drive')
# !cp /content/drive/MyDrive/m3_training_data.zip /content/

# Option B: Upload directly
from google.colab import files
print("Upload m3_training_data.zip...")
uploaded = files.upload()

# -o overwrites existing files without prompting; -d sets the extraction directory.
# The zip path is relative, so this assumes the upload landed in the current working directory.
!unzip -o m3_training_data.zip -d /content/m3_data

In [None]:
# Cell 4: Data preparation - load all skill data sources
import csv
import gzip
import json
import logging
import random
from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd

# Log at INFO so the loaders' progress messages show up in the cell output.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# Root directory that every loader below resolves its input files against.
DATA_DIR = Path("/content/m3_data")


def load_esco_sentences():
    """Load (skill, sentence) pairs from the TechWolf ESCO sentences parquet file.

    Reads ``DATA_DIR / "techwolf_esco_sentences" / "train.parquet"``. The skill
    and sentence columns are located by name heuristics; when either cannot be
    identified, the first two columns are used in that order.

    Returns:
        A list of ``{"skill": str, "sentence": str}`` dicts. Empty when the
        file is missing or has fewer than two columns.
    """
    parquet_path = DATA_DIR / "techwolf_esco_sentences" / "train.parquet"
    if not parquet_path.exists():
        logger.warning("TechWolf ESCO sentences not found -- skipping.")
        return []
    df = pd.read_parquet(parquet_path)

    skill_col, sentence_col = None, None
    for col in df.columns:
        cl = col.lower()
        # A column such as "skill_sentence" must not be mistaken for the skill name.
        if "skill" in cl and "sent" not in cl:
            skill_col = col
        elif "sent" in cl or "text" in cl or "description" in cl:
            sentence_col = col
    if skill_col is None or sentence_col is None:
        cols = df.columns.tolist()
        if len(cols) < 2:
            return []
        skill_col, sentence_col = cols[0], cols[1]

    # Drop rows with a missing value first: str(NaN) would otherwise produce
    # the literal text "nan" and end up in the training data.
    pairs = df[[skill_col, sentence_col]].dropna()
    records = []
    for raw_skill, raw_sentence in zip(pairs[skill_col], pairs[sentence_col]):
        skill = str(raw_skill).strip()
        sentence = str(raw_sentence).strip()
        # Single-character skills are treated as noise.
        if sentence and len(skill) > 1:
            records.append({"skill": skill, "sentence": sentence})
    logger.info("TechWolf ESCO: loaded %d skill-sentence pairs.", len(records))
    return records


def load_tabiya_synonyms():
    """Map each Tabiya ESCO preferred skill label to its alternative labels.

    Looks for ``skills.csv`` at the expected Tabiya path under ``DATA_DIR`` and
    otherwise takes the first ``skills.csv`` found anywhere below ``DATA_DIR``.

    Returns:
        A mapping ``preferred label -> list of alternative labels`` (the list
        may be empty). An empty dict when no file or no label column is found.
    """
    skills_csv = DATA_DIR / "tabiya_esco" / "tabiya-esco-v1.1.1" / "csv" / "skills.csv"
    if not skills_csv.exists():
        # Try alternative paths: first match of a recursive search wins.
        skills_csv = next(DATA_DIR.rglob("skills.csv"), skills_csv)
    if not skills_csv.exists():
        logger.warning("Tabiya skills.csv not found -- skipping.")
        return {}

    frame = pd.read_csv(skills_csv)

    # Both columns of interest contain "label"; the last matching column wins.
    preferred_col = alt_col = None
    for column in frame.columns:
        lowered = column.lower()
        if "label" not in lowered:
            continue
        if "preferred" in lowered:
            preferred_col = column
        elif "alt" in lowered:
            alt_col = column
    if preferred_col is None:
        # Looser fallback: the first column that looks like a label or a name.
        preferred_col = next(
            (c for c in frame.columns if "label" in c.lower() or "name" in c.lower()),
            None,
        )
    if preferred_col is None:
        return {}

    synonyms = defaultdict(list)
    for _, record in frame.iterrows():
        label = str(record[preferred_col]).strip()
        if label in ("", "nan"):
            continue
        alternatives = []
        if alt_col and pd.notna(record.get(alt_col)):
            raw = str(record[alt_col])
            # Only the first separator present (in this priority order) is used.
            separator = next((s for s in ("|", "\n", ";") if s in raw), None)
            if separator is not None:
                alternatives = [part.strip() for part in raw.split(separator) if part.strip()]
            if not alternatives and raw.strip():
                alternatives = [raw.strip()]
        synonyms[label] = alternatives
    logger.info("Tabiya ESCO: loaded %d skills with synonyms.", len(synonyms))
    return synonyms


def _nesta_clusters_from_json(json_files):
    """Return clusters from the first readable cluster/taxonomy JSON file."""
    for jf in json_files:
        name = jf.name.lower()
        if "cluster" not in name and "taxonomy" not in name:
            continue
        try:
            with open(jf, "r", encoding="utf-8") as f:
                data = json.load(f)
            clusters = {}
            if isinstance(data, dict):
                for val in data.values():
                    if isinstance(val, list):
                        skills = [str(v) for v in val if v]
                    elif isinstance(val, dict) and "skills" in val:
                        skills = [str(s) for s in val["skills"] if s]
                    else:
                        continue
                    # Sequential ids: unique per cluster and identical on every run.
                    clusters[len(clusters)] = skills
        except (OSError, ValueError, TypeError) as exc:
            # Best effort: an unreadable or oddly shaped file must not abort the load.
            logger.warning("Nesta: could not parse %s (%s) -- trying next file.", jf, exc)
            continue
        return clusters
    return {}


def _nesta_clusters_from_csv(csv_files):
    """Return clusters from the first skill/cluster CSV that has both columns."""
    for cf in csv_files:
        name = cf.name.lower()
        if "skill" not in name and "cluster" not in name:
            continue
        try:
            df = pd.read_csv(cf)
        except (OSError, ValueError) as exc:
            # pandas parser and empty-file errors are ValueError subclasses.
            logger.warning("Nesta: could not read %s (%s) -- trying next file.", cf, exc)
            continue
        skill_col, cluster_col = None, None
        for col in df.columns:
            cl = col.lower()
            if "skill" in cl or "name" in cl or "label" in cl:
                skill_col = col
            elif "cluster" in cl or "group" in cl or "category" in cl:
                cluster_col = col
        if not (skill_col and cluster_col):
            continue
        cluster_ids = {}
        clusters = {}
        for raw_skill, raw_cid in zip(df[skill_col], df[cluster_col]):
            # Rows without a cluster id must not be pooled into one fake "nan" cluster.
            if pd.isna(raw_skill) or pd.isna(raw_cid):
                continue
            skill = str(raw_skill).strip()
            if not skill or skill == "nan":
                continue
            cid = cluster_ids.setdefault(str(raw_cid), len(cluster_ids))
            clusters.setdefault(cid, []).append(skill)
        return clusters
    return {}


def load_nesta_clusters():
    """Load skill clusters from the Nesta skills taxonomy directory.

    Searches ``DATA_DIR / "nesta_skills_taxonomy"`` recursively. A JSON file
    whose name contains "cluster" or "taxonomy" is tried first; if it yields no
    clusters, a CSV whose name contains "skill" or "cluster" is used instead.

    Cluster ids are small sequential integers assigned in order of first
    appearance. The previous ``hash(key) % 1000`` ids collided (overwriting or
    merging unrelated clusters) and changed from run to run because string
    hashing is randomised per process.

    Returns:
        A mapping ``int cluster id -> list of skill strings``; an empty dict
        when no usable file is found.
    """
    nesta_base = DATA_DIR / "nesta_skills_taxonomy"
    base_exists = nesta_base.exists()
    json_files = list(nesta_base.rglob("*.json")) if base_exists else []
    csv_files = list(nesta_base.rglob("*.csv")) if base_exists else []
    if not json_files and not csv_files:
        logger.warning("Nesta: no data files found -- skipping.")
        return {}

    clusters = _nesta_clusters_from_json(json_files)
    if not clusters:
        clusters = _nesta_clusters_from_csv(csv_files)
    logger.info(
        "Nesta: loaded %d clusters with %d total skills.",
        len(clusters),
        sum(len(v) for v in clusters.values()),
    )
    return clusters


def load_stacklite_cooccurrence(max_rows=2_000_000):
    """Build a tag co-occurrence map from the StackLite question/tag dump.

    Reads ``DATA_DIR / "stacklite" / "question_tags.csv.gz"`` (columns ``Id``
    and ``Tag``, or their lower-case variants). Two tags co-occur when they are
    attached to the same question id.

    Args:
        max_rows: Maximum number of CSV data rows to read.

    Returns:
        A mapping ``tag -> set of other tags seen on the same question``. Tags
        that never share a question with another tag are absent. An empty dict
        when the file is missing or cannot be read.
    """
    tags_path = DATA_DIR / "stacklite" / "question_tags.csv.gz"
    if not tags_path.exists():
        logger.warning("StackLite tags not found -- skipping.")
        return {}

    question_tags = defaultdict(list)
    try:
        with gzip.open(str(tags_path), "rt", encoding="utf-8", newline="") as f:
            for i, row in enumerate(csv.DictReader(f)):
                if i >= max_rows:
                    break
                # DictReader fills missing trailing fields with None, hence the `or` chain.
                raw_id = row.get("Id") or row.get("id") or ""
                tag = (row.get("Tag") or row.get("tag") or "").strip()
                try:
                    qid = int(raw_id)
                except ValueError:
                    # One malformed id must not discard the whole file.
                    continue
                if qid and tag:
                    question_tags[qid].append(tag)
    except Exception as exc:
        # Best-effort source: a corrupt or truncated archive disables hard
        # negatives rather than aborting the notebook, but the cause is logged.
        logger.warning("StackLite: failed to read %s (%s) -- skipping.", tags_path, exc)
        return {}

    cooccur = defaultdict(set)
    for tags in question_tags.values():
        unique_tags = set(tags)
        if len(unique_tags) < 2:
            continue
        for tag in unique_tags:
            cooccur[tag].update(unique_tags - {tag})
    logger.info("StackLite: built co-occurrence for %d tags.", len(cooccur))
    return cooccur


# Load all sources
# Each loader returns an empty list/dict when its input is missing, so this cell
# still runs with a partial dataset.
esco_pairs = load_esco_sentences()
synonyms = load_tabiya_synonyms()
clusters = load_nesta_clusters()
cooccur = load_stacklite_cooccurrence()

In [None]:
# Cell 5: Build triplets
def _pick_excluding(rng, candidates, excluded, attempts=8):
    """Return a random candidate not in ``excluded``, or None if there is none."""
    if not candidates:
        return None
    for _ in range(attempts):
        pick = rng.choice(candidates)
        if pick not in excluded:
            return pick
    # Unlucky draws or a tiny pool: choose among the valid remainder explicitly.
    remaining = [c for c in candidates if c not in excluded]
    return rng.choice(remaining) if remaining else None


def _synonym_triplets(rng, synonyms, all_skills):
    """Yield (preferred label, alternative label, random other skill) triplets."""
    for pref, alts in synonyms.items():
        for alt in alts:
            if alt == pref:
                continue
            neg = _pick_excluding(rng, all_skills, {pref, alt})
            if neg is not None:
                yield {"anchor": pref, "positive": alt, "negative": neg}


def _cluster_triplets(rng, clusters, all_skills):
    """Yield triplets pairing nearby skills of one cluster against another cluster."""
    cluster_list = list(clusters.items())
    for cid, skills in cluster_list:
        if len(skills) < 2:
            continue
        # With a single cluster there is no "other" cluster; use the whole vocabulary.
        pool = [s for oid, oss in cluster_list if oid != cid for s in oss] or all_skills
        for i, anchor in enumerate(skills):
            # Each skill is paired with at most its next two neighbours.
            for positive in skills[i + 1:i + 3]:
                if positive == anchor:
                    continue
                # A skill can be listed in several clusters; never let the
                # negative coincide with the anchor or the positive.
                negative = _pick_excluding(rng, pool, {anchor, positive})
                if negative is not None:
                    yield {"anchor": anchor, "positive": positive, "negative": negative}


def _sentence_triplets(rng, esco_pairs):
    """Yield triplets of two sentences for one skill against a sentence of another."""
    from itertools import combinations

    skill_to_sentences = defaultdict(list)
    for pair in esco_pairs:
        skill_to_sentences[pair["skill"]].append(pair["sentence"])
    skill_keys = list(skill_to_sentences)
    if len(skill_keys) < 2:
        return
    for skill in skill_keys:
        sentences = skill_to_sentences[skill]
        if len(sentences) < 2:
            continue
        # All pairs among the first four sentences (at most six per skill).
        for anchor, positive in combinations(sentences[:4], 2):
            neg_skill = _pick_excluding(rng, skill_keys, {skill})
            # Every key has at least one sentence because keys are created on append.
            neg_sentence = rng.choice(skill_to_sentences[neg_skill])
            yield {"anchor": anchor, "positive": positive, "negative": neg_sentence}


def build_triplets(esco_pairs, synonyms, clusters, cooccur, max_triplets=500000, seed=42):
    """Build shuffled (anchor, positive, negative) training triplets.

    Triplets come from three sources, consumed in this order until
    ``max_triplets`` is reached: synonym pairs, same-cluster skill pairs, and
    pairs of ESCO sentences describing the same skill.

    Args:
        esco_pairs: Iterable of ``{"skill": ..., "sentence": ...}`` dicts.
        synonyms: Mapping of preferred label to a list of alternative labels.
        clusters: Mapping of cluster id to a list of skills.
        cooccur: Mapping whose keys are added to the negative-sampling vocabulary.
        max_triplets: Upper bound on the number of triplets returned.
        seed: Seed for the private random generator.

    Returns:
        A shuffled list of ``{"anchor", "positive", "negative"}`` dicts, or an
        empty list when the combined vocabulary has fewer than 10 skills.
    """
    rng = random.Random(seed)

    vocabulary = set(synonyms)
    vocabulary.update(alt for alts in synonyms.values() for alt in alts)
    vocabulary.update(skill for skills in clusters.values() for skill in skills)
    vocabulary.update(cooccur)
    # Sort before sampling: set iteration order for strings changes between
    # interpreter runs, which would defeat the fixed seed.
    all_skills = sorted(vocabulary, key=str)
    if len(all_skills) < 10:
        logger.warning("Insufficient skill vocabulary (%d)", len(all_skills))
        return []

    sources = (
        _synonym_triplets(rng, synonyms, all_skills),
        _cluster_triplets(rng, clusters, all_skills),
        _sentence_triplets(rng, esco_pairs),
    )
    triplets = []
    for source in sources:
        if len(triplets) >= max_triplets:
            break
        for triplet in source:
            triplets.append(triplet)
            if len(triplets) >= max_triplets:
                break

    rng.shuffle(triplets)
    logger.info("Built %d triplets.", len(triplets))
    return triplets


def create_hard_negatives(triplets, cooccur, fraction=0.3, seed=42):
    """Replace some random negatives with tags that co-occur with the anchor.

    A co-occurring tag is related to the anchor without being a synonym, which
    makes it a harder negative than a uniformly sampled skill.

    Args:
        triplets: List of ``{"anchor", "positive", "negative"}`` dicts; it is
            modified in place.
        cooccur: Mapping of tag to the set of tags it co-occurs with.
        fraction: Share of triplets to attempt to rewrite; values outside
            ``[0, 1]`` are clamped.
        seed: Seed for the private random generator.

    Returns:
        The same ``triplets`` list that was passed in.
    """
    rng = random.Random(seed)
    if not cooccur or not triplets:
        return triplets

    # Clamp so a negative or >1 fraction cannot make rng.sample raise.
    n_replace = max(0, min(int(len(triplets) * fraction), len(triplets)))
    replaced = 0
    for idx in rng.sample(range(len(triplets)), n_replace):
        triplet = triplets[idx]
        anchor, positive = triplet["anchor"], triplet["positive"]
        candidates = set(cooccur.get(anchor, ())) - {positive, anchor}
        if not candidates:
            # Retry with a lower-cased lookup when the exact-case anchor has no match.
            candidates = set(cooccur.get(anchor.lower(), ())) - {positive.lower(), anchor.lower()}
        if candidates:
            # Sort first: set order for strings varies between interpreter runs,
            # so choosing from list(set) would not be reproducible for a given seed.
            triplet["negative"] = rng.choice(sorted(candidates))
            replaced += 1
    logger.info("Replaced %d / %d negatives with hard negatives.", replaced, len(triplets))
    return triplets


# Build the triplets, then swap a share of the random negatives for co-occurrence-based ones.
triplets = build_triplets(esco_pairs, synonyms, clusters, cooccur, max_triplets=500000)
triplets = create_hard_negatives(triplets, cooccur)
print(f"Total triplets: {len(triplets)}")
# build_triplets returns [] when the skill vocabulary is too small; indexing
# an empty list would raise IndexError and hide the real problem.
if triplets:
    print(f"Sample: {triplets[0]}")
else:
    print("Sample: <none> -- no triplets were built; check that the data sources loaded.")

In [None]:
# Cell 6: Train the model
from sentence_transformers import SentenceTransformer, InputExample, losses, models
from sentence_transformers.evaluation import TripletEvaluator
from torch.utils.data import DataLoader

# Config
BASE_MODEL = "TechWolf/JobBERT-v2"
PROJECTION_DIM = 256
EPOCHS = 5
BATCH_SIZE = 64
WARMUP_RATIO = 0.1
FP16 = True
OUTPUT_DIR = "/content/m3_skills_comparator"
SPLIT_SEED = 42  # fixes the train/eval split so runs are comparable

# Fail early with a clear message instead of deep inside the DataLoader/evaluator.
if len(triplets) < 2:
    raise ValueError(f"Need at least 2 triplets for a train/eval split, got {len(triplets)}.")

# Load model + add projection head
model = SentenceTransformer(BASE_MODEL)
# NOTE(review): in_features uses the dimension the base model reports; confirm it
# matches what the model's last module actually emits for plain-string inputs.
projection = models.Dense(
    in_features=model.get_sentence_embedding_dimension(),
    out_features=PROJECTION_DIM,
    activation_function=None,
)
model.add_module("projection", projection)
print(f"Model loaded. Embedding dim: {model.get_sentence_embedding_dimension()}")

# Split triplets (90% train / 10% eval).
# Shuffle a copy with a seeded generator: the split is reproducible and the
# shared `triplets` list is left untouched.
shuffled_triplets = list(triplets)
random.Random(SPLIT_SEED).shuffle(shuffled_triplets)
split = int(len(shuffled_triplets) * 0.9)
train_triplets = shuffled_triplets[:split]
eval_triplets = shuffled_triplets[split:]

# DataLoader
train_examples = [
    InputExample(texts=[t["anchor"], t["positive"], t["negative"]])
    for t in train_triplets
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)

# Loss & evaluator
train_loss = losses.MultipleNegativesRankingLoss(model)
evaluator = TripletEvaluator(
    anchors=[t["anchor"] for t in eval_triplets],
    positives=[t["positive"] for t in eval_triplets],
    negatives=[t["negative"] for t in eval_triplets],
    name="skill_triplet_eval",
)

# Warm up over the first WARMUP_RATIO share of all optimisation steps.
warmup_steps = int(len(train_dataloader) * EPOCHS * WARMUP_RATIO)

print(f"Training: {len(train_triplets)} train, {len(eval_triplets)} eval")
print(f"Batches/epoch: {len(train_dataloader)}, warmup: {warmup_steps} steps")
print("Starting training...")

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    epochs=EPOCHS,
    warmup_steps=warmup_steps,
    output_path=OUTPUT_DIR,
    # Evaluate twice per epoch.
    evaluation_steps=len(train_dataloader) // 2,
    save_best_model=True,
    use_amp=FP16,
)
print("Training complete!")

In [None]:
# Cell 7: Export projection weights & evaluate
# Collect the projection head's parameters keyed by their last name component
# (e.g. "weight", "bias").
projection_weights = {}
for name, param in model.named_parameters():
    if "projection" in name:
        projection_weights[name.split(".")[-1]] = param.detach().cpu().numpy()

output_dir = Path(OUTPUT_DIR)
# Make sure the directory exists even if training did not write a checkpoint there.
output_dir.mkdir(parents=True, exist_ok=True)
projection_path = output_dir / "projection.npy"
# NOTE: np.save stores a dict as a pickled 0-d object array; read it back with
# np.load(path, allow_pickle=True).item().
np.save(str(projection_path), projection_weights)
print(f"Projection weights saved to {projection_path}")

# Final evaluation
eval_result = evaluator(model, output_path=OUTPUT_DIR)
# Older sentence-transformers releases return a float here, newer ones a dict of
# named metrics; normalise both to a single accuracy value before comparing.
if isinstance(eval_result, dict):
    metric_key = getattr(evaluator, "primary_metric", None)
    if metric_key in eval_result:
        eval_score = eval_result[metric_key]
    else:
        eval_score = max(eval_result.values())
else:
    eval_score = eval_result
eval_score = float(eval_score)

TARGET = 0.80
if eval_score >= TARGET:
    print(f"Target accuracy {TARGET:.2f} ACHIEVED (got {eval_score:.4f})")
else:
    print(f"Target accuracy {TARGET:.2f} NOT MET (got {eval_score:.4f})")

In [None]:
# Cell 8: Download trained model
import shutil

# Bundle everything under OUTPUT_DIR into one zip; make_archive appends the ".zip" suffix.
shutil.make_archive(
    base_name="/content/m3_skills_comparator_trained",
    format="zip",
    root_dir=OUTPUT_DIR,
)
print(f"Model zipped to /content/m3_skills_comparator_trained.zip")

# Send the archive to the browser as a download.
from google.colab import files
files.download("/content/m3_skills_comparator_trained.zip")